diff --git a/README.md b/README.md index 50f77b7..ed2ee4d 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,8 @@

17k+ req/s HTTP/2 + WAF + OpenTelemetry Single binary Zero dependencies Zig 0.16+ @@ -22,7 +24,13 @@ ./load_balancer -b localhost:9001 -b localhost:9002 -p 8080 ``` -That's it. Health checks, automatic failover, connection pooling. Done. +That's it. Health checks, automatic failover, connection pooling, **WAF protection**, **distributed tracing**. Done. + +## What's New + +- **Web Application Firewall (WAF)** — Rate limiting, burst detection, request validation +- **OpenTelemetry Tracing** — Full request lifecycle visibility in Jaeger +- **Lock-free Performance** — TigerBeetle-inspired atomic operations, zero locks ## Use Cases @@ -59,28 +67,28 @@ kill -USR2 $(pgrep load_balancer) -### Local microservices gateway +### Rate-limited API Gateway -One port, multiple backends. Round-robin or weighted distribution. +Protect your APIs from abuse with per-IP rate limiting and burst detection. ```bash ./load_balancer \ - -b auth-service:8001 \ - -b api-service:8002 \ - -b cache-service:8003 \ - -s round_robin + -b api-service:8001 \ + --waf-config waf.json \ + --otel-endpoint localhost:4318 ``` -### Debug HTTP/TLS issues +### Debug with Distributed Tracing -See exactly what's on the wire. Hex dumps, TLS cipher info, the works. +See every request in Jaeger with WAF decisions, backend latency, and more. ```bash -./load_balancer --trace --tls-trace \ - -b httpbin.org:443 -p 8080 +./load_balancer -b backend:8001 \ + --otel-endpoint localhost:4318 +# Open http://localhost:16686 for Jaeger UI ``` @@ -123,8 +131,13 @@ zig build -Doptimize=ReleaseFast ./zig-out/bin/backend1 & ./zig-out/bin/backend2 & -# Run -./zig-out/bin/load_balancer -p 8080 -b 127.0.0.1:9001 -b 127.0.0.1:9002 +# Run with WAF and tracing +./zig-out/bin/load_balancer \ + -p 8080 \ + -b 127.0.0.1:9001 \ + -b 127.0.0.1:9002 \ + --waf-config waf.json \ + --otel-endpoint localhost:4318 # Test curl http://localhost:8080 @@ -138,6 +151,7 @@ curl http://localhost:8080 | HAProxy needs a PhD | One binary, zero setup | | Envoy downloads half the internet | No dependencies beyond Zig stdlib | | Node/Go proxies have GC pauses | Memory-safe Zig, no garbage collector | +| WAF costs $$$$ | Built-in, lock-free, high-performance | | "Just use Kubernetes" | This is 4MB, not a lifestyle | **The numbers:** 17,000+ req/s with ~10% overhead vs direct backend access. @@ -157,15 +171,22 @@ curl http://localhost:8080 │ │ │ │ │ │ │ • pool │ │ • pool │ │ • pool │ │ • health │ │ • health │ │ • health │ + │ • WAF │ │ • WAF │ │ • WAF │ └────┬─────┘ └────┬─────┘ └────┬─────┘ │ │ │ └────────────────────┼────────────────────┘ │ SO_REUSEPORT ▼ Port 8080 + │ + ▼ + ┌─────────────────────────────┐ + │ Jaeger / OTLP │ + │ (Distributed Tracing) │ + └─────────────────────────────┘ ``` -Each worker is **fully isolated**: own connection pool, own health state, no locks. A crash in one worker doesn't affect the others. +Each worker is **fully isolated**: own connection pool, own health state, **shared WAF state** (mmap), no locks. A crash in one worker doesn't affect the others. ### Health Checking @@ -173,20 +194,190 @@ Each worker is **fully isolated**: own connection pool, own health state, no loc **Active** — Background probes every 5s catch problems before users hit them. +--- + +## Web Application Firewall (WAF) + +Built-in, lock-free WAF with TigerBeetle-inspired design: + +### Features + +| Feature | Description | +|---------|-------------| +| **Rate Limiting** | Token bucket per IP+path, atomic CAS operations | +| **Burst Detection** | Anomaly detection using EMA (detects sudden traffic spikes) | +| **Request Validation** | URI length, body size, JSON depth limits | +| **Slowloris Protection** | Per-IP connection tracking | +| **Shadow Mode** | Test rules without blocking (log_only) | +| **Hot Reload** | Config changes apply without restart | + +### WAF Request Flow + +``` +Request → WAF Check → Backend + │ + ├─ 1. Validate request (URI, body size) + ├─ 2. Check rate limits (token bucket) + ├─ 3. Check burst detection (EMA anomaly) + └─ 4. Allow / Block / Log +``` + +### WAF Configuration + +Create `waf.json`: + +```json +{ + "enabled": true, + "shadow_mode": false, + "burst_detection_enabled": true, + "burst_threshold": 10, + "rate_limits": [ + { + "name": "login_bruteforce", + "path": "/api/auth/login", + "method": "POST", + "limit": { "requests": 10, "period_sec": 60 }, + "burst": 3, + "by": "ip", + "action": "block" + }, + { + "name": "api_general", + "path": "/api/*", + "limit": { "requests": 100, "period_sec": 60 }, + "burst": 20, + "by": "ip", + "action": "block" + } + ], + "slowloris": { + "max_conns_per_ip": 50 + }, + "request_limits": { + "max_uri_length": 2048, + "max_body_size": 1048576, + "max_json_depth": 20, + "endpoints": [ + { "path": "/api/upload", "max_body_size": 10485760 } + ] + }, + "trusted_proxies": ["10.0.0.0/8", "172.16.0.0/12"], + "logging": { + "log_blocked": true, + "log_allowed": false, + "log_near_limit": true, + "near_limit_threshold": 0.8 + } +} +``` + +Run with WAF: + +```bash +./load_balancer -b backend:8001 --waf-config waf.json +``` + +### WAF Statistics + +Every 10 seconds, the WAF logs statistics: + +``` +[+10000ms] info(waf_stats): WAF Stats: total=1523 allowed=1498 blocked=25 logged=0 block_rate=1% | by_reason: rate_limit=20 slowloris=0 body=3 json=2 +``` + +### Burst Detection + +Detects sudden traffic spikes using Exponential Moving Average (EMA): + +- **Window**: 60 seconds +- **EMA**: `baseline = old * 0.875 + current * 0.125` +- **Trigger**: `current_rate > baseline * threshold` + +Example: An IP normally sends 20 req/min. Suddenly sends 300 req/min → **blocked**. + +--- + +## OpenTelemetry Integration + +Full distributed tracing with Jaeger support. + +### Setup + +1. Start Jaeger: +```bash +docker run -d --name jaeger \ + -p 16686:16686 \ + -p 4318:4318 \ + jaegertracing/all-in-one:latest +``` + +2. Run load balancer with tracing: +```bash +./load_balancer -b backend:8001 --otel-endpoint localhost:4318 +``` + +3. Open Jaeger UI: http://localhost:16686 + +### What's Traced + +Every request gets a span with: + +| Attribute | Description | +|-----------|-------------| +| `http.method` | GET, POST, etc. | +| `http.url` | Request URI | +| `http.status_code` | Response status | +| `waf.decision` | allow / block / log_only | +| `waf.client_ip` | Client IP address | +| `waf.reason` | Why blocked (rate_limit, burst, etc.) | +| `waf.rule` | Which rule triggered | +| `backend.host` | Backend server | +| `backend.latency_ms` | Backend response time | + +### Example Trace + +``` +proxy_request [12.3ms] +├─ http.method: POST +├─ http.url: /api/auth/login +├─ waf.decision: block +├─ waf.client_ip: 192.168.1.100 +├─ waf.reason: rate limit exceeded +└─ waf.rule: login_bruteforce +``` + +### Batching + +Spans are batched for efficiency: +- **Max queue**: 2048 spans +- **Batch size**: 512 spans +- **Export interval**: 5 seconds + +--- + ## CLI Reference ``` --p, --port N Listen port (default: 8080) --b, --backend H:P Backend server (repeat for multiple) --w, --workers N Worker count (default: CPU cores) --s, --strategy S round_robin | weighted | random --c, --config FILE JSON config file (hot-reloaded) --l, --loglevel LVL err | warn | info | debug --k, --insecure Skip TLS verification (dev only!) --t, --trace Dump raw HTTP payloads ---tls-trace Show TLS handshake details ---mode mp|sp Multi-process or single-process ---help You know what this does +-p, --port N Listen port (default: 8080) +-b, --backend H:P Backend server (repeat for multiple) +-w, --workers N Worker count (default: CPU cores) +-s, --strategy S round_robin | weighted | random +-c, --config FILE JSON config file (hot-reloaded) +-l, --loglevel LVL err | warn | info | debug +-k, --insecure Skip TLS verification (dev only!) +-t, --trace Dump raw HTTP payloads +--tls-trace Show TLS handshake details +--mode mp|sp Multi-process or single-process + +WAF Options: +--waf-config FILE WAF configuration JSON file +--waf-shadow Enable shadow mode (log only, don't block) + +Observability: +--otel-endpoint H:P OpenTelemetry OTLP endpoint (e.g., localhost:4318) + +--help You know what this does ``` ## Config File @@ -218,6 +409,8 @@ Changes detected via kqueue (macOS) / inotify (Linux). Zero-downtime reload. | **Prometheus metrics** | `GET /metrics` | | **Connection pooling** | Per-backend, per-worker pools | | **Crash isolation** | Workers are separate processes | +| **WAF** | Rate limiting, burst detection, request validation | +| **Distributed tracing** | OpenTelemetry + Jaeger integration | ## HTTP/2 Support @@ -249,7 +442,7 @@ info(tls_trace): ALPN Protocol: http2 ```bash zig build # Debug zig build -Doptimize=ReleaseFast # Production -zig build test # 242 tests +zig build test # 370+ tests ``` Requires Zig 0.16.0+ @@ -263,6 +456,13 @@ See [ARCHITECTURE.md](ARCHITECTURE.md) for the nerdy stuff: - Binary hot reload via file descriptor passing - SIMD HTTP parsing +See [docs/WAF_ARCHITECTURE.md](docs/WAF_ARCHITECTURE.md) for WAF internals: +- Lock-free token bucket implementation +- Burst detection with EMA algorithm +- Shared memory structures (4MB WafState) +- OpenTelemetry span propagation +- Request flow diagrams with function names + ## License MIT — do whatever you want. diff --git a/__pycache__/h2_backend.cpython-314.pyc b/__pycache__/h2_backend.cpython-314.pyc new file mode 100644 index 0000000..f813f1a Binary files /dev/null and b/__pycache__/h2_backend.cpython-314.pyc differ diff --git a/build.zig b/build.zig index 60bfb80..45e1a2a 100644 --- a/build.zig +++ b/build.zig @@ -18,6 +18,12 @@ pub fn build(b: *std.Build) void { .root_source_file = b.path("vendor/tls/src/root.zig"), }); + // OpenTelemetry SDK module (zig-o11y/opentelemetry-sdk - distributed tracing) + const otel_module = b.dependency("opentelemetry", .{ + .target = target, + .optimize = optimize, + }).module("sdk"); + // Backend 1 const backend1_mod = b.createModule(.{ .root_source_file = b.path("tests/fixtures/backend1.zig"), @@ -72,6 +78,19 @@ pub fn build(b: *std.Build) void { }); const build_test_backend_echo = b.addInstallArtifact(test_backend_echo, .{}); + // Mock OTLP collector (receives traces for integration tests) + const mock_otlp_collector_mod = b.createModule(.{ + .root_source_file = b.path("tests/fixtures/mock_otlp_collector.zig"), + .target = target, + .optimize = optimize, + }); + mock_otlp_collector_mod.addImport("zzz", zzz_module); + const mock_otlp_collector = b.addExecutable(.{ + .name = "mock_otlp_collector", + .root_module = mock_otlp_collector_mod, + }); + const build_mock_otlp_collector = b.addInstallArtifact(mock_otlp_collector, .{}); + // Sanitizer option for debugging const sanitize_thread = b.option(bool, "sanitize-thread", "Enable Thread Sanitizer") orelse false; @@ -85,6 +104,7 @@ pub fn build(b: *std.Build) void { }); load_balancer_mod.addImport("zzz", zzz_module); load_balancer_mod.addImport("tls", tls_module); + load_balancer_mod.addImport("opentelemetry", otel_module); const load_balancer = b.addExecutable(.{ .name = "load_balancer", .root_module = load_balancer_mod, @@ -122,6 +142,7 @@ pub fn build(b: *std.Build) void { }); const run_integration_exe = b.addRunArtifact(integration_exe); run_integration_exe.step.dependOn(&build_test_backend_echo.step); + run_integration_exe.step.dependOn(&build_mock_otlp_collector.step); run_integration_exe.step.dependOn(&build_load_balancer.step); const integration_test_step = b.step("test-integration", "Run integration tests"); @@ -133,6 +154,7 @@ pub fn build(b: *std.Build) void { build_all.dependOn(&build_backend2.step); build_all.dependOn(&build_backend_proxy.step); build_all.dependOn(&build_test_backend_echo.step); + build_all.dependOn(&build_mock_otlp_collector.step); build_all.dependOn(&build_load_balancer.step); const test_step = b.step("test", "Run unit tests"); diff --git a/build.zig.zon b/build.zig.zon index 01e67f4..dcd7b4c 100644 --- a/build.zig.zon +++ b/build.zig.zon @@ -7,6 +7,9 @@ .zzz = .{ .path = "vendor/zzz.io", }, + .opentelemetry = .{ + .path = "vendor/otel", + }, }, .paths = .{ diff --git a/docker-compose.yml b/docker-compose.yml index 846ed6a..bd2aaea 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,6 +6,18 @@ networks: - subnet: 172.28.0.0/16 services: + jaeger: + image: jaegertracing/all-in-one:latest + ports: + - "4317:4317" # OTLP gRPC + - "4318:4318" # OTLP HTTP + - "16686:16686" # Web UI + networks: + lb_net: + ipv4_address: 172.28.0.10 + environment: + - COLLECTOR_OTLP_ENABLED=true + backend1: build: . command: ["./zig-out/bin/backend1"] @@ -49,6 +61,8 @@ services: "172.28.0.2:9001", "--backend", "172.28.0.3:9002", + "--otel-endpoint", + "172.28.0.10:4317", ] ports: - "8080:8080" @@ -56,6 +70,7 @@ services: lb_net: ipv4_address: 172.28.0.4 depends_on: + - jaeger - backend1 - backend2 security_opt: diff --git a/docs/WAF_ARCHITECTURE.md b/docs/WAF_ARCHITECTURE.md new file mode 100644 index 0000000..3bd911a --- /dev/null +++ b/docs/WAF_ARCHITECTURE.md @@ -0,0 +1,702 @@ +# WAF Architecture Documentation + +High-performance Web Application Firewall for the zzz load balancer with OpenTelemetry observability. + +## Table of Contents +- [Overview](#overview) +- [Architecture](#architecture) +- [Request Flow](#request-flow) +- [Components](#components) +- [OpenTelemetry Integration](#opentelemetry-integration) +- [Configuration](#configuration) +- [Data Structures](#data-structures) + +--- + +## Overview + +The WAF provides: +- **Lock-free rate limiting** using atomic CAS operations +- **Request validation** (URI length, body size, JSON depth) +- **Burst detection** (anomaly detection for sudden traffic spikes) +- **Slowloris protection** (connection tracking) +- **Shadow mode** for safe rule testing +- **Hot-reload configuration** +- **Full OpenTelemetry tracing** + +Design Philosophy: **TigerBeetle-style** +- Fixed-size structures with compile-time bounds +- Cache-line alignment to prevent false sharing +- Zero allocation on hot path +- Atomic operations only (no locks) + +--- + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ Load Balancer │ +│ │ +│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────────────┐ │ +│ │ main.zig │ │ handler.zig │ │ telemetry/mod.zig │ │ +│ │ │ │ │ │ │ │ +│ │ • CLI args │───▶│ • Request entry │───▶│ • OTLP exporter │ │ +│ │ • WAF init │ │ • WAF check │ │ • Batching processor │ │ +│ │ • Stats thread │ │ • OTEL spans │ │ • Span attributes │ │ +│ └─────────────────┘ └────────┬────────┘ └─────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌──────────────────────────────────────────────────────────────────────┐ │ +│ │ WAF Module (src/waf/) │ │ +│ │ │ │ +│ │ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ ┌────────────┐ │ │ +│ │ │ mod.zig │ │ engine.zig │ │ config.zig │ │ events.zig │ │ │ +│ │ │ │ │ │ │ │ │ │ │ │ +│ │ │ Public API │ │ Orchestrator │ │ JSON parser │ │ Structured │ │ │ +│ │ │ Re-exports │ │ Main check() │ │ Hot-reload │ │ logging │ │ │ +│ │ └─────────────┘ └──────┬───────┘ └─────────────┘ └────────────┘ │ │ +│ │ │ │ │ +│ │ ┌────────────────┼────────────────┐ │ │ +│ │ ▼ ▼ ▼ │ │ +│ │ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │ │ +│ │ │rate_limiter │ │ validator.zig│ │ state.zig │ │ │ +│ │ │ .zig │ │ │ │ │ │ │ +│ │ │ │ │ URI/body/JSON│ │ Shared mem │ │ │ +│ │ │ Token bucket│ │ validation │ │ structures │ │ │ +│ │ │ Atomic CAS │ │ Streaming │ │ 64K buckets │ │ │ +│ │ └─────────────┘ └──────────────┘ │ Burst track │ │ │ +│ │ │ Metrics │ │ │ +│ │ └─────────────┘ │ │ +│ └──────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────┐ + │ Jaeger / OTLP │ + │ (Distributed Tracing) │ + └─────────────────────────────────┘ +``` + +--- + +## Request Flow + +``` + ┌──────────────────────┐ + │ Incoming Request │ + └──────────┬───────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────────────────────────┐ +│ handler.zig:handle() │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 1. telemetry.startServerSpan("proxy_request") │ │ +│ │ ├─ setStringAttribute("http.method", method) │ │ +│ │ └─ setStringAttribute("http.url", uri) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 2. main.getWafEngine() -> WafEngine │ │ +│ │ └─ Returns engine with (WafState*, WafConfig*) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ 3. Build waf.Request │ │ +│ │ ├─ convertHttpMethod(ctx.request.method) │ │ +│ │ ├─ Extract body length from ctx.request.body │ │ +│ │ └─ Request.init() or Request.withContentLength() │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +└─────────────────────────────────────────┼──────────────────────────────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────────────────────────┐ +│ engine.zig:WafEngine.check(&request) │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 1: Fast path check │ │ +│ │ if (!self.waf_config.enabled) return CheckResult.allow() │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 2: getClientIp(request) │ │ +│ │ ├─ Check if source_ip is trusted proxy │ │ +│ │ ├─ If trusted, parse X-Forwarded-For header │ │ +│ │ └─ Return real client IP (u32) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 3: validateRequest(request) │ │ +│ │ ├─ Check URI length <= max_uri_length │ │ +│ │ ├─ Check body size <= getMaxBodySize(path) │ │ +│ │ └─ Return CheckResult.block(.invalid_request) if failed │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 4: checkRateLimit(request, client_ip) │ │ +│ │ ├─ findMatchingRule(request) -> RateLimitRule │ │ +│ │ ├─ Build Key from (client_ip, hashPath(path)) │ │ +│ │ ├─ RateLimiter.check(key, rule) -> DecisionResult │ │ +│ │ │ └─ Bucket.tryConsume() with atomic CAS │ │ +│ │ └─ Return CheckResult.block(.rate_limit) if exhausted │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 5: checkBurst(client_ip) [if burst_detection_enabled] │ │ +│ │ ├─ getCurrentTimeSec() │ │ +│ │ ├─ WafState.checkBurst(ip_hash, time, threshold) │ │ +│ │ │ ├─ BurstTracker.findOrCreate(ip_hash) │ │ +│ │ │ └─ BurstEntry.recordAndCheck(time, threshold) │ │ +│ │ │ ├─ Update EMA baseline (0.875 * old + 0.125 * current) │ │ +│ │ │ └─ Return true if current > baseline * threshold │ │ +│ │ └─ Return CheckResult.block(.burst) if burst detected │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 6: recordDecision(&result) │ │ +│ │ ├─ metrics.recordAllowed() / recordBlocked(reason) │ │ +│ │ └─ Atomic increment of counters │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Step 7: applyMode(result) │ │ +│ │ ├─ If shadow_mode: result.toShadowMode() (block -> log_only) │ │ +│ │ └─ Return final CheckResult │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ +└─────────────────────────────────────────┬──────────────────────────────────┘ + │ + ▼ +┌────────────────────────────────────────────────────────────────────────────┐ +│ Back to handler.zig │ +│ │ +│ ┌─────────────────────────────────────────────────────────────────────┐ │ +│ │ Add WAF attributes to OTEL span: │ │ +│ │ ├─ span.setStringAttribute("waf.decision", decision) │ │ +│ │ ├─ span.setStringAttribute("waf.client_ip", formatIpv4(ip)) │ │ +│ │ ├─ span.setStringAttribute("waf.reason", reason.description()) │ │ +│ │ └─ span.setStringAttribute("waf.rule", rule_name) │ │ +│ └─────────────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ┌───────────────┴───────────────┐ │ +│ ▼ ▼ │ +│ ┌─────────────────┐ ┌─────────────────┐ │ +│ │ result.isBlocked() │ Proceed to │ │ +│ │ Return 429/403 │ │ Backend │ │ +│ │ + WAF headers │ │ Proxy request │ │ +│ └─────────────────┘ └─────────────────┘ │ +│ │ +└────────────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Components + +### 1. `mod.zig` - Public API Module + +Re-exports all public types for clean external interface. + +```zig +// Usage +const waf = @import("waf"); +var engine = waf.WafEngine.init(&state, &config); +const result = engine.check(&request); +``` + +**Exported Types:** +| Type | Source | Purpose | +|------|--------|---------| +| `WafState` | state.zig | Shared memory container | +| `WafEngine` | engine.zig | Main orchestrator | +| `WafConfig` | config.zig | JSON configuration | +| `Request` | engine.zig | Request representation | +| `CheckResult` | engine.zig | Decision output | +| `Decision` | state.zig | allow/block/log_only | +| `Reason` | state.zig | Why blocked | +| `RateLimiter` | rate_limiter.zig | Token bucket | +| `RequestValidator` | validator.zig | Size/format checks | +| `EventLogger` | events.zig | Structured logging | + +--- + +### 2. `engine.zig` - WAF Engine + +The orchestrator that coordinates all security checks. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ WafEngine │ +├─────────────────────────────────────────────────────────────────┤ +│ Fields: │ +│ waf_state: *WafState (mmap'd shared memory) │ +│ waf_config: *const WafConfig │ +│ limiter: RateLimiter │ +├─────────────────────────────────────────────────────────────────┤ +│ Methods: │ +│ init(state, config) -> WafEngine │ +│ check(request) -> CheckResult ◀── Main entry point │ +│ getClientIp(request) -> u32 │ +│ validateRequest(request) -> CheckResult │ +│ checkRateLimit(request, ip) -> CheckResult │ +│ checkBurst(ip) -> CheckResult │ +│ applyMode(result) -> CheckResult │ +│ recordDecision(result) -> void │ +└─────────────────────────────────────────────────────────────────┘ +``` + +**Key Functions:** + +| Function | Lines | Purpose | +|----------|-------|---------| +| `check()` | 192-227 | Main entry, orchestrates all checks | +| `getClientIp()` | 234-255 | Extract real IP from X-Forwarded-For | +| `validateRequest()` | 258-275 | URI/body size validation | +| `checkRateLimit()` | 278-314 | Token bucket rate limiting | +| `checkBurst()` | 317-326 | Anomaly detection | +| `applyMode()` | 329-334 | Shadow mode transformation | + +--- + +### 3. `state.zig` - Shared Memory Structures + +Lock-free data structures for multi-process sharing. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ WafState (~4MB) │ +│ 64-byte aligned │ +├─────────────────────────────────────────────────────────────────┤ +│ magic: u64 │ Corruption detection "WAFSTV10" │ +├─────────────────────────────────────────────────────────────────┤ +│ buckets: [65536]Bucket │ Token bucket table │ +│ └─ Each Bucket: 64 bytes │ └─ key_hash: u64 │ +│ (cache-line aligned) │ └─ packed_state: u64 (atomic) │ +│ │ └─ tokens: u32 (scaled x1000) │ +│ │ └─ timestamp: u32 │ +├─────────────────────────────────────────────────────────────────┤ +│ conn_tracker: ConnTracker │ Per-IP connection counting │ +│ └─ [16384]ConnEntry │ └─ Slowloris detection │ +├─────────────────────────────────────────────────────────────────┤ +│ burst_tracker: BurstTracker│ Anomaly detection │ +│ └─ [8192]BurstEntry │ └─ baseline_rate (EMA) │ +│ │ └─ current_count │ +│ │ └─ last_window │ +├─────────────────────────────────────────────────────────────────┤ +│ metrics: WafMetrics │ Atomic counters │ +│ └─ requests_allowed │ │ +│ └─ requests_blocked │ │ +│ └─ requests_logged │ │ +│ └─ blocked_by_reason[10] │ │ +├─────────────────────────────────────────────────────────────────┤ +│ config_epoch: u64 │ Hot-reload detection │ +└─────────────────────────────────────────────────────────────────┘ +``` + +**Bucket CAS Operation:** + +``` +┌──────────────────────────────────────────────────────────────────┐ +│ Bucket.tryConsume(current_time, rate, max, cost) │ +│ │ +│ 1. Load packed_state atomically │ +│ ┌────────────────────────────────────────┐ │ +│ │ packed_state (u64) │ │ +│ │ [ tokens (u32) | timestamp (u32) ] │ │ +│ └────────────────────────────────────────┘ │ +│ │ +│ 2. Calculate token refill based on elapsed time │ +│ elapsed = current_time - old_timestamp │ +│ new_tokens = min(old_tokens + elapsed * rate, max) │ +│ │ +│ 3. Attempt consumption │ +│ if new_tokens >= cost: │ +│ new_tokens -= cost │ +│ │ +│ 4. CAS (Compare-And-Swap) atomically │ +│ if cmpxchgWeak(old_packed, new_packed): │ +│ return SUCCESS │ +│ else: │ +│ retry (up to MAX_CAS_ATTEMPTS) │ +│ │ +│ 5. Fail-open: If CAS exhausted, allow request │ +└──────────────────────────────────────────────────────────────────┘ +``` + +--- + +### 4. `rate_limiter.zig` - Token Bucket + +Lock-free rate limiting with atomic operations. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ RateLimiter │ +├─────────────────────────────────────────────────────────────────┤ +│ state: *WafState (pointer to shared memory) │ +├─────────────────────────────────────────────────────────────────┤ +│ check(key, rule) -> DecisionResult │ +│ ├─ findOrCreateBucket(key.hash()) │ +│ ├─ bucket.tryConsume(time, rate, capacity, cost) │ +│ └─ Return allow/block with remaining tokens │ +├─────────────────────────────────────────────────────────────────┤ +│ findOrCreateBucket(hash) -> ?*Bucket │ +│ └─ Open addressing with linear probing │ +│ └─ Probe limit: 16 attempts │ +├─────────────────────────────────────────────────────────────────┤ +│ Key Structure: │ +│ ip: u32 (IPv4 address) │ +│ path_hash: u32 (FNV-1a hash of path) │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +### 5. `validator.zig` - Request Validation + +Zero-allocation request validation with streaming JSON support. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ RequestValidator │ +├─────────────────────────────────────────────────────────────────┤ +│ config: *const ValidatorConfig │ +├─────────────────────────────────────────────────────────────────┤ +│ validateRequest(uri, content_length, headers) │ +│ ├─ Check URI length │ +│ ├─ Check body size │ +│ └─ Check query parameter count │ +├─────────────────────────────────────────────────────────────────┤ +│ validateJsonStream(chunk, state) -> ValidationResult │ +│ └─ Streaming JSON validation (constant memory) │ +│ └─ Tracks nesting depth and key count │ +├─────────────────────────────────────────────────────────────────┤ +│ ValidatorConfig: │ +│ max_uri_length: 2048 │ +│ max_query_params: 50 │ +│ max_body_size: 1MB │ +│ max_json_depth: 20 │ +│ max_json_keys: 1000 │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +### 6. `config.zig` - Configuration + +JSON configuration parsing with validation. + +```json +{ + "enabled": true, + "shadow_mode": false, + "burst_detection_enabled": true, + "burst_threshold": 10, + "rate_limits": [ + { + "name": "login_bruteforce", + "path": "/api/auth/login", + "method": "POST", + "limit": { "requests": 10, "period_sec": 60 }, + "burst": 3, + "by": "ip", + "action": "block" + } + ], + "slowloris": { + "max_conns_per_ip": 50 + }, + "request_limits": { + "max_uri_length": 2048, + "max_body_size": 1048576, + "max_json_depth": 20 + }, + "trusted_proxies": ["10.0.0.0/8"], + "logging": { + "log_blocked": true, + "log_allowed": false + } +} +``` + +--- + +### 7. `events.zig` - Structured Logging + +JSON Lines output for machine parsing. + +```json +{"timestamp":1703635200,"event_type":"blocked","client_ip":"192.168.1.1","method":"POST","path":"/api/login","rule_name":"login_bruteforce","reason":"rate_limit"} +``` + +--- + +## OpenTelemetry Integration + +### Initialization Flow + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ main.zig │ +│ │ +│ telemetry.init(allocator, "localhost:4318") │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐│ +│ │ telemetry/mod.zig:init() ││ +│ │ ││ +│ │ 1. Create ConfigOptions (OTLP endpoint, HTTP protobuf) ││ +│ │ 2. Create OTLPExporter with service name ││ +│ │ 3. Create RandomIDGenerator with seeded PRNG ││ +│ │ 4. Create TracerProvider ││ +│ │ 5. Create BatchingProcessor ││ +│ │ └─ max_queue_size: 2048 ││ +│ │ └─ scheduled_delay_millis: 5000 ││ +│ │ └─ max_export_batch_size: 512 ││ +│ │ 6. Add processor to provider ││ +│ │ 7. Get tracer ("zzz-load-balancer", "0.1.0") ││ +│ │ 8. Store in global_state ││ +│ └─────────────────────────────────────────────────────────────┘│ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Request Tracing Flow + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ handler.zig:handle() │ +│ │ +│ var span = telemetry.startServerSpan("proxy_request"); │ +│ defer span.end(); │ +│ │ │ +│ ▼ │ +│ ┌─────────────────────────────────────────────────────────────┐│ +│ │ engine.checkWithSpan(&request, &span, telemetry) ││ +│ │ ││ +│ │ Creates child spans for each WAF step: ││ +│ │ ││ +│ │ proxy_request (Server) ││ +│ │ └── waf.check (Internal) ││ +│ │ ├── waf.validate_request (Internal) ││ +│ │ │ └─ waf.step: "validate_request" ││ +│ │ │ └─ waf.passed: true/false ││ +│ │ ├── waf.rate_limit (Internal) ││ +│ │ │ └─ waf.step: "rate_limit" ││ +│ │ │ └─ waf.passed: true/false ││ +│ │ │ └─ waf.tokens_remaining: 42 ││ +│ │ │ └─ waf.rule: "api_rate_limit" ││ +│ │ └── waf.burst_detection (Internal) ││ +│ │ └─ waf.step: "burst_detection" ││ +│ │ └─ waf.passed: true/false ││ +│ └─────────────────────────────────────────────────────────────┘│ +│ │ +│ // Summary attributes on parent span │ +│ span.setStringAttribute("waf.decision", "allow"); │ +│ span.setStringAttribute("waf.client_ip", "192.168.1.1"); │ +│ │ +│ span.end(); // Automatically queued for batch export │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ BatchingProcessor (background thread) │ +│ │ +│ Every 5 seconds OR when 512 spans accumulated: │ +│ └─ OTLPExporter.export(batch) │ +│ └─ HTTP POST to http://localhost:4318/v1/traces │ +│ └─ Protobuf-encoded spans │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────┐ +│ Jaeger UI │ +│ │ +│ Trace: proxy_request [12.5ms] │ +│ ├─ waf.check [0.8ms] │ +│ │ ├─ waf.validate_request [0.05ms] │ +│ │ ├─ waf.rate_limit [0.5ms] │ +│ │ └─ waf.burst_detection [0.2ms] │ +│ └─ backend_request [11.2ms] │ +│ │ +│ Span Attributes: │ +│ ├─ http.method: GET │ +│ ├─ http.url: /api/users │ +│ ├─ waf.decision: block │ +│ │ ├─ waf.client_ip: 192.168.1.100 │ +│ │ ├─ waf.reason: rate_limit │ +│ │ └─ waf.rule: api_rate_limit │ +│ └─ Duration: 1.2ms │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Span Attributes Reference + +**Parent Span (proxy_request):** + +| Attribute | Type | Description | +|-----------|------|-------------| +| `http.method` | string | HTTP method (GET, POST, etc.) | +| `http.url` | string | Request URI | +| `http.status_code` | int | Response status code | +| `waf.decision` | string | "allow", "block", or "log_only" | +| `waf.client_ip` | string | Client IP address | +| `waf.reason` | string | Reason description if blocked | +| `waf.rule` | string | Rule name if blocked | +| `backend.host` | string | Backend server address | +| `backend.port` | int | Backend server port | + +**WAF Child Spans (waf.check, waf.validate_request, waf.rate_limit, waf.burst_detection):** + +| Attribute | Type | Description | +|-----------|------|-------------| +| `waf.step` | string | Current step name | +| `waf.passed` | bool | Whether this step passed | +| `waf.tokens_remaining` | int | Remaining rate limit tokens (rate_limit only) | +| `waf.rule` | string | Matched rule name (rate_limit only) | +| `waf.reason` | string | Block reason if failed | +| `waf.blocked_by` | string | Which step blocked (on waf.check span) | +| `waf.result` | string | Final result "allow" (on waf.check span) | +| `waf.enabled` | string | "false" if WAF disabled (fast path) | + +--- + +## Burst Detection (Anomaly Detection) + +Detects sudden traffic spikes using Exponential Moving Average (EMA). + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ BurstEntry (12 bytes per IP) │ +├─────────────────────────────────────────────────────────────────┤ +│ ip_hash: u32 │ FNV-1a hash of client IP │ +│ baseline_rate: u16 │ EMA of requests/window (scaled x16) │ +│ current_count: u16 │ Requests in current window │ +│ last_window: u32 │ Timestamp of current window start │ +└─────────────────────────────────────────────────────────────────┘ + +Algorithm: +┌─────────────────────────────────────────────────────────────────┐ +│ Window = 60 seconds │ +│ │ +│ On each request: │ +│ if (current_time - last_window >= 60): │ +│ # New window - update baseline with EMA │ +│ new_baseline = old_baseline * 0.875 + current_count * 0.125│ +│ current_count = 1 │ +│ last_window = current_time │ +│ return false # No burst on window transition │ +│ else: │ +│ # Same window - increment and check │ +│ current_count += 1 │ +│ if baseline < MIN_BASELINE: │ +│ return false # Not enough history │ +│ if current_count * 16 > baseline * threshold: │ +│ return true # BURST DETECTED │ +│ return false │ +└─────────────────────────────────────────────────────────────────┘ + +Example: + Baseline: 20 req/min (established over time) + Threshold: 10x + + Window 1: 18 requests -> baseline updates to ~20 + Window 2: 22 requests -> baseline updates to ~20 + Window 3: 250 requests + -> 250 * 16 = 4000 > 320 * 10 = 3200 + -> BURST DETECTED at request ~200 +``` + +--- + +## WAF Stats Thread + +Background thread logs statistics every 10 seconds. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ main.zig:wafStatsLoop() │ +│ │ +│ while (true): │ +│ std.posix.nanosleep(10, 0) # Sleep 10 seconds │ +│ stats = global_waf_state.metrics.snapshot() │ +│ │ +│ log.info("WAF Stats: total={} allowed={} blocked={} ...") │ +└─────────────────────────────────────────────────────────────────┘ + +Output: +[+10000ms] info(waf_stats): WAF Stats: total=1523 allowed=1498 blocked=25 logged=0 block_rate=1% | by_reason: rate_limit=20 slowloris=0 body=3 json=2 +``` + +--- + +## Memory Layout + +``` +Total WafState size: ~4.2MB + +┌──────────────────────────────────────────────────────────────────┐ +│ Offset │ Size │ Field │ +├───────────┼───────────┼──────────────────────────────────────────┤ +│ 0x000000 │ 8 bytes │ magic ("WAFSTV10") │ +│ 0x000040 │ 4,194,304 │ buckets[65536] (64 bytes each) │ +│ 0x400040 │ 131,072 │ conn_tracker[16384] (8 bytes each) │ +│ 0x420040 │ 98,304 │ burst_tracker[8192] (12 bytes each) │ +│ 0x438040 │ 128 │ metrics (atomic counters) │ +│ 0x4380C0 │ 8 │ config_epoch │ +└──────────────────────────────────────────────────────────────────┘ + +Cache-line alignment (64 bytes) at: + - buckets array start + - conn_tracker start + - burst_tracker start + - metrics start + - config_epoch +``` + +--- + +## File Reference + +| File | Lines | Purpose | +|------|-------|---------| +| `src/waf/mod.zig` | ~190 | Public API, re-exports | +| `src/waf/engine.zig` | ~400 | Main orchestrator | +| `src/waf/state.zig` | ~1240 | Shared memory structures | +| `src/waf/rate_limiter.zig` | ~350 | Token bucket implementation | +| `src/waf/validator.zig` | ~350 | Request validation | +| `src/waf/config.zig` | ~600 | JSON config parsing | +| `src/waf/events.zig` | ~300 | Structured logging | +| `src/telemetry/mod.zig` | ~220 | OpenTelemetry integration | +| `src/proxy/handler.zig` | ~500 | Request handler with WAF | +| `main.zig` | ~400 | Entry point, WAF init | + +--- + +## Test Coverage + +129 unit tests covering: +- Bucket operations and CAS +- Rate limiter logic +- Burst detection (EMA, spike detection) +- Request validation +- JSON streaming validation +- Config parsing +- Event formatting +- Integration tests + +Run tests: +```bash +zig test src/waf/mod.zig +``` diff --git a/docs/plans/2025-12-26-opentelemetry-tracing-design.md b/docs/plans/2025-12-26-opentelemetry-tracing-design.md new file mode 100644 index 0000000..fe4c09a --- /dev/null +++ b/docs/plans/2025-12-26-opentelemetry-tracing-design.md @@ -0,0 +1,257 @@ +# OpenTelemetry Tracing Design + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Add distributed tracing to the load balancer for debugging request lifecycle, HTTP/2 streams, and GOAWAY retry issues. + +**Architecture:** Traces exported via OTLP/HTTP to Jaeger. Spans created per-request with child spans for pool acquire, TLS handshake, H2 connection, and response forwarding. + +**Tech Stack:** zig-o11y/opentelemetry-sdk, Jaeger + +--- + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Load Balancer │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │ +│ │ Tracer │ │ Spans │ │ OTLP Exporter │ │ +│ │ (global) │──│ (per-req) │──│ HTTP/JSON → Jaeger │ │ +│ └─────────────┘ └─────────────┘ └─────────────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ + │ + ▼ OTLP/HTTP (port 4318) + ┌─────────────────────┐ + │ Jaeger │ + │ localhost:16686 │ + └─────────────────────┘ +``` + +## Span Hierarchy (Implemented) + +### HTTP/2 (HTTPS) Path - New Connection (8 spans) +``` +proxy_request (root, SERVER span) +├─ http.method, http.url, http.status_code +├─ backend.id, backend.host +│ +├── backend_selection (INTERNAL) +│ ├─ lb.backend_count +│ ├─ lb.healthy_count +│ └─ lb.strategy +│ +└── backend_request_h2 (CLIENT) + ├─ http.method, http.url + ├─ http.status_code, http.response_content_length + ├─ h2.retry_count + │ + ├── dns_resolution (INTERNAL) ← only on new connection + │ └─ dns.hostname + │ + ├── tcp_connect (INTERNAL) ← only on new connection + │ ├─ net.peer.name + │ └─ net.peer.port + │ + ├── tls_handshake (INTERNAL) ← only on new connection + │ ├─ tls.server_name + │ ├─ tls.cipher + │ └─ tls.alpn + │ + ├── h2_handshake (INTERNAL) ← only on new connection + │ └─ HTTP/2 preface + SETTINGS exchange + │ + └── response_streaming (INTERNAL) + ├─ http.body.type: "h2_buffered" + ├─ http.body.length + ├─ http.body.bytes_written + └─ http.body.had_error +``` + +### HTTP/2 (HTTPS) Path - Reused Connection (4 spans) +``` +proxy_request (root, SERVER span) +├── backend_selection (INTERNAL) +└── backend_request_h2 (CLIENT) + └── response_streaming (INTERNAL) +``` +Connection spans (dns, tcp, tls, h2_handshake) only appear when creating fresh connections. +H2 multiplexing reuses existing connections for subsequent requests. + +### HTTP/1.1 Path - New Connection +``` +proxy_request (root, SERVER span) +├─ http.method, http.url, http.status_code +│ +├── backend_selection (INTERNAL) +│ ├─ lb.backend_count, lb.healthy_count, lb.strategy +│ +├── backend_connection (INTERNAL) +│ │ +│ ├── dns_resolution (INTERNAL) ← only on new connection +│ │ └─ dns.hostname +│ │ +│ ├── tcp_connect (INTERNAL) ← only on new connection +│ │ ├─ net.peer.name +│ │ └─ net.peer.port +│ │ +│ └── tls_handshake (INTERNAL, if HTTPS) ← only on new connection +│ ├─ tls.server_name, tls.cipher, tls.alpn +│ +├── backend_request (CLIENT) +│ └─ http.method, http.url +│ +└── response_streaming (INTERNAL) + ├─ http.body.type: "content_length" | "chunked" | "until_close" + ├─ http.body.expected_length + ├─ http.body.bytes_transferred + └─ http.body.had_error +``` + +## Files to Create + +### src/telemetry/mod.zig +Public API for tracing: +- `init(endpoint: []const u8)` - Initialize tracer with Jaeger endpoint +- `shutdown()` - Flush and close exporter +- `startSpan(name, parent)` - Create new span +- Global tracer instance + +### src/telemetry/exporter.zig +OTLP/HTTP exporter: +- Batch spans (100 or 5 seconds) +- Non-blocking export in background +- Flush on shutdown + +### src/telemetry/span.zig +Span wrapper: +- setAttribute(key, value) +- addEvent(name) +- setStatus(error) +- end() + +## Files to Modify + +### build.zig +```zig +const otel = b.dependency("opentelemetry-sdk", .{ + .target = target, + .optimize = optimize, +}); +load_balancer_mod.addImport("opentelemetry", otel.module("opentelemetry")); +``` + +### build.zig.zon +```zig +.@"opentelemetry-sdk" = .{ + .url = "git+https://github.com/zig-o11y/opentelemetry-sdk", +}, +``` + +### main.zig +- Add `--otel-endpoint` CLI flag +- Call `telemetry.init(endpoint)` at startup +- Call `telemetry.shutdown()` on exit + +### src/proxy/handler.zig +- Create root span in `streamingProxy()` +- Add `trace_context` field to `ProxyState` +- End span in `streamingProxy_finalize()` + +### src/proxy/connection.zig +- Create `pool_acquire` child span in `acquireConnection()` + +### src/http/http2/connection.zig +- Create `h2_connect` span in `connect()` +- Create `h2_request` span in `request()` +- Add `goaway_received` event when GOAWAY detected + +### src/http/http2/pool.zig +- Add `retry_triggered` event on retry + +### src/http/tls.zig +- Create `tls_handshake` span during TLS setup + +## Context Passing + +```zig +pub const ProxyState = struct { + // ... existing fields ... + trace_span: ?*Span = null, // Root span for this request +}; +``` + +Child spans access parent via `proxy_state.trace_span`. + +## CLI Usage + +```bash +# With tracing +./load_balancer --port 8080 --backend 127.0.0.1:9000 \ + --otel-endpoint http://localhost:4318 + +# Without tracing (default) +./load_balancer --port 8080 --backend 127.0.0.1:9000 +``` + +## Testing + +```bash +# Start Jaeger +docker run -d --name jaeger \ + -p 16686:16686 -p 4318:4318 \ + jaegertracing/all-in-one:latest + +# Run LB with tracing +./zig-out/bin/load_balancer --otel-endpoint http://localhost:4318 ... + +# View traces +open http://localhost:16686 +``` + +## Verification Checklist + +- [x] Normal request shows complete span hierarchy +- [x] Backend selection shows lb.strategy, backend_count, healthy_count +- [x] DNS resolution span with hostname +- [x] TCP connect span with peer name/port +- [x] TLS handshake span with cipher and ALPN protocol +- [x] H2 handshake span for HTTP/2 connections +- [x] Response streaming span with body type and bytes transferred +- [x] Errors mark spans with error status +- [x] GOAWAY shows retry events in h2 span + +## Implementation Status + +**Completed:** +- Vendored OpenTelemetry SDK from zig-o11y +- OTLP/HTTP exporter to Jaeger (port 4318) +- `--otel-endpoint` CLI flag +- Full span hierarchy for HTTP/1.1 and HTTP/2 paths +- Connection phase spans (DNS, TCP, TLS, H2 handshake) +- Response streaming spans +- Proper parent-child span relationships +- Span attributes following OpenTelemetry semantic conventions + +## Performance Characteristics + +**Non-blocking Architecture:** +- Uses `BatchingProcessor` with background export thread +- Request path: quick span clone + queue append (~microseconds) +- Background thread: batched HTTP export every 5s or 512 spans +- Request handling never blocks on OTLP HTTP export + +**SDK Fixes Applied:** +- Added `Span.clone()` for deep copying (attributes, events, links) +- `BatchingProcessor.onEnd()` deep clones spans to take ownership +- Prevents use-after-free when original span is cleaned up +- Mutex unlocked during HTTP export for concurrency + +**Configuration:** +```zig +BatchingProcessor.init(allocator, exporter, .{ + .max_queue_size = 2048, // Max spans in queue + .scheduled_delay_millis = 5000, // Export every 5 seconds + .max_export_batch_size = 512, // Or when 512 spans accumulated +}); +``` diff --git a/docs/plans/2025-12-26-waf-design.md b/docs/plans/2025-12-26-waf-design.md new file mode 100644 index 0000000..76130c5 --- /dev/null +++ b/docs/plans/2025-12-26-waf-design.md @@ -0,0 +1,502 @@ +# ZZZ WAF Design + +**Date:** 2025-12-26 +**Status:** Approved +**Scope:** Integrated WAF for zzz load balancer + +## Overview + +A high-performance Web Application Firewall integrated into the zzz load balancer, focusing on API protection and DDoS/abuse mitigation. Designed with TigerBeetle-style patterns for zero-allocation hot paths and lock-free shared state. + +### Goals + +- **API Protection:** Rate limiting, authentication abuse prevention, JSON depth attacks +- **DDoS Mitigation:** Slowloris, request flooding, volume-based attacks +- **Production-Ready:** Shadow mode, hot-reload config, full observability +- **TigerBeetle-Style:** Fixed-size structures, lock-free atomics, bounded everything + +### Non-Goals (Future Work) + +- OWASP pattern matching (SQLi, XSS) - requires regex engine +- Multi-node state sync (Redis) - can layer on later +- Bot detection / CAPTCHA integration + +--- + +## Architecture + +``` +Request → Router → [WAF Layer] → Proxy Handler → Backend + ↓ + ┌────────┴────────┐ + │ WAF Engine │ + ├─────────────────┤ + │ • RateLimiter │ ← Token bucket per IP/path + │ • SlowlorisGuard│ ← Connection timing + │ • RequestValidator│ ← Size limits, JSON depth + │ • DecisionLogger│ ← Events + traces + └────────┬────────┘ + ↓ + ┌────────┴────────┐ + │ Shared State │ (mmap'd region) + ├─────────────────┤ + │ • BucketTable │ ← Fixed-size hash table + │ • ConnTracker │ ← Per-IP connection counts + │ • Metrics │ ← Atomic counters + └─────────────────┘ +``` + +### Design Principles + +- **Zero allocation on hot path** - All structures pre-allocated at startup +- **Lock-free reads** - Atomic operations only, no mutexes during request handling +- **Fail-open option** - If WAF state is corrupted, configurable to allow or deny +- **Comptime rule compilation** - Rule matching logic generated at compile time where possible + +--- + +## Shared State Design (TigerBeetle-Style) + +### Main State Structure + +```zig +/// Cache-line aligned for atomic access +pub const WafState = extern struct { + // Magic + version for corruption detection + magic: u64 = 0xWAF_STATE_V1, + + // Token bucket table - fixed size, open addressing + buckets: [MAX_BUCKETS]Bucket align(64), + + // Connection tracking for slowloris + conn_tracker: ConnTracker align(64), + + // Global metrics - atomic counters + metrics: WafMetrics align(64), + + // Configuration epoch - for hot-reload detection + config_epoch: u64 align(64), + + comptime { + std.debug.assert(@sizeOf(WafState) == WAF_STATE_SIZE); + } +}; +``` + +### Token Bucket Entry + +```zig +pub const Bucket = extern struct { + // Key: hash of (IP, path_pattern) + key_hash: u64, + + // Token bucket state (packed for atomic CAS) + tokens: u32, // Current tokens (scaled by 1000) + last_update: u32, // Timestamp in seconds + + // Stats for this bucket + total_requests: u64, + total_blocked: u64, + + comptime { + std.debug.assert(@sizeOf(Bucket) == 64); // One cache line + } +}; +``` + +### Constants + +```zig +const MAX_BUCKETS = 65536; // 64K entries = 4MB state +const MAX_TOKENS = 10000; // Scaled for precision +const BUCKET_PROBE_LIMIT = 16; // Open addressing probe limit +const MAX_CAS_ATTEMPTS = 8; // Retry limit for atomic ops +const MAX_TRACKED_IPS = 16384; // For connection counting +``` + +--- + +## Rate Limiter (Token Bucket) + +Lock-free token bucket with atomic compare-and-swap. O(1) per request. + +```zig +pub const RateLimiter = struct { + state: *WafState, + + pub fn check(self: *RateLimiter, key: Key, rule: *const Rule) Decision { + const bucket_idx = self.findBucket(key); + const bucket = &self.state.buckets[bucket_idx]; + + const now_sec: u32 = @truncate(std.time.timestamp()); + + var attempts: u32 = 0; + while (attempts < MAX_CAS_ATTEMPTS) : (attempts += 1) { + const old = @atomicLoad(u64, &bucket.packed, .acquire); + const old_tokens = unpackTokens(old); + const old_time = unpackTime(old); + + // Refill tokens based on elapsed time + const elapsed = now_sec -% old_time; + const refill = @min(elapsed * rule.tokens_per_sec, rule.burst_capacity); + const available = @min(old_tokens + refill, rule.burst_capacity); + + if (available < rule.cost_per_request) { + return .{ .action = .block, .reason = .rate_limit_exceeded }; + } + + const new_tokens = available - rule.cost_per_request; + const new = packState(new_tokens, now_sec); + + if (@cmpxchgWeak(u64, &bucket.packed, old, new, .release, .monotonic)) |_| { + continue; + } + + return .{ .action = .allow }; + } + + // Fail-open under extreme contention + @atomicAdd(&self.state.metrics.cas_exhausted, 1, .monotonic); + return .{ .action = .allow, .reason = .cas_exhausted }; + } +}; +``` + +### Key Properties + +- No mutex, no blocking +- Timestamps wrap-safe using wrapping subtraction +- Bounded CAS retries (fail-open under extreme contention) +- Token precision: scaled by 1000 for sub-integer rates + +--- + +## Slowloris & Connection Abuse Detection + +Per-connection state tracking with timeout enforcement. + +```zig +pub const SlowlorisGuard = struct { + pub const ConnState = struct { + first_byte_time: u64, + headers_complete_time: u64, + bytes_received: u32, + last_activity: u64, + }; + + pub const Config = struct { + header_timeout_ms: u32 = 5_000, + body_timeout_ms: u32 = 30_000, + min_bytes_per_sec: u32 = 100, + max_conns_per_ip: u16 = 100, + }; + + pub fn onDataReceived(self: *SlowlorisGuard, conn: *ConnState, bytes: u32) Decision { + const now = std.time.milliTimestamp(); + conn.bytes_received += bytes; + conn.last_activity = now; + + // Header timeout check + if (conn.headers_complete_time == 0) { + if (now - conn.first_byte_time > self.config.header_timeout_ms) { + return .{ .action = .block, .reason = .header_timeout }; + } + } + + // Transfer rate check (after initial burst window) + const elapsed_sec = (now - conn.first_byte_time) / 1000; + if (elapsed_sec > 2) { + const rate = conn.bytes_received / elapsed_sec; + if (rate < self.config.min_bytes_per_sec) { + return .{ .action = .block, .reason = .slow_transfer }; + } + } + + return .{ .action = .allow }; + } +}; +``` + +### Connection Tracking (Shared Memory) + +```zig +pub const ConnTracker = extern struct { + entries: [MAX_TRACKED_IPS]ConnEntry align(64), +}; + +pub const ConnEntry = extern struct { + ip_hash: u32, + conn_count: u16, // Atomic increment/decrement + _padding: u16, +}; +``` + +--- + +## API Protection (Request Validation) + +Streaming validation with constant memory usage. + +```zig +pub const RequestValidator = struct { + pub const Config = struct { + max_uri_length: u16 = 2048, + max_query_params: u8 = 50, + max_header_value_length: u16 = 8192, + max_cookie_size: u16 = 4096, + max_body_size: u32 = 1_048_576, + max_json_depth: u8 = 20, + max_json_keys: u16 = 1000, + endpoint_overrides: []const EndpointConfig, + }; + + /// Fast pre-body validation + pub fn validateHeaders(self: *RequestValidator, req: *const Request) Decision { + if ((req.uri orelse "").len > self.config.max_uri_length) { + return .{ .action = .block, .reason = .uri_too_long }; + } + + if (req.getHeader("content-length")) |cl| { + const len = std.fmt.parseInt(u32, cl, 10) catch 0; + if (len > self.config.max_body_size) { + return .{ .action = .block, .reason = .body_too_large }; + } + } + + if (req.getHeader("cookie")) |cookie| { + if (cookie.len > self.config.max_cookie_size) { + return .{ .action = .block, .reason = .cookie_too_large }; + } + } + + return .{ .action = .allow }; + } + + /// Streaming JSON validation (constant memory) + pub fn validateJsonStream(self: *RequestValidator, chunk: []const u8, state: *JsonState) Decision { + for (chunk) |byte| { + switch (byte) { + '{', '[' => { + state.depth += 1; + if (state.depth > self.config.max_json_depth) { + return .{ .action = .block, .reason = .json_too_deep }; + } + }, + '}', ']' => state.depth -|= 1, + ':' => { + state.key_count += 1; + if (state.key_count > self.config.max_json_keys) { + return .{ .action = .block, .reason = .json_too_many_keys }; + } + }, + else => {}, + } + } + return .{ .action = .allow }; + } +}; +``` + +--- + +## Configuration + +### Example `waf.json` + +```json +{ + "enabled": true, + "shadow_mode": false, + + "rate_limits": [ + { + "name": "login_bruteforce", + "path": "/api/auth/login", + "method": "POST", + "limit": { "requests": 10, "period_sec": 60 }, + "burst": 3, + "by": "ip", + "action": "block" + }, + { + "name": "api_global", + "path": "/api/*", + "limit": { "requests": 1000, "period_sec": 60 }, + "burst": 100, + "by": "ip", + "action": "block" + } + ], + + "slowloris": { + "header_timeout_ms": 5000, + "body_timeout_ms": 30000, + "min_bytes_per_sec": 100, + "max_conns_per_ip": 50 + }, + + "request_limits": { + "max_uri_length": 2048, + "max_body_size": 1048576, + "max_json_depth": 20, + "endpoints": [ + { "path": "/api/upload", "max_body_size": 10485760 } + ] + }, + + "trusted_proxies": ["10.0.0.0/8", "172.16.0.0/12"], + + "logging": { + "log_blocked": true, + "log_allowed": false, + "log_near_limit": true, + "near_limit_threshold": 0.8 + } +} +``` + +### Hot-Reload Behavior + +- Config file watched using existing `config_watcher.zig` pattern +- Atomic epoch increment signals workers to re-read +- Rate limit buckets NOT cleared on reload (existing limits preserved) +- Only rule definitions change + +### CLI Integration + +``` +./load_balancer --backend 127.0.0.1:8080 --waf waf.json --waf-shadow +``` + +--- + +## Observability + +### Metrics (Prometheus format via `/metrics`) + +```zig +pub const WafMetrics = extern struct { + requests_allowed: u64 align(64), + requests_blocked: u64 align(64), + requests_logged: u64 align(64), + + blocked_rate_limit: u64, + blocked_slowloris: u64, + blocked_body_too_large: u64, + blocked_json_depth: u64, + + bucket_table_usage: u64, + cas_exhausted: u64, + config_reloads: u64, +}; +``` + +### OpenTelemetry Spans + +Child span of `proxy_request`: + +```zig +fn createWafSpan(parent: Span, decision: Decision, rule: ?*const Rule) Span { + var span = parent.child("waf_check"); + span.setAttribute("waf.decision", @tagName(decision.action)); + span.setAttribute("waf.shadow_mode", config.shadow_mode); + + if (decision.action != .allow) { + span.setAttribute("waf.reason", @tagName(decision.reason)); + if (rule) |r| span.setAttribute("waf.rule", r.name); + } + + return span; +} +``` + +### Structured Event Log + +JSON events on interesting events (blocks, near-limit warnings): + +```zig +pub const WafEvent = struct { + timestamp: i64, + event_type: enum { blocked, near_limit, config_reload }, + client_ip: []const u8, + method: []const u8, + path: []const u8, + rule_name: ?[]const u8, + reason: ?[]const u8, + tokens_remaining: ?u32, +}; +``` + +--- + +## Integration Points + +### 1. Router Layer (main.zig) + +```zig +var router = try Router.init(allocator, &.{ + Route.init("/metrics").get({}, metrics.metricsHandler).layer(), + Route.init("/").all(waf_ctx, wafMiddleware).layer(), + Route.init("/%r").all(handler_ctx, generateHandler(strategy)).layer(), +}, .{}); +``` + +### 2. Shared Memory Region + +Extend existing `shared_region.zig`: + +```zig +pub const SharedRegion = struct { + health_state: *SharedHealthState, + waf_state: *WafState, + backend_config: *BackendConfig, +}; +``` + +### 3. CLI Arguments + +``` +--waf Path to WAF config JSON +--waf-shadow Force shadow mode (log only) +--waf-disabled Disable WAF entirely +``` + +--- + +## File Structure + +``` +src/ +├── waf/ +│ ├── mod.zig # Public API +│ ├── engine.zig # Main WAF engine +│ ├── rate_limiter.zig # Token bucket implementation +│ ├── slowloris.zig # Connection abuse detection +│ ├── validator.zig # Request validation +│ ├── state.zig # Shared memory structures +│ ├── config.zig # JSON config parsing +│ └── events.zig # Structured logging +``` + +--- + +## Testing Strategy + +1. **Unit tests** - Each module in isolation (rate_limiter, validator, etc.) +2. **Integration tests** - Add `tests/suites/waf.zig` using existing harness +3. **Load tests** - Verify zero-allocation claims under traffic +4. **Fuzz tests** - JSON parser, config parser edge cases + +--- + +## Implementation Order + +1. `src/waf/state.zig` - Shared memory structures +2. `src/waf/rate_limiter.zig` - Token bucket core +3. `src/waf/config.zig` - JSON parsing +4. `src/waf/engine.zig` - Main orchestration +5. `main.zig` integration - CLI + middleware +6. `src/waf/slowloris.zig` - Connection tracking +7. `src/waf/validator.zig` - Request validation +8. `src/waf/events.zig` - Observability +9. Tests + documentation diff --git a/integration_test b/integration_test new file mode 100755 index 0000000..7a38eae Binary files /dev/null and b/integration_test differ diff --git a/main.zig b/main.zig index f563bd2..19ce07e 100644 --- a/main.zig +++ b/main.zig @@ -16,6 +16,9 @@ const log = std.log.scoped(.lb); const zzz = @import("zzz"); const http = zzz.HTTP; +const telemetry = @import("src/telemetry/mod.zig"); +const waf = @import("src/waf/mod.zig"); + const Io = std.Io; const Server = http.Server; const Router = http.Router; @@ -36,6 +39,75 @@ const SharedHealthState = shared_region.SharedHealthState; /// Runtime log level (can be changed via --loglevel) var runtime_log_level: std.log.Level = .info; +/// Global WAF state (shared across all workers via mmap) +/// This is pointer-stable because it's placed in mmap'd shared memory before fork +var global_waf_state: ?*waf.WafState = null; + +/// Global WAF config (loaded from JSON, immutable after init) +var global_waf_config: waf.WafConfig = .{}; + +/// WAF allocator (for config parsing) +var waf_allocator: ?std.mem.Allocator = null; + +/// Get the global WAF engine for request checking +/// Returns null if WAF is disabled or not configured +pub fn getWafEngine() ?waf.WafEngine { + if (!global_waf_config.enabled) return null; + if (global_waf_state) |state| { + return waf.WafEngine.init(state, &global_waf_config); + } + return null; +} + +/// Check if WAF is enabled and configured +pub fn isWafEnabled() bool { + return global_waf_config.enabled and global_waf_state != null; +} + +/// Get WAF config (for read-only access) +pub fn getWafConfig() *const waf.WafConfig { + return &global_waf_config; +} + +/// WAF stats reporter interval (seconds) +const WAF_STATS_INTERVAL_SEC: u64 = 10; + +/// Start WAF stats reporter thread +fn startWafStatsThread() !std.Thread { + return std.Thread.spawn(.{}, wafStatsLoop, .{}); +} + +/// WAF stats reporter loop - runs in background thread +fn wafStatsLoop() void { + const waf_log = std.log.scoped(.waf_stats); + + while (true) { + std.posix.nanosleep(WAF_STATS_INTERVAL_SEC, 0); + + if (global_waf_state) |state| { + const stats = state.metrics.snapshot(); + const total = stats.totalRequests(); + + if (total > 0) { + waf_log.info( + "WAF Stats: total={d} allowed={d} blocked={d} logged={d} block_rate={d}% | by_reason: rate_limit={d} slowloris={d} body={d} json={d}", + .{ + total, + stats.requests_allowed, + stats.requests_blocked, + stats.requests_logged, + stats.blockRatePercent(), + stats.blocked_rate_limit, + stats.blocked_slowloris, + stats.blocked_body_too_large, + stats.blocked_json_depth, + }, + ); + } + } + } +} + pub const std_options: std.Options = .{ .log_level = .debug, // Compile-time max level (allows all) .logFn = runtimeLogFn, // Custom log function respects runtime level @@ -120,6 +192,11 @@ const Config = struct { insecure_tls: bool = false, // Skip TLS verification (for testing only) trace: bool = false, // Enable hex/ASCII payload tracing tls_trace: bool = false, // Enable detailed TLS handshake tracing + otel_endpoint: ?[]const u8 = null, // OTLP endpoint for OpenTelemetry tracing + // WAF configuration + waf_config_path: ?[]const u8 = null, // Path to WAF JSON config + waf_shadow_mode: bool = false, // Force shadow mode (log only, don't block) + waf_disabled: bool = false, // Disable WAF entirely }; // ============================================================================ @@ -146,7 +223,11 @@ fn printUsage() void { \\ -k, --insecure Skip TLS certificate verification (testing only) \\ -t, --trace Dump raw request/response payloads (hex + ASCII) \\ --tls-trace Show detailed TLS handshake info (cipher, version, CA) + \\ --otel-endpoint OTLP endpoint for OpenTelemetry tracing (e.g. localhost:4318) \\ --upgrade-fd Inherit socket fd for binary hot reload (internal) + \\ --waf Path to WAF configuration JSON file + \\ --waf-shadow Force WAF shadow mode (log only, don't block) + \\ --waf-disabled Disable WAF entirely \\ --help Show this help \\ \\HOT RELOAD: @@ -157,6 +238,7 @@ fn printUsage() void { \\ load_balancer --mode mp --port 8080 --backend 127.0.0.1:9001 \\ load_balancer -m sp -p 8080 -b 127.0.0.1:9001 -b 127.0.0.1:9002 \\ load_balancer -m mp -c backends.json # Hot reload on file change + \\ load_balancer -m sp --waf waf.json # Enable WAF with config \\ , .{}); } @@ -175,6 +257,10 @@ fn parseArgs(allocator: std.mem.Allocator) !Config { var insecure_tls: bool = false; var trace: bool = false; var tls_trace: bool = false; + var otel_endpoint: ?[]const u8 = null; + var waf_config_path: ?[]const u8 = null; + var waf_shadow_mode: bool = false; + var waf_disabled: bool = false; var backend_list: std.ArrayListUnmanaged(BackendDef) = .empty; errdefer backend_list.deinit(allocator); @@ -273,11 +359,25 @@ fn parseArgs(allocator: std.mem.Allocator) !Config { trace = true; } else if (std.mem.eql(u8, arg, "--tls-trace")) { tls_trace = true; + } else if (std.mem.eql(u8, arg, "--otel-endpoint")) { + if (i + 1 < args.len) { + otel_endpoint = try allocator.dupe(u8, args[i + 1]); + i += 1; + } } else if (std.mem.eql(u8, arg, "--upgrade-fd")) { if (i + 1 < args.len) { upgrade_fd = try std.fmt.parseInt(posix.fd_t, args[i + 1], 10); i += 1; } + } else if (std.mem.eql(u8, arg, "--waf")) { + if (i + 1 < args.len) { + waf_config_path = try allocator.dupe(u8, args[i + 1]); + i += 1; + } + } else if (std.mem.eql(u8, arg, "--waf-shadow")) { + waf_shadow_mode = true; + } else if (std.mem.eql(u8, arg, "--waf-disabled")) { + waf_disabled = true; } } @@ -296,6 +396,22 @@ fn parseArgs(allocator: std.mem.Allocator) !Config { std.debug.print("TLS-TRACE: Detailed TLS handshake info enabled\n", .{}); } + // Notify if OpenTelemetry tracing is enabled + if (otel_endpoint) |endpoint| { + std.debug.print("OTEL: OpenTelemetry tracing enabled, endpoint: {s}\n", .{endpoint}); + } + + // Notify about WAF configuration + if (waf_disabled) { + std.debug.print("WAF: Disabled via --waf-disabled\n", .{}); + } else if (waf_config_path) |path| { + if (waf_shadow_mode) { + std.debug.print("WAF: Shadow mode enabled (log only), config: {s}\n", .{path}); + } else { + std.debug.print("WAF: Enabled, config: {s}\n", .{path}); + } + } + // Use default mode if not specified const final_mode = mode orelse RunMode.default(); @@ -338,6 +454,10 @@ fn parseArgs(allocator: std.mem.Allocator) !Config { .insecure_tls = insecure_tls, .trace = trace, .tls_trace = tls_trace, + .otel_endpoint = otel_endpoint, + .waf_config_path = waf_config_path, + .waf_shadow_mode = waf_shadow_mode, + .waf_disabled = waf_disabled, .lbConfig = .{ .worker_count = worker_count, .port = port, @@ -354,6 +474,16 @@ fn freeConfig(allocator: std.mem.Allocator, config: Config) void { allocator.free(backend.host); } allocator.free(config.lbConfig.backends); + + // Free the otel_endpoint if it was allocated + if (config.otel_endpoint) |endpoint| { + allocator.free(endpoint); + } + + // Free the WAF config path if it was allocated + if (config.waf_config_path) |path| { + allocator.free(path); + } } // ============================================================================ @@ -381,6 +511,42 @@ pub fn main() !void { // Set runtime TLS trace mode config_mod.setTlsTraceEnabled(config.tls_trace); + // Initialize OpenTelemetry tracing if endpoint is provided + if (config.otel_endpoint) |endpoint| { + telemetry.init(allocator, endpoint) catch |err| { + log.err("Failed to initialize telemetry: {s}", .{@errorName(err)}); + }; + } + defer telemetry.deinit(); + + // Initialize WAF if configured + if (!config.waf_disabled) { + if (config.waf_config_path) |path| { + waf_allocator = allocator; + global_waf_config = waf.WafConfig.loadFromFile(allocator, path) catch |err| { + log.err("Failed to load WAF config from '{s}': {s}", .{ path, @errorName(err) }); + return err; + }; + // Apply shadow mode override from CLI + if (config.waf_shadow_mode) { + global_waf_config.shadow_mode = true; + } + log.info("WAF config loaded: enabled={}, shadow_mode={}, rules={d}", .{ + global_waf_config.enabled, + global_waf_config.shadow_mode, + global_waf_config.rate_limits.len, + }); + } + } else { + // WAF explicitly disabled + global_waf_config.enabled = false; + } + defer { + if (waf_allocator != null) { + global_waf_config.deinit(); + } + } + // Validate configuration try config.lbConfig.validate(); @@ -416,6 +582,21 @@ fn runMultiProcess(allocator: std.mem.Allocator, config: Config) !void { const region = try shared_allocator.init(); defer shared_allocator.deinit(); + // Initialize WAF state on heap (too large for stack: ~4MB) + // For multi-process, this should ideally be in shared memory (mmap) + // but for now we allocate on heap (each worker gets a copy after fork) + const waf_state_ptr = try allocator.create(waf.WafState); + waf_state_ptr.* = waf.WafState.init(); + global_waf_state = waf_state_ptr; + defer if (!global_waf_config.enabled) allocator.destroy(waf_state_ptr); + if (global_waf_config.enabled) { + log.info("WAF state initialized ({d} bytes)", .{@sizeOf(waf.WafState)}); + // Start WAF stats reporter thread + _ = startWafStatsThread() catch |err| { + log.warn("Failed to start WAF stats thread: {s}", .{@errorName(err)}); + }; + } + // Initialize backends in shared region initSharedBackends(region, mutable_lb_config.backends); @@ -669,6 +850,9 @@ fn workerMain( defer mp_server.deinit(); try mp_server.serve(io, &router, &socket); + + // Clean up H2 connection pool on shutdown + h2_pool_new.deinit(io); } // ============================================================================ @@ -725,6 +909,19 @@ fn runSingleProcess(parent_allocator: std.mem.Allocator, config: Config) !void { const lb_config = config.lbConfig; + // Initialize WAF state on heap (too large for stack: ~4MB) + const waf_state_ptr = try allocator.create(waf.WafState); + waf_state_ptr.* = waf.WafState.init(); + global_waf_state = waf_state_ptr; + defer allocator.destroy(waf_state_ptr); + if (global_waf_config.enabled) { + log.info("WAF state initialized ({d} bytes)", .{@sizeOf(waf.WafState)}); + // Start WAF stats reporter thread + _ = startWafStatsThread() catch |err| { + log.warn("Failed to start WAF stats thread: {s}", .{@errorName(err)}); + }; + } + for (lb_config.backends, 0..) |b, idx| { log.info(" Backend {d}: {s}:{d}", .{ idx + 1, b.host, b.port }); } @@ -817,6 +1014,9 @@ fn runSingleProcess(parent_allocator: std.mem.Allocator, config: Config) !void { defer sp_server.deinit(); try sp_server.serve(io, &router, &socket); + + // Clean up H2 connection pool on shutdown + h2_pool_sp.deinit(io); } fn setupSignalHandlers() void { diff --git a/src/http/http2/pool.zig b/src/http/http2/pool.zig index 5ee77f7..9170ee8 100644 --- a/src/http/http2/pool.zig +++ b/src/http/http2/pool.zig @@ -19,8 +19,9 @@ const UltraSock = ultra_sock_mod.UltraSock; const Protocol = ultra_sock_mod.Protocol; const TlsOptions = ultra_sock_mod.TlsOptions; const BackendServer = @import("../../core/types.zig").BackendServer; +const telemetry = @import("../../telemetry/mod.zig"); -const MAX_CONNECTIONS_PER_BACKEND: usize = 16; +const MAX_CONNECTIONS_PER_BACKEND: usize = 32; const MAX_BACKENDS: usize = 64; /// Idle timeout for connections (30 seconds in nanoseconds) @@ -77,7 +78,7 @@ pub const H2ConnectionPool = struct { /// 3. If found and healthy, return it (stays available for more requests) /// 4. If not found, find empty slot and create fresh /// 5. Unlock mutex on return - pub fn getOrCreate(self: *Self, backend_idx: u32, io: Io) !*H2Connection { + pub fn getOrCreate(self: *Self, backend_idx: u32, io: Io, trace_span: ?*telemetry.Span) !*H2Connection { std.debug.assert(backend_idx < self.backends.len); // Lock per-backend mutex to prevent concurrent creation race @@ -109,7 +110,7 @@ pub const H2ConnectionPool = struct { // Phase 2: Create new connection in empty slot for (&self.slot_state[backend_idx], &self.slots[backend_idx], 0..) |*state, *slot, i| { if (state.* == .empty) { - const conn = try self.createFreshConnection(backend_idx, io); + const conn = try self.createFreshConnection(backend_idx, io, trace_span); conn.ref_count.store(1, .release); // First user slot.* = conn; state.* = .available; @@ -118,8 +119,8 @@ pub const H2ConnectionPool = struct { } } - // All slots full - handler will retry on TooManyStreams - log.warn("Connection pool exhausted: backend={d}", .{backend_idx}); + // All slots full - this is an error condition (causes 503s under load) + log.err("H2 pool exhausted: backend={d} (32 conns x 8 streams = 256 max concurrent)", .{backend_idx}); return error.PoolExhausted; } @@ -175,7 +176,7 @@ pub const H2ConnectionPool = struct { } /// Create fresh connection with TLS and HTTP/2 handshake - fn createFreshConnection(self: *Self, backend_idx: u32, io: Io) !*H2Connection { + fn createFreshConnection(self: *Self, backend_idx: u32, io: Io, trace_span: ?*telemetry.Span) !*H2Connection { std.debug.assert(backend_idx < self.backends.len); const backend = &self.backends[backend_idx]; @@ -186,7 +187,9 @@ pub const H2ConnectionPool = struct { // Create socket from backend server const protocol: Protocol = if (backend.isHttps()) .https else .http; const tls_options = TlsOptions.fromRuntimeWithHttp2(); - const sock = UltraSock.initWithTls(protocol, backend.getHost(), backend.port, tls_options); + var sock = UltraSock.initWithTls(protocol, backend.getHost(), backend.port, tls_options); + // Set trace span for detailed connection phase tracing (DNS, TCP, TLS) + sock.trace_span = trace_span; // Initialize H2Connection with per-connection buffers conn.* = try H2Connection.init(sock, backend_idx, self.allocator); @@ -205,7 +208,13 @@ pub const H2ConnectionPool = struct { conn.sock.enableKeepalive() catch {}; // Perform HTTP/2 handshake (send preface + SETTINGS) + var h2_handshake_span = if (trace_span) |parent| + telemetry.startChildSpan(parent, "h2_handshake", .Internal) + else + telemetry.Span{ .inner = null, .tracer = null, .allocator = undefined }; + defer h2_handshake_span.end(); try conn.connect(io); + h2_handshake_span.setOk(); log.debug("Fresh H2 connection established: backend={d}", .{backend_idx}); return conn; diff --git a/src/http/ultra_sock.zig b/src/http/ultra_sock.zig index dd9a862..6a5248d 100644 --- a/src/http/ultra_sock.zig +++ b/src/http/ultra_sock.zig @@ -10,6 +10,7 @@ const tls_mod = @import("tls.zig"); const tls = tls_mod.tls_lib; const config_mod = @import("../core/config.zig"); +const telemetry = @import("../telemetry/mod.zig"); // Re-export TlsOptions for external use pub const TlsOptions = tls_mod.TlsOptions; @@ -46,6 +47,9 @@ pub const UltraSock = struct { // Cached read timeout for restoration after temporary changes read_timeout_ms: u32 = 1000, // Default 1 second + // Optional parent span for tracing connection phases + trace_span: ?*telemetry.Span = null, + /// Get the current read timeout in milliseconds pub fn getReadTimeout(self: *const UltraSock) u32 { return self.read_timeout_ms; @@ -97,10 +101,19 @@ pub const UltraSock = struct { // Resolve address - first try as IP, then DNS resolution const addr = Io.net.IpAddress.parse(self.host, self.port) catch blk: { // Not a raw IP, try DNS resolution using getaddrinfo + // DNS resolution span + var dns_span = if (self.trace_span) |parent| + telemetry.startChildSpan(parent, "dns_resolution", .Internal) + else + telemetry.Span{ .inner = null, .tracer = null, .allocator = undefined }; + defer dns_span.end(); + dns_span.setStringAttribute("dns.hostname", self.host); + if (trace_enabled) { tls_log.debug(" Resolving DNS for {s}...", .{self.host}); } const resolved = resolveDns(self.host, self.port) catch { + dns_span.setBoolAttribute("dns.success", false); if (trace_enabled) { tls_log.err("!!! DNS resolution failed for {s}", .{self.host}); } else { @@ -108,22 +121,33 @@ pub const UltraSock = struct { } return error.InvalidAddress; }; + dns_span.setBoolAttribute("dns.success", true); if (trace_enabled) { tls_log.debug(" DNS resolved {s}", .{self.host}); } break :blk resolved; }; - // Connect using std.Io + // TCP connect span + var tcp_span = if (self.trace_span) |parent| + telemetry.startChildSpan(parent, "tcp_connect", .Internal) + else + telemetry.Span{ .inner = null, .tracer = null, .allocator = undefined }; + defer tcp_span.end(); + tcp_span.setStringAttribute("net.peer.name", self.host); + tcp_span.setIntAttribute("net.peer.port", @intCast(self.port)); + if (trace_enabled) { tls_log.debug(" TCP connecting to {s}:{}...", .{ self.host, self.port }); } self.stream = addr.connect(io, .{ .mode = .stream }) catch { + tcp_span.setBoolAttribute("tcp.success", false); if (trace_enabled) { tls_log.err("!!! TCP connect failed to {s}:{}", .{ self.host, self.port }); } return error.ConnectionFailed; }; + tcp_span.setBoolAttribute("tcp.success", true); errdefer self.closeStream(); if (trace_enabled) { @@ -162,10 +186,19 @@ pub const UltraSock = struct { } const addr = Io.net.IpAddress.parse(self.host, self.port) catch blk: { + // DNS resolution span + var dns_span = if (self.trace_span) |parent| + telemetry.startChildSpan(parent, "dns_resolution", .Internal) + else + telemetry.Span{ .inner = null, .tracer = null, .allocator = undefined }; + defer dns_span.end(); + dns_span.setStringAttribute("dns.hostname", self.host); + if (trace_enabled) { tls_log.debug(" Resolving DNS for {s}...", .{self.host}); } const resolved = resolveDns(self.host, self.port) catch { + dns_span.setBoolAttribute("dns.success", false); if (trace_enabled) { tls_log.err("!!! DNS resolution failed for {s}", .{self.host}); } else { @@ -173,21 +206,33 @@ pub const UltraSock = struct { } return error.InvalidAddress; }; + dns_span.setBoolAttribute("dns.success", true); if (trace_enabled) { tls_log.debug(" DNS resolved {s}", .{self.host}); } break :blk resolved; }; + // TCP connect span + var tcp_span = if (self.trace_span) |parent| + telemetry.startChildSpan(parent, "tcp_connect", .Internal) + else + telemetry.Span{ .inner = null, .tracer = null, .allocator = undefined }; + defer tcp_span.end(); + tcp_span.setStringAttribute("net.peer.name", self.host); + tcp_span.setIntAttribute("net.peer.port", @intCast(self.port)); + if (trace_enabled) { tls_log.debug(" TCP connecting to {s}:{}...", .{ self.host, self.port }); } self.stream = addr.connect(io, .{ .mode = .stream }) catch { + tcp_span.setBoolAttribute("tcp.success", false); if (trace_enabled) { tls_log.err("!!! TCP connect failed to {s}:{}", .{ self.host, self.port }); } return error.ConnectionFailed; }; + tcp_span.setBoolAttribute("tcp.success", true); errdefer self.closeStream(); if (trace_enabled) { @@ -301,6 +346,14 @@ pub const UltraSock = struct { ) !void { const stream = self.stream orelse return error.SocketNotInitialized; + // TLS handshake span + var tls_span = if (self.trace_span) |parent| + telemetry.startChildSpan(parent, "tls_handshake", .Internal) + else + telemetry.Span{ .inner = null, .tracer = null, .allocator = undefined }; + defer tls_span.end(); + tls_span.setStringAttribute("tls.server_name", self.host); + try tls_mod.ensureCaBundleLoaded(io, self.tls_options.ca); self.stream_reader = stream.reader(io, input_buf); @@ -342,6 +395,8 @@ pub const UltraSock = struct { &self.stream_writer.interface, client_opts, ) catch |err| { + tls_span.setBoolAttribute("tls.success", false); + tls_span.setStringAttribute("tls.error", @errorName(err)); if (trace_enabled) { tls_log.err("!!! TLS handshake FAILED: {} for {s}:{}", .{ err, self.host, self.port }); } else { @@ -349,16 +404,20 @@ pub const UltraSock = struct { } return error.TlsHandshakeFailed; }; + tls_span.setBoolAttribute("tls.success", true); // Convert ALPN to copy-safe enum (TigerStyle: no dangling pointers after copy) if (self.tls_diagnostic.negotiated_alpn) |alpn| { self.negotiated_protocol = if (std.mem.eql(u8, alpn, "h2")) .http2 else .http1_1; + tls_span.setStringAttribute("tls.alpn", alpn); } else { self.negotiated_protocol = .http1_1; + tls_span.setStringAttribute("tls.alpn", "http/1.1"); } // Log TLS connection details const cipher_name = @tagName(self.tls_conn.?.cipher); + tls_span.setStringAttribute("tls.cipher", cipher_name); if (trace_enabled) { // Detailed TLS trace logging from Diagnostic struct diff --git a/src/main.zig b/src/main.zig new file mode 100644 index 0000000..624eef5 --- /dev/null +++ b/src/main.zig @@ -0,0 +1,27 @@ +const std = @import("std"); +const zzz_fix = @import("zzz_fix"); + +pub fn main() !void { + // Prints to stderr, ignoring potential errors. + std.debug.print("All your {s} are belong to us.\n", .{"codebase"}); + try zzz_fix.bufferedPrint(); +} + +test "simple test" { + const gpa = std.testing.allocator; + var list: std.ArrayList(i32) = .empty; + defer list.deinit(gpa); // Try commenting this out and see if zig detects the memory leak! + try list.append(gpa, 42); + try std.testing.expectEqual(@as(i32, 42), list.pop()); +} + +test "fuzz example" { + const Context = struct { + fn testOne(context: @This(), input: []const u8) anyerror!void { + _ = context; + // Try passing `--fuzz` to `zig build test` and see if it manages to fail this test case! + try std.testing.expect(!std.mem.eql(u8, "canyoufindme", input)); + } + }; + try std.testing.fuzz(Context{}, Context.testOne, .{}); +} diff --git a/src/proxy/connection.zig b/src/proxy/connection.zig index b643275..1642aa3 100644 --- a/src/proxy/connection.zig +++ b/src/proxy/connection.zig @@ -20,6 +20,7 @@ const ultra_sock_mod = @import("../http/ultra_sock.zig"); const UltraSock = ultra_sock_mod.UltraSock; const metrics = @import("../metrics/mod.zig"); const WorkerState = @import("../lb/worker.zig").WorkerState; +const telemetry = @import("../telemetry/mod.zig"); // Re-export ProxyState for convenience pub const ProxyState = @import("handler.zig").ProxyState; @@ -41,6 +42,7 @@ pub fn acquireConnection( backend_idx: u32, state: *WorkerState, req_id: u32, + trace_span: ?*telemetry.Span, ) ProxyError!ProxyState { // Prevent bitmap overflow in circuit breaker health tracking. std.debug.assert(backend_idx < MAX_BACKENDS); @@ -73,6 +75,7 @@ pub fn acquireConnection( .body_had_error = false, .client_write_error = false, .backend_wants_close = false, + .trace_span = null, // Set by handler after acquisition. }; proxy_state.tls_conn_ptr = proxy_state.sock.getTlsConnection(); return proxy_state; @@ -85,6 +88,8 @@ pub fn acquireConnection( // Create fresh connection var sock = UltraSock.fromBackendServerWithHttp2(backend); + // Set trace span for detailed connection phase tracing (DNS, TCP, TLS) + sock.trace_span = trace_span; sock.connect(ctx.io) catch { sock.close_blocking(); return ProxyError.BackendUnavailable; @@ -119,6 +124,7 @@ pub fn acquireConnection( .body_had_error = false, .client_write_error = false, .backend_wants_close = false, + .trace_span = null, // Set by handler after acquisition. }; proxy_state.tls_conn_ptr = proxy_state.sock.getTlsConnection(); return proxy_state; diff --git a/src/proxy/handler.zig b/src/proxy/handler.zig index 3421675..539bb2f 100644 --- a/src/proxy/handler.zig +++ b/src/proxy/handler.zig @@ -17,6 +17,10 @@ const log = std.log.scoped(.mp); const zzz = @import("zzz"); const http = zzz.HTTP; +const telemetry = @import("../telemetry/mod.zig"); +const main = @import("../../main.zig"); +const waf = @import("../waf/mod.zig"); + const config = @import("../core/config.zig"); const types = @import("../core/types.zig"); const ultra_sock_mod = @import("../http/ultra_sock.zig"); @@ -53,12 +57,30 @@ pub const ProxyError = error{ ReadFailed, Timeout, EmptyResponse, + PoolExhausted, // Local resource issue, not a backend failure InvalidResponse, /// HTTP/2 GOAWAY exhausted retries - NOT a backend health failure /// Server gracefully closed connection, just need fresh connection GoawayRetriesExhausted, }; +// ============================================================================ +// IP Address Formatting +// ============================================================================ + +/// Format an IPv4 address (u32 in host byte order) as a string +/// Returns a static buffer - only valid until next call +var ip_format_buf: [16]u8 = undefined; +fn formatIpv4(ip: u32) []const u8 { + const len = std.fmt.bufPrint(&ip_format_buf, "{d}.{d}.{d}.{d}", .{ + @as(u8, @truncate(ip >> 24)), + @as(u8, @truncate(ip >> 16)), + @as(u8, @truncate(ip >> 8)), + @as(u8, @truncate(ip)), + }) catch return "0.0.0.0"; + return ip_format_buf[0..len.len]; +} + // ============================================================================ // Connection State (TigerStyle: explicit struct for clarity) // ============================================================================ @@ -77,6 +99,8 @@ pub const ProxyState = struct { body_had_error: bool, client_write_error: bool, backend_wants_close: bool, + /// Parent span for tracing (optional, for creating child spans) + trace_span: ?*telemetry.Span, /// TigerStyle: assertion for valid state (called at multiple points). pub fn assertValid(self: *const ProxyState) void { @@ -118,6 +142,94 @@ pub fn generateHandler( } } + // Start a trace span for this request (before WAF so blocked requests are traced) + const method = ctx.request.method orelse .GET; + const uri = ctx.request.uri orelse "/"; + var span = telemetry.startServerSpan("proxy_request"); + defer span.end(); + span.setStringAttribute("http.method", @tagName(method)); + span.setStringAttribute("http.url", uri); + + // WAF check - before any backend processing + if (main.getWafEngine()) |engine_val| { + // Make a mutable copy since check() requires *WafEngine + var engine = engine_val; + + // Extract source IP from connection (use 0 if not available) + // In production, you'd extract this from the socket address + const source_ip: u32 = 0; // TODO: Extract from ctx.connection when available + + // Build WAF request with body size for content-length validation + const waf_method = convertHttpMethod(ctx.request.method orelse .GET); + const body_len: ?usize = if (ctx.request.body) |b| b.len else null; + const waf_request = if (body_len) |len| + waf.Request.withContentLength( + waf_method, + ctx.request.uri orelse "/", + source_ip, + len, + ) + else + waf.Request.init( + waf_method, + ctx.request.uri orelse "/", + source_ip, + ); + + // Check WAF rules (with tracing - creates child spans for each step) + const waf_result = engine.checkWithSpan(&waf_request, &span, telemetry); + + // Add WAF summary attributes to parent span + const client_ip_str = formatIpv4(source_ip); + span.setStringAttribute("waf.decision", if (waf_result.isBlocked()) "block" else if (waf_result.shouldLog()) "log_only" else "allow"); + span.setStringAttribute("waf.client_ip", client_ip_str); + if (waf_result.reason != .none) { + span.setStringAttribute("waf.reason", waf_result.reason.description()); + } + if (waf_result.rule_name) |rule| { + span.setStringAttribute("waf.rule", rule); + } + + if (waf_result.isBlocked()) { + // Request blocked by WAF + log.warn("[W{d}] WAF BLOCKED: ip={s} uri={s} reason={s} rule={s}", .{ + state.worker_id, + client_ip_str, + ctx.request.uri orelse "/", + waf_result.reason.description(), + waf_result.rule_name orelse "N/A", + }); + + // Return appropriate status based on block reason + const status: http.Status = switch (waf_result.reason) { + .rate_limit => .@"Too Many Requests", + .body_too_large => .@"Content Too Large", + else => .Forbidden, + }; + + span.setIntAttribute("http.status_code", @intFromEnum(status)); + span.setError("WAF blocked"); + + return ctx.response.apply(.{ + .status = status, + .mime = http.Mime.JSON, + .body = "{\"error\":\"blocked by WAF\"}", + }); + } + + // Log if shadow mode decision was made + if (waf_result.shouldLog()) { + log.info("[W{d}] WAF SHADOW: ip={s} uri={s} reason={s} rule={s}", .{ + state.worker_id, + client_ip_str, + ctx.request.uri orelse "/", + waf_result.reason.description(), + waf_result.rule_name orelse "N/A", + }); + span.addEvent("waf_shadow_block"); + } + } + // Use dynamic backend count (from shared region if available) const backend_count = state.getBackendCount(); std.debug.assert(backend_count <= MAX_BACKENDS); @@ -129,6 +241,8 @@ pub fn generateHandler( }); if (backend_count == 0) { + span.setError("No backends configured"); + span.setIntAttribute("http.status_code", 503); return ctx.response.apply(.{ .status = .@"Service Unavailable", .mime = http.Mime.TEXT, @@ -136,7 +250,17 @@ pub fn generateHandler( }); } + // Backend selection with tracing + var selection_span = telemetry.startChildSpan(&span, "backend_selection", .Internal); + selection_span.setIntAttribute("lb.backend_count", @intCast(backend_count)); + selection_span.setIntAttribute("lb.healthy_count", @intCast(state.circuit_breaker.countHealthy())); + selection_span.setStringAttribute("lb.strategy", @tagName(strategy)); + const backend_idx = state.selectBackend(strategy) orelse { + selection_span.setError("No healthy backends"); + selection_span.end(); + span.setError("No backends available"); + span.setIntAttribute("http.status_code", 503); log.warn("[W{d}] selectBackend returned null", .{state.worker_id}); return ctx.response.apply(.{ .status = .@"Service Unavailable", @@ -145,12 +269,26 @@ pub fn generateHandler( }); }; + selection_span.setIntAttribute("lb.selected_backend", @intCast(backend_idx)); + selection_span.setOk(); + selection_span.end(); + // Prevent out-of-bounds access to backends array and bitmap. std.debug.assert(backend_idx < backend_count); std.debug.assert(backend_idx < MAX_BACKENDS); log.debug("[W{d}] Selected backend {d}", .{ state.worker_id, backend_idx }); - return proxyWithFailover(ctx, @intCast(backend_idx), state); + const result = proxyWithFailover(ctx, @intCast(backend_idx), state, &span); + + // Update parent span with final status + if (result) |response| { + span.setOk(); + return response; + } else |err| { + span.setError(@errorName(err)); + span.setIntAttribute("http.status_code", 503); + return err; + } } }.handle; } @@ -160,6 +298,7 @@ inline fn proxyWithFailover( ctx: *const http.Context, primary_idx: u32, state: *WorkerState, + trace_span: *telemetry.Span, ) !http.Respond { // Prevent out-of-bounds access to backends array and bitmap. std.debug.assert(primary_idx < MAX_BACKENDS); @@ -168,27 +307,31 @@ inline fn proxyWithFailover( // Try to get backend from shared region (hot reload) or fall back to local if (state.getSharedBackend(primary_idx)) |shared_backend| { - if (streamingProxyShared(ctx, shared_backend, primary_idx, state)) |response| { + if (streamingProxyShared(ctx, shared_backend, primary_idx, state, trace_span)) |response| { state.recordSuccess(primary_idx); + trace_span.setIntAttribute("http.status_code", 200); return response; } else |err| { - // GOAWAY exhaustion is NOT a backend failure - just connection-level flow control - if (err != ProxyError.GoawayRetriesExhausted) { + // GOAWAY/PoolExhausted are NOT backend failures - just connection-level issues + if (err != ProxyError.GoawayRetriesExhausted and err != ProxyError.PoolExhausted) { state.recordFailure(primary_idx); } + trace_span.addEvent("primary_backend_failed"); log.warn("[W{d}] Backend {d} failed: {s}", .{ state.worker_id, primary_idx + 1, @errorName(err) }); } } else { // Fall back to local backends list const backends = state.backends; - if (streamingProxy(ctx, &backends.items[primary_idx], primary_idx, state)) |response| { + if (streamingProxy(ctx, &backends.items[primary_idx], primary_idx, state, trace_span)) |response| { state.recordSuccess(primary_idx); + trace_span.setIntAttribute("http.status_code", 200); return response; } else |err| { - // GOAWAY exhaustion is NOT a backend failure - just connection-level flow control - if (err != ProxyError.GoawayRetriesExhausted) { + // GOAWAY/PoolExhausted are NOT backend failures - just connection-level issues + if (err != ProxyError.GoawayRetriesExhausted and err != ProxyError.PoolExhausted) { state.recordFailure(primary_idx); } + trace_span.addEvent("primary_backend_failed"); log.warn("[W{d}] Backend {d} failed: {s}", .{ state.worker_id, primary_idx + 1, @errorName(err) }); } } @@ -200,16 +343,19 @@ inline fn proxyWithFailover( log.debug("[W{d}] Failing over to backend {d}", .{ state.worker_id, failover_idx + 1 }); metrics.global_metrics.recordFailover(); + trace_span.addEvent("failover_started"); + trace_span.setIntAttribute("lb.failover_backend", @intCast(failover_idx)); const failover_u32: u32 = @intCast(failover_idx); if (state.getSharedBackend(failover_idx)) |shared_backend| { - if (streamingProxyShared(ctx, shared_backend, failover_u32, state)) |response| { + if (streamingProxyShared(ctx, shared_backend, failover_u32, state, trace_span)) |response| { state.recordSuccess(failover_idx); + trace_span.setIntAttribute("http.status_code", 200); return response; } else |failover_err| { - // GOAWAY exhaustion is NOT a backend failure - if (failover_err != ProxyError.GoawayRetriesExhausted) { + // GOAWAY/PoolExhausted are NOT backend failures + if (failover_err != ProxyError.GoawayRetriesExhausted and failover_err != ProxyError.PoolExhausted) { state.recordFailure(failover_idx); } const err_name = @errorName(failover_err); @@ -222,12 +368,13 @@ inline fn proxyWithFailover( } else { const backends = state.backends; const backend = &backends.items[failover_idx]; - if (streamingProxy(ctx, backend, failover_u32, state)) |response| { + if (streamingProxy(ctx, backend, failover_u32, state, trace_span)) |response| { state.recordSuccess(failover_idx); + trace_span.setIntAttribute("http.status_code", 200); return response; } else |failover_err| { - // GOAWAY exhaustion is NOT a backend failure - if (failover_err != ProxyError.GoawayRetriesExhausted) { + // GOAWAY/PoolExhausted are NOT backend failures + if (failover_err != ProxyError.GoawayRetriesExhausted and failover_err != ProxyError.PoolExhausted) { state.recordFailure(failover_idx); } const err_name = @errorName(failover_err); @@ -240,6 +387,10 @@ inline fn proxyWithFailover( } } + // All backends exhausted - this IS an error (user gets 503) + log.err("[W{d}] All backends exhausted, returning 503", .{state.worker_id}); + trace_span.setError("All backends exhausted"); + return ctx.response.apply(.{ .status = .@"Service Unavailable", .mime = http.Mime.TEXT, @@ -272,6 +423,7 @@ inline fn streamingProxy( backend: *const types.BackendServer, backend_idx: u32, state: *WorkerState, + trace_span: *telemetry.Span, ) ProxyError!http.Respond { // Prevent bitmap overflow in circuit breaker health tracking. std.debug.assert(backend_idx < MAX_BACKENDS); @@ -302,11 +454,17 @@ inline fn streamingProxy( .body_had_error = false, .client_write_error = false, .backend_wants_close = false, + .trace_span = trace_span, }; return streamingProxyHttp2(ctx, backend, &proxy_state, state, backend_idx, start_ns, req_id); } - // Phase 1: Acquire connection (HTTP/1.1 path only now). + // Phase 1: Acquire connection (HTTP/1.1 path only now) with tracing. + var conn_span = telemetry.startChildSpan(trace_span, "backend_connection", .Client); + conn_span.setStringAttribute("backend.host", backend.getHost()); + conn_span.setIntAttribute("backend.port", @intCast(backend.port)); + conn_span.setBoolAttribute("backend.tls", backend.isHttps()); + var proxy_state = proxy_connection.acquireConnection( types.BackendServer, ctx, @@ -314,14 +472,22 @@ inline fn streamingProxy( backend_idx, state, req_id, + &conn_span, ) catch |err| { + conn_span.setError(@errorName(err)); + conn_span.end(); return err; }; + conn_span.setBoolAttribute("connection.from_pool", proxy_state.from_pool); + conn_span.setOk(); + conn_span.end(); + // Fix pointers after struct copy - TLS connection and stream reader/writer // have internal pointers that become dangling when UltraSock is copied by value. proxy_state.sock.fixAllPointersAfterCopy(ctx.io); proxy_state.tls_conn_ptr = proxy_state.sock.getTlsConnection(); + proxy_state.trace_span = trace_span; // TigerStyle: validate state after acquisition. proxy_state.assertValid(); @@ -331,6 +497,11 @@ inline fn streamingProxy( return streamingProxyHttp2(ctx, backend, &proxy_state, state, backend_idx, start_ns, req_id); } + // Phase 2-5: Backend request/response with tracing + var request_span = telemetry.startChildSpan(trace_span, "backend_request", .Client); + request_span.setStringAttribute("http.method", @tagName(ctx.request.method orelse .GET)); + request_span.setStringAttribute("http.url", ctx.request.uri orelse "/"); + // Phase 2: Send request (HTTP/1.1 path). proxy_request.sendRequest( types.BackendServer, @@ -339,10 +510,14 @@ inline fn streamingProxy( &proxy_state, req_id, ) catch |err| { + request_span.setError(@errorName(err)); + request_span.end(); proxy_state.sock.close_blocking(); return err; }; + request_span.addEvent("request_sent"); + // Phase 3: Read and parse headers. // Safe undefined: buffer fully written by backend read before parsing. var header_buffer: [MAX_HEADER_BYTES]u8 = undefined; @@ -356,10 +531,15 @@ inline fn streamingProxy( &header_end, req_id, ) catch |err| { + request_span.setError(@errorName(err)); + request_span.end(); proxy_state.sock.close_blocking(); return err; }; + request_span.addEvent("headers_received"); + request_span.setIntAttribute("http.status_code", @intCast(proxy_state.status_code)); + // HTTP response must have headers, validate parse succeeded before forwarding. std.debug.assert(header_end > 0); std.debug.assert(header_end <= header_len); @@ -375,13 +555,34 @@ inline fn streamingProxy( msg_len, req_id, ) catch |err| { + request_span.setError(@errorName(err)); + request_span.end(); proxy_state.sock.close_blocking(); return err; }; - // Phase 5: Stream body. + // Phase 5: Stream body with tracing. + var body_span = telemetry.startChildSpan(&request_span, "response_streaming", .Internal); + body_span.setStringAttribute("http.body.type", @tagName(msg_len.type)); + if (msg_len.type == .content_length) { + body_span.setIntAttribute("http.body.expected_length", @intCast(msg_len.length)); + } + proxy_io.streamBody(ctx, &proxy_state, header_end, header_len, msg_len, req_id); + body_span.setIntAttribute("http.body.bytes_transferred", @intCast(proxy_state.bytes_from_backend)); + body_span.setBoolAttribute("http.body.had_error", proxy_state.body_had_error); + if (proxy_state.body_had_error) { + body_span.setError("body_transfer_error"); + } else { + body_span.setOk(); + } + body_span.end(); + + request_span.setIntAttribute("http.response_content_length", @intCast(proxy_state.bytes_from_backend)); + request_span.setOk(); + request_span.end(); + // Phase 6: Finalize and return connection. return streamingProxy_finalize(ctx, &proxy_state, state, backend_idx, start_ns, req_id); } @@ -392,6 +593,7 @@ inline fn streamingProxyShared( backend: *const shared_region.SharedBackend, backend_idx: u32, state: *WorkerState, + trace_span: *telemetry.Span, ) ProxyError!http.Respond { // Prevent bitmap overflow in circuit breaker health tracking. std.debug.assert(backend_idx < MAX_BACKENDS); @@ -424,11 +626,17 @@ inline fn streamingProxyShared( .body_had_error = false, .client_write_error = false, .backend_wants_close = false, + .trace_span = trace_span, }; return streamingProxyHttp2(ctx, backend, &proxy_state, state, backend_idx, start_ns, req_id); } - // Phase 1: Acquire connection (HTTP/1.1 path only now). + // Phase 1: Acquire connection (HTTP/1.1 path only now) with tracing. + var conn_span = telemetry.startChildSpan(trace_span, "backend_connection", .Client); + conn_span.setStringAttribute("backend.host", backend.getHost()); + conn_span.setIntAttribute("backend.port", @intCast(backend.port)); + conn_span.setBoolAttribute("backend.tls", backend.isHttps()); + var proxy_state = proxy_connection.acquireConnection( shared_region.SharedBackend, ctx, @@ -436,14 +644,22 @@ inline fn streamingProxyShared( backend_idx, state, req_id, + &conn_span, ) catch |err| { + conn_span.setError(@errorName(err)); + conn_span.end(); return err; }; + conn_span.setBoolAttribute("connection.from_pool", proxy_state.from_pool); + conn_span.setOk(); + conn_span.end(); + // Fix pointers after struct copy - TLS connection and stream reader/writer // have internal pointers that become dangling when UltraSock is copied by value. proxy_state.sock.fixAllPointersAfterCopy(ctx.io); proxy_state.tls_conn_ptr = proxy_state.sock.getTlsConnection(); + proxy_state.trace_span = trace_span; proxy_state.assertValid(); @@ -452,6 +668,11 @@ inline fn streamingProxyShared( return streamingProxyHttp2(ctx, backend, &proxy_state, state, backend_idx, start_ns, req_id); } + // Phase 2-5: Backend request/response with tracing + var request_span = telemetry.startChildSpan(trace_span, "backend_request", .Client); + request_span.setStringAttribute("http.method", @tagName(ctx.request.method orelse .GET)); + request_span.setStringAttribute("http.url", ctx.request.uri orelse "/"); + // Phase 2: Send request (HTTP/1.1 path). proxy_request.sendRequest( shared_region.SharedBackend, @@ -460,10 +681,14 @@ inline fn streamingProxyShared( &proxy_state, req_id, ) catch |err| { + request_span.setError(@errorName(err)); + request_span.end(); proxy_state.sock.close_blocking(); return err; }; + request_span.addEvent("request_sent"); + // Phase 3-6: Same as regular streamingProxy (backend-agnostic) var header_buffer: [MAX_HEADER_BYTES]u8 = undefined; var header_len: u32 = 0; @@ -476,10 +701,15 @@ inline fn streamingProxyShared( &header_end, req_id, ) catch |err| { + request_span.setError(@errorName(err)); + request_span.end(); proxy_state.sock.close_blocking(); return err; }; + request_span.addEvent("headers_received"); + request_span.setIntAttribute("http.status_code", @intCast(proxy_state.status_code)); + std.debug.assert(header_end > 0); std.debug.assert(header_end <= header_len); @@ -493,12 +723,34 @@ inline fn streamingProxyShared( msg_len, req_id, ) catch |err| { + request_span.setError(@errorName(err)); + request_span.end(); proxy_state.sock.close_blocking(); return err; }; + // Stream body with tracing + var body_span = telemetry.startChildSpan(&request_span, "response_streaming", .Internal); + body_span.setStringAttribute("http.body.type", @tagName(msg_len.type)); + if (msg_len.type == .content_length) { + body_span.setIntAttribute("http.body.expected_length", @intCast(msg_len.length)); + } + proxy_io.streamBody(ctx, &proxy_state, header_end, header_len, msg_len, req_id); + body_span.setIntAttribute("http.body.bytes_transferred", @intCast(proxy_state.bytes_from_backend)); + body_span.setBoolAttribute("http.body.had_error", proxy_state.body_had_error); + if (proxy_state.body_had_error) { + body_span.setError("body_transfer_error"); + } else { + body_span.setOk(); + } + body_span.end(); + + request_span.setIntAttribute("http.response_content_length", @intCast(proxy_state.bytes_from_backend)); + request_span.setOk(); + request_span.end(); + return streamingProxy_finalize(ctx, &proxy_state, state, backend_idx, start_ns, req_id); } @@ -518,9 +770,16 @@ fn forwardH2Response( backend_idx: u32, start_ns: ?std.time.Instant, req_id: u32, + parent_span: *telemetry.Span, ) ProxyError!http.Respond { const body = response.getBody(); + // Create response streaming span + var body_span = telemetry.startChildSpan(parent_span, "response_streaming", .Internal); + defer body_span.end(); + body_span.setStringAttribute("http.body.type", "h2_buffered"); + body_span.setIntAttribute("http.body.length", @intCast(body.len)); + // Update proxy state with response info proxy_state.status_code = response.status; proxy_state.bytes_from_backend = @intCast(body.len); @@ -573,6 +832,10 @@ fn forwardH2Response( metrics.global_metrics.recordRequest(duration_ms, proxy_state.status_code); + // Update span with result + body_span.setIntAttribute("http.body.bytes_written", @intCast(proxy_state.bytes_to_client)); + body_span.setBoolAttribute("http.body.had_error", proxy_state.client_write_error); + // Return response type if (proxy_state.client_write_error) { log.debug("[REQ {d}] => .close (client write error)", .{req_id}); @@ -597,8 +860,23 @@ fn streamingProxyHttp2( ) ProxyError!http.Respond { log.debug("[REQ {d}] HTTP/2 request to backend {d}", .{ req_id, backend_idx }); + // Start HTTP/2 request span if we have a trace context + var h2_span = if (proxy_state.trace_span) |trace_span| + telemetry.startChildSpan(trace_span, "backend_request_h2", .Client) + else + telemetry.Span{ .inner = null, .tracer = null, .allocator = undefined }; + defer h2_span.end(); + + h2_span.setStringAttribute("http.method", @tagName(ctx.request.method orelse .GET)); + h2_span.setStringAttribute("http.url", ctx.request.uri orelse "/"); + h2_span.setStringAttribute("http.flavor", "2.0"); + h2_span.setStringAttribute("backend.host", backend.getFullHost()); + // Get pool (must exist) - const pool = state.h2_pool orelse return ProxyError.ConnectionFailed; + const pool = state.h2_pool orelse { + h2_span.setError("No H2 pool"); + return ProxyError.ConnectionFailed; + }; // Retry loop for TooManyStreams (connection full, not broken) const h2_conn = @import("../http/http2/connection.zig"); @@ -617,11 +895,19 @@ fn streamingProxyHttp2( last_was_goaway = false; // Get or create connection (pool handles everything: TLS, handshake, retry) - const conn = pool.getOrCreate(backend_idx, ctx.io) catch |err| { - log.warn("[REQ {d}] H2 pool getOrCreate failed: {}", .{ req_id, err }); + const conn = pool.getOrCreate(backend_idx, ctx.io, &h2_span) catch |err| { + h2_span.setError("Pool getOrCreate failed"); + // PoolExhausted is a local resource issue, not a backend failure + // Pool already logs at error level, so just return the error + if (err == error.PoolExhausted) { + return ProxyError.PoolExhausted; + } + log.err("[REQ {d}] H2 pool getOrCreate failed: {}", .{ req_id, err }); return ProxyError.ConnectionFailed; }; + h2_span.addEvent("connection_acquired"); + // Make request (connection handles: send, reader spawn, await) response = conn.request( @tagName(ctx.request.method orelse .GET), @@ -633,6 +919,7 @@ fn streamingProxyHttp2( if (err == error.TooManyStreams) { // Connection full (all 8 slots busy) - release as healthy and retry log.debug("[REQ {d}] Connection full, retrying...", .{req_id}); + h2_span.addEvent("retry_too_many_streams"); pool.release(conn, true, ctx.io); continue; } @@ -640,12 +927,14 @@ fn streamingProxyHttp2( // GOAWAY received - get fresh connection and retry // NOT a failure - just graceful connection shutdown log.debug("[REQ {d}] GOAWAY, retrying on fresh connection", .{req_id}); + h2_span.addEvent("retry_goaway"); pool.release(conn, false, ctx.io); // Destroy this conn, but NOT a backend failure last_was_goaway = true; continue; } // Other errors - connection is broken log.warn("[REQ {d}] H2 request failed: {}", .{ req_id, err }); + h2_span.setError(@errorName(err)); pool.release(conn, false, ctx.io); return ProxyError.SendFailed; }; @@ -659,6 +948,7 @@ fn streamingProxyHttp2( // GOAWAY exhausted retries - NOT a backend health failure // Server is healthy, just aggressively closing connections under load log.debug("[REQ {d}] GOAWAY exhausted retries (not a failure)", .{req_id}); + h2_span.setError("GOAWAY retries exhausted"); return ProxyError.GoawayRetriesExhausted; } log.warn("[REQ {d}] H2 request failed after {d} retries", .{ req_id, attempts }); @@ -671,8 +961,14 @@ fn streamingProxyHttp2( proxy_state.status_code = response.status; proxy_state.bytes_from_backend = @intCast(response.body.items.len); + // Update span with response info + h2_span.setIntAttribute("http.status_code", @intCast(response.status)); + h2_span.setIntAttribute("http.response_content_length", @intCast(response.body.items.len)); + h2_span.setIntAttribute("h2.retry_count", @intCast(attempts)); + h2_span.setOk(); + // Forward response to client (connection released via defer above) - return forwardH2Response(ctx, proxy_state, &response, backend_idx, start_ns, req_id); + return forwardH2Response(ctx, proxy_state, &response, backend_idx, start_ns, req_id, &h2_span); } // ============================================================================ @@ -769,3 +1065,22 @@ fn streamingProxy_finalize( log.debug("[REQ {d}] => .responded", .{req_id}); return .responded; } + +// ============================================================================ +// WAF Helper Functions +// ============================================================================ + +/// Convert zzz HTTP method to WAF HTTP method +fn convertHttpMethod(method: http.Method) waf.HttpMethod { + return switch (method) { + .GET => .GET, + .POST => .POST, + .PUT => .PUT, + .DELETE => .DELETE, + .PATCH => .PATCH, + .HEAD => .HEAD, + .OPTIONS => .OPTIONS, + .TRACE => .TRACE, + .CONNECT => .CONNECT, + }; +} diff --git a/src/root.zig b/src/root.zig new file mode 100644 index 0000000..94c7cd0 --- /dev/null +++ b/src/root.zig @@ -0,0 +1,23 @@ +//! By convention, root.zig is the root source file when making a library. +const std = @import("std"); + +pub fn bufferedPrint() !void { + // Stdout is for the actual output of your application, for example if you + // are implementing gzip, then only the compressed bytes should be sent to + // stdout, not any debugging messages. + var stdout_buffer: [1024]u8 = undefined; + var stdout_writer = std.fs.File.stdout().writer(&stdout_buffer); + const stdout = &stdout_writer.interface; + + try stdout.print("Run `zig build test` to run the tests.\n", .{}); + + try stdout.flush(); // Don't forget to flush! +} + +pub fn add(a: i32, b: i32) i32 { + return a + b; +} + +test "basic add functionality" { + try std.testing.expect(add(3, 7) == 10); +} diff --git a/src/telemetry/mod.zig b/src/telemetry/mod.zig new file mode 100644 index 0000000..1c8f733 --- /dev/null +++ b/src/telemetry/mod.zig @@ -0,0 +1,313 @@ +//! OpenTelemetry Tracing for Load Balancer +//! +//! Provides distributed tracing for request lifecycle using the zig-o11y SDK. +//! +//! Usage: +//! // Initialize at startup with OTLP endpoint +//! try telemetry.init(allocator, "localhost:4318"); +//! defer telemetry.deinit(); +//! +//! // Create root span for request +//! var root_span = telemetry.startServerSpan("proxy_request"); +//! defer root_span.end(); +//! root_span.setStringAttribute("http.method", "GET"); +//! +//! // Create child spans for sub-operations +//! var child = telemetry.startChildSpan(&root_span, "backend_connection", .Client); +//! defer child.end(); +//! child.setStringAttribute("backend.host", "127.0.0.1"); + +const std = @import("std"); +const otel = @import("opentelemetry"); + +const log = std.log.scoped(.telemetry); + +/// Global telemetry state +var global_state: ?*TelemetryState = null; + +const TelemetryState = struct { + allocator: std.mem.Allocator, + config: *otel.otlp.ConfigOptions, + exporter: *otel.trace.OTLPExporter, + processor: *otel.trace.BatchingProcessor, // Background thread, non-blocking + provider: *otel.trace.TracerProvider, + tracer: *otel.api.trace.TracerImpl, + prng: *std.Random.DefaultPrng, // Keep PRNG alive on heap +}; + +/// Initialize the telemetry system with an OTLP endpoint. +/// Endpoint should be in "host:port" format (e.g., "localhost:4318"). +pub fn init(allocator: std.mem.Allocator, endpoint: []const u8) !void { + // TigerBeetle: validate inputs + std.debug.assert(endpoint.len > 0); + std.debug.assert(endpoint.len < 256); // Reasonable endpoint length + + if (global_state != null) { + log.warn("Telemetry already initialized", .{}); + return; + } + + const state = try allocator.create(TelemetryState); + errdefer allocator.destroy(state); + + // Create OTLP config + const config = try otel.otlp.ConfigOptions.init(allocator); + errdefer config.deinit(); + config.endpoint = endpoint; + config.scheme = .http; // Jaeger OTLP uses HTTP by default + config.protocol = .http_protobuf; // OTLP uses protobuf over HTTP + + // Create OTLP exporter with service name + const exporter = try otel.trace.OTLPExporter.initWithServiceName(allocator, config, "zzz-load-balancer"); + errdefer exporter.deinit(); + + // Create random ID generator with heap-allocated PRNG for persistent state + const nanos: i128 = otel.compat.nanoTimestamp(); + const seed: u64 = @intFromPtr(state) ^ @as(u64, @truncate(@intFromPtr(&allocator))) ^ @as(u64, @truncate(@as(u128, @bitCast(nanos)))); + const prng = try allocator.create(std.Random.DefaultPrng); + errdefer allocator.destroy(prng); + prng.* = std.Random.DefaultPrng.init(seed); + const random_gen = otel.trace.RandomIDGenerator.init(prng.random()); + const id_gen = otel.trace.IDGenerator{ .Random = random_gen }; + + // Create tracer provider + const provider = try otel.trace.TracerProvider.init(allocator, id_gen); + errdefer provider.shutdown(); + + // Create batching processor - exports spans in background thread (non-blocking) + // Config: batch up to 512 spans, export every 5 seconds or when batch full + const processor = try otel.trace.BatchingProcessor.init(allocator, exporter.asSpanExporter(), .{ + .max_queue_size = 2048, + .scheduled_delay_millis = 5000, // Export every 5 seconds + .max_export_batch_size = 512, // Or when 512 spans accumulated + }); + errdefer { + processor.asSpanProcessor().shutdown() catch {}; + processor.deinit(); + } + + // Add the processor to the provider + try provider.addSpanProcessor(processor.asSpanProcessor()); + state.processor = processor; + + // Get a tracer for the load balancer + const tracer = try provider.getTracer(.{ + .name = "zzz-load-balancer", + .version = "0.1.0", + }); + + // Set remaining state fields (processor already set above) + state.allocator = allocator; + state.config = config; + state.exporter = exporter; + state.provider = provider; + state.tracer = tracer; + state.prng = prng; + + global_state = state; + + log.info("Telemetry initialized with endpoint: {s}", .{endpoint}); +} + +/// Shutdown the telemetry system +/// Flushes pending spans and waits for background thread to complete. +pub fn deinit() void { + const state = global_state orelse return; + const allocator = state.allocator; + + // Shutdown provider (which shuts down processors and exports pending spans) + state.provider.shutdown(); + + // Shutdown and cleanup batching processor (waits for background thread) + state.processor.asSpanProcessor().shutdown() catch {}; + state.processor.deinit(); + + // Clean up config and exporter + state.exporter.deinit(); + state.config.deinit(); + + // Clean up PRNG + allocator.destroy(state.prng); + + allocator.destroy(state); + global_state = null; + + log.info("Telemetry shutdown complete", .{}); +} + +/// Check if telemetry is enabled +pub fn isEnabled() bool { + return global_state != null; +} + +/// Span kind for creating spans +pub const SpanKind = enum { + Server, + Client, + Internal, +}; + +/// Span wrapper for easier use with parent-child relationships +/// Optimized to avoid Context serialization allocations +pub const Span = struct { + inner: ?otel.api.trace.Span, + tracer: ?*otel.api.trace.TracerImpl, + allocator: std.mem.Allocator, + + const Self = @This(); + + /// Set a string attribute on the span + pub fn setStringAttribute(self: *Self, key: []const u8, value: []const u8) void { + if (self.inner) |*span| { + span.setAttribute(key, .{ .string = value }) catch {}; + } + } + + /// Set an integer attribute on the span + pub fn setIntAttribute(self: *Self, key: []const u8, value: i64) void { + if (self.inner) |*span| { + span.setAttribute(key, .{ .int = value }) catch {}; + } + } + + /// Set a boolean attribute on the span + pub fn setBoolAttribute(self: *Self, key: []const u8, value: bool) void { + if (self.inner) |*span| { + span.setAttribute(key, .{ .bool = value }) catch {}; + } + } + + /// Add an event to the span + pub fn addEvent(self: *Self, name: []const u8) void { + if (self.inner) |*span| { + span.addEvent(name, null, null) catch {}; + } + } + + /// Set the span status to error with a message + pub fn setError(self: *Self, message: []const u8) void { + if (self.inner) |*span| { + span.setStatus(.{ .code = .Error, .description = message }); + } + } + + /// Set the span status to OK + pub fn setOk(self: *Self) void { + if (self.inner) |*span| { + span.setStatus(.{ .code = .Ok, .description = "" }); + } + } + + /// Get this span's SpanContext directly (no allocation) + pub fn getSpanContext(self: *Self) ?otel.api.trace.SpanContext { + if (self.inner) |*span| { + return span.getContext(); + } + return null; + } + + /// End the span + pub fn end(self: *Self) void { + if (self.inner) |*span| { + if (self.tracer) |tracer| { + tracer.endSpan(span); + } + span.deinit(); + } + self.inner = null; + } +}; + +/// Start a new server span (for incoming requests) +pub fn startServerSpan(name: []const u8) Span { + const state = global_state orelse return Span{ .inner = null, .tracer = null, .allocator = undefined }; + + const span = state.tracer.startSpan(state.allocator, name, .{ + .kind = .Server, + }) catch |err| { + log.debug("Failed to start span: {}", .{err}); + return Span{ .inner = null, .tracer = null, .allocator = state.allocator }; + }; + + return Span{ + .inner = span, + .tracer = state.tracer, + .allocator = state.allocator, + }; +} + +/// Start a new client span (for outgoing requests to backends) +pub fn startClientSpan(name: []const u8) Span { + const state = global_state orelse return Span{ .inner = null, .tracer = null, .allocator = undefined }; + + const span = state.tracer.startSpan(state.allocator, name, .{ + .kind = .Client, + }) catch |err| { + log.debug("Failed to start span: {}", .{err}); + return Span{ .inner = null, .tracer = null, .allocator = state.allocator }; + }; + + return Span{ + .inner = span, + .tracer = state.tracer, + .allocator = state.allocator, + }; +} + +/// Start a new internal span +pub fn startInternalSpan(name: []const u8) Span { + const state = global_state orelse return Span{ .inner = null, .tracer = null, .allocator = undefined }; + + const span = state.tracer.startSpan(state.allocator, name, .{ + .kind = .Internal, + }) catch |err| { + log.debug("Failed to start span: {}", .{err}); + return Span{ .inner = null, .tracer = null, .allocator = state.allocator }; + }; + + return Span{ + .inner = span, + .tracer = state.tracer, + .allocator = state.allocator, + }; +} + +/// Start a child span with a parent span (fast path - no Context allocation) +pub fn startChildSpan(parent: *Span, name: []const u8, kind: SpanKind) Span { + // TigerBeetle: validate inputs + std.debug.assert(name.len > 0); + std.debug.assert(name.len < 128); // Reasonable span name length + + const state = global_state orelse return Span{ .inner = null, .tracer = null, .allocator = undefined }; + + // Get parent SpanContext directly (no allocation!) + const parent_span_ctx = parent.getSpanContext() orelse { + // If no parent context, create a standalone span + return switch (kind) { + .Server => startServerSpan(name), + .Client => startClientSpan(name), + .Internal => startInternalSpan(name), + }; + }; + + const otel_kind: otel.api.trace.SpanKind = switch (kind) { + .Server => .Server, + .Client => .Client, + .Internal => .Internal, + }; + + // Use parent_span_context (fast path) instead of parent_context (slow path) + const span = state.tracer.startSpan(state.allocator, name, .{ + .kind = otel_kind, + .parent_span_context = parent_span_ctx, + }) catch |err| { + log.debug("Failed to start child span: {}", .{err}); + return Span{ .inner = null, .tracer = null, .allocator = state.allocator }; + }; + + return Span{ + .inner = span, + .tracer = state.tracer, + .allocator = state.allocator, + }; +} diff --git a/src/test_load_balancer.zig b/src/test_load_balancer.zig index e876d92..bc67c4f 100644 --- a/src/test_load_balancer.zig +++ b/src/test_load_balancer.zig @@ -39,6 +39,9 @@ pub const simd_parse = @import("internal/simd_parse.zig"); // Core module tests pub const config = @import("core/config.zig"); +// WAF module tests +pub const waf_state = @import("waf/state.zig"); + // Config module tests pub const config_watcher = @import("config/config_watcher.zig"); @@ -69,6 +72,7 @@ comptime { _ = http2_client; _ = simd_parse; _ = config; + _ = waf_state; _ = config_watcher; _ = component_integration_test; _ = proxy_io; diff --git a/src/waf/config.zig b/src/waf/config.zig new file mode 100644 index 0000000..9368db4 --- /dev/null +++ b/src/waf/config.zig @@ -0,0 +1,1201 @@ +/// WAF Configuration - JSON Parsing and Validation +/// +/// Parses waf.json configuration for the Web Application Firewall. +/// Supports hot-reload through config epoch tracking. +/// +/// Design Philosophy (TigerBeetle-inspired): +/// - Fixed-size arrays with explicit compile-time bounds +/// - No unbounded allocations - all limits are known at compile time +/// - Validation before application - invalid configs are rejected early +/// - Hot-reload support via config epoch +/// +/// Example waf.json: +/// ```json +/// { +/// "enabled": true, +/// "shadow_mode": false, +/// "rate_limits": [ +/// { +/// "name": "login_bruteforce", +/// "path": "/api/auth/login", +/// "method": "POST", +/// "limit": { "requests": 10, "period_sec": 60 }, +/// "burst": 3, +/// "by": "ip", +/// "action": "block" +/// } +/// ], +/// "slowloris": { +/// "header_timeout_ms": 5000, +/// "body_timeout_ms": 30000, +/// "min_bytes_per_sec": 100, +/// "max_conns_per_ip": 50 +/// }, +/// "request_limits": { +/// "max_uri_length": 2048, +/// "max_body_size": 1048576, +/// "max_json_depth": 20, +/// "endpoints": [ +/// { "path": "/api/upload", "max_body_size": 10485760 } +/// ] +/// }, +/// "trusted_proxies": ["10.0.0.0/8", "172.16.0.0/12"], +/// "logging": { +/// "log_blocked": true, +/// "log_allowed": false, +/// "log_near_limit": true, +/// "near_limit_threshold": 0.8 +/// } +/// } +/// ``` +const std = @import("std"); +const Allocator = std.mem.Allocator; + +// ============================================================================= +// Constants (TigerBeetle-style: fixed sizes, explicit bounds) +// ============================================================================= + +/// Maximum rate limit rules in config +pub const MAX_RATE_LIMIT_RULES: usize = 64; + +/// Maximum endpoint-specific overrides +pub const MAX_ENDPOINT_OVERRIDES: usize = 32; + +/// Maximum trusted proxy CIDR ranges +pub const MAX_TRUSTED_PROXIES: usize = 16; + +/// Maximum path length for rules +pub const MAX_PATH_LENGTH: usize = 256; + +/// Maximum name length for rules +pub const MAX_NAME_LENGTH: usize = 64; + +/// Maximum header name length for rate limiting by header +pub const MAX_HEADER_NAME_LENGTH: usize = 64; + +/// Maximum config file size +pub const MAX_CONFIG_SIZE: usize = 64 * 1024; + +// ============================================================================= +// Enums +// ============================================================================= + +/// HTTP methods for rate limiting +pub const HttpMethod = enum { + GET, + POST, + PUT, + DELETE, + PATCH, + HEAD, + OPTIONS, + TRACE, + CONNECT, + + /// Parse HTTP method from string (case-insensitive) + pub fn parse(str: []const u8) ?HttpMethod { + const method_map = std.StaticStringMap(HttpMethod).initComptime(.{ + .{ "GET", .GET }, + .{ "POST", .POST }, + .{ "PUT", .PUT }, + .{ "DELETE", .DELETE }, + .{ "PATCH", .PATCH }, + .{ "HEAD", .HEAD }, + .{ "OPTIONS", .OPTIONS }, + .{ "TRACE", .TRACE }, + .{ "CONNECT", .CONNECT }, + // Lowercase variants + .{ "get", .GET }, + .{ "post", .POST }, + .{ "put", .PUT }, + .{ "delete", .DELETE }, + .{ "patch", .PATCH }, + .{ "head", .HEAD }, + .{ "options", .OPTIONS }, + .{ "trace", .TRACE }, + .{ "connect", .CONNECT }, + }); + return method_map.get(str); + } + + /// Convert to string representation + pub fn toString(self: HttpMethod) []const u8 { + return switch (self) { + .GET => "GET", + .POST => "POST", + .PUT => "PUT", + .DELETE => "DELETE", + .PATCH => "PATCH", + .HEAD => "HEAD", + .OPTIONS => "OPTIONS", + .TRACE => "TRACE", + .CONNECT => "CONNECT", + }; + } +}; + +/// What to rate limit by +pub const RateLimitBy = enum { + /// Rate limit by client IP address + ip, + /// Rate limit by specific header value (e.g., API key) + header, + /// Rate limit by request path + path, + + /// Parse from string + pub fn parse(str: []const u8) ?RateLimitBy { + const by_map = std.StaticStringMap(RateLimitBy).initComptime(.{ + .{ "ip", .ip }, + .{ "header", .header }, + .{ "path", .path }, + }); + return by_map.get(str); + } +}; + +/// Action to take when rule matches +pub const Action = enum { + /// Block the request immediately + block, + /// Log but allow the request (shadow mode) + log, + /// Slow down response (tarpit attackers) + tarpit, + + /// Parse from string + pub fn parse(str: []const u8) ?Action { + const action_map = std.StaticStringMap(Action).initComptime(.{ + .{ "block", .block }, + .{ "log", .log }, + .{ "tarpit", .tarpit }, + }); + return action_map.get(str); + } +}; + +// ============================================================================= +// CIDR Range (for trusted proxies) +// ============================================================================= + +/// IPv4 CIDR range for trusted proxy detection +pub const CidrRange = struct { + /// Network address (host byte order) + network: u32, + /// Netmask (host byte order, e.g., 0xFFFFFF00 for /24) + mask: u32, + + /// Parse CIDR notation (e.g., "10.0.0.0/8", "192.168.1.0/24") + pub fn parse(cidr: []const u8) !CidrRange { + // Find the slash separator + const slash_pos = std.mem.indexOf(u8, cidr, "/") orelse return error.InvalidCidr; + + const ip_part = cidr[0..slash_pos]; + const prefix_part = cidr[slash_pos + 1 ..]; + + // Parse IP address + const ip = try parseIpv4(ip_part); + + // Parse prefix length (0-32 valid for IPv4) + const prefix_len = std.fmt.parseInt(u8, prefix_part, 10) catch return error.InvalidCidrPrefix; + if (prefix_len > 32) return error.InvalidCidrPrefix; + + // Calculate mask from prefix length + const mask: u32 = if (prefix_len == 0) + 0 + else if (prefix_len == 32) + 0xFFFFFFFF + else + @as(u32, 0xFFFFFFFF) << @intCast(32 - @as(u6, @intCast(prefix_len))); + + // Validate that IP is actually the network address + if ((ip & mask) != ip) { + return error.IpNotNetworkAddress; + } + + return .{ + .network = ip, + .mask = mask, + }; + } + + /// Check if an IP address falls within this CIDR range + pub fn contains(self: CidrRange, ip: u32) bool { + return (ip & self.mask) == self.network; + } + + /// Format as CIDR string (for debugging) + pub fn format( + self: CidrRange, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, + ) !void { + _ = fmt; + _ = options; + const prefix_len = @popCount(self.mask); + try writer.print("{d}.{d}.{d}.{d}/{d}", .{ + @as(u8, @truncate(self.network >> 24)), + @as(u8, @truncate(self.network >> 16)), + @as(u8, @truncate(self.network >> 8)), + @as(u8, @truncate(self.network)), + prefix_len, + }); + } +}; + +/// Parse IPv4 address string to u32 (host byte order) +fn parseIpv4(ip_str: []const u8) !u32 { + var octets: [4]u8 = undefined; + var octet_idx: usize = 0; + var current: u32 = 0; + + for (ip_str) |c| { + if (c == '.') { + if (octet_idx >= 3) return error.InvalidIpAddress; + if (current > 255) return error.InvalidIpAddress; + octets[octet_idx] = @intCast(current); + octet_idx += 1; + current = 0; + } else if (c >= '0' and c <= '9') { + current = current * 10 + (c - '0'); + if (current > 255) return error.InvalidIpAddress; + } else { + return error.InvalidIpAddress; + } + } + + if (octet_idx != 3) return error.InvalidIpAddress; + octets[3] = @intCast(current); + + return (@as(u32, octets[0]) << 24) | + (@as(u32, octets[1]) << 16) | + (@as(u32, octets[2]) << 8) | + @as(u32, octets[3]); +} + +/// Convert u32 IP to bytes (for display) +pub fn ipToBytes(ip: u32) [4]u8 { + return .{ + @truncate(ip >> 24), + @truncate(ip >> 16), + @truncate(ip >> 8), + @truncate(ip), + }; +} + +// ============================================================================= +// Sub-Configurations +// ============================================================================= + +/// Slowloris attack prevention configuration +pub const SlowlorisConfig = struct { + /// Timeout for receiving all headers (milliseconds) + header_timeout_ms: u32 = 5000, + /// Timeout for receiving request body (milliseconds) + body_timeout_ms: u32 = 30000, + /// Minimum bytes per second for body transfer + min_bytes_per_sec: u32 = 100, + /// Maximum concurrent connections per IP + max_conns_per_ip: u16 = 50, +}; + +/// Endpoint-specific body size override +pub const EndpointOverride = struct { + /// Path pattern (supports * wildcard at end) + path: []const u8, + /// Maximum body size for this endpoint + max_body_size: u32, +}; + +/// Request size and depth limits configuration +pub const RequestLimitsConfig = struct { + /// Maximum URI length in bytes + max_uri_length: u32 = 2048, + /// Maximum request body size in bytes (default 1MB) + max_body_size: u32 = 1048576, + /// Maximum JSON nesting depth + max_json_depth: u8 = 20, + /// Endpoint-specific overrides (e.g., larger limit for upload endpoints) + endpoints: []const EndpointOverride = &.{}, +}; + +/// Logging configuration +pub const LoggingConfig = struct { + /// Log blocked requests + log_blocked: bool = true, + /// Log allowed requests (verbose) + log_allowed: bool = false, + /// Log requests approaching rate limit + log_near_limit: bool = true, + /// Threshold for "near limit" (0.0-1.0, e.g., 0.8 = 80% of limit) + near_limit_threshold: f32 = 0.8, +}; + +// ============================================================================= +// Rate Limit Rule +// ============================================================================= + +/// Rate limit rule configuration +pub const RateLimitRule = struct { + /// Human-readable name for this rule + name: []const u8, + /// Path pattern to match (supports * wildcard at end) + path: []const u8, + /// HTTP method to match (null = all methods) + method: ?HttpMethod = null, + /// Requests allowed per period + requests: u32, + /// Period in seconds + period_sec: u32, + /// Burst allowance (extra requests allowed in short burst) + burst: u32, + /// What to rate limit by + by: RateLimitBy = .ip, + /// Header name when by=header + header_name: ?[]const u8 = null, + /// Action to take when limit exceeded + action: Action = .block, + + /// Check if this rule matches a given path and method + pub fn matches(self: *const RateLimitRule, request_path: []const u8, request_method: ?HttpMethod) bool { + // Check method match (null rule method means match all) + if (self.method) |rule_method| { + if (request_method) |req_method| { + if (rule_method != req_method) return false; + } else { + return false; + } + } + + // Check path match with wildcard support + return pathMatches(self.path, request_path); + } + + /// Calculate tokens per second for rate limiter + /// Scaled by 1000 for sub-token precision + pub fn tokensPerSec(self: *const RateLimitRule) u32 { + if (self.period_sec == 0) return 0; + return (self.requests * 1000) / self.period_sec; + } + + /// Calculate burst capacity for rate limiter + /// Scaled by 1000 for sub-token precision + pub fn burstCapacity(self: *const RateLimitRule) u32 { + return (self.requests + self.burst) * 1000; + } +}; + +/// Check if a path pattern matches a request path +/// Supports * wildcard at end of pattern +fn pathMatches(pattern: []const u8, path: []const u8) bool { + // Empty pattern matches nothing + if (pattern.len == 0) return false; + + // Check for wildcard at end + if (pattern[pattern.len - 1] == '*') { + const prefix = pattern[0 .. pattern.len - 1]; + return std.mem.startsWith(u8, path, prefix); + } + + // Exact match + return std.mem.eql(u8, pattern, path); +} + +// ============================================================================= +// Main WAF Configuration +// ============================================================================= + +/// Main WAF configuration structure +/// Parsed from waf.json with all settings for the firewall +pub const WafConfig = struct { + /// Master enable/disable switch + enabled: bool = true, + /// Shadow mode: log but don't block (for testing rules) + shadow_mode: bool = false, + /// Rate limiting rules + rate_limits: []const RateLimitRule = &.{}, + /// Slowloris attack prevention + slowloris: SlowlorisConfig = .{}, + /// Request size limits + request_limits: RequestLimitsConfig = .{}, + /// Trusted proxy CIDR ranges (for X-Forwarded-For) + trusted_proxies: []const CidrRange = &.{}, + /// Logging configuration + logging: LoggingConfig = .{}, + /// Burst detection: detect sudden velocity spikes + burst_detection_enabled: bool = true, + /// Burst threshold: current rate must exceed baseline * threshold to trigger + burst_threshold: u32 = 10, + /// Config epoch for hot-reload detection + epoch: u64 = 0, + + /// Internal: allocator used for parsing (needed for deinit) + _allocator: ?Allocator = null, + + // ========================================================================= + // Parsing + // ========================================================================= + + /// Parse WAF config from JSON string + pub fn parse(allocator: Allocator, json: []const u8) !WafConfig { + // Parse JSON + const parsed = std.json.parseFromSlice( + JsonWafConfig, + allocator, + json, + .{ .ignore_unknown_fields = true, .allocate = .alloc_always }, + ) catch { + return error.InvalidJson; + }; + defer parsed.deinit(); + + return try fromJson(allocator, parsed.value); + } + + /// Convert from JSON representation to WafConfig + fn fromJson(allocator: Allocator, json: JsonWafConfig) !WafConfig { + var config = WafConfig{ + .enabled = json.enabled, + .shadow_mode = json.shadow_mode, + ._allocator = allocator, + }; + errdefer config.deinit(); + + // Parse rate limits + if (json.rate_limits.len > MAX_RATE_LIMIT_RULES) { + return error.TooManyRateLimitRules; + } + if (json.rate_limits.len > 0) { + const rate_limits = try allocator.alloc(RateLimitRule, json.rate_limits.len); + // Initialize to empty for safe cleanup during errdefer + for (rate_limits) |*rl| { + rl.* = .{ + .name = &.{}, + .path = &.{}, + .requests = 0, + .period_sec = 1, + .burst = 0, + }; + } + config.rate_limits = rate_limits; + + for (json.rate_limits, 0..) |jrl, i| { + rate_limits[i] = try parseRateLimitRule(allocator, jrl); + } + } + + // Parse trusted proxies + if (json.trusted_proxies.len > MAX_TRUSTED_PROXIES) { + return error.TooManyTrustedProxies; + } + if (json.trusted_proxies.len > 0) { + const proxies = try allocator.alloc(CidrRange, json.trusted_proxies.len); + config.trusted_proxies = proxies; + + for (json.trusted_proxies, 0..) |cidr_str, i| { + proxies[i] = CidrRange.parse(cidr_str) catch { + return error.InvalidTrustedProxy; + }; + } + } + + // Parse endpoint overrides + if (json.request_limits) |jrl| { + config.request_limits = .{ + .max_uri_length = jrl.max_uri_length, + .max_body_size = jrl.max_body_size, + .max_json_depth = jrl.max_json_depth, + .endpoints = &.{}, + }; + + if (jrl.endpoints.len > 0) { + const endpoints = try allocator.alloc(EndpointOverride, jrl.endpoints.len); + // Initialize to empty for safe cleanup + for (endpoints) |*ep| { + ep.* = .{ .path = &.{}, .max_body_size = 0 }; + } + config.request_limits.endpoints = endpoints; + + for (jrl.endpoints, 0..) |je, i| { + const path = try allocator.dupe(u8, je.path); + endpoints[i] = .{ + .path = path, + .max_body_size = je.max_body_size, + }; + } + } + } + + // Parse slowloris config + if (json.slowloris) |js| { + config.slowloris = .{ + .header_timeout_ms = js.header_timeout_ms, + .body_timeout_ms = js.body_timeout_ms, + .min_bytes_per_sec = js.min_bytes_per_sec, + .max_conns_per_ip = js.max_conns_per_ip, + }; + } + + // Parse logging config + if (json.logging) |jl| { + config.logging = .{ + .log_blocked = jl.log_blocked, + .log_allowed = jl.log_allowed, + .log_near_limit = jl.log_near_limit, + .near_limit_threshold = jl.near_limit_threshold, + }; + } + + // Validate before returning + try config.validate(); + + return config; + } + + /// Free all allocated memory + pub fn deinit(self: *WafConfig) void { + const allocator = self._allocator orelse return; + + // Free rate limit rules - only free strings that were actually allocated + for (self.rate_limits) |rule| { + if (rule.name.len > 0) allocator.free(rule.name); + if (rule.path.len > 0) allocator.free(rule.path); + if (rule.header_name) |h| { + if (h.len > 0) allocator.free(h); + } + } + // Free the array itself if it was allocated + if (self.rate_limits.len > 0) { + allocator.free(self.rate_limits); + } + + // Free endpoint overrides - only free paths that were actually allocated + for (self.request_limits.endpoints) |ep| { + if (ep.path.len > 0) allocator.free(ep.path); + } + // Free the array itself if it was allocated + if (self.request_limits.endpoints.len > 0) { + allocator.free(self.request_limits.endpoints); + } + + // Free trusted proxies array if it was allocated + if (self.trusted_proxies.len > 0) { + allocator.free(self.trusted_proxies); + } + + self._allocator = null; + } + + // ========================================================================= + // File Loading + // ========================================================================= + + /// Load WAF config from file + pub fn loadFromFile(allocator: Allocator, path: []const u8) !WafConfig { + const file = std.fs.cwd().openFile(path, .{}) catch |err| { + return switch (err) { + error.FileNotFound => error.ConfigFileNotFound, + else => error.ConfigFileOpenFailed, + }; + }; + defer file.close(); + + // Read file content + var content = try allocator.alloc(u8, MAX_CONFIG_SIZE); + defer allocator.free(content); + + var total_read: usize = 0; + while (total_read < MAX_CONFIG_SIZE) { + const bytes_read = file.read(content[total_read..]) catch |err| { + if (err == error.EndOfStream) break; + return error.ConfigFileReadFailed; + }; + if (bytes_read == 0) break; + total_read += bytes_read; + } + + if (total_read == MAX_CONFIG_SIZE) { + return error.ConfigFileTooLarge; + } + + return try parse(allocator, content[0..total_read]); + } + + // ========================================================================= + // Validation + // ========================================================================= + + /// Validate configuration consistency + pub fn validate(self: *const WafConfig) !void { + // Validate rate limit rules + for (self.rate_limits) |rule| { + if (rule.name.len == 0) return error.EmptyRuleName; + if (rule.name.len > MAX_NAME_LENGTH) return error.RuleNameTooLong; + if (rule.path.len == 0) return error.EmptyRulePath; + if (rule.path.len > MAX_PATH_LENGTH) return error.RulePathTooLong; + if (rule.requests == 0) return error.ZeroRequests; + if (rule.period_sec == 0) return error.ZeroPeriod; + if (rule.by == .header and rule.header_name == null) { + return error.MissingHeaderName; + } + } + + // Validate slowloris config + if (self.slowloris.header_timeout_ms == 0) return error.ZeroHeaderTimeout; + if (self.slowloris.body_timeout_ms == 0) return error.ZeroBodyTimeout; + + // Validate request limits + if (self.request_limits.max_uri_length == 0) return error.ZeroMaxUriLength; + if (self.request_limits.max_body_size == 0) return error.ZeroMaxBodySize; + + // Validate endpoint overrides + for (self.request_limits.endpoints) |ep| { + if (ep.path.len == 0) return error.EmptyEndpointPath; + if (ep.path.len > MAX_PATH_LENGTH) return error.EndpointPathTooLong; + if (ep.max_body_size == 0) return error.ZeroEndpointBodySize; + } + + // Validate logging threshold + if (self.logging.near_limit_threshold <= 0.0 or self.logging.near_limit_threshold > 1.0) { + return error.InvalidNearLimitThreshold; + } + } + + // ========================================================================= + // Lookups + // ========================================================================= + + /// Find the first matching rate limit rule for a request + pub fn findRateLimitRule(self: *const WafConfig, path: []const u8, method: ?HttpMethod) ?*const RateLimitRule { + for (self.rate_limits) |*rule| { + if (rule.matches(path, method)) { + return rule; + } + } + return null; + } + + /// Get the effective max body size for a path + pub fn getMaxBodySize(self: *const WafConfig, path: []const u8) u32 { + // Check endpoint-specific overrides first + for (self.request_limits.endpoints) |ep| { + if (pathMatches(ep.path, path)) { + return ep.max_body_size; + } + } + return self.request_limits.max_body_size; + } + + /// Check if an IP is from a trusted proxy + pub fn isTrustedProxy(self: *const WafConfig, ip: u32) bool { + for (self.trusted_proxies) |cidr| { + if (cidr.contains(ip)) { + return true; + } + } + return false; + } +}; + +// ============================================================================= +// JSON Schema Types (for std.json parsing) +// ============================================================================= + +const JsonRateLimitRule = struct { + name: []const u8, + path: []const u8, + method: ?[]const u8 = null, + limit: struct { + requests: u32, + period_sec: u32, + }, + burst: u32 = 0, + by: []const u8 = "ip", + header_name: ?[]const u8 = null, + action: []const u8 = "block", +}; + +const JsonSlowlorisConfig = struct { + header_timeout_ms: u32 = 5000, + body_timeout_ms: u32 = 30000, + min_bytes_per_sec: u32 = 100, + max_conns_per_ip: u16 = 50, +}; + +const JsonEndpointOverride = struct { + path: []const u8, + max_body_size: u32, +}; + +const JsonRequestLimitsConfig = struct { + max_uri_length: u32 = 2048, + max_body_size: u32 = 1048576, + max_json_depth: u8 = 20, + endpoints: []const JsonEndpointOverride = &.{}, +}; + +const JsonLoggingConfig = struct { + log_blocked: bool = true, + log_allowed: bool = false, + log_near_limit: bool = true, + near_limit_threshold: f32 = 0.8, +}; + +const JsonWafConfig = struct { + enabled: bool = true, + shadow_mode: bool = false, + rate_limits: []const JsonRateLimitRule = &.{}, + slowloris: ?JsonSlowlorisConfig = null, + request_limits: ?JsonRequestLimitsConfig = null, + trusted_proxies: []const []const u8 = &.{}, + logging: ?JsonLoggingConfig = null, +}; + +/// Parse a JSON rate limit rule to RateLimitRule +fn parseRateLimitRule(allocator: Allocator, jrl: JsonRateLimitRule) !RateLimitRule { + // Validate lengths + if (jrl.name.len > MAX_NAME_LENGTH) return error.RuleNameTooLong; + if (jrl.path.len > MAX_PATH_LENGTH) return error.RulePathTooLong; + + // Parse method + const method: ?HttpMethod = if (jrl.method) |m| + HttpMethod.parse(m) orelse return error.InvalidHttpMethod + else + null; + + // Parse "by" field + const by = RateLimitBy.parse(jrl.by) orelse return error.InvalidRateLimitBy; + + // Parse action + const action = Action.parse(jrl.action) orelse return error.InvalidAction; + + // Copy strings + const name = try allocator.dupe(u8, jrl.name); + errdefer allocator.free(name); + + const path = try allocator.dupe(u8, jrl.path); + errdefer allocator.free(path); + + const header_name: ?[]const u8 = if (jrl.header_name) |h| blk: { + if (h.len > MAX_HEADER_NAME_LENGTH) return error.HeaderNameTooLong; + break :blk try allocator.dupe(u8, h); + } else null; + + return .{ + .name = name, + .path = path, + .method = method, + .requests = jrl.limit.requests, + .period_sec = jrl.limit.period_sec, + .burst = jrl.burst, + .by = by, + .header_name = header_name, + .action = action, + }; +} + +// ============================================================================= +// Tests +// ============================================================================= + +test "CidrRange: parse valid CIDR" { + const cidr = try CidrRange.parse("10.0.0.0/8"); + try std.testing.expectEqual(@as(u32, 0x0A000000), cidr.network); + try std.testing.expectEqual(@as(u32, 0xFF000000), cidr.mask); +} + +test "CidrRange: parse /24 network" { + const cidr = try CidrRange.parse("192.168.1.0/24"); + try std.testing.expectEqual(@as(u32, 0xC0A80100), cidr.network); + try std.testing.expectEqual(@as(u32, 0xFFFFFF00), cidr.mask); +} + +test "CidrRange: parse /32 (single host)" { + const cidr = try CidrRange.parse("192.168.1.100/32"); + try std.testing.expectEqual(@as(u32, 0xC0A80164), cidr.network); + try std.testing.expectEqual(@as(u32, 0xFFFFFFFF), cidr.mask); +} + +test "CidrRange: contains" { + const cidr = try CidrRange.parse("10.0.0.0/8"); + + // Should contain + try std.testing.expect(cidr.contains(0x0A000001)); // 10.0.0.1 + try std.testing.expect(cidr.contains(0x0AFFFFFF)); // 10.255.255.255 + + // Should not contain + try std.testing.expect(!cidr.contains(0x0B000000)); // 11.0.0.0 + try std.testing.expect(!cidr.contains(0xC0A80101)); // 192.168.1.1 +} + +test "CidrRange: invalid - not network address" { + const result = CidrRange.parse("10.0.0.1/8"); + try std.testing.expectError(error.IpNotNetworkAddress, result); +} + +test "CidrRange: invalid - no slash" { + const result = CidrRange.parse("10.0.0.0"); + try std.testing.expectError(error.InvalidCidr, result); +} + +test "parseIpv4: valid addresses" { + try std.testing.expectEqual(@as(u32, 0x7F000001), try parseIpv4("127.0.0.1")); + try std.testing.expectEqual(@as(u32, 0xC0A80101), try parseIpv4("192.168.1.1")); + try std.testing.expectEqual(@as(u32, 0x00000000), try parseIpv4("0.0.0.0")); + try std.testing.expectEqual(@as(u32, 0xFFFFFFFF), try parseIpv4("255.255.255.255")); +} + +test "parseIpv4: invalid addresses" { + try std.testing.expectError(error.InvalidIpAddress, parseIpv4("256.0.0.0")); + try std.testing.expectError(error.InvalidIpAddress, parseIpv4("10.0.0")); + try std.testing.expectError(error.InvalidIpAddress, parseIpv4("10.0.0.0.0")); + try std.testing.expectError(error.InvalidIpAddress, parseIpv4("abc.def.ghi.jkl")); +} + +test "HttpMethod: parse" { + try std.testing.expectEqual(HttpMethod.GET, HttpMethod.parse("GET").?); + try std.testing.expectEqual(HttpMethod.POST, HttpMethod.parse("post").?); + try std.testing.expect(HttpMethod.parse("INVALID") == null); +} + +test "RateLimitBy: parse" { + try std.testing.expectEqual(RateLimitBy.ip, RateLimitBy.parse("ip").?); + try std.testing.expectEqual(RateLimitBy.header, RateLimitBy.parse("header").?); + try std.testing.expectEqual(RateLimitBy.path, RateLimitBy.parse("path").?); + try std.testing.expect(RateLimitBy.parse("invalid") == null); +} + +test "Action: parse" { + try std.testing.expectEqual(Action.block, Action.parse("block").?); + try std.testing.expectEqual(Action.log, Action.parse("log").?); + try std.testing.expectEqual(Action.tarpit, Action.parse("tarpit").?); + try std.testing.expect(Action.parse("invalid") == null); +} + +test "pathMatches: exact match" { + try std.testing.expect(pathMatches("/api/users", "/api/users")); + try std.testing.expect(!pathMatches("/api/users", "/api/posts")); +} + +test "pathMatches: wildcard" { + try std.testing.expect(pathMatches("/api/*", "/api/users")); + try std.testing.expect(pathMatches("/api/*", "/api/users/123")); + try std.testing.expect(!pathMatches("/api/*", "/other/path")); +} + +test "RateLimitRule: matches" { + const rule = RateLimitRule{ + .name = "test", + .path = "/api/*", + .method = .POST, + .requests = 100, + .period_sec = 60, + .burst = 10, + }; + + try std.testing.expect(rule.matches("/api/users", .POST)); + try std.testing.expect(!rule.matches("/api/users", .GET)); + try std.testing.expect(!rule.matches("/other", .POST)); +} + +test "RateLimitRule: tokensPerSec and burstCapacity" { + const rule = RateLimitRule{ + .name = "test", + .path = "/api/*", + .requests = 60, // 60 requests per minute = 1 per second + .period_sec = 60, + .burst = 10, + }; + + try std.testing.expectEqual(@as(u32, 1000), rule.tokensPerSec()); // 1 * 1000 + try std.testing.expectEqual(@as(u32, 70000), rule.burstCapacity()); // (60 + 10) * 1000 +} + +test "WafConfig: parse minimal config" { + const json = + \\{"enabled": true, "shadow_mode": false} + ; + + var config = try WafConfig.parse(std.testing.allocator, json); + defer config.deinit(); + + try std.testing.expect(config.enabled); + try std.testing.expect(!config.shadow_mode); + try std.testing.expectEqual(@as(usize, 0), config.rate_limits.len); +} + +test "WafConfig: parse with rate limits" { + const json = + \\{ + \\ "enabled": true, + \\ "shadow_mode": false, + \\ "rate_limits": [ + \\ { + \\ "name": "login_bruteforce", + \\ "path": "/api/auth/login", + \\ "method": "POST", + \\ "limit": { "requests": 10, "period_sec": 60 }, + \\ "burst": 3, + \\ "by": "ip", + \\ "action": "block" + \\ } + \\ ] + \\} + ; + + var config = try WafConfig.parse(std.testing.allocator, json); + defer config.deinit(); + + try std.testing.expectEqual(@as(usize, 1), config.rate_limits.len); + try std.testing.expectEqualStrings("login_bruteforce", config.rate_limits[0].name); + try std.testing.expectEqualStrings("/api/auth/login", config.rate_limits[0].path); + try std.testing.expectEqual(HttpMethod.POST, config.rate_limits[0].method.?); + try std.testing.expectEqual(@as(u32, 10), config.rate_limits[0].requests); + try std.testing.expectEqual(@as(u32, 60), config.rate_limits[0].period_sec); + try std.testing.expectEqual(@as(u32, 3), config.rate_limits[0].burst); + try std.testing.expectEqual(RateLimitBy.ip, config.rate_limits[0].by); + try std.testing.expectEqual(Action.block, config.rate_limits[0].action); +} + +test "WafConfig: parse with trusted proxies" { + const json = + \\{ + \\ "trusted_proxies": ["10.0.0.0/8", "172.16.0.0/12"] + \\} + ; + + var config = try WafConfig.parse(std.testing.allocator, json); + defer config.deinit(); + + try std.testing.expectEqual(@as(usize, 2), config.trusted_proxies.len); + + // Test 10.0.0.0/8 + try std.testing.expectEqual(@as(u32, 0x0A000000), config.trusted_proxies[0].network); + try std.testing.expect(config.isTrustedProxy(0x0A010203)); // 10.1.2.3 + + // Test 172.16.0.0/12 + try std.testing.expectEqual(@as(u32, 0xAC100000), config.trusted_proxies[1].network); + try std.testing.expect(config.isTrustedProxy(0xAC1F0001)); // 172.31.0.1 +} + +test "WafConfig: parse with slowloris config" { + const json = + \\{ + \\ "slowloris": { + \\ "header_timeout_ms": 3000, + \\ "body_timeout_ms": 20000, + \\ "min_bytes_per_sec": 50, + \\ "max_conns_per_ip": 25 + \\ } + \\} + ; + + var config = try WafConfig.parse(std.testing.allocator, json); + defer config.deinit(); + + try std.testing.expectEqual(@as(u32, 3000), config.slowloris.header_timeout_ms); + try std.testing.expectEqual(@as(u32, 20000), config.slowloris.body_timeout_ms); + try std.testing.expectEqual(@as(u32, 50), config.slowloris.min_bytes_per_sec); + try std.testing.expectEqual(@as(u16, 25), config.slowloris.max_conns_per_ip); +} + +test "WafConfig: parse with request limits and endpoint overrides" { + const json = + \\{ + \\ "request_limits": { + \\ "max_uri_length": 4096, + \\ "max_body_size": 2097152, + \\ "max_json_depth": 10, + \\ "endpoints": [ + \\ { "path": "/api/upload", "max_body_size": 10485760 } + \\ ] + \\ } + \\} + ; + + var config = try WafConfig.parse(std.testing.allocator, json); + defer config.deinit(); + + try std.testing.expectEqual(@as(u32, 4096), config.request_limits.max_uri_length); + try std.testing.expectEqual(@as(u32, 2097152), config.request_limits.max_body_size); + try std.testing.expectEqual(@as(u8, 10), config.request_limits.max_json_depth); + try std.testing.expectEqual(@as(usize, 1), config.request_limits.endpoints.len); + + // Test getMaxBodySize + try std.testing.expectEqual(@as(u32, 10485760), config.getMaxBodySize("/api/upload")); + try std.testing.expectEqual(@as(u32, 2097152), config.getMaxBodySize("/api/other")); +} + +test "WafConfig: parse with logging config" { + const json = + \\{ + \\ "logging": { + \\ "log_blocked": true, + \\ "log_allowed": true, + \\ "log_near_limit": false, + \\ "near_limit_threshold": 0.9 + \\ } + \\} + ; + + var config = try WafConfig.parse(std.testing.allocator, json); + defer config.deinit(); + + try std.testing.expect(config.logging.log_blocked); + try std.testing.expect(config.logging.log_allowed); + try std.testing.expect(!config.logging.log_near_limit); + try std.testing.expectApproxEqAbs(@as(f32, 0.9), config.logging.near_limit_threshold, 0.001); +} + +test "WafConfig: parse full example config" { + const json = + \\{ + \\ "enabled": true, + \\ "shadow_mode": false, + \\ "rate_limits": [ + \\ { + \\ "name": "login_bruteforce", + \\ "path": "/api/auth/login", + \\ "method": "POST", + \\ "limit": { "requests": 10, "period_sec": 60 }, + \\ "burst": 3, + \\ "by": "ip", + \\ "action": "block" + \\ }, + \\ { + \\ "name": "api_global", + \\ "path": "/api/*", + \\ "limit": { "requests": 1000, "period_sec": 60 }, + \\ "burst": 100, + \\ "by": "ip", + \\ "action": "block" + \\ } + \\ ], + \\ "slowloris": { + \\ "header_timeout_ms": 5000, + \\ "body_timeout_ms": 30000, + \\ "min_bytes_per_sec": 100, + \\ "max_conns_per_ip": 50 + \\ }, + \\ "request_limits": { + \\ "max_uri_length": 2048, + \\ "max_body_size": 1048576, + \\ "max_json_depth": 20, + \\ "endpoints": [ + \\ { "path": "/api/upload", "max_body_size": 10485760 } + \\ ] + \\ }, + \\ "trusted_proxies": ["10.0.0.0/8", "172.16.0.0/12"], + \\ "logging": { + \\ "log_blocked": true, + \\ "log_allowed": false, + \\ "log_near_limit": true, + \\ "near_limit_threshold": 0.8 + \\ } + \\} + ; + + var config = try WafConfig.parse(std.testing.allocator, json); + defer config.deinit(); + + // Verify everything parsed correctly + try std.testing.expect(config.enabled); + try std.testing.expect(!config.shadow_mode); + try std.testing.expectEqual(@as(usize, 2), config.rate_limits.len); + try std.testing.expectEqual(@as(usize, 2), config.trusted_proxies.len); + try std.testing.expectEqual(@as(usize, 1), config.request_limits.endpoints.len); + + // Test findRateLimitRule - should match login rule (more specific) + const login_rule = config.findRateLimitRule("/api/auth/login", .POST); + try std.testing.expect(login_rule != null); + try std.testing.expectEqualStrings("login_bruteforce", login_rule.?.name); + + // Test findRateLimitRule - should match global API rule + const api_rule = config.findRateLimitRule("/api/users", .GET); + try std.testing.expect(api_rule != null); + try std.testing.expectEqualStrings("api_global", api_rule.?.name); + + // Test findRateLimitRule - no match + const no_rule = config.findRateLimitRule("/static/file.js", .GET); + try std.testing.expect(no_rule == null); +} + +test "WafConfig: validation catches zero requests" { + const json = + \\{ + \\ "rate_limits": [ + \\ { + \\ "name": "test", + \\ "path": "/api/*", + \\ "limit": { "requests": 0, "period_sec": 60 } + \\ } + \\ ] + \\} + ; + + const result = WafConfig.parse(std.testing.allocator, json); + try std.testing.expectError(error.ZeroRequests, result); +} + +test "WafConfig: validation catches zero period" { + const json = + \\{ + \\ "rate_limits": [ + \\ { + \\ "name": "test", + \\ "path": "/api/*", + \\ "limit": { "requests": 10, "period_sec": 0 } + \\ } + \\ ] + \\} + ; + + const result = WafConfig.parse(std.testing.allocator, json); + try std.testing.expectError(error.ZeroPeriod, result); +} + +test "WafConfig: validation catches invalid near_limit_threshold" { + const json = + \\{ + \\ "logging": { + \\ "near_limit_threshold": 1.5 + \\ } + \\} + ; + + const result = WafConfig.parse(std.testing.allocator, json); + try std.testing.expectError(error.InvalidNearLimitThreshold, result); +} + +test "WafConfig: invalid JSON returns error" { + const json = \\{invalid json} + ; + + const result = WafConfig.parse(std.testing.allocator, json); + try std.testing.expectError(error.InvalidJson, result); +} + +test "WafConfig: invalid trusted proxy returns error" { + const json = + \\{ + \\ "trusted_proxies": ["invalid-cidr"] + \\} + ; + + const result = WafConfig.parse(std.testing.allocator, json); + try std.testing.expectError(error.InvalidTrustedProxy, result); +} + +test "WafConfig: defaults are sensible" { + var config = WafConfig{}; + + // All defaults should pass validation + try config.validate(); + + try std.testing.expect(config.enabled); + try std.testing.expect(!config.shadow_mode); + try std.testing.expectEqual(@as(u32, 5000), config.slowloris.header_timeout_ms); + try std.testing.expectEqual(@as(u32, 2048), config.request_limits.max_uri_length); + try std.testing.expect(config.logging.log_blocked); +} diff --git a/src/waf/engine.zig b/src/waf/engine.zig new file mode 100644 index 0000000..dd9808c --- /dev/null +++ b/src/waf/engine.zig @@ -0,0 +1,717 @@ +/// WAF Engine - Request Evaluation Orchestrator +/// +/// The heart of the Web Application Firewall. Orchestrates rate limiting, +/// request validation, and decision-making into a unified request evaluation flow. +/// +/// Design Philosophy (TigerBeetle-inspired): +/// - Single entry point for all WAF decisions +/// - Zero allocation on hot path +/// - Deterministic, predictable behavior +/// - Shadow mode for safe rule testing +/// - Composable with existing load balancer infrastructure +/// +/// Request Flow: +/// 1. WAF enabled check (fast path for disabled WAF) +/// 2. Client IP extraction (handles trusted proxies) +/// 3. Rate limit evaluation +/// 4. Request validation (URI length, body size) +/// 5. Metrics recording +/// 6. Decision return +/// +/// Shadow Mode: +/// When enabled, all blocking decisions are converted to log_only decisions. +/// This allows testing rules in production without affecting traffic. +const std = @import("std"); + +const state = @import("state.zig"); +pub const WafState = state.WafState; +pub const Decision = state.Decision; +pub const Reason = state.Reason; + +const rate_limiter = @import("rate_limiter.zig"); +pub const RateLimiter = rate_limiter.RateLimiter; +pub const Key = rate_limiter.Key; +pub const Rule = rate_limiter.Rule; +pub const hashPath = rate_limiter.hashPath; + +const config = @import("config.zig"); +pub const WafConfig = config.WafConfig; +pub const HttpMethod = config.HttpMethod; +pub const RateLimitRule = config.RateLimitRule; +pub const ipToBytes = config.ipToBytes; + + +// ============================================================================= +// Request - Incoming HTTP Request Representation +// ============================================================================= + +/// Represents an incoming HTTP request for WAF evaluation +/// Designed to be lightweight and stack-allocated +pub const Request = struct { + /// HTTP method of the request + method: HttpMethod, + /// Request URI (path + query string) + uri: []const u8, + /// Content-Length header value, if present + content_length: ?usize = null, + /// Direct connection IP (network byte order, u32) + source_ip: u32, + /// X-Forwarded-For header value, if present (for trusted proxy handling) + x_forwarded_for: ?[]const u8 = null, + + /// Create a request from basic parameters + pub fn init(method: HttpMethod, uri: []const u8, source_ip: u32) Request { + return .{ + .method = method, + .uri = uri, + .source_ip = source_ip, + }; + } + + /// Create a request with content length + pub fn withContentLength(method: HttpMethod, uri: []const u8, source_ip: u32, content_length: usize) Request { + return .{ + .method = method, + .uri = uri, + .source_ip = source_ip, + .content_length = content_length, + }; + } + + /// Extract just the path portion of the URI (before query string) + pub fn getPath(self: *const Request) []const u8 { + if (std.mem.indexOf(u8, self.uri, "?")) |query_start| { + return self.uri[0..query_start]; + } + return self.uri; + } +}; + +// ============================================================================= +// CheckResult - WAF Decision Output +// ============================================================================= + +/// Result of WAF request evaluation +/// Contains the decision, reason, and additional context for logging/headers +pub const CheckResult = struct { + /// The WAF decision (allow, block, or log_only) + decision: Decision, + /// Reason for blocking (if blocked or logged) + reason: Reason = .none, + /// Name of the rule that triggered (if any) + rule_name: ?[]const u8 = null, + /// Remaining tokens after rate limit check (for rate limit headers) + tokens_remaining: ?u32 = null, + + /// Create an allow result + pub fn allow() CheckResult { + return .{ + .decision = .allow, + .reason = .none, + }; + } + + /// Create a block result with reason and optional rule name + pub fn block(reason: Reason, rule_name: ?[]const u8) CheckResult { + return .{ + .decision = .block, + .reason = reason, + .rule_name = rule_name, + }; + } + + /// Create a log_only result (shadow mode) + pub fn logOnly(reason: Reason, rule_name: ?[]const u8) CheckResult { + return .{ + .decision = .log_only, + .reason = reason, + .rule_name = rule_name, + }; + } + + /// Check if request should proceed + pub inline fn isAllowed(self: CheckResult) bool { + return self.decision == .allow or self.decision == .log_only; + } + + /// Check if request was blocked + pub inline fn isBlocked(self: CheckResult) bool { + return self.decision == .block; + } + + /// Check if this result should be logged + pub inline fn shouldLog(self: CheckResult) bool { + return self.decision.shouldLog(); + } + + /// Convert block to log_only (for shadow mode) + pub fn toShadowMode(self: CheckResult) CheckResult { + if (self.decision == .block) { + return .{ + .decision = .log_only, + .reason = self.reason, + .rule_name = self.rule_name, + .tokens_remaining = self.tokens_remaining, + }; + } + return self; + } +}; + +// ============================================================================= +// WafEngine - Main WAF Orchestrator +// ============================================================================= + +/// Main WAF engine that orchestrates all security checks +/// Thread-safe through atomic operations in underlying state +pub const WafEngine = struct { + /// Pointer to shared WAF state (mmap'd region) + waf_state: *WafState, + /// Pointer to current configuration + waf_config: *const WafConfig, + /// Rate limiter instance + limiter: RateLimiter, + + /// Initialize the WAF engine with shared state and configuration + pub fn init(waf_state: *WafState, waf_config: *const WafConfig) WafEngine { + return .{ + .waf_state = waf_state, + .waf_config = waf_config, + .limiter = RateLimiter.init(waf_state), + }; + } + + /// Main entry point - evaluate a request through all WAF checks + /// + /// This is the hot path. Designed for minimal latency: + /// 1. Fast-path if WAF disabled + /// 2. Extract real client IP + /// 3. Validate request format + /// 4. Check rate limits + /// 5. Record metrics + /// 6. Return decision (converted to log_only if shadow mode) + pub fn check(self: *WafEngine, request: *const Request) CheckResult { + // Fast path: WAF disabled + if (!self.waf_config.enabled) { + return CheckResult.allow(); + } + + // Extract real client IP (handles trusted proxies) + const client_ip = self.getClientIp(request); + + // Validate request (URI length, body size) + const validation_result = self.validateRequest(request); + if (validation_result.decision != .allow) { + self.recordDecision(&validation_result); + return self.applyMode(validation_result); + } + + // Check rate limits + const rate_limit_result = self.checkRateLimit(request, client_ip); + if (rate_limit_result.decision != .allow) { + self.recordDecision(&rate_limit_result); + return self.applyMode(rate_limit_result); + } + + // Check for burst behavior (sudden velocity spike) + if (self.waf_config.burst_detection_enabled) { + const burst_result = self.checkBurst(client_ip); + if (burst_result.decision != .allow) { + self.recordDecision(&burst_result); + return self.applyMode(burst_result); + } + } + + // All checks passed + self.recordDecision(&rate_limit_result); + return self.applyMode(rate_limit_result); + } + + /// Main entry point with OpenTelemetry tracing + /// + /// Creates child spans for each WAF check step, providing visibility + /// into the WAF decision process in Jaeger/OTLP. + /// + /// The telemetry_mod parameter uses duck typing - any module that provides + /// startChildSpan(span, name, kind) -> Span works. + /// + /// Span hierarchy: + /// proxy_request (parent) + /// └── waf.check + /// ├── waf.validate_request + /// ├── waf.rate_limit + /// └── waf.burst_detection + pub fn checkWithSpan( + self: *WafEngine, + request: *const Request, + parent_span: anytype, + comptime telemetry_mod: type, + ) CheckResult { + // Create WAF check span as child of parent + var waf_span = telemetry_mod.startChildSpan(parent_span, "waf.check", .Internal); + defer waf_span.end(); + + // Fast path: WAF disabled + if (!self.waf_config.enabled) { + waf_span.setStringAttribute("waf.enabled", "false"); + return CheckResult.allow(); + } + + // Extract real client IP (handles trusted proxies) + const client_ip = self.getClientIp(request); + + // Step 1: Validate request (URI length, body size) + { + var validate_span = telemetry_mod.startChildSpan(&waf_span, "waf.validate_request", .Internal); + defer validate_span.end(); + + const validation_result = self.validateRequest(request); + validate_span.setStringAttribute("waf.step", "validate_request"); + validate_span.setBoolAttribute("waf.passed", validation_result.decision == .allow); + + if (validation_result.decision != .allow) { + validate_span.setStringAttribute("waf.reason", validation_result.reason.description()); + waf_span.setStringAttribute("waf.blocked_by", "validate_request"); + self.recordDecision(&validation_result); + return self.applyMode(validation_result); + } + } + + // Step 2: Check rate limits + const rate_limit_result = blk: { + var rate_span = telemetry_mod.startChildSpan(&waf_span, "waf.rate_limit", .Internal); + defer rate_span.end(); + + const result = self.checkRateLimit(request, client_ip); + rate_span.setStringAttribute("waf.step", "rate_limit"); + rate_span.setBoolAttribute("waf.passed", result.decision == .allow); + + if (result.tokens_remaining) |remaining| { + rate_span.setIntAttribute("waf.tokens_remaining", @intCast(remaining)); + } + if (result.rule_name) |rule| { + rate_span.setStringAttribute("waf.rule", rule); + } + + if (result.decision != .allow) { + rate_span.setStringAttribute("waf.reason", result.reason.description()); + waf_span.setStringAttribute("waf.blocked_by", "rate_limit"); + self.recordDecision(&result); + break :blk self.applyMode(result); + } + break :blk result; + }; + + // Early return if rate limited + if (rate_limit_result.decision != .allow) { + return rate_limit_result; + } + + // Step 3: Check for burst behavior (sudden velocity spike) + if (self.waf_config.burst_detection_enabled) { + var burst_span = telemetry_mod.startChildSpan(&waf_span, "waf.burst_detection", .Internal); + defer burst_span.end(); + + const burst_result = self.checkBurst(client_ip); + burst_span.setStringAttribute("waf.step", "burst_detection"); + burst_span.setBoolAttribute("waf.passed", burst_result.decision == .allow); + + if (burst_result.decision != .allow) { + burst_span.setStringAttribute("waf.reason", burst_result.reason.description()); + waf_span.setStringAttribute("waf.blocked_by", "burst_detection"); + self.recordDecision(&burst_result); + return self.applyMode(burst_result); + } + } + + // All checks passed + waf_span.setStringAttribute("waf.result", "allow"); + self.recordDecision(&rate_limit_result); + return self.applyMode(rate_limit_result); + } + + /// Extract the real client IP, handling trusted proxies + /// + /// If the direct connection is from a trusted proxy and X-Forwarded-For + /// is present, parse and return the first (client) IP from the chain. + /// Otherwise, return the direct connection IP. + pub fn getClientIp(self: *WafEngine, request: *const Request) u32 { + // Check if source IP is from a trusted proxy + if (!self.waf_config.isTrustedProxy(request.source_ip)) { + return request.source_ip; + } + + // Source is trusted proxy - try to parse X-Forwarded-For + const xff = request.x_forwarded_for orelse return request.source_ip; + if (xff.len == 0) return request.source_ip; + + // X-Forwarded-For format: "client, proxy1, proxy2" + // We want the leftmost IP (original client) + const first_ip = blk: { + if (std.mem.indexOf(u8, xff, ",")) |comma_pos| { + break :blk std.mem.trim(u8, xff[0..comma_pos], " "); + } + break :blk std.mem.trim(u8, xff, " "); + }; + + // Parse the IP address + return parseIpv4(first_ip) catch request.source_ip; + } + + /// Validate request against size and format limits + fn validateRequest(self: *WafEngine, request: *const Request) CheckResult { + const limits = &self.waf_config.request_limits; + + // Check URI length + if (request.uri.len > limits.max_uri_length) { + return CheckResult.block(.invalid_request, null); + } + + // Check body size (if Content-Length present) + if (request.content_length) |content_len| { + const max_body = self.waf_config.getMaxBodySize(request.getPath()); + if (content_len > max_body) { + return CheckResult.block(.body_too_large, null); + } + } + + return CheckResult.allow(); + } + + /// Check rate limits for the request + fn checkRateLimit(self: *WafEngine, request: *const Request, client_ip: u32) CheckResult { + // Find matching rate limit rule + const rule = self.waf_config.findRateLimitRule( + request.getPath(), + request.method, + ) orelse { + // No matching rule - allow by default + return CheckResult.allow(); + }; + + // Create rate limit key (IP + path hash) + const path_hash = hashPath(rule.path); + const key = Key{ + .ip = client_ip, + .path_hash = path_hash, + }; + + // Convert config rule to rate limiter rule + const limiter_rule = Rule{ + .tokens_per_sec = rule.tokensPerSec(), + .burst_capacity = rule.burstCapacity(), + .cost_per_request = 1000, // 1 token per request (scaled by 1000) + }; + + // Check rate limit + const decision = self.limiter.check(key, &limiter_rule); + + if (decision.action == .block) { + var result = CheckResult.block(.rate_limit, rule.name); + result.tokens_remaining = decision.remaining_tokens; + return result; + } + + var result = CheckResult.allow(); + result.tokens_remaining = decision.remaining_tokens; + return result; + } + + /// Check for burst behavior (sudden velocity spike) + fn checkBurst(self: *WafEngine, client_ip: u32) CheckResult { + const current_time = rate_limiter.getCurrentTimeSec(); + const threshold = self.waf_config.burst_threshold; + + if (self.waf_state.checkBurst(client_ip, current_time, threshold)) { + return CheckResult.block(.burst, null); + } + + return CheckResult.allow(); + } + + /// Apply shadow mode transformation if enabled + fn applyMode(self: *WafEngine, result: CheckResult) CheckResult { + if (self.waf_config.shadow_mode) { + return result.toShadowMode(); + } + return result; + } + + /// Record decision in metrics + fn recordDecision(self: *WafEngine, result: *const CheckResult) void { + switch (result.decision) { + .allow => self.waf_state.metrics.recordAllowed(), + .block => self.waf_state.metrics.recordBlocked(result.reason), + .log_only => self.waf_state.metrics.recordLogged(), + } + } + + /// Get current metrics snapshot + pub fn getMetrics(self: *WafEngine) state.MetricsSnapshot { + return self.waf_state.metrics.snapshot(); + } + + /// Check if config epoch has changed (for hot-reload detection) + pub fn configEpochChanged(self: *WafEngine, last_epoch: u64) bool { + return self.waf_state.getConfigEpoch() != last_epoch; + } + + /// Get current config epoch + pub fn getConfigEpoch(self: *WafEngine) u64 { + return self.waf_state.getConfigEpoch(); + } +}; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Parse IPv4 address string to u32 (host byte order) +fn parseIpv4(ip_str: []const u8) !u32 { + var octets: [4]u8 = undefined; + var octet_idx: usize = 0; + var current: u32 = 0; + + for (ip_str) |c| { + if (c == '.') { + if (octet_idx >= 3) return error.InvalidIpAddress; + if (current > 255) return error.InvalidIpAddress; + octets[octet_idx] = @intCast(current); + octet_idx += 1; + current = 0; + } else if (c >= '0' and c <= '9') { + current = current * 10 + (c - '0'); + if (current > 255) return error.InvalidIpAddress; + } else { + return error.InvalidIpAddress; + } + } + + if (octet_idx != 3) return error.InvalidIpAddress; + octets[3] = @intCast(current); + + return (@as(u32, octets[0]) << 24) | + (@as(u32, octets[1]) << 16) | + (@as(u32, octets[2]) << 8) | + @as(u32, octets[3]); +} + +// ============================================================================= +// Tests +// ============================================================================= + +test "Request: init and getPath" { + const req = Request.init(.GET, "/api/users?page=1", 0xC0A80101); + + try std.testing.expectEqual(HttpMethod.GET, req.method); + try std.testing.expectEqualStrings("/api/users?page=1", req.uri); + try std.testing.expectEqualStrings("/api/users", req.getPath()); + try std.testing.expectEqual(@as(u32, 0xC0A80101), req.source_ip); +} + +test "Request: withContentLength" { + const req = Request.withContentLength(.POST, "/api/upload", 0x0A000001, 1024); + + try std.testing.expectEqual(HttpMethod.POST, req.method); + try std.testing.expectEqual(@as(?usize, 1024), req.content_length); +} + +test "Request: getPath without query string" { + const req = Request.init(.GET, "/api/users", 0xC0A80101); + try std.testing.expectEqualStrings("/api/users", req.getPath()); +} + +test "CheckResult: allow" { + const result = CheckResult.allow(); + + try std.testing.expect(result.isAllowed()); + try std.testing.expect(!result.isBlocked()); + try std.testing.expect(!result.shouldLog()); + try std.testing.expectEqual(Decision.allow, result.decision); + try std.testing.expectEqual(Reason.none, result.reason); +} + +test "CheckResult: block" { + const result = CheckResult.block(.rate_limit, "login_bruteforce"); + + try std.testing.expect(!result.isAllowed()); + try std.testing.expect(result.isBlocked()); + try std.testing.expect(result.shouldLog()); + try std.testing.expectEqual(Decision.block, result.decision); + try std.testing.expectEqual(Reason.rate_limit, result.reason); + try std.testing.expectEqualStrings("login_bruteforce", result.rule_name.?); +} + +test "CheckResult: logOnly" { + const result = CheckResult.logOnly(.body_too_large, null); + + try std.testing.expect(result.isAllowed()); // log_only still allows + try std.testing.expect(!result.isBlocked()); + try std.testing.expect(result.shouldLog()); + try std.testing.expectEqual(Decision.log_only, result.decision); + try std.testing.expectEqual(Reason.body_too_large, result.reason); +} + +test "CheckResult: toShadowMode converts block to log_only" { + const blocked = CheckResult.block(.rate_limit, "test_rule"); + const shadowed = blocked.toShadowMode(); + + try std.testing.expectEqual(Decision.log_only, shadowed.decision); + try std.testing.expectEqual(Reason.rate_limit, shadowed.reason); + try std.testing.expectEqualStrings("test_rule", shadowed.rule_name.?); +} + +test "CheckResult: toShadowMode preserves allow" { + const allowed = CheckResult.allow(); + const shadowed = allowed.toShadowMode(); + + try std.testing.expectEqual(Decision.allow, shadowed.decision); +} + +test "parseIpv4: valid addresses" { + try std.testing.expectEqual(@as(u32, 0x7F000001), try parseIpv4("127.0.0.1")); + try std.testing.expectEqual(@as(u32, 0xC0A80101), try parseIpv4("192.168.1.1")); + try std.testing.expectEqual(@as(u32, 0x0A000001), try parseIpv4("10.0.0.1")); +} + +test "parseIpv4: invalid addresses" { + try std.testing.expectError(error.InvalidIpAddress, parseIpv4("256.0.0.0")); + try std.testing.expectError(error.InvalidIpAddress, parseIpv4("10.0.0")); + try std.testing.expectError(error.InvalidIpAddress, parseIpv4("10.0.0.0.0")); +} + +test "WafEngine: init" { + var waf_state = WafState.init(); + const waf_config = WafConfig{}; + const engine = WafEngine.init(&waf_state, &waf_config); + + try std.testing.expect(engine.waf_state == &waf_state); + try std.testing.expect(engine.waf_config == &waf_config); +} + +test "WafEngine: check with disabled WAF" { + var waf_state = WafState.init(); + var waf_config = WafConfig{}; + waf_config.enabled = false; + + var engine = WafEngine.init(&waf_state, &waf_config); + const req = Request.init(.GET, "/api/users", 0xC0A80101); + const result = engine.check(&req); + + try std.testing.expect(result.isAllowed()); + try std.testing.expectEqual(Decision.allow, result.decision); +} + +test "WafEngine: check allows valid request" { + var waf_state = WafState.init(); + const waf_config = WafConfig{}; // No rate limit rules + + var engine = WafEngine.init(&waf_state, &waf_config); + const req = Request.init(.GET, "/api/users", 0xC0A80101); + const result = engine.check(&req); + + try std.testing.expect(result.isAllowed()); +} + +test "WafEngine: check blocks oversized URI" { + var waf_state = WafState.init(); + var waf_config = WafConfig{}; + waf_config.request_limits.max_uri_length = 10; + + var engine = WafEngine.init(&waf_state, &waf_config); + const req = Request.init(.GET, "/this/is/a/very/long/uri/that/exceeds/the/limit", 0xC0A80101); + const result = engine.check(&req); + + try std.testing.expect(result.isBlocked()); + try std.testing.expectEqual(Reason.invalid_request, result.reason); +} + +test "WafEngine: check blocks oversized body" { + var waf_state = WafState.init(); + var waf_config = WafConfig{}; + waf_config.request_limits.max_body_size = 100; + + var engine = WafEngine.init(&waf_state, &waf_config); + const req = Request.withContentLength(.POST, "/api/upload", 0xC0A80101, 1000); + const result = engine.check(&req); + + try std.testing.expect(result.isBlocked()); + try std.testing.expectEqual(Reason.body_too_large, result.reason); +} + +test "WafEngine: shadow mode converts block to log_only" { + var waf_state = WafState.init(); + var waf_config = WafConfig{}; + waf_config.shadow_mode = true; + waf_config.request_limits.max_uri_length = 10; + + var engine = WafEngine.init(&waf_state, &waf_config); + const req = Request.init(.GET, "/this/is/a/very/long/uri", 0xC0A80101); + const result = engine.check(&req); + + // In shadow mode, should be log_only instead of block + try std.testing.expect(result.isAllowed()); + try std.testing.expectEqual(Decision.log_only, result.decision); + try std.testing.expectEqual(Reason.invalid_request, result.reason); +} + +test "WafEngine: getClientIp returns source IP when not trusted" { + var waf_state = WafState.init(); + const waf_config = WafConfig{}; // No trusted proxies + + var engine = WafEngine.init(&waf_state, &waf_config); + var req = Request.init(.GET, "/api/users", 0xC0A80101); + req.x_forwarded_for = "10.0.0.1"; + + const client_ip = engine.getClientIp(&req); + try std.testing.expectEqual(@as(u32, 0xC0A80101), client_ip); +} + +test "WafEngine: getClientIp returns source IP when no XFF" { + var waf_state = WafState.init(); + const waf_config = WafConfig{}; + + var engine = WafEngine.init(&waf_state, &waf_config); + const req = Request.init(.GET, "/api/users", 0xC0A80101); + + const client_ip = engine.getClientIp(&req); + try std.testing.expectEqual(@as(u32, 0xC0A80101), client_ip); +} + +test "WafEngine: metrics recording" { + var waf_state = WafState.init(); + var waf_config = WafConfig{}; + waf_config.request_limits.max_uri_length = 10; + + var engine = WafEngine.init(&waf_state, &waf_config); + + // Make a valid request + const valid_req = Request.init(.GET, "/api", 0xC0A80101); + _ = engine.check(&valid_req); + + // Make an invalid request + const invalid_req = Request.init(.GET, "/very/long/uri/path", 0xC0A80101); + _ = engine.check(&invalid_req); + + // Check metrics + const metrics = engine.getMetrics(); + try std.testing.expectEqual(@as(u64, 1), metrics.requests_allowed); + try std.testing.expectEqual(@as(u64, 1), metrics.requests_blocked); +} + +test "WafEngine: config epoch" { + var waf_state = WafState.init(); + const waf_config = WafConfig{}; + + var engine = WafEngine.init(&waf_state, &waf_config); + + const epoch1 = engine.getConfigEpoch(); + try std.testing.expectEqual(@as(u64, 0), epoch1); + try std.testing.expect(!engine.configEpochChanged(0)); + + // Increment epoch + _ = waf_state.incrementConfigEpoch(); + + try std.testing.expect(engine.configEpochChanged(0)); + try std.testing.expectEqual(@as(u64, 1), engine.getConfigEpoch()); +} diff --git a/src/waf/events.zig b/src/waf/events.zig new file mode 100644 index 0000000..5a27ad0 --- /dev/null +++ b/src/waf/events.zig @@ -0,0 +1,800 @@ +/// WAF Structured Event Logging +/// +/// Provides structured JSON logging for WAF events including blocked requests, +/// rate limit warnings, and configuration changes. +/// +/// Design Philosophy (TigerBeetle-inspired): +/// - Zero allocation on hot path - fixed-size buffers only +/// - Structured JSON output for machine parsing +/// - Human-readable formatting for debugging +/// - Configurable log levels via LoggingConfig +/// +/// Output Format: JSON Lines (one event per line) +/// ```json +/// {"timestamp":1703635200,"event_type":"blocked","client_ip":"192.168.1.1","method":"POST","path":"/api/login","rule_name":"login_bruteforce","reason":"rate_limit"} +/// ``` +const std = @import("std"); + +const config = @import("config.zig"); +pub const LoggingConfig = config.LoggingConfig; +pub const HttpMethod = config.HttpMethod; + +const state = @import("state.zig"); +pub const Reason = state.Reason; + +// ============================================================================= +// Constants +// ============================================================================= + +/// Maximum length of formatted IP string "255.255.255.255" = 15 chars +pub const MAX_IP_STRING_LEN: usize = 15; + +/// Maximum length of a single log line (generous for JSON overhead) +pub const MAX_LOG_LINE_LEN: usize = 2048; + +/// Maximum path length to include in logs (truncated if longer) +pub const MAX_LOG_PATH_LEN: usize = 256; + +/// Maximum rule name length to include in logs +pub const MAX_LOG_RULE_NAME_LEN: usize = 64; + +// ============================================================================= +// EventType - Classification of WAF Events +// ============================================================================= + +/// Type of WAF event for structured logging +pub const EventType = enum { + /// Request was blocked by WAF + blocked, + /// Request was allowed (verbose logging) + allowed, + /// Request is approaching rate limit threshold + near_limit, + /// WAF configuration was reloaded + config_reload, + + /// Convert to JSON string representation + pub fn toJsonString(self: EventType) []const u8 { + return switch (self) { + .blocked => "blocked", + .allowed => "allowed", + .near_limit => "near_limit", + .config_reload => "config_reload", + }; + } +}; + +// ============================================================================= +// WafEvent - Structured Event Data +// ============================================================================= + +/// Represents a single WAF event for logging +/// All string fields are slices into external memory (no allocation) +pub const WafEvent = struct { + /// Unix timestamp (seconds since epoch) + timestamp: i64, + /// Type of event + event_type: EventType, + /// Client IP address as string (e.g., "192.168.1.1") + client_ip: []const u8, + /// HTTP method (e.g., "GET", "POST") + method: []const u8, + /// Request path (may be truncated) + path: []const u8, + /// Name of the rule that triggered (if applicable) + rule_name: ?[]const u8 = null, + /// Human-readable reason for the decision + reason: ?[]const u8 = null, + /// Remaining rate limit tokens (if applicable) + tokens_remaining: ?u32 = null, + /// New config epoch (for config_reload events) + config_epoch: ?u64 = null, + + /// Format event into a fixed-size buffer as JSON + /// Returns a slice of the buffer containing the formatted output + pub fn format(self: *const WafEvent, buffer: []u8) []const u8 { + var pos: usize = 0; + + // Start JSON object + if (pos + 1 > buffer.len) return buffer[0..0]; + buffer[pos] = '{'; + pos += 1; + + // timestamp + const ts_prefix = "\"timestamp\":"; + if (pos + ts_prefix.len > buffer.len) return buffer[0..pos]; + @memcpy(buffer[pos..][0..ts_prefix.len], ts_prefix); + pos += ts_prefix.len; + pos = appendInt(buffer, pos, self.timestamp); + + // event_type + pos = appendJsonField(buffer, pos, "event_type", self.event_type.toJsonString()); + + // client_ip + pos = appendJsonField(buffer, pos, "client_ip", self.client_ip); + + // method + pos = appendJsonField(buffer, pos, "method", self.method); + + // path (escaped) + pos = appendJsonFieldEscaped(buffer, pos, "path", self.path); + + // rule_name (optional) + if (self.rule_name) |name| { + pos = appendJsonFieldEscaped(buffer, pos, "rule_name", name); + } + + // reason (optional) + if (self.reason) |r| { + pos = appendJsonFieldEscaped(buffer, pos, "reason", r); + } + + // tokens_remaining (optional) + if (self.tokens_remaining) |tokens| { + const tok_prefix = ",\"tokens_remaining\":"; + if (pos + tok_prefix.len <= buffer.len) { + @memcpy(buffer[pos..][0..tok_prefix.len], tok_prefix); + pos += tok_prefix.len; + pos = appendInt(buffer, pos, tokens); + } + } + + // config_epoch (optional) + if (self.config_epoch) |epoch| { + const epoch_prefix = ",\"config_epoch\":"; + if (pos + epoch_prefix.len <= buffer.len) { + @memcpy(buffer[pos..][0..epoch_prefix.len], epoch_prefix); + pos += epoch_prefix.len; + pos = appendInt(buffer, pos, epoch); + } + } + + // Close JSON object and add newline + if (pos + 2 <= buffer.len) { + buffer[pos] = '}'; + pos += 1; + buffer[pos] = '\n'; + pos += 1; + } + + return buffer[0..pos]; + } + + /// Create a blocked event + pub fn blocked( + client_ip: []const u8, + method: []const u8, + path: []const u8, + rule_name: ?[]const u8, + reason: Reason, + tokens_remaining: ?u32, + ) WafEvent { + return .{ + .timestamp = getCurrentTimestamp(), + .event_type = .blocked, + .client_ip = client_ip, + .method = method, + .path = truncatePath(path), + .rule_name = rule_name, + .reason = reason.description(), + .tokens_remaining = tokens_remaining, + }; + } + + /// Create an allowed event + pub fn allowed( + client_ip: []const u8, + method: []const u8, + path: []const u8, + tokens_remaining: ?u32, + ) WafEvent { + return .{ + .timestamp = getCurrentTimestamp(), + .event_type = .allowed, + .client_ip = client_ip, + .method = method, + .path = truncatePath(path), + .tokens_remaining = tokens_remaining, + }; + } + + /// Create a near-limit warning event + pub fn nearLimit( + client_ip: []const u8, + method: []const u8, + path: []const u8, + rule_name: ?[]const u8, + tokens_remaining: u32, + ) WafEvent { + return .{ + .timestamp = getCurrentTimestamp(), + .event_type = .near_limit, + .client_ip = client_ip, + .method = method, + .path = truncatePath(path), + .rule_name = rule_name, + .tokens_remaining = tokens_remaining, + }; + } + + /// Create a config reload event + pub fn configReload(epoch: u64) WafEvent { + return .{ + .timestamp = getCurrentTimestamp(), + .event_type = .config_reload, + .client_ip = "-", + .method = "-", + .path = "-", + .config_epoch = epoch, + }; + } +}; + +// ============================================================================= +// EventLogger - Central Logging Coordinator +// ============================================================================= + +/// Central event logger for WAF +/// Coordinates logging based on configuration settings +pub const EventLogger = struct { + /// Pointer to logging configuration + log_config: *const LoggingConfig, + /// Optional file handle for persistent logging + file: ?std.fs.File, + /// Buffer for IP formatting + ip_buffer: [MAX_IP_STRING_LEN + 1]u8 = undefined, + /// Buffer for log line formatting + line_buffer: [MAX_LOG_LINE_LEN]u8 = undefined, + + /// Initialize an event logger + pub fn init(logging_config: *const LoggingConfig) EventLogger { + return .{ + .log_config = logging_config, + .file = null, + }; + } + + /// Initialize with a file for persistent logging + pub fn initWithFile(logging_config: *const LoggingConfig, file: std.fs.File) EventLogger { + return .{ + .log_config = logging_config, + .file = file, + }; + } + + /// Log a blocked request event + /// Only logs if log_blocked is enabled in config + pub fn logBlocked( + self: *EventLogger, + client_ip: u32, + method: HttpMethod, + path: []const u8, + rule_name: ?[]const u8, + reason: Reason, + tokens_remaining: ?u32, + ) void { + if (!self.log_config.log_blocked) return; + + const ip_str = self.formatIpInternal(client_ip); + const event = WafEvent.blocked( + ip_str, + method.toString(), + path, + rule_name, + reason, + tokens_remaining, + ); + + self.emitEvent(&event); + } + + /// Log an allowed request event + /// Only logs if log_allowed is enabled in config + pub fn logAllowed( + self: *EventLogger, + client_ip: u32, + method: HttpMethod, + path: []const u8, + tokens_remaining: ?u32, + ) void { + if (!self.log_config.log_allowed) return; + + const ip_str = self.formatIpInternal(client_ip); + const event = WafEvent.allowed( + ip_str, + method.toString(), + path, + tokens_remaining, + ); + + self.emitEvent(&event); + } + + /// Log a near-limit warning event + /// Only logs if log_near_limit is enabled and tokens are below threshold + pub fn logNearLimit( + self: *EventLogger, + client_ip: u32, + method: HttpMethod, + path: []const u8, + rule_name: ?[]const u8, + tokens_remaining: u32, + burst_capacity: u32, + ) void { + if (!self.log_config.log_near_limit) return; + + // Check if we're near the limit threshold + if (burst_capacity == 0) return; + + const usage_ratio: f32 = 1.0 - (@as(f32, @floatFromInt(tokens_remaining)) / + @as(f32, @floatFromInt(burst_capacity))); + + if (usage_ratio < self.log_config.near_limit_threshold) return; + + const ip_str = self.formatIpInternal(client_ip); + const event = WafEvent.nearLimit( + ip_str, + method.toString(), + path, + rule_name, + tokens_remaining, + ); + + self.emitEvent(&event); + } + + /// Log a configuration reload event + /// Always logs when called (config reloads are important) + pub fn logConfigReload(self: *EventLogger, epoch: u64) void { + const event = WafEvent.configReload(epoch); + self.emitEvent(&event); + } + + /// Internal: emit an event to configured outputs + fn emitEvent(self: *EventLogger, event: *const WafEvent) void { + const formatted = event.format(&self.line_buffer); + + // Write to file if configured + if (self.file) |file| { + _ = file.write(formatted) catch {}; + } + + // Also write to stderr for visibility (development/debugging) + const stderr = std.fs.File{ .handle = std.posix.STDERR_FILENO }; + _ = stderr.write(formatted) catch {}; + } + + /// Internal: format IP to internal buffer + fn formatIpInternal(self: *EventLogger, ip: u32) []const u8 { + return formatIpToBuffer(ip, &self.ip_buffer); + } +}; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Format IPv4 address (u32) to string representation +/// Returns a fixed-size array with the formatted IP +pub fn formatIp(ip: u32) [MAX_IP_STRING_LEN]u8 { + var buffer: [MAX_IP_STRING_LEN]u8 = [_]u8{0} ** MAX_IP_STRING_LEN; + _ = formatIpToBuffer(ip, &buffer); + return buffer; +} + +/// Format IPv4 address to a provided buffer +/// Returns slice of the buffer containing the formatted string +pub fn formatIpToBuffer(ip: u32, buffer: []u8) []const u8 { + if (buffer.len < MAX_IP_STRING_LEN) return buffer[0..0]; + + const formatted = std.fmt.bufPrint(buffer, "{d}.{d}.{d}.{d}", .{ + @as(u8, @truncate(ip >> 24)), + @as(u8, @truncate(ip >> 16)), + @as(u8, @truncate(ip >> 8)), + @as(u8, @truncate(ip)), + }) catch return buffer[0..0]; + + return formatted; +} + +/// Get current Unix timestamp in seconds +pub fn getCurrentTimestamp() i64 { + const ts = std.posix.clock_gettime(.REALTIME) catch { + return 0; + }; + return ts.sec; +} + +/// Truncate path to maximum log length +fn truncatePath(path: []const u8) []const u8 { + if (path.len <= MAX_LOG_PATH_LEN) return path; + return path[0..MAX_LOG_PATH_LEN]; +} + +/// Append an integer to buffer, return new position +fn appendInt(buffer: []u8, pos: usize, value: anytype) usize { + var int_buf: [32]u8 = undefined; + const formatted = std.fmt.bufPrint(&int_buf, "{d}", .{value}) catch return pos; + if (pos + formatted.len > buffer.len) return pos; + @memcpy(buffer[pos..][0..formatted.len], formatted); + return pos + formatted.len; +} + +/// Append a JSON string field (with comma prefix): ,"key":"value" +fn appendJsonField(buffer: []u8, pos: usize, key: []const u8, value: []const u8) usize { + // Calculate required space: ,"{key}":"{value}" + const overhead = 6; // ,"":"" + const required = overhead + key.len + value.len; + if (pos + required > buffer.len) return pos; + + var p = pos; + + buffer[p] = ','; + p += 1; + buffer[p] = '"'; + p += 1; + @memcpy(buffer[p..][0..key.len], key); + p += key.len; + buffer[p] = '"'; + p += 1; + buffer[p] = ':'; + p += 1; + buffer[p] = '"'; + p += 1; + @memcpy(buffer[p..][0..value.len], value); + p += value.len; + buffer[p] = '"'; + p += 1; + + return p; +} + +/// Append a JSON string field with escaping for special characters +fn appendJsonFieldEscaped(buffer: []u8, pos: usize, key: []const u8, value: []const u8) usize { + // First, try simple path if no escaping needed + var needs_escape = false; + for (value) |c| { + if (c == '"' or c == '\\' or c < 0x20) { + needs_escape = true; + break; + } + } + + if (!needs_escape) { + return appendJsonField(buffer, pos, key, value); + } + + // Need to escape - build escaped value + var p = pos; + + // ,"{key}":" + const prefix_overhead = 5; // ,":" + if (p + prefix_overhead + key.len > buffer.len) return pos; + + buffer[p] = ','; + p += 1; + buffer[p] = '"'; + p += 1; + @memcpy(buffer[p..][0..key.len], key); + p += key.len; + buffer[p] = '"'; + p += 1; + buffer[p] = ':'; + p += 1; + buffer[p] = '"'; + p += 1; + + // Write escaped value + for (value) |c| { + switch (c) { + '"' => { + if (p + 2 > buffer.len) break; + buffer[p] = '\\'; + p += 1; + buffer[p] = '"'; + p += 1; + }, + '\\' => { + if (p + 2 > buffer.len) break; + buffer[p] = '\\'; + p += 1; + buffer[p] = '\\'; + p += 1; + }, + '\n' => { + if (p + 2 > buffer.len) break; + buffer[p] = '\\'; + p += 1; + buffer[p] = 'n'; + p += 1; + }, + '\r' => { + if (p + 2 > buffer.len) break; + buffer[p] = '\\'; + p += 1; + buffer[p] = 'r'; + p += 1; + }, + '\t' => { + if (p + 2 > buffer.len) break; + buffer[p] = '\\'; + p += 1; + buffer[p] = 't'; + p += 1; + }, + else => { + if (c < 0x20) { + // Control characters - escape as \u00XX + if (p + 6 > buffer.len) break; + const hex = std.fmt.bufPrint(buffer[p..][0..6], "\\u00{x:0>2}", .{c}) catch break; + p += hex.len; + } else { + if (p + 1 > buffer.len) break; + buffer[p] = c; + p += 1; + } + }, + } + } + + // Close quote + if (p + 1 <= buffer.len) { + buffer[p] = '"'; + p += 1; + } + + return p; +} + +// ============================================================================= +// Tests +// ============================================================================= + +test "EventType: toJsonString" { + try std.testing.expectEqualStrings("blocked", EventType.blocked.toJsonString()); + try std.testing.expectEqualStrings("allowed", EventType.allowed.toJsonString()); + try std.testing.expectEqualStrings("near_limit", EventType.near_limit.toJsonString()); + try std.testing.expectEqualStrings("config_reload", EventType.config_reload.toJsonString()); +} + +test "formatIp: basic formatting" { + const ip1: u32 = 0xC0A80101; // 192.168.1.1 + var buffer1: [MAX_IP_STRING_LEN]u8 = undefined; + const result1 = formatIpToBuffer(ip1, &buffer1); + try std.testing.expectEqualStrings("192.168.1.1", result1); + + const ip2: u32 = 0x7F000001; // 127.0.0.1 + var buffer2: [MAX_IP_STRING_LEN]u8 = undefined; + const result2 = formatIpToBuffer(ip2, &buffer2); + try std.testing.expectEqualStrings("127.0.0.1", result2); + + const ip3: u32 = 0x0A000001; // 10.0.0.1 + var buffer3: [MAX_IP_STRING_LEN]u8 = undefined; + const result3 = formatIpToBuffer(ip3, &buffer3); + try std.testing.expectEqualStrings("10.0.0.1", result3); +} + +test "formatIp: edge cases" { + const ip_zero: u32 = 0x00000000; // 0.0.0.0 + var buffer1: [MAX_IP_STRING_LEN]u8 = undefined; + const result1 = formatIpToBuffer(ip_zero, &buffer1); + try std.testing.expectEqualStrings("0.0.0.0", result1); + + const ip_max: u32 = 0xFFFFFFFF; // 255.255.255.255 + var buffer2: [MAX_IP_STRING_LEN]u8 = undefined; + const result2 = formatIpToBuffer(ip_max, &buffer2); + try std.testing.expectEqualStrings("255.255.255.255", result2); +} + +test "WafEvent: format JSON" { + const event = WafEvent{ + .timestamp = 1703635200, + .event_type = .blocked, + .client_ip = "192.168.1.1", + .method = "POST", + .path = "/api/login", + .rule_name = "login_bruteforce", + .reason = "rate limit exceeded", + .tokens_remaining = 0, + }; + + var buffer: [MAX_LOG_LINE_LEN]u8 = undefined; + const output = event.format(&buffer); + + // Verify it's valid JSON structure + try std.testing.expect(std.mem.startsWith(u8, output, "{")); + try std.testing.expect(std.mem.endsWith(u8, output, "}\n")); + + // Verify key fields are present + try std.testing.expect(std.mem.indexOf(u8, output, "\"timestamp\":1703635200") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "\"event_type\":\"blocked\"") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "\"client_ip\":\"192.168.1.1\"") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "\"method\":\"POST\"") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "\"path\":\"/api/login\"") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "\"rule_name\":\"login_bruteforce\"") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "\"reason\":\"rate limit exceeded\"") != null); +} + +test "WafEvent: blocked constructor" { + const event = WafEvent.blocked( + "10.0.0.1", + "GET", + "/api/resource", + "rate_limit_rule", + .rate_limit, + 100, + ); + + try std.testing.expectEqual(EventType.blocked, event.event_type); + try std.testing.expectEqualStrings("10.0.0.1", event.client_ip); + try std.testing.expectEqualStrings("GET", event.method); + try std.testing.expectEqualStrings("/api/resource", event.path); + try std.testing.expectEqualStrings("rate_limit_rule", event.rule_name.?); + try std.testing.expectEqualStrings("rate limit exceeded", event.reason.?); + try std.testing.expectEqual(@as(?u32, 100), event.tokens_remaining); +} + +test "WafEvent: allowed constructor" { + const event = WafEvent.allowed( + "192.168.0.1", + "POST", + "/api/upload", + 5000, + ); + + try std.testing.expectEqual(EventType.allowed, event.event_type); + try std.testing.expectEqualStrings("192.168.0.1", event.client_ip); + try std.testing.expectEqualStrings("POST", event.method); + try std.testing.expectEqualStrings("/api/upload", event.path); + try std.testing.expect(event.rule_name == null); + try std.testing.expect(event.reason == null); + try std.testing.expectEqual(@as(?u32, 5000), event.tokens_remaining); +} + +test "WafEvent: nearLimit constructor" { + const event = WafEvent.nearLimit( + "172.16.0.1", + "GET", + "/api/users", + "api_rate_limit", + 500, + ); + + try std.testing.expectEqual(EventType.near_limit, event.event_type); + try std.testing.expectEqualStrings("172.16.0.1", event.client_ip); + try std.testing.expectEqualStrings("api_rate_limit", event.rule_name.?); + try std.testing.expectEqual(@as(?u32, 500), event.tokens_remaining); +} + +test "WafEvent: configReload constructor" { + const event = WafEvent.configReload(42); + + try std.testing.expectEqual(EventType.config_reload, event.event_type); + try std.testing.expectEqualStrings("-", event.client_ip); + try std.testing.expectEqualStrings("-", event.method); + try std.testing.expectEqualStrings("-", event.path); + try std.testing.expectEqual(@as(?u64, 42), event.config_epoch); +} + +test "WafEvent: format handles special characters" { + const event = WafEvent{ + .timestamp = 1703635200, + .event_type = .blocked, + .client_ip = "10.0.0.1", + .method = "GET", + .path = "/api/test?q=\"hello\"&b=\\world", + .rule_name = null, + .reason = "test\nreason", + }; + + var buffer: [MAX_LOG_LINE_LEN]u8 = undefined; + const output = event.format(&buffer); + + // Verify special characters are escaped + try std.testing.expect(std.mem.indexOf(u8, output, "\\\"hello\\\"") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "\\\\world") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "test\\nreason") != null); +} + +test "EventLogger: init" { + const logging_config = LoggingConfig{}; + const logger = EventLogger.init(&logging_config); + + try std.testing.expect(logger.log_config == &logging_config); + try std.testing.expect(logger.file == null); +} + +test "EventLogger: respects log_blocked config" { + var logging_config = LoggingConfig{}; + logging_config.log_blocked = false; + + var logger = EventLogger.init(&logging_config); + + // This should not log anything since log_blocked is false + // (no crash is success, actual output verification would require more setup) + logger.logBlocked( + 0xC0A80101, + .POST, + "/api/login", + "test_rule", + .rate_limit, + 0, + ); +} + +test "EventLogger: respects log_allowed config" { + var logging_config = LoggingConfig{}; + logging_config.log_allowed = false; + + var logger = EventLogger.init(&logging_config); + + // This should not log anything since log_allowed is false + logger.logAllowed( + 0xC0A80101, + .GET, + "/api/users", + 5000, + ); +} + +test "truncatePath: short paths unchanged" { + const short_path = "/api/users"; + try std.testing.expectEqualStrings(short_path, truncatePath(short_path)); +} + +test "truncatePath: long paths truncated" { + var long_path: [MAX_LOG_PATH_LEN + 100]u8 = undefined; + for (&long_path) |*c| { + c.* = 'a'; + } + const truncated = truncatePath(&long_path); + try std.testing.expectEqual(MAX_LOG_PATH_LEN, truncated.len); +} + +test "appendJsonField: basic field" { + var buffer: [100]u8 = undefined; + const new_pos = appendJsonField(&buffer, 0, "key", "value"); + + try std.testing.expectEqualStrings(",\"key\":\"value\"", buffer[0..new_pos]); +} + +test "appendJsonFieldEscaped: no escaping needed" { + var buffer: [100]u8 = undefined; + const new_pos = appendJsonFieldEscaped(&buffer, 0, "key", "simple"); + + try std.testing.expectEqualStrings(",\"key\":\"simple\"", buffer[0..new_pos]); +} + +test "appendJsonFieldEscaped: escapes special chars" { + var buffer: [100]u8 = undefined; + const new_pos = appendJsonFieldEscaped(&buffer, 0, "key", "hello\"world"); + + try std.testing.expectEqualStrings(",\"key\":\"hello\\\"world\"", buffer[0..new_pos]); +} + +test "WafEvent: optional fields omitted when null" { + const event = WafEvent{ + .timestamp = 1703635200, + .event_type = .allowed, + .client_ip = "10.0.0.1", + .method = "GET", + .path = "/health", + .rule_name = null, + .reason = null, + .tokens_remaining = null, + }; + + var buffer: [MAX_LOG_LINE_LEN]u8 = undefined; + const output = event.format(&buffer); + + // Optional fields should not be present + try std.testing.expect(std.mem.indexOf(u8, output, "rule_name") == null); + try std.testing.expect(std.mem.indexOf(u8, output, "reason") == null); + try std.testing.expect(std.mem.indexOf(u8, output, "tokens_remaining") == null); +} + +test "getCurrentTimestamp: returns reasonable value" { + const ts = getCurrentTimestamp(); + // Should be after year 2020 (timestamp > 1577836800) + try std.testing.expect(ts > 1577836800); +} diff --git a/src/waf/mod.zig b/src/waf/mod.zig new file mode 100644 index 0000000..1cb777f --- /dev/null +++ b/src/waf/mod.zig @@ -0,0 +1,190 @@ +//! Web Application Firewall (WAF) Module +//! +//! High-performance WAF for the zzz load balancer with: +//! - Lock-free rate limiting (token bucket) +//! - Request validation (URI, body, JSON depth) +//! - Slowloris detection +//! - Shadow mode for safe rollout +//! - Hot-reload configuration +//! +//! ## Quick Start +//! ```zig +//! const waf = @import("waf"); +//! +//! // Load configuration +//! var config = try waf.WafConfig.loadFromFile(allocator, "waf.json"); +//! defer config.deinit(); +//! +//! // Initialize shared state +//! var state = waf.WafState.init(); +//! +//! // Create engine +//! var engine = waf.WafEngine.init(&state, &config); +//! +//! // Check requests +//! const result = engine.check(&request); +//! if (result.isBlocked()) { +//! // Return 403/429 +//! } +//! ``` + +// ============================================================================= +// Re-export core types from state.zig +// ============================================================================= + +pub const WafState = @import("state.zig").WafState; +pub const Decision = @import("state.zig").Decision; +pub const Reason = @import("state.zig").Reason; +pub const Bucket = @import("state.zig").Bucket; +pub const ConnEntry = @import("state.zig").ConnEntry; +pub const ConnTracker = @import("state.zig").ConnTracker; +pub const BurstEntry = @import("state.zig").BurstEntry; +pub const BurstTracker = @import("state.zig").BurstTracker; +pub const WafMetrics = @import("state.zig").WafMetrics; +pub const MetricsSnapshot = @import("state.zig").MetricsSnapshot; + +// State constants +pub const MAX_BUCKETS = @import("state.zig").MAX_BUCKETS; +pub const MAX_TOKENS = @import("state.zig").MAX_TOKENS; +pub const BUCKET_PROBE_LIMIT = @import("state.zig").BUCKET_PROBE_LIMIT; +pub const MAX_CAS_ATTEMPTS = @import("state.zig").MAX_CAS_ATTEMPTS; +pub const MAX_TRACKED_IPS = @import("state.zig").MAX_TRACKED_IPS; +pub const WAF_STATE_MAGIC = @import("state.zig").WAF_STATE_MAGIC; +pub const WAF_STATE_SIZE = @import("state.zig").WAF_STATE_SIZE; +pub const MAX_BURST_TRACKED = @import("state.zig").MAX_BURST_TRACKED; +pub const BURST_WINDOW_SEC = @import("state.zig").BURST_WINDOW_SEC; +pub const BURST_THRESHOLD_MULTIPLIER = @import("state.zig").BURST_THRESHOLD_MULTIPLIER; + +// State helper functions +pub const packState = @import("state.zig").packState; +pub const unpackTokens = @import("state.zig").unpackTokens; +pub const unpackTime = @import("state.zig").unpackTime; +pub const computeKeyHash = @import("state.zig").computeKeyHash; +pub const computeIpHash = @import("state.zig").computeIpHash; + +// ============================================================================= +// Re-export configuration types from config.zig +// ============================================================================= + +pub const WafConfig = @import("config.zig").WafConfig; +pub const RateLimitRule = @import("config.zig").RateLimitRule; +pub const SlowlorisConfig = @import("config.zig").SlowlorisConfig; +pub const RequestLimitsConfig = @import("config.zig").RequestLimitsConfig; +pub const LoggingConfig = @import("config.zig").LoggingConfig; +pub const EndpointOverride = @import("config.zig").EndpointOverride; +pub const CidrRange = @import("config.zig").CidrRange; +pub const HttpMethod = @import("config.zig").HttpMethod; +pub const RateLimitBy = @import("config.zig").RateLimitBy; +pub const Action = @import("config.zig").Action; + +// Config constants +pub const MAX_RATE_LIMIT_RULES = @import("config.zig").MAX_RATE_LIMIT_RULES; +pub const MAX_ENDPOINT_OVERRIDES = @import("config.zig").MAX_ENDPOINT_OVERRIDES; +pub const MAX_TRUSTED_PROXIES = @import("config.zig").MAX_TRUSTED_PROXIES; +pub const MAX_PATH_LENGTH = @import("config.zig").MAX_PATH_LENGTH; +pub const MAX_NAME_LENGTH = @import("config.zig").MAX_NAME_LENGTH; +pub const MAX_HEADER_NAME_LENGTH = @import("config.zig").MAX_HEADER_NAME_LENGTH; +pub const MAX_CONFIG_SIZE = @import("config.zig").MAX_CONFIG_SIZE; + +// Config helper functions +pub const ipToBytes = @import("config.zig").ipToBytes; + +// ============================================================================= +// Re-export engine types from engine.zig +// ============================================================================= + +pub const WafEngine = @import("engine.zig").WafEngine; +pub const Request = @import("engine.zig").Request; +pub const CheckResult = @import("engine.zig").CheckResult; + +// ============================================================================= +// Re-export rate limiter types from rate_limiter.zig +// ============================================================================= + +pub const RateLimiter = @import("rate_limiter.zig").RateLimiter; +pub const Key = @import("rate_limiter.zig").Key; +pub const Rule = @import("rate_limiter.zig").Rule; +pub const DecisionResult = @import("rate_limiter.zig").DecisionResult; +pub const BucketStats = @import("rate_limiter.zig").BucketStats; + +// Rate limiter helper functions +pub const hashPath = @import("rate_limiter.zig").hashPath; +pub const getCurrentTimeSec = @import("rate_limiter.zig").getCurrentTimeSec; + +// ============================================================================= +// Re-export validator types from validator.zig +// ============================================================================= + +pub const RequestValidator = @import("validator.zig").RequestValidator; +pub const ValidatorConfig = @import("validator.zig").ValidatorConfig; +pub const ValidationResult = @import("validator.zig").ValidationResult; +pub const JsonState = @import("validator.zig").JsonState; + +// ============================================================================= +// Re-export events types from events.zig +// ============================================================================= + +pub const EventLogger = @import("events.zig").EventLogger; +pub const WafEvent = @import("events.zig").WafEvent; +pub const EventType = @import("events.zig").EventType; + +// Events helper functions +pub const formatIp = @import("events.zig").formatIp; +pub const getCurrentTimestamp = @import("events.zig").getCurrentTimestamp; + +// ============================================================================= +// Tests - ensure all imports are valid +// ============================================================================= + +test { + // Import all submodules to ensure they compile + _ = @import("state.zig"); + _ = @import("config.zig"); + _ = @import("engine.zig"); + _ = @import("rate_limiter.zig"); + _ = @import("validator.zig"); + _ = @import("events.zig"); +} + +test "mod: WafState and WafConfig integration" { + const std = @import("std"); + + // Test that the exported types work together + var waf_state = WafState.init(); + try std.testing.expect(waf_state.validate()); + + const waf_config = WafConfig{}; + try waf_config.validate(); + + var engine = WafEngine.init(&waf_state, &waf_config); + const request = Request.init(.GET, "/api/users", 0xC0A80101); + const result = engine.check(&request); + + try std.testing.expect(result.isAllowed()); +} + +test "mod: RateLimiter integration" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(192, 168, 1, 1, hashPath("/api/test")); + const rule = Rule.simple(10, 10); + + const decision = limiter.check(key, &rule); + const std = @import("std"); + try std.testing.expect(decision.isAllowed()); +} + +test "mod: RequestValidator integration" { + const std = @import("std"); + + const config = ValidatorConfig{}; + const validator = RequestValidator.init(&config); + + const result = validator.validateRequest("/api/users", null, null); + try std.testing.expect(result.isValid()); + + var json_state = JsonState{}; + const json_result = validator.validateJsonStream("{\"key\": \"value\"}", &json_state); + try std.testing.expect(json_result.isValid()); +} diff --git a/src/waf/rate_limiter.zig b/src/waf/rate_limiter.zig new file mode 100644 index 0000000..6d5ec74 --- /dev/null +++ b/src/waf/rate_limiter.zig @@ -0,0 +1,684 @@ +/// Lock-free Token Bucket Rate Limiter +/// +/// High-performance rate limiting using atomic CAS operations on shared memory. +/// Designed for multi-process environments with zero allocation on the hot path. +/// +/// Design Philosophy (TigerBeetle-inspired): +/// - Lock-free: Uses atomic CAS operations, no mutexes +/// - Fixed-size: All structures have bounded, compile-time known sizes +/// - Fail-open: Under extreme contention, allows requests rather than blocking +/// - Token precision: Scaled by 1000 for sub-token granularity +/// - Time-safe: Uses wrapping arithmetic for timestamp handling +/// +/// Token Bucket Algorithm: +/// - Tokens refill at a configurable rate per second +/// - Burst capacity limits maximum accumulated tokens +/// - Each request consumes tokens; blocked if insufficient +/// - Atomic CAS ensures correctness under concurrent access +const std = @import("std"); + +const state = @import("state.zig"); +pub const WafState = state.WafState; +pub const Bucket = state.Bucket; +pub const Decision = state.Decision; +pub const Reason = state.Reason; +pub const MAX_BUCKETS = state.MAX_BUCKETS; +pub const MAX_CAS_ATTEMPTS = state.MAX_CAS_ATTEMPTS; +pub const BUCKET_PROBE_LIMIT = state.BUCKET_PROBE_LIMIT; +pub const packState = state.packState; +pub const unpackTokens = state.unpackTokens; +pub const unpackTime = state.unpackTime; + +// ============================================================================= +// Key - What to Rate Limit By +// ============================================================================= + +/// Represents the identity for rate limiting +/// Combines IP address and path pattern hash for fine-grained control +pub const Key = struct { + /// IPv4 address as u32 (network byte order) + ip: u32, + /// Hash of the path pattern being rate limited + path_hash: u32, + + /// Compute a 64-bit hash combining IP and path for bucket lookup + /// Uses FNV-1a for speed and good distribution + pub fn hash(self: Key) u64 { + return computeKeyHash(self.ip, self.path_hash); + } + + /// Create a key from raw IPv4 bytes + pub fn fromIpBytes(ip_bytes: [4]u8, path_hash: u32) Key { + return .{ + .ip = std.mem.readInt(u32, &ip_bytes, .big), + .path_hash = path_hash, + }; + } + + /// Create a key from IPv4 octets + pub fn fromOctets(a: u8, b: u8, c: u8, d: u8, path_hash: u32) Key { + return .{ + .ip = (@as(u32, a) << 24) | (@as(u32, b) << 16) | (@as(u32, c) << 8) | @as(u32, d), + .path_hash = path_hash, + }; + } +}; + +// ============================================================================= +// Rule - Rate Limit Configuration +// ============================================================================= + +/// Rate limit rule configuration +/// Token values are scaled by 1000 for sub-token precision +pub const Rule = struct { + /// Tokens added per second (scaled by 1000) + /// Example: 1500 = 1.5 requests per second + tokens_per_sec: u32, + + /// Maximum tokens the bucket can hold (scaled by 1000) + /// This is the burst capacity - allows temporary spikes + burst_capacity: u32, + + /// Cost per request (scaled by 1000) + /// Default: 1000 = 1 token per request + cost_per_request: u32 = 1000, + + /// Create a simple rule: X requests per second with Y burst + pub fn simple(requests_per_sec: u32, burst: u32) Rule { + return .{ + .tokens_per_sec = requests_per_sec * 1000, + .burst_capacity = burst * 1000, + .cost_per_request = 1000, + }; + } + + /// Create a rule with fractional rates (milli-precision) + /// Example: milliRate(1500, 5000, 1000) = 1.5 req/sec, 5 burst, 1 cost + pub fn milliRate(tokens_per_sec: u32, burst_capacity: u32, cost: u32) Rule { + return .{ + .tokens_per_sec = tokens_per_sec, + .burst_capacity = burst_capacity, + .cost_per_request = cost, + }; + } + + comptime { + // Ensure rule fits in a cache line (important for hot path) + std.debug.assert(@sizeOf(Rule) <= 64); + } +}; + +// ============================================================================= +// DecisionResult - Detailed Rate Limit Response +// ============================================================================= + +/// Result of a rate limit check with detailed information +pub const DecisionResult = struct { + /// The WAF decision (allow/block) + action: Decision, + /// Reason for the decision + reason: Reason = .none, + /// Remaining tokens after this request (if allowed), scaled by 1000 + remaining_tokens: u32 = 0, + + /// Check if request should be blocked + pub inline fn isBlocked(self: DecisionResult) bool { + return self.action == .block; + } + + /// Check if request was allowed + pub inline fn isAllowed(self: DecisionResult) bool { + return self.action == .allow; + } +}; + +// ============================================================================= +// RateLimiter - Main Rate Limiting Engine +// ============================================================================= + +/// Lock-free token bucket rate limiter +/// Uses shared memory for multi-process visibility +pub const RateLimiter = struct { + /// Pointer to shared WAF state (in mmap'd region) + state: *WafState, + + /// Initialize a rate limiter with shared state + pub fn init(waf_state: *WafState) RateLimiter { + return .{ .state = waf_state }; + } + + /// Check if a request should be rate limited + /// This is the hot path - optimized for minimal latency + /// + /// Returns: DecisionResult with action, reason, and remaining tokens + /// + /// Behavior: + /// - Finds or creates bucket for the key + /// - Attempts atomic token consumption via CAS + /// - Fails open under extreme contention (CAS exhausted) + pub fn check(self: *RateLimiter, key: Key, rule: *const Rule) DecisionResult { + const key_hash = key.hash(); + const now_sec = getCurrentTimeSec(); + + // Find or create the bucket for this key + const bucket = self.state.findOrCreateBucket(key_hash, now_sec) orelse { + // Probe limit exceeded - table too full + // Fail open: allow request but log the condition + self.state.metrics.recordCasExhausted(); + return .{ + .action = .allow, + .reason = .none, + .remaining_tokens = 0, + }; + }; + + // Increment total requests counter (relaxed ordering - just metrics) + _ = bucket.getTotalRequestsPtr().fetchAdd(1, .monotonic); + + // Attempt to consume tokens using the bucket's atomic tryConsume + if (bucket.tryConsume(rule.cost_per_request, rule.tokens_per_sec, rule.burst_capacity, now_sec)) { + // Success - request allowed + const current_state = bucket.getPackedStatePtrConst().load(.acquire); + return .{ + .action = .allow, + .reason = .none, + .remaining_tokens = unpackTokens(current_state), + }; + } + + // Rate limited - increment blocked counter + _ = bucket.getTotalBlockedPtr().fetchAdd(1, .monotonic); + + // Get remaining tokens for the response + const current_state = bucket.getPackedStatePtrConst().load(.acquire); + + return .{ + .action = .block, + .reason = .rate_limit, + .remaining_tokens = unpackTokens(current_state), + }; + } + + /// Check rate limit with explicit CAS loop (alternative implementation) + /// Provides more control over the retry behavior + pub fn checkWithRetry(self: *RateLimiter, key: Key, rule: *const Rule) DecisionResult { + const key_hash = key.hash(); + const now_sec = getCurrentTimeSec(); + + const bucket = self.state.findOrCreateBucket(key_hash, now_sec) orelse { + self.state.metrics.recordCasExhausted(); + return .{ .action = .allow, .reason = .none }; + }; + + _ = bucket.getTotalRequestsPtr().fetchAdd(1, .monotonic); + + const state_ptr = bucket.getPackedStatePtr(); + var attempts: u32 = 0; + + while (attempts < MAX_CAS_ATTEMPTS) : (attempts += 1) { + const old = state_ptr.load(.acquire); + const old_tokens = unpackTokens(old); + const old_time = unpackTime(old); + + // Refill tokens based on elapsed time (wrap-safe) + const elapsed = now_sec -% old_time; + const refill = @min( + @as(u64, elapsed) * @as(u64, rule.tokens_per_sec), + @as(u64, rule.burst_capacity), + ); + const available: u32 = @intCast(@min( + @as(u64, old_tokens) + refill, + @as(u64, rule.burst_capacity), + )); + + // Check if we have enough tokens + if (available < rule.cost_per_request) { + _ = bucket.getTotalBlockedPtr().fetchAdd(1, .monotonic); + return .{ + .action = .block, + .reason = .rate_limit, + .remaining_tokens = available, + }; + } + + // Calculate new state + const new_tokens = available - rule.cost_per_request; + const new = packState(new_tokens, now_sec); + + // Attempt atomic update + if (state_ptr.cmpxchgWeak(old, new, .acq_rel, .acquire) == null) { + return .{ + .action = .allow, + .reason = .none, + .remaining_tokens = new_tokens, + }; + } + // CAS failed, retry with fresh state + } + + // CAS exhausted - fail open for availability + self.state.metrics.recordCasExhausted(); + return .{ + .action = .allow, + .reason = .none, + .remaining_tokens = 0, + }; + } + + /// Find the bucket index for a key using open addressing + /// Returns the index if found, or creates a new bucket + pub fn findBucket(self: *RateLimiter, key: Key) ?usize { + const key_hash = key.hash(); + const start_idx = @as(usize, @truncate(key_hash)) % MAX_BUCKETS; + var idx = start_idx; + var probe_count: u32 = 0; + + while (probe_count < BUCKET_PROBE_LIMIT) : (probe_count += 1) { + const bucket = &self.state.buckets[idx]; + + if (bucket.key_hash == key_hash) { + return idx; + } + + if (bucket.key_hash == 0) { + return null; // Not found, slot is empty + } + + // Linear probing + idx = (idx + 1) % MAX_BUCKETS; + } + + return null; // Probe limit exceeded + } + + /// Get remaining tokens for a key (for metrics/headers) + /// Returns null if the key doesn't have an active bucket + pub fn getRemainingTokens(self: *RateLimiter, key: Key) ?u32 { + const key_hash = key.hash(); + const bucket = self.state.findBucket(key_hash) orelse return null; + + const current_state = bucket.getPackedStatePtrConst().load(.acquire); + return unpackTokens(current_state); + } + + /// Get bucket statistics for a key + pub fn getBucketStats(self: *RateLimiter, key: Key) ?BucketStats { + const key_hash = key.hash(); + const bucket = self.state.findBucket(key_hash) orelse return null; + + const current_state = bucket.getPackedStatePtrConst().load(.acquire); + + // Read total_requests and total_blocked atomically + // Use pointer cast for const atomic access + const total_requests_ptr: *const std.atomic.Value(u64) = @ptrCast(&bucket.total_requests); + const total_blocked_ptr: *const std.atomic.Value(u64) = @ptrCast(&bucket.total_blocked); + + return .{ + .tokens = unpackTokens(current_state), + .last_update = unpackTime(current_state), + .total_requests = total_requests_ptr.load(.monotonic), + .total_blocked = total_blocked_ptr.load(.monotonic), + }; + } + + /// Reset a bucket to full tokens (useful for testing or manual intervention) + pub fn resetBucket(self: *RateLimiter, key: Key, max_tokens: u32) bool { + const key_hash = key.hash(); + const now_sec = getCurrentTimeSec(); + const bucket = self.state.findOrCreateBucket(key_hash, now_sec) orelse return false; + + const new_state = packState(max_tokens, now_sec); + bucket.getPackedStatePtr().store(new_state, .release); + return true; + } +}; + +/// Statistics for a rate limit bucket +pub const BucketStats = struct { + /// Current token count (scaled by 1000) + tokens: u32, + /// Last update timestamp (seconds since epoch) + last_update: u32, + /// Total requests to this bucket + total_requests: u64, + /// Total requests blocked + total_blocked: u64, + + /// Calculate block rate as percentage + pub fn blockRatePercent(self: BucketStats) u64 { + if (self.total_requests == 0) return 0; + return (self.total_blocked * 100) / self.total_requests; + } +}; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Compute a 64-bit hash from IP and path hash +/// Uses FNV-1a variant for good distribution +pub fn computeKeyHash(ip: u32, path_hash: u32) u64 { + var hash_val: u64 = 0xcbf29ce484222325; // FNV offset basis + + // Mix in IP bytes + hash_val ^= @as(u64, ip >> 24); + hash_val *%= 0x100000001b3; // FNV prime + hash_val ^= @as(u64, (ip >> 16) & 0xFF); + hash_val *%= 0x100000001b3; + hash_val ^= @as(u64, (ip >> 8) & 0xFF); + hash_val *%= 0x100000001b3; + hash_val ^= @as(u64, ip & 0xFF); + hash_val *%= 0x100000001b3; + + // Separator to avoid collisions + hash_val ^= 0xff; + hash_val *%= 0x100000001b3; + + // Mix in path hash bytes + hash_val ^= @as(u64, path_hash >> 24); + hash_val *%= 0x100000001b3; + hash_val ^= @as(u64, (path_hash >> 16) & 0xFF); + hash_val *%= 0x100000001b3; + hash_val ^= @as(u64, (path_hash >> 8) & 0xFF); + hash_val *%= 0x100000001b3; + hash_val ^= @as(u64, path_hash & 0xFF); + hash_val *%= 0x100000001b3; + + // Ensure non-zero (0 is reserved for empty slots) + return if (hash_val == 0) 1 else hash_val; +} + +/// Get current timestamp in seconds (for rate limiting) +/// Returns seconds since epoch, wrapped to u32 +pub inline fn getCurrentTimeSec() u32 { + const ts = std.posix.clock_gettime(.REALTIME) catch { + // Fallback: return 0 if clock is unavailable (shouldn't happen in practice) + return 0; + }; + // ts.sec is i64 (seconds since epoch), wrap to u32 + return @truncate(@as(u64, @intCast(ts.sec))); +} + +/// Hash a path string to u32 (for Key.path_hash) +pub fn hashPath(path: []const u8) u32 { + var hash_val: u32 = 0x811c9dc5; // FNV-1a 32-bit offset basis + + for (path) |b| { + hash_val ^= b; + hash_val *%= 0x01000193; // FNV-1a 32-bit prime + } + + return hash_val; +} + +// ============================================================================= +// Tests +// ============================================================================= + +test "Key: hash consistency" { + const key1 = Key{ .ip = 0xC0A80101, .path_hash = 0x12345678 }; + const key2 = Key{ .ip = 0xC0A80101, .path_hash = 0x12345678 }; + const key3 = Key{ .ip = 0xC0A80102, .path_hash = 0x12345678 }; + + // Same inputs produce same hash + try std.testing.expectEqual(key1.hash(), key2.hash()); + + // Different inputs produce different hash + try std.testing.expect(key1.hash() != key3.hash()); + + // Hash is never zero + try std.testing.expect(key1.hash() != 0); +} + +test "Key: fromOctets" { + const key = Key.fromOctets(192, 168, 1, 1, 0x12345678); + try std.testing.expectEqual(@as(u32, 0xC0A80101), key.ip); + try std.testing.expectEqual(@as(u32, 0x12345678), key.path_hash); +} + +test "Key: fromIpBytes" { + const bytes = [4]u8{ 192, 168, 1, 1 }; + const key = Key.fromIpBytes(bytes, 0xDEADBEEF); + try std.testing.expectEqual(@as(u32, 0xC0A80101), key.ip); + try std.testing.expectEqual(@as(u32, 0xDEADBEEF), key.path_hash); +} + +test "Rule: simple creation" { + const rule = Rule.simple(10, 20); + try std.testing.expectEqual(@as(u32, 10000), rule.tokens_per_sec); + try std.testing.expectEqual(@as(u32, 20000), rule.burst_capacity); + try std.testing.expectEqual(@as(u32, 1000), rule.cost_per_request); +} + +test "Rule: milliRate creation" { + const rule = Rule.milliRate(1500, 5000, 500); + try std.testing.expectEqual(@as(u32, 1500), rule.tokens_per_sec); + try std.testing.expectEqual(@as(u32, 5000), rule.burst_capacity); + try std.testing.expectEqual(@as(u32, 500), rule.cost_per_request); +} + +test "RateLimiter: init" { + var waf_state = WafState.init(); + const limiter = RateLimiter.init(&waf_state); + try std.testing.expect(limiter.state == &waf_state); +} + +test "RateLimiter: check allows within limit" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(192, 168, 1, 1, hashPath("/api/users")); + const rule = Rule.simple(10, 10); // 10 req/sec, 10 burst + + // First request should be allowed (bucket starts full) + const result = limiter.check(key, &rule); + try std.testing.expect(result.isAllowed()); + try std.testing.expectEqual(Decision.allow, result.action); + try std.testing.expectEqual(Reason.none, result.reason); +} + +test "RateLimiter: check blocks when exhausted" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(10, 0, 0, 1, hashPath("/api/heavy")); + const rule = Rule.simple(1, 2); // 1 req/sec, 2 burst + + // Exhaust the bucket (starts with burst_capacity tokens) + _ = limiter.check(key, &rule); // Uses 1 of 2 + _ = limiter.check(key, &rule); // Uses 1 of 1 + + // Third request should be blocked + const result = limiter.check(key, &rule); + try std.testing.expect(result.isBlocked()); + try std.testing.expectEqual(Decision.block, result.action); + try std.testing.expectEqual(Reason.rate_limit, result.reason); +} + +test "RateLimiter: different keys have different buckets" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key1 = Key.fromOctets(192, 168, 1, 1, hashPath("/api/a")); + const key2 = Key.fromOctets(192, 168, 1, 2, hashPath("/api/a")); + const rule = Rule.simple(1, 1); + + // Exhaust key1's bucket + _ = limiter.check(key1, &rule); + const result1 = limiter.check(key1, &rule); + try std.testing.expect(result1.isBlocked()); + + // Key2 should still be allowed (different bucket) + const result2 = limiter.check(key2, &rule); + try std.testing.expect(result2.isAllowed()); +} + +test "RateLimiter: getRemainingTokens" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(172, 16, 0, 1, hashPath("/health")); + const rule = Rule.simple(10, 5); // 5 burst = 5000 tokens + + // Non-existent key returns null + try std.testing.expect(limiter.getRemainingTokens(key) == null); + + // After first request, bucket exists + _ = limiter.check(key, &rule); + + // Should have tokens (5000 - 1000 = 4000) + const remaining = limiter.getRemainingTokens(key); + try std.testing.expect(remaining != null); + try std.testing.expectEqual(@as(u32, 4000), remaining.?); +} + +test "RateLimiter: getBucketStats" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(10, 10, 10, 10, hashPath("/stats")); + const rule = Rule.simple(1, 2); + + // Non-existent returns null + try std.testing.expect(limiter.getBucketStats(key) == null); + + // Make some requests + _ = limiter.check(key, &rule); // Allowed + _ = limiter.check(key, &rule); // Allowed + _ = limiter.check(key, &rule); // Blocked + + const stats = limiter.getBucketStats(key); + try std.testing.expect(stats != null); + try std.testing.expectEqual(@as(u64, 3), stats.?.total_requests); + try std.testing.expectEqual(@as(u64, 1), stats.?.total_blocked); +} + +test "RateLimiter: resetBucket" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(1, 2, 3, 4, hashPath("/reset")); + const rule = Rule.simple(1, 3); + + // Exhaust bucket + _ = limiter.check(key, &rule); + _ = limiter.check(key, &rule); + _ = limiter.check(key, &rule); + + // Should be blocked + const blocked_result = limiter.check(key, &rule); + try std.testing.expect(blocked_result.isBlocked()); + + // Reset to full + const reset_success = limiter.resetBucket(key, 3000); + try std.testing.expect(reset_success); + + // Should be allowed again + const allowed_result = limiter.check(key, &rule); + try std.testing.expect(allowed_result.isAllowed()); +} + +test "RateLimiter: findBucket returns index" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(8, 8, 8, 8, hashPath("/find")); + const rule = Rule.simple(10, 10); + + // Before any request, bucket doesn't exist + try std.testing.expect(limiter.findBucket(key) == null); + + // Create bucket via check + _ = limiter.check(key, &rule); + + // Now it should exist + const idx = limiter.findBucket(key); + try std.testing.expect(idx != null); + try std.testing.expect(idx.? < MAX_BUCKETS); +} + +test "computeKeyHash: distribution" { + // Test that different inputs produce different hashes + var hashes: [100]u64 = undefined; + for (0..100) |i| { + hashes[i] = computeKeyHash(@intCast(i), 0x12345678); + } + + // Check for uniqueness + for (0..100) |i| { + for (i + 1..100) |j| { + try std.testing.expect(hashes[i] != hashes[j]); + } + } +} + +test "computeKeyHash: non-zero guarantee" { + // Hash should never be zero + for (0..1000) |i| { + const hash_val = computeKeyHash(@intCast(i), @intCast(i)); + try std.testing.expect(hash_val != 0); + } +} + +test "hashPath: consistency" { + const hash1 = hashPath("/api/users"); + const hash2 = hashPath("/api/users"); + const hash3 = hashPath("/api/posts"); + + try std.testing.expectEqual(hash1, hash2); + try std.testing.expect(hash1 != hash3); +} + +test "DecisionResult: helper methods" { + const allowed = DecisionResult{ .action = .allow, .reason = .none }; + const blocked = DecisionResult{ .action = .block, .reason = .rate_limit }; + + try std.testing.expect(allowed.isAllowed()); + try std.testing.expect(!allowed.isBlocked()); + try std.testing.expect(blocked.isBlocked()); + try std.testing.expect(!blocked.isAllowed()); +} + +test "BucketStats: blockRatePercent" { + const stats_empty = BucketStats{ + .tokens = 0, + .last_update = 0, + .total_requests = 0, + .total_blocked = 0, + }; + try std.testing.expectEqual(@as(u64, 0), stats_empty.blockRatePercent()); + + const stats_half = BucketStats{ + .tokens = 0, + .last_update = 0, + .total_requests = 100, + .total_blocked = 50, + }; + try std.testing.expectEqual(@as(u64, 50), stats_half.blockRatePercent()); +} + +test "RateLimiter: checkWithRetry allows within limit" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(192, 168, 2, 1, hashPath("/retry")); + const rule = Rule.simple(10, 10); + + const result = limiter.checkWithRetry(key, &rule); + try std.testing.expect(result.isAllowed()); +} + +test "RateLimiter: checkWithRetry blocks when exhausted" { + var waf_state = WafState.init(); + var limiter = RateLimiter.init(&waf_state); + + const key = Key.fromOctets(10, 0, 0, 2, hashPath("/retry-heavy")); + const rule = Rule.simple(1, 2); + + _ = limiter.checkWithRetry(key, &rule); + _ = limiter.checkWithRetry(key, &rule); + + const result = limiter.checkWithRetry(key, &rule); + try std.testing.expect(result.isBlocked()); + try std.testing.expectEqual(Reason.rate_limit, result.reason); +} diff --git a/src/waf/state.zig b/src/waf/state.zig new file mode 100644 index 0000000..ed26015 --- /dev/null +++ b/src/waf/state.zig @@ -0,0 +1,1243 @@ +/// WAF Shared Memory State +/// +/// Provides cache-line aligned, lock-free data structures for the Web Application +/// Firewall. Designed for multi-process shared memory with atomic operations only. +/// +/// Design Philosophy (TigerBeetle-inspired): +/// - Fixed-size structures with explicit bounds +/// - Cache-line alignment to prevent false sharing +/// - Packed fields for atomic CAS operations +/// - Comptime assertions for size/alignment guarantees +/// - Zero allocation on hot path +/// +/// Memory Layout: +/// - WafState: Main container (~4MB total) +/// - Token bucket table: 64K entries for rate limiting +/// - Connection tracker: 16K entries for slowloris detection +/// - Metrics: Atomic counters for observability +/// - Config epoch: Hot-reload detection +const std = @import("std"); + +// ============================================================================= +// Constants (Single Source of Truth) +// ============================================================================= + +/// Maximum buckets in token bucket table (64K entries = ~4MB state) +pub const MAX_BUCKETS: usize = 65536; + +/// Maximum tokens in a bucket (scaled by 1000 for sub-token precision) +pub const MAX_TOKENS: u32 = 10000; + +/// Open addressing probe limit before giving up +pub const BUCKET_PROBE_LIMIT: u32 = 16; + +/// Maximum CAS retry attempts before declaring exhaustion +pub const MAX_CAS_ATTEMPTS: u32 = 8; + +/// Maximum tracked IPs for connection counting (slowloris prevention) +pub const MAX_TRACKED_IPS: usize = 16384; + +/// Magic number for corruption detection (ASCII: "WAFSTV1\0" + version nibble) +pub const WAF_STATE_MAGIC: u64 = 0x5741465354563130; // "WAFSTV10" + +/// Cache line size for alignment +const CACHE_LINE: usize = 64; + +// ============================================================================= +// Decision and Reason Enums +// ============================================================================= + +/// WAF decision for a request +pub const Decision = enum(u8) { + /// Allow the request to proceed + allow = 0, + /// Block the request (return error response) + block = 1, + /// Log but allow (shadow mode / detection only) + log_only = 2, + + /// Check if request should be blocked + pub inline fn isBlocked(self: Decision) bool { + return self == .block; + } + + /// Check if request should be logged + pub inline fn shouldLog(self: Decision) bool { + return self != .allow; + } +}; + +/// Reason for WAF decision +pub const Reason = enum(u8) { + /// No violation detected + none = 0, + /// Rate limit exceeded + rate_limit = 1, + /// Slowloris attack detected (too many connections) + slowloris = 2, + /// Request body too large + body_too_large = 3, + /// JSON nesting depth exceeded + json_depth = 4, + /// SQL injection pattern detected + sql_injection = 5, + /// XSS pattern detected + xss = 6, + /// Path traversal attempt + path_traversal = 7, + /// Invalid request format + invalid_request = 8, + /// Request velocity burst detected (sudden spike from IP) + burst = 9, + + /// Get human-readable description + pub fn description(self: Reason) []const u8 { + return switch (self) { + .none => "no violation", + .rate_limit => "rate limit exceeded", + .slowloris => "too many connections from IP", + .body_too_large => "request body too large", + .json_depth => "JSON nesting depth exceeded", + .sql_injection => "SQL injection pattern detected", + .xss => "XSS pattern detected", + .path_traversal => "path traversal attempt", + .invalid_request => "invalid request format", + .burst => "request velocity burst detected", + }; + } +}; + +// ============================================================================= +// Token Bucket (Rate Limiting) +// ============================================================================= + +/// Token bucket entry for rate limiting +/// Exactly one cache line (64 bytes) for optimal memory access patterns. +/// +/// The `packed_state` field combines tokens and timestamp for atomic CAS: +/// - High 32 bits: last_update (seconds since epoch, wraps every ~136 years) +/// - Low 32 bits: tokens (scaled by 1000 for precision) +pub const Bucket = extern struct { + /// Hash of (IP, path_pattern) - 0 means empty slot + key_hash: u64 = 0, + + /// Packed token bucket state for atomic CAS operations + /// High 32 bits: last_update timestamp, Low 32 bits: tokens + /// Use getPackedState/setPackedState for atomic access + packed_state: u64 = 0, + + /// Total requests seen for this bucket + total_requests: u64 = 0, + + /// Total requests blocked for this bucket + total_blocked: u64 = 0, + + /// Reserved for future use (alignment padding) + _reserved: [32]u8 = undefined, + + /// Get pointer to packed_state for atomic operations + pub inline fn getPackedStatePtr(self: *Bucket) *std.atomic.Value(u64) { + return @ptrCast(&self.packed_state); + } + + /// Get pointer to packed_state for atomic operations (const version) + pub inline fn getPackedStatePtrConst(self: *const Bucket) *const std.atomic.Value(u64) { + return @ptrCast(&self.packed_state); + } + + /// Get pointer to total_requests for atomic operations + pub inline fn getTotalRequestsPtr(self: *Bucket) *std.atomic.Value(u64) { + return @ptrCast(&self.total_requests); + } + + /// Get pointer to total_blocked for atomic operations + pub inline fn getTotalBlockedPtr(self: *Bucket) *std.atomic.Value(u64) { + return @ptrCast(&self.total_blocked); + } + + /// Pack tokens and timestamp into a single u64 for atomic CAS + pub inline fn packState(tokens: u32, timestamp: u32) u64 { + return (@as(u64, timestamp) << 32) | @as(u64, tokens); + } + + /// Extract tokens from packed state + pub inline fn unpackTokens(state: u64) u32 { + return @truncate(state); + } + + /// Extract timestamp from packed state + pub inline fn unpackTime(state: u64) u32 { + return @truncate(state >> 32); + } + + /// Check if this bucket slot is empty + pub inline fn isEmpty(self: *const Bucket) bool { + return self.key_hash == 0; + } + + /// Atomically try to consume tokens using CAS + /// Returns true if tokens were consumed, false if rate limited + pub fn tryConsume( + self: *Bucket, + tokens_to_consume: u32, + refill_rate: u32, + max_tokens: u32, + current_time: u32, + ) bool { + const state_ptr = self.getPackedStatePtr(); + var attempts: u32 = 0; + while (attempts < MAX_CAS_ATTEMPTS) : (attempts += 1) { + const old_state = state_ptr.load(.acquire); + const old_tokens = Bucket.unpackTokens(old_state); + const old_time = Bucket.unpackTime(old_state); + + // Calculate token refill based on elapsed time (wrap-safe) + const elapsed = current_time -% old_time; + + const refilled = @min( + @as(u64, old_tokens) + @as(u64, elapsed) * @as(u64, refill_rate), + @as(u64, max_tokens), + ); + + // Check if we have enough tokens + if (refilled < tokens_to_consume) { + return false; // Rate limited + } + + // Calculate new state + const new_tokens: u32 = @intCast(refilled - tokens_to_consume); + const new_state = Bucket.packState(new_tokens, current_time); + + // Attempt atomic update + if (state_ptr.cmpxchgWeak( + old_state, + new_state, + .acq_rel, + .acquire, + ) == null) { + // Success + return true; + } + // CAS failed, retry with new state + } + + // CAS exhausted - treat as rate limited for safety + return false; + } + + /// Get pointer to key_hash for atomic operations + pub inline fn getKeyHashPtr(self: *Bucket) *std.atomic.Value(u64) { + return @ptrCast(&self.key_hash); + } + + /// Initialize bucket with a key and full tokens + pub fn init(self: *Bucket, key: u64, max_tokens: u32, current_time: u32) void { + self.getKeyHashPtr().store(key, .release); + self.getPackedStatePtr().store(Bucket.packState(max_tokens, current_time), .release); + self.getTotalRequestsPtr().store(0, .release); + self.getTotalBlockedPtr().store(0, .release); + } + + comptime { + // Ensure bucket is exactly one cache line + std.debug.assert(@sizeOf(Bucket) == CACHE_LINE); + // Ensure proper alignment for atomics + std.debug.assert(@alignOf(Bucket) >= @alignOf(u64)); + } +}; + +// ============================================================================= +// Connection Tracker (Slowloris Prevention) +// ============================================================================= + +/// Connection entry for tracking per-IP connection counts +pub const ConnEntry = extern struct { + /// Hash of IP address - 0 means empty slot + ip_hash: u32 = 0, + + /// Current connection count (use getConnCountPtr for atomic access) + conn_count: u16 = 0, + + /// Padding for alignment + _padding: u16 = 0, + + /// Get pointer to conn_count for atomic operations + pub inline fn getConnCountPtr(self: *ConnEntry) *std.atomic.Value(u16) { + return @ptrCast(&self.conn_count); + } + + /// Get pointer to conn_count for atomic operations (const version) + pub inline fn getConnCountPtrConst(self: *const ConnEntry) *const std.atomic.Value(u16) { + return @ptrCast(&self.conn_count); + } + + /// Check if entry is empty + pub inline fn isEmpty(self: *const ConnEntry) bool { + return self.ip_hash == 0; + } + + /// Atomically increment connection count + /// Returns the new count + pub inline fn incrementConn(self: *ConnEntry) u16 { + return self.getConnCountPtr().fetchAdd(1, .acq_rel) + 1; + } + + /// Atomically decrement connection count using CAS loop + /// Returns the new count, saturates at 0 + pub inline fn decrementConn(self: *ConnEntry) u16 { + const ptr = self.getConnCountPtr(); + while (true) { + const old = ptr.load(.acquire); + if (old == 0) return 0; + if (ptr.cmpxchgWeak(old, old - 1, .acq_rel, .acquire) == null) { + return old - 1; + } + } + } + + /// Get current connection count + pub inline fn getConnCount(self: *const ConnEntry) u16 { + return self.getConnCountPtrConst().load(.acquire); + } + + comptime { + // Ensure entry is 8 bytes (fits nicely in cache) + std.debug.assert(@sizeOf(ConnEntry) == 8); + } +}; + +/// Connection tracker for slowloris prevention +/// Fixed-size hash table with open addressing +pub const ConnTracker = extern struct { + entries: [MAX_TRACKED_IPS]ConnEntry align(CACHE_LINE) = [_]ConnEntry{.{}} ** MAX_TRACKED_IPS, + + /// Find or create entry for an IP hash + /// Returns null if table is full (probe limit exceeded) + pub fn findOrCreate(self: *ConnTracker, ip_hash: u32) ?*ConnEntry { + if (ip_hash == 0) return null; // 0 is reserved for empty + + const start_idx = @as(usize, ip_hash) % MAX_TRACKED_IPS; + var idx = start_idx; + var probe_count: u32 = 0; + + while (probe_count < BUCKET_PROBE_LIMIT) : (probe_count += 1) { + const entry = &self.entries[idx]; + + // Found existing entry + if (entry.ip_hash == ip_hash) { + return entry; + } + + // Found empty slot - try to claim it atomically + if (entry.ip_hash == 0) { + // Use atomic store with release ordering to ensure visibility + const ip_hash_ptr: *std.atomic.Value(u32) = @ptrCast(&entry.ip_hash); + ip_hash_ptr.store(ip_hash, .release); + return entry; + } + + // Linear probing + idx = (idx + 1) % MAX_TRACKED_IPS; + } + + return null; // Table full or probe limit reached + } + + /// Find existing entry for an IP hash + /// Returns null if not found + pub fn find(self: *const ConnTracker, ip_hash: u32) ?*const ConnEntry { + if (ip_hash == 0) return null; + + const start_idx = @as(usize, ip_hash) % MAX_TRACKED_IPS; + var idx = start_idx; + var probe_count: u32 = 0; + + while (probe_count < BUCKET_PROBE_LIMIT) : (probe_count += 1) { + const entry = &self.entries[idx]; + + if (entry.ip_hash == ip_hash) { + return entry; + } + + if (entry.ip_hash == 0) { + return null; // Empty slot means not found + } + + idx = (idx + 1) % MAX_TRACKED_IPS; + } + + return null; + } + + comptime { + // Verify size is what we expect + std.debug.assert(@sizeOf(ConnTracker) == MAX_TRACKED_IPS * @sizeOf(ConnEntry)); + } +}; + +// ============================================================================= +// Burst Detector (Anomaly Detection) +// ============================================================================= + +/// Maximum tracked IPs for burst detection +pub const MAX_BURST_TRACKED: usize = 8192; + +/// Burst detection window in seconds +pub const BURST_WINDOW_SEC: u32 = 10; + +/// Default burst threshold multiplier (current rate > baseline * threshold = burst) +pub const BURST_THRESHOLD_MULTIPLIER: u32 = 10; + +/// Minimum baseline before burst detection activates (avoid false positives on first requests) +pub const BURST_MIN_BASELINE: u16 = 5; + +/// Entry for tracking per-IP request velocity +/// Uses exponential moving average (EMA) to establish baseline +pub const BurstEntry = extern struct { + /// Hash of IP address - 0 means empty slot + ip_hash: u32 = 0, + + /// Baseline request rate (EMA, requests per window, scaled by 16 for precision) + /// Stored as fixed-point: actual_rate = baseline_rate / 16 + baseline_rate: u16 = 0, + + /// Request count in current window + current_count: u16 = 0, + + /// Last window timestamp (seconds, wrapping) + last_window: u32 = 0, + + /// Get atomic pointer to current_count + pub inline fn getCurrentCountPtr(self: *BurstEntry) *std.atomic.Value(u16) { + return @ptrCast(&self.current_count); + } + + /// Get atomic pointer to baseline_rate + pub inline fn getBaselinePtr(self: *BurstEntry) *std.atomic.Value(u16) { + return @ptrCast(&self.baseline_rate); + } + + /// Get atomic pointer to last_window + pub inline fn getLastWindowPtr(self: *BurstEntry) *std.atomic.Value(u32) { + return @ptrCast(&self.last_window); + } + + /// Record a request and check for burst + /// Returns true if this is a burst (anomaly detected) + pub fn recordAndCheck(self: *BurstEntry, current_time: u32, threshold_mult: u32) bool { + const last = self.getLastWindowPtr().load(.acquire); + const time_diff = current_time -% last; + + // Check if we're in a new window + if (time_diff >= BURST_WINDOW_SEC) { + // Window expired - update baseline and reset count + const old_count = self.getCurrentCountPtr().swap(1, .acq_rel); + const old_baseline = self.getBaselinePtr().load(.acquire); + + // Calculate new baseline using EMA: new = old * 0.875 + current * 0.125 + // Using fixed-point: multiply count by 16, then blend + const current_scaled: u32 = @as(u32, old_count) * 16; + const new_baseline: u16 = @intCast( + (@as(u32, old_baseline) * 7 + current_scaled) / 8, + ); + + self.getBaselinePtr().store(new_baseline, .release); + self.getLastWindowPtr().store(current_time, .release); + + return false; // New window, no burst yet + } + + // Same window - increment count and check for burst + const new_count = self.getCurrentCountPtr().fetchAdd(1, .acq_rel) + 1; + const baseline = self.getBaselinePtr().load(.acquire); + + // Skip burst detection if baseline not established + if (baseline < BURST_MIN_BASELINE * 16) { + return false; + } + + // Check if current rate exceeds baseline * threshold + // baseline is scaled by 16, so: current * 16 > baseline * threshold + const current_scaled: u32 = @as(u32, new_count) * 16; + const threshold: u32 = @as(u32, baseline) * threshold_mult; + + return current_scaled > threshold; + } + + comptime { + // Ensure entry is 12 bytes + std.debug.assert(@sizeOf(BurstEntry) == 12); + } +}; + +/// Burst detector hash table +pub const BurstTracker = extern struct { + entries: [MAX_BURST_TRACKED]BurstEntry = [_]BurstEntry{.{}} ** MAX_BURST_TRACKED, + + /// Find or create entry for an IP + /// Returns null if table is full + pub fn findOrCreate(self: *BurstTracker, ip_hash: u32, current_time: u32) ?*BurstEntry { + if (ip_hash == 0) return null; + + const start_idx = ip_hash % MAX_BURST_TRACKED; + var idx = start_idx; + var probe_count: u32 = 0; + + while (probe_count < BUCKET_PROBE_LIMIT) : (probe_count += 1) { + const entry = &self.entries[idx]; + + // Found existing entry + if (entry.ip_hash == ip_hash) { + return entry; + } + + // Found empty slot - try to claim it + if (entry.ip_hash == 0) { + const ptr: *std.atomic.Value(u32) = @ptrCast(&entry.ip_hash); + if (ptr.cmpxchgStrong(0, ip_hash, .acq_rel, .acquire) == null) { + // Successfully claimed slot + entry.getLastWindowPtr().store(current_time, .release); + return entry; + } + // Someone else claimed it, check if it's ours + if (entry.ip_hash == ip_hash) { + return entry; + } + } + + idx = (idx + 1) % MAX_BURST_TRACKED; + } + + return null; // Table too full + } + + /// Check if an IP is bursting (for read-only check) + pub fn isBursting(self: *BurstTracker, ip_hash: u32, current_time: u32, threshold_mult: u32) bool { + if (ip_hash == 0) return false; + + const start_idx = ip_hash % MAX_BURST_TRACKED; + var idx = start_idx; + var probe_count: u32 = 0; + + while (probe_count < BUCKET_PROBE_LIMIT) : (probe_count += 1) { + const entry = &self.entries[idx]; + + if (entry.ip_hash == ip_hash) { + return entry.recordAndCheck(current_time, threshold_mult); + } + + if (entry.ip_hash == 0) { + return false; // Not found + } + + idx = (idx + 1) % MAX_BURST_TRACKED; + } + + return false; + } + + comptime { + std.debug.assert(@sizeOf(BurstTracker) == MAX_BURST_TRACKED * @sizeOf(BurstEntry)); + } +}; + +// ============================================================================= +// WAF Metrics +// ============================================================================= + +/// Single cache line of metrics (64 bytes = 8 x u64) +const MetricsCacheLine = extern struct { + values: [8]u64 = [_]u64{0} ** 8, +}; + +/// WAF metrics with cache-line aligned atomic counters +/// Uses cache-line sized blocks to prevent false sharing +pub const WafMetrics = extern struct { + // Primary counters (hot path) - each in its own cache line block + // Line 0: requests_allowed (index 0) + line0: MetricsCacheLine align(CACHE_LINE) = .{}, + // Line 1: requests_blocked (index 0) + line1: MetricsCacheLine = .{}, + // Line 2: requests_logged (index 0) + line2: MetricsCacheLine = .{}, + + // Block reason breakdown - grouped in one cache line + // blocked_rate_limit (0), blocked_slowloris (1), blocked_body_too_large (2), blocked_json_depth (3) + line3: MetricsCacheLine = .{}, + + // Operational metrics - grouped in one cache line + // bucket_table_usage (0), cas_exhausted (1), config_reloads (2) + line4: MetricsCacheLine = .{}, + + // Helper to get atomic pointer + inline fn atomicPtr(ptr: *u64) *std.atomic.Value(u64) { + return @ptrCast(ptr); + } + + inline fn atomicPtrConst(ptr: *const u64) *const std.atomic.Value(u64) { + return @ptrCast(ptr); + } + + /// Increment allowed counter + pub inline fn recordAllowed(self: *WafMetrics) void { + _ = atomicPtr(&self.line0.values[0]).fetchAdd(1, .monotonic); + } + + /// Increment blocked counter with reason breakdown + pub inline fn recordBlocked(self: *WafMetrics, reason: Reason) void { + _ = atomicPtr(&self.line1.values[0]).fetchAdd(1, .monotonic); + switch (reason) { + .rate_limit => _ = atomicPtr(&self.line3.values[0]).fetchAdd(1, .monotonic), + .slowloris => _ = atomicPtr(&self.line3.values[1]).fetchAdd(1, .monotonic), + .body_too_large => _ = atomicPtr(&self.line3.values[2]).fetchAdd(1, .monotonic), + .json_depth => _ = atomicPtr(&self.line3.values[3]).fetchAdd(1, .monotonic), + else => {}, + } + } + + /// Increment logged counter + pub inline fn recordLogged(self: *WafMetrics) void { + _ = atomicPtr(&self.line2.values[0]).fetchAdd(1, .monotonic); + } + + /// Record a CAS exhaustion event + pub inline fn recordCasExhausted(self: *WafMetrics) void { + _ = atomicPtr(&self.line4.values[1]).fetchAdd(1, .monotonic); + } + + /// Record a config reload + pub inline fn recordConfigReload(self: *WafMetrics) void { + _ = atomicPtr(&self.line4.values[2]).fetchAdd(1, .monotonic); + } + + /// Update bucket table usage count + pub inline fn updateBucketUsage(self: *WafMetrics, count: u64) void { + atomicPtr(&self.line4.values[0]).store(count, .monotonic); + } + + /// Get snapshot of all metrics (for reporting) + pub fn snapshot(self: *const WafMetrics) MetricsSnapshot { + return .{ + .requests_allowed = atomicPtrConst(&self.line0.values[0]).load(.monotonic), + .requests_blocked = atomicPtrConst(&self.line1.values[0]).load(.monotonic), + .requests_logged = atomicPtrConst(&self.line2.values[0]).load(.monotonic), + .blocked_rate_limit = atomicPtrConst(&self.line3.values[0]).load(.monotonic), + .blocked_slowloris = atomicPtrConst(&self.line3.values[1]).load(.monotonic), + .blocked_body_too_large = atomicPtrConst(&self.line3.values[2]).load(.monotonic), + .blocked_json_depth = atomicPtrConst(&self.line3.values[3]).load(.monotonic), + .bucket_table_usage = atomicPtrConst(&self.line4.values[0]).load(.monotonic), + .cas_exhausted = atomicPtrConst(&self.line4.values[1]).load(.monotonic), + .config_reloads = atomicPtrConst(&self.line4.values[2]).load(.monotonic), + }; + } + + comptime { + // Verify each line is cache-line sized + std.debug.assert(@sizeOf(MetricsCacheLine) == CACHE_LINE); + // Verify total size + std.debug.assert(@sizeOf(WafMetrics) == 5 * CACHE_LINE); + } +}; + +/// Non-atomic snapshot of metrics for reporting +pub const MetricsSnapshot = struct { + requests_allowed: u64, + requests_blocked: u64, + requests_logged: u64, + blocked_rate_limit: u64, + blocked_slowloris: u64, + blocked_body_too_large: u64, + blocked_json_depth: u64, + bucket_table_usage: u64, + cas_exhausted: u64, + config_reloads: u64, + + /// Calculate total requests processed + pub fn totalRequests(self: MetricsSnapshot) u64 { + return self.requests_allowed + self.requests_blocked + self.requests_logged; + } + + /// Calculate block rate as percentage (scaled by 100) + pub fn blockRatePercent(self: MetricsSnapshot) u64 { + const total = self.totalRequests(); + if (total == 0) return 0; + return (self.requests_blocked * 100) / total; + } +}; + +// ============================================================================= +// Main WAF State Structure +// ============================================================================= + +/// Calculate sizes for comptime assertions +const BUCKET_TABLE_SIZE: usize = MAX_BUCKETS * @sizeOf(Bucket); +const CONN_TRACKER_SIZE: usize = @sizeOf(ConnTracker); +const METRICS_SIZE: usize = @sizeOf(WafMetrics); +const HEADER_SIZE: usize = CACHE_LINE; // magic + config_epoch + +/// Total WAF state size (for shared memory allocation) +pub const WAF_STATE_SIZE: usize = blk: { + // Calculate with proper alignment padding + var size: usize = 0; + + // Header (magic + padding to cache line) + size += CACHE_LINE; + + // Bucket table (already cache-line aligned entries) + size += BUCKET_TABLE_SIZE; + + // ConnTracker (cache-line aligned) + size = std.mem.alignForward(usize, size, CACHE_LINE); + size += CONN_TRACKER_SIZE; + + // Metrics (cache-line aligned) + size = std.mem.alignForward(usize, size, CACHE_LINE); + size += METRICS_SIZE; + + // Config epoch (cache-line aligned) + size = std.mem.alignForward(usize, size, CACHE_LINE); + size += CACHE_LINE; + + break :blk size; +}; + +/// Main WAF shared state structure +/// All fields are cache-line aligned to prevent false sharing across CPU cores. +pub const WafState = extern struct { + /// Magic number for corruption detection + magic: u64 align(CACHE_LINE) = WAF_STATE_MAGIC, + + /// Version number for compatibility + version: u32 = 1, + + /// Reserved padding to fill first cache line + _header_padding: [52]u8 = undefined, + + /// Token bucket table for rate limiting (fixed-size, open addressing) + buckets: [MAX_BUCKETS]Bucket align(CACHE_LINE) = [_]Bucket{.{}} ** MAX_BUCKETS, + + /// Connection tracker for slowloris detection + conn_tracker: ConnTracker align(CACHE_LINE) = .{}, + + /// Burst detector for anomaly detection + burst_tracker: BurstTracker align(CACHE_LINE) = .{}, + + /// Global metrics with atomic counters + metrics: WafMetrics align(CACHE_LINE) = .{}, + + /// Configuration epoch for hot-reload detection + /// Increment this when WAF config changes; workers can detect stale config + config_epoch: u64 align(CACHE_LINE) = 0, + + /// Padding to ensure config_epoch has its own cache line + _epoch_padding: [56]u8 = undefined, + + // ========================================================================= + // Initialization and Validation + // ========================================================================= + + /// Get atomic pointer to config_epoch + inline fn getConfigEpochPtr(self: *WafState) *std.atomic.Value(u64) { + return @ptrCast(&self.config_epoch); + } + + /// Get atomic pointer to config_epoch (const version) + inline fn getConfigEpochPtrConst(self: *const WafState) *const std.atomic.Value(u64) { + return @ptrCast(&self.config_epoch); + } + + /// Initialize WAF state with magic number and zeroed fields + pub fn init() WafState { + return .{}; + } + + /// Validate the WAF state structure (check magic for corruption) + pub fn validate(self: *const WafState) bool { + return self.magic == WAF_STATE_MAGIC and self.version == 1; + } + + /// Get current config epoch + pub inline fn getConfigEpoch(self: *const WafState) u64 { + return self.getConfigEpochPtrConst().load(.acquire); + } + + /// Increment config epoch (call after hot-reloading config) + pub inline fn incrementConfigEpoch(self: *WafState) u64 { + const new_epoch = self.getConfigEpochPtr().fetchAdd(1, .acq_rel) + 1; + self.metrics.recordConfigReload(); + return new_epoch; + } + + // ========================================================================= + // Bucket Table Operations + // ========================================================================= + + /// Find or create a bucket for the given key hash + /// Uses open addressing with linear probing + /// Returns null if probe limit exceeded (table too full) + pub fn findOrCreateBucket(self: *WafState, key_hash: u64, current_time: u32) ?*Bucket { + if (key_hash == 0) return null; // 0 is reserved + + const start_idx = @as(usize, @truncate(key_hash)) % MAX_BUCKETS; + var idx = start_idx; + var probe_count: u32 = 0; + + while (probe_count < BUCKET_PROBE_LIMIT) : (probe_count += 1) { + const bucket = &self.buckets[idx]; + + // Found existing bucket + if (bucket.key_hash == key_hash) { + return bucket; + } + + // Found empty slot + if (bucket.key_hash == 0) { + // Initialize with full tokens + bucket.init(key_hash, MAX_TOKENS, current_time); + return bucket; + } + + // Linear probing + idx = (idx + 1) % MAX_BUCKETS; + } + + return null; // Probe limit exceeded + } + + /// Find an existing bucket (does not create) + pub fn findBucket(self: *const WafState, key_hash: u64) ?*const Bucket { + if (key_hash == 0) return null; + + const start_idx = @as(usize, @truncate(key_hash)) % MAX_BUCKETS; + var idx = start_idx; + var probe_count: u32 = 0; + + while (probe_count < BUCKET_PROBE_LIMIT) : (probe_count += 1) { + const bucket = &self.buckets[idx]; + + if (bucket.key_hash == key_hash) { + return bucket; + } + + if (bucket.key_hash == 0) { + return null; + } + + idx = (idx + 1) % MAX_BUCKETS; + } + + return null; + } + + /// Count non-empty buckets (for metrics) + pub fn countBuckets(self: *const WafState) u64 { + var count: u64 = 0; + for (&self.buckets) |*bucket| { + if (!bucket.isEmpty()) count += 1; + } + return count; + } + + // ========================================================================= + // Burst Detection Operations + // ========================================================================= + + /// Check if an IP is exhibiting burst behavior (sudden velocity spike) + /// Returns true if current request rate is significantly above baseline + pub fn checkBurst(self: *WafState, ip_hash: u32, current_time: u32, threshold_mult: u32) bool { + if (self.burst_tracker.findOrCreate(ip_hash, current_time)) |entry| { + return entry.recordAndCheck(current_time, threshold_mult); + } + return false; // Table full, fail open + } + + // ========================================================================= + // Comptime Assertions + // ========================================================================= + + comptime { + // Verify magic is at offset 0 and cache-line aligned + std.debug.assert(@offsetOf(WafState, "magic") == 0); + + // Verify all major sections are cache-line aligned + std.debug.assert(@offsetOf(WafState, "buckets") % CACHE_LINE == 0); + std.debug.assert(@offsetOf(WafState, "conn_tracker") % CACHE_LINE == 0); + std.debug.assert(@offsetOf(WafState, "burst_tracker") % CACHE_LINE == 0); + std.debug.assert(@offsetOf(WafState, "metrics") % CACHE_LINE == 0); + std.debug.assert(@offsetOf(WafState, "config_epoch") % CACHE_LINE == 0); + + // Verify struct alignment + std.debug.assert(@alignOf(WafState) >= CACHE_LINE); + } +}; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Pack tokens and timestamp into a single u64 for atomic CAS +/// Exported for use by external code +pub inline fn packState(tokens: u32, timestamp: u32) u64 { + return Bucket.packState(tokens, timestamp); +} + +/// Extract tokens from packed state +pub inline fn unpackTokens(state: u64) u32 { + return Bucket.unpackTokens(state); +} + +/// Extract timestamp from packed state +pub inline fn unpackTime(state: u64) u32 { + return Bucket.unpackTime(state); +} + +/// Compute hash for rate limiting key (IP + path pattern) +/// Uses FNV-1a for speed and good distribution +pub fn computeKeyHash(ip_bytes: []const u8, path: []const u8) u64 { + var hash: u64 = 0xcbf29ce484222325; // FNV offset basis + + for (ip_bytes) |b| { + hash ^= b; + hash *%= 0x100000001b3; // FNV prime + } + + // Separator to avoid collisions between IP and path + hash ^= 0xff; + hash *%= 0x100000001b3; + + for (path) |b| { + hash ^= b; + hash *%= 0x100000001b3; + } + + // Ensure non-zero (0 is reserved for empty slots) + return if (hash == 0) 1 else hash; +} + +/// Compute hash for IP address (connection tracking) +pub fn computeIpHash(ip_bytes: []const u8) u32 { + var hash: u32 = 0x811c9dc5; // FNV-1a 32-bit offset basis + + for (ip_bytes) |b| { + hash ^= b; + hash *%= 0x01000193; // FNV-1a 32-bit prime + } + + // Ensure non-zero + return if (hash == 0) 1 else hash; +} + +// ============================================================================= +// Tests +// ============================================================================= + +test "Bucket: size and alignment" { + try std.testing.expectEqual(@as(usize, 64), @sizeOf(Bucket)); + try std.testing.expect(@alignOf(Bucket) >= 8); +} + +test "Bucket: pack/unpack state" { + const tokens: u32 = 5000; + const timestamp: u32 = 1703548800; // 2023-12-26 00:00:00 UTC + + const pack_val = packState(tokens, timestamp); + try std.testing.expectEqual(tokens, unpackTokens(pack_val)); + try std.testing.expectEqual(timestamp, unpackTime(pack_val)); +} + +test "Bucket: tryConsume" { + var bucket = Bucket{}; + bucket.init(0x12345678, MAX_TOKENS, 1000); + + // Should succeed - we have full tokens + try std.testing.expect(bucket.tryConsume(1000, 100, MAX_TOKENS, 1000)); + + // Check remaining tokens + const bucket_state = bucket.getPackedStatePtrConst().load(.acquire); + try std.testing.expectEqual(@as(u32, 9000), unpackTokens(bucket_state)); + + // Consume more + try std.testing.expect(bucket.tryConsume(9000, 100, MAX_TOKENS, 1000)); + + // Should fail - no tokens left + try std.testing.expect(!bucket.tryConsume(1000, 100, MAX_TOKENS, 1000)); + + // Wait for refill (time advances by 10 seconds, refill rate = 100/sec) + try std.testing.expect(bucket.tryConsume(500, 100, MAX_TOKENS, 1010)); +} + +test "ConnEntry: increment/decrement" { + var entry = ConnEntry{}; + entry.ip_hash = 0x12345678; + + try std.testing.expectEqual(@as(u16, 0), entry.getConnCount()); + + // Increment + try std.testing.expectEqual(@as(u16, 1), entry.incrementConn()); + try std.testing.expectEqual(@as(u16, 2), entry.incrementConn()); + try std.testing.expectEqual(@as(u16, 2), entry.getConnCount()); + + // Decrement + try std.testing.expectEqual(@as(u16, 1), entry.decrementConn()); + try std.testing.expectEqual(@as(u16, 0), entry.decrementConn()); + + // Underflow protection + try std.testing.expectEqual(@as(u16, 0), entry.decrementConn()); +} + +test "ConnTracker: find and create" { + var tracker = ConnTracker{}; + + // Find non-existent + try std.testing.expect(tracker.find(0x12345678) == null); + + // Create + const entry = tracker.findOrCreate(0x12345678); + try std.testing.expect(entry != null); + try std.testing.expectEqual(@as(u32, 0x12345678), entry.?.ip_hash); + + // Find existing + const found = tracker.find(0x12345678); + try std.testing.expect(found != null); + try std.testing.expectEqual(@as(u32, 0x12345678), found.?.ip_hash); + + // Find same entry again + const entry2 = tracker.findOrCreate(0x12345678); + try std.testing.expect(entry2 != null); + try std.testing.expectEqual(entry.?, entry2.?); +} + +test "WafMetrics: record and snapshot" { + var metrics = WafMetrics{}; + + metrics.recordAllowed(); + metrics.recordAllowed(); + metrics.recordBlocked(.rate_limit); + metrics.recordLogged(); + + const snap = metrics.snapshot(); + try std.testing.expectEqual(@as(u64, 2), snap.requests_allowed); + try std.testing.expectEqual(@as(u64, 1), snap.requests_blocked); + try std.testing.expectEqual(@as(u64, 1), snap.requests_logged); + try std.testing.expectEqual(@as(u64, 1), snap.blocked_rate_limit); + try std.testing.expectEqual(@as(u64, 4), snap.totalRequests()); + try std.testing.expectEqual(@as(u64, 25), snap.blockRatePercent()); +} + +test "WafState: init and validate" { + const state = WafState.init(); + try std.testing.expect(state.validate()); + try std.testing.expectEqual(WAF_STATE_MAGIC, state.magic); + try std.testing.expectEqual(@as(u32, 1), state.version); +} + +test "WafState: config epoch" { + var state = WafState.init(); + + try std.testing.expectEqual(@as(u64, 0), state.getConfigEpoch()); + + const epoch1 = state.incrementConfigEpoch(); + try std.testing.expectEqual(@as(u64, 1), epoch1); + try std.testing.expectEqual(@as(u64, 1), state.getConfigEpoch()); + + const epoch2 = state.incrementConfigEpoch(); + try std.testing.expectEqual(@as(u64, 2), epoch2); +} + +test "WafState: bucket operations" { + var state = WafState.init(); + const current_time: u32 = 1000; + + // Find or create bucket + const bucket = state.findOrCreateBucket(0xDEADBEEF, current_time); + try std.testing.expect(bucket != null); + try std.testing.expectEqual(@as(u64, 0xDEADBEEF), bucket.?.key_hash); + + // Find existing bucket + const found = state.findBucket(0xDEADBEEF); + try std.testing.expect(found != null); + try std.testing.expectEqual(@as(u64, 0xDEADBEEF), found.?.key_hash); + + // Find non-existent + try std.testing.expect(state.findBucket(0xCAFEBABE) == null); +} + +test "Decision: enum operations" { + const allow = Decision.allow; + const block = Decision.block; + const log_only = Decision.log_only; + + try std.testing.expect(!allow.isBlocked()); + try std.testing.expect(block.isBlocked()); + try std.testing.expect(!log_only.isBlocked()); + + try std.testing.expect(!allow.shouldLog()); + try std.testing.expect(block.shouldLog()); + try std.testing.expect(log_only.shouldLog()); +} + +test "Reason: descriptions" { + try std.testing.expectEqualStrings("rate limit exceeded", Reason.rate_limit.description()); + try std.testing.expectEqualStrings("too many connections from IP", Reason.slowloris.description()); +} + +test "computeKeyHash: basic" { + const ip = [_]u8{ 192, 168, 1, 1 }; + const path = "/api/users"; + + const hash1 = computeKeyHash(&ip, path); + const hash2 = computeKeyHash(&ip, path); + + // Same input should produce same hash + try std.testing.expectEqual(hash1, hash2); + + // Different path should produce different hash + const hash3 = computeKeyHash(&ip, "/api/posts"); + try std.testing.expect(hash1 != hash3); + + // Hash should never be 0 + try std.testing.expect(hash1 != 0); +} + +test "computeIpHash: basic" { + const ip1 = [_]u8{ 192, 168, 1, 1 }; + const ip2 = [_]u8{ 192, 168, 1, 2 }; + + const hash1 = computeIpHash(&ip1); + const hash2 = computeIpHash(&ip2); + + try std.testing.expect(hash1 != hash2); + try std.testing.expect(hash1 != 0); + try std.testing.expect(hash2 != 0); +} + +test "alignment: all structures properly aligned" { + // Bucket must be cache-line sized + try std.testing.expectEqual(@as(usize, CACHE_LINE), @sizeOf(Bucket)); + + // ConnEntry should be 8 bytes + try std.testing.expectEqual(@as(usize, 8), @sizeOf(ConnEntry)); + + // WafState sections must be cache-line aligned + try std.testing.expect(@offsetOf(WafState, "buckets") % CACHE_LINE == 0); + try std.testing.expect(@offsetOf(WafState, "conn_tracker") % CACHE_LINE == 0); + try std.testing.expect(@offsetOf(WafState, "metrics") % CACHE_LINE == 0); + try std.testing.expect(@offsetOf(WafState, "config_epoch") % CACHE_LINE == 0); +} + +test "BurstEntry: recordAndCheck detects velocity spike" { + var entry = BurstEntry{ + .ip_hash = 0x12345678, + .baseline_rate = 0, + .current_count = 0, + .last_window = 0, + }; + + const threshold: u32 = 3; // Current rate must be > baseline * 3 to trigger + var time: u32 = 1000; + + // Establish baseline over several windows with 20 requests each + // This builds up baseline above BURST_MIN_BASELINE (5 * 16 = 80) + for (0..5) |_| { + for (0..20) |_| { + _ = entry.recordAndCheck(time, threshold); + } + time += BURST_WINDOW_SEC; + } + + // Baseline should now be ~20 requests/window (scaled by 16 = ~320) + // which is above BURST_MIN_BASELINE * 16 = 80 + + // Now simulate a burst: 200 requests in one window + // This should trigger because 200 * 16 = 3200 > 320 * 3 = 960 + var burst_detected = false; + for (0..200) |_| { + if (entry.recordAndCheck(time, threshold)) { + burst_detected = true; + break; + } + } + + try std.testing.expect(burst_detected); +} + +test "BurstEntry: no burst for steady traffic" { + var entry = BurstEntry{ + .ip_hash = 0x12345678, + .baseline_rate = 0, + .current_count = 0, + .last_window = 0, + }; + + const threshold: u32 = 10; + const base_time: u32 = 1000; + + // First window - establish baseline (50 requests) + for (0..50) |_| { + _ = entry.recordAndCheck(base_time, threshold); + } + + // Move to next window + const window2_time = base_time + BURST_WINDOW_SEC; + _ = entry.recordAndCheck(window2_time, threshold); + + // Maintain similar rate - no burst + var burst_detected = false; + for (0..60) |_| { + if (entry.recordAndCheck(window2_time, threshold)) { + burst_detected = true; + } + } + + // Should NOT detect burst (60 is not >> 50 * 10) + try std.testing.expect(!burst_detected); +} + +test "BurstTracker: findOrCreate" { + var tracker = BurstTracker{}; + const current_time: u32 = 1000; + + // Find or create entry + const entry1 = tracker.findOrCreate(0x12345678, current_time); + try std.testing.expect(entry1 != null); + try std.testing.expectEqual(@as(u32, 0x12345678), entry1.?.ip_hash); + + // Same hash should return same entry + const entry2 = tracker.findOrCreate(0x12345678, current_time); + try std.testing.expectEqual(entry1, entry2); + + // Different hash should return different entry + const entry3 = tracker.findOrCreate(0xDEADBEEF, current_time); + try std.testing.expect(entry3 != null); + try std.testing.expect(entry1 != entry3); +} + +test "WafState: checkBurst integration" { + var waf_state = WafState.init(); + + const ip_hash: u32 = 0xCAFEBABE; + const threshold: u32 = 3; + var time: u32 = 1000; + + // Establish baseline over several windows with 30 requests each + for (0..5) |_| { + for (0..30) |_| { + // Should not detect burst while establishing baseline + const result = waf_state.checkBurst(ip_hash, time, threshold); + _ = result; + } + time += BURST_WINDOW_SEC; + } + + // Now burst: 300 requests in one window + // Baseline ~30 req/window * 16 = 480, threshold 480 * 3 = 1440 + // 300 * 16 = 4800 > 1440, should trigger + var burst_detected = false; + for (0..300) |_| { + if (waf_state.checkBurst(ip_hash, time, threshold)) { + burst_detected = true; + break; + } + } + try std.testing.expect(burst_detected); +} diff --git a/src/waf/validator.zig b/src/waf/validator.zig new file mode 100644 index 0000000..98e614e --- /dev/null +++ b/src/waf/validator.zig @@ -0,0 +1,572 @@ +/// Request Validation - Size and Structure Limits +/// +/// High-performance request validation for the Web Application Firewall. +/// Validates URI length, body size, cookie size, and JSON structure. +/// +/// Design Philosophy (TigerBeetle-inspired): +/// - Zero allocation on hot path +/// - Streaming JSON validation (constant memory) +/// - Early rejection of invalid requests +/// - Pre-body validation for fast fail +/// +/// Validation Flow: +/// 1. validateRequest() - Check URI, Content-Length, headers before body +/// 2. validateJsonStream() - Incremental JSON validation during body receipt +/// +/// Memory Characteristics: +/// - ValidatorConfig: ~20 bytes (inline configuration) +/// - JsonState: 4 bytes (streaming state) +/// - No heap allocation during validation +const std = @import("std"); + +const state = @import("state.zig"); +pub const Reason = state.Reason; + +// ============================================================================= +// Validation Result +// ============================================================================= + +/// Result of a validation check +pub const ValidationResult = struct { + /// Whether the request passed validation + valid: bool, + /// Reason for rejection (null if valid) + reason: ?Reason, + + /// Create a passing result + pub inline fn pass() ValidationResult { + return .{ .valid = true, .reason = null }; + } + + /// Create a failing result with reason + pub inline fn fail(reason: Reason) ValidationResult { + return .{ .valid = false, .reason = reason }; + } + + /// Check if validation passed + pub inline fn isValid(self: ValidationResult) bool { + return self.valid; + } + + /// Check if validation failed + pub inline fn isInvalid(self: ValidationResult) bool { + return !self.valid; + } +}; + +// ============================================================================= +// Validator Configuration +// ============================================================================= + +/// Configuration for request validation limits +/// Mirrors RequestLimitsConfig from config.zig with additional fields +/// for comprehensive validation +pub const ValidatorConfig = struct { + /// Maximum URI length in bytes (default 2KB) + max_uri_length: u32 = 2048, + + /// Maximum number of query parameters + max_query_params: u16 = 50, + + /// Maximum header value length (default 8KB) + max_header_value_length: u32 = 8192, + + /// Maximum cookie size in bytes (default 4KB) + max_cookie_size: u32 = 4096, + + /// Maximum request body size in bytes (default 1MB) + max_body_size: u32 = 1_048_576, + + /// Maximum JSON nesting depth + max_json_depth: u8 = 20, + + /// Maximum JSON keys (protects against hash collision attacks) + max_json_keys: u16 = 1000, + + /// Create configuration from RequestLimitsConfig + pub fn fromRequestLimits(limits: anytype) ValidatorConfig { + return .{ + .max_uri_length = limits.max_uri_length, + .max_body_size = limits.max_body_size, + .max_json_depth = limits.max_json_depth, + }; + } + + comptime { + // Ensure config fits in a cache line + std.debug.assert(@sizeOf(ValidatorConfig) <= 64); + } +}; + +// ============================================================================= +// JSON Streaming State +// ============================================================================= + +/// State for streaming JSON validation +/// Tracks nesting depth and key count with constant memory usage +/// Properly handles string escapes to avoid false positives +pub const JsonState = struct { + /// Current nesting depth (objects and arrays) + depth: u8 = 0, + + /// Total keys seen (colons outside strings) + key_count: u16 = 0, + + /// Currently inside a string literal + in_string: bool = false, + + /// Next character is escaped (preceded by backslash) + escape_next: bool = false, + + /// Reset state for new request + pub inline fn reset(self: *JsonState) void { + self.* = .{}; + } + + /// Check if parsing is complete (all brackets closed) + pub inline fn isComplete(self: *const JsonState) bool { + return self.depth == 0 and !self.in_string; + } + + comptime { + // Ensure state is minimal (6 bytes with alignment) + std.debug.assert(@sizeOf(JsonState) <= 8); + } +}; + +// ============================================================================= +// Request Validator +// ============================================================================= + +/// Validates HTTP requests against configured limits +/// Zero allocation, suitable for hot path +pub const RequestValidator = struct { + /// Configuration reference (immutable during request processing) + config: *const ValidatorConfig, + + /// Initialize validator with configuration + pub fn init(config: *const ValidatorConfig) RequestValidator { + return .{ .config = config }; + } + + // ========================================================================= + // Pre-Body Validation + // ========================================================================= + + /// Validate request before reading body + /// Fast path for rejecting obviously invalid requests + /// + /// Checks: + /// - URI length + /// - Content-Length vs max body size + /// - Cookie size + /// + /// Parameters: + /// - uri: Request URI (path + query string) + /// - content_length: Value of Content-Length header (null if not present) + /// - headers: Iterator or slice of header name-value pairs + /// + /// Returns: ValidationResult with pass/fail and reason + pub fn validateRequest( + self: *const RequestValidator, + uri: ?[]const u8, + content_length: ?u32, + headers: anytype, + ) ValidationResult { + // Check URI length + if (uri) |u| { + if (u.len > self.config.max_uri_length) { + return ValidationResult.fail(.invalid_request); + } + + // Count query parameters if present + if (std.mem.indexOf(u8, u, "?")) |query_start| { + const query = u[query_start + 1 ..]; + const param_count = countQueryParams(query); + if (param_count > self.config.max_query_params) { + return ValidationResult.fail(.invalid_request); + } + } + } + + // Check Content-Length against max body size + if (content_length) |len| { + if (len > self.config.max_body_size) { + return ValidationResult.fail(.body_too_large); + } + } + + // Check headers for cookie size + return self.validateHeaders(headers); + } + + /// Validate specific headers + /// Separate method for when you only have headers to check + fn validateHeaders(self: *const RequestValidator, headers: anytype) ValidationResult { + const HeadersType = @TypeOf(headers); + const type_info = @typeInfo(HeadersType); + + // Handle different header representations + switch (type_info) { + .pointer => |ptr| { + // Slice of header pairs + if (ptr.size == .Slice) { + for (headers) |header| { + const result = self.checkHeader(header[0], header[1]); + if (result.isInvalid()) return result; + } + } + }, + .@"struct" => { + // Iterator-like struct with next() method + if (@hasDecl(HeadersType, "next")) { + var iter = headers; + while (iter.next()) |header| { + const result = self.checkHeader(header.name, header.value); + if (result.isInvalid()) return result; + } + } + }, + .null => { + // No headers to check + }, + else => { + // Unsupported type - skip header validation + }, + } + + return ValidationResult.pass(); + } + + /// Check a single header against limits + fn checkHeader(self: *const RequestValidator, name: []const u8, value: []const u8) ValidationResult { + // Check header value length + if (value.len > self.config.max_header_value_length) { + return ValidationResult.fail(.invalid_request); + } + + // Check cookie size specifically + if (std.ascii.eqlIgnoreCase(name, "cookie")) { + if (value.len > self.config.max_cookie_size) { + return ValidationResult.fail(.invalid_request); + } + } + + return ValidationResult.pass(); + } + + // ========================================================================= + // Streaming JSON Validation + // ========================================================================= + + /// Validate a chunk of JSON data + /// Call repeatedly as body chunks arrive + /// Maintains state between calls for streaming validation + /// + /// Features: + /// - Constant memory (O(1) space) + /// - Proper string escape handling + /// - Depth limiting (prevents stack exhaustion attacks) + /// - Key counting (prevents hash collision attacks) + /// + /// Parameters: + /// - chunk: Bytes of JSON data + /// - json_state: Mutable state tracking depth and keys + /// + /// Returns: ValidationResult - pass to continue, fail to reject + pub fn validateJsonStream( + self: *const RequestValidator, + chunk: []const u8, + json_state: *JsonState, + ) ValidationResult { + for (chunk) |byte| { + // Handle escape sequences in strings + if (json_state.escape_next) { + json_state.escape_next = false; + continue; + } + + // Handle string state + if (json_state.in_string) { + switch (byte) { + '\\' => json_state.escape_next = true, + '"' => json_state.in_string = false, + else => {}, + } + continue; + } + + // Not in string - check structural characters + switch (byte) { + '"' => json_state.in_string = true, + + '{', '[' => { + // Use saturating add to prevent overflow + json_state.depth +|= 1; + + if (json_state.depth > self.config.max_json_depth) { + return ValidationResult.fail(.json_depth); + } + }, + + '}', ']' => { + // Use saturating subtract to prevent underflow + json_state.depth -|= 1; + }, + + ':' => { + // Colon indicates a key-value pair in object + json_state.key_count +|= 1; + + if (json_state.key_count > self.config.max_json_keys) { + return ValidationResult.fail(.json_depth); + } + }, + + else => {}, + } + } + + return ValidationResult.pass(); + } + + /// Validate complete JSON body (non-streaming) + /// Convenience method for when entire body is available + pub fn validateJsonBody(self: *const RequestValidator, body: []const u8) ValidationResult { + var json_state = JsonState{}; + return self.validateJsonStream(body, &json_state); + } +}; + +// ============================================================================= +// Helper Functions +// ============================================================================= + +/// Count query parameters in a query string +/// Parameters are separated by & and contain key=value or just key +fn countQueryParams(query: []const u8) u16 { + if (query.len == 0) return 0; + + var count: u16 = 1; // At least one param if query exists + for (query) |c| { + if (c == '&') { + count +|= 1; + } + } + return count; +} + +// ============================================================================= +// Tests +// ============================================================================= + +test "ValidationResult: pass and fail" { + const pass_result = ValidationResult.pass(); + try std.testing.expect(pass_result.isValid()); + try std.testing.expect(!pass_result.isInvalid()); + try std.testing.expect(pass_result.reason == null); + + const fail_result = ValidationResult.fail(.body_too_large); + try std.testing.expect(!fail_result.isValid()); + try std.testing.expect(fail_result.isInvalid()); + try std.testing.expectEqual(Reason.body_too_large, fail_result.reason.?); +} + +test "ValidatorConfig: default values" { + const config = ValidatorConfig{}; + try std.testing.expectEqual(@as(u32, 2048), config.max_uri_length); + try std.testing.expectEqual(@as(u32, 4096), config.max_cookie_size); + try std.testing.expectEqual(@as(u32, 1_048_576), config.max_body_size); + try std.testing.expectEqual(@as(u8, 20), config.max_json_depth); + try std.testing.expectEqual(@as(u16, 1000), config.max_json_keys); +} + +test "JsonState: size and reset" { + var json_state = JsonState{ .depth = 5, .key_count = 100, .in_string = true }; + try std.testing.expect(@sizeOf(JsonState) <= 8); + + json_state.reset(); + try std.testing.expectEqual(@as(u8, 0), json_state.depth); + try std.testing.expectEqual(@as(u16, 0), json_state.key_count); + try std.testing.expect(!json_state.in_string); +} + +test "JsonState: isComplete" { + var json_state = JsonState{}; + try std.testing.expect(json_state.isComplete()); + + json_state.depth = 1; + try std.testing.expect(!json_state.isComplete()); + + json_state.depth = 0; + json_state.in_string = true; + try std.testing.expect(!json_state.isComplete()); +} + +test "RequestValidator: init" { + const config = ValidatorConfig{}; + const validator = RequestValidator.init(&config); + try std.testing.expectEqual(&config, validator.config); +} + +test "RequestValidator: validateRequest passes valid URI" { + const config = ValidatorConfig{ .max_uri_length = 100 }; + const validator = RequestValidator.init(&config); + + const result = validator.validateRequest("/api/users", null, null); + try std.testing.expect(result.isValid()); +} + +test "RequestValidator: validateRequest rejects long URI" { + const config = ValidatorConfig{ .max_uri_length = 10 }; + const validator = RequestValidator.init(&config); + + const result = validator.validateRequest("/api/users/very/long/path", null, null); + try std.testing.expect(result.isInvalid()); + try std.testing.expectEqual(Reason.invalid_request, result.reason.?); +} + +test "RequestValidator: validateRequest rejects large body" { + const config = ValidatorConfig{ .max_body_size = 1000 }; + const validator = RequestValidator.init(&config); + + const result = validator.validateRequest("/api/upload", 5000, null); + try std.testing.expect(result.isInvalid()); + try std.testing.expectEqual(Reason.body_too_large, result.reason.?); +} + +test "RequestValidator: validateRequest allows valid body size" { + const config = ValidatorConfig{ .max_body_size = 10000 }; + const validator = RequestValidator.init(&config); + + const result = validator.validateRequest("/api/upload", 5000, null); + try std.testing.expect(result.isValid()); +} + +test "RequestValidator: validateJsonStream passes simple JSON" { + const config = ValidatorConfig{}; + const validator = RequestValidator.init(&config); + + var json_state = JsonState{}; + const json = "{\"name\": \"test\", \"value\": 123}"; + + const result = validator.validateJsonStream(json, &json_state); + try std.testing.expect(result.isValid()); + try std.testing.expectEqual(@as(u8, 0), json_state.depth); // All closed + try std.testing.expectEqual(@as(u16, 2), json_state.key_count); // Two keys +} + +test "RequestValidator: validateJsonStream rejects deep nesting" { + const config = ValidatorConfig{ .max_json_depth = 3 }; + const validator = RequestValidator.init(&config); + + var json_state = JsonState{}; + const json = "{{{{"; // 4 levels of nesting + + const result = validator.validateJsonStream(json, &json_state); + try std.testing.expect(result.isInvalid()); + try std.testing.expectEqual(Reason.json_depth, result.reason.?); +} + +test "RequestValidator: validateJsonStream rejects too many keys" { + const config = ValidatorConfig{ .max_json_keys = 3 }; + const validator = RequestValidator.init(&config); + + var json_state = JsonState{}; + const json = "{\"a\":1,\"b\":2,\"c\":3,\"d\":4}"; // 4 keys + + const result = validator.validateJsonStream(json, &json_state); + try std.testing.expect(result.isInvalid()); + try std.testing.expectEqual(Reason.json_depth, result.reason.?); +} + +test "RequestValidator: validateJsonStream handles strings correctly" { + const config = ValidatorConfig{ .max_json_depth = 2 }; + const validator = RequestValidator.init(&config); + + var json_state = JsonState{}; + // Braces inside strings should be ignored + const json = "{\"data\": \"{{{{{{{\"}"; + + const result = validator.validateJsonStream(json, &json_state); + try std.testing.expect(result.isValid()); + try std.testing.expectEqual(@as(u8, 0), json_state.depth); +} + +test "RequestValidator: validateJsonStream handles escapes" { + const config = ValidatorConfig{}; + const validator = RequestValidator.init(&config); + + var json_state = JsonState{}; + // Escaped quote should not end string + const json = "{\"data\": \"test\\\"quote\"}"; + + const result = validator.validateJsonStream(json, &json_state); + try std.testing.expect(result.isValid()); + try std.testing.expect(!json_state.in_string); +} + +test "RequestValidator: validateJsonStream chunked" { + const config = ValidatorConfig{}; + const validator = RequestValidator.init(&config); + + var json_state = JsonState{}; + + // Send JSON in chunks + var result = validator.validateJsonStream("{\"na", &json_state); + try std.testing.expect(result.isValid()); + try std.testing.expectEqual(@as(u8, 1), json_state.depth); + + result = validator.validateJsonStream("me\":\"te", &json_state); + try std.testing.expect(result.isValid()); + try std.testing.expect(json_state.in_string); + + result = validator.validateJsonStream("st\"}", &json_state); + try std.testing.expect(result.isValid()); + try std.testing.expectEqual(@as(u8, 0), json_state.depth); + try std.testing.expect(!json_state.in_string); +} + +test "RequestValidator: validateJsonBody convenience" { + const config = ValidatorConfig{}; + const validator = RequestValidator.init(&config); + + const result = validator.validateJsonBody("{\"key\": [1, 2, 3]}"); + try std.testing.expect(result.isValid()); +} + +test "RequestValidator: validateJsonStream handles nested arrays" { + const config = ValidatorConfig{ .max_json_depth = 5 }; + const validator = RequestValidator.init(&config); + + var json_state = JsonState{}; + const json = "{\"arr\": [[1, 2], [3, 4]]}"; + + const result = validator.validateJsonStream(json, &json_state); + try std.testing.expect(result.isValid()); + try std.testing.expectEqual(@as(u8, 0), json_state.depth); +} + +test "countQueryParams: basic" { + try std.testing.expectEqual(@as(u16, 0), countQueryParams("")); + try std.testing.expectEqual(@as(u16, 1), countQueryParams("a=1")); + try std.testing.expectEqual(@as(u16, 2), countQueryParams("a=1&b=2")); + try std.testing.expectEqual(@as(u16, 3), countQueryParams("a=1&b=2&c=3")); +} + +test "RequestValidator: validateRequest rejects too many query params" { + const config = ValidatorConfig{ .max_query_params = 2 }; + const validator = RequestValidator.init(&config); + + const result = validator.validateRequest("/api?a=1&b=2&c=3", null, null); + try std.testing.expect(result.isInvalid()); + try std.testing.expectEqual(Reason.invalid_request, result.reason.?); +} + +test "RequestValidator: validateRequest allows valid query params" { + const config = ValidatorConfig{ .max_query_params = 5 }; + const validator = RequestValidator.init(&config); + + const result = validator.validateRequest("/api?a=1&b=2", null, null); + try std.testing.expect(result.isValid()); +} diff --git a/tests/__pycache__/h2_backend.cpython-314.pyc b/tests/__pycache__/h2_backend.cpython-314.pyc index b54c7a3..99dae5f 100644 Binary files a/tests/__pycache__/h2_backend.cpython-314.pyc and b/tests/__pycache__/h2_backend.cpython-314.pyc differ diff --git a/tests/fixtures/mock_otlp_collector.zig b/tests/fixtures/mock_otlp_collector.zig new file mode 100644 index 0000000..303e21f --- /dev/null +++ b/tests/fixtures/mock_otlp_collector.zig @@ -0,0 +1,226 @@ +/// Mock OTLP Collector - Receives and stores traces for integration tests +/// +/// Provides: +/// - POST /v1/traces - Receive OTLP traces (protobuf) +/// - GET /traces - Retrieve stored traces as JSON +/// - DELETE /traces - Clear stored traces +/// +/// This allows integration tests to verify that telemetry is correctly exported. +const std = @import("std"); +const log = std.log.scoped(.mock_otlp); + +const zzz = @import("zzz"); +const http = zzz.HTTP; + +const Io = std.Io; +const Server = http.Server; +const Router = http.Router; +const Context = http.Context; +const Route = http.Route; +const Respond = http.Respond; + +/// Stored trace data for test verification +const StoredTrace = struct { + sequence: usize, + body_size: usize, + /// Raw protobuf body (first 1KB for debugging) + body_preview: []const u8, +}; + +var stored_traces: std.ArrayListUnmanaged(StoredTrace) = .empty; +var traces_mutex: std.Thread.Mutex = .{}; +var trace_allocator: std.mem.Allocator = undefined; +var trace_sequence: usize = 0; + +/// Handle incoming OTLP traces +fn handleOtlpTraces(ctx: *const Context, _: void) !Respond { + const body = ctx.request.body orelse ""; + + log.info("Received OTLP trace: {d} bytes", .{body.len}); + + // Store trace info + traces_mutex.lock(); + defer traces_mutex.unlock(); + + // Store preview of body (up to 1KB) + const preview_len = @min(body.len, 1024); + const preview = trace_allocator.dupe(u8, body[0..preview_len]) catch { + return ctx.response.apply(.{ + .status = .@"Internal Server Error", + .mime = http.Mime.JSON, + .body = "{\"error\":\"allocation_failed\"}", + }); + }; + + trace_sequence += 1; + stored_traces.append(trace_allocator, .{ + .sequence = trace_sequence, + .body_size = body.len, + .body_preview = preview, + }) catch { + trace_allocator.free(preview); + return ctx.response.apply(.{ + .status = .@"Internal Server Error", + .mime = http.Mime.JSON, + .body = "{\"error\":\"storage_failed\"}", + }); + }; + + log.info("Stored trace #{d}", .{stored_traces.items.len}); + + return ctx.response.apply(.{ + .status = .OK, + .mime = http.Mime.JSON, + .body = "{}", + }); +} + +/// Retrieve stored traces for test verification +fn handleGetTraces(ctx: *const Context, _: void) !Respond { + const allocator = ctx.allocator; + + traces_mutex.lock(); + defer traces_mutex.unlock(); + + var json: std.ArrayListUnmanaged(u8) = .empty; + errdefer json.deinit(allocator); + + try json.appendSlice(allocator, "{\"trace_count\":"); + + var count_buf: [16]u8 = undefined; + const count_str = try std.fmt.bufPrint(&count_buf, "{d}", .{stored_traces.items.len}); + try json.appendSlice(allocator, count_str); + + try json.appendSlice(allocator, ",\"traces\":["); + + for (stored_traces.items, 0..) |trace, i| { + if (i > 0) try json.appendSlice(allocator, ","); + + try json.appendSlice(allocator, "{\"sequence\":"); + var seq_buf: [32]u8 = undefined; + const seq_str = try std.fmt.bufPrint(&seq_buf, "{d}", .{trace.sequence}); + try json.appendSlice(allocator, seq_str); + + try json.appendSlice(allocator, ",\"body_size\":"); + var size_buf: [16]u8 = undefined; + const size_str = try std.fmt.bufPrint(&size_buf, "{d}", .{trace.body_size}); + try json.appendSlice(allocator, size_str); + + // Add hex preview of first few bytes for debugging + try json.appendSlice(allocator, ",\"body_preview_hex\":\""); + const hex_len = @min(trace.body_preview.len, 64); + for (trace.body_preview[0..hex_len]) |byte| { + var hex_buf: [2]u8 = undefined; + _ = std.fmt.bufPrint(&hex_buf, "{x:0>2}", .{byte}) catch continue; + try json.appendSlice(allocator, &hex_buf); + } + try json.appendSlice(allocator, "\"}"); + } + + try json.appendSlice(allocator, "]}"); + + const response_body = try json.toOwnedSlice(allocator); + + return ctx.response.apply(.{ + .status = .OK, + .mime = http.Mime.JSON, + .body = response_body, + }); +} + +/// Clear stored traces +fn handleClearTraces(ctx: *const Context, _: void) !Respond { + traces_mutex.lock(); + defer traces_mutex.unlock(); + + for (stored_traces.items) |trace| { + trace_allocator.free(trace.body_preview); + } + stored_traces.clearRetainingCapacity(); + trace_sequence = 0; + + log.info("Cleared all stored traces", .{}); + + return ctx.response.apply(.{ + .status = .OK, + .mime = http.Mime.JSON, + .body = "{\"cleared\":true}", + }); +} + +var server: Server = undefined; + +fn shutdown(_: std.c.SIG) callconv(.c) void { + server.stop(); +} + +pub fn main() !void { + const args = try std.process.argsAlloc(std.heap.page_allocator); + defer std.process.argsFree(std.heap.page_allocator, args); + + var port: u16 = 14318; // Default OTLP test port + + // Parse command line arguments + var i: usize = 1; + while (i < args.len) : (i += 1) { + if (std.mem.eql(u8, args[i], "--port") or std.mem.eql(u8, args[i], "-p")) { + if (i + 1 < args.len) { + port = try std.fmt.parseInt(u16, args[i + 1], 10); + i += 1; + } + } + } + + const host: []const u8 = "127.0.0.1"; + + var gpa: std.heap.DebugAllocator(.{}) = .init; + const allocator = gpa.allocator(); + defer _ = gpa.deinit(); + + // Set up trace storage allocator + trace_allocator = allocator; + + // Clean up stored traces on exit + defer { + for (stored_traces.items) |trace| { + trace_allocator.free(trace.body_preview); + } + stored_traces.deinit(trace_allocator); + } + + std.posix.sigaction(std.posix.SIG.TERM, &.{ + .handler = .{ .handler = shutdown }, + .mask = std.posix.sigemptyset(), + .flags = 0, + }, null); + + var threaded: Io.Threaded = .init(allocator); + defer threaded.deinit(); + const io = threaded.io(); + + var router = try Router.init(allocator, &.{ + // OTLP trace endpoint + Route.init("/v1/traces").post({}, handleOtlpTraces).layer(), + // Test verification endpoints + Route.init("/traces").get({}, handleGetTraces).delete({}, handleClearTraces).layer(), + }, .{}); + defer router.deinit(allocator); + + const addr = try Io.net.IpAddress.parse(host, port); + var socket = try addr.listen(io, .{ + .kernel_backlog = 1024, + .reuse_address = true, + }); + defer socket.deinit(io); + + log.info("Mock OTLP Collector listening on {s}:{d}", .{ host, port }); + + server = try Server.init(allocator, .{ + .socket_buffer_bytes = 1024 * 64, // Larger buffer for protobuf data + .keepalive_count_max = null, + .connection_count_max = 128, + }); + defer server.deinit(); + + try server.serve(io, &router, &socket); +} diff --git a/tests/integration_test.zig b/tests/integration_test.zig index afa2e73..6cc285e 100644 --- a/tests/integration_test.zig +++ b/tests/integration_test.zig @@ -12,6 +12,8 @@ const headers = @import("suites/headers.zig"); const body = @import("suites/body.zig"); const load_balancing = @import("suites/load_balancing.zig"); const http2 = @import("suites/http2.zig"); +const otel = @import("suites/otel.zig"); +const waf = @import("suites/waf.zig"); pub fn main() !void { var gpa = std.heap.GeneralPurposeAllocator(.{}){}; @@ -28,6 +30,9 @@ pub fn main() !void { body.suite, load_balancing.suite, http2.suite, + otel.suite, + waf.suite, + waf.shadow_suite, }; var suite_failures: usize = 0; diff --git a/tests/process_manager.zig b/tests/process_manager.zig index 8247a1b..0327c5a 100644 --- a/tests/process_manager.zig +++ b/tests/process_manager.zig @@ -8,6 +8,7 @@ const posix = std.posix; const test_utils = @import("test_utils.zig"); pub const H2_BACKEND_PORT: u16 = 9443; +pub const OTLP_COLLECTOR_PORT: u16 = 14318; pub const Process = struct { child: std.process.Child, @@ -227,4 +228,168 @@ pub const ProcessManager = struct { p.deinit(); } } + + /// Start mock OTLP collector for telemetry tests + pub fn startOtlpCollector(self: *ProcessManager) !void { + var port_buf: [8]u8 = undefined; + const port_str = try std.fmt.bufPrint(&port_buf, "{d}", .{OTLP_COLLECTOR_PORT}); + + var child = std.process.Child.init( + &.{ "./zig-out/bin/mock_otlp_collector", "--port", port_str }, + self.allocator, + ); + child.stdin_behavior = .Ignore; + child.stdout_behavior = .Ignore; + child.stderr_behavior = .Ignore; + + try child.spawn(); + errdefer { + _ = child.kill() catch {}; + _ = child.wait() catch {}; + } + + try self.processes.append(self.allocator, .{ + .child = child, + .name = try self.allocator.dupe(u8, "mock_otlp_collector"), + .allocator = self.allocator, + }); + + // Wait for port to be ready + try test_utils.waitForPort(OTLP_COLLECTOR_PORT, 10000); + } + + /// Start load balancer with OTLP telemetry enabled + pub fn startLoadBalancerWithOtel(self: *ProcessManager, backend_ports: []const u16) !void { + var args: std.ArrayList([]const u8) = .empty; + defer args.deinit(self.allocator); + + // Track strings we allocate so we can free them + var allocated_strings: std.ArrayList([]const u8) = .empty; + defer { + for (allocated_strings.items) |s| self.allocator.free(s); + allocated_strings.deinit(self.allocator); + } + + try args.append(self.allocator, "./zig-out/bin/load_balancer"); + try args.append(self.allocator, "--port"); + + var lb_port_buf: [8]u8 = undefined; + const lb_port_str = try std.fmt.bufPrint(&lb_port_buf, "{d}", .{test_utils.LB_PORT}); + const lb_port_dup = try self.allocator.dupe(u8, lb_port_str); + try allocated_strings.append(self.allocator, lb_port_dup); + try args.append(self.allocator, lb_port_dup); + + // Use single-process mode for easier testing + try args.append(self.allocator, "--mode"); + try args.append(self.allocator, "sp"); + + // Add OTLP endpoint + try args.append(self.allocator, "--otel-endpoint"); + var otel_buf: [32]u8 = undefined; + const otel_str = try std.fmt.bufPrint(&otel_buf, "127.0.0.1:{d}", .{OTLP_COLLECTOR_PORT}); + const otel_dup = try self.allocator.dupe(u8, otel_str); + try allocated_strings.append(self.allocator, otel_dup); + try args.append(self.allocator, otel_dup); + + for (backend_ports) |port| { + try args.append(self.allocator, "--backend"); + var buf: [32]u8 = undefined; + const backend_str = try std.fmt.bufPrint(&buf, "127.0.0.1:{d}", .{port}); + const backend_dup = try self.allocator.dupe(u8, backend_str); + try allocated_strings.append(self.allocator, backend_dup); + try args.append(self.allocator, backend_dup); + } + + var child = std.process.Child.init(args.items, self.allocator); + child.stdin_behavior = .Ignore; + child.stdout_behavior = .Ignore; + child.stderr_behavior = .Ignore; + + try child.spawn(); + errdefer { + _ = child.kill() catch {}; + _ = child.wait() catch {}; + } + + try self.processes.append(self.allocator, .{ + .child = child, + .name = try self.allocator.dupe(u8, "load_balancer_otel"), + .allocator = self.allocator, + }); + + // Wait for LB port + try test_utils.waitForPort(test_utils.LB_PORT, 10000); + + // Wait for health checks (backends need to be marked healthy) + posix.nanosleep(2, 0); + } + + /// Start load balancer with WAF enabled + pub fn startLoadBalancerWithWaf(self: *ProcessManager, backend_ports: []const u16, waf_config_path: []const u8) !void { + try self.startLoadBalancerWithWafOnPort(backend_ports, waf_config_path, test_utils.LB_PORT); + } + + /// Start load balancer with WAF enabled on a specific port + pub fn startLoadBalancerWithWafOnPort(self: *ProcessManager, backend_ports: []const u16, waf_config_path: []const u8, port: u16) !void { + var args: std.ArrayList([]const u8) = .empty; + defer args.deinit(self.allocator); + + // Track strings we allocate so we can free them + var allocated_strings: std.ArrayList([]const u8) = .empty; + defer { + for (allocated_strings.items) |s| self.allocator.free(s); + allocated_strings.deinit(self.allocator); + } + + try args.append(self.allocator, "./zig-out/bin/load_balancer"); + try args.append(self.allocator, "--port"); + + var lb_port_buf: [8]u8 = undefined; + const lb_port_str = try std.fmt.bufPrint(&lb_port_buf, "{d}", .{port}); + const lb_port_dup = try self.allocator.dupe(u8, lb_port_str); + try allocated_strings.append(self.allocator, lb_port_dup); + try args.append(self.allocator, lb_port_dup); + + // Use single-process mode for easier testing + try args.append(self.allocator, "--mode"); + try args.append(self.allocator, "sp"); + + // Add WAF config path + try args.append(self.allocator, "--waf"); + const waf_path_dup = try self.allocator.dupe(u8, waf_config_path); + try allocated_strings.append(self.allocator, waf_path_dup); + try args.append(self.allocator, waf_path_dup); + + for (backend_ports) |backend_port| { + try args.append(self.allocator, "--backend"); + var buf: [32]u8 = undefined; + const backend_str = try std.fmt.bufPrint(&buf, "127.0.0.1:{d}", .{backend_port}); + const backend_dup = try self.allocator.dupe(u8, backend_str); + try allocated_strings.append(self.allocator, backend_dup); + try args.append(self.allocator, backend_dup); + } + + var child = std.process.Child.init(args.items, self.allocator); + child.stdin_behavior = .Ignore; + child.stdout_behavior = .Ignore; + child.stderr_behavior = .Ignore; + + try child.spawn(); + errdefer { + _ = child.kill() catch {}; + _ = child.wait() catch {}; + } + + try self.processes.append(self.allocator, .{ + .child = child, + .name = try self.allocator.dupe(u8, "load_balancer_waf"), + .allocator = self.allocator, + }); + + // Wait for LB port + try test_utils.waitForPort(port, 10000); + + // Wait for health checks (backends need to be marked healthy) + posix.nanosleep(2, 0); + } }; diff --git a/tests/suites/otel.zig b/tests/suites/otel.zig new file mode 100644 index 0000000..2773872 --- /dev/null +++ b/tests/suites/otel.zig @@ -0,0 +1,138 @@ +//! OpenTelemetry integration tests. +//! +//! Tests that the load balancer correctly exports traces to an OTLP collector. +//! Uses a mock OTLP collector to receive and verify trace data. + +const std = @import("std"); +const harness = @import("../harness.zig"); +const utils = @import("../test_utils.zig"); +const ProcessManager = @import("../process_manager.zig").ProcessManager; + +var pm: ProcessManager = undefined; + +fn beforeAll(allocator: std.mem.Allocator) !void { + pm = ProcessManager.init(allocator); + + // Start mock OTLP collector first + try pm.startOtlpCollector(); + + // Start backend + try pm.startBackend(utils.BACKEND1_PORT, "backend1"); + + // Start load balancer with OTLP endpoint + try pm.startLoadBalancerWithOtel(&.{utils.BACKEND1_PORT}); +} + +fn afterAll(_: std.mem.Allocator) !void { + pm.deinit(); +} + +fn testTracesExported(allocator: std.mem.Allocator) !void { + // Clear any existing traces + try utils.clearOtlpTraces(allocator); + + // Make a request through the load balancer + const response = try utils.httpRequest(allocator, "GET", utils.LB_PORT, "/test/trace", null, null); + defer allocator.free(response); + + // Verify request succeeded + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 200), status); + + // Wait for traces to be exported (batching processor has delay) + // The BatchingProcessor exports every 5 seconds or when batch is full + try utils.waitForTraces(allocator, 1, 10000); + + // Verify we received at least one trace + const count = try utils.getOtlpTraceCount(allocator); + try std.testing.expect(count >= 1); +} + +fn testMultipleRequestsGenerateTraces(allocator: std.mem.Allocator) !void { + // Clear existing traces + try utils.clearOtlpTraces(allocator); + + // Make multiple requests + const num_requests: usize = 5; + for (0..num_requests) |_| { + const response = try utils.httpRequest(allocator, "GET", utils.LB_PORT, "/api/multi", null, null); + defer allocator.free(response); + + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 200), status); + } + + // Wait for at least one trace export + // Note: BatchingProcessor batches spans, so multiple requests may result + // in a single export containing all spans + try utils.waitForTraces(allocator, 1, 10000); + + // Verify we got at least one trace export + const count = try utils.getOtlpTraceCount(allocator); + try std.testing.expect(count >= 1); +} + +fn testPostRequestGeneratesTrace(allocator: std.mem.Allocator) !void { + // Clear existing traces + try utils.clearOtlpTraces(allocator); + + // Make a POST request with body + const body = "{\"test\":\"data\"}"; + const headers = &[_][2][]const u8{.{ "Content-Type", "application/json" }}; + + const response = try utils.httpRequest(allocator, "POST", utils.LB_PORT, "/api/post", headers, body); + defer allocator.free(response); + + // Verify request succeeded + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 200), status); + + // Wait for trace + try utils.waitForTraces(allocator, 1, 10000); + + const count = try utils.getOtlpTraceCount(allocator); + try std.testing.expect(count >= 1); +} + +fn testTracesHaveData(allocator: std.mem.Allocator) !void { + // Clear existing traces + try utils.clearOtlpTraces(allocator); + + // Make a request + const response = try utils.httpRequest(allocator, "GET", utils.LB_PORT, "/verify/data", null, null); + defer allocator.free(response); + + // Wait for trace + try utils.waitForTraces(allocator, 1, 10000); + + // Get the raw traces response + const traces_response = try utils.httpRequest(allocator, "GET", utils.OTLP_PORT, "/traces", null, null); + defer allocator.free(traces_response); + + const traces_body = try utils.extractJsonBody(traces_response); + + // Verify the response contains trace data + // The body_size field should be > 0 indicating actual protobuf data was received + const parsed = try std.json.parseFromSlice(std.json.Value, allocator, traces_body, .{}); + defer parsed.deinit(); + + const traces_array = parsed.value.object.get("traces") orelse return error.NoTraces; + try std.testing.expect(traces_array.array.items.len >= 1); + + // Check first trace has non-zero body size + const first_trace = traces_array.array.items[0]; + const body_size = first_trace.object.get("body_size") orelse return error.NoBodySize; + try std.testing.expect(body_size.integer > 0); +} + +pub const suite = harness.Suite{ + .name = "OpenTelemetry Tracing", + .before_all = beforeAll, + .after_all = afterAll, + .tests = &.{ + harness.it("exports traces to OTLP collector", testTracesExported), + harness.it("generates traces for multiple requests", testMultipleRequestsGenerateTraces), + harness.it("generates traces for POST requests", testPostRequestGeneratesTrace), + harness.it("trace data contains valid protobuf", testTracesHaveData), + }, +}; diff --git a/tests/suites/waf.zig b/tests/suites/waf.zig new file mode 100644 index 0000000..95e8619 --- /dev/null +++ b/tests/suites/waf.zig @@ -0,0 +1,327 @@ +//! WAF (Web Application Firewall) integration tests. +//! +//! Tests rate limiting, body size limits, URI length limits, and shadow mode. +//! Creates a temporary WAF config file for each test suite run. + +const std = @import("std"); +const harness = @import("../harness.zig"); +const utils = @import("../test_utils.zig"); +const ProcessManager = @import("../process_manager.zig").ProcessManager; +const posix = std.posix; + +var pm: ProcessManager = undefined; +var waf_config_path: []const u8 = undefined; +var allocator: std.mem.Allocator = undefined; + +/// WAF test config - tight limits for testing +const WAF_CONFIG = + \\{ + \\ "enabled": true, + \\ "shadow_mode": false, + \\ "rate_limits": [ + \\ { + \\ "name": "test_limit", + \\ "path": "/api/*", + \\ "limit": { "requests": 5, "period_sec": 60 }, + \\ "burst": 0, + \\ "by": "ip", + \\ "action": "block" + \\ } + \\ ], + \\ "request_limits": { + \\ "max_uri_length": 100, + \\ "max_body_size": 1024 + \\ } + \\} +; + +/// Shadow mode config - logs but doesn't block +const WAF_SHADOW_CONFIG = + \\{ + \\ "enabled": true, + \\ "shadow_mode": true, + \\ "rate_limits": [ + \\ { + \\ "name": "test_limit", + \\ "path": "/api/*", + \\ "limit": { "requests": 2, "period_sec": 60 }, + \\ "burst": 0, + \\ "by": "ip", + \\ "action": "block" + \\ } + \\ ], + \\ "request_limits": { + \\ "max_uri_length": 50, + \\ "max_body_size": 512 + \\ } + \\} +; + +fn beforeAll(alloc: std.mem.Allocator) !void { + allocator = alloc; + pm = ProcessManager.init(alloc); + + // Create temporary WAF config file + waf_config_path = try createTempWafConfig(alloc, WAF_CONFIG); + + // Start backend on dedicated WAF port + try pm.startBackend(utils.WAF_BACKEND_PORT, "waf_backend"); + + // Start load balancer with WAF enabled on dedicated port + try pm.startLoadBalancerWithWafOnPort(&.{utils.WAF_BACKEND_PORT}, waf_config_path, utils.WAF_LB_PORT); +} + +fn afterAll(_: std.mem.Allocator) !void { + pm.deinit(); + + // Clean up temporary WAF config file + deleteTempWafConfig(waf_config_path); + allocator.free(waf_config_path); +} + +/// Create a temporary WAF config file +fn createTempWafConfig(alloc: std.mem.Allocator, config: []const u8) ![]const u8 { + // Use /tmp for temporary files with a unique name based on timestamp + const now = std.time.Instant.now() catch return error.TimerUnavailable; + // Use seconds + nanoseconds for unique filename + const ts_sec: i64 = now.timestamp.sec; + const ts_nsec: i64 = now.timestamp.nsec; + const path = try std.fmt.allocPrint(alloc, "/tmp/waf_test_{d}_{d}.json", .{ ts_sec, ts_nsec }); + errdefer alloc.free(path); + + const file = try std.fs.createFileAbsolute(path, .{}); + defer file.close(); + + try file.writeAll(config); + + return path; +} + +/// Delete the temporary WAF config file +fn deleteTempWafConfig(path: []const u8) void { + std.fs.deleteFileAbsolute(path) catch {}; +} + +// ============================================================================= +// Rate Limiting Tests +// ============================================================================= + +fn testRateLimitingBlocks(alloc: std.mem.Allocator) !void { + // The rate limit is 5 requests per 60 seconds with no burst + // We need to make 6 requests to trigger the block + // Note: we use a unique path suffix to avoid interference from other tests + + const path = "/api/rate-test-1"; + + // Make 5 requests - all should succeed + for (0..5) |_| { + const response = try utils.httpRequest(alloc, "GET", utils.WAF_LB_PORT, path, null, null); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + // First 5 should succeed (200) or we might get 429 from previous test runs + if (status != 200 and status != 429) { + return error.UnexpectedStatus; + } + } + + // The 6th request should be rate limited (429) + const response = try utils.httpRequest(alloc, "GET", utils.WAF_LB_PORT, path, null, null); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 429), status); +} + +fn testNonApiPathNotRateLimited(alloc: std.mem.Allocator) !void { + // The rate limit only applies to /api/* paths + // Requests to other paths should not be rate limited + + const path = "/other/path"; + + // Make many requests - all should succeed + for (0..10) |_| { + const response = try utils.httpRequest(alloc, "GET", utils.WAF_LB_PORT, path, null, null); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 200), status); + } +} + +// ============================================================================= +// Request Size Limit Tests +// ============================================================================= + +fn testBodySizeLimitBlocks(alloc: std.mem.Allocator) !void { + // The max_body_size is 1024 bytes + // Sending a larger body should be blocked with 413 + + // Create a body larger than 1024 bytes + const large_body = try alloc.alloc(u8, 2000); + defer alloc.free(large_body); + @memset(large_body, 'X'); + + const headers = &[_][2][]const u8{.{ "Content-Type", "application/octet-stream" }}; + + const response = try utils.httpRequest(alloc, "POST", utils.WAF_LB_PORT, "/upload", headers, large_body); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 413), status); +} + +fn testSmallBodyAllowed(alloc: std.mem.Allocator) !void { + // A body smaller than 1024 bytes should be allowed + + const small_body = "This is a small body that should be allowed"; + const headers = &[_][2][]const u8{.{ "Content-Type", "text/plain" }}; + + const response = try utils.httpRequest(alloc, "POST", utils.WAF_LB_PORT, "/data", headers, small_body); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 200), status); +} + +// ============================================================================= +// URI Length Limit Tests +// ============================================================================= + +fn testUriLengthLimitBlocks(alloc: std.mem.Allocator) !void { + // The max_uri_length is 100 bytes + // Sending a longer URI should be blocked with 403 + + // Create a URI longer than 100 bytes + var long_uri_buf: [200]u8 = undefined; + @memset(&long_uri_buf, 'a'); + long_uri_buf[0] = '/'; + const long_uri = long_uri_buf[0..150]; + + const response = try utils.httpRequest(alloc, "GET", utils.WAF_LB_PORT, long_uri, null, null); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 403), status); +} + +fn testShortUriAllowed(alloc: std.mem.Allocator) !void { + // A URI shorter than 100 bytes should be allowed + + const short_uri = "/short/path"; + + const response = try utils.httpRequest(alloc, "GET", utils.WAF_LB_PORT, short_uri, null, null); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + try std.testing.expectEqual(@as(u16, 200), status); +} + +pub const suite = harness.Suite{ + .name = "WAF (Web Application Firewall)", + .before_all = beforeAll, + .after_all = afterAll, + .tests = &.{ + harness.it("rate limiting blocks after limit exceeded", testRateLimitingBlocks), + harness.it("non-API paths not rate limited", testNonApiPathNotRateLimited), + harness.it("blocks requests with body exceeding size limit", testBodySizeLimitBlocks), + harness.it("allows requests with small body", testSmallBodyAllowed), + harness.it("blocks requests with URI exceeding length limit", testUriLengthLimitBlocks), + harness.it("allows requests with short URI", testShortUriAllowed), + }, +}; + +// ============================================================================= +// Shadow Mode Test Suite +// ============================================================================= + +var pm_shadow: ProcessManager = undefined; +var waf_shadow_config_path: []const u8 = undefined; + +fn beforeAllShadow(alloc: std.mem.Allocator) !void { + allocator = alloc; + pm_shadow = ProcessManager.init(alloc); + + // Create temporary WAF config file with shadow mode + waf_shadow_config_path = try createTempWafConfig(alloc, WAF_SHADOW_CONFIG); + + // Start backend on dedicated WAF shadow port to avoid conflicts + try pm_shadow.startBackend(utils.WAF_SHADOW_BACKEND_PORT, "waf_shadow_backend"); + + // Start load balancer with WAF shadow mode on dedicated port + try pm_shadow.startLoadBalancerWithWafOnPort(&.{utils.WAF_SHADOW_BACKEND_PORT}, waf_shadow_config_path, utils.WAF_SHADOW_LB_PORT); +} + +fn afterAllShadow(_: std.mem.Allocator) !void { + pm_shadow.deinit(); + + // Clean up temporary WAF config file + deleteTempWafConfig(waf_shadow_config_path); + allocator.free(waf_shadow_config_path); +} + +fn testShadowModeDoesNotBlock(alloc: std.mem.Allocator) !void { + // In shadow mode, the rate limit is 2 requests per 60 seconds + // But shadow mode should NOT block, only log + // All requests should succeed + + const path = "/api/shadow-test"; + + // Make more requests than the limit + for (0..5) |_| { + const response = try utils.httpRequest(alloc, "GET", utils.WAF_SHADOW_LB_PORT, path, null, null); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + // In shadow mode, all requests should succeed + try std.testing.expectEqual(@as(u16, 200), status); + } +} + +fn testShadowModeAllowsLargeBody(alloc: std.mem.Allocator) !void { + // In shadow mode, max_body_size is 512 bytes + // But shadow mode should NOT block, only log + + // Create a body larger than 512 bytes + const large_body = try alloc.alloc(u8, 800); + defer alloc.free(large_body); + @memset(large_body, 'Y'); + + const headers = &[_][2][]const u8{.{ "Content-Type", "application/octet-stream" }}; + + const response = try utils.httpRequest(alloc, "POST", utils.WAF_SHADOW_LB_PORT, "/upload", headers, large_body); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + // In shadow mode, should still succeed + try std.testing.expectEqual(@as(u16, 200), status); +} + +fn testShadowModeAllowsLongUri(alloc: std.mem.Allocator) !void { + // In shadow mode, max_uri_length is 50 bytes + // But shadow mode should NOT block, only log + + // Create a URI longer than 50 bytes but shorter than typical limits + var uri_buf: [80]u8 = undefined; + @memset(&uri_buf, 'z'); + uri_buf[0] = '/'; + const long_uri = uri_buf[0..70]; + + const response = try utils.httpRequest(alloc, "GET", utils.WAF_SHADOW_LB_PORT, long_uri, null, null); + defer alloc.free(response); + + const status = try utils.getResponseStatusCode(response); + // In shadow mode, should still succeed + try std.testing.expectEqual(@as(u16, 200), status); +} + +pub const shadow_suite = harness.Suite{ + .name = "WAF Shadow Mode", + .before_all = beforeAllShadow, + .after_all = afterAllShadow, + .tests = &.{ + harness.it("shadow mode does not block rate-limited requests", testShadowModeDoesNotBlock), + harness.it("shadow mode allows large body requests", testShadowModeAllowsLargeBody), + harness.it("shadow mode allows long URI requests", testShadowModeAllowsLongUri), + }, +}; diff --git a/tests/test_utils.zig b/tests/test_utils.zig index a50ad27..7daca23 100644 --- a/tests/test_utils.zig +++ b/tests/test_utils.zig @@ -14,6 +14,11 @@ pub const BACKEND2_PORT: u16 = 19002; pub const BACKEND3_PORT: u16 = 19003; pub const LB_PORT: u16 = 18080; pub const LB_H2_PORT: u16 = 18081; // Load balancer port for HTTP/2 tests +pub const OTLP_PORT: u16 = 14318; // Mock OTLP collector port +pub const WAF_LB_PORT: u16 = 18082; // Load balancer port for WAF tests +pub const WAF_SHADOW_LB_PORT: u16 = 18083; // Load balancer port for WAF shadow mode tests +pub const WAF_BACKEND_PORT: u16 = 19004; // Backend port for WAF tests +pub const WAF_SHADOW_BACKEND_PORT: u16 = 19005; // Backend port for WAF shadow mode tests /// Wait for a port to accept connections pub fn waitForPort(port: u16, timeout_ms: u64) !void { @@ -238,3 +243,43 @@ pub fn getResponseHeaderValue(response: []const u8, header_name: []const u8) ![] } return error.HeaderNotFound; } + +/// Get trace count from mock OTLP collector +pub fn getOtlpTraceCount(allocator: std.mem.Allocator) !i64 { + const response = try httpRequest(allocator, "GET", OTLP_PORT, "/traces", null, null); + defer allocator.free(response); + + const body = try extractJsonBody(response); + return try getJsonInt(allocator, body, "trace_count"); +} + +/// Clear traces in mock OTLP collector +pub fn clearOtlpTraces(allocator: std.mem.Allocator) !void { + const response = try httpRequest(allocator, "DELETE", OTLP_PORT, "/traces", null, null); + defer allocator.free(response); + + const status = try getResponseStatusCode(response); + if (status != 200) { + return error.ClearFailed; + } +} + +/// Wait for traces to be received by collector +pub fn waitForTraces(allocator: std.mem.Allocator, min_count: i64, timeout_ms: u64) !void { + const start = std.time.Instant.now() catch return error.TimerUnavailable; + const timeout_ns = timeout_ms * std.time.ns_per_ms; + + while (true) { + const count = getOtlpTraceCount(allocator) catch 0; + if (count >= min_count) { + return; + } + + const now = std.time.Instant.now() catch return error.TimerUnavailable; + if (now.since(start) >= timeout_ns) { + return error.TraceTimeout; + } + + posix.nanosleep(0, 100 * std.time.ns_per_ms); + } +} diff --git a/vendor/otel b/vendor/otel new file mode 160000 index 0000000..b03dc06 --- /dev/null +++ b/vendor/otel @@ -0,0 +1 @@ +Subproject commit b03dc06be8f10373d793694b5766413cae9ab4e0 diff --git a/vendor/otel-proto b/vendor/otel-proto new file mode 160000 index 0000000..a4cb1ce --- /dev/null +++ b/vendor/otel-proto @@ -0,0 +1 @@ +Subproject commit a4cb1cece1ba93f0258de091ba826c2bbe025c7b diff --git a/vendor/protobuf b/vendor/protobuf new file mode 160000 index 0000000..2828be0 --- /dev/null +++ b/vendor/protobuf @@ -0,0 +1 @@ +Subproject commit 2828be045c5f3e55c6f3f239c2ec40bc480a26ca diff --git a/waf_test.json b/waf_test.json new file mode 100644 index 0000000..4fc3afc --- /dev/null +++ b/waf_test.json @@ -0,0 +1,18 @@ +{ + "enabled": true, + "shadow_mode": false, + "rate_limits": [ + { + "name": "test_limit", + "path": "/*", + "limit": { "requests": 200000, "period_sec": 60 }, + "burst": 0, + "by": "ip", + "action": "block" + } + ], + "request_limits": { + "max_uri_length": 100, + "max_body_size": 1024 + } +}