diff --git a/.github/workflows/lint-test.yml b/.github/workflows/lint-test.yml index 14a5ebc..4f08669 100644 --- a/.github/workflows/lint-test.yml +++ b/.github/workflows/lint-test.yml @@ -13,7 +13,7 @@ jobs: - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: - go-version: ">=1.25.0" + go-version: "1.26.2" - name: golangci-lint uses: golangci/golangci-lint-action@1e7e51e771db61008b38414a730f564565cf7c20 # v9 diff --git a/.github/workflows/stress-test.yaml b/.github/workflows/stress-test.yaml index b9b4b0b..e92b825 100644 --- a/.github/workflows/stress-test.yaml +++ b/.github/workflows/stress-test.yaml @@ -14,7 +14,7 @@ jobs: - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: - go-version: ">=1.25.0" + go-version: "1.26.2" - name: Run stress tests id: stress_test diff --git a/README.md b/README.md index 41e10ec..931cf36 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,7 @@ services: --providers.docker=true --providers.docker.network=default --experimental.plugins.captcha-protect.modulename=github.com/libops/captcha-protect - --experimental.plugins.captcha-protect.version=v1.12.2 + --experimental.plugins.captcha-protect.version=v1.12.3 volumes: - /var/run/docker.sock:/var/run/docker.sock:z - /CHANGEME/TO/A/HOST/PATH/FOR/STATE/FILE:/tmp/state.json:rw @@ -122,7 +122,7 @@ services: | `window` | `int` | `86400` | Duration (in seconds) for monitoring requests per subnet. | | `ipv4subnetMask` | `int` | `16` | CIDR subnet mask to group IPv4 addresses for rate limiting. | | `ipv6subnetMask` | `int` | `64` | CIDR subnet mask to group IPv6 addresses for rate limiting. | -| `ipForwardedHeader` | `string` | `""` | Header to check for the original client IP if Traefik is behind a load balancer. | +| `ipForwardedHeader` | `string` | `""` | Header to check for the original client IP if Traefik is behind a load balancer. Only set this when Traefik receives the header from a trusted proxy/load balancer; otherwise clients can spoof their IP. | | `ipDepth` | `int` | `0` | How deep past the last non-exempt IP to fetch the real IP from `ipForwardedHeader`. Default 0 returns the last IP in the forward header | | `goodBots` | `[]string` (encouraged) | *see below* | List of second-level domains for bots that are never challenged or rate-limited. | | `enableGooglebotIPCheck`| `string`. | `"false"` | Treat IPs coming from googlebot's known IP ranges as good bots | @@ -136,8 +136,8 @@ services: | `challengeStatusCode` | `int` | `200` | HTTP Response status code to return when serving a challenge | | `enableStatsPage` | `string` | `"false"` | Allows `exemptIps` to access `/captcha-protect/stats` to monitor the rate limiter. | | `logLevel` | `string` | `"INFO"` | Log level for the middleware. Options: `ERROR`, `WARNING`, `INFO`, or `DEBUG`. | -| `persistentStateFile` | `string` | `""` | File path to persist rate limiter state across Traefik restarts. In Docker, mount this file from the host. | -| `enableStateReconciliation` | `string` | `"false"` | When `"true"`, reads and merges disk state before each save to prevent multiple instances from overwriting data. Adds extra I/O overhead. Only enable for multi-instance deployments sharing state. **Performance warning**: Not recommended for sites with >1M unique visitors due to reconciliation overhead (5-8s per cycle at scale). | +| `persistentStateFile` | `string` | `""` | File path to persist rate limiter and verified challenge state across Traefik restarts. When unset, no state load/save goroutine is started. Dirty local state is saved about every 60s plus 0-2s jitter. Derived bot lookup cache entries are not persisted. In Docker, mount this file from the host. | +| `enableStateReconciliation` | `string` | `"false"` | When `"true"`, polls the shared state file for changes and merges newer disk state into memory, then reconciles again before dirty snapshots are saved. Enable for multi-instance deployments sharing state. Dirty shared state is saved about every 10s plus 0-2s jitter. | ### Circuit Breaker (failover if a captcha provider is unavailable) @@ -272,7 +272,7 @@ If you have use a computer within the `exemptIps`, and access to the command lin curl -s https://example.com/captcha-protect/stats | jq -r '.rate | to_entries | sort_by(.value) | .[] | "\(.key): \(.value)"' | tail -25 ``` -This JSON state data is also found in the `state.json` file that you should have configured in your `docker-compose.yml` using the `persistentStateFile` setting and volume definition. NOTE: this file should only be changed by `captcha-protect` and not manually. +The rate limiter and verified challenge portions of this JSON state data are also found in the `state.json` file that you should have configured in your `docker-compose.yml` using the `persistentStateFile` setting and volume definition. When `enableStateReconciliation` is `"false"`, dirty state is saved roughly every 60 seconds plus 0-2 seconds of jitter. When `enableStateReconciliation` is `"true"` for multi-instance shared state, dirty state is saved roughly every 10 seconds plus 0-2 seconds of jitter. If `persistentStateFile` is unset, state persistence is disabled. NOTE: this file should only be changed by `captcha-protect` and not manually. ## Troubleshooting diff --git a/ci/docker-compose.yml b/ci/docker-compose.yml index 92a9623..3559bb4 100644 --- a/ci/docker-compose.yml +++ b/ci/docker-compose.yml @@ -20,7 +20,7 @@ services: traefik.http.middlewares.captcha-protect.plugin.captcha-protect.logLevel: "DEBUG" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectParameters: "${PROTECT_PARAMETERS:-false}" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.goodBots: "" - traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableGooglebotIPCheck: "true" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableGooglebotIPCheck: "false" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.mode: "regex" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectRoutes: "^/" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.excludeRoutes: "\\/oai\\/request,\\/node\\/\\d+\\/(book-)?manifest" @@ -54,7 +54,7 @@ services: traefik.http.middlewares.captcha-protect.plugin.captcha-protect.logLevel: "DEBUG" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectParameters: "${PROTECT_PARAMETERS:-false}" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.goodBots: "" - traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableGooglebotIPCheck: "true" + traefik.http.middlewares.captcha-protect.plugin.captcha-protect.enableGooglebotIPCheck: "false" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.mode: "regex" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.protectRoutes: "^/" traefik.http.middlewares.captcha-protect.plugin.captcha-protect.excludeRoutes: "\\/oai\\/request,\\/node\\/\\d+\\/(book-)?manifest" diff --git a/ci/parse-stress-results/main.go b/ci/parse-stress-results/main.go index 2d591f1..9e8e09c 100644 --- a/ci/parse-stress-results/main.go +++ b/ci/parse-stress-results/main.go @@ -32,17 +32,17 @@ func main() { scanner := bufio.NewScanner(os.Stdin) // Patterns to extract data - sizePattern := regexp.MustCompile(`size: ([\d.]+) MB`) + sizePattern := regexp.MustCompile(`size: ([\d.]+) (KB|MB)`) timePattern := regexp.MustCompile(`took (\d+)ms`) results := make(map[string]*TestResult) currentTest := "" // Initialize known tests - results["Small"] = &TestResult{Name: "Small", Entries: "16 rate / 65K bots / 256 verified", Threshold: 500} - results["Medium"] = &TestResult{Name: "Medium", Entries: "256 rate / 262K bots / 65K verified", Threshold: 1000} - results["Large"] = &TestResult{Name: "Large", Entries: "1K rate / 1M bots / 262K verified", Threshold: 3000} - results["XLarge"] = &TestResult{Name: "XLarge", Entries: "4K rate / 4.2M bots / 1M verified", Threshold: 10000} + results["Small"] = &TestResult{Name: "Small", Entries: "16 rate / 256 verified", Threshold: 500} + results["Medium"] = &TestResult{Name: "Medium", Entries: "256 rate / 65K verified", Threshold: 1000} + results["Large"] = &TestResult{Name: "Large", Entries: "1K rate / 262K verified", Threshold: 3000} + results["XLarge"] = &TestResult{Name: "XLarge", Entries: "4K rate / 1M verified", Threshold: 10000} for scanner.Scan() { line := scanner.Text() @@ -65,9 +65,9 @@ func main() { // Extract size from Marshal test if event.Output != "" && strings.Contains(event.Output, "Marshal took") && strings.Contains(event.Output, "size:") { - if matches := sizePattern.FindStringSubmatch(event.Output); len(matches) > 1 { + if matches := sizePattern.FindStringSubmatch(event.Output); len(matches) > 2 { if currentTest != "" && results[currentTest] != nil { - results[currentTest].Size = matches[1] + " MB" + results[currentTest].Size = matches[1] + " " + matches[2] } } } diff --git a/ci/test.go b/ci/test.go index f1aed01..f30d78b 100755 --- a/ci/test.go +++ b/ci/test.go @@ -1,496 +1,130 @@ package main import ( - "encoding/json" "fmt" - "io" "log/slog" - "math/rand" - "net" "net/http" "os" "os/exec" "strings" - "sync" "time" - - cp "github.com/libops/captcha-protect" - "github.com/libops/captcha-protect/internal/helper" -) - -var ( - rateLimit = 5 - exemptIps []*net.IPNet ) -const numIPs = 100 -const parallelism = 10 +const rateLimit = 5 func main() { - log := slog.New(slog.NewTextHandler(os.Stdout, nil)) - googleCIDRs, err := helper.FetchGoogleCrawlerIPs(log, http.DefaultClient, helper.GoogleCrawlerIPRangeURLs) - if err != nil { - slog.Error("unable to fetch google crawler ips", "err", err) - os.Exit(1) - } - - _ips := []string{ - "127.0.0.0/8", - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - "fc00::/8", - } - _ips = append(_ips, googleCIDRs...) - for _, ip := range _ips { - parsedIp := parseCIDR(ip) - exemptIps = append(exemptIps, parsedIp) - } - - fmt.Printf("Checking rate limit %d\n", rateLimit) - - fmt.Printf("Generating %d IPs\n", numIPs) - ips := generateUniquePublicIPs(numIPs) + _ = os.Remove("./tmp/state.json") fmt.Println("Bringing traefik/nginx online") runCommand("docker", "compose", "up", "-d") waitForService("http://localhost") waitForService("http://localhost/app2") - waitForGoogleExemptionReady(googleCIDRs) - fmt.Printf("Making sure %d attempt(s) pass\n", rateLimit) - runParallelChecks(ips, rateLimit, "http://localhost") - statePath := "./tmp/state.json" - runCommand("jq", ".", statePath) - - fmt.Printf("Making sure attempt #%d causes a redirect to the challenge page\n", rateLimit+1) - ensureRedirect(ips, "http://localhost") - testExcludeRouteRegexBypass(ips) - - fmt.Println("\nTesting state sharing between nginx instances...") - time.Sleep(cp.StateSaveInterval + cp.StateSaveJitter + (5 * time.Second)) - - testStateSharing(ips) - testGoogleBotGetsThrough(googleCIDRs) - - runCommand("docker", "container", "stats", "--no-stream") - - // now restart the containers and make sure the previous state reloaded - runCommand("docker", "compose", "down") - runCommand("docker", "compose", "up", "-d") - waitForService("http://localhost") - time.Sleep(10 * time.Second) - checkStateReload() - - runCommand("rm", "-f", statePath) + fmt.Println("Testing Traefik plugin smoke path...") + assertProtectedRoute("107.198.130.166", "http://localhost", "http://localhost/challenge?destination=%2F") + assertNoRedirect("107.198.130.166", "http://localhost/node/123/manifest") + assertNoRedirect("107.198.130.166", "http://localhost/oai/request?foo=bar") + assertProtectedRoute("108.198.130.167", "http://localhost/app2", "http://localhost/challenge?destination=%2Fapp2") -} - -func generateUniquePublicIPs(n int) []string { - ipSet := make(map[string]struct{}) - var ips []string - config := cp.CreateConfig() - bc := &cp.CaptchaProtect{} - bc.SetExemptIps(exemptIps) - err := bc.SetIpv4Mask(16) - if err != nil { - slog.Error("unable to set ipv4 mask") - os.Exit(1) - } - - err = bc.SetIpv6Mask(64) - if err != nil { - slog.Error("unable to set ipv6 mask") - os.Exit(1) - } - - for len(ips) < n { - ip := randomPublicIP(config) - ip, ipRange := bc.ParseIp(ip) - if _, exists := ipSet[ipRange]; !exists { - ipSet[ipRange] = struct{}{} - ips = append(ips, ip) - } - } - - return ips -} - -func randomPublicIP(config *cp.Config) string { - for { - ip := fmt.Sprintf("%d.%d.%d.%d", - rand.Intn(255)+1, - rand.Intn(255), - rand.Intn(255), - rand.Intn(254)+1, - ) - - if !helper.IsIpExcluded(ip, exemptIps) && !helper.IsIpGoodBot(ip, config.GoodBots) { - return ip - } - } + _ = os.Remove("./tmp/state.json") + fmt.Println("✓ Traefik plugin smoke test passed") } func waitForService(url string) { - for { - resp, err := http.Get(url) + deadline := time.Now().Add(90 * time.Second) + for time.Now().Before(deadline) { + resp, err := http.Get(url) // #nosec G107 -- CI smoke test only calls fixed localhost URLs. if err == nil && resp.StatusCode < 500 { - resp.Body.Close() - time.Sleep(5 * time.Second) // Give it time to stabilize + _ = resp.Body.Close() return } + if resp != nil { + _ = resp.Body.Close() + } fmt.Println("waiting for traefik/nginx to come online...") time.Sleep(1 * time.Second) } -} - -func runParallelChecks(ips []string, rateLimit int, url string) { - var wg sync.WaitGroup - sem := make(chan struct{}, parallelism) - - for range rateLimit { - for _, ip := range ips { - wg.Add(1) - sem <- struct{}{} - go func(ip string) { - defer wg.Done() - defer func() { <-sem }() - fmt.Printf("Checking %s\n", ip) - output := httpRequest(ip, url) - if output != "" { - slog.Error("Unexpected output", "ip", ip, "output", output) - os.Exit(1) - - } - }(ip) - } - } - - wg.Wait() + slog.Error("Timed out waiting for service", "url", url) + os.Exit(1) } -func ensureRedirect(ips []string, url string) { - expectedURL := url + "/challenge?destination=%2F" - if url != "http://localhost" { - // For /app2, the destination should be the app2 path - expectedURL = "http://localhost/challenge?destination=%2Fapp2" +func assertProtectedRoute(ip, url, expectedURL string) { + for i := 0; i < rateLimit; i++ { + assertNoRedirect(ip, url) } - for _, ip := range ips { - fmt.Printf("Checking %s\n", ip) - output := httpRequest(ip, url) - - if output != expectedURL { - slog.Error("Unexpected output", "ip", ip, "output", output, "expected", expectedURL) - os.Exit(1) - } - - fmt.Printf("Got a redirect! %s\n", output) - } -} - -func testExcludeRouteRegexBypass(ips []string) { - fmt.Println("\nTesting regex excludeRoutes bypass...") - - testIP := ips[0] - tests := []struct { - url string - name string - }{ - { - url: "http://localhost/node/123/manifest", - name: "/node/123/manifest", - }, - { - url: "http://localhost/oai/request?foo=bar", - name: "/oai/request?foo=bar", - }, + output, err := httpRequest(ip, url) + if err != nil { + slog.Error("Request failed", "ip", ip, "url", url, "err", err) + os.Exit(1) } - - for _, tt := range tests { - fmt.Printf("Checking excluded route %s with IP %s\n", tt.name, testIP) - output := httpRequest(testIP, tt.url) - if output != "" { - slog.Error("Excluded route was unexpectedly challenged", "ip", testIP, "route", tt.name, "output", output) - os.Exit(1) - } + if output != expectedURL { + slog.Error("Expected protected route to redirect", "ip", ip, "url", url, "output", output, "expected", expectedURL) + os.Exit(1) } - - fmt.Println("✓ regex excludeRoutes bypass works for excluded paths") } -func testStateSharing(ips []string) { - // Use first IP to test state sharing - testIP := ips[0] - - fmt.Printf("Testing with IP: %s\n", testIP) - - // The IP should already be at rate limit from previous tests on localhost/ - // Now verify it's also rate limited on localhost/app2 (shared state) - fmt.Println("Verifying IP is rate limited on /app2 (state should be shared)...") - output := httpRequest(testIP, "http://localhost/app2") - expectedURL := "http://localhost/challenge?destination=%2Fapp2" - - if output != expectedURL { - slog.Error("State NOT shared between instances!", "ip", testIP, "output", output, "expected", expectedURL) +func assertNoRedirect(ip, url string) { + output, err := httpRequest(ip, url) + if err != nil { + slog.Error("Request failed", "ip", ip, "url", url, "err", err) + os.Exit(1) + } + if output != "" { + slog.Error("Unexpected redirect", "ip", ip, "url", url, "output", output) os.Exit(1) } - - fmt.Println("✓ State is correctly shared between nginx instances!") } -func httpRequest(ip, url string) string { +func httpRequest(ip, url string) (string, error) { client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { - // Capture the redirect URL and stop following it if len(via) > 0 { return http.ErrUseLastResponse } return nil }, + Timeout: 10 * time.Second, } - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - slog.Error("Failed to create request", "err", err) - os.Exit(1) + return "", err } req.Header.Set("X-Forwarded-For", ip) + resp, err := client.Do(req) if err != nil { - slog.Error("Request failed", "err", err) - os.Exit(1) - + return "", err } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - // Get redirect URL from response location, err := resp.Location() if err != nil { if err == http.ErrNoLocation { - return "" + return "", nil } - slog.Error("Failed to get redirect URL", "err", err) - os.Exit(1) - + return "", err } - return strings.TrimSpace(location.String()) + return strings.TrimSpace(location.String()), nil } -// runCommand runs a shell command. func runCommand(name string, args ...string) { - runCommandWithEnv(nil, name, args...) -} - -func runCommandWithEnv(env map[string]string, name string, args ...string) { - cmd := exec.Command(name, args...) + cmd := exec.Command(name, args...) // #nosec G204 -- CI smoke test invokes fixed docker compose commands. cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - cmd.Env = append(cmd.Env, fmt.Sprintf("RATE_LIMIT=%d", rateLimit)) - cmd.Env = append(cmd.Env, fmt.Sprintf("PATH=%s", os.Getenv("PATH"))) + cmd.Env = append(os.Environ(), fmt.Sprintf("RATE_LIMIT=%d", rateLimit)) - tt := os.Getenv("TRAEFIK_TAG") - if tt != "" { - cmd.Env = append(cmd.Env, fmt.Sprintf("TRAEFIK_TAG=%s", tt)) - } - for k, v := range env { - cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) + if traefikTag := os.Getenv("TRAEFIK_TAG"); traefikTag != "" { + cmd.Env = append(cmd.Env, fmt.Sprintf("TRAEFIK_TAG=%s", traefikTag)) } + if err := cmd.Run(); err != nil { slog.Error("Command failed", "err", err) os.Exit(1) } } - -func checkStateReload() { - resp, err := http.Get("http://localhost/captcha-protect/stats") - if err != nil { - slog.Error("Failed to make GET request", "err", err) - os.Exit(1) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - slog.Error("Failed to read response body", "err", err) - os.Exit(1) - - } - var jsonResponse map[string]interface{} - err = json.Unmarshal(body, &jsonResponse) - if err != nil { - slog.Error("Failed to unmarshal JSON", "err", err) - os.Exit(1) - - } - bots, exists := jsonResponse["bots"] - if !exists { - slog.Error("Key 'bots' not found in JSON response") - os.Exit(1) - } - botsMap, ok := bots.(map[string]interface{}) - if !ok { - slog.Error("'bots' is not an array") - os.Exit(1) - } - - if len(botsMap) < numIPs { - slog.Error("Unexpected number of bots", "expected", numIPs, "received", len(botsMap)) - os.Exit(1) - } - - slog.Info("State reloaded successfully!") -} - -func parseCIDR(cidr string) *net.IPNet { - _, block, err := net.ParseCIDR(cidr) - if err != nil { - slog.Error("Failed to parse CIDR", "cidr", cidr, "err", err) - } - return block -} - -func getIPFromCIDR(cidr string) (string, error) { - ip, ipnet, err := net.ParseCIDR(cidr) - if err != nil { - return "", err - } - - // For IPv4, increment the IP to get a usable host address - if ip.To4() != nil { - // Clone the IP to avoid modifying the original - newIP := make(net.IP, len(ip)) - copy(newIP, ip) - - for i := len(newIP) - 1; i >= 0; i-- { - newIP[i]++ - if newIP[i] > 0 { - break - } - } - - // If the new IP is the broadcast address, we can't use it. - // This is a simplistic check, and might not cover all cases for small subnets. - // A more robust solution might be needed for very small CIDR ranges. - if !ipnet.Contains(newIP) { - // This can happen for /31 or /32. For now, we just return the network address. - return ip.String(), nil - } - // make sure we don't have a broadcast address - last_ip := make(net.IP, len(ipnet.IP)) - copy(last_ip, ipnet.IP) - for i := 0; i < len(ipnet.Mask); i++ { - last_ip[i] |= ^ipnet.Mask[i] - } - if newIP.Equal(last_ip) { - return ip.String(), nil - } - - return newIP.String(), nil - } - - // For IPv6, we can usually just use the network address. - return ip.String(), nil -} - -func testGoogleBotGetsThrough(googleCIDRs []string) { - fmt.Println("\nTesting GoogleBot exemption...") - - if len(googleCIDRs) == 0 { - slog.Warn("No Google CIDRs found, skipping test") - return - } - - // Pick a Google IP - googleIP, err := getIPFromCIDR(googleCIDRs[len(googleCIDRs)-1]) - if err != nil { - slog.Error("Failed to get an IP from google CIDR", "err", err) - os.Exit(1) - } - - var output string // Declare output once here - - fmt.Printf("Checking GoogleBot IP %s without params - should always pass (making %d requests)\n", googleIP, rateLimit+1) - for i := 0; i < rateLimit+1; i++ { - output = httpRequest(googleIP, "http://localhost") // Assign value to the already declared 'output' - if output != "" { - slog.Error(fmt.Sprintf("GoogleBot with no params was challenged on request #%d", i+1), "ip", googleIP, "output", output) - os.Exit(1) - } - } - fmt.Printf("✓ GoogleBot with no params passed %d requests successfully\n", rateLimit+1) - - // now restart with PROTECT_PARAMETERS=true and test again with params - fmt.Println("\nRestarting traefik with PROTECT_PARAMETERS=true") - runCommand("docker", "compose", "down") - runCommandWithEnv(map[string]string{"PROTECT_PARAMETERS": "true"}, "docker", "compose", "up", "-d") - waitForService("http://localhost") - waitForService("http://localhost/app2") - - // Prime the rate limiter for the GoogleBot IP with parameters - fmt.Printf("Priming rate limiter for GoogleBot IP %s with params (%d requests)\n", googleIP, rateLimit) - for i := range rateLimit { - output = httpRequest(googleIP, "http://localhost/?foo=bar") // Assign value - if output != "" { - slog.Error(fmt.Sprintf("GoogleBot with params was challenged prematurely on request #%d", i+1), "ip", googleIP, "output", output) - os.Exit(1) - } - } - fmt.Printf("✓ Rate limiter primed for GoogleBot IP %s\n", googleIP) - - fmt.Printf("Checking GoogleBot IP %s with params (request #%d) - should be challenged\n", googleIP, rateLimit+1) - output = httpRequest(googleIP, "http://localhost/?foo=bar") // Assign value - expectedURL := "http://localhost/challenge?destination=%2F%3Ffoo%3Dbar" - if output != expectedURL { - slog.Error("GoogleBot with params was not challenged", "ip", googleIP, "output", output, "expected", expectedURL) - os.Exit(1) - } - fmt.Println("✓ GoogleBot with params was challenged") - - // set things back to normal for other tests - runCommand("docker", "compose", "down") -} - -func waitForGoogleExemptionReady(googleCIDRs []string) { - googleIP, err := firstUsableIPv4FromCIDRs(googleCIDRs) - if err != nil { - slog.Warn("Unable to select Google IP for readiness check; skipping warmup", "err", err) - return - } - - deadline := time.Now().Add(90 * time.Second) - for time.Now().Before(deadline) { - ready := true - for i := 0; i < rateLimit+1; i++ { - if output := httpRequest(googleIP, "http://localhost"); output != "" { - ready = false - break - } - } - if ready { - fmt.Printf("Google exemption is active for %s\n", googleIP) - return - } - time.Sleep(500 * time.Millisecond) - } - - slog.Error("Timed out waiting for Google crawler IP exemption to become active", "googleIP", googleIP) - os.Exit(1) -} - -func firstUsableIPv4FromCIDRs(cidrs []string) (string, error) { - for _, cidr := range cidrs { - ip, err := getIPFromCIDR(cidr) - if err != nil { - continue - } - parsed := net.ParseIP(ip) - if parsed != nil && parsed.To4() != nil { - return ip, nil - } - } - - return "", fmt.Errorf("no usable IPv4 found in CIDR list") -} diff --git a/ci_behavior_test.go b/ci_behavior_test.go new file mode 100644 index 0000000..17ac51e --- /dev/null +++ b/ci_behavior_test.go @@ -0,0 +1,197 @@ +package captcha_protect + +import ( + "context" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + "testing/synctest" + "time" + + "github.com/libops/captcha-protect/internal/helper" + lru "github.com/patrickmn/go-cache" +) + +const ciRateLimit = uint(5) + +func TestCILabelEquivalentMiddlewareBehavior(t *testing.T) { + bc := newCILabelEquivalentMiddleware(t, nil) + ip := "107.198.130.166" + + for i := uint(0); i < ciRateLimit; i++ { + assertNoRedirect(t, bc, ip, "/") + } + assertRedirect(t, bc, ip, "/", "/challenge?destination=%2F") + + for _, route := range []string{ + "/node/123/manifest", + "/node/123/book-manifest", + "/oai/request?foo=bar", + } { + assertNoRedirect(t, bc, ip, route) + } + + appIP := "108.198.130.167" + for i := uint(0); i < ciRateLimit; i++ { + assertNoRedirect(t, bc, appIP, "/app2") + } + assertRedirect(t, bc, appIP, "/app2", "/challenge?destination=%2Fapp2") +} + +func TestCILabelEquivalentGooglebotParameterBehavior(t *testing.T) { + googleIP := "203.0.113.10" + + bypass := newCILabelEquivalentMiddleware(t, nil) + bypass.googlebotIPs = helper.NewGooglebotIPs() + bypass.googlebotIPs.Update([]string{"203.0.113.0/24"}, discardLogger()) + bypass.config.EnableGooglebotIPCheck = "true" + + for i := uint(0); i < ciRateLimit+1; i++ { + assertNoRedirect(t, bypass, googleIP, "/") + } + + protectedParams := newCILabelEquivalentMiddleware(t, func(config *Config) { + config.ProtectParameters = "true" + }) + protectedParams.googlebotIPs = helper.NewGooglebotIPs() + protectedParams.googlebotIPs.Update([]string{"203.0.113.0/24"}, discardLogger()) + protectedParams.config.EnableGooglebotIPCheck = "true" + + for i := uint(0); i < ciRateLimit; i++ { + assertNoRedirect(t, protectedParams, googleIP, "/?foo=bar") + } + assertRedirect(t, protectedParams, googleIP, "/?foo=bar", "/challenge?destination=%2F%3Ffoo%3Dbar") +} + +func TestPersistentStateSharingWithSynctest(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + stateFile := filepath.Join(t.TempDir(), "state.json") + writer := newStateOnlyCaptchaProtect(stateFile, 2) + reader := newStateOnlyCaptchaProtect(stateFile, 2) + + ctx, cancel := context.WithCancel(t.Context()) + done := make(chan struct{}, 1) + go func() { + writer.saveState(ctx) + done <- struct{}{} + }() + + for i := uint(0); i < writer.config.RateLimit+1; i++ { + writer.registerRequest("107.198.0.0") + } + + time.Sleep(stateSaveInterval(writer.config) + StateSaveJitter + 3*time.Second) + synctest.Wait() + reader.reconcileStateFromFileIfChanged() + + v, ok := reader.rateCache.Get("107.198.0.0") + if !ok { + t.Fatal("expected reader instance to reconcile writer state") + } + if got, want := v.(uint), writer.config.RateLimit+1; got != want { + t.Fatalf("reconciled rate = %d, want %d", got, want) + } + + cancel() + synctest.Wait() + <-done + }) +} + +func newCILabelEquivalentMiddleware(t *testing.T, mutate func(*Config)) *CaptchaProtect { + t.Helper() + + config := ciLabelEquivalentConfig() + if mutate != nil { + mutate(config) + } + + next := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.WriteHeader(http.StatusNoContent) + }) + bc, err := NewCaptchaProtect(t.Context(), next, config, "captcha-protect") + if err != nil { + t.Fatalf("NewCaptchaProtect failed: %v", err) + } + + return bc +} + +func ciLabelEquivalentConfig() *Config { + config := CreateConfig() + config.RateLimit = ciRateLimit + config.Window = 120 + config.CaptchaProvider = "poj" + config.SiteKey = "test-site-key" + config.SecretKey = "test-secret-key" + config.EnableStatsPage = "true" + config.IPForwardedHeader = "X-Forwarded-For" + config.LogLevel = "ERROR" + config.ProtectParameters = "false" + config.GoodBots = []string{} + config.EnableGooglebotIPCheck = "false" + config.Mode = "regex" + config.ProtectRoutes = []string{"^/"} + config.ExcludeRoutes = []string{ + `\/oai\/request`, + `\/node\/\d+\/(book-)?manifest`, + } + return config +} + +func newStateOnlyCaptchaProtect(stateFile string, rateLimit uint) *CaptchaProtect { + config := ciLabelEquivalentConfig() + config.PersistentStateFile = stateFile + config.EnableStateReconciliation = "true" + config.RateLimit = rateLimit + + return &CaptchaProtect{ + config: config, + log: discardLogger(), + rateCache: lru.New(time.Hour, lru.NoExpiration), + botCache: lru.New(time.Hour, lru.NoExpiration), + verifiedCache: lru.New(time.Hour, lru.NoExpiration), + } +} + +func assertNoRedirect(t *testing.T, handler http.Handler, ip, target string) { + t.Helper() + + status, location := serveCITestRequest(handler, ip, target) + if location != "" { + t.Fatalf("%s returned redirect %q", target, location) + } + if status != http.StatusNoContent { + t.Fatalf("%s status = %d, want %d", target, status, http.StatusNoContent) + } +} + +func assertRedirect(t *testing.T, handler http.Handler, ip, target, expectedLocation string) { + t.Helper() + + status, location := serveCITestRequest(handler, ip, target) + if status != http.StatusFound { + t.Fatalf("%s status = %d, want %d", target, status, http.StatusFound) + } + if location != expectedLocation { + t.Fatalf("%s redirect = %q, want %q", target, location, expectedLocation) + } +} + +func serveCITestRequest(handler http.Handler, ip, target string) (int, string) { + req := httptest.NewRequest(http.MethodGet, target, nil) + req.Host = "localhost" + req.Header.Set("X-Forwarded-For", ip) + rw := httptest.NewRecorder() + + handler.ServeHTTP(rw, req) + + return rw.Code, rw.Header().Get("Location") +} + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} diff --git a/internal/state/lock.go b/internal/state/lock.go index 18137bf..b8cda88 100644 --- a/internal/state/lock.go +++ b/internal/state/lock.go @@ -2,6 +2,7 @@ package state import ( "fmt" + "io" "os" "strconv" "strings" @@ -15,6 +16,12 @@ type FileLock struct { pid int } +type lockPIDFile interface { + io.StringWriter + io.Closer + Sync() error +} + // NewFileLock creates a new file lock for the given path. // It uses a separate .lock file to coordinate access. func NewFileLock(path string) (*FileLock, error) { @@ -33,17 +40,11 @@ func (fl *FileLock) Lock() error { for { // Try to create lock file exclusively - f, err := os.OpenFile(fl.lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644) + f, err := os.OpenFile(fl.lockPath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) if err == nil { // Successfully created lock file - _, err = f.WriteString(strconv.Itoa(fl.pid)) - f.Close() - // Check for write error - if err != nil { - // We got the lock but failed to write. - // Best effort to clean up, then return the error. - _ = os.Remove(fl.lockPath) - return fmt.Errorf("failed to write pid to lock file: %v", err) + if err := writeLockPID(f, fl.lockPath, fl.pid); err != nil { + return err } // We hold the lock return nil @@ -79,6 +80,26 @@ func (fl *FileLock) Lock() error { } } +func writeLockPID(file lockPIDFile, lockPath string, pid int) (err error) { + defer func() { + closeErr := file.Close() + if err == nil && closeErr != nil { + err = fmt.Errorf("failed to close lock file: %w", closeErr) + } + if err != nil { + _ = os.Remove(lockPath) + } + }() + + if _, err := file.WriteString(strconv.Itoa(pid)); err != nil { + return fmt.Errorf("failed to write pid to lock file: %w", err) + } + if err := file.Sync(); err != nil { + return fmt.Errorf("failed to sync lock file: %w", err) + } + return nil +} + // Unlock releases the exclusive lock by removing the lock file // This is now safer and checks the PID. func (fl *FileLock) Unlock() error { diff --git a/internal/state/lock_test.go b/internal/state/lock_test.go index 61e6a5d..8cbb97b 100644 --- a/internal/state/lock_test.go +++ b/internal/state/lock_test.go @@ -2,6 +2,7 @@ package state import ( + "errors" "fmt" "os" "path/filepath" @@ -12,6 +13,24 @@ import ( "time" ) +type fakeLockPIDFile struct { + writeErr error + syncErr error + closeErr error +} + +func (f fakeLockPIDFile) WriteString(string) (int, error) { + return 0, f.writeErr +} + +func (f fakeLockPIDFile) Close() error { + return f.closeErr +} + +func (f fakeLockPIDFile) Sync() error { + return f.syncErr +} + // TestFileLock_LockUnlock tests the basic Lock and Unlock functionality. func TestFileLock_LockUnlock(t *testing.T) { t.Parallel() @@ -27,9 +46,13 @@ func TestFileLock_LockUnlock(t *testing.T) { t.Fatalf("Lock() error = %v", err) } - if _, err := os.Stat(lockPath); err != nil { + info, err := os.Stat(lockPath) + if err != nil { t.Fatalf("lock file was not created: %v", err) } + if mode := info.Mode().Perm(); mode != 0600 { + t.Fatalf("lock file mode = %v, want 0600", mode) + } content, err := os.ReadFile(lockPath) if err != nil { @@ -49,6 +72,51 @@ func TestFileLock_LockUnlock(t *testing.T) { } } +func TestWriteLockPIDErrorsCleanUpLockFile(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + file fakeLockPIDFile + wantErr string + }{ + { + name: "write error", + file: fakeLockPIDFile{writeErr: errors.New("write failed")}, + wantErr: "failed to write pid to lock file: write failed", + }, + { + name: "close error", + file: fakeLockPIDFile{closeErr: errors.New("close failed")}, + wantErr: "failed to close lock file: close failed", + }, + { + name: "sync error", + file: fakeLockPIDFile{syncErr: errors.New("sync failed")}, + wantErr: "failed to sync lock file: sync failed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + lockPath := filepath.Join(t.TempDir(), "test.lock") + if err := os.WriteFile(lockPath, []byte("partial"), 0600); err != nil { + t.Fatalf("failed to create lock file: %v", err) + } + + err := writeLockPID(tt.file, lockPath, os.Getpid()) + if err == nil || err.Error() != tt.wantErr { + t.Fatalf("writeLockPID error = %v, want %q", err, tt.wantErr) + } + if _, statErr := os.Stat(lockPath); !os.IsNotExist(statErr) { + t.Fatalf("expected failed writeLockPID to remove lock file, stat err = %v", statErr) + } + }) + } +} + // TestFileLock_Close tests the Close functionality, including idempotency. func TestFileLock_Close(t *testing.T) { t.Parallel() diff --git a/internal/state/state.go b/internal/state/state.go index 291e5fa..d395778 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -1,12 +1,16 @@ package state import ( + "bufio" "encoding/json" "fmt" "log/slog" "os" + "path/filepath" "reflect" + "strconv" "time" + "unicode/utf8" lru "github.com/patrickmn/go-cache" ) @@ -17,20 +21,55 @@ type CacheEntry struct { Expiration int64 `json:"expiration"` // Unix timestamp in nanoseconds, 0 means no expiration } +// State contains the persisted cache snapshot and approximate cache memory usage. type State struct { Rate map[string]CacheEntry `json:"rate"` - Bots map[string]CacheEntry `json:"bots"` Verified map[string]CacheEntry `json:"verified"` Memory map[string]uintptr `json:"memory"` } -func GetState(rateCache, botCache, verifiedCache map[string]lru.Item) State { +// Keep concrete entry types: Traefik's Yaegi interpreter cannot reliably handle +// generic persistentEntry[T] map fields. +type persistentRateEntry struct { + Value uint `json:"value"` + Expiration int64 `json:"expiration"` +} + +type persistentBoolEntry struct { + Value bool `json:"value"` + Expiration int64 `json:"expiration"` +} + +type persistentState struct { + Rate map[string]persistentRateEntry `json:"rate"` + Verified map[string]persistentBoolEntry `json:"verified"` + Memory map[string]uintptr `json:"memory"` +} + +type reconcileStateFile struct { + Rate map[string]persistentRateEntry `json:"rate"` + Verified map[string]persistentBoolEntry `json:"verified"` +} + +// SaveMetrics reports timing and entry counts for a state save. +type SaveMetrics struct { + LockMs int64 + ReadMs int64 + ReconcileMs int64 + MarshalMs int64 + WriteMs int64 + TotalMs int64 + RateEntries int + VerifiedEntries int +} + +// GetState converts cache items into a serializable state snapshot. +func GetState(rateCache, verifiedCache map[string]lru.Item) State { state := State{ - Memory: make(map[string]uintptr, 3), + Memory: make(map[string]uintptr, 2), } state.Rate, state.Memory["rate"] = getCacheEntries[uint](rateCache) - state.Bots, state.Memory["bot"] = getCacheEntries[bool](botCache) state.Verified, state.Memory["verified"] = getCacheEntries[bool](verifiedCache) return state @@ -38,100 +77,266 @@ func GetState(rateCache, botCache, verifiedCache map[string]lru.Item) State { // SetState loads state data into the provided caches, preserving expiration times. // If an entry has already expired (expiration < now), it will be skipped. -func SetState(state State, rateCache, botCache, verifiedCache *lru.Cache) { +func SetState(state State, rateCache, verifiedCache *lru.Cache) { loadCacheEntries(state.Rate, rateCache, convertRateValue) - loadCacheEntries(state.Bots, botCache, convertBoolValue) loadCacheEntries(state.Verified, verifiedCache, convertBoolValue) } // ReconcileState merges file-based state with in-memory state. -func ReconcileState(fileState State, rateCache, botCache, verifiedCache *lru.Cache) { +func ReconcileState(fileState State, rateCache, verifiedCache *lru.Cache) { rateItems := rateCache.Items() - botItems := botCache.Items() verifiedItems := verifiedCache.Items() // Use "max value wins" for rate cache reconcileRateCache(fileState.Rate, rateItems, rateCache, convertRateValue) - // Use "later expiration wins" for bot and verified caches - reconcileCacheEntries(fileState.Bots, botItems, botCache, convertBoolValue) + // Use "later expiration wins" for verified caches reconcileCacheEntries(fileState.Verified, verifiedItems, verifiedCache, convertBoolValue) } -// SaveStateToFile saves state to a file with locking and optional reconciliation. -// When reconcile is true, it reads and merges existing file state before saving. -// Returns timing metrics for debugging. -func SaveStateToFile( +// SaveStateToFileWithMetrics saves rate and verified state to a file with locking. +func SaveStateToFileWithMetrics( filePath string, reconcile bool, - rateCache, botCache, verifiedCache *lru.Cache, + rateCache, verifiedCache *lru.Cache, log *slog.Logger, -) (lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs int64, err error) { +) (SaveMetrics, error) { startTime := time.Now() + metrics := SaveMetrics{} lock, err := NewFileLock(filePath + ".lock") if err != nil { - return 0, 0, 0, 0, 0, 0, fmt.Errorf("failed to create lock: %w", err) + return metrics, fmt.Errorf("failed to create lock: %w", err) } defer lock.Close() if err := lock.Lock(); err != nil { - return 0, 0, 0, 0, 0, 0, fmt.Errorf("failed to acquire lock: %w", err) + return metrics, fmt.Errorf("failed to acquire lock: %w", err) } - lockDuration := time.Since(startTime) - - var readDuration, reconcileDuration, marshalDuration, writeDuration time.Duration + metrics.LockMs = time.Since(startTime).Milliseconds() // Reconcile with existing file state if enabled if reconcile { readStart := time.Now() - fileContent, readErr := os.ReadFile(filePath) - readDuration = time.Since(readStart) + fileContent, readErr := os.ReadFile(filePath) // #nosec G304 -- persistent state path is trusted middleware configuration. + metrics.ReadMs = time.Since(readStart).Milliseconds() if readErr == nil && len(fileContent) > 0 { reconcileStart := time.Now() - var fileState State + var fileState reconcileStateFile if unmarshalErr := json.Unmarshal(fileContent, &fileState); unmarshalErr == nil { log.Debug("Reconciling state before save", "fileBytes", len(fileContent)) - ReconcileState(fileState, rateCache, botCache, verifiedCache) + reconcilePersistentFileState(fileState, rateCache, verifiedCache) } - reconcileDuration = time.Since(reconcileStart) + metrics.ReconcileMs = time.Since(reconcileStart).Milliseconds() } } - // Marshal current state - marshalStart := time.Now() - currentState := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) - jsonData, err := json.Marshal(currentState) - marshalDuration = time.Since(marshalStart) + rateItems := rateCache.Items() + verifiedItems := verifiedCache.Items() + metrics.RateEntries = len(rateItems) + metrics.VerifiedEntries = len(verifiedItems) + metrics.MarshalMs, metrics.WriteMs, err = atomicWriteStateFile(filePath, rateItems, verifiedItems, 0600) if err != nil { - return lockDuration.Milliseconds(), readDuration.Milliseconds(), - reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), - 0, 0, err + return metrics, err } - // Write to disk + metrics.TotalMs = time.Since(startTime).Milliseconds() + return metrics, nil +} + +func atomicWriteStateFile( + filePath string, + rateItems, verifiedItems map[string]lru.Item, + perm os.FileMode, +) (marshalMs, writeMs int64, err error) { + dir := filepath.Dir(filePath) + tmp, err := os.CreateTemp(dir, filepath.Base(filePath)+".tmp-*") + if err != nil { + return 0, 0, err + } + tmpName := tmp.Name() + defer os.Remove(tmpName) + + writer := bufio.NewWriterSize(tmp, 1024*1024) + marshalStart := time.Now() + if err := writeStateJSON(writer, rateItems, verifiedItems); err != nil { + _ = tmp.Close() + return 0, 0, err + } + if err := writer.Flush(); err != nil { + _ = tmp.Close() + return 0, 0, err + } + marshalMs = time.Since(marshalStart).Milliseconds() + writeStart := time.Now() - err = os.WriteFile(filePath, jsonData, 0644) - writeDuration = time.Since(writeStart) + if err := tmp.Chmod(perm); err != nil { + _ = tmp.Close() + return marshalMs, 0, err + } + if err := tmp.Close(); err != nil { + return marshalMs, 0, err + } + + if err := os.Rename(tmpName, filePath); err != nil { + return marshalMs, 0, err + } + return marshalMs, time.Since(writeStart).Milliseconds(), nil +} + +func writeStateJSON( + writer *bufio.Writer, + rateItems, verifiedItems map[string]lru.Item, +) error { + if err := writeString(writer, `{"rate":`); err != nil { + return err + } + rateMemory, err := writeCacheEntryMap[uint](writer, rateItems, writeUint) + if err != nil { + return err + } + if err := writeString(writer, `,"verified":`); err != nil { + return err + } + verifiedMemory, err := writeCacheEntryMap[bool](writer, verifiedItems, writeBool) + if err != nil { + return err + } + + if err := writeString(writer, `,"memory":{"rate":`); err != nil { + return err + } + if err := writeString(writer, strconv.FormatUint(uint64(rateMemory), 10)); err != nil { + return err + } + if err := writeString(writer, `,"verified":`); err != nil { + return err + } + if err := writeString(writer, strconv.FormatUint(uint64(verifiedMemory), 10)); err != nil { + return err + } + return writeString(writer, `}}`) +} + +func writeCacheEntryMap[T any]( + writer *bufio.Writer, + items map[string]lru.Item, + writeValue func(*bufio.Writer, T) error, +) (uintptr, error) { + if err := writer.WriteByte('{'); err != nil { + return 0, err + } + + memoryUsage := cacheEntryMapSize + first := true + for key, item := range items { + value, ok := item.Object.(T) + if !ok { + return memoryUsage, fmt.Errorf("unexpected cache value type for %q", key) + } + + if !first { + if err := writer.WriteByte(','); err != nil { + return memoryUsage, err + } + } + first = false + if err := writeJSONString(writer, key); err != nil { + return memoryUsage, err + } + if err := writeString(writer, `:{"value":`); err != nil { + return memoryUsage, err + } + if err := writeValue(writer, value); err != nil { + return memoryUsage, err + } + if err := writeString(writer, `,"expiration":`); err != nil { + return memoryUsage, err + } + if err := writeString(writer, strconv.FormatInt(item.Expiration, 10)); err != nil { + return memoryUsage, err + } + if err := writer.WriteByte('}'); err != nil { + return memoryUsage, err + } + + memoryUsage += stringSize + memoryUsage += lruItemSize + memoryUsage += uintptr(len(key)) + } + + if err := writer.WriteByte('}'); err != nil { + return memoryUsage, err + } + return memoryUsage, nil +} + +var ( + cacheEntryMapSize = reflect.TypeOf(map[string]CacheEntry{}).Size() + stringSize = reflect.TypeOf("").Size() + lruItemSize = reflect.TypeOf(lru.Item{}).Size() +) + +func writeUint(writer *bufio.Writer, value uint) error { + return writeString(writer, strconv.FormatUint(uint64(value), 10)) +} + +func writeBool(writer *bufio.Writer, value bool) error { + if value { + return writeString(writer, "true") + } + return writeString(writer, "false") +} + +func writeString(writer *bufio.Writer, value string) error { + _, err := writer.WriteString(value) + return err +} + +func writeJSONString(writer *bufio.Writer, value string) error { + if isPlainJSONString(value) { + if err := writer.WriteByte('"'); err != nil { + return err + } + if err := writeString(writer, value); err != nil { + return err + } + return writer.WriteByte('"') + } + + encoded, err := json.Marshal(value) if err != nil { - return lockDuration.Milliseconds(), readDuration.Milliseconds(), - reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), - writeDuration.Milliseconds(), 0, err + return err } + _, err = writer.Write(encoded) + return err +} - totalDuration := time.Since(startTime) - return lockDuration.Milliseconds(), readDuration.Milliseconds(), - reconcileDuration.Milliseconds(), marshalDuration.Milliseconds(), - writeDuration.Milliseconds(), totalDuration.Milliseconds(), nil +func isPlainJSONString(value string) bool { + asciiOnly := true + for i := 0; i < len(value); i++ { + switch value[i] { + case '\\', '"': + return false + default: + if value[i] < 0x20 { + return false + } + if value[i] >= utf8.RuneSelf { + asciiOnly = false + } + } + } + return asciiOnly || utf8.ValidString(value) } // LoadStateFromFile loads state from a file with locking. func LoadStateFromFile( filePath string, - rateCache, botCache, verifiedCache *lru.Cache, + rateCache, verifiedCache *lru.Cache, ) error { lock, err := NewFileLock(filePath + ".lock") if err != nil { @@ -143,23 +348,51 @@ func LoadStateFromFile( return err } - fileContent, err := os.ReadFile(filePath) + fileContent, err := os.ReadFile(filePath) // #nosec G304 -- persistent state path is trusted middleware configuration. if err != nil || len(fileContent) == 0 { return err } - var loadedState State + var loadedState persistentState err = json.Unmarshal(fileContent, &loadedState) if err != nil { return err } - // Use SetState which properly handles expiration times - SetState(loadedState, rateCache, botCache, verifiedCache) + setPersistentState(loadedState, rateCache, verifiedCache) return nil } +// ReconcileStateFromFile merges newer persisted state into the provided caches. +func ReconcileStateFromFile( + filePath string, + rateCache, verifiedCache *lru.Cache, +) error { + lock, err := NewFileLock(filePath + ".lock") + if err != nil { + return err + } + defer lock.Close() + + if err := lock.Lock(); err != nil { + return err + } + + fileContent, err := os.ReadFile(filePath) // #nosec G304 -- persistent state path is trusted middleware configuration. + if err != nil || len(fileContent) == 0 { + return err + } + + var fileState reconcileStateFile + if err := json.Unmarshal(fileContent, &fileState); err != nil { + return err + } + + reconcilePersistentFileState(fileState, rateCache, verifiedCache) + return nil +} + func calculateDuration(expiration int64, now int64) time.Duration { if expiration == 0 { return lru.NoExpiration @@ -225,8 +458,93 @@ func loadCacheEntries[T any]( } } -// reconcileCacheEntries implements "later expiration wins" -// This is correct for bool flags (Verified, Bots). +func setPersistentState(state persistentState, rateCache, verifiedCache *lru.Cache) { + loadPersistentRateEntries(state.Rate, rateCache) + loadPersistentBoolEntries(state.Verified, verifiedCache) +} + +func loadPersistentRateEntries(entries map[string]persistentRateEntry, cache *lru.Cache) { + now := time.Now().UnixNano() + for key, entry := range entries { + if entry.Expiration > 0 && entry.Expiration <= now { + continue + } + cache.Set(key, entry.Value, calculateDuration(entry.Expiration, now)) + } +} + +func loadPersistentBoolEntries(entries map[string]persistentBoolEntry, cache *lru.Cache) { + now := time.Now().UnixNano() + for key, entry := range entries { + if entry.Expiration > 0 && entry.Expiration <= now { + continue + } + cache.Set(key, entry.Value, calculateDuration(entry.Expiration, now)) + } +} + +func reconcilePersistentFileState(state reconcileStateFile, rateCache, verifiedCache *lru.Cache) { + rateItems := rateCache.Items() + verifiedItems := verifiedCache.Items() + + reconcilePersistentRateCache(state.Rate, rateItems, rateCache) + reconcilePersistentBoolCacheEntries(state.Verified, verifiedItems, verifiedCache) +} + +func reconcilePersistentBoolCacheEntries( + fileEntries map[string]persistentBoolEntry, + memItems map[string]lru.Item, + cache *lru.Cache, +) { + now := time.Now().UnixNano() + for key, fileEntry := range fileEntries { + if fileEntry.Expiration > 0 && fileEntry.Expiration <= now { + continue + } + + duration := calculateDuration(fileEntry.Expiration, now) + memItem, exists := memItems[key] + if !exists { + cache.Set(key, fileEntry.Value, duration) + continue + } + + if fileEntry.Expiration > memItem.Expiration { + cache.Set(key, fileEntry.Value, duration) + } + } +} + +func reconcilePersistentRateCache( + fileEntries map[string]persistentRateEntry, + memItems map[string]lru.Item, + cache *lru.Cache, +) { + now := time.Now().UnixNano() + for key, fileEntry := range fileEntries { + if fileEntry.Expiration > 0 && fileEntry.Expiration <= now { + continue + } + + memItem, exists := memItems[key] + if !exists { + cache.Set(key, fileEntry.Value, calculateDuration(fileEntry.Expiration, now)) + continue + } + + memValue, ok := memItem.Object.(uint) + if !ok { + cache.Set(key, fileEntry.Value, calculateDuration(fileEntry.Expiration, now)) + continue + } + + combinedValue := maxUint(fileEntry.Value, memValue) + laterExpiration := max(fileEntry.Expiration, memItem.Expiration) + cache.Set(key, combinedValue, calculateDuration(laterExpiration, now)) + } +} + +// reconcileCacheEntries implements "later expiration wins" for bool flags. func reconcileCacheEntries[T any]( fileEntries map[string]CacheEntry, memItems map[string]lru.Item, diff --git a/internal/state/state_stress_test.go b/internal/state/state_stress_test.go index 7c43a8e..732d2c5 100644 --- a/internal/state/state_stress_test.go +++ b/internal/state/state_stress_test.go @@ -1,11 +1,12 @@ package state import ( + "bufio" + "bytes" "encoding/json" "fmt" "log/slog" "math" - "os" "testing" "time" @@ -15,13 +16,13 @@ import ( // This file contains stress tests for state persistence operations at various scales. // // Performance Findings (Apple M2 Pro): -// Small (16 rate / 65K bots / 256 verified → 3.87 MB JSON): +// Small (16 rate / 256 verified): // - SaveStateToFile with reconciliation: ~84ms -// Medium (256 rate / 262K bots / 65K verified → 19.31 MB JSON): +// Medium (256 rate / 65K verified): // - SaveStateToFile with reconciliation: ~410ms -// Large (1,024 rate / 1M bots / 262K verified → 77.61 MB JSON): +// Large (1,024 rate / 262K verified): // - SaveStateToFile with reconciliation: ~1.8s -// XLarge (4,096 rate / 4.2M bots / 1M verified → 312.68 MB JSON): +// XLarge (4,096 rate / 1M verified): // - SaveStateToFile with reconciliation: ~8.7s (approaching 10s save window limit) // // Recommendation: Do not enable enableStateReconciliation for sites with >1M unique visitors. @@ -35,7 +36,6 @@ import ( type StressLevel struct { Name string RateEntries int - BotEntries int VerifiedEntries int } @@ -45,26 +45,22 @@ func getStressLevels() []StressLevel { return []StressLevel{ { Name: "Small", - RateEntries: 1 << 4, // 2^4 = 16 - BotEntries: 1 << 16, // 2^16 = 65,536 - VerifiedEntries: 1 << 8, // 2^8 = 256 + RateEntries: 1 << 4, // 2^4 = 16 + VerifiedEntries: 1 << 8, // 2^8 = 256 }, { Name: "Medium", RateEntries: 1 << 8, // 2^8 = 256 - BotEntries: 1 << 18, // 2^18 = 262,144 (capped from 2^32) VerifiedEntries: 1 << 16, // 2^16 = 65,536 }, { Name: "Large", RateEntries: 1 << 10, // 2^10 = 1,024 (capped from 2^16) - BotEntries: 1 << 20, // 2^20 = 1,048,576 (capped from 2^64) VerifiedEntries: 1 << 18, // 2^18 = 262,144 (capped from 2^32) }, { Name: "XLarge", RateEntries: 1 << 12, // 2^12 = 4,096 - BotEntries: 1 << 22, // 2^22 = 4,194,304 VerifiedEntries: 1 << 20, // 2^20 = 1,048,576 }, } @@ -79,7 +75,7 @@ func generateIPv4Subnet(index int) string { return fmt.Sprintf("%d.%d.0.0", a, b) } -// generateIPv4Address generates a unique IPv4 address for bot/verified caches +// generateIPv4Address generates a unique IPv4 address for verified caches // Uses the pattern: A.B.C.D where all octets are derived from the index func generateIPv4Address(index int) string { a := (index >> 24) & 0xFF @@ -90,7 +86,7 @@ func generateIPv4Address(index int) string { } // populateCaches fills caches with test data based on the stress level -func populateCaches(level StressLevel, rateCache, botCache, verifiedCache *lru.Cache) { +func populateCaches(level StressLevel, rateCache, verifiedCache *lru.Cache) { expiration := 24 * time.Hour // Populate rate cache with subnet entries @@ -101,16 +97,7 @@ func populateCaches(level StressLevel, rateCache, botCache, verifiedCache *lru.C rateCache.Set(subnet, rate, expiration) } - // Populate bot cache with IP addresses - for i := 0; i < level.BotEntries; i++ { - ip := generateIPv4Address(i) - // Alternate between verified and unverified bots - isBot := i%2 == 0 - botCache.Set(ip, isBot, expiration) - } - // Populate verified cache with IP addresses - // Use different starting index to avoid overlap with bot cache startOffset := 0x10000000 // Start from 16.0.0.0 for i := 0; i < level.VerifiedEntries; i++ { ip := generateIPv4Address(startOffset + i) @@ -125,24 +112,23 @@ func BenchmarkStateOperations(b *testing.B) { for _, level := range levels { // Create caches and populate with test data rateCache := lru.New(24*time.Hour, lru.NoExpiration) - botCache := lru.New(24*time.Hour, lru.NoExpiration) verifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - b.Logf("Populating caches for %s level (rate=%d, bots=%d, verified=%d)...", - level.Name, level.RateEntries, level.BotEntries, level.VerifiedEntries) - populateCaches(level, rateCache, botCache, verifiedCache) + b.Logf("Populating caches for %s level (rate=%d, verified=%d)...", + level.Name, level.RateEntries, level.VerifiedEntries) + populateCaches(level, rateCache, verifiedCache) // Benchmark GetState (extract to struct) b.Run(fmt.Sprintf("GetState/%s", level.Name), func(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - _ = GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + _ = GetState(rateCache.Items(), verifiedCache.Items()) } }) // Benchmark Marshal - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + state := GetState(rateCache.Items(), verifiedCache.Items()) b.Run(fmt.Sprintf("Marshal/%s", level.Name), func(b *testing.B) { b.ReportAllocs() var jsonData []byte @@ -171,11 +157,10 @@ func BenchmarkStateOperations(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() newRateCache := lru.New(24*time.Hour, lru.NoExpiration) - newBotCache := lru.New(24*time.Hour, lru.NoExpiration) newVerifiedCache := lru.New(24*time.Hour, lru.NoExpiration) b.StartTimer() - SetState(state, newRateCache, newBotCache, newVerifiedCache) + SetState(state, newRateCache, newVerifiedCache) } }) @@ -187,7 +172,6 @@ func BenchmarkStateOperations(b *testing.B) { b.StopTimer() // Create fresh caches with some overlapping data newRateCache := lru.New(24*time.Hour, lru.NoExpiration) - newBotCache := lru.New(24*time.Hour, lru.NoExpiration) newVerifiedCache := lru.New(24*time.Hour, lru.NoExpiration) // Pre-populate with 50% of entries for j := 0; j < level.RateEntries/2; j++ { @@ -196,12 +180,12 @@ func BenchmarkStateOperations(b *testing.B) { } b.StartTimer() - ReconcileState(state, newRateCache, newBotCache, newVerifiedCache) + ReconcileState(state, newRateCache, newVerifiedCache) } }) - // Benchmark full SaveStateToFile cycle (with reconciliation) - b.Run(fmt.Sprintf("SaveStateToFile/%s", level.Name), func(b *testing.B) { + // Benchmark full SaveStateToFileWithMetrics cycle (with reconciliation) + b.Run(fmt.Sprintf("SaveStateToFileWithMetrics/%s", level.Name), func(b *testing.B) { tmpDir := b.TempDir() tmpFile := tmpDir + "/state.json" logger := testLogger() @@ -209,11 +193,10 @@ func BenchmarkStateOperations(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + metrics, err := SaveStateToFileWithMetrics( tmpFile, true, // enable reconciliation rateCache, - botCache, verifiedCache, logger, ) @@ -224,7 +207,7 @@ func BenchmarkStateOperations(b *testing.B) { // Report timing breakdown (only once to avoid noise) if i == 0 { b.Logf("Timing breakdown: lock=%dms read=%dms reconcile=%dms marshal=%dms write=%dms total=%dms", - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs) + metrics.LockMs, metrics.ReadMs, metrics.ReconcileMs, metrics.MarshalMs, metrics.WriteMs, metrics.TotalMs) } } }) @@ -239,56 +222,38 @@ func TestStateOperationsWithinThreshold(t *testing.T) { levels := getStressLevels() - // Define thresholds for each operation (in milliseconds) + // Define thresholds for persistence operations (in milliseconds) // These are generous limits to avoid flaky tests on slower CI machines type thresholds struct { - GetStateMs int64 - MarshalMs int64 - UnmarshalMs int64 - SetStateMs int64 - ReconcileMs int64 + MarshalMs int64 + SaveWithReconcileMs int64 } levelThresholds := map[string]thresholds{ "Small": { - GetStateMs: 100, - MarshalMs: 200, - UnmarshalMs: 200, - SetStateMs: 200, - ReconcileMs: 200, + MarshalMs: 200, + SaveWithReconcileMs: 500, }, "Medium": { - GetStateMs: 200, - MarshalMs: 500, - UnmarshalMs: 500, - SetStateMs: 500, - ReconcileMs: 500, + MarshalMs: 500, + SaveWithReconcileMs: 1000, }, "Large": { - GetStateMs: 500, - MarshalMs: 2000, - UnmarshalMs: 2000, - SetStateMs: 2000, - ReconcileMs: 2000, + MarshalMs: 2000, + SaveWithReconcileMs: 3000, }, "XLarge": { - GetStateMs: 2000, - MarshalMs: 5000, - UnmarshalMs: 5000, - SetStateMs: 3000, - ReconcileMs: 3000, + MarshalMs: 5000, + SaveWithReconcileMs: 10000, }, "XXLarge": { - GetStateMs: 5000, - MarshalMs: 15000, - UnmarshalMs: 15000, - SetStateMs: 10000, - ReconcileMs: 10000, + MarshalMs: 15000, + SaveWithReconcileMs: 30000, }, } @@ -297,44 +262,18 @@ func TestStateOperationsWithinThreshold(t *testing.T) { t.Run(level.Name, func(t *testing.T) { // Create and populate caches rateCache := lru.New(24*time.Hour, lru.NoExpiration) - botCache := lru.New(24*time.Hour, lru.NoExpiration) verifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - t.Logf("Populating caches (rate=%d, bots=%d, verified=%d)...", - level.RateEntries, level.BotEntries, level.VerifiedEntries) - populateCaches(level, rateCache, botCache, verifiedCache) + t.Logf("Populating persistent caches (rate=%d, verified=%d)...", + level.RateEntries, level.VerifiedEntries) + populateCaches(level, rateCache, verifiedCache) thresh := levelThresholds[level.Name] - // Test GetState - t.Run("GetState", func(t *testing.T) { - start := time.Now() - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) - elapsed := time.Since(start).Milliseconds() - - t.Logf("GetState took %dms (threshold: %dms)", elapsed, thresh.GetStateMs) - if elapsed > thresh.GetStateMs { - slog.Error(fmt.Sprintf("GetState took %dms, exceeds threshold of %dms", elapsed, thresh.GetStateMs)) - } - - // Verify counts - if len(state.Rate) != level.RateEntries { - t.Errorf("Expected %d rate entries, got %d", level.RateEntries, len(state.Rate)) - } - if len(state.Bots) != level.BotEntries { - t.Errorf("Expected %d bot entries, got %d", level.BotEntries, len(state.Bots)) - } - if len(state.Verified) != level.VerifiedEntries { - t.Errorf("Expected %d verified entries, got %d", level.VerifiedEntries, len(state.Verified)) - } - }) - - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) - // Test Marshal t.Run("Marshal", func(t *testing.T) { start := time.Now() - jsonData, err := json.Marshal(state) + jsonData, err := marshalPersistentSnapshotForStress(rateCache, verifiedCache) elapsed := time.Since(start).Milliseconds() if err != nil { @@ -356,83 +295,27 @@ func TestStateOperationsWithinThreshold(t *testing.T) { } }) - // Test Unmarshal - jsonData, _ := json.Marshal(state) - t.Run("Unmarshal", func(t *testing.T) { - start := time.Now() - var loadedState State - err := json.Unmarshal(jsonData, &loadedState) - elapsed := time.Since(start).Milliseconds() - - if err != nil { - t.Fatalf("Unmarshal failed: %v", err) - } - - t.Logf("Unmarshal took %dms (threshold: %dms)", elapsed, thresh.UnmarshalMs) - if elapsed > thresh.UnmarshalMs { - slog.Error(fmt.Sprintf("Unmarshal took %dms, exceeds threshold of %dms", elapsed, thresh.UnmarshalMs)) - } - }) - - // Test SetState - t.Run("SetState", func(t *testing.T) { - newRateCache := lru.New(24*time.Hour, lru.NoExpiration) - newBotCache := lru.New(24*time.Hour, lru.NoExpiration) - newVerifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - - start := time.Now() - SetState(state, newRateCache, newBotCache, newVerifiedCache) - elapsed := time.Since(start).Milliseconds() - - t.Logf("SetState took %dms (threshold: %dms)", elapsed, thresh.SetStateMs) - if elapsed > thresh.SetStateMs { - slog.Error(fmt.Sprintf("SetState took %dms, exceeds threshold of %dms", elapsed, thresh.SetStateMs)) - } - - // Verify data was loaded - if newRateCache.ItemCount() != level.RateEntries { - t.Errorf("Expected %d rate entries after SetState, got %d", - level.RateEntries, newRateCache.ItemCount()) - } - }) - - // Test ReconcileState - t.Run("ReconcileState", func(t *testing.T) { - newRateCache := lru.New(24*time.Hour, lru.NoExpiration) - newBotCache := lru.New(24*time.Hour, lru.NoExpiration) - newVerifiedCache := lru.New(24*time.Hour, lru.NoExpiration) - - // Pre-populate with 50% overlapping data - for i := 0; i < level.RateEntries/2; i++ { - subnet := generateIPv4Subnet(i) - newRateCache.Set(subnet, uint(50), 24*time.Hour) - } - - start := time.Now() - ReconcileState(state, newRateCache, newBotCache, newVerifiedCache) - elapsed := time.Since(start).Milliseconds() - - t.Logf("ReconcileState took %dms (threshold: %dms)", elapsed, thresh.ReconcileMs) - if elapsed > thresh.ReconcileMs { - slog.Error(fmt.Sprintf("ReconcileState took %dms, exceeds threshold of %dms", elapsed, thresh.ReconcileMs)) - } - }) - - // Test full SaveStateToFile with reconciliation - t.Run("SaveStateToFile", func(t *testing.T) { + // Test full SaveStateToFileWithMetrics with reconciliation + t.Run("SaveStateToFileWithMetrics", func(t *testing.T) { tmpFile := t.TempDir() + "/state.json" logger := testLogger() // Pre-create a state file to enable reconciliation - initialData, _ := json.Marshal(state) - _ = os.WriteFile(tmpFile, initialData, 0644) + if _, err := SaveStateToFileWithMetrics( + tmpFile, + false, + rateCache, + verifiedCache, + logger, + ); err != nil { + t.Fatalf("Failed to write initial state: %v", err) + } start := time.Now() - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + metrics, err := SaveStateToFileWithMetrics( tmpFile, true, // enable reconciliation rateCache, - botCache, verifiedCache, logger, ) @@ -444,7 +327,7 @@ func TestStateOperationsWithinThreshold(t *testing.T) { t.Logf("SaveStateToFile took %dms (threshold: %dms)", elapsed, thresh.SaveWithReconcileMs) t.Logf(" Breakdown: lock=%dms read=%dms reconcile=%dms marshal=%dms write=%dms total=%dms", - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs) + metrics.LockMs, metrics.ReadMs, metrics.ReconcileMs, metrics.MarshalMs, metrics.WriteMs, metrics.TotalMs) if elapsed > thresh.SaveWithReconcileMs { slog.Error(fmt.Sprintf("SaveStateToFile took %dms, exceeds threshold of %dms", @@ -452,10 +335,10 @@ func TestStateOperationsWithinThreshold(t *testing.T) { } // Verify math adds up (approximately, allowing for measurement overhead) - measuredTotal := lockMs + readMs + reconcileMs + marshalMs + writeMs - if totalMs > 0 && math.Abs(float64(measuredTotal-totalMs)) > float64(totalMs)*0.2 { + measuredTotal := metrics.LockMs + metrics.ReadMs + metrics.ReconcileMs + metrics.MarshalMs + metrics.WriteMs + if metrics.TotalMs > 0 && math.Abs(float64(measuredTotal-metrics.TotalMs)) > float64(metrics.TotalMs)*0.2 { t.Logf("Warning: timing components (%dms) don't add up to total (%dms)", - measuredTotal, totalMs) + measuredTotal, metrics.TotalMs) } }) }) @@ -498,3 +381,15 @@ func TestGenerateUniqueIPs(t *testing.T) { } }) } + +func marshalPersistentSnapshotForStress(rateCache, verifiedCache *lru.Cache) ([]byte, error) { + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + if err := writeStateJSON(writer, rateCache.Items(), verifiedCache.Items()); err != nil { + return nil, err + } + if err := writer.Flush(); err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/internal/state/state_test.go b/internal/state/state_test.go index b1639af..e9e0a4b 100644 --- a/internal/state/state_test.go +++ b/internal/state/state_test.go @@ -1,9 +1,12 @@ package state import ( + "bufio" + "bytes" "encoding/json" "log/slog" "os" + "path/filepath" "testing" "testing/synctest" "time" @@ -14,20 +17,16 @@ import ( func TestGetState(t *testing.T) { // Create test caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) // Add test data rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) rateCache.Set("10.0.0.0", uint(5), lru.DefaultExpiration) - botCache.Set("1.2.3.4", true, lru.DefaultExpiration) - botCache.Set("5.6.7.8", false, lru.DefaultExpiration) - verifiedCache.Set("9.9.9.9", true, lru.DefaultExpiration) // Get state - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + state := GetState(rateCache.Items(), verifiedCache.Items()) // Verify rate cache data if len(state.Rate) != 2 { @@ -44,21 +43,6 @@ func TestGetState(t *testing.T) { t.Error("Expected non-zero expiration for 192.168.0.0") } - // Verify bot cache data - if len(state.Bots) != 2 { - t.Errorf("Expected 2 bot entries, got %d", len(state.Bots)) - } - if state.Bots["1.2.3.4"].Value != true { - t.Error("Expected bot 1.2.3.4 to be true") - } - if state.Bots["5.6.7.8"].Value != false { - t.Error("Expected bot 5.6.7.8 to be false") - } - // Verify expiration timestamps are set - if state.Bots["1.2.3.4"].Expiration == 0 { - t.Error("Expected non-zero expiration for bot 1.2.3.4") - } - // Verify verified cache data if len(state.Verified) != 1 { t.Errorf("Expected 1 verified entry, got %d", len(state.Verified)) @@ -72,15 +56,12 @@ func TestGetState(t *testing.T) { } // Verify memory tracking exists - if len(state.Memory) != 3 { - t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) + if len(state.Memory) != 2 { + t.Errorf("Expected 2 memory entries, got %d", len(state.Memory)) } if state.Memory["rate"] == 0 { t.Error("Expected non-zero memory for rate cache") } - if state.Memory["bot"] == 0 { - t.Error("Expected non-zero memory for bot cache") - } if state.Memory["verified"] == 0 { t.Error("Expected non-zero memory for verified cache") } @@ -89,22 +70,18 @@ func TestGetState(t *testing.T) { func TestGetStateEmpty(t *testing.T) { // Create empty caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - state := GetState(rateCache.Items(), botCache.Items(), verifiedCache.Items()) + state := GetState(rateCache.Items(), verifiedCache.Items()) if len(state.Rate) != 0 { t.Errorf("Expected 0 rate entries, got %d", len(state.Rate)) } - if len(state.Bots) != 0 { - t.Errorf("Expected 0 bot entries, got %d", len(state.Bots)) - } if len(state.Verified) != 0 { t.Errorf("Expected 0 verified entries, got %d", len(state.Verified)) } - if len(state.Memory) != 3 { - t.Errorf("Expected 3 memory entries, got %d", len(state.Memory)) + if len(state.Memory) != 2 { + t.Errorf("Expected 2 memory entries, got %d", len(state.Memory)) } } @@ -119,10 +96,6 @@ func TestSetState(t *testing.T) { "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, "10.0.0.0": {Value: uint(5), Expiration: pastExpiration}, // expired }, - Bots: map[string]CacheEntry{ - "1.2.3.4": {Value: true, Expiration: futureExpiration}, - "5.6.7.8": {Value: false, Expiration: pastExpiration}, // expired - }, Verified: map[string]CacheEntry{ "9.9.9.9": {Value: true, Expiration: futureExpiration}, "8.8.8.8": {Value: true, Expiration: pastExpiration}, // expired @@ -132,11 +105,10 @@ func TestSetState(t *testing.T) { // Create caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) // Set state - SetState(state, rateCache, botCache, verifiedCache) + SetState(state, rateCache, verifiedCache) // Verify only non-expired entries were loaded if rateCache.ItemCount() != 1 { @@ -149,13 +121,6 @@ func TestSetState(t *testing.T) { t.Error("Expected expired entry 10.0.0.0 to be filtered out") } - if botCache.ItemCount() != 1 { - t.Errorf("Expected 1 bot entry (expired filtered out), got %d", botCache.ItemCount()) - } - if v, ok := botCache.Get("1.2.3.4"); !ok || v.(bool) != true { - t.Error("Expected bot 1.2.3.4 to be true") - } - if verifiedCache.ItemCount() != 2 { t.Errorf("Expected 2 verified entries (1 expired filtered out), got %d", verifiedCache.ItemCount()) } @@ -190,7 +155,6 @@ func TestReconcileState(t *testing.T) { // Create memory caches with some overlapping data rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), time.Duration(oldExpiration-now)) // older, should be replaced @@ -200,7 +164,7 @@ func TestReconcileState(t *testing.T) { verifiedCache.Set("2.2.2.2", true, time.Duration(newExpiration-now)) // newer, should be kept // Reconcile - ReconcileState(fileState, rateCache, botCache, verifiedCache) + ReconcileState(fileState, rateCache, verifiedCache) // Verify reconciliation results // 192.168.0.0 should be updated to file's value (newer expiration) @@ -234,26 +198,23 @@ func TestReconcileState(t *testing.T) { } } -func TestSaveStateToFile(t *testing.T) { +func TestSaveStateToFileWithMetrics(t *testing.T) { t.Run("Basic save without reconciliation", func(t *testing.T) { // Create temp file tmpFile := t.TempDir() + "/state.json" // Create caches with test data rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) - botCache.Set("1.2.3.4", false, lru.DefaultExpiration) verifiedCache.Set("5.6.7.8", true, lru.DefaultExpiration) // Save without reconciliation - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + metrics, err := SaveStateToFileWithMetrics( tmpFile, false, // no reconciliation rateCache, - botCache, verifiedCache, testLogger(), ) @@ -263,18 +224,18 @@ func TestSaveStateToFile(t *testing.T) { } // Verify timing metrics - if lockMs < 0 || readMs < 0 || reconcileMs < 0 || marshalMs < 0 || writeMs < 0 || totalMs < 0 { + if metrics.LockMs < 0 || metrics.ReadMs < 0 || metrics.ReconcileMs < 0 || metrics.MarshalMs < 0 || metrics.WriteMs < 0 || metrics.TotalMs < 0 { t.Error("Expected all timing metrics to be non-negative") } // Verify reconcileMs is 0 when reconciliation is disabled - if reconcileMs != 0 { - t.Errorf("Expected reconcileMs to be 0 when reconciliation disabled, got %d", reconcileMs) + if metrics.ReconcileMs != 0 { + t.Errorf("Expected reconcileMs to be 0 when reconciliation disabled, got %d", metrics.ReconcileMs) } // Verify readMs is 0 when reconciliation is disabled - if readMs != 0 { - t.Errorf("Expected readMs to be 0 when reconciliation disabled, got %d", readMs) + if metrics.ReadMs != 0 { + t.Errorf("Expected readMs to be 0 when reconciliation disabled, got %d", metrics.ReadMs) } // Verify file was created and contains data @@ -285,6 +246,9 @@ func TestSaveStateToFile(t *testing.T) { if fileInfo.Size() == 0 { t.Error("State file is empty") } + if mode := fileInfo.Mode().Perm(); mode != 0600 { + t.Fatalf("State file mode = %v, want 0600", mode) + } // Load and verify the saved data savedData, err := os.ReadFile(tmpFile) @@ -300,9 +264,6 @@ func TestSaveStateToFile(t *testing.T) { if len(savedState.Rate) != 1 { t.Errorf("Expected 1 rate entry, got %d", len(savedState.Rate)) } - if len(savedState.Bots) != 1 { - t.Errorf("Expected 1 bot entry, got %d", len(savedState.Bots)) - } if len(savedState.Verified) != 1 { t.Errorf("Expected 1 verified entry, got %d", len(savedState.Verified)) } @@ -318,9 +279,8 @@ func TestSaveStateToFile(t *testing.T) { Rate: map[string]CacheEntry{ "10.0.0.0": {Value: uint(5), Expiration: futureExpiration}, }, - Bots: map[string]CacheEntry{}, Verified: map[string]CacheEntry{}, - Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + Memory: map[string]uintptr{"rate": 8, "verified": 8}, } initialData, _ := json.Marshal(initialState) if err := os.WriteFile(tmpFile, initialData, 0644); err != nil { @@ -329,17 +289,15 @@ func TestSaveStateToFile(t *testing.T) { // Create caches with different data rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) // Save with reconciliation enabled - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := SaveStateToFile( + metrics, err := SaveStateToFileWithMetrics( tmpFile, true, // enable reconciliation rateCache, - botCache, verifiedCache, testLogger(), ) @@ -349,22 +307,22 @@ func TestSaveStateToFile(t *testing.T) { } // Verify timing metrics (all should be non-negative) - if lockMs < 0 { + if metrics.LockMs < 0 { t.Error("Expected non-negative lockMs") } - if readMs < 0 { + if metrics.ReadMs < 0 { t.Error("Expected non-negative readMs when reconciliation is enabled") } - if reconcileMs < 0 { + if metrics.ReconcileMs < 0 { t.Error("Expected non-negative reconcileMs when reconciliation is enabled") } - if marshalMs < 0 { + if metrics.MarshalMs < 0 { t.Error("Expected non-negative marshalMs") } - if writeMs < 0 { + if metrics.WriteMs < 0 { t.Error("Expected non-negative writeMs") } - if totalMs < 0 { + if metrics.TotalMs < 0 { t.Error("Expected non-negative totalMs") } @@ -381,19 +339,48 @@ func TestSaveStateToFile(t *testing.T) { } }) + t.Run("Save with metrics", func(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + verifiedCache.Set("5.6.7.8", true, lru.DefaultExpiration) + + metrics, err := SaveStateToFileWithMetrics( + tmpFile, + false, + rateCache, + verifiedCache, + testLogger(), + ) + if err != nil { + t.Fatalf("SaveStateToFileWithMetrics failed: %v", err) + } + + if metrics.RateEntries != 1 { + t.Errorf("Expected 1 rate entry, got %d", metrics.RateEntries) + } + if metrics.VerifiedEntries != 1 { + t.Errorf("Expected 1 verified entry, got %d", metrics.VerifiedEntries) + } + if matches, err := filepath.Glob(tmpFile + ".tmp-*"); err != nil || len(matches) != 0 { + t.Fatalf("Expected no leftover temp files, matches=%v err=%v", matches, err) + } + }) + t.Run("File write error", func(t *testing.T) { // Use invalid path to trigger error invalidPath := "/invalid/directory/that/does/not/exist/state.json" rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - _, _, _, _, _, _, err := SaveStateToFile( + _, err := SaveStateToFileWithMetrics( invalidPath, false, rateCache, - botCache, verifiedCache, testLogger(), ) @@ -416,13 +403,10 @@ func TestLoadStateFromFile(t *testing.T) { "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, "10.0.0.0": {Value: uint(5), Expiration: futureExpiration}, }, - Bots: map[string]CacheEntry{ - "1.2.3.4": {Value: true, Expiration: futureExpiration}, - }, Verified: map[string]CacheEntry{ "5.6.7.8": {Value: true, Expiration: futureExpiration}, }, - Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + Memory: map[string]uintptr{"rate": 8, "verified": 8}, } data, _ := json.Marshal(testState) @@ -432,10 +416,9 @@ func TestLoadStateFromFile(t *testing.T) { // Load into empty caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + err := LoadStateFromFile(tmpFile, rateCache, verifiedCache) if err != nil { t.Fatalf("LoadStateFromFile failed: %v", err) } @@ -444,9 +427,6 @@ func TestLoadStateFromFile(t *testing.T) { if rateCache.ItemCount() != 2 { t.Errorf("Expected 2 rate entries, got %d", rateCache.ItemCount()) } - if botCache.ItemCount() != 1 { - t.Errorf("Expected 1 bot entry, got %d", botCache.ItemCount()) - } if verifiedCache.ItemCount() != 1 { t.Errorf("Expected 1 verified entry, got %d", verifiedCache.ItemCount()) } @@ -455,9 +435,6 @@ func TestLoadStateFromFile(t *testing.T) { if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 10 { t.Error("Expected rate 10 for 192.168.0.0") } - if v, ok := botCache.Get("1.2.3.4"); !ok || v.(bool) != true { - t.Error("Expected bot 1.2.3.4 to be true") - } if v, ok := verifiedCache.Get("5.6.7.8"); !ok || v.(bool) != true { t.Error("Expected 5.6.7.8 to be verified") } @@ -473,9 +450,8 @@ func TestLoadStateFromFile(t *testing.T) { Rate: map[string]CacheEntry{ "192.168.0.0": {Value: uint(10), Expiration: pastExpiration}, // expired }, - Bots: map[string]CacheEntry{}, Verified: map[string]CacheEntry{}, - Memory: map[string]uintptr{"rate": 8, "bot": 8, "verified": 8}, + Memory: map[string]uintptr{"rate": 8, "verified": 8}, } data, _ := json.Marshal(testState) @@ -485,10 +461,9 @@ func TestLoadStateFromFile(t *testing.T) { } // Load into empty caches rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - err = LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + err = LoadStateFromFile(tmpFile, rateCache, verifiedCache) if err != nil { t.Fatalf("LoadStateFromFile failed: %v", err) } @@ -503,10 +478,9 @@ func TestLoadStateFromFile(t *testing.T) { nonExistentFile := t.TempDir() + "/does-not-exist.json" rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - err := LoadStateFromFile(nonExistentFile, rateCache, botCache, verifiedCache) + err := LoadStateFromFile(nonExistentFile, rateCache, verifiedCache) if err == nil { t.Error("Expected error for non-existent file, got nil") } @@ -521,10 +495,9 @@ func TestLoadStateFromFile(t *testing.T) { } rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) - err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + err := LoadStateFromFile(tmpFile, rateCache, verifiedCache) if err == nil { t.Error("Expected error for invalid JSON, got nil") } @@ -544,11 +517,10 @@ func TestLoadStateFromFile(t *testing.T) { } rateCache := lru.New(1*time.Hour, 1*time.Minute) - botCache := lru.New(1*time.Hour, 1*time.Minute) verifiedCache := lru.New(1*time.Hour, 1*time.Minute) // Empty file returns nil (no state to load, which is fine) - err := LoadStateFromFile(tmpFile, rateCache, botCache, verifiedCache) + err := LoadStateFromFile(tmpFile, rateCache, verifiedCache) if err != nil { t.Errorf("Unexpected error for empty file: %v", err) } @@ -560,6 +532,163 @@ func TestLoadStateFromFile(t *testing.T) { }) } +func TestReconcileStateFromFile(t *testing.T) { + tmpFile := t.TempDir() + "/state.json" + now := time.Now().UnixNano() + futureExpiration := now + int64(1*time.Hour) + laterExpiration := now + int64(2*time.Hour) + + fileState := State{ + Rate: map[string]CacheEntry{ + "192.168.0.0": {Value: uint(10), Expiration: futureExpiration}, + }, + Verified: map[string]CacheEntry{}, + Memory: map[string]uintptr{"rate": 8, "verified": 8}, + } + fileState.Verified = map[string]CacheEntry{ + "5.6.7.8": {Value: true, Expiration: futureExpiration}, + "9.9.9.9": {Value: true, Expiration: futureExpiration}, + } + data, _ := json.Marshal(fileState) + if err := os.WriteFile(tmpFile, data, 0644); err != nil { + t.Fatalf("Failed to write test state: %v", err) + } + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("10.0.0.0", uint(5), lru.DefaultExpiration) + verifiedCache.Set("9.9.9.9", false, time.Duration(laterExpiration-now)) + + if err := ReconcileStateFromFile(tmpFile, rateCache, verifiedCache); err != nil { + t.Fatalf("ReconcileStateFromFile failed: %v", err) + } + + if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 10 { + t.Error("Expected file state to be reconciled into memory") + } + if v, ok := rateCache.Get("10.0.0.0"); !ok || v.(uint) != 5 { + t.Error("Expected existing memory state to be retained") + } + if v, ok := verifiedCache.Get("5.6.7.8"); !ok || v.(bool) != true { + t.Error("Expected file verified state to be reconciled into memory") + } + if v, ok := verifiedCache.Get("9.9.9.9"); !ok || v.(bool) != false { + t.Error("Expected newer memory verified state to be retained") + } +} + +func TestSaveStateToFileWithMetricsWriteError(t *testing.T) { + statePath := t.TempDir() + + rateCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + + metrics, err := SaveStateToFileWithMetrics( + statePath, + false, + rateCache, + verifiedCache, + testLogger(), + ) + if err == nil { + t.Fatal("expected write error when state path is a directory") + } + if metrics.RateEntries != 1 { + t.Fatalf("expected metrics to include marshaled rate entry, got %d", metrics.RateEntries) + } +} + +func TestAtomicWriteStateFileCreateTempError(t *testing.T) { + missingDir := filepath.Join(t.TempDir(), "missing") + _, _, err := atomicWriteStateFile(filepath.Join(missingDir, "state.json"), nil, nil, 0600) + if err == nil { + t.Fatal("expected atomicWriteStateFile to fail when temp directory is missing") + } +} + +func TestWriteStateJSON(t *testing.T) { + rateCache := lru.New(1*time.Hour, 1*time.Minute) + verifiedCache := lru.New(1*time.Hour, 1*time.Minute) + + rateCache.Set("192.168.0.0", uint(10), lru.DefaultExpiration) + rateCache.Set("bad\a-key", uint(2), lru.DefaultExpiration) + verifiedCache.Set("9.9.9.9", true, lru.DefaultExpiration) + + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + if err := writeStateJSON(writer, rateCache.Items(), verifiedCache.Items()); err != nil { + t.Fatalf("writeStateJSON failed: %v", err) + } + if err := writer.Flush(); err != nil { + t.Fatalf("Flush failed: %v", err) + } + + var saved State + if err := json.Unmarshal(buf.Bytes(), &saved); err != nil { + t.Fatalf("state JSON did not unmarshal: %v", err) + } + if len(saved.Rate) != 2 || len(saved.Verified) != 1 { + t.Fatalf("unexpected saved counts: rate=%d verified=%d", len(saved.Rate), len(saved.Verified)) + } + if saved.Rate["bad\a-key"].Value != float64(2) { + t.Fatal("expected JSON-escaped rate key to round-trip") + } +} + +func TestWriteStateJSONUnexpectedType(t *testing.T) { + var buf bytes.Buffer + writer := bufio.NewWriter(&buf) + err := writeStateJSON( + writer, + map[string]lru.Item{"bad": {Object: "not-a-uint", Expiration: time.Now().Add(time.Hour).UnixNano()}}, + nil, + ) + if err == nil { + t.Fatal("expected writeStateJSON to reject unexpected cache value type") + } +} + +func TestReconcileStateFromFileEmptyAndInvalidFiles(t *testing.T) { + newCaches := func() (*lru.Cache, *lru.Cache) { + return lru.New(1*time.Hour, 1*time.Minute), + lru.New(1*time.Hour, 1*time.Minute) + } + + t.Run("missing file", func(t *testing.T) { + rateCache, verifiedCache := newCaches() + err := ReconcileStateFromFile(filepath.Join(t.TempDir(), "missing.json"), rateCache, verifiedCache) + if err == nil { + t.Fatal("expected missing state file to return read error") + } + }) + + t.Run("empty file", func(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + if err := os.WriteFile(tmpFile, nil, 0600); err != nil { + t.Fatalf("failed to write empty state file: %v", err) + } + + rateCache, verifiedCache := newCaches() + if err := ReconcileStateFromFile(tmpFile, rateCache, verifiedCache); err != nil { + t.Fatalf("empty state file should be a no-op: %v", err) + } + }) + + t.Run("invalid JSON", func(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + if err := os.WriteFile(tmpFile, []byte("{invalid json"), 0600); err != nil { + t.Fatalf("failed to write invalid state file: %v", err) + } + + rateCache, verifiedCache := newCaches() + if err := ReconcileStateFromFile(tmpFile, rateCache, verifiedCache); err == nil { + t.Fatal("expected invalid state file to return unmarshal error") + } + }) +} + func testLogger() *slog.Logger { return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ Level: slog.LevelError, // Only show errors during tests @@ -584,12 +713,6 @@ func TestSetStateWithExpiration_Synctest(t *testing.T) { Expiration: start.Add(10 * time.Second).UnixNano(), // expires in 10s }, }, - Bots: map[string]CacheEntry{ - "1.2.3.4": { - Value: true, - Expiration: start.Add(3 * time.Second).UnixNano(), // expires in 3s - }, - }, Verified: map[string]CacheEntry{ "9.9.9.9": { Value: true, @@ -600,32 +723,23 @@ func TestSetStateWithExpiration_Synctest(t *testing.T) { // Create empty caches (no cleanup interval to avoid background goroutines) rateCache := lru.New(1*time.Hour, lru.NoExpiration) - botCache := lru.New(1*time.Hour, lru.NoExpiration) verifiedCache := lru.New(1*time.Hour, lru.NoExpiration) // Load state - SetState(state, rateCache, botCache, verifiedCache) + SetState(state, rateCache, verifiedCache) // Verify all entries are loaded if rateCache.ItemCount() != 2 { t.Errorf("Expected 2 rate entries, got %d", rateCache.ItemCount()) } - if botCache.ItemCount() != 1 { - t.Errorf("Expected 1 bot entry, got %d", botCache.ItemCount()) - } if verifiedCache.ItemCount() != 1 { t.Errorf("Expected 1 verified entry, got %d", verifiedCache.ItemCount()) } - // Advance time by 4 seconds (bot entry should expire, rate entries still valid) + // Advance time by 4 seconds (rate entries still valid) time.Sleep(4 * time.Second) synctest.Wait() - // Bot cache should be empty (expired at 3s) - if _, found := botCache.Get("1.2.3.4"); found { - t.Error("Bot entry should have expired after 3 seconds") - } - // Rate entries should still be present if _, found := rateCache.Get("192.168.0.0"); !found { t.Error("Rate entry 192.168.0.0 should not expire until 5 seconds") @@ -690,7 +804,6 @@ func TestReconcileStateWithExpiration_Synctest(t *testing.T) { // Create memory caches with overlapping data (no cleanup interval to avoid background goroutines) rateCache := lru.New(1*time.Hour, lru.NoExpiration) - botCache := lru.New(1*time.Hour, lru.NoExpiration) verifiedCache := lru.New(1*time.Hour, lru.NoExpiration) // Memory entry with older expiration (should be replaced) @@ -699,7 +812,7 @@ func TestReconcileStateWithExpiration_Synctest(t *testing.T) { rateCache.Set("10.0.0.0", uint(5), 10*time.Second) // Reconcile - ReconcileState(fileState, rateCache, botCache, verifiedCache) + ReconcileState(fileState, rateCache, verifiedCache) // 192.168.0.0 should have file's value (newer expiration) if v, ok := rateCache.Get("192.168.0.0"); !ok || v.(uint) != 15 { @@ -749,20 +862,17 @@ func TestSaveAndLoadStateWithExpiration_Synctest(t *testing.T) { // Create caches with entries expiring at different times (no cleanup interval to avoid background goroutines) rateCache1 := lru.New(1*time.Hour, lru.NoExpiration) - botCache1 := lru.New(1*time.Hour, lru.NoExpiration) verifiedCache1 := lru.New(1*time.Hour, lru.NoExpiration) rateCache1.Set("192.168.0.0", uint(10), 5*time.Second) rateCache1.Set("10.0.0.0", uint(5), 10*time.Second) - botCache1.Set("1.2.3.4", true, 3*time.Second) verifiedCache1.Set("9.9.9.9", true, lru.NoExpiration) // Save state - _, _, _, _, _, _, err := SaveStateToFile( + _, err := SaveStateToFileWithMetrics( tmpFile, false, rateCache1, - botCache1, verifiedCache1, testLogger(), ) @@ -770,25 +880,19 @@ func TestSaveAndLoadStateWithExpiration_Synctest(t *testing.T) { t.Fatalf("SaveStateToFile failed: %v", err) } - // Advance time by 4 seconds (bot expires, rates still valid) + // Advance time by 4 seconds (rates still valid) time.Sleep(4 * time.Second) synctest.Wait() // Load into new caches (no cleanup interval to avoid background goroutines) rateCache2 := lru.New(1*time.Hour, lru.NoExpiration) - botCache2 := lru.New(1*time.Hour, lru.NoExpiration) verifiedCache2 := lru.New(1*time.Hour, lru.NoExpiration) - err = LoadStateFromFile(tmpFile, rateCache2, botCache2, verifiedCache2) + err = LoadStateFromFile(tmpFile, rateCache2, verifiedCache2) if err != nil { t.Fatalf("LoadStateFromFile failed: %v", err) } - // Bot entry should be filtered out (expired 1 second ago) - if botCache2.ItemCount() != 0 { - t.Errorf("Expected 0 bot entries (expired), got %d", botCache2.ItemCount()) - } - // First rate entry should be loaded (expires at 5s, we're at 4s) if _, found := rateCache2.Get("192.168.0.0"); !found { t.Error("Rate entry 192.168.0.0 should be loaded (not yet expired)") @@ -829,12 +933,11 @@ func TestReconcilePreservesNewerData_Synctest(t *testing.T) { initialCache := lru.New(1*time.Hour, lru.NoExpiration) initialCache.Set("192.168.0.0", uint(100), 5*time.Second) - _, _, _, _, _, _, err := SaveStateToFile( + _, err := SaveStateToFileWithMetrics( tmpFile, false, initialCache, lru.New(1*time.Hour, lru.NoExpiration), - lru.New(1*time.Hour, lru.NoExpiration), testLogger(), ) if err != nil { @@ -851,12 +954,11 @@ func TestReconcilePreservesNewerData_Synctest(t *testing.T) { newCache.Set("192.168.0.0", uint(200), 8*time.Second) // expires at start+10s // Save with reconciliation enabled - _, _, _, _, _, _, err = SaveStateToFile( + _, err = SaveStateToFileWithMetrics( tmpFile, true, // reconcile newCache, lru.New(1*time.Hour, lru.NoExpiration), - lru.New(1*time.Hour, lru.NoExpiration), testLogger(), ) if err != nil { @@ -869,7 +971,6 @@ func TestReconcilePreservesNewerData_Synctest(t *testing.T) { tmpFile, loadedCache, lru.New(1*time.Hour, lru.NoExpiration), - lru.New(1*time.Hour, lru.NoExpiration), ) if err != nil { t.Fatalf("Load failed: %v", err) diff --git a/main.go b/main.go index e3191dc..cbba0b1 100644 --- a/main.go +++ b/main.go @@ -2,11 +2,13 @@ package captcha_protect import ( "context" + crand "crypto/rand" "encoding/json" "fmt" + "hash/fnv" htemplate "html/template" "log/slog" - "math/rand" + "math/big" "net" "net/http" "net/url" @@ -26,8 +28,10 @@ import ( ) const ( - // StateSaveInterval is how often the persistent state file is written to disk - StateSaveInterval = 10 * time.Second + // StateSaveInterval is how often local persistent state is written to disk. + StateSaveInterval = 60 * time.Second + // StateReconciliationSaveInterval is the faster save cadence used when multiple instances share state. + StateReconciliationSaveInterval = 10 * time.Second // StateSaveJitter is the maximum random jitter added to save interval to prevent thundering herd StateSaveJitter = 2 * time.Second @@ -98,6 +102,10 @@ type CaptchaProtect struct { ipv6Mask net.IPMask protectRoutesRegex []*regexp.Regexp excludeRoutesRegex []*regexp.Regexp + stateMu sync.Mutex + stateDirty uint64 + stateSavedDirty uint64 + stateFileModTime time.Time // Circuit breaker fields mu sync.RWMutex @@ -158,6 +166,14 @@ func New(ctx context.Context, next http.Handler, config *Config, name string) (h return NewCaptchaProtect(ctx, next, config, name) } +func redactedConfig(config *Config) Config { + c := *config + if c.SecretKey != "" { + c.SecretKey = "[REDACTED]" + } + return c +} + func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, name string) (*CaptchaProtect, error) { log := plog.New(config.LogLevel) @@ -173,7 +189,7 @@ func NewCaptchaProtect(ctx context.Context, next http.Handler, config *Config, n } expiration := time.Duration(config.Window) * time.Second - log.Debug("Captcha config", "config", config) + log.Debug("Captcha config", "config", redactedConfig(config)) if len(config.ProtectRoutes) == 0 && config.Mode != "suffix" { return nil, fmt.Errorf("you must protect at least one route with the protectRoutes config value. / will cover your entire site") @@ -669,6 +685,7 @@ func (bc *CaptchaProtect) verifyChallengePage(rw http.ResponseWriter, req *http. if success { bc.verifiedCache.Set(ip, true, exp) + bc.markStateDirty() destination := normalizeDestination(req.FormValue("destination")) http.Redirect(rw, req, destination, http.StatusFound) @@ -723,7 +740,7 @@ func (bc *CaptchaProtect) serveStatsPage(rw http.ResponseWriter, ip string) { return } - state := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) + state := state.GetState(bc.rateCache.Items(), bc.verifiedCache.Items()) jsonData, err := json.Marshal(state) if err != nil { bc.log.Error("failed to marshal JSON", "err", err) @@ -880,13 +897,21 @@ func (bc *CaptchaProtect) trippedRateLimit(ip string) bool { func (bc *CaptchaProtect) registerRequest(ip string) { err := bc.rateCache.Add(ip, uint(1), lru.DefaultExpiration) if err == nil { + bc.markStateDirty() + return + } + + v, ok := bc.rateCache.Get(ip) + if ok && v.(uint) > bc.config.RateLimit { return } _, err = bc.rateCache.IncrementUint(ip, uint(1)) if err != nil { bc.log.Error("unable to set rate cache", "ip", ip, "err", err) + return } + bc.markStateDirty() } func (bc *CaptchaProtect) getClientIP(req *http.Request) (string, string) { @@ -1014,74 +1039,135 @@ func (c *Config) ParseHttpMethods(log *slog.Logger) { func (bc *CaptchaProtect) saveState(ctx context.Context) { // Add random jitter to prevent multiple instances from trying to save simultaneously - jitter := time.Duration(rand.Intn(int(StateSaveJitter.Milliseconds()))) * time.Millisecond - interval := StateSaveInterval + jitter + jitter := stateSaveJitter() + baseInterval := stateSaveInterval(bc.config) + interval := baseInterval + jitter - bc.log.Debug("State save configured", "baseInterval", StateSaveInterval, "jitter", jitter, "actualInterval", interval) + bc.log.Debug("State save configured", "baseInterval", baseInterval, "jitter", jitter, "actualInterval", interval) - ticker := time.NewTicker(interval) + ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() - file, err := os.OpenFile(bc.config.PersistentStateFile, os.O_CREATE|os.O_WRONLY, 0644) + file, err := os.OpenFile(bc.config.PersistentStateFile, os.O_CREATE|os.O_WRONLY, 0600) if err != nil { bc.log.Error("unable to save state, could not open or create file", "stateFile", bc.config.PersistentStateFile, "err", err) return } // we made sure the file is writable, we can continue in our loop - file.Close() + if err := file.Close(); err != nil { + bc.log.Error("unable to save state, could not close state file", "stateFile", bc.config.PersistentStateFile, "err", err) + return + } + bc.refreshStateFileModTime() + + lastSave := time.Time{} for { select { case <-ticker.C: - bc.log.Debug("Periodic state save triggered") - bc.saveStateNow() + if bc.config.EnableStateReconciliation == "true" { + bc.reconcileStateFromFileIfChanged() + } + if !bc.hasUnsavedState() { + continue + } + if !lastSave.IsZero() && time.Since(lastSave) < interval { + continue + } + bc.log.Debug("Dirty state save triggered", "dirtyChanges", bc.unsavedStateChanges()) + if bc.saveStateNow() { + lastSave = time.Now() + } case <-ctx.Done(): - bc.log.Debug("Context cancelled, running saveState before shutdown") - bc.saveStateNow() + if bc.config.EnableStateReconciliation == "true" { + bc.reconcileStateFromFileIfChanged() + } + if bc.hasUnsavedState() { + bc.log.Debug("Context cancelled, running saveState before shutdown") + bc.saveStateNow() + } return } } } +func stateSaveInterval(config *Config) time.Duration { + if config.EnableStateReconciliation == "true" { + return StateReconciliationSaveInterval + } + return StateSaveInterval +} + +func stateSaveJitter() time.Duration { + maxJitter := big.NewInt(StateSaveJitter.Milliseconds()) + if maxJitter.Sign() <= 0 { + return 0 + } + + jitter, err := crand.Int(crand.Reader, maxJitter) + if err != nil { + return fallbackStateSaveJitter(maxJitter.Int64()) + } + + return time.Duration(jitter.Int64()) * time.Millisecond +} + +func fallbackStateSaveJitter(maxMillis int64) time.Duration { + if maxMillis <= 0 { + return 0 + } + + hostname, _ := os.Hostname() + hash := fnv.New64a() + _, _ = fmt.Fprintf(hash, "%s:%d:%d", hostname, os.Getpid(), time.Now().UnixNano()) + + jitter := new(big.Int).SetUint64(hash.Sum64()) + jitter.Mod(jitter, big.NewInt(maxMillis)) + return time.Duration(jitter.Int64()) * time.Millisecond +} + // saveStateNow performs an immediate state save using the state package. -func (bc *CaptchaProtect) saveStateNow() { +func (bc *CaptchaProtect) saveStateNow() bool { reconcile := bc.config.EnableStateReconciliation == "true" + dirtyAtStart := bc.currentStateDirty() - lockMs, readMs, reconcileMs, marshalMs, writeMs, totalMs, err := state.SaveStateToFile( + metrics, err := state.SaveStateToFileWithMetrics( bc.config.PersistentStateFile, reconcile, bc.rateCache, - bc.botCache, bc.verifiedCache, bc.log, ) if err != nil { bc.log.Error("failed to save state", "err", err) - return + return false } + bc.markStateSaved(dirtyAtStart) + bc.refreshStateFileModTime() - // Get current state for logging (already marshaled in SaveStateToFile, but we need counts) - currentState := state.GetState(bc.rateCache.Items(), bc.botCache.Items(), bc.verifiedCache.Items()) bc.log.Debug("State saved successfully", - "rateEntries", len(currentState.Rate), - "botEntries", len(currentState.Bots), - "verifiedEntries", len(currentState.Verified), - "lockMs", lockMs, - "readMs", readMs, - "reconcileMs", reconcileMs, - "marshalMs", marshalMs, - "writeMs", writeMs, - "totalMs", totalMs, + "rateEntries", metrics.RateEntries, + "verifiedEntries", metrics.VerifiedEntries, + "lockMs", metrics.LockMs, + "readMs", metrics.ReadMs, + "reconcileMs", metrics.ReconcileMs, + "marshalMs", metrics.MarshalMs, + "writeMs", metrics.WriteMs, + "totalMs", metrics.TotalMs, ) + return true } func (bc *CaptchaProtect) loadState() { + bc.loadStateFrom(bc.config.PersistentStateFile) +} + +func (bc *CaptchaProtect) loadStateFrom(filePath string) { err := state.LoadStateFromFile( - bc.config.PersistentStateFile, + filePath, bc.rateCache, - bc.botCache, bc.verifiedCache, ) @@ -1091,6 +1177,92 @@ func (bc *CaptchaProtect) loadState() { } bc.log.Info("Loaded previous state") + bc.refreshStateFileModTimeFrom(filePath) +} + +func (bc *CaptchaProtect) markStateDirty() { + if bc.config.PersistentStateFile == "" { + return + } + bc.stateMu.Lock() + bc.stateDirty++ + bc.stateMu.Unlock() +} + +func (bc *CaptchaProtect) hasUnsavedState() bool { + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + return bc.stateDirty != bc.stateSavedDirty +} + +func (bc *CaptchaProtect) unsavedStateChanges() uint64 { + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + + dirty := bc.stateDirty + saved := bc.stateSavedDirty + if dirty < saved { + return 0 + } + return dirty - saved +} + +func (bc *CaptchaProtect) currentStateDirty() uint64 { + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + return bc.stateDirty +} + +func (bc *CaptchaProtect) markStateSaved(dirty uint64) { + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + bc.stateSavedDirty = dirty +} + +func (bc *CaptchaProtect) refreshStateFileModTime() { + bc.refreshStateFileModTimeFrom(bc.config.PersistentStateFile) +} + +func (bc *CaptchaProtect) refreshStateFileModTimeFrom(filePath string) { + info, err := os.Stat(filePath) + if err != nil { + return + } + bc.stateMu.Lock() + defer bc.stateMu.Unlock() + bc.stateFileModTime = info.ModTime() +} + +func (bc *CaptchaProtect) reconcileStateFromFileIfChanged() { + info, err := os.Stat(bc.config.PersistentStateFile) + if err != nil { + return + } + modTime := info.ModTime() + bc.stateMu.Lock() + lastModTime := bc.stateFileModTime + bc.stateMu.Unlock() + if !lastModTime.IsZero() && !modTime.After(lastModTime) { + return + } + + err = state.ReconcileStateFromFile( + bc.config.PersistentStateFile, + bc.rateCache, + bc.verifiedCache, + ) + if err != nil { + bc.log.Warn("failed to reconcile state file", "err", err) + return + } + + if info, err := os.Stat(bc.config.PersistentStateFile); err == nil { + modTime = info.ModTime() + } + bc.stateMu.Lock() + bc.stateFileModTime = modTime + bc.stateMu.Unlock() + bc.log.Debug("Reconciled newer state file") } func (bc *CaptchaProtect) ChallengeOnPage() bool { diff --git a/main_test.go b/main_test.go index d3a83e0..5121b10 100644 --- a/main_test.go +++ b/main_test.go @@ -706,6 +706,22 @@ func TestNewCaptchaProtectValidation(t *testing.T) { } } +func TestRedactedConfigDoesNotExposeSecretKey(t *testing.T) { + config := CreateConfig() + config.SecretKey = "super-secret" + + redacted := redactedConfig(config) + if redacted.SecretKey == config.SecretKey { + t.Fatal("expected secret key to be redacted") + } + if redacted.SecretKey != "[REDACTED]" { + t.Fatalf("unexpected redacted secret value %q", redacted.SecretKey) + } + if config.SecretKey != "super-secret" { + t.Fatal("redactedConfig mutated the original config") + } +} + func TestRateLimiting(t *testing.T) { config := CreateConfig() config.SiteKey = "test" @@ -974,15 +990,6 @@ func TestStatePersistence(t *testing.T) { config.SiteKey = "test" config.SecretKey = "test" config.ProtectRoutes = []string{"/"} - config.PersistentStateFile = tmpFile - - // Don't pass a context to avoid starting background goroutines - bc1, _ := NewCaptchaProtect(context.Background(), nil, config, "test") - - // Add some state - bc1.rateCache.Set("192.168.0.0", uint(10), 1*time.Hour) - bc1.verifiedCache.Set("1.2.3.4", true, 1*time.Hour) - bc1.botCache.Set("5.6.7.8", false, 1*time.Hour) // Manually save state by writing the file directly // This tests the state format without relying on the background goroutine @@ -1001,12 +1008,6 @@ func TestStatePersistence(t *testing.T) { "expiration": float64(futureExpiration), }, }, - "bots": map[string]map[string]interface{}{ - "5.6.7.8": { - "value": false, - "expiration": float64(futureExpiration), - }, - }, }) err := os.WriteFile(tmpFile, jsonData, 0644) if err != nil { @@ -1015,6 +1016,7 @@ func TestStatePersistence(t *testing.T) { // Create new instance - should load state bc2, _ := NewCaptchaProtect(context.Background(), nil, config, "test") + bc2.loadStateFrom(tmpFile) // Check rate cache val, found := bc2.rateCache.Get("192.168.0.0") @@ -1028,10 +1030,139 @@ func TestStatePersistence(t *testing.T) { t.Error("Verified cache state not persisted correctly") } - // Check bot cache - botVal, found := bc2.botCache.Get("5.6.7.8") - if !found || botVal.(bool) != false { - t.Error("Bot cache state not persisted correctly") +} + +func TestRegisterRequestStopsIncrementingAfterRateLimitTrips(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + + bc := newStateOnlyCaptchaProtect(tmpFile, 2) + + for i := 0; i < 5; i++ { + bc.registerRequest("192.168.0.0") + } + + v, ok := bc.rateCache.Get("192.168.0.0") + if !ok { + t.Fatal("Expected rate cache entry") + } + if got, want := v.(uint), bc.config.RateLimit+1; got != want { + t.Fatalf("Expected sequential requests to stop incrementing at %d, got %d", want, got) + } + if got, want := bc.currentStateDirty(), uint64(bc.config.RateLimit+1); got != want { + t.Fatalf("Expected dirty counter to track only effective mutations, got %d want %d", got, want) + } +} + +func TestStatePersistenceDisabledWithoutStateFile(t *testing.T) { + config := CreateConfig() + config.SiteKey = "test" + config.SecretKey = "test" + config.ProtectRoutes = []string{"/"} + + bc, err := NewCaptchaProtect(context.Background(), nil, config, "test") + if err != nil { + t.Fatalf("NewCaptchaProtect failed: %v", err) + } + + bc.registerRequest("192.168.0.0") + bc.markStateDirty() + + if bc.currentStateDirty() != 0 { + t.Fatal("state dirty counter should remain disabled without a persistent state file") + } + if bc.hasUnsavedState() { + t.Fatal("state should not become unsaved without a persistent state file") + } +} + +func TestStateSaveIntervalUsesFastCadenceOnlyWithReconciliation(t *testing.T) { + config := CreateConfig() + if got := stateSaveInterval(config); got != 60*time.Second { + t.Fatalf("default state save interval = %s, want 60s", got) + } + + config.EnableStateReconciliation = "true" + if got := stateSaveInterval(config); got != 10*time.Second { + t.Fatalf("reconciliation state save interval = %s, want 10s", got) + } +} + +func TestSaveStateFlushesDirtyStateOnCanceledContext(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "state.json") + bc := newStateOnlyCaptchaProtect(tmpFile, 2) + + bc.rateCache.Set("192.168.0.0", uint(1), time.Hour) + bc.markStateDirty() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + bc.saveState(ctx) + + if bc.hasUnsavedState() { + t.Fatal("expected canceled saveState to flush dirty state before returning") + } + + data, err := os.ReadFile(tmpFile) + if err != nil { + t.Fatalf("expected state file to be written: %v", err) + } + + var saved struct { + Rate map[string]json.RawMessage `json:"rate"` + } + if err := json.Unmarshal(data, &saved); err != nil { + t.Fatalf("state file did not contain valid JSON: %v", err) + } + if _, ok := saved.Rate["192.168.0.0"]; !ok { + t.Fatal("expected dirty rate entry to be persisted") + } +} + +func TestSaveStateNowReturnsFalseAndKeepsDirtyStateOnWriteError(t *testing.T) { + statePath := t.TempDir() + bc := newStateOnlyCaptchaProtect(statePath, 2) + bc.rateCache.Set("192.168.0.0", uint(1), time.Hour) + bc.markStateDirty() + + if bc.saveStateNow() { + t.Fatal("expected saveStateNow to fail when persistent state path is a directory") + } + if !bc.hasUnsavedState() { + t.Fatal("expected failed save to keep state dirty") + } +} + +func TestStateBookkeepingErrorBranches(t *testing.T) { + missingFile := filepath.Join(t.TempDir(), "missing", "state.json") + bc := newStateOnlyCaptchaProtect(missingFile, 2) + + bc.stateMu.Lock() + bc.stateDirty = 1 + bc.stateSavedDirty = 2 + bc.stateMu.Unlock() + if got := bc.unsavedStateChanges(); got != 0 { + t.Fatalf("unsavedStateChanges with saved counter ahead = %d, want 0", got) + } + + bc.refreshStateFileModTime() + if !bc.stateFileModTime.IsZero() { + t.Fatal("refreshStateFileModTime should ignore missing state file") + } + + bc.reconcileStateFromFileIfChanged() + if !bc.stateFileModTime.IsZero() { + t.Fatal("reconcileStateFromFileIfChanged should ignore missing state file") + } + + invalidFile := filepath.Join(t.TempDir(), "state.json") + if err := os.WriteFile(invalidFile, []byte("{invalid json"), 0600); err != nil { + t.Fatalf("failed to write invalid state file: %v", err) + } + invalidBC := newStateOnlyCaptchaProtect(invalidFile, 2) + invalidBC.reconcileStateFromFileIfChanged() + if !invalidBC.stateFileModTime.IsZero() { + t.Fatal("failed reconciliation should not advance state file mod time") } } @@ -1228,13 +1359,13 @@ func TestLoadStateInvalidJSON(t *testing.T) { config.SiteKey = "test" config.SecretKey = "test" config.ProtectRoutes = []string{"/"} - config.PersistentStateFile = tmpFile // Should not panic, just log error bc, err := NewCaptchaProtect(context.Background(), nil, config, "test") if err != nil { t.Errorf("Should not fail on invalid state JSON: %v", err) } + bc.loadStateFrom(tmpFile) // Caches should be empty if bc.rateCache.ItemCount() != 0 { @@ -1243,6 +1374,7 @@ func TestLoadStateInvalidJSON(t *testing.T) { // Clean up the file before temp dir cleanup _ = os.Remove(tmpFile) + _ = os.Remove(tmpFile + ".lock") } func TestParseHttpMethodsInvalid(t *testing.T) {