From 5a99c73b3ea5a284514d2383514dfad08247db0d Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Wed, 6 May 2026 22:04:57 +0000 Subject: [PATCH 01/12] Improve state persistence and slim CI smoke test --- README.md | 2 +- ci/docker-compose.yml | 4 +- ci/test.go | 441 +++-------------------------------- ci_behavior_test.go | 201 ++++++++++++++++ internal/state/state.go | 111 +++++++-- internal/state/state_test.go | 74 ++++++ main.go | 130 +++++++++-- main_test.go | 36 ++- 8 files changed, 540 insertions(+), 459 deletions(-) create mode 100644 ci_behavior_test.go diff --git a/README.md b/README.md index 41e10ec..f69c68c 100644 --- a/README.md +++ b/README.md @@ -137,7 +137,7 @@ services: | `enableStatsPage` | `string` | `"false"` | Allows `exemptIps` to access `/captcha-protect/stats` to monitor the rate limiter. | | `logLevel` | `string` | `"INFO"` | Log level for the middleware. Options: `ERROR`, `WARNING`, `INFO`, or `DEBUG`. | | `persistentStateFile` | `string` | `""` | File path to persist rate limiter state across Traefik restarts. In Docker, mount this file from the host. | -| `enableStateReconciliation` | `string` | `"false"` | When `"true"`, reads and merges disk state before each save to prevent multiple instances from overwriting data. Adds extra I/O overhead. Only enable for multi-instance deployments sharing state. **Performance warning**: Not recommended for sites with >1M unique visitors due to reconciliation overhead (5-8s per cycle at scale). | +| `enableStateReconciliation` | `string` | `"false"` | When `"true"`, polls the shared state file for changes and merges newer disk state into memory, then reconciles again before dirty snapshots are saved. Enable for multi-instance deployments sharing state. | ### Circuit Breaker (failover if a captcha provider is unavailable) diff --git a/ci/docker-compose.yml b/ci/docker-compose.yml index 92a9623..3559bb4 100644 --- a/ci/docker-compose.yml +++ b/ci/docker-compose.yml @@ -20,7 +20,7 @@ services: traefik.http.middlewares.captcha-protect.plugin.captcha-protect.logLevel: "DEBUG" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectParameters: "${PROTECT_PARAMETERS:-false}" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.goodBots: "" - traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableGooglebotIPCheck: "true" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableGooglebotIPCheck: "false" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.mode: "regex" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectRoutes: "^/" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.excludeRoutes: "\\/oai\\/request,\\/node\\/\\d+\\/(book-)?manifest" @@ -54,7 +54,7 @@ services: traefik.http.middlewares.captcha-protect.plugin.captcha-protect.logLevel: "DEBUG" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectParameters: "${PROTECT_PARAMETERS:-false}" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.goodBots: "" - traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableGooglebotIPCheck: "true" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableGooglebotIPCheck: "false" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.mode: "regex" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectRoutes: "^/" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.excludeRoutes: "\\/oai\\/request,\\/node\\/\\d+\\/(book-)?manifest" diff --git a/ci/test.go b/ci/test.go index f1aed01..858bb66 100755 --- a/ci/test.go +++ b/ci/test.go @@ -1,271 +1,99 @@ package main import ( - "encoding/json" "fmt" - "io" "log/slog" - "math/rand" - "net" "net/http" "os" "os/exec" "strings" - "sync" "time" - - cp "github.com/libops/captcha-protect" - "github.com/libops/captcha-protect/internal/helper" -) - -var ( - rateLimit = 5 - exemptIps []*net.IPNet ) -const numIPs = 100 -const parallelism = 10 +const rateLimit = 5 func main() { - log := slog.New(slog.NewTextHandler(os.Stdout, nil)) - googleCIDRs, err := helper.FetchGoogleCrawlerIPs(log, http.DefaultClient, helper.GoogleCrawlerIPRangeURLs) - if err != nil { - slog.Error("unable to fetch google crawler ips", "err", err) - os.Exit(1) - } - - _ips := []string{ - "127.0.0.0/8", - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "fc00::/8", - } - _ips = append(_ips, googleCIDRs...) - for _, ip := range _ips { - parsedIp := parseCIDR(ip) - exemptIps = append(exemptIps, parsedIp) - } - - fmt.Printf("Checking rate limit %d\n", rateLimit) - - fmt.Printf("Generating %d IPs\n", numIPs) - ips := generateUniquePublicIPs(numIPs) + _ = os.Remove("./tmp/state.json") fmt.Println("Bringing traefik/nginx online") runCommand("docker", "compose", "up", "-d") waitForService("http://localhost") waitForService("http://localhost/app2") - waitForGoogleExemptionReady(googleCIDRs) - fmt.Printf("Making sure %d attempt(s) pass\n", rateLimit) - runParallelChecks(ips, rateLimit, "http://localhost") - statePath := "./tmp/state.json" - runCommand("jq", ".", statePath) - - fmt.Printf("Making sure attempt #%d causes a redirect to the challenge page\n", rateLimit+1) - ensureRedirect(ips, "http://localhost") - testExcludeRouteRegexBypass(ips) - - fmt.Println("\nTesting state sharing between nginx instances...") - time.Sleep(cp.StateSaveInterval + cp.StateSaveJitter + (5 * time.Second)) - - testStateSharing(ips) - testGoogleBotGetsThrough(googleCIDRs) - - runCommand("docker", "container", "stats", "--no-stream") - - // now restart the containers and make sure the previous state reloaded - runCommand("docker", "compose", "down") - runCommand("docker", "compose", "up", "-d") - waitForService("http://localhost") - time.Sleep(10 * time.Second) - checkStateReload() - - runCommand("rm", "-f", statePath) + fmt.Println("Testing Traefik plugin smoke path...") + assertProtectedRoute("107.198.130.166", "http://localhost", "http://localhost/challenge?destination=%2F") + assertNoRedirect("107.198.130.166", "http://localhost/node/123/manifest") + assertNoRedirect("107.198.130.166", "http://localhost/oai/request?foo=bar") + assertProtectedRoute("108.198.130.167", "http://localhost/app2", "http://localhost/challenge?destination=%2Fapp2") -} - -func generateUniquePublicIPs(n int) []string { - ipSet := make(map[string]struct{}) - var ips []string - config := cp.CreateConfig() - bc := &cp.CaptchaProtect{} - bc.SetExemptIps(exemptIps) - err := bc.SetIpv4Mask(16) - if err != nil { - slog.Error("unable to set ipv4 mask") - os.Exit(1) - } - - err = bc.SetIpv6Mask(64) - if err != nil { - slog.Error("unable to set ipv6 mask") - os.Exit(1) - } - - for len(ips) < n { - ip := randomPublicIP(config) - ip, ipRange := bc.ParseIp(ip) - if _, exists := ipSet[ipRange]; !exists { - ipSet[ipRange] = struct{}{} - ips = append(ips, ip) - } - } - - return ips -} - -func randomPublicIP(config *cp.Config) string { - for { - ip := fmt.Sprintf("%d.%d.%d.%d", - rand.Intn(255)+1, - rand.Intn(255), - rand.Intn(255), - rand.Intn(254)+1, - ) - - if !helper.IsIpExcluded(ip, exemptIps) && !helper.IsIpGoodBot(ip, config.GoodBots) { - return ip - } - } + _ = os.Remove("./tmp/state.json") + fmt.Println("✓ Traefik plugin smoke test passed") } func waitForService(url string) { - for { + deadline := time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { resp, err := http.Get(url) if err == nil && resp.StatusCode < 500 { resp.Body.Close() - time.Sleep(5 * time.Second) // Give it time to stabilize return } + if resp != nil { + resp.Body.Close() + } fmt.Println("waiting for traefik/nginx to come online...") time.Sleep(1 * time.Second) } -} - -func runParallelChecks(ips []string, rateLimit int, url string) { - var wg sync.WaitGroup - sem := make(chan struct{}, parallelism) - - for range rateLimit { - for _, ip := range ips { - wg.Add(1) - sem <- struct{}{} - go func(ip string) { - defer wg.Done() - defer func() { <-sem }() - - fmt.Printf("Checking %s\n", ip) - output := httpRequest(ip, url) - if output != "" { - slog.Error("Unexpected output", "ip", ip, "output", output) - os.Exit(1) - - } - }(ip) - } - } - - wg.Wait() -} - -func ensureRedirect(ips []string, url string) { - expectedURL := url + "/challenge?destination=%2F" - if url != "http://localhost" { - // For /app2, the destination should be the app2 path - expectedURL = "http://localhost/challenge?destination=%2Fapp2" - } - - for _, ip := range ips { - fmt.Printf("Checking %s\n", ip) - output := httpRequest(ip, url) - if output != expectedURL { - slog.Error("Unexpected output", "ip", ip, "output", output, "expected", expectedURL) - os.Exit(1) - } - - fmt.Printf("Got a redirect! %s\n", output) - } + slog.Error("Timed out waiting for service", "url", url) + os.Exit(1) } -func testExcludeRouteRegexBypass(ips []string) { - fmt.Println("\nTesting regex excludeRoutes bypass...") - - testIP := ips[0] - tests := []struct { - url string - name string - }{ - { - url: "http://localhost/node/123/manifest", - name: "/node/123/manifest", - }, - { - url: "http://localhost/oai/request?foo=bar", - name: "/oai/request?foo=bar", - }, +func assertProtectedRoute(ip, url, expectedURL string) { + for i := 0; i < rateLimit; i++ { + assertNoRedirect(ip, url) } - for _, tt := range tests { - fmt.Printf("Checking excluded route %s with IP %s\n", tt.name, testIP) - output := httpRequest(testIP, tt.url) - if output != "" { - slog.Error("Excluded route was unexpectedly challenged", "ip", testIP, "route", tt.name, "output", output) - os.Exit(1) - } + output := httpRequest(ip, url) + if output != expectedURL { + slog.Error("Expected protected route to redirect", "ip", ip, "url", url, "output", output, "expected", expectedURL) + os.Exit(1) } - - fmt.Println("✓ regex excludeRoutes bypass works for excluded paths") } -func testStateSharing(ips []string) { - // Use first IP to test state sharing - testIP := ips[0] - - fmt.Printf("Testing with IP: %s\n", testIP) - - // The IP should already be at rate limit from previous tests on localhost/ - // Now verify it's also rate limited on localhost/app2 (shared state) - fmt.Println("Verifying IP is rate limited on /app2 (state should be shared)...") - output := httpRequest(testIP, "http://localhost/app2") - expectedURL := "http://localhost/challenge?destination=%2Fapp2" - - if output != expectedURL { - slog.Error("State NOT shared between instances!", "ip", testIP, "output", output, "expected", expectedURL) +func assertNoRedirect(ip, url string) { + output := httpRequest(ip, url) + if output != "" { + slog.Error("Unexpected redirect", "ip", ip, "url", url, "output", output) os.Exit(1) } - - fmt.Println("✓ State is correctly shared between nginx instances!") } func httpRequest(ip, url string) string { client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Capture the redirect URL and stop following it if len(via) > 0 { return http.ErrUseLastResponse } return nil }, + Timeout: 10 * time.Second, } - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { slog.Error("Failed to create request", "err", err) os.Exit(1) } req.Header.Set("X-Forwarded-For", ip) + resp, err := client.Do(req) if err != nil { slog.Error("Request failed", "err", err) os.Exit(1) - } defer resp.Body.Close() - // Get redirect URL from response location, err := resp.Location() if err != nil { if err == http.ErrNoLocation { @@ -273,224 +101,23 @@ func httpRequest(ip, url string) string { } slog.Error("Failed to get redirect URL", "err", err) os.Exit(1) - } return strings.TrimSpace(location.String()) } -// runCommand runs a shell command. func runCommand(name string, args ...string) { - runCommandWithEnv(nil, name, args...) -} - -func runCommandWithEnv(env map[string]string, name string, args ...string) { cmd := exec.Command(name, args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - cmd.Env = append(cmd.Env, fmt.Sprintf("RATE_LIMIT=%d", rateLimit)) - cmd.Env = append(cmd.Env, fmt.Sprintf("PATH=%s", os.Getenv("PATH"))) + cmd.Env = append(os.Environ(), fmt.Sprintf("RATE_LIMIT=%d", rateLimit)) - tt := os.Getenv("TRAEFIK_TAG") - if tt != "" { - cmd.Env = append(cmd.Env, fmt.Sprintf("TRAEFIK_TAG=%s", tt)) - } - for k, v := range env { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) + if traefikTag := os.Getenv("TRAEFIK_TAG"); traefikTag != "" { + cmd.Env = append(cmd.Env, fmt.Sprintf("TRAEFIK_TAG=%s", traefikTag)) } + if err := cmd.Run(); err != nil { slog.Error("Command failed", "err", err) os.Exit(1) } } - -func checkStateReload() { - resp, err := http.Get("http://localhost/captcha-protect/stats") - if err != nil { - slog.Error("Failed to make GET request", "err", err) - os.Exit(1) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - slog.Error("Failed to read response body", "err", err) - os.Exit(1) - - } - var jsonResponse map[string]interface{} - err = json.Unmarshal(body, &jsonResponse) - if err != nil { - slog.Error("Failed to unmarshal JSON", "err", err) - os.Exit(1) - - } - bots, exists := jsonResponse["bots"] - if !exists { - slog.Error("Key 'bots' not found in JSON response") - os.Exit(1) - } - botsMap, ok := bots.(map[string]interface{}) - if !ok { - slog.Error("'bots' is not an array") - os.Exit(1) - } - - if len(botsMap) < numIPs { - slog.Error("Unexpected number of bots", "expected", numIPs, "received", len(botsMap)) - os.Exit(1) - } - - slog.Info("State reloaded successfully!") -} - -func parseCIDR(cidr string) *net.IPNet { - _, block, err := net.ParseCIDR(cidr) - if err != nil { - slog.Error("Failed to parse CIDR", "cidr", cidr, "err", err) - } - return block -} - -func getIPFromCIDR(cidr string) (string, error) { - ip, ipnet, err := net.ParseCIDR(cidr) - if err != nil { - return "", err - } - - // For IPv4, increment the IP to get a usable host address - if ip.To4() != nil { - // Clone the IP to avoid modifying the original - newIP := make(net.IP, len(ip)) - copy(newIP, ip) - - for i := len(newIP) - 1; i >= 0; i-- { - newIP[i]++ - if newIP[i] > 0 { - break - } - } - - // If the new IP is the broadcast address, we can't use it. - // This is a simplistic check, and might not cover all cases for small subnets. - // A more robust solution might be needed for very small CIDR ranges. - if !ipnet.Contains(newIP) { - // This can happen for /31 or /32. For now, we just return the network address. - return ip.String(), nil - } - // make sure we don't have a broadcast address - last_ip := make(net.IP, len(ipnet.IP)) - copy(last_ip, ipnet.IP) - for i := 0; i < len(ipnet.Mask); i++ { - last_ip[i] |= ^ipnet.Mask[i] - } - if newIP.Equal(last_ip) { - return ip.String(), nil - } - - return newIP.String(), nil - } - - // For IPv6, we can usually just use the network address. - return ip.String(), nil -} - -func testGoogleBotGetsThrough(googleCIDRs []string) { - fmt.Println("\nTesting GoogleBot exemption...") - - if len(googleCIDRs) == 0 { - slog.Warn("No Google CIDRs found, skipping test") - return - } - - // Pick a Google IP - googleIP, err := getIPFromCIDR(googleCIDRs[len(googleCIDRs)-1]) - if err != nil { - slog.Error("Failed to get an IP from google CIDR", "err", err) - os.Exit(1) - } - - var output string // Declare output once here - - fmt.Printf("Checking GoogleBot IP %s without params - should always pass (making %d requests)\n", googleIP, rateLimit+1) - for i := 0; i < rateLimit+1; i++ { - output = httpRequest(googleIP, "http://localhost") // Assign value to the already declared 'output' - if output != "" { - slog.Error(fmt.Sprintf("GoogleBot with no params was challenged on request #%d", i+1), "ip", googleIP, "output", output) - os.Exit(1) - } - } - fmt.Printf("✓ GoogleBot with no params passed %d requests successfully\n", rateLimit+1) - - // now restart with PROTECT_PARAMETERS=true and test again with params - fmt.Println("\nRestarting traefik with PROTECT_PARAMETERS=true") - runCommand("docker", "compose", "down") - runCommandWithEnv(map[string]string{"PROTECT_PARAMETERS": "true"}, "docker", "compose", "up", "-d") - waitForService("http://localhost") - waitForService("http://localhost/app2") - - // Prime the rate limiter for the GoogleBot IP with parameters - fmt.Printf("Priming rate limiter for GoogleBot IP %s with params (%d requests)\n", googleIP, rateLimit) - for i := range rateLimit { - output = httpRequest(googleIP, "http://localhost/?foo=bar") // Assign value - if output != "" { - slog.Error(fmt.Sprintf("GoogleBot with params was challenged prematurely on request #%d", i+1), "ip", googleIP, "output", output) - os.Exit(1) - } - } - fmt.Printf("✓ Rate limiter primed for GoogleBot IP %s\n", googleIP) - - fmt.Printf("Checking GoogleBot IP %s with params (request #%d) - should be challenged\n", googleIP, rateLimit+1) - output = httpRequest(googleIP, "http://localhost/?foo=bar") // Assign value - expectedURL := "http://localhost/challenge?destination=%2F%3Ffoo%3Dbar" - if output != expectedURL { - slog.Error("GoogleBot with params was not challenged", "ip", googleIP, "output", output, "expected", expectedURL) - os.Exit(1) - } - fmt.Println("✓ GoogleBot with params was challenged") - - // set things back to normal for other tests - runCommand("docker", "compose", "down") -} - -func waitForGoogleExemptionReady(googleCIDRs []string) { - googleIP, err := firstUsableIPv4FromCIDRs(googleCIDRs) - if err != nil { - slog.Warn("Unable to select Google IP for readiness check; skipping warmup", "err", err) - return - } - - deadline := time.Now().Add(90 * time.Second) - for time.Now().Before(deadline) { - ready := true - for i := 0; i < rateLimit+1; i++ { - if output := httpRequest(googleIP, "http://localhost"); output != "" { - ready = false - break - } - } - if ready { - fmt.Printf("Google exemption is active for %s\n", googleIP) - return - } - time.Sleep(500 * time.Millisecond) - } - - slog.Error("Timed out waiting for Google crawler IP exemption to become active", "googleIP", googleIP) - os.Exit(1) -} - -func firstUsableIPv4FromCIDRs(cidrs []string) (string, error) { - for _, cidr := range cidrs { - ip, err := getIPFromCIDR(cidr) - if err != nil { - continue - } - parsed := net.ParseIP(ip) - if parsed != nil && parsed.To4() != nil { - return ip, nil - } - } - - return "", fmt.Errorf("no usable IPv4 found in CIDR list") -} diff --git a/ci_behavior_test.go b/ci_behavior_test.go new file mode 100644 index 0000000..6a6f689 --- /dev/null +++ b/ci_behavior_test.go @@ -0,0 +1,201 @@ +package captcha_protect + +import ( + "context" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + "testing/synctest" + "time" + + "github.com/libops/captcha-protect/internal/helper" + lru "github.com/patrickmn/go-cache" +) + +const ciRateLimit = uint(5) + +func TestCILabelEquivalentMiddlewareBehavior(t *testing.T) { + bc := newCILabelEquivalentMiddleware(t, nil) + ip := "107.198.130.166" + + for i := uint(0); i < ciRateLimit; i++ { + assertNoRedirect(t, bc, ip, "/") + } + assertRedirect(t, bc, ip, "/", "/challenge?destination=%2F") + + for _, route := range []string{ + "/node/123/manifest", + "/node/123/book-manifest", + "/oai/request?foo=bar", + } { + assertNoRedirect(t, bc, ip, route) + } + + appIP := "108.198.130.167" + for i := uint(0); i < ciRateLimit; i++ { + assertNoRedirect(t, bc, appIP, "/app2") + } + assertRedirect(t, bc, appIP, "/app2", "/challenge?destination=%2Fapp2") +} + +func TestCILabelEquivalentGooglebotParameterBehavior(t *testing.T) { + googleIP := "203.0.113.10" + + bypass := newCILabelEquivalentMiddleware(t, nil) + bypass.googlebotIPs = helper.NewGooglebotIPs() + bypass.googlebotIPs.Update([]string{"203.0.113.0/24"}, discardLogger()) + bypass.config.EnableGooglebotIPCheck = "true" + + for i := uint(0); i < ciRateLimit+1; i++ { + assertNoRedirect(t, bypass, googleIP, "/") + } + + protectedParams := newCILabelEquivalentMiddleware(t, func(config *Config) { + config.ProtectParameters = "true" + }) + protectedParams.googlebotIPs = helper.NewGooglebotIPs() + protectedParams.googlebotIPs.Update([]string{"203.0.113.0/24"}, discardLogger()) + protectedParams.config.EnableGooglebotIPCheck = "true" + + for i := uint(0); i < ciRateLimit; i++ { + assertNoRedirect(t, protectedParams, googleIP, "/?foo=bar") + } + assertRedirect(t, protectedParams, googleIP, "/?foo=bar", "/challenge?destination=%2F%3Ffoo%3Dbar") +} + +func TestPersistentStateSharingWithSynctest(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + stateFile := filepath.Join(t.TempDir(), "state.json") + writer := newStateOnlyCaptchaProtect(stateFile, 2) + reader := newStateOnlyCaptchaProtect(stateFile, 2) + + ctx, cancel := context.WithCancel(t.Context()) + done := make(chan struct{}, 2) + go func() { + writer.saveState(ctx) + done <- struct{}{} + }() + go func() { + reader.saveState(ctx) + done <- struct{}{} + }() + + for i := uint(0); i < writer.config.RateLimit+1; i++ { + writer.registerRequest("107.198.0.0") + } + + time.Sleep(StateSaveInterval + StateSaveJitter + 3*time.Second) + synctest.Wait() + + v, ok := reader.rateCache.Get("107.198.0.0") + if !ok { + t.Fatal("expected reader instance to reconcile writer state") + } + if got, want := v.(uint), writer.config.RateLimit+1; got != want { + t.Fatalf("reconciled rate = %d, want %d", got, want) + } + + cancel() + synctest.Wait() + <-done + <-done + }) +} + +func newCILabelEquivalentMiddleware(t *testing.T, mutate func(*Config)) *CaptchaProtect { + t.Helper() + + config := ciLabelEquivalentConfig() + if mutate != nil { + mutate(config) + } + + next := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusNoContent) + }) + bc, err := NewCaptchaProtect(t.Context(), next, config, "captcha-protect") + if err != nil { + t.Fatalf("NewCaptchaProtect failed: %v", err) + } + + return bc +} + +func ciLabelEquivalentConfig() *Config { + config := CreateConfig() + config.RateLimit = ciRateLimit + config.Window = 120 + config.CaptchaProvider = "poj" + config.SiteKey = "test-site-key" + config.SecretKey = "test-secret-key" + config.EnableStatsPage = "true" + config.IPForwardedHeader = "X-Forwarded-For" + config.LogLevel = "ERROR" + config.ProtectParameters = "false" + config.GoodBots = []string{} + config.EnableGooglebotIPCheck = "false" + config.Mode = "regex" + config.ProtectRoutes = []string{"^/"} + config.ExcludeRoutes = []string{ + `\/oai\/request`, + `\/node\/\d+\/(book-)?manifest`, + } + return config +} + +func newStateOnlyCaptchaProtect(stateFile string, rateLimit uint) *CaptchaProtect { + config := ciLabelEquivalentConfig() + config.PersistentStateFile = stateFile + config.EnableStateReconciliation = "true" + config.RateLimit = rateLimit + + return &CaptchaProtect{ + config: config, + log: discardLogger(), + rateCache: lru.New(time.Hour, lru.NoExpiration), + botCache: lru.New(time.Hour, lru.NoExpiration), + verifiedCache: lru.New(time.Hour, lru.NoExpiration), + } +} + +func assertNoRedirect(t *testing.T, handler http.Handler, ip, target string) { + t.Helper() + + status, location := serveCITestRequest(handler, ip, target) + if location != "" { + t.Fatalf("%s returned redirect %q", target, location) + } + if status != http.StatusNoContent { + t.Fatalf("%s status = %d, want %d", target, status, http.StatusNoContent) + } +} + +func assertRedirect(t *testing.T, handler http.Handler, ip, target, expectedLocation string) { + t.Helper() + + status, location := serveCITestRequest(handler, ip, target) + if status != http.StatusFound { + t.Fatalf("%s status = %d, want %d", target, status, http.StatusFound) + } + if location != expectedLocation { + t.Fatalf("%s redirect = %q, want %q", target, location, expectedLocation) + } +} + +func serveCITestRequest(handler http.Handler, ip, target string) (int, string) { + req := httptest.NewRequest(http.MethodGet, target, nil) + req.Host = "localhost" + req.Header.Set("X-Forwarded-For", ip) + rw := httptest.NewRecorder() + + handler.ServeHTTP(rw, req) + + return rw.Code, rw.Header().Get("Location") +} + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} diff --git a/internal/state/state.go b/internal/state/state.go index 291e5fa..d3d55a4 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "path/filepath" "reflect" "time" @@ -24,6 +25,18 @@ type State struct { Memory map[string]uintptr `json:"memory"` } +type SaveMetrics struct { + LockMs int64 + ReadMs int64 + ReconcileMs int64 + MarshalMs int64 + WriteMs int64 + TotalMs int64 + RateEntries int + BotEntries int + VerifiedEntries int +} + func GetState(rateCache, botCache, verifiedCache map[string]lru.Item) State { state := State{ Memory: make(map[string]uintptr, 3), @@ -67,26 +80,35 @@ func SaveStateToFile( rateCache, botCache, verifiedCache *lru.Cache, log *slog.Logger, ) (lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs int64, err error) { + metrics, err := SaveStateToFileWithMetrics(filePath, reconcile, rateCache, botCache, verifiedCache, log) + return metrics.LockMs, metrics.ReadMs, metrics.ReconcileMs, metrics.MarshalMs, metrics.WriteMs, metrics.TotalMs, err +} + +func SaveStateToFileWithMetrics( + filePath string, + reconcile bool, + rateCache, botCache, verifiedCache *lru.Cache, + log *slog.Logger, +) (SaveMetrics, error) { startTime := time.Now() + metrics := SaveMetrics{} lock, err := NewFileLock(filePath + ".lock") if err != nil { - return 0, 0, 0, 0, 0, 0, fmt.Errorf("failed to create lock: %w", err) + return metrics, fmt.Errorf("failed to create lock: %w", err) } defer lock.Close() if err := lock.Lock(); err != nil { - return 0, 0, 0, 0, 0, 0, fmt.Errorf("failed to acquire lock: %w", err) + return metrics, fmt.Errorf("failed to acquire lock: %w", err) } - lockDuration := time.Since(startTime) - - var readDuration, reconcileDuration, marshalDuration, writeDuration time.Duration + metrics.LockMs = time.Since(startTime).Milliseconds() // Reconcile with existing file state if enabled if reconcile { readStart := time.Now() fileContent, readErr := os.ReadFile(filePath) - readDuration = time.Since(readStart) + metrics.ReadMs = time.Since(readStart).Milliseconds() if readErr == nil && len(fileContent) > 0 { reconcileStart := time.Now() @@ -95,7 +117,7 @@ func SaveStateToFile( log.Debug("Reconciling state before save", "fileBytes", len(fileContent)) ReconcileState(fileState, rateCache, botCache, verifiedCache) } - reconcileDuration = time.Since(reconcileStart) + metrics.ReconcileMs = time.Since(reconcileStart).Milliseconds() } } @@ -103,29 +125,50 @@ func SaveStateToFile( marshalStart := time.Now() currentState := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) jsonData, err := json.Marshal(currentState) - marshalDuration = time.Since(marshalStart) + metrics.MarshalMs = time.Since(marshalStart).Milliseconds() + metrics.RateEntries = len(currentState.Rate) + metrics.BotEntries = len(currentState.Bots) + metrics.VerifiedEntries = len(currentState.Verified) if err != nil { - return lockDuration.Milliseconds(), readDuration.Milliseconds(), - reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), - 0, 0, err + return metrics, err } // Write to disk writeStart := time.Now() - err = os.WriteFile(filePath, jsonData, 0644) - writeDuration = time.Since(writeStart) + err = atomicWriteFile(filePath, jsonData, 0644) + metrics.WriteMs = time.Since(writeStart).Milliseconds() if err != nil { - return lockDuration.Milliseconds(), readDuration.Milliseconds(), - reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), - writeDuration.Milliseconds(), 0, err + return metrics, err } - totalDuration := time.Since(startTime) - return lockDuration.Milliseconds(), readDuration.Milliseconds(), - reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), - writeDuration.Milliseconds(), totalDuration.Milliseconds(), nil + metrics.TotalMs = time.Since(startTime).Milliseconds() + return metrics, nil +} + +func atomicWriteFile(filePath string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(filePath) + tmp, err := os.CreateTemp(dir, filepath.Base(filePath)+".tmp-*") + if err != nil { + return err + } + tmpName := tmp.Name() + defer os.Remove(tmpName) + + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Chmod(perm); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + + return os.Rename(tmpName, filePath) } // LoadStateFromFile loads state from a file with locking. @@ -160,6 +203,34 @@ func LoadStateFromFile( return nil } +func ReconcileStateFromFile( + filePath string, + rateCache, botCache, verifiedCache *lru.Cache, +) error { + lock, err := NewFileLock(filePath + ".lock") + if err != nil { + return err + } + defer lock.Close() + + if err := lock.Lock(); err != nil { + return err + } + + fileContent, err := os.ReadFile(filePath) + if err != nil || len(fileContent) == 0 { + return err + } + + var fileState State + if err := json.Unmarshal(fileContent, &fileState); err != nil { + return err + } + + ReconcileState(fileState, rateCache, botCache, verifiedCache) + return nil +} + func calculateDuration(expiration int64, now int64) time.Duration { if expiration == 0 { return lru.NoExpiration diff --git a/internal/state/state_test.go b/internal/state/state_test.go index b1639af..1db9521 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "log/slog" "os" + "path/filepath" "testing" "testing/synctest" "time" @@ -381,6 +382,43 @@ func TestSaveStateToFile(t *testing.T) { } }) + t.Run("Save with metrics", func(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + botCache.Set("1.2.3.4", false, lru.DefaultExpiration) + verifiedCache.Set("5.6.7.8", true, lru.DefaultExpiration) + + metrics, err := SaveStateToFileWithMetrics( + tmpFile, + false, + rateCache, + botCache, + verifiedCache, + testLogger(), + ) + if err != nil { + t.Fatalf("SaveStateToFileWithMetrics failed: %v", err) + } + + if metrics.RateEntries != 1 { + t.Errorf("Expected 1 rate entry, got %d", metrics.RateEntries) + } + if metrics.BotEntries != 1 { + t.Errorf("Expected 1 bot entry, got %d", metrics.BotEntries) + } + if metrics.VerifiedEntries != 1 { + t.Errorf("Expected 1 verified entry, got %d", metrics.VerifiedEntries) + } + if matches, err := filepath.Glob(tmpFile + ".tmp-*"); err != nil || len(matches) != 0 { + t.Fatalf("Expected no leftover temp files, matches=%v err=%v", matches, err) + } + }) + t.Run("File write error", func(t *testing.T) { // Use invalid path to trigger error invalidPath := "/invalid/directory/that/does/not/exist/state.json" @@ -560,6 +598,42 @@ func TestLoadStateFromFile(t *testing.T) { }) } +func TestReconcileStateFromFile(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + now := time.Now().UnixNano() + futureExpiration := now + int64(1*time.Hour) + + fileState := State{ + Rate: map[string]CacheEntry{ + "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, + }, + Bots: map[string]CacheEntry{}, + Verified: map[string]CacheEntry{}, + Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + } + data, _ := json.Marshal(fileState) + if err := os.WriteFile(tmpFile, data, 0644); err != nil { + t.Fatalf("Failed to write test state: %v", err) + } + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("10.0.0.0", uint(5), lru.DefaultExpiration) + + if err := ReconcileStateFromFile(tmpFile, rateCache, botCache, verifiedCache); err != nil { + t.Fatalf("ReconcileStateFromFile failed: %v", err) + } + + if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 10 { + t.Error("Expected file state to be reconciled into memory") + } + if v, ok := rateCache.Get("10.0.0.0"); !ok || v.(uint) != 5 { + t.Error("Expected existing memory state to be retained") + } +} + func testLogger() *slog.Logger { return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: slog.LevelError, // Only show errors during tests diff --git a/main.go b/main.go index e3191dc..264ecd8 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "github.com/libops/captcha-protect/internal/helper" @@ -98,6 +99,9 @@ type CaptchaProtect struct { ipv6Mask net.IPMask protectRoutesRegex []*regexp.Regexp excludeRoutesRegex []*regexp.Regexp + stateDirty atomic.Uint64 + stateSavedDirty atomic.Uint64 + stateFileModTime time.Time // Circuit breaker fields mu sync.RWMutex @@ -669,6 +673,7 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. if success { bc.verifiedCache.Set(ip, true, exp) + bc.markStateDirty() destination := normalizeDestination(req.FormValue("destination")) http.Redirect(rw, req, destination, http.StatusFound) @@ -880,13 +885,21 @@ func (bc *CaptchaProtect) trippedRateLimit(ip string) bool { func (bc *CaptchaProtect) registerRequest(ip string) { err := bc.rateCache.Add(ip, uint(1), lru.DefaultExpiration) if err == nil { + bc.markStateDirty() + return + } + + v, ok := bc.rateCache.Get(ip) + if ok && v.(uint) > bc.config.RateLimit { return } _, err = bc.rateCache.IncrementUint(ip, uint(1)) if err != nil { bc.log.Error("unable to set rate cache", "ip", ip, "err", err) + return } + bc.markStateDirty() } func (bc *CaptchaProtect) getClientIP(req *http.Request) (string, string) { @@ -992,6 +1005,7 @@ func (bc *CaptchaProtect) isGoodBot(req *http.Request, clientIP string) bool { v = helper.IsIpGoodBot(clientIP, bc.config.GoodBots) } bc.botCache.Set(clientIP, v, lru.DefaultExpiration) + bc.markStateDirty() return v } @@ -1019,7 +1033,7 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { bc.log.Debug("State save configured", "baseInterval", StateSaveInterval, "jitter", jitter, "actualInterval", interval) - ticker := time.NewTicker(interval) + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() file, err := os.OpenFile(bc.config.PersistentStateFile, os.O_CREATE|os.O_WRONLY, 0644) @@ -1029,26 +1043,46 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { } // we made sure the file is writable, we can continue in our loop file.Close() + bc.refreshStateFileModTime() + + lastSave := time.Time{} for { select { case <-ticker.C: - bc.log.Debug("Periodic state save triggered") - bc.saveStateNow() + if bc.config.EnableStateReconciliation == "true" { + bc.reconcileStateFromFileIfChanged() + } + if !bc.hasUnsavedState() { + continue + } + if !lastSave.IsZero() && time.Since(lastSave) < interval { + continue + } + bc.log.Debug("Dirty state save triggered", "dirtyChanges", bc.unsavedStateChanges()) + if bc.saveStateNow() { + lastSave = time.Now() + } case <-ctx.Done(): - bc.log.Debug("Context cancelled, running saveState before shutdown") - bc.saveStateNow() + if bc.config.EnableStateReconciliation == "true" { + bc.reconcileStateFromFileIfChanged() + } + if bc.hasUnsavedState() { + bc.log.Debug("Context cancelled, running saveState before shutdown") + bc.saveStateNow() + } return } } } // saveStateNow performs an immediate state save using the state package. -func (bc *CaptchaProtect) saveStateNow() { +func (bc *CaptchaProtect) saveStateNow() bool { reconcile := bc.config.EnableStateReconciliation == "true" + dirtyAtStart := bc.stateDirty.Load() - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := state.SaveStateToFile( + metrics, err := state.SaveStateToFileWithMetrics( bc.config.PersistentStateFile, reconcile, bc.rateCache, @@ -1059,22 +1093,23 @@ func (bc *CaptchaProtect) saveStateNow() { if err != nil { bc.log.Error("failed to save state", "err", err) - return + return false } + bc.stateSavedDirty.Store(dirtyAtStart) + bc.refreshStateFileModTime() - // Get current state for logging (already marshaled in SaveStateToFile, but we need counts) - currentState := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) bc.log.Debug("State saved successfully", - "rateEntries", len(currentState.Rate), - "botEntries", len(currentState.Bots), - "verifiedEntries", len(currentState.Verified), - "lockMs", lockMs, - "readMs", readMs, - "reconcileMs", reconcileMs, - "marshalMs", marshalMs, - "writeMs", writeMs, - "totalMs", totalMs, + "rateEntries", metrics.RateEntries, + "botEntries", metrics.BotEntries, + "verifiedEntries", metrics.VerifiedEntries, + "lockMs", metrics.LockMs, + "readMs", metrics.ReadMs, + "reconcileMs", metrics.ReconcileMs, + "marshalMs", metrics.MarshalMs, + "writeMs", metrics.WriteMs, + "totalMs", metrics.TotalMs, ) + return true } func (bc *CaptchaProtect) loadState() { @@ -1091,6 +1126,63 @@ func (bc *CaptchaProtect) loadState() { } bc.log.Info("Loaded previous state") + bc.refreshStateFileModTime() +} + +func (bc *CaptchaProtect) markStateDirty() { + if bc.config.PersistentStateFile == "" { + return + } + bc.stateDirty.Add(1) +} + +func (bc *CaptchaProtect) hasUnsavedState() bool { + return bc.stateDirty.Load() != bc.stateSavedDirty.Load() +} + +func (bc *CaptchaProtect) unsavedStateChanges() uint64 { + dirty := bc.stateDirty.Load() + saved := bc.stateSavedDirty.Load() + if dirty < saved { + return 0 + } + return dirty - saved +} + +func (bc *CaptchaProtect) refreshStateFileModTime() { + info, err := os.Stat(bc.config.PersistentStateFile) + if err != nil { + return + } + bc.stateFileModTime = info.ModTime() +} + +func (bc *CaptchaProtect) reconcileStateFromFileIfChanged() { + info, err := os.Stat(bc.config.PersistentStateFile) + if err != nil { + return + } + modTime := info.ModTime() + if !bc.stateFileModTime.IsZero() && !modTime.After(bc.stateFileModTime) { + return + } + + err = state.ReconcileStateFromFile( + bc.config.PersistentStateFile, + bc.rateCache, + bc.botCache, + bc.verifiedCache, + ) + if err != nil { + bc.log.Warn("failed to reconcile state file", "err", err) + return + } + + if info, err := os.Stat(bc.config.PersistentStateFile); err == nil { + modTime = info.ModTime() + } + bc.stateFileModTime = modTime + bc.log.Debug("Reconciled newer state file") } func (bc *CaptchaProtect) ChallengeOnPage() bool { diff --git a/main_test.go b/main_test.go index d3a83e0..c760302 100644 --- a/main_test.go +++ b/main_test.go @@ -974,15 +974,6 @@ func TestStatePersistence(t *testing.T) { config.SiteKey = "test" config.SecretKey = "test" config.ProtectRoutes = []string{"/"} - config.PersistentStateFile = tmpFile - - // Don't pass a context to avoid starting background goroutines - bc1, _ := NewCaptchaProtect(context.Background(), nil, config, "test") - - // Add some state - bc1.rateCache.Set("192.168.0.0", uint(10), 1*time.Hour) - bc1.verifiedCache.Set("1.2.3.4", true, 1*time.Hour) - bc1.botCache.Set("5.6.7.8", false, 1*time.Hour) // Manually save state by writing the file directly // This tests the state format without relying on the background goroutine @@ -1015,6 +1006,8 @@ func TestStatePersistence(t *testing.T) { // Create new instance - should load state bc2, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc2.config.PersistentStateFile = tmpFile + bc2.loadState() // Check rate cache val, found := bc2.rateCache.Get("192.168.0.0") @@ -1035,6 +1028,27 @@ func TestStatePersistence(t *testing.T) { } } +func TestRegisterRequestCapsPersistedRateCounter(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + + bc := newStateOnlyCaptchaProtect(tmpFile, 2) + + for i := 0; i < 5; i++ { + bc.registerRequest("192.168.0.0") + } + + v, ok := bc.rateCache.Get("192.168.0.0") + if !ok { + t.Fatal("Expected rate cache entry") + } + if got, want := v.(uint), bc.config.RateLimit+1; got != want { + t.Fatalf("Expected rate counter to cap at %d, got %d", want, got) + } + if got, want := bc.stateDirty.Load(), uint64(bc.config.RateLimit+1); got != want { + t.Fatalf("Expected dirty counter to track only effective mutations, got %d want %d", got, want) + } +} + func TestVerifyChallengePage(t *testing.T) { tests := []struct { name string @@ -1228,13 +1242,14 @@ func TestLoadStateInvalidJSON(t *testing.T) { config.SiteKey = "test" config.SecretKey = "test" config.ProtectRoutes = []string{"/"} - config.PersistentStateFile = tmpFile // Should not panic, just log error bc, err := NewCaptchaProtect(context.Background(), nil, config, "test") if err != nil { t.Errorf("Should not fail on invalid state JSON: %v", err) } + bc.config.PersistentStateFile = tmpFile + bc.loadState() // Caches should be empty if bc.rateCache.ItemCount() != 0 { @@ -1243,6 +1258,7 @@ func TestLoadStateInvalidJSON(t *testing.T) { // Clean up the file before temp dir cleanup _ = os.Remove(tmpFile) + _ = os.Remove(tmpFile + ".lock") } func TestParseHttpMethodsInvalid(t *testing.T) { From 4772cb264d61030b2779af379768a1f838edf847 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Wed, 6 May 2026 22:18:36 +0000 Subject: [PATCH 02/12] Harden state persistence and redact secrets --- .github/workflows/lint-test.yml | 2 +- .github/workflows/stress-test.yaml | 2 +- README.md | 2 +- ci/test.go | 12 +++++----- go.mod | 2 ++ internal/state/lock.go | 15 ++++++++----- internal/state/lock_test.go | 6 ++++- internal/state/state.go | 10 ++++----- internal/state/state_test.go | 3 +++ main.go | 36 +++++++++++++++++++++++++----- main_test.go | 16 +++++++++++++ 11 files changed, 81 insertions(+), 25 deletions(-) diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index 14a5ebc..4f08669 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: - go-version: ">=1.25.0" + go-version: "1.26.2" - name: golangci-lint uses: golangci/golangci-lint-action@1e7e51e771db61008b38414a730f564565cf7c20 # v9 diff --git a/.github/workflows/stress-test.yaml b/.github/workflows/stress-test.yaml index b9b4b0b..e92b825 100644 --- a/.github/workflows/stress-test.yaml +++ b/.github/workflows/stress-test.yaml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: - go-version: ">=1.25.0" + go-version: "1.26.2" - name: Run stress tests id: stress_test diff --git a/README.md b/README.md index f69c68c..a0eb310 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ services: | `window` | `int` | `86400` | Duration (in seconds) for monitoring requests per subnet. | | `ipv4subnetMask` | `int` | `16` | CIDR subnet mask to group IPv4 addresses for rate limiting. | | `ipv6subnetMask` | `int` | `64` | CIDR subnet mask to group IPv6 addresses for rate limiting. | -| `ipForwardedHeader` | `string` | `""` | Header to check for the original client IP if Traefik is behind a load balancer. | +| `ipForwardedHeader` | `string` | `""` | Header to check for the original client IP if Traefik is behind a load balancer. Only set this when Traefik receives the header from a trusted proxy/load balancer; otherwise clients can spoof their IP. | | `ipDepth` | `int` | `0` | How deep past the last non-exempt IP to fetch the real IP from `ipForwardedHeader`. Default 0 returns the last IP in the forward header | | `goodBots` | `[]string` (encouraged) | *see below* | List of second-level domains for bots that are never challenged or rate-limited. | | `enableGooglebotIPCheck`| `string`. | `"false"` | Treat IPs coming from googlebot's known IP ranges as good bots | diff --git a/ci/test.go b/ci/test.go index 858bb66..7658c69 100755 --- a/ci/test.go +++ b/ci/test.go @@ -33,13 +33,13 @@ func main() { func waitForService(url string) { deadline := time.Now().Add(90 * time.Second) for time.Now().Before(deadline) { - resp, err := http.Get(url) + resp, err := http.Get(url) // #nosec G107 -- CI smoke test only calls fixed localhost URLs. if err == nil && resp.StatusCode < 500 { - resp.Body.Close() + _ = resp.Body.Close() return } if resp != nil { - resp.Body.Close() + _ = resp.Body.Close() } fmt.Println("waiting for traefik/nginx to come online...") time.Sleep(1 * time.Second) @@ -92,7 +92,9 @@ func httpRequest(ip, url string) string { slog.Error("Request failed", "err", err) os.Exit(1) } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() location, err := resp.Location() if err != nil { @@ -107,7 +109,7 @@ func httpRequest(ip, url string) string { } func runCommand(name string, args ...string) { - cmd := exec.Command(name, args...) + cmd := exec.Command(name, args...) // #nosec G204 -- CI smoke test invokes fixed docker compose commands. cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Env = append(os.Environ(), fmt.Sprintf("RATE_LIMIT=%d", rateLimit)) diff --git a/go.mod b/go.mod index 4f45b71..fd36af6 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,6 @@ module github.com/libops/captcha-protect go 1.25.0 +toolchain go1.26.2 + require github.com/patrickmn/go-cache v2.1.0+incompatible diff --git a/internal/state/lock.go b/internal/state/lock.go index 18137bf..842e0e4 100644 --- a/internal/state/lock.go +++ b/internal/state/lock.go @@ -33,17 +33,20 @@ func (fl *FileLock) Lock() error { for { // Try to create lock file exclusively - f, err := os.OpenFile(fl.lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644) + f, err := os.OpenFile(fl.lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) if err == nil { // Successfully created lock file - _, err = f.WriteString(strconv.Itoa(fl.pid)) - f.Close() - // Check for write error - if err != nil { + _, writeErr := f.WriteString(strconv.Itoa(fl.pid)) + closeErr := f.Close() + if writeErr != nil { // We got the lock but failed to write. // Best effort to clean up, then return the error. _ = os.Remove(fl.lockPath) - return fmt.Errorf("failed to write pid to lock file: %v", err) + return fmt.Errorf("failed to write pid to lock file: %v", writeErr) + } + if closeErr != nil { + _ = os.Remove(fl.lockPath) + return fmt.Errorf("failed to close lock file: %v", closeErr) } // We hold the lock return nil diff --git a/internal/state/lock_test.go b/internal/state/lock_test.go index 61e6a5d..b6a47e4 100644 --- a/internal/state/lock_test.go +++ b/internal/state/lock_test.go @@ -27,9 +27,13 @@ func TestFileLock_LockUnlock(t *testing.T) { t.Fatalf("Lock() error = %v", err) } - if _, err := os.Stat(lockPath); err != nil { + info, err := os.Stat(lockPath) + if err != nil { t.Fatalf("lock file was not created: %v", err) } + if mode := info.Mode().Perm(); mode != 0600 { + t.Fatalf("lock file mode = %v, want 0600", mode) + } content, err := os.ReadFile(lockPath) if err != nil { diff --git a/internal/state/state.go b/internal/state/state.go index d3d55a4..c109b89 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -107,7 +107,7 @@ func SaveStateToFileWithMetrics( // Reconcile with existing file state if enabled if reconcile { readStart := time.Now() - fileContent, readErr := os.ReadFile(filePath) + fileContent, readErr := os.ReadFile(filePath) // #nosec G304 -- persistent state path is trusted middleware configuration. metrics.ReadMs = time.Since(readStart).Milliseconds() if readErr == nil && len(fileContent) > 0 { @@ -134,9 +134,9 @@ func SaveStateToFileWithMetrics( return metrics, err } - // Write to disk + // Write to disk. State can contain client IPs, so keep snapshots private. writeStart := time.Now() - err = atomicWriteFile(filePath, jsonData, 0644) + err = atomicWriteFile(filePath, jsonData, 0600) metrics.WriteMs = time.Since(writeStart).Milliseconds() if err != nil { @@ -186,7 +186,7 @@ func LoadStateFromFile( return err } - fileContent, err := os.ReadFile(filePath) + fileContent, err := os.ReadFile(filePath) // #nosec G304 -- persistent state path is trusted middleware configuration. if err != nil || len(fileContent) == 0 { return err } @@ -217,7 +217,7 @@ func ReconcileStateFromFile( return err } - fileContent, err := os.ReadFile(filePath) + fileContent, err := os.ReadFile(filePath) // #nosec G304 -- persistent state path is trusted middleware configuration. if err != nil || len(fileContent) == 0 { return err } diff --git a/internal/state/state_test.go b/internal/state/state_test.go index 1db9521..1899a16 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -286,6 +286,9 @@ func TestSaveStateToFile(t *testing.T) { if fileInfo.Size() == 0 { t.Error("State file is empty") } + if mode := fileInfo.Mode().Perm(); mode != 0600 { + t.Fatalf("State file mode = %v, want 0600", mode) + } // Load and verify the saved data savedData, err := os.ReadFile(tmpFile) diff --git a/main.go b/main.go index 264ecd8..b754400 100644 --- a/main.go +++ b/main.go @@ -2,11 +2,12 @@ package captcha_protect import ( "context" + crand "crypto/rand" "encoding/json" "fmt" htemplate "html/template" "log/slog" - "math/rand" + "math/big" "net" "net/http" "net/url" @@ -162,6 +163,14 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h return NewCaptchaProtect(ctx, next, config, name) } +func redactedConfig(config *Config) Config { + c := *config + if c.SecretKey != "" { + c.SecretKey = "[REDACTED]" + } + return c +} + func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, name string) (*CaptchaProtect, error) { log := plog.New(config.LogLevel) @@ -177,7 +186,7 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n } expiration := time.Duration(config.Window) * time.Second - log.Debug("Captcha config", "config", config) + log.Debug("Captcha config", "config", redactedConfig(config)) if len(config.ProtectRoutes) == 0 && config.Mode != "suffix" { return nil, fmt.Errorf("you must protect at least one route with the protectRoutes config value. / will cover your entire site") @@ -1028,7 +1037,7 @@ func (c *Config) ParseHttpMethods(log *slog.Logger) { func (bc *CaptchaProtect) saveState(ctx context.Context) { // Add random jitter to prevent multiple instances from trying to save simultaneously - jitter := time.Duration(rand.Intn(int(StateSaveJitter.Milliseconds()))) * time.Millisecond + jitter := stateSaveJitter() interval := StateSaveInterval + jitter bc.log.Debug("State save configured", "baseInterval", StateSaveInterval, "jitter", jitter, "actualInterval", interval) @@ -1036,13 +1045,16 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() - file, err := os.OpenFile(bc.config.PersistentStateFile, os.O_CREATE|os.O_WRONLY, 0644) + file, err := os.OpenFile(bc.config.PersistentStateFile, os.O_CREATE|os.O_WRONLY, 0600) if err != nil { bc.log.Error("unable to save state, could not open or create file", "stateFile", bc.config.PersistentStateFile, "err", err) return } // we made sure the file is writable, we can continue in our loop - file.Close() + if err := file.Close(); err != nil { + bc.log.Error("unable to save state, could not close state file", "stateFile", bc.config.PersistentStateFile, "err", err) + return + } bc.refreshStateFileModTime() lastSave := time.Time{} @@ -1077,6 +1089,20 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { } } +func stateSaveJitter() time.Duration { + maxJitter := big.NewInt(StateSaveJitter.Milliseconds()) + if maxJitter.Sign() <= 0 { + return 0 + } + + jitter, err := crand.Int(crand.Reader, maxJitter) + if err != nil { + return 0 + } + + return time.Duration(jitter.Int64()) * time.Millisecond +} + // saveStateNow performs an immediate state save using the state package. func (bc *CaptchaProtect) saveStateNow() bool { reconcile := bc.config.EnableStateReconciliation == "true" diff --git a/main_test.go b/main_test.go index c760302..5cdd888 100644 --- a/main_test.go +++ b/main_test.go @@ -706,6 +706,22 @@ func TestNewCaptchaProtectValidation(t *testing.T) { } } +func TestRedactedConfigDoesNotExposeSecretKey(t *testing.T) { + config := CreateConfig() + config.SecretKey = "super-secret" + + redacted := redactedConfig(config) + if redacted.SecretKey == config.SecretKey { + t.Fatal("expected secret key to be redacted") + } + if redacted.SecretKey != "[REDACTED]" { + t.Fatalf("unexpected redacted secret value %q", redacted.SecretKey) + } + if config.SecretKey != "super-secret" { + t.Fatal("redactedConfig mutated the original config") + } +} + func TestRateLimiting(t *testing.T) { config := CreateConfig() config.SiteKey = "test" From 9ab4e7048a1516a5efa4d24846be71b1d90cef0a Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Wed, 6 May 2026 23:11:05 +0000 Subject: [PATCH 03/12] Speed up persisted state snapshots --- README.md | 4 +- ci/parse-stress-results/main.go | 14 +- ci_behavior_test.go | 2 + internal/state/lock.go | 34 ++-- internal/state/lock_test.go | 54 ++++++ internal/state/state.go | 272 +++++++++++++++++++++++++--- internal/state/state_stress_test.go | 190 +++++++------------ internal/state/state_test.go | 149 ++++++++++++++- main_test.go | 77 ++++++++ 9 files changed, 619 insertions(+), 177 deletions(-) diff --git a/README.md b/README.md index a0eb310..e65920c 100644 --- a/README.md +++ b/README.md @@ -136,7 +136,7 @@ services: | `challengeStatusCode` | `int` | `200` | HTTP Response status code to return when serving a challenge | | `enableStatsPage` | `string` | `"false"` | Allows `exemptIps` to access `/captcha-protect/stats` to monitor the rate limiter. | | `logLevel` | `string` | `"INFO"` | Log level for the middleware. Options: `ERROR`, `WARNING`, `INFO`, or `DEBUG`. | -| `persistentStateFile` | `string` | `""` | File path to persist rate limiter state across Traefik restarts. In Docker, mount this file from the host. | +| `persistentStateFile` | `string` | `""` | File path to persist rate limiter and verified challenge state across Traefik restarts. Derived bot lookup cache entries are not persisted. In Docker, mount this file from the host. | | `enableStateReconciliation` | `string` | `"false"` | When `"true"`, polls the shared state file for changes and merges newer disk state into memory, then reconciles again before dirty snapshots are saved. Enable for multi-instance deployments sharing state. | ### Circuit Breaker (failover if a captcha provider is unavailable) @@ -272,7 +272,7 @@ If you have use a computer within the `exemptIps`, and access to the command lin curl -s https://example.com/captcha-protect/stats | jq -r '.rate | to_entries | sort_by(.value) | .[] | "\(.key): \(.value)"' | tail -25 ``` -This JSON state data is also found in the `state.json` file that you should have configured in your `docker-compose.yml` using the `persistentStateFile` setting and volume definition. NOTE: this file should only be changed by `captcha-protect` and not manually. +The rate limiter and verified challenge portions of this JSON state data are also found in the `state.json` file that you should have configured in your `docker-compose.yml` using the `persistentStateFile` setting and volume definition. NOTE: this file should only be changed by `captcha-protect` and not manually. ## Troubleshooting diff --git a/ci/parse-stress-results/main.go b/ci/parse-stress-results/main.go index 2d591f1..74e0c0e 100644 --- a/ci/parse-stress-results/main.go +++ b/ci/parse-stress-results/main.go @@ -32,17 +32,17 @@ func main() { scanner := bufio.NewScanner(os.Stdin) // Patterns to extract data - sizePattern := regexp.MustCompile(`size: ([\d.]+) MB`) + sizePattern := regexp.MustCompile(`size: ([\d.]+) (KB|MB)`) timePattern := regexp.MustCompile(`took (\d+)ms`) results := make(map[string]*TestResult) currentTest := "" // Initialize known tests - results["Small"] = &TestResult{Name: "Small", Entries: "16 rate / 65K bots / 256 verified", Threshold: 500} - results["Medium"] = &TestResult{Name: "Medium", Entries: "256 rate / 262K bots / 65K verified", Threshold: 1000} - results["Large"] = &TestResult{Name: "Large", Entries: "1K rate / 1M bots / 262K verified", Threshold: 3000} - results["XLarge"] = &TestResult{Name: "XLarge", Entries: "4K rate / 4.2M bots / 1M verified", Threshold: 10000} + results["Small"] = &TestResult{Name: "Small", Entries: "16 rate / derived bots skipped / 256 verified", Threshold: 500} + results["Medium"] = &TestResult{Name: "Medium", Entries: "256 rate / derived bots skipped / 65K verified", Threshold: 1000} + results["Large"] = &TestResult{Name: "Large", Entries: "1K rate / derived bots skipped / 262K verified", Threshold: 3000} + results["XLarge"] = &TestResult{Name: "XLarge", Entries: "4K rate / derived bots skipped / 1M verified", Threshold: 10000} for scanner.Scan() { line := scanner.Text() @@ -65,9 +65,9 @@ func main() { // Extract size from Marshal test if event.Output != "" && strings.Contains(event.Output, "Marshal took") && strings.Contains(event.Output, "size:") { - if matches := sizePattern.FindStringSubmatch(event.Output); len(matches) > 1 { + if matches := sizePattern.FindStringSubmatch(event.Output); len(matches) > 2 { if currentTest != "" && results[currentTest] != nil { - results[currentTest].Size = matches[1] + " MB" + results[currentTest].Size = matches[1] + " " + matches[2] } } } diff --git a/ci_behavior_test.go b/ci_behavior_test.go index 6a6f689..7b702ce 100644 --- a/ci_behavior_test.go +++ b/ci_behavior_test.go @@ -89,6 +89,8 @@ func TestPersistentStateSharingWithSynctest(t *testing.T) { time.Sleep(StateSaveInterval + StateSaveJitter + 3*time.Second) synctest.Wait() + reader.stateFileModTime = time.Time{} + reader.reconcileStateFromFileIfChanged() v, ok := reader.rateCache.Get("107.198.0.0") if !ok { diff --git a/internal/state/lock.go b/internal/state/lock.go index 842e0e4..6df25c2 100644 --- a/internal/state/lock.go +++ b/internal/state/lock.go @@ -15,6 +15,11 @@ type FileLock struct { pid int } +type lockPIDFile interface { + WriteString(string) (int, error) + Close() error +} + // NewFileLock creates a new file lock for the given path. // It uses a separate .lock file to coordinate access. func NewFileLock(path string) (*FileLock, error) { @@ -36,17 +41,8 @@ func (fl *FileLock) Lock() error { f, err := os.OpenFile(fl.lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) if err == nil { // Successfully created lock file - _, writeErr := f.WriteString(strconv.Itoa(fl.pid)) - closeErr := f.Close() - if writeErr != nil { - // We got the lock but failed to write. - // Best effort to clean up, then return the error. - _ = os.Remove(fl.lockPath) - return fmt.Errorf("failed to write pid to lock file: %v", writeErr) - } - if closeErr != nil { - _ = os.Remove(fl.lockPath) - return fmt.Errorf("failed to close lock file: %v", closeErr) + if err := writeLockPID(f, fl.lockPath, fl.pid); err != nil { + return err } // We hold the lock return nil @@ -82,6 +78,22 @@ func (fl *FileLock) Lock() error { } } +func writeLockPID(file lockPIDFile, lockPath string, pid int) error { + _, writeErr := file.WriteString(strconv.Itoa(pid)) + closeErr := file.Close() + if writeErr != nil { + // We got the lock but failed to write. + // Best effort to clean up, then return the error. + _ = os.Remove(lockPath) + return fmt.Errorf("failed to write pid to lock file: %v", writeErr) + } + if closeErr != nil { + _ = os.Remove(lockPath) + return fmt.Errorf("failed to close lock file: %v", closeErr) + } + return nil +} + // Unlock releases the exclusive lock by removing the lock file // This is now safer and checks the PID. func (fl *FileLock) Unlock() error { diff --git a/internal/state/lock_test.go b/internal/state/lock_test.go index b6a47e4..75071c8 100644 --- a/internal/state/lock_test.go +++ b/internal/state/lock_test.go @@ -2,6 +2,7 @@ package state import ( + "errors" "fmt" "os" "path/filepath" @@ -12,6 +13,19 @@ import ( "time" ) +type fakeLockPIDFile struct { + writeErr error + closeErr error +} + +func (f fakeLockPIDFile) WriteString(string) (int, error) { + return 0, f.writeErr +} + +func (f fakeLockPIDFile) Close() error { + return f.closeErr +} + // TestFileLock_LockUnlock tests the basic Lock and Unlock functionality. func TestFileLock_LockUnlock(t *testing.T) { t.Parallel() @@ -53,6 +67,46 @@ func TestFileLock_LockUnlock(t *testing.T) { } } +func TestWriteLockPIDErrorsCleanUpLockFile(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + file fakeLockPIDFile + wantErr string + }{ + { + name: "write error", + file: fakeLockPIDFile{writeErr: errors.New("write failed")}, + wantErr: "failed to write pid to lock file: write failed", + }, + { + name: "close error", + file: fakeLockPIDFile{closeErr: errors.New("close failed")}, + wantErr: "failed to close lock file: close failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + lockPath := filepath.Join(t.TempDir(), "test.lock") + if err := os.WriteFile(lockPath, []byte("partial"), 0600); err != nil { + t.Fatalf("failed to create lock file: %v", err) + } + + err := writeLockPID(tt.file, lockPath, os.Getpid()) + if err == nil || err.Error() != tt.wantErr { + t.Fatalf("writeLockPID error = %v, want %q", err, tt.wantErr) + } + if _, statErr := os.Stat(lockPath); !os.IsNotExist(statErr) { + t.Fatalf("expected failed writeLockPID to remove lock file, stat err = %v", statErr) + } + }) + } +} + // TestFileLock_Close tests the Close functionality, including idempotency. func TestFileLock_Close(t *testing.T) { t.Parallel() diff --git a/internal/state/state.go b/internal/state/state.go index c109b89..647f6f6 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -1,12 +1,14 @@ package state import ( + "bufio" "encoding/json" "fmt" "log/slog" "os" "path/filepath" "reflect" + "strconv" "time" lru "github.com/patrickmn/go-cache" @@ -25,6 +27,30 @@ type State struct { Memory map[string]uintptr `json:"memory"` } +type persistentEntry[T any] struct { + Value T `json:"value"` + Expiration int64 `json:"expiration"` +} + +type persistentState struct { + Rate map[string]persistentEntry[uint] `json:"rate"` + Bots map[string]persistentEntry[bool] `json:"bots"` + Verified map[string]persistentEntry[bool] `json:"verified"` + Memory map[string]uintptr `json:"memory"` +} + +type ignoredJSON struct{} + +func (ignoredJSON) UnmarshalJSON(_ []byte) error { + return nil +} + +type reconcileStateFile struct { + Rate map[string]persistentEntry[uint] `json:"rate"` + Bots ignoredJSON `json:"bots"` // Bot cache is derived and too large to merge on every state save. + Verified map[string]persistentEntry[bool] `json:"verified"` +} + type SaveMetrics struct { LockMs int64 ReadMs int64 @@ -112,32 +138,27 @@ func SaveStateToFileWithMetrics( if readErr == nil && len(fileContent) > 0 { reconcileStart := time.Now() - var fileState State + var fileState reconcileStateFile if unmarshalErr := json.Unmarshal(fileContent, &fileState); unmarshalErr == nil { log.Debug("Reconciling state before save", "fileBytes", len(fileContent)) - ReconcileState(fileState, rateCache, botCache, verifiedCache) + reconcilePersistentFileState(fileState, rateCache, verifiedCache) } metrics.ReconcileMs = time.Since(reconcileStart).Milliseconds() } } - // Marshal current state + // Bot cache entries are derived from DNS/IP checks and can dwarf the + // state that needs cross-instance sharing. Persist only rate limiter and + // verified-user state; existing files with bot entries still load. + rateItems := rateCache.Items() + verifiedItems := verifiedCache.Items() + metrics.RateEntries = len(rateItems) + metrics.BotEntries = 0 + metrics.VerifiedEntries = len(verifiedItems) + marshalStart := time.Now() - currentState := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) - jsonData, err := json.Marshal(currentState) + err = atomicWriteStateFile(filePath, rateItems, nil, verifiedItems, 0600) metrics.MarshalMs = time.Since(marshalStart).Milliseconds() - metrics.RateEntries = len(currentState.Rate) - metrics.BotEntries = len(currentState.Bots) - metrics.VerifiedEntries = len(currentState.Verified) - - if err != nil { - return metrics, err - } - - // Write to disk. State can contain client IPs, so keep snapshots private. - writeStart := time.Now() - err = atomicWriteFile(filePath, jsonData, 0600) - metrics.WriteMs = time.Since(writeStart).Milliseconds() if err != nil { return metrics, err @@ -147,7 +168,11 @@ func SaveStateToFileWithMetrics( return metrics, nil } -func atomicWriteFile(filePath string, data []byte, perm os.FileMode) error { +func atomicWriteStateFile( + filePath string, + rateItems, botItems, verifiedItems map[string]lru.Item, + perm os.FileMode, +) error { dir := filepath.Dir(filePath) tmp, err := os.CreateTemp(dir, filepath.Base(filePath)+".tmp-*") if err != nil { @@ -156,7 +181,12 @@ func atomicWriteFile(filePath string, data []byte, perm os.FileMode) error { tmpName := tmp.Name() defer os.Remove(tmpName) - if _, err := tmp.Write(data); err != nil { + writer := bufio.NewWriterSize(tmp, 1024*1024) + if err := writeStateJSON(writer, rateItems, botItems, verifiedItems); err != nil { + _ = tmp.Close() + return err + } + if err := writer.Flush(); err != nil { _ = tmp.Close() return err } @@ -171,6 +201,125 @@ func atomicWriteFile(filePath string, data []byte, perm os.FileMode) error { return os.Rename(tmpName, filePath) } +func writeStateJSON( + writer *bufio.Writer, + rateItems, botItems, verifiedItems map[string]lru.Item, +) error { + if err := writeString(writer, `{"rate":`); err != nil { + return err + } + rateMemory, err := writeCacheEntryMap[uint](writer, rateItems, writeUint) + if err != nil { + return err + } + if err := writeString(writer, `,"bots":`); err != nil { + return err + } + botMemory, err := writeCacheEntryMap[bool](writer, botItems, writeBool) + if err != nil { + return err + } + if err := writeString(writer, `,"verified":`); err != nil { + return err + } + verifiedMemory, err := writeCacheEntryMap[bool](writer, verifiedItems, writeBool) + if err != nil { + return err + } + + if err := writeString(writer, `,"memory":{"rate":`); err != nil { + return err + } + if err := writeString(writer, strconv.FormatUint(uint64(rateMemory), 10)); err != nil { + return err + } + if err := writeString(writer, `,"bot":`); err != nil { + return err + } + if err := writeString(writer, strconv.FormatUint(uint64(botMemory), 10)); err != nil { + return err + } + if err := writeString(writer, `,"verified":`); err != nil { + return err + } + if err := writeString(writer, strconv.FormatUint(uint64(verifiedMemory), 10)); err != nil { + return err + } + return writeString(writer, `}}`) +} + +func writeCacheEntryMap[T any]( + writer *bufio.Writer, + items map[string]lru.Item, + writeValue func(*bufio.Writer, T) error, +) (uintptr, error) { + if err := writer.WriteByte('{'); err != nil { + return 0, err + } + + memoryUsage := reflect.TypeOf(map[string]CacheEntry{}).Size() + first := true + quotedKey := make([]byte, 0, 64) + for key, item := range items { + value, ok := item.Object.(T) + if !ok { + return memoryUsage, fmt.Errorf("unexpected cache value type for %q", key) + } + + if !first { + if err := writer.WriteByte(','); err != nil { + return memoryUsage, err + } + } + first = false + + quotedKey = strconv.AppendQuote(quotedKey[:0], key) + if _, err := writer.Write(quotedKey); err != nil { + return memoryUsage, err + } + if err := writeString(writer, `:{"value":`); err != nil { + return memoryUsage, err + } + if err := writeValue(writer, value); err != nil { + return memoryUsage, err + } + if err := writeString(writer, `,"expiration":`); err != nil { + return memoryUsage, err + } + if err := writeString(writer, strconv.FormatInt(item.Expiration, 10)); err != nil { + return memoryUsage, err + } + if err := writer.WriteByte('}'); err != nil { + return memoryUsage, err + } + + memoryUsage += reflect.TypeOf(key).Size() + memoryUsage += reflect.TypeOf(item).Size() + memoryUsage += uintptr(len(key)) + } + + if err := writer.WriteByte('}'); err != nil { + return memoryUsage, err + } + return memoryUsage, nil +} + +func writeUint(writer *bufio.Writer, value uint) error { + return writeString(writer, strconv.FormatUint(uint64(value), 10)) +} + +func writeBool(writer *bufio.Writer, value bool) error { + if value { + return writeString(writer, "true") + } + return writeString(writer, "false") +} + +func writeString(writer *bufio.Writer, value string) error { + _, err := writer.WriteString(value) + return err +} + // LoadStateFromFile loads state from a file with locking. func LoadStateFromFile( filePath string, @@ -191,14 +340,14 @@ func LoadStateFromFile( return err } - var loadedState State + var loadedState persistentState err = json.Unmarshal(fileContent, &loadedState) if err != nil { return err } // Use SetState which properly handles expiration times - SetState(loadedState, rateCache, botCache, verifiedCache) + setPersistentState(loadedState, rateCache, botCache, verifiedCache) return nil } @@ -222,12 +371,12 @@ func ReconcileStateFromFile( return err } - var fileState State + var fileState reconcileStateFile if err := json.Unmarshal(fileContent, &fileState); err != nil { return err } - ReconcileState(fileState, rateCache, botCache, verifiedCache) + reconcilePersistentFileState(fileState, rateCache, verifiedCache) return nil } @@ -296,6 +445,83 @@ func loadCacheEntries[T any]( } } +func setPersistentState(state persistentState, rateCache, botCache, verifiedCache *lru.Cache) { + loadPersistentEntries(state.Rate, rateCache) + loadPersistentEntries(state.Bots, botCache) + loadPersistentEntries(state.Verified, verifiedCache) +} + +func loadPersistentEntries[T any](entries map[string]persistentEntry[T], cache *lru.Cache) { + now := time.Now().UnixNano() + for key, entry := range entries { + if entry.Expiration > 0 && entry.Expiration <= now { + continue + } + cache.Set(key, entry.Value, calculateDuration(entry.Expiration, now)) + } +} + +func reconcilePersistentFileState(state reconcileStateFile, rateCache, verifiedCache *lru.Cache) { + rateItems := rateCache.Items() + verifiedItems := verifiedCache.Items() + + reconcilePersistentRateCache(state.Rate, rateItems, rateCache) + reconcilePersistentCacheEntries(state.Verified, verifiedItems, verifiedCache) +} + +func reconcilePersistentCacheEntries[T any]( + fileEntries map[string]persistentEntry[T], + memItems map[string]lru.Item, + cache *lru.Cache, +) { + now := time.Now().UnixNano() + for key, fileEntry := range fileEntries { + if fileEntry.Expiration > 0 && fileEntry.Expiration <= now { + continue + } + + duration := calculateDuration(fileEntry.Expiration, now) + memItem, exists := memItems[key] + if !exists { + cache.Set(key, fileEntry.Value, duration) + continue + } + + if fileEntry.Expiration > memItem.Expiration { + cache.Set(key, fileEntry.Value, duration) + } + } +} + +func reconcilePersistentRateCache( + fileEntries map[string]persistentEntry[uint], + memItems map[string]lru.Item, + cache *lru.Cache, +) { + now := time.Now().UnixNano() + for key, fileEntry := range fileEntries { + if fileEntry.Expiration > 0 && fileEntry.Expiration <= now { + continue + } + + memItem, exists := memItems[key] + if !exists { + cache.Set(key, fileEntry.Value, calculateDuration(fileEntry.Expiration, now)) + continue + } + + memValue, ok := memItem.Object.(uint) + if !ok { + cache.Set(key, fileEntry.Value, calculateDuration(fileEntry.Expiration, now)) + continue + } + + combinedValue := maxUint(fileEntry.Value, memValue) + laterExpiration := max(fileEntry.Expiration, memItem.Expiration) + cache.Set(key, combinedValue, calculateDuration(laterExpiration, now)) + } +} + // reconcileCacheEntries implements "later expiration wins" // This is correct for bool flags (Verified, Bots). func reconcileCacheEntries[T any]( diff --git a/internal/state/state_stress_test.go b/internal/state/state_stress_test.go index 7c43a8e..6ad4345 100644 --- a/internal/state/state_stress_test.go +++ b/internal/state/state_stress_test.go @@ -1,11 +1,12 @@ package state import ( + "bufio" + "bytes" "encoding/json" "fmt" "log/slog" "math" - "os" "testing" "time" @@ -15,13 +16,13 @@ import ( // This file contains stress tests for state persistence operations at various scales. // // Performance Findings (Apple M2 Pro): -// Small (16 rate / 65K bots / 256 verified → 3.87 MB JSON): +// Small (16 rate / derived bots skipped / 256 verified): // - SaveStateToFile with reconciliation: ~84ms -// Medium (256 rate / 262K bots / 65K verified → 19.31 MB JSON): +// Medium (256 rate / derived bots skipped / 65K verified): // - SaveStateToFile with reconciliation: ~410ms -// Large (1,024 rate / 1M bots / 262K verified → 77.61 MB JSON): +// Large (1,024 rate / derived bots skipped / 262K verified): // - SaveStateToFile with reconciliation: ~1.8s -// XLarge (4,096 rate / 4.2M bots / 1M verified → 312.68 MB JSON): +// XLarge (4,096 rate / derived bots skipped / 1M verified): // - SaveStateToFile with reconciliation: ~8.7s (approaching 10s save window limit) // // Recommendation: Do not enable enableStateReconciliation for sites with >1M unique visitors. @@ -118,6 +119,22 @@ func populateCaches(level StressLevel, rateCache, botCache, verifiedCache *lru.C } } +func populatePersistentCaches(level StressLevel, rateCache, verifiedCache *lru.Cache) { + expiration := 24 * time.Hour + + for i := 0; i < level.RateEntries; i++ { + subnet := generateIPv4Subnet(i) + rate := uint(1 + (i % 100)) + rateCache.Set(subnet, rate, expiration) + } + + startOffset := 0x10000000 // Start from 16.0.0.0 + for i := 0; i < level.VerifiedEntries; i++ { + ip := generateIPv4Address(startOffset + i) + verifiedCache.Set(ip, true, expiration) + } +} + // BenchmarkStateOperations benchmarks marshal/unmarshal/reconcile at different scales func BenchmarkStateOperations(b *testing.B) { levels := getStressLevels() @@ -239,56 +256,38 @@ func TestStateOperationsWithinThreshold(t *testing.T) { levels := getStressLevels() - // Define thresholds for each operation (in milliseconds) + // Define thresholds for persistence operations (in milliseconds) // These are generous limits to avoid flaky tests on slower CI machines type thresholds struct { - GetStateMs int64 - MarshalMs int64 - UnmarshalMs int64 - SetStateMs int64 - ReconcileMs int64 + MarshalMs int64 + SaveWithReconcileMs int64 } levelThresholds := map[string]thresholds{ "Small": { - GetStateMs: 100, - MarshalMs: 200, - UnmarshalMs: 200, - SetStateMs: 200, - ReconcileMs: 200, + MarshalMs: 200, + SaveWithReconcileMs: 500, }, "Medium": { - GetStateMs: 200, - MarshalMs: 500, - UnmarshalMs: 500, - SetStateMs: 500, - ReconcileMs: 500, + MarshalMs: 500, + SaveWithReconcileMs: 1000, }, "Large": { - GetStateMs: 500, - MarshalMs: 2000, - UnmarshalMs: 2000, - SetStateMs: 2000, - ReconcileMs: 2000, + MarshalMs: 2000, + SaveWithReconcileMs: 3000, }, "XLarge": { - GetStateMs: 2000, - MarshalMs: 5000, - UnmarshalMs: 5000, - SetStateMs: 3000, - ReconcileMs: 3000, + MarshalMs: 5000, + SaveWithReconcileMs: 10000, }, "XXLarge": { - GetStateMs: 5000, - MarshalMs: 15000, - UnmarshalMs: 15000, - SetStateMs: 10000, - ReconcileMs: 10000, + MarshalMs: 15000, + SaveWithReconcileMs: 30000, }, } @@ -300,41 +299,16 @@ func TestStateOperationsWithinThreshold(t *testing.T) { botCache := lru.New(24*time.Hour, lru.NoExpiration) verifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - t.Logf("Populating caches (rate=%d, bots=%d, verified=%d)...", - level.RateEntries, level.BotEntries, level.VerifiedEntries) - populateCaches(level, rateCache, botCache, verifiedCache) + t.Logf("Populating persistent caches (rate=%d, derived bots skipped, verified=%d)...", + level.RateEntries, level.VerifiedEntries) + populatePersistentCaches(level, rateCache, verifiedCache) thresh := levelThresholds[level.Name] - // Test GetState - t.Run("GetState", func(t *testing.T) { - start := time.Now() - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) - elapsed := time.Since(start).Milliseconds() - - t.Logf("GetState took %dms (threshold: %dms)", elapsed, thresh.GetStateMs) - if elapsed > thresh.GetStateMs { - slog.Error(fmt.Sprintf("GetState took %dms, exceeds threshold of %dms", elapsed, thresh.GetStateMs)) - } - - // Verify counts - if len(state.Rate) != level.RateEntries { - t.Errorf("Expected %d rate entries, got %d", level.RateEntries, len(state.Rate)) - } - if len(state.Bots) != level.BotEntries { - t.Errorf("Expected %d bot entries, got %d", level.BotEntries, len(state.Bots)) - } - if len(state.Verified) != level.VerifiedEntries { - t.Errorf("Expected %d verified entries, got %d", level.VerifiedEntries, len(state.Verified)) - } - }) - - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) - // Test Marshal t.Run("Marshal", func(t *testing.T) { start := time.Now() - jsonData, err := json.Marshal(state) + jsonData, err := marshalPersistentSnapshotForStress(rateCache, verifiedCache) elapsed := time.Since(start).Milliseconds() if err != nil { @@ -356,76 +330,22 @@ func TestStateOperationsWithinThreshold(t *testing.T) { } }) - // Test Unmarshal - jsonData, _ := json.Marshal(state) - t.Run("Unmarshal", func(t *testing.T) { - start := time.Now() - var loadedState State - err := json.Unmarshal(jsonData, &loadedState) - elapsed := time.Since(start).Milliseconds() - - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - t.Logf("Unmarshal took %dms (threshold: %dms)", elapsed, thresh.UnmarshalMs) - if elapsed > thresh.UnmarshalMs { - slog.Error(fmt.Sprintf("Unmarshal took %dms, exceeds threshold of %dms", elapsed, thresh.UnmarshalMs)) - } - }) - - // Test SetState - t.Run("SetState", func(t *testing.T) { - newRateCache := lru.New(24*time.Hour, lru.NoExpiration) - newBotCache := lru.New(24*time.Hour, lru.NoExpiration) - newVerifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - - start := time.Now() - SetState(state, newRateCache, newBotCache, newVerifiedCache) - elapsed := time.Since(start).Milliseconds() - - t.Logf("SetState took %dms (threshold: %dms)", elapsed, thresh.SetStateMs) - if elapsed > thresh.SetStateMs { - slog.Error(fmt.Sprintf("SetState took %dms, exceeds threshold of %dms", elapsed, thresh.SetStateMs)) - } - - // Verify data was loaded - if newRateCache.ItemCount() != level.RateEntries { - t.Errorf("Expected %d rate entries after SetState, got %d", - level.RateEntries, newRateCache.ItemCount()) - } - }) - - // Test ReconcileState - t.Run("ReconcileState", func(t *testing.T) { - newRateCache := lru.New(24*time.Hour, lru.NoExpiration) - newBotCache := lru.New(24*time.Hour, lru.NoExpiration) - newVerifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - - // Pre-populate with 50% overlapping data - for i := 0; i < level.RateEntries/2; i++ { - subnet := generateIPv4Subnet(i) - newRateCache.Set(subnet, uint(50), 24*time.Hour) - } - - start := time.Now() - ReconcileState(state, newRateCache, newBotCache, newVerifiedCache) - elapsed := time.Since(start).Milliseconds() - - t.Logf("ReconcileState took %dms (threshold: %dms)", elapsed, thresh.ReconcileMs) - if elapsed > thresh.ReconcileMs { - slog.Error(fmt.Sprintf("ReconcileState took %dms, exceeds threshold of %dms", elapsed, thresh.ReconcileMs)) - } - }) - // Test full SaveStateToFile with reconciliation t.Run("SaveStateToFile", func(t *testing.T) { tmpFile := t.TempDir() + "/state.json" logger := testLogger() // Pre-create a state file to enable reconciliation - initialData, _ := json.Marshal(state) - _ = os.WriteFile(tmpFile, initialData, 0644) + if _, _, _, _, _, _, err := SaveStateToFile( + tmpFile, + false, + rateCache, + botCache, + verifiedCache, + logger, + ); err != nil { + t.Fatalf("Failed to write initial state: %v", err) + } start := time.Now() lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( @@ -498,3 +418,15 @@ func TestGenerateUniqueIPs(t *testing.T) { } }) } + +func marshalPersistentSnapshotForStress(rateCache, verifiedCache *lru.Cache) ([]byte, error) { + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + if err := writeStateJSON(writer, rateCache.Items(), nil, verifiedCache.Items()); err != nil { + return nil, err + } + if err := writer.Flush(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/internal/state/state_test.go b/internal/state/state_test.go index 1899a16..ee88d15 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -1,6 +1,8 @@ package state import ( + "bufio" + "bytes" "encoding/json" "log/slog" "os" @@ -304,8 +306,8 @@ func TestSaveStateToFile(t *testing.T) { if len(savedState.Rate) != 1 { t.Errorf("Expected 1 rate entry, got %d", len(savedState.Rate)) } - if len(savedState.Bots) != 1 { - t.Errorf("Expected 1 bot entry, got %d", len(savedState.Bots)) + if len(savedState.Bots) != 0 { + t.Errorf("Expected derived bot cache entries to be skipped, got %d", len(savedState.Bots)) } if len(savedState.Verified) != 1 { t.Errorf("Expected 1 verified entry, got %d", len(savedState.Verified)) @@ -411,8 +413,8 @@ func TestSaveStateToFile(t *testing.T) { if metrics.RateEntries != 1 { t.Errorf("Expected 1 rate entry, got %d", metrics.RateEntries) } - if metrics.BotEntries != 1 { - t.Errorf("Expected 1 bot entry, got %d", metrics.BotEntries) + if metrics.BotEntries != 0 { + t.Errorf("Expected derived bot cache entries to be skipped, got %d", metrics.BotEntries) } if metrics.VerifiedEntries != 1 { t.Errorf("Expected 1 verified entry, got %d", metrics.VerifiedEntries) @@ -605,15 +607,22 @@ func TestReconcileStateFromFile(t *testing.T) { tmpFile := t.TempDir() + "/state.json" now := time.Now().UnixNano() futureExpiration := now + int64(1*time.Hour) + laterExpiration := now + int64(2*time.Hour) fileState := State{ Rate: map[string]CacheEntry{ "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, }, - Bots: map[string]CacheEntry{}, + Bots: map[string]CacheEntry{ + "1.2.3.4": {Value: true, Expiration: futureExpiration}, + }, Verified: map[string]CacheEntry{}, Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, } + fileState.Verified = map[string]CacheEntry{ + "5.6.7.8": {Value: true, Expiration: futureExpiration}, + "9.9.9.9": {Value: true, Expiration: futureExpiration}, + } data, _ := json.Marshal(fileState) if err := os.WriteFile(tmpFile, data, 0644); err != nil { t.Fatalf("Failed to write test state: %v", err) @@ -624,6 +633,7 @@ func TestReconcileStateFromFile(t *testing.T) { verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("10.0.0.0", uint(5), lru.DefaultExpiration) + verifiedCache.Set("9.9.9.9", false, time.Duration(laterExpiration-now)) if err := ReconcileStateFromFile(tmpFile, rateCache, botCache, verifiedCache); err != nil { t.Fatalf("ReconcileStateFromFile failed: %v", err) @@ -635,6 +645,135 @@ func TestReconcileStateFromFile(t *testing.T) { if v, ok := rateCache.Get("10.0.0.0"); !ok || v.(uint) != 5 { t.Error("Expected existing memory state to be retained") } + if botCache.ItemCount() != 0 { + t.Errorf("Expected derived bot cache entries to be skipped during reconciliation, got %d", botCache.ItemCount()) + } + if v, ok := verifiedCache.Get("5.6.7.8"); !ok || v.(bool) != true { + t.Error("Expected file verified state to be reconciled into memory") + } + if v, ok := verifiedCache.Get("9.9.9.9"); !ok || v.(bool) != false { + t.Error("Expected newer memory verified state to be retained") + } +} + +func TestSaveStateToFileWithMetricsWriteError(t *testing.T) { + statePath := t.TempDir() + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + + metrics, err := SaveStateToFileWithMetrics( + statePath, + false, + rateCache, + botCache, + verifiedCache, + testLogger(), + ) + if err == nil { + t.Fatal("expected write error when state path is a directory") + } + if metrics.RateEntries != 1 { + t.Fatalf("expected metrics to include marshaled rate entry, got %d", metrics.RateEntries) + } +} + +func TestAtomicWriteStateFileCreateTempError(t *testing.T) { + missingDir := filepath.Join(t.TempDir(), "missing") + err := atomicWriteStateFile(filepath.Join(missingDir, "state.json"), nil, nil, nil, 0600) + if err == nil { + t.Fatal("expected atomicWriteStateFile to fail when temp directory is missing") + } +} + +func TestWriteStateJSON(t *testing.T) { + rateCache := lru.New(1*time.Hour, 1*time.Minute) + botCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + botCache.Set("1.2.3.4", true, lru.DefaultExpiration) + botCache.Set("5.6.7.8", false, lru.DefaultExpiration) + verifiedCache.Set("9.9.9.9", true, lru.DefaultExpiration) + + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + if err := writeStateJSON(writer, rateCache.Items(), botCache.Items(), verifiedCache.Items()); err != nil { + t.Fatalf("writeStateJSON failed: %v", err) + } + if err := writer.Flush(); err != nil { + t.Fatalf("Flush failed: %v", err) + } + + var saved State + if err := json.Unmarshal(buf.Bytes(), &saved); err != nil { + t.Fatalf("state JSON did not unmarshal: %v", err) + } + if len(saved.Rate) != 1 || len(saved.Bots) != 2 || len(saved.Verified) != 1 { + t.Fatalf("unexpected saved counts: rate=%d bots=%d verified=%d", len(saved.Rate), len(saved.Bots), len(saved.Verified)) + } + if saved.Bots["1.2.3.4"].Value != true { + t.Fatal("expected true bot value to be written") + } + if saved.Bots["5.6.7.8"].Value != false { + t.Fatal("expected false bot value to be written") + } +} + +func TestWriteStateJSONUnexpectedType(t *testing.T) { + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + err := writeStateJSON( + writer, + map[string]lru.Item{"bad": {Object: "not-a-uint", Expiration: time.Now().Add(time.Hour).UnixNano()}}, + nil, + nil, + ) + if err == nil { + t.Fatal("expected writeStateJSON to reject unexpected cache value type") + } +} + +func TestReconcileStateFromFileEmptyAndInvalidFiles(t *testing.T) { + newCaches := func() (*lru.Cache, *lru.Cache, *lru.Cache) { + return lru.New(1*time.Hour, 1*time.Minute), + lru.New(1*time.Hour, 1*time.Minute), + lru.New(1*time.Hour, 1*time.Minute) + } + + t.Run("missing file", func(t *testing.T) { + rateCache, botCache, verifiedCache := newCaches() + err := ReconcileStateFromFile(filepath.Join(t.TempDir(), "missing.json"), rateCache, botCache, verifiedCache) + if err == nil { + t.Fatal("expected missing state file to return read error") + } + }) + + t.Run("empty file", func(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + if err := os.WriteFile(tmpFile, nil, 0600); err != nil { + t.Fatalf("failed to write empty state file: %v", err) + } + + rateCache, botCache, verifiedCache := newCaches() + if err := ReconcileStateFromFile(tmpFile, rateCache, botCache, verifiedCache); err != nil { + t.Fatalf("empty state file should be a no-op: %v", err) + } + }) + + t.Run("invalid JSON", func(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + if err := os.WriteFile(tmpFile, []byte("{invalid json"), 0600); err != nil { + t.Fatalf("failed to write invalid state file: %v", err) + } + + rateCache, botCache, verifiedCache := newCaches() + if err := ReconcileStateFromFile(tmpFile, rateCache, botCache, verifiedCache); err == nil { + t.Fatal("expected invalid state file to return unmarshal error") + } + }) } func testLogger() *slog.Logger { diff --git a/main_test.go b/main_test.go index 5cdd888..1d00dc4 100644 --- a/main_test.go +++ b/main_test.go @@ -1065,6 +1065,83 @@ func TestRegisterRequestCapsPersistedRateCounter(t *testing.T) { } } +func TestSaveStateFlushesDirtyStateOnCanceledContext(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + bc := newStateOnlyCaptchaProtect(tmpFile, 2) + + bc.rateCache.Set("192.168.0.0", uint(1), time.Hour) + bc.markStateDirty() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + bc.saveState(ctx) + + if bc.hasUnsavedState() { + t.Fatal("expected canceled saveState to flush dirty state before returning") + } + + data, err := os.ReadFile(tmpFile) + if err != nil { + t.Fatalf("expected state file to be written: %v", err) + } + + var saved struct { + Rate map[string]json.RawMessage `json:"rate"` + } + if err := json.Unmarshal(data, &saved); err != nil { + t.Fatalf("state file did not contain valid JSON: %v", err) + } + if _, ok := saved.Rate["192.168.0.0"]; !ok { + t.Fatal("expected dirty rate entry to be persisted") + } +} + +func TestSaveStateNowReturnsFalseAndKeepsDirtyStateOnWriteError(t *testing.T) { + statePath := t.TempDir() + bc := newStateOnlyCaptchaProtect(statePath, 2) + bc.rateCache.Set("192.168.0.0", uint(1), time.Hour) + bc.markStateDirty() + + if bc.saveStateNow() { + t.Fatal("expected saveStateNow to fail when persistent state path is a directory") + } + if !bc.hasUnsavedState() { + t.Fatal("expected failed save to keep state dirty") + } +} + +func TestStateBookkeepingErrorBranches(t *testing.T) { + missingFile := filepath.Join(t.TempDir(), "missing", "state.json") + bc := newStateOnlyCaptchaProtect(missingFile, 2) + + bc.stateDirty.Store(1) + bc.stateSavedDirty.Store(2) + if got := bc.unsavedStateChanges(); got != 0 { + t.Fatalf("unsavedStateChanges with saved counter ahead = %d, want 0", got) + } + + bc.refreshStateFileModTime() + if !bc.stateFileModTime.IsZero() { + t.Fatal("refreshStateFileModTime should ignore missing state file") + } + + bc.reconcileStateFromFileIfChanged() + if !bc.stateFileModTime.IsZero() { + t.Fatal("reconcileStateFromFileIfChanged should ignore missing state file") + } + + invalidFile := filepath.Join(t.TempDir(), "state.json") + if err := os.WriteFile(invalidFile, []byte("{invalid json"), 0600); err != nil { + t.Fatalf("failed to write invalid state file: %v", err) + } + bc.config.PersistentStateFile = invalidFile + bc.reconcileStateFromFileIfChanged() + if !bc.stateFileModTime.IsZero() { + t.Fatal("failed reconciliation should not advance state file mod time") + } +} + func TestVerifyChallengePage(t *testing.T) { tests := []struct { name string From 627a1ecde329a3764b68e34b0879da2528dadb21 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Wed, 6 May 2026 23:41:02 +0000 Subject: [PATCH 04/12] Wait for captcha middleware in CI smoke test --- ci/test.go | 58 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/ci/test.go b/ci/test.go index 7658c69..6d3071a 100755 --- a/ci/test.go +++ b/ci/test.go @@ -16,9 +16,10 @@ func main() { _ = os.Remove("./tmp/state.json") fmt.Println("Bringing traefik/nginx online") - runCommand("docker", "compose", "up", "-d") + runCommand("docker", "compose", "up", "-d", "--force-recreate") waitForService("http://localhost") waitForService("http://localhost/app2") + waitForProtectedRoute("http://localhost", "http://localhost/challenge?destination=%2F") fmt.Println("Testing Traefik plugin smoke path...") assertProtectedRoute("107.198.130.166", "http://localhost", "http://localhost/challenge?destination=%2F") @@ -49,12 +50,44 @@ func waitForService(url string) { os.Exit(1) } +func waitForProtectedRoute(url, expectedURL string) { + deadline := time.Now().Add(90 * time.Second) + attempt := 0 + for time.Now().Before(deadline) { + readinessIP := fmt.Sprintf("109.%d.130.168", attempt%250) + if routeIsProtected(readinessIP, url, expectedURL) { + return + } + attempt++ + fmt.Println("waiting for captcha-protect middleware to become active...") + time.Sleep(1 * time.Second) + } + + slog.Error("Timed out waiting for captcha-protect middleware", "url", url, "expected", expectedURL) + os.Exit(1) +} + +func routeIsProtected(ip, url, expectedURL string) bool { + for i := 0; i < rateLimit; i++ { + if output, err := httpRequest(ip, url); err != nil || output != "" { + return false + } + } + + output, err := httpRequest(ip, url) + return err == nil && output == expectedURL +} + func assertProtectedRoute(ip, url, expectedURL string) { for i := 0; i < rateLimit; i++ { assertNoRedirect(ip, url) } - output := httpRequest(ip, url) + output, err := httpRequest(ip, url) + if err != nil { + slog.Error("Request failed", "ip", ip, "url", url, "err", err) + os.Exit(1) + } if output != expectedURL { slog.Error("Expected protected route to redirect", "ip", ip, "url", url, "output", output, "expected", expectedURL) os.Exit(1) @@ -62,14 +95,18 @@ func assertProtectedRoute(ip, url, expectedURL string) { } func assertNoRedirect(ip, url string) { - output := httpRequest(ip, url) + output, err := httpRequest(ip, url) + if err != nil { + slog.Error("Request failed", "ip", ip, "url", url, "err", err) + os.Exit(1) + } if output != "" { slog.Error("Unexpected redirect", "ip", ip, "url", url, "output", output) os.Exit(1) } } -func httpRequest(ip, url string) string { +func httpRequest(ip, url string) (string, error) { client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { if len(via) > 0 { @@ -82,15 +119,13 @@ func httpRequest(ip, url string) string { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - slog.Error("Failed to create request", "err", err) - os.Exit(1) + return "", err } req.Header.Set("X-Forwarded-For", ip) resp, err := client.Do(req) if err != nil { - slog.Error("Request failed", "err", err) - os.Exit(1) + return "", err } defer func() { _ = resp.Body.Close() @@ -99,13 +134,12 @@ func httpRequest(ip, url string) string { location, err := resp.Location() if err != nil { if err == http.ErrNoLocation { - return "" + return "", nil } - slog.Error("Failed to get redirect URL", "err", err) - os.Exit(1) + return "", err } - return strings.TrimSpace(location.String()) + return strings.TrimSpace(location.String()), nil } func runCommand(name string, args ...string) { From c4c382e0d2d53f4637052afd2948d8abe1671a7f Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Wed, 6 May 2026 23:47:16 +0000 Subject: [PATCH 05/12] Restore fast CI smoke failure signal --- ci/test.go | 31 +------------------------------ go.mod | 2 -- 2 files changed, 1 insertion(+), 32 deletions(-) diff --git a/ci/test.go b/ci/test.go index 6d3071a..f30d78b 100755 --- a/ci/test.go +++ b/ci/test.go @@ -16,10 +16,9 @@ func main() { _ = os.Remove("./tmp/state.json") fmt.Println("Bringing traefik/nginx online") - runCommand("docker", "compose", "up", "-d", "--force-recreate") + runCommand("docker", "compose", "up", "-d") waitForService("http://localhost") waitForService("http://localhost/app2") - waitForProtectedRoute("http://localhost", "http://localhost/challenge?destination=%2F") fmt.Println("Testing Traefik plugin smoke path...") assertProtectedRoute("107.198.130.166", "http://localhost", "http://localhost/challenge?destination=%2F") @@ -50,34 +49,6 @@ func waitForService(url string) { os.Exit(1) } -func waitForProtectedRoute(url, expectedURL string) { - deadline := time.Now().Add(90 * time.Second) - attempt := 0 - for time.Now().Before(deadline) { - readinessIP := fmt.Sprintf("109.%d.130.168", attempt%250) - if routeIsProtected(readinessIP, url, expectedURL) { - return - } - attempt++ - fmt.Println("waiting for captcha-protect middleware to become active...") - time.Sleep(1 * time.Second) - } - - slog.Error("Timed out waiting for captcha-protect middleware", "url", url, "expected", expectedURL) - os.Exit(1) -} - -func routeIsProtected(ip, url, expectedURL string) bool { - for i := 0; i < rateLimit; i++ { - if output, err := httpRequest(ip, url); err != nil || output != "" { - return false - } - } - - output, err := httpRequest(ip, url) - return err == nil && output == expectedURL -} - func assertProtectedRoute(ip, url, expectedURL string) { for i := 0; i < rateLimit; i++ { assertNoRedirect(ip, url) diff --git a/go.mod b/go.mod index fd36af6..4f45b71 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,4 @@ module github.com/libops/captcha-protect go 1.25.0 -toolchain go1.26.2 - require github.com/patrickmn/go-cache v2.1.0+incompatible From 26801dbbfabe8777ef7971dec76f5549752c6a1e Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Wed, 6 May 2026 23:56:06 +0000 Subject: [PATCH 06/12] Restore Yaegi compatibility for state persistence --- internal/state/state.go | 49 +++++++++++++++++++++++++++-------------- main.go | 37 +++++++++++++++++++++++-------- main_test.go | 8 ++++--- 3 files changed, 65 insertions(+), 29 deletions(-) diff --git a/internal/state/state.go b/internal/state/state.go index 647f6f6..d54e05c 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -27,16 +27,21 @@ type State struct { Memory map[string]uintptr `json:"memory"` } -type persistentEntry[T any] struct { - Value T `json:"value"` +type persistentRateEntry struct { + Value uint `json:"value"` + Expiration int64 `json:"expiration"` +} + +type persistentBoolEntry struct { + Value bool `json:"value"` Expiration int64 `json:"expiration"` } type persistentState struct { - Rate map[string]persistentEntry[uint] `json:"rate"` - Bots map[string]persistentEntry[bool] `json:"bots"` - Verified map[string]persistentEntry[bool] `json:"verified"` - Memory map[string]uintptr `json:"memory"` + Rate map[string]persistentRateEntry `json:"rate"` + Bots map[string]persistentBoolEntry `json:"bots"` + Verified map[string]persistentBoolEntry `json:"verified"` + Memory map[string]uintptr `json:"memory"` } type ignoredJSON struct{} @@ -46,9 +51,9 @@ func (ignoredJSON) UnmarshalJSON(_ []byte) error { } type reconcileStateFile struct { - Rate map[string]persistentEntry[uint] `json:"rate"` - Bots ignoredJSON `json:"bots"` // Bot cache is derived and too large to merge on every state save. - Verified map[string]persistentEntry[bool] `json:"verified"` + Rate map[string]persistentRateEntry `json:"rate"` + Bots ignoredJSON `json:"bots"` // Bot cache is derived and too large to merge on every state save. + Verified map[string]persistentBoolEntry `json:"verified"` } type SaveMetrics struct { @@ -446,12 +451,22 @@ func loadCacheEntries[T any]( } func setPersistentState(state persistentState, rateCache, botCache, verifiedCache *lru.Cache) { - loadPersistentEntries(state.Rate, rateCache) - loadPersistentEntries(state.Bots, botCache) - loadPersistentEntries(state.Verified, verifiedCache) + loadPersistentRateEntries(state.Rate, rateCache) + loadPersistentBoolEntries(state.Bots, botCache) + loadPersistentBoolEntries(state.Verified, verifiedCache) +} + +func loadPersistentRateEntries(entries map[string]persistentRateEntry, cache *lru.Cache) { + now := time.Now().UnixNano() + for key, entry := range entries { + if entry.Expiration > 0 && entry.Expiration <= now { + continue + } + cache.Set(key, entry.Value, calculateDuration(entry.Expiration, now)) + } } -func loadPersistentEntries[T any](entries map[string]persistentEntry[T], cache *lru.Cache) { +func loadPersistentBoolEntries(entries map[string]persistentBoolEntry, cache *lru.Cache) { now := time.Now().UnixNano() for key, entry := range entries { if entry.Expiration > 0 && entry.Expiration <= now { @@ -466,11 +481,11 @@ func reconcilePersistentFileState(state reconcileStateFile, rateCache, verifiedC verifiedItems := verifiedCache.Items() reconcilePersistentRateCache(state.Rate, rateItems, rateCache) - reconcilePersistentCacheEntries(state.Verified, verifiedItems, verifiedCache) + reconcilePersistentBoolCacheEntries(state.Verified, verifiedItems, verifiedCache) } -func reconcilePersistentCacheEntries[T any]( - fileEntries map[string]persistentEntry[T], +func reconcilePersistentBoolCacheEntries( + fileEntries map[string]persistentBoolEntry, memItems map[string]lru.Item, cache *lru.Cache, ) { @@ -494,7 +509,7 @@ func reconcilePersistentCacheEntries[T any]( } func reconcilePersistentRateCache( - fileEntries map[string]persistentEntry[uint], + fileEntries map[string]persistentRateEntry, memItems map[string]lru.Item, cache *lru.Cache, ) { diff --git a/main.go b/main.go index b754400..4695339 100644 --- a/main.go +++ b/main.go @@ -17,7 +17,6 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" "github.com/libops/captcha-protect/internal/helper" @@ -100,8 +99,9 @@ type CaptchaProtect struct { ipv6Mask net.IPMask protectRoutesRegex []*regexp.Regexp excludeRoutesRegex []*regexp.Regexp - stateDirty atomic.Uint64 - stateSavedDirty atomic.Uint64 + stateMu sync.Mutex + stateDirty uint64 + stateSavedDirty uint64 stateFileModTime time.Time // Circuit breaker fields @@ -1106,7 +1106,7 @@ func stateSaveJitter() time.Duration { // saveStateNow performs an immediate state save using the state package. func (bc *CaptchaProtect) saveStateNow() bool { reconcile := bc.config.EnableStateReconciliation == "true" - dirtyAtStart := bc.stateDirty.Load() + dirtyAtStart := bc.currentStateDirty() metrics, err := state.SaveStateToFileWithMetrics( bc.config.PersistentStateFile, @@ -1121,7 +1121,7 @@ func (bc *CaptchaProtect) saveStateNow() bool { bc.log.Error("failed to save state", "err", err) return false } - bc.stateSavedDirty.Store(dirtyAtStart) + bc.markStateSaved(dirtyAtStart) bc.refreshStateFileModTime() bc.log.Debug("State saved successfully", @@ -1159,22 +1159,41 @@ func (bc *CaptchaProtect) markStateDirty() { if bc.config.PersistentStateFile == "" { return } - bc.stateDirty.Add(1) + bc.stateMu.Lock() + bc.stateDirty++ + bc.stateMu.Unlock() } func (bc *CaptchaProtect) hasUnsavedState() bool { - return bc.stateDirty.Load() != bc.stateSavedDirty.Load() + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + return bc.stateDirty != bc.stateSavedDirty } func (bc *CaptchaProtect) unsavedStateChanges() uint64 { - dirty := bc.stateDirty.Load() - saved := bc.stateSavedDirty.Load() + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + + dirty := bc.stateDirty + saved := bc.stateSavedDirty if dirty < saved { return 0 } return dirty - saved } +func (bc *CaptchaProtect) currentStateDirty() uint64 { + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + return bc.stateDirty +} + +func (bc *CaptchaProtect) markStateSaved(dirty uint64) { + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + bc.stateSavedDirty = dirty +} + func (bc *CaptchaProtect) refreshStateFileModTime() { info, err := os.Stat(bc.config.PersistentStateFile) if err != nil { diff --git a/main_test.go b/main_test.go index 1d00dc4..3c35efd 100644 --- a/main_test.go +++ b/main_test.go @@ -1060,7 +1060,7 @@ func TestRegisterRequestCapsPersistedRateCounter(t *testing.T) { if got, want := v.(uint), bc.config.RateLimit+1; got != want { t.Fatalf("Expected rate counter to cap at %d, got %d", want, got) } - if got, want := bc.stateDirty.Load(), uint64(bc.config.RateLimit+1); got != want { + if got, want := bc.currentStateDirty(), uint64(bc.config.RateLimit+1); got != want { t.Fatalf("Expected dirty counter to track only effective mutations, got %d want %d", got, want) } } @@ -1115,8 +1115,10 @@ func TestStateBookkeepingErrorBranches(t *testing.T) { missingFile := filepath.Join(t.TempDir(), "missing", "state.json") bc := newStateOnlyCaptchaProtect(missingFile, 2) - bc.stateDirty.Store(1) - bc.stateSavedDirty.Store(2) + bc.stateMu.Lock() + bc.stateDirty = 1 + bc.stateSavedDirty = 2 + bc.stateMu.Unlock() if got := bc.unsavedStateChanges(); got != 0 { t.Fatalf("unsavedStateChanges with saved counter ahead = %d, want 0", got) } From 4be6e44626e4f66e55964a8a87b268edf5f9a78e Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Thu, 7 May 2026 00:19:21 +0000 Subject: [PATCH 07/12] Address state persistence review findings --- ci_behavior_test.go | 8 +-- internal/state/state.go | 101 +++++++++++++++++++--------- internal/state/state_stress_test.go | 24 +++---- internal/state/state_test.go | 44 ++++++------ main.go | 39 +++++++++-- main_test.go | 16 ++--- 6 files changed, 149 insertions(+), 83 deletions(-) diff --git a/ci_behavior_test.go b/ci_behavior_test.go index 7b702ce..c615d5c 100644 --- a/ci_behavior_test.go +++ b/ci_behavior_test.go @@ -73,15 +73,11 @@ func TestPersistentStateSharingWithSynctest(t *testing.T) { reader := newStateOnlyCaptchaProtect(stateFile, 2) ctx, cancel := context.WithCancel(t.Context()) - done := make(chan struct{}, 2) + done := make(chan struct{}, 1) go func() { writer.saveState(ctx) done <- struct{}{} }() - go func() { - reader.saveState(ctx) - done <- struct{}{} - }() for i := uint(0); i < writer.config.RateLimit+1; i++ { writer.registerRequest("107.198.0.0") @@ -89,7 +85,6 @@ func TestPersistentStateSharingWithSynctest(t *testing.T) { time.Sleep(StateSaveInterval + StateSaveJitter + 3*time.Second) synctest.Wait() - reader.stateFileModTime = time.Time{} reader.reconcileStateFromFileIfChanged() v, ok := reader.rateCache.Get("107.198.0.0") @@ -103,7 +98,6 @@ func TestPersistentStateSharingWithSynctest(t *testing.T) { cancel() synctest.Wait() <-done - <-done }) } diff --git a/internal/state/state.go b/internal/state/state.go index d54e05c..5dc02a4 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -10,6 +10,7 @@ import ( "reflect" "strconv" "time" + "unicode/utf8" lru "github.com/patrickmn/go-cache" ) @@ -20,6 +21,7 @@ type CacheEntry struct { Expiration int64 `json:"expiration"` // Unix timestamp in nanoseconds, 0 means no expiration } +// State contains the persisted cache snapshot and approximate cache memory usage. type State struct { Rate map[string]CacheEntry `json:"rate"` Bots map[string]CacheEntry `json:"bots"` @@ -27,6 +29,8 @@ type State struct { Memory map[string]uintptr `json:"memory"` } +// Keep concrete entry types: Traefik's Yaegi interpreter cannot reliably handle +// generic persistentEntry[T] map fields. type persistentRateEntry struct { Value uint `json:"value"` Expiration int64 `json:"expiration"` @@ -56,6 +60,7 @@ type reconcileStateFile struct { Verified map[string]persistentBoolEntry `json:"verified"` } +// SaveMetrics reports timing and entry counts for a state save. type SaveMetrics struct { LockMs int64 ReadMs int64 @@ -68,6 +73,7 @@ type SaveMetrics struct { VerifiedEntries int } +// GetState converts cache items into a serializable state snapshot. func GetState(rateCache, botCache, verifiedCache map[string]lru.Item) State { state := State{ Memory: make(map[string]uintptr, 3), @@ -102,19 +108,8 @@ func ReconcileState(fileState State, rateCache, botCache, verifiedCache *lru.Cac reconcileCacheEntries(fileState.Verified, verifiedItems, verifiedCache, convertBoolValue) } -// SaveStateToFile saves state to a file with locking and optional reconciliation. -// When reconcile is true, it reads and merges existing file state before saving. -// Returns timing metrics for debugging. -func SaveStateToFile( - filePath string, - reconcile bool, - rateCache, botCache, verifiedCache *lru.Cache, - log *slog.Logger, -) (lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs int64, err error) { - metrics, err := SaveStateToFileWithMetrics(filePath, reconcile, rateCache, botCache, verifiedCache, log) - return metrics.LockMs, metrics.ReadMs, metrics.ReconcileMs, metrics.MarshalMs, metrics.WriteMs, metrics.TotalMs, err -} - +// SaveStateToFileWithMetrics saves rate and verified state to a file with locking. +// Bot cache entries are loaded from legacy files but are not persisted because they are derived state. func SaveStateToFileWithMetrics( filePath string, reconcile bool, @@ -161,10 +156,7 @@ func SaveStateToFileWithMetrics( metrics.BotEntries = 0 metrics.VerifiedEntries = len(verifiedItems) - marshalStart := time.Now() - err = atomicWriteStateFile(filePath, rateItems, nil, verifiedItems, 0600) - metrics.MarshalMs = time.Since(marshalStart).Milliseconds() - + metrics.MarshalMs, metrics.WriteMs, err = atomicWriteStateFile(filePath, rateItems, nil, verifiedItems, 0600) if err != nil { return metrics, err } @@ -177,33 +169,40 @@ func atomicWriteStateFile( filePath string, rateItems, botItems, verifiedItems map[string]lru.Item, perm os.FileMode, -) error { +) (marshalMs, writeMs int64, err error) { dir := filepath.Dir(filePath) tmp, err := os.CreateTemp(dir, filepath.Base(filePath)+".tmp-*") if err != nil { - return err + return 0, 0, err } tmpName := tmp.Name() defer os.Remove(tmpName) writer := bufio.NewWriterSize(tmp, 1024*1024) + marshalStart := time.Now() if err := writeStateJSON(writer, rateItems, botItems, verifiedItems); err != nil { _ = tmp.Close() - return err + return 0, 0, err } if err := writer.Flush(); err != nil { _ = tmp.Close() - return err + return 0, 0, err } + marshalMs = time.Since(marshalStart).Milliseconds() + + writeStart := time.Now() if err := tmp.Chmod(perm); err != nil { _ = tmp.Close() - return err + return marshalMs, 0, err } if err := tmp.Close(); err != nil { - return err + return marshalMs, 0, err } - return os.Rename(tmpName, filePath) + if err := os.Rename(tmpName, filePath); err != nil { + return marshalMs, 0, err + } + return marshalMs, time.Since(writeStart).Milliseconds(), nil } func writeStateJSON( @@ -262,9 +261,8 @@ func writeCacheEntryMap[T any]( return 0, err } - memoryUsage := reflect.TypeOf(map[string]CacheEntry{}).Size() + memoryUsage := cacheEntryMapSize first := true - quotedKey := make([]byte, 0, 64) for key, item := range items { value, ok := item.Object.(T) if !ok { @@ -278,8 +276,7 @@ func writeCacheEntryMap[T any]( } first = false - quotedKey = strconv.AppendQuote(quotedKey[:0], key) - if _, err := writer.Write(quotedKey); err != nil { + if err := writeJSONString(writer, key); err != nil { return memoryUsage, err } if err := writeString(writer, `:{"value":`); err != nil { @@ -298,8 +295,8 @@ func writeCacheEntryMap[T any]( return memoryUsage, err } - memoryUsage += reflect.TypeOf(key).Size() - memoryUsage += reflect.TypeOf(item).Size() + memoryUsage += stringSize + memoryUsage += lruItemSize memoryUsage += uintptr(len(key)) } @@ -309,6 +306,12 @@ func writeCacheEntryMap[T any]( return memoryUsage, nil } +var ( + cacheEntryMapSize = reflect.TypeOf(map[string]CacheEntry{}).Size() + stringSize = reflect.TypeOf("").Size() + lruItemSize = reflect.TypeOf(lru.Item{}).Size() +) + func writeUint(writer *bufio.Writer, value uint) error { return writeString(writer, strconv.FormatUint(uint64(value), 10)) } @@ -325,6 +328,43 @@ func writeString(writer *bufio.Writer, value string) error { return err } +func writeJSONString(writer *bufio.Writer, value string) error { + if isPlainJSONString(value) { + if err := writer.WriteByte('"'); err != nil { + return err + } + if err := writeString(writer, value); err != nil { + return err + } + return writer.WriteByte('"') + } + + encoded, err := json.Marshal(value) + if err != nil { + return err + } + _, err = writer.Write(encoded) + return err +} + +func isPlainJSONString(value string) bool { + asciiOnly := true + for i := 0; i < len(value); i++ { + switch value[i] { + case '\\', '"': + return false + default: + if value[i] < 0x20 { + return false + } + if value[i] >= utf8.RuneSelf { + asciiOnly = false + } + } + } + return asciiOnly || utf8.ValidString(value) +} + // LoadStateFromFile loads state from a file with locking. func LoadStateFromFile( filePath string, @@ -357,6 +397,7 @@ func LoadStateFromFile( return nil } +// ReconcileStateFromFile merges newer persisted state into the provided caches. func ReconcileStateFromFile( filePath string, rateCache, botCache, verifiedCache *lru.Cache, diff --git a/internal/state/state_stress_test.go b/internal/state/state_stress_test.go index 6ad4345..ffba4a1 100644 --- a/internal/state/state_stress_test.go +++ b/internal/state/state_stress_test.go @@ -217,8 +217,8 @@ func BenchmarkStateOperations(b *testing.B) { } }) - // Benchmark full SaveStateToFile cycle (with reconciliation) - b.Run(fmt.Sprintf("SaveStateToFile/%s", level.Name), func(b *testing.B) { + // Benchmark full SaveStateToFileWithMetrics cycle (with reconciliation) + b.Run(fmt.Sprintf("SaveStateToFileWithMetrics/%s", level.Name), func(b *testing.B) { tmpDir := b.TempDir() tmpFile := tmpDir + "/state.json" logger := testLogger() @@ -226,7 +226,7 @@ func BenchmarkStateOperations(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + metrics, err := SaveStateToFileWithMetrics( tmpFile, true, // enable reconciliation rateCache, @@ -241,7 +241,7 @@ func BenchmarkStateOperations(b *testing.B) { // Report timing breakdown (only once to avoid noise) if i == 0 { b.Logf("Timing breakdown: lock=%dms read=%dms reconcile=%dms marshal=%dms write=%dms total=%dms", - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs) + metrics.LockMs, metrics.ReadMs, metrics.ReconcileMs, metrics.MarshalMs, metrics.WriteMs, metrics.TotalMs) } } }) @@ -330,13 +330,13 @@ func TestStateOperationsWithinThreshold(t *testing.T) { } }) - // Test full SaveStateToFile with reconciliation - t.Run("SaveStateToFile", func(t *testing.T) { + // Test full SaveStateToFileWithMetrics with reconciliation + t.Run("SaveStateToFileWithMetrics", func(t *testing.T) { tmpFile := t.TempDir() + "/state.json" logger := testLogger() // Pre-create a state file to enable reconciliation - if _, _, _, _, _, _, err := SaveStateToFile( + if _, err := SaveStateToFileWithMetrics( tmpFile, false, rateCache, @@ -348,7 +348,7 @@ func TestStateOperationsWithinThreshold(t *testing.T) { } start := time.Now() - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + metrics, err := SaveStateToFileWithMetrics( tmpFile, true, // enable reconciliation rateCache, @@ -364,7 +364,7 @@ func TestStateOperationsWithinThreshold(t *testing.T) { t.Logf("SaveStateToFile took %dms (threshold: %dms)", elapsed, thresh.SaveWithReconcileMs) t.Logf(" Breakdown: lock=%dms read=%dms reconcile=%dms marshal=%dms write=%dms total=%dms", - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs) + metrics.LockMs, metrics.ReadMs, metrics.ReconcileMs, metrics.MarshalMs, metrics.WriteMs, metrics.TotalMs) if elapsed > thresh.SaveWithReconcileMs { slog.Error(fmt.Sprintf("SaveStateToFile took %dms, exceeds threshold of %dms", @@ -372,10 +372,10 @@ func TestStateOperationsWithinThreshold(t *testing.T) { } // Verify math adds up (approximately, allowing for measurement overhead) - measuredTotal := lockMs + readMs + reconcileMs + marshalMs + writeMs - if totalMs > 0 && math.Abs(float64(measuredTotal-totalMs)) > float64(totalMs)*0.2 { + measuredTotal := metrics.LockMs + metrics.ReadMs + metrics.ReconcileMs + metrics.MarshalMs + metrics.WriteMs + if metrics.TotalMs > 0 && math.Abs(float64(measuredTotal-metrics.TotalMs)) > float64(metrics.TotalMs)*0.2 { t.Logf("Warning: timing components (%dms) don't add up to total (%dms)", - measuredTotal, totalMs) + measuredTotal, metrics.TotalMs) } }) }) diff --git a/internal/state/state_test.go b/internal/state/state_test.go index ee88d15..35df3c1 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -237,7 +237,7 @@ func TestReconcileState(t *testing.T) { } } -func TestSaveStateToFile(t *testing.T) { +func TestSaveStateToFileWithMetrics(t *testing.T) { t.Run("Basic save without reconciliation", func(t *testing.T) { // Create temp file tmpFile := t.TempDir() + "/state.json" @@ -252,7 +252,7 @@ func TestSaveStateToFile(t *testing.T) { verifiedCache.Set("5.6.7.8", true, lru.DefaultExpiration) // Save without reconciliation - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + metrics, err := SaveStateToFileWithMetrics( tmpFile, false, // no reconciliation rateCache, @@ -266,18 +266,18 @@ func TestSaveStateToFile(t *testing.T) { } // Verify timing metrics - if lockMs < 0 || readMs < 0 || reconcileMs < 0 || marshalMs < 0 || writeMs < 0 || totalMs < 0 { + if metrics.LockMs < 0 || metrics.ReadMs < 0 || metrics.ReconcileMs < 0 || metrics.MarshalMs < 0 || metrics.WriteMs < 0 || metrics.TotalMs < 0 { t.Error("Expected all timing metrics to be non-negative") } // Verify reconcileMs is 0 when reconciliation is disabled - if reconcileMs != 0 { - t.Errorf("Expected reconcileMs to be 0 when reconciliation disabled, got %d", reconcileMs) + if metrics.ReconcileMs != 0 { + t.Errorf("Expected reconcileMs to be 0 when reconciliation disabled, got %d", metrics.ReconcileMs) } // Verify readMs is 0 when reconciliation is disabled - if readMs != 0 { - t.Errorf("Expected readMs to be 0 when reconciliation disabled, got %d", readMs) + if metrics.ReadMs != 0 { + t.Errorf("Expected readMs to be 0 when reconciliation disabled, got %d", metrics.ReadMs) } // Verify file was created and contains data @@ -341,7 +341,7 @@ func TestSaveStateToFile(t *testing.T) { rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) // Save with reconciliation enabled - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + metrics, err := SaveStateToFileWithMetrics( tmpFile, true, // enable reconciliation rateCache, @@ -355,22 +355,22 @@ func TestSaveStateToFile(t *testing.T) { } // Verify timing metrics (all should be non-negative) - if lockMs < 0 { + if metrics.LockMs < 0 { t.Error("Expected non-negative lockMs") } - if readMs < 0 { + if metrics.ReadMs < 0 { t.Error("Expected non-negative readMs when reconciliation is enabled") } - if reconcileMs < 0 { + if metrics.ReconcileMs < 0 { t.Error("Expected non-negative reconcileMs when reconciliation is enabled") } - if marshalMs < 0 { + if metrics.MarshalMs < 0 { t.Error("Expected non-negative marshalMs") } - if writeMs < 0 { + if metrics.WriteMs < 0 { t.Error("Expected non-negative writeMs") } - if totalMs < 0 { + if metrics.TotalMs < 0 { t.Error("Expected non-negative totalMs") } @@ -432,7 +432,7 @@ func TestSaveStateToFile(t *testing.T) { botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - _, _, _, _, _, _, err := SaveStateToFile( + _, err := SaveStateToFileWithMetrics( invalidPath, false, rateCache, @@ -682,7 +682,7 @@ func TestSaveStateToFileWithMetricsWriteError(t *testing.T) { func TestAtomicWriteStateFileCreateTempError(t *testing.T) { missingDir := filepath.Join(t.TempDir(), "missing") - err := atomicWriteStateFile(filepath.Join(missingDir, "state.json"), nil, nil, nil, 0600) + _, _, err := atomicWriteStateFile(filepath.Join(missingDir, "state.json"), nil, nil, nil, 0600) if err == nil { t.Fatal("expected atomicWriteStateFile to fail when temp directory is missing") } @@ -694,6 +694,7 @@ func TestWriteStateJSON(t *testing.T) { verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + rateCache.Set("bad\a-key", uint(2), lru.DefaultExpiration) botCache.Set("1.2.3.4", true, lru.DefaultExpiration) botCache.Set("5.6.7.8", false, lru.DefaultExpiration) verifiedCache.Set("9.9.9.9", true, lru.DefaultExpiration) @@ -711,9 +712,12 @@ func TestWriteStateJSON(t *testing.T) { if err := json.Unmarshal(buf.Bytes(), &saved); err != nil { t.Fatalf("state JSON did not unmarshal: %v", err) } - if len(saved.Rate) != 1 || len(saved.Bots) != 2 || len(saved.Verified) != 1 { + if len(saved.Rate) != 2 || len(saved.Bots) != 2 || len(saved.Verified) != 1 { t.Fatalf("unexpected saved counts: rate=%d bots=%d verified=%d", len(saved.Rate), len(saved.Bots), len(saved.Verified)) } + if saved.Rate["bad\a-key"].Value != float64(2) { + t.Fatal("expected JSON-escaped rate key to round-trip") + } if saved.Bots["1.2.3.4"].Value != true { t.Fatal("expected true bot value to be written") } @@ -974,7 +978,7 @@ func TestSaveAndLoadStateWithExpiration_Synctest(t *testing.T) { verifiedCache1.Set("9.9.9.9", true, lru.NoExpiration) // Save state - _, _, _, _, _, _, err := SaveStateToFile( + _, err := SaveStateToFileWithMetrics( tmpFile, false, rateCache1, @@ -1045,7 +1049,7 @@ func TestReconcilePreservesNewerData_Synctest(t *testing.T) { initialCache := lru.New(1*time.Hour, lru.NoExpiration) initialCache.Set("192.168.0.0", uint(100), 5*time.Second) - _, _, _, _, _, _, err := SaveStateToFile( + _, err := SaveStateToFileWithMetrics( tmpFile, false, initialCache, @@ -1067,7 +1071,7 @@ func TestReconcilePreservesNewerData_Synctest(t *testing.T) { newCache.Set("192.168.0.0", uint(200), 8*time.Second) // expires at start+10s // Save with reconciliation enabled - _, _, _, _, _, _, err = SaveStateToFile( + _, err = SaveStateToFileWithMetrics( tmpFile, true, // reconcile newCache, diff --git a/main.go b/main.go index 4695339..a1a9bf9 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( crand "crypto/rand" "encoding/json" "fmt" + "hash/fnv" htemplate "html/template" "log/slog" "math/big" @@ -1014,7 +1015,6 @@ func (bc *CaptchaProtect) isGoodBot(req *http.Request, clientIP string) bool { v = helper.IsIpGoodBot(clientIP, bc.config.GoodBots) } bc.botCache.Set(clientIP, v, lru.DefaultExpiration) - bc.markStateDirty() return v } @@ -1097,9 +1097,23 @@ func stateSaveJitter() time.Duration { jitter, err := crand.Int(crand.Reader, maxJitter) if err != nil { + return fallbackStateSaveJitter(maxJitter.Int64()) + } + + return time.Duration(jitter.Int64()) * time.Millisecond +} + +func fallbackStateSaveJitter(maxMillis int64) time.Duration { + if maxMillis <= 0 { return 0 } + hostname, _ := os.Hostname() + hash := fnv.New64a() + _, _ = fmt.Fprintf(hash, "%s:%d:%d", hostname, os.Getpid(), time.Now().UnixNano()) + + jitter := new(big.Int).SetUint64(hash.Sum64()) + jitter.Mod(jitter, big.NewInt(maxMillis)) return time.Duration(jitter.Int64()) * time.Millisecond } @@ -1139,8 +1153,12 @@ func (bc *CaptchaProtect) saveStateNow() bool { } func (bc *CaptchaProtect) loadState() { + bc.loadStateFrom(bc.config.PersistentStateFile) +} + +func (bc *CaptchaProtect) loadStateFrom(filePath string) { err := state.LoadStateFromFile( - bc.config.PersistentStateFile, + filePath, bc.rateCache, bc.botCache, bc.verifiedCache, @@ -1152,7 +1170,7 @@ func (bc *CaptchaProtect) loadState() { } bc.log.Info("Loaded previous state") - bc.refreshStateFileModTime() + bc.refreshStateFileModTimeFrom(filePath) } func (bc *CaptchaProtect) markStateDirty() { @@ -1195,10 +1213,16 @@ func (bc *CaptchaProtect) markStateSaved(dirty uint64) { } func (bc *CaptchaProtect) refreshStateFileModTime() { - info, err := os.Stat(bc.config.PersistentStateFile) + bc.refreshStateFileModTimeFrom(bc.config.PersistentStateFile) +} + +func (bc *CaptchaProtect) refreshStateFileModTimeFrom(filePath string) { + info, err := os.Stat(filePath) if err != nil { return } + bc.stateMu.Lock() + defer bc.stateMu.Unlock() bc.stateFileModTime = info.ModTime() } @@ -1208,7 +1232,10 @@ func (bc *CaptchaProtect) reconcileStateFromFileIfChanged() { return } modTime := info.ModTime() - if !bc.stateFileModTime.IsZero() && !modTime.After(bc.stateFileModTime) { + bc.stateMu.Lock() + lastModTime := bc.stateFileModTime + bc.stateMu.Unlock() + if !lastModTime.IsZero() && !modTime.After(lastModTime) { return } @@ -1226,7 +1253,9 @@ func (bc *CaptchaProtect) reconcileStateFromFileIfChanged() { if info, err := os.Stat(bc.config.PersistentStateFile); err == nil { modTime = info.ModTime() } + bc.stateMu.Lock() bc.stateFileModTime = modTime + bc.stateMu.Unlock() bc.log.Debug("Reconciled newer state file") } diff --git a/main_test.go b/main_test.go index 3c35efd..4947f9b 100644 --- a/main_test.go +++ b/main_test.go @@ -1022,8 +1022,7 @@ func TestStatePersistence(t *testing.T) { // Create new instance - should load state bc2, _ := NewCaptchaProtect(context.Background(), nil, config, "test") - bc2.config.PersistentStateFile = tmpFile - bc2.loadState() + bc2.loadStateFrom(tmpFile) // Check rate cache val, found := bc2.rateCache.Get("192.168.0.0") @@ -1044,7 +1043,7 @@ func TestStatePersistence(t *testing.T) { } } -func TestRegisterRequestCapsPersistedRateCounter(t *testing.T) { +func TestRegisterRequestStopsIncrementingAfterRateLimitTrips(t *testing.T) { tmpFile := filepath.Join(t.TempDir(), "state.json") bc := newStateOnlyCaptchaProtect(tmpFile, 2) @@ -1058,7 +1057,7 @@ func TestRegisterRequestCapsPersistedRateCounter(t *testing.T) { t.Fatal("Expected rate cache entry") } if got, want := v.(uint), bc.config.RateLimit+1; got != want { - t.Fatalf("Expected rate counter to cap at %d, got %d", want, got) + t.Fatalf("Expected sequential requests to stop incrementing at %d, got %d", want, got) } if got, want := bc.currentStateDirty(), uint64(bc.config.RateLimit+1); got != want { t.Fatalf("Expected dirty counter to track only effective mutations, got %d want %d", got, want) @@ -1137,9 +1136,9 @@ func TestStateBookkeepingErrorBranches(t *testing.T) { if err := os.WriteFile(invalidFile, []byte("{invalid json"), 0600); err != nil { t.Fatalf("failed to write invalid state file: %v", err) } - bc.config.PersistentStateFile = invalidFile - bc.reconcileStateFromFileIfChanged() - if !bc.stateFileModTime.IsZero() { + invalidBC := newStateOnlyCaptchaProtect(invalidFile, 2) + invalidBC.reconcileStateFromFileIfChanged() + if !invalidBC.stateFileModTime.IsZero() { t.Fatal("failed reconciliation should not advance state file mod time") } } @@ -1343,8 +1342,7 @@ func TestLoadStateInvalidJSON(t *testing.T) { if err != nil { t.Errorf("Should not fail on invalid state JSON: %v", err) } - bc.config.PersistentStateFile = tmpFile - bc.loadState() + bc.loadStateFrom(tmpFile) // Caches should be empty if bc.rateCache.ItemCount() != 0 { From bd1cf1fca4814128c91263b2c46f45b083869494 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Thu, 7 May 2026 11:40:09 +0000 Subject: [PATCH 08/12] Sync lock files and remove bot state persistence --- ci/parse-stress-results/main.go | 8 +- internal/state/lock.go | 32 +++-- internal/state/lock_test.go | 10 ++ internal/state/state.go | 65 +++------- internal/state/state_stress_test.go | 73 +++-------- internal/state/state_test.go | 187 +++++----------------------- main.go | 6 +- main_test.go | 11 -- 8 files changed, 101 insertions(+), 291 deletions(-) diff --git a/ci/parse-stress-results/main.go b/ci/parse-stress-results/main.go index 74e0c0e..9e8e09c 100644 --- a/ci/parse-stress-results/main.go +++ b/ci/parse-stress-results/main.go @@ -39,10 +39,10 @@ func main() { currentTest := "" // Initialize known tests - results["Small"] = &TestResult{Name: "Small", Entries: "16 rate / derived bots skipped / 256 verified", Threshold: 500} - results["Medium"] = &TestResult{Name: "Medium", Entries: "256 rate / derived bots skipped / 65K verified", Threshold: 1000} - results["Large"] = &TestResult{Name: "Large", Entries: "1K rate / derived bots skipped / 262K verified", Threshold: 3000} - results["XLarge"] = &TestResult{Name: "XLarge", Entries: "4K rate / derived bots skipped / 1M verified", Threshold: 10000} + results["Small"] = &TestResult{Name: "Small", Entries: "16 rate / 256 verified", Threshold: 500} + results["Medium"] = &TestResult{Name: "Medium", Entries: "256 rate / 65K verified", Threshold: 1000} + results["Large"] = &TestResult{Name: "Large", Entries: "1K rate / 262K verified", Threshold: 3000} + results["XLarge"] = &TestResult{Name: "XLarge", Entries: "4K rate / 1M verified", Threshold: 10000} for scanner.Scan() { line := scanner.Text() diff --git a/internal/state/lock.go b/internal/state/lock.go index 6df25c2..b8cda88 100644 --- a/internal/state/lock.go +++ b/internal/state/lock.go @@ -2,6 +2,7 @@ package state import ( "fmt" + "io" "os" "strconv" "strings" @@ -16,8 +17,9 @@ type FileLock struct { } type lockPIDFile interface { - WriteString(string) (int, error) - Close() error + io.StringWriter + io.Closer + Sync() error } // NewFileLock creates a new file lock for the given path. @@ -78,18 +80,22 @@ func (fl *FileLock) Lock() error { } } -func writeLockPID(file lockPIDFile, lockPath string, pid int) error { - _, writeErr := file.WriteString(strconv.Itoa(pid)) - closeErr := file.Close() - if writeErr != nil { - // We got the lock but failed to write. - // Best effort to clean up, then return the error. - _ = os.Remove(lockPath) - return fmt.Errorf("failed to write pid to lock file: %v", writeErr) +func writeLockPID(file lockPIDFile, lockPath string, pid int) (err error) { + defer func() { + closeErr := file.Close() + if err == nil && closeErr != nil { + err = fmt.Errorf("failed to close lock file: %w", closeErr) + } + if err != nil { + _ = os.Remove(lockPath) + } + }() + + if _, err := file.WriteString(strconv.Itoa(pid)); err != nil { + return fmt.Errorf("failed to write pid to lock file: %w", err) } - if closeErr != nil { - _ = os.Remove(lockPath) - return fmt.Errorf("failed to close lock file: %v", closeErr) + if err := file.Sync(); err != nil { + return fmt.Errorf("failed to sync lock file: %w", err) } return nil } diff --git a/internal/state/lock_test.go b/internal/state/lock_test.go index 75071c8..8cbb97b 100644 --- a/internal/state/lock_test.go +++ b/internal/state/lock_test.go @@ -15,6 +15,7 @@ import ( type fakeLockPIDFile struct { writeErr error + syncErr error closeErr error } @@ -26,6 +27,10 @@ func (f fakeLockPIDFile) Close() error { return f.closeErr } +func (f fakeLockPIDFile) Sync() error { + return f.syncErr +} + // TestFileLock_LockUnlock tests the basic Lock and Unlock functionality. func TestFileLock_LockUnlock(t *testing.T) { t.Parallel() @@ -85,6 +90,11 @@ func TestWriteLockPIDErrorsCleanUpLockFile(t *testing.T) { file: fakeLockPIDFile{closeErr: errors.New("close failed")}, wantErr: "failed to close lock file: close failed", }, + { + name: "sync error", + file: fakeLockPIDFile{syncErr: errors.New("sync failed")}, + wantErr: "failed to sync lock file: sync failed", + }, } for _, tt := range tests { diff --git a/internal/state/state.go b/internal/state/state.go index 5dc02a4..d395778 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -24,7 +24,6 @@ type CacheEntry struct { // State contains the persisted cache snapshot and approximate cache memory usage. type State struct { Rate map[string]CacheEntry `json:"rate"` - Bots map[string]CacheEntry `json:"bots"` Verified map[string]CacheEntry `json:"verified"` Memory map[string]uintptr `json:"memory"` } @@ -43,20 +42,12 @@ type persistentBoolEntry struct { type persistentState struct { Rate map[string]persistentRateEntry `json:"rate"` - Bots map[string]persistentBoolEntry `json:"bots"` Verified map[string]persistentBoolEntry `json:"verified"` Memory map[string]uintptr `json:"memory"` } -type ignoredJSON struct{} - -func (ignoredJSON) UnmarshalJSON(_ []byte) error { - return nil -} - type reconcileStateFile struct { Rate map[string]persistentRateEntry `json:"rate"` - Bots ignoredJSON `json:"bots"` // Bot cache is derived and too large to merge on every state save. Verified map[string]persistentBoolEntry `json:"verified"` } @@ -69,18 +60,16 @@ type SaveMetrics struct { WriteMs int64 TotalMs int64 RateEntries int - BotEntries int VerifiedEntries int } // GetState converts cache items into a serializable state snapshot. -func GetState(rateCache, botCache, verifiedCache map[string]lru.Item) State { +func GetState(rateCache, verifiedCache map[string]lru.Item) State { state := State{ - Memory: make(map[string]uintptr, 3), + Memory: make(map[string]uintptr, 2), } state.Rate, state.Memory["rate"] = getCacheEntries[uint](rateCache) - state.Bots, state.Memory["bot"] = getCacheEntries[bool](botCache) state.Verified, state.Memory["verified"] = getCacheEntries[bool](verifiedCache) return state @@ -88,32 +77,28 @@ func GetState(rateCache, botCache, verifiedCache map[string]lru.Item) State { // SetState loads state data into the provided caches, preserving expiration times. // If an entry has already expired (expiration < now), it will be skipped. -func SetState(state State, rateCache, botCache, verifiedCache *lru.Cache) { +func SetState(state State, rateCache, verifiedCache *lru.Cache) { loadCacheEntries(state.Rate, rateCache, convertRateValue) - loadCacheEntries(state.Bots, botCache, convertBoolValue) loadCacheEntries(state.Verified, verifiedCache, convertBoolValue) } // ReconcileState merges file-based state with in-memory state. -func ReconcileState(fileState State, rateCache, botCache, verifiedCache *lru.Cache) { +func ReconcileState(fileState State, rateCache, verifiedCache *lru.Cache) { rateItems := rateCache.Items() - botItems := botCache.Items() verifiedItems := verifiedCache.Items() // Use "max value wins" for rate cache reconcileRateCache(fileState.Rate, rateItems, rateCache, convertRateValue) - // Use "later expiration wins" for bot and verified caches - reconcileCacheEntries(fileState.Bots, botItems, botCache, convertBoolValue) + // Use "later expiration wins" for verified caches reconcileCacheEntries(fileState.Verified, verifiedItems, verifiedCache, convertBoolValue) } // SaveStateToFileWithMetrics saves rate and verified state to a file with locking. -// Bot cache entries are loaded from legacy files but are not persisted because they are derived state. func SaveStateToFileWithMetrics( filePath string, reconcile bool, - rateCache, botCache, verifiedCache *lru.Cache, + rateCache, verifiedCache *lru.Cache, log *slog.Logger, ) (SaveMetrics, error) { startTime := time.Now() @@ -147,16 +132,12 @@ func SaveStateToFileWithMetrics( } } - // Bot cache entries are derived from DNS/IP checks and can dwarf the - // state that needs cross-instance sharing. Persist only rate limiter and - // verified-user state; existing files with bot entries still load. rateItems := rateCache.Items() verifiedItems := verifiedCache.Items() metrics.RateEntries = len(rateItems) - metrics.BotEntries = 0 metrics.VerifiedEntries = len(verifiedItems) - metrics.MarshalMs, metrics.WriteMs, err = atomicWriteStateFile(filePath, rateItems, nil, verifiedItems, 0600) + metrics.MarshalMs, metrics.WriteMs, err = atomicWriteStateFile(filePath, rateItems, verifiedItems, 0600) if err != nil { return metrics, err } @@ -167,7 +148,7 @@ func SaveStateToFileWithMetrics( func atomicWriteStateFile( filePath string, - rateItems, botItems, verifiedItems map[string]lru.Item, + rateItems, verifiedItems map[string]lru.Item, perm os.FileMode, ) (marshalMs, writeMs int64, err error) { dir := filepath.Dir(filePath) @@ -180,7 +161,7 @@ func atomicWriteStateFile( writer := bufio.NewWriterSize(tmp, 1024*1024) marshalStart := time.Now() - if err := writeStateJSON(writer, rateItems, botItems, verifiedItems); err != nil { + if err := writeStateJSON(writer, rateItems, verifiedItems); err != nil { _ = tmp.Close() return 0, 0, err } @@ -207,7 +188,7 @@ func atomicWriteStateFile( func writeStateJSON( writer *bufio.Writer, - rateItems, botItems, verifiedItems map[string]lru.Item, + rateItems, verifiedItems map[string]lru.Item, ) error { if err := writeString(writer, `{"rate":`); err != nil { return err @@ -216,13 +197,6 @@ func writeStateJSON( if err != nil { return err } - if err := writeString(writer, `,"bots":`); err != nil { - return err - } - botMemory, err := writeCacheEntryMap[bool](writer, botItems, writeBool) - if err != nil { - return err - } if err := writeString(writer, `,"verified":`); err != nil { return err } @@ -237,12 +211,6 @@ func writeStateJSON( if err := writeString(writer, strconv.FormatUint(uint64(rateMemory), 10)); err != nil { return err } - if err := writeString(writer, `,"bot":`); err != nil { - return err - } - if err := writeString(writer, strconv.FormatUint(uint64(botMemory), 10)); err != nil { - return err - } if err := writeString(writer, `,"verified":`); err != nil { return err } @@ -368,7 +336,7 @@ func isPlainJSONString(value string) bool { // LoadStateFromFile loads state from a file with locking. func LoadStateFromFile( filePath string, - rateCache, botCache, verifiedCache *lru.Cache, + rateCache, verifiedCache *lru.Cache, ) error { lock, err := NewFileLock(filePath + ".lock") if err != nil { @@ -391,8 +359,7 @@ func LoadStateFromFile( return err } - // Use SetState which properly handles expiration times - setPersistentState(loadedState, rateCache, botCache, verifiedCache) + setPersistentState(loadedState, rateCache, verifiedCache) return nil } @@ -400,7 +367,7 @@ func LoadStateFromFile( // ReconcileStateFromFile merges newer persisted state into the provided caches. func ReconcileStateFromFile( filePath string, - rateCache, botCache, verifiedCache *lru.Cache, + rateCache, verifiedCache *lru.Cache, ) error { lock, err := NewFileLock(filePath + ".lock") if err != nil { @@ -491,9 +458,8 @@ func loadCacheEntries[T any]( } } -func setPersistentState(state persistentState, rateCache, botCache, verifiedCache *lru.Cache) { +func setPersistentState(state persistentState, rateCache, verifiedCache *lru.Cache) { loadPersistentRateEntries(state.Rate, rateCache) - loadPersistentBoolEntries(state.Bots, botCache) loadPersistentBoolEntries(state.Verified, verifiedCache) } @@ -578,8 +544,7 @@ func reconcilePersistentRateCache( } } -// reconcileCacheEntries implements "later expiration wins" -// This is correct for bool flags (Verified, Bots). +// reconcileCacheEntries implements "later expiration wins" for bool flags. func reconcileCacheEntries[T any]( fileEntries map[string]CacheEntry, memItems map[string]lru.Item, diff --git a/internal/state/state_stress_test.go b/internal/state/state_stress_test.go index ffba4a1..732d2c5 100644 --- a/internal/state/state_stress_test.go +++ b/internal/state/state_stress_test.go @@ -16,13 +16,13 @@ import ( // This file contains stress tests for state persistence operations at various scales. // // Performance Findings (Apple M2 Pro): -// Small (16 rate / derived bots skipped / 256 verified): +// Small (16 rate / 256 verified): // - SaveStateToFile with reconciliation: ~84ms -// Medium (256 rate / derived bots skipped / 65K verified): +// Medium (256 rate / 65K verified): // - SaveStateToFile with reconciliation: ~410ms -// Large (1,024 rate / derived bots skipped / 262K verified): +// Large (1,024 rate / 262K verified): // - SaveStateToFile with reconciliation: ~1.8s -// XLarge (4,096 rate / derived bots skipped / 1M verified): +// XLarge (4,096 rate / 1M verified): // - SaveStateToFile with reconciliation: ~8.7s (approaching 10s save window limit) // // Recommendation: Do not enable enableStateReconciliation for sites with >1M unique visitors. @@ -36,7 +36,6 @@ import ( type StressLevel struct { Name string RateEntries int - BotEntries int VerifiedEntries int } @@ -46,26 +45,22 @@ func getStressLevels() []StressLevel { return []StressLevel{ { Name: "Small", - RateEntries: 1 << 4, // 2^4 = 16 - BotEntries: 1 << 16, // 2^16 = 65,536 - VerifiedEntries: 1 << 8, // 2^8 = 256 + RateEntries: 1 << 4, // 2^4 = 16 + VerifiedEntries: 1 << 8, // 2^8 = 256 }, { Name: "Medium", RateEntries: 1 << 8, // 2^8 = 256 - BotEntries: 1 << 18, // 2^18 = 262,144 (capped from 2^32) VerifiedEntries: 1 << 16, // 2^16 = 65,536 }, { Name: "Large", RateEntries: 1 << 10, // 2^10 = 1,024 (capped from 2^16) - BotEntries: 1 << 20, // 2^20 = 1,048,576 (capped from 2^64) VerifiedEntries: 1 << 18, // 2^18 = 262,144 (capped from 2^32) }, { Name: "XLarge", RateEntries: 1 << 12, // 2^12 = 4,096 - BotEntries: 1 << 22, // 2^22 = 4,194,304 VerifiedEntries: 1 << 20, // 2^20 = 1,048,576 }, } @@ -80,7 +75,7 @@ func generateIPv4Subnet(index int) string { return fmt.Sprintf("%d.%d.0.0", a, b) } -// generateIPv4Address generates a unique IPv4 address for bot/verified caches +// generateIPv4Address generates a unique IPv4 address for verified caches // Uses the pattern: A.B.C.D where all octets are derived from the index func generateIPv4Address(index int) string { a := (index >> 24) & 0xFF @@ -91,7 +86,7 @@ func generateIPv4Address(index int) string { } // populateCaches fills caches with test data based on the stress level -func populateCaches(level StressLevel, rateCache, botCache, verifiedCache *lru.Cache) { +func populateCaches(level StressLevel, rateCache, verifiedCache *lru.Cache) { expiration := 24 * time.Hour // Populate rate cache with subnet entries @@ -102,32 +97,7 @@ func populateCaches(level StressLevel, rateCache, botCache, verifiedCache *lru.C rateCache.Set(subnet, rate, expiration) } - // Populate bot cache with IP addresses - for i := 0; i < level.BotEntries; i++ { - ip := generateIPv4Address(i) - // Alternate between verified and unverified bots - isBot := i%2 == 0 - botCache.Set(ip, isBot, expiration) - } - // Populate verified cache with IP addresses - // Use different starting index to avoid overlap with bot cache - startOffset := 0x10000000 // Start from 16.0.0.0 - for i := 0; i < level.VerifiedEntries; i++ { - ip := generateIPv4Address(startOffset + i) - verifiedCache.Set(ip, true, expiration) - } -} - -func populatePersistentCaches(level StressLevel, rateCache, verifiedCache *lru.Cache) { - expiration := 24 * time.Hour - - for i := 0; i < level.RateEntries; i++ { - subnet := generateIPv4Subnet(i) - rate := uint(1 + (i % 100)) - rateCache.Set(subnet, rate, expiration) - } - startOffset := 0x10000000 // Start from 16.0.0.0 for i := 0; i < level.VerifiedEntries; i++ { ip := generateIPv4Address(startOffset + i) @@ -142,24 +112,23 @@ func BenchmarkStateOperations(b *testing.B) { for _, level := range levels { // Create caches and populate with test data rateCache := lru.New(24*time.Hour, lru.NoExpiration) - botCache := lru.New(24*time.Hour, lru.NoExpiration) verifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - b.Logf("Populating caches for %s level (rate=%d, bots=%d, verified=%d)...", - level.Name, level.RateEntries, level.BotEntries, level.VerifiedEntries) - populateCaches(level, rateCache, botCache, verifiedCache) + b.Logf("Populating caches for %s level (rate=%d, verified=%d)...", + level.Name, level.RateEntries, level.VerifiedEntries) + populateCaches(level, rateCache, verifiedCache) // Benchmark GetState (extract to struct) b.Run(fmt.Sprintf("GetState/%s", level.Name), func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - _ = GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + _ = GetState(rateCache.Items(), verifiedCache.Items()) } }) // Benchmark Marshal - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + state := GetState(rateCache.Items(), verifiedCache.Items()) b.Run(fmt.Sprintf("Marshal/%s", level.Name), func(b *testing.B) { b.ReportAllocs() var jsonData []byte @@ -188,11 +157,10 @@ func BenchmarkStateOperations(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() newRateCache := lru.New(24*time.Hour, lru.NoExpiration) - newBotCache := lru.New(24*time.Hour, lru.NoExpiration) newVerifiedCache := lru.New(24*time.Hour, lru.NoExpiration) b.StartTimer() - SetState(state, newRateCache, newBotCache, newVerifiedCache) + SetState(state, newRateCache, newVerifiedCache) } }) @@ -204,7 +172,6 @@ func BenchmarkStateOperations(b *testing.B) { b.StopTimer() // Create fresh caches with some overlapping data newRateCache := lru.New(24*time.Hour, lru.NoExpiration) - newBotCache := lru.New(24*time.Hour, lru.NoExpiration) newVerifiedCache := lru.New(24*time.Hour, lru.NoExpiration) // Pre-populate with 50% of entries for j := 0; j < level.RateEntries/2; j++ { @@ -213,7 +180,7 @@ func BenchmarkStateOperations(b *testing.B) { } b.StartTimer() - ReconcileState(state, newRateCache, newBotCache, newVerifiedCache) + ReconcileState(state, newRateCache, newVerifiedCache) } }) @@ -230,7 +197,6 @@ func BenchmarkStateOperations(b *testing.B) { tmpFile, true, // enable reconciliation rateCache, - botCache, verifiedCache, logger, ) @@ -296,12 +262,11 @@ func TestStateOperationsWithinThreshold(t *testing.T) { t.Run(level.Name, func(t *testing.T) { // Create and populate caches rateCache := lru.New(24*time.Hour, lru.NoExpiration) - botCache := lru.New(24*time.Hour, lru.NoExpiration) verifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - t.Logf("Populating persistent caches (rate=%d, derived bots skipped, verified=%d)...", + t.Logf("Populating persistent caches (rate=%d, verified=%d)...", level.RateEntries, level.VerifiedEntries) - populatePersistentCaches(level, rateCache, verifiedCache) + populateCaches(level, rateCache, verifiedCache) thresh := levelThresholds[level.Name] @@ -340,7 +305,6 @@ func TestStateOperationsWithinThreshold(t *testing.T) { tmpFile, false, rateCache, - botCache, verifiedCache, logger, ); err != nil { @@ -352,7 +316,6 @@ func TestStateOperationsWithinThreshold(t *testing.T) { tmpFile, true, // enable reconciliation rateCache, - botCache, verifiedCache, logger, ) @@ -422,7 +385,7 @@ func TestGenerateUniqueIPs(t *testing.T) { func marshalPersistentSnapshotForStress(rateCache, verifiedCache *lru.Cache) ([]byte, error) { var buf bytes.Buffer writer := bufio.NewWriter(&buf) - if err := writeStateJSON(writer, rateCache.Items(), nil, verifiedCache.Items()); err != nil { + if err := writeStateJSON(writer, rateCache.Items(), verifiedCache.Items()); err != nil { return nil, err } if err := writer.Flush(); err != nil { diff --git a/internal/state/state_test.go b/internal/state/state_test.go index 35df3c1..e9e0a4b 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -17,20 +17,16 @@ import ( func TestGetState(t *testing.T) { // Create test caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) // Add test data rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) rateCache.Set("10.0.0.0", uint(5), lru.DefaultExpiration) - botCache.Set("1.2.3.4", true, lru.DefaultExpiration) - botCache.Set("5.6.7.8", false, lru.DefaultExpiration) - verifiedCache.Set("9.9.9.9", true, lru.DefaultExpiration) // Get state - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + state := GetState(rateCache.Items(), verifiedCache.Items()) // Verify rate cache data if len(state.Rate) != 2 { @@ -47,21 +43,6 @@ func TestGetState(t *testing.T) { t.Error("Expected non-zero expiration for 192.168.0.0") } - // Verify bot cache data - if len(state.Bots) != 2 { - t.Errorf("Expected 2 bot entries, got %d", len(state.Bots)) - } - if state.Bots["1.2.3.4"].Value != true { - t.Error("Expected bot 1.2.3.4 to be true") - } - if state.Bots["5.6.7.8"].Value != false { - t.Error("Expected bot 5.6.7.8 to be false") - } - // Verify expiration timestamps are set - if state.Bots["1.2.3.4"].Expiration == 0 { - t.Error("Expected non-zero expiration for bot 1.2.3.4") - } - // Verify verified cache data if len(state.Verified) != 1 { t.Errorf("Expected 1 verified entry, got %d", len(state.Verified)) @@ -75,15 +56,12 @@ func TestGetState(t *testing.T) { } // Verify memory tracking exists - if len(state.Memory) != 3 { - t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) + if len(state.Memory) != 2 { + t.Errorf("Expected 2 memory entries, got %d", len(state.Memory)) } if state.Memory["rate"] == 0 { t.Error("Expected non-zero memory for rate cache") } - if state.Memory["bot"] == 0 { - t.Error("Expected non-zero memory for bot cache") - } if state.Memory["verified"] == 0 { t.Error("Expected non-zero memory for verified cache") } @@ -92,22 +70,18 @@ func TestGetState(t *testing.T) { func TestGetStateEmpty(t *testing.T) { // Create empty caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + state := GetState(rateCache.Items(), verifiedCache.Items()) if len(state.Rate) != 0 { t.Errorf("Expected 0 rate entries, got %d", len(state.Rate)) } - if len(state.Bots) != 0 { - t.Errorf("Expected 0 bot entries, got %d", len(state.Bots)) - } if len(state.Verified) != 0 { t.Errorf("Expected 0 verified entries, got %d", len(state.Verified)) } - if len(state.Memory) != 3 { - t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) + if len(state.Memory) != 2 { + t.Errorf("Expected 2 memory entries, got %d", len(state.Memory)) } } @@ -122,10 +96,6 @@ func TestSetState(t *testing.T) { "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, "10.0.0.0": {Value: uint(5), Expiration: pastExpiration}, // expired }, - Bots: map[string]CacheEntry{ - "1.2.3.4": {Value: true, Expiration: futureExpiration}, - "5.6.7.8": {Value: false, Expiration: pastExpiration}, // expired - }, Verified: map[string]CacheEntry{ "9.9.9.9": {Value: true, Expiration: futureExpiration}, "8.8.8.8": {Value: true, Expiration: pastExpiration}, // expired @@ -135,11 +105,10 @@ func TestSetState(t *testing.T) { // Create caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) // Set state - SetState(state, rateCache, botCache, verifiedCache) + SetState(state, rateCache, verifiedCache) // Verify only non-expired entries were loaded if rateCache.ItemCount() != 1 { @@ -152,13 +121,6 @@ func TestSetState(t *testing.T) { t.Error("Expected expired entry 10.0.0.0 to be filtered out") } - if botCache.ItemCount() != 1 { - t.Errorf("Expected 1 bot entry (expired filtered out), got %d", botCache.ItemCount()) - } - if v, ok := botCache.Get("1.2.3.4"); !ok || v.(bool) != true { - t.Error("Expected bot 1.2.3.4 to be true") - } - if verifiedCache.ItemCount() != 2 { t.Errorf("Expected 2 verified entries (1 expired filtered out), got %d", verifiedCache.ItemCount()) } @@ -193,7 +155,6 @@ func TestReconcileState(t *testing.T) { // Create memory caches with some overlapping data rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), time.Duration(oldExpiration-now)) // older, should be replaced @@ -203,7 +164,7 @@ func TestReconcileState(t *testing.T) { verifiedCache.Set("2.2.2.2", true, time.Duration(newExpiration-now)) // newer, should be kept // Reconcile - ReconcileState(fileState, rateCache, botCache, verifiedCache) + ReconcileState(fileState, rateCache, verifiedCache) // Verify reconciliation results // 192.168.0.0 should be updated to file's value (newer expiration) @@ -244,11 +205,9 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { // Create caches with test data rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) - botCache.Set("1.2.3.4", false, lru.DefaultExpiration) verifiedCache.Set("5.6.7.8", true, lru.DefaultExpiration) // Save without reconciliation @@ -256,7 +215,6 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { tmpFile, false, // no reconciliation rateCache, - botCache, verifiedCache, testLogger(), ) @@ -306,9 +264,6 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { if len(savedState.Rate) != 1 { t.Errorf("Expected 1 rate entry, got %d", len(savedState.Rate)) } - if len(savedState.Bots) != 0 { - t.Errorf("Expected derived bot cache entries to be skipped, got %d", len(savedState.Bots)) - } if len(savedState.Verified) != 1 { t.Errorf("Expected 1 verified entry, got %d", len(savedState.Verified)) } @@ -324,9 +279,8 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { Rate: map[string]CacheEntry{ "10.0.0.0": {Value: uint(5), Expiration: futureExpiration}, }, - Bots: map[string]CacheEntry{}, Verified: map[string]CacheEntry{}, - Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + Memory: map[string]uintptr{"rate": 8, "verified": 8}, } initialData, _ := json.Marshal(initialState) if err := os.WriteFile(tmpFile, initialData, 0644); err != nil { @@ -335,7 +289,6 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { // Create caches with different data rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) @@ -345,7 +298,6 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { tmpFile, true, // enable reconciliation rateCache, - botCache, verifiedCache, testLogger(), ) @@ -391,18 +343,15 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { tmpFile := t.TempDir() + "/state.json" rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) - botCache.Set("1.2.3.4", false, lru.DefaultExpiration) verifiedCache.Set("5.6.7.8", true, lru.DefaultExpiration) metrics, err := SaveStateToFileWithMetrics( tmpFile, false, rateCache, - botCache, verifiedCache, testLogger(), ) @@ -413,9 +362,6 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { if metrics.RateEntries != 1 { t.Errorf("Expected 1 rate entry, got %d", metrics.RateEntries) } - if metrics.BotEntries != 0 { - t.Errorf("Expected derived bot cache entries to be skipped, got %d", metrics.BotEntries) - } if metrics.VerifiedEntries != 1 { t.Errorf("Expected 1 verified entry, got %d", metrics.VerifiedEntries) } @@ -429,14 +375,12 @@ func TestSaveStateToFileWithMetrics(t *testing.T) { invalidPath := "/invalid/directory/that/does/not/exist/state.json" rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) _, err := SaveStateToFileWithMetrics( invalidPath, false, rateCache, - botCache, verifiedCache, testLogger(), ) @@ -459,13 +403,10 @@ func TestLoadStateFromFile(t *testing.T) { "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, "10.0.0.0": {Value: uint(5), Expiration: futureExpiration}, }, - Bots: map[string]CacheEntry{ - "1.2.3.4": {Value: true, Expiration: futureExpiration}, - }, Verified: map[string]CacheEntry{ "5.6.7.8": {Value: true, Expiration: futureExpiration}, }, - Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + Memory: map[string]uintptr{"rate": 8, "verified": 8}, } data, _ := json.Marshal(testState) @@ -475,10 +416,9 @@ func TestLoadStateFromFile(t *testing.T) { // Load into empty caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + err := LoadStateFromFile(tmpFile, rateCache, verifiedCache) if err != nil { t.Fatalf("LoadStateFromFile failed: %v", err) } @@ -487,9 +427,6 @@ func TestLoadStateFromFile(t *testing.T) { if rateCache.ItemCount() != 2 { t.Errorf("Expected 2 rate entries, got %d", rateCache.ItemCount()) } - if botCache.ItemCount() != 1 { - t.Errorf("Expected 1 bot entry, got %d", botCache.ItemCount()) - } if verifiedCache.ItemCount() != 1 { t.Errorf("Expected 1 verified entry, got %d", verifiedCache.ItemCount()) } @@ -498,9 +435,6 @@ func TestLoadStateFromFile(t *testing.T) { if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 10 { t.Error("Expected rate 10 for 192.168.0.0") } - if v, ok := botCache.Get("1.2.3.4"); !ok || v.(bool) != true { - t.Error("Expected bot 1.2.3.4 to be true") - } if v, ok := verifiedCache.Get("5.6.7.8"); !ok || v.(bool) != true { t.Error("Expected 5.6.7.8 to be verified") } @@ -516,9 +450,8 @@ func TestLoadStateFromFile(t *testing.T) { Rate: map[string]CacheEntry{ "192.168.0.0": {Value: uint(10), Expiration: pastExpiration}, // expired }, - Bots: map[string]CacheEntry{}, Verified: map[string]CacheEntry{}, - Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + Memory: map[string]uintptr{"rate": 8, "verified": 8}, } data, _ := json.Marshal(testState) @@ -528,10 +461,9 @@ func TestLoadStateFromFile(t *testing.T) { } // Load into empty caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - err = LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + err = LoadStateFromFile(tmpFile, rateCache, verifiedCache) if err != nil { t.Fatalf("LoadStateFromFile failed: %v", err) } @@ -546,10 +478,9 @@ func TestLoadStateFromFile(t *testing.T) { nonExistentFile := t.TempDir() + "/does-not-exist.json" rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - err := LoadStateFromFile(nonExistentFile, rateCache, botCache, verifiedCache) + err := LoadStateFromFile(nonExistentFile, rateCache, verifiedCache) if err == nil { t.Error("Expected error for non-existent file, got nil") } @@ -564,10 +495,9 @@ func TestLoadStateFromFile(t *testing.T) { } rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + err := LoadStateFromFile(tmpFile, rateCache, verifiedCache) if err == nil { t.Error("Expected error for invalid JSON, got nil") } @@ -587,11 +517,10 @@ func TestLoadStateFromFile(t *testing.T) { } rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) // Empty file returns nil (no state to load, which is fine) - err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + err := LoadStateFromFile(tmpFile, rateCache, verifiedCache) if err != nil { t.Errorf("Unexpected error for empty file: %v", err) } @@ -613,11 +542,8 @@ func TestReconcileStateFromFile(t *testing.T) { Rate: map[string]CacheEntry{ "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, }, - Bots: map[string]CacheEntry{ - "1.2.3.4": {Value: true, Expiration: futureExpiration}, - }, Verified: map[string]CacheEntry{}, - Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + Memory: map[string]uintptr{"rate": 8, "verified": 8}, } fileState.Verified = map[string]CacheEntry{ "5.6.7.8": {Value: true, Expiration: futureExpiration}, @@ -629,13 +555,12 @@ func TestReconcileStateFromFile(t *testing.T) { } rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("10.0.0.0", uint(5), lru.DefaultExpiration) verifiedCache.Set("9.9.9.9", false, time.Duration(laterExpiration-now)) - if err := ReconcileStateFromFile(tmpFile, rateCache, botCache, verifiedCache); err != nil { + if err := ReconcileStateFromFile(tmpFile, rateCache, verifiedCache); err != nil { t.Fatalf("ReconcileStateFromFile failed: %v", err) } @@ -645,9 +570,6 @@ func TestReconcileStateFromFile(t *testing.T) { if v, ok := rateCache.Get("10.0.0.0"); !ok || v.(uint) != 5 { t.Error("Expected existing memory state to be retained") } - if botCache.ItemCount() != 0 { - t.Errorf("Expected derived bot cache entries to be skipped during reconciliation, got %d", botCache.ItemCount()) - } if v, ok := verifiedCache.Get("5.6.7.8"); !ok || v.(bool) != true { t.Error("Expected file verified state to be reconciled into memory") } @@ -660,7 +582,6 @@ func TestSaveStateToFileWithMetricsWriteError(t *testing.T) { statePath := t.TempDir() rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) @@ -668,7 +589,6 @@ func TestSaveStateToFileWithMetricsWriteError(t *testing.T) { statePath, false, rateCache, - botCache, verifiedCache, testLogger(), ) @@ -682,7 +602,7 @@ func TestSaveStateToFileWithMetricsWriteError(t *testing.T) { func TestAtomicWriteStateFileCreateTempError(t *testing.T) { missingDir := filepath.Join(t.TempDir(), "missing") - _, _, err := atomicWriteStateFile(filepath.Join(missingDir, "state.json"), nil, nil, nil, 0600) + _, _, err := atomicWriteStateFile(filepath.Join(missingDir, "state.json"), nil, nil, 0600) if err == nil { t.Fatal("expected atomicWriteStateFile to fail when temp directory is missing") } @@ -690,18 +610,15 @@ func TestAtomicWriteStateFileCreateTempError(t *testing.T) { func TestWriteStateJSON(t *testing.T) { rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) rateCache.Set("bad\a-key", uint(2), lru.DefaultExpiration) - botCache.Set("1.2.3.4", true, lru.DefaultExpiration) - botCache.Set("5.6.7.8", false, lru.DefaultExpiration) verifiedCache.Set("9.9.9.9", true, lru.DefaultExpiration) var buf bytes.Buffer writer := bufio.NewWriter(&buf) - if err := writeStateJSON(writer, rateCache.Items(), botCache.Items(), verifiedCache.Items()); err != nil { + if err := writeStateJSON(writer, rateCache.Items(), verifiedCache.Items()); err != nil { t.Fatalf("writeStateJSON failed: %v", err) } if err := writer.Flush(); err != nil { @@ -712,18 +629,12 @@ func TestWriteStateJSON(t *testing.T) { if err := json.Unmarshal(buf.Bytes(), &saved); err != nil { t.Fatalf("state JSON did not unmarshal: %v", err) } - if len(saved.Rate) != 2 || len(saved.Bots) != 2 || len(saved.Verified) != 1 { - t.Fatalf("unexpected saved counts: rate=%d bots=%d verified=%d", len(saved.Rate), len(saved.Bots), len(saved.Verified)) + if len(saved.Rate) != 2 || len(saved.Verified) != 1 { + t.Fatalf("unexpected saved counts: rate=%d verified=%d", len(saved.Rate), len(saved.Verified)) } if saved.Rate["bad\a-key"].Value != float64(2) { t.Fatal("expected JSON-escaped rate key to round-trip") } - if saved.Bots["1.2.3.4"].Value != true { - t.Fatal("expected true bot value to be written") - } - if saved.Bots["5.6.7.8"].Value != false { - t.Fatal("expected false bot value to be written") - } } func TestWriteStateJSONUnexpectedType(t *testing.T) { @@ -733,7 +644,6 @@ func TestWriteStateJSONUnexpectedType(t *testing.T) { writer, map[string]lru.Item{"bad": {Object: "not-a-uint", Expiration: time.Now().Add(time.Hour).UnixNano()}}, nil, - nil, ) if err == nil { t.Fatal("expected writeStateJSON to reject unexpected cache value type") @@ -741,15 +651,14 @@ func TestWriteStateJSONUnexpectedType(t *testing.T) { } func TestReconcileStateFromFileEmptyAndInvalidFiles(t *testing.T) { - newCaches := func() (*lru.Cache, *lru.Cache, *lru.Cache) { + newCaches := func() (*lru.Cache, *lru.Cache) { return lru.New(1*time.Hour, 1*time.Minute), - lru.New(1*time.Hour, 1*time.Minute), lru.New(1*time.Hour, 1*time.Minute) } t.Run("missing file", func(t *testing.T) { - rateCache, botCache, verifiedCache := newCaches() - err := ReconcileStateFromFile(filepath.Join(t.TempDir(), "missing.json"), rateCache, botCache, verifiedCache) + rateCache, verifiedCache := newCaches() + err := ReconcileStateFromFile(filepath.Join(t.TempDir(), "missing.json"), rateCache, verifiedCache) if err == nil { t.Fatal("expected missing state file to return read error") } @@ -761,8 +670,8 @@ func TestReconcileStateFromFileEmptyAndInvalidFiles(t *testing.T) { t.Fatalf("failed to write empty state file: %v", err) } - rateCache, botCache, verifiedCache := newCaches() - if err := ReconcileStateFromFile(tmpFile, rateCache, botCache, verifiedCache); err != nil { + rateCache, verifiedCache := newCaches() + if err := ReconcileStateFromFile(tmpFile, rateCache, verifiedCache); err != nil { t.Fatalf("empty state file should be a no-op: %v", err) } }) @@ -773,8 +682,8 @@ func TestReconcileStateFromFileEmptyAndInvalidFiles(t *testing.T) { t.Fatalf("failed to write invalid state file: %v", err) } - rateCache, botCache, verifiedCache := newCaches() - if err := ReconcileStateFromFile(tmpFile, rateCache, botCache, verifiedCache); err == nil { + rateCache, verifiedCache := newCaches() + if err := ReconcileStateFromFile(tmpFile, rateCache, verifiedCache); err == nil { t.Fatal("expected invalid state file to return unmarshal error") } }) @@ -804,12 +713,6 @@ func TestSetStateWithExpiration_Synctest(t *testing.T) { Expiration: start.Add(10 * time.Second).UnixNano(), // expires in 10s }, }, - Bots: map[string]CacheEntry{ - "1.2.3.4": { - Value: true, - Expiration: start.Add(3 * time.Second).UnixNano(), // expires in 3s - }, - }, Verified: map[string]CacheEntry{ "9.9.9.9": { Value: true, @@ -820,32 +723,23 @@ func TestSetStateWithExpiration_Synctest(t *testing.T) { // Create empty caches (no cleanup interval to avoid background goroutines) rateCache := lru.New(1*time.Hour, lru.NoExpiration) - botCache := lru.New(1*time.Hour, lru.NoExpiration) verifiedCache := lru.New(1*time.Hour, lru.NoExpiration) // Load state - SetState(state, rateCache, botCache, verifiedCache) + SetState(state, rateCache, verifiedCache) // Verify all entries are loaded if rateCache.ItemCount() != 2 { t.Errorf("Expected 2 rate entries, got %d", rateCache.ItemCount()) } - if botCache.ItemCount() != 1 { - t.Errorf("Expected 1 bot entry, got %d", botCache.ItemCount()) - } if verifiedCache.ItemCount() != 1 { t.Errorf("Expected 1 verified entry, got %d", verifiedCache.ItemCount()) } - // Advance time by 4 seconds (bot entry should expire, rate entries still valid) + // Advance time by 4 seconds (rate entries still valid) time.Sleep(4 * time.Second) synctest.Wait() - // Bot cache should be empty (expired at 3s) - if _, found := botCache.Get("1.2.3.4"); found { - t.Error("Bot entry should have expired after 3 seconds") - } - // Rate entries should still be present if _, found := rateCache.Get("192.168.0.0"); !found { t.Error("Rate entry 192.168.0.0 should not expire until 5 seconds") @@ -910,7 +804,6 @@ func TestReconcileStateWithExpiration_Synctest(t *testing.T) { // Create memory caches with overlapping data (no cleanup interval to avoid background goroutines) rateCache := lru.New(1*time.Hour, lru.NoExpiration) - botCache := lru.New(1*time.Hour, lru.NoExpiration) verifiedCache := lru.New(1*time.Hour, lru.NoExpiration) // Memory entry with older expiration (should be replaced) @@ -919,7 +812,7 @@ func TestReconcileStateWithExpiration_Synctest(t *testing.T) { rateCache.Set("10.0.0.0", uint(5), 10*time.Second) // Reconcile - ReconcileState(fileState, rateCache, botCache, verifiedCache) + ReconcileState(fileState, rateCache, verifiedCache) // 192.168.0.0 should have file's value (newer expiration) if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 15 { @@ -969,12 +862,10 @@ func TestSaveAndLoadStateWithExpiration_Synctest(t *testing.T) { // Create caches with entries expiring at different times (no cleanup interval to avoid background goroutines) rateCache1 := lru.New(1*time.Hour, lru.NoExpiration) - botCache1 := lru.New(1*time.Hour, lru.NoExpiration) verifiedCache1 := lru.New(1*time.Hour, lru.NoExpiration) rateCache1.Set("192.168.0.0", uint(10), 5*time.Second) rateCache1.Set("10.0.0.0", uint(5), 10*time.Second) - botCache1.Set("1.2.3.4", true, 3*time.Second) verifiedCache1.Set("9.9.9.9", true, lru.NoExpiration) // Save state @@ -982,7 +873,6 @@ func TestSaveAndLoadStateWithExpiration_Synctest(t *testing.T) { tmpFile, false, rateCache1, - botCache1, verifiedCache1, testLogger(), ) @@ -990,25 +880,19 @@ func TestSaveAndLoadStateWithExpiration_Synctest(t *testing.T) { t.Fatalf("SaveStateToFile failed: %v", err) } - // Advance time by 4 seconds (bot expires, rates still valid) + // Advance time by 4 seconds (rates still valid) time.Sleep(4 * time.Second) synctest.Wait() // Load into new caches (no cleanup interval to avoid background goroutines) rateCache2 := lru.New(1*time.Hour, lru.NoExpiration) - botCache2 := lru.New(1*time.Hour, lru.NoExpiration) verifiedCache2 := lru.New(1*time.Hour, lru.NoExpiration) - err = LoadStateFromFile(tmpFile, rateCache2, botCache2, verifiedCache2) + err = LoadStateFromFile(tmpFile, rateCache2, verifiedCache2) if err != nil { t.Fatalf("LoadStateFromFile failed: %v", err) } - // Bot entry should be filtered out (expired 1 second ago) - if botCache2.ItemCount() != 0 { - t.Errorf("Expected 0 bot entries (expired), got %d", botCache2.ItemCount()) - } - // First rate entry should be loaded (expires at 5s, we're at 4s) if _, found := rateCache2.Get("192.168.0.0"); !found { t.Error("Rate entry 192.168.0.0 should be loaded (not yet expired)") @@ -1054,7 +938,6 @@ func TestReconcilePreservesNewerData_Synctest(t *testing.T) { false, initialCache, lru.New(1*time.Hour, lru.NoExpiration), - lru.New(1*time.Hour, lru.NoExpiration), testLogger(), ) if err != nil { @@ -1076,7 +959,6 @@ func TestReconcilePreservesNewerData_Synctest(t *testing.T) { true, // reconcile newCache, lru.New(1*time.Hour, lru.NoExpiration), - lru.New(1*time.Hour, lru.NoExpiration), testLogger(), ) if err != nil { @@ -1089,7 +971,6 @@ func TestReconcilePreservesNewerData_Synctest(t *testing.T) { tmpFile, loadedCache, lru.New(1*time.Hour, lru.NoExpiration), - lru.New(1*time.Hour, lru.NoExpiration), ) if err != nil { t.Fatalf("Load failed: %v", err) diff --git a/main.go b/main.go index a1a9bf9..f3299db 100644 --- a/main.go +++ b/main.go @@ -738,7 +738,7 @@ func (bc *CaptchaProtect) serveStatsPage(rw http.ResponseWriter, ip string) { return } - state := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) + state := state.GetState(bc.rateCache.Items(), bc.verifiedCache.Items()) jsonData, err := json.Marshal(state) if err != nil { bc.log.Error("failed to marshal JSON", "err", err) @@ -1126,7 +1126,6 @@ func (bc *CaptchaProtect) saveStateNow() bool { bc.config.PersistentStateFile, reconcile, bc.rateCache, - bc.botCache, bc.verifiedCache, bc.log, ) @@ -1140,7 +1139,6 @@ func (bc *CaptchaProtect) saveStateNow() bool { bc.log.Debug("State saved successfully", "rateEntries", metrics.RateEntries, - "botEntries", metrics.BotEntries, "verifiedEntries", metrics.VerifiedEntries, "lockMs", metrics.LockMs, "readMs", metrics.ReadMs, @@ -1160,7 +1158,6 @@ func (bc *CaptchaProtect) loadStateFrom(filePath string) { err := state.LoadStateFromFile( filePath, bc.rateCache, - bc.botCache, bc.verifiedCache, ) @@ -1242,7 +1239,6 @@ func (bc *CaptchaProtect) reconcileStateFromFileIfChanged() { err = state.ReconcileStateFromFile( bc.config.PersistentStateFile, bc.rateCache, - bc.botCache, bc.verifiedCache, ) if err != nil { diff --git a/main_test.go b/main_test.go index 4947f9b..cb8253d 100644 --- a/main_test.go +++ b/main_test.go @@ -1008,12 +1008,6 @@ func TestStatePersistence(t *testing.T) { "expiration": float64(futureExpiration), }, }, - "bots": map[string]map[string]interface{}{ - "5.6.7.8": { - "value": false, - "expiration": float64(futureExpiration), - }, - }, }) err := os.WriteFile(tmpFile, jsonData, 0644) if err != nil { @@ -1036,11 +1030,6 @@ func TestStatePersistence(t *testing.T) { t.Error("Verified cache state not persisted correctly") } - // Check bot cache - botVal, found := bc2.botCache.Get("5.6.7.8") - if !found || botVal.(bool) != false { - t.Error("Bot cache state not persisted correctly") - } } func TestRegisterRequestStopsIncrementingAfterRateLimitTrips(t *testing.T) { From d66918b45a5152c22c99b77047a63d08ae544cf6 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Thu, 7 May 2026 07:56:30 -0400 Subject: [PATCH 09/12] Update captcha-protect plugin version to v1.12.3 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e65920c..f751c39 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ services: --providers.docker=true --providers.docker.network=default --experimental.plugins.captcha-protect.modulename=github.com/libops/captcha-protect - --experimental.plugins.captcha-protect.version=v1.12.2 + --experimental.plugins.captcha-protect.version=v1.12.3 volumes: - /var/run/docker.sock:/var/run/docker.sock:z - /CHANGEME/TO/A/HOST/PATH/FOR/STATE/FILE:/tmp/state.json:rw From b0e02cce7c41021a0efda0d878d5a6811615fa81 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Thu, 7 May 2026 12:12:12 +0000 Subject: [PATCH 10/12] Test disabled state persistence path --- main_test.go | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/main_test.go b/main_test.go index cb8253d..af40881 100644 --- a/main_test.go +++ b/main_test.go @@ -1053,6 +1053,28 @@ func TestRegisterRequestStopsIncrementingAfterRateLimitTrips(t *testing.T) { } } +func TestStatePersistenceDisabledWithoutStateFile(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + + bc, err := NewCaptchaProtect(context.Background(), nil, config, "test") + if err != nil { + t.Fatalf("NewCaptchaProtect failed: %v", err) + } + + bc.registerRequest("192.168.0.0") + bc.markStateDirty() + + if bc.currentStateDirty() != 0 { + t.Fatal("state dirty counter should remain disabled without a persistent state file") + } + if bc.hasUnsavedState() { + t.Fatal("state should not become unsaved without a persistent state file") + } +} + func TestSaveStateFlushesDirtyStateOnCanceledContext(t *testing.T) { tmpFile := filepath.Join(t.TempDir(), "state.json") bc := newStateOnlyCaptchaProtect(tmpFile, 2) From afe2db581da9788638b53add10339edda57b1da0 Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Thu, 7 May 2026 12:16:20 +0000 Subject: [PATCH 11/12] Slow local state persistence cadence --- ci_behavior_test.go | 2 +- main.go | 18 ++++++++++++++---- main_test.go | 12 ++++++++++++ 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/ci_behavior_test.go b/ci_behavior_test.go index c615d5c..17ac51e 100644 --- a/ci_behavior_test.go +++ b/ci_behavior_test.go @@ -83,7 +83,7 @@ func TestPersistentStateSharingWithSynctest(t *testing.T) { writer.registerRequest("107.198.0.0") } - time.Sleep(StateSaveInterval + StateSaveJitter + 3*time.Second) + time.Sleep(stateSaveInterval(writer.config) + StateSaveJitter + 3*time.Second) synctest.Wait() reader.reconcileStateFromFileIfChanged() diff --git a/main.go b/main.go index f3299db..cbba0b1 100644 --- a/main.go +++ b/main.go @@ -28,8 +28,10 @@ import ( ) const ( - // StateSaveInterval is how often the persistent state file is written to disk - StateSaveInterval = 10 * time.Second + // StateSaveInterval is how often local persistent state is written to disk. + StateSaveInterval = 60 * time.Second + // StateReconciliationSaveInterval is the faster save cadence used when multiple instances share state. + StateReconciliationSaveInterval = 10 * time.Second // StateSaveJitter is the maximum random jitter added to save interval to prevent thundering herd StateSaveJitter = 2 * time.Second @@ -1038,9 +1040,10 @@ func (c *Config) ParseHttpMethods(log *slog.Logger) { func (bc *CaptchaProtect) saveState(ctx context.Context) { // Add random jitter to prevent multiple instances from trying to save simultaneously jitter := stateSaveJitter() - interval := StateSaveInterval + jitter + baseInterval := stateSaveInterval(bc.config) + interval := baseInterval + jitter - bc.log.Debug("State save configured", "baseInterval", StateSaveInterval, "jitter", jitter, "actualInterval", interval) + bc.log.Debug("State save configured", "baseInterval", baseInterval, "jitter", jitter, "actualInterval", interval) ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() @@ -1089,6 +1092,13 @@ func (bc *CaptchaProtect) saveState(ctx context.Context) { } } +func stateSaveInterval(config *Config) time.Duration { + if config.EnableStateReconciliation == "true" { + return StateReconciliationSaveInterval + } + return StateSaveInterval +} + func stateSaveJitter() time.Duration { maxJitter := big.NewInt(StateSaveJitter.Milliseconds()) if maxJitter.Sign() <= 0 { diff --git a/main_test.go b/main_test.go index af40881..5121b10 100644 --- a/main_test.go +++ b/main_test.go @@ -1075,6 +1075,18 @@ func TestStatePersistenceDisabledWithoutStateFile(t *testing.T) { } } +func TestStateSaveIntervalUsesFastCadenceOnlyWithReconciliation(t *testing.T) { + config := CreateConfig() + if got := stateSaveInterval(config); got != 60*time.Second { + t.Fatalf("default state save interval = %s, want 60s", got) + } + + config.EnableStateReconciliation = "true" + if got := stateSaveInterval(config); got != 10*time.Second { + t.Fatalf("reconciliation state save interval = %s, want 10s", got) + } +} + func TestSaveStateFlushesDirtyStateOnCanceledContext(t *testing.T) { tmpFile := filepath.Join(t.TempDir(), "state.json") bc := newStateOnlyCaptchaProtect(tmpFile, 2) From 3e94f9c418d290f577834ca7f594399c0906462d Mon Sep 17 00:00:00 2001 From: Joe Corall Date: Thu, 7 May 2026 12:17:29 +0000 Subject: [PATCH 12/12] Document state persistence cadence --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f751c39..931cf36 100644 --- a/README.md +++ b/README.md @@ -136,8 +136,8 @@ services: | `challengeStatusCode` | `int` | `200` | HTTP Response status code to return when serving a challenge | | `enableStatsPage` | `string` | `"false"` | Allows `exemptIps` to access `/captcha-protect/stats` to monitor the rate limiter. | | `logLevel` | `string` | `"INFO"` | Log level for the middleware. Options: `ERROR`, `WARNING`, `INFO`, or `DEBUG`. | -| `persistentStateFile` | `string` | `""` | File path to persist rate limiter and verified challenge state across Traefik restarts. Derived bot lookup cache entries are not persisted. In Docker, mount this file from the host. | -| `enableStateReconciliation` | `string` | `"false"` | When `"true"`, polls the shared state file for changes and merges newer disk state into memory, then reconciles again before dirty snapshots are saved. Enable for multi-instance deployments sharing state. | +| `persistentStateFile` | `string` | `""` | File path to persist rate limiter and verified challenge state across Traefik restarts. When unset, no state load/save goroutine is started. Dirty local state is saved about every 60s plus 0-2s jitter. Derived bot lookup cache entries are not persisted. In Docker, mount this file from the host. | +| `enableStateReconciliation` | `string` | `"false"` | When `"true"`, polls the shared state file for changes and merges newer disk state into memory, then reconciles again before dirty snapshots are saved. Enable for multi-instance deployments sharing state. Dirty shared state is saved about every 10s plus 0-2s jitter. | ### Circuit Breaker (failover if a captcha provider is unavailable) @@ -272,7 +272,7 @@ If you have use a computer within the `exemptIps`, and access to the command lin curl -s https://example.com/captcha-protect/stats | jq -r '.rate | to_entries | sort_by(.value) | .[] | "\(.key): \(.value)"' | tail -25 ``` -The rate limiter and verified challenge portions of this JSON state data are also found in the `state.json` file that you should have configured in your `docker-compose.yml` using the `persistentStateFile` setting and volume definition. NOTE: this file should only be changed by `captcha-protect` and not manually. +The rate limiter and verified challenge portions of this JSON state data are also found in the `state.json` file that you should have configured in your `docker-compose.yml` using the `persistentStateFile` setting and volume definition. When `enableStateReconciliation` is `"false"`, dirty state is saved roughly every 60 seconds plus 0-2 seconds of jitter. When `enableStateReconciliation` is `"true"` for multi-instance shared state, dirty state is saved roughly every 10 seconds plus 0-2 seconds of jitter. If `persistentStateFile` is unset, state persistence is disabled. NOTE: this file should only be changed by `captcha-protect` and not manually. ## Troubleshooting