From 96dedc61fe6dce8531b38f8f83252ab572af9f8f Mon Sep 17 00:00:00 2001 From: vansour Date: Sun, 17 May 2026 14:15:32 +0800 Subject: [PATCH 1/4] chore: tighten clippy lint policy --- Cargo.toml | 32 ++ clippy.toml | 1 + crates/rginx-agent/Cargo.toml | 3 + crates/rginx-agent/src/agent_core.rs | 265 ++++++------ crates/rginx-agent/src/audit.rs | 30 +- crates/rginx-agent/src/auth.rs | 96 ++--- crates/rginx-agent/src/auth/keyring.rs | 94 ++--- crates/rginx-agent/src/circuit_breaker.rs | 228 +++++----- crates/rginx-agent/src/config_history.rs | 147 +++---- crates/rginx-agent/src/config_history/diff.rs | 8 +- crates/rginx-agent/src/config_validator.rs | 140 ++++--- .../rginx-agent/src/control_center/model.rs | 59 +-- .../rginx-agent/src/control_center/query.rs | 109 +++-- .../rginx-agent/src/control_center/rollout.rs | 23 +- .../rginx-agent/src/control_center/store.rs | 193 ++++----- .../src/control_center/trait_adapter.rs | 15 +- crates/rginx-agent/src/error.rs | 28 +- crates/rginx-agent/src/events.rs | 59 +-- crates/rginx-agent/src/gradual_rollout.rs | 196 ++++----- .../rginx-agent/src/gradual_rollout/status.rs | 61 ++- crates/rginx-agent/src/lib.rs | 15 +- crates/rginx-agent/src/metrics.rs | 19 +- crates/rginx-agent/src/model.rs | 60 +-- crates/rginx-agent/src/outbound/auth.rs | 22 +- crates/rginx-agent/src/outbound/client.rs | 82 ++-- crates/rginx-agent/src/outbound/command.rs | 52 +-- crates/rginx-agent/src/outbound/model.rs | 40 +- crates/rginx-agent/src/outbound/runner.rs | 336 +++++++-------- crates/rginx-agent/src/outbound/state.rs | 88 ++-- crates/rginx-agent/src/outbound/stream.rs | 99 +++-- crates/rginx-agent/src/outbound/timing.rs | 2 +- crates/rginx-agent/src/rate_limit.rs | 97 +++-- crates/rginx-agent/src/registry.rs | 144 +++---- crates/rginx-agent/src/registry/tests.rs | 6 +- crates/rginx-agent/src/server/breaker.rs | 4 +- crates/rginx-agent/src/server/config.rs | 68 +-- crates/rginx-agent/src/server/control.rs | 120 +++--- crates/rginx-agent/src/server/maintenance.rs | 4 +- crates/rginx-agent/src/server/mod.rs | 55 ++- crates/rginx-agent/src/server/registry.rs | 126 +++--- crates/rginx-agent/src/server/request.rs | 36 +- .../rginx-agent/src/server/request/query.rs | 121 +++--- crates/rginx-agent/src/server/request/read.rs | 16 +- crates/rginx-agent/src/server/response.rs | 7 +- crates/rginx-agent/src/server/rollout.rs | 31 +- crates/rginx-agent/src/server/write.rs | 12 +- .../rginx-agent/src/server/write/routing.rs | 4 +- crates/rginx-agent/src/system.rs | 29 +- crates/rginx-agent/src/tests.rs | 22 +- crates/rginx-agent/src/tests/outbound.rs | 221 +++++----- crates/rginx-agent/src/tests/outbound_auth.rs | 10 +- .../rginx-agent/src/tests/outbound_stream.rs | 135 +++--- crates/rginx-agent/src/tests/support.rs | 202 ++++----- .../src/tests/support/executors.rs | 18 +- crates/rginx-agent/src/tls.rs | 23 +- crates/rginx-agent/src/websocket.rs | 93 +++-- crates/rginx-app/Cargo.toml | 3 + crates/rginx-app/src/admin_cli/mod.rs | 24 +- crates/rginx-app/src/admin_cli/status.rs | 8 +- crates/rginx-app/src/admin_cli/traffic.rs | 8 +- crates/rginx-app/src/check/acme.rs | 16 +- crates/rginx-app/src/check/control.rs | 4 +- crates/rginx-app/src/check/routes.rs | 64 ++- crates/rginx-app/src/check/summary.rs | 72 ++-- crates/rginx-app/src/check/tls.rs | 40 +- crates/rginx-app/src/cli.rs | 58 ++- crates/rginx-app/src/main.rs | 15 +- crates/rginx-app/src/pid_file.rs | 12 +- crates/rginx-app/tests/access_log.rs | 112 ++--- crates/rginx-app/tests/active_health.rs | 31 +- crates/rginx-app/tests/admin.rs | 59 +-- crates/rginx-app/tests/backup.rs | 13 +- crates/rginx-app/tests/cache.rs | 9 + crates/rginx-app/tests/cache/benchmarks.rs | 6 +- crates/rginx-app/tests/check.rs | 35 +- crates/rginx-app/tests/check/helpers.rs | 4 +- crates/rginx-app/tests/compression.rs | 42 +- crates/rginx-app/tests/dns_refresh.rs | 13 +- crates/rginx-app/tests/downstream_mtls.rs | 59 ++- .../tests/downstream_mtls/verifier.rs | 7 +- crates/rginx-app/tests/failover.rs | 13 +- crates/rginx-app/tests/grpc_http3.rs | 49 ++- crates/rginx-app/tests/grpc_http3/helpers.rs | 2 +- crates/rginx-app/tests/grpc_proxy.rs | 106 ++--- crates/rginx-app/tests/grpc_proxy/basic.rs | 3 +- .../tests/grpc_proxy/basic/grpc_web.rs | 8 +- .../tests/grpc_proxy/basic/routing.rs | 2 +- .../tests/grpc_proxy/helpers/body.rs | 66 +-- .../tests/grpc_proxy/helpers/grpc_web.rs | 18 +- .../tests/grpc_proxy/helpers/server.rs | 12 +- .../rginx-app/tests/grpc_proxy/helpers/tls.rs | 47 ++- .../tests/grpc_proxy/helpers/upstream.rs | 4 +- .../rginx-app/tests/grpc_proxy/lifecycle.rs | 8 +- crates/rginx-app/tests/grpc_proxy/timeout.rs | 3 +- .../tests/grpc_proxy/timeout/grpc_web.rs | 4 +- .../tests/grpc_proxy/timeout/validation.rs | 4 +- crates/rginx-app/tests/hardening.rs | 37 +- crates/rginx-app/tests/http2.rs | 112 ++--- crates/rginx-app/tests/http3.rs | 49 ++- .../rginx-app/tests/http3/helpers/client.rs | 22 +- .../tests/http3/helpers/client/tls.rs | 7 +- .../rginx-app/tests/http3/helpers/fixtures.rs | 56 +-- crates/rginx-app/tests/ip_hash.rs | 13 +- crates/rginx-app/tests/least_conn.rs | 13 +- crates/rginx-app/tests/multi_listener.rs | 112 ++--- crates/rginx-app/tests/nginx_alignment.rs | 13 +- crates/rginx-app/tests/nginx_diff.rs | 13 +- crates/rginx-app/tests/ocsp.rs | 27 +- crates/rginx-app/tests/ocsp/helpers.rs | 40 +- crates/rginx-app/tests/phase1.rs | 47 ++- crates/rginx-app/tests/policy.rs | 28 +- crates/rginx-app/tests/proxy_protocol.rs | 15 +- crates/rginx-app/tests/reload.rs | 170 ++++---- crates/rginx-app/tests/reload/cache.rs | 10 +- .../rginx-app/tests/reload/cache_streaming.rs | 10 +- crates/rginx-app/tests/reload/cli.rs | 6 +- .../rginx-app/tests/reload/reload_boundary.rs | 10 +- crates/rginx-app/tests/reload/reload_flow.rs | 26 +- crates/rginx-app/tests/reload/restart_flow.rs | 10 +- .../rginx-app/tests/reload/streaming_flow.rs | 4 +- .../rginx-app/tests/static_file_streaming.rs | 44 +- crates/rginx-app/tests/streaming_download.rs | 20 +- crates/rginx-app/tests/support/cache/mod.rs | 26 +- .../rginx-app/tests/support/cache/response.rs | 19 +- .../rginx-app/tests/support/cache/upstream.rs | 15 +- crates/rginx-app/tests/support/harness.rs | 201 +++++---- crates/rginx-app/tests/support/http.rs | 30 +- crates/rginx-app/tests/support/mod.rs | 86 ++-- crates/rginx-app/tests/support/nginx.rs | 137 +++--- crates/rginx-app/tests/support/response.rs | 12 +- crates/rginx-app/tests/support/tls.rs | 7 +- crates/rginx-app/tests/tls_policy.rs | 153 +++---- crates/rginx-app/tests/upgrade.rs | 25 +- crates/rginx-app/tests/upstream_http2.rs | 19 +- crates/rginx-app/tests/upstream_http3.rs | 199 ++++----- crates/rginx-app/tests/upstream_mtls.rs | 123 +++--- .../rginx-app/tests/upstream_server_name.rs | 57 +-- crates/rginx-app/tests/vhost.rs | 13 +- .../rginx-app/tests/weighted_round_robin.rs | 13 +- crates/rginx-app/tests/workers.rs | 28 +- crates/rginx-config/Cargo.toml | 3 + crates/rginx-config/src/compile/acme.rs | 2 +- crates/rginx-config/src/compile/cache.rs | 9 +- crates/rginx-config/src/compile/mod.rs | 26 +- crates/rginx-config/src/compile/route.rs | 6 +- crates/rginx-config/src/compile/server.rs | 12 +- .../rginx-config/src/compile/server/fields.rs | 18 +- .../src/compile/server/listener.rs | 23 +- .../compile/server/listener/vhost_binding.rs | 6 +- crates/rginx-config/src/compile/server/tls.rs | 6 +- .../src/compile/server/tls/identity.rs | 8 +- crates/rginx-config/src/compile/tests.rs | 37 +- crates/rginx-config/src/compile/tests/acme.rs | 8 +- .../rginx-config/src/compile/tests/cache.rs | 2 +- .../rginx-config/src/compile/tests/route.rs | 4 +- .../src/compile/tests/server_settings.rs | 2 +- .../rginx-config/src/compile/tests/vhosts.rs | 6 +- crates/rginx-config/src/compile/upstream.rs | 12 +- .../rginx-config/src/compile/upstream/tls.rs | 16 +- crates/rginx-config/src/compile/vhost.rs | 4 +- crates/rginx-config/src/lib.rs | 9 + crates/rginx-config/src/listen.rs | 12 +- crates/rginx-config/src/load.rs | 14 +- crates/rginx-config/src/load/env_expand.rs | 33 +- crates/rginx-config/src/load/layout.rs | 9 +- .../src/load/layout/array_rules.rs | 16 +- .../rginx-config/src/load/layout/scanner.rs | 77 ++-- crates/rginx-config/src/load/tests.rs | 122 +++--- crates/rginx-config/src/managed/mod.rs | 12 +- crates/rginx-config/src/managed/normalize.rs | 2 +- crates/rginx-config/src/managed/paths.rs | 2 +- crates/rginx-config/src/managed/types.rs | 71 ++-- crates/rginx-config/src/model.rs | 20 +- crates/rginx-config/src/model/acme.rs | 12 +- crates/rginx-config/src/model/agent.rs | 26 +- crates/rginx-config/src/model/cache.rs | 76 ++-- .../rginx-config/src/model/control_plane.rs | 18 +- crates/rginx-config/src/model/listener.rs | 54 +-- crates/rginx-config/src/model/route.rs | 88 ++-- crates/rginx-config/src/model/runtime.rs | 4 +- crates/rginx-config/src/model/server.rs | 76 ++-- crates/rginx-config/src/model/tls.rs | 96 ++--- crates/rginx-config/src/model/upstream.rs | 114 ++--- crates/rginx-config/src/model/vhost.rs | 14 +- crates/rginx-config/src/validate.rs | 17 +- crates/rginx-config/src/validate/cache.rs | 7 +- .../src/validate/cache/predicate.rs | 12 +- crates/rginx-config/src/validate/route.rs | 6 +- .../src/validate/route/handler.rs | 4 +- crates/rginx-config/src/validate/server.rs | 8 +- .../src/validate/server/listener.rs | 6 +- .../src/validate/server/listener/base.rs | 42 +- .../src/validate/server/listener/listeners.rs | 14 +- crates/rginx-config/src/validate/tests.rs | 24 +- .../src/validate/tests/control_plane.rs | 2 +- .../rginx-config/src/validate/tests/vhosts.rs | 4 +- crates/rginx-config/src/validate/upstream.rs | 14 +- .../src/validate/upstream/basics.rs | 2 +- .../rginx-config/src/validate/upstream/dns.rs | 2 +- .../src/validate/upstream/health.rs | 2 +- .../src/validate/upstream/protocol.rs | 7 +- .../rginx-config/src/validate/upstream/tls.rs | 2 +- .../src/validate/upstream/tuning.rs | 2 +- crates/rginx-core/Cargo.toml | 3 + crates/rginx-core/src/config.rs | 6 +- crates/rginx-core/src/config/access_log.rs | 137 +++--- .../src/config/access_log/variables.rs | 50 +-- crates/rginx-core/src/config/acme.rs | 13 +- crates/rginx-core/src/config/agent.rs | 18 +- crates/rginx-core/src/config/cache.rs | 86 ++-- .../src/config/cache/key_template.rs | 139 ++++--- .../rginx-core/src/config/cache/predicate.rs | 2 +- crates/rginx-core/src/config/control_plane.rs | 11 +- crates/rginx-core/src/config/listener.rs | 48 +-- crates/rginx-core/src/config/route.rs | 182 ++++---- .../src/config/route/proxy_header.rs | 77 ++-- .../src/config/route/regex_matcher.rs | 24 +- crates/rginx-core/src/config/server.rs | 27 +- crates/rginx-core/src/config/server_name.rs | 64 +-- .../rginx-core/src/config/snapshot/linear.rs | 43 +- .../rginx-core/src/config/snapshot/lookup.rs | 125 +++--- crates/rginx-core/src/config/snapshot/mod.rs | 154 +++---- .../src/config/snapshot/route_selection.rs | 27 +- .../src/config/snapshot/vhost_selection.rs | 47 ++- crates/rginx-core/src/config/tests/core.rs | 2 +- crates/rginx-core/src/config/tests/mod.rs | 11 +- crates/rginx-core/src/config/tls.rs | 40 +- crates/rginx-core/src/config/upstream.rs | 6 +- .../src/config/upstream/selection.rs | 159 +++---- .../rginx-core/src/config/upstream/types.rs | 98 ++--- crates/rginx-core/src/config/virtual_host.rs | 12 +- crates/rginx-core/src/error.rs | 12 +- crates/rginx-core/src/lib.rs | 9 + crates/rginx-http/Cargo.toml | 3 + crates/rginx-http/src/cache/entry.rs | 21 +- crates/rginx-http/src/cache/entry/metadata.rs | 100 ++--- crates/rginx-http/src/cache/entry/response.rs | 55 +-- .../src/cache/entry/response/body.rs | 16 +- .../rginx-http/src/cache/entry/signature.rs | 7 +- crates/rginx-http/src/cache/entry/temp.rs | 5 +- crates/rginx-http/src/cache/entry/write.rs | 2 +- crates/rginx-http/src/cache/fill.rs | 4 +- crates/rginx-http/src/cache/fill/external.rs | 57 +-- crates/rginx-http/src/cache/fill/local.rs | 187 ++++----- .../rginx-http/src/cache/fill/persistence.rs | 6 +- crates/rginx-http/src/cache/fill/shared.rs | 107 ++--- .../src/cache/fill/shared/access.rs | 6 +- crates/rginx-http/src/cache/index.rs | 185 ++++----- crates/rginx-http/src/cache/io.rs | 20 +- crates/rginx-http/src/cache/load.rs | 70 ++-- crates/rginx-http/src/cache/lookup.rs | 109 ++--- crates/rginx-http/src/cache/manager.rs | 30 +- .../rginx-http/src/cache/manager/bootstrap.rs | 6 +- .../rginx-http/src/cache/manager/control.rs | 90 ++-- .../src/cache/manager/lookup_support.rs | 46 ++- .../rginx-http/src/cache/manager/response.rs | 22 +- crates/rginx-http/src/cache/mod.rs | 143 +++---- crates/rginx-http/src/cache/policy.rs | 33 +- crates/rginx-http/src/cache/request.rs | 45 +- crates/rginx-http/src/cache/request/render.rs | 7 +- crates/rginx-http/src/cache/runtime.rs | 14 +- .../rginx-http/src/cache/runtime/context.rs | 72 ++-- .../rginx-http/src/cache/runtime/fill_lock.rs | 225 +++++----- .../rginx-http/src/cache/runtime/support.rs | 23 +- crates/rginx-http/src/cache/runtime/zone.rs | 390 +++++++++--------- .../src/cache/runtime/zone/compare.rs | 2 +- crates/rginx-http/src/cache/shared.rs | 29 +- .../rginx-http/src/cache/shared/bootstrap.rs | 7 +- crates/rginx-http/src/cache/shared/delta.rs | 6 +- .../cache/shared/index_file/codec/binary.rs | 18 +- .../cache/shared/index_file/codec/cursor.rs | 178 ++++---- .../cache/shared/index_file/codec/legacy.rs | 26 +- .../src/cache/shared/index_file/codec/mod.rs | 8 +- .../cache/shared/index_file/memory_backend.rs | 239 ++++++----- .../index_file/memory_backend/changes.rs | 2 +- .../memory_backend/document/codec.rs | 19 +- .../memory_backend/document/codec/cursor.rs | 178 ++++---- .../index_file/memory_backend/document/mod.rs | 6 +- .../memory_backend/document/model.rs | 36 +- .../shared/index_file/memory_backend/locks.rs | 12 +- .../src/cache/shared/index_file/mod.rs | 42 +- .../src/cache/shared/index_file/store.rs | 179 ++++---- crates/rginx-http/src/cache/shared/memory.rs | 195 ++++----- .../src/cache/shared/memory/config.rs | 135 +++--- crates/rginx-http/src/cache/state.rs | 152 ++++--- crates/rginx-http/src/cache/store.rs | 12 +- crates/rginx-http/src/cache/store/helpers.rs | 2 +- .../cache/store/maintenance/index_state.rs | 8 +- .../src/cache/store/maintenance/mod.rs | 15 +- .../cache/store/maintenance/store_update.rs | 8 +- crates/rginx-http/src/cache/store/range.rs | 14 +- .../rginx-http/src/cache/store/revalidate.rs | 11 +- .../rginx-http/src/cache/store/streaming.rs | 79 ++-- .../src/cache/store/streaming/body.rs | 218 +++++----- .../src/cache/store/streaming/finalize.rs | 7 +- crates/rginx-http/src/cache/tests/lookup.rs | 13 +- .../rginx-http/src/cache/tests/lookup/keys.rs | 4 +- crates/rginx-http/src/cache/tests/mod.rs | 24 +- .../rginx-http/src/cache/tests/storage_p1.rs | 4 +- .../rginx-http/src/cache/tests/storage_p2.rs | 3 +- .../tests/storage_p2/cross_process_fill.rs | 7 +- .../cache/tests/storage_p2/shared_index.rs | 10 +- .../storage_p2/shared_index/shared_memory.rs | 4 +- .../shared_index/sync_regressions.rs | 3 +- .../rginx-http/src/cache/tests/storage_p4.rs | 3 +- .../src/cache/tests/storage_p4/lifecycle.rs | 7 +- .../src/cache/tests/storage_p4/streaming.rs | 4 +- .../src/cache/tests/storage_p4/termination.rs | 164 ++++---- crates/rginx-http/src/client_ip.rs | 23 +- .../src/compression/accept_encoding.rs | 16 +- crates/rginx-http/src/compression/mod.rs | 14 +- crates/rginx-http/src/compression/options.rs | 6 +- crates/rginx-http/src/compression/tests.rs | 8 +- crates/rginx-http/src/handler/access_log.rs | 67 ++- .../rginx-http/src/handler/dispatch/date.rs | 24 +- .../rginx-http/src/handler/dispatch/file.rs | 27 +- .../src/handler/dispatch/file/headers.rs | 15 +- .../src/handler/dispatch/file/range.rs | 6 +- .../src/handler/dispatch/file/stream.rs | 23 +- crates/rginx-http/src/handler/dispatch/mod.rs | 32 +- .../rginx-http/src/handler/dispatch/phases.rs | 138 +++---- .../src/handler/dispatch/phases/error_page.rs | 6 +- .../src/handler/dispatch/response.rs | 6 +- .../rginx-http/src/handler/dispatch/select.rs | 2 +- crates/rginx-http/src/handler/grpc/error.rs | 6 +- .../rginx-http/src/handler/grpc/grpc_web.rs | 87 ++-- .../rginx-http/src/handler/grpc/metadata.rs | 24 +- .../src/handler/grpc/observability.rs | 102 +++-- crates/rginx-http/src/handler/mod.rs | 25 +- crates/rginx-http/src/handler/tests.rs | 10 +- .../src/handler/tests/routing/handle.rs | 5 +- .../src/handler/tests/routing/handle/file.rs | 8 +- .../tests/routing/handle/file/presentation.rs | 10 +- crates/rginx-http/src/lib.rs | 25 +- crates/rginx-http/src/pki/certificate.rs | 59 ++- .../src/pki/certificate/extensions.rs | 20 +- .../src/pki/certificate/identity.rs | 2 +- .../rginx-http/src/pki/certificate/inspect.rs | 21 +- crates/rginx-http/src/pki/certificate/name.rs | 6 +- .../rginx-http/src/proxy/clients/factory.rs | 6 +- .../src/proxy/clients/http3/connect.rs | 104 +++-- .../src/proxy/clients/http3/endpoint_cache.rs | 27 +- .../rginx-http/src/proxy/clients/http3/mod.rs | 44 +- .../src/proxy/clients/http3/request.rs | 2 +- .../src/proxy/clients/http3/response_body.rs | 24 +- .../src/proxy/clients/http3/session.rs | 28 +- .../src/proxy/clients/http3/tests.rs | 15 +- .../src/proxy/clients/http_client.rs | 50 ++- crates/rginx-http/src/proxy/clients/mod.rs | 172 ++++---- .../rginx-http/src/proxy/clients/profile.rs | 24 +- .../src/proxy/clients/tls/verifier.rs | 46 +-- crates/rginx-http/src/proxy/common.rs | 21 +- crates/rginx-http/src/proxy/common/uri.rs | 2 +- crates/rginx-http/src/proxy/error_mapping.rs | 2 +- .../rginx-http/src/proxy/forward/attempt.rs | 15 +- .../src/proxy/forward/attempt/background.rs | 20 +- .../src/proxy/forward/attempt/cache_lookup.rs | 6 +- .../src/proxy/forward/attempt/logging.rs | 6 +- .../src/proxy/forward/attempt/primary.rs | 18 +- crates/rginx-http/src/proxy/forward/cache.rs | 122 +++--- crates/rginx-http/src/proxy/forward/error.rs | 11 +- .../src/proxy/forward/error/tests.rs | 6 +- .../rginx-http/src/proxy/forward/failure.rs | 17 +- crates/rginx-http/src/proxy/forward/grpc.rs | 12 +- crates/rginx-http/src/proxy/forward/mod.rs | 12 +- .../rginx-http/src/proxy/forward/response.rs | 21 +- crates/rginx-http/src/proxy/forward/setup.rs | 31 +- .../rginx-http/src/proxy/forward/streaming.rs | 6 +- .../rginx-http/src/proxy/forward/success.rs | 20 +- crates/rginx-http/src/proxy/forward/types.rs | 10 +- .../src/proxy/grpc_web/body/request.rs | 16 +- .../src/proxy/grpc_web/body/response.rs | 8 +- .../src/proxy/grpc_web/body/text_decode.rs | 14 +- .../src/proxy/grpc_web/body/text_encode.rs | 8 +- crates/rginx-http/src/proxy/grpc_web/codec.rs | 16 +- crates/rginx-http/src/proxy/grpc_web/mod.rs | 6 +- crates/rginx-http/src/proxy/health.rs | 17 +- .../src/proxy/health/active_probe.rs | 5 +- .../src/proxy/health/grpc_health_codec.rs | 16 +- .../src/proxy/health/registry/guards.rs | 26 +- .../src/proxy/health/registry/mod.rs | 72 ++-- .../src/proxy/health/registry/policy.rs | 2 +- .../src/proxy/health/registry/selection.rs | 211 +++++----- .../src/proxy/health/registry/snapshot.rs | 25 +- .../src/proxy/health/registry/state.rs | 284 ++++++------- crates/rginx-http/src/proxy/health/request.rs | 7 +- crates/rginx-http/src/proxy/mod.rs | 45 +- .../rginx-http/src/proxy/request_body/mod.rs | 12 +- .../src/proxy/request_body/model.rs | 25 +- .../src/proxy/request_body/prepare.rs | 86 ++-- .../src/proxy/request_body/replay.rs | 12 +- .../src/proxy/request_body/streaming.rs | 15 +- crates/rginx-http/src/proxy/resolver.rs | 57 +-- .../rginx-http/src/proxy/resolver/endpoint.rs | 5 +- .../rginx-http/src/proxy/resolver/runtime.rs | 150 +++---- crates/rginx-http/src/proxy/tests/mod.rs | 58 ++- .../src/proxy/tests/request_headers.rs | 4 +- crates/rginx-http/src/proxy/tests/support.rs | 2 +- crates/rginx-http/src/proxy/upgrade.rs | 2 +- crates/rginx-http/src/rate_limit/local.rs | 61 +-- crates/rginx-http/src/rate_limit/mod.rs | 83 ++-- .../src/rate_limit/shared/bucket.rs | 51 ++- .../src/rate_limit/shared/document.rs | 9 +- .../rginx-http/src/rate_limit/shared/lock.rs | 12 +- .../rginx-http/src/rate_limit/shared/mod.rs | 70 ++-- crates/rginx-http/src/request_target.rs | 10 +- crates/rginx-http/src/router.rs | 64 +-- crates/rginx-http/src/router/tests.rs | 4 +- crates/rginx-http/src/server/connection.rs | 12 +- .../src/server/http3/accept_loop.rs | 2 +- crates/rginx-http/src/server/http3/body.rs | 12 +- .../rginx-http/src/server/http3/connection.rs | 7 +- .../rginx-http/src/server/http3/endpoint.rs | 2 +- .../rginx-http/src/server/http3/host_key.rs | 2 +- crates/rginx-http/src/server/http3/mod.rs | 22 +- crates/rginx-http/src/server/http3/request.rs | 2 +- .../rginx-http/src/server/http3/response.rs | 2 +- .../rginx-http/src/server/proxy_protocol.rs | 2 +- crates/rginx-http/src/server/tests.rs | 66 +-- crates/rginx-http/src/state/agent.rs | 149 +++---- crates/rginx-http/src/state/cache.rs | 97 ++--- .../src/state/connections/guards.rs | 11 +- .../state/connections/operations/acquire.rs | 86 ++-- .../state/connections/operations/record.rs | 224 +++++----- crates/rginx-http/src/state/counters/grpc.rs | 10 +- crates/rginx-http/src/state/counters/http.rs | 31 +- crates/rginx-http/src/state/counters/mod.rs | 9 +- .../rginx-http/src/state/counters/rolling.rs | 130 +++--- .../rginx-http/src/state/counters/traffic.rs | 120 +++--- .../src/state/counters/upstreams.rs | 41 +- .../rginx-http/src/state/counters/versions.rs | 8 +- crates/rginx-http/src/state/helpers/mod.rs | 16 +- crates/rginx-http/src/state/lifecycle/acme.rs | 73 ++-- crates/rginx-http/src/state/lifecycle/mtls.rs | 94 +++-- .../rginx-http/src/state/lifecycle/reload.rs | 268 ++++++------ .../rginx-http/src/state/lifecycle/status.rs | 144 ++++--- .../state/lifecycle/status/status_helpers.rs | 9 +- .../rginx-http/src/state/lifecycle/tasks.rs | 49 +-- .../src/state/lifecycle/topology.rs | 2 +- crates/rginx-http/src/state/mod.rs | 94 ++--- crates/rginx-http/src/state/snapshot_bus.rs | 165 ++++---- .../src/state/snapshot_bus/delta.rs | 32 +- .../src/state/snapshot_bus/identity.rs | 31 +- crates/rginx-http/src/state/snapshots/acme.rs | 16 +- .../rginx-http/src/state/snapshots/active.rs | 6 +- .../rginx-http/src/state/snapshots/apply.rs | 20 +- .../rginx-http/src/state/snapshots/delta.rs | 66 +-- crates/rginx-http/src/state/snapshots/http.rs | 32 +- .../rginx-http/src/state/snapshots/reload.rs | 8 +- .../rginx-http/src/state/snapshots/runtime.rs | 166 ++++---- crates/rginx-http/src/state/snapshots/tls.rs | 134 +++--- .../rginx-http/src/state/snapshots/traffic.rs | 52 +-- .../src/state/snapshots/upstreams.rs | 70 ++-- crates/rginx-http/src/state/structure.rs | 70 ++-- crates/rginx-http/src/state/tests.rs | 20 +- crates/rginx-http/src/state/tests/status.rs | 4 +- .../src/state/tls_runtime/bindings.rs | 10 +- .../src/state/tls_runtime/certificates.rs | 6 +- .../src/state/tls_runtime/listeners.rs | 2 +- .../rginx-http/src/state/tls_runtime/mod.rs | 10 +- .../rginx-http/src/state/tls_runtime/ocsp.rs | 24 +- .../src/state/tls_runtime/reload_boundary.rs | 2 + .../src/state/tls_runtime/upstreams.rs | 2 +- crates/rginx-http/src/state/traffic/record.rs | 78 ++-- crates/rginx-http/src/state/traffic/refs.rs | 4 +- .../rginx-http/src/state/traffic/snapshot.rs | 106 ++--- .../rginx-http/src/state/traffic/versions.rs | 111 ++--- .../rginx-http/src/state/upstreams/record.rs | 138 +++---- .../src/state/upstreams/snapshot.rs | 209 +++++----- crates/rginx-http/src/timeout/body.rs | 8 +- .../src/timeout/body/grpc_deadline.rs | 13 +- crates/rginx-http/src/timeout/body/idle.rs | 13 +- .../rginx-http/src/timeout/body/max_bytes.rs | 41 +- crates/rginx-http/src/timeout/io.rs | 24 +- crates/rginx-http/src/timeout/tests.rs | 94 ++--- crates/rginx-http/src/timeout/timers.rs | 10 +- crates/rginx-http/src/tls/acceptor.rs | 2 +- crates/rginx-http/src/tls/certificates.rs | 10 +- .../rginx-http/src/tls/certificates/tests.rs | 2 +- crates/rginx-http/src/tls/client_auth.rs | 16 +- crates/rginx-http/src/tls/mod.rs | 4 +- crates/rginx-http/src/tls/ocsp/der_helpers.rs | 7 +- crates/rginx-http/src/tls/ocsp/discover.rs | 5 +- crates/rginx-http/src/tls/ocsp/mod.rs | 20 +- crates/rginx-http/src/tls/ocsp/nonce.rs | 2 +- crates/rginx-http/src/tls/ocsp/request.rs | 6 +- crates/rginx-http/src/tls/ocsp/signer.rs | 8 +- crates/rginx-http/src/tls/ocsp/tests.rs | 10 +- .../src/tls/ocsp/tests/discovery.rs | 31 +- .../rginx-http/src/tls/ocsp/tests/support.rs | 4 +- .../src/tls/ocsp/tests/support/certs.rs | 2 +- .../src/tls/ocsp/tests/support/response.rs | 166 ++++---- crates/rginx-http/src/tls/ocsp/time.rs | 2 +- crates/rginx-http/src/tls/ocsp/validate.rs | 11 +- crates/rginx-http/src/tls/provider.rs | 11 +- crates/rginx-http/src/tls/session.rs | 41 +- crates/rginx-http/src/tls/sni.rs | 27 +- crates/rginx-http/src/tls/tests.rs | 8 +- crates/rginx-http/src/transition.rs | 17 +- crates/rginx-observability/Cargo.toml | 3 + crates/rginx-observability/src/lib.rs | 9 + crates/rginx-runtime/Cargo.toml | 3 + crates/rginx-runtime/src/acme/challenge.rs | 8 +- crates/rginx-runtime/src/acme/mod.rs | 12 +- crates/rginx-runtime/src/acme/scheduler.rs | 39 +- crates/rginx-runtime/src/acme/storage.rs | 2 +- crates/rginx-runtime/src/acme/types.rs | 17 +- crates/rginx-runtime/src/admin/mod.rs | 6 +- crates/rginx-runtime/src/admin/model.rs | 70 ++-- crates/rginx-runtime/src/admin/socket.rs | 31 +- crates/rginx-runtime/src/agent.rs | 5 +- crates/rginx-runtime/src/agent/tests.rs | 85 ++-- crates/rginx-runtime/src/apply.rs | 8 +- .../rginx-runtime/src/apply/staged_change.rs | 14 +- .../src/bootstrap/listeners/activate.rs | 4 +- .../src/bootstrap/listeners/bind_udp.rs | 2 +- .../src/bootstrap/listeners/group.rs | 25 +- .../src/bootstrap/listeners/join.rs | 4 +- .../src/bootstrap/listeners/mod.rs | 4 +- .../src/bootstrap/listeners/prepare.rs | 105 +++-- .../src/bootstrap/listeners/reconcile.rs | 2 +- crates/rginx-runtime/src/bootstrap/mod.rs | 10 +- .../rginx-runtime/src/bootstrap/shutdown.rs | 9 +- crates/rginx-runtime/src/health.rs | 24 +- crates/rginx-runtime/src/lib.rs | 9 + crates/rginx-runtime/src/ocsp/mod.rs | 18 +- crates/rginx-runtime/src/restart.rs | 7 +- crates/rginx-runtime/src/state.rs | 9 +- crates/rginx-sdk/Cargo.toml | 4 +- crates/rginx-sdk/src/client.rs | 373 ++++++++--------- crates/rginx-sdk/src/config.rs | 53 +-- crates/rginx-sdk/src/error.rs | 29 +- crates/rginx-sdk/src/lib.rs | 9 + crates/rginx-sdk/src/models.rs | 80 ++-- crates/rginx-sdk/src/websocket.rs | 73 ++-- 535 files changed, 12002 insertions(+), 11002 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index b3ff1391..2af68949 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,38 @@ documentation = "https://github.com/vansour/rginx#readme" readme = "README.md" rust-version = "1.94.1" +[workspace.lints.clippy] +pedantic = "warn" + +# Restriction lints: correctness and reliability. +arithmetic_side_effects = "warn" +as_conversions = "warn" +assertions_on_result_states = "warn" +clone_on_ref_ptr = "warn" +error_impl_error = "warn" +indexing_slicing = "warn" +integer_division = "warn" +integer_division_remainder_used = "warn" +let_underscore_must_use = "warn" +map_err_ignore = "warn" +mem_forget = "warn" +missing_assert_message = "warn" +multiple_unsafe_ops_per_block = "warn" +return_and_then = "warn" +string_slice = "warn" +undocumented_unsafe_blocks = "warn" +unwrap_in_result = "warn" + +# Restriction lints: stricter production hygiene. +create_dir = "warn" +expect_used = "warn" +float_arithmetic = "warn" +panic = "warn" +print_stderr = "warn" +print_stdout = "warn" +unreachable = "warn" +unwrap_used = "warn" + [workspace.dependencies] anyhow = "1.0" base64 = "0.22" diff --git a/clippy.toml b/clippy.toml index aac355c4..d43569f7 100644 --- a/clippy.toml +++ b/clippy.toml @@ -2,3 +2,4 @@ # policy surface. msrv = "1.94.1" too-many-arguments-threshold = 8 +absolute-paths-max-segments = 4 diff --git a/crates/rginx-agent/Cargo.toml b/crates/rginx-agent/Cargo.toml index 07e12bd1..c6fd6647 100644 --- a/crates/rginx-agent/Cargo.toml +++ b/crates/rginx-agent/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true rust-version.workspace = true description = "Remote edge control plane primitives for rginx." +[lints] +workspace = true + [dependencies] rginx-config = { path = "../rginx-config" } rginx-http = { path = "../rginx-http" } diff --git a/crates/rginx-agent/src/agent_core.rs b/crates/rginx-agent/src/agent_core.rs index dd18d2f7..ec5c069e 100644 --- a/crates/rginx-agent/src/agent_core.rs +++ b/crates/rginx-agent/src/agent_core.rs @@ -20,106 +20,12 @@ const RELOAD_COMPLETION_TIMEOUT: Duration = Duration::from_secs(30); #[derive(Clone)] pub struct AgentCore { - state: SharedState, - reload_executor: Arc, config_apply_executor: Arc, + reload_executor: Arc, + state: SharedState, } impl AgentCore { - pub fn new(state: SharedState, reload_executor: Arc) -> Self { - Self { - state, - reload_executor, - config_apply_executor: Arc::new(UnsupportedConfigApplyExecutor), - } - } - - pub fn with_config_apply_executor( - mut self, - config_apply_executor: Arc, - ) -> Self { - self.config_apply_executor = config_apply_executor; - self - } - - pub fn shared_state(&self) -> &SharedState { - &self.state - } - - pub async fn status(&self) -> Result { - Ok(NodeStatusView::from(self.state.status_snapshot().await)) - } - - pub async fn snapshot(&self, window_secs: Option) -> Result { - Ok(NodeSnapshotView { - snapshot_version: self.state.current_snapshot_version(), - status: self.state.status_snapshot().await, - counters: self.state.counters_snapshot(), - traffic: self.state.traffic_stats_snapshot_with_window(window_secs), - peer_health: self.state.peer_health_snapshot().await, - upstreams: self.state.upstream_stats_snapshot_with_window(window_secs), - cache: self.state.cache_stats_snapshot().await, - }) - } - - pub async fn delta_since( - &self, - since_version: u64, - window_secs: Option, - ) -> Result { - let delta = self.state.snapshot_delta_since( - since_version, - Some(&SnapshotModule::all()), - window_secs, - ); - Ok(NodeDeltaView::from(delta)) - } - - pub async fn wait_for_snapshot_change( - &self, - since_version: u64, - timeout: Option, - ) -> Result { - let snapshot_version = self.state.wait_for_snapshot_change(since_version, timeout).await; - Ok(NodeWaitView { snapshot_version }) - } - - pub async fn traffic(&self, window_secs: Option) -> Result { - Ok(NodeTrafficView::from(self.state.traffic_stats_snapshot_with_window(window_secs))) - } - - pub async fn upstreams(&self, window_secs: Option) -> Result { - Ok(NodeUpstreamsView { - peer_health: self.state.peer_health_snapshot().await, - upstreams: self.state.upstream_stats_snapshot_with_window(window_secs), - }) - } - - pub async fn cache(&self) -> Result { - Ok(NodeCacheView::from(self.state.cache_stats_snapshot().await)) - } - - pub async fn system(&self) -> Result { - let config = self.state.current_config().await; - let cache_zone_paths = - config.cache_zones.values().map(|zone| zone.path.clone()).collect::>(); - tokio::task::spawn_blocking(move || collect_system_view(&cache_zone_paths)) - .await - .map_err(|error| Error::Server(error.to_string()))? - } - - pub async fn revision(&self) -> Result { - Ok(NodeRevisionView::from(self.state.revision_status_snapshot().await)) - } - - pub async fn reload(&self) -> Result { - let initial_status = self.state.status_snapshot().await.reload; - let fallback_revision = self.state.current_revision().await; - self.reload_executor.execute().await?; - self.wait_for_reload_attempt(initial_status.attempts_total).await?; - Ok(self.reload_action_status(fallback_revision).await) - } - pub async fn action_status(&self, accepted_revision: u64) -> NodeActionStatusView { NodeActionStatusView { accepted_revision, @@ -129,11 +35,6 @@ impl AgentCore { } } - pub async fn wrap_result(&self, result: T) -> NodeControlResultView { - let current_revision = self.state.current_revision().await; - NodeControlResultView { status: self.action_status(current_revision).await, result } - } - pub async fn apply_config( &self, request: ManagedResourceMutation, @@ -145,23 +46,35 @@ impl AgentCore { }) } - pub async fn purge_cache( + pub async fn cache(&self) -> Result { + Ok(NodeCacheView::from(self.state.cache_stats_snapshot().await)) + } + + pub async fn clear_cache_invalidations( &self, - command: CachePurgeCommand, - ) -> Result> { - let result = match command.target { - CachePurgeTarget::Zone => self.state.purge_cache_zone(&command.zone_name).await, - CachePurgeTarget::Key(key) => { - self.state.purge_cache_key(&command.zone_name, &key).await - } - CachePurgeTarget::Prefix(prefix) => { - self.state.purge_cache_prefix(&command.zone_name, &prefix).await - } - } - .map_err(Error::InvalidRequest)?; + command: CacheClearInvalidationsCommand, + ) -> Result> { + let result = self + .state + .clear_cache_invalidations(&command.zone_name) + .await + .map_err(Error::InvalidRequest)?; Ok(self.wrap_result(result).await) } + pub async fn delta_since( + &self, + since_version: u64, + window_secs: Option, + ) -> Result { + let delta = self.state.snapshot_delta_since( + since_version, + Some(&SnapshotModule::all()), + window_secs, + ); + Ok(NodeDeltaView::from(delta)) + } + pub async fn invalidate_cache( &self, command: CacheInvalidateCommand, @@ -184,24 +97,37 @@ impl AgentCore { Ok(self.wrap_result(result).await) } - pub async fn clear_cache_invalidations( + pub fn new(state: SharedState, reload_executor: Arc) -> Self { + Self { + state, + reload_executor, + config_apply_executor: Arc::new(UnsupportedConfigApplyExecutor), + } + } + + pub async fn purge_cache( &self, - command: CacheClearInvalidationsCommand, - ) -> Result> { - let result = self - .state - .clear_cache_invalidations(&command.zone_name) - .await - .map_err(Error::InvalidRequest)?; + command: CachePurgeCommand, + ) -> Result> { + let result = match command.target { + CachePurgeTarget::Zone => self.state.purge_cache_zone(&command.zone_name).await, + CachePurgeTarget::Key(key) => { + self.state.purge_cache_key(&command.zone_name, &key).await + } + CachePurgeTarget::Prefix(prefix) => { + self.state.purge_cache_prefix(&command.zone_name, &prefix).await + } + } + .map_err(Error::InvalidRequest)?; Ok(self.wrap_result(result).await) } - pub async fn set_desired_revision( - &self, - desired_revision: u64, - ) -> Result { - self.state.set_desired_revision(desired_revision); - Ok(self.action_status(desired_revision).await) + pub async fn reload(&self) -> Result { + let initial_status = self.state.status_snapshot().await.reload; + let fallback_revision = self.state.current_revision().await; + self.reload_executor.execute().await?; + self.wait_for_reload_attempt(initial_status.attempts_total).await?; + Ok(self.reload_action_status(fallback_revision).await) } async fn reload_action_status(&self, fallback_revision: u64) -> NodeActionStatusView { @@ -219,6 +145,59 @@ impl AgentCore { } } + pub async fn revision(&self) -> Result { + Ok(NodeRevisionView::from(self.state.revision_status_snapshot().await)) + } + + pub async fn set_desired_revision( + &self, + desired_revision: u64, + ) -> Result { + self.state.set_desired_revision(desired_revision); + Ok(self.action_status(desired_revision).await) + } + + #[must_use] + pub fn shared_state(&self) -> &SharedState { + &self.state + } + + pub async fn snapshot(&self, window_secs: Option) -> Result { + Ok(NodeSnapshotView { + snapshot_version: self.state.current_snapshot_version(), + status: self.state.status_snapshot().await, + counters: self.state.counters_snapshot(), + traffic: self.state.traffic_stats_snapshot_with_window(window_secs), + peer_health: self.state.peer_health_snapshot().await, + upstreams: self.state.upstream_stats_snapshot_with_window(window_secs), + cache: self.state.cache_stats_snapshot().await, + }) + } + + pub async fn status(&self) -> Result { + Ok(NodeStatusView::from(self.state.status_snapshot().await)) + } + + pub async fn system(&self) -> Result { + let config = self.state.current_config().await; + let cache_zone_paths = + config.cache_zones.values().map(|zone| zone.path.clone()).collect::>(); + tokio::task::spawn_blocking(move || collect_system_view(&cache_zone_paths)) + .await + .map_err(|error| Error::Server(error.to_string()))? + } + + pub async fn traffic(&self, window_secs: Option) -> Result { + Ok(NodeTrafficView::from(self.state.traffic_stats_snapshot_with_window(window_secs))) + } + + pub async fn upstreams(&self, window_secs: Option) -> Result { + Ok(NodeUpstreamsView { + peer_health: self.state.peer_health_snapshot().await, + upstreams: self.state.upstream_stats_snapshot_with_window(window_secs), + }) + } + async fn wait_for_reload_attempt(&self, attempts_before: u64) -> Result<()> { let started = Instant::now(); let mut observed_version = self.state.current_snapshot_version(); @@ -244,33 +223,55 @@ impl AgentCore { observed_version = changed_version; } } + + pub async fn wait_for_snapshot_change( + &self, + since_version: u64, + timeout: Option, + ) -> Result { + let snapshot_version = self.state.wait_for_snapshot_change(since_version, timeout).await; + Ok(NodeWaitView { snapshot_version }) + } + + pub fn with_config_apply_executor( + mut self, + config_apply_executor: Arc, + ) -> Self { + self.config_apply_executor = config_apply_executor; + self + } + + pub async fn wrap_result(&self, result: T) -> NodeControlResultView { + let current_revision = self.state.current_revision().await; + NodeControlResultView { status: self.action_status(current_revision).await, result } + } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct CachePurgeCommand { - pub zone_name: String, pub target: CachePurgeTarget, + pub zone_name: String, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum CachePurgeTarget { - Zone, Key(String), Prefix(String), + Zone, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct CacheInvalidateCommand { - pub zone_name: String, pub target: CacheInvalidateTarget, + pub zone_name: String, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum CacheInvalidateTarget { - Zone, Key(String), Prefix(String), Tag(String), + Zone, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/crates/rginx-agent/src/audit.rs b/crates/rginx-agent/src/audit.rs index 8a869f9e..8d4d4de9 100644 --- a/crates/rginx-agent/src/audit.rs +++ b/crates/rginx-agent/src/audit.rs @@ -12,36 +12,36 @@ pub(crate) struct AuditContext<'a> { pub(crate) method: &'a Method, pub(crate) path: &'a str, pub(crate) peer_addr: SocketAddr, - pub(crate) resource: Option, pub(crate) requirement: AuthorizationRequirement, + pub(crate) resource: Option, } #[derive(Debug, Serialize)] pub struct AuditLog { - pub timestamp: u64, - pub event: &'static str, - pub outcome: AuditOutcome, - pub request_id: Option, - // Authentication info pub actor_id: Option, pub auth_method: Option, - pub scopes: Vec, + pub duration_ms: Option, + pub error: Option, + pub event: &'static str, // Request info pub method: String, + pub outcome: AuditOutcome, pub path: String, pub peer_addr: String, - pub user_agent: Option, + pub request_id: Option, + + pub requirement: String, // Resource info pub resource: Option, - pub requirement: String, + pub scopes: Vec, // Response info pub status: Option, - pub duration_ms: Option, - pub error: Option, + pub timestamp: u64, + pub user_agent: Option, } #[derive(Debug, Serialize)] @@ -108,7 +108,7 @@ pub(crate) fn log_deny( event: "control_plane_audit", outcome: AuditOutcome::Deny, request_id: None, - actor_id: actor_id.map(|s| s.to_string()), + actor_id: actor_id.map(std::string::ToString::to_string), auth_method: if actor_id.is_some() { Some("api_key".to_string()) } else { None }, scopes: scopes.to_vec(), method: context.method.to_string(), @@ -131,9 +131,7 @@ pub(crate) fn log_deny( path = context.path, peer_addr = %context.peer_addr, resource = %context - .resource - .map(|resource| resource.label().to_string()) - .unwrap_or_else(|| "-".to_string()), + .resource.map_or_else(|| "-".to_string(), |resource| resource.label().to_string()), requirement = %context.requirement.label(), error = %error, "control plane request denied" @@ -199,7 +197,7 @@ fn write_audit_log(log: &AuditLog) { let _ = std::fs::OpenOptions::new().create(true).append(true).open(&audit_path).and_then( |mut f| { use std::io::Write; - writeln!(f, "{}", json) + writeln!(f, "{json}") }, ); } diff --git a/crates/rginx-agent/src/auth.rs b/crates/rginx-agent/src/auth.rs index 25838706..fd515e6a 100644 --- a/crates/rginx-agent/src/auth.rs +++ b/crates/rginx-agent/src/auth.rs @@ -9,11 +9,11 @@ pub(crate) use keyring::ApiKeyStore; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ActionScope { + CacheWrite, + ConfigWrite, MetricsRead, RuntimeRead, - CacheWrite, RuntimeReload, - ConfigWrite, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -36,15 +36,15 @@ pub enum ApiKeyStatus { #[derive(Debug, Clone)] pub(crate) struct ApiKeyRecord { - pub(crate) id: String, - pub(crate) secret: String, - pub(crate) scopes: Vec, - #[allow(dead_code)] + pub(crate) allowed_ips: Vec, + #[expect(dead_code, reason = "key creation timestamps are retained for keyring metadata")] pub(crate) created_at: u64, pub(crate) expires_at: Option, + pub(crate) id: String, pub(crate) last_used_at: Option, + pub(crate) scopes: Vec, + pub(crate) secret: String, pub(crate) status: ApiKeyStatus, - pub(crate) allowed_ips: Vec, } pub(crate) struct ControlPlaneIdentity<'a> { @@ -54,11 +54,14 @@ pub(crate) struct ControlPlaneIdentity<'a> { /// Authentication method used for a request #[derive(Debug, Clone)] -#[allow(private_interfaces)] +#[expect( + private_interfaces, + reason = "public auth model is re-exported while sensitive record internals stay crate-private" +)] pub enum AuthMethod { ApiKey(ApiKeyRecord), - ClientCertificate(crate::tls::ClientCertIdentity), Both { api_key: ApiKeyRecord, client_cert: crate::tls::ClientCertIdentity }, + ClientCertificate(crate::tls::ClientCertIdentity), } impl AuthMethod { @@ -70,6 +73,26 @@ impl AuthMethod { } } + #[expect(dead_code, reason = "reserved for auth audit and metrics labeling")] + pub(crate) fn auth_method_label(&self) -> &'static str { + match self { + AuthMethod::ApiKey(_) => "api_key", + AuthMethod::ClientCertificate(_) => "client_cert", + AuthMethod::Both { .. } => "both", + } + } + + pub(crate) fn authorizes(&self, requirement: AuthorizationRequirement) -> AuthDecision { + match self { + AuthMethod::ApiKey(record) => record.authorizes(requirement), + AuthMethod::ClientCertificate(_) => { + // Client certificates have full access + AuthDecision::Allow + } + AuthMethod::Both { api_key, .. } => api_key.authorizes(requirement), + } + } + pub(crate) fn scope_labels(&self) -> Vec { match self { AuthMethod::ApiKey(record) => { @@ -90,29 +113,18 @@ impl AuthMethod { } } } +} - pub(crate) fn authorizes(&self, requirement: AuthorizationRequirement) -> AuthDecision { - match self { - AuthMethod::ApiKey(record) => record.authorizes(requirement), - AuthMethod::ClientCertificate(_) => { - // Client certificates have full access - AuthDecision::Allow - } - AuthMethod::Both { api_key, .. } => api_key.authorizes(requirement), - } - } - - #[allow(dead_code)] - pub(crate) fn auth_method_label(&self) -> &'static str { +impl ActionScope { + pub(crate) fn label(self) -> &'static str { match self { - AuthMethod::ApiKey(_) => "api_key", - AuthMethod::ClientCertificate(_) => "client_cert", - AuthMethod::Both { .. } => "both", + Self::MetricsRead => "metrics.read", + Self::RuntimeRead => "runtime.read", + Self::CacheWrite => "cache.write", + Self::RuntimeReload => "runtime.reload", + Self::ConfigWrite => "config.write", } } -} - -impl ActionScope { pub(crate) fn parse(value: &str) -> Result { match value.trim() { "metrics.read" => Ok(Self::MetricsRead), @@ -123,16 +135,6 @@ impl ActionScope { other => Err(Error::Server(format!("unknown control plane api key scope `{other}`"))), } } - - pub(crate) fn label(self) -> &'static str { - match self { - Self::MetricsRead => "metrics.read", - Self::RuntimeRead => "runtime.read", - Self::CacheWrite => "cache.write", - Self::RuntimeReload => "runtime.reload", - Self::ConfigWrite => "config.write", - } - } } impl AuthorizationRequirement { @@ -145,14 +147,6 @@ impl AuthorizationRequirement { } impl ApiKeyRecord { - #[allow(dead_code)] - pub(crate) fn identity(&self) -> ControlPlaneIdentity<'_> { - ControlPlaneIdentity { - actor_id: &self.id, - scope_labels: self.scopes.iter().map(|scope| scope.label().to_string()).collect(), - } - } - pub(crate) fn authorizes(&self, requirement: AuthorizationRequirement) -> AuthDecision { match requirement { AuthorizationRequirement::AnyRead => { @@ -173,6 +167,16 @@ impl ApiKeyRecord { } } } + #[expect( + dead_code, + reason = "reserved for audit logging paths that need borrowed identity data" + )] + pub(crate) fn identity(&self) -> ControlPlaneIdentity<'_> { + ControlPlaneIdentity { + actor_id: &self.id, + scope_labels: self.scopes.iter().map(|scope| scope.label().to_string()).collect(), + } + } } pub(crate) fn api_key_from_headers(headers: &HeaderMap) -> Option<&str> { diff --git a/crates/rginx-agent/src/auth/keyring.rs b/crates/rginx-agent/src/auth/keyring.rs index f03947fd..58051aa8 100644 --- a/crates/rginx-agent/src/auth/keyring.rs +++ b/crates/rginx-agent/src/auth/keyring.rs @@ -18,35 +18,6 @@ pub struct ApiKeyStore { } impl ApiKeyStore { - pub(crate) fn load(path: &Path) -> Result { - let contents = std::fs::read_to_string(path)?; - let document: ApiKeysDocument = serde_json::from_str(&contents)?; - let mut by_id = BTreeMap::new(); - let mut by_secret = BTreeMap::new(); - - for entry in document.keys { - let record = ApiKeyRecord::from_entry(entry)?; - if by_id.insert(record.id.clone(), record.clone()).is_some() { - return Err(Error::Server(format!( - "duplicate control plane api key id `{}` in {}", - record.id, - path.display() - ))); - } - if by_secret.insert(secret_hash(&record.secret), record.id.clone()).is_some() { - return Err(Error::Server(format!( - "duplicate control plane api key secret in {}", - path.display() - ))); - } - } - - Ok(Self { - by_id: Arc::new(RwLock::new(by_id)), - by_secret: Arc::new(RwLock::new(by_secret)), - }) - } - pub(crate) async fn find_by_secret(&self, secret: &str) -> Option { let secret_hash = secret_hash(secret); let by_secret = self.by_secret.read().await; @@ -72,34 +43,59 @@ impl ApiKeyStore { Some(record.clone()) } - pub(crate) async fn update_last_used(&self, key_id: &str) { - let mut by_id = self.by_id.write().await; - if let Some(record) = by_id.get_mut(key_id) { - record.last_used_at = Some(current_timestamp_ms()); - } - } - - #[allow(dead_code)] + #[expect(dead_code, reason = "reserved for key-management API handlers")] pub(crate) async fn list_keys(&self) -> Vec { let by_id = self.by_id.read().await; by_id.values().cloned().collect() } - #[allow(dead_code)] + pub(crate) fn load(path: &Path) -> Result { + let contents = std::fs::read_to_string(path)?; + let document: ApiKeysDocument = serde_json::from_str(&contents)?; + let mut by_id = BTreeMap::new(); + let mut by_secret = BTreeMap::new(); + + for entry in document.keys { + let record = ApiKeyRecord::from_entry(entry)?; + if by_id.insert(record.id.clone(), record.clone()).is_some() { + return Err(Error::Server(format!( + "duplicate control plane api key id `{}` in {}", + record.id, + path.display() + ))); + } + if by_secret.insert(secret_hash(&record.secret), record.id.clone()).is_some() { + return Err(Error::Server(format!( + "duplicate control plane api key secret in {}", + path.display() + ))); + } + } + + Ok(Self { + by_id: Arc::new(RwLock::new(by_id)), + by_secret: Arc::new(RwLock::new(by_secret)), + }) + } + + #[expect(dead_code, reason = "reserved for key-management API handlers")] pub(crate) async fn revoke_key(&self, key_id: &str) -> Result<()> { let mut by_id = self.by_id.write().await; let record = by_id .get_mut(key_id) - .ok_or_else(|| Error::InvalidRequest(format!("api key {} not found", key_id)))?; + .ok_or_else(|| Error::InvalidRequest(format!("api key {key_id} not found")))?; record.status = ApiKeyStatus::Revoked; tracing::info!(key_id = %key_id, "api key revoked"); Ok(()) } -} -fn current_timestamp_ms() -> u64 { - SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64 + pub(crate) async fn update_last_used(&self, key_id: &str) { + let mut by_id = self.by_id.write().await; + if let Some(record) = by_id.get_mut(key_id) { + record.last_used_at = Some(current_timestamp_ms()); + } + } } #[derive(Debug, Deserialize)] @@ -110,16 +106,16 @@ struct ApiKeysDocument { #[derive(Debug, Deserialize)] struct ApiKeyEntry { - id: String, - secret: String, #[serde(default)] - scopes: Vec, + allowed_ips: Vec, #[serde(default)] created_at: Option, #[serde(default)] expires_at: Option, + id: String, #[serde(default)] - allowed_ips: Vec, + scopes: Vec, + secret: String, } impl ApiKeyRecord { @@ -147,7 +143,7 @@ impl ApiKeyRecord { .into_iter() .map(|cidr| cidr.parse()) .collect::, _>>() - .map_err(|e| Error::Server(format!("invalid CIDR in allowed_ips: {}", e)))?; + .map_err(|e| Error::Server(format!("invalid CIDR in allowed_ips: {e}")))?; Ok(Self { id, @@ -162,6 +158,10 @@ impl ApiKeyRecord { } } +fn current_timestamp_ms() -> u64 { + SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_millis() as u64 +} + fn secret_hash(secret: &str) -> [u8; 32] { Sha256::digest(secret.as_bytes()).into() } diff --git a/crates/rginx-agent/src/circuit_breaker.rs b/crates/rginx-agent/src/circuit_breaker.rs index d465c70c..15c01b22 100644 --- a/crates/rginx-agent/src/circuit_breaker.rs +++ b/crates/rginx-agent/src/circuit_breaker.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; @@ -6,16 +8,16 @@ use tokio::sync::RwLock; #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] pub enum CircuitState { Closed, - Open, HalfOpen, + Open, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CircuitBreakerConfig { pub failure_threshold: u32, + pub half_open_max_requests: u32, pub success_threshold: u32, pub timeout_secs: u64, - pub half_open_max_requests: u32, } impl Default for CircuitBreakerConfig { @@ -31,38 +33,49 @@ impl Default for CircuitBreakerConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CircuitBreakerStats { - pub state: CircuitState, pub failure_count: u32, - pub success_count: u32, - pub total_requests: u64, + pub half_open_requests: u32, pub last_failure_time: Option, pub last_state_change: u64, - pub half_open_requests: u32, + pub state: CircuitState, + pub success_count: u32, + pub total_requests: u64, } pub struct CircuitBreaker { config: CircuitBreakerConfig, - state: Arc>, failure_count: Arc>, - success_count: Arc>, - total_requests: Arc>, + half_open_requests: Arc>, last_failure_time: Arc>>, last_state_change: Arc>, - half_open_requests: Arc>, + state: Arc>, + success_count: Arc>, + total_requests: Arc>, } impl CircuitBreaker { - pub fn new(config: CircuitBreakerConfig) -> Self { - let now = current_timestamp(); - Self { - config, - state: Arc::new(RwLock::new(CircuitState::Closed)), - failure_count: Arc::new(RwLock::new(0)), - success_count: Arc::new(RwLock::new(0)), - total_requests: Arc::new(RwLock::new(0)), - last_failure_time: Arc::new(RwLock::new(None)), - last_state_change: Arc::new(RwLock::new(now)), - half_open_requests: Arc::new(RwLock::new(0)), + async fn allow_request(&self) -> bool { + let state = *self.state.read().await; + + match state { + CircuitState::Closed => true, + CircuitState::Open => { + if self.should_attempt_reset().await { + self.transition_to_half_open().await; + true + } else { + false + } + } + CircuitState::HalfOpen => { + let mut half_open_requests = self.half_open_requests.write().await; + if *half_open_requests < self.config.half_open_max_requests { + *half_open_requests = half_open_requests.saturating_add(1); + true + } else { + false + } + } } } @@ -74,7 +87,10 @@ impl CircuitBreaker { return Err(CircuitBreakerError::CircuitOpen); } - *self.total_requests.write().await += 1; + { + let mut total_requests = self.total_requests.write().await; + *total_requests = total_requests.saturating_add(1); + } match f.await { Ok(result) => { @@ -88,37 +104,55 @@ impl CircuitBreaker { } } - async fn allow_request(&self) -> bool { + pub async fn get_state(&self) -> CircuitState { + *self.state.read().await + } + + pub async fn get_stats(&self) -> CircuitBreakerStats { + CircuitBreakerStats { + state: *self.state.read().await, + failure_count: *self.failure_count.read().await, + success_count: *self.success_count.read().await, + total_requests: *self.total_requests.read().await, + last_failure_time: *self.last_failure_time.read().await, + last_state_change: *self.last_state_change.read().await, + half_open_requests: *self.half_open_requests.read().await, + } + } + + #[must_use] + pub fn new(config: CircuitBreakerConfig) -> Self { + let now = current_timestamp(); + Self { + config, + state: Arc::new(RwLock::new(CircuitState::Closed)), + failure_count: Arc::new(RwLock::new(0)), + success_count: Arc::new(RwLock::new(0)), + total_requests: Arc::new(RwLock::new(0)), + last_failure_time: Arc::new(RwLock::new(None)), + last_state_change: Arc::new(RwLock::new(now)), + half_open_requests: Arc::new(RwLock::new(0)), + } + } + + async fn on_failure(&self) { let state = *self.state.read().await; + *self.last_failure_time.write().await = Some(current_timestamp()); match state { - CircuitState::Closed => true, - CircuitState::Open => { - if self.should_attempt_reset().await { - self.transition_to_half_open().await; - true - } else { - false + CircuitState::Closed => { + let mut failure_count = self.failure_count.write().await; + *failure_count = failure_count.saturating_add(1); + + if *failure_count >= self.config.failure_threshold { + drop(failure_count); + self.transition_to_open().await; } } CircuitState::HalfOpen => { - let mut half_open_requests = self.half_open_requests.write().await; - if *half_open_requests < self.config.half_open_max_requests { - *half_open_requests += 1; - true - } else { - false - } + self.transition_to_open().await; } - } - } - - async fn should_attempt_reset(&self) -> bool { - if let Some(last_failure) = *self.last_failure_time.read().await { - let now = current_timestamp(); - now - last_failure >= self.config.timeout_secs - } else { - false + CircuitState::Open => {} } } @@ -132,7 +166,7 @@ impl CircuitBreaker { CircuitState::HalfOpen => { let should_close = { let mut success_count = self.success_count.write().await; - *success_count += 1; + *success_count = success_count.saturating_add(1); *success_count >= self.config.success_threshold }; @@ -144,32 +178,26 @@ impl CircuitBreaker { } } - async fn on_failure(&self) { - let state = *self.state.read().await; - *self.last_failure_time.write().await = Some(current_timestamp()); - - match state { - CircuitState::Closed => { - let mut failure_count = self.failure_count.write().await; - *failure_count += 1; + pub async fn reset(&self) { + self.transition_to_closed().await; + } - if *failure_count >= self.config.failure_threshold { - drop(failure_count); - self.transition_to_open().await; - } - } - CircuitState::HalfOpen => { - self.transition_to_open().await; - } - CircuitState::Open => {} + async fn should_attempt_reset(&self) -> bool { + if let Some(last_failure) = *self.last_failure_time.read().await { + let now = current_timestamp(); + now.saturating_sub(last_failure) >= self.config.timeout_secs + } else { + false } } - async fn transition_to_open(&self) { - *self.state.write().await = CircuitState::Open; + async fn transition_to_closed(&self) { + *self.state.write().await = CircuitState::Closed; *self.last_state_change.write().await = current_timestamp(); + *self.failure_count.write().await = 0; + *self.success_count.write().await = 0; *self.half_open_requests.write().await = 0; - tracing::warn!("Circuit breaker transitioned to OPEN state"); + tracing::info!("Circuit breaker transitioned to CLOSED state"); } async fn transition_to_half_open(&self) { @@ -180,33 +208,11 @@ impl CircuitBreaker { tracing::info!("Circuit breaker transitioned to HALF_OPEN state"); } - async fn transition_to_closed(&self) { - *self.state.write().await = CircuitState::Closed; + async fn transition_to_open(&self) { + *self.state.write().await = CircuitState::Open; *self.last_state_change.write().await = current_timestamp(); - *self.failure_count.write().await = 0; - *self.success_count.write().await = 0; *self.half_open_requests.write().await = 0; - tracing::info!("Circuit breaker transitioned to CLOSED state"); - } - - pub async fn get_state(&self) -> CircuitState { - *self.state.read().await - } - - pub async fn get_stats(&self) -> CircuitBreakerStats { - CircuitBreakerStats { - state: *self.state.read().await, - failure_count: *self.failure_count.read().await, - success_count: *self.success_count.read().await, - total_requests: *self.total_requests.read().await, - last_failure_time: *self.last_failure_time.read().await, - last_state_change: *self.last_state_change.read().await, - half_open_requests: *self.half_open_requests.read().await, - } - } - - pub async fn reset(&self) { - self.transition_to_closed().await; + tracing::warn!("Circuit breaker transitioned to OPEN state"); } } @@ -222,8 +228,20 @@ pub struct CircuitBreakerRegistry { } impl CircuitBreakerRegistry { - pub fn new(default_config: CircuitBreakerConfig) -> Self { - Self { breakers: Arc::new(RwLock::new(HashMap::new())), default_config } + pub async fn get(&self, name: &str) -> Option> { + let breakers = self.breakers.read().await; + breakers.get(name).map(Arc::clone) + } + + pub async fn get_all_stats(&self) -> HashMap { + let breakers = self.breakers.read().await; + let mut stats = HashMap::new(); + + for (name, breaker) in breakers.iter() { + stats.insert(name.clone(), breaker.get_stats().await); + } + + stats } pub async fn get_or_create(&self, name: &str) -> Arc { @@ -240,25 +258,14 @@ impl CircuitBreakerRegistry { .clone() } - pub async fn get(&self, name: &str) -> Option> { - let breakers = self.breakers.read().await; - breakers.get(name).map(Arc::clone) - } - pub async fn list(&self) -> Vec { let breakers = self.breakers.read().await; breakers.keys().cloned().collect() } - pub async fn get_all_stats(&self) -> HashMap { - let breakers = self.breakers.read().await; - let mut stats = HashMap::new(); - - for (name, breaker) in breakers.iter() { - stats.insert(name.clone(), breaker.get_stats().await); - } - - stats + #[must_use] + pub fn new(default_config: CircuitBreakerConfig) -> Self { + Self { breakers: Arc::new(RwLock::new(HashMap::new())), default_config } } pub async fn reset(&self, name: &str) -> Result<(), String> { @@ -267,7 +274,7 @@ impl CircuitBreakerRegistry { breaker.reset().await; Ok(()) } else { - Err(format!("Circuit breaker {} not found", name)) + Err(format!("Circuit breaker {name} not found")) } } } @@ -281,6 +288,3 @@ impl Default for CircuitBreakerRegistry { fn current_timestamp() -> u64 { std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/config_history.rs b/crates/rginx-agent/src/config_history.rs index 46587409..4a8f329e 100644 --- a/crates/rginx-agent/src/config_history.rs +++ b/crates/rginx-agent/src/config_history.rs @@ -1,3 +1,7 @@ +mod diff; + +#[cfg(test)] +mod tests; use std::collections::BTreeMap; use std::path::PathBuf; use std::sync::Arc; @@ -8,38 +12,36 @@ use tokio::sync::RwLock; use crate::error::{Error, Result}; use crate::registry::current_timestamp_ms; -mod diff; - use diff::{calculate_diff, calculate_hash}; /// Configuration revision record #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConfigRevision { - pub revision: u64, pub applied_at: u64, pub applied_by: String, - pub status: ConfigApplyStatus, pub config_snapshot: ConfigSnapshot, pub diff_from_previous: Option, pub metadata: ConfigMetadata, + pub revision: u64, + pub status: ConfigApplyStatus, } /// Configuration apply status #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ConfigApplyStatus { - Success, Failed, RolledBack, + Success, } /// Configuration snapshot #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConfigSnapshot { - pub hash: String, - pub size_bytes: usize, #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, + pub hash: String, + pub size_bytes: usize, } /// Configuration diff between two versions @@ -52,12 +54,12 @@ pub struct ConfigDiff { /// A single configuration change #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConfigChange { - pub op: ChangeOperation, - pub path: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub old_value: Option, #[serde(skip_serializing_if = "Option::is_none")] pub new_value: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub old_value: Option, + pub op: ChangeOperation, + pub path: String, } /// Change operation type @@ -73,8 +75,8 @@ pub enum ChangeOperation { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DiffSummary { pub additions: usize, - pub removals: usize, pub modifications: usize, + pub removals: usize, } /// Configuration metadata @@ -82,22 +84,61 @@ pub struct DiffSummary { pub struct ConfigMetadata { #[serde(skip_serializing_if = "Option::is_none")] pub reason: Option, - #[serde(default)] - pub tags: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub rollback_from: Option, + #[serde(default)] + pub tags: Vec, } /// Configuration history storage pub struct ConfigHistory { - storage_path: PathBuf, - revisions: Arc>>, max_revisions: usize, + revisions: Arc>>, + storage_path: PathBuf, } impl ConfigHistory { - pub fn new(storage_path: PathBuf, max_revisions: usize) -> Self { - Self { storage_path, revisions: Arc::new(RwLock::new(BTreeMap::new())), max_revisions } + /// Get total revision count + pub async fn count(&self) -> usize { + let revisions = self.revisions.read().await; + revisions.len() + } + + /// Calculate diff between two revisions + pub async fn diff(&self, from: u64, to: u64) -> Result { + let revisions = self.revisions.read().await; + + let from_config = revisions + .get(&from) + .ok_or_else(|| Error::InvalidRequest(format!("revision {from} not found")))?; + let to_config = revisions + .get(&to) + .ok_or_else(|| Error::InvalidRequest(format!("revision {to} not found")))?; + + let from_content = from_config + .config_snapshot + .content + .as_ref() + .ok_or_else(|| Error::InvalidRequest(format!("revision {from} has no content")))?; + let to_content = to_config + .config_snapshot + .content + .as_ref() + .ok_or_else(|| Error::InvalidRequest(format!("revision {to} has no content")))?; + + Ok(calculate_diff(from_content, to_content)) + } + + /// Get a specific revision + pub async fn get(&self, revision: u64) -> Option { + let revisions = self.revisions.read().await; + revisions.get(&revision).cloned() + } + + /// List revisions with pagination + pub async fn list(&self, limit: usize, offset: usize) -> Vec { + let revisions = self.revisions.read().await; + revisions.values().rev().skip(offset).take(limit).cloned().collect() } /// Load history from disk @@ -109,7 +150,7 @@ impl ConfigHistory { let content = tokio::fs::read_to_string(&history_file).await.map_err(Error::Io)?; let revisions: Vec = serde_json::from_str(&content) - .map_err(|e| Error::InvalidRequest(format!("failed to parse history: {}", e)))?; + .map_err(|e| Error::InvalidRequest(format!("failed to parse history: {e}")))?; let mut map = self.revisions.write().await; for revision in revisions { @@ -121,20 +162,9 @@ impl ConfigHistory { Ok(()) } - /// Save history to disk - pub async fn save(&self) -> Result<()> { - let revisions = self.revisions.read().await; - let list: Vec<_> = revisions.values().cloned().collect(); - - let content = serde_json::to_string_pretty(&list) - .map_err(|e| Error::Server(format!("failed to serialize history: {}", e)))?; - - tokio::fs::create_dir_all(&self.storage_path).await.map_err(Error::Io)?; - - let history_file = self.storage_path.join("config_history.json"); - tokio::fs::write(&history_file, content).await.map_err(Error::Io)?; - - Ok(()) + #[must_use] + pub fn new(storage_path: PathBuf, max_revisions: usize) -> Self { + Self { storage_path, revisions: Arc::new(RwLock::new(BTreeMap::new())), max_revisions } } /// Record a new configuration revision @@ -147,7 +177,7 @@ impl ConfigHistory { ) -> Result<()> { let config_hash = calculate_hash(&config); let config_json = serde_json::to_string(&config) - .map_err(|e| Error::Server(format!("failed to serialize config: {}", e)))?; + .map_err(|e| Error::Server(format!("failed to serialize config: {e}")))?; let config_snapshot = ConfigSnapshot { hash: config_hash, @@ -183,7 +213,7 @@ impl ConfigHistory { // Clean up old revisions while revisions.len() > self.max_revisions { - if let Some(oldest) = revisions.keys().next().cloned() { + if let Some(oldest) = revisions.keys().next().copied() { revisions.remove(&oldest); tracing::debug!(revision = oldest, "removed old config revision"); } @@ -196,48 +226,19 @@ impl ConfigHistory { Ok(()) } - /// Get a specific revision - pub async fn get(&self, revision: u64) -> Option { - let revisions = self.revisions.read().await; - revisions.get(&revision).cloned() - } - - /// List revisions with pagination - pub async fn list(&self, limit: usize, offset: usize) -> Vec { - let revisions = self.revisions.read().await; - revisions.values().rev().skip(offset).take(limit).cloned().collect() - } - - /// Get total revision count - pub async fn count(&self) -> usize { + /// Save history to disk + pub async fn save(&self) -> Result<()> { let revisions = self.revisions.read().await; - revisions.len() - } + let list: Vec<_> = revisions.values().cloned().collect(); - /// Calculate diff between two revisions - pub async fn diff(&self, from: u64, to: u64) -> Result { - let revisions = self.revisions.read().await; + let content = serde_json::to_string_pretty(&list) + .map_err(|e| Error::Server(format!("failed to serialize history: {e}")))?; - let from_config = revisions - .get(&from) - .ok_or_else(|| Error::InvalidRequest(format!("revision {} not found", from)))?; - let to_config = revisions - .get(&to) - .ok_or_else(|| Error::InvalidRequest(format!("revision {} not found", to)))?; + tokio::fs::create_dir_all(&self.storage_path).await.map_err(Error::Io)?; - let from_content = - from_config.config_snapshot.content.as_ref().ok_or_else(|| { - Error::InvalidRequest(format!("revision {} has no content", from)) - })?; - let to_content = to_config - .config_snapshot - .content - .as_ref() - .ok_or_else(|| Error::InvalidRequest(format!("revision {} has no content", to)))?; + let history_file = self.storage_path.join("config_history.json"); + tokio::fs::write(&history_file, content).await.map_err(Error::Io)?; - Ok(calculate_diff(from_content, to_content)) + Ok(()) } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/config_history/diff.rs b/crates/rginx-agent/src/config_history/diff.rs index c200306c..1fef60f2 100644 --- a/crates/rginx-agent/src/config_history/diff.rs +++ b/crates/rginx-agent/src/config_history/diff.rs @@ -16,7 +16,7 @@ pub(super) fn calculate_diff(old: &serde_json::Value, new: &serde_json::Value) - diff_values("", old, new, &mut changes, &mut additions, &mut removals, &mut modifications); - ConfigDiff { changes, summary: DiffSummary { additions, removals, modifications } } + ConfigDiff { changes, summary: DiffSummary { additions, modifications, removals } } } fn diff_values( @@ -35,7 +35,7 @@ fn diff_values( diff_object_values(path, old_map, new_map, changes, additions, removals, modifications); } _ if old != new => { - *modifications += 1; + *modifications = modifications.saturating_add(1); changes.push(ConfigChange { op: ChangeOperation::Replace, path: path.to_string(), @@ -71,7 +71,7 @@ fn diff_object_values( ); } } else { - *removals += 1; + *removals = removals.saturating_add(1); changes.push(ConfigChange { op: ChangeOperation::Remove, path: next_path, @@ -85,7 +85,7 @@ fn diff_object_values( if !old_map.contains_key(key) { let next_path = if path.is_empty() { format!("/{key}") } else { format!("{path}/{key}") }; - *additions += 1; + *additions = additions.saturating_add(1); changes.push(ConfigChange { op: ChangeOperation::Add, path: next_path, diff --git a/crates/rginx-agent/src/config_validator.rs b/crates/rginx-agent/src/config_validator.rs index abf4f0bd..4ea3a2dc 100644 --- a/crates/rginx-agent/src/config_validator.rs +++ b/crates/rginx-agent/src/config_validator.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use serde::{Deserialize, Serialize}; use crate::error::{Error, Result}; @@ -6,19 +8,19 @@ use crate::metrics; /// Configuration validation result #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ValidationResult { - pub valid: bool, pub issues: Vec, + pub valid: bool, pub warnings: Vec, } /// A validation issue or warning #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ValidationIssue { - pub severity: IssueSeverity, pub category: String, pub message: String, #[serde(skip_serializing_if = "Option::is_none")] pub path: Option, + pub severity: IssueSeverity, } /// Issue severity level @@ -26,14 +28,49 @@ pub struct ValidationIssue { #[serde(rename_all = "lowercase")] pub enum IssueSeverity { Error, - Warning, Info, + Warning, } /// Configuration validator for dry-run validation pub struct ConfigValidator; impl ConfigValidator { + /// Assess the impact of applying this configuration + pub async fn assess_impact( + &self, + old_config: &serde_json::Value, + new_config: &serde_json::Value, + ) -> ImpactAssessment { + let mut requires_reload = false; + let mut affects_traffic = false; + let mut breaking_changes = Vec::new(); + + // Simple impact assessment + if old_config != new_config { + requires_reload = true; + + // Check if upstreams changed + if old_config.get("upstreams") != new_config.get("upstreams") { + affects_traffic = true; + breaking_changes.push("upstream configuration changed".to_string()); + } + + // Check if routes changed + if old_config.get("routes") != new_config.get("routes") { + affects_traffic = true; + breaking_changes.push("route configuration changed".to_string()); + } + } + + ImpactAssessment { + requires_reload, + affects_traffic, + breaking_changes, + estimated_downtime_ms: if affects_traffic { Some(100) } else { None }, + } + } + #[must_use] pub fn new() -> Self { Self } @@ -77,21 +114,24 @@ impl ConfigValidator { let valid = issues.is_empty(); metrics::record_config_validation(valid); - Ok(ValidationResult { valid, issues, warnings }) + Ok(ValidationResult { issues, valid, warnings }) } - fn validate_syntax(&self, config: &serde_json::Value) -> Result<()> { - // Basic syntax validation - check if it's a valid JSON object - if !config.is_object() { - return Err(Error::InvalidRequest("configuration must be a JSON object".to_string())); - } - - // Check for required top-level fields - let obj = config.as_object().unwrap(); - - // Validate that we have at least some configuration - if obj.is_empty() { - return Err(Error::InvalidRequest("configuration cannot be empty".to_string())); + async fn validate_resources(&self, config: &serde_json::Value) -> Result<()> { + // Validate that referenced resources exist + if let Some(obj) = config.as_object() { + // Check TLS certificates if present + if let Some(tls) = obj.get("tls") + && let Some(tls_obj) = tls.as_object() + && let Some(cert_path) = tls_obj.get("cert_path") + && let Some(path_str) = cert_path.as_str() + && !path_str.is_empty() + && !std::path::Path::new(path_str).exists() + { + return Err(Error::InvalidRequest(format!( + "certificate file not found: {path_str}" + ))); + } } Ok(()) @@ -124,8 +164,8 @@ impl ConfigValidator { warnings.push(ValidationIssue { severity: IssueSeverity::Warning, category: "semantics".to_string(), - message: format!("upstream '{}' has no peers", name), - path: Some(format!("/upstreams/{}/peers", name)), + message: format!("upstream '{name}' has no peers"), + path: Some(format!("/upstreams/{name}/peers")), }); } } @@ -135,60 +175,21 @@ impl ConfigValidator { Ok(warnings) } - async fn validate_resources(&self, config: &serde_json::Value) -> Result<()> { - // Validate that referenced resources exist - if let Some(obj) = config.as_object() { - // Check TLS certificates if present - if let Some(tls) = obj.get("tls") - && let Some(tls_obj) = tls.as_object() - && let Some(cert_path) = tls_obj.get("cert_path") - && let Some(path_str) = cert_path.as_str() - && !path_str.is_empty() - && !std::path::Path::new(path_str).exists() - { - return Err(Error::InvalidRequest(format!( - "certificate file not found: {}", - path_str - ))); - } + fn validate_syntax(&self, config: &serde_json::Value) -> Result<()> { + // Basic syntax validation - check if it's a valid JSON object + if !config.is_object() { + return Err(Error::InvalidRequest("configuration must be a JSON object".to_string())); } - Ok(()) - } - - /// Assess the impact of applying this configuration - pub async fn assess_impact( - &self, - old_config: &serde_json::Value, - new_config: &serde_json::Value, - ) -> ImpactAssessment { - let mut requires_reload = false; - let mut affects_traffic = false; - let mut breaking_changes = Vec::new(); - - // Simple impact assessment - if old_config != new_config { - requires_reload = true; - - // Check if upstreams changed - if old_config.get("upstreams") != new_config.get("upstreams") { - affects_traffic = true; - breaking_changes.push("upstream configuration changed".to_string()); - } + // Check for required top-level fields + let obj = config.as_object().unwrap(); - // Check if routes changed - if old_config.get("routes") != new_config.get("routes") { - affects_traffic = true; - breaking_changes.push("route configuration changed".to_string()); - } + // Validate that we have at least some configuration + if obj.is_empty() { + return Err(Error::InvalidRequest("configuration cannot be empty".to_string())); } - ImpactAssessment { - requires_reload, - affects_traffic, - breaking_changes, - estimated_downtime_ms: if affects_traffic { Some(100) } else { None }, - } + Ok(()) } } @@ -201,12 +202,9 @@ impl Default for ConfigValidator { /// Impact assessment for configuration changes #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ImpactAssessment { - pub requires_reload: bool, pub affects_traffic: bool, pub breaking_changes: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub estimated_downtime_ms: Option, + pub requires_reload: bool, } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/control_center/model.rs b/crates/rginx-agent/src/control_center/model.rs index e0a6c086..a8931ec9 100644 --- a/crates/rginx-agent/src/control_center/model.rs +++ b/crates/rginx-agent/src/control_center/model.rs @@ -6,22 +6,23 @@ use crate::{AgentCommand, AgentCommandResult, AgentCommandStatus, AgentCommandTy #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct ControlCenterNode { - pub node_id: String, - pub version: String, - pub region: Option, - pub pop: Option, - pub labels: BTreeMap, pub capabilities: Vec, + pub converged: bool, + pub current_revision: u64, + pub desired_revision: u64, pub health: ControlCenterNodeHealth, - pub registered_at_unix_ms: u64, + pub labels: BTreeMap, pub last_heartbeat_at_unix_ms: u64, + pub node_id: String, + pub pop: Option, + pub region: Option, + pub registered_at_unix_ms: u64, pub snapshot_version: u64, - pub current_revision: u64, - pub desired_revision: u64, - pub converged: bool, + pub version: String, } impl ControlCenterNode { + #[must_use] pub fn matches(&self, filter: &ControlCenterNodeFilter) -> bool { if let Some(region) = &filter.region && self.region.as_ref() != Some(region) @@ -45,37 +46,37 @@ impl ControlCenterNode { #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum ControlCenterNodeHealth { - Healthy, Degraded, + Healthy, Offline, } #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct NodeSession { - pub session_id: String, + pub last_seen_at_unix_ms: u64, pub node_id: String, + pub session_id: String, pub started_at_unix_ms: u64, - pub last_seen_at_unix_ms: u64, } #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct HeartbeatRecord { + pub converged: bool, + pub current_revision: u64, + pub desired_revision: u64, pub node_id: String, pub recorded_at_unix_ms: u64, pub snapshot_version: u64, - pub current_revision: u64, - pub desired_revision: u64, - pub converged: bool, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum ControlCenterCommandState { - Queued, Delivered, - Succeeded, - Failed, Expired, + Failed, + Queued, + Succeeded, } impl From for ControlCenterCommandState { @@ -92,16 +93,17 @@ pub struct ControlCenterCommandCreate { #[serde(rename = "type")] pub command_type: AgentCommandType, #[serde(default)] - pub revision: Option, - #[serde(default)] pub expires_at_unix_ms: Option, #[serde(default)] pub payload: serde_json::Value, #[serde(default)] + pub revision: Option, + #[serde(default)] pub signature: Option, } impl ControlCenterCommandCreate { + #[must_use] pub fn reload() -> Self { Self { command_type: AgentCommandType::Reload, @@ -112,6 +114,7 @@ impl ControlCenterCommandCreate { } } + #[must_use] pub fn set_desired_revision(revision: u64) -> Self { Self { command_type: AgentCommandType::SetDesiredRevision, @@ -126,23 +129,23 @@ impl ControlCenterCommandCreate { #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct ControlCenterCommandRecord { pub command: AgentCommand, - pub state: ControlCenterCommandState, + pub completed_at_unix_ms: Option, pub created_at_unix_ms: u64, pub delivered_at_unix_ms: Option, - pub completed_at_unix_ms: Option, pub result: Option, + pub state: ControlCenterCommandState, } #[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)] pub struct ControlCenterNodeFilter { - #[serde(default)] - pub region: Option, - #[serde(default)] - pub pop: Option, #[serde(default)] pub health: Option, #[serde(default)] pub labels: BTreeMap, + #[serde(default)] + pub pop: Option, + #[serde(default)] + pub region: Option, } #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] @@ -154,9 +157,9 @@ pub struct ControlCenterEventCreate { #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct ControlCenterEvent { + pub created_at_unix_ms: u64, + pub event_type: String, pub id: String, pub node_id: String, - pub event_type: String, pub payload: serde_json::Value, - pub created_at_unix_ms: u64, } diff --git a/crates/rginx-agent/src/control_center/query.rs b/crates/rginx-agent/src/control_center/query.rs index 624296e9..42b626fe 100644 --- a/crates/rginx-agent/src/control_center/query.rs +++ b/crates/rginx-agent/src/control_center/query.rs @@ -14,33 +14,42 @@ use super::rollout::{ use super::store::{ControlCenterStore, ensure_node, unix_ms}; impl ControlCenterStore { - pub async fn get_command(&self, command_id: &str) -> Option { - self.state.read().await.commands.get(command_id).cloned() + pub async fn create_rollout( + &self, + request: ControlCenterRolloutCreate, + ) -> Result { + let mut state = self.state.write().await; + let target_node_ids = selected_targets(&state.nodes, &request.selector) + .into_iter() + .map(|target| target.node_id) + .collect(); + let rollout = ControlCenterRollout { + id: Uuid::now_v7().to_string(), + name: request.name, + selector: request.selector, + target_node_ids, + created_at_unix_ms: unix_ms(), + }; + state.rollouts.insert(rollout.id.clone(), rollout.clone()); + Ok(rollout) } - pub async fn recent_results(&self, node_id: &str, limit: usize) -> Vec { - let state = self.state.read().await; - let mut results = state - .commands - .values() - .filter_map(|record| record.result.clone()) - .filter(|result| result.node_id == node_id) - .collect::>(); - results.sort_by_key(|result| Reverse(result.finished_at_unix_ms)); - results.truncate(limit); - results + pub async fn events_for_node(&self, node_id: &str) -> Vec { + self.state + .read() + .await + .events + .iter() + .filter(|event| event.node_id == node_id) + .cloned() + .collect() + } + pub async fn get_command(&self, command_id: &str) -> Option { + self.state.read().await.commands.get(command_id).cloned() } - pub async fn sessions_for_node(&self, node_id: &str) -> Vec { - let state = self.state.read().await; - let mut sessions = state - .node_sessions - .values() - .filter(|session| session.node_id == node_id) - .cloned() - .collect::>(); - sessions.sort_by_key(|session| session.started_at_unix_ms); - sessions + pub async fn get_rollout(&self, rollout_id: &str) -> Option { + self.state.read().await.rollouts.get(rollout_id).cloned() } pub async fn heartbeats_for_node(&self, node_id: &str) -> Vec { @@ -66,6 +75,19 @@ impl ControlCenterStore { Ok(()) } + pub async fn recent_results(&self, node_id: &str, limit: usize) -> Vec { + let state = self.state.read().await; + let mut results = state + .commands + .values() + .filter_map(|record| record.result.clone()) + .filter(|result| result.node_id == node_id) + .collect::>(); + results.sort_by_key(|result| Reverse(result.finished_at_unix_ms)); + results.truncate(limit); + results + } + pub async fn select_rollout_targets( &self, selector: RolloutTargetSelector, @@ -74,39 +96,16 @@ impl ControlCenterStore { selected_targets(&state.nodes, &selector) } - pub async fn create_rollout( - &self, - request: ControlCenterRolloutCreate, - ) -> Result { - let mut state = self.state.write().await; - let target_node_ids = selected_targets(&state.nodes, &request.selector) - .into_iter() - .map(|target| target.node_id) - .collect(); - let rollout = ControlCenterRollout { - id: Uuid::now_v7().to_string(), - name: request.name, - selector: request.selector, - target_node_ids, - created_at_unix_ms: unix_ms(), - }; - state.rollouts.insert(rollout.id.clone(), rollout.clone()); - Ok(rollout) - } - - pub async fn get_rollout(&self, rollout_id: &str) -> Option { - self.state.read().await.rollouts.get(rollout_id).cloned() - } - - pub async fn events_for_node(&self, node_id: &str) -> Vec { - self.state - .read() - .await - .events - .iter() - .filter(|event| event.node_id == node_id) + pub async fn sessions_for_node(&self, node_id: &str) -> Vec { + let state = self.state.read().await; + let mut sessions = state + .node_sessions + .values() + .filter(|session| session.node_id == node_id) .cloned() - .collect() + .collect::>(); + sessions.sort_by_key(|session| session.started_at_unix_ms); + sessions } } diff --git a/crates/rginx-agent/src/control_center/rollout.rs b/crates/rginx-agent/src/control_center/rollout.rs index 003221ed..c3cf0ae7 100644 --- a/crates/rginx-agent/src/control_center/rollout.rs +++ b/crates/rginx-agent/src/control_center/rollout.rs @@ -7,20 +7,21 @@ use super::model::{ControlCenterNode, ControlCenterNodeHealth}; #[derive(Debug, Clone, Default, PartialEq, Eq, Deserialize, Serialize)] pub struct RolloutTargetSelector { #[serde(default)] - pub region: Option, - #[serde(default)] - pub pop: Option, + pub desired_revision: Option, #[serde(default)] pub labels: BTreeMap, #[serde(default)] - pub require_healthy: bool, + pub pop: Option, + #[serde(default)] + pub region: Option, #[serde(default)] pub require_converged: bool, #[serde(default)] - pub desired_revision: Option, + pub require_healthy: bool, } impl RolloutTargetSelector { + #[must_use] pub fn matches(&self, node: &ControlCenterNode) -> bool { if let Some(region) = &self.region && node.region.as_ref() != Some(region) @@ -49,14 +50,14 @@ impl RolloutTargetSelector { #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct RolloutTarget { - pub node_id: String, - pub region: Option, - pub pop: Option, - pub labels: BTreeMap, + pub converged: bool, pub current_revision: u64, pub desired_revision: u64, - pub converged: bool, pub health: ControlCenterNodeHealth, + pub labels: BTreeMap, + pub node_id: String, + pub pop: Option, + pub region: Option, } impl From<&ControlCenterNode> for RolloutTarget { @@ -82,9 +83,9 @@ pub struct ControlCenterRolloutCreate { #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct ControlCenterRollout { + pub created_at_unix_ms: u64, pub id: String, pub name: String, pub selector: RolloutTargetSelector, pub target_node_ids: Vec, - pub created_at_unix_ms: u64, } diff --git a/crates/rginx-agent/src/control_center/store.rs b/crates/rginx-agent/src/control_center/store.rs index 0d795241..7142af4d 100644 --- a/crates/rginx-agent/src/control_center/store.rs +++ b/crates/rginx-agent/src/control_center/store.rs @@ -23,108 +23,23 @@ const DEFAULT_SIGNED_COMMAND_TTL_MS: u64 = 5 * 60 * 1000; #[derive(Clone, Default)] pub struct ControlCenterStore { - pub(super) state: Arc>, command_signing_key: Option>, + pub(super) state: Arc>, } #[derive(Default)] pub(super) struct ControlCenterState { - pub(super) nodes: BTreeMap, - pub(super) node_sessions: BTreeMap, pub(super) commands: BTreeMap, - pub(super) node_command_order: BTreeMap>, pub(super) desired_revisions: BTreeMap, - pub(super) heartbeats: BTreeMap>, pub(super) events: VecDeque, + pub(super) heartbeats: BTreeMap>, + pub(super) node_command_order: BTreeMap>, + pub(super) node_sessions: BTreeMap, + pub(super) nodes: BTreeMap, pub(super) rollouts: BTreeMap, } impl ControlCenterStore { - pub fn new() -> Self { - Self::default() - } - - pub fn with_command_signing_key(mut self, key: impl Into) -> Self { - let key = key.into().trim().to_string(); - if !key.is_empty() { - self.command_signing_key = Some(Arc::from(key)); - } - self - } - - pub async fn register(&self, request: AgentRegisterRequest) -> Result { - let now = unix_ms(); - let mut state = self.state.write().await; - let previous = state.nodes.get(&request.node_id).cloned(); - let desired = previous.as_ref().map(|node| node.desired_revision).unwrap_or(0); - let node = ControlCenterNode { - node_id: request.node_id.clone(), - version: request.version, - region: request.region, - pop: request.pop, - labels: request.labels, - capabilities: request.capabilities, - health: ControlCenterNodeHealth::Healthy, - registered_at_unix_ms: previous.as_ref().map_or(now, |node| node.registered_at_unix_ms), - last_heartbeat_at_unix_ms: now, - snapshot_version: previous.as_ref().map_or(0, |node| node.snapshot_version), - current_revision: previous.as_ref().map_or(0, |node| node.current_revision), - desired_revision: desired, - converged: previous.as_ref().is_none_or(|node| node.converged), - }; - let session = NodeSession { - session_id: Uuid::now_v7().to_string(), - node_id: node.node_id.clone(), - started_at_unix_ms: now, - last_seen_at_unix_ms: now, - }; - state.node_sessions.insert(session.session_id.clone(), session); - state.nodes.insert(node.node_id.clone(), node.clone()); - Ok(node) - } - - pub async fn heartbeat(&self, request: AgentHeartbeatRequest) -> Result { - let now = unix_ms(); - let mut state = self.state.write().await; - ensure_node(&state, &request.node_id)?; - let desired = *state - .desired_revisions - .entry(request.node_id.clone()) - .or_insert(request.desired_revision); - let updated_node = { - let node = state.nodes.get_mut(&request.node_id).expect("node was checked"); - node.last_heartbeat_at_unix_ms = now; - node.snapshot_version = request.snapshot_version; - node.current_revision = request.current_revision; - node.desired_revision = desired; - node.converged = request.converged - && request.desired_revision == desired - && request.current_revision == desired; - node.health = ControlCenterNodeHealth::Healthy; - node.clone() - }; - - let record = HeartbeatRecord { - node_id: request.node_id.clone(), - recorded_at_unix_ms: now, - snapshot_version: request.snapshot_version, - current_revision: request.current_revision, - desired_revision: desired, - converged: updated_node.converged, - }; - let records = state.heartbeats.entry(request.node_id.clone()).or_default(); - records.push_back(record); - while records.len() > HEARTBEAT_LIMIT_PER_NODE { - records.pop_front(); - } - for session in - state.node_sessions.values_mut().filter(|item| item.node_id == request.node_id) - { - session.last_seen_at_unix_ms = now; - } - Ok(updated_node) - } - pub async fn create_command( &self, node_id: &str, @@ -181,6 +96,62 @@ impl ControlCenterStore { Ok(record) } + pub async fn get_node(&self, node_id: &str) -> Option { + self.state.read().await.nodes.get(node_id).cloned() + } + + pub async fn heartbeat(&self, request: AgentHeartbeatRequest) -> Result { + let now = unix_ms(); + let mut state = self.state.write().await; + ensure_node(&state, &request.node_id)?; + let desired = *state + .desired_revisions + .entry(request.node_id.clone()) + .or_insert(request.desired_revision); + let updated_node = { + let node = state.nodes.get_mut(&request.node_id).expect("node was checked"); + node.last_heartbeat_at_unix_ms = now; + node.snapshot_version = request.snapshot_version; + node.current_revision = request.current_revision; + node.desired_revision = desired; + node.converged = request.converged + && request.desired_revision == desired + && request.current_revision == desired; + node.health = ControlCenterNodeHealth::Healthy; + node.clone() + }; + + let record = HeartbeatRecord { + node_id: request.node_id.clone(), + recorded_at_unix_ms: now, + snapshot_version: request.snapshot_version, + current_revision: request.current_revision, + desired_revision: desired, + converged: updated_node.converged, + }; + let records = state.heartbeats.entry(request.node_id.clone()).or_default(); + records.push_back(record); + while records.len() > HEARTBEAT_LIMIT_PER_NODE { + records.pop_front(); + } + for session in + state.node_sessions.values_mut().filter(|item| item.node_id == request.node_id) + { + session.last_seen_at_unix_ms = now; + } + Ok(updated_node) + } + + pub async fn list_nodes(&self, filter: ControlCenterNodeFilter) -> Vec { + let state = self.state.read().await; + state.nodes.values().filter(|node| node.matches(&filter)).cloned().collect() + } + + #[must_use] + pub fn new() -> Self { + Self::default() + } + pub async fn poll_commands( &self, node_id: &str, @@ -193,7 +164,7 @@ impl ControlCenterStore { let start = cursor .as_ref() .and_then(|id| order.iter().position(|candidate| candidate == id)) - .map_or(0, |index| index + 1); + .map_or(0, |index| index.saturating_add(1)); let mut commands = Vec::new(); for command_id in order.iter().skip(start) { let Some(record) = state.commands.get_mut(command_id) else { @@ -267,13 +238,43 @@ impl ControlCenterStore { Ok(event) } - pub async fn list_nodes(&self, filter: ControlCenterNodeFilter) -> Vec { - let state = self.state.read().await; - state.nodes.values().filter(|node| node.matches(&filter)).cloned().collect() + pub async fn register(&self, request: AgentRegisterRequest) -> Result { + let now = unix_ms(); + let mut state = self.state.write().await; + let previous = state.nodes.get(&request.node_id).cloned(); + let desired = previous.as_ref().map_or(0, |node| node.desired_revision); + let node = ControlCenterNode { + node_id: request.node_id.clone(), + version: request.version, + region: request.region, + pop: request.pop, + labels: request.labels, + capabilities: request.capabilities, + health: ControlCenterNodeHealth::Healthy, + registered_at_unix_ms: previous.as_ref().map_or(now, |node| node.registered_at_unix_ms), + last_heartbeat_at_unix_ms: now, + snapshot_version: previous.as_ref().map_or(0, |node| node.snapshot_version), + current_revision: previous.as_ref().map_or(0, |node| node.current_revision), + desired_revision: desired, + converged: previous.as_ref().is_none_or(|node| node.converged), + }; + let session = NodeSession { + session_id: Uuid::now_v7().to_string(), + node_id: node.node_id.clone(), + started_at_unix_ms: now, + last_seen_at_unix_ms: now, + }; + state.node_sessions.insert(session.session_id.clone(), session); + state.nodes.insert(node.node_id.clone(), node.clone()); + Ok(node) } - pub async fn get_node(&self, node_id: &str) -> Option { - self.state.read().await.nodes.get(node_id).cloned() + pub fn with_command_signing_key(mut self, key: impl Into) -> Self { + let key = key.into().trim().to_string(); + if !key.is_empty() { + self.command_signing_key = Some(Arc::from(key)); + } + self } } diff --git a/crates/rginx-agent/src/control_center/trait_adapter.rs b/crates/rginx-agent/src/control_center/trait_adapter.rs index 4eb6223d..2e813e29 100644 --- a/crates/rginx-agent/src/control_center/trait_adapter.rs +++ b/crates/rginx-agent/src/control_center/trait_adapter.rs @@ -10,14 +10,6 @@ use crate::{ use super::store::ControlCenterStore; impl OutboundControlPlaneClient for ControlCenterStore { - fn register( - &self, - request: AgentRegisterRequest, - ) -> Pin> + Send + 'static>> { - let store = self.clone(); - Box::pin(async move { ControlCenterStore::register(&store, request).await.map(|_| ()) }) - } - fn heartbeat( &self, request: AgentHeartbeatRequest, @@ -43,4 +35,11 @@ impl OutboundControlPlaneClient for ControlCenterStore { let store = self.clone(); Box::pin(async move { ControlCenterStore::post_result(&store, result).await.map(|_| ()) }) } + fn register( + &self, + request: AgentRegisterRequest, + ) -> Pin> + Send + 'static>> { + let store = self.clone(); + Box::pin(async move { ControlCenterStore::register(&store, request).await.map(|_| ()) }) + } } diff --git a/crates/rginx-agent/src/error.rs b/crates/rginx-agent/src/error.rs index 17439961..73bf37d8 100644 --- a/crates/rginx-agent/src/error.rs +++ b/crates/rginx-agent/src/error.rs @@ -5,30 +5,30 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum Error { - #[error("io error: {0}")] - Io(#[from] io::Error), + #[error("address parse error: {0}")] + AddrParse(#[from] std::net::AddrParseError), + #[error(transparent)] + Core(#[from] CoreError), + #[error("forbidden control plane request: {0}")] + Forbidden(String), #[error("http error: {0}")] Http(#[from] http::Error), #[error("hyper error: {0}")] Hyper(#[from] hyper::Error), - #[error("serde error: {0}")] - Serde(#[from] serde_json::Error), - #[error("pem error: {0}")] - Pem(#[from] pem::PemError), - #[error("address parse error: {0}")] - AddrParse(#[from] std::net::AddrParseError), #[error("invalid control plane request: {0}")] InvalidRequest(String), - #[error("unauthorized control plane request: {0}")] - Unauthorized(String), - #[error("forbidden control plane request: {0}")] - Forbidden(String), + #[error("io error: {0}")] + Io(#[from] io::Error), #[error("not found: {0}")] NotFound(String), + #[error("pem error: {0}")] + Pem(#[from] pem::PemError), + #[error("serde error: {0}")] + Serde(#[from] serde_json::Error), #[error("control plane server error: {0}")] Server(String), - #[error(transparent)] - Core(#[from] CoreError), + #[error("unauthorized control plane request: {0}")] + Unauthorized(String), } pub type Result = std::result::Result; diff --git a/crates/rginx-agent/src/events.rs b/crates/rginx-agent/src/events.rs index 58eb7b4e..c8ed2ff1 100644 --- a/crates/rginx-agent/src/events.rs +++ b/crates/rginx-agent/src/events.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::collections::HashMap; use std::sync::Arc; @@ -12,28 +14,22 @@ use crate::registry::NodeStatus; #[derive(Debug, Clone, Serialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ControlPlaneEvent { - ConfigUpdateAvailable { + CacheInvalidated { node_id: String, - revision: u64, - config_hash: String, + zone_name: String, + invalidation_type: String, timestamp: u64, }, - ReloadRequired { + CertificateExpiring { node_id: String, - reason: String, + domain: String, + days_left: u32, timestamp: u64, }, - ReloadCompleted { + ConfigUpdateAvailable { node_id: String, revision: u64, - success: bool, - duration_ms: u64, - timestamp: u64, - }, - CertificateExpiring { - node_id: String, - domain: String, - days_left: u32, + config_hash: String, timestamp: u64, }, HealthCheckFailed { @@ -49,15 +45,22 @@ pub enum ControlPlaneEvent { new_status: NodeStatus, timestamp: u64, }, - CacheInvalidated { + ReloadCompleted { node_id: String, - zone_name: String, - invalidation_type: String, + revision: u64, + success: bool, + duration_ms: u64, + timestamp: u64, + }, + ReloadRequired { + node_id: String, + reason: String, timestamp: u64, }, } impl ControlPlaneEvent { + #[must_use] pub fn event_type(&self) -> String { match self { Self::ConfigUpdateAvailable { .. } => "config_update_available".to_string(), @@ -70,6 +73,7 @@ impl ControlPlaneEvent { } } + #[must_use] pub fn node_id(&self) -> Option { match self { Self::ConfigUpdateAvailable { node_id, .. } @@ -82,6 +86,7 @@ impl ControlPlaneEvent { } } + #[must_use] pub fn timestamp(&self) -> u64 { match self { Self::ConfigUpdateAvailable { timestamp, .. } @@ -104,6 +109,7 @@ pub struct EventFilter { } impl EventFilter { + #[must_use] pub fn matches(&self, event: &ControlPlaneEvent) -> bool { if !self.event_types.is_empty() && !self.event_types.contains(&event.event_type()) { return false; @@ -136,6 +142,7 @@ pub struct EventBus { } impl EventBus { + #[must_use] pub fn new(capacity: usize) -> Self { let (sender, _) = broadcast::channel(capacity); Self { sender, subscribers: Arc::new(RwLock::new(HashMap::new())) } @@ -176,14 +183,8 @@ impl EventBus { tracing::info!(sub_id = %subscription_id, "event subscription created"); } - /// Unsubscribe from events - pub async fn unsubscribe(&self, subscription_id: &str) { - let mut subscribers = self.subscribers.write().await; - subscribers.remove(subscription_id); - tracing::info!(sub_id = %subscription_id, "event subscription removed"); - } - /// Get a broadcast receiver for channel-based subscriptions + #[must_use] pub fn subscribe_channel(&self) -> broadcast::Receiver { self.sender.subscribe() } @@ -193,7 +194,11 @@ impl EventBus { let subscribers = self.subscribers.read().await; subscribers.len() } -} -#[cfg(test)] -mod tests; + /// Unsubscribe from events + pub async fn unsubscribe(&self, subscription_id: &str) { + let mut subscribers = self.subscribers.write().await; + subscribers.remove(subscription_id); + tracing::info!(sub_id = %subscription_id, "event subscription removed"); + } +} diff --git a/crates/rginx-agent/src/gradual_rollout.rs b/crates/rginx-agent/src/gradual_rollout.rs index ee8ce31c..4bd43798 100644 --- a/crates/rginx-agent/src/gradual_rollout.rs +++ b/crates/rginx-agent/src/gradual_rollout.rs @@ -1,97 +1,130 @@ +mod status; + +#[cfg(test)] +mod tests; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; -mod status; - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum RolloutStrategy { - Canary, BlueGreen, + Canary, Progressive, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum RolloutPhase { - Pending, + Completed, + Failed, InProgress, Paused, - Completed, + Pending, RolledBack, - Failed, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RolloutStage { - pub stage_id: u32, - pub target_percentage: u32, - pub target_nodes: Vec, + pub completed_at: Option, pub duration_secs: u64, pub health_check_interval_secs: u64, - pub success_threshold: f64, + pub stage_id: u32, pub started_at: Option, - pub completed_at: Option, pub status: RolloutPhase, + pub success_threshold: f64, + pub target_nodes: Vec, + pub target_percentage: u32, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RolloutPlan { - pub rollout_id: String, - pub strategy: RolloutStrategy, - pub config_revision: u64, - pub stages: Vec, pub auto_rollback_on_failure: bool, + pub config_revision: u64, pub created_at: u64, pub created_by: String, pub current_stage: u32, pub phase: RolloutPhase, + pub rollout_id: String, + pub stages: Vec, + pub strategy: RolloutStrategy, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RolloutStatus { - pub rollout_id: String, - pub phase: RolloutPhase, + pub completed_at: Option, pub current_stage: u32, - pub total_stages: u32, - pub nodes_updated: u32, + pub error_message: Option, pub nodes_total: u32, - pub success_rate: f64, + pub nodes_updated: u32, + pub phase: RolloutPhase, + pub rollout_id: String, pub started_at: Option, - pub completed_at: Option, - pub error_message: Option, + pub success_rate: f64, + pub total_stages: u32, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeRolloutState { + pub applied_at: u64, + pub config_revision: u64, + pub error_count: u32, + pub health_status: HealthStatus, pub node_id: String, pub rollout_id: String, pub stage_id: u32, - pub config_revision: u64, - pub applied_at: u64, - pub health_status: HealthStatus, - pub error_count: u32, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum HealthStatus { - Healthy, Degraded, + Healthy, Unhealthy, Unknown, } pub struct GradualRolloutManager { - rollouts: Arc>>, node_states: Arc>>, + rollouts: Arc>>, } impl GradualRolloutManager { - pub fn new() -> Self { - Self { - rollouts: Arc::new(RwLock::new(HashMap::new())), - node_states: Arc::new(RwLock::new(HashMap::new())), + pub async fn advance_stage(&self, rollout_id: &str) -> Result<(), String> { + let mut rollouts = self.rollouts.write().await; + let rollout = rollouts + .get_mut(rollout_id) + .ok_or_else(|| format!("Rollout {rollout_id} not found"))?; + + if rollout.phase != RolloutPhase::InProgress { + return Err(format!( + "Rollout {} is not in progress, current: {:?}", + rollout_id, rollout.phase + )); + } + + let current_stage_idx = rollout.current_stage as usize; + if current_stage_idx >= rollout.stages.len() { + return Err("No more stages to advance".to_string()); + } + + if let Some(current_stage) = rollout.stages.get_mut(current_stage_idx) { + current_stage.status = RolloutPhase::Completed; + current_stage.completed_at = Some(current_timestamp()); } + + let next_stage_idx = current_stage_idx.saturating_add(1); + if next_stage_idx >= rollout.stages.len() { + rollout.phase = RolloutPhase::Completed; + return Ok(()); + } + + rollout.current_stage = next_stage_idx as u32; + if let Some(next_stage) = rollout.stages.get_mut(next_stage_idx) { + next_stage.status = RolloutPhase::InProgress; + next_stage.started_at = Some(current_timestamp()); + } + + Ok(()) } pub async fn create_rollout(&self, plan: RolloutPlan) -> Result { @@ -99,7 +132,7 @@ impl GradualRolloutManager { return Err("Rollout plan must have at least one stage".to_string()); } - let mut total_percentage = 0; + let mut total_percentage = 0u32; for stage in &plan.stages { if stage.target_percentage == 0 || stage.target_percentage > 100 { return Err(format!( @@ -107,11 +140,11 @@ impl GradualRolloutManager { stage.target_percentage, stage.stage_id )); } - total_percentage += stage.target_percentage; + total_percentage = total_percentage.saturating_add(stage.target_percentage); } if total_percentage != 100 { - return Err(format!("Total percentage must equal 100, got {}", total_percentage)); + return Err(format!("Total percentage must equal 100, got {total_percentage}")); } let rollout_id = plan.rollout_id.clone(); @@ -121,6 +154,11 @@ impl GradualRolloutManager { Ok(rollout_id) } + pub async fn get_node_state(&self, node_id: &str) -> Option { + let node_states = self.node_states.read().await; + node_states.get(node_id).cloned() + } + pub async fn get_rollout(&self, rollout_id: &str) -> Option { let rollouts = self.rollouts.read().await; rollouts.get(rollout_id).cloned() @@ -131,35 +169,19 @@ impl GradualRolloutManager { rollouts.values().cloned().collect() } - pub async fn start_rollout(&self, rollout_id: &str) -> Result<(), String> { - let mut rollouts = self.rollouts.write().await; - let rollout = rollouts - .get_mut(rollout_id) - .ok_or_else(|| format!("Rollout {} not found", rollout_id))?; - - if rollout.phase != RolloutPhase::Pending { - return Err(format!( - "Rollout {} is not in pending state, current: {:?}", - rollout_id, rollout.phase - )); - } - - rollout.phase = RolloutPhase::InProgress; - rollout.current_stage = 0; - - if let Some(first_stage) = rollout.stages.first_mut() { - first_stage.status = RolloutPhase::InProgress; - first_stage.started_at = Some(current_timestamp()); + #[must_use] + pub fn new() -> Self { + Self { + rollouts: Arc::new(RwLock::new(HashMap::new())), + node_states: Arc::new(RwLock::new(HashMap::new())), } - - Ok(()) } pub async fn pause_rollout(&self, rollout_id: &str) -> Result<(), String> { let mut rollouts = self.rollouts.write().await; let rollout = rollouts .get_mut(rollout_id) - .ok_or_else(|| format!("Rollout {} not found", rollout_id))?; + .ok_or_else(|| format!("Rollout {rollout_id} not found"))?; if rollout.phase != RolloutPhase::InProgress { return Err(format!( @@ -176,7 +198,7 @@ impl GradualRolloutManager { let mut rollouts = self.rollouts.write().await; let rollout = rollouts .get_mut(rollout_id) - .ok_or_else(|| format!("Rollout {} not found", rollout_id))?; + .ok_or_else(|| format!("Rollout {rollout_id} not found"))?; if rollout.phase != RolloutPhase::Paused { return Err(format!( @@ -189,63 +211,49 @@ impl GradualRolloutManager { Ok(()) } - pub async fn advance_stage(&self, rollout_id: &str) -> Result<(), String> { + pub async fn rollback(&self, rollout_id: &str, reason: &str) -> Result<(), String> { let mut rollouts = self.rollouts.write().await; let rollout = rollouts .get_mut(rollout_id) - .ok_or_else(|| format!("Rollout {} not found", rollout_id))?; + .ok_or_else(|| format!("Rollout {rollout_id} not found"))?; - if rollout.phase != RolloutPhase::InProgress { + if rollout.phase == RolloutPhase::Completed || rollout.phase == RolloutPhase::RolledBack { return Err(format!( - "Rollout {} is not in progress, current: {:?}", + "Cannot rollback rollout {} in state {:?}", rollout_id, rollout.phase )); } - let current_stage_idx = rollout.current_stage as usize; - if current_stage_idx >= rollout.stages.len() { - return Err("No more stages to advance".to_string()); - } - - if let Some(current_stage) = rollout.stages.get_mut(current_stage_idx) { - current_stage.status = RolloutPhase::Completed; - current_stage.completed_at = Some(current_timestamp()); - } - - let next_stage_idx = current_stage_idx + 1; - if next_stage_idx >= rollout.stages.len() { - rollout.phase = RolloutPhase::Completed; - return Ok(()); - } + rollout.phase = RolloutPhase::RolledBack; - rollout.current_stage = next_stage_idx as u32; - if let Some(next_stage) = rollout.stages.get_mut(next_stage_idx) { - next_stage.status = RolloutPhase::InProgress; - next_stage.started_at = Some(current_timestamp()); - } + let mut node_states = self.node_states.write().await; + node_states.retain(|_, state| state.rollout_id != rollout_id); + tracing::info!("Rolled back rollout {}: {}", rollout_id, reason); Ok(()) } - pub async fn rollback(&self, rollout_id: &str, reason: &str) -> Result<(), String> { + pub async fn start_rollout(&self, rollout_id: &str) -> Result<(), String> { let mut rollouts = self.rollouts.write().await; let rollout = rollouts .get_mut(rollout_id) - .ok_or_else(|| format!("Rollout {} not found", rollout_id))?; + .ok_or_else(|| format!("Rollout {rollout_id} not found"))?; - if rollout.phase == RolloutPhase::Completed || rollout.phase == RolloutPhase::RolledBack { + if rollout.phase != RolloutPhase::Pending { return Err(format!( - "Cannot rollback rollout {} in state {:?}", + "Rollout {} is not in pending state, current: {:?}", rollout_id, rollout.phase )); } - rollout.phase = RolloutPhase::RolledBack; + rollout.phase = RolloutPhase::InProgress; + rollout.current_stage = 0; - let mut node_states = self.node_states.write().await; - node_states.retain(|_, state| state.rollout_id != rollout_id); + if let Some(first_stage) = rollout.stages.first_mut() { + first_stage.status = RolloutPhase::InProgress; + first_stage.started_at = Some(current_timestamp()); + } - tracing::info!("Rolled back rollout {}: {}", rollout_id, reason); Ok(()) } @@ -254,11 +262,6 @@ impl GradualRolloutManager { node_states.insert(state.node_id.clone(), state); Ok(()) } - - pub async fn get_node_state(&self, node_id: &str) -> Option { - let node_states = self.node_states.read().await; - node_states.get(node_id).cloned() - } } impl Default for GradualRolloutManager { @@ -270,6 +273,3 @@ impl Default for GradualRolloutManager { fn current_timestamp() -> u64 { std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/gradual_rollout/status.rs b/crates/rginx-agent/src/gradual_rollout/status.rs index 037207d4..095aeccb 100644 --- a/crates/rginx-agent/src/gradual_rollout/status.rs +++ b/crates/rginx-agent/src/gradual_rollout/status.rs @@ -1,6 +1,34 @@ -use super::*; +use super::{GradualRolloutManager, HealthStatus, RolloutPhase, RolloutStatus}; impl GradualRolloutManager { + pub async fn check_stage_health(&self, rollout_id: &str) -> Result { + let rollouts = self.rollouts.read().await; + let rollout = + rollouts.get(rollout_id).ok_or_else(|| format!("Rollout {rollout_id} not found"))?; + + let current_stage_idx = rollout.current_stage as usize; + let current_stage = rollout + .stages + .get(current_stage_idx) + .ok_or_else(|| "Invalid current stage".to_string())?; + + let node_states = self.node_states.read().await; + let stage_nodes = node_states + .values() + .filter(|state| { + state.rollout_id == rollout_id && state.stage_id == current_stage.stage_id + }) + .collect::>(); + + if stage_nodes.is_empty() { + return Ok(true); + } + + let healthy_count = + stage_nodes.iter().filter(|state| state.health_status == HealthStatus::Healthy).count(); + let success_rate = healthy_count as f64 / stage_nodes.len() as f64; + Ok(success_rate >= current_stage.success_threshold) + } pub async fn get_rollout_status(&self, rollout_id: &str) -> Option { let rollouts = self.rollouts.read().await; let rollout = rollouts.get(rollout_id)?; @@ -16,7 +44,7 @@ impl GradualRolloutManager { .filter(|state| state.health_status == HealthStatus::Healthy) .count(); let success_rate = - if nodes_updated > 0 { healthy_nodes as f64 / nodes_updated as f64 } else { 0.0 }; + if nodes_updated > 0 { healthy_nodes as f64 / f64::from(nodes_updated) } else { 0.0 }; let started_at = rollout.stages.first().and_then(|stage| stage.started_at); let completed_at = (rollout.phase == RolloutPhase::Completed) .then(|| rollout.stages.last().and_then(|stage| stage.completed_at)) @@ -35,33 +63,4 @@ impl GradualRolloutManager { error_message: None, }) } - - pub async fn check_stage_health(&self, rollout_id: &str) -> Result { - let rollouts = self.rollouts.read().await; - let rollout = - rollouts.get(rollout_id).ok_or_else(|| format!("Rollout {rollout_id} not found"))?; - - let current_stage_idx = rollout.current_stage as usize; - let current_stage = rollout - .stages - .get(current_stage_idx) - .ok_or_else(|| "Invalid current stage".to_string())?; - - let node_states = self.node_states.read().await; - let stage_nodes = node_states - .values() - .filter(|state| { - state.rollout_id == rollout_id && state.stage_id == current_stage.stage_id - }) - .collect::>(); - - if stage_nodes.is_empty() { - return Ok(true); - } - - let healthy_count = - stage_nodes.iter().filter(|state| state.health_status == HealthStatus::Healthy).count(); - let success_rate = healthy_count as f64 / stage_nodes.len() as f64; - Ok(success_rate >= current_stage.success_threshold) - } } diff --git a/crates/rginx-agent/src/lib.rs b/crates/rginx-agent/src/lib.rs index 92e8a2c5..bb819e69 100644 --- a/crates/rginx-agent/src/lib.rs +++ b/crates/rginx-agent/src/lib.rs @@ -4,6 +4,15 @@ //! it listens on a node-local control port and accepts inbound management //! requests. New control-plane communication should target the outbound agent //! model described in `docs/AGENT_OUTBOUND_CONTROL_PLANE_PLAN.md`. +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] pub mod agent_core; pub mod api; @@ -26,6 +35,9 @@ mod system; mod tls; mod websocket; +#[cfg(test)] +mod tests; + pub use agent_core::{ AgentCore, CacheClearInvalidationsCommand, CacheInvalidateCommand, CacheInvalidateTarget, CachePurgeCommand, CachePurgeTarget, @@ -72,6 +84,3 @@ pub use server::control::{ }; pub use server::{run, run_with_context, run_with_listener}; pub use tls::ClientCertIdentity; - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/metrics.rs b/crates/rginx-agent/src/metrics.rs index 451487ab..8b57b42c 100644 --- a/crates/rginx-agent/src/metrics.rs +++ b/crates/rginx-agent/src/metrics.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use lazy_static::lazy_static; use prometheus::{ CounterVec, Encoder, Gauge, HistogramVec, Registry, TextEncoder, register_counter_vec, @@ -66,15 +68,15 @@ lazy_static! { } pub struct MetricsCollector { - #[allow(dead_code)] + #[expect( + dead_code, + reason = "collector owns a registry handle for future custom registration" + )] registry: Arc, } impl MetricsCollector { - pub fn new() -> Self { - Self { registry: Arc::new(Registry::new()) } - } - + #[must_use] pub fn gather(&self) -> String { let encoder = TextEncoder::new(); let metric_families = prometheus::gather(); @@ -82,6 +84,10 @@ impl MetricsCollector { encoder.encode(&metric_families, &mut buffer).unwrap(); String::from_utf8(buffer).unwrap() } + #[must_use] + pub fn new() -> Self { + Self { registry: Arc::new(Registry::new()) } + } } impl Default for MetricsCollector { @@ -139,6 +145,3 @@ pub fn record_config_rollback(success: bool) { let status = if success { "success" } else { "failure" }; CONFIG_ROLLBACKS_TOTAL.with_label_values(&[status]).inc(); } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/model.rs b/crates/rginx-agent/src/model.rs index 9db8a519..f97a2b53 100644 --- a/crates/rginx-agent/src/model.rs +++ b/crates/rginx-agent/src/model.rs @@ -7,31 +7,31 @@ use serde::Serialize; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum NodeObservabilityView { - Status, - Snapshot, - Delta, - Wait, - Traffic, - Upstreams, Cache, + Delta, Revision, + Snapshot, + Status, System, + Traffic, + Upstreams, + Wait, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum NodeControlAction { - Reload, - PurgeCache, - InvalidateCache, - ClearCacheInvalidations, ApplyConfig, + ClearCacheInvalidations, + InvalidateCache, PublishDesiredRevision, + PurgeCache, + Reload, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ControlPlaneResource { - Observability(NodeObservabilityView), Control(NodeControlAction), + Observability(NodeObservabilityView), Registry, } @@ -46,13 +46,13 @@ pub struct NodeStatusView(pub RuntimeStatusSnapshot); #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct NodeSnapshotView { + pub cache: CacheStatsSnapshot, + pub counters: HttpCountersSnapshot, + pub peer_health: Vec, pub snapshot_version: u64, pub status: RuntimeStatusSnapshot, - pub counters: HttpCountersSnapshot, pub traffic: TrafficStatsSnapshot, - pub peer_health: Vec, pub upstreams: Vec, - pub cache: CacheStatsSnapshot, } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] @@ -81,72 +81,72 @@ pub struct NodeRevisionView(pub RevisionStatusSnapshot); #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct NodeSystemView { pub collected_at_unix_ms: u64, + pub filesystems: Vec, pub hostname: String, pub kernel_release: String, pub kernel_version: String, - pub uptime_secs: u64, pub load: SystemLoadView, pub memory: SystemMemoryView, - pub filesystems: Vec, pub networks: Vec, + pub uptime_secs: u64, } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct NodeActionStatusView { pub accepted_revision: u64, + pub last_apply_result: Option, + pub last_reload_result: Option, #[serde(flatten)] pub revision: RevisionStatusSnapshot, - pub last_reload_result: Option, - pub last_apply_result: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct NodeControlResultView { + pub result: T, #[serde(flatten)] pub status: NodeActionStatusView, - pub result: T, } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct ConfigApplyResultView { - pub operation: String, pub kind: String, - pub resource_id: String, + pub managed_path: String, + pub operation: String, #[serde(skip_serializing_if = "Option::is_none")] pub owner: Option, + pub resource_id: String, #[serde(skip_serializing_if = "Option::is_none")] pub tenant: Option, - pub managed_path: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct SystemLoadView { + pub last_pid: u64, + pub loadavg_15m: String, pub loadavg_1m: String, pub loadavg_5m: String, - pub loadavg_15m: String, pub running_tasks: u64, pub total_tasks: u64, - pub last_pid: u64, } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct SystemMemoryView { - pub total_bytes: u64, pub available_bytes: u64, - pub used_bytes: u64, - pub free_bytes: u64, - pub cached_bytes: u64, pub buffers_bytes: u64, - pub swap_total_bytes: u64, + pub cached_bytes: u64, + pub free_bytes: u64, pub swap_free_bytes: u64, + pub swap_total_bytes: u64, pub swap_used_bytes: u64, + pub total_bytes: u64, + pub used_bytes: u64, } #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct SystemFilesystemView { + pub available_bytes: u64, pub path: String, pub total_bytes: u64, - pub available_bytes: u64, pub used_bytes: u64, } diff --git a/crates/rginx-agent/src/outbound/auth.rs b/crates/rginx-agent/src/outbound/auth.rs index 8f903903..5745b2c5 100644 --- a/crates/rginx-agent/src/outbound/auth.rs +++ b/crates/rginx-agent/src/outbound/auth.rs @@ -19,10 +19,10 @@ pub const SIGNATURE_HEADER: &str = "x-rginx-signature"; #[derive(Clone, PartialEq, Eq)] pub struct AuthenticatedRequestHeaders { pub authorization: String, - pub timestamp: String, - pub nonce: String, pub body_sha256: String, + pub nonce: String, pub signature: String, + pub timestamp: String, } #[derive(Clone)] @@ -48,6 +48,10 @@ impl OutboundRequestSigner { self.sign_with_nonce(method.as_str(), path_and_query, body, unix_ms(), Uuid::now_v7()) } + pub fn sign_command(&self, command: &AgentCommand) -> Result { + sign_agent_command(&self.token, command) + } + pub(crate) fn sign_with_nonce( &self, method: &str, @@ -69,16 +73,12 @@ impl OutboundRequestSigner { signature: hmac_sha256_hex(self.token.as_bytes(), material.as_bytes())?, }) } - - pub fn sign_command(&self, command: &AgentCommand) -> Result { - sign_agent_command(&self.token, command) - } } pub struct OutboundAuthVerifier { - token: String, max_clock_skew: Duration, seen_nonces: HashMap, + token: String, } impl OutboundAuthVerifier { @@ -90,6 +90,10 @@ impl OutboundAuthVerifier { Ok(Self { token, max_clock_skew, seen_nonces: HashMap::new() }) } + fn prune_seen_nonces(&mut self, now_unix_ms: u64, skew_ms: u64) { + self.seen_nonces.retain(|_, timestamp| now_unix_ms.saturating_sub(*timestamp) <= skew_ms); + } + pub fn verify( &mut self, method: &str, @@ -131,10 +135,6 @@ impl OutboundAuthVerifier { self.seen_nonces.insert(headers.nonce.clone(), timestamp); Ok(()) } - - fn prune_seen_nonces(&mut self, now_unix_ms: u64, skew_ms: u64) { - self.seen_nonces.retain(|_, timestamp| now_unix_ms.saturating_sub(*timestamp) <= skew_ms); - } } pub fn sign_agent_command(token: &str, command: &AgentCommand) -> Result { diff --git a/crates/rginx-agent/src/outbound/client.rs b/crates/rginx-agent/src/outbound/client.rs index e788e962..49f062c6 100644 --- a/crates/rginx-agent/src/outbound/client.rs +++ b/crates/rginx-agent/src/outbound/client.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::future::Future; use std::pin::Pin; use std::time::Duration; @@ -27,7 +29,6 @@ const NODE_ID_HEADER: HeaderName = HeaderName::from_static("x-rginx-node-id"); pub type OutboundClientFuture = Pin> + Send + 'static>>; pub trait OutboundControlPlaneClient: Send + Sync { - fn register(&self, request: AgentRegisterRequest) -> OutboundClientFuture<()>; fn heartbeat(&self, request: AgentHeartbeatRequest) -> OutboundClientFuture<()>; fn poll_commands( &self, @@ -36,14 +37,15 @@ pub trait OutboundControlPlaneClient: Send + Sync { timeout: Duration, ) -> OutboundClientFuture; fn post_result(&self, result: AgentCommandResult) -> OutboundClientFuture<()>; + fn register(&self, request: AgentRegisterRequest) -> OutboundClientFuture<()>; } #[derive(Clone)] pub struct HttpOutboundControlPlaneClient { + client: Client, Full>, endpoint: Uri, - signer: OutboundRequestSigner, request_timeout: Duration, - client: Client, Full>, + signer: OutboundRequestSigner, } impl HttpOutboundControlPlaneClient { @@ -51,21 +53,9 @@ impl HttpOutboundControlPlaneClient { Self::with_request_timeout(endpoint, token, Duration::from_secs(30)) } - pub fn with_request_timeout(endpoint: Uri, token: String, request_timeout: Duration) -> Self { - let connector = HttpsConnectorBuilder::new() - .with_native_roots() - .expect("native TLS roots should load") - .https_or_http() - .enable_http1() - .enable_http2() - .build(); - Self { - endpoint, - signer: OutboundRequestSigner::new(token) - .expect("outbound agent token should be validated before client construction"), - request_timeout, - client: Client::builder(TokioExecutor::new()).build(connector), - } + fn path_for(&self, path_and_query: &str) -> String { + let base_path = self.endpoint.path().trim_end_matches('/'); + format!("{base_path}{path_and_query}") } async fn send_json( @@ -129,29 +119,25 @@ impl HttpOutboundControlPlaneClient { .map_err(|error| Error::InvalidRequest(error.to_string())) } - fn path_for(&self, path_and_query: &str) -> String { - let base_path = self.endpoint.path().trim_end_matches('/'); - format!("{base_path}{path_and_query}") + pub fn with_request_timeout(endpoint: Uri, token: String, request_timeout: Duration) -> Self { + let connector = HttpsConnectorBuilder::new() + .with_native_roots() + .expect("native TLS roots should load") + .https_or_http() + .enable_http1() + .enable_http2() + .build(); + Self { + endpoint, + signer: OutboundRequestSigner::new(token) + .expect("outbound agent token should be validated before client construction"), + request_timeout, + client: Client::builder(TokioExecutor::new()).build(connector), + } } } impl OutboundControlPlaneClient for HttpOutboundControlPlaneClient { - fn register(&self, request: AgentRegisterRequest) -> OutboundClientFuture<()> { - let client = self.clone(); - Box::pin(async move { - client - .send_json::<_, serde_json::Value>( - Method::POST, - &request.node_id, - "/v1/agents/register".to_string(), - Some(&request), - client.request_timeout, - ) - .await?; - Ok(()) - }) - } - fn heartbeat(&self, request: AgentHeartbeatRequest) -> OutboundClientFuture<()> { let client = self.clone(); Box::pin(async move { @@ -184,7 +170,9 @@ impl OutboundControlPlaneClient for HttpOutboundControlPlaneClient { &node_id, path, None, - timeout + Duration::from_secs(5), + timeout + .checked_add(Duration::from_secs(5)) + .expect("poll timeout grace period remains representable"), ) .await? .unwrap_or_else(AgentPollResponse::empty)) @@ -207,6 +195,21 @@ impl OutboundControlPlaneClient for HttpOutboundControlPlaneClient { Ok(()) }) } + fn register(&self, request: AgentRegisterRequest) -> OutboundClientFuture<()> { + let client = self.clone(); + Box::pin(async move { + client + .send_json::<_, serde_json::Value>( + Method::POST, + &request.node_id, + "/v1/agents/register".to_string(), + Some(&request), + client.request_timeout, + ) + .await?; + Ok(()) + }) + } } fn header_value(value: &str) -> Result { @@ -241,6 +244,3 @@ fn command_result_path(node_id: &str, command_id: &str) -> String { encode_path_segment(command_id) ) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/outbound/command.rs b/crates/rginx-agent/src/outbound/command.rs index dc9e74cf..ffde67e0 100644 --- a/crates/rginx-agent/src/outbound/command.rs +++ b/crates/rginx-agent/src/outbound/command.rs @@ -10,6 +10,32 @@ use crate::error::{Error, Result}; use super::model::{AgentCommand, AgentCommandType}; +#[derive(Deserialize)] +struct SnapshotPayload { + #[serde(default)] + window_secs: Option, +} + +#[derive(Deserialize)] +struct DesiredRevisionPayload { + desired_revision: Option, +} + +#[derive(Deserialize)] +struct CachePurgePayload { + key: Option, + prefix: Option, + zone_name: String, +} + +#[derive(Deserialize)] +struct CacheInvalidatePayload { + key: Option, + prefix: Option, + tag: Option, + zone_name: String, +} + pub(super) async fn execute( core: &AgentCore, node_id: &str, @@ -76,32 +102,6 @@ fn validate_command( Ok(()) } -#[derive(Deserialize)] -struct SnapshotPayload { - #[serde(default)] - window_secs: Option, -} - -#[derive(Deserialize)] -struct DesiredRevisionPayload { - desired_revision: Option, -} - -#[derive(Deserialize)] -struct CachePurgePayload { - zone_name: String, - key: Option, - prefix: Option, -} - -#[derive(Deserialize)] -struct CacheInvalidatePayload { - zone_name: String, - key: Option, - prefix: Option, - tag: Option, -} - fn desired_revision(command: &AgentCommand) -> Result { let payload = payload::(&command.payload)?; payload diff --git a/crates/rginx-agent/src/outbound/model.rs b/crates/rginx-agent/src/outbound/model.rs index 7baf3695..b903a415 100644 --- a/crates/rginx-agent/src/outbound/model.rs +++ b/crates/rginx-agent/src/outbound/model.rs @@ -4,21 +4,21 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct AgentRegisterRequest { + pub capabilities: Vec, + pub labels: BTreeMap, pub node_id: String, - pub version: String, - pub region: Option, pub pop: Option, - pub labels: BTreeMap, - pub capabilities: Vec, + pub region: Option, + pub version: String, } #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct AgentHeartbeatRequest { - pub node_id: String, - pub snapshot_version: u64, + pub converged: bool, pub current_revision: u64, pub desired_revision: u64, - pub converged: bool, + pub node_id: String, + pub snapshot_version: u64, } #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] @@ -30,6 +30,7 @@ pub struct AgentPollResponse { } impl AgentPollResponse { + #[must_use] pub fn empty() -> Self { Self { commands: Vec::new(), next_cursor: None } } @@ -37,32 +38,33 @@ impl AgentPollResponse { #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct AgentCommand { - pub id: String, #[serde(rename = "type")] pub command_type: AgentCommandType, - pub target_node_id: String, - #[serde(default)] - pub revision: Option, #[serde(default)] pub expires_at_unix_ms: Option, + pub id: String, #[serde(default)] pub payload: serde_json::Value, #[serde(default)] + pub revision: Option, + #[serde(default)] pub signature: Option, + pub target_node_id: String, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "snake_case")] pub enum AgentCommandType { - Reload, ApplyConfig, - SetDesiredRevision, - CachePurge, CacheInvalidate, + CachePurge, CollectSnapshot, + Reload, + SetDesiredRevision, } impl AgentCommandType { + #[must_use] pub fn as_str(self) -> &'static str { match self { Self::Reload => "reload", @@ -78,17 +80,17 @@ impl AgentCommandType { #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AgentCommandStatus { - Succeeded, Failed, + Succeeded, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentCommandResult { pub command_id: String, - pub node_id: String, - pub status: AgentCommandStatus, - pub started_at_unix_ms: u64, + pub error: Option, pub finished_at_unix_ms: u64, + pub node_id: String, pub result: serde_json::Value, - pub error: Option, + pub started_at_unix_ms: u64, + pub status: AgentCommandStatus, } diff --git a/crates/rginx-agent/src/outbound/runner.rs b/crates/rginx-agent/src/outbound/runner.rs index 2a5ba750..39b7db0d 100644 --- a/crates/rginx-agent/src/outbound/runner.rs +++ b/crates/rginx-agent/src/outbound/runner.rs @@ -23,35 +23,100 @@ use super::stream::{AgentStreamHello, OutboundStreamClient}; use super::timing::{wait_or_shutdown, wait_while_locally_disabled}; pub struct OutboundAgent { - settings: AgentSettings, - core: AgentCore, client: Arc, - stream_client: Option>, - registered: bool, + command_signing_key: Option>, + core: AgentCore, + local_disable: Option>, next_heartbeat_due: Option, + pending_results: VecDeque, + registered: bool, + settings: AgentSettings, state: AgentPersistentState, state_store: Option, - pending_results: VecDeque, - command_signing_key: Option>, - local_disable: Option>, + stream_client: Option>, } impl OutboundAgent { - pub fn with_default_client(settings: AgentSettings, token: String, core: AgentCore) -> Self { - let client = Arc::new(HttpOutboundControlPlaneClient::with_request_timeout( - settings.endpoint.clone(), - token.clone(), - settings.request_timeout, - )); - let stream_client = - Arc::new(super::stream::WebSocketOutboundStreamClient::with_connect_timeout( - settings.endpoint.clone(), - token.clone(), - settings.connect_timeout, - )); - Self::new(settings, core, client) - .with_stream_client(stream_client) - .with_command_signing_key(token) + async fn command_response( + &self, + ) -> Result<(AgentPollResponse, Option>)> { + if let Some(stream_client) = &self.stream_client { + let hello = AgentStreamHello { + node_id: self.settings.node_id.clone(), + version: env!("CARGO_PKG_VERSION").to_string(), + cursor: self.state.command_cursor.clone(), + capabilities: register_request(&self.settings).capabilities, + }; + match stream_client.receive_commands(hello, self.settings.poll_timeout).await { + Ok(batch) => return Ok((batch.into(), Some(stream_client.clone()))), + Err(error) => tracing::warn!( + %error, + "outbound agent stream failed; falling back to long polling" + ), + } + } + Ok(( + self.client + .poll_commands( + self.settings.node_id.clone(), + self.state.command_cursor.clone(), + self.settings.poll_timeout, + ) + .await?, + None, + )) + } + + async fn command_result(&mut self, command: AgentCommand) -> Result { + if let Some(result) = self.state.cached_result(&command.id) { + return Ok(result); + } + + self.state.in_flight_command = Some(AgentInFlightCommand::received(command.clone())); + self.persist_state()?; + if let Some(in_flight) = self.state.in_flight_command.as_mut() { + in_flight.mark(AgentCommandExecutionState::Executing); + } + self.persist_state()?; + + let started_at_unix_ms = unix_ms(); + let outcome = execute( + &self.core, + &self.settings.node_id, + &command, + self.command_signing_key.as_deref(), + ) + .await; + let finished_at_unix_ms = unix_ms(); + + let result = match outcome { + Ok(result) => AgentCommandResult { + command_id: command.id, + node_id: self.settings.node_id.clone(), + status: AgentCommandStatus::Succeeded, + started_at_unix_ms, + finished_at_unix_ms, + result, + error: None, + }, + Err(error) => AgentCommandResult { + command_id: command.id, + node_id: self.settings.node_id.clone(), + status: AgentCommandStatus::Failed, + started_at_unix_ms, + finished_at_unix_ms, + result: serde_json::Value::Null, + error: Some(error.to_string()), + }, + }; + self.state.in_flight_command = None; + self.state.remember_result(result.clone(), self.recent_results_limit()); + self.persist_state()?; + Ok(result) + } + + fn heartbeat_due(&self) -> bool { + self.next_heartbeat_due.is_none_or(|due| Instant::now() >= due) } pub fn new( @@ -75,37 +140,55 @@ impl OutboundAgent { } } - pub fn with_command_signing_key(mut self, key: impl Into) -> Self { - let key = key.into().trim().to_string(); - if !key.is_empty() { - self.command_signing_key = Some(Arc::from(key)); + fn persist_state(&self) -> Result<()> { + if let Some(store) = &self.state_store { + store.save(&self.state)?; } - self + publish_agent_runtime(&self.core, &self.state); + Ok(()) } - pub fn with_stream_client(mut self, client: Arc) -> Self { - self.stream_client = Some(client); - self + async fn post_pending_results(&mut self) -> Result<()> { + while let Some(result) = self.pending_results.front().cloned() { + self.client.post_result(result).await?; + self.pending_results.pop_front(); + } + Ok(()) } - pub fn with_state_path(self, path: impl Into) -> Result { - self.with_state_store(AgentStateStore::new(path)) - } + async fn process_poll_response( + &mut self, + response: AgentPollResponse, + stream_client: Option>, + ) -> Result { + let commands_received = response.commands.len(); + let mut results_posted = 0usize; + let mut last_command_id = None; - pub fn with_local_disable(mut self, disabled: watch::Receiver) -> Self { - self.local_disable = Some(disabled); - self - } + for command in response.commands { + let command_id = command.id.clone(); + let result = self.command_result(command).await?; + if let Some(stream_client) = &stream_client { + stream_client.post_result(result).await?; + } else { + self.client.post_result(result).await?; + } + self.state.command_cursor = Some(command_id.clone()); + self.persist_state()?; + last_command_id = Some(command_id); + results_posted = results_posted.saturating_add(1); + } - pub fn with_state_store(mut self, store: AgentStateStore) -> Result { - self.state = store.load_or_default(&self.settings.node_id)?; - if let Some(result) = self.state.recover_in_flight_as_failure() { - self.state.remember_result(result.clone(), store.recent_results_limit()); - self.pending_results.push_back(result); + if let Some(next_cursor) = response.next_cursor.or(last_command_id) { + self.state.command_cursor = Some(next_cursor); + self.persist_state()?; } - store.save(&self.state)?; - self.state_store = Some(store); - Ok(self) + + Ok(OutboundAgentCycleOutcome { commands_received, results_posted }) + } + + fn recent_results_limit(&self) -> usize { + self.state_store.as_ref().map_or(128, AgentStateStore::recent_results_limit) } pub async fn run(mut self, mut shutdown: watch::Receiver) -> Result<()> { @@ -154,7 +237,11 @@ impl OutboundAgent { self.client.heartbeat(heartbeat_request(&self.settings, &self.core).await?).await?; self.state.last_heartbeat_success_unix_ms = Some(unix_ms()); self.state.connection_state = AgentConnectionState::Connected; - self.next_heartbeat_due = Some(Instant::now() + self.settings.heartbeat_interval); + self.next_heartbeat_due = Some( + Instant::now() + .checked_add(self.settings.heartbeat_interval) + .expect("heartbeat deadline remains representable"), + ); self.persist_state()?; } @@ -164,136 +251,53 @@ impl OutboundAgent { self.process_poll_response(response, stream_client).await } - async fn command_response( - &self, - ) -> Result<(AgentPollResponse, Option>)> { - if let Some(stream_client) = &self.stream_client { - let hello = AgentStreamHello { - node_id: self.settings.node_id.clone(), - version: env!("CARGO_PKG_VERSION").to_string(), - cursor: self.state.command_cursor.clone(), - capabilities: register_request(&self.settings).capabilities, - }; - match stream_client.receive_commands(hello, self.settings.poll_timeout).await { - Ok(batch) => return Ok((batch.into(), Some(stream_client.clone()))), - Err(error) => tracing::warn!( - %error, - "outbound agent stream failed; falling back to long polling" - ), - } - } - Ok(( - self.client - .poll_commands( - self.settings.node_id.clone(), - self.state.command_cursor.clone(), - self.settings.poll_timeout, - ) - .await?, - None, - )) - } - - async fn process_poll_response( - &mut self, - response: AgentPollResponse, - stream_client: Option>, - ) -> Result { - let commands_received = response.commands.len(); - let mut results_posted = 0usize; - let mut last_command_id = None; - - for command in response.commands { - let command_id = command.id.clone(); - let result = self.command_result(command).await?; - if let Some(stream_client) = &stream_client { - stream_client.post_result(result).await?; - } else { - self.client.post_result(result).await?; - } - self.state.command_cursor = Some(command_id.clone()); - self.persist_state()?; - last_command_id = Some(command_id); - results_posted += 1; - } - - if let Some(next_cursor) = response.next_cursor.or(last_command_id) { - self.state.command_cursor = Some(next_cursor); - self.persist_state()?; + pub fn with_command_signing_key(mut self, key: impl Into) -> Self { + let key = key.into().trim().to_string(); + if !key.is_empty() { + self.command_signing_key = Some(Arc::from(key)); } - - Ok(OutboundAgentCycleOutcome { commands_received, results_posted }) + self } - async fn post_pending_results(&mut self) -> Result<()> { - while let Some(result) = self.pending_results.front().cloned() { - self.client.post_result(result).await?; - self.pending_results.pop_front(); - } - Ok(()) + pub fn with_default_client(settings: AgentSettings, token: String, core: AgentCore) -> Self { + let client = Arc::new(HttpOutboundControlPlaneClient::with_request_timeout( + settings.endpoint.clone(), + token.clone(), + settings.request_timeout, + )); + let stream_client = + Arc::new(super::stream::WebSocketOutboundStreamClient::with_connect_timeout( + settings.endpoint.clone(), + token.clone(), + settings.connect_timeout, + )); + Self::new(settings, core, client) + .with_stream_client(stream_client) + .with_command_signing_key(token) } - async fn command_result(&mut self, command: AgentCommand) -> Result { - if let Some(result) = self.state.cached_result(&command.id) { - return Ok(result); - } - - self.state.in_flight_command = Some(AgentInFlightCommand::received(command.clone())); - self.persist_state()?; - if let Some(in_flight) = self.state.in_flight_command.as_mut() { - in_flight.mark(AgentCommandExecutionState::Executing); - } - self.persist_state()?; - - let started_at_unix_ms = unix_ms(); - let outcome = execute( - &self.core, - &self.settings.node_id, - &command, - self.command_signing_key.as_deref(), - ) - .await; - let finished_at_unix_ms = unix_ms(); - - let result = match outcome { - Ok(result) => AgentCommandResult { - command_id: command.id, - node_id: self.settings.node_id.clone(), - status: AgentCommandStatus::Succeeded, - started_at_unix_ms, - finished_at_unix_ms, - result, - error: None, - }, - Err(error) => AgentCommandResult { - command_id: command.id, - node_id: self.settings.node_id.clone(), - status: AgentCommandStatus::Failed, - started_at_unix_ms, - finished_at_unix_ms, - result: serde_json::Value::Null, - error: Some(error.to_string()), - }, - }; - self.state.in_flight_command = None; - self.state.remember_result(result.clone(), self.recent_results_limit()); - self.persist_state()?; - Ok(result) + pub fn with_local_disable(mut self, disabled: watch::Receiver) -> Self { + self.local_disable = Some(disabled); + self } - fn heartbeat_due(&self) -> bool { - self.next_heartbeat_due.is_none_or(|due| Instant::now() >= due) + pub fn with_state_path(self, path: impl Into) -> Result { + self.with_state_store(AgentStateStore::new(path)) } - fn recent_results_limit(&self) -> usize { - self.state_store.as_ref().map(AgentStateStore::recent_results_limit).unwrap_or(128) + pub fn with_state_store(mut self, store: AgentStateStore) -> Result { + self.state = store.load_or_default(&self.settings.node_id)?; + if let Some(result) = self.state.recover_in_flight_as_failure() { + self.state.remember_result(result.clone(), store.recent_results_limit()); + self.pending_results.push_back(result); + } + store.save(&self.state)?; + self.state_store = Some(store); + Ok(self) } - fn persist_state(&self) -> Result<()> { - if let Some(store) = &self.state_store { - store.save(&self.state)?; - } - publish_agent_runtime(&self.core, &self.state); - Ok(()) + pub fn with_stream_client(mut self, client: Arc) -> Self { + self.stream_client = Some(client); + self } } diff --git a/crates/rginx-agent/src/outbound/state.rs b/crates/rginx-agent/src/outbound/state.rs index bbfad1dc..ff071f8e 100644 --- a/crates/rginx-agent/src/outbound/state.rs +++ b/crates/rginx-agent/src/outbound/state.rs @@ -14,15 +14,16 @@ const DEFAULT_RECENT_RESULTS_LIMIT: usize = 128; #[serde(rename_all = "snake_case")] #[derive(Default)] pub enum AgentConnectionState { - #[default] - Starting, - Registering, Connected, Degraded, OfflineRetrying, + Registering, + #[default] + Starting, } impl AgentConnectionState { + #[must_use] pub fn as_str(self) -> &'static str { match self { Self::Starting => "starting", @@ -37,24 +38,29 @@ impl AgentConnectionState { #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum AgentCommandExecutionState { - Received, Accepted, Executing, - Succeeded, Failed, + Received, Rejected, + Succeeded, TimedOut, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentInFlightCommand { pub command: AgentCommand, - pub state: AgentCommandExecutionState, pub received_at_unix_ms: u64, + pub state: AgentCommandExecutionState, pub updated_at_unix_ms: u64, } impl AgentInFlightCommand { + pub fn mark(&mut self, state: AgentCommandExecutionState) { + self.state = state; + self.updated_at_unix_ms = unix_ms(); + } + #[must_use] pub fn received(command: AgentCommand) -> Self { let now = unix_ms(); Self { @@ -64,31 +70,31 @@ impl AgentInFlightCommand { updated_at_unix_ms: now, } } - - pub fn mark(&mut self, state: AgentCommandExecutionState) { - self.state = state; - self.updated_at_unix_ms = unix_ms(); - } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentPersistentState { - pub version: u32, - pub node_id: String, - pub connection_state: AgentConnectionState, #[serde(default)] pub command_cursor: Option, + pub connection_state: AgentConnectionState, #[serde(default)] pub in_flight_command: Option, #[serde(default)] - pub recent_results: Vec, + pub last_heartbeat_success_unix_ms: Option, #[serde(default)] pub last_register_success_unix_ms: Option, + pub node_id: String, #[serde(default)] - pub last_heartbeat_success_unix_ms: Option, + pub recent_results: Vec, + pub version: u32, } impl AgentPersistentState { + #[must_use] + pub fn cached_result(&self, command_id: &str) -> Option { + self.recent_results.iter().find(|result| result.command_id == command_id).cloned() + } + pub fn default_for_node(node_id: impl Into) -> Self { Self { version: STATE_VERSION, @@ -102,20 +108,6 @@ impl AgentPersistentState { } } - pub fn cached_result(&self, command_id: &str) -> Option { - self.recent_results.iter().find(|result| result.command_id == command_id).cloned() - } - - pub fn remember_result(&mut self, result: AgentCommandResult, limit: usize) { - self.recent_results.retain(|stored| stored.command_id != result.command_id); - self.recent_results.push(result); - let limit = limit.max(1); - if self.recent_results.len() > limit { - let drain_count = self.recent_results.len() - limit; - self.recent_results.drain(0..drain_count); - } - } - pub fn recover_in_flight_as_failure(&mut self) -> Option { let in_flight = self.in_flight_command.take()?; if self.cached_result(&in_flight.command.id).is_some() { @@ -135,6 +127,16 @@ impl AgentPersistentState { )), }) } + + pub fn remember_result(&mut self, result: AgentCommandResult, limit: usize) { + self.recent_results.retain(|stored| stored.command_id != result.command_id); + self.recent_results.push(result); + let limit = limit.max(1); + if self.recent_results.len() > limit { + let drain_count = self.recent_results.len().saturating_sub(limit); + self.recent_results.drain(0..drain_count); + } + } } #[derive(Debug, Clone)] @@ -144,18 +146,6 @@ pub struct AgentStateStore { } impl AgentStateStore { - pub fn new(path: impl Into) -> Self { - Self { path: path.into(), recent_results_limit: DEFAULT_RECENT_RESULTS_LIMIT } - } - - pub fn path(&self) -> &Path { - &self.path - } - - pub fn recent_results_limit(&self) -> usize { - self.recent_results_limit - } - pub fn load_or_default(&self, node_id: &str) -> Result { let bytes = match std::fs::read(&self.path) { Ok(bytes) => bytes, @@ -178,6 +168,20 @@ impl AgentStateStore { } } + pub fn new(path: impl Into) -> Self { + Self { path: path.into(), recent_results_limit: DEFAULT_RECENT_RESULTS_LIMIT } + } + + #[must_use] + pub fn path(&self) -> &Path { + &self.path + } + + #[must_use] + pub fn recent_results_limit(&self) -> usize { + self.recent_results_limit + } + pub fn save(&self, state: &AgentPersistentState) -> Result<()> { let parent = self.path.parent().unwrap_or_else(|| Path::new(".")); std::fs::create_dir_all(parent)?; diff --git a/crates/rginx-agent/src/outbound/stream.rs b/crates/rginx-agent/src/outbound/stream.rs index 5b2e4a44..72137e9c 100644 --- a/crates/rginx-agent/src/outbound/stream.rs +++ b/crates/rginx-agent/src/outbound/stream.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::time::Duration; use futures_util::{SinkExt, StreamExt}; @@ -19,21 +21,21 @@ const NODE_ID_HEADER: &str = "x-rginx-node-id"; const STREAM_PATH: &str = "/v1/agents/stream"; pub trait OutboundStreamClient: Send + Sync { + fn post_result(&self, result: AgentCommandResult) -> OutboundClientFuture<()>; fn receive_commands( &self, hello: AgentStreamHello, timeout: Duration, ) -> OutboundClientFuture; - fn post_result(&self, result: AgentCommandResult) -> OutboundClientFuture<()>; } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentStreamHello { - pub node_id: String, - pub version: String, + pub capabilities: Vec, #[serde(default)] pub cursor: Option, - pub capabilities: Vec, + pub node_id: String, + pub version: String, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -43,6 +45,7 @@ pub struct AgentStreamCommandBatch { } impl AgentStreamCommandBatch { + #[must_use] pub fn empty() -> Self { Self { commands: Vec::new(), next_cursor: None } } @@ -57,36 +60,31 @@ impl From for AgentPollResponse { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum AgentStreamMessage { - Hello(AgentStreamHello), - Registered { - #[serde(default)] - node_id: Option, - }, - Heartbeat { - #[serde(default)] - payload: serde_json::Value, - }, Command { command: AgentCommand, #[serde(default)] next_cursor: Option, }, + CommandResult { + result: AgentCommandResult, + }, Commands { commands: Vec, #[serde(default)] next_cursor: Option, }, - CommandResult { - result: AgentCommandResult, + Error { + message: String, }, - SnapshotDelta { + Event { #[serde(default)] payload: serde_json::Value, }, - Event { + Heartbeat { #[serde(default)] payload: serde_json::Value, }, + Hello(AgentStreamHello), Ping { #[serde(default)] id: Option, @@ -95,32 +93,24 @@ pub enum AgentStreamMessage { #[serde(default)] id: Option, }, - Error { - message: String, + Registered { + #[serde(default)] + node_id: Option, + }, + SnapshotDelta { + #[serde(default)] + payload: serde_json::Value, }, } #[derive(Clone)] pub struct WebSocketOutboundStreamClient { + connect_timeout: Duration, endpoint: Uri, signer: OutboundRequestSigner, - connect_timeout: Duration, } impl WebSocketOutboundStreamClient { - pub fn new(endpoint: Uri, token: String) -> Self { - Self::with_connect_timeout(endpoint, token, Duration::from_secs(30)) - } - - pub fn with_connect_timeout(endpoint: Uri, token: String, connect_timeout: Duration) -> Self { - Self { - endpoint, - signer: OutboundRequestSigner::new(token) - .expect("outbound agent token should be validated before stream construction"), - connect_timeout, - } - } - async fn connect( &self, node_id: &str, @@ -147,6 +137,15 @@ impl WebSocketOutboundStreamClient { Ok(stream) } + pub fn new(endpoint: Uri, token: String) -> Self { + Self::with_connect_timeout(endpoint, token, Duration::from_secs(30)) + } + + fn path_for(&self, path_and_query: &str) -> String { + let base_path = self.endpoint.path().trim_end_matches('/'); + format!("{base_path}{path_and_query}") + } + fn stream_uri_for_path(&self, path_and_query: &str) -> Result { let scheme = match self.endpoint.scheme_str() { Some("http") => "ws", @@ -161,13 +160,26 @@ impl WebSocketOutboundStreamClient { .map_err(|error| Error::InvalidRequest(error.to_string())) } - fn path_for(&self, path_and_query: &str) -> String { - let base_path = self.endpoint.path().trim_end_matches('/'); - format!("{base_path}{path_and_query}") + pub fn with_connect_timeout(endpoint: Uri, token: String, connect_timeout: Duration) -> Self { + Self { + endpoint, + signer: OutboundRequestSigner::new(token) + .expect("outbound agent token should be validated before stream construction"), + connect_timeout, + } } } impl OutboundStreamClient for WebSocketOutboundStreamClient { + fn post_result(&self, result: AgentCommandResult) -> OutboundClientFuture<()> { + let client = self.clone(); + Box::pin(async move { + let mut stream = client.connect(&result.node_id, None).await?; + send_json_message(&mut stream, &AgentStreamMessage::CommandResult { result }).await?; + let _ = stream.close(None).await; + Ok(()) + }) + } fn receive_commands( &self, hello: AgentStreamHello, @@ -213,7 +225,7 @@ impl OutboundStreamClient for WebSocketOutboundStreamClient { Message::Ping(data) => { stream.send(Message::Pong(data)).await.map_err(|error| { Error::Server(format!("outbound agent stream failed: {error}")) - })? + })?; } Message::Close(_) => { return Err(Error::Server("outbound agent stream closed".into())); @@ -223,16 +235,6 @@ impl OutboundStreamClient for WebSocketOutboundStreamClient { } }) } - - fn post_result(&self, result: AgentCommandResult) -> OutboundClientFuture<()> { - let client = self.clone(); - Box::pin(async move { - let mut stream = client.connect(&result.node_id, None).await?; - send_json_message(&mut stream, &AgentStreamMessage::CommandResult { result }).await?; - let _ = stream.close(None).await; - Ok(()) - }) - } } fn stream_path(cursor: Option<&str>) -> String { @@ -255,6 +257,3 @@ async fn send_json_message( .await .map_err(|error| Error::Server(format!("outbound agent stream failed: {error}"))) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/outbound/timing.rs b/crates/rginx-agent/src/outbound/timing.rs index 589cad38..353d2ccb 100644 --- a/crates/rginx-agent/src/outbound/timing.rs +++ b/crates/rginx-agent/src/outbound/timing.rs @@ -7,7 +7,7 @@ pub(super) async fn wait_or_shutdown( shutdown: &mut watch::Receiver, ) -> bool { tokio::select! { - _ = tokio::time::sleep(duration) => false, + () = tokio::time::sleep(duration) => false, changed = shutdown.changed() => changed.is_err() || *shutdown.borrow(), } } diff --git a/crates/rginx-agent/src/rate_limit.rs b/crates/rginx-agent/src/rate_limit.rs index 50bb540d..e29b23f5 100644 --- a/crates/rginx-agent/src/rate_limit.rs +++ b/crates/rginx-agent/src/rate_limit.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -26,72 +28,58 @@ impl Default for RateLimitConfig { #[derive(Debug, Clone, Copy)] pub struct RateLimit { - pub requests_per_second: u32, pub burst: u32, + pub requests_per_second: u32, } // Token bucket implementation pub struct TokenBucket { capacity: u32, - tokens: f64, - refill_rate: f64, // tokens per second last_refill: Instant, + refill_rate: f64, + tokens: f64, } impl TokenBucket { - pub fn new(capacity: u32, refill_rate: f64) -> Self { - Self { capacity, tokens: capacity as f64, refill_rate, last_refill: Instant::now() } - } - - pub fn try_acquire(&mut self, tokens: u32) -> bool { + pub fn available_tokens(&mut self) -> u32 { self.refill(); - - if self.tokens >= tokens as f64 { - self.tokens -= tokens as f64; - true - } else { - false - } + self.tokens as u32 + } + #[must_use] + pub fn new(capacity: u32, refill_rate: f64) -> Self { + Self { capacity, tokens: f64::from(capacity), refill_rate, last_refill: Instant::now() } } fn refill(&mut self) { let now = Instant::now(); let elapsed = now.duration_since(self.last_refill).as_secs_f64(); let new_tokens = elapsed * self.refill_rate; - self.tokens = (self.tokens + new_tokens).min(self.capacity as f64); + self.tokens = (self.tokens + new_tokens).min(f64::from(self.capacity)); self.last_refill = now; } - pub fn available_tokens(&mut self) -> u32 { + pub fn try_acquire(&mut self, tokens: u32) -> bool { self.refill(); - self.tokens as u32 + + if self.tokens >= f64::from(tokens) { + self.tokens -= f64::from(tokens); + true + } else { + false + } } } // Rate limiter pub struct RateLimiter { - config: RateLimitConfig, - global_bucket: Arc>>, api_key_buckets: Arc>>, + config: RateLimitConfig, endpoint_buckets: Arc>>, + global_bucket: Arc>>, ip_buckets: Arc>>, } impl RateLimiter { - pub fn new(config: RateLimitConfig) -> Self { - let global_bucket = config - .global - .map(|limit| TokenBucket::new(limit.burst, limit.requests_per_second as f64)); - - Self { - config, - global_bucket: Arc::new(RwLock::new(global_bucket)), - api_key_buckets: Arc::new(RwLock::new(HashMap::new())), - endpoint_buckets: Arc::new(RwLock::new(HashMap::new())), - ip_buckets: Arc::new(RwLock::new(HashMap::new())), - } - } - pub async fn check_rate_limit( &self, api_key_id: Option<&str>, @@ -113,12 +101,12 @@ impl RateLimiter { && let Some(limit) = &self.config.per_api_key { let mut buckets = self.api_key_buckets.write().await; - let bucket = buckets - .entry(key_id.to_string()) - .or_insert_with(|| TokenBucket::new(limit.burst, limit.requests_per_second as f64)); + let bucket = buckets.entry(key_id.to_string()).or_insert_with(|| { + TokenBucket::new(limit.burst, f64::from(limit.requests_per_second)) + }); if !bucket.try_acquire(1) { return Ok(RateLimitDecision::Reject { - reason: format!("api key {} rate limit exceeded", key_id), + reason: format!("api key {key_id} rate limit exceeded"), retry_after_secs: 1, }); } @@ -127,12 +115,12 @@ impl RateLimiter { // 3. Check endpoint rate limit if let Some(limit) = self.config.per_endpoint.get(endpoint) { let mut buckets = self.endpoint_buckets.write().await; - let bucket = buckets - .entry(endpoint.to_string()) - .or_insert_with(|| TokenBucket::new(limit.burst, limit.requests_per_second as f64)); + let bucket = buckets.entry(endpoint.to_string()).or_insert_with(|| { + TokenBucket::new(limit.burst, f64::from(limit.requests_per_second)) + }); if !bucket.try_acquire(1) { return Ok(RateLimitDecision::Reject { - reason: format!("endpoint {} rate limit exceeded", endpoint), + reason: format!("endpoint {endpoint} rate limit exceeded"), retry_after_secs: 1, }); } @@ -141,12 +129,12 @@ impl RateLimiter { // 4. Check IP rate limit if let Some(limit) = &self.config.per_ip { let mut buckets = self.ip_buckets.write().await; - let bucket = buckets - .entry(client_ip.to_string()) - .or_insert_with(|| TokenBucket::new(limit.burst, limit.requests_per_second as f64)); + let bucket = buckets.entry(client_ip.to_string()).or_insert_with(|| { + TokenBucket::new(limit.burst, f64::from(limit.requests_per_second)) + }); if !bucket.try_acquire(1) { return Ok(RateLimitDecision::Reject { - reason: format!("ip {} rate limit exceeded", client_ip), + reason: format!("ip {client_ip} rate limit exceeded"), retry_after_secs: 1, }); } @@ -171,6 +159,20 @@ impl RateLimiter { let mut ip_buckets = self.ip_buckets.write().await; ip_buckets.retain(|_, bucket| now.duration_since(bucket.last_refill) < max_age); } + #[must_use] + pub fn new(config: RateLimitConfig) -> Self { + let global_bucket = config + .global + .map(|limit| TokenBucket::new(limit.burst, f64::from(limit.requests_per_second))); + + Self { + config, + global_bucket: Arc::new(RwLock::new(global_bucket)), + api_key_buckets: Arc::new(RwLock::new(HashMap::new())), + endpoint_buckets: Arc::new(RwLock::new(HashMap::new())), + ip_buckets: Arc::new(RwLock::new(HashMap::new())), + } + } } #[derive(Debug, Clone)] @@ -178,6 +180,3 @@ pub enum RateLimitDecision { Allow, Reject { reason: String, retry_after_secs: u64 }, } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/registry.rs b/crates/rginx-agent/src/registry.rs index 0ddf2aa9..b0f3f6ef 100644 --- a/crates/rginx-agent/src/registry.rs +++ b/crates/rginx-agent/src/registry.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; @@ -11,46 +13,46 @@ use crate::metrics; /// Node registration information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeRegistration { - pub node_id: String, - pub region: Option, - pub pop: Option, pub capabilities: Vec, pub control_plane_addr: String, pub labels: HashMap, #[serde(default)] pub metadata: HashMap, + pub node_id: String, + pub pop: Option, + pub region: Option, } /// Node information including registration and runtime state #[derive(Debug, Clone, Serialize)] pub struct NodeInfo { - pub registration: NodeRegistration, - pub status: NodeStatus, pub health: NodeHealth, - pub registered_at: u64, - pub last_heartbeat_at: u64, pub heartbeat_interval_secs: u64, + pub last_heartbeat_at: u64, + pub registered_at: u64, + pub registration: NodeRegistration, + pub status: NodeStatus, } /// Node status #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum NodeStatus { + Draining, Healthy, - Unhealthy, Offline, - Draining, + Unhealthy, } /// Node health metrics #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeHealth { + pub active_connections: u64, + pub disk_usage_percent: f64, + pub load_avg_15m: f64, pub load_avg_1m: f64, pub load_avg_5m: f64, - pub load_avg_15m: f64, pub memory_usage_percent: f64, - pub disk_usage_percent: f64, - pub active_connections: u64, pub requests_per_second: f64, } @@ -70,16 +72,68 @@ impl Default for NodeHealth { /// Node registry for managing edge nodes pub struct NodeRegistry { - nodes: Arc>>, heartbeat_timeout: Duration, + nodes: Arc>>, } impl NodeRegistry { + /// Check for heartbeat timeouts and mark nodes as offline + pub async fn check_heartbeat_timeouts(&self) { + let now = current_timestamp_ms(); + let timeout_ms = self.heartbeat_timeout.as_millis() as u64; + + let mut nodes = self.nodes.write().await; + for (node_id, node) in nodes.iter_mut() { + let elapsed = now.saturating_sub(node.last_heartbeat_at); + if elapsed > timeout_ms && node.status != NodeStatus::Offline { + node.status = NodeStatus::Offline; + tracing::warn!( + node_id = %node_id, + elapsed_secs = elapsed / 1000, + "node marked offline due to heartbeat timeout" + ); + } + } + } + + /// Get a specific node by ID + pub async fn get_node(&self, node_id: &str) -> Option { + let nodes = self.nodes.read().await; + nodes.get(node_id).cloned() + } + + /// Update node heartbeat + pub async fn heartbeat(&self, node_id: &str, health: NodeHealth) -> Result { + let mut nodes = self.nodes.write().await; + let node = nodes + .get_mut(node_id) + .ok_or_else(|| Error::InvalidRequest(format!("node `{node_id}` not registered")))?; + + node.last_heartbeat_at = current_timestamp_ms(); + node.health = health; + node.status = NodeStatus::Healthy; + + Ok(node.clone()) + } + + /// List all nodes matching the filter + pub async fn list_nodes(&self, filter: NodeFilter) -> Vec { + let nodes = self.nodes.read().await; + nodes.values().filter(|node| filter.matches(node)).cloned().collect() + } + /// Create a new node registry + #[must_use] pub fn new(heartbeat_timeout: Duration) -> Self { Self { nodes: Arc::new(RwLock::new(HashMap::new())), heartbeat_timeout } } + /// Get the number of registered nodes + pub async fn node_count(&self) -> usize { + let nodes = self.nodes.read().await; + nodes.len() + } + /// Register a new node pub async fn register(&self, registration: NodeRegistration) -> Result { let now = current_timestamp_ms(); @@ -109,26 +163,12 @@ impl NodeRegistry { Ok(node_info) } - /// Update node heartbeat - pub async fn heartbeat(&self, node_id: &str, health: NodeHealth) -> Result { - let mut nodes = self.nodes.write().await; - let node = nodes - .get_mut(node_id) - .ok_or_else(|| Error::InvalidRequest(format!("node `{}` not registered", node_id)))?; - - node.last_heartbeat_at = current_timestamp_ms(); - node.health = health; - node.status = NodeStatus::Healthy; - - Ok(node.clone()) - } - /// Unregister a node pub async fn unregister(&self, node_id: &str) -> Result<()> { let mut nodes = self.nodes.write().await; nodes .remove(node_id) - .ok_or_else(|| Error::InvalidRequest(format!("node `{}` not registered", node_id)))?; + .ok_or_else(|| Error::InvalidRequest(format!("node `{node_id}` not registered")))?; let node_count = nodes.len() as f64; drop(nodes); @@ -137,56 +177,20 @@ impl NodeRegistry { tracing::info!(node_id = %node_id, "node unregistered"); Ok(()) } - - /// List all nodes matching the filter - pub async fn list_nodes(&self, filter: NodeFilter) -> Vec { - let nodes = self.nodes.read().await; - nodes.values().filter(|node| filter.matches(node)).cloned().collect() - } - - /// Get a specific node by ID - pub async fn get_node(&self, node_id: &str) -> Option { - let nodes = self.nodes.read().await; - nodes.get(node_id).cloned() - } - - /// Check for heartbeat timeouts and mark nodes as offline - pub async fn check_heartbeat_timeouts(&self) { - let now = current_timestamp_ms(); - let timeout_ms = self.heartbeat_timeout.as_millis() as u64; - - let mut nodes = self.nodes.write().await; - for (node_id, node) in nodes.iter_mut() { - let elapsed = now.saturating_sub(node.last_heartbeat_at); - if elapsed > timeout_ms && node.status != NodeStatus::Offline { - node.status = NodeStatus::Offline; - tracing::warn!( - node_id = %node_id, - elapsed_secs = elapsed / 1000, - "node marked offline due to heartbeat timeout" - ); - } - } - } - - /// Get the number of registered nodes - pub async fn node_count(&self) -> usize { - let nodes = self.nodes.read().await; - nodes.len() - } } /// Filter for querying nodes #[derive(Debug, Clone, Default)] pub struct NodeFilter { - pub region: Option, + pub labels: HashMap, pub pop: Option, + pub region: Option, pub status: Option, - pub labels: HashMap, } impl NodeFilter { /// Check if a node matches this filter + #[must_use] pub fn matches(&self, node: &NodeInfo) -> bool { if let Some(region) = &self.region && node.registration.region.as_ref() != Some(region) @@ -217,7 +221,8 @@ impl NodeFilter { } impl NodeInfo { - /// Add missing node_id field for serialization + /// Add missing `node_id` field for serialization + #[must_use] pub fn node_id(&self) -> &str { &self.registration.node_id } @@ -226,6 +231,3 @@ impl NodeInfo { pub(crate) fn current_timestamp_ms() -> u64 { std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64 } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/registry/tests.rs b/crates/rginx-agent/src/registry/tests.rs index 42094598..3d617396 100644 --- a/crates/rginx-agent/src/registry/tests.rs +++ b/crates/rginx-agent/src/registry/tests.rs @@ -2,7 +2,7 @@ use super::*; #[tokio::test] async fn test_node_registration() { - let registry = NodeRegistry::new(Duration::from_secs(60)); + let registry = NodeRegistry::new(Duration::from_mins(1)); let registration = NodeRegistration { node_id: "test-node-1".to_string(), @@ -21,7 +21,7 @@ async fn test_node_registration() { #[tokio::test] async fn test_heartbeat() { - let registry = NodeRegistry::new(Duration::from_secs(60)); + let registry = NodeRegistry::new(Duration::from_mins(1)); let registration = NodeRegistration { node_id: "test-node-1".to_string(), @@ -51,7 +51,7 @@ async fn test_heartbeat() { #[tokio::test] async fn test_node_filter() { - let registry = NodeRegistry::new(Duration::from_secs(60)); + let registry = NodeRegistry::new(Duration::from_mins(1)); let registration1 = NodeRegistration { node_id: "node-1".to_string(), diff --git a/crates/rginx-agent/src/server/breaker.rs b/crates/rginx-agent/src/server/breaker.rs index 0f8e2588..8d55c781 100644 --- a/crates/rginx-agent/src/server/breaker.rs +++ b/crates/rginx-agent/src/server/breaker.rs @@ -24,7 +24,7 @@ pub async fn handle_get_circuit_breaker_stats( registry: Arc, ) -> Result>, String> { let breaker = - registry.get(name).await.ok_or_else(|| format!("Circuit breaker {} not found", name))?; + registry.get(name).await.ok_or_else(|| format!("Circuit breaker {name} not found"))?; let stats = breaker.get_stats().await; let response = serde_json::to_string(&stats).unwrap(); @@ -54,7 +54,7 @@ pub async fn handle_reset_circuit_breaker( name: &str, registry: Arc, ) -> Result>, String> { - registry.reset(name).await.map_err(|e| format!("Failed to reset circuit breaker: {}", e))?; + registry.reset(name).await.map_err(|e| format!("Failed to reset circuit breaker: {e}"))?; let response = json!({ "name": name, diff --git a/crates/rginx-agent/src/server/config.rs b/crates/rginx-agent/src/server/config.rs index ddd14218..19900ba3 100644 --- a/crates/rginx-agent/src/server/config.rs +++ b/crates/rginx-agent/src/server/config.rs @@ -9,6 +9,38 @@ use crate::config_validator::ConfigValidator; use crate::error::{Error, Result}; use crate::server::response::json_response; +// Request/Response types + +#[derive(Debug, Deserialize)] +struct ValidateRequest { + config: serde_json::Value, +} + +#[derive(Debug, Serialize)] +struct ConfigHistoryListResponse { + revisions: Vec, + total: usize, +} + +#[derive(Debug, Serialize)] +struct ConfigRevisionSummary { + applied_at: u64, + applied_by: String, + config_hash: String, + #[serde(skip_serializing_if = "Option::is_none")] + diff_summary: Option, + metadata: ConfigMetadata, + revision: u64, + status: crate::config_history::ConfigApplyStatus, +} + +#[derive(Debug, Serialize)] +struct ConfigDiffResponse { + diff: crate::config_history::ConfigDiff, + from_revision: u64, + to_revision: u64, +} + /// Handle config history list pub(super) async fn handle_config_history_list( request: Request, @@ -48,7 +80,7 @@ pub(super) async fn handle_config_history_get( let config_revision = history .get(revision) .await - .ok_or_else(|| Error::InvalidRequest(format!("revision {} not found", revision)))?; + .ok_or_else(|| Error::InvalidRequest(format!("revision {revision} not found")))?; json_response(config_revision) } @@ -76,7 +108,7 @@ pub(super) async fn handle_config_validate( ) -> Result>> { let body = request.into_body().collect().await?.to_bytes(); let validate_req: ValidateRequest = serde_json::from_slice(&body) - .map_err(|e| Error::InvalidRequest(format!("invalid validation payload: {}", e)))?; + .map_err(|e| Error::InvalidRequest(format!("invalid validation payload: {e}")))?; let result = validator.validate_dry_run(&validate_req.config).await?; @@ -145,35 +177,3 @@ fn parse_diff_query(query: &str) -> Result<(u64, u64)> { Ok((from, to)) } - -// Request/Response types - -#[derive(Debug, Deserialize)] -struct ValidateRequest { - config: serde_json::Value, -} - -#[derive(Debug, Serialize)] -struct ConfigHistoryListResponse { - revisions: Vec, - total: usize, -} - -#[derive(Debug, Serialize)] -struct ConfigRevisionSummary { - revision: u64, - applied_at: u64, - applied_by: String, - status: crate::config_history::ConfigApplyStatus, - config_hash: String, - #[serde(skip_serializing_if = "Option::is_none")] - diff_summary: Option, - metadata: ConfigMetadata, -} - -#[derive(Debug, Serialize)] -struct ConfigDiffResponse { - from_revision: u64, - to_revision: u64, - diff: crate::config_history::ConfigDiff, -} diff --git a/crates/rginx-agent/src/server/control.rs b/crates/rginx-agent/src/server/control.rs index 65a9a431..e023e65e 100644 --- a/crates/rginx-agent/src/server/control.rs +++ b/crates/rginx-agent/src/server/control.rs @@ -37,15 +37,55 @@ pub struct ConfigApplyOutcome { #[derive(Clone)] pub struct ControlPlaneContext { agent_core: AgentCore, - node_registry: Arc, - event_bus: Arc, + circuit_breaker_registry: Arc, config_history: Arc, config_validator: Arc, + event_bus: Arc, + node_registry: Arc, rollout_manager: Arc, - circuit_breaker_registry: Arc, } impl ControlPlaneContext { + pub async fn action_status(&self, accepted_revision: u64) -> NodeActionStatusView { + self.agent_core.action_status(accepted_revision).await + } + + #[must_use] + pub fn agent_core(&self) -> &AgentCore { + &self.agent_core + } + + #[must_use] + pub fn circuit_breaker_registry(&self) -> &Arc { + &self.circuit_breaker_registry + } + + #[must_use] + pub fn config_history(&self) -> &Arc { + &self.config_history + } + + #[must_use] + pub fn config_validator(&self) -> &Arc { + &self.config_validator + } + + #[must_use] + pub fn event_bus(&self) -> &Arc { + &self.event_bus + } + + pub async fn execute_config_apply( + &self, + request: ManagedResourceMutation, + ) -> Result> { + self.agent_core.apply_config(request).await + } + + pub async fn execute_reload(&self) -> Result { + self.agent_core.reload().await + } + pub fn new(state: SharedState, reload_executor: Arc) -> Self { let temp_dir = std::env::temp_dir().join("rginx-config-history"); Self { @@ -61,6 +101,21 @@ impl ControlPlaneContext { } } + #[must_use] + pub fn node_registry(&self) -> &Arc { + &self.node_registry + } + + #[must_use] + pub fn rollout_manager(&self) -> &Arc { + &self.rollout_manager + } + + #[must_use] + pub fn shared_state(&self) -> &SharedState { + self.agent_core.shared_state() + } + pub fn with_config_apply_executor( mut self, config_apply_executor: Arc, @@ -69,71 +124,27 @@ impl ControlPlaneContext { self } - pub fn with_node_registry(mut self, node_registry: Arc) -> Self { - self.node_registry = node_registry; + #[must_use] + pub fn with_config_history(mut self, config_history: Arc) -> Self { + self.config_history = config_history; self } + #[must_use] pub fn with_event_bus(mut self, event_bus: Arc) -> Self { self.event_bus = event_bus; self } - pub fn with_config_history(mut self, config_history: Arc) -> Self { - self.config_history = config_history; + #[must_use] + pub fn with_node_registry(mut self, node_registry: Arc) -> Self { + self.node_registry = node_registry; self } - pub fn agent_core(&self) -> &AgentCore { - &self.agent_core - } - - pub fn shared_state(&self) -> &SharedState { - self.agent_core.shared_state() - } - - pub fn node_registry(&self) -> &Arc { - &self.node_registry - } - - pub fn event_bus(&self) -> &Arc { - &self.event_bus - } - - pub fn config_history(&self) -> &Arc { - &self.config_history - } - - pub fn config_validator(&self) -> &Arc { - &self.config_validator - } - - pub fn rollout_manager(&self) -> &Arc { - &self.rollout_manager - } - - pub fn circuit_breaker_registry(&self) -> &Arc { - &self.circuit_breaker_registry - } - - pub async fn execute_reload(&self) -> Result { - self.agent_core.reload().await - } - - pub async fn action_status(&self, accepted_revision: u64) -> NodeActionStatusView { - self.agent_core.action_status(accepted_revision).await - } - pub async fn wrap_result(&self, result: T) -> NodeControlResultView { self.agent_core.wrap_result(result).await } - - pub async fn execute_config_apply( - &self, - request: ManagedResourceMutation, - ) -> Result> { - self.agent_core.apply_config(request).await - } } #[cfg(unix)] @@ -143,6 +154,7 @@ pub struct ProcessSignalReloadExecutor { #[cfg(unix)] impl ProcessSignalReloadExecutor { + #[must_use] pub fn current_process() -> Self { Self { pid: std::process::id() as libc::pid_t } } diff --git a/crates/rginx-agent/src/server/maintenance.rs b/crates/rginx-agent/src/server/maintenance.rs index c2624a31..fe72e824 100644 --- a/crates/rginx-agent/src/server/maintenance.rs +++ b/crates/rginx-agent/src/server/maintenance.rs @@ -11,11 +11,11 @@ pub(super) fn spawn_rate_limiter_cleanup( mut shutdown: watch::Receiver, ) { tokio::spawn(async move { - let mut interval = tokio::time::interval(Duration::from_secs(300)); + let mut interval = tokio::time::interval(Duration::from_mins(5)); loop { tokio::select! { _ = interval.tick() => { - rate_limiter.cleanup_stale_buckets(Duration::from_secs(600)).await; + rate_limiter.cleanup_stale_buckets(Duration::from_mins(10)).await; } _ = shutdown.changed() => { if *shutdown.borrow() { diff --git a/crates/rginx-agent/src/server/mod.rs b/crates/rginx-agent/src/server/mod.rs index 5f52f68d..490d934b 100644 --- a/crates/rginx-agent/src/server/mod.rs +++ b/crates/rginx-agent/src/server/mod.rs @@ -1,3 +1,13 @@ +pub(crate) mod breaker; +pub(crate) mod config; +pub mod control; +mod maintenance; +pub(crate) mod registry; +mod request; +mod response; +pub(crate) mod rollout; +mod write; + use std::convert::Infallible; use std::net::SocketAddr; use std::time::Duration; @@ -18,16 +28,6 @@ use crate::error::Result; use crate::rate_limit::{RateLimitConfig, RateLimiter}; use crate::tls::load_tls_server_config; -pub(crate) mod breaker; -pub(crate) mod config; -pub mod control; -mod maintenance; -pub(crate) mod registry; -mod request; -mod response; -pub(crate) mod rollout; -mod write; - const MAX_CONCURRENT_CONNECTIONS: usize = 1024; const TLS_HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); @@ -101,16 +101,13 @@ pub async fn run_with_listener( tracing::warn!(%peer_addr, "control plane denied connection by cidr allowlist"); continue; } - let slot = match connection_slots.clone().try_acquire_owned() { - Ok(slot) => slot, - Err(_) => { - tracing::warn!( - %peer_addr, - limit = MAX_CONCURRENT_CONNECTIONS, - "control plane overloaded: rejecting connection" - ); - continue; - } + let slot = if let Ok(slot) = connection_slots.clone().try_acquire_owned() { slot } else { + tracing::warn!( + %peer_addr, + limit = MAX_CONCURRENT_CONNECTIONS, + "control plane overloaded: rejecting connection" + ); + continue; }; let context = context.clone(); let tls_acceptor = tls_acceptor.clone(); @@ -175,17 +172,13 @@ async fn handle_connection( } } accepted = handshake.as_mut() => { - match accepted { - Ok(result) => break result?, - Err(_) => { - tracing::warn!( - %peer_addr, - timeout_secs = TLS_HANDSHAKE_TIMEOUT.as_secs(), - "control plane tls handshake timed out" - ); - return Ok(()); - } - } + if let Ok(result) = accepted { break result? } + tracing::warn!( + %peer_addr, + timeout_secs = TLS_HANDSHAKE_TIMEOUT.as_secs(), + "control plane tls handshake timed out" + ); + return Ok(()); } } }; diff --git a/crates/rginx-agent/src/server/registry.rs b/crates/rginx-agent/src/server/registry.rs index 4b0e3edc..1182df1c 100644 --- a/crates/rginx-agent/src/server/registry.rs +++ b/crates/rginx-agent/src/server/registry.rs @@ -8,6 +8,65 @@ use crate::error::{Error, Result}; use crate::registry::{NodeFilter, NodeHealth, NodeRegistration, NodeRegistry, NodeStatus}; use crate::server::response::json_response; +// Request/Response types + +#[derive(Debug, Deserialize)] +struct HeartbeatRequest { + health: NodeHealth, +} + +#[derive(Debug, Serialize)] +struct RegisterResponse { + heartbeat_interval_secs: u64, + node_id: String, + registered_at: u64, +} + +#[derive(Debug, Serialize)] +struct HeartbeatResponse { + next_heartbeat_in_secs: u64, + status: NodeStatus, +} + +#[derive(Debug, Serialize)] +struct UnregisterResponse { + unregistered_at: u64, +} + +#[derive(Debug, Serialize)] +struct ListNodesResponse { + nodes: Vec, + total: usize, +} + +#[derive(Debug, Serialize)] +struct NodeSummary { + capabilities: Vec, + health: NodeHealth, + labels: std::collections::HashMap, + last_heartbeat_at: u64, + node_id: String, + pop: Option, + region: Option, + registered_at: u64, + status: NodeStatus, +} + +#[derive(Debug, Serialize)] +struct NodeDetailResponse { + capabilities: Vec, + control_plane_addr: String, + health: NodeHealth, + heartbeat_interval_secs: u64, + labels: std::collections::HashMap, + last_heartbeat_at: u64, + node_id: String, + pop: Option, + region: Option, + registered_at: u64, + status: NodeStatus, +} + /// Register a new node pub(super) async fn handle_register( request: Request, @@ -15,7 +74,7 @@ pub(super) async fn handle_register( ) -> Result>> { let body = request.into_body().collect().await?.to_bytes(); let registration: NodeRegistration = serde_json::from_slice(&body) - .map_err(|e| Error::InvalidRequest(format!("invalid registration payload: {}", e)))?; + .map_err(|e| Error::InvalidRequest(format!("invalid registration payload: {e}")))?; let node_info = registry.register(registration).await?; @@ -36,7 +95,7 @@ pub(super) async fn handle_heartbeat( ) -> Result>> { let body = request.into_body().collect().await?.to_bytes(); let heartbeat_req: HeartbeatRequest = serde_json::from_slice(&body) - .map_err(|e| Error::InvalidRequest(format!("invalid heartbeat payload: {}", e)))?; + .map_err(|e| Error::InvalidRequest(format!("invalid heartbeat payload: {e}")))?; let node_info = registry.heartbeat(&node_id, heartbeat_req.health).await?; @@ -101,7 +160,7 @@ pub(super) async fn handle_get_node( let node = registry .get_node(&node_id) .await - .ok_or_else(|| Error::InvalidRequest(format!("node `{}` not found", node_id)))?; + .ok_or_else(|| Error::InvalidRequest(format!("node `{node_id}` not found")))?; let response = NodeDetailResponse { node_id: node.registration.node_id.clone(), @@ -159,65 +218,6 @@ fn parse_node_status(s: &str) -> Result { "unhealthy" => Ok(NodeStatus::Unhealthy), "offline" => Ok(NodeStatus::Offline), "draining" => Ok(NodeStatus::Draining), - _ => Err(Error::InvalidRequest(format!("invalid node status: {}", s))), + _ => Err(Error::InvalidRequest(format!("invalid node status: {s}"))), } } - -// Request/Response types - -#[derive(Debug, Deserialize)] -struct HeartbeatRequest { - health: NodeHealth, -} - -#[derive(Debug, Serialize)] -struct RegisterResponse { - node_id: String, - registered_at: u64, - heartbeat_interval_secs: u64, -} - -#[derive(Debug, Serialize)] -struct HeartbeatResponse { - status: NodeStatus, - next_heartbeat_in_secs: u64, -} - -#[derive(Debug, Serialize)] -struct UnregisterResponse { - unregistered_at: u64, -} - -#[derive(Debug, Serialize)] -struct ListNodesResponse { - nodes: Vec, - total: usize, -} - -#[derive(Debug, Serialize)] -struct NodeSummary { - node_id: String, - region: Option, - pop: Option, - status: NodeStatus, - registered_at: u64, - last_heartbeat_at: u64, - health: NodeHealth, - capabilities: Vec, - labels: std::collections::HashMap, -} - -#[derive(Debug, Serialize)] -struct NodeDetailResponse { - node_id: String, - region: Option, - pop: Option, - status: NodeStatus, - health: NodeHealth, - capabilities: Vec, - labels: std::collections::HashMap, - registered_at: u64, - last_heartbeat_at: u64, - heartbeat_interval_secs: u64, - control_plane_addr: String, -} diff --git a/crates/rginx-agent/src/server/request.rs b/crates/rginx-agent/src/server/request.rs index d97d5a44..7a2244fd 100644 --- a/crates/rginx-agent/src/server/request.rs +++ b/crates/rginx-agent/src/server/request.rs @@ -1,3 +1,9 @@ +mod query; +mod read; +mod resource; +#[cfg(test)] +mod tests; + use std::net::SocketAddr; use bytes::Bytes; @@ -18,12 +24,6 @@ use crate::server::response::error_response; use crate::server::write; use crate::tls::ClientCertIdentity; -mod query; -mod read; -mod resource; -#[cfg(test)] -mod tests; - use self::read::route_get_request; use self::resource::request_resource; @@ -40,8 +40,7 @@ pub(super) async fn handle_request( let path = request.uri().path().to_string(); let resource = request_resource(&method, &path); let requirement = resource - .map(ControlPlaneResource::authorization_requirement) - .unwrap_or(AuthorizationRequirement::AnyRead); + .map_or(AuthorizationRequirement::AnyRead, ControlPlaneResource::authorization_requirement); let audit = AuditContext { method: &method, path: &path, peer_addr, resource, requirement }; let auth_method = @@ -100,17 +99,16 @@ pub(super) async fn handle_request( return response; } - let resource = match resource { - Some(resource) => resource, - None => { - let error = Error::InvalidRequest(format!("unknown control plane path `{path}`")); - log_deny(&audit, Some(&actor_id), &scope_labels, &error); - let response = error_response(error, peer_addr); - let duration = start_time.elapsed().as_secs_f64(); - metrics::record_request(method.as_ref(), response.status().as_u16(), Some(&actor_id)); - metrics::record_request_duration(method.as_ref(), response.status().as_u16(), duration); - return response; - } + let resource = if let Some(resource) = resource { + resource + } else { + let error = Error::InvalidRequest(format!("unknown control plane path `{path}`")); + log_deny(&audit, Some(&actor_id), &scope_labels, &error); + let response = error_response(error, peer_addr); + let duration = start_time.elapsed().as_secs_f64(); + metrics::record_request(method.as_ref(), response.status().as_u16(), Some(&actor_id)); + metrics::record_request_duration(method.as_ref(), response.status().as_u16(), duration); + return response; }; if let Err(error) = authorize_authenticated_request(&auth_method, resource) { diff --git a/crates/rginx-agent/src/server/request/query.rs b/crates/rginx-agent/src/server/request/query.rs index 0ebe4c5d..cf71641a 100644 --- a/crates/rginx-agent/src/server/request/query.rs +++ b/crates/rginx-agent/src/server/request/query.rs @@ -5,6 +5,52 @@ use crate::error::{Error, Result}; const DEFAULT_RECENT_WINDOW_SECS: u64 = 60; const EXTENDED_RECENT_WINDOW_SECS: u64 = 300; +#[derive(Default)] +struct QueryParams { + entries: Vec<(String, String)>, +} + +impl QueryParams { + fn get_required_since_version(&self) -> Result<&str> { + match (self.get_unique("since_version")?, self.get_unique("since")?) { + (Some(_), Some(_)) => Err(Error::InvalidRequest( + "use either `since_version` or `since`, not both".to_string(), + )), + (Some(value), None) | (None, Some(value)) => Ok(value), + (None, None) => Err(Error::InvalidRequest( + "missing required query parameter `since_version`".to_string(), + )), + } + } + + fn get_unique(&self, key: &str) -> Result> { + let mut values = self + .entries + .iter() + .filter(|(entry_key, _)| entry_key == key) + .map(|(_, value)| value.as_str()); + let first = values.next(); + if values.next().is_some() { + return Err(Error::InvalidRequest(format!("duplicate query parameter `{key}`"))); + } + Ok(first) + } + + fn parse(query: Option<&str>) -> Self { + let mut params = Self::default(); + let Some(query) = query else { + return params; + }; + for pair in query.split('&').filter(|segment| !segment.is_empty()) { + let mut parts = pair.splitn(2, '='); + let key = decode_query_component(parts.next().unwrap_or_default()); + let value = decode_query_component(parts.next().unwrap_or_default()); + params.entries.push((key, value)); + } + params + } +} + pub(super) fn parse_delta_query(query: Option<&str>) -> Result<(u64, Option)> { let params = QueryParams::parse(query); let since_version = params @@ -55,52 +101,6 @@ fn parse_optional_window_secs(value: Option<&str>) -> Result> { } } -#[derive(Default)] -struct QueryParams { - entries: Vec<(String, String)>, -} - -impl QueryParams { - fn parse(query: Option<&str>) -> Self { - let mut params = Self::default(); - let Some(query) = query else { - return params; - }; - for pair in query.split('&').filter(|segment| !segment.is_empty()) { - let mut parts = pair.splitn(2, '='); - let key = decode_query_component(parts.next().unwrap_or_default()); - let value = decode_query_component(parts.next().unwrap_or_default()); - params.entries.push((key, value)); - } - params - } - - fn get_unique(&self, key: &str) -> Result> { - let mut values = self - .entries - .iter() - .filter(|(entry_key, _)| entry_key == key) - .map(|(_, value)| value.as_str()); - let first = values.next(); - if values.next().is_some() { - return Err(Error::InvalidRequest(format!("duplicate query parameter `{key}`"))); - } - Ok(first) - } - - fn get_required_since_version(&self) -> Result<&str> { - match (self.get_unique("since_version")?, self.get_unique("since")?) { - (Some(_), Some(_)) => Err(Error::InvalidRequest( - "use either `since_version` or `since`, not both".to_string(), - )), - (Some(value), None) | (None, Some(value)) => Ok(value), - (None, None) => Err(Error::InvalidRequest( - "missing required query parameter `since_version`".to_string(), - )), - } - } -} - pub(super) fn decode_query_component(component: &str) -> String { let bytes = component.as_bytes(); if !bytes.iter().any(|byte| matches!(byte, b'%' | b'+')) { @@ -111,21 +111,22 @@ pub(super) fn decode_query_component(component: &str) -> String { let mut index = 0usize; while index < bytes.len() { match bytes[index] { - b'%' if index + 2 < bytes.len() => { - match (decode_hex(bytes[index + 1]), decode_hex(bytes[index + 2])) { - (Some(high), Some(low)) => { - decoded.push((high << 4) | low); - index += 3; - } - _ => { - decoded.push(bytes[index]); - index += 1; - } + b'%' if index.saturating_add(2) < bytes.len() => { + let high_index = index.saturating_add(1); + let low_index = index.saturating_add(2); + if let (Some(high), Some(low)) = + (decode_hex(bytes[high_index]), decode_hex(bytes[low_index])) + { + decoded.push((high << 4) | low); + index = index.saturating_add(3); + } else { + decoded.push(bytes[index]); + index = index.saturating_add(1); } } byte => { decoded.push(byte); - index += 1; + index = index.saturating_add(1); } } } @@ -135,9 +136,9 @@ pub(super) fn decode_query_component(component: &str) -> String { fn decode_hex(byte: u8) -> Option { match byte { - b'0'..=b'9' => Some(byte - b'0'), - b'a'..=b'f' => Some(byte - b'a' + 10), - b'A'..=b'F' => Some(byte - b'A' + 10), + b'0'..=b'9' => Some(byte.saturating_sub(b'0')), + b'a'..=b'f' => Some(byte.saturating_sub(b'a').saturating_add(10)), + b'A'..=b'F' => Some(byte.saturating_sub(b'A').saturating_add(10)), _ => None, } } diff --git a/crates/rginx-agent/src/server/request/read.rs b/crates/rginx-agent/src/server/request/read.rs index 59776518..f1493045 100644 --- a/crates/rginx-agent/src/server/request/read.rs +++ b/crates/rginx-agent/src/server/request/read.rs @@ -147,13 +147,13 @@ fn handle_metrics_request() -> Result>> { let mut buffer = Vec::new(); encoder .encode(&metric_families, &mut buffer) - .map_err(|e| Error::Server(format!("failed to encode metrics: {}", e)))?; + .map_err(|e| Error::Server(format!("failed to encode metrics: {e}")))?; Response::builder() .status(200) .header("Content-Type", encoder.format_type()) .body(Full::new(Bytes::from(buffer))) - .map_err(|e| Error::Server(format!("failed to build metrics response: {}", e))) + .map_err(|e| Error::Server(format!("failed to build metrics response: {e}"))) } /// Handle /health endpoint - basic health check @@ -177,12 +177,10 @@ async fn handle_readiness_check(context: &ControlPlaneContext) -> Result Result Response, ) -> Result>, String> { let plan: RolloutPlan = - serde_json::from_slice(&body_bytes).map_err(|e| format!("Invalid rollout plan: {}", e))?; + serde_json::from_slice(&body_bytes).map_err(|e| format!("Invalid rollout plan: {e}"))?; - let rollout_id = manager - .create_rollout(plan) - .await - .map_err(|e| format!("Failed to create rollout: {}", e))?; + let rollout_id = + manager.create_rollout(plan).await.map_err(|e| format!("Failed to create rollout: {e}"))?; let response = json!({ "rollout_id": rollout_id, @@ -36,7 +34,7 @@ pub async fn handle_get_rollout( let rollout = manager .get_rollout(rollout_id) .await - .ok_or_else(|| format!("Rollout {} not found", rollout_id))?; + .ok_or_else(|| format!("Rollout {rollout_id} not found"))?; let response = serde_json::to_string(&rollout).unwrap(); @@ -65,10 +63,7 @@ pub async fn handle_start_rollout( rollout_id: &str, manager: Arc, ) -> Result>, String> { - manager - .start_rollout(rollout_id) - .await - .map_err(|e| format!("Failed to start rollout: {}", e))?; + manager.start_rollout(rollout_id).await.map_err(|e| format!("Failed to start rollout: {e}"))?; let response = json!({ "rollout_id": rollout_id, @@ -86,10 +81,7 @@ pub async fn handle_pause_rollout( rollout_id: &str, manager: Arc, ) -> Result>, String> { - manager - .pause_rollout(rollout_id) - .await - .map_err(|e| format!("Failed to pause rollout: {}", e))?; + manager.pause_rollout(rollout_id).await.map_err(|e| format!("Failed to pause rollout: {e}"))?; let response = json!({ "rollout_id": rollout_id, @@ -110,7 +102,7 @@ pub async fn handle_resume_rollout( manager .resume_rollout(rollout_id) .await - .map_err(|e| format!("Failed to resume rollout: {}", e))?; + .map_err(|e| format!("Failed to resume rollout: {e}"))?; let response = json!({ "rollout_id": rollout_id, @@ -128,10 +120,7 @@ pub async fn handle_advance_stage( rollout_id: &str, manager: Arc, ) -> Result>, String> { - manager - .advance_stage(rollout_id) - .await - .map_err(|e| format!("Failed to advance stage: {}", e))?; + manager.advance_stage(rollout_id).await.map_err(|e| format!("Failed to advance stage: {e}"))?; let response = json!({ "rollout_id": rollout_id, @@ -150,7 +139,7 @@ pub async fn handle_rollback( manager: Arc, reason: &str, ) -> Result>, String> { - manager.rollback(rollout_id, reason).await.map_err(|e| format!("Failed to rollback: {}", e))?; + manager.rollback(rollout_id, reason).await.map_err(|e| format!("Failed to rollback: {e}"))?; let response = json!({ "rollout_id": rollout_id, @@ -172,7 +161,7 @@ pub async fn handle_get_rollout_status( let status = manager .get_rollout_status(rollout_id) .await - .ok_or_else(|| format!("Rollout {} not found", rollout_id))?; + .ok_or_else(|| format!("Rollout {rollout_id} not found"))?; let response = serde_json::to_string(&status).unwrap(); diff --git a/crates/rginx-agent/src/server/write.rs b/crates/rginx-agent/src/server/write.rs index 45ba2dc7..a518759d 100644 --- a/crates/rginx-agent/src/server/write.rs +++ b/crates/rginx-agent/src/server/write.rs @@ -1,3 +1,5 @@ +mod routing; + use bytes::Bytes; use http::{Request, Response}; use http_body_util::{BodyExt, Full}; @@ -15,26 +17,24 @@ use crate::server::response::json_response; const MAX_CONTROL_PLANE_BODY_BYTES: usize = 1024 * 1024; -mod routing; - #[derive(Debug, Deserialize)] struct CachePurgeRequest { - zone_name: String, #[serde(default)] key: Option, #[serde(default)] prefix: Option, + zone_name: String, } #[derive(Debug, Deserialize)] struct CacheInvalidateRequest { - zone_name: String, #[serde(default)] key: Option, #[serde(default)] prefix: Option, #[serde(default)] tag: Option, + zone_name: String, } #[derive(Debug, Deserialize)] @@ -98,7 +98,7 @@ pub(super) async fn handle_post( unreachable!("selector validation should reject multiple variants") } }; - let command = CachePurgeCommand { zone_name, target }; + let command = CachePurgeCommand { target, zone_name }; json_response(context.agent_core().purge_cache(command).await?) } "/v1/cache/invalidate" => { @@ -120,7 +120,7 @@ pub(super) async fn handle_post( (None, None, None) => CacheInvalidateTarget::Zone, _ => unreachable!("selector validation should reject multiple variants"), }; - let command = CacheInvalidateCommand { zone_name, target }; + let command = CacheInvalidateCommand { target, zone_name }; json_response(context.agent_core().invalidate_cache(command).await?) } "/v1/cache/clear-invalidations" => { diff --git a/crates/rginx-agent/src/server/write/routing.rs b/crates/rginx-agent/src/server/write/routing.rs index a874ac25..955463c9 100644 --- a/crates/rginx-agent/src/server/write/routing.rs +++ b/crates/rginx-agent/src/server/write/routing.rs @@ -1,4 +1,6 @@ -use super::*; +use super::{ + BodyExt, Bytes, ControlPlaneContext, Error, Full, Incoming, Request, Response, Result, +}; pub(super) async fn route_rollout_post_request( request: Request, diff --git a/crates/rginx-agent/src/system.rs b/crates/rginx-agent/src/system.rs index f4d0b236..3d3dfdd2 100644 --- a/crates/rginx-agent/src/system.rs +++ b/crates/rginx-agent/src/system.rs @@ -11,6 +11,11 @@ use crate::model::{ SystemNetworkInterfaceView, }; +struct StatvfsSnapshot { + available_bytes: u64, + total_bytes: u64, +} + pub(crate) fn collect_system_view(cache_zone_paths: &[PathBuf]) -> Result { Ok(NodeSystemView { collected_at_unix_ms: unix_now_ms(), @@ -26,10 +31,7 @@ pub(crate) fn collect_system_view(cache_zone_paths: &[PathBuf]) -> Result u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map(|duration| duration.as_millis() as u64) - .unwrap_or(0) + SystemTime::now().duration_since(UNIX_EPOCH).map_or(0, |duration| duration.as_millis() as u64) } fn read_hostname() -> Result { @@ -85,7 +87,7 @@ fn read_loadavg() -> Result { .parse::() .map_err(|error| Error::Server(format!("invalid last pid: {error}")))?; - Ok(SystemLoadView { loadavg_1m, loadavg_5m, loadavg_15m, running_tasks, total_tasks, last_pid }) + Ok(SystemLoadView { last_pid, loadavg_15m, loadavg_1m, loadavg_5m, running_tasks, total_tasks }) } fn read_memory() -> Result { @@ -95,7 +97,7 @@ fn read_memory() -> Result { .filter_map(|line| { let (name, value) = line.split_once(':')?; let kb = value.split_whitespace().next()?.parse::().ok()?; - Some((name.trim().to_string(), kb * 1024)) + Some((name.trim().to_string(), kb.saturating_mul(1024))) }) .collect::>(); @@ -110,15 +112,15 @@ fn read_memory() -> Result { let swap_used_bytes = swap_total_bytes.saturating_sub(swap_free_bytes); Ok(SystemMemoryView { - total_bytes, available_bytes, - used_bytes, - free_bytes, - cached_bytes, buffers_bytes, - swap_total_bytes, + cached_bytes, + free_bytes, swap_free_bytes, + swap_total_bytes, swap_used_bytes, + total_bytes, + used_bytes, }) } @@ -198,11 +200,6 @@ where Ok(value.to_string_lossy().into_owned()) } -struct StatvfsSnapshot { - total_bytes: u64, - available_bytes: u64, -} - fn statvfs(path: &Path) -> Result { let path_bytes = std::ffi::CString::new(path.as_os_str().as_encoded_bytes().to_vec()).map_err(|error| { diff --git a/crates/rginx-agent/src/tests.rs b/crates/rginx-agent/src/tests.rs index adac70ac..baf43167 100644 --- a/crates/rginx-agent/src/tests.rs +++ b/crates/rginx-agent/src/tests.rs @@ -1,3 +1,14 @@ +mod agent_core; +mod control_center; +mod outbound; +mod outbound_auth; +mod outbound_persistence; +mod outbound_runtime; +mod outbound_stream; +mod read_api; +mod support; +mod write_api; + use std::collections::HashMap; use std::future::Future; use std::pin::Pin; @@ -25,17 +36,6 @@ use crate::{ ControlPlaneResource, NodeControlAction, NodeObservabilityView, }; -mod agent_core; -mod control_center; -mod outbound; -mod outbound_auth; -mod outbound_persistence; -mod outbound_runtime; -mod outbound_stream; -mod read_api; -mod support; -mod write_api; - use self::support::*; #[test] diff --git a/crates/rginx-agent/src/tests/outbound.rs b/crates/rginx-agent/src/tests/outbound.rs index c82c2408..85f5d2d5 100644 --- a/crates/rginx-agent/src/tests/outbound.rs +++ b/crates/rginx-agent/src/tests/outbound.rs @@ -7,6 +7,116 @@ use rginx_core::{AgentAuthSettings, AgentSettings}; use super::*; +#[derive(Clone, Default)] +pub(super) struct MockControlCenter { + state: Arc>, +} + +#[derive(Default)] +struct MockControlCenterState { + calls: Vec, + poll_errors: VecDeque, + polls: VecDeque, + register_errors: VecDeque, + result_errors: VecDeque, + results: Vec, +} + +impl MockControlCenter { + pub(super) fn calls(&self) -> Vec { + self.state.lock().unwrap().calls.clone() + } + + pub(super) fn fail_next_poll(&self, error: &str) { + self.state.lock().unwrap().poll_errors.push_back(error.to_string()); + } + + pub(super) fn fail_next_register(&self, error: &str) { + self.state.lock().unwrap().register_errors.push_back(error.to_string()); + } + + pub(super) fn fail_next_result(&self, error: &str) { + self.state.lock().unwrap().result_errors.push_back(error.to_string()); + } + + pub(super) fn new() -> Self { + Self::default() + } + + pub(super) fn push_poll(&self, response: crate::AgentPollResponse) { + self.state.lock().unwrap().polls.push_back(response); + } + + pub(super) fn results(&self) -> Vec { + self.state.lock().unwrap().results.clone() + } +} + +impl crate::OutboundControlPlaneClient for MockControlCenter { + fn heartbeat( + &self, + request: crate::AgentHeartbeatRequest, + ) -> Pin> + Send + 'static>> { + let state = self.state.clone(); + Box::pin(async move { + state.lock().unwrap().calls.push(format!("heartbeat:{}", request.node_id)); + Ok(()) + }) + } + + fn poll_commands( + &self, + node_id: String, + cursor: Option, + _timeout: std::time::Duration, + ) -> Pin> + Send + 'static>> + { + let state = self.state.clone(); + Box::pin(async move { + let mut state = state.lock().unwrap(); + state.calls.push(format!("poll:{node_id}:{cursor:?}")); + if let Some(error) = state.poll_errors.pop_front() { + return Err(crate::Error::Server(error)); + } + Ok(state.polls.pop_front().unwrap_or_else(crate::AgentPollResponse::empty)) + }) + } + + fn post_result( + &self, + result: crate::AgentCommandResult, + ) -> Pin> + Send + 'static>> { + let state = self.state.clone(); + Box::pin(async move { + let mut state = state.lock().unwrap(); + state.calls.push(format!( + "result:{}:{}", + result.command_id, + serde_json::to_value(result.status).unwrap().as_str().unwrap() + )); + if let Some(error) = state.result_errors.pop_front() { + return Err(crate::Error::Server(error)); + } + state.results.push(result); + Ok(()) + }) + } + fn register( + &self, + request: crate::AgentRegisterRequest, + ) -> Pin> + Send + 'static>> { + let state = self.state.clone(); + Box::pin(async move { + let mut state = state.lock().unwrap(); + state.calls.push(format!("register:{}", request.node_id)); + if let Some(error) = state.register_errors.pop_front() { + return Err(crate::Error::Server(error)); + } + Ok(()) + }) + } +} + #[tokio::test] async fn outbound_agent_cycle_registers_heartbeats_polls_and_posts_results() { let tempdir = tempfile::tempdir().expect("cache temp dir should exist"); @@ -173,117 +283,6 @@ async fn outbound_agent_can_recover_after_control_center_error() { ); } -#[derive(Clone, Default)] -pub(super) struct MockControlCenter { - state: Arc>, -} - -#[derive(Default)] -struct MockControlCenterState { - calls: Vec, - polls: VecDeque, - register_errors: VecDeque, - poll_errors: VecDeque, - result_errors: VecDeque, - results: Vec, -} - -impl MockControlCenter { - pub(super) fn new() -> Self { - Self::default() - } - - pub(super) fn push_poll(&self, response: crate::AgentPollResponse) { - self.state.lock().unwrap().polls.push_back(response); - } - - pub(super) fn fail_next_register(&self, error: &str) { - self.state.lock().unwrap().register_errors.push_back(error.to_string()); - } - - pub(super) fn fail_next_poll(&self, error: &str) { - self.state.lock().unwrap().poll_errors.push_back(error.to_string()); - } - - pub(super) fn fail_next_result(&self, error: &str) { - self.state.lock().unwrap().result_errors.push_back(error.to_string()); - } - - pub(super) fn calls(&self) -> Vec { - self.state.lock().unwrap().calls.clone() - } - - pub(super) fn results(&self) -> Vec { - self.state.lock().unwrap().results.clone() - } -} - -impl crate::OutboundControlPlaneClient for MockControlCenter { - fn register( - &self, - request: crate::AgentRegisterRequest, - ) -> Pin> + Send + 'static>> { - let state = self.state.clone(); - Box::pin(async move { - let mut state = state.lock().unwrap(); - state.calls.push(format!("register:{}", request.node_id)); - if let Some(error) = state.register_errors.pop_front() { - return Err(crate::Error::Server(error)); - } - Ok(()) - }) - } - - fn heartbeat( - &self, - request: crate::AgentHeartbeatRequest, - ) -> Pin> + Send + 'static>> { - let state = self.state.clone(); - Box::pin(async move { - state.lock().unwrap().calls.push(format!("heartbeat:{}", request.node_id)); - Ok(()) - }) - } - - fn poll_commands( - &self, - node_id: String, - cursor: Option, - _timeout: std::time::Duration, - ) -> Pin> + Send + 'static>> - { - let state = self.state.clone(); - Box::pin(async move { - let mut state = state.lock().unwrap(); - state.calls.push(format!("poll:{node_id}:{cursor:?}")); - if let Some(error) = state.poll_errors.pop_front() { - return Err(crate::Error::Server(error)); - } - Ok(state.polls.pop_front().unwrap_or_else(crate::AgentPollResponse::empty)) - }) - } - - fn post_result( - &self, - result: crate::AgentCommandResult, - ) -> Pin> + Send + 'static>> { - let state = self.state.clone(); - Box::pin(async move { - let mut state = state.lock().unwrap(); - state.calls.push(format!( - "result:{}:{}", - result.command_id, - serde_json::to_value(result.status).unwrap().as_str().unwrap() - )); - if let Some(error) = state.result_errors.pop_front() { - return Err(crate::Error::Server(error)); - } - state.results.push(result); - Ok(()) - }) - } -} - pub(super) fn settings() -> AgentSettings { settings_with_state_path("/tmp/rginx-agent-state.json") } diff --git a/crates/rginx-agent/src/tests/outbound_auth.rs b/crates/rginx-agent/src/tests/outbound_auth.rs index a95abf4f..66b4cacf 100644 --- a/crates/rginx-agent/src/tests/outbound_auth.rs +++ b/crates/rginx-agent/src/tests/outbound_auth.rs @@ -7,7 +7,7 @@ use uuid::Uuid; fn outbound_request_auth_rejects_missing_wrong_expired_and_replayed_tokens() { assert!(crate::OutboundRequestSigner::new(" ").err().unwrap().to_string().contains("missing")); assert!( - crate::OutboundAuthVerifier::new(" ", Duration::from_secs(60)) + crate::OutboundAuthVerifier::new(" ", Duration::from_mins(1)) .err() .unwrap() .to_string() @@ -15,7 +15,7 @@ fn outbound_request_auth_rejects_missing_wrong_expired_and_replayed_tokens() { ); let signer = crate::OutboundRequestSigner::new("secret").unwrap(); - let mut verifier = crate::OutboundAuthVerifier::new("secret", Duration::from_secs(60)).unwrap(); + let mut verifier = crate::OutboundAuthVerifier::new("secret", Duration::from_mins(1)).unwrap(); let body = br#"{"node_id":"edge-sfo-1"}"#; let signed = signer.sign(&Method::POST, "/v1/agents/register", body).unwrap(); let request_time = signed.timestamp.parse::().unwrap(); @@ -31,7 +31,7 @@ fn outbound_request_auth_rejects_missing_wrong_expired_and_replayed_tokens() { .contains("replay") ); - let mut verifier = crate::OutboundAuthVerifier::new("other", Duration::from_secs(60)).unwrap(); + let mut verifier = crate::OutboundAuthVerifier::new("other", Duration::from_mins(1)).unwrap(); assert!( verifier .verify("POST", "/v1/agents/register", body, &signed, request_time) @@ -53,7 +53,7 @@ fn outbound_request_auth_rejects_missing_wrong_expired_and_replayed_tokens() { #[test] fn outbound_request_auth_rejects_body_or_signature_tampering_without_burning_nonce() { let signer = crate::OutboundRequestSigner::new("secret").unwrap(); - let mut verifier = crate::OutboundAuthVerifier::new("secret", Duration::from_secs(60)).unwrap(); + let mut verifier = crate::OutboundAuthVerifier::new("secret", Duration::from_mins(1)).unwrap(); let body = br#"{"node_id":"edge-sfo-1"}"#; let signed = signer.sign(&Method::POST, "/v1/agents/register", body).unwrap(); let request_time = signed.timestamp.parse::().unwrap(); @@ -75,7 +75,7 @@ fn outbound_request_auth_rejects_body_or_signature_tampering_without_burning_non .verify("POST", "/v1/agents/register", body, &signed, request_time) .expect("valid request should still authenticate after rejected tampering"); - let mut verifier = crate::OutboundAuthVerifier::new("secret", Duration::from_secs(60)).unwrap(); + let mut verifier = crate::OutboundAuthVerifier::new("secret", Duration::from_mins(1)).unwrap(); let mut tampered = signed; tampered.signature = "bad-signature".to_string(); assert!( diff --git a/crates/rginx-agent/src/tests/outbound_stream.rs b/crates/rginx-agent/src/tests/outbound_stream.rs index 863f4b15..9695dfb9 100644 --- a/crates/rginx-agent/src/tests/outbound_stream.rs +++ b/crates/rginx-agent/src/tests/outbound_stream.rs @@ -5,6 +5,73 @@ use std::pin::Pin; use super::outbound::{MockControlCenter, command, settings}; use super::*; +#[derive(Clone, Default)] +struct MockStream { + state: Arc>, +} + +#[derive(Default)] +struct MockStreamState { + batches: VecDeque>, + hellos: Vec, + results: Vec, +} + +impl MockStream { + fn fail_next(&self, error: &str) { + self.state.lock().unwrap().batches.push_back(Err(error.to_string())); + } + + fn hellos(&self) -> Vec { + self.state.lock().unwrap().hellos.clone() + } + + fn new() -> Self { + Self::default() + } + + fn push_batch(&self, batch: crate::AgentStreamCommandBatch) { + self.state.lock().unwrap().batches.push_back(Ok(batch)); + } + + fn results(&self) -> Vec { + self.state.lock().unwrap().results.clone() + } +} + +impl crate::OutboundStreamClient for MockStream { + fn post_result( + &self, + result: crate::AgentCommandResult, + ) -> Pin> + Send + 'static>> { + let state = self.state.clone(); + Box::pin(async move { + state.lock().unwrap().results.push(result); + Ok(()) + }) + } + fn receive_commands( + &self, + hello: crate::AgentStreamHello, + _timeout: std::time::Duration, + ) -> Pin> + Send + 'static>> + { + let state = self.state.clone(); + Box::pin(async move { + let mut state = state.lock().unwrap(); + state.hellos.push(hello); + match state + .batches + .pop_front() + .unwrap_or_else(|| Ok(crate::AgentStreamCommandBatch::empty())) + { + Ok(batch) => Ok(batch), + Err(error) => Err(crate::Error::Server(error)), + } + }) + } +} + #[tokio::test] async fn outbound_agent_uses_stream_commands_without_polling_when_stream_is_online() { let state = @@ -77,71 +144,3 @@ async fn outbound_agent_fallback_polling_continues_from_stream_cursor() { assert!(control.calls().contains(&"poll:edge-sfo-1:Some(\"cursor-from-stream\")".to_string())); assert_eq!(stream.hellos()[1].cursor.as_deref(), Some("cursor-from-stream")); } - -#[derive(Clone, Default)] -struct MockStream { - state: Arc>, -} - -#[derive(Default)] -struct MockStreamState { - hellos: Vec, - batches: VecDeque>, - results: Vec, -} - -impl MockStream { - fn new() -> Self { - Self::default() - } - - fn push_batch(&self, batch: crate::AgentStreamCommandBatch) { - self.state.lock().unwrap().batches.push_back(Ok(batch)); - } - - fn fail_next(&self, error: &str) { - self.state.lock().unwrap().batches.push_back(Err(error.to_string())); - } - - fn hellos(&self) -> Vec { - self.state.lock().unwrap().hellos.clone() - } - - fn results(&self) -> Vec { - self.state.lock().unwrap().results.clone() - } -} - -impl crate::OutboundStreamClient for MockStream { - fn receive_commands( - &self, - hello: crate::AgentStreamHello, - _timeout: std::time::Duration, - ) -> Pin> + Send + 'static>> - { - let state = self.state.clone(); - Box::pin(async move { - let mut state = state.lock().unwrap(); - state.hellos.push(hello); - match state - .batches - .pop_front() - .unwrap_or_else(|| Ok(crate::AgentStreamCommandBatch::empty())) - { - Ok(batch) => Ok(batch), - Err(error) => Err(crate::Error::Server(error)), - } - }) - } - - fn post_result( - &self, - result: crate::AgentCommandResult, - ) -> Pin> + Send + 'static>> { - let state = self.state.clone(); - Box::pin(async move { - state.lock().unwrap().results.push(result); - Ok(()) - }) - } -} diff --git a/crates/rginx-agent/src/tests/support.rs b/crates/rginx-agent/src/tests/support.rs index 8f0be0ca..00bcfd45 100644 --- a/crates/rginx-agent/src/tests/support.rs +++ b/crates/rginx-agent/src/tests/support.rs @@ -1,96 +1,35 @@ +mod executors; + use super::*; -mod executors; +use hyper_util::client::legacy::connect::Connect; pub(super) use self::executors::{TestConfigApplyExecutor, TestReloadExecutor}; const CONTROL_PLANE_RETRY_ATTEMPTS: usize = 100; const CONTROL_PLANE_RETRY_DELAY: Duration = Duration::from_millis(50); -pub(super) fn snapshot() -> ConfigSnapshot { - ConfigSnapshot { - control_plane: None, - agent: None, - acme: None, - managed_certificates: Vec::new(), - cache_zones: HashMap::new(), - runtime: RuntimeSettings { - shutdown_timeout: Duration::from_secs(1), - worker_threads: None, - accept_workers: 1, - }, - listeners: vec![Listener { - id: "default".to_string(), - name: "default".to_string(), - server: Server { - listen_addr: "127.0.0.1:8080".parse().unwrap(), - server_header: rginx_core::default_server_header(), - default_certificate: None, - trusted_proxies: Vec::new(), - client_ip_header: None, - keep_alive: true, - max_headers: None, - max_request_body_bytes: None, - max_connections: None, - header_read_timeout: None, - request_body_read_timeout: None, - response_write_timeout: None, - http1: rginx_core::Http1Settings::default(), - access_log_format: None, - tls: None, - }, - default_server: true, - reuse_port_enabled: false, - tls_termination_enabled: false, - proxy_protocol_enabled: false, - http3: None, - }], - default_vhost: VirtualHost { - id: "server".to_string(), - listener_ids: Vec::new(), - default_listener_ids: Vec::new(), - server_names: Vec::new(), - routes: Vec::new(), - tls: None, - }, - vhosts: Vec::new(), - upstreams: HashMap::new(), - lookup: Default::default(), - } -} - -pub(super) fn snapshot_with_cache_zone(path: std::path::PathBuf) -> ConfigSnapshot { - let mut config = snapshot(); - config.cache_zones.insert( - "default".to_string(), - Arc::new(CacheZone { - name: "default".to_string(), - path, - max_size_bytes: Some(1024 * 1024), - inactive: Duration::from_secs(60), - default_ttl: Duration::from_secs(60), - max_entry_bytes: 1024, - path_levels: vec![2], - loader_batch_entries: 100, - loader_sleep: Duration::ZERO, - manager_batch_entries: 100, - manager_sleep: Duration::ZERO, - inactive_cleanup_interval: Duration::from_secs(60), - shared_index: true, - }), - ); - config -} - pub(super) struct ControlPlaneFixture { _tempdir: tempfile::TempDir, + ca_der: CertificateDer<'static>, pub(super) cert_path: std::path::PathBuf, pub(super) key_path: std::path::PathBuf, pub(super) keyring_path: std::path::PathBuf, - ca_der: CertificateDer<'static>, } impl ControlPlaneFixture { + pub(super) fn client(&self) -> Client, Full> { + let mut roots = RootCertStore::empty(); + roots.add(self.ca_der.clone()).expect("ca should add"); + let connector = HttpsConnectorBuilder::new() + .with_tls_config( + rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth(), + ) + .https_only() + .enable_http1() + .build(); + Client::builder(TokioExecutor::new()).build(connector) + } pub(super) fn new() -> std::result::Result> { let tempdir = tempfile::tempdir()?; let mut params = CertificateParams::new(vec!["localhost".to_string()])?; @@ -151,30 +90,24 @@ impl ControlPlaneFixture { ca_der: cert.der().clone(), }) } - - pub(super) fn client(&self) -> Client, Full> { - let mut roots = RootCertStore::empty(); - roots.add(self.ca_der.clone()).expect("ca should add"); - let connector = HttpsConnectorBuilder::new() - .with_tls_config( - rustls::ClientConfig::builder().with_root_certificates(roots).with_no_client_auth(), - ) - .https_only() - .enable_http1() - .build(); - Client::builder(TokioExecutor::new()).build(connector) - } } pub(super) struct RunningControlPlane { _fixture: ControlPlaneFixture, - pub(super) listen_addr: std::net::SocketAddr, pub(super) client: Client, Full>, - shutdown_tx: watch::Sender, + pub(super) listen_addr: std::net::SocketAddr, pub(super) server_task: tokio::task::JoinHandle>, + shutdown_tx: watch::Sender, } impl RunningControlPlane { + pub(super) async fn shutdown(self) { + let _ = self.shutdown_tx.send(true); + self.server_task + .await + .expect("server task should join") + .expect("server should stop cleanly"); + } pub(super) async fn start() -> Self { let state = rginx_http::SharedState::from_config(snapshot()).expect("shared state should build"); @@ -236,23 +169,90 @@ impl RunningControlPlane { Self { _fixture: fixture, listen_addr, client, shutdown_tx, server_task } } +} - pub(super) async fn shutdown(self) { - let _ = self.shutdown_tx.send(true); - self.server_task - .await - .expect("server task should join") - .expect("server should stop cleanly"); +pub(super) fn snapshot() -> ConfigSnapshot { + ConfigSnapshot { + control_plane: None, + agent: None, + acme: None, + managed_certificates: Vec::new(), + cache_zones: HashMap::new(), + runtime: RuntimeSettings { + shutdown_timeout: Duration::from_secs(1), + worker_threads: None, + accept_workers: 1, + }, + listeners: vec![Listener { + id: "default".to_string(), + name: "default".to_string(), + server: Server { + listen_addr: "127.0.0.1:8080".parse().unwrap(), + server_header: rginx_core::default_server_header(), + default_certificate: None, + trusted_proxies: Vec::new(), + client_ip_header: None, + keep_alive: true, + max_headers: None, + max_request_body_bytes: None, + max_connections: None, + header_read_timeout: None, + request_body_read_timeout: None, + response_write_timeout: None, + http1: rginx_core::Http1Settings::default(), + access_log_format: None, + tls: None, + }, + default_server: true, + reuse_port_enabled: false, + tls_termination_enabled: false, + proxy_protocol_enabled: false, + http3: None, + }], + default_vhost: VirtualHost { + id: "server".to_string(), + listener_ids: Vec::new(), + default_listener_ids: Vec::new(), + server_names: Vec::new(), + routes: Vec::new(), + tls: None, + }, + vhosts: Vec::new(), + upstreams: HashMap::new(), + lookup: Default::default(), } } +pub(super) fn snapshot_with_cache_zone(path: std::path::PathBuf) -> ConfigSnapshot { + let mut config = snapshot(); + config.cache_zones.insert( + "default".to_string(), + Arc::new(CacheZone { + name: "default".to_string(), + path, + max_size_bytes: Some(1024 * 1024), + inactive: Duration::from_mins(1), + default_ttl: Duration::from_mins(1), + max_entry_bytes: 1024, + path_levels: vec![2], + loader_batch_entries: 100, + loader_sleep: Duration::ZERO, + manager_batch_entries: 100, + manager_sleep: Duration::ZERO, + inactive_cleanup_interval: Duration::from_mins(1), + shared_index: true, + }), + ); + config +} + pub(super) async fn retry_get_json( client: &Client>, uri: String, api_key: Option<&str>, ) -> hyper::Response where - C: hyper_util::client::legacy::connect::Connect + Clone + Send + Sync + 'static, + C: Connect + Clone + Send + Sync + 'static, { let mut request = Request::builder() .method(Method::GET) @@ -278,7 +278,7 @@ pub(super) async fn retry_get_json_with_server( server_task: &mut tokio::task::JoinHandle>, ) -> hyper::Response where - C: hyper_util::client::legacy::connect::Connect + Clone + Send + Sync + 'static, + C: Connect + Clone + Send + Sync + 'static, { let mut request = Request::builder() .method(Method::GET) @@ -324,7 +324,7 @@ pub(super) async fn retry_json_request_with_server( server_task: &mut tokio::task::JoinHandle>, ) -> hyper::Response where - C: hyper_util::client::legacy::connect::Connect + Clone + Send + Sync + 'static, + C: Connect + Clone + Send + Sync + 'static, { let body = json_body .map(|value| Bytes::from(serde_json::to_vec(&value).expect("json body should encode"))) diff --git a/crates/rginx-agent/src/tests/support/executors.rs b/crates/rginx-agent/src/tests/support/executors.rs index 40301a05..d1c9c8de 100644 --- a/crates/rginx-agent/src/tests/support/executors.rs +++ b/crates/rginx-agent/src/tests/support/executors.rs @@ -5,18 +5,17 @@ pub(in crate::tests) struct TestReloadExecutor { } enum ReloadMode { - Success { state: Box }, Failure, + Success { state: Box }, } impl TestReloadExecutor { - pub(in crate::tests) fn success(state: rginx_http::SharedState) -> Self { - Self { mode: ReloadMode::Success { state: Box::new(state) } } - } - pub(in crate::tests) fn failing() -> Self { Self { mode: ReloadMode::Failure } } + pub(in crate::tests) fn success(state: rginx_http::SharedState) -> Self { + Self { mode: ReloadMode::Success { state: Box::new(state) } } + } } impl crate::ReloadExecutor for TestReloadExecutor { @@ -46,18 +45,17 @@ pub(in crate::tests) struct TestConfigApplyExecutor { } enum ConfigApplyMode { - Success { state: Box }, Failure, + Success { state: Box }, } impl TestConfigApplyExecutor { - pub(in crate::tests) fn success(state: rginx_http::SharedState) -> Self { - Self { mode: ConfigApplyMode::Success { state: Box::new(state) } } - } - pub(in crate::tests) fn failing() -> Self { Self { mode: ConfigApplyMode::Failure } } + pub(in crate::tests) fn success(state: rginx_http::SharedState) -> Self { + Self { mode: ConfigApplyMode::Success { state: Box::new(state) } } + } } impl crate::ConfigApplyExecutor for TestConfigApplyExecutor { diff --git a/crates/rginx-agent/src/tls.rs b/crates/rginx-agent/src/tls.rs index af09c22d..3c4119d2 100644 --- a/crates/rginx-agent/src/tls.rs +++ b/crates/rginx-agent/src/tls.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::path::Path; use std::sync::Arc; @@ -13,6 +15,15 @@ use tokio_rustls::server::TlsStream; use crate::error::{Error, Result}; +/// Client certificate identity extracted from the peer certificate +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ClientCertIdentity { + pub common_name: String, + pub organization: Option, + pub organizational_unit: Option, + pub serial_number: String, +} + pub(crate) fn load_tls_server_config( settings: &ControlPlaneTlsSettings, ) -> Result> { @@ -112,15 +123,6 @@ fn map_pem_error(path: &Path, item: &str, error: PemError) -> Error { } } -/// Client certificate identity extracted from the peer certificate -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ClientCertIdentity { - pub common_name: String, - pub organization: Option, - pub organizational_unit: Option, - pub serial_number: String, -} - /// Extract client identity from TLS stream pub fn extract_client_identity(tls_stream: &TlsStream) -> Option { let (_io, server_conn) = tls_stream.get_ref(); @@ -158,6 +160,3 @@ fn parse_certificate(cert_der: &CertificateDer) -> Option { serial_number, }) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-agent/src/websocket.rs b/crates/rginx-agent/src/websocket.rs index bafad2e8..5a039193 100644 --- a/crates/rginx-agent/src/websocket.rs +++ b/crates/rginx-agent/src/websocket.rs @@ -14,22 +14,60 @@ use crate::server::control::ControlPlaneContext; /// WebSocket request from client #[derive(Debug, Deserialize)] pub struct WebSocketRequest { - pub request_id: String, pub action: String, #[serde(default)] pub filter: Option, + pub request_id: String, } /// WebSocket response to client #[derive(Debug, Serialize)] pub struct WebSocketResponse { - pub request_id: String, pub action: String, pub data: serde_json::Value, + pub request_id: String, +} + +impl<'de> serde::Deserialize<'de> for EventFilter { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + struct EventFilterHelper { + #[serde(default)] + event_types: Vec, + #[serde(default)] + node_ids: Vec, + #[serde(default)] + regions: Vec, + } + + let helper = EventFilterHelper::deserialize(deserializer)?; + Ok(EventFilter { + event_types: helper.event_types, + node_ids: helper.node_ids, + regions: helper.regions, + }) + } +} + +impl serde::Serialize for EventFilter { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut state = serializer.serialize_struct("EventFilter", 3)?; + state.serialize_field("event_types", &self.event_types)?; + state.serialize_field("node_ids", &self.node_ids)?; + state.serialize_field("regions", &self.regions)?; + state.end() + } } /// Handle WebSocket upgrade and connection -#[allow(dead_code)] +#[expect(dead_code, reason = "websocket control-plane entrypoint is not wired into routing yet")] pub async fn handle_websocket_connection( stream: TcpStream, peer_addr: SocketAddr, @@ -37,7 +75,7 @@ pub async fn handle_websocket_connection( ) -> Result<()> { let ws_stream = accept_async(stream) .await - .map_err(|e| Error::Server(format!("websocket handshake failed: {}", e)))?; + .map_err(|e| Error::Server(format!("websocket handshake failed: {e}")))?; tracing::info!(%peer_addr, "websocket connection established"); metrics::increment_websocket_connections(); @@ -92,14 +130,13 @@ pub async fn handle_websocket_connection( Ok(()) } -#[allow(dead_code)] async fn handle_websocket_message( text: &str, context: &ControlPlaneContext, tx: &tokio::sync::mpsc::Sender, ) -> Result<()> { let request: WebSocketRequest = serde_json::from_str(text) - .map_err(|e| Error::InvalidRequest(format!("invalid json: {}", e)))?; + .map_err(|e| Error::InvalidRequest(format!("invalid json: {e}")))?; match request.action.as_str() { "subscribe" => { @@ -113,7 +150,7 @@ async fn handle_websocket_message( }; tx.send(Message::Text(serde_json::to_string(&response)?.into())) .await - .map_err(|e| Error::Server(format!("failed to send response: {}", e)))?; + .map_err(|e| Error::Server(format!("failed to send response: {e}")))?; } "unsubscribe" => { context.event_bus().unsubscribe(&request.request_id).await; @@ -125,7 +162,7 @@ async fn handle_websocket_message( }; tx.send(Message::Text(serde_json::to_string(&response)?.into())) .await - .map_err(|e| Error::Server(format!("failed to send response: {}", e)))?; + .map_err(|e| Error::Server(format!("failed to send response: {e}")))?; } "ping" => { let response = WebSocketResponse { @@ -135,7 +172,7 @@ async fn handle_websocket_message( }; tx.send(Message::Text(serde_json::to_string(&response)?.into())) .await - .map_err(|e| Error::Server(format!("failed to send response: {}", e)))?; + .map_err(|e| Error::Server(format!("failed to send response: {e}")))?; } _ => { return Err(Error::InvalidRequest(format!("unknown action: {}", request.action))); @@ -144,41 +181,3 @@ async fn handle_websocket_message( Ok(()) } - -impl<'de> serde::Deserialize<'de> for EventFilter { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - #[derive(Deserialize)] - struct EventFilterHelper { - #[serde(default)] - event_types: Vec, - #[serde(default)] - node_ids: Vec, - #[serde(default)] - regions: Vec, - } - - let helper = EventFilterHelper::deserialize(deserializer)?; - Ok(EventFilter { - event_types: helper.event_types, - node_ids: helper.node_ids, - regions: helper.regions, - }) - } -} - -impl serde::Serialize for EventFilter { - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut state = serializer.serialize_struct("EventFilter", 3)?; - state.serialize_field("event_types", &self.event_types)?; - state.serialize_field("node_ids", &self.node_ids)?; - state.serialize_field("regions", &self.regions)?; - state.end() - } -} diff --git a/crates/rginx-app/Cargo.toml b/crates/rginx-app/Cargo.toml index ac27cb04..41e95580 100644 --- a/crates/rginx-app/Cargo.toml +++ b/crates/rginx-app/Cargo.toml @@ -13,6 +13,9 @@ description = "Rust edge reverse proxy for small and medium deployments." keywords = ["reverse-proxy", "edge-proxy", "http", "tls", "rust"] categories = ["network-programming", "web-programming::http-server"] +[lints] +workspace = true + [dependencies] anyhow.workspace = true clap.workspace = true diff --git a/crates/rginx-app/src/admin_cli/mod.rs b/crates/rginx-app/src/admin_cli/mod.rs index abf54359..0c7c74da 100644 --- a/crates/rginx-app/src/admin_cli/mod.rs +++ b/crates/rginx-app/src/admin_cli/mod.rs @@ -1,3 +1,15 @@ +mod agent; +mod cache; +mod counters; +mod peers; +mod render; +mod revision; +mod snapshot; +mod socket; +mod status; +mod traffic; +mod upstreams; + pub(super) use std::io::{BufReader, Read, Write}; pub(super) use std::os::unix::net::UnixStream; pub(super) use std::path::Path; @@ -13,18 +25,6 @@ pub(super) use crate::cli::{ InvalidateCacheArgs, PurgeCacheArgs, SnapshotArgs, SnapshotModuleArg, WaitArgs, WindowArgs, }; -mod agent; -mod cache; -mod counters; -mod peers; -mod render; -mod revision; -mod snapshot; -mod socket; -mod status; -mod traffic; -mod upstreams; - pub(crate) fn run_admin_command(config_path: &Path, command: &Command) -> anyhow::Result { match command { Command::Snapshot(args) => { diff --git a/crates/rginx-app/src/admin_cli/status.rs b/crates/rginx-app/src/admin_cli/status.rs index eebf6253..a1b581a4 100644 --- a/crates/rginx-app/src/admin_cli/status.rs +++ b/crates/rginx-app/src/admin_cli/status.rs @@ -1,13 +1,13 @@ -use super::cache::print_status_cache; -use super::socket::{query_admin_socket, unexpected_admin_response}; -use super::*; - mod acme; mod listeners; mod runtime; mod tls; mod upstream_tls; +use super::cache::print_status_cache; +use super::socket::{query_admin_socket, unexpected_admin_response}; +use super::*; + pub(super) fn print_admin_status(config_path: &Path) -> anyhow::Result<()> { match query_admin_socket(config_path, AdminRequest::GetStatus)? { AdminResponse::Status(status) => { diff --git a/crates/rginx-app/src/admin_cli/traffic.rs b/crates/rginx-app/src/admin_cli/traffic.rs index 4073ed88..f3cb8055 100644 --- a/crates/rginx-app/src/admin_cli/traffic.rs +++ b/crates/rginx-app/src/admin_cli/traffic.rs @@ -1,12 +1,12 @@ -use super::socket::{query_admin_socket, unexpected_admin_response}; -use super::*; -use crate::admin_cli::render::print_section; - mod listeners; mod render; mod routes; mod vhosts; +use super::socket::{query_admin_socket, unexpected_admin_response}; +use super::*; +use crate::admin_cli::render::print_section; + pub(super) fn print_admin_traffic(config_path: &Path, args: &WindowArgs) -> anyhow::Result<()> { match query_admin_socket( config_path, diff --git a/crates/rginx-app/src/check/acme.rs b/crates/rginx-app/src/check/acme.rs index b0ac8d7c..7c3b1d6d 100644 --- a/crates/rginx-app/src/check/acme.rs +++ b/crates/rginx-app/src/check/acme.rs @@ -4,22 +4,22 @@ use std::path::PathBuf; use super::tls::TlsCheckDetails; pub(super) struct AcmeCheckDetails { - pub(super) enabled: bool, pub(super) directory_url: Option, - pub(super) state_dir: Option, - pub(super) renew_before_days: Option, - pub(super) poll_interval_secs: Option, + pub(super) enabled: bool, pub(super) managed_certificates: Vec, + pub(super) poll_interval_secs: Option, + pub(super) renew_before_days: Option, + pub(super) state_dir: Option, } pub(super) struct AcmeManagedCertificateCheck { - pub(super) scope: String, - pub(super) domains: Vec, - pub(super) managed: bool, - pub(super) challenge_type: String, pub(super) cert_path: PathBuf, + pub(super) challenge_type: String, + pub(super) domains: Vec, pub(super) key_path: PathBuf, + pub(super) managed: bool, pub(super) next_renewal_unix_ms: Option, + pub(super) scope: String, } pub(super) fn acme_check_details( diff --git a/crates/rginx-app/src/check/control.rs b/crates/rginx-app/src/check/control.rs index 082ee185..3ebbf4e8 100644 --- a/crates/rginx-app/src/check/control.rs +++ b/crates/rginx-app/src/check/control.rs @@ -2,13 +2,13 @@ use std::net::SocketAddr; use std::path::PathBuf; pub(super) struct ControlPlaneCheckDetails { - pub(super) mode: &'static str, pub(super) agent_endpoint: Option, pub(super) agent_node_id: Option, pub(super) agent_state_path: Option, pub(super) legacy_listen: Option, - pub(super) opens_extra_node_port: bool, pub(super) migration_hint: Option<&'static str>, + pub(super) mode: &'static str, + pub(super) opens_extra_node_port: bool, } pub(super) fn control_plane_check_details( diff --git a/crates/rginx-app/src/check/routes.rs b/crates/rginx-app/src/check/routes.rs index b395c983..4c5ea949 100644 --- a/crates/rginx-app/src/check/routes.rs +++ b/crates/rginx-app/src/check/routes.rs @@ -1,18 +1,18 @@ #[derive(Default)] pub(super) struct RouteTransportCheckDetails { + pub(super) cache_enabled_routes: usize, + pub(super) compression_auto_routes: usize, + pub(super) compression_force_routes: usize, + pub(super) compression_off_routes: usize, + pub(super) custom_compression_content_types_routes: usize, + pub(super) custom_compression_min_bytes_routes: usize, pub(super) request_buffering_auto_routes: usize, - pub(super) request_buffering_on_routes: usize, pub(super) request_buffering_off_routes: usize, + pub(super) request_buffering_on_routes: usize, pub(super) response_buffering_auto_routes: usize, - pub(super) response_buffering_on_routes: usize, pub(super) response_buffering_off_routes: usize, - pub(super) compression_auto_routes: usize, - pub(super) compression_off_routes: usize, - pub(super) compression_force_routes: usize, - pub(super) custom_compression_min_bytes_routes: usize, - pub(super) custom_compression_content_types_routes: usize, + pub(super) response_buffering_on_routes: usize, pub(super) streaming_response_idle_timeout_routes: usize, - pub(super) cache_enabled_routes: usize, } pub(super) fn route_transport_check_details( @@ -22,34 +22,52 @@ pub(super) fn route_transport_check_details( for route in all_routes(config) { match route.request_buffering { - rginx_core::RouteBufferingPolicy::Auto => details.request_buffering_auto_routes += 1, - rginx_core::RouteBufferingPolicy::On => details.request_buffering_on_routes += 1, - rginx_core::RouteBufferingPolicy::Off => details.request_buffering_off_routes += 1, + rginx_core::RouteBufferingPolicy::Auto => { + increment(&mut details.request_buffering_auto_routes); + } + rginx_core::RouteBufferingPolicy::On => { + increment(&mut details.request_buffering_on_routes); + } + rginx_core::RouteBufferingPolicy::Off => { + increment(&mut details.request_buffering_off_routes); + } } match route.response_buffering { - rginx_core::RouteBufferingPolicy::Auto => details.response_buffering_auto_routes += 1, - rginx_core::RouteBufferingPolicy::On => details.response_buffering_on_routes += 1, - rginx_core::RouteBufferingPolicy::Off => details.response_buffering_off_routes += 1, + rginx_core::RouteBufferingPolicy::Auto => { + increment(&mut details.response_buffering_auto_routes); + } + rginx_core::RouteBufferingPolicy::On => { + increment(&mut details.response_buffering_on_routes); + } + rginx_core::RouteBufferingPolicy::Off => { + increment(&mut details.response_buffering_off_routes); + } } match route.compression { - rginx_core::RouteCompressionPolicy::Auto => details.compression_auto_routes += 1, - rginx_core::RouteCompressionPolicy::Off => details.compression_off_routes += 1, - rginx_core::RouteCompressionPolicy::Force => details.compression_force_routes += 1, + rginx_core::RouteCompressionPolicy::Auto => { + increment(&mut details.compression_auto_routes); + } + rginx_core::RouteCompressionPolicy::Off => { + increment(&mut details.compression_off_routes); + } + rginx_core::RouteCompressionPolicy::Force => { + increment(&mut details.compression_force_routes); + } } if route.compression_min_bytes.is_some() { - details.custom_compression_min_bytes_routes += 1; + increment(&mut details.custom_compression_min_bytes_routes); } if !route.compression_content_types.is_empty() { - details.custom_compression_content_types_routes += 1; + increment(&mut details.custom_compression_content_types_routes); } if route.streaming_response_idle_timeout.is_some() { - details.streaming_response_idle_timeout_routes += 1; + increment(&mut details.streaming_response_idle_timeout_routes); } if route.cache.is_some() { - details.cache_enabled_routes += 1; + increment(&mut details.cache_enabled_routes); } } @@ -63,3 +81,7 @@ fn all_routes( .chain(config.vhosts.iter()) .flat_map(|vhost| vhost.routes.iter()) } + +fn increment(value: &mut usize) { + *value = value.saturating_add(1); +} diff --git a/crates/rginx-app/src/check/summary.rs b/crates/rginx-app/src/check/summary.rs index d3ab1e4e..15534801 100644 --- a/crates/rginx-app/src/check/summary.rs +++ b/crates/rginx-app/src/check/summary.rs @@ -6,67 +6,67 @@ use super::routes::{RouteTransportCheckDetails, route_transport_check_details}; use super::tls::{TlsCheckDetails, tls_check_details}; pub(crate) struct CheckSummary { - pub(super) listener_model: &'static str, - pub(super) listener_count: usize, + pub(super) accept_workers: usize, + pub(super) acme: AcmeCheckDetails, + pub(super) cache_enabled_route_count: usize, + pub(super) cache_zone_count: usize, + pub(super) cache_zones: Vec, + pub(super) control_plane: ControlPlaneCheckDetails, + pub(super) http3_early_data_enabled_listeners: usize, + pub(super) http3_enabled: bool, pub(super) listener_binding_count: usize, + pub(super) listener_count: usize, + pub(super) listener_model: &'static str, pub(super) listeners: Vec, + pub(super) route_transport: RouteTransportCheckDetails, + pub(super) tls: TlsCheckDetails, pub(super) tls_enabled: bool, - pub(super) http3_enabled: bool, - pub(super) http3_early_data_enabled_listeners: usize, - pub(super) total_vhost_count: usize, pub(super) total_route_count: usize, + pub(super) total_vhost_count: usize, pub(super) upstream_count: usize, - pub(super) cache_zone_count: usize, - pub(super) cache_enabled_route_count: usize, - pub(super) cache_zones: Vec, pub(super) worker_threads: Option, - pub(super) accept_workers: usize, - pub(super) route_transport: RouteTransportCheckDetails, - pub(super) control_plane: ControlPlaneCheckDetails, - pub(super) acme: AcmeCheckDetails, - pub(super) tls: TlsCheckDetails, } pub(super) struct CheckCacheZoneSummary { - pub(super) name: String, - pub(super) path: PathBuf, - pub(super) max_size_bytes: Option, - pub(super) inactive_secs: u64, pub(super) default_ttl_secs: u64, + pub(super) inactive_secs: u64, pub(super) max_entry_bytes: usize, + pub(super) max_size_bytes: Option, + pub(super) name: String, + pub(super) path: PathBuf, } pub(super) struct CheckListenerSummary { - pub(super) id: String, - pub(super) name: String, - pub(super) listen_addr: std::net::SocketAddr, + pub(super) access_log_format_configured: bool, pub(super) binding_count: usize, - pub(super) http3_enabled: bool, - pub(super) tls_enabled: bool, - pub(super) proxy_protocol_enabled: bool, + pub(super) bindings: Vec, pub(super) default_certificate: Option, + pub(super) http3_enabled: bool, + pub(super) id: String, pub(super) keep_alive: bool, + pub(super) listen_addr: std::net::SocketAddr, pub(super) max_connections: Option, - pub(super) access_log_format_configured: bool, - pub(super) bindings: Vec, + pub(super) name: String, + pub(super) proxy_protocol_enabled: bool, + pub(super) tls_enabled: bool, } pub(super) struct CheckListenerBindingSummary { - pub(super) binding_name: String, - pub(super) transport: String, - pub(super) listen_addr: std::net::SocketAddr, - pub(super) protocols: Vec, - pub(super) worker_count: usize, - pub(super) reuse_port_enabled: Option, pub(super) advertise_alt_svc: Option, pub(super) alt_svc_max_age_secs: Option, - pub(super) http3_max_concurrent_streams: Option, - pub(super) http3_stream_buffer_size: Option, + pub(super) binding_name: String, pub(super) http3_active_connection_id_limit: Option, - pub(super) http3_retry: Option, - pub(super) http3_host_key_path: Option, - pub(super) http3_gso: Option, pub(super) http3_early_data_enabled: Option, + pub(super) http3_gso: Option, + pub(super) http3_host_key_path: Option, + pub(super) http3_max_concurrent_streams: Option, + pub(super) http3_retry: Option, + pub(super) http3_stream_buffer_size: Option, + pub(super) listen_addr: std::net::SocketAddr, + pub(super) protocols: Vec, + pub(super) reuse_port_enabled: Option, + pub(super) transport: String, + pub(super) worker_count: usize, } pub(crate) fn build_check_summary(config: &rginx_config::ConfigSnapshot) -> CheckSummary { diff --git a/crates/rginx-app/src/check/tls.rs b/crates/rginx-app/src/check/tls.rs index 3178c2c7..b7fb9e72 100644 --- a/crates/rginx-app/src/check/tls.rs +++ b/crates/rginx-app/src/check/tls.rs @@ -1,36 +1,36 @@ use std::collections::BTreeSet; pub(super) struct TlsCheckDetails { - pub(super) listener_tls_profiles: usize, - pub(super) vhost_tls_overrides: usize, - pub(super) sni_name_count: usize, pub(super) certificate_bundle_count: usize, + pub(super) certificates: Vec, + pub(super) default_certificate_bindings: Vec, pub(super) default_certificates: Vec, pub(super) expiring_certificates: Vec, - pub(super) reloadable_fields: Vec, - pub(super) restart_required_fields: Vec, + pub(super) listener_tls_profiles: usize, pub(super) listeners: Vec, - pub(super) certificates: Vec, pub(super) ocsp: Vec, - pub(super) vhost_bindings: Vec, + pub(super) reloadable_fields: Vec, + pub(super) restart_required_fields: Vec, pub(super) sni_bindings: Vec, pub(super) sni_conflicts: Vec, - pub(super) default_certificate_bindings: Vec, + pub(super) sni_name_count: usize, + pub(super) vhost_bindings: Vec, + pub(super) vhost_tls_overrides: usize, } pub(super) struct TlsSniBindingCheck { - pub(super) listener_name: String, - pub(super) server_name: String, + pub(super) default_selected: bool, pub(super) fingerprints: Vec, + pub(super) listener_name: String, pub(super) scopes: Vec, - pub(super) default_selected: bool, + pub(super) server_name: String, } pub(super) struct TlsDefaultCertificateBindingCheck { - pub(super) listener_name: String, - pub(super) server_name: String, pub(super) fingerprints: Vec, + pub(super) listener_name: String, pub(super) scopes: Vec, + pub(super) server_name: String, } pub(super) fn tls_check_details(config: &rginx_config::ConfigSnapshot) -> TlsCheckDetails { @@ -55,13 +55,15 @@ pub(super) fn tls_check_details(config: &rginx_config::ConfigSnapshot) -> TlsChe .listeners .iter() .filter_map(|listener| listener.server.tls.as_ref()) - .map(|tls| 1 + tls.additional_certificates.len()) + .map(|tls| tls.additional_certificates.len().saturating_add(1)) .sum::() - + std::iter::once(&config.default_vhost) - .chain(config.vhosts.iter()) - .filter_map(|vhost| vhost.tls.as_ref()) - .map(|tls| 1 + tls.additional_certificates.len()) - .sum::(); + .saturating_add( + std::iter::once(&config.default_vhost) + .chain(config.vhosts.iter()) + .filter_map(|vhost| vhost.tls.as_ref()) + .map(|tls| tls.additional_certificates.len().saturating_add(1)) + .sum::(), + ); let default_certificates = config .listeners .iter() diff --git a/crates/rginx-app/src/cli.rs b/crates/rginx-app/src/cli.rs index ea7f6310..0ea1e2d2 100644 --- a/crates/rginx-app/src/cli.rs +++ b/crates/rginx-app/src/cli.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::env; use std::path::{Path, PathBuf}; @@ -10,38 +12,37 @@ use clap::{ArgAction, ArgGroup, Args, Parser, Subcommand, ValueEnum}; about = "A Rust edge reverse proxy for small and medium deployments" )] pub struct Cli { + #[command(subcommand)] + pub command: Option, #[arg(short, long, global = true, default_value_os_t = default_config_path())] pub config: PathBuf, - #[arg(short = 't', action = ArgAction::SetTrue, conflicts_with = "signal")] - pub test_config: bool, - #[arg(short = 's', value_enum, conflicts_with = "test_config")] pub signal: Option, - #[command(subcommand)] - pub command: Option, + #[arg(short = 't', action = ArgAction::SetTrue, conflicts_with = "signal")] + pub test_config: bool, } #[derive(Debug, Clone, Subcommand)] pub enum Command { Acme(AcmeArgs), Agent(AgentArgs), + Cache, Check, - Snapshot(SnapshotArgs), - SnapshotVersion, + ClearCacheInvalidations(CacheZoneArgs), + Counters, Delta(DeltaArgs), - Wait(WaitArgs), + InvalidateCache(InvalidateCacheArgs), + Peers, + PurgeCache(PurgeCacheArgs), SetDesiredRevision(DesiredRevisionArgs), + Snapshot(SnapshotArgs), + SnapshotVersion, Status, - Cache, - PurgeCache(PurgeCacheArgs), - InvalidateCache(InvalidateCacheArgs), - ClearCacheInvalidations(CacheZoneArgs), - Counters, Traffic(WindowArgs), - Peers, Upstreams(WindowArgs), + Wait(WaitArgs), } #[derive(Debug, Clone, Args)] @@ -63,9 +64,9 @@ pub enum AcmeCommand { #[derive(Debug, Clone, Copy, PartialEq, Eq, Subcommand)] pub enum AgentAdminCommand { - Status, Disable, Enable, + Status, } #[derive(Debug, Clone, Args)] @@ -107,12 +108,12 @@ pub struct SnapshotArgs { #[derive(Debug, Clone, Args)] pub struct DeltaArgs { - #[arg(long)] - pub since_version: u64, - #[arg(long, value_enum)] pub include: Vec, + #[arg(long)] + pub since_version: u64, + #[arg(long, value_parser = parse_recent_window_secs)] pub window_secs: Option, } @@ -120,22 +121,18 @@ pub struct DeltaArgs { #[derive(Debug, Clone, Args)] #[command(group = ArgGroup::new("selector").args(["key", "prefix"]).multiple(false))] pub struct PurgeCacheArgs { - #[arg(long)] - pub zone: String, - #[arg(long)] pub key: Option, #[arg(long)] pub prefix: Option, + #[arg(long)] + pub zone: String, } #[derive(Debug, Clone, Args)] #[command(group = ArgGroup::new("selector").args(["key", "prefix", "tag"]).multiple(false))] pub struct InvalidateCacheArgs { - #[arg(long)] - pub zone: String, - #[arg(long)] pub key: Option, @@ -144,6 +141,8 @@ pub struct InvalidateCacheArgs { #[arg(long)] pub tag: Option, + #[arg(long)] + pub zone: String, } #[derive(Debug, Clone, Args)] @@ -154,21 +153,21 @@ pub struct CacheZoneArgs { #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] pub enum SnapshotModuleArg { - Status, + Cache, Counters, - Traffic, #[value(name = "peer-health")] PeerHealth, + Status, + Traffic, Upstreams, - Cache, } #[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] pub enum SignalCommand { + Quit, Reload, Restart, Stop, - Quit, } impl SignalCommand { @@ -234,6 +233,3 @@ pub fn pid_path_for_config(config_path: &Path) -> PathBuf { config_path.with_extension("pid") } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-app/src/main.rs b/crates/rginx-app/src/main.rs index e414df85..9c223a17 100644 --- a/crates/rginx-app/src/main.rs +++ b/crates/rginx-app/src/main.rs @@ -1,3 +1,12 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] mod admin_cli; mod check; mod cli; @@ -5,9 +14,6 @@ mod pid_file; mod runtime; mod signal; -#[cfg(not(target_os = "linux"))] -compile_error!("rginx supports Linux only"); - use anyhow::{Context, anyhow}; use clap::Parser; @@ -17,6 +23,9 @@ use crate::pid_file::PidFileGuard; use crate::runtime::build_runtime; use crate::signal::send_signal_from_pid_file; +#[cfg(not(target_os = "linux"))] +compile_error!("rginx supports Linux only"); + fn main() -> anyhow::Result<()> { rginx_http::install_default_crypto_provider(); let cli = Cli::parse(); diff --git a/crates/rginx-app/src/pid_file.rs b/crates/rginx-app/src/pid_file.rs index 0e84966a..60e5852d 100644 --- a/crates/rginx-app/src/pid_file.rs +++ b/crates/rginx-app/src/pid_file.rs @@ -45,12 +45,6 @@ impl PidFileGuard { } } -pub(crate) fn read_pid_file_record(path: &Path) -> anyhow::Result { - let contents = fs::read_to_string(path) - .with_context(|| format!("failed to read pid file {}", path.display()))?; - parse_pid_file_record(&contents, path) -} - impl Drop for PidFileGuard { fn drop(&mut self) { let expected_pid = std::process::id().to_string(); @@ -95,6 +89,12 @@ impl Drop for PidFileGuard { } } +pub(crate) fn read_pid_file_record(path: &Path) -> anyhow::Result { + let contents = fs::read_to_string(path) + .with_context(|| format!("failed to read pid file {}", path.display()))?; + parse_pid_file_record(&contents, path) +} + fn parse_pid_file_record(contents: &str, path: &Path) -> anyhow::Result { let mut lines = contents.lines(); let raw_pid = lines diff --git a/crates/rginx-app/tests/access_log.rs b/crates/rginx-app/tests/access_log.rs index efd1242a..38d9bc0a 100644 --- a/crates/rginx-app/tests/access_log.rs +++ b/crates/rginx-app/tests/access_log.rs @@ -1,3 +1,14 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::path::Path; @@ -9,13 +20,60 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::{ClientConfig, ClientConnection, DigitallySignedStruct, SignatureScheme, StreamOwned}; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4E\nFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2zn\nyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngq\np7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gp\nqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+\nyfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6\nJrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5\nx23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59\nCiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; const TEST_SERVER_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/GfUhio45GzZi\nz/o67ngbdqW77Bk54aQ/L9kthS47SkqHVaC8+/AYXfDddl0Q+EHQleGrCqjPrRLA\nZAm0Qz0YZEdWcFX4F4T/gvVc7gfobPmkwoVoJwk/A7giuymN9+kpDbXsr2X29uMr\nfW0krA4ZpD/+JzLMXXfG5S9mazrs81wGqFlByGV32z3Qiecvd8JL4nnyz2mMTSqV\nOvtuh285dpgHFugOUGHhhXs9b/CPAGiBTBxuIm+2tGdjtmVGL8L9tlRMYnmQ7u9K\nUCAyHYovUPJjZazTqfkpQ2CkC4iOZ7Ay25gn0fbGBEysyOBZIPw39FEW63jGembu\nOVF6nO/DAgMBAAECggEAKLC7v80TVHiFX4veQZ8WRu7AAmAWzPrNMMEc8rLZcblz\nXhau956DdITILTevQFZEGUhYuUU3RaUaCYojgNUSVLfBctfPjlhfstItMYDjgSt3\nCox6wH8TWm4NzqNgiUCgzmODeaatROUz4MY/r5/NDsuo7pJlIBvEzb5uFdY+QUZ/\nR5gHRiD2Q3wCODe8zQRfTZGo7jCimAuWTLurWZl6ax/4TjWbXCD6DTuUo81cW3vy\nne6tEetHcABRO7uDoBYXk12pCgqFZzjLMnKJjQM+OYnSj6DoWjOu1drT5YyRLGDj\nfzN8V0aKRkOYoZ5QZOua8pByOyQElJnM16vkPtHgPQKBgQD6SOUNWEghvYIGM/lx\nc22/zjvDjeaGC3qSmlpQYN5MGuDoszeDBZ+rMTmHqJ9FcHYkLQnUI7ZkHhRGt/wQ\n/w3CroJjPBgKk+ipy2cBHSI+z+U20xjYzE8hxArWbXG1G4rDt5AIz68IQPsfkVND\nktkDABDaU+KwBPx8fjeeqtRQxQKBgQDDdxdLB1XcfZMX0KEP5RfA8ar1nW41TUAl\nTCOLaXIQbHZ0BeW7USE9mK8OKnVALZGJ+rpxvYFPZ5MWxchpb/cuIwXjLoN6uZVb\nfx4Hho+2iCfhcEKzs8XZW48duKIfhx13BiILLf/YaHAWFs9UfVcQog4Qx03guyMr\n7k9bFuy25wKBgQDpE48zAT6TJS775dTrAQp4b28aan/93pyz/8gRSFRb3UALlDIi\n8s7BluKzYaWI/fUXNVYM14EX9Sb+wIGdtlezL94+2Yyt9RXbYY8361Cj2+jiSG3A\nH2ulzzIkg+E7Pj3Yi443lmiysAjsWeKHcC5l697F4w6cytfye3wCZ6W23QKBgQC0\n9tX+5aytdSkwnDvxXlVOka+ItBcri/i+Ty59TMOIxxInuqoFcUhIIcq4X8CsCUQ8\nLYBd+2fznt3D8JrqWvnKoiw6N38MqTLJQfgIWaFGCep6QhfPDbo30RfAGYcnj01N\nO8Va+lxq+84B9V5AR8bKpG5HRG4qiLc4XerkV2YSswKBgDt9eerSBZyLVwfku25Y\nfrh+nEjUZy81LdlpJmu/bfa2FfItzBqDZPskkJJW9ON82z/ejGFbsU48RF7PJUMr\nGimE33QeTDToGozHCq0QOd0SMfsVkOQR+EROdmY52UIYAYgQUfI1FQ9lLsw10wlQ\nD11SHTL7b9pefBWfW73I7ttV\n-----END PRIVATE KEY-----\n"; +#[derive(Debug)] +struct InsecureServerCertVerifier { + supported_schemes: Vec, +} + +impl InsecureServerCertVerifier { + fn new() -> Self { + Self { + supported_schemes: rustls::crypto::aws_lc_rs::default_provider() + .signature_verification_algorithms + .supported_schemes(), + } + } +} + +impl ServerCertVerifier for InsecureServerCertVerifier { + fn supported_verify_schemes(&self) -> Vec { + self.supported_schemes.clone() + } + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } +} + #[test] fn custom_access_log_format_emits_expected_line() { let listen_addr = reserve_loopback_addr(); @@ -256,53 +314,3 @@ fn send_https_request(listen_addr: SocketAddr, request: &str) -> Result, -} - -impl InsecureServerCertVerifier { - fn new() -> Self { - Self { - supported_schemes: rustls::crypto::aws_lc_rs::default_provider() - .signature_verification_algorithms - .supported_schemes(), - } - } -} - -impl ServerCertVerifier for InsecureServerCertVerifier { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp_response: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } -} diff --git a/crates/rginx-app/tests/active_health.rs b/crates/rginx-app/tests/active_health.rs index 36e5f7b7..fb797d7f 100644 --- a/crates/rginx-app/tests/active_health.rs +++ b/crates/rginx-app/tests/active_health.rs @@ -1,3 +1,14 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::sync::Arc; @@ -5,10 +16,14 @@ use std::sync::atomic::{AtomicBool, Ordering}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, read_http_head, reserve_loopback_addr}; +#[derive(Debug)] +struct ParsedResponse { + body: Vec, + status: u16, +} + #[test] fn active_health_checks_mark_peer_unhealthy_and_recover_after_successive_probes() { let health_ok = Arc::new(AtomicBool::new(false)); @@ -94,7 +109,8 @@ fn wait_for_proxy_response( expectation: &str, server: &mut ServerHarness, ) { - let deadline = std::time::Instant::now() + timeout; + let deadline = + std::time::Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while std::time::Instant::now() < deadline { @@ -161,12 +177,6 @@ fn spawn_health_upstream(health_ok: Option>, body: &'static str) listen_addr } -#[derive(Debug)] -struct ParsedResponse { - status: u16, - body: Vec, -} - fn send_http_request( listen_addr: SocketAddr, method: &str, @@ -207,7 +217,8 @@ fn parse_http_response(bytes: &[u8]) -> Result { .parse::() .map_err(|error| format!("invalid status code: {error}"))?; - Ok(ParsedResponse { status, body: bytes[head_end + 4..].to_vec() }) + let body_start = head_end.saturating_add(4); + Ok(ParsedResponse { status, body: bytes[body_start..].to_vec() }) } fn request_path(head: &str) -> &str { diff --git a/crates/rginx-app/tests/admin.rs b/crates/rginx-app/tests/admin.rs index 16f93fb1..4cefc84d 100644 --- a/crates/rginx-app/tests/admin.rs +++ b/crates/rginx-app/tests/admin.rs @@ -1,28 +1,36 @@ #![cfg(unix)] +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + +#[path = "admin/commands.rs"] +mod commands; +#[path = "admin/delta_wait.rs"] +mod delta_wait; +#[path = "admin/snapshot.rs"] +mod snapshot; -#[allow(unused_imports)] use std::io::{BufReader, Read, Write}; -#[allow(unused_imports)] use std::net::TcpListener; -#[allow(unused_imports)] use std::os::unix::net::UnixStream; -#[allow(unused_imports)] use std::path::Path; -#[allow(unused_imports)] use std::process::Command; -#[allow(unused_imports)] use std::time::{Duration, Instant}; -#[allow(unused_imports)] use rcgen::{ BasicConstraints, CertificateParams, CertificateRevocationList, CertificateRevocationListParams, DnType, IsCa, Issuer, KeyIdMethod, KeyPair, KeyUsagePurpose, RevocationReason, RevokedCertParams, SerialNumber, date_time_ymd, }; -mod support; - -#[allow(unused_imports)] use rginx_runtime::admin::{ AdminRequest, AdminResponse, RevisionSnapshot, admin_socket_path_for_config, }; @@ -30,15 +38,20 @@ use support::{ READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr, reserve_loopback_addr_pair, }; -#[path = "admin/commands.rs"] -mod commands; -#[path = "admin/delta_wait.rs"] -mod delta_wait; -#[path = "admin/snapshot.rs"] -mod snapshot; +struct TestCertifiedKey { + cert: rcgen::Certificate, + params: CertificateParams, + signing_key: KeyPair, +} + +impl TestCertifiedKey { + fn issuer(&self) -> Issuer<'_, &KeyPair> { + Issuer::from_params(&self.params, &self.signing_key) + } +} fn wait_for_admin_socket(path: &Path, timeout: Duration) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -237,18 +250,6 @@ fn render_output(output: &std::process::Output) -> String { ) } -struct TestCertifiedKey { - cert: rcgen::Certificate, - signing_key: KeyPair, - params: CertificateParams, -} - -impl TestCertifiedKey { - fn issuer(&self) -> Issuer<'_, &KeyPair> { - Issuer::from_params(&self.params, &self.signing_key) - } -} - fn generate_cert(hostname: &str) -> TestCertifiedKey { let params = CertificateParams::new(vec![hostname.to_string()]) .expect("self-signed certificate should generate"); diff --git a/crates/rginx-app/tests/backup.rs b/crates/rginx-app/tests/backup.rs index 7788f525..c7c148f3 100644 --- a/crates/rginx-app/tests/backup.rs +++ b/crates/rginx-app/tests/backup.rs @@ -1,10 +1,19 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] diff --git a/crates/rginx-app/tests/cache.rs b/crates/rginx-app/tests/cache.rs index 36e0ed96..e09890f3 100644 --- a/crates/rginx-app/tests/cache.rs +++ b/crates/rginx-app/tests/cache.rs @@ -1,4 +1,13 @@ #![cfg(unix)] +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] mod support; diff --git a/crates/rginx-app/tests/cache/benchmarks.rs b/crates/rginx-app/tests/cache/benchmarks.rs index abf58ab8..67133dbc 100644 --- a/crates/rginx-app/tests/cache/benchmarks.rs +++ b/crates/rginx-app/tests/cache/benchmarks.rs @@ -171,10 +171,10 @@ fn benchmark_shared_fill_path( .expect("first upstream fill should start"); let second_started = Instant::now(); let second = thread::spawn(move || send_get_request(listen_addr, &second_path)); - thread::sleep(barrier_delay / 5); + thread::sleep(barrier_delay.checked_div(5).expect("benchmark divisor is nonzero")); assert_eq!( upstream_hits.load(Ordering::Relaxed), - index + 1, + index.saturating_add(1), "shared-fill benchmark should only open one upstream request per cold key" ); open_gate(&release_gate); @@ -244,7 +244,7 @@ fn spawn_keyed_blocked_fill_response_server(body: &'static str) -> KeyedBlockedF let mut counts = path_counts.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); let count = counts.entry(path.clone()).or_insert(0); - *count += 1; + *count = count.saturating_add(1); *count == 1 }; hits.fetch_add(1, Ordering::Relaxed); diff --git a/crates/rginx-app/tests/check.rs b/crates/rginx-app/tests/check.rs index 72c82c04..60ab7d40 100644 --- a/crates/rginx-app/tests/check.rs +++ b/crates/rginx-app/tests/check.rs @@ -1,16 +1,12 @@ -use std::env; -use std::fs; -use std::net::{SocketAddr, TcpListener}; -use std::path::{Path, PathBuf}; -use std::process::{Command, Output}; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::time::{SystemTime, UNIX_EPOCH}; - -use rcgen::{ - BasicConstraints, CertificateParams, CertifiedKey, DnType, ExtendedKeyUsagePurpose, IsCa, - KeyPair, -}; - +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] #[path = "check/acme.rs"] mod acme; #[path = "check/basic.rs"] @@ -24,4 +20,17 @@ mod summary; #[path = "check/tls.rs"] mod tls; +use std::env; +use std::fs; +use std::net::{SocketAddr, TcpListener}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Output}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::{SystemTime, UNIX_EPOCH}; + +use rcgen::{ + BasicConstraints, CertificateParams, CertifiedKey, DnType, ExtendedKeyUsagePurpose, IsCa, + KeyPair, +}; + pub(crate) use helpers::*; diff --git a/crates/rginx-app/tests/check/helpers.rs b/crates/rginx-app/tests/check/helpers.rs index 013e2dba..c4077e32 100644 --- a/crates/rginx-app/tests/check/helpers.rs +++ b/crates/rginx-app/tests/check/helpers.rs @@ -1,5 +1,7 @@ use super::*; +pub(crate) type TestCertifiedKey = CertifiedKey; + pub(crate) fn run_rginx(args: impl IntoIterator>) -> Output { let mut command = Command::new(binary_path()); for arg in args { @@ -54,8 +56,6 @@ pub(crate) fn expected_reloadable_tls_updates_line() -> String { format!("Reloadable TLS updates: {}", rginx_http::tls_reloadable_fields().join(", ")) } -pub(crate) type TestCertifiedKey = CertifiedKey; - pub(crate) fn generate_cert(hostname: &str) -> TestCertifiedKey { rcgen::generate_simple_self_signed(vec![hostname.to_string()]) .expect("self-signed certificate should generate") diff --git a/crates/rginx-app/tests/compression.rs b/crates/rginx-app/tests/compression.rs index 7105b959..1198585b 100644 --- a/crates/rginx-app/tests/compression.rs +++ b/crates/rginx-app/tests/compression.rs @@ -1,3 +1,14 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::collections::HashMap; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream}; @@ -6,10 +17,21 @@ use std::time::Duration; use brotli::Decompressor; use flate2::read::GzDecoder; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; +#[derive(Debug)] +struct ParsedResponse { + body: Vec, + headers: HashMap, + status: u16, +} + +impl ParsedResponse { + fn header(&self, name: &str) -> Option<&str> { + self.headers.get(&name.to_ascii_lowercase()).map(String::as_str) + } +} + #[test] fn gzips_large_return_text_responses_when_client_accepts_gzip() { let listen_addr = reserve_loopback_addr(); @@ -162,19 +184,6 @@ fn custom_compression_content_types_can_disable_default_text_compression() { server.shutdown_and_wait(Duration::from_secs(5)); } -#[derive(Debug)] -struct ParsedResponse { - status: u16, - headers: HashMap, - body: Vec, -} - -impl ParsedResponse { - fn header(&self, name: &str) -> Option<&str> { - self.headers.get(&name.to_ascii_lowercase()).map(String::as_str) - } -} - fn send_http_request(listen_addr: SocketAddr, request: &str) -> Result { let mut stream = TcpStream::connect_timeout(&listen_addr, Duration::from_millis(200)) .map_err(|error| format!("failed to connect to {listen_addr}: {error}"))?; @@ -219,7 +228,8 @@ fn parse_http_response(bytes: &[u8]) -> Result { headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string()); } - Ok(ParsedResponse { status, headers, body: bytes[head_end + 4..].to_vec() }) + let body_start = head_end.saturating_add(4); + Ok(ParsedResponse { status, headers, body: bytes[body_start..].to_vec() }) } fn decode_gzip(bytes: &[u8]) -> Vec { diff --git a/crates/rginx-app/tests/dns_refresh.rs b/crates/rginx-app/tests/dns_refresh.rs index 3cbc6464..3b19db8e 100644 --- a/crates/rginx-app/tests/dns_refresh.rs +++ b/crates/rginx-app/tests/dns_refresh.rs @@ -1,10 +1,19 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] diff --git a/crates/rginx-app/tests/downstream_mtls.rs b/crates/rginx-app/tests/downstream_mtls.rs index 92cbdb5e..10473b94 100644 --- a/crates/rginx-app/tests/downstream_mtls.rs +++ b/crates/rginx-app/tests/downstream_mtls.rs @@ -1,55 +1,55 @@ -#[allow(unused_imports)] +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + +#[path = "downstream_mtls/enforcement.rs"] +mod enforcement; +#[path = "downstream_mtls/observability.rs"] +mod observability; +#[path = "downstream_mtls/validation.rs"] +mod validation; + +#[path = "downstream_mtls/verifier.rs"] +mod verifier; + use std::env; -#[allow(unused_imports)] use std::io::{Read, Write}; -#[allow(unused_imports)] use std::net::SocketAddr; -#[allow(unused_imports)] use std::os::unix::net::UnixStream; -#[allow(unused_imports)] use std::path::{Path, PathBuf}; -#[allow(unused_imports)] use std::sync::Arc; -#[allow(unused_imports)] use std::sync::atomic::{AtomicU64, Ordering}; -#[allow(unused_imports)] use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -#[allow(unused_imports)] use rcgen::{ BasicConstraints, CertificateParams, CertificateRevocationList, CertificateRevocationListParams, DnType, ExtendedKeyUsagePurpose, IsCa, Issuer, KeyIdMethod, KeyPair, KeyUsagePurpose, RevocationReason, RevokedCertParams, SerialNumber, date_time_ymd, }; -#[allow(unused_imports)] use rginx_runtime::admin::{AdminRequest, AdminResponse, admin_socket_path_for_config}; -#[allow(unused_imports)] use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; -#[allow(unused_imports)] use rustls::pki_types::pem::PemObject; -#[allow(unused_imports)] use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; -#[allow(unused_imports)] use rustls::{ClientConfig, ClientConnection, DigitallySignedStruct, SignatureScheme, StreamOwned}; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; -#[path = "downstream_mtls/enforcement.rs"] -mod enforcement; -#[path = "downstream_mtls/observability.rs"] -mod observability; -#[path = "downstream_mtls/validation.rs"] -mod validation; +use verifier::*; struct TlsFixture { _dir: PathBuf, ca_cert_pem: String, - server_cert_pem: String, - server_key_pem: String, client_cert_path: PathBuf, client_key_path: PathBuf, + server_cert_pem: String, + server_key_pem: String, } impl TlsFixture { @@ -82,8 +82,8 @@ impl TlsFixture { struct TestCertifiedKey { cert: rcgen::Certificate, - signing_key: KeyPair, params: CertificateParams, + signing_key: KeyPair, } impl TestCertifiedKey { @@ -185,7 +185,7 @@ fn wait_for_https_text_response( expected_body: &str, timeout: Duration, ) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -351,7 +351,7 @@ fn temp_dir(prefix: &str) -> PathBuf { } fn wait_for_admin_socket(path: &Path, timeout: Duration) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -389,8 +389,3 @@ fn query_admin_socket(path: &Path, request: AdminRequest) -> Result Vec { + self.supported_schemes.clone() + } fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, @@ -44,8 +47,4 @@ impl ServerCertVerifier for InsecureServerCertVerifier { ) -> Result { Ok(HandshakeSignatureValid::assertion()) } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } } diff --git a/crates/rginx-app/tests/failover.rs b/crates/rginx-app/tests/failover.rs index dc1b26a7..e1e948ff 100644 --- a/crates/rginx-app/tests/failover.rs +++ b/crates/rginx-app/tests/failover.rs @@ -1,3 +1,14 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::sync::Arc; @@ -5,8 +16,6 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] diff --git a/crates/rginx-app/tests/grpc_http3.rs b/crates/rginx-app/tests/grpc_http3.rs index 94774f14..d02fbf65 100644 --- a/crates/rginx-app/tests/grpc_http3.rs +++ b/crates/rginx-app/tests/grpc_http3.rs @@ -1,3 +1,25 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + +#[path = "grpc_http3/config.rs"] +mod config; +#[path = "grpc_http3/health.rs"] +mod health; +#[path = "grpc_http3/helpers.rs"] +mod helpers; +#[path = "grpc_http3/proxy.rs"] +mod proxy; +#[path = "grpc_http3/timeout.rs"] +mod timeout; + use std::env; use std::fs; use std::net::SocketAddr; @@ -20,32 +42,19 @@ use rustls::{ClientConfig, RootCertStore}; use tokio::sync::oneshot; use tokio::task::JoinHandle; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; +use config::*; +use helpers::*; + const GRPC_METHOD_PATH: &str = "/grpc.health.v1.Health/Check"; const GRPC_REQUEST_FRAME: &[u8] = b"\x00\x00\x00\x00\x02hi"; const GRPC_RESPONSE_FRAME: &[u8] = b"\x00\x00\x00\x00\x02ok"; -#[path = "grpc_http3/config.rs"] -mod config; -#[path = "grpc_http3/health.rs"] -mod health; -#[path = "grpc_http3/helpers.rs"] -mod helpers; -#[path = "grpc_http3/proxy.rs"] -mod proxy; -#[path = "grpc_http3/timeout.rs"] -mod timeout; - -use config::*; -use helpers::*; - #[derive(Debug, Clone, PartialEq, Eq)] struct ObservedGrpcRequest { - path: String, content_type: Option, + path: String, } #[derive(Debug)] @@ -73,16 +82,16 @@ impl Drop for TempDirGuard { #[derive(Debug)] struct H3Response { - status: StatusCode, - headers: std::collections::HashMap, body: Bytes, + headers: std::collections::HashMap, + status: StatusCode, trailers: Option, } #[derive(Clone, Copy)] enum UpstreamMode { - Immediate, DelayHeaders(Duration), + Immediate, } async fn h3_request( diff --git a/crates/rginx-app/tests/grpc_http3/helpers.rs b/crates/rginx-app/tests/grpc_http3/helpers.rs index c09af764..bb9cd9ca 100644 --- a/crates/rginx-app/tests/grpc_http3/helpers.rs +++ b/crates/rginx-app/tests/grpc_http3/helpers.rs @@ -46,7 +46,7 @@ pub(super) fn decode_grpc_web_response(body: &[u8]) -> (Vec, HeaderMap) { while cursor.len() >= 5 { let flags = cursor[0]; let len = u32::from_be_bytes([cursor[1], cursor[2], cursor[3], cursor[4]]) as usize; - let frame_len = 5 + len; + let frame_len = len.saturating_add(5); let frame = &cursor[..frame_len]; let payload = &frame[5..]; if flags & 0x80 == 0 { diff --git a/crates/rginx-app/tests/grpc_proxy.rs b/crates/rginx-app/tests/grpc_proxy.rs index 362cfd26..9cabd174 100644 --- a/crates/rginx-app/tests/grpc_proxy.rs +++ b/crates/rginx-app/tests/grpc_proxy.rs @@ -1,71 +1,78 @@ -#[allow(unused_imports)] +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + +#[path = "grpc_proxy/basic.rs"] +mod basic; +#[path = "grpc_proxy/lifecycle.rs"] +mod lifecycle; +#[path = "grpc_proxy/timeout.rs"] +mod timeout; + +#[path = "grpc_proxy/helpers/body.rs"] +mod helpers_body; +#[path = "grpc_proxy/helpers/config.rs"] +mod helpers_config; +#[path = "grpc_proxy/helpers/grpc_web.rs"] +mod helpers_grpc_web; +#[path = "grpc_proxy/helpers/server.rs"] +mod helpers_server; +#[path = "grpc_proxy/helpers/tls.rs"] +mod helpers_tls; +#[path = "grpc_proxy/helpers/upstream.rs"] +mod helpers_upstream; + use std::convert::Infallible; -#[allow(unused_imports)] use std::fs; -#[allow(unused_imports)] use std::future::Future; -#[allow(unused_imports)] use std::net::SocketAddr; -#[allow(unused_imports)] use std::path::{Path, PathBuf}; -#[allow(unused_imports)] use std::pin::Pin; -#[allow(unused_imports)] use std::sync::atomic::{AtomicU8, AtomicU64, Ordering}; -#[allow(unused_imports)] use std::sync::{Arc, Mutex}; -#[allow(unused_imports)] use std::task::{Context, Poll}; -#[allow(unused_imports)] use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -#[allow(unused_imports)] use base64::Engine as _; -#[allow(unused_imports)] use base64::engine::general_purpose::STANDARD; -#[allow(unused_imports)] use bytes::{Bytes, BytesMut}; -#[allow(unused_imports)] use http_body_util::{BodyExt, Empty, Full}; -#[allow(unused_imports)] use hyper::body::{Body, Frame, Incoming, SizeHint}; -#[allow(unused_imports)] use hyper::http::HeaderMap; -#[allow(unused_imports)] use hyper::http::header::{CONTENT_TYPE, HeaderName, HeaderValue, TE}; -#[allow(unused_imports)] use hyper::server::conn::http2; -#[allow(unused_imports)] use hyper::service::service_fn; -#[allow(unused_imports)] use hyper::{Request, Response, StatusCode, Version}; -#[allow(unused_imports)] use hyper_rustls::HttpsConnectorBuilder; -#[allow(unused_imports)] use hyper_util::client::legacy::Client; -#[allow(unused_imports)] +use hyper_util::client::legacy::connect::HttpConnector; use hyper_util::rt::{TokioExecutor, TokioIo}; -#[allow(unused_imports)] use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; -#[allow(unused_imports)] use rustls::pki_types::pem::PemObject; -#[allow(unused_imports)] use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime}; -#[allow(unused_imports)] use rustls::{ClientConfig, DigitallySignedStruct, SignatureScheme}; -#[allow(unused_imports)] use tokio::sync::oneshot; -#[allow(unused_imports)] use tokio::task::JoinHandle; -#[allow(unused_imports)] use tokio_rustls::TlsAcceptor; -mod support; - pub(crate) use support::{ READY_ROUTE_CONFIG, ServerHarness, apply_tls_placeholders, reserve_loopback_addr, }; +pub(crate) use helpers_body::*; +pub(crate) use helpers_config::*; +pub(crate) use helpers_grpc_web::*; +pub(crate) use helpers_server::*; +pub(crate) use helpers_tls::*; +pub(crate) use helpers_upstream::*; + const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4E\nFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2zn\nyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngq\np7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gp\nqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+\nyfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6\nJrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5\nx23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59\nCiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; const TEST_SERVER_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/GfUhio45GzZi\nz/o67ngbdqW77Bk54aQ/L9kthS47SkqHVaC8+/AYXfDddl0Q+EHQleGrCqjPrRLA\nZAm0Qz0YZEdWcFX4F4T/gvVc7gfobPmkwoVoJwk/A7giuymN9+kpDbXsr2X29uMr\nfW0krA4ZpD/+JzLMXXfG5S9mazrs81wGqFlByGV32z3Qiecvd8JL4nnyz2mMTSqV\nOvtuh285dpgHFugOUGHhhXs9b/CPAGiBTBxuIm+2tGdjtmVGL8L9tlRMYnmQ7u9K\nUCAyHYovUPJjZazTqfkpQ2CkC4iOZ7Ay25gn0fbGBEysyOBZIPw39FEW63jGembu\nOVF6nO/DAgMBAAECggEAKLC7v80TVHiFX4veQZ8WRu7AAmAWzPrNMMEc8rLZcblz\nXhau956DdITILTevQFZEGUhYuUU3RaUaCYojgNUSVLfBctfPjlhfstItMYDjgSt3\nCox6wH8TWm4NzqNgiUCgzmODeaatROUz4MY/r5/NDsuo7pJlIBvEzb5uFdY+QUZ/\nR5gHRiD2Q3wCODe8zQRfTZGo7jCimAuWTLurWZl6ax/4TjWbXCD6DTuUo81cW3vy\nne6tEetHcABRO7uDoBYXk12pCgqFZzjLMnKJjQM+OYnSj6DoWjOu1drT5YyRLGDj\nfzN8V0aKRkOYoZ5QZOua8pByOyQElJnM16vkPtHgPQKBgQD6SOUNWEghvYIGM/lx\nc22/zjvDjeaGC3qSmlpQYN5MGuDoszeDBZ+rMTmHqJ9FcHYkLQnUI7ZkHhRGt/wQ\n/w3CroJjPBgKk+ipy2cBHSI+z+U20xjYzE8hxArWbXG1G4rDt5AIz68IQPsfkVND\nktkDABDaU+KwBPx8fjeeqtRQxQKBgQDDdxdLB1XcfZMX0KEP5RfA8ar1nW41TUAl\nTCOLaXIQbHZ0BeW7USE9mK8OKnVALZGJ+rpxvYFPZ5MWxchpb/cuIwXjLoN6uZVb\nfx4Hho+2iCfhcEKzs8XZW48duKIfhx13BiILLf/YaHAWFs9UfVcQog4Qx03guyMr\n7k9bFuy25wKBgQDpE48zAT6TJS775dTrAQp4b28aan/93pyz/8gRSFRb3UALlDIi\n8s7BluKzYaWI/fUXNVYM14EX9Sb+wIGdtlezL94+2Yyt9RXbYY8361Cj2+jiSG3A\nH2ulzzIkg+E7Pj3Yi443lmiysAjsWeKHcC5l697F4w6cytfye3wCZ6W23QKBgQC0\n9tX+5aytdSkwnDvxXlVOka+ItBcri/i+Ty59TMOIxxInuqoFcUhIIcq4X8CsCUQ8\nLYBd+2fznt3D8JrqWvnKoiw6N38MqTLJQfgIWaFGCep6QhfPDbo30RfAGYcnj01N\nO8Va+lxq+84B9V5AR8bKpG5HRG4qiLc4XerkV2YSswKBgDt9eerSBZyLVwfku25Y\nfrh+nEjUZy81LdlpJmu/bfa2FfItzBqDZPskkJJW9ON82z/ejGFbsU48RF7PJUMr\nGimE33QeTDToGozHCq0QOd0SMfsVkOQR+EROdmY52UIYAYgQUfI1FQ9lLsw10wlQ\nD11SHTL7b9pefBWfW73I7ttV\n-----END PRIVATE KEY-----\n"; const GRPC_METHOD_PATH: &str = "/grpc.health.v1.Health/Check"; @@ -75,40 +82,13 @@ const GRPC_RESPONSE_FRAME: &[u8] = b"\x00\x00\x00\x00\x02ok"; #[derive(Debug)] struct ObservedRequest { - method: String, - version: Version, - path: String, alpn_protocol: Option, + body: Bytes, content_type: Option, grpc_timeout: Option, + method: String, + path: String, te: Option, - body: Bytes, trailers: Option, + version: Version, } - -#[path = "grpc_proxy/basic.rs"] -mod basic; -#[path = "grpc_proxy/lifecycle.rs"] -mod lifecycle; -#[path = "grpc_proxy/timeout.rs"] -mod timeout; - -#[path = "grpc_proxy/helpers/body.rs"] -mod helpers_body; -#[path = "grpc_proxy/helpers/config.rs"] -mod helpers_config; -#[path = "grpc_proxy/helpers/grpc_web.rs"] -mod helpers_grpc_web; -#[path = "grpc_proxy/helpers/server.rs"] -mod helpers_server; -#[path = "grpc_proxy/helpers/tls.rs"] -mod helpers_tls; -#[path = "grpc_proxy/helpers/upstream.rs"] -mod helpers_upstream; - -pub(crate) use helpers_body::*; -pub(crate) use helpers_config::*; -pub(crate) use helpers_grpc_web::*; -pub(crate) use helpers_server::*; -pub(crate) use helpers_tls::*; -pub(crate) use helpers_upstream::*; diff --git a/crates/rginx-app/tests/grpc_proxy/basic.rs b/crates/rginx-app/tests/grpc_proxy/basic.rs index 7c0a54d3..3ce4da54 100644 --- a/crates/rginx-app/tests/grpc_proxy/basic.rs +++ b/crates/rginx-app/tests/grpc_proxy/basic.rs @@ -1,8 +1,7 @@ -use super::*; - #[path = "basic/grpc.rs"] mod grpc; #[path = "basic/grpc_web.rs"] mod grpc_web; #[path = "basic/routing.rs"] mod routing; +use super::*; diff --git a/crates/rginx-app/tests/grpc_proxy/basic/grpc_web.rs b/crates/rginx-app/tests/grpc_proxy/basic/grpc_web.rs index d7768600..05194fcd 100644 --- a/crates/rginx-app/tests/grpc_proxy/basic/grpc_web.rs +++ b/crates/rginx-app/tests/grpc_proxy/basic/grpc_web.rs @@ -12,7 +12,7 @@ async fn proxies_basic_grpc_web_binary_requests_to_http2_grpc_upstreams() { ); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request = Request::builder() @@ -70,7 +70,7 @@ async fn proxies_basic_grpc_web_text_requests_to_http2_grpc_upstreams() { let mut server = TestServer::spawn(listen_addr, plain_proxy_config(listen_addr, upstream_addr)); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let encoded_request = format!("{}\r\n", encode_grpc_web_text_payload(GRPC_REQUEST_FRAME)); @@ -134,7 +134,7 @@ async fn proxies_grpc_web_binary_trailer_frames_to_http2_request_trailers() { let mut server = TestServer::spawn(listen_addr, plain_proxy_config(listen_addr, upstream_addr)); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request_body = grpc_web_request_with_trailers(); @@ -196,7 +196,7 @@ async fn proxies_grpc_web_text_trailer_frames_to_http2_request_trailers() { let mut server = TestServer::spawn(listen_addr, plain_proxy_config(listen_addr, upstream_addr)); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request_body = diff --git a/crates/rginx-app/tests/grpc_proxy/basic/routing.rs b/crates/rginx-app/tests/grpc_proxy/basic/routing.rs index 38000fb6..d7a04ce1 100644 --- a/crates/rginx-app/tests/grpc_proxy/basic/routing.rs +++ b/crates/rginx-app/tests/grpc_proxy/basic/routing.rs @@ -12,7 +12,7 @@ async fn routes_grpc_requests_by_service_and_method() { ); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request = Request::builder() diff --git a/crates/rginx-app/tests/grpc_proxy/helpers/body.rs b/crates/rginx-app/tests/grpc_proxy/helpers/body.rs index 2ef6ce78..84e95b9e 100644 --- a/crates/rginx-app/tests/grpc_proxy/helpers/body.rs +++ b/crates/rginx-app/tests/grpc_proxy/helpers/body.rs @@ -14,6 +14,10 @@ impl Body for GrpcResponseBody { type Data = Bytes; type Error = Infallible; + fn is_end_stream(&self) -> bool { + self.state >= 2 + } + fn poll_frame( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -36,10 +40,6 @@ impl Body for GrpcResponseBody { } } - fn is_end_stream(&self) -> bool { - self.state >= 2 - } - fn size_hint(&self) -> SizeHint { let mut hint = SizeHint::new(); hint.set_exact(GRPC_RESPONSE_FRAME.len() as u64); @@ -49,14 +49,14 @@ impl Body for GrpcResponseBody { #[derive(Clone, Copy)] pub(crate) enum UpstreamResponseMode { - Immediate, - DelayHeaders(Duration), DelayBody(Duration), + DelayHeaders(Duration), + Immediate, } pub(crate) struct DelayedGrpcResponseBody { - state: u8, delay: Pin>, + state: u8, } impl DelayedGrpcResponseBody { @@ -66,16 +66,20 @@ impl DelayedGrpcResponseBody { } pub(crate) enum EitherGrpcResponseBody { - Immediate(GrpcResponseBody), - Delayed(DelayedGrpcResponseBody), Cancellable(CancellableGrpcResponseBody), + Delayed(DelayedGrpcResponseBody), Full(Full), + Immediate(GrpcResponseBody), } impl Body for DelayedGrpcResponseBody { type Data = Bytes; type Error = Infallible; + fn is_end_stream(&self) -> bool { + self.state >= 2 + } + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -101,10 +105,6 @@ impl Body for DelayedGrpcResponseBody { } } - fn is_end_stream(&self) -> bool { - self.state >= 2 - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } @@ -114,6 +114,15 @@ impl Body for EitherGrpcResponseBody { type Data = Bytes; type Error = Infallible; + fn is_end_stream(&self) -> bool { + match self { + Self::Immediate(body) => body.is_end_stream(), + Self::Delayed(body) => body.is_end_stream(), + Self::Cancellable(body) => body.is_end_stream(), + Self::Full(body) => body.is_end_stream(), + } + } + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -126,15 +135,6 @@ impl Body for EitherGrpcResponseBody { } } - fn is_end_stream(&self) -> bool { - match self { - Self::Immediate(body) => body.is_end_stream(), - Self::Delayed(body) => body.is_end_stream(), - Self::Cancellable(body) => body.is_end_stream(), - Self::Full(body) => body.is_end_stream(), - } - } - fn size_hint(&self) -> SizeHint { match self { Self::Immediate(body) => body.size_hint(), @@ -146,9 +146,9 @@ impl Body for EitherGrpcResponseBody { } pub(crate) struct CancellableGrpcResponseBody { - state: u8, - delay: Pin>, cancelled_tx: Option>>>>, + delay: Pin>, + state: u8, } impl CancellableGrpcResponseBody { @@ -165,7 +165,7 @@ impl Drop for CancellableGrpcResponseBody { fn drop(&mut self) { if let Some(cancelled_tx) = self.cancelled_tx.take() && let Some(sender) = - cancelled_tx.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).take() + cancelled_tx.lock().unwrap_or_else(std::sync::PoisonError::into_inner).take() { let _ = sender.send(()); } @@ -176,6 +176,10 @@ impl Body for CancellableGrpcResponseBody { type Data = Bytes; type Error = Infallible; + fn is_end_stream(&self) -> bool { + self.state >= 2 + } + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -201,10 +205,6 @@ impl Body for CancellableGrpcResponseBody { } } - fn is_end_stream(&self) -> bool { - self.state >= 2 - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } @@ -224,6 +224,10 @@ impl Body for GrpcRequestBody { type Data = Bytes; type Error = Infallible; + fn is_end_stream(&self) -> bool { + self.state >= 2 + } + fn poll_frame( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -246,10 +250,6 @@ impl Body for GrpcRequestBody { } } - fn is_end_stream(&self) -> bool { - self.state >= 2 - } - fn size_hint(&self) -> SizeHint { let mut hint = SizeHint::new(); hint.set_exact(GRPC_REQUEST_FRAME.len() as u64); diff --git a/crates/rginx-app/tests/grpc_proxy/helpers/grpc_web.rs b/crates/rginx-app/tests/grpc_proxy/helpers/grpc_web.rs index e5562af6..8056c9a5 100644 --- a/crates/rginx-app/tests/grpc_proxy/helpers/grpc_web.rs +++ b/crates/rginx-app/tests/grpc_proxy/helpers/grpc_web.rs @@ -11,19 +11,21 @@ pub(crate) fn decode_grpc_web_response(bytes: &[u8]) -> (Vec, HeaderMap) "grpc-web frame should include a 5-byte header" ); let flags = bytes[offset]; + let len_start = offset.saturating_add(1); let len = u32::from_be_bytes([ - bytes[offset + 1], - bytes[offset + 2], - bytes[offset + 3], - bytes[offset + 4], + bytes[len_start], + bytes[offset.saturating_add(2)], + bytes[offset.saturating_add(3)], + bytes[offset.saturating_add(4)], ]) as usize; - offset += 5; + offset = offset.saturating_add(5); assert!( bytes.len().saturating_sub(offset) >= len, "grpc-web frame payload should be fully present" ); - let payload = &bytes[offset..offset + len]; - offset += len; + let payload_end = offset.saturating_add(len); + let payload = &bytes[offset..payload_end]; + offset = payload_end; if (flags & 0x80) != 0 { for line in payload.split(|byte| *byte == b'\n') { @@ -67,7 +69,7 @@ pub(crate) fn grpc_web_request_with_trailers() -> Bytes { pub(crate) fn grpc_web_trailer_frame() -> Bytes { let block = b"x-client-trailer: sent\r\nx-request-checksum: abc123\r\n"; - let mut frame = Vec::with_capacity(5 + block.len()); + let mut frame = Vec::with_capacity(block.len().saturating_add(5)); frame.push(0x80); frame.extend_from_slice(&(block.len() as u32).to_be_bytes()); frame.extend_from_slice(block); diff --git a/crates/rginx-app/tests/grpc_proxy/helpers/server.rs b/crates/rginx-app/tests/grpc_proxy/helpers/server.rs index 00881ba9..21029a50 100644 --- a/crates/rginx-app/tests/grpc_proxy/helpers/server.rs +++ b/crates/rginx-app/tests/grpc_proxy/helpers/server.rs @@ -5,6 +5,9 @@ pub(crate) struct TestServer { } impl TestServer { + pub(crate) fn shutdown_and_wait(&mut self, timeout: Duration) { + self.inner.shutdown_and_wait(timeout); + } pub(crate) fn spawn(listen_addr: SocketAddr, config: String) -> Self { let _ = listen_addr; Self { @@ -24,13 +27,8 @@ impl TestServer { pub(crate) fn wait_for_https_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { self.inner.wait_for_https_ready(listen_addr, timeout); } - - pub(crate) fn shutdown_and_wait(&mut self, timeout: Duration) { - self.inner.shutdown_and_wait(timeout); - } } -pub(crate) fn https_h2_connector() --> hyper_rustls::HttpsConnector { +pub(crate) fn https_h2_connector() -> hyper_rustls::HttpsConnector { HttpsConnectorBuilder::new() .with_tls_config( ClientConfig::builder() @@ -44,7 +42,7 @@ pub(crate) fn https_h2_connector() } pub(crate) async fn wait_for_log_contains(server: &TestServer, timeout: Duration, needle: &str) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_logs = String::new(); while Instant::now() < deadline { diff --git a/crates/rginx-app/tests/grpc_proxy/helpers/tls.rs b/crates/rginx-app/tests/grpc_proxy/helpers/tls.rs index d182eafe..f9415fdf 100644 --- a/crates/rginx-app/tests/grpc_proxy/helpers/tls.rs +++ b/crates/rginx-app/tests/grpc_proxy/helpers/tls.rs @@ -1,26 +1,5 @@ use super::*; -pub(crate) fn load_certs(path: &Path) -> Vec> { - CertificateDer::pem_file_iter(path) - .expect("certificate file should open") - .collect::, _>>() - .expect("certificate PEM should parse") -} - -pub(crate) fn load_private_key(path: &Path) -> PrivateKeyDer<'static> { - PrivateKeyDer::from_pem_file(path).expect("private key PEM should parse") -} - -pub(crate) fn temp_dir(prefix: &str) -> PathBuf { - static NEXT_ID: AtomicU64 = AtomicU64::new(1); - let unique = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("system time should be after unix epoch") - .as_nanos(); - let id = NEXT_ID.fetch_add(1, Ordering::Relaxed); - std::env::temp_dir().join(format!("{prefix}-{unique}-{id}")) -} - #[derive(Debug)] pub(crate) struct InsecureServerCertVerifier { supported_schemes: Vec, @@ -37,6 +16,9 @@ impl InsecureServerCertVerifier { } impl ServerCertVerifier for InsecureServerCertVerifier { + fn supported_verify_schemes(&self) -> Vec { + self.supported_schemes.clone() + } fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, @@ -65,8 +47,25 @@ impl ServerCertVerifier for InsecureServerCertVerifier { ) -> Result { Ok(HandshakeSignatureValid::assertion()) } +} - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } +pub(crate) fn load_certs(path: &Path) -> Vec> { + CertificateDer::pem_file_iter(path) + .expect("certificate file should open") + .collect::, _>>() + .expect("certificate PEM should parse") +} + +pub(crate) fn load_private_key(path: &Path) -> PrivateKeyDer<'static> { + PrivateKeyDer::from_pem_file(path).expect("private key PEM should parse") +} + +pub(crate) fn temp_dir(prefix: &str) -> PathBuf { + static NEXT_ID: AtomicU64 = AtomicU64::new(1); + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after unix epoch") + .as_nanos(); + let id = NEXT_ID.fetch_add(1, Ordering::Relaxed); + std::env::temp_dir().join(format!("{prefix}-{unique}-{id}")) } diff --git a/crates/rginx-app/tests/grpc_proxy/helpers/upstream.rs b/crates/rginx-app/tests/grpc_proxy/helpers/upstream.rs index d5375d35..8f456a9a 100644 --- a/crates/rginx-app/tests/grpc_proxy/helpers/upstream.rs +++ b/crates/rginx-app/tests/grpc_proxy/helpers/upstream.rs @@ -25,7 +25,7 @@ pub(crate) async fn spawn_h2c_grpc_upstream() let (body_bytes, trailers) = read_body_and_trailers(body).await; if let Some(sender) = - observed_tx.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).take() + observed_tx.lock().unwrap_or_else(std::sync::PoisonError::into_inner).take() { let _ = sender.send(ObservedRequest { method: parts.method.as_str().to_string(), @@ -267,7 +267,7 @@ pub(crate) async fn spawn_grpc_upstream_with_mode( let (body_bytes, trailers) = read_body_and_trailers(body).await; if let Some(sender) = - observed_tx.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).take() + observed_tx.lock().unwrap_or_else(std::sync::PoisonError::into_inner).take() { let _ = sender.send(ObservedRequest { method: parts.method.as_str().to_string(), diff --git a/crates/rginx-app/tests/grpc_proxy/lifecycle.rs b/crates/rginx-app/tests/grpc_proxy/lifecycle.rs index 9884f0e1..d44ecd6b 100644 --- a/crates/rginx-app/tests/grpc_proxy/lifecycle.rs +++ b/crates/rginx-app/tests/grpc_proxy/lifecycle.rs @@ -59,7 +59,7 @@ async fn returns_grpc_web_status_for_unavailable_upstreams() { TestServer::spawn(listen_addr, plain_proxy_config(listen_addr, unavailable_addr)); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request = Request::builder() @@ -113,7 +113,7 @@ async fn active_grpc_health_checks_gate_proxy_requests_until_peer_recovers() { ); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); // Wait for the health check to fail and peer to enter cooldown. @@ -253,7 +253,7 @@ async fn grpc_web_cancellation_closes_upstream_stream_and_emits_access_log_statu ); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request = Request::builder() @@ -319,7 +319,7 @@ async fn grpc_web_text_cancellation_closes_upstream_stream_and_emits_access_log_ ); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let encoded_request = format!("{}\r\n", encode_grpc_web_text_payload(GRPC_REQUEST_FRAME)); diff --git a/crates/rginx-app/tests/grpc_proxy/timeout.rs b/crates/rginx-app/tests/grpc_proxy/timeout.rs index 7367faa3..9cca4e55 100644 --- a/crates/rginx-app/tests/grpc_proxy/timeout.rs +++ b/crates/rginx-app/tests/grpc_proxy/timeout.rs @@ -1,8 +1,7 @@ -use super::*; - #[path = "timeout/grpc.rs"] mod grpc; #[path = "timeout/grpc_web.rs"] mod grpc_web; #[path = "timeout/validation.rs"] mod validation; +use super::*; diff --git a/crates/rginx-app/tests/grpc_proxy/timeout/grpc_web.rs b/crates/rginx-app/tests/grpc_proxy/timeout/grpc_web.rs index 16de9c06..9b5fdca3 100644 --- a/crates/rginx-app/tests/grpc_proxy/timeout/grpc_web.rs +++ b/crates/rginx-app/tests/grpc_proxy/timeout/grpc_web.rs @@ -12,7 +12,7 @@ async fn respects_grpc_timeout_across_grpc_web_response_body_streams() { ); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request = Request::builder() @@ -82,7 +82,7 @@ async fn respects_grpc_timeout_across_grpc_web_text_response_body_streams_and_re ); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let encoded_request = format!("{}\r\n", encode_grpc_web_text_payload(GRPC_REQUEST_FRAME)); diff --git a/crates/rginx-app/tests/grpc_proxy/timeout/validation.rs b/crates/rginx-app/tests/grpc_proxy/timeout/validation.rs index 29d0191e..ce3ada8d 100644 --- a/crates/rginx-app/tests/grpc_proxy/timeout/validation.rs +++ b/crates/rginx-app/tests/grpc_proxy/timeout/validation.rs @@ -9,7 +9,7 @@ async fn rejects_invalid_grpc_timeout_for_grpc_web_requests() { let mut server = TestServer::spawn(listen_addr, plain_proxy_config(listen_addr, upstream_addr)); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request = Request::builder() @@ -80,7 +80,7 @@ async fn rejects_invalid_grpc_web_text_request_body_before_contacting_upstream() ); server.wait_for_http_ready(listen_addr, Duration::from_secs(5)); - let connector = hyper_util::client::legacy::connect::HttpConnector::new(); + let connector = HttpConnector::new(); let client: Client<_, Full> = Client::builder(TokioExecutor::new()).build(connector); let request = Request::builder() diff --git a/crates/rginx-app/tests/hardening.rs b/crates/rginx-app/tests/hardening.rs index 30a6c0b9..7f72dbb1 100644 --- a/crates/rginx-app/tests/hardening.rs +++ b/crates/rginx-app/tests/hardening.rs @@ -1,4 +1,20 @@ #![cfg(unix)] +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + +#[path = "hardening/connection.rs"] +mod connection; +#[path = "hardening/limits.rs"] +mod limits; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; @@ -6,19 +22,12 @@ use std::sync::{Mutex, OnceLock}; use std::thread; use std::time::{Duration, Instant}; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; -#[path = "hardening/connection.rs"] -mod connection; -#[path = "hardening/limits.rs"] -mod limits; - #[derive(Debug)] struct ParsedResponse { - status: u16, body: Vec, + status: u16, } struct TestServer { @@ -26,6 +35,9 @@ struct TestServer { } impl TestServer { + fn shutdown_and_wait(&mut self, timeout: Duration) { + self.inner.terminate_and_wait(timeout); + } fn spawn(listen_addr: SocketAddr, config: String) -> Self { let _ = listen_addr; Self { inner: ServerHarness::spawn("rginx-hardening-test", |_| config) } @@ -51,9 +63,6 @@ impl TestServer { fn wait_for_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { self.inner.wait_for_http_ready(listen_addr, timeout); } - fn shutdown_and_wait(&mut self, timeout: Duration) { - self.inner.terminate_and_wait(timeout); - } } fn return_config(listen_addr: SocketAddr, server_extra: Option<&str>, body: &str) -> String { @@ -145,11 +154,13 @@ fn parse_http_response(bytes: &[u8]) -> Result { } } - Ok(ParsedResponse { status, body: bytes[head_end + 4..].to_vec() }) + let body_start = head_end.saturating_add(4); + Ok(ParsedResponse { status, body: bytes[body_start..].to_vec() }) } fn read_http_response_once(stream: &mut TcpStream) -> Result { - let deadline = Instant::now() + Duration::from_secs(1); + let deadline = + Instant::now().checked_add(Duration::from_secs(1)).expect("test deadline remains valid"); let mut response = Vec::new(); while Instant::now() < deadline { diff --git a/crates/rginx-app/tests/http2.rs b/crates/rginx-app/tests/http2.rs index 69ee8e89..d4b4a957 100644 --- a/crates/rginx-app/tests/http2.rs +++ b/crates/rginx-app/tests/http2.rs @@ -1,3 +1,14 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::io::{ErrorKind, Write}; use std::net::{SocketAddr, TcpListener}; use std::sync::Arc; @@ -14,8 +25,6 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::{ClientConfig, DigitallySignedStruct, SignatureScheme}; -mod support; - use support::{ READY_ROUTE_CONFIG, ServerHarness, apply_tls_placeholders, read_http_head, reserve_loopback_addr, @@ -26,6 +35,55 @@ const STREAMING_CHUNK_DEADLINE: Duration = Duration::from_millis(1000); const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4E\nFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2zn\nyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngq\np7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gp\nqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+\nyfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6\nJrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5\nx23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59\nCiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; const TEST_SERVER_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/GfUhio45GzZi\nz/o67ngbdqW77Bk54aQ/L9kthS47SkqHVaC8+/AYXfDddl0Q+EHQleGrCqjPrRLA\nZAm0Qz0YZEdWcFX4F4T/gvVc7gfobPmkwoVoJwk/A7giuymN9+kpDbXsr2X29uMr\nfW0krA4ZpD/+JzLMXXfG5S9mazrs81wGqFlByGV32z3Qiecvd8JL4nnyz2mMTSqV\nOvtuh285dpgHFugOUGHhhXs9b/CPAGiBTBxuIm+2tGdjtmVGL8L9tlRMYnmQ7u9K\nUCAyHYovUPJjZazTqfkpQ2CkC4iOZ7Ay25gn0fbGBEysyOBZIPw39FEW63jGembu\nOVF6nO/DAgMBAAECggEAKLC7v80TVHiFX4veQZ8WRu7AAmAWzPrNMMEc8rLZcblz\nXhau956DdITILTevQFZEGUhYuUU3RaUaCYojgNUSVLfBctfPjlhfstItMYDjgSt3\nCox6wH8TWm4NzqNgiUCgzmODeaatROUz4MY/r5/NDsuo7pJlIBvEzb5uFdY+QUZ/\nR5gHRiD2Q3wCODe8zQRfTZGo7jCimAuWTLurWZl6ax/4TjWbXCD6DTuUo81cW3vy\nne6tEetHcABRO7uDoBYXk12pCgqFZzjLMnKJjQM+OYnSj6DoWjOu1drT5YyRLGDj\nfzN8V0aKRkOYoZ5QZOua8pByOyQElJnM16vkPtHgPQKBgQD6SOUNWEghvYIGM/lx\nc22/zjvDjeaGC3qSmlpQYN5MGuDoszeDBZ+rMTmHqJ9FcHYkLQnUI7ZkHhRGt/wQ\n/w3CroJjPBgKk+ipy2cBHSI+z+U20xjYzE8hxArWbXG1G4rDt5AIz68IQPsfkVND\nktkDABDaU+KwBPx8fjeeqtRQxQKBgQDDdxdLB1XcfZMX0KEP5RfA8ar1nW41TUAl\nTCOLaXIQbHZ0BeW7USE9mK8OKnVALZGJ+rpxvYFPZ5MWxchpb/cuIwXjLoN6uZVb\nfx4Hho+2iCfhcEKzs8XZW48duKIfhx13BiILLf/YaHAWFs9UfVcQog4Qx03guyMr\n7k9bFuy25wKBgQDpE48zAT6TJS775dTrAQp4b28aan/93pyz/8gRSFRb3UALlDIi\n8s7BluKzYaWI/fUXNVYM14EX9Sb+wIGdtlezL94+2Yyt9RXbYY8361Cj2+jiSG3A\nH2ulzzIkg+E7Pj3Yi443lmiysAjsWeKHcC5l697F4w6cytfye3wCZ6W23QKBgQC0\n9tX+5aytdSkwnDvxXlVOka+ItBcri/i+Ty59TMOIxxInuqoFcUhIIcq4X8CsCUQ8\nLYBd+2fznt3D8JrqWvnKoiw6N38MqTLJQfgIWaFGCep6QhfPDbo30RfAGYcnj01N\nO8Va+lxq+84B9V5AR8bKpG5HRG4qiLc4XerkV2YSswKBgDt9eerSBZyLVwfku25Y\nfrh+nEjUZy81LdlpJmu/bfa2FfItzBqDZPskkJJW9ON82z/ejGFbsU48RF7PJUMr\nGimE33QeTDToGozHCq0QOd0SMfsVkOQR+EROdmY52UIYAYgQUfI1FQ9lLsw10wlQ\nD11SHTL7b9pefBWfW73I7ttV\n-----END PRIVATE KEY-----\n"; +#[derive(Debug)] +struct InsecureServerCertVerifier { + supported_schemes: Vec, +} + +impl InsecureServerCertVerifier { + fn new() -> Self { + Self { + supported_schemes: rustls::crypto::aws_lc_rs::default_provider() + .signature_verification_algorithms + .supported_schemes(), + } + } +} + +impl ServerCertVerifier for InsecureServerCertVerifier { + fn supported_verify_schemes(&self) -> Vec { + self.supported_schemes.clone() + } + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } +} + #[tokio::test(flavor = "multi_thread")] async fn serves_http2_over_tls_and_proxies_to_http11_upstreams() { let upstream_listener = @@ -281,53 +339,3 @@ fn write_chunk(stream: &mut std::net::TcpStream, chunk: &[u8]) { stream.write_all(b"\r\n").expect("chunk terminator should write"); stream.flush().expect("chunk should flush"); } - -#[derive(Debug)] -struct InsecureServerCertVerifier { - supported_schemes: Vec, -} - -impl InsecureServerCertVerifier { - fn new() -> Self { - Self { - supported_schemes: rustls::crypto::aws_lc_rs::default_provider() - .signature_verification_algorithms - .supported_schemes(), - } - } -} - -impl ServerCertVerifier for InsecureServerCertVerifier { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp_response: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } -} diff --git a/crates/rginx-app/tests/http3.rs b/crates/rginx-app/tests/http3.rs index f1b8855a..4c41e5d0 100644 --- a/crates/rginx-app/tests/http3.rs +++ b/crates/rginx-app/tests/http3.rs @@ -1,3 +1,31 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + +#[path = "http3/basic.rs"] +mod basic; +#[path = "http3/early_data.rs"] +mod early_data; +#[path = "http3/helpers/client.rs"] +mod helpers_client; +#[path = "http3/helpers/config.rs"] +mod helpers_config; +#[path = "http3/helpers/fixtures.rs"] +mod helpers_fixtures; +#[path = "http3/mtls.rs"] +mod mtls; +#[path = "http3/policy.rs"] +mod policy; +#[path = "http3/retry.rs"] +mod retry; + use std::fs; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener}; @@ -14,8 +42,10 @@ use flate2::read::GzDecoder; use h3::client; use http_body_util::Empty; use hyper::http::{Request, StatusCode}; +use hyper_rustls::HttpsConnector; use hyper_rustls::HttpsConnectorBuilder; use hyper_util::client::legacy::Client; +use hyper_util::client::legacy::connect::HttpConnector; use hyper_util::rt::TokioExecutor; use quinn::crypto::rustls::QuicClientConfig; use rcgen::{ @@ -24,27 +54,8 @@ use rcgen::{ use rustls::pki_types::{CertificateDer, pem::PemObject}; use rustls::{ClientConfig, RootCertStore}; -mod support; - pub(crate) use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; -#[path = "http3/basic.rs"] -mod basic; -#[path = "http3/early_data.rs"] -mod early_data; -#[path = "http3/helpers/client.rs"] -mod helpers_client; -#[path = "http3/helpers/config.rs"] -mod helpers_config; -#[path = "http3/helpers/fixtures.rs"] -mod helpers_fixtures; -#[path = "http3/mtls.rs"] -mod mtls; -#[path = "http3/policy.rs"] -mod policy; -#[path = "http3/retry.rs"] -mod retry; - pub(crate) use helpers_client::*; pub(crate) use helpers_config::*; pub(crate) use helpers_fixtures::*; diff --git a/crates/rginx-app/tests/http3/helpers/client.rs b/crates/rginx-app/tests/http3/helpers/client.rs index b7ad9f61..5c0888c5 100644 --- a/crates/rginx-app/tests/http3/helpers/client.rs +++ b/crates/rginx-app/tests/http3/helpers/client.rs @@ -1,9 +1,14 @@ +#[path = "client/tls.rs"] +mod tls; + use super::*; +pub(crate) use tls::*; + pub(crate) struct Http3Response { - pub(crate) status: StatusCode, - pub(crate) headers: std::collections::HashMap, pub(crate) body: Vec, + pub(crate) headers: std::collections::HashMap, + pub(crate) status: StatusCode, } pub(crate) async fn http3_get( @@ -189,7 +194,7 @@ pub(crate) fn http3_client_endpoint( Ok(endpoint) } -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments, reason = "test helper mirrors HTTP request fields explicitly")] pub(crate) async fn http3_request_with_endpoint( endpoint: &quinn::Endpoint, listen_addr: SocketAddr, @@ -296,7 +301,8 @@ pub(crate) async fn wait_for_http3_0rtt_request( path: &str, timeout: Duration, ) -> Result<(Http3Response, bool), String> { - let deadline = std::time::Instant::now() + timeout; + let deadline = + std::time::Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while std::time::Instant::now() < deadline { @@ -340,7 +346,8 @@ pub(crate) async fn wait_for_http3_0rtt_request_status( expected_status: StatusCode, timeout: Duration, ) -> Result<(Http3Response, bool), String> { - let deadline = std::time::Instant::now() + timeout; + let deadline = + std::time::Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while std::time::Instant::now() < deadline { @@ -376,8 +383,3 @@ pub(crate) async fn wait_for_http3_0rtt_request_status( listen_addr.port() )) } - -#[path = "client/tls.rs"] -mod tls; - -pub(crate) use tls::*; diff --git a/crates/rginx-app/tests/http3/helpers/client/tls.rs b/crates/rginx-app/tests/http3/helpers/client/tls.rs index acc0c01b..10d83caa 100644 --- a/crates/rginx-app/tests/http3/helpers/client/tls.rs +++ b/crates/rginx-app/tests/http3/helpers/client/tls.rs @@ -1,11 +1,6 @@ use super::*; -pub(crate) fn https_client( - cert_pem: &str, -) -> Client< - hyper_rustls::HttpsConnector, - Empty, -> { +pub(crate) fn https_client(cert_pem: &str) -> Client, Empty> { let roots = root_store_from_pem(cert_pem).expect("root store should build"); let client_config = ClientConfig::builder_with_provider(Arc::new( rustls::crypto::aws_lc_rs::default_provider(), diff --git a/crates/rginx-app/tests/http3/helpers/fixtures.rs b/crates/rginx-app/tests/http3/helpers/fixtures.rs index 2b83cd71..23ce83b2 100644 --- a/crates/rginx-app/tests/http3/helpers/fixtures.rs +++ b/crates/rginx-app/tests/http3/helpers/fixtures.rs @@ -1,17 +1,12 @@ use super::*; -pub(crate) fn generate_cert(hostname: &str) -> rcgen::CertifiedKey { - rcgen::generate_simple_self_signed(vec![hostname.to_string()]) - .expect("self-signed certificate should generate") -} - pub(crate) struct Http3MtlsFixture { pub(crate) _dir: PathBuf, pub(crate) ca_cert_pem: String, - pub(crate) server_cert_pem: String, - pub(crate) server_key_pem: String, pub(crate) client_cert_path: PathBuf, pub(crate) client_key_path: PathBuf, + pub(crate) server_cert_pem: String, + pub(crate) server_key_pem: String, } impl Http3MtlsFixture { @@ -45,20 +40,10 @@ impl Http3MtlsFixture { } } -pub(crate) fn temp_dir(prefix: &str) -> PathBuf { - static NEXT_ID: AtomicU64 = AtomicU64::new(1); - let unique = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("system time should be after unix epoch") - .as_nanos(); - let id = NEXT_ID.fetch_add(1, Ordering::Relaxed); - PathBuf::from(format!("{}/{}-{}-{}", std::env::temp_dir().display(), prefix, unique, id)) -} - pub(crate) struct TestCertifiedKey { pub(crate) cert: rcgen::Certificate, - pub(crate) signing_key: KeyPair, pub(crate) params: CertificateParams, + pub(crate) signing_key: KeyPair, } impl TestCertifiedKey { @@ -67,6 +52,27 @@ impl TestCertifiedKey { } } +impl Http3Response { + pub(crate) fn header(&self, name: &str) -> Option<&str> { + self.headers.get(&name.to_ascii_lowercase()).map(String::as_str) + } +} + +pub(crate) fn generate_cert(hostname: &str) -> rcgen::CertifiedKey { + rcgen::generate_simple_self_signed(vec![hostname.to_string()]) + .expect("self-signed certificate should generate") +} + +pub(crate) fn temp_dir(prefix: &str) -> PathBuf { + static NEXT_ID: AtomicU64 = AtomicU64::new(1); + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after unix epoch") + .as_nanos(); + let id = NEXT_ID.fetch_add(1, Ordering::Relaxed); + PathBuf::from(format!("{}/{}-{}-{}", std::env::temp_dir().display(), prefix, unique, id)) +} + pub(crate) fn generate_ca_cert(common_name: &str) -> TestCertifiedKey { let mut params = CertificateParams::default(); params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); @@ -102,14 +108,9 @@ pub(crate) fn body_text(response: &Http3Response) -> String { String::from_utf8(response.body.clone()).expect("response body should be valid UTF-8") } -impl Http3Response { - pub(crate) fn header(&self, name: &str) -> Option<&str> { - self.headers.get(&name.to_ascii_lowercase()).map(String::as_str) - } -} - pub(crate) fn wait_for_admin_socket(path: &Path, timeout: Duration) { - let deadline = std::time::Instant::now() + timeout; + let deadline = + std::time::Instant::now().checked_add(timeout).expect("test deadline remains valid"); while std::time::Instant::now() < deadline { if path.exists() && UnixStream::connect(path).is_ok() { return; @@ -130,7 +131,8 @@ pub(crate) async fn wait_for_http3_text_response( cert_pem: &str, timeout: Duration, ) { - let deadline = std::time::Instant::now() + timeout; + let deadline = + std::time::Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while std::time::Instant::now() < deadline { @@ -206,7 +208,7 @@ pub(crate) fn read_http_head_from_stream(stream: &mut std::net::TcpStream) -> St buffer.extend_from_slice(&chunk[..read]); if let Some(head_end) = buffer.windows(4).position(|window| window == b"\r\n\r\n") { - return String::from_utf8(buffer[..head_end + 4].to_vec()) + return String::from_utf8(buffer[..head_end.saturating_add(4)].to_vec()) .expect("HTTP head should be valid UTF-8"); } } diff --git a/crates/rginx-app/tests/ip_hash.rs b/crates/rginx-app/tests/ip_hash.rs index 9df191c9..727322e0 100644 --- a/crates/rginx-app/tests/ip_hash.rs +++ b/crates/rginx-app/tests/ip_hash.rs @@ -1,11 +1,20 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::collections::HashSet; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] diff --git a/crates/rginx-app/tests/least_conn.rs b/crates/rginx-app/tests/least_conn.rs index 9bc5389e..d9bfc911 100644 --- a/crates/rginx-app/tests/least_conn.rs +++ b/crates/rginx-app/tests/least_conn.rs @@ -1,11 +1,20 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::sync::mpsc; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] diff --git a/crates/rginx-app/tests/multi_listener.rs b/crates/rginx-app/tests/multi_listener.rs index a688b928..4b91eb37 100644 --- a/crates/rginx-app/tests/multi_listener.rs +++ b/crates/rginx-app/tests/multi_listener.rs @@ -1,4 +1,15 @@ #![cfg(unix)] +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream}; @@ -9,13 +20,60 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::{ClientConfig, ClientConnection, DigitallySignedStruct, SignatureScheme, StreamOwned}; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, read_http_head, reserve_loopback_addr}; const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4E\nFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2zn\nyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngq\np7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gp\nqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+\nyfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6\nJrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5\nx23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59\nCiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; const TEST_SERVER_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/GfUhio45GzZi\nz/o67ngbdqW77Bk54aQ/L9kthS47SkqHVaC8+/AYXfDddl0Q+EHQleGrCqjPrRLA\nZAm0Qz0YZEdWcFX4F4T/gvVc7gfobPmkwoVoJwk/A7giuymN9+kpDbXsr2X29uMr\nfW0krA4ZpD/+JzLMXXfG5S9mazrs81wGqFlByGV32z3Qiecvd8JL4nnyz2mMTSqV\nOvtuh285dpgHFugOUGHhhXs9b/CPAGiBTBxuIm+2tGdjtmVGL8L9tlRMYnmQ7u9K\nUCAyHYovUPJjZazTqfkpQ2CkC4iOZ7Ay25gn0fbGBEysyOBZIPw39FEW63jGembu\nOVF6nO/DAgMBAAECggEAKLC7v80TVHiFX4veQZ8WRu7AAmAWzPrNMMEc8rLZcblz\nXhau956DdITILTevQFZEGUhYuUU3RaUaCYojgNUSVLfBctfPjlhfstItMYDjgSt3\nCox6wH8TWm4NzqNgiUCgzmODeaatROUz4MY/r5/NDsuo7pJlIBvEzb5uFdY+QUZ/\nR5gHRiD2Q3wCODe8zQRfTZGo7jCimAuWTLurWZl6ax/4TjWbXCD6DTuUo81cW3vy\nne6tEetHcABRO7uDoBYXk12pCgqFZzjLMnKJjQM+OYnSj6DoWjOu1drT5YyRLGDj\nfzN8V0aKRkOYoZ5QZOua8pByOyQElJnM16vkPtHgPQKBgQD6SOUNWEghvYIGM/lx\nc22/zjvDjeaGC3qSmlpQYN5MGuDoszeDBZ+rMTmHqJ9FcHYkLQnUI7ZkHhRGt/wQ\n/w3CroJjPBgKk+ipy2cBHSI+z+U20xjYzE8hxArWbXG1G4rDt5AIz68IQPsfkVND\nktkDABDaU+KwBPx8fjeeqtRQxQKBgQDDdxdLB1XcfZMX0KEP5RfA8ar1nW41TUAl\nTCOLaXIQbHZ0BeW7USE9mK8OKnVALZGJ+rpxvYFPZ5MWxchpb/cuIwXjLoN6uZVb\nfx4Hho+2iCfhcEKzs8XZW48duKIfhx13BiILLf/YaHAWFs9UfVcQog4Qx03guyMr\n7k9bFuy25wKBgQDpE48zAT6TJS775dTrAQp4b28aan/93pyz/8gRSFRb3UALlDIi\n8s7BluKzYaWI/fUXNVYM14EX9Sb+wIGdtlezL94+2Yyt9RXbYY8361Cj2+jiSG3A\nH2ulzzIkg+E7Pj3Yi443lmiysAjsWeKHcC5l697F4w6cytfye3wCZ6W23QKBgQC0\n9tX+5aytdSkwnDvxXlVOka+ItBcri/i+Ty59TMOIxxInuqoFcUhIIcq4X8CsCUQ8\nLYBd+2fznt3D8JrqWvnKoiw6N38MqTLJQfgIWaFGCep6QhfPDbo30RfAGYcnj01N\nO8Va+lxq+84B9V5AR8bKpG5HRG4qiLc4XerkV2YSswKBgDt9eerSBZyLVwfku25Y\nfrh+nEjUZy81LdlpJmu/bfa2FfItzBqDZPskkJJW9ON82z/ejGFbsU48RF7PJUMr\nGimE33QeTDToGozHCq0QOd0SMfsVkOQR+EROdmY52UIYAYgQUfI1FQ9lLsw10wlQ\nD11SHTL7b9pefBWfW73I7ttV\n-----END PRIVATE KEY-----\n"; +#[derive(Debug)] +struct InsecureServerCertVerifier { + supported_schemes: Vec, +} + +impl InsecureServerCertVerifier { + fn new() -> Self { + Self { + supported_schemes: rustls::crypto::aws_lc_rs::default_provider() + .signature_verification_algorithms + .supported_schemes(), + } + } +} + +impl ServerCertVerifier for InsecureServerCertVerifier { + fn supported_verify_schemes(&self) -> Vec { + self.supported_schemes.clone() + } + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } +} + #[test] fn serves_the_same_routes_on_explicit_http_and_https_listeners() { let http_addr = reserve_loopback_addr(); @@ -291,53 +349,3 @@ fn send_https_request( .map_err(|error| format!("failed to read HTTPS response: {error}"))?; Ok(response) } - -#[derive(Debug)] -struct InsecureServerCertVerifier { - supported_schemes: Vec, -} - -impl InsecureServerCertVerifier { - fn new() -> Self { - Self { - supported_schemes: rustls::crypto::aws_lc_rs::default_provider() - .signature_verification_algorithms - .supported_schemes(), - } - } -} - -impl ServerCertVerifier for InsecureServerCertVerifier { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp_response: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } -} diff --git a/crates/rginx-app/tests/nginx_alignment.rs b/crates/rginx-app/tests/nginx_alignment.rs index 324a8244..acf1946c 100644 --- a/crates/rginx-app/tests/nginx_alignment.rs +++ b/crates/rginx-app/tests/nginx_alignment.rs @@ -1,10 +1,19 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::thread; use std::time::Duration; -mod support; - use support::{ HttpChunkRead, READY_ROUTE_CONFIG, ServerHarness, connect_http_client, read_http_chunk, read_http_head_and_pending, reserve_loopback_addr, spawn_scripted_chunked_response_server, diff --git a/crates/rginx-app/tests/nginx_diff.rs b/crates/rginx-app/tests/nginx_diff.rs index ebded8eb..aa210cb8 100644 --- a/crates/rginx-app/tests/nginx_diff.rs +++ b/crates/rginx-app/tests/nginx_diff.rs @@ -1,8 +1,17 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::net::SocketAddr; use std::time::Duration; -mod support; - use support::{NginxHarness, READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] diff --git a/crates/rginx-app/tests/ocsp.rs b/crates/rginx-app/tests/ocsp.rs index c4d7a15a..bbe4af01 100644 --- a/crates/rginx-app/tests/ocsp.rs +++ b/crates/rginx-app/tests/ocsp.rs @@ -1,4 +1,22 @@ #![cfg(unix)] +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + +#[path = "ocsp/cache.rs"] +mod cache; +#[path = "ocsp/helpers.rs"] +mod helpers; +#[path = "ocsp/refresh.rs"] +mod refresh; use std::env; use std::fs; @@ -28,17 +46,8 @@ use rcgen::{ }; use sha1::Digest; -mod support; - pub(crate) use support::{ READY_ROUTE_CONFIG, ServerHarness, read_http_head, reserve_loopback_addr, }; -#[path = "ocsp/cache.rs"] -mod cache; -#[path = "ocsp/helpers.rs"] -mod helpers; -#[path = "ocsp/refresh.rs"] -mod refresh; - pub(crate) use helpers::*; diff --git a/crates/rginx-app/tests/ocsp/helpers.rs b/crates/rginx-app/tests/ocsp/helpers.rs index c4f018fe..a2a1f037 100644 --- a/crates/rginx-app/tests/ocsp/helpers.rs +++ b/crates/rginx-app/tests/ocsp/helpers.rs @@ -1,5 +1,22 @@ use super::*; +pub(crate) struct TestCertifiedKey { + pub(crate) cert: rcgen::Certificate, + pub(crate) params: CertificateParams, + pub(crate) signing_key: KeyPair, +} + +impl TestCertifiedKey { + pub(crate) fn issuer(&self) -> Issuer<'_, &KeyPair> { + Issuer::from_params(&self.params, &self.signing_key) + } +} + +pub(crate) enum TimeOffset { + After(Duration), + Before(Duration), +} + pub(crate) fn spawn_ocsp_responder( requests: Arc, body: Arc>>, @@ -31,7 +48,7 @@ pub(crate) fn wait_for_command_output( timeout: Duration, predicate: impl Fn(&str) -> bool, ) -> String { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_stdout = String::new(); let mut last_stderr = String::new(); @@ -87,18 +104,6 @@ pub(crate) fn dynamic_ocsp_config( ) } -pub(crate) struct TestCertifiedKey { - pub(crate) cert: rcgen::Certificate, - pub(crate) signing_key: KeyPair, - pub(crate) params: CertificateParams, -} - -impl TestCertifiedKey { - pub(crate) fn issuer(&self) -> Issuer<'_, &KeyPair> { - Issuer::from_params(&self.params, &self.signing_key) - } -} - pub(crate) fn generate_ca_cert(common_name: &str) -> TestCertifiedKey { let mut params = CertificateParams::new(vec![common_name.to_string()]).expect("CA params should build"); @@ -159,7 +164,7 @@ pub(crate) fn der_length(length: usize) -> Vec { return vec![length as u8]; } let bytes = length.to_be_bytes().into_iter().skip_while(|byte| *byte == 0).collect::>(); - let mut encoded = Vec::with_capacity(bytes.len() + 1); + let mut encoded = Vec::with_capacity(bytes.len().saturating_add(1)); encoded.push(0x80 | (bytes.len() as u8)); encoded.extend(bytes); encoded @@ -233,17 +238,12 @@ pub(crate) fn extract_ocsp_cert_id_from_request(request_der: &[u8]) -> RasnCertI .expect("OCSP request should contain a CertId") } -pub(crate) enum TimeOffset { - Before(Duration), - After(Duration), -} - pub(crate) fn ocsp_time_with_offset(base: SystemTime, offset: TimeOffset) -> GeneralizedTime { let time = match offset { TimeOffset::Before(duration) => { base.checked_sub(duration).expect("time offset should stay after unix epoch") } - TimeOffset::After(duration) => base + duration, + TimeOffset::After(duration) => base.checked_add(duration).expect("test time remains valid"), }; generalized_time_from_system_time(time) } diff --git a/crates/rginx-app/tests/phase1.rs b/crates/rginx-app/tests/phase1.rs index 8855a8a7..863e4f2a 100644 --- a/crates/rginx-app/tests/phase1.rs +++ b/crates/rginx-app/tests/phase1.rs @@ -1,3 +1,14 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::collections::HashMap; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; @@ -5,10 +16,21 @@ use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, read_http_head, reserve_loopback_addr}; +#[derive(Debug)] +struct ParsedResponse { + body: Vec, + headers: HashMap, + status: u16, +} + +impl ParsedResponse { + fn header(&self, name: &str) -> Option<&str> { + self.headers.get(&name.to_ascii_lowercase()).map(String::as_str) + } +} + #[test] fn return_responses_generate_and_preserve_request_id_headers() { let listen_addr = reserve_loopback_addr(); @@ -173,19 +195,6 @@ fn proxy_request_buffering_off_streams_chunked_uploads_and_enforces_limits() { ); } -#[derive(Debug)] -struct ParsedResponse { - status: u16, - headers: HashMap, - body: Vec, -} - -impl ParsedResponse { - fn header(&self, name: &str) -> Option<&str> { - self.headers.get(&name.to_ascii_lowercase()).map(String::as_str) - } -} - fn send_http_request(listen_addr: SocketAddr, request: &str) -> Result { send_http_request_with_timeouts( listen_addr, @@ -231,9 +240,10 @@ fn read_http_head_and_body(stream: &mut TcpStream) -> (String, Vec) { buffer.extend_from_slice(&chunk[..read]); if let Some(head_end) = buffer.windows(4).position(|window| window == b"\r\n\r\n") { - let head = String::from_utf8(buffer[..head_end + 4].to_vec()) + let body_start = head_end.saturating_add(4); + let head = String::from_utf8(buffer[..body_start].to_vec()) .expect("HTTP head should be valid UTF-8"); - let body = buffer[head_end + 4..].to_vec(); + let body = buffer[body_start..].to_vec(); return (head, body); } } @@ -262,7 +272,8 @@ fn parse_http_response(bytes: &[u8]) -> Result { headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string()); } - Ok(ParsedResponse { status, headers, body: bytes[head_end + 4..].to_vec() }) + let body_start = head_end.saturating_add(4); + Ok(ParsedResponse { status, headers, body: bytes[body_start..].to_vec() }) } fn assert_generated_request_id(value: Option<&str>) { diff --git a/crates/rginx-app/tests/policy.rs b/crates/rginx-app/tests/policy.rs index 2309e698..af1548e6 100644 --- a/crates/rginx-app/tests/policy.rs +++ b/crates/rginx-app/tests/policy.rs @@ -1,11 +1,26 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream}; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; +#[derive(Debug)] +struct ParsedResponse { + body: Vec, + status: u16, +} + #[test] fn route_access_control_allows_and_denies_requests_end_to_end() { let listen_addr = reserve_loopback_addr(); @@ -53,12 +68,6 @@ fn route_rate_limit_rejects_requests_after_capacity_is_exhausted() { server.shutdown_and_wait(Duration::from_secs(5)); } -#[derive(Debug)] -struct ParsedResponse { - status: u16, - body: Vec, -} - fn send_http_request( listen_addr: SocketAddr, method: &str, @@ -99,7 +108,8 @@ fn parse_http_response(bytes: &[u8]) -> Result { .parse::() .map_err(|error| format!("invalid status code: {error}"))?; - Ok(ParsedResponse { status, body: bytes[head_end + 4..].to_vec() }) + let body_start = head_end.saturating_add(4); + Ok(ParsedResponse { status, body: bytes[body_start..].to_vec() }) } fn acl_config(listen_addr: SocketAddr) -> String { diff --git a/crates/rginx-app/tests/proxy_protocol.rs b/crates/rginx-app/tests/proxy_protocol.rs index f5666636..77051a9f 100644 --- a/crates/rginx-app/tests/proxy_protocol.rs +++ b/crates/rginx-app/tests/proxy_protocol.rs @@ -1,11 +1,19 @@ #![cfg(unix)] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream}; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] @@ -56,7 +64,8 @@ fn untrusted_transport_peer_ignores_proxy_protocol_source() { } fn wait_for_proxy_protocol_ready(listen_addr: SocketAddr, timeout: Duration) { - let deadline = std::time::Instant::now() + timeout; + let deadline = + std::time::Instant::now().checked_add(timeout).expect("test deadline remains valid"); while std::time::Instant::now() < deadline { if send_proxy_protocol_request( listen_addr, diff --git a/crates/rginx-app/tests/reload.rs b/crates/rginx-app/tests/reload.rs index 6fab7148..d1b9fe2b 100644 --- a/crates/rginx-app/tests/reload.rs +++ b/crates/rginx-app/tests/reload.rs @@ -1,29 +1,15 @@ #![cfg(unix)] - -#[allow(unused_imports)] -use std::fs; -#[allow(unused_imports)] -use std::io::{Read, Write}; -#[allow(unused_imports)] -use std::net::{SocketAddr, TcpStream}; -#[allow(unused_imports)] -use std::path::{Path, PathBuf}; -#[allow(unused_imports)] -use std::process::{Command, Output}; -#[allow(unused_imports)] -use std::sync::mpsc; -#[allow(unused_imports)] -use std::sync::{Mutex, OnceLock}; -#[allow(unused_imports)] -use std::time::{Duration, Instant}; - -#[allow(unused_imports)] -use rcgen::{CertifiedKey, KeyPair}; +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] mod support; -use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; - #[path = "reload/cache.rs"] mod cache; #[path = "reload/cache_streaming.rs"] @@ -37,11 +23,64 @@ mod restart_flow; #[path = "reload/streaming_flow.rs"] mod streaming_flow; +#[path = "reload/cli.rs"] +mod cli; + +use std::fs; +use std::io::{Read, Write}; +use std::net::{SocketAddr, TcpStream}; +use std::path::{Path, PathBuf}; +use std::process::{Command, Output}; +use std::sync::mpsc; +use std::sync::{Mutex, OnceLock}; +use std::time::{Duration, Instant}; + +use rcgen::{CertifiedKey, KeyPair}; + +use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; + +use cli::*; + struct TestServer { inner: ServerHarness, } impl TestServer { + fn kill_and_wait(&mut self, timeout: Duration) { + self.inner.kill_and_wait(timeout); + } + + fn pid_path(&self) -> PathBuf { + self.inner.config_path().with_extension("pid") + } + + fn run_cli_command<'a>(&self, args: impl IntoIterator) -> Output { + let mut command = Command::new(binary_path()); + command.arg("--config").arg(self.inner.config_path()); + for arg in args { + command.arg(arg); + } + command.output().expect("rginx command should run") + } + + fn send_cli_signal(&self, signal: &str) -> Output { + Command::new(binary_path()) + .arg("--config") + .arg(self.inner.config_path()) + .arg("-s") + .arg(signal) + .output() + .expect("rginx signal command should run") + } + + fn send_signal(&self, signal: i32) { + self.inner.send_signal(signal); + } + + fn shutdown_and_wait(&mut self, timeout: Duration) { + self.inner.terminate_and_wait(timeout); + } + fn spawn(listen_addr: SocketAddr, body: &str) -> Self { Self::spawn_with_config("rginx-test", return_config(listen_addr, body)) } @@ -54,12 +93,8 @@ impl TestServer { Self { inner: ServerHarness::spawn(prefix, setup) } } - fn write_return_config(&self, listen_addr: SocketAddr, body: &str) { - write_return_config(self.inner.config_path(), listen_addr, body); - } - - fn write_config(&self, config: String) { - fs::write(self.inner.config_path(), config).expect("config file should be written"); + fn temp_dir(&self) -> &Path { + self.inner.temp_dir() } fn wait_for_body(&mut self, listen_addr: SocketAddr, expected: &str, timeout: Duration) { @@ -73,51 +108,12 @@ impl TestServer { ); } - fn wait_for_http_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { - self.inner.wait_for_http_ready(listen_addr, timeout); - } - - fn send_signal(&self, signal: i32) { - self.inner.send_signal(signal); - } - - fn shutdown_and_wait(&mut self, timeout: Duration) { - self.inner.terminate_and_wait(timeout); - } - - fn kill_and_wait(&mut self, timeout: Duration) { - self.inner.kill_and_wait(timeout); - } - fn wait_for_exit(&mut self, timeout: Duration) -> std::process::ExitStatus { self.inner.wait_for_exit(timeout) } - fn temp_dir(&self) -> &Path { - self.inner.temp_dir() - } - - fn pid_path(&self) -> PathBuf { - self.inner.config_path().with_extension("pid") - } - - fn send_cli_signal(&self, signal: &str) -> Output { - Command::new(binary_path()) - .arg("--config") - .arg(self.inner.config_path()) - .arg("-s") - .arg(signal) - .output() - .expect("rginx signal command should run") - } - - fn run_cli_command<'a>(&self, args: impl IntoIterator) -> Output { - let mut command = Command::new(binary_path()); - command.arg("--config").arg(self.inner.config_path()); - for arg in args { - command.arg(arg); - } - command.output().expect("rginx command should run") + fn wait_for_http_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { + self.inner.wait_for_http_ready(listen_addr, timeout); } fn wait_for_status_output( @@ -127,6 +123,14 @@ impl TestServer { ) -> String { wait_for_status_output(self.inner.config_path(), predicate, timeout) } + + fn write_config(&self, config: String) { + fs::write(self.inner.config_path(), config).expect("config file should be written"); + } + + fn write_return_config(&self, listen_addr: SocketAddr, body: &str) { + write_return_config(self.inner.config_path(), listen_addr, body); + } } fn write_return_config(path: &Path, listen_addr: SocketAddr, body: &str) { @@ -149,8 +153,7 @@ fn return_config_with_runtime(listen_addr: SocketAddr, body: &str, runtime_extra fn return_route_fragment(body: &str) -> String { format!( - "LocationConfig(\n matcher: Exact(\"/\"),\n handler: Return(\n status: 200,\n location: \"\",\n body: Some({:?}),\n ),\n),\n", - body + "LocationConfig(\n matcher: Exact(\"/\"),\n handler: Return(\n status: 200,\n location: \"\",\n body: Some({body:?}),\n ),\n),\n" ) } @@ -168,10 +171,7 @@ fn explicit_listeners_config(listeners: &[(&str, SocketAddr)], body: &str) -> St .join(",\n"); format!( - "Config(\n runtime: RuntimeConfig(\n shutdown_timeout_secs: 2,\n ),\n listeners: [\n{listeners}\n ],\n server: ServerConfig(\n ),\n upstreams: [],\n locations: [\n{ready_route} LocationConfig(\n matcher: Exact(\"/\"),\n handler: Return(\n status: 200,\n location: \"\",\n body: Some({body:?}),\n ),\n ),\n ],\n)\n", - listeners = listeners, - ready_route = READY_ROUTE_CONFIG, - body = body, + "Config(\n runtime: RuntimeConfig(\n shutdown_timeout_secs: 2,\n ),\n listeners: [\n{listeners}\n ],\n server: ServerConfig(\n ),\n upstreams: [],\n locations: [\n{READY_ROUTE_CONFIG} LocationConfig(\n matcher: Exact(\"/\"),\n handler: Return(\n status: 200,\n location: \"\",\n body: Some({body:?}),\n ),\n ),\n ],\n)\n", ) } @@ -279,14 +279,14 @@ fn spawn_delayed_response_server( } fn assert_unreachable(listen_addr: SocketAddr, timeout: Duration) { - let deadline = std::time::Instant::now() + timeout; + let deadline = + std::time::Instant::now().checked_add(timeout).expect("test deadline remains valid"); while std::time::Instant::now() < deadline { match fetch_text_response(listen_addr, "/") { Ok((status, body)) => { panic!( - "expected {} to stay unreachable, got status={} body={:?}", - listen_addr, status, body + "expected {listen_addr} to stay unreachable, got status={status} body={body:?}" ); } Err(_) => { @@ -297,7 +297,7 @@ fn assert_unreachable(listen_addr: SocketAddr, timeout: Duration) { } fn wait_for_body(listen_addr: SocketAddr, expected: &str, timeout: Duration) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -312,8 +312,7 @@ fn wait_for_body(listen_addr: SocketAddr, expected: &str, timeout: Duration) { } panic!( - "timed out waiting for expected response on {}; expected body {:?}; last error: {}", - listen_addr, expected, last_error + "timed out waiting for expected response on {listen_addr}; expected body {expected:?}; last error: {last_error}" ); } @@ -329,7 +328,7 @@ fn read_pid_file(path: &Path) -> i32 { } fn wait_for_pid_change(path: &Path, old_pid: i32, timeout: Duration) -> i32 { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); while Instant::now() < deadline { if path.exists() { @@ -345,7 +344,7 @@ fn wait_for_pid_change(path: &Path, old_pid: i32, timeout: Duration) -> i32 { } fn wait_for_process_exit(pid: i32, timeout: Duration) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); while Instant::now() < deadline { let result = unsafe { libc::kill(pid, 0) }; @@ -361,8 +360,3 @@ fn wait_for_process_exit(pid: i32, timeout: Duration) { panic!("timed out waiting for pid {pid} to exit"); } - -#[path = "reload/cli.rs"] -mod cli; - -use cli::*; diff --git a/crates/rginx-app/tests/reload/cache.rs b/crates/rginx-app/tests/reload/cache.rs index e4796bb9..713b6dff 100644 --- a/crates/rginx-app/tests/reload/cache.rs +++ b/crates/rginx-app/tests/reload/cache.rs @@ -10,7 +10,7 @@ use super::*; #[test] fn reload_preserves_cache_entries_when_zone_path_is_reused() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let (upstream_addr, upstream_hits) = spawn_counting_response_server("reload cache ok\n"); let mut server = TestServer::spawn_with_setup("rginx-reload-cache", |temp_dir| { @@ -56,7 +56,7 @@ fn reload_preserves_cache_entries_when_zone_path_is_reused() { #[test] #[ignore = "cache stress suite; run via scripts/run-cache-stress.sh"] fn reload_keeps_hot_cache_hits_available_under_concurrent_traffic() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let (upstream_addr, upstream_hits) = spawn_counting_response_server("reload cache stress\n"); let mut server = TestServer::spawn_with_setup("rginx-reload-cache-stress", |temp_dir| { @@ -91,7 +91,7 @@ fn reload_keeps_hot_cache_hits_available_under_concurrent_traffic() { if response.status != 200 || x_cache.as_deref() != Some("HIT") { failures .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .push(format!( "worker {worker_id} observed status={} x-cache={x_cache:?}", response.status @@ -102,7 +102,7 @@ fn reload_keeps_hot_cache_hits_available_under_concurrent_traffic() { Err(error) => { failures .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .push(format!("worker {worker_id} request failed: {error}")); break; } @@ -133,7 +133,7 @@ fn reload_keeps_hot_cache_hits_available_under_concurrent_traffic() { worker.join().expect("cache stress worker should join cleanly"); } - let failures = failures.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let failures = failures.lock().unwrap_or_else(std::sync::PoisonError::into_inner); assert!(failures.is_empty(), "concurrent reload cache traffic should stay hot: {failures:?}"); assert_eq!(upstream_hits.load(Ordering::Relaxed), 1); diff --git a/crates/rginx-app/tests/reload/cache_streaming.rs b/crates/rginx-app/tests/reload/cache_streaming.rs index d9b486b4..c96bb6ed 100644 --- a/crates/rginx-app/tests/reload/cache_streaming.rs +++ b/crates/rginx-app/tests/reload/cache_streaming.rs @@ -9,7 +9,7 @@ use super::*; #[test] fn reload_preserves_streaming_cache_entries_across_reloads() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let (upstream_addr, upstream_hits) = spawn_chunked_counting_response_server([ b"reload ".as_slice(), @@ -67,7 +67,7 @@ fn reload_preserves_streaming_cache_entries_across_reloads() { #[test] #[ignore = "cache stress suite; run via scripts/run-cache-stress.sh"] fn reload_keeps_hot_streaming_cache_hits_available_under_concurrent_traffic() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let (upstream_addr, upstream_hits) = spawn_chunked_counting_response_server([ b"reload ".as_slice(), @@ -116,7 +116,7 @@ fn reload_keeps_hot_streaming_cache_hits_available_under_concurrent_traffic() { { failures .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .push(format!( "worker {worker_id} observed status={} x-cache={x_cache:?} body={body:?}", response.status @@ -127,7 +127,7 @@ fn reload_keeps_hot_streaming_cache_hits_available_under_concurrent_traffic() { Err(error) => { failures .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .push(format!("worker {worker_id} request failed: {error}")); break; } @@ -158,7 +158,7 @@ fn reload_keeps_hot_streaming_cache_hits_available_under_concurrent_traffic() { worker.join().expect("cache stress worker should join cleanly"); } - let failures = failures.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let failures = failures.lock().unwrap_or_else(std::sync::PoisonError::into_inner); assert!( failures.is_empty(), "concurrent reload streaming cache traffic should stay hot: {failures:?}" diff --git a/crates/rginx-app/tests/reload/cli.rs b/crates/rginx-app/tests/reload/cli.rs index af4a33cd..524954af 100644 --- a/crates/rginx-app/tests/reload/cli.rs +++ b/crates/rginx-app/tests/reload/cli.rs @@ -1,5 +1,7 @@ use super::*; +pub(super) type TestCertifiedKey = CertifiedKey; + pub(super) fn binary_path() -> std::path::PathBuf { std::env::var_os("CARGO_BIN_EXE_rginx") .map(std::path::PathBuf::from) @@ -15,8 +17,6 @@ pub(super) fn render_output(output: &Output) -> String { ) } -pub(super) type TestCertifiedKey = CertifiedKey; - pub(super) fn generate_cert(hostname: &str) -> TestCertifiedKey { rcgen::generate_simple_self_signed(vec![hostname.to_string()]) .expect("self-signed certificate should generate") @@ -53,7 +53,7 @@ pub(super) fn wait_for_status_output( predicate: impl Fn(&str) -> bool, timeout: Duration, ) -> String { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_output = String::new(); while Instant::now() < deadline { diff --git a/crates/rginx-app/tests/reload/reload_boundary.rs b/crates/rginx-app/tests/reload/reload_boundary.rs index 39c8b048..0422c04c 100644 --- a/crates/rginx-app/tests/reload/reload_boundary.rs +++ b/crates/rginx-app/tests/reload/reload_boundary.rs @@ -2,7 +2,7 @@ use super::*; #[test] fn sighup_rejects_listen_address_changes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let initial_addr = reserve_loopback_addr(); let rejected_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(initial_addr, "stable config\n"); @@ -24,7 +24,7 @@ fn sighup_rejects_listen_address_changes() { #[test] fn sighup_rejects_accept_worker_changes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(listen_addr, "stable workers\n"); @@ -47,7 +47,7 @@ fn sighup_rejects_accept_worker_changes() { #[test] fn sighup_rejects_runtime_worker_thread_changes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(listen_addr, "stable runtime\n"); @@ -70,7 +70,7 @@ fn sighup_rejects_runtime_worker_thread_changes() { #[test] fn sighup_status_reports_restart_required_fields_for_startup_boundary_changes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(listen_addr, "stable runtime\n"); @@ -115,7 +115,7 @@ fn sighup_status_reports_restart_required_fields_for_startup_boundary_changes() #[test] fn sighup_reload_picks_up_updated_included_fragments() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn_with_setup("rginx-reload-include-test", |temp_dir| { fs::write(temp_dir.join("routes.ron"), return_route_fragment("before include reload\n")) diff --git a/crates/rginx-app/tests/reload/reload_flow.rs b/crates/rginx-app/tests/reload/reload_flow.rs index d963aea7..d704970a 100644 --- a/crates/rginx-app/tests/reload/reload_flow.rs +++ b/crates/rginx-app/tests/reload/reload_flow.rs @@ -10,7 +10,7 @@ use rustls::{ClientConfig, RootCertStore}; #[test] fn sighup_reload_applies_updated_routes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(listen_addr, "before reload\n"); @@ -25,7 +25,7 @@ fn sighup_reload_applies_updated_routes() { #[test] fn nginx_style_reload_command_applies_updated_routes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(listen_addr, "before reload\n"); @@ -44,7 +44,7 @@ fn nginx_style_reload_command_applies_updated_routes() { #[test] fn nginx_style_quit_command_stops_the_server() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(listen_addr, "before quit\n"); @@ -59,7 +59,7 @@ fn nginx_style_quit_command_stops_the_server() { #[test] fn sighup_reload_adds_explicit_listener_without_restart() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let http_addr = reserve_loopback_addr(); let admin_addr = reserve_loopback_addr(); let mut server = TestServer::spawn_with_config( @@ -83,7 +83,7 @@ fn sighup_reload_adds_explicit_listener_without_restart() { #[test] fn sighup_reload_removes_explicit_listener_without_restart() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let http_addr = reserve_loopback_addr(); let admin_addr = reserve_loopback_addr(); let mut server = TestServer::spawn_with_config( @@ -104,7 +104,7 @@ fn sighup_reload_removes_explicit_listener_without_restart() { #[test] fn removed_listener_drains_in_flight_request_before_going_unreachable() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let http_addr = reserve_loopback_addr(); let drain_addr = reserve_loopback_addr(); let (ready_tx, ready_rx) = mpsc::channel(); @@ -144,7 +144,7 @@ fn removed_listener_drains_in_flight_request_before_going_unreachable() { #[test] fn removed_http3_listener_drains_in_flight_request_before_going_unreachable() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let cert = generate_cert("localhost"); let http_addr = reserve_loopback_addr(); let h3_addr = reserve_loopback_addr(); @@ -229,7 +229,7 @@ fn wait_for_http3_body( expected: &str, timeout: Duration, ) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -247,8 +247,7 @@ fn wait_for_http3_body( } panic!( - "timed out waiting for expected http3 response on {}; expected body {:?}; last error: {}", - listen_addr, expected, last_error + "timed out waiting for expected http3 response on {listen_addr}; expected body {expected:?}; last error: {last_error}" ); } @@ -259,7 +258,7 @@ fn assert_http3_unreachable( cert_pem: &str, timeout: Duration, ) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); while Instant::now() < deadline { let result = tokio::runtime::Builder::new_current_thread() @@ -268,10 +267,7 @@ fn assert_http3_unreachable( .expect("tokio runtime should build") .block_on(async { http3_get_body(listen_addr, server_name, path, cert_pem).await }); if let Ok(body) = result { - panic!( - "expected http3 listener {} to stay unreachable, got body {:?}", - listen_addr, body - ); + panic!("expected http3 listener {listen_addr} to stay unreachable, got body {body:?}"); } std::thread::sleep(Duration::from_millis(50)); } diff --git a/crates/rginx-app/tests/reload/restart_flow.rs b/crates/rginx-app/tests/reload/restart_flow.rs index 58d42e8d..0c272e55 100644 --- a/crates/rginx-app/tests/reload/restart_flow.rs +++ b/crates/rginx-app/tests/reload/restart_flow.rs @@ -10,7 +10,7 @@ use rustls::{ClientConfig, RootCertStore}; #[test] fn nginx_style_restart_command_applies_listen_address_changes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let initial_addr = reserve_loopback_addr(); let restarted_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(initial_addr, "before restart\n"); @@ -40,7 +40,7 @@ fn nginx_style_restart_command_applies_listen_address_changes() { #[test] fn nginx_style_restart_command_applies_runtime_worker_changes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(listen_addr, "runtime restart\n"); @@ -81,7 +81,7 @@ fn nginx_style_restart_command_applies_runtime_worker_changes() { #[test] fn nginx_style_restart_command_keeps_old_process_running_when_replacement_fails() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let mut server = TestServer::spawn(listen_addr, "stable runtime\n"); @@ -107,7 +107,7 @@ fn nginx_style_restart_command_keeps_old_process_running_when_replacement_fails( #[test] fn sighup_status_reports_tls_certificate_changes_after_rotation() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let listen_addr = reserve_loopback_addr(); let initial_cert = generate_cert("localhost"); let rotated_cert = generate_cert("localhost"); @@ -156,7 +156,7 @@ fn sighup_status_reports_tls_certificate_changes_after_rotation() { #[test] fn nginx_style_restart_command_keeps_http3_listener_available() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); tokio::runtime::Builder::new_current_thread() .enable_all() .build() diff --git a/crates/rginx-app/tests/reload/streaming_flow.rs b/crates/rginx-app/tests/reload/streaming_flow.rs index dbdd0552..3e64207e 100644 --- a/crates/rginx-app/tests/reload/streaming_flow.rs +++ b/crates/rginx-app/tests/reload/streaming_flow.rs @@ -2,7 +2,7 @@ use super::*; #[test] fn sighup_reload_drains_inflight_streaming_response_before_switching_routes() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let (upstream_addr, upstream_task) = super::support::spawn_scripted_chunked_response_server( "GET / HTTP/1.1\r\n", b"before reload\n", @@ -53,7 +53,7 @@ fn sighup_reload_drains_inflight_streaming_response_before_switching_routes() { #[test] fn nginx_style_restart_command_drains_inflight_streaming_response_before_old_process_exits() { - let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = test_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); let (upstream_addr, upstream_task) = super::support::spawn_scripted_chunked_response_server( "GET / HTTP/1.1\r\n", b"before restart\n", diff --git a/crates/rginx-app/tests/static_file_streaming.rs b/crates/rginx-app/tests/static_file_streaming.rs index e87ae5a7..aecfc4e4 100644 --- a/crates/rginx-app/tests/static_file_streaming.rs +++ b/crates/rginx-app/tests/static_file_streaming.rs @@ -1,3 +1,14 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::collections::HashMap; use std::fs; use std::io::{Read, Write}; @@ -7,13 +18,24 @@ use std::sync::{Arc, Barrier, Mutex, OnceLock}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; const LARGE_FILE_PATH: &str = "/download.bin"; const LARGE_FILE_LEN: usize = 512 * 1024; +#[derive(Debug)] +struct ParsedResponse { + body: Vec, + headers: HashMap, + status: u16, +} + +impl ParsedResponse { + fn header(&self, name: &str) -> Option<&str> { + self.headers.get(&name.to_ascii_lowercase()).map(String::as_str) + } +} + #[test] fn static_file_large_download_and_head_work_on_real_server() { let _guard = test_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); @@ -161,19 +183,6 @@ fn static_file_large_downloads_can_run_concurrently_on_real_server() { server.shutdown_and_wait(Duration::from_secs(5)); } -#[derive(Debug)] -struct ParsedResponse { - status: u16, - headers: HashMap, - body: Vec, -} - -impl ParsedResponse { - fn header(&self, name: &str) -> Option<&str> { - self.headers.get(&name.to_ascii_lowercase()).map(String::as_str) - } -} - fn send_http_request(listen_addr: SocketAddr, request: &str) -> Result { let mut stream = TcpStream::connect_timeout(&listen_addr, Duration::from_millis(200)) .map_err(|error| format!("failed to connect to {listen_addr}: {error}"))?; @@ -218,7 +227,8 @@ fn parse_http_response(bytes: &[u8]) -> Result { headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string()); } - Ok(ParsedResponse { status, headers, body: bytes[head_end + 4..].to_vec() }) + let body_start = head_end.saturating_add(4); + Ok(ParsedResponse { status, headers, body: bytes[body_start..].to_vec() }) } fn create_static_file_fixture(temp_dir: &Path, file_name: &str, bytes: &[u8]) { @@ -236,7 +246,7 @@ fn static_file_config(listen_addr: SocketAddr) -> String { } fn test_file_bytes(len: usize) -> Vec { - (0..len).map(|index| b'a' + (index % 26) as u8).collect() + (0..len).map(|index| b'a'.saturating_add((index.rem_euclid(26)) as u8)).collect() } fn test_lock() -> &'static Mutex<()> { diff --git a/crates/rginx-app/tests/streaming_download.rs b/crates/rginx-app/tests/streaming_download.rs index 70ab1cea..1ba2f4f3 100644 --- a/crates/rginx-app/tests/streaming_download.rs +++ b/crates/rginx-app/tests/streaming_download.rs @@ -1,9 +1,18 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::SocketAddr; use std::time::{Duration, Instant}; -mod support; - use support::{ HttpChunkRead, READY_ROUTE_CONFIG, ServerHarness, connect_http_client, read_http_chunk, read_http_head_and_pending, reserve_loopback_addr, spawn_scripted_chunked_response_server, @@ -276,14 +285,15 @@ fn decode_response_body_text(response: &str) -> Option { if chunk_len == 0 { return String::from_utf8(decoded).ok(); } - if tail.len() < chunk_len + 2 { + let chunk_with_terminator = chunk_len.saturating_add(2); + if tail.len() < chunk_with_terminator { return None; } - if &tail.as_bytes()[chunk_len..chunk_len + 2] != b"\r\n" { + if &tail.as_bytes()[chunk_len..chunk_with_terminator] != b"\r\n" { return None; } decoded.extend_from_slice(&tail.as_bytes()[..chunk_len]); - rest = &tail[chunk_len + 2..]; + rest = &tail[chunk_with_terminator..]; } } diff --git a/crates/rginx-app/tests/support/cache/mod.rs b/crates/rginx-app/tests/support/cache/mod.rs index 03022f96..0db38e8b 100644 --- a/crates/rginx-app/tests/support/cache/mod.rs +++ b/crates/rginx-app/tests/support/cache/mod.rs @@ -1,16 +1,16 @@ -use std::collections::HashMap; -use std::sync::{Mutex, OnceLock}; -use std::time::Duration; - pub mod config; pub mod response; pub mod upstream; +use std::collections::HashMap; +use std::sync::{Mutex, OnceLock}; +use std::time::Duration; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParsedResponse { - pub status: u16, - pub headers: HashMap, pub body: Vec, + pub headers: HashMap, + pub status: u16, } impl ParsedResponse { @@ -21,9 +21,9 @@ impl ParsedResponse { #[derive(Clone, Copy)] pub enum StatusLine { - Ok, - NotModified, InternalServerError, + NotModified, + Ok, } impl StatusLine { @@ -37,12 +37,12 @@ impl StatusLine { } pub struct MetricsSample { - pub samples: Vec, + pub avg_ms: f64, + pub max_ms: f64, pub min_ms: f64, pub p50_ms: f64, pub p95_ms: f64, - pub max_ms: f64, - pub avg_ms: f64, + pub samples: Vec, } impl MetricsSample { @@ -52,12 +52,12 @@ impl MetricsSample { let as_ms = |duration: Duration| duration.as_secs_f64() * 1000.0; let len = samples.len(); let avg_ms = - samples.iter().map(|sample| sample.as_secs_f64()).sum::() * 1000.0 / len as f64; + samples.iter().map(std::time::Duration::as_secs_f64).sum::() * 1000.0 / len as f64; Self { min_ms: as_ms(samples[0]), p50_ms: percentile_ms(&samples, 0.50), p95_ms: percentile_ms(&samples, 0.95), - max_ms: as_ms(samples[len - 1]), + max_ms: as_ms(samples[len.saturating_sub(1)]), avg_ms, samples, } diff --git a/crates/rginx-app/tests/support/cache/response.rs b/crates/rginx-app/tests/support/cache/response.rs index 1d49f1f0..1f2b192b 100644 --- a/crates/rginx-app/tests/support/cache/response.rs +++ b/crates/rginx-app/tests/support/cache/response.rs @@ -44,7 +44,7 @@ pub fn wait_for_response( timeout: Duration, predicate: impl Fn(&ParsedResponse) -> bool, ) -> Result { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -73,7 +73,8 @@ fn parse_response(response: &[u8]) -> Result { .windows(4) .position(|window| window == b"\r\n\r\n") .ok_or_else(|| format!("malformed response: {:?}", String::from_utf8_lossy(response)))?; - let head = std::str::from_utf8(&response[..head_end + 4]) + let body_start = head_end.saturating_add(4); + let head = std::str::from_utf8(&response[..body_start]) .map_err(|error| format!("response head should be utf-8: {error}"))?; let mut lines = head.lines(); let status = lines @@ -92,12 +93,12 @@ fn parse_response(response: &[u8]) -> Result { } } - let mut body = response[head_end + 4..].to_vec(); + let mut body = response[body_start..].to_vec(); if headers.get("transfer-encoding").is_some_and(|value| value.eq_ignore_ascii_case("chunked")) { body = decode_chunked_body(&body)?; } - Ok(ParsedResponse { status, headers, body }) + Ok(ParsedResponse { body, headers, status }) } fn decode_chunked_body(body: &[u8]) -> Result, String> { @@ -115,17 +116,19 @@ fn decode_chunked_body(body: &[u8]) -> Result, String> { usize::from_str_radix(line.trim(), 16) .map_err(|error| format!("chunk length should be valid hex: {error}")) })?; - remaining = &remaining[line_end + 2..]; + let payload_start = line_end.saturating_add(2); + remaining = &remaining[payload_start..]; if chunk_len == 0 { return Ok(decoded); } - if remaining.len() < chunk_len + 2 { + let chunk_with_terminator = chunk_len.saturating_add(2); + if remaining.len() < chunk_with_terminator { return Err("chunk payload shorter than advertised".to_string()); } - if &remaining[chunk_len..chunk_len + 2] != b"\r\n" { + if &remaining[chunk_len..chunk_with_terminator] != b"\r\n" { return Err("missing chunk payload terminator".to_string()); } decoded.extend_from_slice(&remaining[..chunk_len]); - remaining = &remaining[chunk_len + 2..]; + remaining = &remaining[chunk_with_terminator..]; } } diff --git a/crates/rginx-app/tests/support/cache/upstream.rs b/crates/rginx-app/tests/support/cache/upstream.rs index 85b8604f..8f94afa5 100644 --- a/crates/rginx-app/tests/support/cache/upstream.rs +++ b/crates/rginx-app/tests/support/cache/upstream.rs @@ -42,7 +42,7 @@ where let attempt = hits.fetch_add(1, Ordering::Relaxed); requests .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .push(request.clone()); let response = handler(attempt, &request); let _ = stream.write_all(&response); @@ -125,10 +125,10 @@ pub fn spawn_blocked_fill_response_server(body: &'static str) -> BlockedFillServ hits.fetch_add(1, Ordering::Relaxed); let _ = started_tx.send(()); let (open, notify) = &*gate; - let mut released = open.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let mut released = open.lock().unwrap_or_else(std::sync::PoisonError::into_inner); while !*released { released = - notify.wait(released).unwrap_or_else(|poisoned| poisoned.into_inner()); + notify.wait(released).unwrap_or_else(std::sync::PoisonError::into_inner); } drop(released); @@ -145,14 +145,14 @@ pub fn spawn_blocked_fill_response_server(body: &'static str) -> BlockedFillServ pub fn open_gate(gate: &FillGate) { let (open, notify) = &**gate; - let mut released = open.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let mut released = open.lock().unwrap_or_else(std::sync::PoisonError::into_inner); *released = true; notify.notify_all(); } pub fn reset_gate(gate: &FillGate) { let (open, _) = &**gate; - let mut released = open.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let mut released = open.lock().unwrap_or_else(std::sync::PoisonError::into_inner); *released = false; } @@ -177,14 +177,14 @@ pub fn spawn_range_response_server(seen_ranges: Arc>>) -> Sock if let Some(range) = &range { seen_ranges .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .push(range.clone()); } let payload = b"abcdefghijklmnopqrstuvwxyz"; let response = match range.and_then(|range| parse_range_header(&range)) { Some((start, end)) if start < payload.len() => { - let end = end.min(payload.len() - 1); + let end = end.min(payload.len().saturating_sub(1)); let body = &payload[start..=end]; format!( "HTTP/1.1 206 Partial Content\r\ncontent-length: {}\r\ncontent-range: bytes {}-{}/{}\r\ncache-control: max-age=60\r\nconnection: close\r\n\r\n{}", @@ -243,7 +243,6 @@ pub fn header_only_response(status: StatusLine, extra_headers: &[(&str, &str)]) response.into_bytes() } -#[allow(dead_code)] pub fn _shared_fill_streaming_probe( listen_addr: SocketAddr, ) -> Result<(String, HttpChunkRead, String, HttpChunkRead), String> { diff --git a/crates/rginx-app/tests/support/harness.rs b/crates/rginx-app/tests/support/harness.rs index 27ea44b4..9ff38c5e 100644 --- a/crates/rginx-app/tests/support/harness.rs +++ b/crates/rginx-app/tests/support/harness.rs @@ -3,29 +3,57 @@ use super::*; pub struct ServerHarness { child: Child, config_path: PathBuf, - temp_dir: PathBuf, - stdout_path: PathBuf, stderr_path: PathBuf, + stdout_path: PathBuf, + temp_dir: PathBuf, } impl ServerHarness { - pub fn spawn(prefix: &str, build_config: impl FnOnce(&Path) -> String) -> Self { - Self::spawn_inner(prefix, move |temp_dir| build_config(temp_dir)) + pub fn assert_running(&mut self) { + if let Some(status) = self.child.try_wait().expect("child status should be readable") { + panic!("rginx exited unexpectedly with status {status}\n{}", self.combined_output()); + } } - pub fn spawn_with_tls( - prefix: &str, - cert_pem: &str, - key_pem: &str, - build_config: impl FnOnce(&Path, &Path, &Path) -> String, - ) -> Self { - Self::spawn_inner(prefix, move |temp_dir| { - let cert_path = temp_dir.join("server.crt"); - let key_path = temp_dir.join("server.key"); - fs::write(&cert_path, cert_pem).expect("test cert should be written"); - fs::write(&key_path, key_pem).expect("test key should be written"); - build_config(temp_dir, &cert_path, &key_path) - }) + pub fn combined_output(&self) -> String { + let stdout = read_optional_log(&self.stdout_path); + let stderr = read_optional_log(&self.stderr_path); + format!("stdout:\n{stdout}\nstderr:\n{stderr}") + } + + pub fn config_path(&self) -> &Path { + &self.config_path + } + + pub fn kill_and_wait(&mut self, timeout: Duration) { + self.child.kill().expect("rginx should accept a kill signal"); + let status = self.wait_for_exit(timeout); + assert!( + !status.success() || status.code() == Some(0), + "rginx should exit after the test, got {status}\n{}", + self.combined_output() + ); + } + + #[cfg(unix)] + pub fn send_signal(&self, signal: i32) { + let result = unsafe { libc::kill(self.child.id() as i32, signal) }; + assert!( + result == 0, + "failed to send signal {} to pid {}: {}\n{}", + signal, + self.child.id(), + std::io::Error::last_os_error(), + self.combined_output() + ); + } + + pub fn shutdown_and_wait(&mut self, timeout: Duration) { + self.kill_and_wait(timeout); + } + + pub fn spawn(prefix: &str, build_config: impl FnOnce(&Path) -> String) -> Self { + Self::spawn_inner(prefix, move |temp_dir| build_config(temp_dir)) } fn spawn_inner(prefix: &str, build_config: impl FnOnce(&Path) -> String) -> Self { @@ -46,34 +74,62 @@ impl ServerHarness { .spawn() .expect("rginx should start"); - Self { child, config_path, temp_dir, stdout_path, stderr_path } + Self { child, config_path, stderr_path, stdout_path, temp_dir } } - pub fn config_path(&self) -> &Path { - &self.config_path + pub fn spawn_with_tls( + prefix: &str, + cert_pem: &str, + key_pem: &str, + build_config: impl FnOnce(&Path, &Path, &Path) -> String, + ) -> Self { + Self::spawn_inner(prefix, move |temp_dir| { + let cert_path = temp_dir.join("server.crt"); + let key_path = temp_dir.join("server.key"); + fs::write(&cert_path, cert_pem).expect("test cert should be written"); + fs::write(&key_path, key_pem).expect("test key should be written"); + build_config(temp_dir, &cert_path, &key_path) + }) } pub fn temp_dir(&self) -> &Path { &self.temp_dir } - pub fn wait_for_http_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { - self.wait_for_http_text_response( - listen_addr, - &listen_addr.to_string(), - READY_PATH, - 200, - READY_BODY, - timeout, + #[cfg(unix)] + pub fn terminate_and_wait(&mut self, timeout: Duration) { + self.send_signal(libc::SIGTERM); + let status = self.wait_for_exit(timeout); + assert!( + status.success(), + "rginx should exit successfully, got {status}\n{}", + self.combined_output() ); } - pub fn wait_for_https_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { - self.wait_for_https_text_response( + pub fn wait_for_exit(&mut self, timeout: Duration) -> ExitStatus { + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); + + loop { + if let Some(status) = self.child.try_wait().expect("child status should be readable") { + return status; + } + + if Instant::now() >= deadline { + let _ = self.child.kill(); + let _ = self.child.wait(); + panic!("timed out waiting for rginx to exit\n{}", self.combined_output()); + } + + thread::sleep(Duration::from_millis(50)); + } + } + + pub fn wait_for_http_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { + self.wait_for_http_text_response( listen_addr, &listen_addr.to_string(), READY_PATH, - DEFAULT_TLS_SERVER_NAME, 200, READY_BODY, timeout, @@ -98,6 +154,18 @@ impl ServerHarness { ); } + pub fn wait_for_https_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { + self.wait_for_https_text_response( + listen_addr, + &listen_addr.to_string(), + READY_PATH, + DEFAULT_TLS_SERVER_NAME, + 200, + READY_BODY, + timeout, + ); + } + pub fn wait_for_https_text_response( &mut self, listen_addr: SocketAddr, @@ -125,7 +193,7 @@ impl ServerHarness { expected_status: u16, expected_body: &str, ) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -154,75 +222,6 @@ impl ServerHarness { self.combined_output() ); } - - pub fn shutdown_and_wait(&mut self, timeout: Duration) { - self.kill_and_wait(timeout); - } - - pub fn kill_and_wait(&mut self, timeout: Duration) { - self.child.kill().expect("rginx should accept a kill signal"); - let status = self.wait_for_exit(timeout); - assert!( - !status.success() || status.code() == Some(0), - "rginx should exit after the test, got {status}\n{}", - self.combined_output() - ); - } - - #[cfg(unix)] - pub fn send_signal(&self, signal: i32) { - let result = unsafe { libc::kill(self.child.id() as i32, signal) }; - if result != 0 { - panic!( - "failed to send signal {} to pid {}: {}\n{}", - signal, - self.child.id(), - std::io::Error::last_os_error(), - self.combined_output() - ); - } - } - - #[cfg(unix)] - pub fn terminate_and_wait(&mut self, timeout: Duration) { - self.send_signal(libc::SIGTERM); - let status = self.wait_for_exit(timeout); - assert!( - status.success(), - "rginx should exit successfully, got {status}\n{}", - self.combined_output() - ); - } - - pub fn wait_for_exit(&mut self, timeout: Duration) -> ExitStatus { - let deadline = Instant::now() + timeout; - - loop { - if let Some(status) = self.child.try_wait().expect("child status should be readable") { - return status; - } - - if Instant::now() >= deadline { - let _ = self.child.kill(); - let _ = self.child.wait(); - panic!("timed out waiting for rginx to exit\n{}", self.combined_output()); - } - - thread::sleep(Duration::from_millis(50)); - } - } - - pub fn assert_running(&mut self) { - if let Some(status) = self.child.try_wait().expect("child status should be readable") { - panic!("rginx exited unexpectedly with status {status}\n{}", self.combined_output()); - } - } - - pub fn combined_output(&self) -> String { - let stdout = read_optional_log(&self.stdout_path); - let stderr = read_optional_log(&self.stderr_path); - format!("stdout:\n{stdout}\nstderr:\n{stderr}") - } } impl Drop for ServerHarness { diff --git a/crates/rginx-app/tests/support/http.rs b/crates/rginx-app/tests/support/http.rs index 0775b059..26396236 100644 --- a/crates/rginx-app/tests/support/http.rs +++ b/crates/rginx-app/tests/support/http.rs @@ -1,5 +1,12 @@ use super::*; +#[derive(Debug, PartialEq, Eq)] +pub enum HttpChunkRead { + Chunk(Vec), + End, + TimedOut, +} + pub fn reserve_loopback_addr() -> SocketAddr { let listener = TcpListener::bind(("127.0.0.1", 0)).expect("ephemeral loopback listener should bind"); @@ -30,7 +37,8 @@ pub fn read_http_head(stream: &mut TcpStream) -> String { buffer.extend_from_slice(&chunk[..read]); if let Some(head_end) = buffer.windows(4).position(|window| window == b"\r\n\r\n") { - return String::from_utf8(buffer[..head_end + 4].to_vec()) + let body_start = head_end.saturating_add(4); + return String::from_utf8(buffer[..body_start].to_vec()) .expect("HTTP head should be valid UTF-8"); } } @@ -56,22 +64,16 @@ pub fn read_http_head_and_pending(stream: &mut TcpStream) -> (String, Vec) { buffer.extend_from_slice(&chunk[..read]); if let Some(head_end) = buffer.windows(4).position(|window| window == b"\r\n\r\n") { + let body_start = head_end.saturating_add(4); return ( - String::from_utf8(buffer[..head_end + 4].to_vec()) + String::from_utf8(buffer[..body_start].to_vec()) .expect("HTTP head should be valid UTF-8"), - buffer[head_end + 4..].to_vec(), + buffer[body_start..].to_vec(), ); } } } -#[derive(Debug, PartialEq, Eq)] -pub enum HttpChunkRead { - Chunk(Vec), - End, - TimedOut, -} - pub fn read_http_chunk(stream: &mut TcpStream, pending: &mut Vec) -> HttpChunkRead { let mut scratch = [0u8; 256]; @@ -98,9 +100,11 @@ pub fn read_http_chunk(stream: &mut TcpStream, pending: &mut Vec) -> HttpChu String::from_utf8(pending[..line_end].to_vec()).expect("chunk header should be utf-8"); let chunk_len = usize::from_str_radix(line.trim(), 16).expect("chunk length should be valid hex"); - pending.drain(..line_end + 2); + let header_end = line_end.saturating_add(2); + pending.drain(..header_end); - while pending.len() < chunk_len + 2 { + let chunk_with_terminator = chunk_len.saturating_add(2); + while pending.len() < chunk_with_terminator { match stream.read(&mut scratch) { Ok(0) => return HttpChunkRead::End, Ok(read) => pending.extend_from_slice(&scratch[..read]), @@ -122,7 +126,7 @@ pub fn read_http_chunk(stream: &mut TcpStream, pending: &mut Vec) -> HttpChu } let chunk = pending[..chunk_len].to_vec(); - pending.drain(..chunk_len + 2); + pending.drain(..chunk_with_terminator); HttpChunkRead::Chunk(chunk) } diff --git a/crates/rginx-app/tests/support/mod.rs b/crates/rginx-app/tests/support/mod.rs index a7445ea8..c2e21008 100644 --- a/crates/rginx-app/tests/support/mod.rs +++ b/crates/rginx-app/tests/support/mod.rs @@ -1,9 +1,13 @@ -#![allow(dead_code)] +#![expect(dead_code, reason = "shared integration-test helpers are used by different test targets")] +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] -#[ctor::ctor(unsafe)] -fn install_test_crypto_provider() { - rginx_http::install_default_crypto_provider(); -} +pub mod cache; +mod fs_paths; +mod harness; +mod http; +mod nginx; +mod response; +mod tls; use std::env; use std::fs; @@ -21,31 +25,63 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::{ClientConfig, ClientConnection, DigitallySignedStruct, SignatureScheme, StreamOwned}; +use fs_paths::{binary_path, read_optional_log, temp_dir}; +use response::{fetch_http_text_response, fetch_https_text_response}; +use tls::InsecureServerCertVerifier; + pub const READY_ROUTE_CONFIG: &str = " LocationConfig(\n matcher: Exact(\"/-/ready\"),\n handler: Return(\n status: 200,\n location: \"\",\n body: Some(\"ready\\n\"),\n ),\n ),\n"; const READY_PATH: &str = "/-/ready"; const READY_BODY: &str = "ready\n"; const DEFAULT_TLS_SERVER_NAME: &str = "localhost"; -pub mod cache; -mod fs_paths; -mod harness; -mod http; -mod nginx; -mod response; -mod tls; +pub type ServerHarness = harness::ServerHarness; +pub type HttpChunkRead = http::HttpChunkRead; +pub type NginxHarness = nginx::NginxHarness; -#[allow(unused_imports)] -pub use harness::ServerHarness; -#[allow(unused_imports)] -pub use http::{ - HttpChunkRead, apply_tls_placeholders, connect_http_client, read_http_chunk, read_http_head, - read_http_head_and_pending, reserve_loopback_addr, reserve_loopback_addr_pair, - spawn_scripted_chunked_response_server, -}; -#[allow(unused_imports)] -pub use nginx::NginxHarness; +#[ctor::ctor(unsafe)] +fn install_test_crypto_provider() { + rginx_http::install_default_crypto_provider(); +} -use fs_paths::{binary_path, read_optional_log, temp_dir}; -use response::{fetch_http_text_response, fetch_https_text_response}; -use tls::InsecureServerCertVerifier; +pub fn reserve_loopback_addr() -> SocketAddr { + http::reserve_loopback_addr() +} + +pub fn reserve_loopback_addr_pair() -> (SocketAddr, SocketAddr) { + http::reserve_loopback_addr_pair() +} + +pub fn read_http_head(stream: &mut TcpStream) -> String { + http::read_http_head(stream) +} + +pub fn connect_http_client(listen_addr: SocketAddr, read_timeout: Duration) -> TcpStream { + http::connect_http_client(listen_addr, read_timeout) +} + +pub fn read_http_head_and_pending(stream: &mut TcpStream) -> (String, Vec) { + http::read_http_head_and_pending(stream) +} + +pub fn read_http_chunk(stream: &mut TcpStream, pending: &mut Vec) -> HttpChunkRead { + http::read_http_chunk(stream, pending) +} + +pub fn spawn_scripted_chunked_response_server( + expected_request_line: &'static str, + first_chunk: &'static [u8], + pause_after_first_chunk: Duration, + second_chunk: Option<&'static [u8]>, +) -> (SocketAddr, JoinHandle<()>) { + http::spawn_scripted_chunked_response_server( + expected_request_line, + first_chunk, + pause_after_first_chunk, + second_chunk, + ) +} + +pub fn apply_tls_placeholders(config: String, cert_path: &Path, key_path: &Path) -> String { + http::apply_tls_placeholders(config, cert_path, key_path) +} diff --git a/crates/rginx-app/tests/support/nginx.rs b/crates/rginx-app/tests/support/nginx.rs index 1cd01125..3304bdc1 100644 --- a/crates/rginx-app/tests/support/nginx.rs +++ b/crates/rginx-app/tests/support/nginx.rs @@ -3,12 +3,49 @@ use super::*; pub struct NginxHarness { child: Child, config_path: PathBuf, - temp_dir: PathBuf, - stdout_path: PathBuf, stderr_path: PathBuf, + stdout_path: PathBuf, + temp_dir: PathBuf, } impl NginxHarness { + pub fn assert_running(&mut self) { + if let Some(status) = self.child.try_wait().expect("nginx child status should read") { + panic!("nginx exited unexpectedly with status {status}\n{}", self.combined_output()); + } + } + + pub fn combined_output(&self) -> String { + let stdout = read_optional_log(&self.stdout_path); + let stderr = read_optional_log(&self.stderr_path); + format!("stdout:\n{stdout}\nstderr:\n{stderr}") + } + + pub fn config_path(&self) -> &Path { + &self.config_path + } + + pub fn is_available() -> bool { + nginx_binary_path().is_some() + } + + #[cfg(unix)] + pub fn send_signal(&self, signal: i32) { + let result = unsafe { libc::kill(self.child.id() as i32, signal) }; + assert!( + result == 0, + "failed to send signal {} to nginx pid {}: {}\n{}", + signal, + self.child.id(), + std::io::Error::last_os_error(), + self.combined_output() + ); + } + + pub fn shutdown_and_wait(&mut self, timeout: Duration) { + self.terminate_and_wait(timeout); + } + pub fn spawn(prefix: &str, build_config: impl FnOnce(&Path) -> String) -> Self { let temp_dir = temp_dir(prefix); fs::create_dir_all(&temp_dir).expect("temp nginx test dir should be created"); @@ -33,19 +70,40 @@ impl NginxHarness { .spawn() .expect("nginx should start"); - Self { child, config_path, temp_dir, stdout_path, stderr_path } + Self { child, config_path, stderr_path, stdout_path, temp_dir } } - pub fn is_available() -> bool { - nginx_binary_path().is_some() + pub fn temp_dir(&self) -> &Path { + &self.temp_dir } - pub fn config_path(&self) -> &Path { - &self.config_path + #[cfg(unix)] + pub fn terminate_and_wait(&mut self, timeout: Duration) { + self.send_signal(libc::SIGTERM); + let status = self.wait_for_exit(timeout); + assert!( + status.success(), + "nginx should exit successfully, got {status}\n{}", + self.combined_output() + ); } - pub fn temp_dir(&self) -> &Path { - &self.temp_dir + pub fn wait_for_exit(&mut self, timeout: Duration) -> ExitStatus { + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); + + loop { + if let Some(status) = self.child.try_wait().expect("nginx child status should read") { + return status; + } + + if Instant::now() >= deadline { + let _ = self.child.kill(); + let _ = self.child.wait(); + panic!("timed out waiting for nginx to exit\n{}", self.combined_output()); + } + + thread::sleep(Duration::from_millis(50)); + } } pub fn wait_for_http_ready(&mut self, listen_addr: SocketAddr, timeout: Duration) { @@ -85,7 +143,7 @@ impl NginxHarness { expected_status: u16, expected_body: &str, ) { - let deadline = Instant::now() + timeout; + let deadline = Instant::now().checked_add(timeout).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -114,65 +172,6 @@ impl NginxHarness { self.combined_output() ); } - - pub fn shutdown_and_wait(&mut self, timeout: Duration) { - self.terminate_and_wait(timeout); - } - - #[cfg(unix)] - pub fn send_signal(&self, signal: i32) { - let result = unsafe { libc::kill(self.child.id() as i32, signal) }; - if result != 0 { - panic!( - "failed to send signal {} to nginx pid {}: {}\n{}", - signal, - self.child.id(), - std::io::Error::last_os_error(), - self.combined_output() - ); - } - } - - #[cfg(unix)] - pub fn terminate_and_wait(&mut self, timeout: Duration) { - self.send_signal(libc::SIGTERM); - let status = self.wait_for_exit(timeout); - assert!( - status.success(), - "nginx should exit successfully, got {status}\n{}", - self.combined_output() - ); - } - - pub fn wait_for_exit(&mut self, timeout: Duration) -> ExitStatus { - let deadline = Instant::now() + timeout; - - loop { - if let Some(status) = self.child.try_wait().expect("nginx child status should read") { - return status; - } - - if Instant::now() >= deadline { - let _ = self.child.kill(); - let _ = self.child.wait(); - panic!("timed out waiting for nginx to exit\n{}", self.combined_output()); - } - - thread::sleep(Duration::from_millis(50)); - } - } - - pub fn assert_running(&mut self) { - if let Some(status) = self.child.try_wait().expect("nginx child status should read") { - panic!("nginx exited unexpectedly with status {status}\n{}", self.combined_output()); - } - } - - pub fn combined_output(&self) -> String { - let stdout = read_optional_log(&self.stdout_path); - let stderr = read_optional_log(&self.stderr_path); - format!("stdout:\n{stdout}\nstderr:\n{stderr}") - } } impl Drop for NginxHarness { diff --git a/crates/rginx-app/tests/support/response.rs b/crates/rginx-app/tests/support/response.rs index 43194a56..501d3ac3 100644 --- a/crates/rginx-app/tests/support/response.rs +++ b/crates/rginx-app/tests/support/response.rs @@ -60,7 +60,7 @@ fn read_text_response_from_stream(stream: &mut impl Read) -> Result<(u16, String } buffer.extend_from_slice(&chunk[..read]); if let Some(position) = buffer.windows(4).position(|window| window == b"\r\n\r\n") { - break position + 4; + break position.saturating_add(4); } }; @@ -150,12 +150,13 @@ fn decode_chunked_body(stream: &mut impl Read, mut buffer: Vec) -> Result) -> Result) -> Result Vec { + self.supported_schemes.clone() + } fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, @@ -44,8 +47,4 @@ impl ServerCertVerifier for InsecureServerCertVerifier { ) -> Result { Ok(HandshakeSignatureValid::assertion()) } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } } diff --git a/crates/rginx-app/tests/tls_policy.rs b/crates/rginx-app/tests/tls_policy.rs index c038df54..e671a147 100644 --- a/crates/rginx-app/tests/tls_policy.rs +++ b/crates/rginx-app/tests/tls_policy.rs @@ -1,3 +1,14 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream}; use std::sync::Arc; @@ -11,13 +22,79 @@ use rustls::{ StreamOwned, }; -mod support; - use support::{ServerHarness, apply_tls_placeholders, reserve_loopback_addr}; const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4E\nFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2zn\nyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngq\np7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gp\nqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+\nyfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6\nJrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5\nx23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59\nCiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; const TEST_SERVER_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/GfUhio45GzZi\nz/o67ngbdqW77Bk54aQ/L9kthS47SkqHVaC8+/AYXfDddl0Q+EHQleGrCqjPrRLA\nZAm0Qz0YZEdWcFX4F4T/gvVc7gfobPmkwoVoJwk/A7giuymN9+kpDbXsr2X29uMr\nfW0krA4ZpD/+JzLMXXfG5S9mazrs81wGqFlByGV32z3Qiecvd8JL4nnyz2mMTSqV\nOvtuh285dpgHFugOUGHhhXs9b/CPAGiBTBxuIm+2tGdjtmVGL8L9tlRMYnmQ7u9K\nUCAyHYovUPJjZazTqfkpQ2CkC4iOZ7Ay25gn0fbGBEysyOBZIPw39FEW63jGembu\nOVF6nO/DAgMBAAECggEAKLC7v80TVHiFX4veQZ8WRu7AAmAWzPrNMMEc8rLZcblz\nXhau956DdITILTevQFZEGUhYuUU3RaUaCYojgNUSVLfBctfPjlhfstItMYDjgSt3\nCox6wH8TWm4NzqNgiUCgzmODeaatROUz4MY/r5/NDsuo7pJlIBvEzb5uFdY+QUZ/\nR5gHRiD2Q3wCODe8zQRfTZGo7jCimAuWTLurWZl6ax/4TjWbXCD6DTuUo81cW3vy\nne6tEetHcABRO7uDoBYXk12pCgqFZzjLMnKJjQM+OYnSj6DoWjOu1drT5YyRLGDj\nfzN8V0aKRkOYoZ5QZOua8pByOyQElJnM16vkPtHgPQKBgQD6SOUNWEghvYIGM/lx\nc22/zjvDjeaGC3qSmlpQYN5MGuDoszeDBZ+rMTmHqJ9FcHYkLQnUI7ZkHhRGt/wQ\n/w3CroJjPBgKk+ipy2cBHSI+z+U20xjYzE8hxArWbXG1G4rDt5AIz68IQPsfkVND\nktkDABDaU+KwBPx8fjeeqtRQxQKBgQDDdxdLB1XcfZMX0KEP5RfA8ar1nW41TUAl\nTCOLaXIQbHZ0BeW7USE9mK8OKnVALZGJ+rpxvYFPZ5MWxchpb/cuIwXjLoN6uZVb\nfx4Hho+2iCfhcEKzs8XZW48duKIfhx13BiILLf/YaHAWFs9UfVcQog4Qx03guyMr\n7k9bFuy25wKBgQDpE48zAT6TJS775dTrAQp4b28aan/93pyz/8gRSFRb3UALlDIi\n8s7BluKzYaWI/fUXNVYM14EX9Sb+wIGdtlezL94+2Yyt9RXbYY8361Cj2+jiSG3A\nH2ulzzIkg+E7Pj3Yi443lmiysAjsWeKHcC5l697F4w6cytfye3wCZ6W23QKBgQC0\n9tX+5aytdSkwnDvxXlVOka+ItBcri/i+Ty59TMOIxxInuqoFcUhIIcq4X8CsCUQ8\nLYBd+2fznt3D8JrqWvnKoiw6N38MqTLJQfgIWaFGCep6QhfPDbo30RfAGYcnj01N\nO8Va+lxq+84B9V5AR8bKpG5HRG4qiLc4XerkV2YSswKBgDt9eerSBZyLVwfku25Y\nfrh+nEjUZy81LdlpJmu/bfa2FfItzBqDZPskkJJW9ON82z/ejGFbsU48RF7PJUMr\nGimE33QeTDToGozHCq0QOd0SMfsVkOQR+EROdmY52UIYAYgQUfI1FQ9lLsw10wlQ\nD11SHTL7b9pefBWfW73I7ttV\n-----END PRIVATE KEY-----\n"; +struct TlsPolicyResponse { + alpn_protocol: Option, + body: String, + protocol_version: Option, + status: u16, +} + +struct TlsPolicyCase<'a> { + alpn_protocols: &'a [&'a [u8]], + enable_sni: bool, + expected_alpn: Option<&'a str>, + expected_body: &'a str, + expected_protocol_version: Option, + expected_status: u16, + host: &'a str, + path: &'a str, + server_name: &'a str, +} + +#[derive(Debug)] +struct InsecureServerCertVerifier { + supported_schemes: Vec, +} + +impl InsecureServerCertVerifier { + fn new() -> Self { + Self { + supported_schemes: rustls::crypto::aws_lc_rs::default_provider() + .signature_verification_algorithms + .supported_schemes(), + } + } +} + +impl ServerCertVerifier for InsecureServerCertVerifier { + fn supported_verify_schemes(&self) -> Vec { + self.supported_schemes.clone() + } + fn verify_server_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _server_name: &ServerName<'_>, + _ocsp_response: &[u8], + _now: UnixTime, + ) -> Result { + Ok(ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(HandshakeSignatureValid::assertion()) + } +} + #[test] fn tls12_only_listener_negotiates_tls12() { let listen_addr = reserve_loopback_addr(); @@ -115,31 +192,13 @@ fn default_certificate_supports_sniless_clients_with_multiple_vhost_certs() { server.shutdown_and_wait(Duration::from_secs(5)); } -struct TlsPolicyResponse { - status: u16, - body: String, - protocol_version: Option, - alpn_protocol: Option, -} - -struct TlsPolicyCase<'a> { - host: &'a str, - path: &'a str, - server_name: &'a str, - enable_sni: bool, - alpn_protocols: &'a [&'a [u8]], - expected_status: u16, - expected_body: &'a str, - expected_protocol_version: Option, - expected_alpn: Option<&'a str>, -} - fn wait_for_https_policy_response( server: &mut ServerHarness, listen_addr: SocketAddr, case: TlsPolicyCase<'_>, ) { - let deadline = Instant::now() + Duration::from_secs(5); + let deadline = + Instant::now().checked_add(Duration::from_secs(5)).expect("test deadline remains valid"); let mut last_error = String::new(); while Instant::now() < deadline { @@ -263,53 +322,3 @@ fn sniless_default_certificate_config(listen_addr: SocketAddr) -> String { listen_addr.to_string(), ) } - -#[derive(Debug)] -struct InsecureServerCertVerifier { - supported_schemes: Vec, -} - -impl InsecureServerCertVerifier { - fn new() -> Self { - Self { - supported_schemes: rustls::crypto::aws_lc_rs::default_provider() - .signature_verification_algorithms - .supported_schemes(), - } - } -} - -impl ServerCertVerifier for InsecureServerCertVerifier { - fn verify_server_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _server_name: &ServerName<'_>, - _ocsp_response: &[u8], - _now: UnixTime, - ) -> Result { - Ok(ServerCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } -} diff --git a/crates/rginx-app/tests/upgrade.rs b/crates/rginx-app/tests/upgrade.rs index 86cc2ec5..527dc480 100644 --- a/crates/rginx-app/tests/upgrade.rs +++ b/crates/rginx-app/tests/upgrade.rs @@ -1,12 +1,14 @@ -use std::io::{self, Read, Write}; -use std::net::{SocketAddr, TcpListener, TcpStream}; -use std::thread; -use std::time::{Duration, Instant}; - +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] mod support; -use support::{READY_ROUTE_CONFIG, ServerHarness, read_http_head, reserve_loopback_addr}; - #[path = "upgrade/max_connections.rs"] mod max_connections; #[path = "upgrade/proxy_flow.rs"] @@ -14,6 +16,13 @@ mod proxy_flow; #[path = "upgrade/reload_flow.rs"] mod reload_flow; +use std::io::{self, Read, Write}; +use std::net::{SocketAddr, TcpListener, TcpStream}; +use std::thread; +use std::time::{Duration, Instant}; + +use support::{READY_ROUTE_CONFIG, ServerHarness, read_http_head, reserve_loopback_addr}; + fn upgrade_proxy_config(listen_addr: SocketAddr, upstream_addr: SocketAddr) -> String { upgrade_proxy_config_with_server_extra(listen_addr, upstream_addr, None) } @@ -66,7 +75,7 @@ fn try_read_http_head(stream: &mut TcpStream) -> io::Result { buffer.extend_from_slice(&chunk[..read]); if let Some(head_end) = buffer.windows(4).position(|window| window == b"\r\n\r\n") { - return String::from_utf8(buffer[..head_end + 4].to_vec()) + return String::from_utf8(buffer[..head_end.saturating_add(4)].to_vec()) .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error)); } } diff --git a/crates/rginx-app/tests/upstream_http2.rs b/crates/rginx-app/tests/upstream_http2.rs index cebe3a53..11c1b866 100644 --- a/crates/rginx-app/tests/upstream_http2.rs +++ b/crates/rginx-app/tests/upstream_http2.rs @@ -1,3 +1,14 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::convert::Infallible; use std::env; use std::fs; @@ -22,8 +33,6 @@ use tokio::sync::oneshot; use tokio::task::JoinHandle; use tokio_rustls::TlsAcceptor; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4E\nFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2zn\nyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngq\np7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gp\nqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+\nyfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6\nJrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5\nx23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59\nCiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; @@ -31,11 +40,11 @@ const TEST_SERVER_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkq #[derive(Debug, Clone, PartialEq, Eq)] struct ObservedRequest { - version: Version, - path: String, + alpn_protocol: Option, authority: Option, host: Option, - alpn_protocol: Option, + path: String, + version: Version, } #[tokio::test(flavor = "multi_thread")] diff --git a/crates/rginx-app/tests/upstream_http3.rs b/crates/rginx-app/tests/upstream_http3.rs index d42438b1..fa552852 100644 --- a/crates/rginx-app/tests/upstream_http3.rs +++ b/crates/rginx-app/tests/upstream_http3.rs @@ -1,3 +1,14 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::env; use std::fmt; use std::fs; @@ -26,16 +37,105 @@ use rustls::{ use tokio::sync::oneshot; use tokio::task::JoinHandle; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[derive(Debug, Clone, PartialEq, Eq)] struct ObservedHttp3Request { - sni: Option, + path: String, peer_certificates_present: bool, protocol_version: Option, - path: String, + sni: Option, +} + +struct CapturingResolver { + certified_key: Arc, + observed_sni: Arc>>, +} + +impl fmt::Debug for CapturingResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CapturingResolver").finish_non_exhaustive() + } +} + +impl ResolvesServerCert for CapturingResolver { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + if let Some(server_name) = client_hello.server_name() { + *self.observed_sni.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) = + Some(server_name.to_string()); + } + Some(self.certified_key.clone()) + } +} + +#[derive(Debug)] +struct RequireTrustedClientCertVerifier { + client_cert_seen: Arc, + inner: Arc, +} + +impl RequireTrustedClientCertVerifier { + fn new(ca_path: &Path, client_cert_seen: Arc) -> Self { + let mut roots = RootCertStore::empty(); + for cert in load_certs(ca_path) { + roots.add(cert).expect("client CA certificate should load into root store"); + } + let inner = WebPkiClientVerifier::builder(roots.into()) + .build() + .expect("client verifier should build"); + Self { inner, client_cert_seen } + } +} + +impl ClientCertVerifier for RequireTrustedClientCertVerifier { + fn root_hint_subjects(&self) -> &[DistinguishedName] { + self.inner.root_hint_subjects() + } + + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } + + fn verify_client_cert( + &self, + end_entity: &CertificateDer<'_>, + intermediates: &[CertificateDer<'_>], + now: rustls::pki_types::UnixTime, + ) -> Result { + let verified = self.inner.verify_client_cert(end_entity, intermediates, now)?; + self.client_cert_seen.store(true, Ordering::Relaxed); + Ok(verified) + } + + fn verify_tls12_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( + &self, + message: &[u8], + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { + self.inner.verify_tls13_signature(message, cert, dss) + } +} + +struct TestCertifiedKey { + cert: rcgen::Certificate, + params: CertificateParams, + signing_key: KeyPair, +} + +impl TestCertifiedKey { + fn issuer(&self) -> Issuer<'_, &KeyPair> { + Issuer::from_params(&self.params, &self.signing_key) + } } #[tokio::test(flavor = "multi_thread")] @@ -269,97 +369,6 @@ fn temp_dir(prefix: &str) -> PathBuf { env::temp_dir().join(format!("{prefix}-{unique}-{id}")) } -struct CapturingResolver { - certified_key: Arc, - observed_sni: Arc>>, -} - -impl fmt::Debug for CapturingResolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("CapturingResolver").finish_non_exhaustive() - } -} - -impl ResolvesServerCert for CapturingResolver { - fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { - if let Some(server_name) = client_hello.server_name() { - *self.observed_sni.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) = - Some(server_name.to_string()); - } - Some(self.certified_key.clone()) - } -} - -#[derive(Debug)] -struct RequireTrustedClientCertVerifier { - inner: Arc, - client_cert_seen: Arc, -} - -impl RequireTrustedClientCertVerifier { - fn new(ca_path: &Path, client_cert_seen: Arc) -> Self { - let mut roots = RootCertStore::empty(); - for cert in load_certs(ca_path) { - roots.add(cert).expect("client CA certificate should load into root store"); - } - let inner = WebPkiClientVerifier::builder(roots.into()) - .build() - .expect("client verifier should build"); - Self { inner, client_cert_seen } - } -} - -impl ClientCertVerifier for RequireTrustedClientCertVerifier { - fn root_hint_subjects(&self) -> &[DistinguishedName] { - self.inner.root_hint_subjects() - } - - fn verify_client_cert( - &self, - end_entity: &CertificateDer<'_>, - intermediates: &[CertificateDer<'_>], - now: rustls::pki_types::UnixTime, - ) -> Result { - let verified = self.inner.verify_client_cert(end_entity, intermediates, now)?; - self.client_cert_seen.store(true, Ordering::Relaxed); - Ok(verified) - } - - fn verify_tls12_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - self.inner.verify_tls12_signature(message, cert, dss) - } - - fn verify_tls13_signature( - &self, - message: &[u8], - cert: &CertificateDer<'_>, - dss: &DigitallySignedStruct, - ) -> Result { - self.inner.verify_tls13_signature(message, cert, dss) - } - - fn supported_verify_schemes(&self) -> Vec { - self.inner.supported_verify_schemes() - } -} - -struct TestCertifiedKey { - cert: rcgen::Certificate, - signing_key: KeyPair, - params: CertificateParams, -} - -impl TestCertifiedKey { - fn issuer(&self) -> Issuer<'_, &KeyPair> { - Issuer::from_params(&self.params, &self.signing_key) - } -} - fn generate_ca_cert(common_name: &str) -> TestCertifiedKey { let mut params = CertificateParams::new(vec![common_name.to_string()]).expect("CA params should build"); diff --git a/crates/rginx-app/tests/upstream_mtls.rs b/crates/rginx-app/tests/upstream_mtls.rs index d71e5b7e..2b3092a0 100644 --- a/crates/rginx-app/tests/upstream_mtls.rs +++ b/crates/rginx-app/tests/upstream_mtls.rs @@ -1,3 +1,14 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::convert::Infallible; use std::env; use std::fs; @@ -23,8 +34,6 @@ use tokio::sync::oneshot; use tokio::task::JoinHandle; use tokio_rustls::TlsAcceptor; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4EFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2znyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngqp7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gpqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+yfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6JrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5x23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59CiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; @@ -32,9 +41,63 @@ const TEST_SERVER_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkq #[derive(Debug, Clone, PartialEq, Eq)] struct ObservedMtlsRequest { + host: Option, peer_certificates_present: bool, protocol_version: Option, - host: Option, +} + +#[derive(Debug)] +struct RequireAnyClientCertVerifier { + root_hints: Vec, + supported_schemes: Vec, +} + +impl RequireAnyClientCertVerifier { + fn new() -> Self { + Self { + root_hints: Vec::new(), + supported_schemes: rustls::crypto::aws_lc_rs::default_provider() + .signature_verification_algorithms + .supported_schemes(), + } + } +} + +impl ClientCertVerifier for RequireAnyClientCertVerifier { + fn root_hint_subjects(&self) -> &[DistinguishedName] { + &self.root_hints + } + + fn supported_verify_schemes(&self) -> Vec { + self.supported_schemes.clone() + } + + fn verify_client_cert( + &self, + _end_entity: &CertificateDer<'_>, + _intermediates: &[CertificateDer<'_>], + _now: rustls::pki_types::UnixTime, + ) -> Result { + Ok(ClientCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &CertificateDer<'_>, + _dss: &DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } } #[tokio::test(flavor = "multi_thread")] @@ -192,57 +255,3 @@ fn temp_dir(prefix: &str) -> PathBuf { let id = NEXT_ID.fetch_add(1, Ordering::Relaxed); env::temp_dir().join(format!("{prefix}-{unique}-{id}")) } - -#[derive(Debug)] -struct RequireAnyClientCertVerifier { - root_hints: Vec, - supported_schemes: Vec, -} - -impl RequireAnyClientCertVerifier { - fn new() -> Self { - Self { - root_hints: Vec::new(), - supported_schemes: rustls::crypto::aws_lc_rs::default_provider() - .signature_verification_algorithms - .supported_schemes(), - } - } -} - -impl ClientCertVerifier for RequireAnyClientCertVerifier { - fn root_hint_subjects(&self) -> &[DistinguishedName] { - &self.root_hints - } - - fn verify_client_cert( - &self, - _end_entity: &CertificateDer<'_>, - _intermediates: &[CertificateDer<'_>], - _now: rustls::pki_types::UnixTime, - ) -> Result { - Ok(ClientCertVerified::assertion()) - } - - fn verify_tls12_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn verify_tls13_signature( - &self, - _message: &[u8], - _cert: &CertificateDer<'_>, - _dss: &DigitallySignedStruct, - ) -> Result { - Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) - } - - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } -} diff --git a/crates/rginx-app/tests/upstream_server_name.rs b/crates/rginx-app/tests/upstream_server_name.rs index 1e8ee785..70a0c406 100644 --- a/crates/rginx-app/tests/upstream_server_name.rs +++ b/crates/rginx-app/tests/upstream_server_name.rs @@ -1,3 +1,14 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] +mod support; + use std::convert::Infallible; use std::env; use std::fmt; @@ -24,13 +35,33 @@ use tokio::sync::oneshot; use tokio::task::JoinHandle; use tokio_rustls::TlsAcceptor; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4E\nFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2zn\nyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngq\np7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gp\nqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+\nyfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6\nJrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5\nx23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59\nCiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; const TEST_SERVER_KEY_PEM: &str = "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC/GfUhio45GzZi\nz/o67ngbdqW77Bk54aQ/L9kthS47SkqHVaC8+/AYXfDddl0Q+EHQleGrCqjPrRLA\nZAm0Qz0YZEdWcFX4F4T/gvVc7gfobPmkwoVoJwk/A7giuymN9+kpDbXsr2X29uMr\nfW0krA4ZpD/+JzLMXXfG5S9mazrs81wGqFlByGV32z3Qiecvd8JL4nnyz2mMTSqV\nOvtuh285dpgHFugOUGHhhXs9b/CPAGiBTBxuIm+2tGdjtmVGL8L9tlRMYnmQ7u9K\nUCAyHYovUPJjZazTqfkpQ2CkC4iOZ7Ay25gn0fbGBEysyOBZIPw39FEW63jGembu\nOVF6nO/DAgMBAAECggEAKLC7v80TVHiFX4veQZ8WRu7AAmAWzPrNMMEc8rLZcblz\nXhau956DdITILTevQFZEGUhYuUU3RaUaCYojgNUSVLfBctfPjlhfstItMYDjgSt3\nCox6wH8TWm4NzqNgiUCgzmODeaatROUz4MY/r5/NDsuo7pJlIBvEzb5uFdY+QUZ/\nR5gHRiD2Q3wCODe8zQRfTZGo7jCimAuWTLurWZl6ax/4TjWbXCD6DTuUo81cW3vy\nne6tEetHcABRO7uDoBYXk12pCgqFZzjLMnKJjQM+OYnSj6DoWjOu1drT5YyRLGDj\nfzN8V0aKRkOYoZ5QZOua8pByOyQElJnM16vkPtHgPQKBgQD6SOUNWEghvYIGM/lx\nc22/zjvDjeaGC3qSmlpQYN5MGuDoszeDBZ+rMTmHqJ9FcHYkLQnUI7ZkHhRGt/wQ\n/w3CroJjPBgKk+ipy2cBHSI+z+U20xjYzE8hxArWbXG1G4rDt5AIz68IQPsfkVND\nktkDABDaU+KwBPx8fjeeqtRQxQKBgQDDdxdLB1XcfZMX0KEP5RfA8ar1nW41TUAl\nTCOLaXIQbHZ0BeW7USE9mK8OKnVALZGJ+rpxvYFPZ5MWxchpb/cuIwXjLoN6uZVb\nfx4Hho+2iCfhcEKzs8XZW48duKIfhx13BiILLf/YaHAWFs9UfVcQog4Qx03guyMr\n7k9bFuy25wKBgQDpE48zAT6TJS775dTrAQp4b28aan/93pyz/8gRSFRb3UALlDIi\n8s7BluKzYaWI/fUXNVYM14EX9Sb+wIGdtlezL94+2Yyt9RXbYY8361Cj2+jiSG3A\nH2ulzzIkg+E7Pj3Yi443lmiysAjsWeKHcC5l697F4w6cytfye3wCZ6W23QKBgQC0\n9tX+5aytdSkwnDvxXlVOka+ItBcri/i+Ty59TMOIxxInuqoFcUhIIcq4X8CsCUQ8\nLYBd+2fznt3D8JrqWvnKoiw6N38MqTLJQfgIWaFGCep6QhfPDbo30RfAGYcnj01N\nO8Va+lxq+84B9V5AR8bKpG5HRG4qiLc4XerkV2YSswKBgDt9eerSBZyLVwfku25Y\nfrh+nEjUZy81LdlpJmu/bfa2FfItzBqDZPskkJJW9ON82z/ejGFbsU48RF7PJUMr\nGimE33QeTDToGozHCq0QOd0SMfsVkOQR+EROdmY52UIYAYgQUfI1FQ9lLsw10wlQ\nD11SHTL7b9pefBWfW73I7ttV\n-----END PRIVATE KEY-----\n"; +struct CapturingResolver { + certified_key: Arc, + observed_sni: Arc>>>>, +} + +impl fmt::Debug for CapturingResolver { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("CapturingResolver").finish_non_exhaustive() + } +} + +impl ResolvesServerCert for CapturingResolver { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + if let Some(sender) = + self.observed_sni.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).take() + { + let _ = sender.send(client_hello.server_name().map(str::to_string)); + } + Some(self.certified_key.clone()) + } +} + #[tokio::test(flavor = "multi_thread")] async fn upstream_server_name_override_sets_sni() { let (upstream_addr, observed_rx, upstream_task, upstream_temp_dir) = @@ -188,25 +219,3 @@ fn temp_dir(prefix: &str) -> PathBuf { let id = NEXT_ID.fetch_add(1, Ordering::Relaxed); env::temp_dir().join(format!("{prefix}-{unique}-{id}")) } - -struct CapturingResolver { - certified_key: Arc, - observed_sni: Arc>>>>, -} - -impl fmt::Debug for CapturingResolver { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("CapturingResolver").finish_non_exhaustive() - } -} - -impl ResolvesServerCert for CapturingResolver { - fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { - if let Some(sender) = - self.observed_sni.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).take() - { - let _ = sender.send(client_hello.server_name().map(str::to_string)); - } - Some(self.certified_key.clone()) - } -} diff --git a/crates/rginx-app/tests/vhost.rs b/crates/rginx-app/tests/vhost.rs index f6f8a0cd..5733bf19 100644 --- a/crates/rginx-app/tests/vhost.rs +++ b/crates/rginx-app/tests/vhost.rs @@ -1,8 +1,17 @@ -use std::net::SocketAddr; -use std::time::Duration; +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] mod support; +use std::net::SocketAddr; +use std::time::Duration; + use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] diff --git a/crates/rginx-app/tests/weighted_round_robin.rs b/crates/rginx-app/tests/weighted_round_robin.rs index a49ef0d4..3bd1d565 100644 --- a/crates/rginx-app/tests/weighted_round_robin.rs +++ b/crates/rginx-app/tests/weighted_round_robin.rs @@ -1,10 +1,19 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; #[test] diff --git a/crates/rginx-app/tests/workers.rs b/crates/rginx-app/tests/workers.rs index 1ee837b5..8fe722d5 100644 --- a/crates/rginx-app/tests/workers.rs +++ b/crates/rginx-app/tests/workers.rs @@ -1,13 +1,28 @@ +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] + +mod support; + use std::io::{Read, Write}; use std::net::{SocketAddr, TcpStream}; use std::sync::{Arc, Barrier}; use std::thread; use std::time::Duration; -mod support; - use support::{READY_ROUTE_CONFIG, ServerHarness, reserve_loopback_addr}; +#[derive(Debug)] +struct ParsedResponse { + body: Vec, + status: u16, +} + #[test] fn serves_requests_with_configured_runtime_and_accept_workers() { let listen_addr = reserve_loopback_addr(); @@ -39,12 +54,6 @@ fn serves_requests_with_configured_runtime_and_accept_workers() { server.shutdown_and_wait(Duration::from_secs(5)); } -#[derive(Debug)] -struct ParsedResponse { - status: u16, - body: Vec, -} - fn send_http_request(listen_addr: SocketAddr, request: &str) -> Result { let mut stream = TcpStream::connect_timeout(&listen_addr, Duration::from_millis(500)) .map_err(|error| format!("failed to connect to {listen_addr}: {error}"))?; @@ -87,7 +96,8 @@ fn parse_http_response(bytes: &[u8]) -> Result { }; } - Ok(ParsedResponse { status, body: bytes[head_end + 4..].to_vec() }) + let body_start = head_end.saturating_add(4); + Ok(ParsedResponse { status, body: bytes[body_start..].to_vec() }) } fn return_config(listen_addr: SocketAddr) -> String { diff --git a/crates/rginx-config/Cargo.toml b/crates/rginx-config/Cargo.toml index 0e57b687..7838a552 100644 --- a/crates/rginx-config/Cargo.toml +++ b/crates/rginx-config/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true rust-version.workspace = true description = "Configuration loading, validation, and compilation for rginx." +[lints] +workspace = true + [dependencies] http.workspace = true ipnet.workspace = true diff --git a/crates/rginx-config/src/compile/acme.rs b/crates/rginx-config/src/compile/acme.rs index 49b76992..65ce38e2 100644 --- a/crates/rginx-config/src/compile/acme.rs +++ b/crates/rginx-config/src/compile/acme.rs @@ -17,7 +17,7 @@ pub(super) fn compile_global_acme( let contacts = acme.contacts.into_iter().map(|contact| contact.trim().to_string()).collect(); let state_dir = super::resolve_path(base_dir, acme.state_dir); let renew_before = Duration::from_secs( - acme.renew_before_days.unwrap_or(DEFAULT_ACME_RENEW_BEFORE_DAYS) * 86_400, + acme.renew_before_days.unwrap_or(DEFAULT_ACME_RENEW_BEFORE_DAYS).saturating_mul(86_400), ); let poll_interval = Duration::from_secs(acme.poll_interval_secs.unwrap_or(DEFAULT_ACME_POLL_INTERVAL_SECS)); diff --git a/crates/rginx-config/src/compile/cache.rs b/crates/rginx-config/src/compile/cache.rs index 8e1ead0f..517c9fb6 100644 --- a/crates/rginx-config/src/compile/cache.rs +++ b/crates/rginx-config/src/compile/cache.rs @@ -44,9 +44,9 @@ pub(super) fn compile_cache_zones( let default_max_entry_bytes = usize_from_u64(DEFAULT_CACHE_MAX_ENTRY_BYTES)?; let max_entry_bytes = match zone.max_entry_bytes { Some(value) => usize_from_u64(value)?, - None => max_size_bytes - .map(|max_size| default_max_entry_bytes.min(max_size)) - .unwrap_or(default_max_entry_bytes), + None => max_size_bytes.map_or(default_max_entry_bytes, |max_size| { + default_max_entry_bytes.min(max_size) + }), }; if let Some(max_size_bytes) = max_size_bytes && max_entry_bytes > max_size_bytes @@ -181,8 +181,7 @@ pub(super) fn compile_route_cache( ), range_requests: cache .range_requests - .map(compile_range_request_policy) - .unwrap_or(CacheRangeRequestPolicy::Bypass), + .map_or(CacheRangeRequestPolicy::Bypass, compile_range_request_policy), slice_size_bytes: cache.slice_size_bytes, convert_head: cache.convert_head.unwrap_or(DEFAULT_CACHE_CONVERT_HEAD), }) diff --git a/crates/rginx-config/src/compile/mod.rs b/crates/rginx-config/src/compile/mod.rs index 2edcbf6b..c74f2017 100644 --- a/crates/rginx-config/src/compile/mod.rs +++ b/crates/rginx-config/src/compile/mod.rs @@ -1,12 +1,3 @@ -use std::path::Path; -use std::{collections::HashSet, path::PathBuf}; - -use crate::model::Config; -use crate::model::VirtualHostConfig; -use rginx_core::{ConfigSnapshot, Result, VirtualHost}; - -use crate::validate::validate; - mod acme; mod agent; mod cache; @@ -18,6 +9,19 @@ mod server; mod upstream; mod vhost; +#[cfg(test)] +mod tests; +use std::path::Path; +use std::{collections::HashSet, path::PathBuf}; + +use crate::model::Config; +use crate::model::VirtualHostConfig; +use rginx_core::{ConfigSnapshot, Result, VirtualHost}; + +use crate::validate::validate; + +use path::resolve_path; + const DEFAULT_UPSTREAM_REQUEST_TIMEOUT_SECS: u64 = 30; const DEFAULT_UPSTREAM_CONNECT_TIMEOUT_SECS: u64 = DEFAULT_UPSTREAM_REQUEST_TIMEOUT_SECS; const DEFAULT_UPSTREAM_WRITE_TIMEOUT_SECS: u64 = DEFAULT_UPSTREAM_REQUEST_TIMEOUT_SECS; @@ -39,8 +43,6 @@ const DEFAULT_UPSTREAM_DNS_REFRESH_BEFORE_EXPIRY_SECS: u64 = 10; const DEFAULT_VHOST_ID: &str = "server"; const DEFAULT_GRPC_HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check"; -use path::resolve_path; - #[derive(Debug, Clone, Copy, Default)] pub struct CompileOptions { pub allow_missing_managed_tls_identity: bool, @@ -180,5 +182,3 @@ fn collect_managed_identity_pairs( }) .collect() } -#[cfg(test)] -mod tests; diff --git a/crates/rginx-config/src/compile/route.rs b/crates/rginx-config/src/compile/route.rs index c70acacf..01040188 100644 --- a/crates/rginx-config/src/compile/route.rs +++ b/crates/rginx-config/src/compile/route.rs @@ -1,3 +1,6 @@ +mod action; +mod file; + use std::collections::HashMap; use std::path::Path; use std::sync::Arc; @@ -16,9 +19,6 @@ use crate::model::{ RouteBufferingPolicyConfig, RouteCompressionPolicyConfig, TryFileStepConfig, }; -mod action; -mod file; - use action::compile_route_action; use file::file_route_prefix_from_config; diff --git a/crates/rginx-config/src/compile/server.rs b/crates/rginx-config/src/compile/server.rs index 92df372f..8c534dc6 100644 --- a/crates/rginx-config/src/compile/server.rs +++ b/crates/rginx-config/src/compile/server.rs @@ -1,3 +1,9 @@ +mod fields; +mod http3; +mod listener; +mod listener_managed_identity; +mod tls; + use std::collections::HashSet; use std::path::{Path, PathBuf}; @@ -5,12 +11,6 @@ use rginx_core::{Listener, Result, VirtualHostTls}; use crate::model::{ListenerConfig, ServerConfig, VirtualHostConfig, VirtualHostTlsConfig}; -mod fields; -mod http3; -mod listener; -mod listener_managed_identity; -mod tls; - #[cfg(test)] pub(super) const DEFAULT_HTTP3_MAX_CONCURRENT_STREAMS: usize = http3::DEFAULT_HTTP3_MAX_CONCURRENT_STREAMS; diff --git a/crates/rginx-config/src/compile/server/fields.rs b/crates/rginx-config/src/compile/server/fields.rs index 8f92f226..a3bb44f3 100644 --- a/crates/rginx-config/src/compile/server/fields.rs +++ b/crates/rginx-config/src/compile/server/fields.rs @@ -18,22 +18,22 @@ pub(super) struct CompiledServerFields { } pub(super) struct ServerFieldConfig { - pub(super) listen: String, - pub(super) server_header: Option, - pub(super) default_certificate: Option, - pub(super) trusted_proxies: Vec, + pub(super) access_log_format: Option, + pub(super) allow_missing_tls_identity: bool, pub(super) client_ip_header: Option, + pub(super) default_certificate: Option, + pub(super) header_read_timeout_secs: Option, + pub(super) http1: Option, pub(super) keep_alive: Option, + pub(super) listen: String, + pub(super) max_connections: Option, pub(super) max_headers: Option, pub(super) max_request_body_bytes: Option, - pub(super) max_connections: Option, - pub(super) header_read_timeout_secs: Option, pub(super) request_body_read_timeout_secs: Option, pub(super) response_write_timeout_secs: Option, - pub(super) access_log_format: Option, + pub(super) server_header: Option, pub(super) tls: Option, - pub(super) http1: Option, - pub(super) allow_missing_tls_identity: bool, + pub(super) trusted_proxies: Vec, } pub(super) fn compile_server_fields( diff --git a/crates/rginx-config/src/compile/server/listener.rs b/crates/rginx-config/src/compile/server/listener.rs index 10a4d2cf..b659b7d2 100644 --- a/crates/rginx-config/src/compile/server/listener.rs +++ b/crates/rginx-config/src/compile/server/listener.rs @@ -1,3 +1,6 @@ +mod ids; +mod vhost_binding; + use std::collections::{BTreeMap, HashSet, btree_map::Entry}; use std::net::SocketAddr; use std::path::{Path, PathBuf}; @@ -10,8 +13,6 @@ use super::CompiledServer; use super::fields::{ServerFieldConfig, compile_server_fields}; use super::http3::compile_http3; use super::listener_managed_identity::tls_identity_is_managed; -mod ids; -mod vhost_binding; use self::ids::{explicit_listener_id, vhost_listener_id, vhost_listener_name}; use self::vhost_binding::{VhostListenerBinding, validate_vhost_listener_merge}; @@ -51,22 +52,22 @@ pub(super) fn compile_legacy_server( .is_some_and(|tls| tls_identity_is_managed(tls, base_dir, managed_identity_pairs)); let compiled = compile_server_fields( ServerFieldConfig { - listen, - server_header, - default_certificate, - trusted_proxies, + access_log_format, + allow_missing_tls_identity, client_ip_header, + default_certificate, + header_read_timeout_secs, + http1, keep_alive, + listen, + max_connections, max_headers, max_request_body_bytes, - max_connections, - header_read_timeout_secs, request_body_read_timeout_secs, response_write_timeout_secs, - access_log_format, + server_header, tls, - http1, - allow_missing_tls_identity, + trusted_proxies, }, base_dir, )?; diff --git a/crates/rginx-config/src/compile/server/listener/vhost_binding.rs b/crates/rginx-config/src/compile/server/listener/vhost_binding.rs index dfee24ab..00ec5694 100644 --- a/crates/rginx-config/src/compile/server/listener/vhost_binding.rs +++ b/crates/rginx-config/src/compile/server/listener/vhost_binding.rs @@ -5,12 +5,12 @@ use rginx_core::{Error, Result}; use crate::model::Http3Config; pub(super) struct VhostListenerBinding { - pub(super) ssl: bool, - pub(super) http3: Option, - pub(super) proxy_protocol: bool, pub(super) default_certificate: Option, pub(super) default_server: bool, + pub(super) http3: Option, + pub(super) proxy_protocol: bool, pub(super) reuse_port: bool, + pub(super) ssl: bool, } pub(super) fn validate_vhost_listener_merge( diff --git a/crates/rginx-config/src/compile/server/tls.rs b/crates/rginx-config/src/compile/server/tls.rs index d0b7f70a..6a58cf4f 100644 --- a/crates/rginx-config/src/compile/server/tls.rs +++ b/crates/rginx-config/src/compile/server/tls.rs @@ -1,12 +1,12 @@ +mod identity; +mod policy; + use std::path::Path; use rginx_core::{Result, ServerTls, VirtualHostTls}; use crate::model::{ServerTlsConfig, VirtualHostTlsConfig}; -mod identity; -mod policy; - use identity::{compile_certificate_material, compile_client_auth_policy}; use policy::{ compile_alpn_protocols, compile_session_cache_size, compile_session_ticket_count, diff --git a/crates/rginx-config/src/compile/server/tls/identity.rs b/crates/rginx-config/src/compile/server/tls/identity.rs index 82202e2c..73186bfd 100644 --- a/crates/rginx-config/src/compile/server/tls/identity.rs +++ b/crates/rginx-config/src/compile/server/tls/identity.rs @@ -11,11 +11,11 @@ use crate::model::{ }; pub(super) struct CompiledCertificateMaterial { + pub(super) additional_certificates: Vec, pub(super) cert_path: PathBuf, pub(super) key_path: PathBuf, - pub(super) additional_certificates: Vec, - pub(super) ocsp_staple_path: Option, pub(super) ocsp: OcspConfig, + pub(super) ocsp_staple_path: Option, } pub(super) fn compile_certificate_material( @@ -56,11 +56,11 @@ pub(super) fn compile_certificate_material( .collect::>>()?; Ok(CompiledCertificateMaterial { + additional_certificates, cert_path, key_path, - additional_certificates, - ocsp_staple_path, ocsp, + ocsp_staple_path, }) } diff --git a/crates/rginx-config/src/compile/tests.rs b/crates/rginx-config/src/compile/tests.rs index 6e40d817..c8529734 100644 --- a/crates/rginx-config/src/compile/tests.rs +++ b/crates/rginx-config/src/compile/tests.rs @@ -1,3 +1,21 @@ +mod acme; +mod agent; +mod cache; +mod cache_p1; +mod cache_p2; +mod cache_p3; +mod control_plane; +mod http3; +mod listeners; +mod route; +mod server_settings; +mod server_tls; +mod upstream_defaults; +mod upstream_fallbacks; +mod upstream_server_name; +mod upstream_tls; +mod upstream_transport; +mod vhosts; use std::fs; use std::time::Duration; @@ -140,22 +158,3 @@ fn base_config() -> Config { servers: Vec::new(), } } - -mod acme; -mod agent; -mod cache; -mod cache_p1; -mod cache_p2; -mod cache_p3; -mod control_plane; -mod http3; -mod listeners; -mod route; -mod server_settings; -mod server_tls; -mod upstream_defaults; -mod upstream_fallbacks; -mod upstream_server_name; -mod upstream_tls; -mod upstream_transport; -mod vhosts; diff --git a/crates/rginx-config/src/compile/tests/acme.rs b/crates/rginx-config/src/compile/tests/acme.rs index ee3b1b62..0f32946f 100644 --- a/crates/rginx-config/src/compile/tests/acme.rs +++ b/crates/rginx-config/src/compile/tests/acme.rs @@ -187,8 +187,8 @@ fn compile_resolves_global_acme_settings_relative_to_base() { vec!["mailto:ops@example.com".to_string(), "mailto:security@example.com".to_string(),] ); assert_eq!(acme.state_dir, base_dir.path().join("state/acme")); - assert_eq!(acme.renew_before, Duration::from_secs(21 * 86_400)); - assert_eq!(acme.poll_interval, Duration::from_secs(600)); + assert_eq!(acme.renew_before, Duration::from_hours(504)); + assert_eq!(acme.poll_interval, Duration::from_mins(10)); } #[test] @@ -212,8 +212,8 @@ fn compile_emits_managed_certificate_specs_for_acme_vhosts() { compile_with_base(config, base_dir.path()).expect("managed ACME config should compile"); let acme = snapshot.acme.as_ref().expect("compiled snapshot should include ACME settings"); - assert_eq!(acme.renew_before, Duration::from_secs(30 * 86_400)); - assert_eq!(acme.poll_interval, Duration::from_secs(3600)); + assert_eq!(acme.renew_before, Duration::from_hours(720)); + assert_eq!(acme.poll_interval, Duration::from_hours(1)); assert_eq!(snapshot.managed_certificates.len(), 1); let spec = &snapshot.managed_certificates[0]; diff --git a/crates/rginx-config/src/compile/tests/cache.rs b/crates/rginx-config/src/compile/tests/cache.rs index f5241c8e..dbf0536a 100644 --- a/crates/rginx-config/src/compile/tests/cache.rs +++ b/crates/rginx-config/src/compile/tests/cache.rs @@ -142,7 +142,7 @@ fn compile_attaches_cache_zones_and_route_policy() { let zone = snapshot.cache_zones.get("default").expect("cache zone should compile"); assert_eq!(zone.path, base_dir.path().join("cache/default")); assert_eq!(zone.max_size_bytes, Some(1024 * 1024)); - assert_eq!(zone.inactive, Duration::from_secs(120)); + assert_eq!(zone.inactive, Duration::from_mins(2)); assert_eq!(zone.default_ttl, Duration::from_secs(30)); assert_eq!(zone.max_entry_bytes, 1024); assert!(zone.shared_index); diff --git a/crates/rginx-config/src/compile/tests/route.rs b/crates/rginx-config/src/compile/tests/route.rs index b2e37175..3f0760a2 100644 --- a/crates/rginx-config/src/compile/tests/route.rs +++ b/crates/rginx-config/src/compile/tests/route.rs @@ -1,10 +1,10 @@ -use super::*; - mod file_handler; mod preferred_prefix; mod proxy_rewrite; mod regex; +use super::*; + #[test] fn compile_attaches_route_access_control() { let config = Config { diff --git a/crates/rginx-config/src/compile/tests/server_settings.rs b/crates/rginx-config/src/compile/tests/server_settings.rs index 2f619ecf..a41bdc80 100644 --- a/crates/rginx-config/src/compile/tests/server_settings.rs +++ b/crates/rginx-config/src/compile/tests/server_settings.rs @@ -117,7 +117,7 @@ fn compile_normalizes_trusted_proxy_ips_and_cidrs() { let snapshot = compile(config).expect("trusted proxies should compile"); assert_eq!(default_listener_server(&snapshot).trusted_proxies.len(), 2); assert_eq!( - default_listener_server(&snapshot).client_ip_header.as_ref().map(|name| name.as_str()), + default_listener_server(&snapshot).client_ip_header.as_ref().map(http::HeaderName::as_str), Some("cf-connecting-ip") ); assert!(default_listener_server(&snapshot).is_trusted_proxy("10.1.2.3".parse().unwrap())); diff --git a/crates/rginx-config/src/compile/tests/vhosts.rs b/crates/rginx-config/src/compile/tests/vhosts.rs index 9cb6e428..3672fc3e 100644 --- a/crates/rginx-config/src/compile/tests/vhosts.rs +++ b/crates/rginx-config/src/compile/tests/vhosts.rs @@ -1,7 +1,7 @@ -use super::*; - mod listener_conflicts; +use super::*; + #[test] fn compile_generates_deduplicated_listeners_from_vhost_listen() { let base_dir = temp_base_dir("rginx-vhost-listen-test-"); @@ -72,7 +72,7 @@ fn compile_generates_deduplicated_listeners_from_vhost_listen() { Some("api.example.com") ); let http3 = snapshot.listeners[1].http3.as_ref().expect("http3 should compile"); - assert_eq!(http3.alt_svc_max_age, Duration::from_secs(7200)); + assert_eq!(http3.alt_svc_max_age, Duration::from_hours(2)); assert_eq!(snapshot.total_listener_binding_count(), 3); assert_eq!(snapshot.vhosts.len(), 2); assert!(snapshot.default_vhost.routes.is_empty()); diff --git a/crates/rginx-config/src/compile/upstream.rs b/crates/rginx-config/src/compile/upstream.rs index f5086696..d14e9b7f 100644 --- a/crates/rginx-config/src/compile/upstream.rs +++ b/crates/rginx-config/src/compile/upstream.rs @@ -1,3 +1,8 @@ +mod dns; +mod peer; +mod settings; +mod tls; + use std::collections::HashMap; use std::path::Path; use std::sync::Arc; @@ -7,16 +12,11 @@ use rginx_core::{Error, Result, Upstream, UpstreamSettings}; use crate::model::UpstreamConfig; -mod dns; -mod peer; -mod settings; -mod tls; - pub(super) fn compile_upstreams( raw_upstreams: Vec, base_dir: &Path, ) -> Result>> { - compile_upstreams_with_names(raw_upstreams, base_dir, |name| name.to_string()) + compile_upstreams_with_names(raw_upstreams, base_dir, std::string::ToString::to_string) } pub(super) fn compile_scoped_upstreams( diff --git a/crates/rginx-config/src/compile/upstream/tls.rs b/crates/rginx-config/src/compile/upstream/tls.rs index a82ca476..17500037 100644 --- a/crates/rginx-config/src/compile/upstream/tls.rs +++ b/crates/rginx-config/src/compile/upstream/tls.rs @@ -5,11 +5,11 @@ use rginx_core::{ClientIdentity, Error, Result, TlsVersion, UpstreamTls}; use crate::model::{TlsVersionConfig, UpstreamTlsConfig, UpstreamTlsModeConfig}; pub(super) struct CompiledUpstreamTls { - pub(super) verify_mode: UpstreamTls, - pub(super) tls_versions: Option>, - pub(super) server_verify_depth: Option, - pub(super) server_crl_path: Option, pub(super) client_identity: Option, + pub(super) server_crl_path: Option, + pub(super) server_verify_depth: Option, + pub(super) tls_versions: Option>, + pub(super) verify_mode: UpstreamTls, } pub(super) fn compile_tls( @@ -86,11 +86,11 @@ pub(super) fn compile_tls( }; Ok(CompiledUpstreamTls { - verify_mode, - tls_versions, - server_verify_depth, - server_crl_path, client_identity, + server_crl_path, + server_verify_depth, + tls_versions, + verify_mode, }) } diff --git a/crates/rginx-config/src/compile/vhost.rs b/crates/rginx-config/src/compile/vhost.rs index f37dc56f..22b43caf 100644 --- a/crates/rginx-config/src/compile/vhost.rs +++ b/crates/rginx-config/src/compile/vhost.rs @@ -9,9 +9,9 @@ use rginx_core::{ManagedCertificateSpec, Result, Upstream, VirtualHost}; use crate::model::VirtualHostConfig; pub(super) struct CompiledVirtualHost { - pub(super) vhost: VirtualHost, - pub(super) upstreams: HashMap>, pub(super) managed_certificate: Option, + pub(super) upstreams: HashMap>, + pub(super) vhost: VirtualHost, } pub(super) fn compile_virtual_host( diff --git a/crates/rginx-config/src/lib.rs b/crates/rginx-config/src/lib.rs index af5ba1e9..586a8956 100644 --- a/crates/rginx-config/src/lib.rs +++ b/crates/rginx-config/src/lib.rs @@ -1,3 +1,12 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] pub mod compile; pub mod load; pub mod managed; diff --git a/crates/rginx-config/src/listen.rs b/crates/rginx-config/src/listen.rs index e1d50937..a5cf05b7 100644 --- a/crates/rginx-config/src/listen.rs +++ b/crates/rginx-config/src/listen.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::net::{Ipv4Addr, SocketAddr}; use rginx_core::{Error, Result}; @@ -5,12 +7,12 @@ use rginx_core::{Error, Result}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) struct ParsedVhostListen { pub(crate) addr: SocketAddr, - pub(crate) ssl: bool, + pub(crate) default_server: bool, pub(crate) http2: bool, pub(crate) http3: bool, pub(crate) proxy_protocol: bool, - pub(crate) default_server: bool, pub(crate) reuse_port: bool, + pub(crate) ssl: bool, } pub(crate) fn parse_vhost_listen(owner_label: &str, raw: &str) -> Result { @@ -67,8 +69,7 @@ fn parse_listen_addr(owner_label: &str, value: &str) -> Result { let normalized = value .strip_prefix("*:") - .map(|port| format!("0.0.0.0:{port}")) - .unwrap_or_else(|| value.to_string()); + .map_or_else(|| value.to_string(), |port| format!("0.0.0.0:{port}")); normalized.parse::().map_err(|error| { Error::Config(format!("{owner_label} listen `{value}` is invalid: {error}")) @@ -80,6 +81,3 @@ fn parse_port(owner_label: &str, value: &str) -> Result { Error::Config(format!("{owner_label} listen port `{value}` is invalid: {error}")) }) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-config/src/load.rs b/crates/rginx-config/src/load.rs index 7bf2ebd7..4356a899 100644 --- a/crates/rginx-config/src/load.rs +++ b/crates/rginx-config/src/load.rs @@ -1,10 +1,3 @@ -use std::fs; -use std::path::Path; - -use rginx_core::Result; - -use crate::model::Config; - mod env_expand; mod include; mod layout; @@ -13,6 +6,13 @@ mod preprocess; #[cfg(test)] mod tests; +use std::fs; +use std::path::Path; + +use rginx_core::Result; + +use crate::model::Config; + #[cfg(test)] use env_expand::expand_env_placeholders_in_ron_strings; diff --git a/crates/rginx-config/src/load/env_expand.rs b/crates/rginx-config/src/load/env_expand.rs index 9bc7e773..6b86b486 100644 --- a/crates/rginx-config/src/load/env_expand.rs +++ b/crates/rginx-config/src/load/env_expand.rs @@ -19,14 +19,14 @@ pub(super) fn expand_env_placeholders_in_ron_strings( if ch == '"' { in_string = true; } - index += 1; + advance(&mut index, 1); continue; } if escaped { expanded.push(ch); escaped = false; - index += 1; + advance(&mut index, 1); continue; } @@ -34,36 +34,37 @@ pub(super) fn expand_env_placeholders_in_ron_strings( '\\' => { expanded.push(ch); escaped = true; - index += 1; + advance(&mut index, 1); } '"' => { expanded.push(ch); in_string = false; - index += 1; + advance(&mut index, 1); } - '$' if chars.get(index + 1) == Some(&'$') => { + '$' if chars.get(checked_add(index, 1)) == Some(&'$') => { expanded.push('$'); - index += 2; + advance(&mut index, 2); } - '$' if chars.get(index + 1) == Some(&'{') => { - let end = chars[index + 2..] + '$' if chars.get(checked_add(index, 1)) == Some(&'{') => { + let token_start = checked_add(index, 2); + let end = chars[token_start..] .iter() .position(|candidate| *candidate == '}') - .map(|offset| index + 2 + offset) + .map(|offset| checked_add(token_start, offset)) .ok_or_else(|| { Error::Config(format!( "unterminated environment placeholder in `{}`", source_path.display() )) })?; - let token = chars[index + 2..end].iter().collect::(); + let token = chars[token_start..end].iter().collect::(); let replacement = resolve_env_placeholder(&token, source_path)?; expanded.push_str(&escape_ron_string_fragment(&replacement)); - index = end + 1; + index = checked_add(end, 1); } _ => { expanded.push(ch); - index += 1; + advance(&mut index, 1); } } } @@ -112,3 +113,11 @@ fn escape_ron_string_fragment(value: &str) -> String { } escaped } + +fn advance(value: &mut usize, increment: usize) { + *value = checked_add(*value, increment); +} + +fn checked_add(value: usize, increment: usize) -> usize { + value.checked_add(increment).expect("RON environment expansion index remains representable") +} diff --git a/crates/rginx-config/src/load/layout.rs b/crates/rginx-config/src/load/layout.rs index 1fd794a6..bc80facf 100644 --- a/crates/rginx-config/src/load/layout.rs +++ b/crates/rginx-config/src/load/layout.rs @@ -1,13 +1,14 @@ +mod array_rules; +mod scanner; + use std::path::Path; use rginx_core::Result; -const CANONICAL_ROOT_CONFIG_NAME: &str = "rginx.ron"; -mod array_rules; -mod scanner; - use array_rules::{validate_empty_root_array_field, validate_servers_root_array_field}; +const CANONICAL_ROOT_CONFIG_NAME: &str = "rginx.ron"; + pub(super) fn validate_canonical_root_layout(contents: &str, source_path: &Path) -> Result<()> { if source_path.file_name().and_then(|name| name.to_str()) != Some(CANONICAL_ROOT_CONFIG_NAME) { return Ok(()); diff --git a/crates/rginx-config/src/load/layout/array_rules.rs b/crates/rginx-config/src/load/layout/array_rules.rs index 7ef86d10..0dc8d0d3 100644 --- a/crates/rginx-config/src/load/layout/array_rules.rs +++ b/crates/rginx-config/src/load/layout/array_rules.rs @@ -51,13 +51,13 @@ fn has_disallowed_array_tokens(contents: &str, allow_include_comments: bool) -> while i < bytes.len() { match bytes[i] { b' ' | b'\t' | b'\r' | b'\n' | b',' => { - i += 1; + advance(&mut i, 1); } - b'/' if bytes.get(i + 1) == Some(&b'/') => { + b'/' if bytes.get(checked_add(i, 1)) == Some(&b'/') => { let comment_start = i; - i += 2; + advance(&mut i, 2); while i < bytes.len() && bytes[i] != b'\n' { - i += 1; + advance(&mut i, 1); } let comment = &contents[comment_start..i]; if !allow_include_comments && comment.trim_start().starts_with("// @include ") { @@ -71,6 +71,14 @@ fn has_disallowed_array_tokens(contents: &str, allow_include_comments: bool) -> false } +fn advance(value: &mut usize, increment: usize) { + *value = checked_add(*value, increment); +} + +fn checked_add(value: usize, increment: usize) -> usize { + value.checked_add(increment).expect("RON array scanner index remains representable") +} + fn array_inner(value: &str) -> Option<&str> { let trimmed = value.trim(); let inner = trimmed.strip_prefix('[')?.strip_suffix(']')?; diff --git a/crates/rginx-config/src/load/layout/scanner.rs b/crates/rginx-config/src/load/layout/scanner.rs index 259c43bd..f89d8d94 100644 --- a/crates/rginx-config/src/load/layout/scanner.rs +++ b/crates/rginx-config/src/load/layout/scanner.rs @@ -4,7 +4,7 @@ pub(super) fn find_top_level_field_value<'a>( ) -> Option<&'a str> { let root_start = find_root_config_open(contents)?; let bytes = contents.as_bytes(); - let mut i = root_start + 1; + let mut i = checked_add(root_start, 1); let mut in_string = false; let mut escaped = false; let mut in_comment = false; @@ -19,7 +19,7 @@ pub(super) fn find_top_level_field_value<'a>( if byte == b'\n' { in_comment = false; } - i += 1; + advance(&mut i, 1); continue; } @@ -31,33 +31,33 @@ pub(super) fn find_top_level_field_value<'a>( } else if byte == b'"' { in_string = false; } - i += 1; + advance(&mut i, 1); continue; } - if byte == b'/' && bytes.get(i + 1) == Some(&b'/') { + if byte == b'/' && bytes.get(checked_add(i, 1)) == Some(&b'/') { in_comment = true; - i += 2; + advance(&mut i, 2); continue; } if byte == b'"' { in_string = true; - i += 1; + advance(&mut i, 1); continue; } if paren_depth == 0 && bracket_depth == 0 && brace_depth == 0 && is_identifier_start(byte) { let name_start = i; - i += 1; + advance(&mut i, 1); while i < bytes.len() && is_identifier_continue(bytes[i]) { - i += 1; + advance(&mut i, 1); } let name = &contents[name_start..i]; let mut cursor = skip_ascii_whitespace(contents, i); if bytes.get(cursor) == Some(&b':') { - cursor = skip_ascii_whitespace(contents, cursor + 1); + cursor = skip_ascii_whitespace(contents, checked_add(cursor, 1)); let value_end = scan_top_level_value_end(contents, cursor); if name == field_name { return Some(contents[cursor..value_end].trim()); @@ -68,21 +68,21 @@ pub(super) fn find_top_level_field_value<'a>( } match byte { - b'(' => paren_depth += 1, + b'(' => advance(&mut paren_depth, 1), b')' => { if paren_depth == 0 { break; } - paren_depth -= 1; + paren_depth = paren_depth.saturating_sub(1); } - b'[' => bracket_depth += 1, + b'[' => advance(&mut bracket_depth, 1), b']' => bracket_depth = bracket_depth.saturating_sub(1), - b'{' => brace_depth += 1, + b'{' => advance(&mut brace_depth, 1), b'}' => brace_depth = brace_depth.saturating_sub(1), _ => {} } - i += 1; + advance(&mut i, 1); } None @@ -102,7 +102,7 @@ fn find_root_config_open(contents: &str) -> Option { if byte == b'\n' { in_comment = false; } - i += 1; + advance(&mut i, 1); continue; } @@ -114,34 +114,35 @@ fn find_root_config_open(contents: &str) -> Option { } else if byte == b'"' { in_string = false; } - i += 1; + advance(&mut i, 1); continue; } - if byte == b'/' && bytes.get(i + 1) == Some(&b'/') { + if byte == b'/' && bytes.get(checked_add(i, 1)) == Some(&b'/') { in_comment = true; - i += 2; + advance(&mut i, 2); continue; } if byte == b'"' { in_string = true; - i += 1; + advance(&mut i, 1); continue; } + let config_end = checked_add(i, 6); if byte == b'C' - && bytes.get(i..i + 6) == Some(b"Config") + && bytes.get(i..config_end) == Some(b"Config") && !bytes.get(i.wrapping_sub(1)).is_some_and(|prev| is_identifier_continue(*prev)) - && !bytes.get(i + 6).is_some_and(|next| is_identifier_continue(*next)) + && !bytes.get(config_end).is_some_and(|next| is_identifier_continue(*next)) { - let cursor = skip_ascii_whitespace(contents, i + 6); + let cursor = skip_ascii_whitespace(contents, config_end); if bytes.get(cursor) == Some(&b'(') { return Some(cursor); } } - i += 1; + advance(&mut i, 1); } None @@ -164,7 +165,7 @@ fn scan_top_level_value_end(contents: &str, start: usize) -> usize { if byte == b'\n' { in_comment = false; } - i += 1; + advance(&mut i, 1); continue; } @@ -176,39 +177,39 @@ fn scan_top_level_value_end(contents: &str, start: usize) -> usize { } else if byte == b'"' { in_string = false; } - i += 1; + advance(&mut i, 1); continue; } - if byte == b'/' && bytes.get(i + 1) == Some(&b'/') { + if byte == b'/' && bytes.get(checked_add(i, 1)) == Some(&b'/') { in_comment = true; - i += 2; + advance(&mut i, 2); continue; } if byte == b'"' { in_string = true; - i += 1; + advance(&mut i, 1); continue; } match byte { - b'(' => paren_depth += 1, + b'(' => advance(&mut paren_depth, 1), b')' => { if paren_depth == 0 && bracket_depth == 0 && brace_depth == 0 { return i; } paren_depth = paren_depth.saturating_sub(1); } - b'[' => bracket_depth += 1, + b'[' => advance(&mut bracket_depth, 1), b']' => bracket_depth = bracket_depth.saturating_sub(1), - b'{' => brace_depth += 1, + b'{' => advance(&mut brace_depth, 1), b'}' => brace_depth = brace_depth.saturating_sub(1), b',' if paren_depth == 0 && bracket_depth == 0 && brace_depth == 0 => return i, _ => {} } - i += 1; + advance(&mut i, 1); } i @@ -216,12 +217,20 @@ fn scan_top_level_value_end(contents: &str, start: usize) -> usize { fn skip_ascii_whitespace(contents: &str, mut index: usize) -> usize { let bytes = contents.as_bytes(); - while bytes.get(index).is_some_and(|byte| byte.is_ascii_whitespace()) { - index += 1; + while bytes.get(index).is_some_and(u8::is_ascii_whitespace) { + advance(&mut index, 1); } index } +fn advance(value: &mut usize, increment: usize) { + *value = checked_add(*value, increment); +} + +fn checked_add(value: usize, increment: usize) -> usize { + value.checked_add(increment).expect("RON layout scanner index remains representable") +} + fn is_identifier_start(byte: u8) -> bool { byte.is_ascii_alphabetic() || byte == b'_' } diff --git a/crates/rginx-config/src/load/tests.rs b/crates/rginx-config/src/load/tests.rs index bf21229f..a0aba54b 100644 --- a/crates/rginx-config/src/load/tests.rs +++ b/crates/rginx-config/src/load/tests.rs @@ -7,6 +7,64 @@ use proptest::prelude::*; use super::{expand_env_placeholders_in_ron_strings, load_from_path, load_from_str}; +proptest! { + #![proptest_config(ProptestConfig::with_cases(64))] + + #[test] + fn env_placeholder_expansion_leaves_placeholder_free_sources_unchanged( + source in prop::collection::vec( + any::().prop_filter("source must not contain `$`", |ch| *ch != '$'), + 0..128, + ) + .prop_map(|chars| chars.into_iter().collect::()) + ) { + let expanded = expand_env_placeholders_in_ron_strings(&source, Path::new("inline.ron")) + .expect("placeholder-free source should not fail"); + + prop_assert_eq!(expanded, source); + } + + #[test] + fn load_from_str_round_trips_arbitrary_env_string_values(value in arbitrary_env_value()) { + let _guard = env_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); + let _scoped_env = ScopedEnvVar::set("rginx_test_prop_body", &value); + + let config = load_from_str( + "Config(\n runtime: RuntimeConfig(\n shutdown_timeout_secs: 2,\n ),\n server: ServerConfig(\n listen: \"127.0.0.1:18080\",\n ),\n upstreams: [],\n locations: [\n LocationConfig(\n matcher: Exact(\"/\"),\n handler: Return(\n status: 200,\n location: \"\",\n body: Some(\"${rginx_test_prop_body}\"),\n ),\n ),\n ],\n)\n", + Path::new("inline.ron"), + ) + .expect("config should load with arbitrary env expansion"); + + match &config.locations[0].handler { + crate::model::HandlerConfig::Return { body, .. } => { + prop_assert_eq!(body.as_deref(), Some(value.as_str())); + } + _ => panic!("expected return handler"), + } + } +} + +struct ScopedEnvVar { + key: &'static str, +} + +impl ScopedEnvVar { + fn set(key: &'static str, value: &str) -> Self { + unsafe { + std::env::set_var(key, value); + } + Self { key } + } +} + +impl Drop for ScopedEnvVar { + fn drop(&mut self) { + unsafe { + std::env::remove_var(self.key); + } + } +} + #[test] fn load_from_str_deserializes_cache_zones_and_route_cache_policy() { let config = load_from_str( @@ -31,7 +89,7 @@ fn load_from_str_deserializes_cache_zones_and_route_cache_policy() { #[test] fn load_from_str_expands_environment_placeholders_inside_strings() { - let _guard = env_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = env_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); unsafe { std::env::set_var("rginx_test_listen", "127.0.0.1:19090"); std::env::set_var("rginx_test_body", "hello \"env\"\n"); @@ -59,7 +117,7 @@ fn load_from_str_expands_environment_placeholders_inside_strings() { #[test] fn load_from_str_supports_env_defaults_and_literal_dollar_escape() { - let _guard = env_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = env_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); unsafe { std::env::remove_var("rginx_test_missing"); } @@ -103,7 +161,7 @@ fn load_from_str_supports_legacy_and_structured_upstream_tls_config() { #[test] fn load_from_str_rejects_missing_environment_placeholders() { - let _guard = env_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let _guard = env_lock().lock().unwrap_or_else(std::sync::PoisonError::into_inner); unsafe { std::env::remove_var("rginx_test_required"); } @@ -247,43 +305,6 @@ fn load_from_path_rejects_root_inline_locations_in_canonical_root_config() { let _ = fs::remove_dir_all(temp_dir); } -proptest! { - #![proptest_config(ProptestConfig::with_cases(64))] - - #[test] - fn env_placeholder_expansion_leaves_placeholder_free_sources_unchanged( - source in prop::collection::vec( - any::().prop_filter("source must not contain `$`", |ch| *ch != '$'), - 0..128, - ) - .prop_map(|chars| chars.into_iter().collect::()) - ) { - let expanded = expand_env_placeholders_in_ron_strings(&source, Path::new("inline.ron")) - .expect("placeholder-free source should not fail"); - - prop_assert_eq!(expanded, source); - } - - #[test] - fn load_from_str_round_trips_arbitrary_env_string_values(value in arbitrary_env_value()) { - let _guard = env_lock().lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - let _scoped_env = ScopedEnvVar::set("rginx_test_prop_body", &value); - - let config = load_from_str( - "Config(\n runtime: RuntimeConfig(\n shutdown_timeout_secs: 2,\n ),\n server: ServerConfig(\n listen: \"127.0.0.1:18080\",\n ),\n upstreams: [],\n locations: [\n LocationConfig(\n matcher: Exact(\"/\"),\n handler: Return(\n status: 200,\n location: \"\",\n body: Some(\"${rginx_test_prop_body}\"),\n ),\n ),\n ],\n)\n", - Path::new("inline.ron"), - ) - .expect("config should load with arbitrary env expansion"); - - match &config.locations[0].handler { - crate::model::HandlerConfig::Return { body, .. } => { - prop_assert_eq!(body.as_deref(), Some(value.as_str())); - } - _ => panic!("expected return handler"), - } - } -} - fn temp_dir(prefix: &str) -> PathBuf { let unique = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -306,24 +327,3 @@ fn arbitrary_env_value() -> impl Strategy { ) .prop_map(|chars| chars.into_iter().collect::()) } - -struct ScopedEnvVar { - key: &'static str, -} - -impl ScopedEnvVar { - fn set(key: &'static str, value: &str) -> Self { - unsafe { - std::env::set_var(key, value); - } - Self { key } - } -} - -impl Drop for ScopedEnvVar { - fn drop(&mut self) { - unsafe { - std::env::remove_var(self.key); - } - } -} diff --git a/crates/rginx-config/src/managed/mod.rs b/crates/rginx-config/src/managed/mod.rs index bd02f383..384d7468 100644 --- a/crates/rginx-config/src/managed/mod.rs +++ b/crates/rginx-config/src/managed/mod.rs @@ -1,3 +1,9 @@ +mod normalize; +mod paths; +#[cfg(test)] +mod tests; +mod types; + use std::collections::HashMap; use std::fs; use std::path::Path; @@ -6,12 +12,6 @@ use rginx_core::{Error, Result}; use crate::model::{Config, LocationConfig}; -mod normalize; -mod paths; -#[cfg(test)] -mod tests; -mod types; - pub use self::paths::{managed_resource_path_for_config, managed_root_for_config}; pub use self::types::{ AppliedManagedMutation, ManagedResourceDocument, ManagedResourceKind, ManagedResourceMetadata, diff --git a/crates/rginx-config/src/managed/normalize.rs b/crates/rginx-config/src/managed/normalize.rs index 75015658..e3a2ea8b 100644 --- a/crates/rginx-config/src/managed/normalize.rs +++ b/crates/rginx-config/src/managed/normalize.rs @@ -82,7 +82,7 @@ fn normalize_target( ))); } - Ok(Some(ManagedRouteTarget { server_name, managed_virtual_host_id })) + Ok(Some(ManagedRouteTarget { managed_virtual_host_id, server_name })) } ManagedResourceKind::VirtualHost | ManagedResourceKind::Upstream => { if target.is_some() { diff --git a/crates/rginx-config/src/managed/paths.rs b/crates/rginx-config/src/managed/paths.rs index 4a486ae3..9a448f34 100644 --- a/crates/rginx-config/src/managed/paths.rs +++ b/crates/rginx-config/src/managed/paths.rs @@ -51,7 +51,7 @@ pub fn managed_resource_path_for_config( } pub(super) fn encode_id_for_path(id: &str) -> String { - let mut encoded = String::with_capacity(id.len() * 3); + let mut encoded = String::with_capacity(id.len().saturating_mul(3)); for byte in id.bytes() { let ch = byte as char; if ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_' | '.') { diff --git a/crates/rginx-config/src/managed/types.rs b/crates/rginx-config/src/managed/types.rs index a3614d76..3a1de4b5 100644 --- a/crates/rginx-config/src/managed/types.rs +++ b/crates/rginx-config/src/managed/types.rs @@ -8,9 +8,9 @@ use super::normalize::normalize_document; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ManagedResourceKind { - VirtualHost, - Upstream, Route, + Upstream, + VirtualHost, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -24,10 +24,10 @@ pub struct ManagedResourceMetadata { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ManagedRouteTarget { - #[serde(default)] - pub server_name: Option, #[serde(default)] pub managed_virtual_host_id: Option, + #[serde(default)] + pub server_name: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -35,14 +35,18 @@ pub struct ManagedResourceDocument { pub api_version: String, pub kind: ManagedResourceKind, pub metadata: ManagedResourceMetadata, + pub spec: serde_json::Value, #[serde(default)] pub target: Option, - pub spec: serde_json::Value, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "operation", rename_all = "snake_case")] pub enum ManagedResourceMutation { + Delete { + kind: ManagedResourceKind, + id: String, + }, Upsert { kind: ManagedResourceKind, metadata: ManagedResourceMetadata, @@ -50,27 +54,16 @@ pub enum ManagedResourceMutation { target: Option, spec: serde_json::Value, }, - Delete { - kind: ManagedResourceKind, - id: String, - }, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum AppliedManagedMutation { - Upsert { document: ManagedResourceDocument, previous: Option }, Delete { kind: ManagedResourceKind, removed: ManagedResourceDocument }, + Upsert { document: ManagedResourceDocument, previous: Option }, } impl ManagedResourceKind { - pub fn label(self) -> &'static str { - match self { - Self::VirtualHost => "virtual_host", - Self::Upstream => "upstream", - Self::Route => "route", - } - } - + #[must_use] pub fn directory_name(self) -> &'static str { match self { Self::VirtualHost => "vhosts", @@ -79,6 +72,15 @@ impl ManagedResourceKind { } } + #[must_use] + pub fn label(self) -> &'static str { + match self { + Self::VirtualHost => "virtual_host", + Self::Upstream => "upstream", + Self::Route => "route", + } + } + pub(super) fn ordered() -> [Self; 3] { [Self::VirtualHost, Self::Upstream, Self::Route] } @@ -103,6 +105,7 @@ impl ManagedResourceDocument { } impl ManagedResourceMutation { + #[must_use] pub fn operation_label(&self) -> &'static str { match self { Self::Upsert { .. } => "upsert", @@ -112,13 +115,7 @@ impl ManagedResourceMutation { } impl AppliedManagedMutation { - pub fn operation_label(&self) -> &'static str { - match self { - Self::Upsert { .. } => "upsert", - Self::Delete { .. } => "delete", - } - } - + #[must_use] pub fn kind(&self) -> ManagedResourceKind { match self { Self::Upsert { document, .. } => document.kind, @@ -126,13 +123,18 @@ impl AppliedManagedMutation { } } - pub fn resource_id(&self) -> &str { + pub fn managed_path_for_config(&self, config_path: impl AsRef) -> Result { + super::managed_resource_path_for_config(config_path, self.kind(), self.resource_id()) + } + #[must_use] + pub fn operation_label(&self) -> &'static str { match self { - Self::Upsert { document, .. } => &document.metadata.id, - Self::Delete { removed, .. } => &removed.metadata.id, + Self::Upsert { .. } => "upsert", + Self::Delete { .. } => "delete", } } + #[must_use] pub fn owner(&self) -> Option<&str> { match self { Self::Upsert { document, .. } | Self::Delete { removed: document, .. } => { @@ -141,6 +143,15 @@ impl AppliedManagedMutation { } } + #[must_use] + pub fn resource_id(&self) -> &str { + match self { + Self::Upsert { document, .. } => &document.metadata.id, + Self::Delete { removed, .. } => &removed.metadata.id, + } + } + + #[must_use] pub fn tenant(&self) -> Option<&str> { match self { Self::Upsert { document, .. } | Self::Delete { removed: document, .. } => { @@ -148,8 +159,4 @@ impl AppliedManagedMutation { } } } - - pub fn managed_path_for_config(&self, config_path: impl AsRef) -> Result { - super::managed_resource_path_for_config(config_path, self.kind(), self.resource_id()) - } } diff --git a/crates/rginx-config/src/model.rs b/crates/rginx-config/src/model.rs index 4921558d..459f95c6 100644 --- a/crates/rginx-config/src/model.rs +++ b/crates/rginx-config/src/model.rs @@ -1,5 +1,3 @@ -use serde::Deserialize; - mod acme; mod agent; mod cache; @@ -12,6 +10,8 @@ mod tls; mod upstream; mod vhost; +use serde::Deserialize; + pub use acme::{AcmeChallengeConfig, AcmeConfig, VirtualHostAcmeConfig}; pub use agent::AgentConfig; pub use cache::{ @@ -41,22 +41,22 @@ pub use vhost::VirtualHostConfig; #[derive(Debug, Clone, Deserialize)] pub struct Config { - pub runtime: RuntimeConfig, - #[serde(default)] - pub agent: Option, - #[serde(default)] - pub control_plane: Option, #[serde(default)] pub acme: Option, #[serde(default)] - pub listeners: Vec, + pub agent: Option, #[serde(default)] pub cache_zones: Vec, - pub server: ServerConfig, #[serde(default)] - pub upstreams: Vec, + pub control_plane: Option, + #[serde(default)] + pub listeners: Vec, #[serde(default)] pub locations: Vec, + pub runtime: RuntimeConfig, + pub server: ServerConfig, #[serde(default)] pub servers: Vec, + #[serde(default)] + pub upstreams: Vec, } diff --git a/crates/rginx-config/src/model/acme.rs b/crates/rginx-config/src/model/acme.rs index 87a60dda..bb30d29b 100644 --- a/crates/rginx-config/src/model/acme.rs +++ b/crates/rginx-config/src/model/acme.rs @@ -2,22 +2,22 @@ use serde::Deserialize; #[derive(Debug, Clone, Deserialize)] pub struct AcmeConfig { - pub directory_url: String, #[serde(default)] pub contacts: Vec, - pub state_dir: String, - #[serde(default)] - pub renew_before_days: Option, + pub directory_url: String, #[serde(default)] pub poll_interval_secs: Option, + #[serde(default)] + pub renew_before_days: Option, + pub state_dir: String, } #[derive(Debug, Clone, Deserialize)] pub struct VirtualHostAcmeConfig { - #[serde(default)] - pub domains: Vec, #[serde(default)] pub challenge: Option, + #[serde(default)] + pub domains: Vec, } #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Hash)] diff --git a/crates/rginx-config/src/model/agent.rs b/crates/rginx-config/src/model/agent.rs index b886ffc3..97fa6296 100644 --- a/crates/rginx-config/src/model/agent.rs +++ b/crates/rginx-config/src/model/agent.rs @@ -9,32 +9,32 @@ use serde::Deserialize; /// node-side management port. #[derive(Debug, Clone, Deserialize, Default)] pub struct AgentConfig { + #[serde(default)] + pub backoff_initial_ms: Option, + #[serde(default)] + pub backoff_max_secs: Option, + #[serde(default)] + pub connect_timeout_secs: Option, #[serde(default)] pub enabled: Option, #[serde(default)] pub endpoint: Option, #[serde(default)] - pub node_id: Option, + pub heartbeat_interval_secs: Option, #[serde(default)] - pub token_path: Option, + pub labels: BTreeMap, #[serde(default)] - pub state_path: Option, + pub node_id: Option, #[serde(default)] - pub region: Option, + pub poll_timeout_secs: Option, #[serde(default)] pub pop: Option, #[serde(default)] - pub labels: BTreeMap, - #[serde(default)] - pub heartbeat_interval_secs: Option, - #[serde(default)] - pub connect_timeout_secs: Option, + pub region: Option, #[serde(default)] pub request_timeout_secs: Option, #[serde(default)] - pub poll_timeout_secs: Option, - #[serde(default)] - pub backoff_initial_ms: Option, + pub state_path: Option, #[serde(default)] - pub backoff_max_secs: Option, + pub token_path: Option, } diff --git a/crates/rginx-config/src/model/cache.rs b/crates/rginx-config/src/model/cache.rs index 8332a1cd..23653a63 100644 --- a/crates/rginx-config/src/model/cache.rs +++ b/crates/rginx-config/src/model/cache.rs @@ -2,18 +2,12 @@ use serde::Deserialize; #[derive(Debug, Clone, Deserialize)] pub struct CacheZoneConfig { - pub name: String, - pub path: String, - #[serde(default)] - pub max_size_bytes: Option, - #[serde(default)] - pub inactive_secs: Option, #[serde(default)] pub default_ttl_secs: Option, #[serde(default)] - pub max_entry_bytes: Option, + pub inactive_cleanup_interval_secs: Option, #[serde(default)] - pub path_levels: Option>, + pub inactive_secs: Option, #[serde(default)] pub loader_batch_entries: Option, #[serde(default)] @@ -23,52 +17,58 @@ pub struct CacheZoneConfig { #[serde(default)] pub manager_sleep_millis: Option, #[serde(default)] - pub inactive_cleanup_interval_secs: Option, + pub max_entry_bytes: Option, + #[serde(default)] + pub max_size_bytes: Option, + pub name: String, + pub path: String, + #[serde(default)] + pub path_levels: Option>, #[serde(default)] pub shared_index: Option, } #[derive(Debug, Clone, Deserialize)] pub struct CacheRouteConfig { - pub zone: String, #[serde(default)] - pub methods: Option>, - #[serde(default)] - pub statuses: Option>, - #[serde(default)] - pub ttl_secs_by_status: Option>, - #[serde(default)] - pub key: Option, + pub background_update: Option, #[serde(default)] pub cache_bypass: Option, #[serde(default)] - pub no_cache: Option, - #[serde(default)] - pub stale_if_error_secs: Option, + pub convert_head: Option, #[serde(default)] pub grace_secs: Option, #[serde(default)] - pub keep_secs: Option, + pub ignore_headers: Option>, #[serde(default)] - pub pass_ttl_secs: Option, + pub keep_secs: Option, #[serde(default)] - pub use_stale: Option>, + pub key: Option, #[serde(default)] - pub background_update: Option, + pub lock_age_secs: Option, #[serde(default)] pub lock_timeout_secs: Option, #[serde(default)] - pub lock_age_secs: Option, + pub methods: Option>, #[serde(default)] pub min_uses: Option, #[serde(default)] - pub ignore_headers: Option>, + pub no_cache: Option, + #[serde(default)] + pub pass_ttl_secs: Option, #[serde(default)] pub range_requests: Option, #[serde(default)] pub slice_size_bytes: Option, #[serde(default)] - pub convert_head: Option, + pub stale_if_error_secs: Option, + #[serde(default)] + pub statuses: Option>, + #[serde(default)] + pub ttl_secs_by_status: Option>, + #[serde(default)] + pub use_stale: Option>, + pub zone: String, } #[derive(Debug, Clone, Deserialize)] @@ -79,16 +79,16 @@ pub struct CacheStatusTtlConfig { #[derive(Debug, Clone, Deserialize)] pub enum CachePredicateConfig { - Any(Vec), All(Vec), - Not(Box), - Method(String), - HeaderExists(String), + Any(Vec), + CookieEquals { name: String, value: String }, + CookieExists(String), HeaderEquals { name: String, value: String }, - QueryExists(String), + HeaderExists(String), + Method(String), + Not(Box), QueryEquals { name: String, value: String }, - CookieExists(String), - CookieEquals { name: String, value: String }, + QueryExists(String), Status(u16), Statuses(Vec), } @@ -96,8 +96,6 @@ pub enum CachePredicateConfig { #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] pub enum CacheUseStaleConditionConfig { Error, - Timeout, - Updating, Http403, Http404, Http429, @@ -105,15 +103,17 @@ pub enum CacheUseStaleConditionConfig { Http502, Http503, Http504, + Timeout, + Updating, } #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] pub enum CacheIgnoreHeaderConfig { - XAccelExpires, - Expires, CacheControl, + Expires, SetCookie, Vary, + XAccelExpires, } #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] diff --git a/crates/rginx-config/src/model/control_plane.rs b/crates/rginx-config/src/model/control_plane.rs index 7fba222a..030b0c4c 100644 --- a/crates/rginx-config/src/model/control_plane.rs +++ b/crates/rginx-config/src/model/control_plane.rs @@ -10,31 +10,31 @@ use serde::Deserialize; #[derive(Debug, Clone, Deserialize, Default)] pub struct ControlPlaneConfig { #[serde(default)] - pub enabled: Option, + pub allowed_cidrs: Vec, #[serde(default)] - pub listen: Option, + pub api_keys_path: Option, #[serde(default)] - pub tls: Option, + pub enabled: Option, #[serde(default)] - pub allowed_cidrs: Vec, + pub labels: BTreeMap, #[serde(default)] - pub api_keys_path: Option, + pub listen: Option, #[serde(default)] pub node_id: Option, #[serde(default)] - pub region: Option, - #[serde(default)] pub pop: Option, #[serde(default)] - pub labels: BTreeMap, + pub region: Option, + #[serde(default)] + pub tls: Option, } #[derive(Debug, Clone, Deserialize)] pub struct ControlPlaneTlsConfig { pub cert_path: String, - pub key_path: String, #[serde(default)] pub client_ca_path: Option, + pub key_path: String, #[serde(default)] pub require_client_cert: Option, } diff --git a/crates/rginx-config/src/model/listener.rs b/crates/rginx-config/src/model/listener.rs index e0c744aa..f43630d4 100644 --- a/crates/rginx-config/src/model/listener.rs +++ b/crates/rginx-config/src/model/listener.rs @@ -6,18 +6,18 @@ use super::ServerTlsConfig; pub struct Http1Config { #[serde(default)] pub half_close: Option, - // When combined with preserve_header_case, hyper preserves known original - // casing first and falls back to title-cased serialization otherwise. #[serde(default)] - pub title_case_headers: Option, + pub max_buf_size_bytes: Option, + #[serde(default)] + pub pipeline_flush: Option, // When combined with title_case_headers, preserved original casing wins for // headers that have a captured case map, while others still title-case. #[serde(default)] pub preserve_header_case: Option, + // When combined with preserve_header_case, hyper preserves known original + // casing first and falls back to title-cased serialization otherwise. #[serde(default)] - pub max_buf_size_bytes: Option, - #[serde(default)] - pub pipeline_flush: Option, + pub title_case_headers: Option, #[serde(default)] pub writev: Option, } @@ -25,61 +25,61 @@ pub struct Http1Config { #[derive(Debug, Clone, Deserialize, Default, PartialEq, Eq)] pub struct Http3Config { #[serde(default)] - pub listen: Option, + pub active_connection_id_limit: Option, #[serde(default)] pub advertise_alt_svc: Option, #[serde(default)] pub alt_svc_max_age_secs: Option, #[serde(default)] - pub max_concurrent_streams: Option, + pub early_data: Option, #[serde(default)] - pub stream_buffer_size_bytes: Option, + pub gso: Option, #[serde(default)] - pub active_connection_id_limit: Option, + pub host_key_path: Option, #[serde(default)] - pub retry: Option, + pub listen: Option, #[serde(default)] - pub host_key_path: Option, + pub max_concurrent_streams: Option, #[serde(default)] - pub gso: Option, + pub retry: Option, #[serde(default)] - pub early_data: Option, + pub stream_buffer_size_bytes: Option, } #[derive(Debug, Clone, Deserialize)] pub struct ListenerConfig { - pub name: String, - pub listen: String, #[serde(default)] - pub server_header: Option, + pub access_log_format: Option, #[serde(default)] - pub proxy_protocol: Option, + pub client_ip_header: Option, #[serde(default)] pub default_certificate: Option, #[serde(default)] - pub trusted_proxies: Vec, + pub header_read_timeout_secs: Option, #[serde(default)] - pub client_ip_header: Option, + pub http1: Option, + #[serde(default)] + pub http3: Option, #[serde(default)] pub keep_alive: Option, + pub listen: String, + #[serde(default)] + pub max_connections: Option, #[serde(default)] pub max_headers: Option, #[serde(default)] pub max_request_body_bytes: Option, + pub name: String, #[serde(default)] - pub max_connections: Option, - #[serde(default)] - pub header_read_timeout_secs: Option, + pub proxy_protocol: Option, #[serde(default)] pub request_body_read_timeout_secs: Option, #[serde(default)] pub response_write_timeout_secs: Option, #[serde(default)] - pub access_log_format: Option, + pub server_header: Option, #[serde(default)] pub tls: Option, #[serde(default)] - pub http1: Option, - #[serde(default)] - pub http3: Option, + pub trusted_proxies: Vec, } diff --git a/crates/rginx-config/src/model/route.rs b/crates/rginx-config/src/model/route.rs index 49e91ccb..6569f64e 100644 --- a/crates/rginx-config/src/model/route.rs +++ b/crates/rginx-config/src/model/route.rs @@ -8,66 +8,66 @@ use super::CacheRouteConfig; pub enum RouteBufferingPolicyConfig { #[default] Auto, - On, Off, + On, } #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Default)] pub enum RouteCompressionPolicyConfig { - Off, #[default] Auto, Force, + Off, } #[derive(Debug, Clone, Deserialize)] pub struct LocationConfig { - pub matcher: MatcherConfig, - pub handler: HandlerConfig, #[serde(default)] - pub internal: Option, + pub allow_cidrs: Vec, #[serde(default)] - pub rewrite_rules: Vec, + pub allow_early_data: Option, #[serde(default)] - pub try_files: Vec, + pub burst: Option, #[serde(default)] - pub error_pages: Vec, + pub cache: Option, #[serde(default)] - pub grpc_service: Option, + pub compression: Option, #[serde(default)] - pub grpc_method: Option, + pub compression_content_types: Option>, #[serde(default)] - pub allow_cidrs: Vec, + pub compression_min_bytes: Option, #[serde(default)] pub deny_cidrs: Vec, #[serde(default)] - pub requests_per_sec: Option, + pub error_pages: Vec, #[serde(default)] - pub burst: Option, + pub grpc_method: Option, #[serde(default)] - pub allow_early_data: Option, + pub grpc_service: Option, + pub handler: HandlerConfig, #[serde(default)] - pub request_buffering: Option, + pub internal: Option, + pub matcher: MatcherConfig, #[serde(default)] - pub response_buffering: Option, + pub request_buffering: Option, #[serde(default)] - pub compression: Option, + pub requests_per_sec: Option, #[serde(default)] - pub compression_min_bytes: Option, + pub response_buffering: Option, #[serde(default)] - pub compression_content_types: Option>, + pub rewrite_rules: Vec, #[serde(default)] pub streaming_response_idle_timeout_secs: Option, #[serde(default)] - pub cache: Option, + pub try_files: Vec, } #[derive(Debug, Clone, Deserialize)] pub enum MatcherConfig { Exact(String), + Named(String), PreferredPrefix(String), Prefix(String), - Named(String), Regex { pattern: String, #[serde(default)] @@ -77,21 +77,6 @@ pub enum MatcherConfig { #[derive(Debug, Clone, Deserialize)] pub enum HandlerConfig { - Proxy { - upstream: String, - #[serde(default)] - preserve_host: Option, - #[serde(default)] - strip_prefix: Option, - #[serde(default)] - proxy_pass_uri: Option, - #[serde(default)] - proxy_http_version: Option, - #[serde(default)] - proxy_redirect: Option, - #[serde(default)] - proxy_set_headers: HashMap, - }, File { #[serde(default)] root: Option, @@ -122,6 +107,21 @@ pub enum HandlerConfig { #[serde(default)] follow_symlinks: Option, }, + Proxy { + upstream: String, + #[serde(default)] + preserve_host: Option, + #[serde(default)] + strip_prefix: Option, + #[serde(default)] + proxy_pass_uri: Option, + #[serde(default)] + proxy_http_version: Option, + #[serde(default)] + proxy_redirect: Option, + #[serde(default)] + proxy_set_headers: HashMap, + }, Return { status: u16, location: String, @@ -140,9 +140,9 @@ pub struct RewriteRuleConfig { #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] pub enum RewriteStopConfig { + Break, Continue, Last, - Break, } #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq)] @@ -159,8 +159,8 @@ pub enum ProxyRedirectModeConfig { #[derive(Debug, Clone, Deserialize)] pub enum TryFileStepConfig { - Path(String), Named(String), + Path(String), Status(u16), } @@ -172,9 +172,9 @@ pub struct ErrorPageConfig { #[derive(Debug, Clone, Deserialize)] pub enum ErrorPageTargetConfig { - Uri(String), Named(String), Status(u16), + Uri(String), } #[derive(Debug, Clone, Deserialize)] @@ -191,13 +191,13 @@ pub enum ProxyHeaderValueConfig { #[derive(Debug, Clone, Deserialize)] pub enum ProxyHeaderDynamicValueConfig { - Host, - Scheme, ClientIp, - RemoteAddr, - PeerAddr, ForwardedFor, + Host, + PeerAddr, + RemoteAddr, + Remove, RequestHeader(String), + Scheme, Template(String), - Remove, } diff --git a/crates/rginx-config/src/model/runtime.rs b/crates/rginx-config/src/model/runtime.rs index 4adb633f..a22c759a 100644 --- a/crates/rginx-config/src/model/runtime.rs +++ b/crates/rginx-config/src/model/runtime.rs @@ -2,9 +2,9 @@ use serde::Deserialize; #[derive(Debug, Clone, Deserialize)] pub struct RuntimeConfig { + #[serde(default)] + pub accept_workers: Option, pub shutdown_timeout_secs: u64, #[serde(default)] pub worker_threads: Option, - #[serde(default)] - pub accept_workers: Option, } diff --git a/crates/rginx-config/src/model/server.rs b/crates/rginx-config/src/model/server.rs index 530b3911..dc8cf0a4 100644 --- a/crates/rginx-config/src/model/server.rs +++ b/crates/rginx-config/src/model/server.rs @@ -4,24 +4,24 @@ use super::{Http1Config, Http3Config, ServerTlsConfig}; #[derive(Debug, Clone)] pub struct ServerConfig { - pub listen: Option, - pub server_header: Option, - pub proxy_protocol: Option, - pub default_certificate: Option, - pub server_names: Vec, - pub trusted_proxies: Vec, + pub access_log_format: Option, pub client_ip_header: Option, + pub default_certificate: Option, + pub header_read_timeout_secs: Option, + pub http1: Option, + pub http3: Option, pub keep_alive: Option, + pub listen: Option, + pub max_connections: Option, pub max_headers: Option, pub max_request_body_bytes: Option, - pub max_connections: Option, - pub header_read_timeout_secs: Option, + pub proxy_protocol: Option, pub request_body_read_timeout_secs: Option, pub response_write_timeout_secs: Option, - pub access_log_format: Option, + pub server_header: Option, + pub server_names: Vec, pub tls: Option, - pub http1: Option, - pub http3: Option, + pub trusted_proxies: Vec, } impl<'de> Deserialize<'de> for ServerConfig { @@ -33,63 +33,63 @@ impl<'de> Deserialize<'de> for ServerConfig { #[serde(rename = "ServerConfig")] struct ServerConfigDe { #[serde(default)] - listen: MaybeString, - #[serde(default)] - server_header: Option, + access_log_format: Option, #[serde(default)] - proxy_protocol: Option, + client_ip_header: Option, #[serde(default)] default_certificate: Option, #[serde(default)] - server_names: Vec, + header_read_timeout_secs: Option, #[serde(default)] - trusted_proxies: Vec, + http1: Option, #[serde(default)] - client_ip_header: Option, + http3: Option, #[serde(default)] keep_alive: Option, #[serde(default)] + listen: MaybeString, + #[serde(default)] + max_connections: Option, + #[serde(default)] max_headers: Option, #[serde(default)] max_request_body_bytes: Option, #[serde(default)] - max_connections: Option, - #[serde(default)] - header_read_timeout_secs: Option, + proxy_protocol: Option, #[serde(default)] request_body_read_timeout_secs: Option, #[serde(default)] response_write_timeout_secs: Option, #[serde(default)] - access_log_format: Option, + server_header: Option, #[serde(default)] - tls: Option, + server_names: Vec, #[serde(default)] - http1: Option, + tls: Option, #[serde(default)] - http3: Option, + trusted_proxies: Vec, } let server = ServerConfigDe::deserialize(deserializer)?; Ok(ServerConfig { - listen: server.listen.0, - server_header: server.server_header, - proxy_protocol: server.proxy_protocol, - default_certificate: server.default_certificate, - server_names: server.server_names, - trusted_proxies: server.trusted_proxies, + access_log_format: server.access_log_format, client_ip_header: server.client_ip_header, + default_certificate: server.default_certificate, + header_read_timeout_secs: server.header_read_timeout_secs, + http1: server.http1, + http3: server.http3, keep_alive: server.keep_alive, + listen: server.listen.0, + max_connections: server.max_connections, max_headers: server.max_headers, max_request_body_bytes: server.max_request_body_bytes, - max_connections: server.max_connections, - header_read_timeout_secs: server.header_read_timeout_secs, + proxy_protocol: server.proxy_protocol, request_body_read_timeout_secs: server.request_body_read_timeout_secs, response_write_timeout_secs: server.response_write_timeout_secs, - access_log_format: server.access_log_format, + server_header: server.server_header, + server_names: server.server_names, tls: server.tls, - http1: server.http1, - http3: server.http3, + trusted_proxies: server.trusted_proxies, }) } } @@ -105,13 +105,13 @@ impl<'de> Deserialize<'de> for MaybeString { #[derive(Deserialize)] #[serde(untagged)] enum StringOrOption { - String(String), Option(Option), + String(String), } Ok(match StringOrOption::deserialize(deserializer)? { - StringOrOption::String(value) => Self(Some(value)), StringOrOption::Option(value) => Self(value), + StringOrOption::String(value) => Self(Some(value)), }) } } diff --git a/crates/rginx-config/src/model/tls.rs b/crates/rginx-config/src/model/tls.rs index 2ce84752..66795872 100644 --- a/crates/rginx-config/src/model/tls.rs +++ b/crates/rginx-config/src/model/tls.rs @@ -3,42 +3,42 @@ use serde::{Deserialize, Deserializer, de}; #[derive(Debug, Clone, Deserialize)] pub struct ServerTlsConfig { - pub cert_path: String, - pub key_path: String, #[serde(default)] pub additional_certificates: Option>, #[serde(default)] - pub versions: Option>, + pub alpn_protocols: Option>, + pub cert_path: String, #[serde(default)] pub cipher_suites: Option>, #[serde(default)] + pub client_auth: Option, + #[serde(default)] pub key_exchange_groups: Option>, + pub key_path: String, #[serde(default)] - pub alpn_protocols: Option>, + pub ocsp: Option, #[serde(default)] pub ocsp_staple_path: Option, #[serde(default)] - pub ocsp: Option, + pub session_cache_size: Option, #[serde(default)] pub session_resumption: Option, #[serde(default)] - pub session_tickets: Option, - #[serde(default)] - pub session_cache_size: Option, - #[serde(default)] pub session_ticket_count: Option, #[serde(default)] - pub client_auth: Option, + pub session_tickets: Option, + #[serde(default)] + pub versions: Option>, } #[derive(Debug, Clone)] pub struct VirtualHostTlsConfig { + pub acme: Option, + pub additional_certificates: Option>, pub cert_path: String, pub key_path: String, - pub additional_certificates: Option>, - pub ocsp_staple_path: Option, pub ocsp: Option, - pub acme: Option, + pub ocsp_staple_path: Option, } #[derive(Debug, Clone, Deserialize)] @@ -46,9 +46,9 @@ pub struct ServerCertificateBundleConfig { pub cert_path: String, pub key_path: String, #[serde(default)] - pub ocsp_staple_path: Option, - #[serde(default)] pub ocsp: Option, + #[serde(default)] + pub ocsp_staple_path: Option, } #[derive(Debug, Clone, Deserialize)] @@ -80,26 +80,26 @@ pub enum TlsVersionConfig { #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Hash)] pub enum TlsCipherSuiteConfig { - Tls13Aes256GcmSha384, Tls13Aes128GcmSha256, + Tls13Aes256GcmSha384, Tls13Chacha20Poly1305Sha256, - TlsEcdheEcdsaWithAes256GcmSha384, TlsEcdheEcdsaWithAes128GcmSha256, + TlsEcdheEcdsaWithAes256GcmSha384, TlsEcdheEcdsaWithChacha20Poly1305Sha256, - TlsEcdheRsaWithAes256GcmSha384, TlsEcdheRsaWithAes128GcmSha256, + TlsEcdheRsaWithAes256GcmSha384, TlsEcdheRsaWithChacha20Poly1305Sha256, } #[derive(Debug, Clone, Copy, Deserialize, PartialEq, Eq, Hash)] pub enum TlsKeyExchangeGroupConfig { - X25519, + Mlkem1024, + Mlkem768, Secp256r1, + Secp256r1Mlkem768, Secp384r1, + X25519, X25519Mlkem768, - Secp256r1Mlkem768, - Mlkem768, - Mlkem1024, } #[derive(Debug, Clone, Deserialize, PartialEq, Eq)] @@ -110,12 +110,12 @@ pub enum ServerClientAuthModeConfig { #[derive(Debug, Clone, Deserialize)] pub struct ServerClientAuthConfig { - pub mode: ServerClientAuthModeConfig, pub ca_cert_path: String, #[serde(default)] - pub verify_depth: Option, - #[serde(default)] pub crl_path: Option, + pub mode: ServerClientAuthModeConfig, + #[serde(default)] + pub verify_depth: Option, } impl<'de> Deserialize<'de> for VirtualHostTlsConfig { @@ -125,47 +125,47 @@ impl<'de> Deserialize<'de> for VirtualHostTlsConfig { { #[derive(Debug, Clone, Deserialize)] enum VirtualHostTlsConfigDe { - VirtualHostTlsConfig { - cert_path: String, - key_path: String, - #[serde(default)] - additional_certificates: Option>, - #[serde(default)] - ocsp_staple_path: Option, - #[serde(default)] - ocsp: Option, + ServerTlsConfig { #[serde(default)] acme: Option, - }, - ServerTlsConfig { - cert_path: String, - key_path: String, #[serde(default)] additional_certificates: Option>, #[serde(default)] - versions: Option>, + alpn_protocols: Option>, + cert_path: String, #[serde(default)] cipher_suites: Option>, #[serde(default)] + client_auth: Option, + #[serde(default)] key_exchange_groups: Option>, + key_path: String, #[serde(default)] - alpn_protocols: Option>, + ocsp: Option, #[serde(default)] ocsp_staple_path: Option, #[serde(default)] - ocsp: Option, + session_cache_size: Option, #[serde(default)] session_resumption: Option, #[serde(default)] + session_ticket_count: Option, + #[serde(default)] session_tickets: Option, #[serde(default)] - session_cache_size: Option, + versions: Option>, + }, + VirtualHostTlsConfig { #[serde(default)] - session_ticket_count: Option, + acme: Option, #[serde(default)] - client_auth: Option, + additional_certificates: Option>, + cert_path: String, + key_path: String, #[serde(default)] - acme: Option, + ocsp: Option, + #[serde(default)] + ocsp_staple_path: Option, }, } @@ -178,12 +178,12 @@ impl<'de> Deserialize<'de> for VirtualHostTlsConfig { ocsp, acme, } => Ok(Self { + acme, + additional_certificates, cert_path, key_path, - additional_certificates, - ocsp_staple_path, ocsp, - acme, + ocsp_staple_path, }), VirtualHostTlsConfigDe::ServerTlsConfig { cert_path, diff --git a/crates/rginx-config/src/model/upstream.rs b/crates/rginx-config/src/model/upstream.rs index b1174e42..d3b0cf02 100644 --- a/crates/rginx-config/src/model/upstream.rs +++ b/crates/rginx-config/src/model/upstream.rs @@ -4,132 +4,128 @@ use super::TlsVersionConfig; #[derive(Debug, Clone, Deserialize)] pub struct UpstreamConfig { - pub name: String, - pub peers: Vec, - pub tls: Option, + #[serde(default)] + pub connect_timeout_secs: Option, #[serde(default)] pub dns: Option, #[serde(default)] - pub protocol: UpstreamProtocolConfig, + pub health_check_grpc_service: Option, #[serde(default)] - pub load_balance: UpstreamLoadBalanceConfig, + pub health_check_interval_secs: Option, #[serde(default)] - pub server_name: Option, - #[serde(default, alias = "tls_server_name")] - pub server_name_override: Option, + pub health_check_path: Option, #[serde(default)] - pub request_timeout_secs: Option, + pub health_check_timeout_secs: Option, #[serde(default)] - pub connect_timeout_secs: Option, + pub healthy_successes_required: Option, #[serde(default)] - pub read_timeout_secs: Option, + pub http2_keep_alive_interval_secs: Option, #[serde(default)] - pub write_timeout_secs: Option, + pub http2_keep_alive_timeout_secs: Option, + #[serde(default)] + pub http2_keep_alive_while_idle: Option, #[serde(default)] pub idle_timeout_secs: Option, #[serde(default)] + pub load_balance: UpstreamLoadBalanceConfig, + #[serde(default)] + pub max_replayable_request_body_bytes: Option, + pub name: String, + pub peers: Vec, + #[serde(default)] pub pool_idle_timeout_secs: Option, #[serde(default)] pub pool_max_idle_per_host: Option, #[serde(default)] - pub tcp_keepalive_secs: Option, + pub protocol: UpstreamProtocolConfig, #[serde(default)] - pub tcp_nodelay: Option, + pub read_timeout_secs: Option, #[serde(default)] - pub http2_keep_alive_interval_secs: Option, + pub request_timeout_secs: Option, #[serde(default)] - pub http2_keep_alive_timeout_secs: Option, + pub server_name: Option, + #[serde(default, alias = "tls_server_name")] + pub server_name_override: Option, #[serde(default)] - pub http2_keep_alive_while_idle: Option, + pub tcp_keepalive_secs: Option, #[serde(default)] - pub max_replayable_request_body_bytes: Option, + pub tcp_nodelay: Option, + pub tls: Option, #[serde(default)] pub unhealthy_after_failures: Option, #[serde(default)] pub unhealthy_cooldown_secs: Option, #[serde(default)] - pub health_check_path: Option, - #[serde(default)] - pub health_check_grpc_service: Option, - #[serde(default)] - pub health_check_interval_secs: Option, - #[serde(default)] - pub health_check_timeout_secs: Option, - #[serde(default)] - pub healthy_successes_required: Option, + pub write_timeout_secs: Option, } #[derive(Debug, Clone, Deserialize, Default)] pub struct UpstreamDnsConfig { #[serde(default)] - pub resolver_addrs: Vec, + pub max_ttl_secs: Option, #[serde(default)] pub min_ttl_secs: Option, #[serde(default)] - pub max_ttl_secs: Option, - #[serde(default)] pub negative_ttl_secs: Option, #[serde(default)] - pub stale_if_error_secs: Option, + pub prefer_ipv4: Option, + #[serde(default)] + pub prefer_ipv6: Option, #[serde(default)] pub refresh_before_expiry_secs: Option, #[serde(default)] - pub prefer_ipv4: Option, + pub resolver_addrs: Vec, #[serde(default)] - pub prefer_ipv6: Option, + pub stale_if_error_secs: Option, } #[derive(Debug, Clone, Deserialize)] pub struct UpstreamPeerConfig { - pub url: String, - #[serde(default = "default_upstream_peer_weight")] - pub weight: u32, #[serde(default)] pub backup: bool, #[serde(default)] pub max_conns: Option, -} - -const fn default_upstream_peer_weight() -> u32 { - 1 + pub url: String, + #[serde(default = "default_upstream_peer_weight")] + pub weight: u32, } #[derive(Debug, Clone, Deserialize, Default)] pub enum UpstreamTlsModeConfig { - #[default] - NativeRoots, CustomCa { ca_cert_path: String, }, Insecure, + #[default] + NativeRoots, } #[derive(Debug, Clone)] pub struct UpstreamTlsConfig { - pub verify: UpstreamTlsModeConfig, - pub versions: Option>, - pub verify_depth: Option, - pub crl_path: Option, pub client_cert_path: Option, pub client_key_path: Option, + pub crl_path: Option, + pub verify: UpstreamTlsModeConfig, + pub verify_depth: Option, + pub versions: Option>, } #[derive(Debug, Clone, Deserialize, Default)] pub enum UpstreamProtocolConfig { #[default] Auto, + H2c, Http1, Http2, - H2c, Http3, } #[derive(Debug, Clone, Deserialize, Default)] pub enum UpstreamLoadBalanceConfig { - #[default] - RoundRobin, IpHash, LeastConn, + #[default] + RoundRobin, } impl<'de> Deserialize<'de> for UpstreamTlsConfig { @@ -139,24 +135,24 @@ impl<'de> Deserialize<'de> for UpstreamTlsConfig { { #[derive(Debug, Clone, Deserialize)] enum UpstreamTlsConfigDe { - NativeRoots, CustomCa { ca_cert_path: String, }, Insecure, + NativeRoots, UpstreamTlsConfig { - #[serde(default, alias = "verify")] - verify: UpstreamTlsModeConfig, #[serde(default)] - versions: Option>, + client_cert_path: Option, #[serde(default)] - verify_depth: Option, + client_key_path: Option, #[serde(default)] crl_path: Option, + #[serde(default, alias = "verify")] + verify: UpstreamTlsModeConfig, #[serde(default)] - client_cert_path: Option, + verify_depth: Option, #[serde(default)] - client_key_path: Option, + versions: Option>, }, } @@ -193,8 +189,12 @@ impl<'de> Deserialize<'de> for UpstreamTlsConfig { client_cert_path, client_key_path, } => { - Self { verify, versions, verify_depth, crl_path, client_cert_path, client_key_path } + Self { client_cert_path, client_key_path, crl_path, verify, verify_depth, versions } } }) } } + +const fn default_upstream_peer_weight() -> u32 { + 1 +} diff --git a/crates/rginx-config/src/model/vhost.rs b/crates/rginx-config/src/model/vhost.rs index 15fd0345..8e97f8a8 100644 --- a/crates/rginx-config/src/model/vhost.rs +++ b/crates/rginx-config/src/model/vhost.rs @@ -4,18 +4,18 @@ use super::{Http3Config, LocationConfig, UpstreamConfig, VirtualHostTlsConfig}; #[derive(Debug, Clone, Deserialize)] pub struct VirtualHostConfig { + #[serde(default)] + pub http3: Option, #[serde(default, deserialize_with = "deserialize_string_list")] pub listen: Vec, - #[serde(default, alias = "server_name", deserialize_with = "deserialize_string_list")] - pub server_names: Vec, - #[serde(default)] - pub upstreams: Vec, #[serde(default)] pub locations: Vec, + #[serde(default, alias = "server_name", deserialize_with = "deserialize_string_list")] + pub server_names: Vec, #[serde(default)] pub tls: Option, #[serde(default)] - pub http3: Option, + pub upstreams: Vec, } fn deserialize_string_list<'de, D>(deserializer: D) -> Result, D::Error> @@ -25,12 +25,12 @@ where #[derive(Deserialize)] #[serde(untagged)] enum StringOrList { - String(String), List(Vec), + String(String), } Ok(match StringOrList::deserialize(deserializer)? { - StringOrList::String(value) => vec![value], StringOrList::List(values) => values, + StringOrList::String(value) => vec![value], }) } diff --git a/crates/rginx-config/src/validate.rs b/crates/rginx-config/src/validate.rs index 11930304..6255af32 100644 --- a/crates/rginx-config/src/validate.rs +++ b/crates/rginx-config/src/validate.rs @@ -1,9 +1,3 @@ -use std::collections::HashSet; - -use rginx_core::{Error, Result}; - -use crate::model::{Config, LocationConfig, RouteBufferingPolicyConfig}; - mod acme; mod agent; mod cache; @@ -14,6 +8,14 @@ mod server; mod upstream; mod vhost; +#[cfg(test)] +mod tests; +use std::collections::HashSet; + +use rginx_core::{Error, Result}; + +use crate::model::{Config, LocationConfig, RouteBufferingPolicyConfig}; + const DEFAULT_GRPC_HEALTH_CHECK_PATH: &str = "/grpc.health.v1.Health/Check"; pub fn validate(config: &Config) -> Result<()> { @@ -90,6 +92,3 @@ fn validate_request_buffering_limits(config: &Config) -> Result<()> { fn route_uses_forced_request_buffering(location: &LocationConfig) -> bool { matches!(location.request_buffering, Some(RouteBufferingPolicyConfig::On)) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-config/src/validate/cache.rs b/crates/rginx-config/src/validate/cache.rs index 0e570186..ea77a71a 100644 --- a/crates/rginx-config/src/validate/cache.rs +++ b/crates/rginx-config/src/validate/cache.rs @@ -1,11 +1,11 @@ +mod predicate; + use std::collections::HashSet; use rginx_core::{Error, Result}; use crate::model::{CacheRouteConfig, CacheStatusTtlConfig, CacheZoneConfig}; -mod predicate; - use predicate::{PredicateValidationMode, validate_cache_predicate}; pub(super) fn validate_cache_zones(zones: &[CacheZoneConfig]) -> Result> { @@ -65,8 +65,7 @@ pub(super) fn validate_route_cache( } if !cache_zone_names.contains(zone) { return Err(Error::Config(format!( - "{route_scope} references undefined cache zone `{}`", - zone + "{route_scope} references undefined cache zone `{zone}`" ))); } diff --git a/crates/rginx-config/src/validate/cache/predicate.rs b/crates/rginx-config/src/validate/cache/predicate.rs index a9fd35e9..f520db2b 100644 --- a/crates/rginx-config/src/validate/cache/predicate.rs +++ b/crates/rginx-config/src/validate/cache/predicate.rs @@ -3,6 +3,12 @@ use rginx_core::{Error, Result}; use crate::model::CachePredicateConfig; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum PredicateValidationMode { + RequestOnly, + RequestOrResponse, +} + pub(super) fn validate_cache_predicate( route_scope: &str, field: &str, @@ -79,9 +85,3 @@ fn validate_header_name(route_scope: &str, field: &str, name: &str) -> Result<() })?; Ok(()) } - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(super) enum PredicateValidationMode { - RequestOnly, - RequestOrResponse, -} diff --git a/crates/rginx-config/src/validate/route.rs b/crates/rginx-config/src/validate/route.rs index 6a3b7ec9..982dcebf 100644 --- a/crates/rginx-config/src/validate/route.rs +++ b/crates/rginx-config/src/validate/route.rs @@ -1,3 +1,6 @@ +mod handler; +mod matcher; + use std::collections::HashSet; use ipnet::IpNet; @@ -9,9 +12,6 @@ use crate::model::{ RouteCompressionPolicyConfig, }; -mod handler; -mod matcher; - use matcher::{exact_route_key, matcher_label, route_scope}; pub(super) fn validate_locations( diff --git a/crates/rginx-config/src/validate/route/handler.rs b/crates/rginx-config/src/validate/route/handler.rs index 7cddfb77..eb3c7df4 100644 --- a/crates/rginx-config/src/validate/route/handler.rs +++ b/crates/rginx-config/src/validate/route/handler.rs @@ -1,3 +1,5 @@ +mod file; + use std::collections::HashSet; use rginx_core::{Error, ProxyHeaderTemplate, Result}; @@ -6,8 +8,6 @@ use crate::model::{ HandlerConfig, MatcherConfig, ProxyHeaderDynamicValueConfig, ProxyHeaderValueConfig, }; -mod file; - use file::validate_file_handler; pub(super) fn validate_handler( diff --git a/crates/rginx-config/src/validate/server.rs b/crates/rginx-config/src/validate/server.rs index 7fa69a61..b42eb623 100644 --- a/crates/rginx-config/src/validate/server.rs +++ b/crates/rginx-config/src/validate/server.rs @@ -1,13 +1,13 @@ -use rginx_core::Result; - -use crate::model::{Http3Config, ListenerConfig, ServerConfig, ServerTlsConfig, VirtualHostConfig}; - mod http3; mod listener; mod names; mod proxies; mod tls; +use rginx_core::Result; + +use crate::model::{Http3Config, ListenerConfig, ServerConfig, ServerTlsConfig, VirtualHostConfig}; + /// Validates the legacy top-level `server` block. pub(super) fn validate_server(server: &ServerConfig) -> Result<()> { listener::validate_server(server) diff --git a/crates/rginx-config/src/validate/server/listener.rs b/crates/rginx-config/src/validate/server/listener.rs index 6668fff3..ecdc9138 100644 --- a/crates/rginx-config/src/validate/server/listener.rs +++ b/crates/rginx-config/src/validate/server/listener.rs @@ -1,10 +1,10 @@ +mod base; +mod listeners; + use rginx_core::Result; use crate::model::{ListenerConfig, ServerConfig, VirtualHostConfig}; -mod base; -mod listeners; - pub(super) fn validate_server(server: &ServerConfig) -> Result<()> { base::validate_server(server) } diff --git a/crates/rginx-config/src/validate/server/listener/base.rs b/crates/rginx-config/src/validate/server/listener/base.rs index 66b25507..4abd8091 100644 --- a/crates/rginx-config/src/validate/server/listener/base.rs +++ b/crates/rginx-config/src/validate/server/listener/base.rs @@ -11,6 +11,27 @@ use super::super::tls::{ validate_tls_key_exchange_groups, validate_tls_versions, }; +pub(super) struct ListenerLikeRef<'a> { + pub(super) access_log_format: Option<&'a str>, + pub(super) client_ip_header: Option<&'a str>, + pub(super) default_certificate: Option<&'a str>, + pub(super) header_read_timeout_secs: Option, + pub(super) http1: Option<&'a Http1Config>, + pub(super) http3: Option<&'a Http3Config>, + pub(super) listen: Option<&'a str>, + pub(super) max_connections: Option, + pub(super) max_headers: Option, + pub(super) max_request_body_bytes: Option, + pub(super) owner_label: &'a str, + pub(super) proxy_protocol: Option, + pub(super) request_body_read_timeout_secs: Option, + pub(super) require_listen: bool, + pub(super) response_write_timeout_secs: Option, + pub(super) server_header: Option<&'a str>, + pub(super) tls: Option<&'a ServerTlsConfig>, + pub(super) trusted_proxies: &'a [String], +} + pub(super) fn validate_server(server: &ServerConfig) -> Result<()> { validate_listener_like(ListenerLikeRef { owner_label: "server", @@ -36,27 +57,6 @@ pub(super) fn validate_server(server: &ServerConfig) -> Result<()> { Ok(()) } -pub(super) struct ListenerLikeRef<'a> { - pub(super) owner_label: &'a str, - pub(super) listen: Option<&'a str>, - pub(super) server_header: Option<&'a str>, - pub(super) proxy_protocol: Option, - pub(super) default_certificate: Option<&'a str>, - pub(super) trusted_proxies: &'a [String], - pub(super) client_ip_header: Option<&'a str>, - pub(super) max_headers: Option, - pub(super) max_request_body_bytes: Option, - pub(super) max_connections: Option, - pub(super) header_read_timeout_secs: Option, - pub(super) request_body_read_timeout_secs: Option, - pub(super) response_write_timeout_secs: Option, - pub(super) access_log_format: Option<&'a str>, - pub(super) tls: Option<&'a ServerTlsConfig>, - pub(super) http1: Option<&'a Http1Config>, - pub(super) http3: Option<&'a Http3Config>, - pub(super) require_listen: bool, -} - pub(super) fn validate_listener_like(config: ListenerLikeRef<'_>) -> Result<()> { if config.require_listen { let listen = config.listen.unwrap_or_default().trim(); diff --git a/crates/rginx-config/src/validate/server/listener/listeners.rs b/crates/rginx-config/src/validate/server/listener/listeners.rs index efc05dad..c8ff6ed8 100644 --- a/crates/rginx-config/src/validate/server/listener/listeners.rs +++ b/crates/rginx-config/src/validate/server/listener/listeners.rs @@ -7,6 +7,13 @@ use crate::model::{Http3Config, ListenerConfig, ServerConfig, VirtualHostConfig} use super::base::{ListenerLikeRef, legacy_server_listener_fields, validate_listener_like}; +#[derive(Clone)] +struct VhostListenerBinding { + http3: Option, + proxy_protocol: bool, + ssl: bool, +} + pub(super) fn validate_listeners( listeners: &[ListenerConfig], server: &ServerConfig, @@ -128,13 +135,6 @@ fn validate_server_fields_for_vhost_listen(server: &ServerConfig) -> Result<()> Ok(()) } -#[derive(Clone)] -struct VhostListenerBinding { - ssl: bool, - proxy_protocol: bool, - http3: Option, -} - fn validate_vhost_listener_bindings(vhosts: &[VirtualHostConfig]) -> Result<()> { let mut tcp_bindings = HashMap::::new(); diff --git a/crates/rginx-config/src/validate/tests.rs b/crates/rginx-config/src/validate/tests.rs index 2a20f613..bb33d8fd 100644 --- a/crates/rginx-config/src/validate/tests.rs +++ b/crates/rginx-config/src/validate/tests.rs @@ -1,15 +1,3 @@ -use crate::model::{ - AgentConfig, CacheRouteConfig, CacheZoneConfig, Config, ControlPlaneConfig, - ControlPlaneTlsConfig, HandlerConfig, Http1Config, Http3Config, ListenerConfig, LocationConfig, - MatcherConfig, ProxyHeaderDynamicValueConfig, ProxyHeaderValueConfig, - RouteBufferingPolicyConfig, RouteCompressionPolicyConfig, RuntimeConfig, ServerConfig, - ServerTlsConfig, TlsCipherSuiteConfig, TlsKeyExchangeGroupConfig, TlsVersionConfig, - UpstreamConfig, UpstreamLoadBalanceConfig, UpstreamPeerConfig, UpstreamProtocolConfig, - VirtualHostConfig, VirtualHostTlsConfig, -}; - -use super::{DEFAULT_GRPC_HEALTH_CHECK_PATH, validate}; - mod acme; mod agent; mod cache; @@ -26,6 +14,18 @@ mod upstream_health; mod upstream_tls; mod vhosts; +use crate::model::{ + AgentConfig, CacheRouteConfig, CacheZoneConfig, Config, ControlPlaneConfig, + ControlPlaneTlsConfig, HandlerConfig, Http1Config, Http3Config, ListenerConfig, LocationConfig, + MatcherConfig, ProxyHeaderDynamicValueConfig, ProxyHeaderValueConfig, + RouteBufferingPolicyConfig, RouteCompressionPolicyConfig, RuntimeConfig, ServerConfig, + ServerTlsConfig, TlsCipherSuiteConfig, TlsKeyExchangeGroupConfig, TlsVersionConfig, + UpstreamConfig, UpstreamLoadBalanceConfig, UpstreamPeerConfig, UpstreamProtocolConfig, + VirtualHostConfig, VirtualHostTlsConfig, +}; + +use super::{DEFAULT_GRPC_HEALTH_CHECK_PATH, validate}; + fn base_config() -> Config { Config { agent: None, diff --git a/crates/rginx-config/src/validate/tests/control_plane.rs b/crates/rginx-config/src/validate/tests/control_plane.rs index 786bbc2e..a7ee4663 100644 --- a/crates/rginx-config/src/validate/tests/control_plane.rs +++ b/crates/rginx-config/src/validate/tests/control_plane.rs @@ -149,7 +149,7 @@ fn validate_rejects_blank_node_identity_fields_and_labels() { node_id: Some(" ".to_string()), region: None, pop: None, - labels: [("".to_string(), "edge".to_string())].into_iter().collect(), + labels: [(String::new(), "edge".to_string())].into_iter().collect(), }); let error = validate(&config).expect_err("blank node identity fields should be rejected"); diff --git a/crates/rginx-config/src/validate/tests/vhosts.rs b/crates/rginx-config/src/validate/tests/vhosts.rs index e1f13be1..44b518f1 100644 --- a/crates/rginx-config/src/validate/tests/vhosts.rs +++ b/crates/rginx-config/src/validate/tests/vhosts.rs @@ -1,7 +1,7 @@ -use super::*; - mod listen; +use super::*; + #[test] fn validate_rejects_empty_default_server_name() { let mut config = base_config(); diff --git a/crates/rginx-config/src/validate/upstream.rs b/crates/rginx-config/src/validate/upstream.rs index 2361730c..1a87cc56 100644 --- a/crates/rginx-config/src/validate/upstream.rs +++ b/crates/rginx-config/src/validate/upstream.rs @@ -1,3 +1,10 @@ +mod basics; +mod dns; +mod health; +mod protocol; +mod tls; +mod tuning; + use std::collections::HashSet; use rginx_core::{Error, Result}; @@ -6,13 +13,6 @@ use crate::model::{ TlsVersionConfig, UpstreamConfig, UpstreamProtocolConfig, UpstreamTlsModeConfig, }; -mod basics; -mod dns; -mod health; -mod protocol; -mod tls; -mod tuning; - pub(super) fn validate_upstreams(upstreams: &[UpstreamConfig]) -> Result> { let mut upstream_names = HashSet::new(); diff --git a/crates/rginx-config/src/validate/upstream/basics.rs b/crates/rginx-config/src/validate/upstream/basics.rs index 37165c7f..7f88e7e7 100644 --- a/crates/rginx-config/src/validate/upstream/basics.rs +++ b/crates/rginx-config/src/validate/upstream/basics.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Error, HashSet, Result, UpstreamConfig}; pub(super) fn validate_upstream_name_and_peers( upstream: &UpstreamConfig, diff --git a/crates/rginx-config/src/validate/upstream/dns.rs b/crates/rginx-config/src/validate/upstream/dns.rs index 74caa476..24f8b767 100644 --- a/crates/rginx-config/src/validate/upstream/dns.rs +++ b/crates/rginx-config/src/validate/upstream/dns.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Error, Result, UpstreamConfig}; pub(super) fn validate_dns_settings(upstream: &UpstreamConfig) -> Result<()> { let Some(dns) = &upstream.dns else { diff --git a/crates/rginx-config/src/validate/upstream/health.rs b/crates/rginx-config/src/validate/upstream/health.rs index 7e0e1d28..6ea39e04 100644 --- a/crates/rginx-config/src/validate/upstream/health.rs +++ b/crates/rginx-config/src/validate/upstream/health.rs @@ -1,6 +1,6 @@ use http::uri::PathAndQuery; -use super::*; +use super::{Error, Result, UpstreamConfig, UpstreamProtocolConfig}; pub(super) fn validate_active_health_settings(upstream: &UpstreamConfig) -> Result<()> { if let Some(path) = &upstream.health_check_path { diff --git a/crates/rginx-config/src/validate/upstream/protocol.rs b/crates/rginx-config/src/validate/upstream/protocol.rs index 2c36cb46..11bca5c5 100644 --- a/crates/rginx-config/src/validate/upstream/protocol.rs +++ b/crates/rginx-config/src/validate/upstream/protocol.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Error, Result, UpstreamConfig, UpstreamProtocolConfig}; pub(super) fn validate_protocol_requirements(upstream: &UpstreamConfig) -> Result<()> { for peer in &upstream.peers { @@ -10,10 +10,9 @@ pub(super) fn validate_protocol_requirements(upstream: &UpstreamConfig) -> Resul })?; match (&upstream.protocol, uri.scheme_str()) { - (UpstreamProtocolConfig::Http2, Some("https")) - | (UpstreamProtocolConfig::Http3, Some("https")) + (UpstreamProtocolConfig::Http2 | UpstreamProtocolConfig::Http3, Some("https")) | (UpstreamProtocolConfig::H2c, Some("http")) => {} - (UpstreamProtocolConfig::Http2, _) | (UpstreamProtocolConfig::Http3, _) => { + (UpstreamProtocolConfig::Http2 | UpstreamProtocolConfig::Http3, _) => { let protocol = match &upstream.protocol { UpstreamProtocolConfig::Http2 => "Http2", UpstreamProtocolConfig::Http3 => "Http3", diff --git a/crates/rginx-config/src/validate/upstream/tls.rs b/crates/rginx-config/src/validate/upstream/tls.rs index f4b1b57c..2e03d68d 100644 --- a/crates/rginx-config/src/validate/upstream/tls.rs +++ b/crates/rginx-config/src/validate/upstream/tls.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Error, HashSet, Result, TlsVersionConfig, UpstreamConfig, UpstreamTlsModeConfig}; pub(super) fn validate_tls_settings(upstream: &UpstreamConfig) -> Result<()> { let Some(tls) = &upstream.tls else { diff --git a/crates/rginx-config/src/validate/upstream/tuning.rs b/crates/rginx-config/src/validate/upstream/tuning.rs index afce0573..6473b60a 100644 --- a/crates/rginx-config/src/validate/upstream/tuning.rs +++ b/crates/rginx-config/src/validate/upstream/tuning.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Error, Result, UpstreamConfig}; pub(super) fn validate_timeout_and_tuning(upstream: &UpstreamConfig) -> Result<()> { for (field, value) in [ diff --git a/crates/rginx-core/Cargo.toml b/crates/rginx-core/Cargo.toml index 17d9dd81..fd717c10 100644 --- a/crates/rginx-core/Cargo.toml +++ b/crates/rginx-core/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true rust-version.workspace = true description = "Core types and shared runtime models for rginx." +[lints] +workspace = true + [dependencies] http.workspace = true ipnet.workspace = true diff --git a/crates/rginx-core/src/config.rs b/crates/rginx-core/src/config.rs index 86b476ae..610406fc 100644 --- a/crates/rginx-core/src/config.rs +++ b/crates/rginx-core/src/config.rs @@ -12,6 +12,9 @@ mod tls; mod upstream; mod virtual_host; +#[cfg(test)] +mod tests; + pub use access_log::{AccessLogFormat, AccessLogValues}; pub use acme::{AcmeChallengeType, AcmeSettings, ManagedCertificateSpec}; pub use agent::{AgentAuthSettings, AgentSettings}; @@ -50,6 +53,3 @@ pub use upstream::{ UpstreamProtocol, UpstreamSettings, UpstreamTls, }; pub use virtual_host::VirtualHost; - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-core/src/config/access_log.rs b/crates/rginx-core/src/config/access_log.rs index 4200cb75..34460563 100644 --- a/crates/rginx-core/src/config/access_log.rs +++ b/crates/rginx-core/src/config/access_log.rs @@ -1,17 +1,17 @@ +mod helpers; +mod variables; + use std::fmt::Write as _; use crate::{Error, Result}; -mod helpers; -mod variables; - use helpers::{fallback_access_log_option, fallback_access_log_value, is_access_log_variable_char}; use variables::{AccessLogVariable, parse_access_log_variable}; #[derive(Debug, Clone, PartialEq, Eq)] pub struct AccessLogFormat { - template: String, segments: Vec, + template: String, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -22,42 +22,42 @@ enum AccessLogSegment { #[derive(Debug, Clone, Copy)] pub struct AccessLogValues<'a> { - pub request_id: &'a str, - pub remote_addr: &'a str, - pub peer_addr: &'a str, - pub method: &'a str, + pub body_bytes_sent: Option, + pub cache_status: Option<&'a str>, + pub client_ip_source: &'a str, + pub elapsed_ms: u64, + pub grpc_message: Option<&'a str>, + pub grpc_method: Option<&'a str>, + pub grpc_protocol: Option<&'a str>, + pub grpc_service: Option<&'a str>, + pub grpc_status: Option<&'a str>, pub host: &'a str, + pub http_version: &'a str, + pub method: &'a str, pub path: &'a str, + pub peer_addr: &'a str, + pub referer: Option<&'a str>, + pub remote_addr: &'a str, pub request: &'a str, - pub status: u16, - pub body_bytes_sent: Option, - pub elapsed_ms: u64, - pub client_ip_source: &'a str, - pub vhost: &'a str, + pub request_id: &'a str, pub route: &'a str, pub scheme: &'a str, - pub http_version: &'a str, - pub tls_version: Option<&'a str>, + pub status: u16, pub tls_alpn: Option<&'a str>, - pub user_agent: Option<&'a str>, - pub referer: Option<&'a str>, pub tls_client_authenticated: bool, - pub tls_client_subject: Option<&'a str>, - pub tls_client_issuer: Option<&'a str>, - pub tls_client_serial: Option<&'a str>, - pub tls_client_san_dns_names: Option<&'a str>, pub tls_client_chain_length: Option, pub tls_client_chain_subjects: Option<&'a str>, - pub grpc_protocol: Option<&'a str>, - pub grpc_service: Option<&'a str>, - pub grpc_method: Option<&'a str>, - pub grpc_status: Option<&'a str>, - pub grpc_message: Option<&'a str>, - pub cache_status: Option<&'a str>, - pub upstream_name: Option<&'a str>, + pub tls_client_issuer: Option<&'a str>, + pub tls_client_san_dns_names: Option<&'a str>, + pub tls_client_serial: Option<&'a str>, + pub tls_client_subject: Option<&'a str>, + pub tls_version: Option<&'a str>, pub upstream_addr: Option<&'a str>, - pub upstream_status: Option, + pub upstream_name: Option<&'a str>, pub upstream_response_time_ms: Option, + pub upstream_status: Option, + pub user_agent: Option<&'a str>, + pub vhost: &'a str, } impl AccessLogFormat { @@ -70,7 +70,7 @@ impl AccessLogFormat { while index < bytes.len() { if bytes[index] != b'$' { - index += 1; + advance_index(&mut index, 1); continue; } @@ -79,43 +79,44 @@ impl AccessLogFormat { .push(AccessLogSegment::Literal(template[literal_start..index].to_string())); } - if let Some(next) = bytes.get(index + 1) { + if let Some(next) = bytes.get(checked_index(index, 1)) { if *next == b'$' { segments.push(AccessLogSegment::Literal("$".to_string())); - index += 2; + advance_index(&mut index, 2); literal_start = index; continue; } if *next == b'{' { - let Some(relative_end) = template[index + 2..].find('}') else { + let name_start = checked_index(index, 2); + let Some(relative_end) = template[name_start..].find('}') else { return Err(Error::Config( "access_log_format contains an unterminated `${...}` variable" .to_string(), )); }; - let end = index + 2 + relative_end; - let name = &template[index + 2..end]; + let end = checked_index(name_start, relative_end); + let name = &template[name_start..end]; segments.push(AccessLogSegment::Variable(parse_access_log_variable(name)?)); - index = end + 1; + index = checked_index(end, 1); literal_start = index; continue; } } - let mut end = index + 1; + let mut end = checked_index(index, 1); while end < bytes.len() && is_access_log_variable_char(bytes[end]) { - end += 1; + advance_index(&mut end, 1); } - if end == index + 1 { + if end == checked_index(index, 1) { segments.push(AccessLogSegment::Literal("$".to_string())); - index += 1; + advance_index(&mut index, 1); literal_start = index; continue; } - let name = &template[index + 1..end]; + let name = &template[checked_index(index, 1)..end]; segments.push(AccessLogSegment::Variable(parse_access_log_variable(name)?)); index = end; literal_start = end; @@ -125,15 +126,12 @@ impl AccessLogFormat { segments.push(AccessLogSegment::Literal(template[literal_start..].to_string())); } - Ok(Self { template, segments }) - } - - pub fn template(&self) -> &str { - &self.template + Ok(Self { segments, template }) } + #[must_use] pub fn render(&self, values: &AccessLogValues<'_>) -> String { - let mut rendered = String::with_capacity(self.template.len() + 64); + let mut rendered = String::with_capacity(self.template.len().saturating_add(64)); for segment in &self.segments { match segment { @@ -144,7 +142,7 @@ impl AccessLogFormat { AccessLogVariable::PeerAddr => rendered.push_str(values.peer_addr), AccessLogVariable::Method => rendered.push_str(values.method), AccessLogVariable::Host => { - rendered.push_str(fallback_access_log_value(values.host)) + rendered.push_str(fallback_access_log_value(values.host)); } AccessLogVariable::Path => rendered.push_str(values.path), AccessLogVariable::Request => rendered.push_str(values.request), @@ -167,27 +165,27 @@ impl AccessLogFormat { AccessLogVariable::Scheme => rendered.push_str(values.scheme), AccessLogVariable::HttpVersion => rendered.push_str(values.http_version), AccessLogVariable::TlsVersion => { - rendered.push_str(fallback_access_log_option(values.tls_version)) + rendered.push_str(fallback_access_log_option(values.tls_version)); } AccessLogVariable::TlsAlpn => { - rendered.push_str(fallback_access_log_option(values.tls_alpn)) + rendered.push_str(fallback_access_log_option(values.tls_alpn)); } AccessLogVariable::UserAgent => { - rendered.push_str(fallback_access_log_option(values.user_agent)) + rendered.push_str(fallback_access_log_option(values.user_agent)); } AccessLogVariable::Referer => { - rendered.push_str(fallback_access_log_option(values.referer)) + rendered.push_str(fallback_access_log_option(values.referer)); } AccessLogVariable::TlsClientAuthenticated => rendered .push_str(if values.tls_client_authenticated { "true" } else { "false" }), AccessLogVariable::TlsClientSubject => { - rendered.push_str(fallback_access_log_option(values.tls_client_subject)) + rendered.push_str(fallback_access_log_option(values.tls_client_subject)); } AccessLogVariable::TlsClientIssuer => { - rendered.push_str(fallback_access_log_option(values.tls_client_issuer)) + rendered.push_str(fallback_access_log_option(values.tls_client_issuer)); } AccessLogVariable::TlsClientSerial => { - rendered.push_str(fallback_access_log_option(values.tls_client_serial)) + rendered.push_str(fallback_access_log_option(values.tls_client_serial)); } AccessLogVariable::TlsClientSanDnsNames => rendered .push_str(fallback_access_log_option(values.tls_client_san_dns_names)), @@ -201,28 +199,28 @@ impl AccessLogFormat { AccessLogVariable::TlsClientChainSubjects => rendered .push_str(fallback_access_log_option(values.tls_client_chain_subjects)), AccessLogVariable::GrpcProtocol => { - rendered.push_str(fallback_access_log_option(values.grpc_protocol)) + rendered.push_str(fallback_access_log_option(values.grpc_protocol)); } AccessLogVariable::GrpcService => { - rendered.push_str(fallback_access_log_option(values.grpc_service)) + rendered.push_str(fallback_access_log_option(values.grpc_service)); } AccessLogVariable::GrpcMethod => { - rendered.push_str(fallback_access_log_option(values.grpc_method)) + rendered.push_str(fallback_access_log_option(values.grpc_method)); } AccessLogVariable::GrpcStatus => { - rendered.push_str(fallback_access_log_option(values.grpc_status)) + rendered.push_str(fallback_access_log_option(values.grpc_status)); } AccessLogVariable::GrpcMessage => { - rendered.push_str(fallback_access_log_option(values.grpc_message)) + rendered.push_str(fallback_access_log_option(values.grpc_message)); } AccessLogVariable::CacheStatus => { - rendered.push_str(fallback_access_log_option(values.cache_status)) + rendered.push_str(fallback_access_log_option(values.cache_status)); } AccessLogVariable::UpstreamName => { - rendered.push_str(fallback_access_log_option(values.upstream_name)) + rendered.push_str(fallback_access_log_option(values.upstream_name)); } AccessLogVariable::UpstreamAddr => { - rendered.push_str(fallback_access_log_option(values.upstream_addr)) + rendered.push_str(fallback_access_log_option(values.upstream_addr)); } AccessLogVariable::UpstreamStatus => { if let Some(status) = values.upstream_status { @@ -244,4 +242,17 @@ impl AccessLogFormat { rendered } + + #[must_use] + pub fn template(&self) -> &str { + &self.template + } +} + +fn advance_index(index: &mut usize, increment: usize) { + *index = checked_index(*index, increment); +} + +fn checked_index(index: usize, increment: usize) -> usize { + index.checked_add(increment).expect("access log template index remains representable") } diff --git a/crates/rginx-core/src/config/access_log/variables.rs b/crates/rginx-core/src/config/access_log/variables.rs index ffb4d4e6..f0021afd 100644 --- a/crates/rginx-core/src/config/access_log/variables.rs +++ b/crates/rginx-core/src/config/access_log/variables.rs @@ -2,42 +2,42 @@ use crate::{Error, Result}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(super) enum AccessLogVariable { - RequestId, - RemoteAddr, - PeerAddr, - Method, + BodyBytesSent, + CacheStatus, + ClientIpSource, + ElapsedMs, + GrpcMessage, + GrpcMethod, + GrpcProtocol, + GrpcService, + GrpcStatus, Host, + HttpVersion, + Method, Path, + PeerAddr, + Referer, + RemoteAddr, Request, - Status, - BodyBytesSent, - ElapsedMs, - ClientIpSource, - Vhost, + RequestId, Route, Scheme, - HttpVersion, - TlsVersion, + Status, TlsAlpn, - UserAgent, - Referer, TlsClientAuthenticated, - TlsClientSubject, - TlsClientIssuer, - TlsClientSerial, - TlsClientSanDnsNames, TlsClientChainLength, TlsClientChainSubjects, - GrpcProtocol, - GrpcService, - GrpcMethod, - GrpcStatus, - GrpcMessage, - CacheStatus, - UpstreamName, + TlsClientIssuer, + TlsClientSanDnsNames, + TlsClientSerial, + TlsClientSubject, + TlsVersion, UpstreamAddr, - UpstreamStatus, + UpstreamName, UpstreamResponseTimeMs, + UpstreamStatus, + UserAgent, + Vhost, } pub(super) fn parse_access_log_variable(name: &str) -> Result { diff --git a/crates/rginx-core/src/config/acme.rs b/crates/rginx-core/src/config/acme.rs index 291cfd7c..ba61667c 100644 --- a/crates/rginx-core/src/config/acme.rs +++ b/crates/rginx-core/src/config/acme.rs @@ -7,6 +7,7 @@ pub enum AcmeChallengeType { } impl AcmeChallengeType { + #[must_use] pub fn as_str(self) -> &'static str { match self { Self::Http01 => "http-01", @@ -16,18 +17,18 @@ impl AcmeChallengeType { #[derive(Debug, Clone, PartialEq, Eq)] pub struct AcmeSettings { - pub directory_url: String, pub contacts: Vec, - pub state_dir: PathBuf, - pub renew_before: Duration, + pub directory_url: String, pub poll_interval: Duration, + pub renew_before: Duration, + pub state_dir: PathBuf, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct ManagedCertificateSpec { - pub scope: String, - pub domains: Vec, pub cert_path: PathBuf, - pub key_path: PathBuf, pub challenge: AcmeChallengeType, + pub domains: Vec, + pub key_path: PathBuf, + pub scope: String, } diff --git a/crates/rginx-core/src/config/agent.rs b/crates/rginx-core/src/config/agent.rs index 282ac51f..66aa0b4c 100644 --- a/crates/rginx-core/src/config/agent.rs +++ b/crates/rginx-core/src/config/agent.rs @@ -11,17 +11,17 @@ pub struct AgentAuthSettings { #[derive(Debug, Clone)] pub struct AgentSettings { + pub auth: AgentAuthSettings, + pub backoff_initial: Duration, + pub backoff_max: Duration, + pub connect_timeout: Duration, pub endpoint: Uri, + pub heartbeat_interval: Duration, + pub labels: BTreeMap, pub node_id: String, - pub auth: AgentAuthSettings, - pub state_path: PathBuf, - pub region: Option, + pub poll_timeout: Duration, pub pop: Option, - pub labels: BTreeMap, - pub heartbeat_interval: Duration, - pub connect_timeout: Duration, + pub region: Option, pub request_timeout: Duration, - pub poll_timeout: Duration, - pub backoff_initial: Duration, - pub backoff_max: Duration, + pub state_path: PathBuf, } diff --git a/crates/rginx-core/src/config/cache.rs b/crates/rginx-core/src/config/cache.rs index 6b068a00..7e7b37ec 100644 --- a/crates/rginx-core/src/config/cache.rs +++ b/crates/rginx-core/src/config/cache.rs @@ -1,52 +1,52 @@ +mod key_template; +mod predicate; + use std::time::Duration; use http::header::HeaderName; use http::{HeaderMap, Method, StatusCode}; -mod key_template; -mod predicate; - pub use key_template::CacheKeyTemplateError; #[derive(Debug, Clone)] pub struct CacheZone { - pub name: String, - pub path: std::path::PathBuf, - pub max_size_bytes: Option, - pub inactive: Duration, pub default_ttl: Duration, - pub max_entry_bytes: usize, - pub path_levels: Vec, + pub inactive: Duration, + pub inactive_cleanup_interval: Duration, pub loader_batch_entries: usize, pub loader_sleep: Duration, pub manager_batch_entries: usize, pub manager_sleep: Duration, - pub inactive_cleanup_interval: Duration, + pub max_entry_bytes: usize, + pub max_size_bytes: Option, + pub name: String, + pub path: std::path::PathBuf, + pub path_levels: Vec, pub shared_index: bool, } #[derive(Debug, Clone)] pub struct RouteCachePolicy { - pub zone: String, - pub methods: Vec, - pub statuses: Vec, - pub ttl_by_status: Vec, - pub key: CacheKeyTemplate, + pub background_update: bool, pub cache_bypass: Option, - pub no_cache: Option, - pub stale_if_error: Option, + pub convert_head: bool, pub grace: Option, + pub ignore_headers: Vec, pub keep: Option, - pub pass_ttl: Option, - pub use_stale: Vec, - pub background_update: bool, - pub lock_timeout: Duration, + pub key: CacheKeyTemplate, pub lock_age: Duration, + pub lock_timeout: Duration, + pub methods: Vec, pub min_uses: u64, - pub ignore_headers: Vec, + pub no_cache: Option, + pub pass_ttl: Option, pub range_requests: CacheRangeRequestPolicy, pub slice_size_bytes: Option, - pub convert_head: bool, + pub stale_if_error: Option, + pub statuses: Vec, + pub ttl_by_status: Vec, + pub use_stale: Vec, + pub zone: String, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -57,24 +57,22 @@ pub struct CacheStatusTtlRule { #[derive(Debug, Clone, PartialEq, Eq)] pub enum CachePredicate { - Any(Vec), All(Vec), - Not(Box), - Method(Method), - HeaderExists(HeaderName), + Any(Vec), + CookieEquals { name: String, value: String }, + CookieExists(String), HeaderEquals { name: HeaderName, value: String }, - QueryExists(String), + HeaderExists(HeaderName), + Method(Method), + Not(Box), QueryEquals { name: String, value: String }, - CookieExists(String), - CookieEquals { name: String, value: String }, + QueryExists(String), Status(Vec), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CacheUseStaleCondition { Error, - Timeout, - Updating, Http403, Http404, Http429, @@ -82,15 +80,17 @@ pub enum CacheUseStaleCondition { Http502, Http503, Http504, + Timeout, + Updating, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CacheIgnoreHeader { - XAccelExpires, - Expires, CacheControl, + Expires, SetCookie, Vary, + XAccelExpires, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -101,15 +101,15 @@ pub enum CacheRangeRequestPolicy { #[derive(Debug, Clone, Copy)] pub struct CachePredicateRequestContext<'a> { + pub headers: &'a HeaderMap, pub method: &'a Method, pub uri: &'a str, - pub headers: &'a HeaderMap, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct CacheKeyTemplate { - raw: String, parts: Vec, + raw: String, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -120,20 +120,20 @@ enum CacheKeyTemplatePart { #[derive(Debug, Clone, PartialEq, Eq)] enum CacheKeyVariable { - Scheme, + Cookie(String), + Header(HeaderName), Host, - Uri, Method, - Header(HeaderName), Query(String), - Cookie(String), + Scheme, + Uri, } #[derive(Debug, Clone, Copy)] pub struct CacheKeyRenderContext<'a> { - pub scheme: &'a str, + pub headers: &'a HeaderMap, pub host: &'a str, - pub uri: &'a str, pub method: &'a str, - pub headers: &'a HeaderMap, + pub scheme: &'a str, + pub uri: &'a str, } diff --git a/crates/rginx-core/src/config/cache/key_template.rs b/crates/rginx-core/src/config/cache/key_template.rs index 2f382ce2..7242efbe 100644 --- a/crates/rginx-core/src/config/cache/key_template.rs +++ b/crates/rginx-core/src/config/cache/key_template.rs @@ -2,9 +2,57 @@ use http::HeaderMap; use thiserror::Error; use super::predicate::{cookie_pairs, query_pairs}; -use super::*; +use super::{ + CacheKeyRenderContext, CacheKeyTemplate, CacheKeyTemplatePart, CacheKeyVariable, HeaderName, +}; impl CacheKeyTemplate { + pub fn append_rendered(&self, rendered: &mut String, context: &CacheKeyRenderContext<'_>) { + for part in &self.parts { + match part { + CacheKeyTemplatePart::Literal(value) => rendered.push_str(value), + CacheKeyTemplatePart::Variable(variable) => match variable { + CacheKeyVariable::Scheme => rendered.push_str(context.scheme), + CacheKeyVariable::Host => rendered.push_str(context.host), + CacheKeyVariable::Uri => rendered.push_str(context.uri), + CacheKeyVariable::Method => rendered.push_str(context.method), + CacheKeyVariable::Header(name) => { + append_joined_header_values(rendered, context.headers, name); + } + CacheKeyVariable::Query(name) => { + if let Some(value) = query_pairs(context.uri) + .find_map(|(key, value)| (key == name).then_some(value)) + { + rendered.push_str(value); + } + } + CacheKeyVariable::Cookie(name) => { + if let Some(value) = cookie_pairs(context.headers) + .find_map(|(key, value)| (key == name).then_some(value)) + { + rendered.push_str(value); + } + } + }, + } + } + } + + #[must_use] + pub fn as_str(&self) -> &str { + &self.raw + } + + #[must_use] + pub fn estimated_rendered_capacity(&self, context: &CacheKeyRenderContext<'_>) -> usize { + self.raw + .len() + .saturating_add(context.scheme.len()) + .saturating_add(context.host.len()) + .saturating_add(context.uri.len()) + .saturating_add(context.method.len()) + } + pub fn parse(raw: impl Into) -> Result { let raw = raw.into(); let mut parts = Vec::new(); @@ -15,12 +63,12 @@ impl CacheKeyTemplate { let remainder = &raw[index..]; if remainder.starts_with("{{") { literal.push('{'); - index += 2; + advance_index(&mut index, 2); continue; } if remainder.starts_with("}}") { literal.push('}'); - index += 2; + advance_index(&mut index, 2); continue; } @@ -31,7 +79,8 @@ impl CacheKeyTemplate { parts.push(CacheKeyTemplatePart::Literal(std::mem::take(&mut literal))); } - let after_start = &raw[index + 1..]; + let variable_start = checked_index(index, 1); + let after_start = &raw[variable_start..]; let Some(end) = after_start.find('}') else { return Err(CacheKeyTemplateError::UnclosedVariable { template: raw.clone(), @@ -42,14 +91,14 @@ impl CacheKeyTemplate { return Err(CacheKeyTemplateError::EmptyVariable { template: raw.clone() }); } parts.push(CacheKeyTemplatePart::Variable(parse_cache_key_variable(variable)?)); - index += end + 2; + index = checked_index(variable_start, checked_index(end, 1)); } '}' => { return Err(CacheKeyTemplateError::UnescapedClose { template: raw.clone() }); } _ => { literal.push(ch); - index += ch.len_utf8(); + advance_index(&mut index, ch.len_utf8()); } } } @@ -58,58 +107,24 @@ impl CacheKeyTemplate { parts.push(CacheKeyTemplatePart::Literal(literal)); } - Ok(Self { raw, parts }) + Ok(Self { parts, raw }) } - pub fn as_str(&self) -> &str { - &self.raw + #[must_use] + pub fn references_method(&self) -> bool { + self.parts + .iter() + .any(|part| matches!(part, CacheKeyTemplatePart::Variable(CacheKeyVariable::Method))) } + #[must_use] pub fn render(&self, context: &CacheKeyRenderContext<'_>) -> String { let mut rendered = String::with_capacity(self.estimated_rendered_capacity(context)); self.append_rendered(&mut rendered, context); rendered } - pub fn append_rendered(&self, rendered: &mut String, context: &CacheKeyRenderContext<'_>) { - for part in &self.parts { - match part { - CacheKeyTemplatePart::Literal(value) => rendered.push_str(value), - CacheKeyTemplatePart::Variable(variable) => match variable { - CacheKeyVariable::Scheme => rendered.push_str(context.scheme), - CacheKeyVariable::Host => rendered.push_str(context.host), - CacheKeyVariable::Uri => rendered.push_str(context.uri), - CacheKeyVariable::Method => rendered.push_str(context.method), - CacheKeyVariable::Header(name) => { - append_joined_header_values(rendered, context.headers, name); - } - CacheKeyVariable::Query(name) => { - if let Some(value) = query_pairs(context.uri) - .find_map(|(key, value)| (key == name).then_some(value)) - { - rendered.push_str(value); - } - } - CacheKeyVariable::Cookie(name) => { - if let Some(value) = cookie_pairs(context.headers) - .find_map(|(key, value)| (key == name).then_some(value)) - { - rendered.push_str(value); - } - } - }, - } - } - } - - pub fn estimated_rendered_capacity(&self, context: &CacheKeyRenderContext<'_>) -> usize { - self.raw.len() - + context.scheme.len() - + context.host.len() - + context.uri.len() - + context.method.len() - } - + #[must_use] pub fn static_rendered(&self) -> Option<&str> { match self.parts.as_slice() { [] => Some(""), @@ -117,32 +132,34 @@ impl CacheKeyTemplate { _ => None, } } +} - pub fn references_method(&self) -> bool { - self.parts - .iter() - .any(|part| matches!(part, CacheKeyTemplatePart::Variable(CacheKeyVariable::Method))) - } +fn advance_index(index: &mut usize, increment: usize) { + *index = checked_index(*index, increment); +} + +fn checked_index(index: usize, increment: usize) -> usize { + index.checked_add(increment).expect("cache key template index remains representable") } #[derive(Debug, Error)] pub enum CacheKeyTemplateError { - #[error("cache key template `{template}` has an unclosed variable")] - UnclosedVariable { template: String }, - #[error("cache key template `{template}` has an unescaped closing brace")] - UnescapedClose { template: String }, #[error("cache key template `{template}` has an empty variable")] EmptyVariable { template: String }, - #[error("cache key template variable `{name}` is not supported")] - UnknownVariable { name: String }, + #[error("cache key template variable `{name}` requires a non-empty value")] + EmptyVariableArgument { name: String }, #[error("cache key template request header `{name}` is invalid: {source}")] InvalidRequestHeader { name: String, #[source] source: http::header::InvalidHeaderName, }, - #[error("cache key template variable `{name}` requires a non-empty value")] - EmptyVariableArgument { name: String }, + #[error("cache key template `{template}` has an unclosed variable")] + UnclosedVariable { template: String }, + #[error("cache key template `{template}` has an unescaped closing brace")] + UnescapedClose { template: String }, + #[error("cache key template variable `{name}` is not supported")] + UnknownVariable { name: String }, } fn parse_cache_key_variable(name: &str) -> Result { diff --git a/crates/rginx-core/src/config/cache/predicate.rs b/crates/rginx-core/src/config/cache/predicate.rs index 6c20523e..006784a4 100644 --- a/crates/rginx-core/src/config/cache/predicate.rs +++ b/crates/rginx-core/src/config/cache/predicate.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{CachePredicate, CachePredicateRequestContext, HeaderMap, StatusCode}; impl CachePredicate { pub fn matches_request(&self, request: &CachePredicateRequestContext<'_>) -> bool { diff --git a/crates/rginx-core/src/config/control_plane.rs b/crates/rginx-core/src/config/control_plane.rs index ab6d3b48..a40a5e8a 100644 --- a/crates/rginx-core/src/config/control_plane.rs +++ b/crates/rginx-core/src/config/control_plane.rs @@ -7,8 +7,8 @@ use ipnet::IpNet; #[derive(Debug, Clone)] pub struct ControlPlaneTlsSettings { pub cert_path: PathBuf, - pub key_path: PathBuf, pub client_ca_path: Option, + pub key_path: PathBuf, pub require_client_cert: bool, } @@ -20,17 +20,18 @@ pub struct ControlPlaneTlsSettings { /// they are required for compatibility. #[derive(Debug, Clone)] pub struct ControlPlaneSettings { - pub listen: SocketAddr, - pub tls: ControlPlaneTlsSettings, pub allowed_cidrs: Vec, pub api_keys_path: PathBuf, + pub labels: BTreeMap, + pub listen: SocketAddr, pub node_id: Option, - pub region: Option, pub pop: Option, - pub labels: BTreeMap, + pub region: Option, + pub tls: ControlPlaneTlsSettings, } impl ControlPlaneSettings { + #[must_use] pub fn allows(&self, ip: IpAddr) -> bool { if self.allowed_cidrs.is_empty() { return true; diff --git a/crates/rginx-core/src/config/listener.rs b/crates/rginx-core/src/config/listener.rs index 00ac035c..c1be6729 100644 --- a/crates/rginx-core/src/config/listener.rs +++ b/crates/rginx-core/src/config/listener.rs @@ -11,6 +11,7 @@ pub enum ListenerTransportKind { } impl ListenerTransportKind { + #[must_use] pub const fn as_str(self) -> &'static str { match self { Self::Tcp => "tcp", @@ -27,6 +28,7 @@ pub enum ListenerApplicationProtocol { } impl ListenerApplicationProtocol { + #[must_use] pub const fn as_str(self) -> &'static str { match self { Self::Http1 => "http1", @@ -38,58 +40,58 @@ impl ListenerApplicationProtocol { #[derive(Debug, Clone)] pub struct ListenerHttp3 { - pub listen_addr: SocketAddr, + pub active_connection_id_limit: u32, pub advertise_alt_svc: bool, pub alt_svc_max_age: Duration, + pub early_data_enabled: bool, + pub gso: bool, + pub host_key_path: Option, + pub listen_addr: SocketAddr, pub max_concurrent_streams: usize, - pub stream_buffer_size: usize, - pub active_connection_id_limit: u32, pub retry: bool, - pub host_key_path: Option, - pub gso: bool, - pub early_data_enabled: bool, + pub stream_buffer_size: usize, } #[derive(Debug, Clone)] pub struct ListenerTransportBinding { - pub name: &'static str, - pub kind: ListenerTransportKind, - pub listen_addr: SocketAddr, - pub protocols: Vec, pub advertise_alt_svc: bool, pub alt_svc_max_age: Option, - pub http3_max_concurrent_streams: Option, - pub http3_stream_buffer_size: Option, pub http3_active_connection_id_limit: Option, - pub http3_retry: Option, - pub http3_host_key_path: Option, - pub http3_gso: Option, pub http3_early_data_enabled: Option, + pub http3_gso: Option, + pub http3_host_key_path: Option, + pub http3_max_concurrent_streams: Option, + pub http3_retry: Option, + pub http3_stream_buffer_size: Option, + pub kind: ListenerTransportKind, + pub listen_addr: SocketAddr, + pub name: &'static str, + pub protocols: Vec, } #[derive(Debug, Clone)] pub struct Listener { + pub default_server: bool, + pub http3: Option, pub id: String, pub name: String, - pub server: Server, - pub default_server: bool, + pub proxy_protocol_enabled: bool, pub reuse_port_enabled: bool, + pub server: Server, pub tls_termination_enabled: bool, - pub proxy_protocol_enabled: bool, - pub http3: Option, } impl Listener { - pub fn tls_enabled(&self) -> bool { - self.tls_termination_enabled + pub fn binding_count(&self) -> usize { + usize::from(self.http3.is_some()).saturating_add(1) } pub fn http3_enabled(&self) -> bool { self.http3.is_some() } - pub fn binding_count(&self) -> usize { - 1 + usize::from(self.http3.is_some()) + pub fn tls_enabled(&self) -> bool { + self.tls_termination_enabled } pub fn transport_bindings(&self) -> Vec { diff --git a/crates/rginx-core/src/config/route.rs b/crates/rginx-core/src/config/route.rs index 234eecb0..40ef1c3b 100644 --- a/crates/rginx-core/src/config/route.rs +++ b/crates/rginx-core/src/config/route.rs @@ -1,3 +1,6 @@ +mod proxy_header; +mod regex_matcher; + use std::net::IpAddr; use std::path::PathBuf; use std::sync::Arc; @@ -9,9 +12,6 @@ use ipnet::IpNet; use super::cache::RouteCachePolicy; use super::upstream::Upstream; -mod proxy_header; -mod regex_matcher; - pub use proxy_header::{ ProxyHeaderRenderContext, ProxyHeaderTemplate, ProxyHeaderTemplateError, ProxyHeaderValue, }; @@ -19,30 +19,30 @@ pub use regex_matcher::{RouteRegexError, RouteRegexMatcher}; #[derive(Debug, Clone)] pub struct Route { - pub id: String, - pub matcher: RouteMatcher, - pub internal: bool, - pub rewrites: Vec, - pub try_files: Vec, + pub access_control: RouteAccessControl, + pub action: RouteAction, + pub allow_early_data: bool, + pub cache: Option, + pub compression: RouteCompressionPolicy, + pub compression_content_types: Vec, + pub compression_min_bytes: Option, pub error_pages: Vec, pub grpc_match: Option, - pub action: RouteAction, - pub access_control: RouteAccessControl, + pub id: String, + pub internal: bool, + pub matcher: RouteMatcher, pub rate_limit: Option, - pub allow_early_data: bool, pub request_buffering: RouteBufferingPolicy, pub response_buffering: RouteBufferingPolicy, - pub compression: RouteCompressionPolicy, - pub compression_min_bytes: Option, - pub compression_content_types: Vec, + pub rewrites: Vec, pub streaming_response_idle_timeout: Option, - pub cache: Option, + pub try_files: Vec, } impl Route { pub fn priority(&self) -> (u8, usize, u8) { let (matcher_rank, matcher_len) = self.matcher.priority(); - let grpc_rank = self.grpc_match.as_ref().map_or(0, |grpc_match| grpc_match.priority()); + let grpc_rank = self.grpc_match.as_ref().map_or(0, GrpcRouteMatch::priority); (matcher_rank, matcher_len, grpc_rank) } } @@ -54,10 +54,7 @@ pub struct RouteAccessControl { } impl RouteAccessControl { - pub fn new(allow_cidrs: Vec, deny_cidrs: Vec) -> Self { - Self { allow_cidrs, deny_cidrs } - } - + #[must_use] pub fn allows(&self, ip: IpAddr) -> bool { if self.deny_cidrs.iter().any(|cidr| cidr.contains(&ip)) { return false; @@ -65,17 +62,22 @@ impl RouteAccessControl { self.allow_cidrs.is_empty() || self.allow_cidrs.iter().any(|cidr| cidr.contains(&ip)) } + #[must_use] + pub fn new(allow_cidrs: Vec, deny_cidrs: Vec) -> Self { + Self { allow_cidrs, deny_cidrs } + } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RouteRateLimit { - pub requests_per_sec: u32, pub burst: u32, + pub requests_per_sec: u32, } impl RouteRateLimit { + #[must_use] pub fn new(requests_per_sec: u32, burst: u32) -> Self { - Self { requests_per_sec, burst } + Self { burst, requests_per_sec } } } @@ -83,11 +85,12 @@ impl RouteRateLimit { pub enum RouteBufferingPolicy { #[default] Auto, - On, Off, + On, } impl RouteBufferingPolicy { + #[must_use] pub const fn as_str(self) -> &'static str { match self { Self::Auto => "auto", @@ -99,13 +102,14 @@ impl RouteBufferingPolicy { #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum RouteCompressionPolicy { - Off, #[default] Auto, Force, + Off, } impl RouteCompressionPolicy { + #[must_use] pub const fn as_str(self) -> &'static str { match self { Self::Off => "off", @@ -118,32 +122,14 @@ impl RouteCompressionPolicy { #[derive(Debug, Clone, PartialEq, Eq)] pub enum RouteMatcher { Exact(String), + Named(String), PreferredPrefix(String), Prefix(String), Regex(RouteRegexMatcher), - Named(String), } impl RouteMatcher { - pub fn matches(&self, path: &str) -> bool { - match self { - Self::Exact(expected) => path == expected, - Self::PreferredPrefix(prefix) | Self::Prefix(prefix) => prefix_matches(prefix, path), - Self::Regex(regex) => regex.matches(path), - Self::Named(_) => false, - } - } - - pub fn priority(&self) -> (u8, usize) { - match self { - Self::Exact(path) => (4, path.len()), - Self::PreferredPrefix(path) => (3, path.len()), - Self::Regex(_) => (2, 0), - Self::Prefix(path) => (1, path.len()), - Self::Named(_) => (0, 0), - } - } - + #[must_use] pub fn id_fragment(&self) -> String { match self { Self::Exact(path) => format!("exact:{path}"), @@ -159,38 +145,36 @@ impl RouteMatcher { Self::Named(name) => format!("named:{name}"), } } -} - -fn prefix_matches(prefix: &str, path: &str) -> bool { - if prefix == "/" { - return true; + #[must_use] + pub fn matches(&self, path: &str) -> bool { + match self { + Self::Exact(expected) => path == expected, + Self::PreferredPrefix(prefix) | Self::Prefix(prefix) => prefix_matches(prefix, path), + Self::Regex(regex) => regex.matches(path), + Self::Named(_) => false, + } } - if path == prefix { - return true; + #[must_use] + pub fn priority(&self) -> (u8, usize) { + match self { + Self::Exact(path) => (4, path.len()), + Self::PreferredPrefix(path) => (3, path.len()), + Self::Regex(_) => (2, 0), + Self::Prefix(path) => (1, path.len()), + Self::Named(_) => (0, 0), + } } - - path.strip_prefix(prefix).is_some_and(|remainder| { - if prefix.ends_with('/') { true } else { remainder.starts_with('/') } - }) } #[derive(Debug, Clone, PartialEq, Eq)] pub struct GrpcRouteMatch { - pub service: Option, pub method: Option, + pub service: Option, } impl GrpcRouteMatch { - pub fn matches(&self, service: &str, method: &str) -> bool { - self.service.as_deref().is_none_or(|expected| expected == service) - && self.method.as_deref().is_none_or(|expected| expected == method) - } - - pub fn priority(&self) -> u8 { - u8::from(self.service.is_some()) + u8::from(self.method.is_some()) - } - + #[must_use] pub fn id_fragment(&self) -> String { let mut fragments = Vec::new(); if let Some(service) = &self.service { @@ -201,12 +185,22 @@ impl GrpcRouteMatch { } format!("grpc:{}", fragments.join(",")) } + #[must_use] + pub fn matches(&self, service: &str, method: &str) -> bool { + self.service.as_deref().is_none_or(|expected| expected == service) + && self.method.as_deref().is_none_or(|expected| expected == method) + } + + #[must_use] + pub fn priority(&self) -> u8 { + u8::from(self.service.is_some()).saturating_add(u8::from(self.method.is_some())) + } } #[derive(Debug, Clone)] pub enum RouteAction { - Proxy(ProxyTarget), File(FileRoute), + Proxy(ProxyTarget), Return(ReturnAction), } @@ -226,20 +220,20 @@ pub enum RouteRewriteStop { #[derive(Debug, Clone)] pub struct ProxyTarget { - pub upstream_name: String, - pub upstream: Arc, pub preserve_host: bool, - pub uri_mode: ProxyUriMode, - pub request_version: Version, - pub redirect: ProxyRedirectMode, pub proxy_set_headers: Vec<(HeaderName, ProxyHeaderValue)>, + pub redirect: ProxyRedirectMode, + pub request_version: Version, + pub upstream: Arc, + pub upstream_name: String, + pub uri_mode: ProxyUriMode, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum ProxyUriMode { PassOriginal, - StripPrefix(String), ReplacePrefix { matched_prefix: String, replacement: String }, + StripPrefix(String), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -250,31 +244,31 @@ pub enum ProxyRedirectMode { #[derive(Debug, Clone)] pub struct FileRoute { - pub path_strategy: FilePathStrategy, - pub route_prefix: Option, - pub index_files: Vec, - pub etag: bool, - pub default_content_type: String, - pub content_type_overrides: Vec<(String, String)>, - pub cache_control: Option, - pub expires: Option, - pub immutable: bool, pub autoindex: bool, + pub cache_control: Option, pub content_disposition: Option, - pub hide_dotfiles: bool, + pub content_type_overrides: Vec<(String, String)>, + pub default_content_type: String, + pub etag: bool, + pub expires: Option, pub follow_symlinks: bool, + pub hide_dotfiles: bool, + pub immutable: bool, + pub index_files: Vec, + pub path_strategy: FilePathStrategy, + pub route_prefix: Option, } #[derive(Debug, Clone)] pub enum FilePathStrategy { - Root(PathBuf), Alias(PathBuf), + Root(PathBuf), } #[derive(Debug, Clone, PartialEq, Eq)] pub enum TryFileStep { - Path(String), Named(String), + Path(String), Status(StatusCode), } @@ -286,14 +280,28 @@ pub struct RouteErrorPage { #[derive(Debug, Clone, PartialEq, Eq)] pub enum ErrorPageTarget { - Uri(String), Named(String), Status(StatusCode), + Uri(String), } #[derive(Debug, Clone)] pub struct ReturnAction { - pub status: StatusCode, - pub location: String, pub body: Option, + pub location: String, + pub status: StatusCode, +} + +fn prefix_matches(prefix: &str, path: &str) -> bool { + if prefix == "/" { + return true; + } + + if path == prefix { + return true; + } + + path.strip_prefix(prefix).is_some_and(|remainder| { + if prefix.ends_with('/') { true } else { remainder.starts_with('/') } + }) } diff --git a/crates/rginx-core/src/config/route/proxy_header.rs b/crates/rginx-core/src/config/route/proxy_header.rs index 6d3d9bc0..d7645e77 100644 --- a/crates/rginx-core/src/config/route/proxy_header.rs +++ b/crates/rginx-core/src/config/route/proxy_header.rs @@ -10,19 +10,22 @@ use thiserror::Error; /// deliberately. #[derive(Debug, Clone, PartialEq, Eq)] pub enum ProxyHeaderValue { - Static(HeaderValue), - Host, - Scheme, ClientIp, - RemoteAddr, - PeerAddr, ForwardedFor, + Host, + PeerAddr, + RemoteAddr, + Remove, RequestHeader(HeaderName), + Scheme, + Static(HeaderValue), Template(ProxyHeaderTemplate), - Remove, } impl ProxyHeaderValue { + pub fn removes_header(&self) -> bool { + matches!(self, Self::Remove) + } pub fn render( &self, context: &ProxyHeaderRenderContext<'_>, @@ -41,19 +44,20 @@ impl ProxyHeaderValue { Self::Remove => Ok(None), } } - - pub fn removes_header(&self) -> bool { - matches!(self, Self::Remove) - } } #[derive(Debug, Clone, PartialEq, Eq)] pub struct ProxyHeaderTemplate { - raw: String, parts: Vec, + raw: String, } impl ProxyHeaderTemplate { + #[must_use] + pub fn as_str(&self) -> &str { + &self.raw + } + pub fn parse(raw: String) -> Result { let mut parts = Vec::new(); let mut literal = String::new(); @@ -63,12 +67,12 @@ impl ProxyHeaderTemplate { let remainder = &raw[index..]; if remainder.starts_with("{{") { literal.push('{'); - index += 2; + advance_index(&mut index, 2); continue; } if remainder.starts_with("}}") { literal.push('}'); - index += 2; + advance_index(&mut index, 2); continue; } @@ -79,7 +83,8 @@ impl ProxyHeaderTemplate { parts.push(ProxyHeaderTemplatePart::Literal(std::mem::take(&mut literal))); } - let after_start = &raw[index + 1..]; + let variable_start = checked_index(index, 1); + let after_start = &raw[variable_start..]; let Some(end) = after_start.find('}') else { return Err(ProxyHeaderTemplateError::UnclosedVariable { template: raw.clone(), @@ -94,14 +99,14 @@ impl ProxyHeaderTemplate { parts.push(ProxyHeaderTemplatePart::Variable(parse_proxy_header_variable( variable, )?)); - index += end + 2; + index = checked_index(variable_start, checked_index(end, 1)); } '}' => { return Err(ProxyHeaderTemplateError::UnescapedClose { template: raw.clone() }); } _ => { literal.push(ch); - index += ch.len_utf8(); + advance_index(&mut index, ch.len_utf8()); } } } @@ -110,11 +115,7 @@ impl ProxyHeaderTemplate { parts.push(ProxyHeaderTemplatePart::Literal(literal)); } - Ok(Self { raw, parts }) - } - - pub fn as_str(&self) -> &str { - &self.raw + Ok(Self { parts, raw }) } fn render( @@ -136,6 +137,14 @@ impl ProxyHeaderTemplate { } } +fn advance_index(index: &mut usize, increment: usize) { + *index = checked_index(*index, increment); +} + +fn checked_index(index: usize, increment: usize) -> usize { + index.checked_add(increment).expect("proxy header template index remains representable") +} + #[derive(Debug, Clone, PartialEq, Eq)] enum ProxyHeaderTemplatePart { Literal(String), @@ -144,13 +153,13 @@ enum ProxyHeaderTemplatePart { #[derive(Debug, Clone, PartialEq, Eq)] enum ProxyHeaderVariable { - Host, - Scheme, ClientIp, - RemoteAddr, - PeerAddr, ForwardedFor, + Host, + PeerAddr, + RemoteAddr, RequestHeader(HeaderName), + Scheme, } impl ProxyHeaderVariable { @@ -174,13 +183,13 @@ impl ProxyHeaderVariable { #[derive(Debug, Clone, Copy)] pub struct ProxyHeaderRenderContext<'a> { + pub client_ip: IpAddr, + pub forwarded_for: &'a str, pub original_headers: &'a HeaderMap, pub original_host: Option<&'a HeaderValue>, - pub upstream_authority: &'a str, - pub client_ip: IpAddr, pub peer_addr: SocketAddr, - pub forwarded_for: &'a str, pub scheme: &'a str, + pub upstream_authority: &'a str, } impl ProxyHeaderRenderContext<'_> { @@ -191,20 +200,20 @@ impl ProxyHeaderRenderContext<'_> { #[derive(Debug, Error)] pub enum ProxyHeaderTemplateError { - #[error("proxy header template `{template}` has an unclosed variable")] - UnclosedVariable { template: String }, - #[error("proxy header template `{template}` has an unescaped closing brace")] - UnescapedClose { template: String }, #[error("proxy header template `{template}` has an empty variable")] EmptyVariable { template: String }, - #[error("proxy header template variable `{name}` is not supported")] - UnknownVariable { name: String }, #[error("proxy header template request header `{name}` is invalid: {source}")] InvalidRequestHeader { name: String, #[source] source: http::header::InvalidHeaderName, }, + #[error("proxy header template `{template}` has an unclosed variable")] + UnclosedVariable { template: String }, + #[error("proxy header template `{template}` has an unescaped closing brace")] + UnescapedClose { template: String }, + #[error("proxy header template variable `{name}` is not supported")] + UnknownVariable { name: String }, } fn parse_proxy_header_variable( diff --git a/crates/rginx-core/src/config/route/regex_matcher.rs b/crates/rginx-core/src/config/route/regex_matcher.rs index 1a6752e5..25a131aa 100644 --- a/crates/rginx-core/src/config/route/regex_matcher.rs +++ b/crates/rginx-core/src/config/route/regex_matcher.rs @@ -3,14 +3,24 @@ use thiserror::Error; #[derive(Debug, Clone)] pub struct RouteRegexMatcher { - pattern: String, case_insensitive: bool, + pattern: String, regex: Regex, } impl RouteRegexMatcher { pub const SIZE_LIMIT_BYTES: usize = 1 << 20; + #[must_use] + pub fn case_insensitive(&self) -> bool { + self.case_insensitive + } + + #[must_use] + pub fn matches(&self, path: &str) -> bool { + self.regex.is_match(path) + } + pub fn new(pattern: String, case_insensitive: bool) -> Result { let regex = RegexBuilder::new(&pattern) .case_insensitive(case_insensitive) @@ -21,21 +31,15 @@ impl RouteRegexMatcher { source, })?; - Ok(Self { pattern, case_insensitive, regex }) + Ok(Self { case_insensitive, pattern, regex }) } + #[must_use] pub fn pattern(&self) -> &str { &self.pattern } - pub fn case_insensitive(&self) -> bool { - self.case_insensitive - } - - pub fn matches(&self, path: &str) -> bool { - self.regex.is_match(path) - } - + #[must_use] pub fn replace(&self, text: &str, replacement: &str) -> String { self.regex.replace(text, replacement).into_owned() } diff --git a/crates/rginx-core/src/config/server.rs b/crates/rginx-core/src/config/server.rs index c3da55dd..13c1db99 100644 --- a/crates/rginx-core/src/config/server.rs +++ b/crates/rginx-core/src/config/server.rs @@ -6,40 +6,42 @@ use ipnet::IpNet; use super::{AccessLogFormat, ServerTls}; +pub const DEFAULT_SERVER_HEADER: &str = "rginx"; + #[derive(Debug, Clone)] pub struct RuntimeSettings { + pub accept_workers: usize, pub shutdown_timeout: Duration, pub worker_threads: Option, - pub accept_workers: usize, } #[derive(Debug, Clone, Default)] pub struct Http1Settings { pub half_close: bool, - pub title_case_headers: bool, - pub preserve_header_case: bool, pub max_buf_size: Option, pub pipeline_flush: bool, + pub preserve_header_case: bool, + pub title_case_headers: bool, pub writev: Option, } #[derive(Debug, Clone)] pub struct Server { - pub listen_addr: SocketAddr, - pub server_header: HeaderValue, - pub default_certificate: Option, - pub trusted_proxies: Vec, + pub access_log_format: Option, pub client_ip_header: Option, + pub default_certificate: Option, + pub header_read_timeout: Option, + pub http1: Http1Settings, pub keep_alive: bool, + pub listen_addr: SocketAddr, + pub max_connections: Option, pub max_headers: Option, pub max_request_body_bytes: Option, - pub max_connections: Option, - pub header_read_timeout: Option, pub request_body_read_timeout: Option, pub response_write_timeout: Option, - pub http1: Http1Settings, - pub access_log_format: Option, + pub server_header: HeaderValue, pub tls: Option, + pub trusted_proxies: Vec, } impl Server { @@ -48,8 +50,7 @@ impl Server { } } -pub const DEFAULT_SERVER_HEADER: &str = "rginx"; - +#[must_use] pub fn default_server_header() -> HeaderValue { HeaderValue::from_static(DEFAULT_SERVER_HEADER) } diff --git a/crates/rginx-core/src/config/server_name.rs b/crates/rginx-core/src/config/server_name.rs index 61677259..e4589c7b 100644 --- a/crates/rginx-core/src/config/server_name.rs +++ b/crates/rginx-core/src/config/server_name.rs @@ -2,14 +2,15 @@ use regex::Regex; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum ServerNameMatch { - Exact, DotWildcard { suffix_len: usize }, + Exact, LeadingWildcard { suffix_len: usize }, - TrailingWildcard { prefix_len: usize }, Regex, + TrailingWildcard { prefix_len: usize }, } impl ServerNameMatch { + #[must_use] pub fn priority(self) -> (u8, usize) { match self { Self::Exact => (4, 0), @@ -23,11 +24,12 @@ impl ServerNameMatch { #[derive(Debug, Clone)] pub struct CompiledServerNamePattern { - pattern: String, kind: CompiledServerNamePatternKind, + pattern: String, } impl CompiledServerNamePattern { + #[must_use] pub fn compile(pattern: &str) -> Option { let pattern = pattern.trim().to_ascii_lowercase(); if pattern.is_empty() { @@ -35,59 +37,55 @@ impl CompiledServerNamePattern { } let kind = CompiledServerNamePatternKind::compile(&pattern)?; - Some(Self { pattern, kind }) + Some(Self { kind, pattern }) } - pub fn pattern(&self) -> &str { - &self.pattern + pub(crate) fn into_kind(self) -> CompiledServerNamePatternKind { + self.kind } + #[must_use] pub fn is_exact(&self) -> bool { matches!(self.kind, CompiledServerNamePatternKind::Exact(_)) } - pub fn priority(&self) -> (u8, usize) { - self.kind.priority() - } - - pub fn normalize_host(host: &str) -> String { - normalize_host_for_match(host) - } - + #[must_use] pub fn matches(&self, host: &str) -> Option { let hostname = normalize_host_for_match(host); self.matches_normalized_host(&hostname) } + #[must_use] pub fn matches_normalized_host(&self, hostname: &str) -> Option { self.kind.matches_hostname(hostname) } - pub(crate) fn into_kind(self) -> CompiledServerNamePatternKind { - self.kind + #[must_use] + pub fn normalize_host(host: &str) -> String { + normalize_host_for_match(host) + } + + #[must_use] + pub fn pattern(&self) -> &str { + &self.pattern + } + + #[must_use] + pub fn priority(&self) -> (u8, usize) { + self.kind.priority() } } #[derive(Debug, Clone)] pub(crate) enum CompiledServerNamePatternKind { - Exact(String), DotWildcard { suffix: String, matched: ServerNameMatch }, + Exact(String), LeadingWildcard { suffix: String, matched: ServerNameMatch }, - TrailingWildcard { prefix: String, matched: ServerNameMatch }, Regex(Regex), + TrailingWildcard { prefix: String, matched: ServerNameMatch }, } impl CompiledServerNamePatternKind { - pub(crate) fn priority(&self) -> (u8, usize) { - match self { - Self::Exact(_) => ServerNameMatch::Exact.priority(), - Self::DotWildcard { matched, .. } - | Self::LeadingWildcard { matched, .. } - | Self::TrailingWildcard { matched, .. } => matched.priority(), - Self::Regex(_) => ServerNameMatch::Regex.priority(), - } - } - pub(crate) fn compile(pattern: &str) -> Option { if let Some(regex_pattern) = pattern.strip_prefix('~') { let regex_pattern = regex_pattern.trim(); @@ -158,6 +156,15 @@ impl CompiledServerNamePatternKind { Self::Regex(regex) => regex.is_match(hostname).then_some(ServerNameMatch::Regex), } } + pub(crate) fn priority(&self) -> (u8, usize) { + match self { + Self::Exact(_) => ServerNameMatch::Exact.priority(), + Self::DotWildcard { matched, .. } + | Self::LeadingWildcard { matched, .. } + | Self::TrailingWildcard { matched, .. } => matched.priority(), + Self::Regex(_) => ServerNameMatch::Regex.priority(), + } + } } pub fn best_matching_server_name_pattern<'a>( @@ -176,6 +183,7 @@ pub fn best_matching_server_name_pattern<'a>( }) } +#[must_use] pub fn match_server_name(pattern: &str, host: &str) -> Option { CompiledServerNamePattern::compile(pattern)?.matches(host) } diff --git a/crates/rginx-core/src/config/snapshot/linear.rs b/crates/rginx-core/src/config/snapshot/linear.rs index 571c9471..9b49316a 100644 --- a/crates/rginx-core/src/config/snapshot/linear.rs +++ b/crates/rginx-core/src/config/snapshot/linear.rs @@ -2,6 +2,28 @@ use crate::config::{Route, RouteMatcher, VirtualHost}; use super::route_selection::RequestGrpcMatch; +#[derive(Clone, Copy)] +struct RouteSelectionCandidate { + grpc_rank: u8, + index: usize, + match_len: usize, + preferred_prefix: bool, +} + +impl RouteSelectionCandidate { + fn new(index: usize, match_len: usize, preferred_prefix: bool, route: &Route) -> Self { + Self { + index, + match_len, + preferred_prefix, + grpc_rank: route + .grpc_match + .as_ref() + .map_or(0, super::super::route::GrpcRouteMatch::priority), + } + } +} + pub(super) fn select_vhost_with_linear_scan<'a>( vhosts: &'a [VirtualHost], default: &'a VirtualHost, @@ -31,7 +53,7 @@ pub(super) fn select_vhost_with_linear_scan<'a>( match selected { None => selected = Some((priority, vhost)), Some((current_priority, _)) if priority > current_priority => { - selected = Some((priority, vhost)) + selected = Some((priority, vhost)); } Some(_) => {} } @@ -100,25 +122,6 @@ pub(super) fn select_route_with_linear_scan<'a>( prefix.and_then(|candidate| routes.get(candidate.index)) } -#[derive(Clone, Copy)] -struct RouteSelectionCandidate { - index: usize, - match_len: usize, - grpc_rank: u8, - preferred_prefix: bool, -} - -impl RouteSelectionCandidate { - fn new(index: usize, match_len: usize, preferred_prefix: bool, route: &Route) -> Self { - Self { - index, - match_len, - preferred_prefix, - grpc_rank: route.grpc_match.as_ref().map_or(0, |grpc| grpc.priority()), - } - } -} - fn select_more_specific( current: Option, candidate: RouteSelectionCandidate, diff --git a/crates/rginx-core/src/config/snapshot/lookup.rs b/crates/rginx-core/src/config/snapshot/lookup.rs index ae42801b..d0a64af9 100644 --- a/crates/rginx-core/src/config/snapshot/lookup.rs +++ b/crates/rginx-core/src/config/snapshot/lookup.rs @@ -8,73 +8,16 @@ use super::vhost_selection::CompiledVhostSelection; #[derive(Debug, Clone, Default)] pub struct ConfigLookup { built: bool, + default_route_selection: CompiledRouteSelection, listener_index: HashMap, - vhost_index: HashMap>, - route_index: HashMap, usize)>, named_route_index: HashMap<(Option, String), usize>, - default_route_selection: CompiledRouteSelection, + route_index: HashMap, usize)>, + vhost_index: HashMap>, vhost_route_selections: Vec, vhost_selection: CompiledVhostSelection, } impl ConfigLookup { - pub fn is_built(&self) -> bool { - self.built - } - - pub fn listener_index(&self, id: &str) -> Option { - self.built.then_some(())?; - self.listener_index.get(id).copied() - } - - pub fn vhost_index(&self, id: &str) -> Option> { - self.built.then_some(())?; - self.vhost_index.get(id).copied() - } - - pub fn route_index(&self, id: &str) -> Option<(Option, usize)> { - self.built.then_some(())?; - self.route_index.get(id).copied() - } - - pub fn select_vhost_index(&self, host: &str, listener_id: Option<&str>) -> Option { - self.built.then_some(())?; - self.vhost_selection.select(host, listener_id) - } - - pub fn select_route_index( - &self, - snapshot: &ConfigSnapshot, - vhost_index: Option, - path: &str, - grpc_service: Option<&str>, - grpc_method: Option<&str>, - ) -> Option { - self.built.then_some(())?; - let request_grpc = RequestGrpcMatch::new(grpc_service, grpc_method); - match vhost_index { - None => self.default_route_selection.select( - &snapshot.default_vhost.routes, - path, - request_grpc, - ), - Some(idx) => self.vhost_route_selections.get(idx)?.select( - &snapshot.vhosts.get(idx)?.routes, - path, - request_grpc, - ), - } - } - - pub fn select_named_route_index( - &self, - vhost_index: Option, - name: &str, - ) -> Option { - self.built.then_some(())?; - self.named_route_index.get(&(vhost_index, name.to_string())).copied() - } - pub(crate) fn compile(snapshot: &ConfigSnapshot) -> Self { let mut lookup = Self { built: true, ..Self::default() }; @@ -120,4 +63,66 @@ impl ConfigLookup { lookup.vhost_selection = CompiledVhostSelection::compile(&snapshot.vhosts); lookup } + #[must_use] + pub fn is_built(&self) -> bool { + self.built + } + + #[must_use] + pub fn listener_index(&self, id: &str) -> Option { + self.built.then_some(())?; + self.listener_index.get(id).copied() + } + + #[must_use] + pub fn route_index(&self, id: &str) -> Option<(Option, usize)> { + self.built.then_some(())?; + self.route_index.get(id).copied() + } + + #[must_use] + pub fn select_named_route_index( + &self, + vhost_index: Option, + name: &str, + ) -> Option { + self.built.then_some(())?; + self.named_route_index.get(&(vhost_index, name.to_string())).copied() + } + + pub fn select_route_index( + &self, + snapshot: &ConfigSnapshot, + vhost_index: Option, + path: &str, + grpc_service: Option<&str>, + grpc_method: Option<&str>, + ) -> Option { + self.built.then_some(())?; + let request_grpc = RequestGrpcMatch::new(grpc_service, grpc_method); + match vhost_index { + None => self.default_route_selection.select( + &snapshot.default_vhost.routes, + path, + request_grpc, + ), + Some(idx) => self.vhost_route_selections.get(idx)?.select( + &snapshot.vhosts.get(idx)?.routes, + path, + request_grpc, + ), + } + } + + #[must_use] + pub fn select_vhost_index(&self, host: &str, listener_id: Option<&str>) -> Option { + self.built.then_some(())?; + self.vhost_selection.select(host, listener_id) + } + + #[must_use] + pub fn vhost_index(&self, id: &str) -> Option> { + self.built.then_some(())?; + self.vhost_index.get(id).copied() + } } diff --git a/crates/rginx-core/src/config/snapshot/mod.rs b/crates/rginx-core/src/config/snapshot/mod.rs index 3aba1b16..7b4c8443 100644 --- a/crates/rginx-core/src/config/snapshot/mod.rs +++ b/crates/rginx-core/src/config/snapshot/mod.rs @@ -1,3 +1,8 @@ +mod linear; +mod lookup; +mod route_selection; +mod vhost_selection; + use std::collections::HashMap; use std::sync::Arc; @@ -7,54 +12,24 @@ use super::{ }; use crate::config::{AgentSettings, ControlPlaneSettings}; -mod linear; -mod lookup; -mod route_selection; -mod vhost_selection; - pub use lookup::ConfigLookup; #[derive(Debug, Clone)] pub struct ConfigSnapshot { - pub runtime: RuntimeSettings, + pub acme: Option, pub agent: Option, + pub cache_zones: HashMap>, pub control_plane: Option, - pub acme: Option, - pub managed_certificates: Vec, - pub listeners: Vec, pub default_vhost: VirtualHost, - pub vhosts: Vec, - pub cache_zones: HashMap>, - pub upstreams: HashMap>, + pub listeners: Vec, pub lookup: ConfigLookup, + pub managed_certificates: Vec, + pub runtime: RuntimeSettings, + pub upstreams: HashMap>, + pub vhosts: Vec, } impl ConfigSnapshot { - pub fn rebuild_lookup(&mut self) { - self.lookup = ConfigLookup::compile(self); - } - - pub fn total_route_count(&self) -> usize { - self.default_vhost.routes.len() - + self.vhosts.iter().map(|vhost| vhost.routes.len()).sum::() - } - - pub fn total_vhost_count(&self) -> usize { - 1 + self.vhosts.len() - } - - pub fn total_listener_count(&self) -> usize { - self.listeners.len() - } - - pub fn total_listener_binding_count(&self) -> usize { - self.listeners.iter().map(Listener::binding_count).sum() - } - - pub fn tls_enabled(&self) -> bool { - self.listeners.iter().any(Listener::tls_enabled) - } - pub fn http3_enabled(&self) -> bool { self.listeners.iter().any(Listener::http3_enabled) } @@ -67,15 +42,8 @@ impl ConfigSnapshot { self.listeners.iter().find(|listener| listener.id == id) } - pub fn vhost(&self, id: &str) -> Option<&VirtualHost> { - if let Some(compiled) = self.lookup.vhost_index(id) { - return match compiled { - None => Some(&self.default_vhost), - Some(idx) => self.vhosts.get(idx), - }; - } - - self.vhosts.iter().chain(std::iter::once(&self.default_vhost)).find(|vhost| vhost.id == id) + pub fn rebuild_lookup(&mut self) { + self.lookup = ConfigLookup::compile(self); } pub fn route(&self, id: &str) -> Option<(&VirtualHost, &Route)> { @@ -96,16 +64,32 @@ impl ConfigSnapshot { }) } - pub fn select_vhost_for_listener(&self, host: &str, listener_id: Option<&str>) -> &VirtualHost { + pub fn select_named_route_in_vhost<'a>( + &'a self, + vhost: &'a VirtualHost, + name: &str, + ) -> Option<&'a Route> { if self.lookup.is_built() { - return self - .lookup - .select_vhost_index(host, listener_id) - .and_then(|idx| self.vhosts.get(idx)) - .unwrap_or(&self.default_vhost); + if let Some(vhost_index) = self.lookup.vhost_index(&vhost.id) { + let indexed_vhost = match vhost_index { + None => &self.default_vhost, + Some(idx) => self.vhosts.get(idx)?, + }; + if let Some(route_index) = self.lookup.select_named_route_index(vhost_index, name) { + return indexed_vhost.routes.get(route_index); + } + return None; + } + return vhost.routes.iter().find(|route| match &route.matcher { + RouteMatcher::Named(candidate) => candidate == name, + _ => false, + }); } - linear::select_vhost_with_linear_scan(&self.vhosts, &self.default_vhost, host, listener_id) + vhost.routes.iter().find(|route| match &route.matcher { + RouteMatcher::Named(candidate) => candidate == name, + _ => false, + }) } pub fn select_route_in_vhost<'a>( @@ -143,31 +127,49 @@ impl ConfigSnapshot { linear::select_route_with_linear_scan(&vhost.routes, path, grpc_service, grpc_method) } - pub fn select_named_route_in_vhost<'a>( - &'a self, - vhost: &'a VirtualHost, - name: &str, - ) -> Option<&'a Route> { + pub fn select_vhost_for_listener(&self, host: &str, listener_id: Option<&str>) -> &VirtualHost { if self.lookup.is_built() { - if let Some(vhost_index) = self.lookup.vhost_index(&vhost.id) { - let indexed_vhost = match vhost_index { - None => &self.default_vhost, - Some(idx) => self.vhosts.get(idx)?, - }; - if let Some(route_index) = self.lookup.select_named_route_index(vhost_index, name) { - return indexed_vhost.routes.get(route_index); - } - return None; - } - return vhost.routes.iter().find(|route| match &route.matcher { - RouteMatcher::Named(candidate) => candidate == name, - _ => false, - }); + return self + .lookup + .select_vhost_index(host, listener_id) + .and_then(|idx| self.vhosts.get(idx)) + .unwrap_or(&self.default_vhost); } - vhost.routes.iter().find(|route| match &route.matcher { - RouteMatcher::Named(candidate) => candidate == name, - _ => false, - }) + linear::select_vhost_with_linear_scan(&self.vhosts, &self.default_vhost, host, listener_id) + } + + pub fn tls_enabled(&self) -> bool { + self.listeners.iter().any(Listener::tls_enabled) + } + + pub fn total_listener_binding_count(&self) -> usize { + self.listeners.iter().map(Listener::binding_count).sum() + } + + pub fn total_listener_count(&self) -> usize { + self.listeners.len() + } + + pub fn total_route_count(&self) -> usize { + self.default_vhost + .routes + .len() + .saturating_add(self.vhosts.iter().map(|vhost| vhost.routes.len()).sum::()) + } + + pub fn total_vhost_count(&self) -> usize { + self.vhosts.len().saturating_add(1) + } + + pub fn vhost(&self, id: &str) -> Option<&VirtualHost> { + if let Some(compiled) = self.lookup.vhost_index(id) { + return match compiled { + None => Some(&self.default_vhost), + Some(idx) => self.vhosts.get(idx), + }; + } + + self.vhosts.iter().chain(std::iter::once(&self.default_vhost)).find(|vhost| vhost.id == id) } } diff --git a/crates/rginx-core/src/config/snapshot/route_selection.rs b/crates/rginx-core/src/config/snapshot/route_selection.rs index 7c775f80..78532522 100644 --- a/crates/rginx-core/src/config/snapshot/route_selection.rs +++ b/crates/rginx-core/src/config/snapshot/route_selection.rs @@ -63,15 +63,11 @@ impl CompiledRouteSelection { #[derive(Clone, Copy)] pub(super) struct RequestGrpcMatch<'a> { - service: Option<&'a str>, method: Option<&'a str>, + service: Option<&'a str>, } impl<'a> RequestGrpcMatch<'a> { - pub(super) fn new(service: Option<&'a str>, method: Option<&'a str>) -> Self { - Self { service, method } - } - pub(super) fn matches(self, route: &Route) -> bool { route.grpc_match.as_ref().is_none_or(|candidate| { let Some(service) = self.service else { @@ -83,6 +79,17 @@ impl<'a> RequestGrpcMatch<'a> { candidate.matches(service, method) }) } + pub(super) fn new(service: Option<&'a str>, method: Option<&'a str>) -> Self { + Self { method, service } + } +} + +#[derive(Clone, Copy)] +struct RouteSelectionCandidate { + grpc_rank: u8, + index: usize, + match_len: usize, + preferred_prefix: bool, } fn select_route_candidate( @@ -138,7 +145,7 @@ fn path_prefix_candidates(path: &str) -> Vec<&str> { if bytes[idx] != b'/' { continue; } - let inclusive = &path[..idx + 1]; + let inclusive = &path[..idx.saturating_add(1)]; if candidates.last().copied() != Some(inclusive) { candidates.push(inclusive); } @@ -152,14 +159,6 @@ fn path_prefix_candidates(path: &str) -> Vec<&str> { candidates } -#[derive(Clone, Copy)] -struct RouteSelectionCandidate { - index: usize, - match_len: usize, - grpc_rank: u8, - preferred_prefix: bool, -} - fn select_more_specific( current: Option, candidate: RouteSelectionCandidate, diff --git a/crates/rginx-core/src/config/snapshot/vhost_selection.rs b/crates/rginx-core/src/config/snapshot/vhost_selection.rs index 24e5f275..3a131d7f 100644 --- a/crates/rginx-core/src/config/snapshot/vhost_selection.rs +++ b/crates/rginx-core/src/config/snapshot/vhost_selection.rs @@ -8,11 +8,11 @@ use crate::config::server_name::{ #[derive(Debug, Clone, Default)] pub(super) struct CompiledVhostSelection { exact: HashMap>, + listener_defaults: HashMap, + listener_ids: Vec>, rank_two_wildcards: Vec<(usize, Vec)>, - trailing_wildcards: Vec<(usize, Vec)>, regex: Vec, - listener_ids: Vec>, - listener_defaults: HashMap, + trailing_wildcards: Vec<(usize, Vec)>, } impl CompiledVhostSelection { @@ -84,6 +84,16 @@ impl CompiledVhostSelection { compiled } + fn matches_listener(&self, vhost_index: usize, listener_id: Option<&str>) -> bool { + let Some(listener_id) = listener_id else { + return true; + }; + let Some(listener_ids) = self.listener_ids.get(vhost_index) else { + return false; + }; + listener_ids.is_empty() || listener_ids.iter().any(|candidate| candidate == listener_id) + } + pub(super) fn select(&self, host: &str, listener_id: Option<&str>) -> Option { let hostname = normalize_host_for_match(host); @@ -99,7 +109,7 @@ impl CompiledVhostSelection { self.matches_listener(candidate.vhost_index(), listener_id) .then_some(candidate) .and_then(|candidate| candidate.matches(&hostname)) - .map(|_| candidate.vhost_index()) + .map(|()| candidate.vhost_index()) }) { return Some(vhost_index); } @@ -110,7 +120,7 @@ impl CompiledVhostSelection { self.matches_listener(candidate.vhost_index(), listener_id) .then_some(candidate) .and_then(|candidate| candidate.matches(&hostname)) - .map(|_| candidate.vhost_index()) + .map(|()| candidate.vhost_index()) }) { return Some(vhost_index); } @@ -126,16 +136,6 @@ impl CompiledVhostSelection { listener_id.and_then(|id| self.listener_defaults.get(id).copied()) } - - fn matches_listener(&self, vhost_index: usize, listener_id: Option<&str>) -> bool { - let Some(listener_id) = listener_id else { - return true; - }; - let Some(listener_ids) = self.listener_ids.get(vhost_index) else { - return false; - }; - listener_ids.is_empty() || listener_ids.iter().any(|candidate| candidate == listener_id) - } } #[derive(Debug, Clone)] @@ -146,14 +146,6 @@ enum VhostServerNameCandidate { } impl VhostServerNameCandidate { - fn vhost_index(&self) -> usize { - match self { - Self::Dot { vhost_index, .. } - | Self::Leading { vhost_index, .. } - | Self::Trailing { vhost_index, .. } => *vhost_index, - } - } - fn matches(&self, hostname: &str) -> Option<()> { match self { Self::Dot { suffix, .. } => { @@ -176,12 +168,19 @@ impl VhostServerNameCandidate { .map(|_| ()), } } + fn vhost_index(&self) -> usize { + match self { + Self::Dot { vhost_index, .. } + | Self::Leading { vhost_index, .. } + | Self::Trailing { vhost_index, .. } => *vhost_index, + } + } } #[derive(Debug, Clone)] struct VhostRegexCandidate { - vhost_index: usize, regex: regex::Regex, + vhost_index: usize, } fn sort_groups_desc(groups: HashMap>) -> Vec<(usize, Vec)> { diff --git a/crates/rginx-core/src/config/tests/core.rs b/crates/rginx-core/src/config/tests/core.rs index d49d1d23..914c0728 100644 --- a/crates/rginx-core/src/config/tests/core.rs +++ b/crates/rginx-core/src/config/tests/core.rs @@ -207,7 +207,7 @@ fn listener_transport_bindings_include_udp_http3_binding_when_configured() { http3: Some(ListenerHttp3 { listen_addr: "127.0.0.1:443".parse().unwrap(), advertise_alt_svc: true, - alt_svc_max_age: Duration::from_secs(3600), + alt_svc_max_age: Duration::from_hours(1), max_concurrent_streams: 128, stream_buffer_size: 64 * 1024, active_connection_id_limit: 2, diff --git a/crates/rginx-core/src/config/tests/mod.rs b/crates/rginx-core/src/config/tests/mod.rs index 4e21653a..9bf5b758 100644 --- a/crates/rginx-core/src/config/tests/mod.rs +++ b/crates/rginx-core/src/config/tests/mod.rs @@ -1,3 +1,8 @@ +mod access_log; +mod core; +mod proxy_header; +mod route_matcher; +mod snapshot_lookup; use std::collections::HashMap; use std::net::IpAddr; use std::time::Duration; @@ -9,9 +14,3 @@ use super::{ ListenerHttp3, ListenerTransportKind, ReturnAction, Route, RouteAccessControl, RouteAction, RouteMatcher, RuntimeSettings, Server, VirtualHost, default_server_header, match_server_name, }; - -mod access_log; -mod core; -mod proxy_header; -mod route_matcher; -mod snapshot_lookup; diff --git a/crates/rginx-core/src/config/tls.rs b/crates/rginx-core/src/config/tls.rs index cc12ca20..3adcbf2e 100644 --- a/crates/rginx-core/src/config/tls.rs +++ b/crates/rginx-core/src/config/tls.rs @@ -25,8 +25,8 @@ pub struct OcspConfig { pub struct ServerCertificateBundle { pub cert_path: PathBuf, pub key_path: PathBuf, - pub ocsp_staple_path: Option, pub ocsp: OcspConfig, + pub ocsp_staple_path: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -37,26 +37,26 @@ pub enum TlsVersion { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TlsCipherSuite { - Tls13Aes256GcmSha384, Tls13Aes128GcmSha256, + Tls13Aes256GcmSha384, Tls13Chacha20Poly1305Sha256, - TlsEcdheEcdsaWithAes256GcmSha384, TlsEcdheEcdsaWithAes128GcmSha256, + TlsEcdheEcdsaWithAes256GcmSha384, TlsEcdheEcdsaWithChacha20Poly1305Sha256, - TlsEcdheRsaWithAes256GcmSha384, TlsEcdheRsaWithAes128GcmSha256, + TlsEcdheRsaWithAes256GcmSha384, TlsEcdheRsaWithChacha20Poly1305Sha256, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum TlsKeyExchangeGroup { - X25519, + Mlkem1024, + Mlkem768, Secp256r1, + Secp256r1Mlkem768, Secp384r1, + X25519, X25519Mlkem768, - Secp256r1Mlkem768, - Mlkem768, - Mlkem1024, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -67,10 +67,10 @@ pub enum ServerClientAuthMode { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct ServerClientAuthPolicy { - pub mode: ServerClientAuthMode, pub ca_cert_path: PathBuf, - pub verify_depth: Option, pub crl_path: Option, + pub mode: ServerClientAuthMode, + pub verify_depth: Option, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -81,27 +81,27 @@ pub struct ClientIdentity { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct VirtualHostTls { + pub additional_certificates: Vec, pub cert_path: PathBuf, pub key_path: PathBuf, - pub additional_certificates: Vec, - pub ocsp_staple_path: Option, pub ocsp: OcspConfig, + pub ocsp_staple_path: Option, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct ServerTls { - pub cert_path: PathBuf, - pub key_path: PathBuf, pub additional_certificates: Vec, - pub versions: Option>, + pub alpn_protocols: Option>, + pub cert_path: PathBuf, pub cipher_suites: Option>, + pub client_auth: Option, pub key_exchange_groups: Option>, - pub alpn_protocols: Option>, - pub ocsp_staple_path: Option, + pub key_path: PathBuf, pub ocsp: OcspConfig, - pub session_resumption: Option, - pub session_tickets: Option, + pub ocsp_staple_path: Option, pub session_cache_size: Option, + pub session_resumption: Option, pub session_ticket_count: Option, - pub client_auth: Option, + pub session_tickets: Option, + pub versions: Option>, } diff --git a/crates/rginx-core/src/config/upstream.rs b/crates/rginx-core/src/config/upstream.rs index 3f73e0bd..837a42d1 100644 --- a/crates/rginx-core/src/config/upstream.rs +++ b/crates/rginx-core/src/config/upstream.rs @@ -1,3 +1,6 @@ +mod selection; +mod types; + use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use std::sync::atomic::AtomicUsize; @@ -5,9 +8,6 @@ use std::time::Duration; use super::tls::{ClientIdentity, TlsVersion}; -mod selection; -mod types; - pub use types::{ ActiveHealthCheck, Upstream, UpstreamDnsPolicy, UpstreamLoadBalance, UpstreamPeer, UpstreamProtocol, UpstreamSettings, UpstreamTls, diff --git a/crates/rginx-core/src/config/upstream/selection.rs b/crates/rginx-core/src/config/upstream/selection.rs index 97ca18f8..41014de4 100644 --- a/crates/rginx-core/src/config/upstream/selection.rs +++ b/crates/rginx-core/src/config/upstream/selection.rs @@ -1,8 +1,46 @@ use std::sync::atomic::Ordering; -use super::*; +use super::{ + AtomicUsize, IpAddr, Upstream, UpstreamLoadBalance, UpstreamPeer, UpstreamSettings, UpstreamTls, +}; impl Upstream { + pub fn backup_next_peers(&self, limit: usize) -> Vec { + self.next_peers_in_pool(limit, true) + } + + pub fn backup_peers_for_client_ip(&self, client_ip: IpAddr, limit: usize) -> Vec { + self.peers_for_client_ip_in_pool(client_ip, limit, true) + } + + pub fn has_primary_peers(&self) -> bool { + self.peers.iter().any(|peer| !peer.backup) + } + + fn ip_hash_peers_in_pool( + &self, + client_ip: IpAddr, + limit: usize, + peer_indices: &[usize], + ) -> Vec { + if peer_indices.is_empty() || limit == 0 { + return Vec::new(); + } + + let total_weight = self.total_weight_for_indices(peer_indices); + if total_weight == 0 { + return peer_indices + .iter() + .take(limit) + .map(|index| self.peers[*index].clone()) + .collect(); + } + + let start = (stable_ip_hash(client_ip) as usize).rem_euclid(total_weight); + self.weighted_peers_from_indices(peer_indices, start, limit) + } + + #[must_use] pub fn new( name: String, peers: Vec, @@ -50,39 +88,6 @@ impl Upstream { if primary.is_empty() { self.next_peers_in_pool(limit, true) } else { primary } } - pub fn primary_next_peers(&self, limit: usize) -> Vec { - self.next_peers_in_pool(limit, false) - } - - pub fn backup_next_peers(&self, limit: usize) -> Vec { - self.next_peers_in_pool(limit, true) - } - - pub fn has_primary_peers(&self) -> bool { - self.peers.iter().any(|peer| !peer.backup) - } - - pub fn peers_for_client_ip(&self, client_ip: IpAddr, limit: usize) -> Vec { - let primary = self.peers_for_client_ip_in_pool(client_ip, limit, false); - if primary.is_empty() { - self.peers_for_client_ip_in_pool(client_ip, limit, true) - } else { - primary - } - } - - pub fn primary_peers_for_client_ip( - &self, - client_ip: IpAddr, - limit: usize, - ) -> Vec { - self.peers_for_client_ip_in_pool(client_ip, limit, false) - } - - pub fn backup_peers_for_client_ip(&self, client_ip: IpAddr, limit: usize) -> Vec { - self.peers_for_client_ip_in_pool(client_ip, limit, true) - } - fn next_peers_in_pool(&self, limit: usize, backup: bool) -> Vec { let peer_indices = self.peer_indices_for_pool(backup); if peer_indices.is_empty() || limit == 0 { @@ -98,10 +103,44 @@ impl Upstream { .collect(); } - let start = self.cursor.fetch_add(1, Ordering::Relaxed) % total_weight; + let start = self.cursor.fetch_add(1, Ordering::Relaxed).rem_euclid(total_weight); self.weighted_peers_from_indices(&peer_indices, start, limit) } + fn peer_indices_for_pool(&self, backup: bool) -> Vec { + self.peers + .iter() + .enumerate() + .filter_map(|(index, peer)| (peer.backup == backup).then_some(index)) + .collect() + } + + fn peer_position_for_weighted_slot( + &self, + peer_indices: &[usize], + slot: usize, + ) -> Option { + let mut offset = 0usize; + + for (position, index) in peer_indices.iter().enumerate() { + offset = offset.saturating_add(self.peers[*index].weight as usize); + if slot < offset { + return Some(position); + } + } + + None + } + + pub fn peers_for_client_ip(&self, client_ip: IpAddr, limit: usize) -> Vec { + let primary = self.peers_for_client_ip_in_pool(client_ip, limit, false); + if primary.is_empty() { + self.peers_for_client_ip_in_pool(client_ip, limit, true) + } else { + primary + } + } + fn peers_for_client_ip_in_pool( &self, client_ip: IpAddr, @@ -122,35 +161,16 @@ impl Upstream { } } - fn ip_hash_peers_in_pool( + pub fn primary_next_peers(&self, limit: usize) -> Vec { + self.next_peers_in_pool(limit, false) + } + + pub fn primary_peers_for_client_ip( &self, client_ip: IpAddr, limit: usize, - peer_indices: &[usize], ) -> Vec { - if peer_indices.is_empty() || limit == 0 { - return Vec::new(); - } - - let total_weight = self.total_weight_for_indices(peer_indices); - if total_weight == 0 { - return peer_indices - .iter() - .take(limit) - .map(|index| self.peers[*index].clone()) - .collect(); - } - - let start = stable_ip_hash(client_ip) as usize % total_weight; - self.weighted_peers_from_indices(peer_indices, start, limit) - } - - fn peer_indices_for_pool(&self, backup: bool) -> Vec { - self.peers - .iter() - .enumerate() - .filter_map(|(index, peer)| (peer.backup == backup).then_some(index)) - .collect() + self.peers_for_client_ip_in_pool(client_ip, limit, false) } fn total_weight_for_indices(&self, peer_indices: &[usize]) -> usize { @@ -173,7 +193,7 @@ impl Upstream { let mut seen = vec![false; peer_indices.len()]; for offset in 0..total_weight { - let slot = (start + offset) % total_weight; + let slot = start.saturating_add(offset).rem_euclid(total_weight); let Some(position) = self.peer_position_for_weighted_slot(peer_indices, slot) else { continue; }; @@ -191,23 +211,6 @@ impl Upstream { selected } - - fn peer_position_for_weighted_slot( - &self, - peer_indices: &[usize], - slot: usize, - ) -> Option { - let mut offset = 0usize; - - for (position, index) in peer_indices.iter().enumerate() { - offset += self.peers[*index].weight as usize; - if slot < offset { - return Some(position); - } - } - - None - } } fn stable_ip_hash(client_ip: IpAddr) -> u64 { diff --git a/crates/rginx-core/src/config/upstream/types.rs b/crates/rginx-core/src/config/upstream/types.rs index a6162575..244759f7 100644 --- a/crates/rginx-core/src/config/upstream/types.rs +++ b/crates/rginx-core/src/config/upstream/types.rs @@ -1,24 +1,25 @@ -use super::*; +use super::{AtomicUsize, ClientIdentity, Duration, PathBuf, SocketAddr, TlsVersion}; #[derive(Debug, Clone)] pub struct ActiveHealthCheck { - pub path: String, pub grpc_service: Option, + pub healthy_successes_required: u32, pub interval: Duration, + pub path: String, pub timeout: Duration, - pub healthy_successes_required: u32, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum UpstreamProtocol { Auto, + H2c, Http1, Http2, - H2c, Http3, } impl UpstreamProtocol { + #[must_use] pub const fn as_str(self) -> &'static str { match self { Self::Auto => "auto", @@ -32,12 +33,13 @@ impl UpstreamProtocol { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum UpstreamLoadBalance { - RoundRobin, IpHash, LeastConn, + RoundRobin, } impl UpstreamLoadBalance { + #[must_use] pub const fn as_str(self) -> &'static str { match self { Self::RoundRobin => "round_robin", @@ -49,14 +51,14 @@ impl UpstreamLoadBalance { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct UpstreamDnsPolicy { - pub resolver_addrs: Vec, - pub min_ttl: Duration, pub max_ttl: Duration, + pub min_ttl: Duration, pub negative_ttl: Duration, - pub stale_if_error: Duration, - pub refresh_before_expiry: Duration, pub prefer_ipv4: bool, pub prefer_ipv6: bool, + pub refresh_before_expiry: Duration, + pub resolver_addrs: Vec, + pub stale_if_error: Duration, } impl Default for UpstreamDnsPolicy { @@ -64,9 +66,9 @@ impl Default for UpstreamDnsPolicy { Self { resolver_addrs: Vec::new(), min_ttl: Duration::from_secs(5), - max_ttl: Duration::from_secs(300), + max_ttl: Duration::from_mins(5), negative_ttl: Duration::from_secs(30), - stale_if_error: Duration::from_secs(60), + stale_if_error: Duration::from_mins(1), refresh_before_expiry: Duration::from_secs(10), prefer_ipv4: false, prefer_ipv6: false, @@ -76,77 +78,77 @@ impl Default for UpstreamDnsPolicy { #[derive(Debug, Clone)] pub struct UpstreamSettings { - pub protocol: UpstreamProtocol, - pub load_balance: UpstreamLoadBalance, - pub dns: UpstreamDnsPolicy, - pub server_name: bool, - pub server_name_override: Option, - pub tls_versions: Option>, - pub server_verify_depth: Option, - pub server_crl_path: Option, + pub active_health_check: Option, pub client_identity: Option, - pub request_timeout: Duration, pub connect_timeout: Duration, - pub write_timeout: Duration, + pub dns: UpstreamDnsPolicy, + pub http2_keep_alive_interval: Option, + pub http2_keep_alive_timeout: Duration, + pub http2_keep_alive_while_idle: bool, pub idle_timeout: Duration, + pub load_balance: UpstreamLoadBalance, + pub max_replayable_request_body_bytes: usize, pub pool_idle_timeout: Option, pub pool_max_idle_per_host: usize, + pub protocol: UpstreamProtocol, + pub request_timeout: Duration, + pub server_crl_path: Option, + pub server_name: bool, + pub server_name_override: Option, + pub server_verify_depth: Option, pub tcp_keepalive: Option, pub tcp_nodelay: bool, - pub http2_keep_alive_interval: Option, - pub http2_keep_alive_timeout: Duration, - pub http2_keep_alive_while_idle: bool, - pub max_replayable_request_body_bytes: usize, + pub tls_versions: Option>, pub unhealthy_after_failures: u32, pub unhealthy_cooldown: Duration, - pub active_health_check: Option, + pub write_timeout: Duration, } #[derive(Debug)] pub struct Upstream { + pub active_health_check: Option, + pub client_identity: Option, + pub connect_timeout: Duration, + pub(super) cursor: AtomicUsize, + pub dns: UpstreamDnsPolicy, + pub http2_keep_alive_interval: Option, + pub http2_keep_alive_timeout: Duration, + pub http2_keep_alive_while_idle: bool, + pub idle_timeout: Duration, + pub load_balance: UpstreamLoadBalance, + pub max_replayable_request_body_bytes: usize, pub name: String, pub peers: Vec, - pub tls: UpstreamTls, + pub pool_idle_timeout: Option, + pub pool_max_idle_per_host: usize, pub protocol: UpstreamProtocol, - pub load_balance: UpstreamLoadBalance, - pub dns: UpstreamDnsPolicy, + pub request_timeout: Duration, + pub server_crl_path: Option, pub server_name: bool, pub server_name_override: Option, - pub tls_versions: Option>, pub server_verify_depth: Option, - pub server_crl_path: Option, - pub client_identity: Option, - pub request_timeout: Duration, - pub connect_timeout: Duration, - pub write_timeout: Duration, - pub idle_timeout: Duration, - pub pool_idle_timeout: Option, - pub pool_max_idle_per_host: usize, pub tcp_keepalive: Option, pub tcp_nodelay: bool, - pub http2_keep_alive_interval: Option, - pub http2_keep_alive_timeout: Duration, - pub http2_keep_alive_while_idle: bool, - pub max_replayable_request_body_bytes: usize, + pub tls: UpstreamTls, + pub tls_versions: Option>, pub unhealthy_after_failures: u32, pub unhealthy_cooldown: Duration, - pub active_health_check: Option, - pub(super) cursor: AtomicUsize, + pub write_timeout: Duration, } #[derive(Debug, Clone)] pub struct UpstreamPeer { - pub url: String, - pub scheme: String, pub authority: String, - pub weight: u32, pub backup: bool, pub max_conns: Option, + pub scheme: String, + pub url: String, + pub weight: u32, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum UpstreamTls { - NativeRoots, CustomCa { ca_cert_path: PathBuf }, Insecure, + NativeRoots, } diff --git a/crates/rginx-core/src/config/virtual_host.rs b/crates/rginx-core/src/config/virtual_host.rs index 2edea366..33d5aca6 100644 --- a/crates/rginx-core/src/config/virtual_host.rs +++ b/crates/rginx-core/src/config/virtual_host.rs @@ -2,21 +2,21 @@ use super::{Route, ServerNameMatch, VirtualHostTls, best_matching_server_name_pa #[derive(Debug, Clone)] pub struct VirtualHost { + pub default_listener_ids: Vec, pub id: String, pub listener_ids: Vec, - pub default_listener_ids: Vec, - pub server_names: Vec, pub routes: Vec, + pub server_names: Vec, pub tls: Option, } impl VirtualHost { - pub fn matches_host(&self, host: &str) -> bool { - self.server_names.is_empty() || self.best_server_name_match(host).is_some() - } - pub fn best_server_name_match(&self, host: &str) -> Option { best_matching_server_name_pattern(self.server_names.iter().map(String::as_str), host) .map(|(_, matched)| matched) } + #[must_use] + pub fn matches_host(&self, host: &str) -> bool { + self.server_names.is_empty() || self.best_server_name_match(host).is_some() + } } diff --git a/crates/rginx-core/src/error.rs b/crates/rginx-core/src/error.rs index 91f1b164..82e478af 100644 --- a/crates/rginx-core/src/error.rs +++ b/crates/rginx-core/src/error.rs @@ -2,16 +2,16 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum Error { - #[error("io error: {0}")] - Io(#[from] std::io::Error), #[error("address parse error: {0}")] AddrParse(#[from] std::net::AddrParseError), - #[error("invalid uri: {0}")] - InvalidUri(#[from] http::uri::InvalidUri), - #[error("invalid status code: {0}")] - InvalidStatusCode(#[from] http::status::InvalidStatusCode), #[error("configuration error: {0}")] Config(String), + #[error("invalid status code: {0}")] + InvalidStatusCode(#[from] http::status::InvalidStatusCode), + #[error("invalid uri: {0}")] + InvalidUri(#[from] http::uri::InvalidUri), + #[error("io error: {0}")] + Io(#[from] std::io::Error), #[error("server error: {0}")] Server(String), } diff --git a/crates/rginx-core/src/lib.rs b/crates/rginx-core/src/lib.rs index 9ec95e6a..496270f5 100644 --- a/crates/rginx-core/src/lib.rs +++ b/crates/rginx-core/src/lib.rs @@ -1,3 +1,12 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] pub mod config; pub mod context; pub mod error; diff --git a/crates/rginx-http/Cargo.toml b/crates/rginx-http/Cargo.toml index 4de94b6a..5e196214 100644 --- a/crates/rginx-http/Cargo.toml +++ b/crates/rginx-http/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true rust-version.workspace = true description = "HTTP server, routing, and proxying for rginx." +[lints] +workspace = true + [dependencies] aws-lc-rs = "1.17.0" base64.workspace = true diff --git a/crates/rginx-http/src/cache/entry.rs b/crates/rginx-http/src/cache/entry.rs index 9d78150c..f5d6d113 100644 --- a/crates/rginx-http/src/cache/entry.rs +++ b/crates/rginx-http/src/cache/entry.rs @@ -1,3 +1,9 @@ +mod metadata; +mod response; +mod signature; +mod temp; +mod write; + use bytes::Bytes; use http::StatusCode; use http::header::{ @@ -17,12 +23,6 @@ use super::{ CachedVaryHeaderValue, PreparedCacheResponseHead, build_conditional_headers, }; -mod metadata; -mod response; -mod signature; -mod temp; -mod write; - pub(in crate::cache) use metadata::{ CacheMetadata, CacheMetadataInput, CachedHeader, cache_metadata, prepare_cached_response_head, read_cache_metadata, @@ -40,9 +40,9 @@ pub(in crate::cache) use write::{ }; pub(super) struct CachePaths { + pub(super) body: PathBuf, pub(super) dir: PathBuf, pub(super) metadata: PathBuf, - pub(super) body: PathBuf, } pub(super) async fn read_cached_response_for_request( @@ -204,7 +204,8 @@ fn cache_paths_with_levels(base: &Path, levels: &[usize], hash: &str) -> CachePa } fn cache_path_segment(hash: &str, offset: usize, level_len: usize) -> String { - hash.get(offset..offset.saturating_add(level_len)).map(str::to_string).unwrap_or_else(|| { - format!("{:0, - #[serde(default)] - pub(in crate::cache) tags: Vec, - pub(in crate::cache) status: u16, + pub(in crate::cache) grace_until_unix_ms: Option, pub(in crate::cache) headers: Vec, - pub(in crate::cache) stored_at_unix_ms: u64, - pub(in crate::cache) expires_at_unix_ms: u64, + #[serde(default)] + pub(in crate::cache) keep_until_unix_ms: Option, + #[serde(default)] + pub(in crate::cache) key: String, #[serde(default)] pub(in crate::cache) kind: CacheIndexEntryKind, #[serde(default)] - pub(in crate::cache) grace_until_unix_ms: Option, + pub(in crate::cache) must_revalidate: bool, #[serde(default)] - pub(in crate::cache) keep_until_unix_ms: Option, + pub(in crate::cache) requires_revalidation: bool, #[serde(default)] pub(in crate::cache) stale_if_error_until_unix_ms: Option, #[serde(default)] pub(in crate::cache) stale_while_revalidate_until_unix_ms: Option, + pub(in crate::cache) status: u16, + pub(in crate::cache) stored_at_unix_ms: u64, #[serde(default)] - pub(in crate::cache) requires_revalidation: bool, + pub(in crate::cache) tags: Vec, #[serde(default)] - pub(in crate::cache) must_revalidate: bool, - pub(in crate::cache) body_size_bytes: usize, + pub(in crate::cache) vary: Vec, } #[derive(Debug, Deserialize)] struct RawCacheMetadata { - #[serde(default)] - key: String, #[serde(default)] base_key: String, + body_size_bytes: usize, + expires_at_unix_ms: u64, #[serde(default)] - vary: Vec, - #[serde(default)] - tags: Vec, - status: u16, + grace_until_unix_ms: Option, headers: Vec, - stored_at_unix_ms: u64, - expires_at_unix_ms: u64, + #[serde(default)] + keep_until_unix_ms: Option, + #[serde(default)] + key: String, #[serde(default)] kind: CacheIndexEntryKind, #[serde(default)] - grace_until_unix_ms: Option, + must_revalidate: bool, #[serde(default)] - keep_until_unix_ms: Option, + requires_revalidation: Option, #[serde(default)] stale_if_error_until_unix_ms: Option, #[serde(default)] stale_while_revalidate_until_unix_ms: Option, + status: u16, + stored_at_unix_ms: u64, #[serde(default)] - requires_revalidation: Option, + tags: Vec, #[serde(default)] - must_revalidate: bool, - body_size_bytes: usize, + vary: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -77,19 +81,33 @@ pub(in crate::cache) struct CachedVaryHeader { #[derive(Debug, Clone)] pub(in crate::cache) struct CacheMetadataInput { - pub(in crate::cache) kind: CacheIndexEntryKind, pub(in crate::cache) base_key: String, - pub(in crate::cache) vary: Vec, - pub(in crate::cache) tags: Vec, - pub(in crate::cache) stored_at_unix_ms: u64, + pub(in crate::cache) body_size_bytes: usize, pub(in crate::cache) expires_at_unix_ms: u64, pub(in crate::cache) grace_until_unix_ms: Option, pub(in crate::cache) keep_until_unix_ms: Option, + pub(in crate::cache) kind: CacheIndexEntryKind, + pub(in crate::cache) must_revalidate: bool, + pub(in crate::cache) requires_revalidation: bool, pub(in crate::cache) stale_if_error_until_unix_ms: Option, pub(in crate::cache) stale_while_revalidate_until_unix_ms: Option, - pub(in crate::cache) requires_revalidation: bool, - pub(in crate::cache) must_revalidate: bool, - pub(in crate::cache) body_size_bytes: usize, + pub(in crate::cache) stored_at_unix_ms: u64, + pub(in crate::cache) tags: Vec, + pub(in crate::cache) vary: Vec, +} + +impl CacheMetadata { + pub(super) fn headers_map(&self) -> std::io::Result { + let mut headers = HeaderMap::new(); + for header in &self.headers { + let name = HeaderName::from_bytes(header.name.as_bytes()) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; + let value = HeaderValue::from_bytes(&header.value) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; + headers.append(name, value); + } + Ok(headers) + } } pub(in crate::cache) fn cache_metadata( @@ -167,17 +185,3 @@ pub(in crate::cache) async fn read_cache_metadata(path: &Path) -> std::io::Resul body_size_bytes: raw.body_size_bytes, }) } - -impl CacheMetadata { - pub(super) fn headers_map(&self) -> std::io::Result { - let mut headers = HeaderMap::new(); - for header in &self.headers { - let name = HeaderName::from_bytes(header.name.as_bytes()) - .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; - let value = HeaderValue::from_bytes(&header.value) - .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; - headers.append(name, value); - } - Ok(headers) - } -} diff --git a/crates/rginx-http/src/cache/entry/response.rs b/crates/rginx-http/src/cache/entry/response.rs index 9ba5c166..6329080e 100644 --- a/crates/rginx-http/src/cache/entry/response.rs +++ b/crates/rginx-http/src/cache/entry/response.rs @@ -1,26 +1,43 @@ +mod body; + use super::super::store::range::build_downstream_response; -use super::*; +use super::{ + CACHE_STATUS_HEADER, CONNECTION, CONTENT_LENGTH, CONTENT_RANGE, CacheRequest, CachedHeader, + HeaderMap, HeaderName, HeaderValue, HttpResponse, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, + StatusCode, TE, TRAILER, TRANSFER_ENCODING, TryFrom, UPGRADE, +}; use crate::handler::BoxError; use bytes::Bytes; use http::Response; use hyper::body::Body; -mod body; +pub(in crate::cache) use body::CachedFileBody; struct CachedContentRange { - start: u64, end: u64, + start: u64, total: Option, } #[derive(Debug, Clone)] pub(in crate::cache) struct DownstreamRangeTrimPlan { + emit_bytes: usize, headers: HeaderMap, skip_bytes: usize, - emit_bytes: usize, } -pub(in crate::cache) use body::CachedFileBody; +impl DownstreamRangeTrimPlan { + pub(in crate::cache) fn emit_bytes(&self) -> usize { + self.emit_bytes + } + pub(in crate::cache) fn headers(&self) -> &HeaderMap { + &self.headers + } + + pub(in crate::cache) fn skip_bytes(&self) -> usize { + self.skip_bytes + } +} pub(in crate::cache) fn finalize_response_for_request( status: StatusCode, @@ -87,10 +104,12 @@ pub(in crate::cache) fn downstream_range_trim_plan( )); } - let skip_bytes = usize::try_from(request_range.request_start - cached_range.start) - .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; - let emit_bytes = usize::try_from(response_end - request_range.request_start + 1) - .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; + let skip_bytes = + usize::try_from(request_range.request_start.saturating_sub(cached_range.start)) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; + let emit_bytes = + usize::try_from(response_end.saturating_sub(request_range.request_start).saturating_add(1)) + .map_err(|error| std::io::Error::new(std::io::ErrorKind::InvalidData, error))?; let mut trimmed_headers = headers.clone(); update_downstream_range_headers( &mut trimmed_headers, @@ -170,7 +189,7 @@ fn parse_cached_content_range(headers: &HeaderMap) -> std::io::Result &HeaderMap { - &self.headers - } - - pub(in crate::cache) fn skip_bytes(&self) -> usize { - self.skip_bytes - } - - pub(in crate::cache) fn emit_bytes(&self) -> usize { - self.emit_bytes - } -} diff --git a/crates/rginx-http/src/cache/entry/response/body.rs b/crates/rginx-http/src/cache/entry/response/body.rs index 9e823fda..d88c735e 100644 --- a/crates/rginx-http/src/cache/entry/response/body.rs +++ b/crates/rginx-http/src/cache/entry/response/body.rs @@ -9,10 +9,10 @@ use tokio::task::JoinHandle; const COMMITTED_CACHE_READ_CHUNK_BYTES: usize = 64 * 1024; pub(in crate::cache) struct CachedFileBody { - rx: mpsc::Receiver, BoxError>>, - size_hint: SizeHint, done: bool, join_handle: Option>, + rx: mpsc::Receiver, BoxError>>, + size_hint: SizeHint, } impl CachedFileBody { @@ -41,6 +41,10 @@ impl Body for CachedFileBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -55,10 +59,6 @@ impl Body for CachedFileBody { } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> SizeHint { self.size_hint.clone() } @@ -85,14 +85,14 @@ async fn stream_cached_file_body( .await; return; } - Ok(read) => filled += read, + Ok(read) => filled = filled.saturating_add(read), Err(error) => { let _ = tx.send(Err(error.into())).await; return; } } } - remaining_bytes -= chunk_len; + remaining_bytes = remaining_bytes.saturating_sub(chunk_len); if tx.send(Ok(Frame::data(Bytes::from(buffer)))).await.is_err() { return; } diff --git a/crates/rginx-http/src/cache/entry/signature.rs b/crates/rginx-http/src/cache/entry/signature.rs index a27e6164..91065b4a 100644 --- a/crates/rginx-http/src/cache/entry/signature.rs +++ b/crates/rginx-http/src/cache/entry/signature.rs @@ -3,7 +3,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use sha2::{Digest, Sha256}; use super::super::vary::canonical_vary_headers; -use super::*; +use super::CachedVaryHeaderValue; pub(in crate::cache) fn cache_key_hash(key: &str) -> String { cache_bytes_hash(key.as_bytes()) @@ -31,13 +31,12 @@ pub(in crate::cache) fn cache_variant_key( pub(in crate::cache) fn unix_time_ms(time: SystemTime) -> u64 { time.duration_since(UNIX_EPOCH) - .map(|duration| duration.as_millis().min(u128::from(u64::MAX)) as u64) - .unwrap_or(0) + .map_or(0, |duration| duration.as_millis().min(u128::from(u64::MAX)) as u64) } fn cache_bytes_hash(bytes: &[u8]) -> String { let digest = Sha256::digest(bytes); - let mut encoded = String::with_capacity(digest.len() * 2); + let mut encoded = String::with_capacity(digest.len().saturating_mul(2)); for byte in digest { use std::fmt::Write as _; let _ = write!(encoded, "{byte:02x}"); diff --git a/crates/rginx-http/src/cache/entry/temp.rs b/crates/rginx-http/src/cache/entry/temp.rs index 3d90f3f6..726644a9 100644 --- a/crates/rginx-http/src/cache/entry/temp.rs +++ b/crates/rginx-http/src/cache/entry/temp.rs @@ -14,8 +14,9 @@ pub(super) fn next_temp_suffix() -> String { } pub(super) fn sibling_temp_path(path: &Path, suffix: &str) -> PathBuf { - let mut file_name = - path.file_name().map_or_else(|| OsString::from("cache-entry"), |name| name.to_os_string()); + let mut file_name = path + .file_name() + .map_or_else(|| OsString::from("cache-entry"), std::ffi::OsStr::to_os_string); file_name.push(suffix); path.with_file_name(file_name) } diff --git a/crates/rginx-http/src/cache/entry/write.rs b/crates/rginx-http/src/cache/entry/write.rs index ad7d9ca4..c20ad713 100644 --- a/crates/rginx-http/src/cache/entry/write.rs +++ b/crates/rginx-http/src/cache/entry/write.rs @@ -3,7 +3,7 @@ use std::path::{Path, PathBuf}; use tokio::fs; use super::temp::{cleanup_failed_write, next_temp_suffix, sibling_temp_path}; -use super::*; +use super::{CacheMetadata, CachePaths}; #[cfg(test)] pub(in crate::cache) async fn write_cache_entry( diff --git a/crates/rginx-http/src/cache/fill.rs b/crates/rginx-http/src/cache/fill.rs index 29586d5e..bca82b9a 100644 --- a/crates/rginx-http/src/cache/fill.rs +++ b/crates/rginx-http/src/cache/fill.rs @@ -1,11 +1,11 @@ -use super::*; - mod common; mod external; mod local; mod persistence; mod shared; +use super::{CacheRequest, InflightFillReadState, RouteCachePolicy}; + pub(in crate::cache) use local::{CacheFillReadState, inflight_fill_body}; pub(in crate::cache) use shared::{ ExternalCacheFillReadState, SharedFillExternalStateHandle, clear_stale_memory_shared_fill_lock, diff --git a/crates/rginx-http/src/cache/fill/external.rs b/crates/rginx-http/src/cache/fill/external.rs index 56044759..9619e324 100644 --- a/crates/rginx-http/src/cache/fill/external.rs +++ b/crates/rginx-http/src/cache/fill/external.rs @@ -16,29 +16,11 @@ use super::shared::{ }; use crate::handler::BoxError; -pub(in crate::cache::fill) fn build_external_fill_response( - state: ExternalCacheFillReadState, - request: &CacheRequest, - policy: &RouteCachePolicy, - read_body: bool, -) -> std::io::Result { - let status = state.status; - let headers = state.headers.clone(); - build_inflight_fill_response_with_body( - status, - &headers, - request, - policy, - read_body, - move || ExternalInFlightFillBody::new(state), - ) -} - struct ExternalInFlightFillBody { - rx: mpsc::Receiver, BoxError>>, - size_hint: hyper::body::SizeHint, done: bool, join_handle: Option>, + rx: mpsc::Receiver, BoxError>>, + size_hint: hyper::body::SizeHint, } impl ExternalInFlightFillBody { @@ -66,6 +48,10 @@ impl hyper::body::Body for ExternalInFlightFillBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -80,15 +66,29 @@ impl hyper::body::Body for ExternalInFlightFillBody { } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> hyper::body::SizeHint { self.size_hint.clone() } } +pub(in crate::cache::fill) fn build_external_fill_response( + state: ExternalCacheFillReadState, + request: &CacheRequest, + policy: &RouteCachePolicy, + read_body: bool, +) -> std::io::Result { + let status = state.status; + let headers = state.headers.clone(); + build_inflight_fill_response_with_body( + status, + &headers, + request, + policy, + read_body, + move || ExternalInFlightFillBody::new(state), + ) +} + async fn stream_external_fill_body( state: ExternalCacheFillReadState, tx: mpsc::Sender, BoxError>>, @@ -128,9 +128,10 @@ async fn stream_external_fill_body( }; if available > offset { - let chunk_len = - usize::try_from((available - offset).min(IN_FLIGHT_FILL_READ_CHUNK_BYTES as u64)) - .expect("bounded read chunk length should fit in usize"); + let chunk_len = usize::try_from( + available.saturating_sub(offset).min(IN_FLIGHT_FILL_READ_CHUNK_BYTES as u64), + ) + .expect("bounded read chunk length should fit in usize"); let mut buffer = vec![0; chunk_len]; let mut filled = 0usize; while filled < chunk_len { @@ -146,7 +147,7 @@ async fn stream_external_fill_body( return; } Ok(0) => sleep(EXTERNAL_FILL_POLL_INTERVAL).await, - Ok(read) => filled += read, + Ok(read) => filled = filled.saturating_add(read), Err(error) => { let _ = tx.send(Err(error.into())).await; return; diff --git a/crates/rginx-http/src/cache/fill/local.rs b/crates/rginx-http/src/cache/fill/local.rs index 803cdbfa..072eb9fa 100644 --- a/crates/rginx-http/src/cache/fill/local.rs +++ b/crates/rginx-http/src/cache/fill/local.rs @@ -18,20 +18,72 @@ use super::shared::SharedFillExternalStateHandle; use crate::handler::{BoxError, HttpBody, boxed_body}; pub(in crate::cache) struct CacheFillReadState { - status: StatusCode, - headers: HeaderMap, - body_tmp_path: PathBuf, body_path: PathBuf, + body_tmp_path: PathBuf, bytes_written: AtomicU64, + error: Mutex>, + external_state: Option, finished: AtomicBool, - upstream_completed: AtomicBool, + headers: HeaderMap, notify: Arc, + status: StatusCode, trailers: Mutex>, - error: Mutex>, - external_state: Option, + upstream_completed: AtomicBool, } impl CacheFillReadState { + fn bytes_written(&self) -> u64 { + self.bytes_written.load(Ordering::Acquire) + } + + fn can_serve(&self) -> bool { + self.error.lock().unwrap_or_else(std::sync::PoisonError::into_inner).as_ref().is_none() + } + + pub(in crate::cache) fn can_share(&self) -> bool { + self.can_serve() && !self.upstream_completed.load(Ordering::Acquire) + } + + fn error_message(&self) -> Option { + self.error.lock().unwrap_or_else(std::sync::PoisonError::into_inner).clone() + } + + pub(in crate::cache) fn fail(&self, error: impl std::fmt::Display) { + let error = error.to_string(); + *self.error.lock().unwrap_or_else(std::sync::PoisonError::into_inner) = Some(error.clone()); + self.notify.notify_waiters(); + if let Some(external_state) = self.external_state.as_ref() + && let Err(external_error) = external_state.fail(&error) + { + tracing::warn!(%external_error, "failed to mark shared fill state failed"); + } + } + + pub(in crate::cache) fn finish(&self, trailers: Option) { + *self.trailers.lock().unwrap_or_else(std::sync::PoisonError::into_inner) = trailers.clone(); + self.finished.store(true, Ordering::Release); + self.notify.notify_waiters(); + if let Some(external_state) = self.external_state.as_ref() + && let Err(error) = external_state.finish(trailers) + { + tracing::warn!(%error, "failed to mark shared fill state complete"); + } + } + + fn is_finished(&self) -> bool { + self.finished.load(Ordering::Acquire) + } + + pub(in crate::cache) fn mark_upstream_complete(&self) { + self.upstream_completed.store(true, Ordering::Release); + self.notify.notify_waiters(); + if let Some(external_state) = self.external_state.as_ref() + && let Err(error) = external_state.mark_upstream_complete() + { + tracing::warn!(%error, "failed to mark shared fill state upstream-complete"); + } + } + pub(in crate::cache) fn new( status: StatusCode, headers: HeaderMap, @@ -76,94 +128,20 @@ impl CacheFillReadState { } } - pub(in crate::cache) fn finish(&self, trailers: Option) { - *self.trailers.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) = trailers.clone(); - self.finished.store(true, Ordering::Release); - self.notify.notify_waiters(); - if let Some(external_state) = self.external_state.as_ref() - && let Err(error) = external_state.finish(trailers) - { - tracing::warn!(%error, "failed to mark shared fill state complete"); - } - } - - pub(in crate::cache) fn fail(&self, error: impl std::fmt::Display) { - let error = error.to_string(); - *self.error.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) = Some(error.clone()); - self.notify.notify_waiters(); - if let Some(external_state) = self.external_state.as_ref() - && let Err(external_error) = external_state.fail(&error) - { - tracing::warn!(%external_error, "failed to mark shared fill state failed"); - } - } - - pub(in crate::cache) fn can_share(&self) -> bool { - self.can_serve() && !self.upstream_completed.load(Ordering::Acquire) - } - - pub(in crate::cache) fn mark_upstream_complete(&self) { - self.upstream_completed.store(true, Ordering::Release); - self.notify.notify_waiters(); - if let Some(external_state) = self.external_state.as_ref() - && let Err(error) = external_state.mark_upstream_complete() - { - tracing::warn!(%error, "failed to mark shared fill state upstream-complete"); - } - } - - fn can_serve(&self) -> bool { - self.error.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).as_ref().is_none() - } - - fn bytes_written(&self) -> u64 { - self.bytes_written.load(Ordering::Acquire) - } - - fn is_finished(&self) -> bool { - self.finished.load(Ordering::Acquire) - } - - fn trailers(&self) -> Option { - self.trailers.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).clone() - } - - fn error_message(&self) -> Option { - self.error.lock().unwrap_or_else(|poisoned| poisoned.into_inner()).clone() - } - fn size_hint(&self) -> hyper::body::SizeHint { size_hint_from_headers(&self.headers) } -} - -pub(in crate::cache::fill) fn build_local_fill_response( - state: Arc, - request: &CacheRequest, - policy: &RouteCachePolicy, - read_body: bool, -) -> std::io::Result { - let status = state.status; - let headers = state.headers.clone(); - build_inflight_fill_response_with_body( - status, - &headers, - request, - policy, - read_body, - move || InFlightFillBody::new(state), - ) -} -pub(in crate::cache) fn inflight_fill_body(state: Arc) -> HttpBody { - boxed_body(InFlightFillBody::new(state)) + fn trailers(&self) -> Option { + self.trailers.lock().unwrap_or_else(std::sync::PoisonError::into_inner).clone() + } } struct InFlightFillBody { - rx: mpsc::Receiver, BoxError>>, - size_hint: hyper::body::SizeHint, done: bool, join_handle: Option>, + rx: mpsc::Receiver, BoxError>>, + size_hint: hyper::body::SizeHint, } impl InFlightFillBody { @@ -191,6 +169,10 @@ impl hyper::body::Body for InFlightFillBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -205,15 +187,33 @@ impl hyper::body::Body for InFlightFillBody { } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> hyper::body::SizeHint { self.size_hint.clone() } } +pub(in crate::cache::fill) fn build_local_fill_response( + state: Arc, + request: &CacheRequest, + policy: &RouteCachePolicy, + read_body: bool, +) -> std::io::Result { + let status = state.status; + let headers = state.headers.clone(); + build_inflight_fill_response_with_body( + status, + &headers, + request, + policy, + read_body, + move || InFlightFillBody::new(state), + ) +} + +pub(in crate::cache) fn inflight_fill_body(state: Arc) -> HttpBody { + boxed_body(InFlightFillBody::new(state)) +} + async fn stream_inflight_fill_body( state: Arc, tx: mpsc::Sender, BoxError>>, @@ -231,9 +231,10 @@ async fn stream_inflight_fill_body( let notified = state.notify.notified(); let available = state.bytes_written(); if available > offset { - let chunk_len = - usize::try_from((available - offset).min(IN_FLIGHT_FILL_READ_CHUNK_BYTES as u64)) - .expect("bounded read chunk length should fit in usize"); + let chunk_len = usize::try_from( + available.saturating_sub(offset).min(IN_FLIGHT_FILL_READ_CHUNK_BYTES as u64), + ) + .expect("bounded read chunk length should fit in usize"); let mut buffer = vec![0; chunk_len]; let mut filled = 0usize; while filled < chunk_len { @@ -249,7 +250,7 @@ async fn stream_inflight_fill_body( return; } Ok(0) => state.notify.notified().await, - Ok(read) => filled += read, + Ok(read) => filled = filled.saturating_add(read), Err(error) => { let _ = tx.send(Err(error.into())).await; return; diff --git a/crates/rginx-http/src/cache/fill/persistence.rs b/crates/rginx-http/src/cache/fill/persistence.rs index a9b068d2..10184014 100644 --- a/crates/rginx-http/src/cache/fill/persistence.rs +++ b/crates/rginx-http/src/cache/fill/persistence.rs @@ -24,10 +24,8 @@ pub(super) fn next_shared_fill_nonce(now: u64) -> String { fn temp_json_path(path: &Path) -> PathBuf { let counter = SHARED_FILL_STATE_TMP_COUNTER.fetch_add(1, Ordering::Relaxed); - let mut file_name = path - .file_name() - .map(|name| name.to_os_string()) - .unwrap_or_else(|| "shared-fill-state".into()); + let mut file_name = + path.file_name().map_or_else(|| "shared-fill-state".into(), std::ffi::OsStr::to_os_string); file_name.push(format!(".tmp-{}-{counter}", std::process::id())); path.with_file_name(file_name) } diff --git a/crates/rginx-http/src/cache/fill/shared.rs b/crates/rginx-http/src/cache/fill/shared.rs index 46abbc92..fa40437c 100644 --- a/crates/rginx-http/src/cache/fill/shared.rs +++ b/crates/rginx-http/src/cache/fill/shared.rs @@ -1,3 +1,5 @@ +mod access; + use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; use std::time::SystemTime; @@ -11,7 +13,6 @@ use super::super::shared::{ shared_fill_lock_path, shared_fill_state_path, }; use super::persistence::{atomic_write_json, next_shared_fill_nonce}; -mod access; use self::access::ExternalFillStateSource; pub(in crate::cache) use self::access::{ @@ -41,11 +42,11 @@ enum SharedFillStateBackend { #[derive(Clone)] pub(in crate::cache) struct ExternalCacheFillReadState { - pub(super) status: StatusCode, - pub(super) headers: HeaderMap, - pub(super) body_tmp_path: PathBuf, pub(super) body_path: PathBuf, + pub(super) body_tmp_path: PathBuf, + pub(super) headers: HeaderMap, pub(super) source: ExternalFillStateSource, + pub(super) status: StatusCode, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -56,22 +57,22 @@ struct SharedFillLockRecord { #[derive(Debug, Clone, Serialize, Deserialize)] pub(super) struct SharedFillStateRecord { + pub(super) error: Option, + pub(super) finished: bool, pub(super) nonce: String, + pub(super) response: Option, #[serde(default)] pub(super) share_fingerprint: String, - pub(super) response: Option, - pub(super) upstream_completed: bool, - pub(super) finished: bool, pub(super) trailers: Option>, - pub(super) error: Option, + pub(super) upstream_completed: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub(super) struct SharedFillResponseMetadata { - status: u16, - headers: Vec, - body_tmp_path: PathBuf, body_path: PathBuf, + body_tmp_path: PathBuf, + headers: Vec, + status: u16, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -142,26 +143,10 @@ impl SharedFillExternalStateHandle { } } - pub(super) fn publish_response( - &self, - status: StatusCode, - headers: &HeaderMap, - body_tmp_path: &Path, - body_path: &Path, - ) -> std::io::Result<()> { - self.update_state(|state| { - state.response = Some(SharedFillResponseMetadata { - status: status.as_u16(), - headers: shared_headers_from_map(headers), - body_tmp_path: body_tmp_path.to_path_buf(), - body_path: body_path.to_path_buf(), - }); - }) - } - - pub(super) fn mark_upstream_complete(&self) -> std::io::Result<()> { - self.update_state(|state| { - state.upstream_completed = true; + pub(super) fn fail(&self, error: impl std::fmt::Display) -> std::io::Result<()> { + let error = error.to_string(); + self.update_state(move |state| { + state.error = Some(error.clone()); }) } @@ -173,16 +158,9 @@ impl SharedFillExternalStateHandle { }) } - pub(super) fn fail(&self, error: impl std::fmt::Display) -> std::io::Result<()> { - let error = error.to_string(); - self.update_state(move |state| { - state.error = Some(error.clone()); - }) - } - pub(super) fn heartbeat(&self) -> std::io::Result<()> { run_blocking(|| { - let state = self.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let state = self.state.lock().unwrap_or_else(std::sync::PoisonError::into_inner); match &self.backend { SharedFillStateBackend::File { lock_path, .. } => persist_shared_fill_lock_record( lock_path, @@ -201,9 +179,15 @@ impl SharedFillExternalStateHandle { }) } + pub(super) fn mark_upstream_complete(&self) -> std::io::Result<()> { + self.update_state(|state| { + state.upstream_completed = true; + }) + } + fn persist_lock_and_state(&self, now: u64) -> std::io::Result<()> { run_blocking(|| { - let state = self.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let state = self.state.lock().unwrap_or_else(std::sync::PoisonError::into_inner); match &self.backend { SharedFillStateBackend::File { lock_path, state_path } => { persist_shared_fill_state_record(state_path, &state)?; @@ -225,9 +209,38 @@ impl SharedFillExternalStateHandle { }) } + pub(super) fn publish_response( + &self, + status: StatusCode, + headers: &HeaderMap, + body_tmp_path: &Path, + body_path: &Path, + ) -> std::io::Result<()> { + self.update_state(|state| { + state.response = Some(SharedFillResponseMetadata { + status: status.as_u16(), + headers: shared_headers_from_map(headers), + body_tmp_path: body_tmp_path.to_path_buf(), + body_path: body_path.to_path_buf(), + }); + }) + } + + pub(in crate::cache) fn release(&self) { + let state = self.state.lock().unwrap_or_else(std::sync::PoisonError::into_inner); + match &self.backend { + SharedFillStateBackend::File { lock_path, .. } => { + let _ = std::fs::remove_file(lock_path); + } + SharedFillStateBackend::SharedMemory { store, key } => { + let _ = store.release_fill_lock(key, &state.nonce); + } + } + } + fn update_state(&self, update: impl FnOnce(&mut SharedFillStateRecord)) -> std::io::Result<()> { run_blocking(|| { - let mut state = self.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let mut state = self.state.lock().unwrap_or_else(std::sync::PoisonError::into_inner); update(&mut state); match &self.backend { SharedFillStateBackend::File { lock_path, state_path } => { @@ -249,16 +262,4 @@ impl SharedFillExternalStateHandle { } }) } - - pub(in crate::cache) fn release(&self) { - let state = self.state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - match &self.backend { - SharedFillStateBackend::File { lock_path, .. } => { - let _ = std::fs::remove_file(lock_path); - } - SharedFillStateBackend::SharedMemory { store, key } => { - let _ = store.release_fill_lock(key, &state.nonce); - } - } - } } diff --git a/crates/rginx-http/src/cache/fill/shared/access.rs b/crates/rginx-http/src/cache/fill/shared/access.rs index b103ccd5..a8ca778d 100644 --- a/crates/rginx-http/src/cache/fill/shared/access.rs +++ b/crates/rginx-http/src/cache/fill/shared/access.rs @@ -3,7 +3,11 @@ use std::sync::Arc; use http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; -use super::*; +use super::{ + ExternalCacheFillReadState, SharedFillExternalStateHandle, SharedFillHeader, + SharedFillLockRecord, SharedFillLockStatus, SharedFillStateRecord, SharedIndexStore, + atomic_write_json, shared_fill_lock_path, shared_fill_state_path, +}; #[derive(Clone)] pub(in crate::cache::fill) enum ExternalFillStateSource { diff --git a/crates/rginx-http/src/cache/index.rs b/crates/rginx-http/src/cache/index.rs index 2ecadc66..f7f15c54 100644 --- a/crates/rginx-http/src/cache/index.rs +++ b/crates/rginx-http/src/cache/index.rs @@ -5,6 +5,38 @@ use super::{ }; impl CacheIndex { + pub(super) fn add_invalidation_rule(&mut self, rule: CacheInvalidationRule) { + if self.invalidations.contains(&rule) { + return; + } + self.invalidation_index.record_rule(&rule); + self.invalidations.push(rule); + } + + pub(super) fn clear_invalidation_rules(&mut self) { + self.invalidations.clear(); + self.invalidation_index.clear(); + } + + fn decrement_hash_ref(&mut self, hash: &str) { + let Some(count) = self.hash_ref_counts.get_mut(hash) else { + return; + }; + *count = count.saturating_sub(1); + if *count == 0 { + self.hash_ref_counts.remove(hash); + } + } + + pub(super) fn hash_is_referenced(&self, hash: &str) -> bool { + self.hash_ref_counts.get(hash).copied().unwrap_or_default() > 0 + } + + fn increment_hash_ref(&mut self, hash: &str) { + let count = self.hash_ref_counts.entry(hash.to_string()).or_insert(0); + *count = count.saturating_add(1); + } + pub(super) fn insert_entry( &mut self, key: String, @@ -20,15 +52,14 @@ impl CacheIndex { existing } - pub(super) fn remove_entry(&mut self, key: &str) -> Option { - let removed = self.entries.remove(key)?; - self.decrement_hash_ref(&removed.hash); - self.unschedule_entry_access(key); - Some(removed) + pub(super) fn oldest_scheduled_access_unix_ms(&self) -> Option { + self.access_schedule.first().map(|entry| entry.last_access_unix_ms) } - pub(super) fn hash_is_referenced(&self, hash: &str) -> bool { - self.hash_ref_counts.get(hash).copied().unwrap_or_default() > 0 + pub(super) fn pop_oldest_scheduled_entry(&mut self) -> Option<(String, u64)> { + let scheduled = self.access_schedule.pop_first()?; + self.access_ticket_by_key.remove(&scheduled.key); + Some((scheduled.key, scheduled.last_access_unix_ms)) } #[cfg(test)] @@ -46,33 +77,11 @@ impl CacheIndex { } } - pub(super) fn reschedule_entry_access(&mut self, key: &str, last_access_unix_ms: u64) { - if self.entries.contains_key(key) { - self.schedule_entry_access(key, last_access_unix_ms); - } - } - - pub(super) fn oldest_scheduled_access_unix_ms(&self) -> Option { - self.access_schedule.first().map(|entry| entry.last_access_unix_ms) - } - - pub(super) fn pop_oldest_scheduled_entry(&mut self) -> Option<(String, u64)> { - let scheduled = self.access_schedule.pop_first()?; - self.access_ticket_by_key.remove(&scheduled.key); - Some((scheduled.key, scheduled.last_access_unix_ms)) - } - - pub(super) fn add_invalidation_rule(&mut self, rule: CacheInvalidationRule) { - if self.invalidations.contains(&rule) { - return; - } - self.invalidation_index.record_rule(&rule); - self.invalidations.push(rule); - } - - pub(super) fn clear_invalidation_rules(&mut self) { - self.invalidations.clear(); - self.invalidation_index.clear(); + pub(super) fn remove_entry(&mut self, key: &str) -> Option { + let removed = self.entries.remove(key)?; + self.decrement_hash_ref(&removed.hash); + self.unschedule_entry_access(key); + Some(removed) } pub(super) fn replace_invalidation_rules(&mut self, rules: Vec) { @@ -80,17 +89,9 @@ impl CacheIndex { self.invalidation_index.rebuild(&self.invalidations); } - fn increment_hash_ref(&mut self, hash: &str) { - *self.hash_ref_counts.entry(hash.to_string()).or_insert(0) += 1; - } - - fn decrement_hash_ref(&mut self, hash: &str) { - let Some(count) = self.hash_ref_counts.get_mut(hash) else { - return; - }; - *count = count.saturating_sub(1); - if *count == 0 { - self.hash_ref_counts.remove(hash); + pub(super) fn reschedule_entry_access(&mut self, key: &str, last_access_unix_ms: u64) { + if self.entries.contains_key(key) { + self.schedule_entry_access(key, last_access_unix_ms); } } @@ -127,45 +128,6 @@ impl CacheInvalidationIndex { self.latest_tag_created_at_unix_ms.clear(); } - fn rebuild(&mut self, rules: &[CacheInvalidationRule]) { - self.clear(); - for rule in rules { - self.record_rule(rule); - } - } - - fn record_rule(&mut self, rule: &CacheInvalidationRule) { - match &rule.selector { - CacheInvalidationSelector::All => { - Self::record_latest( - &mut self.latest_all_created_at_unix_ms, - rule.created_at_unix_ms, - ); - } - CacheInvalidationSelector::Exact(key) => { - Self::record_named( - &mut self.latest_exact_created_at_unix_ms, - key, - rule.created_at_unix_ms, - ); - } - CacheInvalidationSelector::Prefix(prefix) => { - Self::record_named( - &mut self.latest_prefix_created_at_unix_ms, - prefix, - rule.created_at_unix_ms, - ); - } - CacheInvalidationSelector::Tag(tag) => { - Self::record_named( - &mut self.latest_tag_created_at_unix_ms, - tag, - rule.created_at_unix_ms, - ); - } - } - } - pub(super) fn latest_matching_created_at_unix_ms( &self, key: &str, @@ -198,6 +160,13 @@ impl CacheInvalidationIndex { latest } + fn rebuild(&mut self, rules: &[CacheInvalidationRule]) { + self.clear(); + for rule in rules { + self.record_rule(rule); + } + } + fn record_latest(slot: &mut Option, created_at_unix_ms: u64) { *slot = Some(slot.map_or(created_at_unix_ms, |current| current.max(created_at_unix_ms))); } @@ -212,15 +181,38 @@ impl CacheInvalidationIndex { .and_modify(|current| *current = (*current).max(created_at_unix_ms)) .or_insert(created_at_unix_ms); } -} -fn utf8_prefixes(value: &str) -> impl Iterator { - value - .char_indices() - .map(|(index, _)| index) - .skip(1) - .chain(std::iter::once(value.len())) - .map(move |end| &value[..end]) + fn record_rule(&mut self, rule: &CacheInvalidationRule) { + match &rule.selector { + CacheInvalidationSelector::All => { + Self::record_latest( + &mut self.latest_all_created_at_unix_ms, + rule.created_at_unix_ms, + ); + } + CacheInvalidationSelector::Exact(key) => { + Self::record_named( + &mut self.latest_exact_created_at_unix_ms, + key, + rule.created_at_unix_ms, + ); + } + CacheInvalidationSelector::Prefix(prefix) => { + Self::record_named( + &mut self.latest_prefix_created_at_unix_ms, + prefix, + rule.created_at_unix_ms, + ); + } + CacheInvalidationSelector::Tag(tag) => { + Self::record_named( + &mut self.latest_tag_created_at_unix_ms, + tag, + rule.created_at_unix_ms, + ); + } + } + } } impl CacheIndexEntry { @@ -246,3 +238,12 @@ impl CacheIndexEntry { && self.must_revalidate == other.must_revalidate } } + +fn utf8_prefixes(value: &str) -> impl Iterator { + value + .char_indices() + .map(|(index, _)| index) + .skip(1) + .chain(std::iter::once(value.len())) + .map(move |end| &value[..end]) +} diff --git a/crates/rginx-http/src/cache/io.rs b/crates/rginx-http/src/cache/io.rs index f8d76c5b..26a2aa2d 100644 --- a/crates/rginx-http/src/cache/io.rs +++ b/crates/rginx-http/src/cache/io.rs @@ -31,6 +31,10 @@ impl CacheIoLockPool { CacheIoReadGuard { _guard: self.stripes[stripe].clone().read_owned().await } } + fn stripe(&self, hash: &str) -> usize { + cache_io_lock_stripe_with_len(hash, self.stripes.len()) + } + async fn write(&self, hash: &str) -> CacheIoWriteGuards { self.write_hashes([hash]).await } @@ -46,15 +50,6 @@ impl CacheIoLockPool { } CacheIoWriteGuards { _guards: guards } } - - fn stripe(&self, hash: &str) -> usize { - cache_io_lock_stripe_with_len(hash, self.stripes.len()) - } -} - -#[cfg(test)] -pub(super) fn cache_io_lock_stripe(hash: &str) -> usize { - cache_io_lock_stripe_with_len(hash, CACHE_IO_LOCK_STRIPES) } impl CacheZoneRuntime { @@ -67,8 +62,13 @@ impl CacheZoneRuntime { } } +#[cfg(test)] +pub(super) fn cache_io_lock_stripe(hash: &str) -> usize { + cache_io_lock_stripe_with_len(hash, CACHE_IO_LOCK_STRIPES) +} + fn cache_io_lock_stripe_with_len(hash: &str, stripe_len: usize) -> usize { let mut hasher = std::hash::DefaultHasher::new(); hash.hash(&mut hasher); - (hasher.finish() as usize) % stripe_len + (hasher.finish() as usize).rem_euclid(stripe_len) } diff --git a/crates/rginx-http/src/cache/load.rs b/crates/rginx-http/src/cache/load.rs index 06bdea96..7b1d7c16 100644 --- a/crates/rginx-http/src/cache/load.rs +++ b/crates/rginx-http/src/cache/load.rs @@ -11,32 +11,32 @@ use super::{CacheIndex, CacheIndexEntry, CacheIndexEntryKind, CachedVaryHeaderVa #[derive(Debug, Deserialize)] struct ScannedCacheMetadata { - #[serde(default)] - key: String, #[serde(default)] base_key: String, + body_size_bytes: usize, + expires_at_unix_ms: u64, #[serde(default)] - vary: Vec, + grace_until_unix_ms: Option, #[serde(default)] - tags: Vec, + keep_until_unix_ms: Option, #[serde(default)] - stored_at_unix_ms: u64, - expires_at_unix_ms: u64, + key: String, #[serde(default)] kind: CacheIndexEntryKind, #[serde(default)] - grace_until_unix_ms: Option, + must_revalidate: bool, #[serde(default)] - keep_until_unix_ms: Option, + requires_revalidation: Option, #[serde(default)] stale_if_error_until_unix_ms: Option, #[serde(default)] stale_while_revalidate_until_unix_ms: Option, #[serde(default)] - requires_revalidation: Option, + stored_at_unix_ms: u64, #[serde(default)] - must_revalidate: bool, - body_size_bytes: usize, + tags: Vec, + #[serde(default)] + vary: Vec, } #[derive(Debug, Deserialize)] @@ -46,6 +46,30 @@ struct ScannedVaryHeader { value: Option, } +#[derive(Default)] +struct LoaderState { + processed_entries: usize, +} + +impl LoaderState { + fn maybe_sleep(&mut self, zone: &CacheZone) { + self.processed_entries = self.processed_entries.saturating_add(1); + if zone.loader_batch_entries == 0 + || zone.loader_sleep.is_zero() + || !self.processed_entries.is_multiple_of(zone.loader_batch_entries) + { + return; + } + if let Ok(handle) = tokio::runtime::Handle::try_current() + && handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread + { + tokio::task::block_in_place(|| thread::sleep(zone.loader_sleep)); + return; + } + thread::sleep(zone.loader_sleep); + } +} + pub(super) fn load_index_from_disk(zone: &CacheZone) -> io::Result { let mut index = CacheIndex::default(); if !zone.path.exists() { @@ -242,30 +266,6 @@ fn remove_cache_files(zone: &CacheZone, hash: &str) { let _ = fs::remove_file(paths.body); } -#[derive(Default)] -struct LoaderState { - processed_entries: usize, -} - -impl LoaderState { - fn maybe_sleep(&mut self, zone: &CacheZone) { - self.processed_entries = self.processed_entries.saturating_add(1); - if zone.loader_batch_entries == 0 - || zone.loader_sleep.is_zero() - || !self.processed_entries.is_multiple_of(zone.loader_batch_entries) - { - return; - } - if let Ok(handle) = tokio::runtime::Handle::try_current() - && handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::MultiThread - { - tokio::task::block_in_place(|| thread::sleep(zone.loader_sleep)); - return; - } - thread::sleep(zone.loader_sleep); - } -} - fn add_variant_key( variants: &mut std::collections::HashMap>, base_key: String, diff --git a/crates/rginx-http/src/cache/lookup.rs b/crates/rginx-http/src/cache/lookup.rs index da83a1ff..6e3a4bdc 100644 --- a/crates/rginx-http/src/cache/lookup.rs +++ b/crates/rginx-http/src/cache/lookup.rs @@ -1,12 +1,66 @@ +mod helpers; + use super::invalidation::entry_is_logically_invalid; use super::runtime::{CacheEntryLifecyclePhase, lifecycle_phase, stale_updating_allowed}; -use super::*; - -mod helpers; +use super::{ + Arc, CacheIndexEntry, CacheManager, CacheRequest, CacheStatus, CacheZoneRuntime, + FillLockDecision, HttpResponse, LookupDecision, LookupWait, PreparedCacheResponseHead, + RouteCachePolicy, load_cached_response_head, read_cached_response_for_request, read_index, + remove_cache_entry_if_matches, +}; use helpers::{fill_share_fingerprint, matching_variant_key}; impl CacheManager { + pub(super) async fn load_lookup_cached_response( + &self, + zone: &Arc, + key: &str, + entry: &CacheIndexEntry, + request: &CacheRequest, + policy: &RouteCachePolicy, + read_cached_body: bool, + ) -> Option<(Arc, HttpResponse)> { + match read_cached_response_for_request(zone, key, entry, request, policy, read_cached_body) + .await + { + Ok(response) => Some(response), + Err(error) => { + tracing::warn!( + zone = %zone.config.name, + key = %key, + key_hash = %entry.hash, + %error, + "failed to read cached response; removing entry" + ); + remove_cache_entry_if_matches(zone, key, entry).await; + None + } + } + } + + pub(super) async fn load_lookup_response_head( + &self, + zone: &Arc, + key: &str, + entry: &CacheIndexEntry, + ) -> Option> { + match load_cached_response_head(zone, key, entry).await { + Ok(response_head) => Some(response_head), + Err(error) => { + tracing::warn!( + zone = %zone.config.name, + key = %key, + key_hash = %entry.hash, + %error, + "failed to load cached response head; removing entry" + ); + remove_cache_entry_if_matches(zone, key, entry).await; + None + } + } + } + pub(super) fn lookup_decision( &self, zone: &Arc, @@ -153,53 +207,4 @@ impl CacheManager { }, } } - - pub(super) async fn load_lookup_response_head( - &self, - zone: &Arc, - key: &str, - entry: &CacheIndexEntry, - ) -> Option> { - match load_cached_response_head(zone, key, entry).await { - Ok(response_head) => Some(response_head), - Err(error) => { - tracing::warn!( - zone = %zone.config.name, - key = %key, - key_hash = %entry.hash, - %error, - "failed to load cached response head; removing entry" - ); - remove_cache_entry_if_matches(zone, key, entry).await; - None - } - } - } - - pub(super) async fn load_lookup_cached_response( - &self, - zone: &Arc, - key: &str, - entry: &CacheIndexEntry, - request: &CacheRequest, - policy: &RouteCachePolicy, - read_cached_body: bool, - ) -> Option<(Arc, HttpResponse)> { - match read_cached_response_for_request(zone, key, entry, request, policy, read_cached_body) - .await - { - Ok(response) => Some(response), - Err(error) => { - tracing::warn!( - zone = %zone.config.name, - key = %key, - key_hash = %entry.hash, - %error, - "failed to read cached response; removing entry" - ); - remove_cache_entry_if_matches(zone, key, entry).await; - None - } - } - } } diff --git a/crates/rginx-http/src/cache/manager.rs b/crates/rginx-http/src/cache/manager.rs index 4303d74f..a094991a 100644 --- a/crates/rginx-http/src/cache/manager.rs +++ b/crates/rginx-http/src/cache/manager.rs @@ -1,10 +1,21 @@ -use super::*; - mod bootstrap; mod control; mod lookup_support; mod response; +use super::{ + Arc, AsyncMutex, AtomicU64, CacheChangeNotifier, CacheFillGuard, CacheIndexEntry, + CacheInvalidationResult, CacheInvalidationSelector, CacheIoLockPool, CacheLookup, CacheManager, + CachePurgeResult, CacheRequest, CacheStatus, CacheStoreContext, CacheStoreError, + CacheZoneRuntime, CacheZoneRuntimeSnapshot, CacheZoneStats, ConfigSnapshot, Error, HashMap, + HttpResponse, LookupDecision, LookupWait, Method, Mutex, PreparedCacheResponseHead, + PurgeSelector, Result, RouteCachePolicy, RwLock, StatusCode, SystemTime, + bootstrap_shared_index, cache_request_bypass, cleanup_inactive_entries_in_zone, + purge_zone_entries, record_zone_shared_entry_access, refresh_not_modified_response, + remove_cache_entry_if_matches, render_cache_key, request_requires_revalidation, store_response, + sync_zone_shared_index_if_needed, unix_time_ms, with_cache_status, +}; + use lookup_support::{ CacheLookupRequest, CacheStoreContextParts, CachedLookupResponse, record_lookup_miss_status, wait_for_cache_fill, @@ -144,12 +155,13 @@ impl CacheManager { cache_status, } => { let cached_response_head = if let Some(entry) = &cached_entry { - match self.load_lookup_response_head(&zone, &key, entry).await { - Some(response_head) => Some(response_head), - None => { - drop(fill_guard); - continue; - } + if let Some(response_head) = + self.load_lookup_response_head(&zone, &key, entry).await + { + Some(response_head) + } else { + drop(fill_guard); + continue; } } else { None @@ -224,6 +236,6 @@ impl CacheManager { ) .await?; let response = self.finalize_lookup_cached_response(zone, key, status, response).await; - Some(CachedLookupResponse { response_head, response }) + Some(CachedLookupResponse { response, response_head }) } } diff --git a/crates/rginx-http/src/cache/manager/bootstrap.rs b/crates/rginx-http/src/cache/manager/bootstrap.rs index 617c442f..f980a218 100644 --- a/crates/rginx-http/src/cache/manager/bootstrap.rs +++ b/crates/rginx-http/src/cache/manager/bootstrap.rs @@ -1,4 +1,8 @@ -use super::*; +use super::{ + Arc, AsyncMutex, AtomicU64, CacheChangeNotifier, CacheIoLockPool, CacheManager, + CacheZoneRuntime, CacheZoneStats, ConfigSnapshot, Error, HashMap, Mutex, Result, RwLock, + bootstrap_shared_index, +}; impl CacheManager { pub(crate) fn from_config_with_notifier( diff --git a/crates/rginx-http/src/cache/manager/control.rs b/crates/rginx-http/src/cache/manager/control.rs index 57a53b00..a23cc9a9 100644 --- a/crates/rginx-http/src/cache/manager/control.rs +++ b/crates/rginx-http/src/cache/manager/control.rs @@ -1,21 +1,12 @@ -use super::*; +use super::{ + CacheInvalidationResult, CacheInvalidationSelector, CacheManager, CachePurgeResult, + CacheZoneRuntimeSnapshot, PurgeSelector, cleanup_inactive_entries_in_zone, purge_zone_entries, + sync_zone_shared_index_if_needed, +}; use crate::cache::invalidation::normalize_cache_tag; use crate::cache::store::{clear_zone_invalidations, invalidate_zone_entries}; impl CacheManager { - pub(crate) fn snapshot(&self) -> Vec { - let mut snapshots = self.zones.values().map(|zone| zone.snapshot()).collect::>(); - snapshots.sort_by(|left, right| left.zone_name.cmp(&right.zone_name)); - snapshots - } - - pub(crate) async fn snapshot_with_shared_sync(&self) -> Vec { - for zone in self.zones.values() { - sync_zone_shared_index_if_needed(zone).await; - } - self.snapshot() - } - pub(crate) async fn cleanup_inactive_entries(&self) { for zone in self.zones.values() { sync_zone_shared_index_if_needed(zone).await; @@ -23,50 +14,52 @@ impl CacheManager { } } - pub(crate) async fn purge_zone( + pub(crate) async fn clear_invalidations( &self, zone_name: &str, - ) -> std::result::Result { + ) -> std::result::Result { let zone = self .zones .get(zone_name) .cloned() .ok_or_else(|| format!("unknown cache zone `{zone_name}`"))?; sync_zone_shared_index_if_needed(&zone).await; - Ok(purge_zone_entries(zone, PurgeSelector::All).await) + Ok(clear_zone_invalidations(zone).await) } - pub(crate) async fn purge_key( + pub(crate) async fn invalidate_key( &self, zone_name: &str, key: &str, - ) -> std::result::Result { + ) -> std::result::Result { let zone = self .zones .get(zone_name) .cloned() .ok_or_else(|| format!("unknown cache zone `{zone_name}`"))?; sync_zone_shared_index_if_needed(&zone).await; - Ok(purge_zone_entries(zone, PurgeSelector::Exact(key.to_string())).await) + Ok(invalidate_zone_entries(zone, CacheInvalidationSelector::Exact(key.to_string())).await) } - pub(crate) async fn purge_prefix( + pub(crate) async fn invalidate_prefix( &self, zone_name: &str, prefix: &str, - ) -> std::result::Result { + ) -> std::result::Result { let zone = self .zones .get(zone_name) .cloned() .ok_or_else(|| format!("unknown cache zone `{zone_name}`"))?; sync_zone_shared_index_if_needed(&zone).await; - Ok(purge_zone_entries(zone, PurgeSelector::Prefix(prefix.to_string())).await) + Ok(invalidate_zone_entries(zone, CacheInvalidationSelector::Prefix(prefix.to_string())) + .await) } - pub(crate) async fn invalidate_zone( + pub(crate) async fn invalidate_tag( &self, zone_name: &str, + tag: &str, ) -> std::result::Result { let zone = self .zones @@ -74,13 +67,15 @@ impl CacheManager { .cloned() .ok_or_else(|| format!("unknown cache zone `{zone_name}`"))?; sync_zone_shared_index_if_needed(&zone).await; - Ok(invalidate_zone_entries(zone, CacheInvalidationSelector::All).await) + let Some(tag) = normalize_cache_tag(tag) else { + return Err("cache invalidation tag must not be empty".to_string()); + }; + Ok(invalidate_zone_entries(zone, CacheInvalidationSelector::Tag(tag)).await) } - pub(crate) async fn invalidate_key( + pub(crate) async fn invalidate_zone( &self, zone_name: &str, - key: &str, ) -> std::result::Result { let zone = self .zones @@ -88,51 +83,60 @@ impl CacheManager { .cloned() .ok_or_else(|| format!("unknown cache zone `{zone_name}`"))?; sync_zone_shared_index_if_needed(&zone).await; - Ok(invalidate_zone_entries(zone, CacheInvalidationSelector::Exact(key.to_string())).await) + Ok(invalidate_zone_entries(zone, CacheInvalidationSelector::All).await) } - pub(crate) async fn invalidate_prefix( + pub(crate) async fn purge_key( &self, zone_name: &str, - prefix: &str, - ) -> std::result::Result { + key: &str, + ) -> std::result::Result { let zone = self .zones .get(zone_name) .cloned() .ok_or_else(|| format!("unknown cache zone `{zone_name}`"))?; sync_zone_shared_index_if_needed(&zone).await; - Ok(invalidate_zone_entries(zone, CacheInvalidationSelector::Prefix(prefix.to_string())) - .await) + Ok(purge_zone_entries(zone, PurgeSelector::Exact(key.to_string())).await) } - pub(crate) async fn invalidate_tag( + pub(crate) async fn purge_prefix( &self, zone_name: &str, - tag: &str, - ) -> std::result::Result { + prefix: &str, + ) -> std::result::Result { let zone = self .zones .get(zone_name) .cloned() .ok_or_else(|| format!("unknown cache zone `{zone_name}`"))?; sync_zone_shared_index_if_needed(&zone).await; - let Some(tag) = normalize_cache_tag(tag) else { - return Err("cache invalidation tag must not be empty".to_string()); - }; - Ok(invalidate_zone_entries(zone, CacheInvalidationSelector::Tag(tag)).await) + Ok(purge_zone_entries(zone, PurgeSelector::Prefix(prefix.to_string())).await) } - pub(crate) async fn clear_invalidations( + pub(crate) async fn purge_zone( &self, zone_name: &str, - ) -> std::result::Result { + ) -> std::result::Result { let zone = self .zones .get(zone_name) .cloned() .ok_or_else(|| format!("unknown cache zone `{zone_name}`"))?; sync_zone_shared_index_if_needed(&zone).await; - Ok(clear_zone_invalidations(zone).await) + Ok(purge_zone_entries(zone, PurgeSelector::All).await) + } + + pub(crate) fn snapshot(&self) -> Vec { + let mut snapshots = self.zones.values().map(|zone| zone.snapshot()).collect::>(); + snapshots.sort_by(|left, right| left.zone_name.cmp(&right.zone_name)); + snapshots + } + + pub(crate) async fn snapshot_with_shared_sync(&self) -> Vec { + for zone in self.zones.values() { + sync_zone_shared_index_if_needed(zone).await; + } + self.snapshot() } } diff --git a/crates/rginx-http/src/cache/manager/lookup_support.rs b/crates/rginx-http/src/cache/manager/lookup_support.rs index b4825b3c..ca763da3 100644 --- a/crates/rginx-http/src/cache/manager/lookup_support.rs +++ b/crates/rginx-http/src/cache/manager/lookup_support.rs @@ -1,26 +1,30 @@ -use super::*; +use super::{ + Arc, CacheFillGuard, CacheIndexEntry, CacheRequest, CacheStatus, CacheStoreContext, + CacheZoneRuntime, HttpResponse, LookupWait, Method, PreparedCacheResponseHead, + RouteCachePolicy, render_cache_key, request_requires_revalidation, +}; pub(super) struct CacheLookupRequest { - pub(super) request: CacheRequest, pub(super) base_key: String, - pub(super) request_forces_revalidation: bool, pub(super) read_cached_body: bool, + pub(super) request: CacheRequest, + pub(super) request_forces_revalidation: bool, } pub(super) struct CachedLookupResponse { - pub(super) response_head: Arc, pub(super) response: HttpResponse, + pub(super) response_head: Arc, } pub(super) struct CacheStoreContextParts<'a> { - pub(super) zone: Arc, - pub(super) key: String, pub(super) base_key: String, - pub(super) policy: &'a RouteCachePolicy, - pub(super) fill_guard: Option, - pub(super) cached_entry: Option, pub(super) cache_status: CacheStatus, + pub(super) cached_entry: Option, + pub(super) fill_guard: Option, + pub(super) key: String, + pub(super) policy: &'a RouteCachePolicy, pub(super) store_response: bool, + pub(super) zone: Arc, } impl CacheLookupRequest { @@ -38,18 +42,7 @@ impl CacheLookupRequest { ); let request_forces_revalidation = request_requires_revalidation(&request.headers); let read_cached_body = request.method != Method::HEAD; - Self { request, base_key, request_forces_revalidation, read_cached_body } - } - - pub(super) fn with_method(&self, method: Method) -> Self { - let request = self.request.with_method(method); - let read_cached_body = request.method != Method::HEAD; - Self { - request, - base_key: self.base_key.clone(), - request_forces_revalidation: self.request_forces_revalidation, - read_cached_body, - } + Self { base_key, read_cached_body, request, request_forces_revalidation } } pub(super) fn should_store_response(&self, policy: &RouteCachePolicy) -> bool { @@ -74,6 +67,17 @@ impl CacheLookupRequest { read_cached_body: self.read_cached_body, } } + + pub(super) fn with_method(&self, method: Method) -> Self { + let request = self.request.with_method(method); + let read_cached_body = request.method != Method::HEAD; + Self { + request, + base_key: self.base_key.clone(), + request_forces_revalidation: self.request_forces_revalidation, + read_cached_body, + } + } } impl CacheStoreContext { diff --git a/crates/rginx-http/src/cache/manager/response.rs b/crates/rginx-http/src/cache/manager/response.rs index 262aa886..fba48ab2 100644 --- a/crates/rginx-http/src/cache/manager/response.rs +++ b/crates/rginx-http/src/cache/manager/response.rs @@ -1,6 +1,18 @@ -use super::*; +use super::{ + Arc, CacheManager, CacheStatus, CacheStoreContext, CacheStoreError, CacheZoneRuntime, + HttpResponse, StatusCode, SystemTime, record_zone_shared_entry_access, + refresh_not_modified_response, store_response, sync_zone_shared_index_if_needed, unix_time_ms, + with_cache_status, +}; impl CacheManager { + pub(crate) async fn complete_not_modified( + &self, + context: CacheStoreContext, + response: HttpResponse, + ) -> std::result::Result { + refresh_not_modified_response(context, response).await + } pub(super) async fn finalize_lookup_cached_response( &self, zone: &Arc, @@ -46,12 +58,4 @@ impl CacheManager { with_cache_status(response, status) } - - pub(crate) async fn complete_not_modified( - &self, - context: CacheStoreContext, - response: HttpResponse, - ) -> std::result::Result { - refresh_not_modified_response(context, response).await - } } diff --git a/crates/rginx-http/src/cache/mod.rs b/crates/rginx-http/src/cache/mod.rs index 7f103204..0d277ff7 100644 --- a/crates/rginx-http/src/cache/mod.rs +++ b/crates/rginx-http/src/cache/mod.rs @@ -1,20 +1,5 @@ //! Route-level HTTP response cache for proxied responses. -use std::collections::{BTreeSet, HashMap}; -use std::path::PathBuf; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, RwLock, Weak}; -use std::time::SystemTime; - -use http::header::{HeaderMap, HeaderValue, IF_MODIFIED_SINCE, IF_NONE_MATCH}; -use http::{Method, Request, StatusCode, Uri}; -use rginx_core::{CacheZone, ConfigSnapshot, Error, Result, RouteCachePolicy}; -use serde::{Deserialize, Serialize}; -use tokio::sync::futures::OwnedNotified; -use tokio::sync::{Mutex as AsyncMutex, Notify}; - -use crate::handler::{HttpBody, HttpResponse}; - mod entry; mod fill; mod index; @@ -31,14 +16,33 @@ mod state; mod store; mod vary; +#[cfg(test)] +#[path = "invalidation/tests.rs"] +mod invalidation_tests; + +#[cfg(test)] +mod tests; + +use std::collections::{BTreeSet, HashMap}; +use std::path::PathBuf; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, RwLock, Weak}; +use std::time::SystemTime; + +use http::header::{HeaderMap, HeaderValue, IF_MODIFIED_SINCE, IF_NONE_MATCH}; +use http::{Method, Request, StatusCode, Uri}; +use rginx_core::{CacheZone, ConfigSnapshot, Error, Result, RouteCachePolicy}; +use serde::{Deserialize, Serialize}; +use tokio::sync::futures::OwnedNotified; +use tokio::sync::{Mutex as AsyncMutex, Notify}; + +use crate::handler::{HttpBody, HttpResponse}; + use entry::{ CacheMetadata, cache_paths_for_zone, load_cached_response_head, read_cached_response_for_request, serve_prepared_cached_response_for_request, unix_time_ms, }; #[cfg(test)] -#[path = "invalidation/tests.rs"] -mod invalidation_tests; -#[cfg(test)] use entry::{ cache_key_hash, cache_metadata, cache_paths, cache_variant_key, prepare_cached_response_head, read_cache_metadata, write_cache_entry, @@ -89,107 +93,107 @@ pub(crate) struct CacheManager { } pub(crate) struct CacheStoreContext { - zone: Arc, - policy: RouteCachePolicy, - request: CacheRequest, + _fill_guard: Option, base_key: String, - key: String, cache_status: CacheStatus, - store_response: bool, - _fill_guard: Option, cached_entry: Option, cached_response_head: Option>, - revalidating: bool, - request_forces_revalidation: bool, + key: String, + policy: RouteCachePolicy, read_cached_body: bool, + request: CacheRequest, + request_forces_revalidation: bool, + revalidating: bool, + store_response: bool, + zone: Arc, } #[derive(Clone)] pub(crate) struct CacheRequest { + headers: HeaderMap, method: Method, uri: Uri, - headers: HeaderMap, } pub(crate) enum CacheLookup { + Bypass(CacheStatus), Hit(HttpResponse), Miss(Box), Updating(HttpResponse, Box), - Bypass(CacheStatus), } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum CacheStatus { - Hit, - Miss, Bypass, Expired, + Hit, + Miss, + Revalidated, Stale, Updating, - Revalidated, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CacheZoneRuntimeSnapshot { - pub zone_name: String, - pub path: PathBuf, - pub max_size_bytes: Option, - pub inactive_secs: u64, + pub active_invalidation_rules: usize, + pub bypass_total: u64, + pub current_size_bytes: usize, pub default_ttl_secs: u64, - pub max_entry_bytes: usize, pub entry_count: usize, - pub current_size_bytes: usize, + pub eviction_total: u64, + pub expired_total: u64, pub hit_total: u64, + pub inactive_cleanup_total: u64, + pub inactive_secs: u64, + pub invalidation_total: u64, + pub max_entry_bytes: usize, + pub max_size_bytes: Option, pub miss_total: u64, - pub bypass_total: u64, - pub expired_total: u64, - pub stale_total: u64, - pub updating_total: u64, - pub revalidated_total: u64, - pub write_success_total: u64, - pub write_error_total: u64, - pub eviction_total: u64, + pub path: PathBuf, pub purge_total: u64, - pub invalidation_total: u64, - pub inactive_cleanup_total: u64, - pub wait_local_total: u64, - pub wait_external_total: u64, - pub stale_serve_updating_total: u64, - pub stale_serve_error_total: u64, - pub stale_serve_timeout_total: u64, - pub stale_serve_status_total: u64, pub revalidate_not_modified_total: u64, - pub active_invalidation_rules: usize, + pub revalidated_total: u64, + pub shared_index_capacity_rejection_total: u64, + pub shared_index_current_size_bytes: u64, pub shared_index_enabled: bool, - pub shared_index_generation: u64, - pub shared_index_shm_capacity_bytes: u64, - pub shared_index_shm_used_bytes: u64, pub shared_index_entry_count: u64, - pub shared_index_current_size_bytes: u64, + pub shared_index_full_reload_total: u64, + pub shared_index_generation: u64, + pub shared_index_lock_contention_total: u64, pub shared_index_operation_ring_capacity: u64, pub shared_index_operation_ring_used: u64, - pub shared_index_lock_contention_total: u64, - pub shared_index_full_reload_total: u64, pub shared_index_rebuild_total: u64, + pub shared_index_shm_capacity_bytes: u64, + pub shared_index_shm_used_bytes: u64, pub shared_index_stale_fill_lock_cleanup_total: u64, - pub shared_index_capacity_rejection_total: u64, + pub stale_serve_error_total: u64, + pub stale_serve_status_total: u64, + pub stale_serve_timeout_total: u64, + pub stale_serve_updating_total: u64, + pub stale_total: u64, + pub updating_total: u64, + pub wait_external_total: u64, + pub wait_local_total: u64, + pub write_error_total: u64, + pub write_success_total: u64, + pub zone_name: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CachePurgeResult { - pub zone_name: String, - pub scope: String, - pub removed_entries: usize, pub removed_bytes: usize, + pub removed_entries: usize, + pub scope: String, + pub zone_name: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct CacheInvalidationResult { - pub zone_name: String, - pub scope: String, - pub affected_entries: usize, - pub affected_bytes: usize, pub active_rules: usize, + pub affected_bytes: usize, + pub affected_entries: usize, + pub scope: String, + pub zone_name: String, } impl CacheStatus { @@ -205,6 +209,3 @@ impl CacheStatus { }) } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/cache/policy.rs b/crates/rginx-http/src/cache/policy.rs index f5e386bb..6f22500a 100644 --- a/crates/rginx-http/src/cache/policy.rs +++ b/crates/rginx-http/src/cache/policy.rs @@ -1,3 +1,5 @@ +mod directives; + use std::time::Duration; use http::StatusCode; @@ -13,8 +15,6 @@ use crate::handler::HttpResponse; use super::CacheStoreContext; use super::request::{cacheable_range_request, response_content_range_matches_request}; -mod directives; - pub(super) use directives::cache_control_contains; use directives::{ cache_control_duration, cache_control_max_age, expires_ttl, pragma_contains, @@ -23,11 +23,11 @@ use directives::{ #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(super) struct ResponseFreshness { - pub(super) ttl: Duration, + pub(super) must_revalidate: bool, + pub(super) requires_revalidation: bool, pub(super) stale_if_error: Option, pub(super) stale_while_revalidate: Option, - pub(super) requires_revalidation: bool, - pub(super) must_revalidate: bool, + pub(super) ttl: Duration, } #[derive(Debug, Clone, Copy)] @@ -37,6 +37,17 @@ pub(super) struct ResponseBodySize { upper: Option, } +impl ResponseBodySize { + pub(super) fn exact(body_size_bytes: usize) -> Self { + let body_size_bytes = body_size_bytes as u64; + Self { exact: Some(body_size_bytes), lower: body_size_bytes, upper: Some(body_size_bytes) } + } + fn from_response(response: &HttpResponse) -> Self { + let size_hint = response.body().size_hint(); + Self { exact: size_hint.exact(), lower: size_hint.lower(), upper: size_hint.upper() } + } +} + pub(super) fn response_is_storable(context: &CacheStoreContext, response: &HttpResponse) -> bool { response_is_storable_with_size( context, @@ -222,15 +233,3 @@ fn parse_content_length(headers: &HeaderMap) -> Option { .and_then(|value| value.to_str().ok()) .and_then(|value| value.parse::().ok()) } - -impl ResponseBodySize { - fn from_response(response: &HttpResponse) -> Self { - let size_hint = response.body().size_hint(); - Self { exact: size_hint.exact(), lower: size_hint.lower(), upper: size_hint.upper() } - } - - pub(super) fn exact(body_size_bytes: usize) -> Self { - let body_size_bytes = body_size_bytes as u64; - Self { exact: Some(body_size_bytes), lower: body_size_bytes, upper: Some(body_size_bytes) } - } -} diff --git a/crates/rginx-http/src/cache/request.rs b/crates/rginx-http/src/cache/request.rs index 7090ccd5..f82c7e2e 100644 --- a/crates/rginx-http/src/cache/request.rs +++ b/crates/rginx-http/src/cache/request.rs @@ -1,3 +1,5 @@ +mod render; + use http::HeaderValue; use http::header::{AUTHORIZATION, CONTENT_RANGE, CONTENT_TYPE, HeaderMap, IF_RANGE, RANGE}; use http::{Method, Uri}; @@ -9,25 +11,32 @@ use super::CacheRequest; use super::policy::cache_control_contains; use super::vary::normalized_accept_encoding; -mod render; - use render::{append_cache_key_suffix, cache_key_suffix_capacity}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(super) struct CacheableRangeRequest { - pub(super) request_start: u64, - pub(super) request_end: u64, - pub(super) cache_start: u64, pub(super) cache_end: u64, + pub(super) cache_start: u64, + pub(super) request_end: u64, + pub(super) request_start: u64, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum ParsedRangeHeader { + Invalid, SingleBounded { start: u64, end: u64 }, UnsupportedMultiRange, UnsupportedOpenEnded, UnsupportedSuffix, - Invalid, +} + +impl CacheableRangeRequest { + pub(super) fn needs_downstream_trimming(self) -> bool { + self.request_start != self.cache_start || self.request_end != self.cache_end + } + pub(super) fn upstream_header_value(self) -> String { + format!("bytes={}-{}", self.cache_start, self.cache_end) + } } pub(super) fn cache_request_bypass(request: &CacheRequest, policy: &RouteCachePolicy) -> bool { @@ -94,16 +103,16 @@ pub(super) fn render_cache_key( range, ); let mut rendered = if let Some(static_key) = policy.key.static_rendered() { - let mut rendered = String::with_capacity(static_key.len() + suffix_capacity); + let mut rendered = String::with_capacity(static_key.len().saturating_add(suffix_capacity)); rendered.push_str(static_key); rendered } else { - let request_uri = uri.path_and_query().map(|value| value.as_str()).unwrap_or("/"); + let request_uri = uri.path_and_query().map_or("/", http::uri::PathAndQuery::as_str); let host = headers .get(http::header::HOST) .and_then(|value| value.to_str().ok()) .filter(|value| !value.trim().is_empty()) - .or_else(|| uri.authority().map(|authority| authority.as_str())) + .or_else(|| uri.authority().map(http::uri::Authority::as_str)) .unwrap_or("-"); let context = CacheKeyRenderContext { scheme, @@ -113,7 +122,7 @@ pub(super) fn render_cache_key( headers, }; let mut rendered = String::with_capacity( - policy.key.estimated_rendered_capacity(&context) + suffix_capacity, + policy.key.estimated_rendered_capacity(&context).saturating_add(suffix_capacity), ); policy.key.append_rendered(&mut rendered, &context); rendered @@ -266,7 +275,11 @@ fn cacheable_range_request_for_policy( let Some(slice_size) = policy.slice_size_bytes else { return Some(request); }; - let slice_start = request.request_start / slice_size * slice_size; + let slice_start = request + .request_start + .checked_div(slice_size) + .expect("validated cache slice size is nonzero") + .saturating_mul(slice_size); let slice_end = slice_start.saturating_add(slice_size.saturating_sub(1)); (request.request_end <= slice_end).then_some(CacheableRangeRequest { cache_start: slice_start, @@ -274,13 +287,3 @@ fn cacheable_range_request_for_policy( ..request }) } - -impl CacheableRangeRequest { - pub(super) fn upstream_header_value(self) -> String { - format!("bytes={}-{}", self.cache_start, self.cache_end) - } - - pub(super) fn needs_downstream_trimming(self) -> bool { - self.request_start != self.cache_start || self.request_end != self.cache_end - } -} diff --git a/crates/rginx-http/src/cache/request/render.rs b/crates/rginx-http/src/cache/request/render.rs index e333b118..cb1dd2ba 100644 --- a/crates/rginx-http/src/cache/request/render.rs +++ b/crates/rginx-http/src/cache/request/render.rs @@ -31,10 +31,11 @@ pub(super) fn cache_key_suffix_capacity( ) -> usize { let mut capacity = 0usize; if add_cache_method_suffix { - capacity = capacity.saturating_add("|cache-method:".len() + cache_method.len()); + capacity = + capacity.saturating_add("|cache-method:".len().saturating_add(cache_method.len())); } if let Some(accept_encoding) = accept_encoding { - capacity = capacity.saturating_add("|ae:".len() + accept_encoding.len()); + capacity = capacity.saturating_add("|ae:".len().saturating_add(accept_encoding.len())); } if let Some(range) = range { capacity = capacity @@ -50,7 +51,7 @@ fn decimal_len(mut value: u64) -> usize { let mut digits = 1usize; while value >= 10 { value /= 10; - digits += 1; + digits = digits.saturating_add(1); } digits } diff --git a/crates/rginx-http/src/cache/runtime.rs b/crates/rginx-http/src/cache/runtime.rs index 489fb704..f418e4de 100644 --- a/crates/rginx-http/src/cache/runtime.rs +++ b/crates/rginx-http/src/cache/runtime.rs @@ -1,10 +1,20 @@ -use super::*; - mod context; mod fill_lock; mod support; mod zone; +use super::{ + Arc, AtomicU64, CACHE_STATUS_HEADER, CacheConditionalHeaders, CacheEntryHotState, + CacheFillGuard, CacheFillLockState, CacheFillReadState, CacheIndex, CacheIndexEntry, + CacheRequest, CacheStaleReason, CacheStatus, CacheStoreContext, CacheZoneRuntime, + CacheZoneRuntimeSnapshot, FillLockDecision, HeaderMap, HttpBody, HttpResponse, + IF_MODIFIED_SINCE, IF_NONE_MATCH, InflightFillReadState, Method, Mutex, Notify, Ordering, + PreparedCacheResponseHead, Request, SharedIndexStore, StatusCode, SystemTime, + cache_paths_for_zone, header_value, read_index, record_zone_shared_entry_access, + remove_zone_index_entry_if_matches, request, serve_prepared_cached_response_for_request, + shared, unix_time_ms, +}; + pub(in crate::cache) use support::PurgeSelector; #[cfg(test)] pub(in crate::cache) use support::stale_reuse_blocked; diff --git a/crates/rginx-http/src/cache/runtime/context.rs b/crates/rginx-http/src/cache/runtime/context.rs index 190f0343..97713d27 100644 --- a/crates/rginx-http/src/cache/runtime/context.rs +++ b/crates/rginx-http/src/cache/runtime/context.rs @@ -1,4 +1,10 @@ -use super::*; +use super::{ + CacheRequest, CacheStaleReason, CacheStatus, CacheStoreContext, HeaderMap, HttpBody, + HttpResponse, IF_MODIFIED_SINCE, IF_NONE_MATCH, Method, Request, StatusCode, SystemTime, + record_zone_shared_entry_access, remove_cache_entry_if_matches, request, + serve_prepared_cached_response_for_request, stale_for_reason_allowed, unix_time_ms, + with_cache_status, +}; impl CacheRequest { pub(crate) fn from_request(request: &Request) -> Self { @@ -9,26 +15,32 @@ impl CacheRequest { } } - pub(crate) fn with_method(&self, method: Method) -> Self { - Self { method, uri: self.uri.clone(), headers: self.headers.clone() } + pub(crate) fn request_uri(&self) -> &str { + self.uri.path_and_query().map_or("/", http::uri::PathAndQuery::as_str) } - pub(crate) fn request_uri(&self) -> &str { - self.uri.path_and_query().map(|value| value.as_str()).unwrap_or("/") + pub(crate) fn with_method(&self, method: Method) -> Self { + Self { method, uri: self.uri.clone(), headers: self.headers.clone() } } } impl CacheStoreContext { - pub(crate) fn cache_status(&self) -> CacheStatus { - self.cache_status - } - - pub(crate) fn prepares_cacheable_upstream_request(&self) -> bool { - self.store_response + pub(crate) fn apply_conditional_request_headers(&self, headers: &mut HeaderMap) { + let Some(conditional_headers) = + self.cached_response_head.as_ref().and_then(|head| head.conditional_headers.as_ref()) + else { + return; + }; + if let Some(value) = conditional_headers.if_none_match.clone() { + headers.insert(IF_NONE_MATCH, value); + } + if let Some(value) = conditional_headers.if_modified_since.clone() { + headers.insert(IF_MODIFIED_SINCE, value); + } } - pub(crate) fn upstream_request_method(&self) -> Method { - request::upstream_cache_request_method(&self.request.method, &self.policy) + pub(crate) fn apply_upstream_request_headers(&self, headers: &mut HeaderMap) { + request::apply_upstream_range_headers(&self.request.method, headers, &self.policy); } pub(crate) fn build_background_request(&self) -> Request { @@ -43,26 +55,8 @@ impl CacheStoreContext { request } - pub(crate) fn apply_upstream_request_headers(&self, headers: &mut HeaderMap) { - request::apply_upstream_range_headers(&self.request.method, headers, &self.policy); - } - - pub(crate) fn apply_conditional_request_headers(&self, headers: &mut HeaderMap) { - let Some(conditional_headers) = - self.cached_response_head.as_ref().and_then(|head| head.conditional_headers.as_ref()) - else { - return; - }; - if let Some(value) = conditional_headers.if_none_match.clone() { - headers.insert(IF_NONE_MATCH, value); - } - if let Some(value) = conditional_headers.if_modified_since.clone() { - headers.insert(IF_MODIFIED_SINCE, value); - } - } - - pub(crate) fn should_refresh_from_not_modified(&self, status: StatusCode) -> bool { - self.cached_entry.is_some() && status == StatusCode::NOT_MODIFIED + pub(crate) fn cache_status(&self) -> CacheStatus { + self.cache_status } pub(crate) fn can_serve_stale(&self, reason: CacheStaleReason) -> bool { @@ -76,6 +70,10 @@ impl CacheStoreContext { stale_for_reason_allowed(&self.policy, entry, now, self.request_forces_revalidation, reason) } + pub(crate) fn prepares_cacheable_upstream_request(&self) -> bool { + self.store_response + } + pub(crate) async fn serve_stale( &self, reason: CacheStaleReason, @@ -125,4 +123,12 @@ impl CacheStoreContext { } } } + + pub(crate) fn should_refresh_from_not_modified(&self, status: StatusCode) -> bool { + self.cached_entry.is_some() && status == StatusCode::NOT_MODIFIED + } + + pub(crate) fn upstream_request_method(&self) -> Method { + request::upstream_cache_request_method(&self.request.method, &self.policy) + } } diff --git a/crates/rginx-http/src/cache/runtime/fill_lock.rs b/crates/rginx-http/src/cache/runtime/fill_lock.rs index 3af77813..05fa3760 100644 --- a/crates/rginx-http/src/cache/runtime/fill_lock.rs +++ b/crates/rginx-http/src/cache/runtime/fill_lock.rs @@ -2,15 +2,58 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::SystemTime; -use super::*; +use super::{ + CacheFillGuard, CacheFillLockState, CacheFillReadState, CacheZoneRuntime, FillLockDecision, + InflightFillReadState, Notify, Ordering, SharedIndexStore, shared, unix_time_ms, +}; enum FileSharedFillLockState { - Missing, Fresh, + Missing, Stale, } impl CacheZoneRuntime { + pub(in crate::cache) fn attach_fill_read_state( + &self, + key: &str, + generation: u64, + state: Arc, + ) -> Option> { + let mut fill_locks = + self.fill_locks.lock().unwrap_or_else(std::sync::PoisonError::into_inner); + let lock = fill_locks.get_mut(key)?; + if lock.generation != generation { + return None; + } + lock.reader_state = Some(state.clone()); + lock.notify.notify_waiters(); + Some(state) + } + + fn file_shared_fill_lock_state( + &self, + lock_path: &Path, + lock_age: std::time::Duration, + ) -> FileSharedFillLockState { + let metadata = match std::fs::metadata(lock_path) { + Ok(metadata) => metadata, + Err(error) if error.kind() == std::io::ErrorKind::NotFound => { + return FileSharedFillLockState::Missing; + } + Err(_) => return FileSharedFillLockState::Fresh, + }; + let Ok(modified) = metadata.modified() else { + return FileSharedFillLockState::Fresh; + }; + if unix_time_ms(SystemTime::now()).saturating_sub(unix_time_ms(modified)) + > lock_age.as_millis() as u64 + { + FileSharedFillLockState::Stale + } else { + FileSharedFillLockState::Fresh + } + } pub(in crate::cache) fn fill_lock_decision( self: &Arc, key: &str, @@ -19,7 +62,7 @@ impl CacheZoneRuntime { share_fingerprint: Option<&str>, ) -> FillLockDecision { let mut fill_locks = - self.fill_locks.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + self.fill_locks.lock().unwrap_or_else(std::sync::PoisonError::into_inner); if let Some(lock) = fill_locks.get(key).cloned() && now.saturating_sub(lock.acquired_at_unix_ms) <= lock_age.as_millis() as u64 { @@ -66,49 +109,45 @@ impl CacheZoneRuntime { (None, None) } } - } else { - match self.try_acquire_file_shared_fill_lock(key, lock_age) { - Some(lock_path) => { - match super::super::fill::create_file_shared_external_fill_handle( - self.config.as_ref(), - key, - &lock_path, - now, - share_fingerprint, - ) { - Ok(state) => (Some(lock_path), Some(state)), - Err(error) => { - tracing::warn!( - zone = %self.config.name, - key = %key, - path = %lock_path.display(), - %error, - "failed to initialize external shared fill state; falling back to local fill coordination" - ); - let _ = std::fs::remove_file(&lock_path); - (None, None) - } - } - } - None => { - if let Some(state) = super::super::fill::load_file_external_fill_state( - self.config.as_ref(), - key, - share_fingerprint, - ) { - return FillLockDecision::Read { - state: InflightFillReadState::External(state), - }; - } - return FillLockDecision::WaitExternal { key: key.to_string() }; + } else if let Some(lock_path) = self.try_acquire_file_shared_fill_lock(key, lock_age) { + match super::super::fill::create_file_shared_external_fill_handle( + self.config.as_ref(), + key, + &lock_path, + now, + share_fingerprint, + ) { + Ok(state) => (Some(lock_path), Some(state)), + Err(error) => { + tracing::warn!( + zone = %self.config.name, + key = %key, + path = %lock_path.display(), + %error, + "failed to initialize external shared fill state; falling back to local fill coordination" + ); + let _ = std::fs::remove_file(&lock_path); + (None, None) } } + } else { + if let Some(state) = super::super::fill::load_file_external_fill_state( + self.config.as_ref(), + key, + share_fingerprint, + ) { + return FillLockDecision::Read { + state: InflightFillReadState::External(state), + }; + } + return FillLockDecision::WaitExternal { key: key.to_string() }; } } else { (None, None) }; let notify = Arc::new(Notify::new()); - let generation = self.fill_lock_generation.fetch_add(1, Ordering::Relaxed) + 1; + let generation = + self.fill_lock_generation.fetch_add(1, Ordering::Relaxed).saturating_add(1); fill_locks.insert( key.to_string(), CacheFillLockState { @@ -129,6 +168,34 @@ impl CacheZoneRuntime { }) } + fn try_acquire_file_shared_fill_lock( + &self, + key: &str, + lock_age: std::time::Duration, + ) -> Option { + let lock_path = shared::shared_fill_lock_path(self.config.as_ref(), key); + loop { + match std::fs::OpenOptions::new().write(true).create_new(true).open(&lock_path) { + Ok(_) => return Some(lock_path), + Err(error) if error.kind() == std::io::ErrorKind::AlreadyExists => { + match self.file_shared_fill_lock_state(&lock_path, lock_age) { + FileSharedFillLockState::Missing | FileSharedFillLockState::Stale => { + match std::fs::remove_file(&lock_path) { + Ok(()) => continue, + Err(error) if error.kind() == std::io::ErrorKind::NotFound => { + continue; + } + Err(_) => return None, + } + } + FileSharedFillLockState::Fresh => return None, + } + } + Err(_) => return None, + } + } + } + pub(in crate::cache) async fn wait_for_external_fill_lock( &self, key: &str, @@ -144,7 +211,9 @@ impl CacheZoneRuntime { } let lock_path = shared::shared_fill_lock_path(self.config.as_ref(), key); - let deadline = tokio::time::Instant::now() + lock_timeout; + let deadline = tokio::time::Instant::now() + .checked_add(lock_timeout) + .expect("cache fill lock deadline remains representable"); loop { match self.file_shared_fill_lock_state(&lock_path, lock_age) { FileSharedFillLockState::Missing => return true, @@ -173,7 +242,9 @@ impl CacheZoneRuntime { lock_timeout: std::time::Duration, lock_age: std::time::Duration, ) -> bool { - let deadline = tokio::time::Instant::now() + lock_timeout; + let deadline = tokio::time::Instant::now() + .checked_add(lock_timeout) + .expect("cache fill lock deadline remains representable"); loop { let now = unix_time_ms(SystemTime::now()); match super::super::fill::memory_shared_fill_lock_state(store, key, now, lock_age) { @@ -198,81 +269,13 @@ impl CacheZoneRuntime { tokio::time::sleep(remaining.min(std::time::Duration::from_millis(25))).await; } } - - fn try_acquire_file_shared_fill_lock( - &self, - key: &str, - lock_age: std::time::Duration, - ) -> Option { - let lock_path = shared::shared_fill_lock_path(self.config.as_ref(), key); - loop { - match std::fs::OpenOptions::new().write(true).create_new(true).open(&lock_path) { - Ok(_) => return Some(lock_path), - Err(error) if error.kind() == std::io::ErrorKind::AlreadyExists => { - match self.file_shared_fill_lock_state(&lock_path, lock_age) { - FileSharedFillLockState::Missing | FileSharedFillLockState::Stale => { - match std::fs::remove_file(&lock_path) { - Ok(()) => continue, - Err(error) if error.kind() == std::io::ErrorKind::NotFound => { - continue; - } - Err(_) => return None, - } - } - FileSharedFillLockState::Fresh => return None, - } - } - Err(_) => return None, - } - } - } - - pub(in crate::cache) fn attach_fill_read_state( - &self, - key: &str, - generation: u64, - state: Arc, - ) -> Option> { - let mut fill_locks = - self.fill_locks.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - let lock = fill_locks.get_mut(key)?; - if lock.generation != generation { - return None; - } - lock.reader_state = Some(state.clone()); - lock.notify.notify_waiters(); - Some(state) - } - - fn file_shared_fill_lock_state( - &self, - lock_path: &Path, - lock_age: std::time::Duration, - ) -> FileSharedFillLockState { - let metadata = match std::fs::metadata(lock_path) { - Ok(metadata) => metadata, - Err(error) if error.kind() == std::io::ErrorKind::NotFound => { - return FileSharedFillLockState::Missing; - } - Err(_) => return FileSharedFillLockState::Fresh, - }; - let Ok(modified) = metadata.modified() else { - return FileSharedFillLockState::Fresh; - }; - if unix_time_ms(SystemTime::now()).saturating_sub(unix_time_ms(modified)) - > lock_age.as_millis() as u64 - { - FileSharedFillLockState::Stale - } else { - FileSharedFillLockState::Fresh - } - } } impl Drop for CacheFillGuard { fn drop(&mut self) { if let Some(fill_locks) = self.fill_locks.upgrade() { - let mut fill_locks = fill_locks.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let mut fill_locks = + fill_locks.lock().unwrap_or_else(std::sync::PoisonError::into_inner); if fill_locks.get(&self.key).is_some_and(|lock| lock.generation == self.generation) { fill_locks.remove(&self.key); } diff --git a/crates/rginx-http/src/cache/runtime/support.rs b/crates/rginx-http/src/cache/runtime/support.rs index 20b11ba0..0b249043 100644 --- a/crates/rginx-http/src/cache/runtime/support.rs +++ b/crates/rginx-http/src/cache/runtime/support.rs @@ -3,14 +3,24 @@ use http::header::{ETAG, HeaderMap, HeaderValue, LAST_MODIFIED}; use rginx_core::RouteCachePolicy; use tokio::fs; -use super::*; +use super::{ + Arc, CacheConditionalHeaders, CacheIndexEntry, CacheStaleReason, CacheZoneRuntime, + cache_paths_for_zone, header_value, read_index, remove_zone_index_entry_if_matches, +}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(in crate::cache) enum CacheEntryLifecyclePhase { + Dead, Fresh, Grace, Keep, - Dead, +} + +#[derive(Debug, Clone)] +pub(in crate::cache) enum PurgeSelector { + All, + Exact(String), + Prefix(String), } pub(in crate::cache) async fn remove_cache_entry_if_matches( @@ -52,7 +62,7 @@ pub(in crate::cache) fn build_conditional_headers( let if_modified_since = header_value(headers, LAST_MODIFIED).and_then(|value| HeaderValue::from_str(&value).ok()); (if_none_match.is_some() || if_modified_since.is_some()) - .then_some(CacheConditionalHeaders { if_none_match, if_modified_since }) + .then_some(CacheConditionalHeaders { if_modified_since, if_none_match }) } pub(in crate::cache) fn stale_if_error_window_open(entry: &CacheIndexEntry, now: u64) -> bool { @@ -172,13 +182,6 @@ pub(in crate::cache) fn stale_for_reason_allowed( } } -#[derive(Debug, Clone)] -pub(in crate::cache) enum PurgeSelector { - All, - Exact(String), - Prefix(String), -} - pub(in crate::cache) async fn remove_cache_files_locked(zone: &rginx_core::CacheZone, hash: &str) { let paths = cache_paths_for_zone(zone, hash); remove_cache_file(&paths.metadata, hash).await; diff --git a/crates/rginx-http/src/cache/runtime/zone.rs b/crates/rginx-http/src/cache/runtime/zone.rs index fac3fe1c..e2c76ce2 100644 --- a/crates/rginx-http/src/cache/runtime/zone.rs +++ b/crates/rginx-http/src/cache/runtime/zone.rs @@ -1,75 +1,31 @@ -use super::*; +mod compare; +use super::{ + Arc, AtomicU64, CacheEntryHotState, CacheIndex, CacheIndexEntry, CacheStaleReason, + CacheZoneRuntime, CacheZoneRuntimeSnapshot, Mutex, Ordering, PreparedCacheResponseHead, + read_index, +}; use crate::cache::shared::SHARED_ACCESS_TOUCH_PUBLISH_INTERVAL_MS; use crate::cache::shared::run_blocking; -mod compare; use compare::same_cached_object_metadata; impl CacheZoneRuntime { - pub(in crate::cache) fn snapshot(&self) -> CacheZoneRuntimeSnapshot { - let (entry_count, current_size_bytes, active_invalidation_rules) = { - let index = read_index(&self.index); - (index.entries.len(), index.current_size_bytes, index.invalidations.len()) - }; - let shared_index_metrics = self - .shared_index_store - .as_ref() - .and_then(|store| run_blocking(|| store.metrics()).ok()) - .unwrap_or_default(); - CacheZoneRuntimeSnapshot { - zone_name: self.config.name.clone(), - path: self.config.path.clone(), - max_size_bytes: self.config.max_size_bytes, - inactive_secs: self.config.inactive.as_secs(), - default_ttl_secs: self.config.default_ttl.as_secs(), - max_entry_bytes: self.config.max_entry_bytes, - entry_count, - current_size_bytes, - hit_total: self.stats.hit_total.load(Ordering::Relaxed), - miss_total: self.stats.miss_total.load(Ordering::Relaxed), - bypass_total: self.stats.bypass_total.load(Ordering::Relaxed), - expired_total: self.stats.expired_total.load(Ordering::Relaxed), - stale_total: self.stats.stale_total.load(Ordering::Relaxed), - updating_total: self.stats.updating_total.load(Ordering::Relaxed), - revalidated_total: self.stats.revalidated_total.load(Ordering::Relaxed), - write_success_total: self.stats.write_success_total.load(Ordering::Relaxed), - write_error_total: self.stats.write_error_total.load(Ordering::Relaxed), - eviction_total: self.stats.eviction_total.load(Ordering::Relaxed), - purge_total: self.stats.purge_total.load(Ordering::Relaxed), - invalidation_total: self.stats.invalidation_total.load(Ordering::Relaxed), - inactive_cleanup_total: self.stats.inactive_cleanup_total.load(Ordering::Relaxed), - wait_local_total: self.stats.wait_local_total.load(Ordering::Relaxed), - wait_external_total: self.stats.wait_external_total.load(Ordering::Relaxed), - stale_serve_updating_total: self - .stats - .stale_serve_updating_total - .load(Ordering::Relaxed), - stale_serve_error_total: self.stats.stale_serve_error_total.load(Ordering::Relaxed), - stale_serve_timeout_total: self.stats.stale_serve_timeout_total.load(Ordering::Relaxed), - stale_serve_status_total: self.stats.stale_serve_status_total.load(Ordering::Relaxed), - revalidate_not_modified_total: self - .stats - .revalidate_not_modified_total - .load(Ordering::Relaxed), - active_invalidation_rules, - shared_index_enabled: self.config.shared_index, - shared_index_generation: self.shared_index_generation.load(Ordering::Relaxed), - shared_index_shm_capacity_bytes: shared_index_metrics.shm_capacity_bytes, - shared_index_shm_used_bytes: shared_index_metrics.shm_used_bytes, - shared_index_entry_count: shared_index_metrics.entry_count, - shared_index_current_size_bytes: shared_index_metrics.current_size_bytes, - shared_index_operation_ring_capacity: shared_index_metrics.operation_ring_capacity, - shared_index_operation_ring_used: shared_index_metrics.operation_ring_used, - shared_index_lock_contention_total: shared_index_metrics.lock_contention_total, - shared_index_full_reload_total: shared_index_metrics.full_reload_total, - shared_index_rebuild_total: shared_index_metrics.rebuild_total, - shared_index_stale_fill_lock_cleanup_total: shared_index_metrics - .stale_fill_lock_cleanup_total, - shared_index_capacity_rejection_total: shared_index_metrics.capacity_rejection_total, - } + #[cfg(test)] + pub(in crate::cache) fn clear_hot_entries(&self) { + self.hot_entries.write().unwrap_or_else(|poisoned| poisoned.into_inner()).clear(); + } + + pub(in crate::cache) fn effective_last_access_unix_ms( + &self, + key: &str, + entry: &CacheIndexEntry, + ) -> u64 { + self.hot_entry(key) + .map(|hot| hot.last_access_unix_ms.load(Ordering::Relaxed)) + .map_or(entry.last_access_unix_ms, |hot| hot.max(entry.last_access_unix_ms)) } fn hot_entry(&self, key: &str) -> Option> { - self.hot_entries.read().unwrap_or_else(|poisoned| poisoned.into_inner()).get(key).cloned() + self.hot_entries.read().unwrap_or_else(std::sync::PoisonError::into_inner).get(key).cloned() } fn hot_entry_for_key(&self, key: &str) -> Arc { @@ -78,7 +34,7 @@ impl CacheZoneRuntime { } let mut hot_entries = - self.hot_entries.write().unwrap_or_else(|poisoned| poisoned.into_inner()); + self.hot_entries.write().unwrap_or_else(std::sync::PoisonError::into_inner); hot_entries .entry(key.to_string()) .or_insert_with(|| { @@ -91,39 +47,12 @@ impl CacheZoneRuntime { .clone() } - pub(in crate::cache) fn record_entry_access(&self, key: &str, now: u64) { - self.hot_entry_for_key(key).last_access_unix_ms.fetch_max(now, Ordering::Relaxed); - } - - pub(in crate::cache) fn should_publish_shared_entry_access(&self, key: &str, now: u64) -> bool { - let hot_entry = self.hot_entry_for_key(key); - loop { - let last_published = hot_entry.shared_touch_published_unix_ms.load(Ordering::Relaxed); - if last_published != 0 - && now.saturating_sub(last_published) < SHARED_ACCESS_TOUCH_PUBLISH_INTERVAL_MS - { - return false; - } - if hot_entry - .shared_touch_published_unix_ms - .compare_exchange(last_published, now, Ordering::Relaxed, Ordering::Relaxed) - .is_ok() - { - return true; - } + pub(in crate::cache) fn notify_changed(&self) { + if let Some(notifier) = &self.change_notifier { + notifier(&self.config.name); } } - pub(in crate::cache) fn effective_last_access_unix_ms( - &self, - key: &str, - entry: &CacheIndexEntry, - ) -> u64 { - self.hot_entry(key) - .map(|hot| hot.last_access_unix_ms.load(Ordering::Relaxed)) - .map_or(entry.last_access_unix_ms, |hot| hot.max(entry.last_access_unix_ms)) - } - pub(in crate::cache) fn prepared_response_head( &self, key: &str, @@ -133,159 +62,234 @@ impl CacheZoneRuntime { let response_head = hot_entry .response_head .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .clone()?; (response_head.hash == expected_hash).then_some(response_head) } - pub(in crate::cache) fn store_prepared_response_head( - &self, - key: &str, - last_access_unix_ms: u64, - response_head: Arc, - ) { - let still_current = read_index(&self.index) - .entries - .get(key) - .is_some_and(|entry| entry.hash == response_head.hash); - if !still_current { - return; - } - - let hot_entry = self.hot_entry_for_key(key); - hot_entry.last_access_unix_ms.fetch_max(last_access_unix_ms, Ordering::Relaxed); - *hot_entry.response_head.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) = - Some(response_head); + pub(in crate::cache) fn record_bypass(&self) { + self.record_counter(&self.stats.bypass_total, 1); } - pub(in crate::cache) fn remove_hot_entry(&self, key: &str) { - self.hot_entries.write().unwrap_or_else(|poisoned| poisoned.into_inner()).remove(key); + fn record_counter(&self, counter: &AtomicU64, value: u64) { + counter.fetch_add(value, Ordering::Relaxed); } - #[cfg(test)] - pub(in crate::cache) fn clear_hot_entries(&self) { - self.hot_entries.write().unwrap_or_else(|poisoned| poisoned.into_inner()).clear(); + pub(in crate::cache) fn record_entry_access(&self, key: &str, now: u64) { + self.hot_entry_for_key(key).last_access_unix_ms.fetch_max(now, Ordering::Relaxed); } - pub(in crate::cache) fn retain_hot_entries_for_reloaded_index( - &self, - previous: &CacheIndex, - next: &CacheIndex, - ) { - self.hot_entries.write().unwrap_or_else(|poisoned| poisoned.into_inner()).retain( - |key, hot_entry| { - let Some(previous_entry) = previous.entries.get(key) else { - return false; - }; - let Some(next_entry) = next.entries.get(key) else { - return false; - }; - if !same_cached_object_metadata(previous_entry, next_entry) { - return false; - } + pub(in crate::cache) fn record_evictions(&self, count: usize) { + if count > 0 { + self.record_counter(&self.stats.eviction_total, count as u64); + } + } - hot_entry - .response_head - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .as_ref() - .is_none_or(|response_head| response_head.hash == next_entry.hash) - }, - ); + pub(in crate::cache) fn record_expired(&self) { + self.record_counter(&self.stats.expired_total, 1); } pub(in crate::cache) fn record_hit(&self) { self.record_counter(&self.stats.hit_total, 1); } - pub(in crate::cache) fn record_miss(&self) { - self.record_counter(&self.stats.miss_total, 1); + pub(in crate::cache) fn record_inactive_cleanup(&self, count: usize) { + if count > 0 { + self.record_counter(&self.stats.inactive_cleanup_total, count as u64); + } } - pub(in crate::cache) fn record_bypass(&self) { - self.record_counter(&self.stats.bypass_total, 1); + pub(in crate::cache) fn record_invalidation(&self, count: usize) { + if count > 0 { + self.record_counter(&self.stats.invalidation_total, count as u64); + } } - pub(in crate::cache) fn record_expired(&self) { - self.record_counter(&self.stats.expired_total, 1); + pub(in crate::cache) fn record_miss(&self) { + self.record_counter(&self.stats.miss_total, 1); } - pub(in crate::cache) fn record_stale(&self) { - self.record_counter(&self.stats.stale_total, 1); + pub(in crate::cache) fn record_purge(&self, count: usize) { + if count > 0 { + self.record_counter(&self.stats.purge_total, count as u64); + } } - pub(in crate::cache) fn record_updating(&self) { - self.record_counter(&self.stats.updating_total, 1); + pub(in crate::cache) fn record_revalidate_not_modified(&self) { + self.record_counter(&self.stats.revalidate_not_modified_total, 1); } pub(in crate::cache) fn record_revalidated(&self) { self.record_counter(&self.stats.revalidated_total, 1); } - pub(in crate::cache) fn record_write_success(&self) { - self.record_counter(&self.stats.write_success_total, 1); - } - - pub(in crate::cache) fn record_write_error(&self) { - self.record_counter(&self.stats.write_error_total, 1); + pub(in crate::cache) fn record_stale(&self) { + self.record_counter(&self.stats.stale_total, 1); } - pub(in crate::cache) fn record_evictions(&self, count: usize) { - if count > 0 { - self.record_counter(&self.stats.eviction_total, count as u64); - } + pub(in crate::cache) fn record_stale_serve_reason(&self, reason: CacheStaleReason) { + let counter = match reason { + CacheStaleReason::Error => &self.stats.stale_serve_error_total, + CacheStaleReason::Timeout => &self.stats.stale_serve_timeout_total, + CacheStaleReason::Status(_) => &self.stats.stale_serve_status_total, + }; + self.record_counter(counter, 1); } - pub(in crate::cache) fn record_purge(&self, count: usize) { - if count > 0 { - self.record_counter(&self.stats.purge_total, count as u64); - } + pub(in crate::cache) fn record_stale_serve_updating(&self) { + self.record_counter(&self.stats.stale_serve_updating_total, 1); } - pub(in crate::cache) fn record_invalidation(&self, count: usize) { - if count > 0 { - self.record_counter(&self.stats.invalidation_total, count as u64); - } + pub(in crate::cache) fn record_updating(&self) { + self.record_counter(&self.stats.updating_total, 1); } - pub(in crate::cache) fn record_inactive_cleanup(&self, count: usize) { - if count > 0 { - self.record_counter(&self.stats.inactive_cleanup_total, count as u64); - } + pub(in crate::cache) fn record_wait_external(&self) { + self.record_counter(&self.stats.wait_external_total, 1); } pub(in crate::cache) fn record_wait_local(&self) { self.record_counter(&self.stats.wait_local_total, 1); } - pub(in crate::cache) fn record_wait_external(&self) { - self.record_counter(&self.stats.wait_external_total, 1); + pub(in crate::cache) fn record_write_error(&self) { + self.record_counter(&self.stats.write_error_total, 1); } - pub(in crate::cache) fn record_stale_serve_updating(&self) { - self.record_counter(&self.stats.stale_serve_updating_total, 1); + pub(in crate::cache) fn record_write_success(&self) { + self.record_counter(&self.stats.write_success_total, 1); } - pub(in crate::cache) fn record_stale_serve_reason(&self, reason: CacheStaleReason) { - let counter = match reason { - CacheStaleReason::Error => &self.stats.stale_serve_error_total, - CacheStaleReason::Timeout => &self.stats.stale_serve_timeout_total, - CacheStaleReason::Status(_) => &self.stats.stale_serve_status_total, - }; - self.record_counter(counter, 1); + pub(in crate::cache) fn remove_hot_entry(&self, key: &str) { + self.hot_entries.write().unwrap_or_else(std::sync::PoisonError::into_inner).remove(key); } - pub(in crate::cache) fn record_revalidate_not_modified(&self) { - self.record_counter(&self.stats.revalidate_not_modified_total, 1); + pub(in crate::cache) fn retain_hot_entries_for_reloaded_index( + &self, + previous: &CacheIndex, + next: &CacheIndex, + ) { + self.hot_entries.write().unwrap_or_else(std::sync::PoisonError::into_inner).retain( + |key, hot_entry| { + let Some(previous_entry) = previous.entries.get(key) else { + return false; + }; + let Some(next_entry) = next.entries.get(key) else { + return false; + }; + if !same_cached_object_metadata(previous_entry, next_entry) { + return false; + } + + hot_entry + .response_head + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .as_ref() + .is_none_or(|response_head| response_head.hash == next_entry.hash) + }, + ); } - fn record_counter(&self, counter: &AtomicU64, value: u64) { - counter.fetch_add(value, Ordering::Relaxed); + pub(in crate::cache) fn should_publish_shared_entry_access(&self, key: &str, now: u64) -> bool { + let hot_entry = self.hot_entry_for_key(key); + loop { + let last_published = hot_entry.shared_touch_published_unix_ms.load(Ordering::Relaxed); + if last_published != 0 + && now.saturating_sub(last_published) < SHARED_ACCESS_TOUCH_PUBLISH_INTERVAL_MS + { + return false; + } + if hot_entry + .shared_touch_published_unix_ms + .compare_exchange(last_published, now, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + return true; + } + } } - pub(in crate::cache) fn notify_changed(&self) { - if let Some(notifier) = &self.change_notifier { - notifier(&self.config.name); + pub(in crate::cache) fn snapshot(&self) -> CacheZoneRuntimeSnapshot { + let (entry_count, current_size_bytes, active_invalidation_rules) = { + let index = read_index(&self.index); + (index.entries.len(), index.current_size_bytes, index.invalidations.len()) + }; + let shared_index_metrics = self + .shared_index_store + .as_ref() + .and_then(|store| run_blocking(|| store.metrics()).ok()) + .unwrap_or_default(); + CacheZoneRuntimeSnapshot { + zone_name: self.config.name.clone(), + path: self.config.path.clone(), + max_size_bytes: self.config.max_size_bytes, + inactive_secs: self.config.inactive.as_secs(), + default_ttl_secs: self.config.default_ttl.as_secs(), + max_entry_bytes: self.config.max_entry_bytes, + entry_count, + current_size_bytes, + hit_total: self.stats.hit_total.load(Ordering::Relaxed), + miss_total: self.stats.miss_total.load(Ordering::Relaxed), + bypass_total: self.stats.bypass_total.load(Ordering::Relaxed), + expired_total: self.stats.expired_total.load(Ordering::Relaxed), + stale_total: self.stats.stale_total.load(Ordering::Relaxed), + updating_total: self.stats.updating_total.load(Ordering::Relaxed), + revalidated_total: self.stats.revalidated_total.load(Ordering::Relaxed), + write_success_total: self.stats.write_success_total.load(Ordering::Relaxed), + write_error_total: self.stats.write_error_total.load(Ordering::Relaxed), + eviction_total: self.stats.eviction_total.load(Ordering::Relaxed), + purge_total: self.stats.purge_total.load(Ordering::Relaxed), + invalidation_total: self.stats.invalidation_total.load(Ordering::Relaxed), + inactive_cleanup_total: self.stats.inactive_cleanup_total.load(Ordering::Relaxed), + wait_local_total: self.stats.wait_local_total.load(Ordering::Relaxed), + wait_external_total: self.stats.wait_external_total.load(Ordering::Relaxed), + stale_serve_updating_total: self + .stats + .stale_serve_updating_total + .load(Ordering::Relaxed), + stale_serve_error_total: self.stats.stale_serve_error_total.load(Ordering::Relaxed), + stale_serve_timeout_total: self.stats.stale_serve_timeout_total.load(Ordering::Relaxed), + stale_serve_status_total: self.stats.stale_serve_status_total.load(Ordering::Relaxed), + revalidate_not_modified_total: self + .stats + .revalidate_not_modified_total + .load(Ordering::Relaxed), + active_invalidation_rules, + shared_index_enabled: self.config.shared_index, + shared_index_generation: self.shared_index_generation.load(Ordering::Relaxed), + shared_index_shm_capacity_bytes: shared_index_metrics.shm_capacity_bytes, + shared_index_shm_used_bytes: shared_index_metrics.shm_used_bytes, + shared_index_entry_count: shared_index_metrics.entry_count, + shared_index_current_size_bytes: shared_index_metrics.current_size_bytes, + shared_index_operation_ring_capacity: shared_index_metrics.operation_ring_capacity, + shared_index_operation_ring_used: shared_index_metrics.operation_ring_used, + shared_index_lock_contention_total: shared_index_metrics.lock_contention_total, + shared_index_full_reload_total: shared_index_metrics.full_reload_total, + shared_index_rebuild_total: shared_index_metrics.rebuild_total, + shared_index_stale_fill_lock_cleanup_total: shared_index_metrics + .stale_fill_lock_cleanup_total, + shared_index_capacity_rejection_total: shared_index_metrics.capacity_rejection_total, } } + + pub(in crate::cache) fn store_prepared_response_head( + &self, + key: &str, + last_access_unix_ms: u64, + response_head: Arc, + ) { + let still_current = read_index(&self.index) + .entries + .get(key) + .is_some_and(|entry| entry.hash == response_head.hash); + if !still_current { + return; + } + + let hot_entry = self.hot_entry_for_key(key); + hot_entry.last_access_unix_ms.fetch_max(last_access_unix_ms, Ordering::Relaxed); + *hot_entry.response_head.lock().unwrap_or_else(std::sync::PoisonError::into_inner) = + Some(response_head); + } } diff --git a/crates/rginx-http/src/cache/runtime/zone/compare.rs b/crates/rginx-http/src/cache/runtime/zone/compare.rs index a712e978..929d5f3d 100644 --- a/crates/rginx-http/src/cache/runtime/zone/compare.rs +++ b/crates/rginx-http/src/cache/runtime/zone/compare.rs @@ -1,4 +1,4 @@ -use super::*; +use super::CacheIndexEntry; pub(super) fn same_cached_object_metadata(left: &CacheIndexEntry, right: &CacheIndexEntry) -> bool { left.kind == right.kind diff --git a/crates/rginx-http/src/cache/shared.rs b/crates/rginx-http/src/cache/shared.rs index 8e16df42..0009334a 100644 --- a/crates/rginx-http/src/cache/shared.rs +++ b/crates/rginx-http/src/cache/shared.rs @@ -1,3 +1,11 @@ +mod bootstrap; +mod delta; +#[cfg(target_os = "linux")] +pub(crate) mod memory; +#[cfg(not(target_os = "linux"))] +pub(crate) mod memory {} +mod index_file; + use std::io; use std::sync::Arc; use std::sync::atomic::Ordering; @@ -7,14 +15,6 @@ use super::load::load_index_from_disk; use super::store::lock_index; use super::{CacheIndex, CacheIndexEntry, CacheInvalidationRule, CacheZoneRuntime}; -mod bootstrap; -mod delta; -#[cfg(target_os = "linux")] -pub(crate) mod memory; -#[cfg(not(target_os = "linux"))] -pub(crate) mod memory {} -mod index_file; - pub(in crate::cache) use bootstrap::bootstrap_shared_index; use delta::{apply_shared_index_delta, reload_zone_shared_index}; pub(super) use index_file::SharedIndexStore; @@ -35,18 +35,19 @@ pub(in crate::cache) use index_file::{ run_blocking, shared_fill_lock_path, shared_fill_state_path, }; -type SharedIndexBootstrap = (CacheIndex, Option>, u64, u64, u64); pub(super) const SHARED_ACCESS_TOUCH_PUBLISH_INTERVAL_MS: u64 = 1_000; +type SharedIndexBootstrap = (CacheIndex, Option>, u64, u64, u64); + #[derive(Clone)] pub(super) enum SharedIndexOperation { - UpsertEntry { key: String, entry: CacheIndexEntry }, - RemoveEntry { key: String }, - TouchEntry { key: String, last_access_unix_ms: u64 }, - SetAdmissionCount { key: String, uses: u64 }, - RemoveAdmissionCount { key: String }, AddInvalidation { rule: CacheInvalidationRule }, ClearInvalidations, + RemoveAdmissionCount { key: String }, + RemoveEntry { key: String }, + SetAdmissionCount { key: String, uses: u64 }, + TouchEntry { key: String, last_access_unix_ms: u64 }, + UpsertEntry { key: String, entry: CacheIndexEntry }, } pub(super) async fn sync_zone_shared_index_if_needed(zone: &Arc) { diff --git a/crates/rginx-http/src/cache/shared/bootstrap.rs b/crates/rginx-http/src/cache/shared/bootstrap.rs index 0268420a..65de62f5 100644 --- a/crates/rginx-http/src/cache/shared/bootstrap.rs +++ b/crates/rginx-http/src/cache/shared/bootstrap.rs @@ -1,4 +1,9 @@ -use super::*; +use super::{ + Arc, CacheIndex, SharedIndexBootstrap, SharedIndexStore, delete_legacy_shared_index_file, + index_file, io, legacy_shared_index_path, load_index_from_disk, + load_legacy_shared_index_from_disk, load_shared_index_from_disk, recreate_shared_index_on_disk, + run_blocking, shared_index_store, +}; pub(in crate::cache) fn bootstrap_shared_index( zone: &rginx_core::CacheZone, diff --git a/crates/rginx-http/src/cache/shared/delta.rs b/crates/rginx-http/src/cache/shared/delta.rs index c4cbf631..e513f86e 100644 --- a/crates/rginx-http/src/cache/shared/delta.rs +++ b/crates/rginx-http/src/cache/shared/delta.rs @@ -1,4 +1,8 @@ -use super::*; +use super::{ + Arc, CacheIndex, CacheInvalidationRule, CacheZoneRuntime, Ordering, SharedIndexOperation, + SharedIndexStore, invalidation_rule_matches_entry, load_shared_index_from_disk, lock_index, + run_blocking, +}; pub(super) fn reload_zone_shared_index(zone: &Arc, store: &SharedIndexStore) { let loaded = match run_blocking(|| load_shared_index_from_disk(store)) { diff --git a/crates/rginx-http/src/cache/shared/index_file/codec/binary.rs b/crates/rginx-http/src/cache/shared/index_file/codec/binary.rs index d579e9ee..80c9324f 100644 --- a/crates/rginx-http/src/cache/shared/index_file/codec/binary.rs +++ b/crates/rginx-http/src/cache/shared/index_file/codec/binary.rs @@ -131,21 +131,21 @@ pub(in crate::cache::shared::index_file) fn deserialize_entry_record( cursor.finish()?; Ok(CacheIndexEntry { - kind, - hash, base_key, - stored_at_unix_ms, - vary, - tags, body_size_bytes, expires_at_unix_ms, grace_until_unix_ms, + hash, keep_until_unix_ms, + kind, + last_access_unix_ms, + must_revalidate, + requires_revalidation, stale_if_error_until_unix_ms, stale_while_revalidate_until_unix_ms, - requires_revalidation, - must_revalidate, - last_access_unix_ms, + stored_at_unix_ms, + tags, + vary, }) } @@ -201,5 +201,5 @@ pub(in crate::cache::shared::index_file) fn deserialize_invalidation_rule( }; let created_at_unix_ms = cursor.read_u64()?; cursor.finish()?; - Ok(CacheInvalidationRule { selector, created_at_unix_ms }) + Ok(CacheInvalidationRule { created_at_unix_ms, selector }) } diff --git a/crates/rginx-http/src/cache/shared/index_file/codec/cursor.rs b/crates/rginx-http/src/cache/shared/index_file/codec/cursor.rs index e7de0687..bf6368af 100644 --- a/crates/rginx-http/src/cache/shared/index_file/codec/cursor.rs +++ b/crates/rginx-http/src/cache/shared/index_file/codec/cursor.rs @@ -2,6 +2,87 @@ use std::io; use super::invalid_data_error; +pub(super) struct BinaryCursor<'a> { + bytes: &'a [u8], + offset: usize, +} + +impl<'a> BinaryCursor<'a> { + pub(super) fn finish(&self) -> io::Result<()> { + if self.offset == self.bytes.len() { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "binary shared index record contained trailing bytes", + )) + } + } + pub(super) fn new(bytes: &'a [u8]) -> Self { + Self { bytes, offset: 0 } + } + + pub(super) fn read_bool(&mut self) -> io::Result { + match self.read_u8()? { + 0 => Ok(false), + 1 => Ok(true), + other => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("invalid boolean discriminant `{other}`"), + )), + } + } + + pub(super) fn read_bytes(&mut self) -> io::Result> { + let len = usize::try_from(self.read_u32()?) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "byte length exceeds usize"))?; + Ok(self.read_exact(len)?.to_vec()) + } + + pub(super) fn read_exact(&mut self, len: usize) -> io::Result<&'a [u8]> { + let end = self.offset.checked_add(len).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "binary cursor overflowed") + })?; + let slice = self.bytes.get(self.offset..end).ok_or_else(|| { + io::Error::new(io::ErrorKind::UnexpectedEof, "truncated binary shared index record") + })?; + self.offset = end; + Ok(slice) + } + + pub(super) fn read_optional_string(&mut self) -> io::Result> { + self.read_bool()?.then(|| self.read_string()).transpose() + } + + pub(super) fn read_optional_u64(&mut self) -> io::Result> { + self.read_bool()?.then(|| self.read_u64()).transpose() + } + + pub(super) fn read_string(&mut self) -> io::Result { + String::from_utf8(self.read_bytes()?).map_err(invalid_data_error) + } + + pub(super) fn read_u32(&mut self) -> io::Result { + Ok(u32::from_le_bytes( + self.read_exact(4)? + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid u32 bytes"))?, + )) + } + + pub(super) fn read_u64(&mut self) -> io::Result { + Ok(u64::from_le_bytes( + self.read_exact(8)? + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid u64 bytes"))?, + )) + } + + pub(super) fn read_u8(&mut self) -> io::Result { + Ok(*self.read_exact(1)?.first().expect("single-byte slice should exist")) + } +} + pub(super) fn write_u8(bytes: &mut Vec, value: u8) { bytes.push(value); } @@ -33,15 +114,12 @@ pub(super) fn write_string(bytes: &mut Vec, value: &str) -> io::Result<()> { } pub(super) fn write_optional_string(bytes: &mut Vec, value: Option<&str>) -> io::Result<()> { - match value { - Some(value) => { - write_bool(bytes, true); - write_string(bytes, value) - } - None => { - write_bool(bytes, false); - Ok(()) - } + if let Some(value) = value { + write_bool(bytes, true); + write_string(bytes, value) + } else { + write_bool(bytes, false); + Ok(()) } } @@ -54,85 +132,3 @@ pub(super) fn write_bytes(bytes: &mut Vec, value: &[u8]) -> io::Result<()> { bytes.extend_from_slice(value); Ok(()) } - -pub(super) struct BinaryCursor<'a> { - bytes: &'a [u8], - offset: usize, -} - -impl<'a> BinaryCursor<'a> { - pub(super) fn new(bytes: &'a [u8]) -> Self { - Self { bytes, offset: 0 } - } - - pub(super) fn read_u8(&mut self) -> io::Result { - Ok(*self.read_exact(1)?.first().expect("single-byte slice should exist")) - } - - pub(super) fn read_u32(&mut self) -> io::Result { - Ok(u32::from_le_bytes( - self.read_exact(4)? - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid u32 bytes"))?, - )) - } - - pub(super) fn read_u64(&mut self) -> io::Result { - Ok(u64::from_le_bytes( - self.read_exact(8)? - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid u64 bytes"))?, - )) - } - - pub(super) fn read_bool(&mut self) -> io::Result { - match self.read_u8()? { - 0 => Ok(false), - 1 => Ok(true), - other => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("invalid boolean discriminant `{other}`"), - )), - } - } - - pub(super) fn read_optional_u64(&mut self) -> io::Result> { - self.read_bool()?.then(|| self.read_u64()).transpose() - } - - pub(super) fn read_string(&mut self) -> io::Result { - String::from_utf8(self.read_bytes()?).map_err(invalid_data_error) - } - - pub(super) fn read_optional_string(&mut self) -> io::Result> { - self.read_bool()?.then(|| self.read_string()).transpose() - } - - pub(super) fn read_bytes(&mut self) -> io::Result> { - let len = usize::try_from(self.read_u32()?) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "byte length exceeds usize"))?; - Ok(self.read_exact(len)?.to_vec()) - } - - pub(super) fn read_exact(&mut self, len: usize) -> io::Result<&'a [u8]> { - let end = self.offset.checked_add(len).ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "binary cursor overflowed") - })?; - let slice = self.bytes.get(self.offset..end).ok_or_else(|| { - io::Error::new(io::ErrorKind::UnexpectedEof, "truncated binary shared index record") - })?; - self.offset = end; - Ok(slice) - } - - pub(super) fn finish(&self) -> io::Result<()> { - if self.offset == self.bytes.len() { - Ok(()) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "binary shared index record contained trailing bytes", - )) - } - } -} diff --git a/crates/rginx-http/src/cache/shared/index_file/codec/legacy.rs b/crates/rginx-http/src/cache/shared/index_file/codec/legacy.rs index 775fa2bd..87798d92 100644 --- a/crates/rginx-http/src/cache/shared/index_file/codec/legacy.rs +++ b/crates/rginx-http/src/cache/shared/index_file/codec/legacy.rs @@ -11,38 +11,38 @@ pub(super) struct SharedVaryHeader { #[derive(Debug, Deserialize)] pub(super) struct LegacySharedIndexFile { - pub(super) version: u8, - pub(super) generation: u64, - pub(super) entries: Vec, #[serde(default)] pub(super) admission_counts: Vec, + pub(super) entries: Vec, + pub(super) generation: u64, + pub(super) version: u8, } #[derive(Debug, Deserialize)] pub(super) struct LegacySharedIndexEntry { - pub(super) key: String, - pub(super) hash: String, pub(super) base_key: String, - #[serde(default)] - pub(super) stored_at_unix_ms: u64, - pub(super) vary: Vec, - #[serde(default)] - pub(super) tags: Vec, pub(super) body_size_bytes: usize, pub(super) expires_at_unix_ms: u64, #[serde(default)] pub(super) grace_until_unix_ms: Option, + pub(super) hash: String, #[serde(default)] pub(super) keep_until_unix_ms: Option, + pub(super) key: String, + pub(super) last_access_unix_ms: u64, + #[serde(default)] + pub(super) must_revalidate: bool, + #[serde(default)] + pub(super) requires_revalidation: Option, #[serde(default)] pub(super) stale_if_error_until_unix_ms: Option, #[serde(default)] pub(super) stale_while_revalidate_until_unix_ms: Option, #[serde(default)] - pub(super) requires_revalidation: Option, + pub(super) stored_at_unix_ms: u64, #[serde(default)] - pub(super) must_revalidate: bool, - pub(super) last_access_unix_ms: u64, + pub(super) tags: Vec, + pub(super) vary: Vec, } impl LegacySharedIndexEntry { diff --git a/crates/rginx-http/src/cache/shared/index_file/codec/mod.rs b/crates/rginx-http/src/cache/shared/index_file/codec/mod.rs index 8441f560..76c88e92 100644 --- a/crates/rginx-http/src/cache/shared/index_file/codec/mod.rs +++ b/crates/rginx-http/src/cache/shared/index_file/codec/mod.rs @@ -1,13 +1,13 @@ +mod binary; +mod cursor; +mod legacy; + use std::io; use std::path::Path; use super::super::super::{CacheIndex, CacheIndexEntry, CachedVaryHeaderValue}; use super::{LoadedSharedIndex, SHARED_INDEX_SCHEMA_VERSION, invalid_data_error}; -mod binary; -mod cursor; -mod legacy; - pub(super) use binary::{ deserialize_entry_record, deserialize_invalidation_rule, serialize_entry_record, serialize_invalidation_rule, diff --git a/crates/rginx-http/src/cache/shared/index_file/memory_backend.rs b/crates/rginx-http/src/cache/shared/index_file/memory_backend.rs index 619eb414..d6f2cd30 100644 --- a/crates/rginx-http/src/cache/shared/index_file/memory_backend.rs +++ b/crates/rginx-http/src/cache/shared/index_file/memory_backend.rs @@ -1,3 +1,12 @@ +mod changes; +mod config; +mod document; +mod fill_locks; +mod locks; +mod segment; +#[cfg(test)] +mod tests; + use super::codec::serialize_invalidation_rule; use super::{ AppliedSharedIndexOperations, LoadedSharedIndex, LoadedSharedIndexChanges, @@ -10,15 +19,6 @@ use std::io; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicU64, Ordering}; -mod changes; -mod config; -mod document; -mod fill_locks; -mod locks; -mod segment; -#[cfg(test)] -mod tests; - use changes::{ apply_operation_to_document, change_record_from_operation, operations_since, trim_change_ring, }; @@ -31,16 +31,20 @@ use locks::{FileLock, lock}; use segment::{open_or_create_segment, read_document, sync_state_from_header, write_document}; pub(super) struct MemorySharedIndexStore { - path: PathBuf, + capacity_rejection_total: AtomicU64, + full_reload_total: AtomicU64, + lock_contention_total: AtomicU64, lock_path: PathBuf, - segment_config: SharedMemorySegmentConfig, operation_ring_capacity: usize, - lock_contention_total: AtomicU64, - full_reload_total: AtomicU64, - capacity_rejection_total: AtomicU64, + path: PathBuf, + segment_config: SharedMemorySegmentConfig, } impl MemorySharedIndexStore { + fn lock(&self) -> io::Result { + lock(&self.lock_path, &self.lock_contention_total) + } + pub(super) fn new(zone: &rginx_core::CacheZone) -> Self { let capacity_bytes = memory_capacity_bytes(); let operation_ring_capacity = memory_operation_ring_capacity(); @@ -60,6 +64,17 @@ impl MemorySharedIndexStore { } } + fn open_or_create_segment(&self) -> io::Result { + open_or_create_segment(&self.segment_config) + } + + fn read_document( + &self, + segment: &SharedMemorySegment, + ) -> io::Result { + read_document(segment) + } + fn with_document_lock( &self, operation: impl FnOnce(&SharedMemorySegment, SharedMemoryIndexDocument) -> io::Result, @@ -79,21 +94,6 @@ impl MemorySharedIndexStore { operation(&segment) } - fn lock(&self) -> io::Result { - lock(&self.lock_path, &self.lock_contention_total) - } - - fn open_or_create_segment(&self) -> io::Result { - open_or_create_segment(&self.segment_config) - } - - fn read_document( - &self, - segment: &SharedMemorySegment, - ) -> io::Result { - read_document(segment) - } - fn write_document( &self, segment: &SharedMemorySegment, @@ -110,12 +110,49 @@ impl MemorySharedIndexStore { } impl SharedIndexBackend for MemorySharedIndexStore { - fn path(&self) -> &Path { - &self.path + fn apply_operations( + &self, + operations: &[SharedIndexOperation], + ) -> io::Result { + self.with_document_lock(|segment, mut document| { + if !operations.is_empty() { + document.generation = document.generation.saturating_add(1); + for operation in operations { + apply_operation_to_document(&mut document, operation)?; + document.last_change_seq = document.last_change_seq.saturating_add(1); + document.changes.push(change_record_from_operation( + document.last_change_seq, + document.generation, + operation, + )?); + } + trim_change_ring(&mut document, self.operation_ring_capacity); + self.write_document(segment, &document)?; + } + Ok(AppliedSharedIndexOperations { + generation: document.generation, + store_epoch: document.store_epoch, + last_change_seq: document.last_change_seq, + }) + }) } - fn supports_shared_fill_locks(&self) -> bool { - true + fn clear_stale_fill_lock( + &self, + key: &str, + now_unix_ms: u64, + lock_age_ms: u64, + ) -> io::Result { + fill_locks::clear_stale_fill_lock(self, key, now_unix_ms, lock_age_ms) + } + + fn fill_lock_status( + &self, + key: &str, + now_unix_ms: u64, + lock_age_ms: u64, + ) -> io::Result { + fill_locks::fill_lock_status(self, key, now_unix_ms, lock_age_ms) } fn load(&self) -> io::Result { @@ -125,18 +162,6 @@ impl SharedIndexBackend for MemorySharedIndexStore { Ok(loaded) } - fn sync_state(&self) -> io::Result { - match SharedMemorySegment::attach(&self.segment_config) { - Ok(segment) => Ok(sync_state_from_header(segment.header())), - Err(error) - if matches!(error.kind(), io::ErrorKind::NotFound | io::ErrorKind::InvalidData) => - { - self.with_segment_lock(|segment| Ok(sync_state_from_header(segment.header()))) - } - Err(error) => Err(error), - } - } - fn load_changes_since(&self, after_seq: u64) -> io::Result { self.with_document_lock(|_segment, document| { let operations = operations_since(&document, after_seq)?; @@ -149,6 +174,37 @@ impl SharedIndexBackend for MemorySharedIndexStore { }) } + fn load_fill_lock(&self, key: &str) -> io::Result> { + fill_locks::load_fill_lock(self, key) + } + + fn metrics(&self) -> io::Result { + self.with_document_lock(|segment, document| { + let bytes = encode_document(&document)?; + let header = segment.header(); + let used_payload_bytes = bytes.len().saturating_add(PAYLOAD_LEN_BYTES); + let used_bytes = u64::try_from(used_payload_bytes) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "document is too large"))? + .saturating_add(u64::from(header.header_len)); + Ok(SharedIndexMetrics { + shm_capacity_bytes: header.capacity_bytes, + shm_used_bytes: used_bytes, + entry_count: header.entry_count, + current_size_bytes: header.current_size_bytes, + operation_ring_capacity: header.operation_ring_capacity, + operation_ring_used: document.changes.len() as u64, + lock_contention_total: self.lock_contention_total.load(Ordering::Relaxed), + full_reload_total: self.full_reload_total.load(Ordering::Relaxed), + rebuild_total: document.rebuild_total, + stale_fill_lock_cleanup_total: document.stale_fill_lock_cleanup_total, + capacity_rejection_total: self.capacity_rejection_total.load(Ordering::Relaxed), + }) + }) + } + fn path(&self) -> &Path { + &self.path + } + fn recreate( &self, index: &CacheIndex, @@ -157,7 +213,7 @@ impl SharedIndexBackend for MemorySharedIndexStore { let _lock = self.lock()?; let segment = self.open_or_create_segment()?; let previous_rebuild_total = - self.read_document(&segment).map(|document| document.rebuild_total).unwrap_or(0); + self.read_document(&segment).map_or(0, |document| document.rebuild_total); let mut document = SharedMemoryIndexDocument::empty(segment.header().store_epoch); document.generation = generation; document.rebuild_total = previous_rebuild_total.saturating_add(1); @@ -178,31 +234,24 @@ impl SharedIndexBackend for MemorySharedIndexStore { Ok(applied) } - fn apply_operations( - &self, - operations: &[SharedIndexOperation], - ) -> io::Result { - self.with_document_lock(|segment, mut document| { - if !operations.is_empty() { - document.generation = document.generation.saturating_add(1); - for operation in operations { - apply_operation_to_document(&mut document, operation)?; - document.last_change_seq = document.last_change_seq.saturating_add(1); - document.changes.push(change_record_from_operation( - document.last_change_seq, - document.generation, - operation, - )?); - } - trim_change_ring(&mut document, self.operation_ring_capacity); - self.write_document(segment, &document)?; + fn release_fill_lock(&self, key: &str, nonce: &str) -> io::Result<()> { + fill_locks::release_fill_lock(self, key, nonce) + } + + fn supports_shared_fill_locks(&self) -> bool { + true + } + + fn sync_state(&self) -> io::Result { + match SharedMemorySegment::attach(&self.segment_config) { + Ok(segment) => Ok(sync_state_from_header(segment.header())), + Err(error) + if matches!(error.kind(), io::ErrorKind::NotFound | io::ErrorKind::InvalidData) => + { + self.with_segment_lock(|segment| Ok(sync_state_from_header(segment.header()))) } - Ok(AppliedSharedIndexOperations { - generation: document.generation, - store_epoch: document.store_epoch, - last_change_seq: document.last_change_seq, - }) - }) + Err(error) => Err(error), + } } fn try_acquire_fill_lock( @@ -216,10 +265,6 @@ impl SharedIndexBackend for MemorySharedIndexStore { fill_locks::try_acquire_fill_lock(self, key, now_unix_ms, lock_age_ms, nonce, state_json) } - fn load_fill_lock(&self, key: &str) -> io::Result> { - fill_locks::load_fill_lock(self, key) - } - fn update_fill_lock( &self, key: &str, @@ -229,52 +274,6 @@ impl SharedIndexBackend for MemorySharedIndexStore { ) -> io::Result<()> { fill_locks::update_fill_lock(self, key, nonce, updated_at_unix_ms, state_json) } - - fn release_fill_lock(&self, key: &str, nonce: &str) -> io::Result<()> { - fill_locks::release_fill_lock(self, key, nonce) - } - - fn fill_lock_status( - &self, - key: &str, - now_unix_ms: u64, - lock_age_ms: u64, - ) -> io::Result { - fill_locks::fill_lock_status(self, key, now_unix_ms, lock_age_ms) - } - - fn clear_stale_fill_lock( - &self, - key: &str, - now_unix_ms: u64, - lock_age_ms: u64, - ) -> io::Result { - fill_locks::clear_stale_fill_lock(self, key, now_unix_ms, lock_age_ms) - } - - fn metrics(&self) -> io::Result { - self.with_document_lock(|segment, document| { - let bytes = encode_document(&document)?; - let header = segment.header(); - let used_payload_bytes = bytes.len().saturating_add(PAYLOAD_LEN_BYTES); - let used_bytes = u64::try_from(used_payload_bytes) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "document is too large"))? - .saturating_add(header.header_len as u64); - Ok(SharedIndexMetrics { - shm_capacity_bytes: header.capacity_bytes, - shm_used_bytes: used_bytes, - entry_count: header.entry_count, - current_size_bytes: header.current_size_bytes, - operation_ring_capacity: header.operation_ring_capacity, - operation_ring_used: document.changes.len() as u64, - lock_contention_total: self.lock_contention_total.load(Ordering::Relaxed), - full_reload_total: self.full_reload_total.load(Ordering::Relaxed), - rebuild_total: document.rebuild_total, - stale_fill_lock_cleanup_total: document.stale_fill_lock_cleanup_total, - capacity_rejection_total: self.capacity_rejection_total.load(Ordering::Relaxed), - }) - }) - } } #[cfg(all(test, target_os = "linux"))] diff --git a/crates/rginx-http/src/cache/shared/index_file/memory_backend/changes.rs b/crates/rginx-http/src/cache/shared/index_file/memory_backend/changes.rs index 6b01a7ee..0fd2b1f1 100644 --- a/crates/rginx-http/src/cache/shared/index_file/memory_backend/changes.rs +++ b/crates/rginx-http/src/cache/shared/index_file/memory_backend/changes.rs @@ -199,7 +199,7 @@ pub(super) fn trim_change_ring( return; } if document.changes.len() > operation_ring_capacity { - let drop_count = document.changes.len() - operation_ring_capacity; + let drop_count = document.changes.len().saturating_sub(operation_ring_capacity); document.changes.drain(0..drop_count); } } diff --git a/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/codec.rs b/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/codec.rs index 07094624..a7ffe5ce 100644 --- a/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/codec.rs +++ b/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/codec.rs @@ -1,8 +1,9 @@ +mod cursor; +mod records; + use std::io; use super::model::SharedMemoryIndexDocument; -mod cursor; -mod records; use self::cursor::{BinaryCursor, write_bytes, write_string, write_u32, write_u64}; use self::records::{ @@ -117,17 +118,17 @@ pub(in crate::cache::shared::index_file::memory_backend) fn decode_document( cursor.finish()?; Ok(SharedMemoryIndexDocument { - version, + admission_counts, + changes, + entries, + fill_locks, generation, - store_epoch, + invalidations, last_change_seq, rebuild_total, stale_fill_lock_cleanup_total, - entries, - admission_counts, - invalidations, - fill_locks, - changes, + store_epoch, + version, }) } diff --git a/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/codec/cursor.rs b/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/codec/cursor.rs index 42378920..b99cf163 100644 --- a/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/codec/cursor.rs +++ b/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/codec/cursor.rs @@ -2,6 +2,87 @@ use std::io; use super::super::super::super::invalid_data_error; +pub(super) struct BinaryCursor<'a> { + bytes: &'a [u8], + offset: usize, +} + +impl<'a> BinaryCursor<'a> { + pub(super) fn finish(&self) -> io::Result<()> { + if self.offset == self.bytes.len() { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::InvalidData, + "shared memory document contained trailing bytes", + )) + } + } + pub(super) fn new(bytes: &'a [u8]) -> Self { + Self { bytes, offset: 0 } + } + + pub(super) fn read_bool(&mut self) -> io::Result { + match self.read_u8()? { + 0 => Ok(false), + 1 => Ok(true), + other => Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("invalid boolean discriminant `{other}`"), + )), + } + } + + pub(super) fn read_bytes(&mut self) -> io::Result> { + let len = usize::try_from(self.read_u32()?) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "byte length exceeds usize"))?; + Ok(self.read_exact(len)?.to_vec()) + } + + pub(super) fn read_exact(&mut self, len: usize) -> io::Result<&'a [u8]> { + let end = self.offset.checked_add(len).ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidData, "binary cursor overflowed") + })?; + let slice = self.bytes.get(self.offset..end).ok_or_else(|| { + io::Error::new(io::ErrorKind::UnexpectedEof, "truncated shared memory document") + })?; + self.offset = end; + Ok(slice) + } + + pub(super) fn read_optional_bytes(&mut self) -> io::Result>> { + self.read_bool()?.then(|| self.read_bytes()).transpose() + } + + pub(super) fn read_optional_u64(&mut self) -> io::Result> { + self.read_bool()?.then(|| self.read_u64()).transpose() + } + + pub(super) fn read_string(&mut self) -> io::Result { + String::from_utf8(self.read_bytes()?).map_err(invalid_data_error) + } + + pub(super) fn read_u32(&mut self) -> io::Result { + Ok(u32::from_le_bytes( + self.read_exact(4)? + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid u32 bytes"))?, + )) + } + + pub(super) fn read_u64(&mut self) -> io::Result { + Ok(u64::from_le_bytes( + self.read_exact(8)? + .try_into() + .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid u64 bytes"))?, + )) + } + + pub(super) fn read_u8(&mut self) -> io::Result { + Ok(*self.read_exact(1)?.first().expect("single-byte slice should exist")) + } +} + pub(super) fn write_u8(bytes: &mut Vec, value: u8) { bytes.push(value); } @@ -39,15 +120,12 @@ pub(super) fn write_bytes(bytes: &mut Vec, value: &[u8]) -> io::Result<()> { } pub(super) fn write_optional_bytes(bytes: &mut Vec, value: Option<&[u8]>) -> io::Result<()> { - match value { - Some(value) => { - write_bool(bytes, true); - write_bytes(bytes, value) - } - None => { - write_bool(bytes, false); - Ok(()) - } + if let Some(value) = value { + write_bool(bytes, true); + write_bytes(bytes, value) + } else { + write_bool(bytes, false); + Ok(()) } } @@ -56,85 +134,3 @@ fn usize_to_u32(value: usize, label: &str) -> io::Result { io::Error::new(io::ErrorKind::InvalidData, format!("{label} exceeds u32 capacity")) }) } - -pub(super) struct BinaryCursor<'a> { - bytes: &'a [u8], - offset: usize, -} - -impl<'a> BinaryCursor<'a> { - pub(super) fn new(bytes: &'a [u8]) -> Self { - Self { bytes, offset: 0 } - } - - pub(super) fn read_u8(&mut self) -> io::Result { - Ok(*self.read_exact(1)?.first().expect("single-byte slice should exist")) - } - - pub(super) fn read_u32(&mut self) -> io::Result { - Ok(u32::from_le_bytes( - self.read_exact(4)? - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid u32 bytes"))?, - )) - } - - pub(super) fn read_u64(&mut self) -> io::Result { - Ok(u64::from_le_bytes( - self.read_exact(8)? - .try_into() - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "invalid u64 bytes"))?, - )) - } - - pub(super) fn read_bool(&mut self) -> io::Result { - match self.read_u8()? { - 0 => Ok(false), - 1 => Ok(true), - other => Err(io::Error::new( - io::ErrorKind::InvalidData, - format!("invalid boolean discriminant `{other}`"), - )), - } - } - - pub(super) fn read_optional_u64(&mut self) -> io::Result> { - self.read_bool()?.then(|| self.read_u64()).transpose() - } - - pub(super) fn read_string(&mut self) -> io::Result { - String::from_utf8(self.read_bytes()?).map_err(invalid_data_error) - } - - pub(super) fn read_bytes(&mut self) -> io::Result> { - let len = usize::try_from(self.read_u32()?) - .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "byte length exceeds usize"))?; - Ok(self.read_exact(len)?.to_vec()) - } - - pub(super) fn read_optional_bytes(&mut self) -> io::Result>> { - self.read_bool()?.then(|| self.read_bytes()).transpose() - } - - pub(super) fn read_exact(&mut self, len: usize) -> io::Result<&'a [u8]> { - let end = self.offset.checked_add(len).ok_or_else(|| { - io::Error::new(io::ErrorKind::InvalidData, "binary cursor overflowed") - })?; - let slice = self.bytes.get(self.offset..end).ok_or_else(|| { - io::Error::new(io::ErrorKind::UnexpectedEof, "truncated shared memory document") - })?; - self.offset = end; - Ok(slice) - } - - pub(super) fn finish(&self) -> io::Result<()> { - if self.offset == self.bytes.len() { - Ok(()) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "shared memory document contained trailing bytes", - )) - } - } -} diff --git a/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/mod.rs b/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/mod.rs index 92a445a9..5517c031 100644 --- a/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/mod.rs +++ b/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/mod.rs @@ -1,3 +1,6 @@ +mod codec; +mod model; + use std::io; use super::super::LoadedSharedIndex; @@ -6,9 +9,6 @@ use super::super::codec::{ }; use crate::cache::CacheIndex; -mod codec; -mod model; - pub(super) use codec::{ PAYLOAD_LEN_BYTES, decode_document, document_current_size_bytes, encode_document, }; diff --git a/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/model.rs b/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/model.rs index 068cdaf2..750b1374 100644 --- a/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/model.rs +++ b/crates/rginx-http/src/cache/shared/index_file/memory_backend/document/model.rs @@ -6,47 +6,47 @@ const SHARED_MEMORY_DOCUMENT_VERSION: u32 = 2; #[derive(Debug, Serialize, Deserialize)] pub(in crate::cache::shared::index_file::memory_backend) struct SharedMemoryIndexDocument { - pub(in crate::cache::shared::index_file::memory_backend) version: u32, + pub(in crate::cache::shared::index_file::memory_backend) admission_counts: + BTreeMap, + pub(in crate::cache::shared::index_file::memory_backend) changes: Vec, + pub(in crate::cache::shared::index_file::memory_backend) entries: BTreeMap>, + #[serde(default)] + pub(in crate::cache::shared::index_file::memory_backend) fill_locks: + BTreeMap, pub(in crate::cache::shared::index_file::memory_backend) generation: u64, - pub(in crate::cache::shared::index_file::memory_backend) store_epoch: u64, + pub(in crate::cache::shared::index_file::memory_backend) invalidations: Vec>, pub(in crate::cache::shared::index_file::memory_backend) last_change_seq: u64, #[serde(default)] pub(in crate::cache::shared::index_file::memory_backend) rebuild_total: u64, #[serde(default)] pub(in crate::cache::shared::index_file::memory_backend) stale_fill_lock_cleanup_total: u64, - pub(in crate::cache::shared::index_file::memory_backend) entries: BTreeMap>, - pub(in crate::cache::shared::index_file::memory_backend) admission_counts: - BTreeMap, - pub(in crate::cache::shared::index_file::memory_backend) invalidations: Vec>, - #[serde(default)] - pub(in crate::cache::shared::index_file::memory_backend) fill_locks: - BTreeMap, - pub(in crate::cache::shared::index_file::memory_backend) changes: Vec, + pub(in crate::cache::shared::index_file::memory_backend) store_epoch: u64, + pub(in crate::cache::shared::index_file::memory_backend) version: u32, } #[derive(Debug, Clone, Serialize, Deserialize)] pub(in crate::cache::shared::index_file::memory_backend) struct SharedMemoryChangeRecord { - pub(in crate::cache::shared::index_file::memory_backend) seq: u64, + pub(in crate::cache::shared::index_file::memory_backend) entry_json: Option>, pub(in crate::cache::shared::index_file::memory_backend) generation: u64, - pub(in crate::cache::shared::index_file::memory_backend) op_kind: u8, pub(in crate::cache::shared::index_file::memory_backend) key: String, - pub(in crate::cache::shared::index_file::memory_backend) entry_json: Option>, - pub(in crate::cache::shared::index_file::memory_backend) uses: Option, #[serde(default)] pub(in crate::cache::shared::index_file::memory_backend) last_access_unix_ms: Option, + pub(in crate::cache::shared::index_file::memory_backend) op_kind: u8, + pub(in crate::cache::shared::index_file::memory_backend) seq: u64, + pub(in crate::cache::shared::index_file::memory_backend) uses: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub(in crate::cache::shared::index_file::memory_backend) struct SharedMemoryFillLockRecord { + pub(in crate::cache::shared::index_file::memory_backend) acquired_at_unix_ms: u64, pub(in crate::cache::shared::index_file::memory_backend) key_hash: String, - pub(in crate::cache::shared::index_file::memory_backend) owner_pid: u32, - pub(in crate::cache::shared::index_file::memory_backend) owner_generation: u64, pub(in crate::cache::shared::index_file::memory_backend) nonce: String, - pub(in crate::cache::shared::index_file::memory_backend) acquired_at_unix_ms: u64, - pub(in crate::cache::shared::index_file::memory_backend) updated_at_unix_ms: u64, + pub(in crate::cache::shared::index_file::memory_backend) owner_generation: u64, + pub(in crate::cache::shared::index_file::memory_backend) owner_pid: u32, #[serde(default)] pub(in crate::cache::shared::index_file::memory_backend) released: bool, pub(in crate::cache::shared::index_file::memory_backend) state_json: Vec, + pub(in crate::cache::shared::index_file::memory_backend) updated_at_unix_ms: u64, } impl SharedMemoryIndexDocument { diff --git a/crates/rginx-http/src/cache/shared/index_file/memory_backend/locks.rs b/crates/rginx-http/src/cache/shared/index_file/memory_backend/locks.rs index b28a7565..afb800f9 100644 --- a/crates/rginx-http/src/cache/shared/index_file/memory_backend/locks.rs +++ b/crates/rginx-http/src/cache/shared/index_file/memory_backend/locks.rs @@ -8,6 +8,12 @@ pub(super) struct FileLock { pub(super) file: File, } +impl Drop for FileLock { + fn drop(&mut self) { + let _ = unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) }; + } +} + pub(super) fn lock(lock_path: &PathBuf, lock_contention_total: &AtomicU64) -> io::Result { if let Some(parent) = lock_path.parent() { std::fs::create_dir_all(parent)?; @@ -31,9 +37,3 @@ pub(super) fn lock(lock_path: &PathBuf, lock_contention_total: &AtomicU64) -> io } Err(io::Error::last_os_error()) } - -impl Drop for FileLock { - fn drop(&mut self) { - let _ = unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) }; - } -} diff --git a/crates/rginx-http/src/cache/shared/index_file/mod.rs b/crates/rginx-http/src/cache/shared/index_file/mod.rs index 92219668..1eb7afcb 100644 --- a/crates/rginx-http/src/cache/shared/index_file/mod.rs +++ b/crates/rginx-http/src/cache/shared/index_file/mod.rs @@ -1,14 +1,7 @@ -use std::io; -use std::path::PathBuf; - -use super::super::CacheIndex; -use super::SharedIndexOperation; - mod codec; #[cfg(target_os = "linux")] mod memory_backend; mod store; -pub(super) const SHARED_INDEX_SCHEMA_VERSION: u8 = 2; #[cfg(not(target_os = "linux"))] mod memory_backend { use std::io; @@ -32,18 +25,17 @@ mod memory_backend { } impl SharedIndexBackend for MemorySharedIndexStore { - fn path(&self) -> &Path { - &self.path - } - - fn load(&self) -> io::Result { + fn apply_operations( + &self, + _operations: &[SharedIndexOperation], + ) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, "shared memory index is only supported on linux", )) } - fn sync_state(&self) -> io::Result { + fn load(&self) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, "shared memory index is only supported on linux", @@ -57,20 +49,20 @@ mod memory_backend { )) } - fn recreate( - &self, - _index: &CacheIndex, - _generation: u64, - ) -> io::Result { + fn metrics(&self) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, "shared memory index is only supported on linux", )) } + fn path(&self) -> &Path { + &self.path + } - fn apply_operations( + fn recreate( &self, - _operations: &[SharedIndexOperation], + _index: &CacheIndex, + _generation: u64, ) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, @@ -82,7 +74,7 @@ mod memory_backend { false } - fn metrics(&self) -> io::Result { + fn sync_state(&self) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, "shared memory index is only supported on linux", @@ -94,6 +86,12 @@ mod memory_backend { mod legacy; mod paths; +use std::io; +use std::path::PathBuf; + +use super::super::CacheIndex; +use super::SharedIndexOperation; + use store::SharedIndexBackend; pub(in crate::cache::shared) use store::{ AppliedSharedIndexOperations, LoadedSharedIndex, LoadedSharedIndexChanges, @@ -102,6 +100,8 @@ pub(in crate::cache::shared) use store::{ pub(in crate::cache) use store::{SharedFillLockAcquire, SharedFillLockStatus}; pub(in crate::cache) use store::{SharedIndexMetrics, SharedIndexStore}; +pub(super) const SHARED_INDEX_SCHEMA_VERSION: u8 = 2; + pub(in crate::cache) fn shared_fill_lock_path(zone: &rginx_core::CacheZone, key: &str) -> PathBuf { paths::shared_fill_lock_path(zone, key) } diff --git a/crates/rginx-http/src/cache/shared/index_file/store.rs b/crates/rginx-http/src/cache/shared/index_file/store.rs index f74f6fc4..b75c6d2f 100644 --- a/crates/rginx-http/src/cache/shared/index_file/store.rs +++ b/crates/rginx-http/src/cache/shared/index_file/store.rs @@ -5,23 +5,23 @@ use super::super::super::shared::index_file::memory_backend; use super::SharedIndexOperation; pub(in crate::cache::shared) struct LoadedSharedIndex { - pub(in crate::cache::shared) index: CacheIndex, pub(in crate::cache::shared) generation: u64, - pub(in crate::cache::shared) store_epoch: u64, + pub(in crate::cache::shared) index: CacheIndex, pub(in crate::cache::shared) last_change_seq: u64, + pub(in crate::cache::shared) store_epoch: u64, } pub(in crate::cache::shared) struct LoadedSharedIndexChanges { pub(in crate::cache::shared) generation: u64, - pub(in crate::cache::shared) store_epoch: u64, pub(in crate::cache::shared) last_change_seq: u64, pub(in crate::cache::shared) operations: Vec, + pub(in crate::cache::shared) store_epoch: u64, } pub(in crate::cache::shared) struct AppliedSharedIndexOperations { pub(in crate::cache::shared) generation: u64, - pub(in crate::cache::shared) store_epoch: u64, pub(in crate::cache::shared) last_change_seq: u64, + pub(in crate::cache::shared) store_epoch: u64, } pub(in crate::cache::shared) struct SharedIndexSyncState { @@ -31,17 +31,17 @@ pub(in crate::cache::shared) struct SharedIndexSyncState { #[derive(Debug, Clone, Default)] pub(in crate::cache) struct SharedIndexMetrics { - pub(in crate::cache) shm_capacity_bytes: u64, - pub(in crate::cache) shm_used_bytes: u64, - pub(in crate::cache) entry_count: u64, + pub(in crate::cache) capacity_rejection_total: u64, pub(in crate::cache) current_size_bytes: u64, + pub(in crate::cache) entry_count: u64, + pub(in crate::cache) full_reload_total: u64, + pub(in crate::cache) lock_contention_total: u64, pub(in crate::cache) operation_ring_capacity: u64, pub(in crate::cache) operation_ring_used: u64, - pub(in crate::cache) lock_contention_total: u64, - pub(in crate::cache) full_reload_total: u64, pub(in crate::cache) rebuild_total: u64, + pub(in crate::cache) shm_capacity_bytes: u64, + pub(in crate::cache) shm_used_bytes: u64, pub(in crate::cache) stale_fill_lock_cleanup_total: u64, - pub(in crate::cache) capacity_rejection_total: u64, } pub(in crate::cache) struct SharedFillLockSnapshot { @@ -55,51 +55,44 @@ pub(in crate::cache) enum SharedFillLockAcquire { } pub(in crate::cache) enum SharedFillLockStatus { - Missing, Fresh, + Missing, Stale, } pub(super) trait SharedIndexBackend: Send + Sync { - fn path(&self) -> &std::path::Path; - fn load(&self) -> io::Result; - fn sync_state(&self) -> io::Result; - fn load_changes_since(&self, after_seq: u64) -> io::Result; - fn recreate( - &self, - index: &CacheIndex, - generation: u64, - ) -> io::Result; fn apply_operations( &self, operations: &[SharedIndexOperation], ) -> io::Result; - fn supports_shared_fill_locks(&self) -> bool { - false - } - - fn metrics(&self) -> io::Result { + fn clear_stale_fill_lock( + &self, + _key: &str, + _now_unix_ms: u64, + _lock_age_ms: u64, + ) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, - "shared index metrics are not supported by this backend", + "shared fill locks are not supported by this backend", )) } - fn try_acquire_fill_lock( + fn fill_lock_status( &self, _key: &str, _now_unix_ms: u64, _lock_age_ms: u64, - _nonce: &str, - _state_json: &[u8], - ) -> io::Result { + ) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, "shared fill locks are not supported by this backend", )) } + fn load(&self) -> io::Result; + fn load_changes_since(&self, after_seq: u64) -> io::Result; + fn load_fill_lock(&self, _key: &str) -> io::Result> { Err(io::Error::new( io::ErrorKind::Unsupported, @@ -107,19 +100,20 @@ pub(super) trait SharedIndexBackend: Send + Sync { )) } - fn update_fill_lock( - &self, - _key: &str, - _nonce: &str, - _updated_at_unix_ms: u64, - _state_json: &[u8], - ) -> io::Result<()> { + fn metrics(&self) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, - "shared fill locks are not supported by this backend", + "shared index metrics are not supported by this backend", )) } + fn path(&self) -> &std::path::Path; + fn recreate( + &self, + index: &CacheIndex, + generation: u64, + ) -> io::Result; + fn release_fill_lock(&self, _key: &str, _nonce: &str) -> io::Result<()> { Err(io::Error::new( io::ErrorKind::Unsupported, @@ -127,24 +121,33 @@ pub(super) trait SharedIndexBackend: Send + Sync { )) } - fn fill_lock_status( + fn supports_shared_fill_locks(&self) -> bool { + false + } + + fn sync_state(&self) -> io::Result; + + fn try_acquire_fill_lock( &self, _key: &str, _now_unix_ms: u64, _lock_age_ms: u64, - ) -> io::Result { + _nonce: &str, + _state_json: &[u8], + ) -> io::Result { Err(io::Error::new( io::ErrorKind::Unsupported, "shared fill locks are not supported by this backend", )) } - fn clear_stale_fill_lock( + fn update_fill_lock( &self, _key: &str, - _now_unix_ms: u64, - _lock_age_ms: u64, - ) -> io::Result { + _nonce: &str, + _updated_at_unix_ms: u64, + _state_json: &[u8], + ) -> io::Result<()> { Err(io::Error::new( io::ErrorKind::Unsupported, "shared fill locks are not supported by this backend", @@ -157,20 +160,37 @@ pub(in crate::cache) struct SharedIndexStore { } impl SharedIndexStore { - pub(in crate::cache) fn for_zone(zone: &rginx_core::CacheZone) -> Self { - Self { backend: Box::new(memory_backend::MemorySharedIndexStore::new(zone)) } + pub(super) fn apply_operations( + &self, + operations: &[SharedIndexOperation], + ) -> io::Result { + self.backend.apply_operations(operations) } - pub(in crate::cache) fn path(&self) -> &std::path::Path { - self.backend.path() + pub(in crate::cache) fn clear_stale_fill_lock( + &self, + key: &str, + now_unix_ms: u64, + lock_age_ms: u64, + ) -> io::Result { + self.backend.clear_stale_fill_lock(key, now_unix_ms, lock_age_ms) } - pub(super) fn load(&self) -> io::Result { - self.backend.load() + pub(in crate::cache) fn fill_lock_status( + &self, + key: &str, + now_unix_ms: u64, + lock_age_ms: u64, + ) -> io::Result { + self.backend.fill_lock_status(key, now_unix_ms, lock_age_ms) } - pub(super) fn sync_state(&self) -> io::Result { - self.backend.sync_state() + pub(in crate::cache) fn for_zone(zone: &rginx_core::CacheZone) -> Self { + Self { backend: Box::new(memory_backend::MemorySharedIndexStore::new(zone)) } + } + + pub(super) fn load(&self) -> io::Result { + self.backend.load() } pub(super) fn load_changes_since( @@ -180,6 +200,21 @@ impl SharedIndexStore { self.backend.load_changes_since(after_seq) } + pub(in crate::cache) fn load_fill_lock( + &self, + key: &str, + ) -> io::Result> { + self.backend.load_fill_lock(key) + } + + pub(in crate::cache) fn metrics(&self) -> io::Result { + self.backend.metrics() + } + + pub(in crate::cache) fn path(&self) -> &std::path::Path { + self.backend.path() + } + pub(super) fn recreate( &self, index: &CacheIndex, @@ -188,19 +223,16 @@ impl SharedIndexStore { self.backend.recreate(index, generation) } - pub(super) fn apply_operations( - &self, - operations: &[SharedIndexOperation], - ) -> io::Result { - self.backend.apply_operations(operations) + pub(in crate::cache) fn release_fill_lock(&self, key: &str, nonce: &str) -> io::Result<()> { + self.backend.release_fill_lock(key, nonce) } pub(in crate::cache) fn supports_shared_fill_locks(&self) -> bool { self.backend.supports_shared_fill_locks() } - pub(in crate::cache) fn metrics(&self) -> io::Result { - self.backend.metrics() + pub(super) fn sync_state(&self) -> io::Result { + self.backend.sync_state() } pub(in crate::cache) fn try_acquire_fill_lock( @@ -214,13 +246,6 @@ impl SharedIndexStore { self.backend.try_acquire_fill_lock(key, now_unix_ms, lock_age_ms, nonce, state_json) } - pub(in crate::cache) fn load_fill_lock( - &self, - key: &str, - ) -> io::Result> { - self.backend.load_fill_lock(key) - } - pub(in crate::cache) fn update_fill_lock( &self, key: &str, @@ -230,26 +255,4 @@ impl SharedIndexStore { ) -> io::Result<()> { self.backend.update_fill_lock(key, nonce, updated_at_unix_ms, state_json) } - - pub(in crate::cache) fn release_fill_lock(&self, key: &str, nonce: &str) -> io::Result<()> { - self.backend.release_fill_lock(key, nonce) - } - - pub(in crate::cache) fn fill_lock_status( - &self, - key: &str, - now_unix_ms: u64, - lock_age_ms: u64, - ) -> io::Result { - self.backend.fill_lock_status(key, now_unix_ms, lock_age_ms) - } - - pub(in crate::cache) fn clear_stale_fill_lock( - &self, - key: &str, - now_unix_ms: u64, - lock_age_ms: u64, - ) -> io::Result { - self.backend.clear_stale_fill_lock(key, now_unix_ms, lock_age_ms) - } } diff --git a/crates/rginx-http/src/cache/shared/memory.rs b/crates/rginx-http/src/cache/shared/memory.rs index 7c89045a..2a5dfae7 100644 --- a/crates/rginx-http/src/cache/shared/memory.rs +++ b/crates/rginx-http/src/cache/shared/memory.rs @@ -1,4 +1,14 @@ -#![allow(dead_code)] +#![expect( + dead_code, + reason = "shared memory primitives are reused across platform and test targets" +)] + +mod config; +mod sys; + +#[cfg(test)] +#[path = "memory/tests.rs"] +mod tests; use std::fs; use std::io; @@ -6,9 +16,6 @@ use std::mem::size_of; use std::os::fd::RawFd; use std::ptr::NonNull; -mod config; -mod sys; - pub(crate) use self::config::{SharedMemorySegmentBacking, SharedMemorySegmentConfig}; use self::sys::{close_fd, fd_size, ftruncate, invalid_header, open_file, shm_open}; @@ -39,34 +46,15 @@ pub(crate) struct SharedMemoryHeader { #[derive(Debug)] pub(crate) struct SharedMemorySegment { + capacity_bytes: usize, fd: RawFd, ptr: NonNull, - capacity_bytes: usize, } unsafe impl Send for SharedMemorySegment {} unsafe impl Sync for SharedMemorySegment {} impl SharedMemorySegment { - pub(crate) fn create_or_reset(config: &SharedMemorySegmentConfig) -> io::Result { - config.validate_capacity()?; - let fd = Self::open_for_create(config)?; - if let Err(error) = ftruncate(fd, config.capacity_bytes) { - let _ = close_fd(fd); - return Err(error); - } - let segment = match Self::map(fd, config.capacity_bytes) { - Ok(segment) => segment, - Err(error) => { - let _ = close_fd(fd); - return Err(error); - } - }; - segment.zero(); - segment.write_header(config.initial_header()?); - Ok(segment) - } - pub(crate) fn attach(config: &SharedMemorySegmentConfig) -> io::Result { config.validate_capacity()?; let fd = Self::open_for_attach(config)?; @@ -101,52 +89,61 @@ impl SharedMemorySegment { Ok(segment) } - pub(crate) fn unlink(config: &SharedMemorySegmentConfig) -> io::Result<()> { - match &config.backing { - SharedMemorySegmentBacking::NamedShm { .. } => { - let name = config.shm_name()?; - let result = unsafe { libc::shm_unlink(name.as_ptr()) }; - if result == 0 { - return Ok(()); - } - let error = io::Error::last_os_error(); - if error.kind() == io::ErrorKind::NotFound { Ok(()) } else { Err(error) } - } - SharedMemorySegmentBacking::File { path } => match fs::remove_file(path) { - Ok(()) => Ok(()), - Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(()), - Err(error) => Err(error), - }, + pub(crate) fn create_or_reset(config: &SharedMemorySegmentConfig) -> io::Result { + config.validate_capacity()?; + let fd = Self::open_for_create(config)?; + if let Err(error) = ftruncate(fd, config.capacity_bytes) { + let _ = close_fd(fd); + return Err(error); } + let segment = match Self::map(fd, config.capacity_bytes) { + Ok(segment) => segment, + Err(error) => { + let _ = close_fd(fd); + return Err(error); + } + }; + segment.zero(); + segment.write_header(config.initial_header()?); + Ok(segment) } pub(crate) fn header(&self) -> SharedMemoryHeader { unsafe { self.header_ptr().read() } } - pub(crate) fn payload_capacity(&self) -> usize { - self.capacity_bytes.saturating_sub(size_of::()) + fn header_ptr(&self) -> *mut SharedMemoryHeader { + self.ptr.as_ptr().cast::() } - pub(crate) fn write_payload(&self, offset: usize, bytes: &[u8]) -> io::Result<()> { - self.validate_payload_range(offset, bytes.len())?; - unsafe { - std::ptr::copy_nonoverlapping( - bytes.as_ptr(), - self.payload_ptr().add(offset), - bytes.len(), - ); + fn map(fd: RawFd, capacity_bytes: usize) -> io::Result { + let mapped = unsafe { + libc::mmap( + std::ptr::null_mut(), + capacity_bytes, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED, + fd, + 0, + ) + }; + if mapped == libc::MAP_FAILED { + return Err(io::Error::last_os_error()); } - Ok(()) + let ptr = NonNull::new(mapped.cast::()).ok_or_else(|| { + io::Error::other("mmap returned a null pointer for shared memory segment") + })?; + Ok(Self { capacity_bytes, fd, ptr }) } - pub(crate) fn read_payload(&self, offset: usize, len: usize) -> io::Result> { - self.validate_payload_range(offset, len)?; - let mut bytes = vec![0; len]; - unsafe { - std::ptr::copy_nonoverlapping(self.payload_ptr().add(offset), bytes.as_mut_ptr(), len); + fn open_for_attach(config: &SharedMemorySegmentConfig) -> io::Result { + match &config.backing { + SharedMemorySegmentBacking::NamedShm { .. } => { + let name = config.shm_name()?; + shm_open(&name, libc::O_RDWR, 0o600) + } + SharedMemorySegmentBacking::File { path } => open_file(path, false), } - Ok(bytes) } fn open_for_create(config: &SharedMemorySegmentConfig) -> io::Result { @@ -164,45 +161,39 @@ impl SharedMemorySegment { } } - fn open_for_attach(config: &SharedMemorySegmentConfig) -> io::Result { - match &config.backing { - SharedMemorySegmentBacking::NamedShm { .. } => { - let name = config.shm_name()?; - shm_open(&name, libc::O_RDWR, 0o600) - } - SharedMemorySegmentBacking::File { path } => open_file(path, false), - } + pub(crate) fn payload_capacity(&self) -> usize { + self.capacity_bytes.saturating_sub(size_of::()) } - fn map(fd: RawFd, capacity_bytes: usize) -> io::Result { - let mapped = unsafe { - libc::mmap( - std::ptr::null_mut(), - capacity_bytes, - libc::PROT_READ | libc::PROT_WRITE, - libc::MAP_SHARED, - fd, - 0, - ) - }; - if mapped == libc::MAP_FAILED { - return Err(io::Error::last_os_error()); - } - let ptr = NonNull::new(mapped.cast::()).ok_or_else(|| { - io::Error::other("mmap returned a null pointer for shared memory segment") - })?; - Ok(Self { fd, ptr, capacity_bytes }) + fn payload_ptr(&self) -> *mut u8 { + unsafe { self.ptr.as_ptr().add(size_of::()) } } - fn zero(&self) { + pub(crate) fn read_payload(&self, offset: usize, len: usize) -> io::Result> { + self.validate_payload_range(offset, len)?; + let mut bytes = vec![0; len]; unsafe { - std::ptr::write_bytes(self.ptr.as_ptr(), 0, self.capacity_bytes); + std::ptr::copy_nonoverlapping(self.payload_ptr().add(offset), bytes.as_mut_ptr(), len); } + Ok(bytes) } - pub(crate) fn write_header(&self, header: SharedMemoryHeader) { - unsafe { - self.header_ptr().write(header); + pub(crate) fn unlink(config: &SharedMemorySegmentConfig) -> io::Result<()> { + match &config.backing { + SharedMemorySegmentBacking::NamedShm { .. } => { + let name = config.shm_name()?; + let result = unsafe { libc::shm_unlink(name.as_ptr()) }; + if result == 0 { + return Ok(()); + } + let error = io::Error::last_os_error(); + if error.kind() == io::ErrorKind::NotFound { Ok(()) } else { Err(error) } + } + SharedMemorySegmentBacking::File { path } => match fs::remove_file(path) { + Ok(()) => Ok(()), + Err(error) if error.kind() == io::ErrorKind::NotFound => Ok(()), + Err(error) => Err(error), + }, } } @@ -278,12 +269,28 @@ impl SharedMemorySegment { Ok(()) } - fn header_ptr(&self) -> *mut SharedMemoryHeader { - self.ptr.as_ptr().cast::() + pub(crate) fn write_header(&self, header: SharedMemoryHeader) { + unsafe { + self.header_ptr().write(header); + } } - fn payload_ptr(&self) -> *mut u8 { - unsafe { self.ptr.as_ptr().add(size_of::()) } + pub(crate) fn write_payload(&self, offset: usize, bytes: &[u8]) -> io::Result<()> { + self.validate_payload_range(offset, bytes.len())?; + unsafe { + std::ptr::copy_nonoverlapping( + bytes.as_ptr(), + self.payload_ptr().add(offset), + bytes.len(), + ); + } + Ok(()) + } + + fn zero(&self) { + unsafe { + std::ptr::write_bytes(self.ptr.as_ptr(), 0, self.capacity_bytes); + } } } @@ -293,7 +300,3 @@ impl Drop for SharedMemorySegment { let _ = close_fd(self.fd); } } - -#[cfg(test)] -#[path = "memory/tests.rs"] -mod tests; diff --git a/crates/rginx-http/src/cache/shared/memory/config.rs b/crates/rginx-http/src/cache/shared/memory/config.rs index 9668b811..47018b2b 100644 --- a/crates/rginx-http/src/cache/shared/memory/config.rs +++ b/crates/rginx-http/src/cache/shared/memory/config.rs @@ -4,33 +4,37 @@ use std::mem::size_of; use std::path::PathBuf; use std::time::{SystemTime, UNIX_EPOCH}; -use super::*; +use super::{ + DEFAULT_HASH_BUCKET_COUNT, DEFAULT_OPERATION_RING_CAPACITY, SHM_ABI_VERSION, SHM_MAGIC, + SHM_NAME_PREFIX, SharedMemoryHeader, +}; #[derive(Clone, Debug)] pub(crate) enum SharedMemorySegmentBacking { - NamedShm { name: String }, File { path: PathBuf }, + NamedShm { name: String }, } #[derive(Clone, Debug)] pub(crate) struct SharedMemorySegmentConfig { pub(super) backing: SharedMemorySegmentBacking, - pub(super) zone_name_hash: u64, pub(super) capacity_bytes: usize, + pub(super) flags: u64, pub(super) hash_bucket_count: u64, pub(super) operation_ring_capacity: u64, - pub(super) flags: u64, + pub(super) zone_name_hash: u64, } impl SharedMemorySegmentConfig { - pub(super) fn for_zone(zone_name: &str, capacity_bytes: usize) -> Self { - let zone_name_hash = stable_hash(zone_name.as_bytes()); - Self::new(format!("{SHM_NAME_PREFIX}{zone_name_hash:016x}"), zone_name_hash, capacity_bytes) - } - - pub(crate) fn for_identity(identity: &str, capacity_bytes: usize) -> Self { - let zone_name_hash = stable_hash(identity.as_bytes()); - Self::new(format!("{SHM_NAME_PREFIX}{zone_name_hash:016x}"), zone_name_hash, capacity_bytes) + fn file_backed(path: PathBuf, zone_name_hash: u64, capacity_bytes: usize) -> Self { + Self { + backing: SharedMemorySegmentBacking::File { path }, + zone_name_hash, + capacity_bytes, + hash_bucket_count: DEFAULT_HASH_BUCKET_COUNT, + operation_ring_capacity: DEFAULT_OPERATION_RING_CAPACITY, + flags: 0, + } } pub(crate) fn for_file_identity( @@ -42,32 +46,45 @@ impl SharedMemorySegmentConfig { Self::new_file(path, zone_name_hash, capacity_bytes) } - pub(crate) fn new(name: impl Into, zone_name_hash: u64, capacity_bytes: usize) -> Self { - Self::named_shm(name.into(), zone_name_hash, capacity_bytes) + pub(crate) fn for_identity(identity: &str, capacity_bytes: usize) -> Self { + let zone_name_hash = stable_hash(identity.as_bytes()); + Self::new(format!("{SHM_NAME_PREFIX}{zone_name_hash:016x}"), zone_name_hash, capacity_bytes) } - pub(crate) fn new_file( - path: impl Into, - zone_name_hash: u64, - capacity_bytes: usize, - ) -> Self { - Self::file_backed(path.into(), zone_name_hash, capacity_bytes) + pub(super) fn for_zone(zone_name: &str, capacity_bytes: usize) -> Self { + let zone_name_hash = stable_hash(zone_name.as_bytes()); + Self::new(format!("{SHM_NAME_PREFIX}{zone_name_hash:016x}"), zone_name_hash, capacity_bytes) } - fn named_shm(name: String, zone_name_hash: u64, capacity_bytes: usize) -> Self { - Self { - backing: SharedMemorySegmentBacking::NamedShm { name }, - zone_name_hash, - capacity_bytes, - hash_bucket_count: DEFAULT_HASH_BUCKET_COUNT, - operation_ring_capacity: DEFAULT_OPERATION_RING_CAPACITY, - flags: 0, - } + pub(super) fn initial_header(&self) -> io::Result { + Ok(SharedMemoryHeader { + magic: SHM_MAGIC, + abi_version: SHM_ABI_VERSION, + header_len: u32::try_from(size_of::()).map_err(|_| { + io::Error::new(io::ErrorKind::InvalidInput, "shared memory header is too large") + })?, + zone_name_hash: self.zone_name_hash, + capacity_bytes: u64::try_from(self.capacity_bytes).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "shared memory capacity does not fit u64", + ) + })?, + generation: 1, + store_epoch: current_store_epoch(), + operation_seq: 0, + entry_count: 0, + current_size_bytes: 0, + hash_bucket_count: self.hash_bucket_count, + operation_ring_capacity: self.operation_ring_capacity, + allocator_free_head: 0, + flags: self.flags, + }) } - fn file_backed(path: PathBuf, zone_name_hash: u64, capacity_bytes: usize) -> Self { + fn named_shm(name: String, zone_name_hash: u64, capacity_bytes: usize) -> Self { Self { - backing: SharedMemorySegmentBacking::File { path }, + backing: SharedMemorySegmentBacking::NamedShm { name }, zone_name_hash, capacity_bytes, hash_bucket_count: DEFAULT_HASH_BUCKET_COUNT, @@ -76,19 +93,16 @@ impl SharedMemorySegmentConfig { } } - pub(crate) fn with_hash_bucket_count(mut self, hash_bucket_count: u64) -> Self { - self.hash_bucket_count = hash_bucket_count; - self - } - - pub(crate) fn with_operation_ring_capacity(mut self, operation_ring_capacity: u64) -> Self { - self.operation_ring_capacity = operation_ring_capacity; - self + pub(crate) fn new(name: impl Into, zone_name_hash: u64, capacity_bytes: usize) -> Self { + Self::named_shm(name.into(), zone_name_hash, capacity_bytes) } - pub(crate) fn with_capacity_bytes(mut self, capacity_bytes: usize) -> Self { - self.capacity_bytes = capacity_bytes; - self + pub(crate) fn new_file( + path: impl Into, + zone_name_hash: u64, + capacity_bytes: usize, + ) -> Self { + Self::file_backed(path.into(), zone_name_hash, capacity_bytes) } pub(super) fn shm_name(&self) -> io::Result { @@ -128,30 +142,19 @@ impl SharedMemorySegmentConfig { Ok(()) } - pub(super) fn initial_header(&self) -> io::Result { - Ok(SharedMemoryHeader { - magic: SHM_MAGIC, - abi_version: SHM_ABI_VERSION, - header_len: u32::try_from(size_of::()).map_err(|_| { - io::Error::new(io::ErrorKind::InvalidInput, "shared memory header is too large") - })?, - zone_name_hash: self.zone_name_hash, - capacity_bytes: u64::try_from(self.capacity_bytes).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "shared memory capacity does not fit u64", - ) - })?, - generation: 1, - store_epoch: current_store_epoch(), - operation_seq: 0, - entry_count: 0, - current_size_bytes: 0, - hash_bucket_count: self.hash_bucket_count, - operation_ring_capacity: self.operation_ring_capacity, - allocator_free_head: 0, - flags: self.flags, - }) + pub(crate) fn with_capacity_bytes(mut self, capacity_bytes: usize) -> Self { + self.capacity_bytes = capacity_bytes; + self + } + + pub(crate) fn with_hash_bucket_count(mut self, hash_bucket_count: u64) -> Self { + self.hash_bucket_count = hash_bucket_count; + self + } + + pub(crate) fn with_operation_ring_capacity(mut self, operation_ring_capacity: u64) -> Self { + self.operation_ring_capacity = operation_ring_capacity; + self } } diff --git a/crates/rginx-http/src/cache/state.rs b/crates/rginx-http/src/cache/state.rs index c8685fa4..e72756ff 100644 --- a/crates/rginx-http/src/cache/state.rs +++ b/crates/rginx-http/src/cache/state.rs @@ -1,34 +1,39 @@ -use super::*; +use super::{ + Arc, AsyncMutex, AtomicU64, BTreeSet, CacheChangeNotifier, CacheFillReadState, CacheIoLockPool, + CacheMetadata, CacheStatus, CacheZone, Deserialize, ExternalCacheFillReadState, HashMap, + HeaderMap, HeaderValue, Mutex, Notify, OwnedNotified, PathBuf, RwLock, Serialize, + SharedFillExternalStateHandle, SharedIndexStore, StatusCode, Weak, +}; pub(super) struct CacheZoneRuntime { + pub(super) change_notifier: Option, pub(super) config: Arc, - pub(super) index: RwLock, + pub(super) fill_lock_generation: AtomicU64, + pub(super) fill_locks: Arc>>, pub(super) hot_entries: RwLock>>, + pub(super) index: RwLock, pub(super) io_locks: CacheIoLockPool, - pub(super) shared_index_sync_lock: AsyncMutex<()>, - pub(super) shared_index_store: Option>, - pub(super) fill_locks: Arc>>, - pub(super) fill_lock_generation: AtomicU64, pub(super) last_inactive_cleanup_unix_ms: AtomicU64, + pub(super) shared_index_change_seq: AtomicU64, pub(super) shared_index_generation: AtomicU64, + pub(super) shared_index_store: Option>, pub(super) shared_index_store_epoch: AtomicU64, - pub(super) shared_index_change_seq: AtomicU64, + pub(super) shared_index_sync_lock: AsyncMutex<()>, pub(super) stats: CacheZoneStats, - pub(super) change_notifier: Option, } #[derive(Default, Clone)] pub(super) struct CacheIndex { + pub(super) access_schedule: BTreeSet, + pub(super) access_ticket_by_key: HashMap, + pub(super) admission_counts: HashMap, + pub(super) current_size_bytes: usize, pub(super) entries: HashMap, pub(super) hash_ref_counts: HashMap, - pub(super) variants: HashMap>, - pub(super) admission_counts: HashMap, - pub(super) invalidations: Vec, pub(super) invalidation_index: CacheInvalidationIndex, - pub(super) current_size_bytes: usize, + pub(super) invalidations: Vec, pub(super) maintenance_next_ticket: u64, - pub(super) access_schedule: BTreeSet, - pub(super) access_ticket_by_key: HashMap, + pub(super) variants: HashMap>, } #[derive(Default, Clone)] @@ -39,11 +44,26 @@ pub(super) struct CacheInvalidationIndex { pub(super) latest_tag_created_at_unix_ms: HashMap, } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct CacheAccessScheduleEntry { + pub(super) key: String, pub(super) last_access_unix_ms: u64, pub(super) ticket: u64, - pub(super) key: String, +} + +impl Ord for CacheAccessScheduleEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.last_access_unix_ms + .cmp(&other.last_access_unix_ms) + .then_with(|| self.ticket.cmp(&other.ticket)) + .then_with(|| self.key.cmp(&other.key)) + } +} + +impl PartialOrd for CacheAccessScheduleEntry { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -54,29 +74,29 @@ pub(super) struct CacheAccessScheduleTicket { #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct CacheIndexEntry { - pub(super) kind: CacheIndexEntryKind, - pub(super) hash: String, pub(super) base_key: String, - pub(super) stored_at_unix_ms: u64, - pub(super) vary: Vec, - pub(super) tags: Vec, pub(super) body_size_bytes: usize, pub(super) expires_at_unix_ms: u64, pub(super) grace_until_unix_ms: Option, + pub(super) hash: String, pub(super) keep_until_unix_ms: Option, + pub(super) kind: CacheIndexEntryKind, + pub(super) last_access_unix_ms: u64, + pub(super) must_revalidate: bool, + pub(super) requires_revalidation: bool, pub(super) stale_if_error_until_unix_ms: Option, pub(super) stale_while_revalidate_until_unix_ms: Option, - pub(super) requires_revalidation: bool, - pub(super) must_revalidate: bool, - pub(super) last_access_unix_ms: u64, + pub(super) stored_at_unix_ms: u64, + pub(super) tags: Vec, + pub(super) vary: Vec, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] #[serde(rename_all = "snake_case")] pub(super) enum CacheIndexEntryKind { + HitForPass, #[default] Response, - HitForPass, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -90,41 +110,41 @@ pub(super) enum CacheInvalidationSelector { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub(super) struct CacheInvalidationRule { - pub(super) selector: CacheInvalidationSelector, pub(super) created_at_unix_ms: u64, + pub(super) selector: CacheInvalidationSelector, } #[derive(Default)] pub(super) struct CacheZoneStats { - pub(super) hit_total: AtomicU64, - pub(super) miss_total: AtomicU64, pub(super) bypass_total: AtomicU64, - pub(super) expired_total: AtomicU64, - pub(super) stale_total: AtomicU64, - pub(super) updating_total: AtomicU64, - pub(super) revalidated_total: AtomicU64, - pub(super) write_success_total: AtomicU64, - pub(super) write_error_total: AtomicU64, pub(super) eviction_total: AtomicU64, - pub(super) purge_total: AtomicU64, - pub(super) invalidation_total: AtomicU64, + pub(super) expired_total: AtomicU64, + pub(super) hit_total: AtomicU64, pub(super) inactive_cleanup_total: AtomicU64, - pub(super) wait_local_total: AtomicU64, - pub(super) wait_external_total: AtomicU64, - pub(super) stale_serve_updating_total: AtomicU64, + pub(super) invalidation_total: AtomicU64, + pub(super) miss_total: AtomicU64, + pub(super) purge_total: AtomicU64, + pub(super) revalidate_not_modified_total: AtomicU64, + pub(super) revalidated_total: AtomicU64, pub(super) stale_serve_error_total: AtomicU64, - pub(super) stale_serve_timeout_total: AtomicU64, pub(super) stale_serve_status_total: AtomicU64, - pub(super) revalidate_not_modified_total: AtomicU64, + pub(super) stale_serve_timeout_total: AtomicU64, + pub(super) stale_serve_updating_total: AtomicU64, + pub(super) stale_total: AtomicU64, + pub(super) updating_total: AtomicU64, + pub(super) wait_external_total: AtomicU64, + pub(super) wait_local_total: AtomicU64, + pub(super) write_error_total: AtomicU64, + pub(super) write_success_total: AtomicU64, } pub(super) struct CacheFillGuard { - pub(super) key: String, - pub(super) generation: u64, - pub(super) fill_locks: Weak>>, - pub(super) notify: Arc, pub(super) external_lock_path: Option, pub(super) external_state: Option, + pub(super) fill_locks: Weak>>, + pub(super) generation: u64, + pub(super) key: String, + pub(super) notify: Arc, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -135,41 +155,46 @@ pub(super) struct CachedVaryHeaderValue { #[derive(Clone)] pub(super) struct CacheFillLockState { - pub(super) notify: Arc, pub(super) acquired_at_unix_ms: u64, pub(super) generation: u64, - pub(super) share_fingerprint: String, + pub(super) notify: Arc, pub(super) reader_state: Option>, + pub(super) share_fingerprint: String, } #[derive(Clone)] pub(super) struct CacheConditionalHeaders { - pub(super) if_none_match: Option, pub(super) if_modified_since: Option, + pub(super) if_none_match: Option, } pub(super) struct CacheEntryHotState { pub(super) last_access_unix_ms: AtomicU64, - pub(super) shared_touch_published_unix_ms: AtomicU64, pub(super) response_head: Mutex>>, + pub(super) shared_touch_published_unix_ms: AtomicU64, } pub(super) struct PreparedCacheResponseHead { + pub(super) body_size_bytes: usize, + pub(super) conditional_headers: Option, pub(super) hash: String, + pub(super) headers: HeaderMap, pub(super) metadata: Arc, pub(super) status: StatusCode, - pub(super) headers: HeaderMap, - pub(super) conditional_headers: Option, - pub(super) body_size_bytes: usize, } #[derive(Clone)] pub(super) enum InflightFillReadState { - Local(Arc), External(ExternalCacheFillReadState), + Local(Arc), } pub(super) enum LookupDecision { + BackgroundUpdate { + key: String, + cached_entry: CacheIndexEntry, + fill_guard: CacheFillGuard, + }, Bypass { status: CacheStatus, }, @@ -181,16 +206,6 @@ pub(super) enum LookupDecision { key: String, entry: CacheIndexEntry, }, - Stale { - key: String, - entry: CacheIndexEntry, - status: CacheStatus, - }, - BackgroundUpdate { - key: String, - cached_entry: CacheIndexEntry, - fill_guard: CacheFillGuard, - }, Miss { key: String, base_key: String, @@ -201,26 +216,31 @@ pub(super) enum LookupDecision { ReadWhileFill { state: InflightFillReadState, }, + Stale { + key: String, + entry: CacheIndexEntry, + status: CacheStatus, + }, Wait { strategy: LookupWait, }, } pub(super) enum LookupWait { - Local { waiter: OwnedNotified }, External { key: String }, + Local { waiter: OwnedNotified }, } pub(super) enum FillLockDecision { Acquired(CacheFillGuard), Read { state: InflightFillReadState }, - WaitLocal { waiter: OwnedNotified }, WaitExternal { key: String }, + WaitLocal { waiter: OwnedNotified }, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum CacheStaleReason { Error, - Timeout, Status(StatusCode), + Timeout, } diff --git a/crates/rginx-http/src/cache/store.rs b/crates/rginx-http/src/cache/store.rs index a603154f..38ecdc25 100644 --- a/crates/rginx-http/src/cache/store.rs +++ b/crates/rginx-http/src/cache/store.rs @@ -1,3 +1,9 @@ +mod helpers; +mod maintenance; +pub(in crate::cache) mod range; +mod revalidate; +mod streaming; + use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -19,12 +25,6 @@ use super::{ with_cache_status, }; -mod helpers; -mod maintenance; -pub(in crate::cache) mod range; -mod revalidate; -mod streaming; - use helpers::{ cache_final_key_for_response, cache_metadata_input, freshness_is_cacheable, merge_not_modified_headers, should_remember_hit_for_pass, diff --git a/crates/rginx-http/src/cache/store/helpers.rs b/crates/rginx-http/src/cache/store/helpers.rs index 0d1b65dd..2fcfd01c 100644 --- a/crates/rginx-http/src/cache/store/helpers.rs +++ b/crates/rginx-http/src/cache/store/helpers.rs @@ -7,7 +7,7 @@ use super::super::invalidation::normalize_cache_tag; use super::super::policy::{ResponseFreshness, vary_headers}; use super::super::vary::normalized_request_header_values; use super::super::{CachedVaryHeaderValue, PurgeSelector}; -use super::*; +use super::{CacheMetadataInput, cache_variant_key, duration_to_ms}; pub(super) fn cache_metadata_input( base_key: &str, diff --git a/crates/rginx-http/src/cache/store/maintenance/index_state.rs b/crates/rginx-http/src/cache/store/maintenance/index_state.rs index 042e5d40..c1f1347e 100644 --- a/crates/rginx-http/src/cache/store/maintenance/index_state.rs +++ b/crates/rginx-http/src/cache/store/maintenance/index_state.rs @@ -4,7 +4,7 @@ use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; use super::super::super::shared::{ SharedIndexOperation, apply_zone_shared_index_operations_locked, }; -use super::super::*; +use super::super::{Arc, CacheIndex, CacheIndexEntry, CacheZoneRuntime}; use crate::cache::vary::sorted_vary_dimension_names; pub(in crate::cache) struct CacheAdmissionDecision { @@ -12,8 +12,8 @@ pub(in crate::cache) struct CacheAdmissionDecision { } pub(in crate::cache) struct RemovedIndexEntry { - pub(in crate::cache) hash: String, pub(in crate::cache) delete_files: bool, + pub(in crate::cache) hash: String, } pub(in crate::cache) async fn record_cache_admission_attempt( @@ -172,11 +172,11 @@ pub(in crate::cache) fn inactive_cleanup_candidates( } pub(in crate::cache) fn read_index(index: &RwLock) -> RwLockReadGuard<'_, CacheIndex> { - index.read().unwrap_or_else(|poisoned| poisoned.into_inner()) + index.read().unwrap_or_else(std::sync::PoisonError::into_inner) } pub(in crate::cache) fn lock_index(index: &RwLock) -> RwLockWriteGuard<'_, CacheIndex> { - index.write().unwrap_or_else(|poisoned| poisoned.into_inner()) + index.write().unwrap_or_else(std::sync::PoisonError::into_inner) } pub(super) fn add_variant_key( diff --git a/crates/rginx-http/src/cache/store/maintenance/mod.rs b/crates/rginx-http/src/cache/store/maintenance/mod.rs index cadaef37..151efbe2 100644 --- a/crates/rginx-http/src/cache/store/maintenance/mod.rs +++ b/crates/rginx-http/src/cache/store/maintenance/mod.rs @@ -1,16 +1,19 @@ +mod index_state; +mod store_update; + use super::super::invalidation::{ entry_is_logically_invalid, invalidation_rule_matches_entry, invalidation_scope, }; use super::super::remove_cache_files_if_unreferenced; use super::super::shared::{SharedIndexOperation, apply_zone_shared_index_operations_locked}; -use super::*; +use super::{ + Arc, CacheIndexEntry, CachePurgeResult, CacheZoneRuntime, SystemTime, duration_to_ms, + purge_scope, purge_selector_matches, unix_time_ms, +}; use crate::cache::{ CacheInvalidationResult, CacheInvalidationRule, CacheInvalidationSelector, PurgeSelector, }; -mod index_state; -mod store_update; - use index_state::{add_variant_key, remove_variant_key, variant_keys_with_different_dimensions}; pub(in crate::cache) use index_state::{ eviction_candidates, inactive_cleanup_candidates, lock_index, read_index, @@ -56,7 +59,7 @@ pub(in crate::cache) async fn cleanup_inactive_entries_in_zone(zone: &Arc>(); let mut removed = Vec::with_capacity(matching_keys.len()); - let mut shared_operations = Vec::with_capacity(matching_keys.len() * 2); + let mut shared_operations = Vec::with_capacity(matching_keys.len().saturating_mul(2)); for key in matching_keys { if let Some(entry) = index.remove_entry(&key) { index.current_size_bytes = diff --git a/crates/rginx-http/src/cache/store/maintenance/store_update.rs b/crates/rginx-http/src/cache/store/maintenance/store_update.rs index bb950547..1e4c2e22 100644 --- a/crates/rginx-http/src/cache/store/maintenance/store_update.rs +++ b/crates/rginx-http/src/cache/store/maintenance/store_update.rs @@ -1,4 +1,8 @@ -use super::*; +use super::{ + Arc, CacheIndexEntry, CacheZoneRuntime, SharedIndexOperation, add_variant_key, + apply_zone_shared_index_operations_locked, eviction_candidates, lock_index, remove_variant_key, + variant_keys_with_different_dimensions, +}; pub(in crate::cache) async fn update_index_after_store( zone: &Arc, @@ -77,7 +81,7 @@ pub(in crate::cache) async fn update_index_after_store( removed_hashes.insert(evicted_entry.hash); } removed_keys.push(evicted_key.clone()); - eviction_count += 1; + eviction_count = eviction_count.saturating_add(1); shared_operations.push(SharedIndexOperation::RemoveEntry { key: evicted_key.clone() }); shared_operations.push(SharedIndexOperation::RemoveAdmissionCount { key: evicted_key }); } diff --git a/crates/rginx-http/src/cache/store/range.rs b/crates/rginx-http/src/cache/store/range.rs index 546bc7b3..281b8720 100644 --- a/crates/rginx-http/src/cache/store/range.rs +++ b/crates/rginx-http/src/cache/store/range.rs @@ -4,7 +4,7 @@ use hyper::body::{Body, Frame, SizeHint}; use pin_project_lite::pin_project; use super::super::entry::DownstreamRangeTrimPlan; -use super::*; +use super::HttpResponse; use crate::handler::{BoxError, boxed_body}; pin_project! { @@ -31,6 +31,10 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -50,7 +54,7 @@ where std::task::Poll::Ready(Some(Ok(frame))) => match frame.into_data() { Ok(mut data) => { if *this.skip_bytes >= data.len() { - *this.skip_bytes -= data.len(); + *this.skip_bytes = (*this.skip_bytes).saturating_sub(data.len()); continue; } if *this.skip_bytes > 0 { @@ -63,7 +67,7 @@ where if data.len() > *this.emit_bytes { data = data.slice(..*this.emit_bytes); } - *this.emit_bytes -= data.len(); + *this.emit_bytes = (*this.emit_bytes).saturating_sub(data.len()); if *this.emit_bytes == 0 { *this.done = true; } @@ -99,10 +103,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> SizeHint { let mut hint = SizeHint::default(); hint.set_exact(self.emit_bytes as u64); diff --git a/crates/rginx-http/src/cache/store/revalidate.rs b/crates/rginx-http/src/cache/store/revalidate.rs index c1a6cf9a..f40d287b 100644 --- a/crates/rginx-http/src/cache/store/revalidate.rs +++ b/crates/rginx-http/src/cache/store/revalidate.rs @@ -1,4 +1,13 @@ -use super::*; +use super::{ + Arc, CacheIndexEntry, CacheStatus, CacheStoreContext, CacheStoreError, HttpResponse, + ResponseBodySize, SystemTime, build_cached_response_for_request, cache_final_key_for_response, + cache_metadata, cache_metadata_input, cache_paths_for_zone, freshness_is_cacheable, + merge_not_modified_headers, prepare_cached_response_head, remember_hit_for_pass, + remove_cache_files_if_unreferenced, remove_cache_files_locked, + remove_zone_index_entry_if_matches, response_freshness, response_is_storable_with_size, + response_no_cache, should_remember_hit_for_pass, unix_time_ms, update_index_after_store, + with_cache_status, write_cache_metadata, +}; pub(in crate::cache) async fn refresh_not_modified_response( context: CacheStoreContext, diff --git a/crates/rginx-http/src/cache/store/streaming.rs b/crates/rginx-http/src/cache/store/streaming.rs index 869a39d8..2ab3119c 100644 --- a/crates/rginx-http/src/cache/store/streaming.rs +++ b/crates/rginx-http/src/cache/store/streaming.rs @@ -1,3 +1,6 @@ +mod body; +mod finalize; + use std::path::PathBuf; use std::time::Duration; @@ -8,65 +11,58 @@ use hyper::body::Body as _; use tokio::fs::{self, File}; use tokio::sync::mpsc; -mod body; -mod finalize; - use super::super::entry::DownstreamRangeTrimPlan; use super::super::fill::{CacheFillReadState, inflight_fill_body}; use super::range::build_downstream_response; -use super::*; +use super::{ + Arc, CacheIndexEntry, CacheStoreContext, CacheZoneRuntime, HttpBody, HttpResponse, SystemTime, + cache_entry_temp_body_path, cache_final_key_for_response, cache_key_hash, cache_metadata, + cache_metadata_input, cache_paths_for_zone, commit_cache_entry_temp_body, duration_to_ms, + freshness_is_cacheable, prepare_cached_response_head, record_cache_admission_attempt, + remember_hit_for_pass, remove_cache_files_if_unreferenced, response_freshness, + should_remember_hit_for_pass, unix_time_ms, update_index_after_store, +}; use body::{StreamingCacheBody, spawn_streaming_origin_fill}; use finalize::spawn_streaming_cache_writer; const STREAMING_CACHE_WRITE_QUEUE_DEPTH: usize = 8; struct StreamingCachePlan { - zone: Arc, + _fill_guard: Option, base_key: String, + body_tmp: PathBuf, + expected_body_bytes: Option, + fill_state: Option>, final_key: String, - vary: Vec, - tags: Vec, - status: StatusCode, - headers: HeaderMap, freshness: super::super::policy::ResponseFreshness, - now: u64, - hash: String, - paths: super::super::entry::CachePaths, - body_tmp: PathBuf, - max_entry_bytes: usize, grace: Option, + hash: String, + headers: HeaderMap, keep: Option, + max_entry_bytes: usize, + now: u64, pass_ttl: Option, - expected_body_bytes: Option, - revalidating: bool, + paths: super::super::entry::CachePaths, replaced_entry: Option<(String, CacheIndexEntry)>, - _fill_guard: Option, - fill_state: Option>, + revalidating: bool, + status: StatusCode, + tags: Vec, + vary: Vec, + zone: Arc, } pub(super) enum StreamingCacheWriteMessage { + Abort, Data(Bytes), Finish { trailers: Option }, - Abort, } pub(super) struct StreamingCacheWriter { - tx: mpsc::Sender, fill_state: Option>, + tx: mpsc::Sender, } impl StreamingCacheWriter { - fn new( - tx: mpsc::Sender, - fill_state: Option>, - ) -> Self { - Self { tx, fill_state } - } - - fn try_send_data(&self, bytes: Bytes) -> bool { - self.tx.try_send(StreamingCacheWriteMessage::Data(bytes)).is_ok() - } - fn abort(&self, reason: &str) { if let Some(fill_state) = self.fill_state.as_ref() { fill_state.fail(reason); @@ -74,25 +70,35 @@ impl StreamingCacheWriter { let _ = self.tx.try_send(StreamingCacheWriteMessage::Abort); } - fn try_finish(&self, trailers: Option) -> bool { - let sent = self.tx.try_send(StreamingCacheWriteMessage::Finish { trailers }).is_ok(); + async fn finish(self, trailers: Option) -> bool { + let sent = self.tx.send(StreamingCacheWriteMessage::Finish { trailers }).await.is_ok(); if sent && let Some(fill_state) = self.fill_state.as_ref() { fill_state.mark_upstream_complete(); } sent } + fn new( + tx: mpsc::Sender, + fill_state: Option>, + ) -> Self { + Self { fill_state, tx } + } async fn send_data(&self, bytes: Bytes) -> bool { self.tx.send(StreamingCacheWriteMessage::Data(bytes)).await.is_ok() } - async fn finish(self, trailers: Option) -> bool { - let sent = self.tx.send(StreamingCacheWriteMessage::Finish { trailers }).await.is_ok(); + fn try_finish(&self, trailers: Option) -> bool { + let sent = self.tx.try_send(StreamingCacheWriteMessage::Finish { trailers }).is_ok(); if sent && let Some(fill_state) = self.fill_state.as_ref() { fill_state.mark_upstream_complete(); } sent } + + fn try_send_data(&self, bytes: Bytes) -> bool { + self.tx.try_send(StreamingCacheWriteMessage::Data(bytes)).is_ok() + } } pub(super) async fn store_streaming_response( @@ -137,8 +143,7 @@ pub(super) async fn store_streaming_response( .cached_entry .as_ref() .filter(|_| context.key == final_key) - .map(|entry| entry.hash.clone()) - .unwrap_or_else(|| cache_key_hash(&final_key)); + .map_or_else(|| cache_key_hash(&final_key), |entry| entry.hash.clone()); let paths = cache_paths_for_zone(context.zone.config.as_ref(), &hash); let max_entry_bytes = context.zone.config.max_entry_bytes; let replaced_entry = context diff --git a/crates/rginx-http/src/cache/store/streaming/body.rs b/crates/rginx-http/src/cache/store/streaming/body.rs index 932dc816..5aeec29d 100644 --- a/crates/rginx-http/src/cache/store/streaming/body.rs +++ b/crates/rginx-http/src/cache/store/streaming/body.rs @@ -5,106 +5,30 @@ use bytes::Bytes; use http_body_util::BodyExt; use hyper::body::{Frame, SizeHint}; -pub(super) fn spawn_streaming_origin_fill( - mut inner: HttpBody, - writer: StreamingCacheWriter, - fill_state: Option>, -) { - let Some(handle) = tokio::runtime::Handle::try_current().ok() else { - if let Some(fill_state) = fill_state.as_ref() { - fill_state.fail("streaming cache fill requires an active Tokio runtime"); - } - return; - }; - - handle.spawn(async move { - drive_streaming_origin_fill(&mut inner, writer, fill_state).await; - }); -} - -async fn drive_streaming_origin_fill( - inner: &mut HttpBody, - writer: StreamingCacheWriter, - fill_state: Option>, -) { - let mut writer = Some(writer); - - while let Some(frame) = inner.frame().await { - let Ok(frame) = frame else { - if let Some(fill_state) = fill_state.as_ref() { - fill_state.fail("upstream body read failed while filling cache"); - } - return; - }; - - let stream_completed = frame.is_trailers() || inner.is_end_stream(); - let trailers = frame.trailers_ref().cloned(); - - if let Some(data) = frame.data_ref() { - if data.is_empty() { - if !stream_completed { - continue; - } - } else { - let Some(cache_writer) = writer.as_ref() else { - return; - }; - if !cache_writer.send_data(data.clone()).await { - if let Some(fill_state) = fill_state.as_ref() { - fill_state.fail("streaming cache writer channel closed before EOF"); - } - return; - } - } - } - - if stream_completed { - let Some(cache_writer) = writer.take() else { - return; - }; - if !cache_writer.finish(trailers).await - && let Some(fill_state) = fill_state.as_ref() - { - fill_state.fail("streaming cache writer channel closed before end-of-stream"); - } - return; - } - } - - let Some(cache_writer) = writer.take() else { - return; - }; - if !cache_writer.finish(None).await - && let Some(fill_state) = fill_state.as_ref() - { - fill_state.fail("streaming cache writer channel closed before end-of-stream"); - } -} - pub(super) struct StreamingCacheBody { - inner: HttpBody, - size_hint: SizeHint, cache: Option, cached_body_bytes: usize, - max_entry_bytes: usize, done: bool, + inner: HttpBody, + max_entry_bytes: usize, + size_hint: SizeHint, } impl StreamingCacheBody { - pub(super) fn new( - inner: HttpBody, - size_hint: SizeHint, - cache: StreamingCacheWriter, - max_entry_bytes: usize, - ) -> Self { - Self { - inner, - size_hint, - cache: Some(cache), - cached_body_bytes: 0, - max_entry_bytes, - done: false, + fn cache_frame_data(&mut self, data: &Bytes) { + if data.is_empty() { + return; } + + let Some(cache) = self.cache.as_ref() else { + return; + }; + let next_body_size = self.cached_body_bytes.saturating_add(data.len()); + if next_body_size > self.max_entry_bytes || !cache.try_send_data(data.clone()) { + self.disable_cache(); + return; + } + self.cached_body_bytes = next_body_size; } fn disable_cache(&mut self) { @@ -128,20 +52,20 @@ impl StreamingCacheBody { } } - fn cache_frame_data(&mut self, data: &Bytes) { - if data.is_empty() { - return; - } - - let Some(cache) = self.cache.as_ref() else { - return; - }; - let next_body_size = self.cached_body_bytes.saturating_add(data.len()); - if next_body_size > self.max_entry_bytes || !cache.try_send_data(data.clone()) { - self.disable_cache(); - return; + pub(super) fn new( + inner: HttpBody, + size_hint: SizeHint, + cache: StreamingCacheWriter, + max_entry_bytes: usize, + ) -> Self { + Self { + inner, + size_hint, + cache: Some(cache), + cached_body_bytes: 0, + max_entry_bytes, + done: false, } - self.cached_body_bytes = next_body_size; } } @@ -168,6 +92,10 @@ impl hyper::body::Body for StreamingCacheBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -204,15 +132,87 @@ impl hyper::body::Body for StreamingCacheBody { } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> SizeHint { self.size_hint.clone() } } +pub(super) fn spawn_streaming_origin_fill( + mut inner: HttpBody, + writer: StreamingCacheWriter, + fill_state: Option>, +) { + let Some(handle) = tokio::runtime::Handle::try_current().ok() else { + if let Some(fill_state) = fill_state.as_ref() { + fill_state.fail("streaming cache fill requires an active Tokio runtime"); + } + return; + }; + + handle.spawn(async move { + drive_streaming_origin_fill(&mut inner, writer, fill_state).await; + }); +} + +async fn drive_streaming_origin_fill( + inner: &mut HttpBody, + writer: StreamingCacheWriter, + fill_state: Option>, +) { + let mut writer = Some(writer); + + while let Some(frame) = inner.frame().await { + let Ok(frame) = frame else { + if let Some(fill_state) = fill_state.as_ref() { + fill_state.fail("upstream body read failed while filling cache"); + } + return; + }; + + let stream_completed = frame.is_trailers() || inner.is_end_stream(); + let trailers = frame.trailers_ref().cloned(); + + if let Some(data) = frame.data_ref() { + if data.is_empty() { + if !stream_completed { + continue; + } + } else { + let Some(cache_writer) = writer.as_ref() else { + return; + }; + if !cache_writer.send_data(data.clone()).await { + if let Some(fill_state) = fill_state.as_ref() { + fill_state.fail("streaming cache writer channel closed before EOF"); + } + return; + } + } + } + + if stream_completed { + let Some(cache_writer) = writer.take() else { + return; + }; + if !cache_writer.finish(trailers).await + && let Some(fill_state) = fill_state.as_ref() + { + fill_state.fail("streaming cache writer channel closed before end-of-stream"); + } + return; + } + } + + let Some(cache_writer) = writer.take() else { + return; + }; + if !cache_writer.finish(None).await + && let Some(fill_state) = fill_state.as_ref() + { + fill_state.fail("streaming cache writer channel closed before end-of-stream"); + } +} + async fn drain_remaining_frames( mut inner: HttpBody, cache: StreamingCacheWriter, diff --git a/crates/rginx-http/src/cache/store/streaming/finalize.rs b/crates/rginx-http/src/cache/store/streaming/finalize.rs index ef20c8e2..9d2858ac 100644 --- a/crates/rginx-http/src/cache/store/streaming/finalize.rs +++ b/crates/rginx-http/src/cache/store/streaming/finalize.rs @@ -1,4 +1,9 @@ -use super::*; +use super::{ + Arc, CacheIndexEntry, StreamingCachePlan, StreamingCacheWriteMessage, StreamingCacheWriter, + cache_key_hash, cache_metadata, cache_metadata_input, commit_cache_entry_temp_body, + duration_to_ms, prepare_cached_response_head, remove_cache_files_if_unreferenced, + update_index_after_store, +}; use tokio::fs::{self, File}; use tokio::io::AsyncWriteExt; use tokio::sync::mpsc; diff --git a/crates/rginx-http/src/cache/tests/lookup.rs b/crates/rginx-http/src/cache/tests/lookup.rs index e541b0e2..ea6acf8a 100644 --- a/crates/rginx-http/src/cache/tests/lookup.rs +++ b/crates/rginx-http/src/cache/tests/lookup.rs @@ -1,3 +1,9 @@ +mod background; +mod keys; +mod keys_head; +mod keys_slice; +mod recovery; +mod truncated_body; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -9,10 +15,3 @@ use tokio::sync::Notify; use crate::handler::full_body; use super::*; - -mod background; -mod keys; -mod keys_head; -mod keys_slice; -mod recovery; -mod truncated_body; diff --git a/crates/rginx-http/src/cache/tests/lookup/keys.rs b/crates/rginx-http/src/cache/tests/lookup/keys.rs index 98628df8..0f391f66 100644 --- a/crates/rginx-http/src/cache/tests/lookup/keys.rs +++ b/crates/rginx-http/src/cache/tests/lookup/keys.rs @@ -1,7 +1,7 @@ -use super::*; - mod range_requests; +use super::*; + #[test] fn cache_key_template_renders_request_parts() { let template = diff --git a/crates/rginx-http/src/cache/tests/mod.rs b/crates/rginx-http/src/cache/tests/mod.rs index 9157c98b..49f20544 100644 --- a/crates/rginx-http/src/cache/tests/mod.rs +++ b/crates/rginx-http/src/cache/tests/mod.rs @@ -1,15 +1,3 @@ -use std::collections::HashMap; -use std::path::PathBuf; -use std::sync::atomic::AtomicU64; -use std::sync::{Arc, Mutex, RwLock}; -use std::time::Duration; - -use http::{Method, StatusCode}; -use rginx_core::{CacheZone, RouteCachePolicy}; - -use super::entry::CacheMetadataInput; -use super::*; - mod lookup; mod notifications; mod policy; @@ -22,6 +10,18 @@ mod storage_p5; mod storage_regressions; mod stress; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::atomic::AtomicU64; +use std::sync::{Arc, Mutex, RwLock}; +use std::time::Duration; + +use http::{Method, StatusCode}; +use rginx_core::{CacheZone, RouteCachePolicy}; + +use super::entry::CacheMetadataInput; +use super::*; + fn test_zone(path: PathBuf, max_entry_bytes: usize) -> Arc { test_zone_with_notifier(path, max_entry_bytes, None) } diff --git a/crates/rginx-http/src/cache/tests/storage_p1.rs b/crates/rginx-http/src/cache/tests/storage_p1.rs index 376c203a..e02d2c80 100644 --- a/crates/rginx-http/src/cache/tests/storage_p1.rs +++ b/crates/rginx-http/src/cache/tests/storage_p1.rs @@ -1,3 +1,5 @@ +mod edge_cases; + use std::collections::HashMap; use std::sync::atomic::AtomicU64; use std::sync::{Arc, Mutex, RwLock}; @@ -13,8 +15,6 @@ use crate::handler::full_body; use super::*; -mod edge_cases; - #[tokio::test] async fn cache_manager_requires_min_uses_before_storing() { let temp = tempfile::tempdir().expect("cache temp dir should exist"); diff --git a/crates/rginx-http/src/cache/tests/storage_p2.rs b/crates/rginx-http/src/cache/tests/storage_p2.rs index 8b7f4480..19cb9a6e 100644 --- a/crates/rginx-http/src/cache/tests/storage_p2.rs +++ b/crates/rginx-http/src/cache/tests/storage_p2.rs @@ -1,4 +1,3 @@ -use super::*; - mod cross_process_fill; mod shared_index; +use super::*; diff --git a/crates/rginx-http/src/cache/tests/storage_p2/cross_process_fill.rs b/crates/rginx-http/src/cache/tests/storage_p2/cross_process_fill.rs index e8b070fc..0a5c6332 100644 --- a/crates/rginx-http/src/cache/tests/storage_p2/cross_process_fill.rs +++ b/crates/rginx-http/src/cache/tests/storage_p2/cross_process_fill.rs @@ -1,3 +1,7 @@ +mod coordination; +mod head_convert; +mod streaming; + use std::time::SystemTime; use bytes::Bytes; @@ -11,9 +15,6 @@ use tokio::time::timeout; use crate::handler::{BoxError, boxed_body, full_body}; use super::*; -mod coordination; -mod head_convert; -mod streaming; #[cfg(target_os = "linux")] fn shared_memory_test_config(path: std::path::PathBuf) -> Arc { diff --git a/crates/rginx-http/src/cache/tests/storage_p2/shared_index.rs b/crates/rginx-http/src/cache/tests/storage_p2/shared_index.rs index 4a6f3b4c..74e8a009 100644 --- a/crates/rginx-http/src/cache/tests/storage_p2/shared_index.rs +++ b/crates/rginx-http/src/cache/tests/storage_p2/shared_index.rs @@ -1,3 +1,8 @@ +mod manager_sync; +mod shared_memory_basic; +mod shared_memory_recovery; +mod sync_regressions; + use std::sync::atomic::Ordering; use std::time::{Duration, SystemTime}; @@ -11,11 +16,6 @@ use crate::handler::full_body; use super::*; -mod manager_sync; -mod shared_memory_basic; -mod shared_memory_recovery; -mod sync_regressions; - fn test_manager_with_max_size(path: std::path::PathBuf, max_size_bytes: usize) -> CacheManager { let config = Arc::new(CacheZone { name: "default".to_string(), diff --git a/crates/rginx-http/src/cache/tests/storage_p2/shared_index/shared_memory.rs b/crates/rginx-http/src/cache/tests/storage_p2/shared_index/shared_memory.rs index 52ffbb8a..f75968a4 100644 --- a/crates/rginx-http/src/cache/tests/storage_p2/shared_index/shared_memory.rs +++ b/crates/rginx-http/src/cache/tests/storage_p2/shared_index/shared_memory.rs @@ -1,4 +1,4 @@ -use super::*; - mod shared_memory_basic; mod shared_memory_recovery; + +use super::*; diff --git a/crates/rginx-http/src/cache/tests/storage_p2/shared_index/sync_regressions.rs b/crates/rginx-http/src/cache/tests/storage_p2/shared_index/sync_regressions.rs index 0c4f90c2..1e7866fd 100644 --- a/crates/rginx-http/src/cache/tests/storage_p2/shared_index/sync_regressions.rs +++ b/crates/rginx-http/src/cache/tests/storage_p2/shared_index/sync_regressions.rs @@ -1,9 +1,10 @@ -use super::*; #[path = "bootstrap_recovery.rs"] mod bootstrap_recovery; #[path = "delta_eviction.rs"] mod delta_eviction; +use super::*; + fn shared_index_test_zone(path: &std::path::Path) -> CacheZone { CacheZone { name: "default".to_string(), diff --git a/crates/rginx-http/src/cache/tests/storage_p4.rs b/crates/rginx-http/src/cache/tests/storage_p4.rs index 27d4856f..3fa5b03d 100644 --- a/crates/rginx-http/src/cache/tests/storage_p4.rs +++ b/crates/rginx-http/src/cache/tests/storage_p4.rs @@ -1,6 +1,5 @@ -use super::*; - mod lifecycle; mod ranges; mod streaming; mod termination; +use super::*; diff --git a/crates/rginx-http/src/cache/tests/storage_p4/lifecycle.rs b/crates/rginx-http/src/cache/tests/storage_p4/lifecycle.rs index 5d1d7152..5578cc24 100644 --- a/crates/rginx-http/src/cache/tests/storage_p4/lifecycle.rs +++ b/crates/rginx-http/src/cache/tests/storage_p4/lifecycle.rs @@ -1,3 +1,6 @@ +mod grace_keep; +mod pass_markers; +mod stale_reason; use std::time::{Duration, SystemTime}; use bytes::Bytes; @@ -12,7 +15,3 @@ use crate::cache::runtime::{ use crate::handler::full_body; use super::*; - -mod grace_keep; -mod pass_markers; -mod stale_reason; diff --git a/crates/rginx-http/src/cache/tests/storage_p4/streaming.rs b/crates/rginx-http/src/cache/tests/storage_p4/streaming.rs index 4c2d0c61..552cdb77 100644 --- a/crates/rginx-http/src/cache/tests/storage_p4/streaming.rs +++ b/crates/rginx-http/src/cache/tests/storage_p4/streaming.rs @@ -1,3 +1,5 @@ +mod committed_hits; + use std::time::Duration; use bytes::Bytes; @@ -12,8 +14,6 @@ use crate::handler::{BoxError, boxed_body, full_body}; use super::*; -mod committed_hits; - #[tokio::test] async fn cache_manager_caches_unknown_size_response_after_stream_completion() { let temp = tempfile::tempdir().expect("cache temp dir should exist"); diff --git a/crates/rginx-http/src/cache/tests/storage_p4/termination.rs b/crates/rginx-http/src/cache/tests/storage_p4/termination.rs index 184d17ce..df1664ba 100644 --- a/crates/rginx-http/src/cache/tests/storage_p4/termination.rs +++ b/crates/rginx-http/src/cache/tests/storage_p4/termination.rs @@ -11,6 +11,88 @@ use crate::handler::{boxed_body, full_body}; use super::*; +#[derive(Default)] +struct LateEndStreamBody { + state: Cell, +} + +impl hyper::body::Body for LateEndStreamBody { + type Data = Bytes; + type Error = crate::handler::BoxError; + + fn is_end_stream(&self) -> bool { + match self.state.get() { + 1 => { + self.state.set(2); + false + } + 2 | 3 => true, + _ => false, + } + } + + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match self.state.get() { + 0 => { + self.state.set(1); + Poll::Ready(Some(Ok(Frame::data(Bytes::from_static(b"late"))))) + } + 1 => Poll::Pending, + 2 | 3 => { + self.state.set(3); + Poll::Ready(None) + } + state => panic!("unexpected poll state: {state}"), + } + } +} + +#[derive(Default)] +struct ExactSizeTrailersBody { + state: Cell, +} + +impl hyper::body::Body for ExactSizeTrailersBody { + type Data = Bytes; + type Error = crate::handler::BoxError; + + fn is_end_stream(&self) -> bool { + matches!(self.state.get(), 2 | 3) + } + + fn poll_frame( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + match self.state.get() { + 0 => { + self.state.set(1); + Poll::Ready(Some(Ok(Frame::data(Bytes::from_static(b"late"))))) + } + 1 => { + self.state.set(2); + let mut trailers = http::HeaderMap::new(); + trailers.insert("x-end", http::HeaderValue::from_static("done")); + Poll::Ready(Some(Ok(Frame::trailers(trailers)))) + } + 2 | 3 => { + self.state.set(3); + Poll::Ready(None) + } + state => panic!("unexpected poll state: {state}"), + } + } + + fn size_hint(&self) -> hyper::body::SizeHint { + let mut hint = hyper::body::SizeHint::default(); + hint.set_exact(4); + hint + } +} + #[tokio::test] async fn cache_manager_ends_downstream_stream_when_upstream_ends() { let temp = tempfile::tempdir().expect("cache temp dir should exist"); @@ -94,85 +176,3 @@ async fn cache_manager_preserves_trailers_after_exact_size_hint() { let body = response.into_body().collect().await.unwrap().to_bytes(); assert_eq!(body.as_ref(), b"late"); } - -#[derive(Default)] -struct LateEndStreamBody { - state: Cell, -} - -impl hyper::body::Body for LateEndStreamBody { - type Data = Bytes; - type Error = crate::handler::BoxError; - - fn poll_frame( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - match self.state.get() { - 0 => { - self.state.set(1); - Poll::Ready(Some(Ok(Frame::data(Bytes::from_static(b"late"))))) - } - 1 => Poll::Pending, - 2 | 3 => { - self.state.set(3); - Poll::Ready(None) - } - state => panic!("unexpected poll state: {state}"), - } - } - - fn is_end_stream(&self) -> bool { - match self.state.get() { - 1 => { - self.state.set(2); - false - } - 2 | 3 => true, - _ => false, - } - } -} - -#[derive(Default)] -struct ExactSizeTrailersBody { - state: Cell, -} - -impl hyper::body::Body for ExactSizeTrailersBody { - type Data = Bytes; - type Error = crate::handler::BoxError; - - fn poll_frame( - self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll, Self::Error>>> { - match self.state.get() { - 0 => { - self.state.set(1); - Poll::Ready(Some(Ok(Frame::data(Bytes::from_static(b"late"))))) - } - 1 => { - self.state.set(2); - let mut trailers = http::HeaderMap::new(); - trailers.insert("x-end", http::HeaderValue::from_static("done")); - Poll::Ready(Some(Ok(Frame::trailers(trailers)))) - } - 2 | 3 => { - self.state.set(3); - Poll::Ready(None) - } - state => panic!("unexpected poll state: {state}"), - } - } - - fn is_end_stream(&self) -> bool { - matches!(self.state.get(), 2 | 3) - } - - fn size_hint(&self) -> hyper::body::SizeHint { - let mut hint = hyper::body::SizeHint::default(); - hint.set_exact(4); - hint - } -} diff --git a/crates/rginx-http/src/client_ip.rs b/crates/rginx-http/src/client_ip.rs index ee659475..08c5f2d2 100644 --- a/crates/rginx-http/src/client_ip.rs +++ b/crates/rginx-http/src/client_ip.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::net::{IpAddr, SocketAddr}; use http::HeaderMap; @@ -5,9 +7,9 @@ use rginx_core::Server; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ClientIpSource { + ClientIpHeader, SocketPeer, XForwardedFor, - ClientIpHeader, } impl ClientIpSource { @@ -22,30 +24,30 @@ impl ClientIpSource { #[derive(Debug, Clone, PartialEq, Eq)] pub struct ClientAddress { - pub peer_addr: SocketAddr, pub client_ip: IpAddr, pub forwarded_for: String, + pub peer_addr: SocketAddr, pub source: ClientIpSource, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct TlsClientIdentity { - pub subject: Option, - pub issuer: Option, - pub serial_number: Option, - pub san_dns_names: Vec, pub chain_length: usize, pub chain_subjects: Vec, + pub issuer: Option, + pub san_dns_names: Vec, + pub serial_number: Option, + pub subject: Option, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct ConnectionPeerAddrs { - pub socket_peer_addr: SocketAddr, + pub early_data: bool, pub proxy_protocol_source_addr: Option, + pub socket_peer_addr: SocketAddr, + pub tls_alpn: Option, pub tls_client_identity: Option, pub tls_version: Option, - pub tls_alpn: Option, - pub early_data: bool, } pub fn resolve_client_address( @@ -152,6 +154,3 @@ fn select_client_ip_with_immediate_peer( .find(|ip| !server.is_trusted_proxy(*ip)) .unwrap_or(chain[0]) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/compression/accept_encoding.rs b/crates/rginx-http/src/compression/accept_encoding.rs index 4a92f351..dba0a8a8 100644 --- a/crates/rginx-http/src/compression/accept_encoding.rs +++ b/crates/rginx-http/src/compression/accept_encoding.rs @@ -23,7 +23,7 @@ impl ContentCoding { pub(super) fn merge_vary_header(headers: &mut HeaderMap, token: &str) { let mut values = Vec::::new(); - for value in headers.get_all(VARY).iter() { + for value in &headers.get_all(VARY) { let Ok(value) = value.to_str() else { continue; }; @@ -58,6 +58,13 @@ pub(super) fn preferred_response_encoding(headers: &HeaderMap) -> Option f32 { + match coding { + ContentCoding::Brotli => self.brotli.or(self.wildcard).unwrap_or(0.0), + ContentCoding::Gzip => self.gzip.or(self.wildcard).unwrap_or(0.0), + } + } + fn record(&mut self, coding: &str, q: f32) { let slot = if coding.eq_ignore_ascii_case("br") { Some(&mut self.brotli) @@ -74,13 +81,6 @@ pub(super) fn preferred_response_encoding(headers: &HeaderMap) -> Option f32 { - match coding { - ContentCoding::Brotli => self.brotli.or(self.wildcard).unwrap_or(0.0), - ContentCoding::Gzip => self.gzip.or(self.wildcard).unwrap_or(0.0), - } - } } let mut accepted = AcceptedEncodings::default(); diff --git a/crates/rginx-http/src/compression/mod.rs b/crates/rginx-http/src/compression/mod.rs index 53aef642..b2152c5d 100644 --- a/crates/rginx-http/src/compression/mod.rs +++ b/crates/rginx-http/src/compression/mod.rs @@ -1,3 +1,10 @@ +mod accept_encoding; +mod content_type; +mod encode; +mod options; +#[cfg(test)] +mod tests; + use bytes::Bytes; use http::header::{ACCEPT_RANGES, CONTENT_ENCODING, CONTENT_LENGTH, HeaderValue}; use http::{HeaderMap, Method, Response, StatusCode}; @@ -7,13 +14,6 @@ use rginx_core::RouteBufferingPolicy; use crate::handler::{HttpBody, HttpResponse, full_body}; -mod accept_encoding; -mod content_type; -mod encode; -mod options; -#[cfg(test)] -mod tests; - use accept_encoding::{ContentCoding, merge_vary_header, preferred_response_encoding}; use content_type::response_is_eligible; use encode::compress_bytes; diff --git a/crates/rginx-http/src/compression/options.rs b/crates/rginx-http/src/compression/options.rs index a2aa49f8..55d7b951 100644 --- a/crates/rginx-http/src/compression/options.rs +++ b/crates/rginx-http/src/compression/options.rs @@ -6,10 +6,10 @@ const MIN_COMPRESSIBLE_RESPONSE_BYTES: usize = 256; #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) struct ResponseCompressionOptions<'a> { - pub(crate) response_buffering: RouteBufferingPolicy, pub(crate) compression: RouteCompressionPolicy, - pub(crate) compression_min_bytes: Option, pub(crate) compression_content_types: Cow<'a, [String]>, + pub(crate) compression_min_bytes: Option, + pub(crate) response_buffering: RouteBufferingPolicy, } impl Default for ResponseCompressionOptions<'_> { @@ -23,7 +23,7 @@ impl Default for ResponseCompressionOptions<'_> { } } -impl<'a> ResponseCompressionOptions<'a> { +impl ResponseCompressionOptions<'_> { pub(crate) fn min_bytes(&self) -> usize { match self.compression { RouteCompressionPolicy::Force => self.compression_min_bytes.unwrap_or(1), diff --git a/crates/rginx-http/src/compression/tests.rs b/crates/rginx-http/src/compression/tests.rs index f9e2bfe8..fec7ce9a 100644 --- a/crates/rginx-http/src/compression/tests.rs +++ b/crates/rginx-http/src/compression/tests.rs @@ -26,6 +26,10 @@ impl hyper::body::Body for CollectErrorBody { type Data = Bytes; type Error = std::io::Error; + fn is_end_stream(&self) -> bool { + false + } + fn poll_frame( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -33,10 +37,6 @@ impl hyper::body::Body for CollectErrorBody { Poll::Ready(Some(Err(std::io::Error::other("collect failed")))) } - fn is_end_stream(&self) -> bool { - false - } - fn size_hint(&self) -> SizeHint { let mut hint = SizeHint::new(); hint.set_exact(512); diff --git a/crates/rginx-http/src/handler/access_log.rs b/crates/rginx-http/src/handler/access_log.rs index 599510c6..580bdb7d 100644 --- a/crates/rginx-http/src/handler/access_log.rs +++ b/crates/rginx-http/src/handler/access_log.rs @@ -7,56 +7,56 @@ use rginx_core::{AccessLogFormat, AccessLogValues}; #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct UpstreamAccessLog { - pub(crate) upstream_name: String, pub(crate) upstream_addr: String, - pub(crate) upstream_status: u16, + pub(crate) upstream_name: String, pub(crate) upstream_response_time_ms: u64, + pub(crate) upstream_status: u16, } pub(super) struct AccessLogContext<'a> { - pub(crate) request_id: &'a str, - pub(crate) method: &'a str, + pub(crate) body_bytes_sent: Option, + pub(crate) cache_status: Option<&'a str>, + pub(crate) client_address: &'a ClientAddress, + pub(crate) downstream_scheme: &'a str, + pub(crate) elapsed_ms: u64, + pub(crate) grpc: Option<&'a GrpcObservability>, pub(crate) host: &'a str, + pub(crate) method: &'a str, pub(crate) path: &'a str, - pub(crate) request_version: Version, - pub(crate) user_agent: Option<&'a str>, pub(crate) referer: Option<&'a str>, - pub(crate) client_address: &'a ClientAddress, - pub(crate) vhost: &'a str, + pub(crate) request_id: &'a str, + pub(crate) request_version: Version, pub(crate) route: &'a str, pub(crate) status: u16, - pub(crate) elapsed_ms: u64, - pub(crate) downstream_scheme: &'a str, - pub(crate) tls_version: Option<&'a str>, pub(crate) tls_alpn: Option<&'a str>, - pub(crate) body_bytes_sent: Option, pub(crate) tls_client_identity: Option<&'a TlsClientIdentity>, - pub(crate) grpc: Option<&'a GrpcObservability>, - pub(crate) cache_status: Option<&'a str>, + pub(crate) tls_version: Option<&'a str>, pub(crate) upstream: Option<&'a UpstreamAccessLog>, + pub(crate) user_agent: Option<&'a str>, + pub(crate) vhost: &'a str, } #[derive(Debug, Clone)] pub(super) struct OwnedAccessLogContext { - pub(crate) request_id: String, - pub(crate) method: String, + pub(crate) body_bytes_sent: Option, + pub(crate) cache_status: Option, + pub(crate) client_address: ClientAddress, + pub(crate) downstream_scheme: String, + pub(crate) elapsed_ms: u64, pub(crate) host: String, + pub(crate) method: String, pub(crate) path: String, - pub(crate) request_version: Version, - pub(crate) user_agent: Option, pub(crate) referer: Option, - pub(crate) client_address: ClientAddress, - pub(crate) vhost: String, + pub(crate) request_id: String, + pub(crate) request_version: Version, pub(crate) route: String, pub(crate) status: u16, - pub(crate) elapsed_ms: u64, - pub(crate) downstream_scheme: String, - pub(crate) tls_version: Option, pub(crate) tls_alpn: Option, - pub(crate) body_bytes_sent: Option, pub(crate) tls_client_identity: Option, - pub(crate) cache_status: Option, + pub(crate) tls_version: Option, pub(crate) upstream: Option, + pub(crate) user_agent: Option, + pub(crate) vhost: String, } impl OwnedAccessLogContext { @@ -136,27 +136,20 @@ pub(super) fn log_access_event(format: Option<&AccessLogFormat>, context: Access cache_status = context.cache_status.unwrap_or("-"), upstream_name = context .upstream - .map(|upstream| upstream.upstream_name.as_str()) - .unwrap_or("-"), + .map_or("-", |upstream| upstream.upstream_name.as_str()), upstream_addr = context .upstream - .map(|upstream| upstream.upstream_addr.as_str()) - .unwrap_or("-"), + .map_or("-", |upstream| upstream.upstream_addr.as_str()), upstream_status = context - .upstream - .map(|upstream| upstream.upstream_status.to_string()) - .unwrap_or_else(|| "-".to_string()), + .upstream.map_or_else(|| "-".to_string(), |upstream| upstream.upstream_status.to_string()), upstream_response_time_ms = context - .upstream - .map(|upstream| upstream.upstream_response_time_ms.to_string()) - .unwrap_or_else(|| "-".to_string()), + .upstream.map_or_else(|| "-".to_string(), |upstream| upstream.upstream_response_time_ms.to_string()), tls_client_san_dns_names = joined_tls_client_san_dns_names(context.tls_client_identity) .as_deref() .unwrap_or("-"), tls_client_chain_length = context .tls_client_identity - .map(|identity| identity.chain_length) - .unwrap_or(0), + .map_or(0, |identity| identity.chain_length), tls_client_chain_subjects = joined_tls_client_chain_subjects(context.tls_client_identity) .as_deref() .unwrap_or("-"), diff --git a/crates/rginx-http/src/handler/dispatch/date.rs b/crates/rginx-http/src/handler/dispatch/date.rs index 1e6361e5..0564d01c 100644 --- a/crates/rginx-http/src/handler/dispatch/date.rs +++ b/crates/rginx-http/src/handler/dispatch/date.rs @@ -11,19 +11,11 @@ struct CachedHttpDate { value: HeaderValue, } -pub(super) fn current_http_date() -> HeaderValue { - let unix_epoch_seconds = current_unix_epoch_seconds(); - let cache = HTTP_DATE_CACHE.get_or_init(|| Mutex::new(CachedHttpDate::new(unix_epoch_seconds))); - let mut cached = cache.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - if cached.unix_epoch_seconds != unix_epoch_seconds { - *cached = CachedHttpDate::new(unix_epoch_seconds); - } - cached.value.clone() -} - impl CachedHttpDate { fn new(unix_epoch_seconds: u64) -> Self { - let timestamp = UNIX_EPOCH + Duration::from_secs(unix_epoch_seconds); + let timestamp = UNIX_EPOCH + .checked_add(Duration::from_secs(unix_epoch_seconds)) + .expect("HTTP date timestamp remains representable"); let value = httpdate::fmt_http_date(timestamp); let value = HeaderValue::from_str(&value).expect("formatted HTTP date should be a valid header"); @@ -31,6 +23,16 @@ impl CachedHttpDate { } } +pub(super) fn current_http_date() -> HeaderValue { + let unix_epoch_seconds = current_unix_epoch_seconds(); + let cache = HTTP_DATE_CACHE.get_or_init(|| Mutex::new(CachedHttpDate::new(unix_epoch_seconds))); + let mut cached = cache.lock().unwrap_or_else(std::sync::PoisonError::into_inner); + if cached.unix_epoch_seconds != unix_epoch_seconds { + *cached = CachedHttpDate::new(unix_epoch_seconds); + } + cached.value.clone() +} + fn current_unix_epoch_seconds() -> u64 { match SystemTime::now().duration_since(UNIX_EPOCH) { Ok(duration) => duration.as_secs(), diff --git a/crates/rginx-http/src/handler/dispatch/file.rs b/crates/rginx-http/src/handler/dispatch/file.rs index e85201e1..aa8362c2 100644 --- a/crates/rginx-http/src/handler/dispatch/file.rs +++ b/crates/rginx-http/src/handler/dispatch/file.rs @@ -1,3 +1,11 @@ +mod autoindex; +mod headers; +mod path; +mod range; +mod stream; + +#[cfg(test)] +mod tests; use std::fs::Metadata; use std::io::SeekFrom; use std::path::{Component, Path, PathBuf}; @@ -18,14 +26,6 @@ use tokio::task::JoinHandle; use crate::handler::{BoxError, HttpResponse, boxed_body, full_body, text_response}; -const FILE_STREAM_CHUNK_BYTES: usize = 64 * 1024; - -mod autoindex; -mod headers; -mod path; -mod range; -mod stream; - use autoindex::{apply_autoindex_headers, render_directory_listing}; use headers::{FileResponseHeaders, apply_file_headers, build_file_response_headers}; use path::{ @@ -35,6 +35,8 @@ use path::{ use range::{RangeRequest, conditional_not_modified, parse_single_range}; use stream::{FileStreamInstrumentation, streaming_file_body}; +const FILE_STREAM_CHUNK_BYTES: usize = 64 * 1024; + enum FileAccessError { Forbidden(&'static str), NotFound(&'static str), @@ -46,8 +48,8 @@ enum ResolveFileResult { } struct ResolvedFile { - path: PathBuf, metadata: Metadata, + path: PathBuf, } pub(super) async fn serve_file_response( @@ -106,7 +108,7 @@ pub(super) async fn serve_file_response( RangeRequest::Satisfiable { start, end } if *request_method == Method::HEAD => { let mut response = http::Response::builder() .status(StatusCode::PARTIAL_CONTENT) - .header(CONTENT_LENGTH, (end - start + 1).to_string()) + .header(CONTENT_LENGTH, end.saturating_sub(start).saturating_add(1).to_string()) .header(CONTENT_RANGE, format!("bytes {start}-{end}/{content_len}")); response = apply_file_headers( response, @@ -146,7 +148,7 @@ pub(super) async fn serve_file_response( match range_request { RangeRequest::Satisfiable { start, end } => { - let body_len = end - start + 1; + let body_len = end.saturating_sub(start).saturating_add(1); let mut response = http::Response::builder() .status(StatusCode::PARTIAL_CONTENT) .header(CONTENT_LENGTH, body_len.to_string()) @@ -265,6 +267,3 @@ async fn resolve_index_file( } Ok(None) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/handler/dispatch/file/headers.rs b/crates/rginx-http/src/handler/dispatch/file/headers.rs index de6aafec..4a81c75c 100644 --- a/crates/rginx-http/src/handler/dispatch/file/headers.rs +++ b/crates/rginx-http/src/handler/dispatch/file/headers.rs @@ -4,12 +4,12 @@ use super::{ }; pub(super) struct FileResponseHeaders { + pub(super) cache_control: Option, + pub(super) content_disposition: Option, pub(super) content_type: String, pub(super) etag: Option, - pub(super) last_modified: Option, - pub(super) cache_control: Option, pub(super) expires: Option, - pub(super) content_disposition: Option, + pub(super) last_modified: Option, } pub(super) fn build_file_response_headers( @@ -88,8 +88,7 @@ pub(super) fn content_type_for_path(path: &Path, file_route: &rginx_core::FileRo } default_content_type_for_path(path) - .map(str::to_string) - .unwrap_or_else(|| file_route.default_content_type.clone()) + .map_or_else(|| file_route.default_content_type.clone(), str::to_string) } fn build_etag(metadata: &Metadata) -> String { @@ -97,8 +96,7 @@ fn build_etag(metadata: &Metadata) -> String { .modified() .ok() .and_then(|time| time.duration_since(UNIX_EPOCH).ok()) - .map(|duration| duration.as_secs()) - .unwrap_or(0); + .map_or(0, |duration| duration.as_secs()); format!("\"{:x}-{:x}\"", metadata.len(), modified) } @@ -107,8 +105,7 @@ fn build_last_modified(metadata: &Metadata) -> Option { } fn default_content_type_for_path(path: &Path) -> Option<&'static str> { - let extension = - path.extension().and_then(|value| value.to_str()).map(|value| value.to_ascii_lowercase()); + let extension = path.extension().and_then(|value| value.to_str()).map(str::to_ascii_lowercase); match extension.as_deref() { Some("html" | "htm") => Some("text/html; charset=utf-8"), diff --git a/crates/rginx-http/src/handler/dispatch/file/range.rs b/crates/rginx-http/src/handler/dispatch/file/range.rs index 609f588e..2e2ddf36 100644 --- a/crates/rginx-http/src/handler/dispatch/file/range.rs +++ b/crates/rginx-http/src/handler/dispatch/file/range.rs @@ -81,7 +81,7 @@ pub(super) fn parse_single_range( return RangeRequest::Unsatisfiable; } let start = body_len.saturating_sub(suffix_len); - return RangeRequest::Satisfiable { start, end: body_len - 1 }; + return RangeRequest::Satisfiable { start, end: body_len.saturating_sub(1) }; } let Ok(start) = start.parse::() else { @@ -91,12 +91,12 @@ pub(super) fn parse_single_range( return RangeRequest::Unsatisfiable; } let end = if end.is_empty() { - body_len - 1 + body_len.saturating_sub(1) } else { let Ok(end) = end.parse::() else { return RangeRequest::Ignore; }; - end.min(body_len - 1) + end.min(body_len.saturating_sub(1)) }; if start > end { return RangeRequest::Unsatisfiable; diff --git a/crates/rginx-http/src/handler/dispatch/file/stream.rs b/crates/rginx-http/src/handler/dispatch/file/stream.rs index 41e0c71b..af16dd50 100644 --- a/crates/rginx-http/src/handler/dispatch/file/stream.rs +++ b/crates/rginx-http/src/handler/dispatch/file/stream.rs @@ -9,21 +9,20 @@ pub(super) struct FileStreamInstrumentation { } impl FileStreamInstrumentation { - fn record_read(&self, bytes: usize) { - self.bytes_read.fetch_add(bytes, Ordering::Relaxed); - } - #[cfg(test)] pub(super) fn bytes_read(&self) -> usize { self.bytes_read.load(Ordering::Relaxed) } + fn record_read(&self, bytes: usize) { + self.bytes_read.fetch_add(bytes, Ordering::Relaxed); + } } struct FileBody { - rx: mpsc::Receiver, BoxError>>, - size_hint: SizeHint, done: bool, join_handle: Option>, + rx: mpsc::Receiver, BoxError>>, + size_hint: SizeHint, } impl FileBody { @@ -57,6 +56,10 @@ impl hyper::body::Body for FileBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -71,10 +74,6 @@ impl hyper::body::Body for FileBody { } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> SizeHint { self.size_hint.clone() } @@ -117,7 +116,7 @@ async fn stream_file_body( .await; return; } - Ok(read) => filled += read, + Ok(read) => filled = filled.saturating_add(read), Err(error) => { let _ = tx.send(Err(error.into())).await; return; @@ -125,7 +124,7 @@ async fn stream_file_body( } } - remaining_bytes -= chunk_len; + remaining_bytes = remaining_bytes.saturating_sub(chunk_len); instrumentation.record_read(chunk_len); if tx.send(Ok(Frame::data(Bytes::copy_from_slice(&buffer[..chunk_len])))).await.is_err() { return; diff --git a/crates/rginx-http/src/handler/dispatch/mod.rs b/crates/rginx-http/src/handler/dispatch/mod.rs index 8a600105..939ddd5b 100644 --- a/crates/rginx-http/src/handler/dispatch/mod.rs +++ b/crates/rginx-http/src/handler/dispatch/mod.rs @@ -1,3 +1,12 @@ +mod acme; +mod authorize; +mod date; +mod file; +mod phases; +mod response; +mod route; +mod select; + use std::sync::Arc; use std::time::Instant; @@ -13,16 +22,7 @@ use super::grpc::{ GrpcStatsContext, OwnedGrpcRequestMetadata, grpc_request_metadata, wrap_grpc_observability_response, }; -use super::*; - -mod acme; -mod authorize; -mod date; -mod file; -mod phases; -mod response; -mod route; -mod select; +use super::{EarlyDataFlag, HttpBody, HttpResponse, full_body, grpc_error_response, text_response}; #[cfg(test)] pub(super) use authorize::authorize_route; @@ -41,18 +41,18 @@ use response::alt_svc_header_value; const REQUEST_ID_HEADER: &str = "x-request-id"; struct DispatchRequestMetadata { - method: Method, - request_version: Version, - request_headers: HeaderMap, + early_data: bool, + grpc_request: Option, host: String, + method: Method, path: String, - user_agent: Option, referer: Option, + request_headers: HeaderMap, request_id: String, request_id_header: HeaderValue, - grpc_request: Option, + request_version: Version, tls_client_identity: Option, - early_data: bool, + user_agent: Option, } pub async fn handle( diff --git a/crates/rginx-http/src/handler/dispatch/phases.rs b/crates/rginx-http/src/handler/dispatch/phases.rs index 7783e797..df2723b1 100644 --- a/crates/rginx-http/src/handler/dispatch/phases.rs +++ b/crates/rginx-http/src/handler/dispatch/phases.rs @@ -1,3 +1,7 @@ +mod error_page; + +#[cfg(test)] +mod tests; use std::borrow::Cow; use http::header::{CONTENT_LENGTH, LOCATION}; @@ -16,10 +20,6 @@ use super::select::{route_match_context, select_vhost_for_request}; use super::{HttpBody, HttpResponse, grpc_error_response, grpc_request_metadata, text_response}; use crate::handler::grpc::GrpcStatusCode; -const MAX_INTERNAL_REDIRECTS: usize = 16; - -mod error_page; - #[cfg(test)] use error_page::error_page_request_method; use error_page::{ @@ -27,36 +27,28 @@ use error_page::{ content_phase_redirect_target, error_page_target, }; -#[cfg(test)] -pub(in crate::handler) fn internal_redirect_status() -> StatusCode { - error_page::internal_redirect_status() -} - -#[cfg(test)] -pub(in crate::handler) fn internal_redirect_location(target: &str) -> String { - error_page::internal_redirect_location(target) -} +const MAX_INTERNAL_REDIRECTS: usize = 16; #[derive(Clone, Copy)] pub(super) struct ListenerRequestContext<'a> { pub(super) listener_id: &'a str, pub(super) listener_tls_enabled: bool, - pub(super) request_body_read_timeout: Option, pub(super) max_request_body_bytes: Option, + pub(super) request_body_read_timeout: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(super) enum RequestPhase { - SelectServer, - Rewrite, Access, Content, + Rewrite, + SelectServer, } #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct InternalRedirectState { - redirects: usize, max_redirects: usize, + redirects: usize, } impl Default for InternalRedirectState { @@ -74,7 +66,7 @@ impl InternalRedirectState { }); } - self.redirects += 1; + self.redirects = self.redirects.saturating_add(1); Ok(()) } } @@ -85,24 +77,24 @@ pub(super) enum InternalRedirectError { } pub(super) struct PhaseRequestContext<'a> { - pub(super) state: &'a SharedState, pub(super) active: &'a ActiveState, - pub(super) listener_id: &'a str, - pub(super) listener: &'a rginx_core::Listener, pub(super) client_address: &'a ClientAddress, + pub(super) early_data: bool, + pub(super) listener: &'a rginx_core::Listener, + pub(super) listener_id: &'a str, pub(super) request_headers: &'a http::HeaderMap, pub(super) request_id: &'a str, - pub(super) early_data: bool, + pub(super) state: &'a SharedState, } #[derive(Clone)] pub(super) struct PhaseSelection<'a> { - pub(super) vhost_id: &'a str, - pub(super) route_id: &'a str, - pub(super) selected_route_id: Option<&'a str>, + mode: SelectionModeOwned, pub(super) response_compression_options: ResponseCompressionOptions<'static>, pub(super) route: Option<&'a rginx_core::Route>, - mode: SelectionModeOwned, + pub(super) route_id: &'a str, + pub(super) selected_route_id: Option<&'a str>, + pub(super) vhost_id: &'a str, } #[derive(Clone, Copy)] @@ -117,6 +109,49 @@ enum SelectionModeOwned { NamedRoute, } +struct RewriteResult { + stop: rginx_core::RouteRewriteStop, + uri: Uri, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum TryFilesResolution { + ContinueCurrentRoute, + Named(String), + Status(StatusCode), + Uri(String), +} + +pub(super) struct PhaseExecution { + pub(super) response: HttpResponse, + pub(super) response_compression_options: ResponseCompressionOptions<'static>, + pub(super) route_id: String, + pub(super) selected_route_id: Option, + pub(super) vhost_id: String, +} + +impl PhaseExecution { + fn from_response(selection: PhaseSelection<'_>, response: HttpResponse) -> Self { + Self { + vhost_id: selection.vhost_id.to_string(), + route_id: selection.route_id.to_string(), + selected_route_id: selection.selected_route_id.map(str::to_string), + response_compression_options: selection.response_compression_options, + response, + } + } +} + +#[cfg(test)] +pub(in crate::handler) fn internal_redirect_status() -> StatusCode { + error_page::internal_redirect_status() +} + +#[cfg(test)] +pub(in crate::handler) fn internal_redirect_location(target: &str) -> String { + error_page::internal_redirect_location(target) +} + pub(super) async fn execute_request_phases( request: Request, context: PhaseRequestContext<'_>, @@ -247,14 +282,13 @@ pub(super) async fn execute_request_phases( )); phase = RequestPhase::Access; continue; - } else { - current_uri = redirect_uri(¤t_uri, &redirect_target); - request = request - .map(|request| rebuild_request_uri(request, current_uri.clone())); - selection = None; - phase = RequestPhase::SelectServer; - continue; } + current_uri = redirect_uri(¤t_uri, &redirect_target); + request = + request.map(|request| rebuild_request_uri(request, current_uri.clone())); + selection = None; + phase = RequestPhase::SelectServer; + continue; } let internal_request_template = selected.route.as_ref().and_then(|route| { (!route.error_pages.is_empty()).then(|| { @@ -334,39 +368,6 @@ pub(super) async fn execute_request_phases( } } -struct RewriteResult { - uri: Uri, - stop: rginx_core::RouteRewriteStop, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -enum TryFilesResolution { - ContinueCurrentRoute, - Uri(String), - Named(String), - Status(StatusCode), -} - -pub(super) struct PhaseExecution { - pub(super) vhost_id: String, - pub(super) route_id: String, - pub(super) selected_route_id: Option, - pub(super) response_compression_options: ResponseCompressionOptions<'static>, - pub(super) response: HttpResponse, -} - -impl PhaseExecution { - fn from_response(selection: PhaseSelection<'_>, response: HttpResponse) -> Self { - Self { - vhost_id: selection.vhost_id.to_string(), - route_id: selection.route_id.to_string(), - selected_route_id: selection.selected_route_id.map(str::to_string), - response_compression_options: selection.response_compression_options, - response, - } - } -} - fn select_phase_targets<'a>(context: &PhaseRequestContext<'a>, uri: &Uri) -> PhaseSelection<'a> { let config = context.active.config.as_ref(); select_phase_targets_with_mode(context, config, SelectionMode::ExternalPath(uri)) @@ -403,7 +404,7 @@ fn select_phase_targets_with_mode<'a>( } }; - let route_id = selected_route.map(|route| route.id.as_str()).unwrap_or("__unmatched__"); + let route_id = selected_route.map_or("__unmatched__", |route| route.id.as_str()); let selected_route_id = selected_route.map(|route| route.id.as_str()); let response_compression_options = selected_route.map(owned_response_compression_options).unwrap_or_default(); @@ -670,6 +671,3 @@ fn owned_response_compression_options( compression_content_types: Cow::Owned(route.compression_content_types.clone()), } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/handler/dispatch/phases/error_page.rs b/crates/rginx-http/src/handler/dispatch/phases/error_page.rs index 37ee945a..50f1a991 100644 --- a/crates/rginx-http/src/handler/dispatch/phases/error_page.rs +++ b/crates/rginx-http/src/handler/dispatch/phases/error_page.rs @@ -12,17 +12,17 @@ const INTERNAL_REDIRECT_STATUS_CODE: u16 = 399; const INTERNAL_REDIRECT_LOCATION_PREFIX: &str = "internal:"; pub(super) enum ErrorPageRedirect { - Uri(String), Named(String), Status(StatusCode), + Uri(String), } #[derive(Clone)] pub(super) struct InternalRequestTemplate { + extensions: http::Extensions, + headers: HeaderMap, method: Method, version: http::Version, - headers: HeaderMap, - extensions: http::Extensions, } impl InternalRequestTemplate { diff --git a/crates/rginx-http/src/handler/dispatch/response.rs b/crates/rginx-http/src/handler/dispatch/response.rs index daf373ff..cb348372 100644 --- a/crates/rginx-http/src/handler/dispatch/response.rs +++ b/crates/rginx-http/src/handler/dispatch/response.rs @@ -9,10 +9,10 @@ use super::date::current_http_date; use super::full_body; pub(in crate::handler) struct FinalizedDownstreamResponse { - pub(in crate::handler) response: HttpResponse, - pub(in crate::handler) status: StatusCode, pub(in crate::handler) body_bytes_sent: Option, pub(in crate::handler) grpc: Option, + pub(in crate::handler) response: HttpResponse, + pub(in crate::handler) status: StatusCode, } pub(in crate::handler) async fn finalize_downstream_response( @@ -49,7 +49,7 @@ pub(in crate::handler) async fn finalize_downstream_response( let status = response.status(); let body_bytes_sent = response_body_bytes_sent(method.as_str(), &response); - FinalizedDownstreamResponse { response, status, body_bytes_sent, grpc } + FinalizedDownstreamResponse { body_bytes_sent, grpc, response, status } } pub(in crate::handler) fn response_body_bytes_sent( diff --git a/crates/rginx-http/src/handler/dispatch/select.rs b/crates/rginx-http/src/handler/dispatch/select.rs index b3da1900..595c12f9 100644 --- a/crates/rginx-http/src/handler/dispatch/select.rs +++ b/crates/rginx-http/src/handler/dispatch/select.rs @@ -10,7 +10,7 @@ pub(in crate::handler) fn request_host<'a>(headers: &'a HeaderMap, uri: &'a Uri) headers .get(HOST) .and_then(|value| value.to_str().ok()) - .or_else(|| uri.authority().map(|authority| authority.as_str())) + .or_else(|| uri.authority().map(http::uri::Authority::as_str)) .unwrap_or_default() } diff --git a/crates/rginx-http/src/handler/grpc/error.rs b/crates/rginx-http/src/handler/grpc/error.rs index 2e1c8044..44607a10 100644 --- a/crates/rginx-http/src/handler/grpc/error.rs +++ b/crates/rginx-http/src/handler/grpc/error.rs @@ -11,12 +11,12 @@ use super::metadata::grpc_protocol; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum GrpcStatusCode { Cancelled, - InvalidArgument, DeadlineExceeded, + InvalidArgument, PermissionDenied, ResourceExhausted, - Unimplemented, Unavailable, + Unimplemented, } impl GrpcStatusCode { @@ -131,7 +131,7 @@ fn encode_grpc_web_error_body(trailers: &HeaderMap, is_text: bool) -> Bytes { trailer_block.extend_from_slice(b"\r\n"); } - let mut encoded = Vec::with_capacity(5 + trailer_block.len()); + let mut encoded = Vec::with_capacity(trailer_block.len().saturating_add(5)); encoded.push(0x80); encoded.extend_from_slice(&(trailer_block.len() as u32).to_be_bytes()); encoded.extend_from_slice(&trailer_block); diff --git a/crates/rginx-http/src/handler/grpc/grpc_web.rs b/crates/rginx-http/src/handler/grpc/grpc_web.rs index e55f9fff..37f4b11d 100644 --- a/crates/rginx-http/src/handler/grpc/grpc_web.rs +++ b/crates/rginx-http/src/handler/grpc/grpc_web.rs @@ -9,53 +9,14 @@ use super::observability::GrpcObservability; #[derive(Debug, Default)] pub(in crate::handler) struct GrpcWebObservabilityParser { - is_text: bool, - encoded_carryover: BytesMut, buffer: BytesMut, - saw_trailers: bool, disabled: bool, + encoded_carryover: BytesMut, + is_text: bool, + saw_trailers: bool, } impl GrpcWebObservabilityParser { - pub(in crate::handler) fn for_protocol(protocol: &str) -> Option { - match protocol { - "grpc-web" => Some(Self { is_text: false, ..Self::default() }), - "grpc-web-text" => Some(Self { is_text: true, ..Self::default() }), - _ => None, - } - } - - pub(in crate::handler) fn observe_chunk(&mut self, data: &[u8], grpc: &mut GrpcObservability) { - if self.disabled { - return; - } - - let result = if self.is_text { - decode_grpc_web_text_observability_chunk(&mut self.encoded_carryover, data).and_then( - |decoded| { - if let Some(decoded) = decoded { - self.observe_binary_chunk(&decoded, false, grpc) - } else { - Ok(()) - } - }, - ) - } else { - self.observe_binary_chunk(data, false, grpc) - }; - - if let Err(error) = result { - self.disabled = true; - tracing::debug!( - protocol = %grpc.protocol, - service = %grpc.service, - method = %grpc.method, - %error, - "failed to parse grpc-web response trailers for observability" - ); - } - } - pub(in crate::handler) fn finish(&mut self, grpc: &mut GrpcObservability) { if self.disabled { return; @@ -88,6 +49,14 @@ impl GrpcWebObservabilityParser { } } + pub(in crate::handler) fn for_protocol(protocol: &str) -> Option { + match protocol { + "grpc-web" => Some(Self { is_text: false, ..Self::default() }), + "grpc-web-text" => Some(Self { is_text: true, ..Self::default() }), + _ => None, + } + } + fn observe_binary_chunk( &mut self, data: &[u8], @@ -117,6 +86,37 @@ impl GrpcWebObservabilityParser { } } } + + pub(in crate::handler) fn observe_chunk(&mut self, data: &[u8], grpc: &mut GrpcObservability) { + if self.disabled { + return; + } + + let result = if self.is_text { + decode_grpc_web_text_observability_chunk(&mut self.encoded_carryover, data).and_then( + |decoded| { + if let Some(decoded) = decoded { + self.observe_binary_chunk(&decoded, false, grpc) + } else { + Ok(()) + } + }, + ) + } else { + self.observe_binary_chunk(data, false, grpc) + }; + + if let Err(error) = result { + self.disabled = true; + tracing::debug!( + protocol = %grpc.protocol, + service = %grpc.service, + method = %grpc.method, + %error, + "failed to parse grpc-web response trailers for observability" + ); + } + } } enum ParsedGrpcWebObservabilityFrame { @@ -205,7 +205,8 @@ fn decode_grpc_web_text_observability_chunk( } } - let complete_len = carryover.len() / 4 * 4; + let complete_len = + carryover.len().checked_div(4).expect("base64 block size is nonzero").saturating_mul(4); if complete_len == 0 { return Ok(None); } diff --git a/crates/rginx-http/src/handler/grpc/metadata.rs b/crates/rginx-http/src/handler/grpc/metadata.rs index 472bc543..7e3ab42f 100644 --- a/crates/rginx-http/src/handler/grpc/metadata.rs +++ b/crates/rginx-http/src/handler/grpc/metadata.rs @@ -3,26 +3,16 @@ use http::header::CONTENT_TYPE; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(in crate::handler) struct GrpcRequestMetadata<'a> { + pub(crate) method: &'a str, pub(crate) protocol: &'static str, pub(crate) service: &'a str, - pub(crate) method: &'a str, -} - -pub(in crate::handler) fn grpc_request_metadata<'a>( - headers: &HeaderMap, - request_path: &'a str, -) -> Option> { - let protocol = grpc_protocol(headers)?; - let (service, method) = grpc_service_method(request_path)?; - - Some(GrpcRequestMetadata { protocol, service, method }) } #[derive(Debug, Clone, PartialEq, Eq)] pub(in crate::handler) struct OwnedGrpcRequestMetadata { + pub(crate) method: String, pub(crate) protocol: &'static str, pub(crate) service: String, - pub(crate) method: String, } impl OwnedGrpcRequestMetadata { @@ -45,6 +35,16 @@ impl<'a> From> for OwnedGrpcRequestMetadata { } } +pub(in crate::handler) fn grpc_request_metadata<'a>( + headers: &HeaderMap, + request_path: &'a str, +) -> Option> { + let protocol = grpc_protocol(headers)?; + let (service, method) = grpc_service_method(request_path)?; + + Some(GrpcRequestMetadata { method, protocol, service }) +} + pub(super) fn grpc_protocol(headers: &HeaderMap) -> Option<&'static str> { let (mime, _) = split_header_content_type(crate::handler::dispatch::header_value(headers, CONTENT_TYPE)?); diff --git a/crates/rginx-http/src/handler/grpc/observability.rs b/crates/rginx-http/src/handler/grpc/observability.rs index b825dcbc..6d25b9d3 100644 --- a/crates/rginx-http/src/handler/grpc/observability.rs +++ b/crates/rginx-http/src/handler/grpc/observability.rs @@ -16,27 +16,11 @@ use super::metadata::GrpcRequestMetadata; #[derive(Debug, Clone, PartialEq, Eq)] pub(in crate::handler) struct GrpcObservability { + pub(crate) message: Option, + pub(crate) method: String, pub(crate) protocol: String, pub(crate) service: String, - pub(crate) method: String, pub(crate) status: Option, - pub(crate) message: Option, -} - -pub(in crate::handler) fn grpc_observability( - metadata: Option>, - response_headers: &HeaderMap, -) -> Option { - let metadata = metadata?; - let mut grpc = GrpcObservability { - protocol: metadata.protocol.to_string(), - service: metadata.service.to_string(), - method: metadata.method.to_string(), - status: None, - message: None, - }; - grpc.update_from_headers(response_headers); - Some(grpc) } impl GrpcObservability { @@ -56,28 +40,20 @@ impl GrpcObservability { #[derive(Clone)] pub(in crate::handler) struct GrpcStatsContext { - pub state: SharedState, pub listener_id: String, - pub vhost_id: String, pub route_id: Option, + pub state: SharedState, + pub vhost_id: String, } struct GrpcResponseFinalizer { - format: Option, context: OwnedAccessLogContext, - stats: Option, finalized: bool, + format: Option, + stats: Option, } impl GrpcResponseFinalizer { - fn new( - format: Option, - context: OwnedAccessLogContext, - stats: Option, - ) -> Self { - Self { format, context, stats, finalized: false } - } - fn finalize(&mut self, grpc: &GrpcObservability) { if self.finalized { return; @@ -94,34 +70,24 @@ impl GrpcResponseFinalizer { } log_access_event(self.format.as_ref(), self.context.as_borrowed(Some(grpc))); } + fn new( + format: Option, + context: OwnedAccessLogContext, + stats: Option, + ) -> Self { + Self { format, context, stats, finalized: false } + } } struct GrpcAccessLogBody { - inner: HttpBody, finalizer: GrpcResponseFinalizer, grpc: GrpcObservability, grpc_web: Option, + inner: HttpBody, stream_completed: bool, } impl GrpcAccessLogBody { - fn new(inner: HttpBody, finalizer: GrpcResponseFinalizer, grpc: GrpcObservability) -> Self { - let grpc_web = GrpcWebObservabilityParser::for_protocol(&grpc.protocol); - Self { inner, finalizer, grpc, grpc_web, stream_completed: false } - } - - fn observe_frame(&mut self, frame: &Frame) { - if let Some(trailers) = frame.trailers_ref() { - self.grpc.update_from_headers(trailers); - } - - if let Some(data) = frame.data_ref() - && let Some(parser) = self.grpc_web.as_mut() - { - parser.observe_chunk(data, &mut self.grpc); - } - } - fn finish(&mut self) { if let Some(parser) = self.grpc_web.as_mut() { parser.finish(&mut self.grpc); @@ -139,6 +105,22 @@ impl GrpcAccessLogBody { self.grpc.message = Some("downstream cancelled".to_string()); } } + fn new(inner: HttpBody, finalizer: GrpcResponseFinalizer, grpc: GrpcObservability) -> Self { + let grpc_web = GrpcWebObservabilityParser::for_protocol(&grpc.protocol); + Self { inner, finalizer, grpc, grpc_web, stream_completed: false } + } + + fn observe_frame(&mut self, frame: &Frame) { + if let Some(trailers) = frame.trailers_ref() { + self.grpc.update_from_headers(trailers); + } + + if let Some(data) = frame.data_ref() + && let Some(parser) = self.grpc_web.as_mut() + { + parser.observe_chunk(data, &mut self.grpc); + } + } } impl Drop for GrpcAccessLogBody { @@ -152,6 +134,10 @@ impl hyper::body::Body for GrpcAccessLogBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + fn poll_frame( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -180,15 +166,27 @@ impl hyper::body::Body for GrpcAccessLogBody { } } - fn is_end_stream(&self) -> bool { - self.inner.is_end_stream() - } - fn size_hint(&self) -> SizeHint { self.inner.size_hint() } } +pub(in crate::handler) fn grpc_observability( + metadata: Option>, + response_headers: &HeaderMap, +) -> Option { + let metadata = metadata?; + let mut grpc = GrpcObservability { + protocol: metadata.protocol.to_string(), + service: metadata.service.to_string(), + method: metadata.method.to_string(), + status: None, + message: None, + }; + grpc.update_from_headers(response_headers); + Some(grpc) +} + pub(in crate::handler) fn wrap_grpc_observability_response( response: HttpResponse, format: Option, diff --git a/crates/rginx-http/src/handler/mod.rs b/crates/rginx-http/src/handler/mod.rs index d83d593e..11b641a2 100644 --- a/crates/rginx-http/src/handler/mod.rs +++ b/crates/rginx-http/src/handler/mod.rs @@ -1,3 +1,10 @@ +mod access_log; +mod dispatch; +mod grpc; +mod response; + +#[cfg(test)] +mod tests; use std::error::Error as StdError; use bytes::Bytes; @@ -8,6 +15,11 @@ use hyper::{Request, Response}; use crate::client_ip::ConnectionPeerAddrs; +pub(crate) use access_log::UpstreamAccessLog; +pub use dispatch::handle; +pub(crate) use grpc::{GrpcStatusCode, grpc_error_response}; +pub(crate) use response::{full_body, text_response}; + pub(crate) type BoxError = Box; pub(crate) type HttpBody = UnsyncBoxBody; pub(crate) type HttpResponse = Response; @@ -15,16 +27,6 @@ pub(crate) type HttpResponse = Response; #[derive(Clone, Copy, Debug)] pub(crate) struct EarlyDataFlag(pub bool); -mod access_log; -mod dispatch; -mod grpc; -mod response; - -pub(crate) use access_log::UpstreamAccessLog; -pub use dispatch::handle; -pub(crate) use grpc::{GrpcStatusCode, grpc_error_response}; -pub(crate) use response::{full_body, text_response}; - pub(crate) fn boxed_body(body: B) -> HttpBody where B: Body + Send + 'static, @@ -42,6 +44,3 @@ pub(crate) fn attach_connection_metadata( request.extensions_mut().insert(identity); } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/handler/tests.rs b/crates/rginx-http/src/handler/tests.rs index 54131075..845a4c5b 100644 --- a/crates/rginx-http/src/handler/tests.rs +++ b/crates/rginx-http/src/handler/tests.rs @@ -1,3 +1,8 @@ +mod observability; +mod responses; +mod routing; +mod support; + use std::collections::HashMap; use std::time::Duration; @@ -23,9 +28,4 @@ use super::{GrpcStatusCode, attach_connection_metadata, grpc_error_response, tex use crate::client_ip::{ClientAddress, ClientIpSource, ConnectionPeerAddrs, TlsClientIdentity}; use crate::compression::ResponseCompressionOptions; -mod observability; -mod responses; -mod routing; -mod support; - pub(crate) use support::*; diff --git a/crates/rginx-http/src/handler/tests/routing/handle.rs b/crates/rginx-http/src/handler/tests/routing/handle.rs index 8ebb845b..2f996238 100644 --- a/crates/rginx-http/src/handler/tests/routing/handle.rs +++ b/crates/rginx-http/src/handler/tests/routing/handle.rs @@ -1,3 +1,5 @@ +mod file; +mod redirects; use super::super::*; #[tokio::test] @@ -209,6 +211,3 @@ async fn handle_short_circuits_acme_http01_requests() { let body = response.into_body().collect().await.expect("body should collect").to_bytes(); assert_eq!(body.as_ref(), b"demo-key-authorization"); } - -mod file; -mod redirects; diff --git a/crates/rginx-http/src/handler/tests/routing/handle/file.rs b/crates/rginx-http/src/handler/tests/routing/handle/file.rs index f4843b0e..79338bcf 100644 --- a/crates/rginx-http/src/handler/tests/routing/handle/file.rs +++ b/crates/rginx-http/src/handler/tests/routing/handle/file.rs @@ -1,13 +1,13 @@ +mod error_pages; +mod presentation; +mod redirects; + use std::path::Path; use std::sync::Arc; use std::time::Duration; use super::super::super::*; -mod error_pages; -mod presentation; -mod redirects; - fn default_file_route(base: &Path, route_prefix: &str) -> rginx_core::FileRoute { rginx_core::FileRoute { path_strategy: rginx_core::FilePathStrategy::Root(base.to_path_buf()), diff --git a/crates/rginx-http/src/handler/tests/routing/handle/file/presentation.rs b/crates/rginx-http/src/handler/tests/routing/handle/file/presentation.rs index c9b37522..0b8a2fea 100644 --- a/crates/rginx-http/src/handler/tests/routing/handle/file/presentation.rs +++ b/crates/rginx-http/src/handler/tests/routing/handle/file/presentation.rs @@ -1,5 +1,8 @@ use super::*; +#[cfg(unix)] +use std::os::unix::fs::symlink; + #[tokio::test] async fn handle_static_file_response_can_be_compressed() { let temp_dir = tempfile::tempdir().expect("temp dir should create"); @@ -157,11 +160,8 @@ async fn handle_file_route_rejects_symlinks_when_disabled() { let temp_dir = tempfile::tempdir().expect("temp dir should create"); std::fs::create_dir_all(temp_dir.path().join("site")).unwrap(); std::fs::write(temp_dir.path().join("secret.txt"), "secret").unwrap(); - std::os::unix::fs::symlink( - temp_dir.path().join("secret.txt"), - temp_dir.path().join("site/secret-link.txt"), - ) - .unwrap(); + symlink(temp_dir.path().join("secret.txt"), temp_dir.path().join("site/secret-link.txt")) + .unwrap(); let route = default_file_route_action( "server/routes[0]|prefix:/", diff --git a/crates/rginx-http/src/lib.rs b/crates/rginx-http/src/lib.rs index b3b5fc9d..22d655ce 100644 --- a/crates/rginx-http/src/lib.rs +++ b/crates/rginx-http/src/lib.rs @@ -1,3 +1,12 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] mod cache; mod client_ip; mod compression; @@ -13,14 +22,6 @@ mod timeout; mod tls; mod transition; -/// Upper bound for accepted OCSP response payloads. -pub const MAX_OCSP_RESPONSE_BYTES: usize = 128 * 1024; - -/// Installs the default process-wide TLS crypto provider used by HTTP components. -pub fn install_default_crypto_provider() { - tls::install_default_crypto_provider(); -} - pub use client_ip::TlsClientIdentity; pub use proxy::{PeerHealthSnapshot, UpstreamHealthSnapshot}; pub use server::serve; @@ -48,6 +49,14 @@ pub use transition::{ config_transition_boundary, plan_config_transition, validate_config_transition, }; +/// Upper bound for accepted OCSP response payloads. +pub const MAX_OCSP_RESPONSE_BYTES: usize = 128 * 1024; + +/// Installs the default process-wide TLS crypto provider used by HTTP components. +pub fn install_default_crypto_provider() { + tls::install_default_crypto_provider(); +} + #[cfg(test)] #[ctor::ctor(unsafe)] fn install_test_crypto_provider() { diff --git a/crates/rginx-http/src/pki/certificate.rs b/crates/rginx-http/src/pki/certificate.rs index 1040be39..2565fa23 100644 --- a/crates/rginx-http/src/pki/certificate.rs +++ b/crates/rginx-http/src/pki/certificate.rs @@ -1,5 +1,17 @@ +mod decode; +mod extensions; +mod helpers; +mod identity; +mod inspect; +mod name; + +#[cfg(test)] +mod tests; use crate::client_ip::TlsClientIdentity; +pub(crate) use identity::parse_tls_client_identity; +pub(crate) use inspect::inspect_certificate; + const TLS_EXPIRY_WARNING_DAYS: i64 = 30; const OID_ATTR_COMMON_NAME: &str = "2.5.4.3"; @@ -26,45 +38,35 @@ const OID_EKU_EMAIL_PROTECTION: &str = "1.3.6.1.5.5.7.3.4"; const OID_EKU_TIME_STAMPING: &str = "1.3.6.1.5.5.7.3.8"; const OID_EKU_OCSP_SIGNING: &str = "1.3.6.1.5.5.7.3.9"; -mod decode; -mod extensions; -mod helpers; -mod identity; -mod inspect; -mod name; - -pub(crate) use identity::parse_tls_client_identity; -pub(crate) use inspect::inspect_certificate; - #[derive(Debug, Clone)] pub(crate) struct InspectedCertificate { - pub(crate) subject: Option, - pub(crate) issuer: Option, - pub(crate) serial_number: Option, - pub(crate) san_dns_names: Vec, - pub(crate) fingerprint_sha256: Option, - pub(crate) subject_key_identifier: Option, pub(crate) authority_key_identifier: Option, + pub(crate) chain_diagnostics: Vec, + pub(crate) chain_length: usize, + pub(crate) chain_subjects: Vec, + pub(crate) expires_in_days: Option, + pub(crate) extended_key_usage: Vec, + pub(crate) fingerprint_sha256: Option, pub(crate) is_ca: Option, - pub(crate) path_len_constraint: Option, + pub(crate) issuer: Option, pub(crate) key_usage: Option, - pub(crate) extended_key_usage: Vec, - pub(crate) not_before_unix_ms: Option, pub(crate) not_after_unix_ms: Option, - pub(crate) expires_in_days: Option, - pub(crate) chain_length: usize, - pub(crate) chain_subjects: Vec, - pub(crate) chain_diagnostics: Vec, + pub(crate) not_before_unix_ms: Option, + pub(crate) path_len_constraint: Option, + pub(crate) san_dns_names: Vec, + pub(crate) serial_number: Option, + pub(crate) subject: Option, + pub(crate) subject_key_identifier: Option, } #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct ParsedClientIdentity { - pub(crate) subject: Option, - pub(crate) issuer: Option, - pub(crate) serial_number: Option, - pub(crate) san_dns_names: Vec, pub(crate) chain_length: usize, pub(crate) chain_subjects: Vec, + pub(crate) issuer: Option, + pub(crate) san_dns_names: Vec, + pub(crate) serial_number: Option, + pub(crate) subject: Option, } impl From for TlsClientIdentity { @@ -79,6 +81,3 @@ impl From for TlsClientIdentity { } } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/pki/certificate/extensions.rs b/crates/rginx-http/src/pki/certificate/extensions.rs index 13f5246e..5f70987b 100644 --- a/crates/rginx-http/src/pki/certificate/extensions.rs +++ b/crates/rginx-http/src/pki/certificate/extensions.rs @@ -11,6 +11,15 @@ use super::{ OID_EXT_KEY_USAGE, OID_EXT_SUBJECT_ALT_NAME, OID_EXT_SUBJECT_KEY_IDENTIFIER, }; +#[derive(Debug, Clone)] +pub(super) struct ParsedKeyUsage { + pub(super) digital_signature: bool, + pub(super) key_agreement: bool, + pub(super) key_cert_sign: bool, + pub(super) key_encipherment: bool, + names: Vec<&'static str>, +} + pub(super) fn subject_key_identifier(extensions: Option<&[Extension]>) -> Option { let extension = find_extension(extensions, OID_EXT_SUBJECT_KEY_IDENTIFIER)?; let value: SubjectKeyIdentifier = rasn::der::decode(extension.extn_value.as_ref()).ok()?; @@ -44,19 +53,10 @@ pub(super) fn basic_constraints(extensions: Option<&[Extension]>) -> Option, -} - pub(super) fn key_usage(extensions: Option<&[Extension]>) -> Option { let extension = find_extension(extensions, OID_EXT_KEY_USAGE)?; let bits: KeyUsage = rasn::der::decode(extension.extn_value.as_ref()).ok()?; - let bit = |index: usize| bits.get(index).map(|value| *value).unwrap_or(false); + let bit = |index: usize| bits.get(index).is_some_and(|value| *value); let mut names = Vec::new(); if bit(0) { names.push("digitalSignature"); diff --git a/crates/rginx-http/src/pki/certificate/identity.rs b/crates/rginx-http/src/pki/certificate/identity.rs index da82ef1f..a82d4eaa 100644 --- a/crates/rginx-http/src/pki/certificate/identity.rs +++ b/crates/rginx-http/src/pki/certificate/identity.rs @@ -17,7 +17,7 @@ pub(crate) fn parse_tls_client_identity<'a>( }; for (index, der) in der_chain.into_iter().enumerate() { - identity.chain_length += 1; + identity.chain_length = identity.chain_length.saturating_add(1); if let Some(cert) = decode_certificate(der) { let extensions = cert.tbs_certificate.extensions.as_ref().map(|value| value.as_slice()); let subject = name_to_string(&cert.tbs_certificate.subject); diff --git a/crates/rginx-http/src/pki/certificate/inspect.rs b/crates/rginx-http/src/pki/certificate/inspect.rs index 13f2e9e7..df098f8d 100644 --- a/crates/rginx-http/src/pki/certificate/inspect.rs +++ b/crates/rginx-http/src/pki/certificate/inspect.rs @@ -43,7 +43,7 @@ pub(crate) fn inspect_certificate(path: &Path) -> Option { let subject = name_to_string(&cert.tbs_certificate.subject); let issuer = name_to_string(&cert.tbs_certificate.issuer); let expires_in_days = time_to_unix_secs(cert.tbs_certificate.validity.not_after) - .map(|not_after| (not_after - now_secs).div_euclid(86_400)); + .map(|not_after| not_after.saturating_sub(now_secs).div_euclid(86_400)); let basic_constraints = basic_constraints(extensions); let key_usage = key_usage(extensions); let extended_key_usage = extended_key_usage(extensions); @@ -111,32 +111,29 @@ pub(crate) fn inspect_certificate(path: &Path) -> Option { if chain_entries.len() == certs.len() { for index in 0..chain_entries.len().saturating_sub(1) { let issuer = chain_entries[index].issuer.as_deref(); - let next_subject = chain_entries[index + 1].subject.as_deref(); + let next_index = index.saturating_add(1); + let next_subject = chain_entries[next_index].subject.as_deref(); if issuer != next_subject { chain_diagnostics.push(format!( - "chain_link_mismatch cert[{index}]_issuer_to_cert[{}]_subject", - index + 1 + "chain_link_mismatch cert[{index}]_issuer_to_cert[{next_index}]_subject" )); } if let (Some(aki), Some(ski)) = ( chain_entries[index].authority_key_identifier.as_deref(), - chain_entries[index + 1].subject_key_identifier.as_deref(), + chain_entries[next_index].subject_key_identifier.as_deref(), ) && aki != ski { chain_diagnostics - .push(format!("chain_aki_ski_mismatch cert[{index}]_to_cert[{}]", index + 1)); + .push(format!("chain_aki_ski_mismatch cert[{index}]_to_cert[{next_index}]")); } - if let Some(path_len_constraint) = chain_entries[index + 1].path_len_constraint { - let descendant_ca_certs = chain_entries[..index + 1] + if let Some(path_len_constraint) = chain_entries[next_index].path_len_constraint { + let descendant_ca_certs = chain_entries[..next_index] .iter() .filter(|entry| entry.is_ca == Some(true)) .count() as u32; if descendant_ca_certs > path_len_constraint { chain_diagnostics.push(format!( - "cert[{}] path_len_constraint_exceeded descendant_ca_certs={} path_len_constraint={}", - index + 1, - descendant_ca_certs, - path_len_constraint + "cert[{next_index}] path_len_constraint_exceeded descendant_ca_certs={descendant_ca_certs} path_len_constraint={path_len_constraint}" )); } } diff --git a/crates/rginx-http/src/pki/certificate/name.rs b/crates/rginx-http/src/pki/certificate/name.rs index a538ad85..6331b188 100644 --- a/crates/rginx-http/src/pki/certificate/name.rs +++ b/crates/rginx-http/src/pki/certificate/name.rs @@ -77,12 +77,10 @@ fn decode_directory_or_string(bytes: &[u8]) -> String { fn decode_printable_string(bytes: &[u8]) -> String { rasn::der::decode::(bytes) - .map(|value| bytes_to_lossy_string(value.as_bytes())) - .unwrap_or_else(|_| hex_string(bytes)) + .map_or_else(|_| hex_string(bytes), |value| bytes_to_lossy_string(value.as_bytes())) } fn decode_ia5_string(bytes: &[u8]) -> String { rasn::der::decode::(bytes) - .map(|value| value.to_string()) - .unwrap_or_else(|_| hex_string(bytes)) + .map_or_else(|_| hex_string(bytes), |value| value.to_string()) } diff --git a/crates/rginx-http/src/proxy/clients/factory.rs b/crates/rginx-http/src/proxy/clients/factory.rs index eb2fe5ab..f6fe4db7 100644 --- a/crates/rginx-http/src/proxy/clients/factory.rs +++ b/crates/rginx-http/src/proxy/clients/factory.rs @@ -1,6 +1,10 @@ use tokio::sync::Mutex; -use super::*; +use super::{ + Arc, EndpointClientCache, Error, HttpProxyClient, ProxyClient, ServerName, + UpstreamClientProfile, UpstreamProtocol, UpstreamResolver, endpoint_client_cache_capacity, + http3, tls, +}; pub(super) fn build_client_for_profile( profile: &UpstreamClientProfile, diff --git a/crates/rginx-http/src/proxy/clients/http3/connect.rs b/crates/rginx-http/src/proxy/clients/http3/connect.rs index 04be31c1..43954d5e 100644 --- a/crates/rginx-http/src/proxy/clients/http3/connect.rs +++ b/crates/rginx-http/src/proxy/clients/http3/connect.rs @@ -12,6 +12,56 @@ use super::Http3Client; use super::session::{Http3Session, Http3SessionEntry, Http3SessionKey}; impl Http3Client { + #[cfg(test)] + pub(super) async fn cached_session_count(&self) -> usize { + self.sessions + .lock() + .await + .values() + .filter(|entry| matches!(entry, Http3SessionEntry::Ready(_))) + .count() + } + + async fn connect_session( + &self, + key: &Http3SessionKey, + peer_url: &str, + ) -> Result { + let endpoint = self.endpoint_for_remote(key.remote_addr).await?; + let connecting = endpoint.connect(key.remote_addr, &key.server_name).map_err(|error| { + Error::Server(format!( + "failed to start upstream http3 connect to `{peer_url}`: {error}" + )) + })?; + let connection = tokio::time::timeout(self.connect_timeout, connecting) + .await + .map_err(|_| { + Error::Server(format!( + "upstream http3 connect to `{}` timed out after {} ms", + peer_url, + self.connect_timeout.as_millis() + )) + })? + .map_err(|error| { + Error::Server(format!("upstream http3 connect to `{peer_url}` failed: {error}")) + })?; + + let (mut driver, send_request) = + client::new(h3_quinn::Connection::new(connection)).await.map_err(|error| { + Error::Server(format!( + "failed to initialize upstream http3 session for `{peer_url}`: {error}" + )) + })?; + let session = Http3Session::new(send_request); + let driver_closed = session.close_flag(); + let driver_task = tokio::spawn(async move { + let _ = poll_fn(|cx| driver.poll_close(cx)).await; + driver_closed.store(true, Ordering::Release); + }); + session.set_driver_task(driver_task).await; + Ok(session) + } + pub(super) async fn session_for( &self, key: Http3SessionKey, @@ -19,8 +69,8 @@ impl Http3Client { ) -> Result, Error> { loop { enum SessionAction { - Wait(Arc), Connect(Arc), + Wait(Arc), } let action = { @@ -57,58 +107,6 @@ impl Http3Client { } } } - - async fn connect_session( - &self, - key: &Http3SessionKey, - peer_url: &str, - ) -> Result { - let endpoint = self.endpoint_for_remote(key.remote_addr).await?; - let connecting = endpoint.connect(key.remote_addr, &key.server_name).map_err(|error| { - Error::Server(format!( - "failed to start upstream http3 connect to `{}`: {error}", - peer_url - )) - })?; - let connection = tokio::time::timeout(self.connect_timeout, connecting) - .await - .map_err(|_| { - Error::Server(format!( - "upstream http3 connect to `{}` timed out after {} ms", - peer_url, - self.connect_timeout.as_millis() - )) - })? - .map_err(|error| { - Error::Server(format!("upstream http3 connect to `{}` failed: {error}", peer_url)) - })?; - - let (mut driver, send_request) = - client::new(h3_quinn::Connection::new(connection)).await.map_err(|error| { - Error::Server(format!( - "failed to initialize upstream http3 session for `{}`: {error}", - peer_url - )) - })?; - let session = Http3Session::new(send_request); - let driver_closed = session.close_flag(); - let driver_task = tokio::spawn(async move { - let _ = poll_fn(|cx| driver.poll_close(cx)).await; - driver_closed.store(true, Ordering::Release); - }); - session.set_driver_task(driver_task).await; - Ok(session) - } - - #[cfg(test)] - pub(super) async fn cached_session_count(&self) -> usize { - self.sessions - .lock() - .await - .values() - .filter(|entry| matches!(entry, Http3SessionEntry::Ready(_))) - .count() - } } pub(super) fn server_name_for_peer( diff --git a/crates/rginx-http/src/proxy/clients/http3/endpoint_cache.rs b/crates/rginx-http/src/proxy/clients/http3/endpoint_cache.rs index a04fb94d..ad42dd1d 100644 --- a/crates/rginx-http/src/proxy/clients/http3/endpoint_cache.rs +++ b/crates/rginx-http/src/proxy/clients/http3/endpoint_cache.rs @@ -11,6 +11,19 @@ pub(super) struct Http3ClientEndpoints { } impl Http3Client { + #[cfg(test)] + pub(super) async fn cached_endpoint_count(&self) -> usize { + usize::from(self.endpoints.ipv4.lock().await.is_some()) + .saturating_add(usize::from(self.endpoints.ipv6.lock().await.is_some())) + } + + #[cfg(test)] + pub(super) async fn cached_endpoint_local_addr( + &self, + remote_addr: SocketAddr, + ) -> Result { + self.endpoint_for_remote(remote_addr).await?.local_addr().map_err(Error::Io) + } pub(super) async fn endpoint_for_remote( &self, remote_addr: SocketAddr, @@ -34,18 +47,4 @@ impl Http3Client { *endpoint = Some(created); Ok(reusable) } - - #[cfg(test)] - pub(super) async fn cached_endpoint_count(&self) -> usize { - usize::from(self.endpoints.ipv4.lock().await.is_some()) - + usize::from(self.endpoints.ipv6.lock().await.is_some()) - } - - #[cfg(test)] - pub(super) async fn cached_endpoint_local_addr( - &self, - remote_addr: SocketAddr, - ) -> Result { - self.endpoint_for_remote(remote_addr).await?.local_addr().map_err(Error::Io) - } } diff --git a/crates/rginx-http/src/proxy/clients/http3/mod.rs b/crates/rginx-http/src/proxy/clients/http3/mod.rs index 949e0216..58b2e5f0 100644 --- a/crates/rginx-http/src/proxy/clients/http3/mod.rs +++ b/crates/rginx-http/src/proxy/clients/http3/mod.rs @@ -1,21 +1,5 @@ //! Upstream HTTP/3 session reuse, endpoint caching, and body bridging. -use std::collections::HashMap; -use std::sync::Arc; - -use tokio::sync::Mutex; - -use super::*; -use endpoint_cache::Http3ClientEndpoints; -use session::{Http3SessionEntry, Http3SessionKey}; - -#[cfg(test)] -use crate::handler::boxed_body; -#[cfg(test)] -use std::net::SocketAddr; -#[cfg(test)] -use tokio::task::JoinHandle; - mod connect; mod endpoint_cache; mod request; @@ -24,16 +8,35 @@ mod session; #[cfg(test)] mod tests; +use std::collections::HashMap; +use std::sync::Arc; + +use tokio::sync::Mutex; + +use super::{ + Duration, Error, ResolvedUpstreamPeer, UpstreamPeer, UpstreamResolver, + UpstreamResolverRuntimeSnapshot, +}; +use endpoint_cache::Http3ClientEndpoints; +use session::{Http3SessionEntry, Http3SessionKey}; + #[derive(Clone)] pub(crate) struct Http3Client { client_config: quinn::ClientConfig, connect_timeout: Duration, - resolver: Arc, endpoints: Arc, + resolver: Arc, sessions: Arc>>, } impl Http3Client { + pub(super) async fn cached_peer_endpoints( + &self, + peer: &UpstreamPeer, + ) -> Result, Error> { + self.resolver.cached_peer_endpoints(peer).await + } + pub(super) fn new( client_config: quinn::ClientConfig, connect_timeout: Duration, @@ -55,13 +58,6 @@ impl Http3Client { self.resolver.resolve_peer(peer).await } - pub(super) async fn cached_peer_endpoints( - &self, - peer: &UpstreamPeer, - ) -> Result, Error> { - self.resolver.cached_peer_endpoints(peer).await - } - pub(super) async fn resolver_snapshot(&self) -> UpstreamResolverRuntimeSnapshot { self.resolver.snapshot().await } diff --git a/crates/rginx-http/src/proxy/clients/http3/request.rs b/crates/rginx-http/src/proxy/clients/http3/request.rs index 4f8a3f8c..c5caaa24 100644 --- a/crates/rginx-http/src/proxy/clients/http3/request.rs +++ b/crates/rginx-http/src/proxy/clients/http3/request.rs @@ -96,7 +96,7 @@ impl Http3Client { )) })?; - let (parts, _) = response.into_parts(); + let (parts, ()) = response.into_parts(); let size_hint = response_size_hint(&parts.headers); let body = streaming_response_body(request_stream, session, peer.display_url.clone(), size_hint); diff --git a/crates/rginx-http/src/proxy/clients/http3/response_body.rs b/crates/rginx-http/src/proxy/clients/http3/response_body.rs index 1fa2a045..696217cf 100644 --- a/crates/rginx-http/src/proxy/clients/http3/response_body.rs +++ b/crates/rginx-http/src/proxy/clients/http3/response_body.rs @@ -8,11 +8,17 @@ use crate::handler::{BoxError, HttpBody, boxed_body}; use super::session::{H3RequestStream, Http3Session}; +// Fallback matching for clean HTTP/3 shutdowns when only Display-formatted +// errors are available here. These strings correspond to upstream quinn/h3 +// formatting of an application close with the HTTP/3 code H3_NO_ERROR. +const QUINN_APPLICATION_CLOSE_H3_NO_ERROR: &str = "ApplicationClose: H3_NO_ERROR"; +const H3_APPLICATION_CODE_H3_NO_ERROR: &str = "Application { code: H3_NO_ERROR"; + struct StreamingResponseBody { - rx: mpsc::Receiver, BoxError>>, - size_hint: SizeHint, done: bool, join_handle: Option>, + rx: mpsc::Receiver, BoxError>>, + size_hint: SizeHint, } impl StreamingResponseBody { @@ -39,6 +45,10 @@ impl hyper::body::Body for StreamingResponseBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -53,10 +63,6 @@ impl hyper::body::Body for StreamingResponseBody { } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> SizeHint { self.size_hint.clone() } @@ -131,12 +137,6 @@ pub(super) fn response_size_hint(headers: &HeaderMap) -> SizeHint { hint } -// Fallback matching for clean HTTP/3 shutdowns when only Display-formatted -// errors are available here. These strings correspond to upstream quinn/h3 -// formatting of an application close with the HTTP/3 code H3_NO_ERROR. -const QUINN_APPLICATION_CLOSE_H3_NO_ERROR: &str = "ApplicationClose: H3_NO_ERROR"; -const H3_APPLICATION_CODE_H3_NO_ERROR: &str = "Application { code: H3_NO_ERROR"; - fn is_clean_http3_response_shutdown(error: &impl std::fmt::Display) -> bool { let error = error.to_string(); error.contains(QUINN_APPLICATION_CLOSE_H3_NO_ERROR) diff --git a/crates/rginx-http/src/proxy/clients/http3/session.rs b/crates/rginx-http/src/proxy/clients/http3/session.rs index b93229ca..0e383b19 100644 --- a/crates/rginx-http/src/proxy/clients/http3/session.rs +++ b/crates/rginx-http/src/proxy/clients/http3/session.rs @@ -16,18 +16,30 @@ pub(super) struct Http3SessionKey { } pub(super) struct Http3Session { - sender: Mutex, closed: Arc, driver_task: Mutex>>, + sender: Mutex, } #[derive(Clone)] pub(super) enum Http3SessionEntry { - Ready(Arc), Pending(Arc), + Ready(Arc), } impl Http3Session { + pub(super) fn close_flag(&self) -> Arc { + self.closed.clone() + } + + pub(super) fn is_closed(&self) -> bool { + self.closed.load(Ordering::Acquire) + } + + pub(super) fn mark_closed(&self) { + self.closed.store(true, Ordering::Release); + } + pub(super) fn new(sender: H3SendRequest) -> Self { Self { sender: Mutex::new(sender), @@ -43,18 +55,6 @@ impl Http3Session { pub(super) async fn set_driver_task(&self, task: JoinHandle<()>) { *self.driver_task.lock().await = Some(task); } - - pub(super) fn close_flag(&self) -> Arc { - self.closed.clone() - } - - pub(super) fn mark_closed(&self) { - self.closed.store(true, Ordering::Release); - } - - pub(super) fn is_closed(&self) -> bool { - self.closed.load(Ordering::Acquire) - } } impl Drop for Http3Session { diff --git a/crates/rginx-http/src/proxy/clients/http3/tests.rs b/crates/rginx-http/src/proxy/clients/http3/tests.rs index 83068dce..4f65c470 100644 --- a/crates/rginx-http/src/proxy/clients/http3/tests.rs +++ b/crates/rginx-http/src/proxy/clients/http3/tests.rs @@ -1,9 +1,18 @@ -use super::*; -use http::StatusCode; +use super::{Http3Client, UpstreamResolver}; +use crate::handler::{HttpBody, boxed_body}; +use bytes::Bytes; +use http::{Request, Response, StatusCode}; use http_body_util::BodyExt; -use rginx_core::{UpstreamDnsPolicy, UpstreamLoadBalance, UpstreamSettings, UpstreamTls}; +use rginx_core::{ + Upstream, UpstreamDnsPolicy, UpstreamLoadBalance, UpstreamPeer, UpstreamProtocol, + UpstreamSettings, UpstreamTls, +}; +use std::net::SocketAddr; +use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; use std::time::Instant; +use tokio::task::JoinHandle; fn test_resolver() -> Arc { Arc::new(UpstreamResolver::new(UpstreamDnsPolicy::default()).expect("resolver should build")) diff --git a/crates/rginx-http/src/proxy/clients/http_client.rs b/crates/rginx-http/src/proxy/clients/http_client.rs index 4eea1bd3..b8da89e3 100644 --- a/crates/rginx-http/src/proxy/clients/http_client.rs +++ b/crates/rginx-http/src/proxy/clients/http_client.rs @@ -3,11 +3,16 @@ use std::io; use std::net::SocketAddr; use std::task::{Context, Poll}; +use hyper_util::client::legacy::connect::dns::Name; use rustls::ClientConfig; use tokio::sync::Mutex; use tower_service::Service; -use super::*; +use super::{ + Arc, Client, Error, FixedServerNameResolver, HashMap, HttpConnector, HttpsConnectorBuilder, + HyperProxyClient, ResolvedUpstreamPeer, ServerName, TokioExecutor, TokioTimer, + UpstreamClientProfile, UpstreamProtocol, UpstreamResolver, +}; const ENDPOINT_CLIENT_CACHE_MIN_CAPACITY: usize = 16; const ENDPOINT_CLIENT_CACHE_MAX_CAPACITY: usize = 1024; @@ -19,15 +24,15 @@ pub(crate) struct HttpProxyClient { // endpoint client, so effective idle capacity is per live endpoint until LRU // eviction trims stale DNS endpoints from this bounded cache. pub(super) endpoint_clients: Arc>, - pub(super) resolver: Arc, pub(super) profile: UpstreamClientProfile, - pub(super) tls_config: ClientConfig, + pub(super) resolver: Arc, pub(super) server_name_override: Option>, + pub(super) tls_config: ClientConfig, } pub(super) struct EndpointClientCache { - pub(super) entries: HashMap, capacity: usize, + pub(super) entries: HashMap, next_access: u64, } @@ -59,8 +64,16 @@ impl HttpProxyClient { } impl EndpointClientCache { - pub(super) fn new(capacity: usize) -> Self { - Self { entries: HashMap::new(), capacity: capacity.max(1), next_access: 0 } + fn evict_lru(&mut self) { + let Some(socket_addr) = self + .entries + .iter() + .min_by_key(|(_socket_addr, entry)| entry.last_used) + .map(|(socket_addr, _entry)| *socket_addr) + else { + return; + }; + self.entries.remove(&socket_addr); } pub(super) fn get(&mut self, socket_addr: SocketAddr) -> Option { @@ -85,16 +98,8 @@ impl EndpointClientCache { client } - fn evict_lru(&mut self) { - let Some(socket_addr) = self - .entries - .iter() - .min_by_key(|(_socket_addr, entry)| entry.last_used) - .map(|(socket_addr, _entry)| *socket_addr) - else { - return; - }; - self.entries.remove(&socket_addr); + pub(super) fn new(capacity: usize) -> Self { + Self { entries: HashMap::new(), capacity: capacity.max(1), next_access: 0 } } fn next_access(&mut self) -> u64 { @@ -109,18 +114,19 @@ impl FixedEndpointResolver { } } -impl Service for FixedEndpointResolver { - type Response = std::vec::IntoIter; +impl Service for FixedEndpointResolver { type Error = io::Error; type Future = Ready>; - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } + type Response = std::vec::IntoIter; - fn call(&mut self, _name: hyper_util::client::legacy::connect::dns::Name) -> Self::Future { + fn call(&mut self, _name: Name) -> Self::Future { ready(Ok(vec![self.socket_addr].into_iter())) } + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } } pub(super) fn build_hyper_client_for_endpoint( diff --git a/crates/rginx-http/src/proxy/clients/mod.rs b/crates/rginx-http/src/proxy/clients/mod.rs index 5c24aa16..d7317362 100644 --- a/crates/rginx-http/src/proxy/clients/mod.rs +++ b/crates/rginx-http/src/proxy/clients/mod.rs @@ -1,12 +1,5 @@ //! Cached upstream proxy client selection for HTTP/1.1, HTTP/2, and HTTP/3. -use super::health::{ - ActivePeerGuard, ActiveProbeStatus, PeerFailureStatus, PeerHealthRegistry, SelectedPeers, - UpstreamHealthSnapshot, -}; -use super::*; -use std::error::Error as StdError; - mod factory; mod http3; mod http_client; @@ -15,12 +8,21 @@ mod profile; mod tests; mod tls; +use super::health::{ + ActivePeerGuard, ActiveProbeStatus, PeerFailureStatus, PeerHealthRegistry, SelectedPeers, + UpstreamHealthSnapshot, +}; +use super::{ + Arc, Client, ConfigSnapshot, Duration, Error, FixedServerNameResolver, HashMap, HashSet, + HttpBody, HttpConnector, HttpsConnector, HttpsConnectorBuilder, Request, ResolvedUpstreamPeer, + Response, ServerName, TokioExecutor, TokioTimer, Upstream, UpstreamPeer, UpstreamProtocol, + UpstreamResolver, UpstreamResolverRuntimeSnapshot, UpstreamTls, +}; +use std::error::Error as StdError; + #[cfg(test)] pub(super) use tls::load_custom_ca_store; -type HyperProxyClient = Client>, HttpBody>; -pub(crate) type HealthChangeNotifier = Arc; - use factory::build_client_for_profile; pub(crate) use http_client::HttpProxyClient; #[cfg(test)] @@ -28,11 +30,14 @@ use http_client::build_hyper_client_for_endpoint; use http_client::{EndpointClientCache, FixedEndpointResolver, endpoint_client_cache_capacity}; use profile::UpstreamClientProfile; +type HyperProxyClient = Client>, HttpBody>; +pub(crate) type HealthChangeNotifier = Arc; + #[derive(Clone)] pub struct ProxyClients { - upstreams: Arc>>, clients: Arc>, health: PeerHealthRegistry, + upstreams: Arc>>, } #[derive(Clone)] @@ -42,6 +47,21 @@ pub(crate) enum ProxyClient { } impl ProxyClients { + #[cfg(test)] + pub(super) fn cached_client_count(&self) -> usize { + self.clients.len() + } + + pub(crate) fn for_upstream(&self, upstream: &Upstream) -> Result { + let profile = UpstreamClientProfile::from_upstream(upstream); + self.clients.get(&profile).cloned().ok_or_else(|| { + Error::Server(format!( + "missing cached proxy client for upstream `{}` with TLS profile {:?}", + upstream.name, profile + )) + }) + } + pub fn from_config(config: &ConfigSnapshot) -> Result { Self::from_config_with_health_notifier(config, None) } @@ -75,61 +95,6 @@ impl ProxyClients { }) } - pub(crate) fn for_upstream(&self, upstream: &Upstream) -> Result { - let profile = UpstreamClientProfile::from_upstream(upstream); - self.clients.get(&profile).cloned().ok_or_else(|| { - Error::Server(format!( - "missing cached proxy client for upstream `{}` with TLS profile {:?}", - upstream.name, profile - )) - }) - } - - pub(super) async fn select_peers( - &self, - upstream: &Upstream, - client_ip: std::net::IpAddr, - limit: usize, - ) -> SelectedPeers { - match self.for_upstream(upstream) { - Ok(client) => self.health.select_peers(&client, upstream, client_ip, limit).await, - Err(_) => SelectedPeers { peers: Vec::new(), skipped_unhealthy: 0 }, - } - } - - pub(super) fn record_peer_success(&self, upstream_name: &str, peer_url: &str) -> bool { - self.health.record_success(upstream_name, peer_url) - } - - pub(super) fn record_peer_failure( - &self, - upstream_name: &str, - peer_url: &str, - ) -> PeerFailureStatus { - self.health.record_failure(upstream_name, peer_url) - } - - pub(super) fn record_active_peer_success( - &self, - upstream_name: &str, - peer_url: &str, - healthy_successes_required: u32, - ) -> ActiveProbeStatus { - self.health.record_active_success(upstream_name, peer_url, healthy_successes_required) - } - - pub(crate) fn record_active_peer_failure(&self, upstream_name: &str, peer_url: &str) -> bool { - self.health.record_active_failure(upstream_name, peer_url) - } - - pub(super) fn track_active_request( - &self, - upstream_name: &str, - peer_url: &str, - ) -> ActivePeerGuard { - self.health.track_active_request(upstream_name, peer_url) - } - pub(crate) async fn peer_health_snapshot(&self) -> Vec { let mut upstreams = self.upstreams.values().cloned().collect::>(); upstreams.sort_by(|left, right| left.name.cmp(&right.name)); @@ -152,23 +117,53 @@ impl ProxyClients { .collect() } - #[cfg(test)] - pub(super) fn cached_client_count(&self) -> usize { - self.clients.len() + pub(crate) fn record_active_peer_failure(&self, upstream_name: &str, peer_url: &str) -> bool { + self.health.record_active_failure(upstream_name, peer_url) } -} -impl ProxyClient { - pub(crate) async fn resolve_peer( + pub(super) fn record_active_peer_success( &self, - peer: &UpstreamPeer, - ) -> Result, Error> { - match self { - Self::Http(client) => client.resolver.resolve_peer(peer).await, - Self::Http3(client) => client.resolve_peer(peer).await, + upstream_name: &str, + peer_url: &str, + healthy_successes_required: u32, + ) -> ActiveProbeStatus { + self.health.record_active_success(upstream_name, peer_url, healthy_successes_required) + } + + pub(super) fn record_peer_failure( + &self, + upstream_name: &str, + peer_url: &str, + ) -> PeerFailureStatus { + self.health.record_failure(upstream_name, peer_url) + } + + pub(super) fn record_peer_success(&self, upstream_name: &str, peer_url: &str) -> bool { + self.health.record_success(upstream_name, peer_url) + } + + pub(super) async fn select_peers( + &self, + upstream: &Upstream, + client_ip: std::net::IpAddr, + limit: usize, + ) -> SelectedPeers { + match self.for_upstream(upstream) { + Ok(client) => self.health.select_peers(&client, upstream, client_ip, limit).await, + Err(_) => SelectedPeers { peers: Vec::new(), skipped_unhealthy: 0 }, } } + pub(super) fn track_active_request( + &self, + upstream_name: &str, + peer_url: &str, + ) -> ActivePeerGuard { + self.health.track_active_request(upstream_name, peer_url) + } +} + +impl ProxyClient { pub(crate) async fn cached_peer_endpoints( &self, peer: &UpstreamPeer, @@ -179,13 +174,6 @@ impl ProxyClient { } } - pub(crate) async fn resolver_snapshot(&self) -> UpstreamResolverRuntimeSnapshot { - match self { - Self::Http(client) => client.resolver.snapshot().await, - Self::Http3(client) => client.resolver_snapshot().await, - } - } - pub async fn request( &self, upstream: &Upstream, @@ -206,6 +194,22 @@ impl ProxyClient { Self::Http3(client) => client.request(upstream, peer, request).await, } } + pub(crate) async fn resolve_peer( + &self, + peer: &UpstreamPeer, + ) -> Result, Error> { + match self { + Self::Http(client) => client.resolver.resolve_peer(peer).await, + Self::Http3(client) => client.resolve_peer(peer).await, + } + } + + pub(crate) async fn resolver_snapshot(&self) -> UpstreamResolverRuntimeSnapshot { + match self { + Self::Http(client) => client.resolver.snapshot().await, + Self::Http3(client) => client.resolver_snapshot().await, + } + } } fn format_error_chain(prefix: &str, error: &(dyn StdError + 'static)) -> String { diff --git a/crates/rginx-http/src/proxy/clients/profile.rs b/crates/rginx-http/src/proxy/clients/profile.rs index 950cd6b7..5aa3b7e6 100644 --- a/crates/rginx-http/src/proxy/clients/profile.rs +++ b/crates/rginx-http/src/proxy/clients/profile.rs @@ -2,27 +2,27 @@ use std::path::PathBuf; use rginx_core::{ClientIdentity, TlsVersion}; -use super::*; +use super::{Duration, Upstream, UpstreamProtocol, UpstreamTls}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(super) struct UpstreamClientProfile { - pub(super) tls: UpstreamTls, - pub(super) dns: rginx_core::UpstreamDnsPolicy, - pub(super) tls_versions: Option>, - pub(super) server_verify_depth: Option, - pub(super) server_crl_path: Option, pub(super) client_identity: Option, - pub(super) protocol: UpstreamProtocol, - pub(super) server_name: bool, - pub(super) server_name_override: Option, pub(super) connect_timeout: Duration, + pub(super) dns: rginx_core::UpstreamDnsPolicy, + pub(super) http2_keep_alive_interval: Option, + pub(super) http2_keep_alive_timeout: Duration, + pub(super) http2_keep_alive_while_idle: bool, pub(super) pool_idle_timeout: Option, pub(super) pool_max_idle_per_host: usize, + pub(super) protocol: UpstreamProtocol, + pub(super) server_crl_path: Option, + pub(super) server_name: bool, + pub(super) server_name_override: Option, + pub(super) server_verify_depth: Option, pub(super) tcp_keepalive: Option, pub(super) tcp_nodelay: bool, - pub(super) http2_keep_alive_interval: Option, - pub(super) http2_keep_alive_timeout: Duration, - pub(super) http2_keep_alive_while_idle: bool, + pub(super) tls: UpstreamTls, + pub(super) tls_versions: Option>, } impl UpstreamClientProfile { diff --git a/crates/rginx-http/src/proxy/clients/tls/verifier.rs b/crates/rginx-http/src/proxy/clients/tls/verifier.rs index 00830554..392a141a 100644 --- a/crates/rginx-http/src/proxy/clients/tls/verifier.rs +++ b/crates/rginx-http/src/proxy/clients/tls/verifier.rs @@ -8,23 +8,6 @@ use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, Server use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::{DigitallySignedStruct, RootCertStore, SignatureScheme}; -pub(super) fn build_server_cert_verifier( - roots: RootCertStore, - server_verify_depth: Option, - server_crl_path: Option<&Path>, -) -> Result, Error> { - let builder = if let Some(crl_path) = server_crl_path { - WebPkiServerVerifier::builder(roots.into()) - .with_crls(load_certificate_revocation_lists(crl_path)?) - } else { - WebPkiServerVerifier::builder(roots.into()) - }; - let verifier = builder.build().map_err(|error| { - Error::Server(format!("failed to build upstream certificate verifier: {error}")) - })?; - Ok(Arc::new(DepthLimitedServerCertVerifier::new(verifier, server_verify_depth))) -} - #[derive(Debug)] pub(super) struct InsecureServerCertVerifier { supported_schemes: Vec, @@ -43,6 +26,9 @@ impl DepthLimitedServerCertVerifier { } impl ServerCertVerifier for DepthLimitedServerCertVerifier { + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } fn verify_server_cert( &self, end_entity: &CertificateDer<'_>, @@ -80,10 +66,6 @@ impl ServerCertVerifier for DepthLimitedServerCertVerifier { ) -> Result { self.inner.verify_tls13_signature(message, cert, dss) } - - fn supported_verify_schemes(&self) -> Vec { - self.inner.supported_verify_schemes() - } } impl InsecureServerCertVerifier { @@ -96,6 +78,9 @@ impl InsecureServerCertVerifier { } impl ServerCertVerifier for InsecureServerCertVerifier { + fn supported_verify_schemes(&self) -> Vec { + self.supported_schemes.clone() + } fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, @@ -124,8 +109,21 @@ impl ServerCertVerifier for InsecureServerCertVerifier { ) -> Result { Ok(HandshakeSignatureValid::assertion()) } +} - fn supported_verify_schemes(&self) -> Vec { - self.supported_schemes.clone() - } +pub(super) fn build_server_cert_verifier( + roots: RootCertStore, + server_verify_depth: Option, + server_crl_path: Option<&Path>, +) -> Result, Error> { + let builder = if let Some(crl_path) = server_crl_path { + WebPkiServerVerifier::builder(roots.into()) + .with_crls(load_certificate_revocation_lists(crl_path)?) + } else { + WebPkiServerVerifier::builder(roots.into()) + }; + let verifier = builder.build().map_err(|error| { + Error::Server(format!("failed to build upstream certificate verifier: {error}")) + })?; + Ok(Arc::new(DepthLimitedServerCertVerifier::new(verifier, server_verify_depth))) } diff --git a/crates/rginx-http/src/proxy/common.rs b/crates/rginx-http/src/proxy/common.rs index 2e2aaaab..4dcc0d8f 100644 --- a/crates/rginx-http/src/proxy/common.rs +++ b/crates/rginx-http/src/proxy/common.rs @@ -1,9 +1,19 @@ -use super::*; - mod uri; +use super::{ + CONNECTION, CONTENT_LENGTH, CONTENT_TYPE, ClientAddress, GrpcWebMode, HOST, HeaderMap, + HeaderName, HeaderValue, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, ProxyHeaderRenderContext, + ProxyHeaderValue, ResolvedUpstreamPeer, StatusCode, TE, TRAILER, TRANSFER_ENCODING, UPGRADE, + UpstreamProtocol, Uri, Version, +}; + pub(super) use uri::build_proxy_uri; +enum ProxyHeaderOverride { + Remove(HeaderName), + Set(HeaderName, HeaderValue), +} + pub(super) fn split_content_type(content_type: &str) -> (&str, &str) { let mut parts = content_type.splitn(2, ';'); let mime = parts.next().unwrap_or_default().trim(); @@ -13,7 +23,7 @@ pub(super) fn split_content_type(content_type: &str) -> (&str, &str) { pub(super) fn append_header_map(destination: &mut HeaderMap, source: &HeaderMap) { for name in source.keys() { - for value in source.get_all(name).iter() { + for value in &source.get_all(name) { destination.append(name.clone(), value.clone()); } } @@ -91,11 +101,6 @@ pub(super) fn sanitize_request_headers( Ok(()) } -enum ProxyHeaderOverride { - Set(HeaderName, HeaderValue), - Remove(HeaderName), -} - fn render_proxy_header_overrides( original_headers: &HeaderMap, authority: &str, diff --git a/crates/rginx-http/src/proxy/common/uri.rs b/crates/rginx-http/src/proxy/common/uri.rs index 7353384d..d86dd926 100644 --- a/crates/rginx-http/src/proxy/common/uri.rs +++ b/crates/rginx-http/src/proxy/common/uri.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{ResolvedUpstreamPeer, Uri}; use rginx_core::ProxyUriMode; pub(in crate::proxy) fn build_proxy_uri( diff --git a/crates/rginx-http/src/proxy/error_mapping.rs b/crates/rginx-http/src/proxy/error_mapping.rs index e6fb3577..fd809552 100644 --- a/crates/rginx-http/src/proxy/error_mapping.rs +++ b/crates/rginx-http/src/proxy/error_mapping.rs @@ -1,4 +1,4 @@ -use super::*; +use super::UpstreamTls; pub(crate) fn upstream_tls_verify_label(tls: &UpstreamTls) -> &'static str { match tls { diff --git a/crates/rginx-http/src/proxy/forward/attempt.rs b/crates/rginx-http/src/proxy/forward/attempt.rs index da9cb3ec..ec570d52 100644 --- a/crates/rginx-http/src/proxy/forward/attempt.rs +++ b/crates/rginx-http/src/proxy/forward/attempt.rs @@ -1,10 +1,21 @@ -use super::*; - mod background; mod cache_lookup; mod logging; mod primary; +use super::{ + BodyExt, ClientAddress, DownstreamRequestBodyFailure, DownstreamRequestContext, + DownstreamRequestOptions, DownstreamResponseContext, ForwardCacheBackend, ForwardCacheContext, + ForwardCacheLookup, HOST, HttpBody, HttpResponse, ProxyClients, ProxyTarget, Request, + ResolvedUpstreamPeer, SharedState, UpstreamSuccessContext, bad_gateway, + build_downstream_response, can_retry_peer_request, classify_downstream_request_body_failure, + classify_upstream_tls_failure, downstream_request_body_failure_response, + failed_to_build_request_response, finalize_streaming_request_body, finalize_upstream_success, + grpc_response_deadline, grpc_timeout_message, lookup_forward_cache, prepare_forward_request, + setup, upstream_gateway_timeout_response, upstream_tls_verify_label, + upstream_unavailable_response, wait_for_upstream_stage, +}; + use background::spawn_background_cache_refresh; use cache_lookup::resolve_forward_cache; use logging::{log_downstream_request_body_failure, log_successful_attempt}; diff --git a/crates/rginx-http/src/proxy/forward/attempt/background.rs b/crates/rginx-http/src/proxy/forward/attempt/background.rs index 13b6a699..8d4ef01d 100644 --- a/crates/rginx-http/src/proxy/forward/attempt/background.rs +++ b/crates/rginx-http/src/proxy/forward/attempt/background.rs @@ -1,15 +1,21 @@ -use super::*; +use super::{ + BodyExt, ClientAddress, DownstreamRequestContext, DownstreamRequestOptions, + DownstreamResponseContext, ForwardCacheBackend, HttpResponse, ProxyClients, ProxyTarget, + SharedState, build_downstream_response, can_retry_peer_request, + finalize_streaming_request_body, grpc_response_deadline, prepare_forward_request, setup, + wait_for_upstream_stage, +}; struct BackgroundCacheRefreshTask { - state: SharedState, - target: ProxyTarget, + cache_backend: B, + cache_store: crate::cache::CacheStoreContext, client_address: ClientAddress, - listener_id: String, downstream_proto: String, - request_id: String, + listener_id: String, options: DownstreamRequestOptions, - cache_backend: B, - cache_store: crate::cache::CacheStoreContext, + request_id: String, + state: SharedState, + target: ProxyTarget, } pub(super) fn spawn_background_cache_refresh( diff --git a/crates/rginx-http/src/proxy/forward/attempt/cache_lookup.rs b/crates/rginx-http/src/proxy/forward/attempt/cache_lookup.rs index 22d9899a..a7b0af06 100644 --- a/crates/rginx-http/src/proxy/forward/attempt/cache_lookup.rs +++ b/crates/rginx-http/src/proxy/forward/attempt/cache_lookup.rs @@ -1,4 +1,8 @@ -use super::*; +use super::{ + ClientAddress, DownstreamRequestContext, ForwardCacheBackend, ForwardCacheContext, + ForwardCacheLookup, HttpResponse, ProxyTarget, SharedState, lookup_forward_cache, + spawn_background_cache_refresh, +}; pub(super) async fn resolve_forward_cache( state: &SharedState, diff --git a/crates/rginx-http/src/proxy/forward/attempt/logging.rs b/crates/rginx-http/src/proxy/forward/attempt/logging.rs index bb618476..53fdf2c9 100644 --- a/crates/rginx-http/src/proxy/forward/attempt/logging.rs +++ b/crates/rginx-http/src/proxy/forward/attempt/logging.rs @@ -1,4 +1,6 @@ -use super::*; +use super::{ + DownstreamRequestBodyFailure, DownstreamRequestContext, ProxyTarget, ResolvedUpstreamPeer, +}; pub(super) fn log_successful_attempt( target: &ProxyTarget, @@ -21,7 +23,7 @@ pub(super) fn log_successful_attempt( upstream = %target.upstream_name, peer = %peer.display_url, logical_peer = %peer.logical_peer_url, - attempt = attempt_index + 1, + attempt = attempt_index.saturating_add(1), "upstream failover request succeeded" ); } diff --git a/crates/rginx-http/src/proxy/forward/attempt/primary.rs b/crates/rginx-http/src/proxy/forward/attempt/primary.rs index 7aa34038..784f23a2 100644 --- a/crates/rginx-http/src/proxy/forward/attempt/primary.rs +++ b/crates/rginx-http/src/proxy/forward/attempt/primary.rs @@ -1,4 +1,12 @@ -use super::*; +use super::{ + ClientAddress, DownstreamRequestContext, HOST, HttpBody, HttpResponse, ProxyTarget, Request, + SharedState, UpstreamSuccessContext, bad_gateway, can_retry_peer_request, + classify_downstream_request_body_failure, downstream_request_body_failure_response, + failed_to_build_request_response, finalize_streaming_request_body, finalize_upstream_success, + grpc_response_deadline, grpc_timeout_message, log_downstream_request_body_failure, + log_successful_attempt, prepare_forward_request, resolve_forward_cache, setup, + upstream_gateway_timeout_response, upstream_unavailable_response, wait_for_upstream_stage, +}; pub async fn forward_request( state: SharedState, @@ -144,14 +152,14 @@ pub async fn forward_request( state.record_upstream_failover(&target.upstream_name); let failure = clients.record_peer_failure(&target.upstream_name, &peer.endpoint_key); - let next_peer = &peers[attempt_index + 1]; + let next_peer = &peers[attempt_index.saturating_add(1)]; tracing::warn!( request_id = %downstream.request_id, upstream = %target.upstream_name, failed_peer = %peer.display_url, failed_logical_peer = %peer.logical_peer_url, next_peer = %next_peer.display_url, - attempt = attempt_index + 1, + attempt = attempt_index.saturating_add(1), upstream_sni_enabled = target.upstream.server_name, upstream_server_name = target.upstream.server_name_override.as_deref().unwrap_or("-"), upstream_verify = super::upstream_tls_verify_label(&target.upstream.tls), @@ -218,14 +226,14 @@ pub async fn forward_request( state.record_upstream_failover(&target.upstream_name); let failure = clients.record_peer_failure(&target.upstream_name, &peer.endpoint_key); - let next_peer = &peers[attempt_index + 1]; + let next_peer = &peers[attempt_index.saturating_add(1)]; tracing::warn!( request_id = %downstream.request_id, upstream = %target.upstream_name, failed_peer = %peer.display_url, failed_logical_peer = %peer.logical_peer_url, next_peer = %next_peer.display_url, - attempt = attempt_index + 1, + attempt = attempt_index.saturating_add(1), timeout_ms = upstream_request_timeout.as_millis() as u64, upstream_sni_enabled = target.upstream.server_name, upstream_server_name = target.upstream.server_name_override.as_deref().unwrap_or("-"), diff --git a/crates/rginx-http/src/proxy/forward/cache.rs b/crates/rginx-http/src/proxy/forward/cache.rs index d80492a5..1ec5a0ce 100644 --- a/crates/rginx-http/src/proxy/forward/cache.rs +++ b/crates/rginx-http/src/proxy/forward/cache.rs @@ -1,8 +1,13 @@ use std::future::Future; -use super::*; +use super::{HeaderMap, HttpBody, HttpResponse}; pub(super) trait ForwardCacheBackend { + fn complete_not_modified( + &self, + context: crate::cache::CacheStoreContext, + response: HttpResponse, + ) -> impl Future> + Send; fn lookup( &self, request: crate::cache::CacheRequest, @@ -15,15 +20,17 @@ pub(super) trait ForwardCacheBackend { context: crate::cache::CacheStoreContext, response: HttpResponse, ) -> impl Future + Send; +} +impl ForwardCacheBackend for crate::cache::CacheManager { fn complete_not_modified( &self, context: crate::cache::CacheStoreContext, response: HttpResponse, - ) -> impl Future> + Send; -} - -impl ForwardCacheBackend for crate::cache::CacheManager { + ) -> impl Future> + Send + { + crate::cache::CacheManager::complete_not_modified(self, context, response) + } fn lookup( &self, request: crate::cache::CacheRequest, @@ -40,26 +47,63 @@ impl ForwardCacheBackend for crate::cache::CacheManager { ) -> impl Future + Send { crate::cache::CacheManager::store_response(self, context, response) } - - fn complete_not_modified( - &self, - context: crate::cache::CacheStoreContext, - response: HttpResponse, - ) -> impl Future> + Send - { - crate::cache::CacheManager::complete_not_modified(self, context, response) - } } pub(super) struct ForwardCacheContext { - pub(super) store: Option>, pub(super) status: Option, + pub(super) store: Option>, } pub(super) enum ForwardCacheLookup { Hit(HttpResponse), - Updating(HttpResponse, Box), Proceed(Box), + Updating(HttpResponse, Box), +} + +impl ForwardCacheContext { + pub(super) fn apply_conditional_request_headers(&self, headers: &mut HeaderMap) { + if let Some(store) = + self.store.as_ref().filter(|store| store.prepares_cacheable_upstream_request()) + { + store.apply_conditional_request_headers(headers); + } + } + + pub(super) fn apply_upstream_request_headers(&self, headers: &mut HeaderMap) { + if let Some(store) = + self.store.as_ref().filter(|store| store.prepares_cacheable_upstream_request()) + { + store.apply_upstream_request_headers(headers); + } + } + + pub(super) fn apply_upstream_request_method(&self, request: &mut http::Request) { + if let Some(store) = + self.store.as_ref().filter(|store| store.prepares_cacheable_upstream_request()) + { + *request.method_mut() = store.upstream_request_method(); + } + } + + pub(super) fn mark_response(&self, response: HttpResponse) -> HttpResponse { + if let Some(status) = self.status { + crate::cache::with_cache_status(response, status) + } else { + response + } + } + + pub(super) async fn serve_stale_for_reason( + &self, + reason: crate::cache::CacheStaleReason, + status: crate::cache::CacheStatus, + ) -> Option { + let store = self.store.as_ref()?; + if !store.can_serve_stale(reason) { + return None; + } + store.serve_stale(reason, status).await + } } pub(super) async fn lookup_forward_cache( @@ -99,49 +143,3 @@ pub(super) async fn lookup_forward_cache( } } } - -impl ForwardCacheContext { - pub(super) fn mark_response(&self, response: HttpResponse) -> HttpResponse { - if let Some(status) = self.status { - crate::cache::with_cache_status(response, status) - } else { - response - } - } - - pub(super) fn apply_upstream_request_method(&self, request: &mut http::Request) { - if let Some(store) = - self.store.as_ref().filter(|store| store.prepares_cacheable_upstream_request()) - { - *request.method_mut() = store.upstream_request_method(); - } - } - - pub(super) fn apply_upstream_request_headers(&self, headers: &mut HeaderMap) { - if let Some(store) = - self.store.as_ref().filter(|store| store.prepares_cacheable_upstream_request()) - { - store.apply_upstream_request_headers(headers); - } - } - - pub(super) fn apply_conditional_request_headers(&self, headers: &mut HeaderMap) { - if let Some(store) = - self.store.as_ref().filter(|store| store.prepares_cacheable_upstream_request()) - { - store.apply_conditional_request_headers(headers); - } - } - - pub(super) async fn serve_stale_for_reason( - &self, - reason: crate::cache::CacheStaleReason, - status: crate::cache::CacheStatus, - ) -> Option { - let store = self.store.as_ref()?; - if !store.can_serve_stale(reason) { - return None; - } - store.serve_stale(reason, status).await - } -} diff --git a/crates/rginx-http/src/proxy/forward/error.rs b/crates/rginx-http/src/proxy/forward/error.rs index cb72c5b7..9714f723 100644 --- a/crates/rginx-http/src/proxy/forward/error.rs +++ b/crates/rginx-http/src/proxy/forward/error.rs @@ -1,8 +1,12 @@ +#[cfg(test)] +mod tests; use std::error::Error as StdError; use std::future::Future; use super::super::request_body::request_body_limit_error; -use super::*; +use super::{ + Duration, Error, GrpcStatusCode, HeaderMap, HttpResponse, StatusCode, grpc_error_response, +}; pub(super) fn invalid_downstream_request_body_error(error: &(dyn StdError + 'static)) -> bool { let mut current = Some(error); @@ -49,7 +53,7 @@ pub(super) fn downstream_request_body_limit(error: &(dyn StdError + 'static)) -> fn parse_request_body_limit_message(message: &str) -> Option { let prefix = "request body exceeded configured limit of "; let suffix = " bytes"; - let start = message.find(prefix)? + prefix.len(); + let start = message.find(prefix)?.saturating_add(prefix.len()); let rest = &message[start..]; let end = rest.find(suffix)?; let value = &rest[..end]; @@ -132,6 +136,3 @@ pub(super) fn unsupported_media_type(request_headers: &HeaderMap, message: Strin }, ) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/proxy/forward/error/tests.rs b/crates/rginx-http/src/proxy/forward/error/tests.rs index c79ccf74..e3ae48be 100644 --- a/crates/rginx-http/src/proxy/forward/error/tests.rs +++ b/crates/rginx-http/src/proxy/forward/error/tests.rs @@ -1,6 +1,10 @@ use std::io; -use super::*; +use super::super::failure::{ + DownstreamRequestBodyFailure, classify_downstream_request_body_failure, +}; +use super::Error; +use super::{downstream_request_body_limit, invalid_downstream_request_body_error}; #[test] fn detects_invalid_io_kinds_in_error_chain() { diff --git a/crates/rginx-http/src/proxy/forward/failure.rs b/crates/rginx-http/src/proxy/forward/failure.rs index 48f08f8d..67bd384c 100644 --- a/crates/rginx-http/src/proxy/forward/failure.rs +++ b/crates/rginx-http/src/proxy/forward/failure.rs @@ -1,4 +1,14 @@ -use super::*; +use super::{ + DownstreamRequestContext, ForwardCacheContext, HeaderMap, HttpResponse, ProxyTarget, + ResolvedUpstreamPeer, SharedState, bad_gateway, bad_request, downstream_request_body_limit, + gateway_timeout, invalid_downstream_request_body_error, payload_too_large, + unsupported_media_type, +}; + +pub(super) enum DownstreamRequestBodyFailure { + Invalid, + PayloadTooLarge { max_request_body_bytes: usize }, +} pub(super) fn upstream_bad_gateway_response( state: &SharedState, @@ -34,11 +44,6 @@ pub(super) fn upstream_invalid_request_body_response( ) } -pub(super) enum DownstreamRequestBodyFailure { - PayloadTooLarge { max_request_body_bytes: usize }, - Invalid, -} - pub(super) fn classify_downstream_request_body_failure( error: &(dyn std::error::Error + 'static), ) -> Option { diff --git a/crates/rginx-http/src/proxy/forward/grpc.rs b/crates/rginx-http/src/proxy/forward/grpc.rs index 2f214e79..ef746c6a 100644 --- a/crates/rginx-http/src/proxy/forward/grpc.rs +++ b/crates/rginx-http/src/proxy/forward/grpc.rs @@ -1,5 +1,9 @@ use super::super::grpc_web::{GrpcWebEncoding, GrpcWebMode}; -use super::*; +use super::{ + CONTENT_TYPE, Duration, GRPC_CONTENT_TYPE_PREFIX, GRPC_TIMEOUT_HEADER, + GRPC_WEB_CONTENT_TYPE_PREFIX, GRPC_WEB_TEXT_CONTENT_TYPE_PREFIX, HeaderMap, HeaderValue, + MAX_GRPC_TIMEOUT_DIGITS, TokioInstant, split_content_type, +}; pub(crate) fn detect_grpc_web_mode( headers: &HeaderMap, @@ -69,7 +73,9 @@ pub(super) fn grpc_response_deadline( upstream_request_timeout: Duration, ) -> Option { grpc_protocol_request(request_headers).then(|| super::response::GrpcResponseDeadline { - deadline: TokioInstant::now() + upstream_request_timeout, + deadline: TokioInstant::now() + .checked_add(upstream_request_timeout) + .expect("gRPC response deadline remains representable"), timeout: upstream_request_timeout, timeout_message: super::error::grpc_timeout_message( upstream_name, @@ -104,7 +110,7 @@ pub(crate) fn parse_grpc_timeout(headers: &HeaderMap) -> Result )); } - let (amount, unit) = value.split_at(value.len() - 1); + let (amount, unit) = value.split_at(value.len().saturating_sub(1)); if amount.is_empty() || amount.len() > MAX_GRPC_TIMEOUT_DIGITS || !amount.bytes().all(|byte| byte.is_ascii_digit()) diff --git a/crates/rginx-http/src/proxy/forward/mod.rs b/crates/rginx-http/src/proxy/forward/mod.rs index c7cd4ca7..311ccaf9 100644 --- a/crates/rginx-http/src/proxy/forward/mod.rs +++ b/crates/rginx-http/src/proxy/forward/mod.rs @@ -1,9 +1,3 @@ -use super::request_body::{ - PrepareRequestError, PreparedProxyRequest, StreamingBodyCompletion, can_retry_peer_request, -}; -use super::upgrade::proxy_upgraded_connection; -use super::*; - mod attempt; mod cache; mod error; @@ -15,6 +9,12 @@ mod streaming; mod success; mod types; +use super::request_body::{ + PrepareRequestError, PreparedProxyRequest, StreamingBodyCompletion, can_retry_peer_request, +}; +use super::upgrade::proxy_upgraded_connection; +use super::*; + use cache::{ForwardCacheBackend, ForwardCacheContext, ForwardCacheLookup, lookup_forward_cache}; use error::{ bad_gateway, bad_request, downstream_request_body_limit, gateway_timeout, grpc_timeout_message, diff --git a/crates/rginx-http/src/proxy/forward/response.rs b/crates/rginx-http/src/proxy/forward/response.rs index 1f0be2b2..82133cac 100644 --- a/crates/rginx-http/src/proxy/forward/response.rs +++ b/crates/rginx-http/src/proxy/forward/response.rs @@ -1,8 +1,14 @@ +#[cfg(test)] +mod tests; use super::super::grpc_web::{ GrpcWebMode, GrpcWebResponseBody, GrpcWebTextEncodeBody, extract_grpc_initial_trailers, }; use super::health::{ActivePeerBody, ActivePeerGuard}; -use super::*; +use super::{ + BodyExt, BoxError, Bytes, CONTENT_LENGTH, CONTENT_TYPE, Duration, GrpcDeadlineBody, HeaderMap, + HeaderValue, HttpBody, HttpResponse, IdleTimeoutBody, ProxyTarget, ResolvedUpstreamPeer, + Response, TokioInstant, full_body, is_upgrade_response, sanitize_response_headers, +}; use rginx_core::ProxyRedirectMode; #[derive(Debug, Clone)] @@ -13,14 +19,14 @@ pub(super) struct GrpcResponseDeadline { } pub(super) struct DownstreamResponseContext<'a> { - pub(super) target: &'a ProxyTarget, - pub(super) peer: &'a ResolvedUpstreamPeer, - pub(super) downstream_scheme: &'a str, + pub(super) active_peer: Option, pub(super) downstream_host: Option<&'a HeaderValue>, - pub(super) idle_timeout: Duration, + pub(super) downstream_scheme: &'a str, pub(super) grpc_response_deadline: Option, pub(super) grpc_web_mode: Option<&'a GrpcWebMode>, - pub(super) active_peer: Option, + pub(super) idle_timeout: Duration, + pub(super) peer: &'a ResolvedUpstreamPeer, + pub(super) target: &'a ProxyTarget, pub(super) upstream_access_log: Option, } @@ -162,6 +168,3 @@ where body.boxed_unsync() } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/proxy/forward/setup.rs b/crates/rginx-http/src/proxy/forward/setup.rs index e8637ccd..cafd11af 100644 --- a/crates/rginx-http/src/proxy/forward/setup.rs +++ b/crates/rginx-http/src/proxy/forward/setup.rs @@ -1,24 +1,33 @@ -use super::*; +use super::{ + ClientAddress, DownstreamRequestContext, Duration, GrpcWebMode, HeaderMap, HttpBody, + HttpResponse, MAX_FAILOVER_ATTEMPTS, OnUpgrade, PrepareRequestError, PreparedProxyRequest, + ProxyClients, ProxyTarget, Request, ResolvedUpstreamPeer, SharedState, UpstreamProtocol, + detect_grpc_web_mode, effective_upstream_request_timeout, + invalid_downstream_request_body_error, is_upgrade_request, upstream_bad_gateway_response, + upstream_bad_request_response, upstream_invalid_request_body_response, + upstream_no_healthy_peers_response, upstream_payload_too_large_response, + upstream_unsupported_media_type_response, +}; use crate::proxy::clients::ProxyClient; use crate::proxy::request_body::PrepareProxyRequestOptions; type PreparedResult = Result; struct PreparedForwardPolicy { - response_idle_timeout: Duration, grpc_web_mode: Option, + response_idle_timeout: Duration, upstream_request_timeout: Duration, } pub(super) struct PreparedForwardRequest { - pub(super) request_headers: HeaderMap, - pub(super) response_idle_timeout: Duration, - pub(super) grpc_web_mode: Option, - pub(super) upstream_request_timeout: Duration, pub(super) client: ProxyClient, pub(super) downstream_upgrade: Option, - pub(super) prepared_request: PreparedProxyRequest, + pub(super) grpc_web_mode: Option, pub(super) peers: Vec, + pub(super) prepared_request: PreparedProxyRequest, + pub(super) request_headers: HeaderMap, + pub(super) response_idle_timeout: Duration, + pub(super) upstream_request_timeout: Duration, } pub(super) async fn prepare_forward_request( @@ -67,7 +76,7 @@ pub(super) async fn prepare_forward_request( }) } -#[allow( +#[expect( clippy::result_large_err, reason = "keep synchronous setup helpers consistent with HttpResponse-returning async forwarding code" )] @@ -98,10 +107,10 @@ fn prepare_forward_policy( ) })?; - Ok(PreparedForwardPolicy { response_idle_timeout, grpc_web_mode, upstream_request_timeout }) + Ok(PreparedForwardPolicy { grpc_web_mode, response_idle_timeout, upstream_request_timeout }) } -#[allow( +#[expect( clippy::result_large_err, reason = "keep synchronous setup helpers consistent with HttpResponse-returning async forwarding code" )] @@ -135,7 +144,7 @@ fn prepare_proxy_client( } } -#[allow( +#[expect( clippy::result_large_err, reason = "keep synchronous setup helpers consistent with HttpResponse-returning async forwarding code" )] diff --git a/crates/rginx-http/src/proxy/forward/streaming.rs b/crates/rginx-http/src/proxy/forward/streaming.rs index 9fcf3a36..dd1a5b3f 100644 --- a/crates/rginx-http/src/proxy/forward/streaming.rs +++ b/crates/rginx-http/src/proxy/forward/streaming.rs @@ -1,4 +1,8 @@ -use super::*; +use super::{ + BoxError, DownstreamRequestBodyFailure, DownstreamRequestContext, HeaderMap, HttpResponse, + ProxyTarget, ResolvedUpstreamPeer, SharedState, StreamingBodyCompletion, bad_gateway, + classify_downstream_request_body_failure, downstream_request_body_failure_response, +}; pub(super) async fn finalize_streaming_request_body( body_completion: Option, diff --git a/crates/rginx-http/src/proxy/forward/success.rs b/crates/rginx-http/src/proxy/forward/success.rs index 250cdca3..fe5377cf 100644 --- a/crates/rginx-http/src/proxy/forward/success.rs +++ b/crates/rginx-http/src/proxy/forward/success.rs @@ -2,20 +2,20 @@ use super::*; use http::HeaderValue; pub(super) struct UpstreamSuccessContext<'a, B: ForwardCacheBackend + ?Sized> { - pub(super) state: &'a SharedState, + pub(super) active_peer: super::super::health::ActivePeerGuard, + pub(super) cache_backend: &'a B, + pub(super) cache_status: Option, + pub(super) cache_store: Option>, + pub(super) downstream_host: Option<&'a HeaderValue>, + pub(super) downstream_scheme: &'a str, pub(super) downstream_upgrade: Option, + pub(super) grpc_response_deadline: Option, + pub(super) grpc_web_mode: Option<&'a GrpcWebMode>, pub(super) listener_id: &'a str, - pub(super) downstream_scheme: &'a str, - pub(super) downstream_host: Option<&'a HeaderValue>, - pub(super) target: &'a ProxyTarget, pub(super) peer: &'a ResolvedUpstreamPeer, - pub(super) active_peer: super::super::health::ActivePeerGuard, pub(super) response_idle_timeout: Duration, - pub(super) grpc_response_deadline: Option, - pub(super) grpc_web_mode: Option<&'a GrpcWebMode>, - pub(super) cache_backend: &'a B, - pub(super) cache_store: Option>, - pub(super) cache_status: Option, + pub(super) state: &'a SharedState, + pub(super) target: &'a ProxyTarget, pub(super) upstream_response_time_ms: u64, } diff --git a/crates/rginx-http/src/proxy/forward/types.rs b/crates/rginx-http/src/proxy/forward/types.rs index 21b12a87..df17a272 100644 --- a/crates/rginx-http/src/proxy/forward/types.rs +++ b/crates/rginx-http/src/proxy/forward/types.rs @@ -1,18 +1,18 @@ -use super::*; +use super::{Duration, RouteBufferingPolicy}; #[derive(Debug, Clone)] pub struct DownstreamRequestOptions { - pub request_body_read_timeout: Option, + pub cache: Option, pub max_request_body_bytes: Option, + pub request_body_read_timeout: Option, pub request_buffering: RouteBufferingPolicy, pub streaming_response_idle_timeout: Option, - pub cache: Option, } #[derive(Debug, Clone)] pub struct DownstreamRequestContext<'a> { - pub listener_id: &'a str, pub downstream_proto: &'a str, - pub request_id: &'a str, + pub listener_id: &'a str, pub options: DownstreamRequestOptions, + pub request_id: &'a str, } diff --git a/crates/rginx-http/src/proxy/grpc_web/body/request.rs b/crates/rginx-http/src/proxy/grpc_web/body/request.rs index 1f07119c..ce47fd25 100644 --- a/crates/rginx-http/src/proxy/grpc_web/body/request.rs +++ b/crates/rginx-http/src/proxy/grpc_web/body/request.rs @@ -38,6 +38,14 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.inner_finished + && self.inner_trailers.is_none() + && self.pending_data.is_none() + && self.pending_trailers.is_none() + && self.buffer.is_empty() + } + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -120,14 +128,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.inner_finished - && self.inner_trailers.is_none() - && self.pending_data.is_none() - && self.pending_trailers.is_none() - && self.buffer.is_empty() - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } diff --git a/crates/rginx-http/src/proxy/grpc_web/body/response.rs b/crates/rginx-http/src/proxy/grpc_web/body/response.rs index 44dfcdf3..99e200de 100644 --- a/crates/rginx-http/src/proxy/grpc_web/body/response.rs +++ b/crates/rginx-http/src/proxy/grpc_web/body/response.rs @@ -25,6 +25,10 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.inner_finished && self.pending_frame.is_none() + } + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -76,10 +80,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.inner_finished && self.pending_frame.is_none() - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } diff --git a/crates/rginx-http/src/proxy/grpc_web/body/text_decode.rs b/crates/rginx-http/src/proxy/grpc_web/body/text_decode.rs index dcf9ec66..5b7d57a6 100644 --- a/crates/rginx-http/src/proxy/grpc_web/body/text_decode.rs +++ b/crates/rginx-http/src/proxy/grpc_web/body/text_decode.rs @@ -34,6 +34,13 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.inner_finished + && self.pending_data.is_none() + && self.pending_trailers.is_none() + && self.carryover.is_empty() + } + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -97,13 +104,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.inner_finished - && self.pending_data.is_none() - && self.pending_trailers.is_none() - && self.carryover.is_empty() - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } diff --git a/crates/rginx-http/src/proxy/grpc_web/body/text_encode.rs b/crates/rginx-http/src/proxy/grpc_web/body/text_encode.rs index fd10a859..5ba41b5f 100644 --- a/crates/rginx-http/src/proxy/grpc_web/body/text_encode.rs +++ b/crates/rginx-http/src/proxy/grpc_web/body/text_encode.rs @@ -28,6 +28,10 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.inner_finished && self.pending_data.is_none() && self.carryover.is_empty() + } + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -82,10 +86,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.inner_finished && self.pending_data.is_none() && self.carryover.is_empty() - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } diff --git a/crates/rginx-http/src/proxy/grpc_web/codec.rs b/crates/rginx-http/src/proxy/grpc_web/codec.rs index 7048e9f9..200f265c 100644 --- a/crates/rginx-http/src/proxy/grpc_web/codec.rs +++ b/crates/rginx-http/src/proxy/grpc_web/codec.rs @@ -1,14 +1,14 @@ use super::*; -pub(crate) fn invalid_grpc_web_body(message: &str) -> BoxError { - std::io::Error::new(std::io::ErrorKind::InvalidData, message).into() -} - pub(crate) enum ParsedGrpcWebRequestFrame { Data(Bytes), Trailers(HeaderMap), } +pub(crate) fn invalid_grpc_web_body(message: &str) -> BoxError { + std::io::Error::new(std::io::ErrorKind::InvalidData, message).into() +} + pub(crate) fn decode_grpc_web_request_frame( buffer: &mut BytesMut, inner_finished: bool, @@ -86,7 +86,8 @@ pub(crate) fn decode_grpc_web_text_chunk( } } - let complete_len = carryover.len() / 4 * 4; + let complete_len = + carryover.len().checked_div(4).expect("base64 block size is nonzero").saturating_mul(4); if complete_len == 0 { return Ok(None); } @@ -128,7 +129,8 @@ pub(crate) fn decode_grpc_web_text_final( pub(crate) fn encode_grpc_web_text_chunk(carryover: &mut BytesMut, data: &[u8]) -> Option { carryover.extend_from_slice(data); - let complete_len = carryover.len() / 3 * 3; + let complete_len = + carryover.len().checked_div(3).expect("base64 raw block size is nonzero").saturating_mul(3); if complete_len == 0 { return None; } @@ -167,7 +169,7 @@ pub(crate) fn encode_grpc_web_trailers(trailers: &HeaderMap) -> Bytes { trailer_block.extend_from_slice(b"\r\n"); } - let mut encoded = Vec::with_capacity(5 + trailer_block.len()); + let mut encoded = Vec::with_capacity(trailer_block.len().saturating_add(5)); encoded.push(0x80); encoded.extend_from_slice(&(trailer_block.len() as u32).to_be_bytes()); encoded.extend_from_slice(&trailer_block); diff --git a/crates/rginx-http/src/proxy/grpc_web/mod.rs b/crates/rginx-http/src/proxy/grpc_web/mod.rs index ba99c664..93273b69 100644 --- a/crates/rginx-http/src/proxy/grpc_web/mod.rs +++ b/crates/rginx-http/src/proxy/grpc_web/mod.rs @@ -1,8 +1,8 @@ -use super::*; - mod body; mod codec; +use super::*; + pub(super) use body::{ GrpcWebRequestBody, GrpcWebResponseBody, GrpcWebTextDecodeBody, GrpcWebTextEncodeBody, }; @@ -16,8 +16,8 @@ pub(super) use codec::{ #[derive(Debug, Clone)] pub(crate) struct GrpcWebMode { pub downstream_content_type: HeaderValue, - pub upstream_content_type: HeaderValue, pub encoding: GrpcWebEncoding, + pub upstream_content_type: HeaderValue, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] diff --git a/crates/rginx-http/src/proxy/health.rs b/crates/rginx-http/src/proxy/health.rs index 8695996a..b4549eca 100644 --- a/crates/rginx-http/src/proxy/health.rs +++ b/crates/rginx-http/src/proxy/health.rs @@ -1,17 +1,24 @@ //! Upstream health probing and health-state exports for proxy routing. -use super::*; - mod active_probe; mod grpc_health_codec; mod registry; mod request; +use super::{ + ActiveHealthCheck, Arc, Bytes, CONTENT_LENGTH, CONTENT_TYPE, Error, HOST, HeaderValue, + HttpBody, Method, Request, ResolvedUpstreamPeer, TE, Upstream, UpstreamPeer, UpstreamProtocol, + Uri, build_proxy_uri, classify_upstream_tls_failure, clients, full_body, + remove_redundant_host_header_for_authority_pseudo_header, upstream_request_version, + upstream_tls_verify_label, +}; + pub use active_probe::probe_upstream_peer; -#[allow(unused_imports)] +#[cfg(test)] +pub(crate) use grpc_health_codec::decode_grpc_health_check_response; pub(crate) use grpc_health_codec::{ - GrpcHealthProbeResult, GrpcHealthServingStatus, decode_grpc_health_check_response, - encode_grpc_health_check_request, evaluate_grpc_health_probe_response, + GrpcHealthProbeResult, GrpcHealthServingStatus, encode_grpc_health_check_request, + evaluate_grpc_health_probe_response, }; pub(crate) use registry::{ ActivePeerBody, ActivePeerGuard, ActiveProbeStatus, PeerFailureStatus, PeerHealthRegistry, diff --git a/crates/rginx-http/src/proxy/health/active_probe.rs b/crates/rginx-http/src/proxy/health/active_probe.rs index 7ed509f1..c6cfb18a 100644 --- a/crates/rginx-http/src/proxy/health/active_probe.rs +++ b/crates/rginx-http/src/proxy/health/active_probe.rs @@ -1,5 +1,8 @@ use super::clients::ProxyClients; -use super::*; +use super::{ + Arc, GrpcHealthProbeResult, GrpcHealthServingStatus, Upstream, UpstreamPeer, + evaluate_grpc_health_probe_response, +}; pub async fn probe_upstream_peer( clients: ProxyClients, diff --git a/crates/rginx-http/src/proxy/health/grpc_health_codec.rs b/crates/rginx-http/src/proxy/health/grpc_health_codec.rs index 83ad6daa..45bdc45a 100644 --- a/crates/rginx-http/src/proxy/health/grpc_health_codec.rs +++ b/crates/rginx-http/src/proxy/health/grpc_health_codec.rs @@ -8,11 +8,11 @@ const GRPC_HEALTH_SERVING_STATUS_SERVING: u64 = 1; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum GrpcHealthServingStatus { - Unknown, - Serving, NotServing, - ServiceUnknown, Other(u64), + ServiceUnknown, + Serving, + Unknown, } impl GrpcHealthServingStatus { @@ -33,12 +33,12 @@ impl GrpcHealthServingStatus { #[derive(Debug)] pub(crate) enum GrpcHealthProbeResult { - Serving, NotServing { http_status: StatusCode, grpc_status: Option, serving_status: Option, }, + Serving, } pub(crate) async fn evaluate_grpc_health_probe_response( @@ -122,7 +122,7 @@ pub(crate) fn encode_grpc_health_check_request(service: &str) -> Bytes { payload.extend_from_slice(service.as_bytes()); } - let mut frame = BytesMut::with_capacity(5 + payload.len()); + let mut frame = BytesMut::with_capacity(payload.len().saturating_add(5)); frame.extend_from_slice(&[0]); frame.extend_from_slice(&(payload.len() as u32).to_be_bytes()); frame.extend_from_slice(&payload); @@ -153,7 +153,7 @@ fn decode_grpc_frame_payload(body: &[u8]) -> Result<&[u8], BoxError> { } let len = u32::from_be_bytes([body[1], body[2], body[3], body[4]]) as usize; - let expected_len = 5 + len; + let expected_len = len.saturating_add(5); if body.len() != expected_len { return Err(invalid_grpc_health_probe(format!( "gRPC health response frame length mismatch: expected {expected_len} bytes, got {}", @@ -207,14 +207,14 @@ fn decode_protobuf_varint(payload: &[u8], index: &mut usize) -> Result= 64 { return Err(invalid_grpc_health_probe("protobuf varint is too large")); } diff --git a/crates/rginx-http/src/proxy/health/registry/guards.rs b/crates/rginx-http/src/proxy/health/registry/guards.rs index 5f1e0152..bc0aab8e 100644 --- a/crates/rginx-http/src/proxy/health/registry/guards.rs +++ b/crates/rginx-http/src/proxy/health/registry/guards.rs @@ -7,9 +7,17 @@ use crate::proxy::HealthChangeNotifier; use super::PeerHealth; +pin_project! { + pub(crate) struct ActivePeerBody { + #[pin] + inner: B, + guard: Option, + } +} + pub(crate) struct ActivePeerGuard { - pub(super) peer: Option>, pub(super) notifier: Option, + pub(super) peer: Option>, pub(super) upstream_name: String, } @@ -24,14 +32,6 @@ impl Drop for ActivePeerGuard { } } -pin_project! { - pub(crate) struct ActivePeerBody { - #[pin] - inner: B, - guard: Option, - } -} - impl ActivePeerBody { pub(crate) fn new(inner: B, guard: ActivePeerGuard) -> Self { Self { inner, guard: Some(guard) } @@ -46,6 +46,10 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.inner.is_end_stream() + } + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -65,10 +69,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.inner.is_end_stream() - } - fn size_hint(&self) -> SizeHint { self.inner.size_hint() } diff --git a/crates/rginx-http/src/proxy/health/registry/mod.rs b/crates/rginx-http/src/proxy/health/registry/mod.rs index 2d0e6459..b74324cd 100644 --- a/crates/rginx-http/src/proxy/health/registry/mod.rs +++ b/crates/rginx-http/src/proxy/health/registry/mod.rs @@ -1,5 +1,13 @@ //! Per-upstream passive and active health state used by peer selection. +mod guards; +mod policy; +mod selection; +mod snapshot; +mod state; +#[cfg(test)] +mod tests; + use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; @@ -10,21 +18,13 @@ use serde::{Deserialize, Serialize}; use crate::proxy::clients::ProxyClient; use crate::proxy::{HealthChangeNotifier, ResolvedUpstreamPeer, UpstreamResolverRuntimeSnapshot}; -mod guards; -mod policy; -mod selection; -mod snapshot; -mod state; -#[cfg(test)] -mod tests; - pub(crate) use guards::{ActivePeerBody, ActivePeerGuard}; #[derive(Debug, Clone, Copy)] struct PeerHealthPolicy { - unhealthy_after_failures: u32, - cooldown: Duration, active_health_enabled: bool, + cooldown: Duration, + unhealthy_after_failures: u32, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -36,28 +36,28 @@ pub(crate) struct PeerFailureStatus { #[derive(Debug, Default)] struct PassiveHealthState { consecutive_failures: u32, - unhealthy_until: Option, pending_recovery: bool, + unhealthy_until: Option, } #[derive(Debug, Default)] struct ActiveHealthState { - unhealthy: bool, consecutive_successes: u32, + unhealthy: bool, } #[derive(Debug, Default)] struct PeerHealthState { - passive: PassiveHealthState, active: ActiveHealthState, active_requests: u64, + passive: PassiveHealthState, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) struct ActiveProbeStatus { + pub consecutive_successes: u32, pub healthy: bool, pub recovered: bool, - pub consecutive_successes: u32, } #[derive(Debug, Default)] @@ -70,53 +70,53 @@ type UpstreamPeerHealthMap = HashMap; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct PeerHealthSnapshot { - pub peer_url: String, - pub backup: bool, - pub weight: u32, + pub active_consecutive_successes: u32, + pub active_requests: u64, + pub active_unhealthy: bool, pub available: bool, + pub backup: bool, pub passive_consecutive_failures: u32, pub passive_cooldown_remaining_ms: Option, pub passive_pending_recovery: bool, - pub active_unhealthy: bool, - pub active_consecutive_successes: u32, - pub active_requests: u64, + pub peer_url: String, + pub weight: u32, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ResolvedEndpointHealthSnapshot { + pub active_consecutive_successes: u32, + pub active_requests: u64, + pub active_unhealthy: bool, + pub available: bool, + pub backup: bool, + pub dial_addr: String, + pub display_url: String, pub endpoint_key: String, pub logical_peer_url: String, - pub display_url: String, - pub dial_addr: String, - pub server_name: String, - pub backup: bool, - pub weight: u32, - pub available: bool, pub passive_consecutive_failures: u32, pub passive_cooldown_remaining_ms: Option, pub passive_pending_recovery: bool, - pub active_unhealthy: bool, - pub active_consecutive_successes: u32, - pub active_requests: u64, + pub server_name: String, + pub weight: u32, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UpstreamHealthSnapshot { - pub upstream_name: String, - pub unhealthy_after_failures: u32, - pub cooldown_ms: u64, pub active_health_enabled: bool, - pub resolver: UpstreamResolverRuntimeSnapshot, - pub peers: Vec, + pub cooldown_ms: u64, pub endpoints: Vec, + pub peers: Vec, + pub resolver: UpstreamResolverRuntimeSnapshot, + pub unhealthy_after_failures: u32, + pub upstream_name: String, } #[derive(Clone)] pub(crate) struct PeerHealthRegistry { - policies: Arc>, - peers: Arc, endpoint_peers: Arc>, notifier: Option, + peers: Arc, + policies: Arc>, } pub(crate) struct SelectedPeers { diff --git a/crates/rginx-http/src/proxy/health/registry/policy.rs b/crates/rginx-http/src/proxy/health/registry/policy.rs index f25f5641..0d4da6c0 100644 --- a/crates/rginx-http/src/proxy/health/registry/policy.rs +++ b/crates/rginx-http/src/proxy/health/registry/policy.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{PeerHealthPolicy, Upstream}; impl PeerHealthPolicy { pub(super) fn from_upstream(upstream: &Upstream) -> Self { diff --git a/crates/rginx-http/src/proxy/health/registry/selection.rs b/crates/rginx-http/src/proxy/health/registry/selection.rs index 74e6027b..99dbeda0 100644 --- a/crates/rginx-http/src/proxy/health/registry/selection.rs +++ b/crates/rginx-http/src/proxy/health/registry/selection.rs @@ -1,6 +1,94 @@ -use super::*; +use super::{ + PeerHealthRegistry, ProxyClient, ResolvedUpstreamPeer, SelectedPeers, Upstream, + UpstreamLoadBalance, UpstreamPeer, +}; impl PeerHealthRegistry { + fn endpoint_is_selectable( + &self, + upstream_name: &str, + endpoint: &ResolvedUpstreamPeer, + active_requests: u64, + ) -> bool { + if !self.is_available(upstream_name, &endpoint.endpoint_key) { + return false; + } + + endpoint.max_conns.is_none_or(|max_conns| active_requests < max_conns as u64) + } + + async fn select_available_peers( + &self, + client: &ProxyClient, + upstream: &Upstream, + ordered: Vec, + limit: usize, + ) -> SelectedPeers { + if limit == 0 { + return SelectedPeers { peers: Vec::new(), skipped_unhealthy: 0 }; + } + + let mut batches = Vec::new(); + let mut skipped_unhealthy = 0usize; + + for peer in ordered { + let mut endpoints: Vec = + client.resolve_peer(&peer).await.unwrap_or_default(); + for endpoint in &endpoints { + self.ensure_endpoint( + &upstream.name, + &endpoint.endpoint_key, + &endpoint.logical_peer_url, + ); + } + endpoints.sort_by(|left, right| { + self.active_requests(&upstream.name, &left.logical_peer_url) + .cmp(&self.active_requests(&upstream.name, &right.logical_peer_url)) + .then(left.dial_authority.cmp(&right.dial_authority)) + }); + + let mut available = Vec::new(); + for endpoint in endpoints { + self.ensure_endpoint( + &upstream.name, + &endpoint.endpoint_key, + &endpoint.logical_peer_url, + ); + let active_requests = + self.active_requests(&upstream.name, &endpoint.logical_peer_url); + if self.endpoint_is_selectable(&upstream.name, &endpoint, active_requests) { + available.push(endpoint); + } else { + skipped_unhealthy = skipped_unhealthy.saturating_add(1); + } + } + if !available.is_empty() { + batches.push(available); + } + } + + let mut selected = Vec::new(); + let mut depth = 0usize; + while selected.len() < limit { + let mut advanced = false; + for batch in &batches { + if let Some(endpoint) = batch.get(depth) { + selected.push(endpoint.clone()); + advanced = true; + if selected.len() == limit { + break; + } + } + } + if !advanced { + break; + } + depth = depth.saturating_add(1); + } + + SelectedPeers { peers: selected, skipped_unhealthy } + } + pub(crate) async fn select_peers( &self, client: &ProxyClient, @@ -32,7 +120,7 @@ impl PeerHealthRegistry { return primary; } - let remaining = limit - primary.peers.len(); + let remaining = limit.saturating_sub(primary.peers.len()); merge_selected_peers( primary, self.select_peers_in_pool(client, upstream, client_ip, remaining, true).await, @@ -65,30 +153,13 @@ impl PeerHealthRegistry { return primary; } - let remaining = limit - primary.peers.len(); + let remaining = limit.saturating_sub(primary.peers.len()); merge_selected_peers( primary, self.select_peers_by_least_conn_in_pool(client, upstream, remaining, true).await, ) } - async fn select_peers_in_pool( - &self, - client: &ProxyClient, - upstream: &Upstream, - client_ip: std::net::IpAddr, - limit: usize, - backup: bool, - ) -> SelectedPeers { - let ordered = if backup { - upstream.backup_peers_for_client_ip(client_ip, upstream.peers.len()) - } else { - upstream.primary_peers_for_client_ip(client_ip, upstream.peers.len()) - }; - - self.select_available_peers(client, upstream, ordered, limit).await - } - async fn select_peers_by_least_conn_in_pool( &self, client: &ProxyClient, @@ -97,7 +168,7 @@ impl PeerHealthRegistry { backup: bool, ) -> SelectedPeers { let mut available = Vec::new(); - let mut skipped_unhealthy = 0; + let mut skipped_unhealthy = 0usize; for (order, peer) in upstream.peers.iter().cloned().enumerate() { if peer.backup != backup { @@ -117,7 +188,7 @@ impl PeerHealthRegistry { if self.endpoint_is_selectable(&upstream.name, &endpoint, active_requests) { available.push((active_requests, order, endpoint)); } else { - skipped_unhealthy += 1; + skipped_unhealthy = skipped_unhealthy.saturating_add(1); } } } @@ -135,94 +206,27 @@ impl PeerHealthRegistry { } } - async fn select_available_peers( + async fn select_peers_in_pool( &self, client: &ProxyClient, upstream: &Upstream, - ordered: Vec, + client_ip: std::net::IpAddr, limit: usize, + backup: bool, ) -> SelectedPeers { - if limit == 0 { - return SelectedPeers { peers: Vec::new(), skipped_unhealthy: 0 }; - } - - let mut batches = Vec::new(); - let mut skipped_unhealthy = 0; - - for peer in ordered { - let mut endpoints: Vec = - client.resolve_peer(&peer).await.unwrap_or_default(); - for endpoint in &endpoints { - self.ensure_endpoint( - &upstream.name, - &endpoint.endpoint_key, - &endpoint.logical_peer_url, - ); - } - endpoints.sort_by(|left, right| { - self.active_requests(&upstream.name, &left.logical_peer_url) - .cmp(&self.active_requests(&upstream.name, &right.logical_peer_url)) - .then(left.dial_authority.cmp(&right.dial_authority)) - }); - - let mut available = Vec::new(); - for endpoint in endpoints { - self.ensure_endpoint( - &upstream.name, - &endpoint.endpoint_key, - &endpoint.logical_peer_url, - ); - let active_requests = - self.active_requests(&upstream.name, &endpoint.logical_peer_url); - if self.endpoint_is_selectable(&upstream.name, &endpoint, active_requests) { - available.push(endpoint); - } else { - skipped_unhealthy += 1; - } - } - if !available.is_empty() { - batches.push(available); - } - } - - let mut selected = Vec::new(); - let mut depth = 0usize; - while selected.len() < limit { - let mut advanced = false; - for batch in &batches { - if let Some(endpoint) = batch.get(depth) { - selected.push(endpoint.clone()); - advanced = true; - if selected.len() == limit { - break; - } - } - } - if !advanced { - break; - } - depth += 1; - } - - SelectedPeers { peers: selected, skipped_unhealthy } - } - - fn endpoint_is_selectable( - &self, - upstream_name: &str, - endpoint: &ResolvedUpstreamPeer, - active_requests: u64, - ) -> bool { - if !self.is_available(upstream_name, &endpoint.endpoint_key) { - return false; - } + let ordered = if backup { + upstream.backup_peers_for_client_ip(client_ip, upstream.peers.len()) + } else { + upstream.primary_peers_for_client_ip(client_ip, upstream.peers.len()) + }; - endpoint.max_conns.is_none_or(|max_conns| active_requests < max_conns as u64) + self.select_available_peers(client, upstream, ordered, limit).await } } fn merge_selected_peers(mut primary: SelectedPeers, secondary: SelectedPeers) -> SelectedPeers { - primary.skipped_unhealthy += secondary.skipped_unhealthy; + primary.skipped_unhealthy = + primary.skipped_unhealthy.saturating_add(secondary.skipped_unhealthy); primary.peers.extend(secondary.peers); primary } @@ -233,8 +237,9 @@ fn projected_least_conn_load( right_active_requests: u64, right_weight: u32, ) -> std::cmp::Ordering { - let left = u128::from(left_active_requests.saturating_add(1)) * u128::from(right_weight.max(1)); - let right = - u128::from(right_active_requests.saturating_add(1)) * u128::from(left_weight.max(1)); + let left = u128::from(left_active_requests.saturating_add(1)) + .saturating_mul(u128::from(right_weight.max(1))); + let right = u128::from(right_active_requests.saturating_add(1)) + .saturating_mul(u128::from(left_weight.max(1))); left.cmp(&right) } diff --git a/crates/rginx-http/src/proxy/health/registry/snapshot.rs b/crates/rginx-http/src/proxy/health/registry/snapshot.rs index dba7a64d..6d4ee64d 100644 --- a/crates/rginx-http/src/proxy/health/registry/snapshot.rs +++ b/crates/rginx-http/src/proxy/health/registry/snapshot.rs @@ -1,4 +1,8 @@ -use super::*; +use super::{ + Instant, PeerHealth, PeerHealthPolicy, PeerHealthRegistry, PeerHealthSnapshot, + ResolvedEndpointHealthSnapshot, ResolvedUpstreamPeer, Upstream, UpstreamHealthSnapshot, + UpstreamPeer, UpstreamResolverRuntimeSnapshot, +}; use super::state::lock_peer_health; @@ -18,9 +22,10 @@ impl PeerHealthRegistry { let endpoint_snapshots = endpoints .iter() .map(|endpoint| { - self.get_health(&upstream.name, &endpoint.endpoint_key) - .map(|health| health.snapshot_endpoint(endpoint)) - .unwrap_or_else(|| default_endpoint_snapshot(endpoint)) + self.get_health(&upstream.name, &endpoint.endpoint_key).map_or_else( + || default_endpoint_snapshot(endpoint), + |health| health.snapshot_endpoint(endpoint), + ) }) .collect::>(); @@ -33,14 +38,16 @@ impl PeerHealthRegistry { .filter(|endpoint| endpoint.logical_peer_url == peer.url) .cloned() .collect::>(); - if !peer_endpoints.is_empty() { - aggregate_peer_snapshot(peer, &peer_endpoints) - } else { + if peer_endpoints.is_empty() { self.peers .get(&upstream.name) .and_then(|upstream_peers| upstream_peers.get(&peer.url)) - .map(|health| health.snapshot(peer)) - .unwrap_or_else(|| default_peer_snapshot(peer, !peer_is_hostname(peer))) + .map_or_else( + || default_peer_snapshot(peer, !peer_is_hostname(peer)), + |health| health.snapshot(peer), + ) + } else { + aggregate_peer_snapshot(peer, &peer_endpoints) } }) .collect::>(); diff --git a/crates/rginx-http/src/proxy/health/registry/state.rs b/crates/rginx-http/src/proxy/health/registry/state.rs index 7f024d0f..43a56f59 100644 --- a/crates/rginx-http/src/proxy/health/registry/state.rs +++ b/crates/rginx-http/src/proxy/health/registry/state.rs @@ -1,8 +1,35 @@ -use super::*; +use super::{ + ActiveHealthState, ActiveProbeStatus, Arc, ConfigSnapshot, HashMap, HealthChangeNotifier, + Instant, Mutex, PassiveHealthState, PeerFailureStatus, PeerHealth, PeerHealthPolicy, + PeerHealthRegistry, PeerHealthState, +}; use super::guards::ActivePeerGuard; impl PeerHealthRegistry { + pub(crate) fn active_requests(&self, upstream_name: &str, peer_url: &str) -> u64 { + self.get_health(upstream_name, peer_url).map_or(0, |health| health.active_requests()) + } + + pub(super) fn ensure_endpoint( + &self, + upstream_name: &str, + endpoint_key: &str, + logical_peer_url: &str, + ) -> Arc { + let mut endpoints = + self.endpoint_peers.lock().unwrap_or_else(std::sync::PoisonError::into_inner); + endpoints + .entry(upstream_name.to_string()) + .or_default() + .entry(endpoint_key.to_string()) + .or_insert_with(|| { + self.get_logical_peer(upstream_name, logical_peer_url) + .unwrap_or_else(|| Arc::new(PeerHealth::default())) + }) + .clone() + } + pub(crate) fn from_config(config: &ConfigSnapshot) -> Self { Self::from_config_with_notifier(config, None) } @@ -44,59 +71,96 @@ impl PeerHealthRegistry { } } - pub(crate) fn record_success(&self, upstream_name: &str, peer_url: &str) -> bool { - if let Some(health) = self.get_health(upstream_name, peer_url) { - let recovered = health.record_success(); - self.notify_change(upstream_name); - return recovered; - } - - false + pub(super) fn get_endpoint( + &self, + upstream_name: &str, + peer_url: &str, + ) -> Option> { + self.endpoint_peers + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .get(upstream_name) + .and_then(|upstream_peers| upstream_peers.get(peer_url)) + .cloned() } - pub(crate) fn record_failure(&self, upstream_name: &str, peer_url: &str) -> PeerFailureStatus { - let Some(policy) = self.policies.get(upstream_name).copied() else { - return PeerFailureStatus { consecutive_failures: 0, entered_cooldown: false }; - }; + pub(super) fn get_health( + &self, + upstream_name: &str, + peer_url: &str, + ) -> Option> { + self.get_endpoint(upstream_name, peer_url) + .or_else(|| self.get_logical_peer(upstream_name, peer_url)) + } - self.get_health(upstream_name, peer_url) - .map(|health| { - let status = health.record_failure(policy); - self.notify_change(upstream_name); - status - }) - .unwrap_or(PeerFailureStatus { consecutive_failures: 0, entered_cooldown: false }) + pub(super) fn get_logical_peer( + &self, + upstream_name: &str, + peer_url: &str, + ) -> Option> { + self.peers + .get(upstream_name) + .and_then(|upstream_peers| upstream_peers.get(peer_url)) + .cloned() } pub(super) fn is_available(&self, upstream_name: &str, peer_url: &str) -> bool { self.get_health(upstream_name, peer_url).is_none_or(|health| health.is_available()) } + pub(super) fn notify_change(&self, upstream_name: &str) { + if let Some(notifier) = &self.notifier { + notifier(upstream_name); + } + } + + pub(crate) fn record_active_failure(&self, upstream_name: &str, peer_url: &str) -> bool { + self.get_health(upstream_name, peer_url).is_some_and(|health| { + let changed = health.record_active_failure(); + self.notify_change(upstream_name); + changed + }) + } + pub(crate) fn record_active_success( &self, upstream_name: &str, peer_url: &str, healthy_successes_required: u32, ) -> ActiveProbeStatus { - self.get_health(upstream_name, peer_url) - .map(|health| { + self.get_health(upstream_name, peer_url).map_or( + ActiveProbeStatus { healthy: true, recovered: false, consecutive_successes: 0 }, + |health| { let status = health.record_active_success(healthy_successes_required); self.notify_change(upstream_name); status - }) - .unwrap_or(ActiveProbeStatus { - healthy: true, - recovered: false, - consecutive_successes: 0, - }) + }, + ) } - pub(crate) fn record_active_failure(&self, upstream_name: &str, peer_url: &str) -> bool { - self.get_health(upstream_name, peer_url).is_some_and(|health| { - let changed = health.record_active_failure(); + pub(crate) fn record_failure(&self, upstream_name: &str, peer_url: &str) -> PeerFailureStatus { + let Some(policy) = self.policies.get(upstream_name).copied() else { + return PeerFailureStatus { consecutive_failures: 0, entered_cooldown: false }; + }; + + self.get_health(upstream_name, peer_url).map_or( + PeerFailureStatus { consecutive_failures: 0, entered_cooldown: false }, + |health| { + let status = health.record_failure(policy); + self.notify_change(upstream_name); + status + }, + ) + } + + pub(crate) fn record_success(&self, upstream_name: &str, peer_url: &str) -> bool { + if let Some(health) = self.get_health(upstream_name, peer_url) { + let recovered = health.record_success(); self.notify_change(upstream_name); - changed - }) + return recovered; + } + + false } pub(crate) fn track_active_request( @@ -118,71 +182,27 @@ impl PeerHealthRegistry { upstream_name: upstream_name.to_string(), } } +} - pub(crate) fn active_requests(&self, upstream_name: &str, peer_url: &str) -> u64 { - self.get_health(upstream_name, peer_url).map(|health| health.active_requests()).unwrap_or(0) - } - - pub(super) fn get_health( - &self, - upstream_name: &str, - peer_url: &str, - ) -> Option> { - self.get_endpoint(upstream_name, peer_url) - .or_else(|| self.get_logical_peer(upstream_name, peer_url)) - } - - pub(super) fn get_logical_peer( - &self, - upstream_name: &str, - peer_url: &str, - ) -> Option> { - self.peers - .get(upstream_name) - .and_then(|upstream_peers| upstream_peers.get(peer_url)) - .cloned() - } - - pub(super) fn get_endpoint( - &self, - upstream_name: &str, - peer_url: &str, - ) -> Option> { - self.endpoint_peers - .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .get(upstream_name) - .and_then(|upstream_peers| upstream_peers.get(peer_url)) - .cloned() +impl PeerHealth { + pub(super) fn active_requests(&self) -> u64 { + lock_peer_health(&self.state).active_requests } - pub(super) fn ensure_endpoint( - &self, - upstream_name: &str, - endpoint_key: &str, - logical_peer_url: &str, - ) -> Arc { - let mut endpoints = - self.endpoint_peers.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - endpoints - .entry(upstream_name.to_string()) - .or_default() - .entry(endpoint_key.to_string()) - .or_insert_with(|| { - self.get_logical_peer(upstream_name, logical_peer_url) - .unwrap_or_else(|| Arc::new(PeerHealth::default())) - }) - .clone() + pub(super) fn decrement_active_requests(&self) -> bool { + let mut state = lock_peer_health(&self.state); + let was_active = state.active_requests > 0; + state.active_requests = state.active_requests.saturating_sub(1); + was_active && state.active_requests == 0 } - pub(super) fn notify_change(&self, upstream_name: &str) { - if let Some(notifier) = &self.notifier { - notifier(upstream_name); - } + fn increment_active_requests(&self) -> bool { + let mut state = lock_peer_health(&self.state); + let transitioned_from_idle = state.active_requests == 0; + state.active_requests = state.active_requests.saturating_add(1); + transitioned_from_idle } -} -impl PeerHealth { pub(super) fn is_available(&self) -> bool { let state = lock_peer_health(&self.state); let passive_available = @@ -190,11 +210,28 @@ impl PeerHealth { passive_available && !state.active.unhealthy } - fn record_success(&self) -> bool { + fn record_active_failure(&self) -> bool { let mut state = lock_peer_health(&self.state); - let recovered = state.passive.pending_recovery; - state.passive = PassiveHealthState::default(); - recovered + let was_healthy = !state.active.unhealthy; + state.active.unhealthy = true; + state.active.consecutive_successes = 0; + was_healthy + } + + fn record_active_success(&self, healthy_successes_required: u32) -> ActiveProbeStatus { + let mut state = lock_peer_health(&self.state); + if !state.active.unhealthy { + return ActiveProbeStatus { healthy: true, recovered: false, consecutive_successes: 0 }; + } + + state.active.consecutive_successes = state.active.consecutive_successes.saturating_add(1); + let consecutive_successes = state.active.consecutive_successes; + let recovered = consecutive_successes >= healthy_successes_required; + if recovered { + state.active = ActiveHealthState::default(); + } + + ActiveProbeStatus { healthy: recovered, recovered, consecutive_successes } } fn record_failure(&self, policy: PeerHealthPolicy) -> PeerFailureStatus { @@ -207,14 +244,20 @@ impl PeerHealth { } let already_in_cooldown = state.passive.unhealthy_until.is_some_and(|until| until > now); - state.passive.consecutive_failures += 1; + state.passive.consecutive_failures = state.passive.consecutive_failures.saturating_add(1); let entered_cooldown = !already_in_cooldown && state.passive.consecutive_failures >= policy.unhealthy_after_failures; if entered_cooldown { - state.passive.unhealthy_until = Some(now + policy.cooldown); + state.passive.unhealthy_until = Some( + now.checked_add(policy.cooldown) + .expect("peer cooldown deadline remains representable"), + ); state.passive.pending_recovery = true; } else if already_in_cooldown { - state.passive.unhealthy_until = Some(now + policy.cooldown); + state.passive.unhealthy_until = Some( + now.checked_add(policy.cooldown) + .expect("peer cooldown deadline remains representable"), + ); } PeerFailureStatus { @@ -223,51 +266,16 @@ impl PeerHealth { } } - fn record_active_success(&self, healthy_successes_required: u32) -> ActiveProbeStatus { - let mut state = lock_peer_health(&self.state); - if !state.active.unhealthy { - return ActiveProbeStatus { healthy: true, recovered: false, consecutive_successes: 0 }; - } - - state.active.consecutive_successes += 1; - let consecutive_successes = state.active.consecutive_successes; - let recovered = consecutive_successes >= healthy_successes_required; - if recovered { - state.active = ActiveHealthState::default(); - } - - ActiveProbeStatus { healthy: recovered, recovered, consecutive_successes } - } - - fn record_active_failure(&self) -> bool { - let mut state = lock_peer_health(&self.state); - let was_healthy = !state.active.unhealthy; - state.active.unhealthy = true; - state.active.consecutive_successes = 0; - was_healthy - } - - fn increment_active_requests(&self) -> bool { - let mut state = lock_peer_health(&self.state); - let transitioned_from_idle = state.active_requests == 0; - state.active_requests += 1; - transitioned_from_idle - } - - pub(super) fn decrement_active_requests(&self) -> bool { + fn record_success(&self) -> bool { let mut state = lock_peer_health(&self.state); - let was_active = state.active_requests > 0; - state.active_requests = state.active_requests.saturating_sub(1); - was_active && state.active_requests == 0 - } - - pub(super) fn active_requests(&self) -> u64 { - lock_peer_health(&self.state).active_requests + let recovered = state.passive.pending_recovery; + state.passive = PassiveHealthState::default(); + recovered } } pub(super) fn lock_peer_health( state: &Mutex, ) -> std::sync::MutexGuard<'_, PeerHealthState> { - state.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) + state.lock().unwrap_or_else(std::sync::PoisonError::into_inner) } diff --git a/crates/rginx-http/src/proxy/health/request.rs b/crates/rginx-http/src/proxy/health/request.rs index 2f1af14a..e4b8aa95 100644 --- a/crates/rginx-http/src/proxy/health/request.rs +++ b/crates/rginx-http/src/proxy/health/request.rs @@ -1,4 +1,9 @@ -use super::*; +use super::{ + ActiveHealthCheck, Bytes, CONTENT_LENGTH, CONTENT_TYPE, Error, HOST, HeaderValue, HttpBody, + Method, Request, ResolvedUpstreamPeer, TE, Upstream, UpstreamProtocol, Uri, build_proxy_uri, + encode_grpc_health_check_request, full_body, + remove_redundant_host_header_for_authority_pseudo_header, upstream_request_version, +}; use rginx_core::ProxyUriMode; pub(in super::super) fn build_active_health_request( diff --git a/crates/rginx-http/src/proxy/mod.rs b/crates/rginx-http/src/proxy/mod.rs index a1ed60ab..91749951 100644 --- a/crates/rginx-http/src/proxy/mod.rs +++ b/crates/rginx-http/src/proxy/mod.rs @@ -1,3 +1,15 @@ +mod clients; +mod common; +mod error_mapping; +mod forward; +mod grpc_web; +mod health; +mod request_body; +mod resolver; +#[cfg(test)] +mod tests; +mod upgrade; + pub(super) use std::collections::{HashMap, HashSet}; pub(super) use std::sync::Arc; pub(super) use std::time::Duration; @@ -33,34 +45,27 @@ use crate::handler::{ use crate::state::SharedState; use crate::timeout::{GrpcDeadlineBody, IdleTimeoutBody, MaxBytesBody}; -mod clients; -mod common; -mod error_mapping; -mod forward; -mod grpc_web; -mod health; -mod request_body; -mod resolver; -#[cfg(test)] -mod tests; -mod upgrade; - -const MAX_FAILOVER_ATTEMPTS: usize = 2; -const GRPC_CONTENT_TYPE_PREFIX: &str = "application/grpc"; -const GRPC_WEB_CONTENT_TYPE_PREFIX: &str = "application/grpc-web"; -const GRPC_WEB_TEXT_CONTENT_TYPE_PREFIX: &str = "application/grpc-web-text"; -const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; -const MAX_GRPC_TIMEOUT_DIGITS: usize = 8; - pub(crate) use clients::HealthChangeNotifier; pub use clients::ProxyClients; pub use forward::{DownstreamRequestContext, DownstreamRequestOptions, forward_request}; pub use health::probe_upstream_peer; pub use health::{PeerHealthSnapshot, UpstreamHealthSnapshot}; -use self::common::*; +use self::common::{ + append_header_map, build_proxy_uri, is_upgrade_request, is_upgrade_response, + preserved_te_trailers_value, remove_redundant_host_header_for_authority_pseudo_header, + sanitize_request_headers, sanitize_response_headers, split_content_type, + upstream_request_version, +}; pub(crate) use error_mapping::{classify_upstream_tls_failure, upstream_tls_verify_label}; pub(super) use grpc_web::GrpcWebMode; pub(crate) use resolver::{ ResolvedUpstreamPeer, UpstreamResolver, UpstreamResolverRuntimeSnapshot, }; + +const MAX_FAILOVER_ATTEMPTS: usize = 2; +const GRPC_CONTENT_TYPE_PREFIX: &str = "application/grpc"; +const GRPC_WEB_CONTENT_TYPE_PREFIX: &str = "application/grpc-web"; +const GRPC_WEB_TEXT_CONTENT_TYPE_PREFIX: &str = "application/grpc-web-text"; +const GRPC_TIMEOUT_HEADER: &str = "grpc-timeout"; +const MAX_GRPC_TIMEOUT_DIGITS: usize = 8; diff --git a/crates/rginx-http/src/proxy/request_body/mod.rs b/crates/rginx-http/src/proxy/request_body/mod.rs index 303b071f..46ab4406 100644 --- a/crates/rginx-http/src/proxy/request_body/mod.rs +++ b/crates/rginx-http/src/proxy/request_body/mod.rs @@ -1,9 +1,3 @@ -use std::error::Error as StdError; - -use super::grpc_web::{GrpcWebMode, GrpcWebRequestBody, GrpcWebTextDecodeBody}; -use super::*; -use crate::handler::boxed_body; - mod limits; mod model; mod prepare; @@ -12,6 +6,12 @@ mod streaming; #[cfg(test)] mod tests; +use std::error::Error as StdError; + +use super::grpc_web::{GrpcWebMode, GrpcWebRequestBody, GrpcWebTextDecodeBody}; +use super::*; +use crate::handler::boxed_body; + pub(super) use limits::request_body_limit_error; pub(super) use model::{ BuiltUpstreamRequest, PrepareProxyRequestOptions, PrepareRequestError, PreparedProxyRequest, diff --git a/crates/rginx-http/src/proxy/request_body/model.rs b/crates/rginx-http/src/proxy/request_body/model.rs index 6643c50e..106f94e7 100644 --- a/crates/rginx-http/src/proxy/request_body/model.rs +++ b/crates/rginx-http/src/proxy/request_body/model.rs @@ -1,12 +1,12 @@ use super::*; pub(in crate::proxy) struct PreparedProxyRequest { - pub method: Method, - pub uri: Uri, - pub headers: HeaderMap, pub body: PreparedRequestBody, + pub headers: HeaderMap, + pub method: Method, pub observed_request_body_bytes: u64, pub(in crate::proxy) peer_failover_enabled: bool, + pub uri: Uri, pub(in crate::proxy) wait_for_streaming_body: bool, } @@ -16,34 +16,30 @@ pub(in crate::proxy) enum PreparedRequestBody { } pub(in crate::proxy) struct BuiltUpstreamRequest { - pub request: Request, pub body_completion: Option, + pub request: Request, } pub(in crate::proxy) type StreamingBodyCompletion = tokio::sync::oneshot::Receiver>; pub(in crate::proxy) struct PrepareProxyRequestOptions<'a> { - pub upstream_name: &'a str, - pub request_body_read_timeout: Option, - pub write_timeout: Duration, + pub grpc_web_mode: Option<&'a GrpcWebMode>, pub max_replayable_request_body_bytes: usize, pub max_request_body_bytes: Option, + pub request_body_read_timeout: Option, pub request_buffering: RouteBufferingPolicy, - pub grpc_web_mode: Option<&'a GrpcWebMode>, + pub upstream_name: &'a str, + pub write_timeout: Duration, } #[derive(Debug)] pub(in crate::proxy) enum PrepareRequestError { - PayloadTooLarge { max_request_body_bytes: usize }, Other(Box), + PayloadTooLarge { max_request_body_bytes: usize }, } impl PrepareRequestError { - pub(in crate::proxy) fn payload_too_large(max_request_body_bytes: usize) -> Self { - Self::PayloadTooLarge { max_request_body_bytes } - } - pub(in crate::proxy) fn boxed(error: BoxError) -> Self { if let Some(max_request_body_bytes) = request_body_limit_error(error.as_ref()) { Self::payload_too_large(max_request_body_bytes) @@ -51,6 +47,9 @@ impl PrepareRequestError { Self::Other(error) } } + pub(in crate::proxy) fn payload_too_large(max_request_body_bytes: usize) -> Self { + Self::PayloadTooLarge { max_request_body_bytes } + } } impl std::fmt::Display for PrepareRequestError { diff --git a/crates/rginx-http/src/proxy/request_body/prepare.rs b/crates/rginx-http/src/proxy/request_body/prepare.rs index d020486a..935514a3 100644 --- a/crates/rginx-http/src/proxy/request_body/prepare.rs +++ b/crates/rginx-http/src/proxy/request_body/prepare.rs @@ -9,55 +9,17 @@ struct CollectedRequestBody { } struct PrepareRequestBodyConfig<'a> { - upstream_name: &'a str, - method: &'a Method, - headers: &'a HeaderMap, body_timeout: Duration, + grpc_web_mode: Option<&'a GrpcWebMode>, + headers: &'a HeaderMap, max_replayable_request_body_bytes: usize, max_request_body_bytes: Option, + method: &'a Method, request_buffering: RouteBufferingPolicy, - grpc_web_mode: Option<&'a GrpcWebMode>, + upstream_name: &'a str, } impl PreparedProxyRequest { - pub(in crate::proxy) async fn from_request( - request: Request, - options: PrepareProxyRequestOptions<'_>, - ) -> Result { - let (parts, body) = request.into_parts(); - let body_timeout = options.request_body_read_timeout.unwrap_or(options.write_timeout); - let prepared_body = prepare_request_body( - body, - PrepareRequestBodyConfig { - upstream_name: options.upstream_name, - method: &parts.method, - headers: &parts.headers, - body_timeout, - max_replayable_request_body_bytes: options.max_replayable_request_body_bytes, - max_request_body_bytes: options.max_request_body_bytes, - request_buffering: options.request_buffering, - grpc_web_mode: options.grpc_web_mode, - }, - ) - .await?; - - Ok(Self { - method: parts.method, - uri: parts.uri, - headers: parts.headers, - observed_request_body_bytes: prepared_body.observed_bytes(), - body: prepared_body, - peer_failover_enabled: options.request_buffering != RouteBufferingPolicy::Off, - wait_for_streaming_body: options.max_request_body_bytes.is_some(), - }) - } - - pub(in crate::proxy) fn can_failover(&self) -> bool { - self.peer_failover_enabled - && is_idempotent_method(&self.method) - && matches!(self.body, PreparedRequestBody::Replayable { .. }) - } - pub(in crate::proxy) fn build_for_peer( &mut self, peer: &ResolvedUpstreamPeer, @@ -116,7 +78,45 @@ impl PreparedProxyRequest { *request.version_mut() = target.request_version; *request.uri_mut() = uri; *request.headers_mut() = headers; - Ok(BuiltUpstreamRequest { request, body_completion }) + Ok(BuiltUpstreamRequest { body_completion, request }) + } + + pub(in crate::proxy) fn can_failover(&self) -> bool { + self.peer_failover_enabled + && is_idempotent_method(&self.method) + && matches!(self.body, PreparedRequestBody::Replayable { .. }) + } + + pub(in crate::proxy) async fn from_request( + request: Request, + options: PrepareProxyRequestOptions<'_>, + ) -> Result { + let (parts, body) = request.into_parts(); + let body_timeout = options.request_body_read_timeout.unwrap_or(options.write_timeout); + let prepared_body = prepare_request_body( + body, + PrepareRequestBodyConfig { + upstream_name: options.upstream_name, + method: &parts.method, + headers: &parts.headers, + body_timeout, + max_replayable_request_body_bytes: options.max_replayable_request_body_bytes, + max_request_body_bytes: options.max_request_body_bytes, + request_buffering: options.request_buffering, + grpc_web_mode: options.grpc_web_mode, + }, + ) + .await?; + + Ok(Self { + method: parts.method, + uri: parts.uri, + headers: parts.headers, + observed_request_body_bytes: prepared_body.observed_bytes(), + body: prepared_body, + peer_failover_enabled: options.request_buffering != RouteBufferingPolicy::Off, + wait_for_streaming_body: options.max_request_body_bytes.is_some(), + }) } } diff --git a/crates/rginx-http/src/proxy/request_body/replay.rs b/crates/rginx-http/src/proxy/request_body/replay.rs index 8844d83f..d255cd78 100644 --- a/crates/rginx-http/src/proxy/request_body/replay.rs +++ b/crates/rginx-http/src/proxy/request_body/replay.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{BoxError, Bytes, Frame, HeaderMap, Method, PreparedProxyRequest, SizeHint}; pub(super) struct ReplayableRequestBody { body: Option, @@ -15,6 +15,10 @@ impl hyper::body::Body for ReplayableRequestBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.body.as_ref().is_none_or(Bytes::is_empty) && self.trailers.is_none() + } + fn poll_frame( self: std::pin::Pin<&mut Self>, _cx: &mut std::task::Context<'_>, @@ -34,10 +38,6 @@ impl hyper::body::Body for ReplayableRequestBody { std::task::Poll::Ready(None) } - fn is_end_stream(&self) -> bool { - self.body.as_ref().is_none_or(Bytes::is_empty) && self.trailers.is_none() - } - fn size_hint(&self) -> SizeHint { let mut hint = SizeHint::new(); hint.set_exact(self.body.as_ref().map_or(0, |body| body.len() as u64)); @@ -57,5 +57,5 @@ pub(in crate::proxy) fn can_retry_peer_request( peer_count: usize, attempt_index: usize, ) -> bool { - prepared_request.can_failover() && attempt_index + 1 < peer_count + prepared_request.can_failover() && attempt_index.saturating_add(1) < peer_count } diff --git a/crates/rginx-http/src/proxy/request_body/streaming.rs b/crates/rginx-http/src/proxy/request_body/streaming.rs index 6e2aae56..6e5b1465 100644 --- a/crates/rginx-http/src/proxy/request_body/streaming.rs +++ b/crates/rginx-http/src/proxy/request_body/streaming.rs @@ -1,9 +1,12 @@ -use super::*; +use super::{ + BodyExt, BoxError, Bytes, Frame, HttpBody, SizeHint, StdError, StreamingBodyCompletion, + boxed_body, request_body_limit_error, +}; #[derive(Debug)] struct RelayedRequestBody { - receiver: tokio::sync::mpsc::Receiver, BoxError>>, done: bool, + receiver: tokio::sync::mpsc::Receiver, BoxError>>, } impl RelayedRequestBody { @@ -16,6 +19,10 @@ impl hyper::body::Body for RelayedRequestBody { type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -43,10 +50,6 @@ impl hyper::body::Body for RelayedRequestBody { } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } diff --git a/crates/rginx-http/src/proxy/resolver.rs b/crates/rginx-http/src/proxy/resolver.rs index 67370007..bbfb215e 100644 --- a/crates/rginx-http/src/proxy/resolver.rs +++ b/crates/rginx-http/src/proxy/resolver.rs @@ -1,3 +1,6 @@ +mod endpoint; +mod runtime; + use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; @@ -9,77 +12,77 @@ use rginx_core::UpstreamDnsPolicy; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex; -mod endpoint; -mod runtime; - #[derive(Debug, Clone)] pub(crate) struct ResolvedUpstreamPeer { - #[cfg_attr(not(test), allow(dead_code))] - pub url: String, - pub logical_peer_url: String, - pub endpoint_key: String, + pub backup: bool, + pub dial_authority: String, pub display_url: String, + pub endpoint_key: String, + pub logical_peer_url: String, + pub max_conns: Option, pub scheme: String, - pub upstream_authority: String, - pub dial_authority: String, - pub socket_addr: SocketAddr, pub server_name: String, + pub socket_addr: SocketAddr, + pub upstream_authority: String, + #[cfg_attr( + not(test), + expect(dead_code, reason = "canonical peer URL is asserted by resolver tests") + )] + pub url: String, pub weight: u32, - pub backup: bool, - pub max_conns: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct UpstreamResolverCacheEntrySnapshot { - pub hostname: String, pub addresses: Vec, + pub hostname: String, + pub last_error: Option, pub negative: bool, - pub valid_for_ms: Option, pub stale_for_ms: Option, - pub last_error: Option, + pub valid_for_ms: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] pub struct UpstreamResolverRuntimeSnapshot { - pub resolve_requests_total: u64, + pub cache_entries: Vec, pub cache_hits_total: u64, pub cache_misses_total: u64, pub refreshes_total: u64, pub resolve_errors_total: u64, + pub resolve_requests_total: u64, pub stale_answers_total: u64, - pub cache_entries: Vec, } #[derive(Debug, Clone)] pub(super) struct CacheEntry { pub(super) addresses: Vec, - pub(super) valid_until: Instant, - pub(super) stale_until: Instant, - pub(super) negative: bool, pub(super) last_error: Option, + pub(super) negative: bool, + pub(super) stale_until: Instant, + pub(super) valid_until: Instant, } #[derive(Debug, Clone)] pub(super) struct PeerAddressing { + pub(super) authority: String, + pub(super) backup: bool, pub(super) host: String, + pub(super) logical_peer_url: String, + pub(super) max_conns: Option, pub(super) port: u16, pub(super) scheme: String, - pub(super) authority: String, - pub(super) logical_peer_url: String, pub(super) weight: u32, - pub(super) backup: bool, - pub(super) max_conns: Option, } #[derive(Debug, Clone)] pub(crate) struct UpstreamResolver { - pub(super) policy: UpstreamDnsPolicy, - pub(super) resolver: TokioResolver, pub(super) cache: Arc>>, - pub(super) resolve_requests_total: Arc, pub(super) cache_hits_total: Arc, pub(super) cache_misses_total: Arc, + pub(super) policy: UpstreamDnsPolicy, pub(super) refreshes_total: Arc, pub(super) resolve_errors_total: Arc, + pub(super) resolve_requests_total: Arc, + pub(super) resolver: TokioResolver, pub(super) stale_answers_total: Arc, } diff --git a/crates/rginx-http/src/proxy/resolver/endpoint.rs b/crates/rginx-http/src/proxy/resolver/endpoint.rs index 4c068b86..ca1433cf 100644 --- a/crates/rginx-http/src/proxy/resolver/endpoint.rs +++ b/crates/rginx-http/src/proxy/resolver/endpoint.rs @@ -7,7 +7,7 @@ use hickory_resolver::config::{ use hickory_resolver::net::runtime::TokioRuntimeProvider; use rginx_core::{Error, UpstreamDnsPolicy, UpstreamPeer}; -use super::*; +use super::{PeerAddressing, ResolvedUpstreamPeer}; pub(super) fn build_resolver(policy: &UpstreamDnsPolicy) -> Result { let mut options = ResolverOpts::default(); @@ -112,7 +112,8 @@ pub(super) fn refresh_due( valid_until: std::time::Instant, refresh_before_expiry: std::time::Duration, ) -> bool { - now >= valid_until || now + refresh_before_expiry >= valid_until + now >= valid_until + || now.checked_add(refresh_before_expiry).is_none_or(|refresh_at| refresh_at >= valid_until) } pub(super) fn order_addresses( diff --git a/crates/rginx-http/src/proxy/resolver/runtime.rs b/crates/rginx-http/src/proxy/resolver/runtime.rs index 6249a69f..37989726 100644 --- a/crates/rginx-http/src/proxy/resolver/runtime.rs +++ b/crates/rginx-http/src/proxy/resolver/runtime.rs @@ -9,37 +9,12 @@ use super::endpoint::{ build_endpoint, build_resolver, clamp_ttl, duration_to_ms, order_addresses, parse_peer_addressing, refresh_due, }; -use super::*; +use super::{ + Arc, AtomicU64, CacheEntry, HashMap, Mutex, ResolvedUpstreamPeer, UpstreamDnsPolicy, + UpstreamResolver, UpstreamResolverCacheEntrySnapshot, UpstreamResolverRuntimeSnapshot, +}; impl UpstreamResolver { - pub(crate) fn new(policy: UpstreamDnsPolicy) -> Result { - let resolver = build_resolver(&policy)?; - Ok(Self { - policy, - resolver, - cache: Arc::new(Mutex::new(HashMap::new())), - resolve_requests_total: Arc::new(AtomicU64::new(0)), - cache_hits_total: Arc::new(AtomicU64::new(0)), - cache_misses_total: Arc::new(AtomicU64::new(0)), - refreshes_total: Arc::new(AtomicU64::new(0)), - resolve_errors_total: Arc::new(AtomicU64::new(0)), - stale_answers_total: Arc::new(AtomicU64::new(0)), - }) - } - - pub(crate) async fn resolve_peer( - &self, - peer: &UpstreamPeer, - ) -> Result, Error> { - let addressing = parse_peer_addressing(peer)?; - if let Ok(ip) = IpAddr::from_str(&addressing.host) { - return Ok(vec![build_endpoint(&addressing, ip)]); - } - - let addresses = self.resolve_host(&addressing.host).await?; - Ok(addresses.into_iter().map(|ip| build_endpoint(&addressing, ip)).collect()) - } - pub(crate) async fn cached_peer_endpoints( &self, peer: &UpstreamPeer, @@ -61,31 +36,31 @@ impl UpstreamResolver { Ok(entry.addresses.iter().copied().map(|ip| build_endpoint(&addressing, ip)).collect()) } - pub(crate) async fn snapshot(&self) -> UpstreamResolverRuntimeSnapshot { + async fn lookup_host(&self, host: &str) -> Result<(Vec, Duration), Error> { + let lookup = + self.resolver.lookup_ip(host).await.map_err(|error| { + Error::Server(format!("dns lookup failed for `{host}`: {error}")) + })?; let now = Instant::now(); - let cache = self.cache.lock().await; - let mut cache_entries = cache - .iter() - .map(|(hostname, entry)| UpstreamResolverCacheEntrySnapshot { - hostname: hostname.clone(), - addresses: entry.addresses.iter().map(ToString::to_string).collect(), - negative: entry.negative, - valid_for_ms: entry.valid_until.checked_duration_since(now).map(duration_to_ms), - stale_for_ms: entry.stale_until.checked_duration_since(now).map(duration_to_ms), - last_error: entry.last_error.clone(), - }) - .collect::>(); - cache_entries.sort_by(|left, right| left.hostname.cmp(&right.hostname)); - - UpstreamResolverRuntimeSnapshot { - resolve_requests_total: self.resolve_requests_total.load(Ordering::Relaxed), - cache_hits_total: self.cache_hits_total.load(Ordering::Relaxed), - cache_misses_total: self.cache_misses_total.load(Ordering::Relaxed), - refreshes_total: self.refreshes_total.load(Ordering::Relaxed), - resolve_errors_total: self.resolve_errors_total.load(Ordering::Relaxed), - stale_answers_total: self.stale_answers_total.load(Ordering::Relaxed), - cache_entries, - } + let ttl = lookup.valid_until().checked_duration_since(now).unwrap_or(self.policy.max_ttl); + let mut addresses = lookup.iter().collect::>(); + addresses.sort(); + addresses.dedup(); + Ok((addresses, ttl)) + } + pub(crate) fn new(policy: UpstreamDnsPolicy) -> Result { + let resolver = build_resolver(&policy)?; + Ok(Self { + policy, + resolver, + cache: Arc::new(Mutex::new(HashMap::new())), + resolve_requests_total: Arc::new(AtomicU64::new(0)), + cache_hits_total: Arc::new(AtomicU64::new(0)), + cache_misses_total: Arc::new(AtomicU64::new(0)), + refreshes_total: Arc::new(AtomicU64::new(0)), + resolve_errors_total: Arc::new(AtomicU64::new(0)), + stale_answers_total: Arc::new(AtomicU64::new(0)), + }) } async fn resolve_host(&self, host: &str) -> Result, Error> { @@ -108,11 +83,15 @@ impl UpstreamResolver { match self.lookup_host(host).await { Ok((addresses, ttl)) if !addresses.is_empty() => { let ttl = clamp_ttl(ttl, self.policy.min_ttl, self.policy.max_ttl); - let stale_until = now + ttl + self.policy.stale_if_error; + let valid_until = + now.checked_add(ttl).expect("DNS TTL deadline remains representable"); + let stale_until = valid_until + .checked_add(self.policy.stale_if_error) + .expect("DNS stale deadline remains representable"); let addresses = order_addresses(addresses, &self.policy); let entry = CacheEntry { addresses: addresses.clone(), - valid_until: now + ttl, + valid_until, stale_until, negative: false, last_error: None, @@ -141,6 +120,46 @@ impl UpstreamResolver { } } + pub(crate) async fn resolve_peer( + &self, + peer: &UpstreamPeer, + ) -> Result, Error> { + let addressing = parse_peer_addressing(peer)?; + if let Ok(ip) = IpAddr::from_str(&addressing.host) { + return Ok(vec![build_endpoint(&addressing, ip)]); + } + + let addresses = self.resolve_host(&addressing.host).await?; + Ok(addresses.into_iter().map(|ip| build_endpoint(&addressing, ip)).collect()) + } + + pub(crate) async fn snapshot(&self) -> UpstreamResolverRuntimeSnapshot { + let now = Instant::now(); + let cache = self.cache.lock().await; + let mut cache_entries = cache + .iter() + .map(|(hostname, entry)| UpstreamResolverCacheEntrySnapshot { + hostname: hostname.clone(), + addresses: entry.addresses.iter().map(ToString::to_string).collect(), + negative: entry.negative, + valid_for_ms: entry.valid_until.checked_duration_since(now).map(duration_to_ms), + stale_for_ms: entry.stale_until.checked_duration_since(now).map(duration_to_ms), + last_error: entry.last_error.clone(), + }) + .collect::>(); + cache_entries.sort_by(|left, right| left.hostname.cmp(&right.hostname)); + + UpstreamResolverRuntimeSnapshot { + resolve_requests_total: self.resolve_requests_total.load(Ordering::Relaxed), + cache_hits_total: self.cache_hits_total.load(Ordering::Relaxed), + cache_misses_total: self.cache_misses_total.load(Ordering::Relaxed), + refreshes_total: self.refreshes_total.load(Ordering::Relaxed), + resolve_errors_total: self.resolve_errors_total.load(Ordering::Relaxed), + stale_answers_total: self.stale_answers_total.load(Ordering::Relaxed), + cache_entries, + } + } + async fn store_negative_and_fail( &self, host: &str, @@ -152,25 +171,16 @@ impl UpstreamResolver { host.to_string(), CacheEntry { addresses: Vec::new(), - valid_until: now + self.policy.negative_ttl, - stale_until: now + self.policy.negative_ttl, + valid_until: now + .checked_add(self.policy.negative_ttl) + .expect("negative DNS TTL deadline remains representable"), + stale_until: now + .checked_add(self.policy.negative_ttl) + .expect("negative DNS stale deadline remains representable"), negative: true, last_error: Some(message.to_string()), }, ); Err(Error::Server(format!("failed to resolve upstream hostname `{host}`: {message}"))) } - - async fn lookup_host(&self, host: &str) -> Result<(Vec, Duration), Error> { - let lookup = - self.resolver.lookup_ip(host).await.map_err(|error| { - Error::Server(format!("dns lookup failed for `{host}`: {error}")) - })?; - let now = Instant::now(); - let ttl = lookup.valid_until().checked_duration_since(now).unwrap_or(self.policy.max_ttl); - let mut addresses = lookup.iter().collect::>(); - addresses.sort(); - addresses.dedup(); - Ok((addresses, ttl)) - } } diff --git a/crates/rginx-http/src/proxy/tests/mod.rs b/crates/rginx-http/src/proxy/tests/mod.rs index d497445c..26ba4739 100644 --- a/crates/rginx-http/src/proxy/tests/mod.rs +++ b/crates/rginx-http/src/proxy/tests/mod.rs @@ -1,4 +1,11 @@ -#![allow(unused_imports)] +mod cache; +mod cache_stale; +mod client_profiles; +mod grpc; +mod peer_recovery; +mod peer_selection; +mod request_headers; +mod support; use std::collections::{HashMap, VecDeque}; use std::convert::Infallible; @@ -52,17 +59,27 @@ use super::{ use crate::client_ip::{ClientAddress, ClientIpSource}; use tempfile::TempDir; -mod cache; -mod cache_stale; -mod client_profiles; -mod grpc; -mod peer_recovery; -mod peer_selection; -mod request_headers; -mod support; - use support::spawn_range_server; +const TEST_CA_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDXTCCAkWgAwIBAgIJAOIvDiVb18eVMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV\nBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX\naWRnaXRzIFB0eSBMdGQwHhcNMTYwODE0MTY1NjExWhcNMjYwODEyMTY1NjExWjBF\nMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50\nZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\nCgKCAQEArVHWFn52Lbl1l59exduZntVSZyDYpzDND+S2LUcO6fRBWhV/1Kzox+2G\nZptbuMGmfI3iAnb0CFT4uC3kBkQQlXonGATSVyaFTFR+jq/lc0SP+9Bd7SBXieIV\neIXlY1TvlwIvj3Ntw9zX+scTA4SXxH6M0rKv9gTOub2vCMSHeF16X8DQr4XsZuQr\n7Cp7j1I4aqOJyap5JTl5ijmG8cnu0n+8UcRlBzy99dLWJG0AfI3VRJdWpGTNVZ92\naFff3RpK3F/WI2gp3qV1ynRAKuvmncGC3LDvYfcc2dgsc1N6Ffq8GIrkgRob6eBc\nklDHp1d023Lwre+VaVDSo1//Y72UFwIDAQABo1AwTjAdBgNVHQ4EFgQUbNOlA6sN\nXyzJjYqciKeId7g3/ZowHwYDVR0jBBgwFoAUbNOlA6sNXyzJjYqciKeId7g3/Zow\nDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAVVaR5QWLZIRR4Dw6TSBn\nBQiLpBSXN6oAxdDw6n4PtwW6CzydaA+creiK6LfwEsiifUfQe9f+T+TBSpdIYtMv\nZ2H2tjlFX8VrjUFvPrvn5c28CuLI0foBgY8XGSkR2YMYzWw2jPEq3Th/KM5Catn3\nAFm3bGKWMtGPR4v+90chEN0jzaAmJYRrVUh9vea27bOCn31Nse6XXQPmSI6Gyncy\nOAPUsvPClF3IjeL1tmBotWqSGn1cYxLo+Lwjk22A9h6vjcNQRyZF2VLVvtwYrNU3\nmwJ6GCLsLHpwW/yjyvn8iEltnJvByM/eeRnfXV6WDObyiZsE/n6DxIRJodQzFqy9\nGA==\n-----END CERTIFICATE-----\n"; + +struct StatusServerHandle { + listen_addr: SocketAddr, + shutdown: Arc, + thread: Option>, +} + +impl Drop for StatusServerHandle { + fn drop(&mut self) { + self.shutdown.store(true, Ordering::Relaxed); + if let Some(thread) = self.thread.take() { + let _ = thread.join(); + } + } +} + +type TestCertifiedKey = CertifiedKey; + fn upstream_settings(protocol: UpstreamProtocol) -> UpstreamSettings { UpstreamSettings { protocol, @@ -295,28 +312,13 @@ fn grpc_health_response_body(serving_status: u64) -> Bytes { panic!("test serving status should fit in a single-byte protobuf varint"); } - let mut body = BytesMut::with_capacity(5 + payload.len()); + let mut body = BytesMut::with_capacity(payload.len().saturating_add(5)); body.extend_from_slice(&[0]); body.extend_from_slice(&(payload.len() as u32).to_be_bytes()); body.extend_from_slice(&payload); body.freeze() } -struct StatusServerHandle { - listen_addr: SocketAddr, - shutdown: Arc, - thread: Option>, -} - -impl Drop for StatusServerHandle { - fn drop(&mut self) { - self.shutdown.store(true, Ordering::Relaxed); - if let Some(thread) = self.thread.take() { - let _ = thread.join(); - } - } -} - async fn spawn_status_server(statuses: Arc>>) -> StatusServerHandle { let listener = TcpListener::bind(("127.0.0.1", 0)).expect("test status listener should bind"); listener.set_nonblocking(true).expect("status listener should support nonblocking mode"); @@ -375,10 +377,6 @@ async fn spawn_status_server(statuses: Arc>>) -> Stat StatusServerHandle { listen_addr, shutdown, thread: Some(thread) } } -const TEST_CA_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDXTCCAkWgAwIBAgIJAOIvDiVb18eVMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV\nBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX\naWRnaXRzIFB0eSBMdGQwHhcNMTYwODE0MTY1NjExWhcNMjYwODEyMTY1NjExWjBF\nMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50\nZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB\nCgKCAQEArVHWFn52Lbl1l59exduZntVSZyDYpzDND+S2LUcO6fRBWhV/1Kzox+2G\nZptbuMGmfI3iAnb0CFT4uC3kBkQQlXonGATSVyaFTFR+jq/lc0SP+9Bd7SBXieIV\neIXlY1TvlwIvj3Ntw9zX+scTA4SXxH6M0rKv9gTOub2vCMSHeF16X8DQr4XsZuQr\n7Cp7j1I4aqOJyap5JTl5ijmG8cnu0n+8UcRlBzy99dLWJG0AfI3VRJdWpGTNVZ92\naFff3RpK3F/WI2gp3qV1ynRAKuvmncGC3LDvYfcc2dgsc1N6Ffq8GIrkgRob6eBc\nklDHp1d023Lwre+VaVDSo1//Y72UFwIDAQABo1AwTjAdBgNVHQ4EFgQUbNOlA6sN\nXyzJjYqciKeId7g3/ZowHwYDVR0jBBgwFoAUbNOlA6sNXyzJjYqciKeId7g3/Zow\nDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAVVaR5QWLZIRR4Dw6TSBn\nBQiLpBSXN6oAxdDw6n4PtwW6CzydaA+creiK6LfwEsiifUfQe9f+T+TBSpdIYtMv\nZ2H2tjlFX8VrjUFvPrvn5c28CuLI0foBgY8XGSkR2YMYzWw2jPEq3Th/KM5Catn3\nAFm3bGKWMtGPR4v+90chEN0jzaAmJYRrVUh9vea27bOCn31Nse6XXQPmSI6Gyncy\nOAPUsvPClF3IjeL1tmBotWqSGn1cYxLo+Lwjk22A9h6vjcNQRyZF2VLVvtwYrNU3\nmwJ6GCLsLHpwW/yjyvn8iEltnJvByM/eeRnfXV6WDObyiZsE/n6DxIRJodQzFqy9\nGA==\n-----END CERTIFICATE-----\n"; - -type TestCertifiedKey = CertifiedKey; - fn write_test_identity(cert_path: &Path, key_path: &Path) { let identity = generate_test_identity("localhost"); std::fs::write(cert_path, identity.cert.pem()).expect("test cert should be written"); diff --git a/crates/rginx-http/src/proxy/tests/request_headers.rs b/crates/rginx-http/src/proxy/tests/request_headers.rs index 2e8938df..a6fc87d2 100644 --- a/crates/rginx-http/src/proxy/tests/request_headers.rs +++ b/crates/rginx-http/src/proxy/tests/request_headers.rs @@ -1,7 +1,7 @@ -use super::*; - mod uri; +use super::*; + #[test] fn sanitize_request_headers_overwrites_x_forwarded_for_with_sanitized_chain() { let mut headers = HeaderMap::new(); diff --git a/crates/rginx-http/src/proxy/tests/support.rs b/crates/rginx-http/src/proxy/tests/support.rs index 701a1480..c3192245 100644 --- a/crates/rginx-http/src/proxy/tests/support.rs +++ b/crates/rginx-http/src/proxy/tests/support.rs @@ -46,7 +46,7 @@ pub(super) async fn spawn_range_server(seen_ranges: Arc>>) -> let response = match range.and_then(|range| parse_test_range_header(&range)) { Some((start, end)) if start < payload.len() => { - let end = end.min(payload.len() - 1); + let end = end.min(payload.len().saturating_sub(1)); let body = &payload[start..=end]; format!( "HTTP/1.1 206 Partial Content\r\ncontent-length: {}\r\ncontent-range: bytes {}-{}/{}\r\nconnection: close\r\n\r\n{}", diff --git a/crates/rginx-http/src/proxy/upgrade.rs b/crates/rginx-http/src/proxy/upgrade.rs index a6f1f3a4..7a1ebe6f 100644 --- a/crates/rginx-http/src/proxy/upgrade.rs +++ b/crates/rginx-http/src/proxy/upgrade.rs @@ -1,5 +1,5 @@ use super::health::ActivePeerGuard; -use super::*; +use super::{OnUpgrade, TokioIo, copy_bidirectional}; pub(super) async fn proxy_upgraded_connection( downstream_upgrade: OnUpgrade, diff --git a/crates/rginx-http/src/rate_limit/local.rs b/crates/rginx-http/src/rate_limit/local.rs index b91115bd..e51d9545 100644 --- a/crates/rginx-http/src/rate_limit/local.rs +++ b/crates/rginx-http/src/rate_limit/local.rs @@ -14,8 +14,8 @@ const SHARD_CLEANUP_INTERVAL: Duration = Duration::from_secs(30); #[derive(Default)] pub(super) struct LocalRateLimitersInner { - pub(super) shard_mask: usize, pub(super) cleanup_interval: Duration, + pub(super) shard_mask: usize, pub(super) shards: Box<[Shard]>, } @@ -34,15 +34,15 @@ pub(super) struct ShardState { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub(super) struct BucketKey { - pub(super) route: String, pub(super) client_ip: IpAddr, + pub(super) route: String, } #[derive(Debug, Clone)] pub(super) struct TokenBucket { + last_refill: Instant, policy: RouteRateLimit, tokens: f64, - last_refill: Instant, } impl Default for LocalRateLimiters { @@ -78,14 +78,20 @@ impl LocalRateLimiters { pub(super) fn with_config(shard_count: usize, cleanup_interval: Duration) -> Self { let shard_count = shard_count.max(1).next_power_of_two(); let now = Instant::now(); - let next_cleanup_at = now + cleanup_interval; + let next_cleanup_at = now + .checked_add(cleanup_interval) + .expect("rate limit cleanup deadline remains representable"); let shards = (0..shard_count) .map(|_| Shard::new(next_cleanup_at)) .collect::>() .into_boxed_slice(); Self { - inner: LocalRateLimitersInner { shard_mask: shard_count - 1, cleanup_interval, shards }, + inner: LocalRateLimitersInner { + shard_mask: shard_count.saturating_sub(1), + cleanup_interval, + shards, + }, } } } @@ -110,27 +116,15 @@ impl Shard { } impl TokenBucket { - pub(super) fn new(policy: RouteRateLimit, now: Instant) -> Self { - Self { policy, tokens: bucket_capacity(policy), last_refill: now } - } - - pub(super) fn try_acquire(&mut self, policy: RouteRateLimit, now: Instant) -> bool { - self.reconfigure(policy, now); - self.refill(now); - - if self.tokens < 1.0 { - return false; - } - - self.tokens -= 1.0; - true - } - pub(super) fn is_evictable(&mut self, now: Instant) -> bool { self.refill(now); self.tokens >= bucket_capacity(self.policy) } + pub(super) fn new(policy: RouteRateLimit, now: Instant) -> Self { + Self { policy, tokens: bucket_capacity(policy), last_refill: now } + } + fn reconfigure(&mut self, policy: RouteRateLimit, now: Instant) { if self.policy == policy { return; @@ -144,14 +138,26 @@ impl TokenBucket { fn refill(&mut self, now: Instant) { let replenished = self.tokens + now.duration_since(self.last_refill).as_secs_f64() - * self.policy.requests_per_sec as f64; + * f64::from(self.policy.requests_per_sec); self.tokens = replenished.min(bucket_capacity(self.policy)); self.last_refill = now; } + + pub(super) fn try_acquire(&mut self, policy: RouteRateLimit, now: Instant) -> bool { + self.reconfigure(policy, now); + self.refill(now); + + if self.tokens < 1.0 { + return false; + } + + self.tokens -= 1.0; + true + } } pub(super) fn bucket_capacity(policy: RouteRateLimit) -> f64 { - (policy.burst + 1) as f64 + f64::from(policy.burst.saturating_add(1)) } fn maybe_cleanup_buckets(state: &mut ShardState, now: Instant, cleanup_interval: Duration) { @@ -160,16 +166,17 @@ fn maybe_cleanup_buckets(state: &mut ShardState, now: Instant, cleanup_interval: } state.buckets.retain(|_, bucket| !bucket.is_evictable(now)); - state.next_cleanup_at = now + cleanup_interval; + state.next_cleanup_at = now + .checked_add(cleanup_interval) + .expect("rate limit cleanup deadline remains representable"); } fn default_shard_count() -> usize { std::thread::available_parallelism() - .map(|parallelism| parallelism.get().next_power_of_two()) - .unwrap_or(MIN_SHARD_COUNT) + .map_or(MIN_SHARD_COUNT, |parallelism| parallelism.get().next_power_of_two()) .clamp(MIN_SHARD_COUNT, MAX_SHARD_COUNT) } pub(super) fn lock_map(mutex: &Mutex) -> std::sync::MutexGuard<'_, T> { - mutex.lock().unwrap_or_else(|poisoned| poisoned.into_inner()) + mutex.lock().unwrap_or_else(std::sync::PoisonError::into_inner) } diff --git a/crates/rginx-http/src/rate_limit/mod.rs b/crates/rginx-http/src/rate_limit/mod.rs index bc4a3ba8..3f4e4c4a 100644 --- a/crates/rginx-http/src/rate_limit/mod.rs +++ b/crates/rginx-http/src/rate_limit/mod.rs @@ -1,9 +1,3 @@ -use std::net::IpAddr; -use std::path::Path; -use std::sync::Arc; - -use rginx_core::RouteRateLimit; - mod local; #[cfg(target_os = "linux")] mod shared; @@ -12,6 +6,12 @@ mod shared {} #[cfg(test)] mod tests; +use std::net::IpAddr; +use std::path::Path; +use std::sync::Arc; + +use rginx_core::RouteRateLimit; + use local::LocalRateLimiters; #[cfg(target_os = "linux")] @@ -31,26 +31,7 @@ impl Default for RateLimiters { } impl RateLimiters { - pub fn for_runtime(config_path: Option<&Path>) -> Self { - let local = Arc::new(LocalRateLimiters::default()); - #[cfg(target_os = "linux")] - let shared = config_path.and_then(|path| { - SharedRateLimitStore::new(path).map(Arc::new).map_err(|error| { - tracing::warn!( - path = %path.display(), - %error, - "failed to initialize shared rate-limit store; falling back to local limiter state" - ); - }).ok() - }); - - Self { - local, - #[cfg(target_os = "linux")] - shared, - } - } - + #[must_use] pub fn check(&self, route: &str, client_ip: IpAddr, policy: Option<&RouteRateLimit>) -> bool { let Some(policy) = policy.copied() else { return true; @@ -72,6 +53,41 @@ impl RateLimiters { self.local.check(route, client_ip, policy) } + #[cfg(all(test, target_os = "linux"))] + fn check_shared_at( + &self, + route: &str, + client_ip: IpAddr, + policy: RouteRateLimit, + now_unix_ms: u64, + ) -> bool { + self.shared + .as_ref() + .expect("shared store should exist for shared tests") + .check_at(route, client_ip, policy, now_unix_ms) + .expect("shared rate-limit check should succeed in tests") + } + #[must_use] + pub fn for_runtime(config_path: Option<&Path>) -> Self { + let local = Arc::new(LocalRateLimiters::default()); + #[cfg(target_os = "linux")] + let shared = config_path.and_then(|path| { + SharedRateLimitStore::new(path).map(Arc::new).map_err(|error| { + tracing::warn!( + path = %path.display(), + %error, + "failed to initialize shared rate-limit store; falling back to local limiter state" + ); + }).ok() + }); + + Self { + local, + #[cfg(target_os = "linux")] + shared, + } + } + fn local_only() -> Self { Self { local: Arc::new(LocalRateLimiters::default()), @@ -99,19 +115,4 @@ impl RateLimiters { )), } } - - #[cfg(all(test, target_os = "linux"))] - fn check_shared_at( - &self, - route: &str, - client_ip: IpAddr, - policy: RouteRateLimit, - now_unix_ms: u64, - ) -> bool { - self.shared - .as_ref() - .expect("shared store should exist for shared tests") - .check_at(route, client_ip, policy, now_unix_ms) - .expect("shared rate-limit check should succeed in tests") - } } diff --git a/crates/rginx-http/src/rate_limit/shared/bucket.rs b/crates/rginx-http/src/rate_limit/shared/bucket.rs index 681f4a23..77966082 100644 --- a/crates/rginx-http/src/rate_limit/shared/bucket.rs +++ b/crates/rginx-http/src/rate_limit/shared/bucket.rs @@ -1,18 +1,25 @@ use rginx_core::RouteRateLimit; use serde::{Deserialize, Serialize}; +use super::document::SharedRateLimitDocument; + const CLEANUP_INTERVAL_MS: u64 = 30_000; #[derive(Debug, Clone, Serialize, Deserialize)] pub(super) struct SharedTokenBucket { - requests_per_sec: u32, burst: u32, - tokens: f64, last_refill_unix_ms: u64, last_seen_unix_ms: u64, + requests_per_sec: u32, + tokens: f64, } impl SharedTokenBucket { + fn is_evictable(&mut self, now_unix_ms: u64) -> bool { + self.refill(now_unix_ms); + self.tokens >= bucket_capacity(RouteRateLimit::new(self.requests_per_sec, self.burst)) + } + pub(super) fn new(policy: RouteRateLimit, now_unix_ms: u64) -> Self { Self { requests_per_sec: policy.requests_per_sec, @@ -23,24 +30,6 @@ impl SharedTokenBucket { } } - pub(super) fn try_acquire(&mut self, policy: RouteRateLimit, now_unix_ms: u64) -> bool { - self.reconfigure(policy, now_unix_ms); - self.refill(now_unix_ms); - self.last_seen_unix_ms = now_unix_ms; - - if self.tokens < 1.0 { - return false; - } - - self.tokens -= 1.0; - true - } - - fn is_evictable(&mut self, now_unix_ms: u64) -> bool { - self.refill(now_unix_ms); - self.tokens >= bucket_capacity(RouteRateLimit::new(self.requests_per_sec, self.burst)) - } - fn reconfigure(&mut self, policy: RouteRateLimit, now_unix_ms: u64) { if self.requests_per_sec == policy.requests_per_sec && self.burst == policy.burst { return; @@ -55,17 +44,27 @@ impl SharedTokenBucket { fn refill(&mut self, now_unix_ms: u64) { let elapsed_ms = now_unix_ms.saturating_sub(self.last_refill_unix_ms); let replenished = - self.tokens + (elapsed_ms as f64 / 1_000.0) * self.requests_per_sec as f64; + self.tokens + (elapsed_ms as f64 / 1_000.0) * f64::from(self.requests_per_sec); self.tokens = replenished .min(bucket_capacity(RouteRateLimit::new(self.requests_per_sec, self.burst))); self.last_refill_unix_ms = now_unix_ms; } + + pub(super) fn try_acquire(&mut self, policy: RouteRateLimit, now_unix_ms: u64) -> bool { + self.reconfigure(policy, now_unix_ms); + self.refill(now_unix_ms); + self.last_seen_unix_ms = now_unix_ms; + + if self.tokens < 1.0 { + return false; + } + + self.tokens -= 1.0; + true + } } -pub(super) fn maybe_cleanup_document( - document: &mut crate::rate_limit::shared::document::SharedRateLimitDocument, - now_unix_ms: u64, -) { +pub(super) fn maybe_cleanup_document(document: &mut SharedRateLimitDocument, now_unix_ms: u64) { if now_unix_ms < document.next_cleanup_unix_ms { return; } @@ -75,5 +74,5 @@ pub(super) fn maybe_cleanup_document( } pub(super) fn bucket_capacity(policy: RouteRateLimit) -> f64 { - (policy.burst + 1) as f64 + f64::from(policy.burst.saturating_add(1)) } diff --git a/crates/rginx-http/src/rate_limit/shared/document.rs b/crates/rginx-http/src/rate_limit/shared/document.rs index 0d1917de..227f9c2f 100644 --- a/crates/rginx-http/src/rate_limit/shared/document.rs +++ b/crates/rginx-http/src/rate_limit/shared/document.rs @@ -15,11 +15,11 @@ const DEFAULT_SHM_CAPACITY_BYTES: usize = 4 * 1024 * 1024; #[derive(Debug, Serialize, Deserialize)] pub(super) struct SharedRateLimitDocument { - pub(super) version: u32, - #[serde(default)] - pub(super) next_cleanup_unix_ms: u64, #[serde(default)] pub(super) buckets: BTreeMap, + #[serde(default)] + pub(super) next_cleanup_unix_ms: u64, + pub(super) version: u32, } impl Default for SharedRateLimitDocument { @@ -115,8 +115,7 @@ pub(super) fn stable_hash(value: &str) -> u64 { pub(super) fn unix_time_ms(time: SystemTime) -> u64 { time.duration_since(UNIX_EPOCH) - .map(|duration| duration.as_millis().min(u128::from(u64::MAX)) as u64) - .unwrap_or(0) + .map_or(0, |duration| duration.as_millis().min(u128::from(u64::MAX)) as u64) } pub(super) fn invalid_data_error(error: impl std::fmt::Display) -> io::Error { diff --git a/crates/rginx-http/src/rate_limit/shared/lock.rs b/crates/rginx-http/src/rate_limit/shared/lock.rs index ec4dad10..306e9f90 100644 --- a/crates/rginx-http/src/rate_limit/shared/lock.rs +++ b/crates/rginx-http/src/rate_limit/shared/lock.rs @@ -8,6 +8,12 @@ pub(super) struct FileLock { file: File, } +impl Drop for FileLock { + fn drop(&mut self) { + let _ = unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) }; + } +} + pub(super) fn acquire(lock_path: &Path, contention_total: &AtomicU64) -> io::Result { if let Some(parent) = lock_path.parent() { std::fs::create_dir_all(parent)?; @@ -31,9 +37,3 @@ pub(super) fn acquire(lock_path: &Path, contention_total: &AtomicU64) -> io::Res } Err(io::Error::last_os_error()) } - -impl Drop for FileLock { - fn drop(&mut self) { - let _ = unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) }; - } -} diff --git a/crates/rginx-http/src/rate_limit/shared/mod.rs b/crates/rginx-http/src/rate_limit/shared/mod.rs index 810c36a1..6d62434f 100644 --- a/crates/rginx-http/src/rate_limit/shared/mod.rs +++ b/crates/rginx-http/src/rate_limit/shared/mod.rs @@ -1,3 +1,7 @@ +mod bucket; +mod document; +mod lock; + use std::io; use std::net::IpAddr; use std::path::{Path, PathBuf}; @@ -7,10 +11,6 @@ use rginx_core::RouteRateLimit; use crate::cache::shared::memory::{SharedMemorySegment, SharedMemorySegmentConfig}; -mod bucket; -mod document; -mod lock; - use bucket::{SharedTokenBucket, maybe_cleanup_document}; use document::{ SharedRateLimitDocument, bucket_key, read_document, shm_capacity_bytes, stable_hash, @@ -19,38 +19,12 @@ use document::{ use lock::FileLock; pub(super) struct SharedRateLimitStore { - segment_config: SharedMemorySegmentConfig, - lock_path: PathBuf, lock_contention_total: AtomicU64, + lock_path: PathBuf, + segment_config: SharedMemorySegmentConfig, } impl SharedRateLimitStore { - pub(super) fn new(config_path: &Path) -> io::Result { - let identity = format!("rate-limit:{}", config_path.display()); - Self::from_identity(&identity) - } - - #[cfg(test)] - pub(super) fn for_identity(identity: &str) -> io::Result { - Self::from_identity(identity) - } - - fn from_identity(identity: &str) -> io::Result { - let identity_hash = stable_hash(identity); - let temp_dir = std::env::temp_dir(); - let segment_path = temp_dir.join(format!("rginx-rate-limit-{identity_hash:016x}.shm")); - let lock_path = temp_dir.join(format!("rginx-rate-limit-{identity_hash:016x}.lock")); - Ok(Self { - segment_config: SharedMemorySegmentConfig::for_file_identity( - identity, - segment_path, - shm_capacity_bytes(), - ), - lock_path, - lock_contention_total: AtomicU64::new(0), - }) - } - pub(super) fn check( &self, route: &str, @@ -83,15 +57,35 @@ impl SharedRateLimitStore { } #[cfg(test)] - pub(super) fn unlink_for_test(&self) -> io::Result<()> { - let _ = std::fs::remove_file(&self.lock_path); - SharedMemorySegment::unlink(&self.segment_config) + pub(super) fn for_identity(identity: &str) -> io::Result { + Self::from_identity(identity) + } + + fn from_identity(identity: &str) -> io::Result { + let identity_hash = stable_hash(identity); + let temp_dir = std::env::temp_dir(); + let segment_path = temp_dir.join(format!("rginx-rate-limit-{identity_hash:016x}.shm")); + let lock_path = temp_dir.join(format!("rginx-rate-limit-{identity_hash:016x}.lock")); + Ok(Self { + segment_config: SharedMemorySegmentConfig::for_file_identity( + identity, + segment_path, + shm_capacity_bytes(), + ), + lock_path, + lock_contention_total: AtomicU64::new(0), + }) } fn lock(&self) -> io::Result { lock::acquire(&self.lock_path, &self.lock_contention_total) } + pub(super) fn new(config_path: &Path) -> io::Result { + let identity = format!("rate-limit:{}", config_path.display()); + Self::from_identity(&identity) + } + fn open_or_create_segment(&self) -> io::Result { match SharedMemorySegment::attach(&self.segment_config) { Ok(segment) => Ok(segment), @@ -109,4 +103,10 @@ impl SharedRateLimitStore { Err(error) => Err(error), } } + + #[cfg(test)] + pub(super) fn unlink_for_test(&self) -> io::Result<()> { + let _ = std::fs::remove_file(&self.lock_path); + SharedMemorySegment::unlink(&self.segment_config) + } } diff --git a/crates/rginx-http/src/request_target.rs b/crates/rginx-http/src/request_target.rs index e6bdcfa4..0fafcf73 100644 --- a/crates/rginx-http/src/request_target.rs +++ b/crates/rginx-http/src/request_target.rs @@ -1,3 +1,6 @@ +#[cfg(test)] +mod tests; + use http::Uri; pub(crate) struct NormalizedRequestTarget { @@ -5,9 +8,6 @@ pub(crate) struct NormalizedRequestTarget { pub(crate) path_and_query: String, } -#[cfg(test)] -mod tests; - pub(crate) fn normalize_request_target(uri: &Uri) -> NormalizedRequestTarget { let path = normalize_request_path(uri.path()); let path_and_query = match uri.query() { @@ -19,9 +19,7 @@ pub(crate) fn normalize_request_target(uri: &Uri) -> NormalizedRequestTarget { } pub(crate) fn raw_request_target(uri: &Uri) -> String { - uri.path_and_query() - .map(|value| value.as_str().to_string()) - .unwrap_or_else(|| uri.path().to_string()) + uri.path_and_query().map_or_else(|| uri.path().to_string(), |value| value.as_str().to_string()) } pub(crate) fn normalize_request_path(path: &str) -> String { diff --git a/crates/rginx-http/src/router.rs b/crates/rginx-http/src/router.rs index 1a0de9b0..8710911a 100644 --- a/crates/rginx-http/src/router.rs +++ b/crates/rginx-http/src/router.rs @@ -1,24 +1,49 @@ +#[cfg(test)] +mod tests; use rginx_core::{ConfigSnapshot, Route, VirtualHost}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct GrpcRequestMatch<'a> { - pub service: &'a str, pub method: &'a str, + pub service: &'a str, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RouteMatchContext<'a> { - pub path: &'a str, pub grpc: Option>, + pub path: &'a str, } impl<'a> RouteMatchContext<'a> { + #[must_use] pub fn new(path: &'a str) -> Self { Self { path, grpc: None } } + #[must_use] pub fn with_grpc(path: &'a str, service: &'a str, method: &'a str) -> Self { - Self { path, grpc: Some(GrpcRequestMatch { service, method }) } + Self { path, grpc: Some(GrpcRequestMatch { method, service }) } + } +} + +#[derive(Clone, Copy)] +struct Candidate<'a> { + grpc_rank: u8, + index: usize, + match_len: usize, + preferred_prefix: bool, + route: &'a Route, +} + +impl<'a> Candidate<'a> { + fn new(index: usize, route: &'a Route, match_len: usize, preferred_prefix: bool) -> Self { + Self { + index, + route, + match_len, + grpc_rank: route.grpc_match.as_ref().map_or(0, rginx_core::GrpcRouteMatch::priority), + preferred_prefix, + } } } @@ -76,27 +101,6 @@ pub fn select_route_with_context<'a>( prefix.map(|candidate| candidate.route) } -#[derive(Clone, Copy)] -struct Candidate<'a> { - index: usize, - route: &'a Route, - match_len: usize, - grpc_rank: u8, - preferred_prefix: bool, -} - -impl<'a> Candidate<'a> { - fn new(index: usize, route: &'a Route, match_len: usize, preferred_prefix: bool) -> Self { - Self { - index, - route, - match_len, - grpc_rank: route.grpc_match.as_ref().map_or(0, |grpc_match| grpc_match.priority()), - preferred_prefix, - } - } -} - fn select_more_specific<'a>( current: Option>, candidate: Candidate<'a>, @@ -119,6 +123,7 @@ fn select_more_specific<'a>( } /// 根据 Host 选择虚拟主机 +#[must_use] pub fn select_vhost<'a>( vhosts: &'a [VirtualHost], default: &'a VirtualHost, @@ -127,6 +132,7 @@ pub fn select_vhost<'a>( select_vhost_for_listener(vhosts, default, host, None) } +#[must_use] pub fn select_vhost_for_listener<'a>( vhosts: &'a [VirtualHost], default: &'a VirtualHost, @@ -156,7 +162,7 @@ pub fn select_vhost_for_listener<'a>( match selected { None => selected = Some((priority, vhost)), Some((current_priority, _)) if priority > current_priority => { - selected = Some((priority, vhost)) + selected = Some((priority, vhost)); } Some(_) => {} } @@ -174,10 +180,12 @@ pub fn select_vhost_from_snapshot<'a>( } /// 在指定虚拟主机内选择路由 +#[must_use] pub fn select_route_in_vhost<'a>(vhost: &'a VirtualHost, path: &str) -> Option<&'a Route> { select_route(&vhost.routes, path) } +#[must_use] pub fn select_route_in_vhost_with_context<'a>( vhost: &'a VirtualHost, context: &RouteMatchContext<'_>, @@ -198,6 +206,7 @@ pub fn select_route_in_snapshot_vhost_with_context<'a>( ) } +#[must_use] pub fn select_named_route_in_vhost<'a>(vhost: &'a VirtualHost, name: &str) -> Option<&'a Route> { vhost.routes.iter().find(|route| match &route.matcher { rginx_core::RouteMatcher::Named(candidate) => candidate == name, @@ -214,6 +223,7 @@ pub fn select_named_route_in_snapshot_vhost<'a>( } /// 组合:Host + Path 双层匹配 +#[must_use] pub fn select_route_by_host<'a>( default_vhost: &'a VirtualHost, vhosts: &'a [VirtualHost], @@ -224,6 +234,7 @@ pub fn select_route_by_host<'a>( select_route_in_vhost(vhost, path).map(|route| (vhost, route)) } +#[must_use] pub fn select_route_by_host_with_context<'a>( default_vhost: &'a VirtualHost, vhosts: &'a [VirtualHost], @@ -243,6 +254,3 @@ fn route_matches(route: &Route, context: &RouteMatchContext<'_>) -> bool { context.grpc.is_some_and(|grpc| grpc_match.matches(grpc.service, grpc.method)) }) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/router/tests.rs b/crates/rginx-http/src/router/tests.rs index b28f72e0..679c527c 100644 --- a/crates/rginx-http/src/router/tests.rs +++ b/crates/rginx-http/src/router/tests.rs @@ -1,10 +1,10 @@ +mod context; + use http::StatusCode; use rginx_core::{ReturnAction, Route, RouteAccessControl, RouteAction, RouteMatcher, VirtualHost}; use super::{select_route, select_route_by_host, select_vhost}; -mod context; - fn make_route(path: &str, body: &str) -> Route { Route { cache: None, diff --git a/crates/rginx-http/src/server/connection.rs b/crates/rginx-http/src/server/connection.rs index df83a344..26252ee5 100644 --- a/crates/rginx-http/src/server/connection.rs +++ b/crates/rginx-http/src/server/connection.rs @@ -16,15 +16,15 @@ const ALPN_H2: &[u8] = b"h2"; #[derive(Clone, Copy)] pub(super) struct Http1ConnectionOptions { + pub(super) half_close: bool, + pub(super) header_read_timeout: Option, pub(super) keep_alive: bool, + pub(super) max_buf_size: Option, pub(super) max_headers: Option, - pub(super) header_read_timeout: Option, + pub(super) pipeline_flush: bool, + pub(super) preserve_header_case: bool, pub(super) response_write_timeout: Option, - pub(super) half_close: bool, pub(super) title_case_headers: bool, - pub(super) preserve_header_case: bool, - pub(super) max_buf_size: Option, - pub(super) pipeline_flush: bool, pub(super) writev: Option, } @@ -165,7 +165,7 @@ fn extract_tls_client_identity( stream: &tokio_rustls::server::TlsStream, ) -> Option { let certs = stream.get_ref().1.peer_certificates()?; - Some(parse_tls_client_identity(certs.iter().map(|cert| cert.as_ref()))) + Some(parse_tls_client_identity(certs.iter().map(std::convert::AsRef::as_ref))) } pub(super) fn parse_tls_client_identity<'a>( diff --git a/crates/rginx-http/src/server/http3/accept_loop.rs b/crates/rginx-http/src/server/http3/accept_loop.rs index 1f2315ec..7a699d28 100644 --- a/crates/rginx-http/src/server/http3/accept_loop.rs +++ b/crates/rginx-http/src/server/http3/accept_loop.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{JoinError, JoinSet, Result, watch}; pub async fn serve_http3( endpoint: quinn::Endpoint, diff --git a/crates/rginx-http/src/server/http3/body.rs b/crates/rginx-http/src/server/http3/body.rs index 79c70b76..dd6fc913 100644 --- a/crates/rginx-http/src/server/http3/body.rs +++ b/crates/rginx-http/src/server/http3/body.rs @@ -1,4 +1,6 @@ -use super::*; +use super::{ + Body, BoxError, Buf, Bytes, Context, Frame, Poll, RecvStream, RequestStream, SizeHint, +}; use pin_project_lite::pin_project; pin_project! { @@ -37,6 +39,10 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.data_finished && self.trailers_finished + } + fn poll_frame( self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>, @@ -91,10 +97,6 @@ where Poll::Ready(None) } - fn is_end_stream(&self) -> bool { - self.data_finished && self.trailers_finished - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } diff --git a/crates/rginx-http/src/server/http3/connection.rs b/crates/rginx-http/src/server/http3/connection.rs index 4bc76508..ea84ee07 100644 --- a/crates/rginx-http/src/server/http3/connection.rs +++ b/crates/rginx-http/src/server/http3/connection.rs @@ -1,4 +1,7 @@ -use super::*; +use super::{ + Arc, AtomicBool, ConnectionPeerAddrs, Duration, Error, H3Connection, Incoming, JoinError, + JoinSet, Ordering, Result, SocketAddr, TlsClientIdentity, watch, +}; pub(super) async fn serve_http3_connection( incoming: Incoming, @@ -189,7 +192,7 @@ fn extract_http3_tls_client_identity(connection: &quinn::Connection) -> Option>>() .ok()?; Some(super::super::connection::parse_tls_client_identity( - certificates.iter().map(|certificate| certificate.as_ref()), + certificates.iter().map(std::convert::AsRef::as_ref), )) } diff --git a/crates/rginx-http/src/server/http3/endpoint.rs b/crates/rginx-http/src/server/http3/endpoint.rs index c7fe8ec7..5d92a8ea 100644 --- a/crates/rginx-http/src/server/http3/endpoint.rs +++ b/crates/rginx-http/src/server/http3/endpoint.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Arc, Error, Result, build_http3_server_config, hkdf, hmac}; const HTTP3_ACTIVE_CONNECTION_ID_LIMIT_NO_MIGRATION: u32 = 2; const HTTP3_ACTIVE_CONNECTION_ID_LIMIT_MIGRATION: u32 = 5; diff --git a/crates/rginx-http/src/server/http3/host_key.rs b/crates/rginx-http/src/server/http3/host_key.rs index cf66b317..eef20889 100644 --- a/crates/rginx-http/src/server/http3/host_key.rs +++ b/crates/rginx-http/src/server/http3/host_key.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Digest, Error, Path, Result, SecureRandom, Sha256, rand}; const HTTP3_HOST_KEY_BYTES: usize = 64; diff --git a/crates/rginx-http/src/server/http3/mod.rs b/crates/rginx-http/src/server/http3/mod.rs index 13c197cd..db5a9c70 100644 --- a/crates/rginx-http/src/server/http3/mod.rs +++ b/crates/rginx-http/src/server/http3/mod.rs @@ -1,5 +1,16 @@ //! HTTP/3 listener bootstrap, connection lifecycle, and request/response bridging. +mod accept_loop; +mod body; +mod close_reason; +mod connection; +mod endpoint; +mod host_key; +mod request; +mod response; +#[cfg(test)] +mod tests; + use std::net::SocketAddr; use std::path::Path; use std::sync::Arc; @@ -25,17 +36,6 @@ use crate::client_ip::{ConnectionPeerAddrs, TlsClientIdentity}; use crate::handler::{BoxError, HttpResponse}; use crate::tls::build_http3_server_config; -mod accept_loop; -mod body; -mod close_reason; -mod connection; -mod endpoint; -mod host_key; -mod request; -mod response; -#[cfg(test)] -mod tests; - pub use accept_loop::serve_http3; pub use endpoint::{bind_http3_endpoint, bind_http3_endpoint_with_socket}; diff --git a/crates/rginx-http/src/server/http3/request.rs b/crates/rginx-http/src/server/http3/request.rs index 2f45ce93..e7ad17df 100644 --- a/crates/rginx-http/src/server/http3/request.rs +++ b/crates/rginx-http/src/server/http3/request.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Arc, Bytes, ConnectionPeerAddrs, Error, RequestResolver, Result, Version}; pub(super) async fn serve_http3_request( resolver: RequestResolver, diff --git a/crates/rginx-http/src/server/http3/response.rs b/crates/rginx-http/src/server/http3/response.rs index e8b1eb95..bf30c77b 100644 --- a/crates/rginx-http/src/server/http3/response.rs +++ b/crates/rginx-http/src/server/http3/response.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{BodyExt, Bytes, Error, HttpResponse, RequestStream, Response, Result, SendStream}; pub(super) async fn send_http3_response( mut stream: RequestStream, diff --git a/crates/rginx-http/src/server/proxy_protocol.rs b/crates/rginx-http/src/server/proxy_protocol.rs index 033024f4..d1611f2a 100644 --- a/crates/rginx-http/src/server/proxy_protocol.rs +++ b/crates/rginx-http/src/server/proxy_protocol.rs @@ -61,7 +61,7 @@ pub(super) fn parse_proxy_protocol_v1( } let source = match protocol { - Some("TCP4") | Some("TCP6") => { + Some("TCP4" | "TCP6") => { let ip = source_addr .ok_or_else(|| { std::io::Error::new( diff --git a/crates/rginx-http/src/server/tests.rs b/crates/rginx-http/src/server/tests.rs index 70b55010..e608275b 100644 --- a/crates/rginx-http/src/server/tests.rs +++ b/crates/rginx-http/src/server/tests.rs @@ -8,39 +8,6 @@ use rustls::pki_types::{CertificateDer, pem::PemObject}; use super::connection::parse_tls_client_identity; use super::proxy_protocol::parse_proxy_protocol_v1; -fn remote_proxy_peer_addr() -> SocketAddr { - SocketAddr::from((Ipv4Addr::new(10, 0, 0, 1), 4000)) -} - -#[test] -fn proxy_protocol_v1_parses_tcp4_source_address() { - let source = parse_proxy_protocol_v1( - "PROXY TCP4 198.51.100.9 203.0.113.10 12345 443\r\n", - "10.0.0.1:4000".parse().unwrap(), - true, - ) - .expect("header should parse"); - - assert_eq!(source, Some("198.51.100.9:12345".parse().unwrap())); -} - -#[test] -fn proxy_protocol_v1_accepts_unknown_transport() { - let source = - parse_proxy_protocol_v1("PROXY UNKNOWN\r\n", "10.0.0.1:4000".parse().unwrap(), true) - .expect("unknown header should parse"); - - assert_eq!(source, None); -} - -#[test] -fn proxy_protocol_v1_rejects_invalid_headers() { - let error = parse_proxy_protocol_v1("BROKEN\r\n", "10.0.0.1:4000".parse().unwrap(), true) - .expect_err("invalid header should fail"); - - assert_eq!(error.kind(), std::io::ErrorKind::InvalidData); -} - proptest! { #![proptest_config(ProptestConfig::with_cases(128))] @@ -105,6 +72,39 @@ proptest! { } } +fn remote_proxy_peer_addr() -> SocketAddr { + SocketAddr::from((Ipv4Addr::new(10, 0, 0, 1), 4000)) +} + +#[test] +fn proxy_protocol_v1_parses_tcp4_source_address() { + let source = parse_proxy_protocol_v1( + "PROXY TCP4 198.51.100.9 203.0.113.10 12345 443\r\n", + "10.0.0.1:4000".parse().unwrap(), + true, + ) + .expect("header should parse"); + + assert_eq!(source, Some("198.51.100.9:12345".parse().unwrap())); +} + +#[test] +fn proxy_protocol_v1_accepts_unknown_transport() { + let source = + parse_proxy_protocol_v1("PROXY UNKNOWN\r\n", "10.0.0.1:4000".parse().unwrap(), true) + .expect("unknown header should parse"); + + assert_eq!(source, None); +} + +#[test] +fn proxy_protocol_v1_rejects_invalid_headers() { + let error = parse_proxy_protocol_v1("BROKEN\r\n", "10.0.0.1:4000".parse().unwrap(), true) + .expect_err("invalid header should fail"); + + assert_eq!(error.kind(), std::io::ErrorKind::InvalidData); +} + #[test] fn parse_tls_client_identity_extracts_subject_and_dns_san() { let mut params = CertificateParams::new(vec!["localhost".to_string()]) diff --git a/crates/rginx-http/src/state/agent.rs b/crates/rginx-http/src/state/agent.rs index 5764a35b..fac60962 100644 --- a/crates/rginx-http/src/state/agent.rs +++ b/crates/rginx-http/src/state/agent.rs @@ -1,31 +1,53 @@ -use super::*; +use super::{AgentRuntimeSnapshot, BTreeMap, Ordering, PathBuf, SharedState, watch}; #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct AgentRuntimeUpdate { - pub connection_state: String, pub command_cursor: Option, + pub connection_state: String, pub in_flight_command_id: Option, - pub last_register_success_unix_ms: Option, pub last_heartbeat_success_unix_ms: Option, + pub last_register_success_unix_ms: Option, } #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct AgentRuntimeState { + command_cursor: Option, configured: bool, - endpoint: Option, - node_id: Option, - state_path: Option, - region: Option, - pop: Option, - labels: BTreeMap, connection_state: String, - command_cursor: Option, + endpoint: Option, in_flight_command_id: Option, - last_register_success_unix_ms: Option, + labels: BTreeMap, last_heartbeat_success_unix_ms: Option, + last_register_success_unix_ms: Option, + node_id: Option, + pop: Option, + region: Option, + state_path: Option, } impl AgentRuntimeState { + pub(super) fn apply_update(&mut self, update: AgentRuntimeUpdate) -> bool { + let mut changed = false; + update_if_changed(&mut self.connection_state, update.connection_state, &mut changed); + update_if_changed(&mut self.command_cursor, update.command_cursor, &mut changed); + update_if_changed( + &mut self.in_flight_command_id, + update.in_flight_command_id, + &mut changed, + ); + update_if_changed( + &mut self.last_register_success_unix_ms, + update.last_register_success_unix_ms, + &mut changed, + ); + update_if_changed( + &mut self.last_heartbeat_success_unix_ms, + update.last_heartbeat_success_unix_ms, + &mut changed, + ); + changed + } + pub(super) fn from_config(agent: Option<&rginx_core::AgentSettings>) -> Self { let Some(agent) = agent else { return Self::default(); @@ -47,42 +69,12 @@ impl AgentRuntimeState { } } - pub(super) fn sync_config(&mut self, agent: Option<&rginx_core::AgentSettings>) -> bool { - let mut next = Self::from_config(agent); - if self.same_config_identity(&next) { - next.connection_state.clone_from(&self.connection_state); - next.command_cursor.clone_from(&self.command_cursor); - next.in_flight_command_id.clone_from(&self.in_flight_command_id); - next.last_register_success_unix_ms = self.last_register_success_unix_ms; - next.last_heartbeat_success_unix_ms = self.last_heartbeat_success_unix_ms; - } - if *self == next { - return false; - } - *self = next; - true - } - - pub(super) fn apply_update(&mut self, update: AgentRuntimeUpdate) -> bool { - let mut changed = false; - update_if_changed(&mut self.connection_state, update.connection_state, &mut changed); - update_if_changed(&mut self.command_cursor, update.command_cursor, &mut changed); - update_if_changed( - &mut self.in_flight_command_id, - update.in_flight_command_id, - &mut changed, - ); - update_if_changed( - &mut self.last_register_success_unix_ms, - update.last_register_success_unix_ms, - &mut changed, - ); - update_if_changed( - &mut self.last_heartbeat_success_unix_ms, - update.last_heartbeat_success_unix_ms, - &mut changed, - ); - changed + fn same_config_identity(&self, other: &Self) -> bool { + self.configured + && other.configured + && self.endpoint == other.endpoint + && self.node_id == other.node_id + && self.state_path == other.state_path } pub(super) fn snapshot(&self, locally_disabled: bool) -> AgentRuntimeSnapshot { @@ -108,12 +100,20 @@ impl AgentRuntimeState { } } - fn same_config_identity(&self, other: &Self) -> bool { - self.configured - && other.configured - && self.endpoint == other.endpoint - && self.node_id == other.node_id - && self.state_path == other.state_path + pub(super) fn sync_config(&mut self, agent: Option<&rginx_core::AgentSettings>) -> bool { + let mut next = Self::from_config(agent); + if self.same_config_identity(&next) { + next.connection_state.clone_from(&self.connection_state); + next.command_cursor.clone_from(&self.command_cursor); + next.in_flight_command_id.clone_from(&self.in_flight_command_id); + next.last_register_success_unix_ms = self.last_register_success_unix_ms; + next.last_heartbeat_success_unix_ms = self.last_heartbeat_success_unix_ms; + } + if *self == next { + return false; + } + *self = next; + true } } @@ -137,36 +137,31 @@ impl Default for AgentRuntimeState { } impl SharedState { - pub fn set_agent_configured(&self, agent: &rginx_core::AgentSettings) { - if self.sync_agent_runtime_config(Some(agent)) { - self.mark_status_snapshot_changed(); - } + #[must_use] + pub fn agent_disabled_receiver(&self) -> watch::Receiver { + self.lifecycle.agent_disabled.subscribe() } - pub fn update_agent_runtime(&self, update: AgentRuntimeUpdate) { - let changed = self - .lifecycle - .agent_runtime - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .apply_update(update); - if changed { - self.mark_status_snapshot_changed(); - } + fn agent_locally_disabled(&self) -> bool { + self.lifecycle.agent_disabled_value.load(Ordering::Acquire) } + #[must_use] pub fn agent_status_snapshot(&self) -> AgentRuntimeSnapshot { self.lifecycle .agent_runtime .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .snapshot(self.agent_locally_disabled()) } - pub fn agent_disabled_receiver(&self) -> watch::Receiver { - self.lifecycle.agent_disabled.subscribe() + pub fn set_agent_configured(&self, agent: &rginx_core::AgentSettings) { + if self.sync_agent_runtime_config(Some(agent)) { + self.mark_status_snapshot_changed(); + } } + #[must_use] pub fn set_agent_locally_disabled(&self, disabled: bool) -> AgentRuntimeSnapshot { let previous = self.lifecycle.agent_disabled_value.swap(disabled, Ordering::AcqRel); if previous != disabled { @@ -183,12 +178,20 @@ impl SharedState { self.lifecycle .agent_runtime .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .sync_config(agent) } - fn agent_locally_disabled(&self) -> bool { - self.lifecycle.agent_disabled_value.load(Ordering::Acquire) + pub fn update_agent_runtime(&self, update: AgentRuntimeUpdate) { + let changed = self + .lifecycle + .agent_runtime + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .apply_update(update); + if changed { + self.mark_status_snapshot_changed(); + } } } diff --git a/crates/rginx-http/src/state/cache.rs b/crates/rginx-http/src/state/cache.rs index 43cadec7..318a2480 100644 --- a/crates/rginx-http/src/state/cache.rs +++ b/crates/rginx-http/src/state/cache.rs @@ -1,4 +1,7 @@ -use super::*; +use super::{ + CacheInvalidationResult, CachePurgeResult, CacheStatsSnapshot, ConfigSnapshot, HashMap, + SharedState, +}; impl SharedState { pub async fn cache_stats_snapshot(&self) -> CacheStatsSnapshot { @@ -17,97 +20,116 @@ impl SharedState { cache.cleanup_inactive_entries().await; } - pub async fn purge_cache_zone( + pub async fn clear_cache_invalidations( &self, zone_name: &str, - ) -> std::result::Result { + ) -> std::result::Result { let cache = { let state = self.inner.read().await; state.cache.clone() }; - cache.purge_zone(zone_name).await + cache.clear_invalidations(zone_name).await } - pub async fn purge_cache_key( + pub async fn invalidate_cache_key( &self, zone_name: &str, key: &str, - ) -> std::result::Result { + ) -> std::result::Result { let cache = { let state = self.inner.read().await; state.cache.clone() }; - cache.purge_key(zone_name, key).await + cache.invalidate_key(zone_name, key).await } - pub async fn purge_cache_prefix( + pub async fn invalidate_cache_prefix( &self, zone_name: &str, prefix: &str, - ) -> std::result::Result { + ) -> std::result::Result { let cache = { let state = self.inner.read().await; state.cache.clone() }; - cache.purge_prefix(zone_name, prefix).await + cache.invalidate_prefix(zone_name, prefix).await } - pub async fn invalidate_cache_zone( + pub async fn invalidate_cache_tag( &self, zone_name: &str, + tag: &str, ) -> std::result::Result { let cache = { let state = self.inner.read().await; state.cache.clone() }; - cache.invalidate_zone(zone_name).await + cache.invalidate_tag(zone_name, tag).await } - pub async fn invalidate_cache_key( + pub async fn invalidate_cache_zone( &self, zone_name: &str, - key: &str, ) -> std::result::Result { let cache = { let state = self.inner.read().await; state.cache.clone() }; - cache.invalidate_key(zone_name, key).await + cache.invalidate_zone(zone_name).await } - pub async fn invalidate_cache_prefix( + pub(crate) fn mark_all_cache_zones_changed( + &self, + previous: &ConfigSnapshot, + next: &ConfigSnapshot, + version: u64, + ) { + let mut cache_versions = self + .observability + .cache_component_versions + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + for zone_name in previous.cache_zones.keys() { + cache_versions.insert(zone_name.clone(), version); + } + for zone_name in next.cache_zones.keys() { + cache_versions.insert(zone_name.clone(), version); + } + } + + pub async fn purge_cache_key( &self, zone_name: &str, - prefix: &str, - ) -> std::result::Result { + key: &str, + ) -> std::result::Result { let cache = { let state = self.inner.read().await; state.cache.clone() }; - cache.invalidate_prefix(zone_name, prefix).await + cache.purge_key(zone_name, key).await } - pub async fn invalidate_cache_tag( + pub async fn purge_cache_prefix( &self, zone_name: &str, - tag: &str, - ) -> std::result::Result { + prefix: &str, + ) -> std::result::Result { let cache = { let state = self.inner.read().await; state.cache.clone() }; - cache.invalidate_tag(zone_name, tag).await + cache.purge_prefix(zone_name, prefix).await } - pub async fn clear_cache_invalidations( + pub async fn purge_cache_zone( &self, zone_name: &str, - ) -> std::result::Result { + ) -> std::result::Result { let cache = { let state = self.inner.read().await; state.cache.clone() }; - cache.clear_invalidations(zone_name).await + cache.purge_zone(zone_name).await } pub(crate) fn sync_cache_versions(&self, config: &ConfigSnapshot) { @@ -115,33 +137,14 @@ impl SharedState { .observability .cache_component_versions .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); let next = build_cache_zone_versions(config, Some(&*existing)); drop(existing); *self .observability .cache_component_versions .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) = next; - } - - pub(crate) fn mark_all_cache_zones_changed( - &self, - previous: &ConfigSnapshot, - next: &ConfigSnapshot, - version: u64, - ) { - let mut cache_versions = self - .observability - .cache_component_versions - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - for zone_name in previous.cache_zones.keys() { - cache_versions.insert(zone_name.clone(), version); - } - for zone_name in next.cache_zones.keys() { - cache_versions.insert(zone_name.clone(), version); - } + .unwrap_or_else(std::sync::PoisonError::into_inner) = next; } } diff --git a/crates/rginx-http/src/state/connections/guards.rs b/crates/rginx-http/src/state/connections/guards.rs index 27e299ba..5a2d8e87 100644 --- a/crates/rginx-http/src/state/connections/guards.rs +++ b/crates/rginx-http/src/state/connections/guards.rs @@ -1,4 +1,7 @@ -use super::super::*; +use super::super::{ + Arc, AtomicUsize, ListenerTrafficCounters, Ordering, SnapshotBusState, StdRwLock, + TrafficComponentVersions, +}; pub struct ActiveConnectionGuard { pub(super) active_connections: Arc, @@ -30,7 +33,7 @@ impl Drop for ActiveConnectionGuard { self.snapshot_bus.mark_changed_and_notify(true, false, true, false, false, false); self.traffic_component_versions .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .listeners .insert(self.listener_id.clone(), version); } @@ -43,7 +46,7 @@ impl Drop for ActiveHttp3ConnectionGuard { self.snapshot_bus.mark_changed_and_notify(true, false, true, false, false, false); self.traffic_component_versions .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .listeners .insert(self.listener_id.clone(), version); } @@ -56,7 +59,7 @@ impl Drop for ActiveHttp3RequestStreamGuard { self.snapshot_bus.mark_changed_and_notify(true, false, true, false, false, false); self.traffic_component_versions .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .listeners .insert(self.listener_id.clone(), version); } diff --git a/crates/rginx-http/src/state/connections/operations/acquire.rs b/crates/rginx-http/src/state/connections/operations/acquire.rs index 38a33326..eb2ba20f 100644 --- a/crates/rginx-http/src/state/connections/operations/acquire.rs +++ b/crates/rginx-http/src/state/connections/operations/acquire.rs @@ -1,56 +1,21 @@ -use super::super::super::*; +use super::super::super::{Ordering, Result, SharedState}; use super::super::guards::{ ActiveConnectionGuard, ActiveHttp3ConnectionGuard, ActiveHttp3RequestStreamGuard, }; impl SharedState { + #[must_use] pub fn active_connection_count(&self) -> usize { self.request_runtime.active_connections.load(Ordering::Acquire) } - pub fn try_acquire_connection( - &self, - listener_id: &str, - limit: Option, - ) -> Option { - let listener_active_connections = self - .listener_runtime - .active_connections - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .get(listener_id)? - .clone(); - loop { - let current = listener_active_connections.load(Ordering::Acquire); - if limit.is_some_and(|limit| current >= limit) { - return None; - } - - if listener_active_connections - .compare_exchange_weak(current, current + 1, Ordering::AcqRel, Ordering::Acquire) - .is_ok() - { - self.request_runtime.active_connections.fetch_add(1, Ordering::AcqRel); - return Some(ActiveConnectionGuard { - active_connections: self.request_runtime.active_connections.clone(), - listener_active_connections, - listener_id: listener_id.to_string(), - snapshot_bus: self.snapshot_bus.clone(), - traffic_component_versions: self - .observability - .traffic_component_versions - .clone(), - }); - } - } - } - + #[must_use] pub fn retain_connection_slot(&self, listener_id: &str) -> ActiveConnectionGuard { let listener_active_connections = self .listener_runtime .active_connections .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .get(listener_id) .expect("listener id should exist while retaining a connection slot") .clone(); @@ -103,4 +68,47 @@ impl SharedState { traffic_component_versions: self.observability.traffic_component_versions.clone(), }) } + + #[must_use] + pub fn try_acquire_connection( + &self, + listener_id: &str, + limit: Option, + ) -> Option { + let listener_active_connections = self + .listener_runtime + .active_connections + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .get(listener_id)? + .clone(); + loop { + let current = listener_active_connections.load(Ordering::Acquire); + if limit.is_some_and(|limit| current >= limit) { + return None; + } + + if listener_active_connections + .compare_exchange_weak( + current, + current.saturating_add(1), + Ordering::AcqRel, + Ordering::Acquire, + ) + .is_ok() + { + self.request_runtime.active_connections.fetch_add(1, Ordering::AcqRel); + return Some(ActiveConnectionGuard { + active_connections: self.request_runtime.active_connections.clone(), + listener_active_connections, + listener_id: listener_id.to_string(), + snapshot_bus: self.snapshot_bus.clone(), + traffic_component_versions: self + .observability + .traffic_component_versions + .clone(), + }); + } + } + } } diff --git a/crates/rginx-http/src/state/connections/operations/record.rs b/crates/rginx-http/src/state/connections/operations/record.rs index 7b14b910..78f8435d 100644 --- a/crates/rginx-http/src/state/connections/operations/record.rs +++ b/crates/rginx-http/src/state/connections/operations/record.rs @@ -1,4 +1,4 @@ -use super::super::super::*; +use super::super::super::{Ordering, SharedState, TlsHandshakeFailureReason}; impl SharedState { pub(crate) fn record_connection_accepted(&self, listener_id: &str) { @@ -9,110 +9,60 @@ impl SharedState { self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); } - pub(crate) fn record_mtls_handshake_success(&self, listener_id: &str, authenticated: bool) { - if !authenticated { - return; - } - - self.observability - .counters - .downstream_mtls_authenticated_connections - .fetch_add(1, Ordering::Relaxed); + pub(crate) fn record_connection_rejected(&self, listener_id: &str) { + self.observability.counters.downstream_connections_rejected.fetch_add(1, Ordering::Relaxed); if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.downstream_mtls_authenticated_connections.fetch_add(1, Ordering::Relaxed); + counters.downstream_connections_rejected.fetch_add(1, Ordering::Relaxed); } self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); } - pub(crate) fn record_tls_handshake_failure( + pub(crate) fn record_http3_connection_close( &self, listener_id: &str, - reason: TlsHandshakeFailureReason, + reason: quinn::ConnectionError, ) { - self.observability - .counters - .downstream_tls_handshake_failures - .fetch_add(1, Ordering::Relaxed); - match reason { - TlsHandshakeFailureReason::MissingClientCert => { - self.observability - .counters - .downstream_tls_handshake_failures_missing_client_cert - .fetch_add(1, Ordering::Relaxed); - if let Some(counters) = self.listener_traffic_counters(listener_id) { + if let Some(counters) = self.listener_traffic_counters(listener_id) { + match reason { + quinn::ConnectionError::VersionMismatch => { counters - .downstream_tls_handshake_failures_missing_client_cert + .http3_connection_close_version_mismatch_total .fetch_add(1, Ordering::Relaxed); } - } - TlsHandshakeFailureReason::UnknownCa => { - self.observability - .counters - .downstream_tls_handshake_failures_unknown_ca - .fetch_add(1, Ordering::Relaxed); - if let Some(counters) = self.listener_traffic_counters(listener_id) { + quinn::ConnectionError::TransportError(_) => { counters - .downstream_tls_handshake_failures_unknown_ca + .http3_connection_close_transport_error_total .fetch_add(1, Ordering::Relaxed); } - } - TlsHandshakeFailureReason::BadCertificate => { - self.observability - .counters - .downstream_tls_handshake_failures_bad_certificate - .fetch_add(1, Ordering::Relaxed); - if let Some(counters) = self.listener_traffic_counters(listener_id) { + quinn::ConnectionError::ConnectionClosed(_) => { counters - .downstream_tls_handshake_failures_bad_certificate + .http3_connection_close_connection_closed_total .fetch_add(1, Ordering::Relaxed); } - } - TlsHandshakeFailureReason::CertificateRevoked => { - self.observability - .counters - .downstream_tls_handshake_failures_certificate_revoked - .fetch_add(1, Ordering::Relaxed); - if let Some(counters) = self.listener_traffic_counters(listener_id) { + quinn::ConnectionError::ApplicationClosed(_) => { counters - .downstream_tls_handshake_failures_certificate_revoked + .http3_connection_close_application_closed_total .fetch_add(1, Ordering::Relaxed); } - } - TlsHandshakeFailureReason::VerifyDepthExceeded => { - self.observability - .counters - .downstream_tls_handshake_failures_verify_depth_exceeded - .fetch_add(1, Ordering::Relaxed); - if let Some(counters) = self.listener_traffic_counters(listener_id) { + quinn::ConnectionError::Reset => { + counters.http3_connection_close_reset_total.fetch_add(1, Ordering::Relaxed); + } + quinn::ConnectionError::TimedOut => { + counters.http3_connection_close_timed_out_total.fetch_add(1, Ordering::Relaxed); + } + quinn::ConnectionError::LocallyClosed => { counters - .downstream_tls_handshake_failures_verify_depth_exceeded + .http3_connection_close_locally_closed_total .fetch_add(1, Ordering::Relaxed); } - } - TlsHandshakeFailureReason::Other => { - self.observability - .counters - .downstream_tls_handshake_failures_other - .fetch_add(1, Ordering::Relaxed); - if let Some(counters) = self.listener_traffic_counters(listener_id) { + quinn::ConnectionError::CidsExhausted => { counters - .downstream_tls_handshake_failures_other + .http3_connection_close_cids_exhausted_total .fetch_add(1, Ordering::Relaxed); } } } - if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.downstream_tls_handshake_failures.fetch_add(1, Ordering::Relaxed); - } - self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); - } - - pub(crate) fn record_connection_rejected(&self, listener_id: &str) { - self.observability.counters.downstream_connections_rejected.fetch_add(1, Ordering::Relaxed); - if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.downstream_connections_rejected.fetch_add(1, Ordering::Relaxed); - } - self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); + self.mark_traffic_snapshot_changed(true, false, Some(listener_id), None, None); } pub(crate) fn record_http3_early_data_accepted_request(&self, listener_id: &str) { @@ -137,93 +87,143 @@ impl SharedState { self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); } - pub(crate) fn record_http3_retry_issued(&self, listener_id: &str) { + pub(crate) fn record_http3_request_accept_error(&self, listener_id: &str) { if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.http3_retry_issued_total.fetch_add(1, Ordering::Relaxed); + counters.http3_request_accept_errors_total.fetch_add(1, Ordering::Relaxed); } self.mark_traffic_snapshot_changed(true, false, Some(listener_id), None, None); } - pub(crate) fn record_http3_retry_failed(&self, listener_id: &str) { + pub(crate) fn record_http3_request_body_stream_error(&self, listener_id: &str) { if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.http3_retry_failed_total.fetch_add(1, Ordering::Relaxed); + counters.http3_request_body_stream_errors_total.fetch_add(1, Ordering::Relaxed); } self.mark_traffic_snapshot_changed(true, false, Some(listener_id), None, None); } - pub(crate) fn record_http3_request_accept_error(&self, listener_id: &str) { + pub(crate) fn record_http3_request_resolve_error(&self, listener_id: &str) { if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.http3_request_accept_errors_total.fetch_add(1, Ordering::Relaxed); + counters.http3_request_resolve_errors_total.fetch_add(1, Ordering::Relaxed); } self.mark_traffic_snapshot_changed(true, false, Some(listener_id), None, None); } - pub(crate) fn record_http3_request_resolve_error(&self, listener_id: &str) { + pub(crate) fn record_http3_response_stream_error(&self, listener_id: &str) { if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.http3_request_resolve_errors_total.fetch_add(1, Ordering::Relaxed); + counters.http3_response_stream_errors_total.fetch_add(1, Ordering::Relaxed); } self.mark_traffic_snapshot_changed(true, false, Some(listener_id), None, None); } - pub(crate) fn record_http3_request_body_stream_error(&self, listener_id: &str) { + pub(crate) fn record_http3_retry_failed(&self, listener_id: &str) { if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.http3_request_body_stream_errors_total.fetch_add(1, Ordering::Relaxed); + counters.http3_retry_failed_total.fetch_add(1, Ordering::Relaxed); } self.mark_traffic_snapshot_changed(true, false, Some(listener_id), None, None); } - pub(crate) fn record_http3_response_stream_error(&self, listener_id: &str) { + pub(crate) fn record_http3_retry_issued(&self, listener_id: &str) { if let Some(counters) = self.listener_traffic_counters(listener_id) { - counters.http3_response_stream_errors_total.fetch_add(1, Ordering::Relaxed); + counters.http3_retry_issued_total.fetch_add(1, Ordering::Relaxed); } self.mark_traffic_snapshot_changed(true, false, Some(listener_id), None, None); } - pub(crate) fn record_http3_connection_close( + pub(crate) fn record_mtls_handshake_success(&self, listener_id: &str, authenticated: bool) { + if !authenticated { + return; + } + + self.observability + .counters + .downstream_mtls_authenticated_connections + .fetch_add(1, Ordering::Relaxed); + if let Some(counters) = self.listener_traffic_counters(listener_id) { + counters.downstream_mtls_authenticated_connections.fetch_add(1, Ordering::Relaxed); + } + self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); + } + + pub(crate) fn record_tls_handshake_failure( &self, listener_id: &str, - reason: quinn::ConnectionError, + reason: TlsHandshakeFailureReason, ) { - if let Some(counters) = self.listener_traffic_counters(listener_id) { - match reason { - quinn::ConnectionError::VersionMismatch => { + self.observability + .counters + .downstream_tls_handshake_failures + .fetch_add(1, Ordering::Relaxed); + match reason { + TlsHandshakeFailureReason::MissingClientCert => { + self.observability + .counters + .downstream_tls_handshake_failures_missing_client_cert + .fetch_add(1, Ordering::Relaxed); + if let Some(counters) = self.listener_traffic_counters(listener_id) { counters - .http3_connection_close_version_mismatch_total + .downstream_tls_handshake_failures_missing_client_cert .fetch_add(1, Ordering::Relaxed); } - quinn::ConnectionError::TransportError(_) => { + } + TlsHandshakeFailureReason::UnknownCa => { + self.observability + .counters + .downstream_tls_handshake_failures_unknown_ca + .fetch_add(1, Ordering::Relaxed); + if let Some(counters) = self.listener_traffic_counters(listener_id) { counters - .http3_connection_close_transport_error_total + .downstream_tls_handshake_failures_unknown_ca .fetch_add(1, Ordering::Relaxed); } - quinn::ConnectionError::ConnectionClosed(_) => { + } + TlsHandshakeFailureReason::BadCertificate => { + self.observability + .counters + .downstream_tls_handshake_failures_bad_certificate + .fetch_add(1, Ordering::Relaxed); + if let Some(counters) = self.listener_traffic_counters(listener_id) { counters - .http3_connection_close_connection_closed_total + .downstream_tls_handshake_failures_bad_certificate .fetch_add(1, Ordering::Relaxed); } - quinn::ConnectionError::ApplicationClosed(_) => { + } + TlsHandshakeFailureReason::CertificateRevoked => { + self.observability + .counters + .downstream_tls_handshake_failures_certificate_revoked + .fetch_add(1, Ordering::Relaxed); + if let Some(counters) = self.listener_traffic_counters(listener_id) { counters - .http3_connection_close_application_closed_total + .downstream_tls_handshake_failures_certificate_revoked .fetch_add(1, Ordering::Relaxed); } - quinn::ConnectionError::Reset => { - counters.http3_connection_close_reset_total.fetch_add(1, Ordering::Relaxed); - } - quinn::ConnectionError::TimedOut => { - counters.http3_connection_close_timed_out_total.fetch_add(1, Ordering::Relaxed); - } - quinn::ConnectionError::LocallyClosed => { + } + TlsHandshakeFailureReason::VerifyDepthExceeded => { + self.observability + .counters + .downstream_tls_handshake_failures_verify_depth_exceeded + .fetch_add(1, Ordering::Relaxed); + if let Some(counters) = self.listener_traffic_counters(listener_id) { counters - .http3_connection_close_locally_closed_total + .downstream_tls_handshake_failures_verify_depth_exceeded .fetch_add(1, Ordering::Relaxed); } - quinn::ConnectionError::CidsExhausted => { + } + TlsHandshakeFailureReason::Other => { + self.observability + .counters + .downstream_tls_handshake_failures_other + .fetch_add(1, Ordering::Relaxed); + if let Some(counters) = self.listener_traffic_counters(listener_id) { counters - .http3_connection_close_cids_exhausted_total + .downstream_tls_handshake_failures_other .fetch_add(1, Ordering::Relaxed); } } } - self.mark_traffic_snapshot_changed(true, false, Some(listener_id), None, None); + if let Some(counters) = self.listener_traffic_counters(listener_id) { + counters.downstream_tls_handshake_failures.fetch_add(1, Ordering::Relaxed); + } + self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); } } diff --git a/crates/rginx-http/src/state/counters/grpc.rs b/crates/rginx-http/src/state/counters/grpc.rs index b4becc2c..dd0e28ea 100644 --- a/crates/rginx-http/src/state/counters/grpc.rs +++ b/crates/rginx-http/src/state/counters/grpc.rs @@ -1,19 +1,19 @@ -use super::*; +use super::{AtomicU64, GrpcTrafficSnapshot, Ordering}; #[derive(Debug, Default)] pub(in crate::state) struct GrpcTrafficCounters { - requests_total: AtomicU64, protocol_grpc_total: AtomicU64, - protocol_grpc_web_total: AtomicU64, protocol_grpc_web_text_total: AtomicU64, + protocol_grpc_web_total: AtomicU64, + requests_total: AtomicU64, status_0_total: AtomicU64, + status_12_total: AtomicU64, + status_14_total: AtomicU64, status_1_total: AtomicU64, status_3_total: AtomicU64, status_4_total: AtomicU64, status_7_total: AtomicU64, status_8_total: AtomicU64, - status_12_total: AtomicU64, - status_14_total: AtomicU64, status_other_total: AtomicU64, } diff --git a/crates/rginx-http/src/state/counters/http.rs b/crates/rginx-http/src/state/counters/http.rs index 66cd63d9..e93bb89a 100644 --- a/crates/rginx-http/src/state/counters/http.rs +++ b/crates/rginx-http/src/state/counters/http.rs @@ -1,9 +1,17 @@ -use super::*; +use super::{ + ApplyResultSnapshot, ApplyStatusSnapshot, AtomicU64, HttpCountersSnapshot, Ordering, + ReloadResultSnapshot, ReloadStatusSnapshot, +}; #[derive(Debug, Default)] pub(in crate::state) struct HttpCounters { pub(in crate::state) downstream_connections_accepted: AtomicU64, pub(in crate::state) downstream_connections_rejected: AtomicU64, + pub(in crate::state) downstream_http3_early_data_accepted_requests: AtomicU64, + pub(in crate::state) downstream_http3_early_data_rejected_requests: AtomicU64, + pub(in crate::state) downstream_mtls_anonymous_requests: AtomicU64, + pub(in crate::state) downstream_mtls_authenticated_connections: AtomicU64, + pub(in crate::state) downstream_mtls_authenticated_requests: AtomicU64, pub(in crate::state) downstream_requests: AtomicU64, pub(in crate::state) downstream_responses: AtomicU64, pub(in crate::state) downstream_responses_1xx: AtomicU64, @@ -11,44 +19,39 @@ pub(in crate::state) struct HttpCounters { pub(in crate::state) downstream_responses_3xx: AtomicU64, pub(in crate::state) downstream_responses_4xx: AtomicU64, pub(in crate::state) downstream_responses_5xx: AtomicU64, - pub(in crate::state) downstream_mtls_authenticated_connections: AtomicU64, - pub(in crate::state) downstream_mtls_authenticated_requests: AtomicU64, - pub(in crate::state) downstream_mtls_anonymous_requests: AtomicU64, pub(in crate::state) downstream_tls_handshake_failures: AtomicU64, - pub(in crate::state) downstream_tls_handshake_failures_missing_client_cert: AtomicU64, - pub(in crate::state) downstream_tls_handshake_failures_unknown_ca: AtomicU64, pub(in crate::state) downstream_tls_handshake_failures_bad_certificate: AtomicU64, pub(in crate::state) downstream_tls_handshake_failures_certificate_revoked: AtomicU64, - pub(in crate::state) downstream_tls_handshake_failures_verify_depth_exceeded: AtomicU64, + pub(in crate::state) downstream_tls_handshake_failures_missing_client_cert: AtomicU64, pub(in crate::state) downstream_tls_handshake_failures_other: AtomicU64, - pub(in crate::state) downstream_http3_early_data_accepted_requests: AtomicU64, - pub(in crate::state) downstream_http3_early_data_rejected_requests: AtomicU64, + pub(in crate::state) downstream_tls_handshake_failures_unknown_ca: AtomicU64, + pub(in crate::state) downstream_tls_handshake_failures_verify_depth_exceeded: AtomicU64, } #[derive(Debug, Default)] pub(in crate::state) struct ReloadHistory { pub(in crate::state) attempts_total: u64, - pub(in crate::state) successes_total: u64, pub(in crate::state) failures_total: u64, pub(in crate::state) last_result: Option, + pub(in crate::state) successes_total: u64, } #[derive(Debug, Default)] pub(in crate::state) struct ApplyHistory { pub(in crate::state) attempts_total: u64, - pub(in crate::state) successes_total: u64, pub(in crate::state) failures_total: u64, pub(in crate::state) last_result: Option, + pub(in crate::state) successes_total: u64, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub(crate) enum TlsHandshakeFailureReason { - MissingClientCert, - UnknownCa, BadCertificate, CertificateRevoked, - VerifyDepthExceeded, + MissingClientCert, Other, + UnknownCa, + VerifyDepthExceeded, } impl TlsHandshakeFailureReason { diff --git a/crates/rginx-http/src/state/counters/mod.rs b/crates/rginx-http/src/state/counters/mod.rs index 94f9403d..c5847c54 100644 --- a/crates/rginx-http/src/state/counters/mod.rs +++ b/crates/rginx-http/src/state/counters/mod.rs @@ -1,5 +1,3 @@ -use super::*; - mod grpc; pub(in crate::state) mod http; mod rolling; @@ -7,6 +5,13 @@ mod traffic; mod upstreams; mod versions; +use super::{ + ApplyResultSnapshot, ApplyStatusSnapshot, Arc, AtomicU64, AtomicUsize, ConfigSnapshot, + GrpcTrafficSnapshot, HashMap, HttpCountersSnapshot, MAX_RECENT_WINDOW_SECS, Mutex, Ordering, + RECENT_WINDOW_SECS, RecentTrafficStatsSnapshot, RecentUpstreamStatsSnapshot, + ReloadResultSnapshot, ReloadStatusSnapshot, StatusCode, SystemTime, UNIX_EPOCH, VecDeque, +}; + use self::grpc::GrpcTrafficCounters; pub(super) use self::http::{ApplyHistory, HttpCounters, ReloadHistory}; use self::rolling::{RecentTrafficStatsCounters, RecentUpstreamStatsCounters}; diff --git a/crates/rginx-http/src/state/counters/rolling.rs b/crates/rginx-http/src/state/counters/rolling.rs index d5b018ab..057a4111 100644 --- a/crates/rginx-http/src/state/counters/rolling.rs +++ b/crates/rginx-http/src/state/counters/rolling.rs @@ -1,4 +1,7 @@ -use super::*; +use super::{ + MAX_RECENT_WINDOW_SECS, Mutex, RECENT_WINDOW_SECS, RecentTrafficStatsSnapshot, + RecentUpstreamStatsSnapshot, StatusCode, SystemTime, UNIX_EPOCH, VecDeque, +}; #[derive(Debug, Default)] struct RollingCounter { @@ -7,52 +10,48 @@ struct RollingCounter { #[derive(Debug, Default)] pub(in crate::state) struct RecentTrafficStatsCounters { - downstream_requests_total: RollingCounter, downstream_request_bytes_total: RollingCounter, - downstream_responses_total: RollingCounter, + downstream_requests_total: RollingCounter, downstream_response_bytes_total: RollingCounter, downstream_responses_2xx_total: RollingCounter, downstream_responses_4xx_total: RollingCounter, downstream_responses_5xx_total: RollingCounter, + downstream_responses_total: RollingCounter, grpc_requests_total: RollingCounter, } #[derive(Debug, Default)] pub(in crate::state) struct RecentUpstreamStatsCounters { - downstream_requests_total: RollingCounter, - downstream_request_bytes_total: RollingCounter, - peer_attempts_total: RollingCounter, + bad_gateway_responses_total: RollingCounter, completed_responses_total: RollingCounter, + downstream_request_bytes_total: RollingCounter, + downstream_requests_total: RollingCounter, downstream_response_bytes_total: RollingCounter, - bad_gateway_responses_total: RollingCounter, - gateway_timeout_responses_total: RollingCounter, failovers_total: RollingCounter, + gateway_timeout_responses_total: RollingCounter, + peer_attempts_total: RollingCounter, } impl RollingCounter { - fn add_now(&self, value: u64) { - self.add_at(window_now_secs(), value); - } - fn add_at(&self, second: u64, value: u64) { if value == 0 { return; } - let mut buckets = self.buckets.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let mut buckets = self.buckets.lock().unwrap_or_else(std::sync::PoisonError::into_inner); trim_old_buckets(&mut buckets, second, MAX_RECENT_WINDOW_SECS); match buckets.back_mut() { Some((bucket_second, count)) if *bucket_second == second => { - *count += value; + *count = count.saturating_add(value); } Some((bucket_second, count)) if *bucket_second > second => { - *count += value; + *count = count.saturating_add(value); } _ => buckets.push_back((second, value)), } } - fn increment_now(&self) { - self.add_now(1); + fn add_now(&self, value: u64) { + self.add_at(window_now_secs(), value); } #[cfg(test)] @@ -60,8 +59,12 @@ impl RollingCounter { self.add_at(second, 1); } + fn increment_now(&self) { + self.add_now(1); + } + fn sum_recent(&self, now_second: u64, window_secs: u64) -> u64 { - let mut buckets = self.buckets.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let mut buckets = self.buckets.lock().unwrap_or_else(std::sync::PoisonError::into_inner); trim_old_buckets(&mut buckets, now_second, MAX_RECENT_WINDOW_SECS); let cutoff = now_second.saturating_sub(window_secs.saturating_sub(1)); buckets.iter().filter_map(|(second, count)| (*second >= cutoff).then_some(*count)).sum() @@ -69,6 +72,29 @@ impl RollingCounter { } impl RecentTrafficStatsCounters { + pub(in crate::state) fn record_downstream_request(&self, request_bytes: u64) { + self.downstream_requests_total.increment_now(); + self.downstream_request_bytes_total.add_now(request_bytes); + } + + pub(in crate::state) fn record_downstream_response( + &self, + status: StatusCode, + response_bytes: u64, + ) { + self.downstream_responses_total.increment_now(); + self.downstream_response_bytes_total.add_now(response_bytes); + match status.as_u16() / 100 { + 2 => self.downstream_responses_2xx_total.increment_now(), + 4 => self.downstream_responses_4xx_total.increment_now(), + 5 => self.downstream_responses_5xx_total.increment_now(), + _ => {} + } + } + + pub(in crate::state) fn record_grpc_request(&self) { + self.grpc_requests_total.increment_now(); + } pub(in crate::state) fn snapshot(&self) -> RecentTrafficStatsSnapshot { self.snapshot_for_window(RECENT_WINDOW_SECS) } @@ -102,33 +128,38 @@ impl RecentTrafficStatsCounters { grpc_requests_total: self.grpc_requests_total.sum_recent(now, window_secs), } } +} + +impl RecentUpstreamStatsCounters { + pub(in crate::state) fn record_bad_gateway_response(&self) { + self.bad_gateway_responses_total.increment_now(); + } + + pub(in crate::state) fn record_completed_response(&self) { + self.completed_responses_total.increment_now(); + } pub(in crate::state) fn record_downstream_request(&self, request_bytes: u64) { self.downstream_requests_total.increment_now(); self.downstream_request_bytes_total.add_now(request_bytes); } - pub(in crate::state) fn record_downstream_response( - &self, - status: StatusCode, - response_bytes: u64, - ) { - self.downstream_responses_total.increment_now(); + pub(in crate::state) fn record_downstream_response_bytes(&self, response_bytes: u64) { self.downstream_response_bytes_total.add_now(response_bytes); - match status.as_u16() / 100 { - 2 => self.downstream_responses_2xx_total.increment_now(), - 4 => self.downstream_responses_4xx_total.increment_now(), - 5 => self.downstream_responses_5xx_total.increment_now(), - _ => {} - } } - pub(in crate::state) fn record_grpc_request(&self) { - self.grpc_requests_total.increment_now(); + pub(in crate::state) fn record_failover(&self) { + self.failovers_total.increment_now(); + } + + pub(in crate::state) fn record_gateway_timeout_response(&self) { + self.gateway_timeout_responses_total.increment_now(); + } + + pub(in crate::state) fn record_peer_attempt(&self) { + self.peer_attempts_total.increment_now(); } -} -impl RecentUpstreamStatsCounters { pub(in crate::state) fn snapshot(&self) -> RecentUpstreamStatsSnapshot { self.snapshot_for_window(RECENT_WINDOW_SECS) } @@ -158,39 +189,10 @@ impl RecentUpstreamStatsCounters { failovers_total: self.failovers_total.sum_recent(now, window_secs), } } - - pub(in crate::state) fn record_downstream_request(&self, request_bytes: u64) { - self.downstream_requests_total.increment_now(); - self.downstream_request_bytes_total.add_now(request_bytes); - } - - pub(in crate::state) fn record_peer_attempt(&self) { - self.peer_attempts_total.increment_now(); - } - - pub(in crate::state) fn record_completed_response(&self) { - self.completed_responses_total.increment_now(); - } - - pub(in crate::state) fn record_downstream_response_bytes(&self, response_bytes: u64) { - self.downstream_response_bytes_total.add_now(response_bytes); - } - - pub(in crate::state) fn record_bad_gateway_response(&self) { - self.bad_gateway_responses_total.increment_now(); - } - - pub(in crate::state) fn record_gateway_timeout_response(&self) { - self.gateway_timeout_responses_total.increment_now(); - } - - pub(in crate::state) fn record_failover(&self) { - self.failovers_total.increment_now(); - } } fn window_now_secs() -> u64 { - SystemTime::now().duration_since(UNIX_EPOCH).map(|duration| duration.as_secs()).unwrap_or(0) + SystemTime::now().duration_since(UNIX_EPOCH).map_or(0, |duration| duration.as_secs()) } fn trim_old_buckets(buckets: &mut VecDeque<(u64, u64)>, now_second: u64, window_secs: u64) { diff --git a/crates/rginx-http/src/state/counters/traffic.rs b/crates/rginx-http/src/state/counters/traffic.rs index 3dab8a60..583e1779 100644 --- a/crates/rginx-http/src/state/counters/traffic.rs +++ b/crates/rginx-http/src/state/counters/traffic.rs @@ -1,4 +1,7 @@ -use super::*; +use super::{ + Arc, AtomicU64, AtomicUsize, ConfigSnapshot, GrpcTrafficCounters, HashMap, Ordering, + RecentTrafficStatsCounters, StatusCode, +}; #[derive(Debug, Default)] pub(in crate::state) struct ResponseCounters { @@ -12,103 +15,103 @@ pub(in crate::state) struct ResponseCounters { #[derive(Debug, Default)] pub(in crate::state) struct RequestTrafficCounters { - pub(in crate::state) downstream_requests: AtomicU64, pub(in crate::state) downstream_request_bytes_total: AtomicU64, - pub(in crate::state) unmatched_requests_total: AtomicU64, + pub(in crate::state) downstream_requests: AtomicU64, pub(in crate::state) downstream_response_bytes_total: AtomicU64, - pub(in crate::state) responses: ResponseCounters, - pub(in crate::state) recent_60s: RecentTrafficStatsCounters, pub(in crate::state) grpc: GrpcTrafficCounters, + pub(in crate::state) recent_60s: RecentTrafficStatsCounters, + pub(in crate::state) responses: ResponseCounters, + pub(in crate::state) unmatched_requests_total: AtomicU64, } #[derive(Debug, Default)] pub(in crate::state) struct ListenerTrafficCounters { - pub(in crate::state) downstream_connections_accepted: AtomicU64, - pub(in crate::state) downstream_connections_rejected: AtomicU64, pub(in crate::state) active_http3_connections: AtomicUsize, pub(in crate::state) active_http3_request_streams: AtomicUsize, - pub(in crate::state) http3_retry_issued_total: AtomicU64, - pub(in crate::state) http3_retry_failed_total: AtomicU64, - pub(in crate::state) http3_request_accept_errors_total: AtomicU64, - pub(in crate::state) http3_request_resolve_errors_total: AtomicU64, - pub(in crate::state) http3_request_body_stream_errors_total: AtomicU64, - pub(in crate::state) http3_response_stream_errors_total: AtomicU64, - pub(in crate::state) http3_connection_close_version_mismatch_total: AtomicU64, - pub(in crate::state) http3_connection_close_transport_error_total: AtomicU64, - pub(in crate::state) http3_connection_close_connection_closed_total: AtomicU64, - pub(in crate::state) http3_connection_close_application_closed_total: AtomicU64, - pub(in crate::state) http3_connection_close_reset_total: AtomicU64, - pub(in crate::state) http3_connection_close_timed_out_total: AtomicU64, - pub(in crate::state) http3_connection_close_locally_closed_total: AtomicU64, - pub(in crate::state) http3_connection_close_cids_exhausted_total: AtomicU64, + pub(in crate::state) downstream_connections_accepted: AtomicU64, + pub(in crate::state) downstream_connections_rejected: AtomicU64, + pub(in crate::state) downstream_http3_early_data_accepted_requests: AtomicU64, + pub(in crate::state) downstream_http3_early_data_rejected_requests: AtomicU64, + pub(in crate::state) downstream_mtls_anonymous_requests: AtomicU64, pub(in crate::state) downstream_mtls_authenticated_connections: AtomicU64, + pub(in crate::state) downstream_mtls_authenticated_requests: AtomicU64, + pub(in crate::state) downstream_request_bytes_total: AtomicU64, + pub(in crate::state) downstream_requests: AtomicU64, + pub(in crate::state) downstream_response_bytes_total: AtomicU64, pub(in crate::state) downstream_tls_handshake_failures: AtomicU64, - pub(in crate::state) downstream_tls_handshake_failures_missing_client_cert: AtomicU64, - pub(in crate::state) downstream_tls_handshake_failures_unknown_ca: AtomicU64, pub(in crate::state) downstream_tls_handshake_failures_bad_certificate: AtomicU64, pub(in crate::state) downstream_tls_handshake_failures_certificate_revoked: AtomicU64, - pub(in crate::state) downstream_tls_handshake_failures_verify_depth_exceeded: AtomicU64, + pub(in crate::state) downstream_tls_handshake_failures_missing_client_cert: AtomicU64, pub(in crate::state) downstream_tls_handshake_failures_other: AtomicU64, - pub(in crate::state) downstream_http3_early_data_accepted_requests: AtomicU64, - pub(in crate::state) downstream_http3_early_data_rejected_requests: AtomicU64, - pub(in crate::state) downstream_requests: AtomicU64, - pub(in crate::state) downstream_request_bytes_total: AtomicU64, - pub(in crate::state) downstream_mtls_authenticated_requests: AtomicU64, - pub(in crate::state) downstream_mtls_anonymous_requests: AtomicU64, - pub(in crate::state) unmatched_requests_total: AtomicU64, - pub(in crate::state) downstream_response_bytes_total: AtomicU64, - pub(in crate::state) responses: ResponseCounters, - pub(in crate::state) recent_60s: RecentTrafficStatsCounters, + pub(in crate::state) downstream_tls_handshake_failures_unknown_ca: AtomicU64, + pub(in crate::state) downstream_tls_handshake_failures_verify_depth_exceeded: AtomicU64, pub(in crate::state) grpc: GrpcTrafficCounters, + pub(in crate::state) http3_connection_close_application_closed_total: AtomicU64, + pub(in crate::state) http3_connection_close_cids_exhausted_total: AtomicU64, + pub(in crate::state) http3_connection_close_connection_closed_total: AtomicU64, + pub(in crate::state) http3_connection_close_locally_closed_total: AtomicU64, + pub(in crate::state) http3_connection_close_reset_total: AtomicU64, + pub(in crate::state) http3_connection_close_timed_out_total: AtomicU64, + pub(in crate::state) http3_connection_close_transport_error_total: AtomicU64, + pub(in crate::state) http3_connection_close_version_mismatch_total: AtomicU64, + pub(in crate::state) http3_request_accept_errors_total: AtomicU64, + pub(in crate::state) http3_request_body_stream_errors_total: AtomicU64, + pub(in crate::state) http3_request_resolve_errors_total: AtomicU64, + pub(in crate::state) http3_response_stream_errors_total: AtomicU64, + pub(in crate::state) http3_retry_failed_total: AtomicU64, + pub(in crate::state) http3_retry_issued_total: AtomicU64, + pub(in crate::state) recent_60s: RecentTrafficStatsCounters, + pub(in crate::state) responses: ResponseCounters, + pub(in crate::state) unmatched_requests_total: AtomicU64, } #[derive(Debug, Default)] pub(in crate::state) struct RouteTrafficCounters { - pub(in crate::state) downstream_requests: AtomicU64, + pub(in crate::state) access_denied_total: AtomicU64, pub(in crate::state) downstream_request_bytes_total: AtomicU64, + pub(in crate::state) downstream_requests: AtomicU64, pub(in crate::state) downstream_response_bytes_total: AtomicU64, - pub(in crate::state) responses: ResponseCounters, - pub(in crate::state) access_denied_total: AtomicU64, + pub(in crate::state) grpc: GrpcTrafficCounters, pub(in crate::state) rate_limited_total: AtomicU64, pub(in crate::state) recent_60s: RecentTrafficStatsCounters, - pub(in crate::state) grpc: GrpcTrafficCounters, + pub(in crate::state) responses: ResponseCounters, } #[derive(Debug)] pub(in crate::state) struct ListenerTrafficEntry { - pub(in crate::state) listener_name: String, - pub(in crate::state) listen_addr: std::net::SocketAddr, - pub(in crate::state) http3_enabled: bool, pub(in crate::state) counters: Arc, + pub(in crate::state) http3_enabled: bool, + pub(in crate::state) listen_addr: std::net::SocketAddr, + pub(in crate::state) listener_name: String, } #[derive(Debug)] pub(in crate::state) struct VhostTrafficEntry { - pub(in crate::state) server_names: Vec, pub(in crate::state) counters: Arc, + pub(in crate::state) server_names: Vec, } #[derive(Debug)] pub(in crate::state) struct RouteTrafficEntry { - pub(in crate::state) vhost_id: String, pub(in crate::state) counters: Arc, + pub(in crate::state) vhost_id: String, } #[derive(Debug, Default)] pub(in crate::state) struct TrafficStatsIndex { - pub(in crate::state) listeners: HashMap, pub(in crate::state) listener_order: Vec, - pub(in crate::state) vhosts: HashMap, - pub(in crate::state) vhost_order: Vec, - pub(in crate::state) routes: HashMap, + pub(in crate::state) listeners: HashMap, pub(in crate::state) route_order: Vec, + pub(in crate::state) routes: HashMap, + pub(in crate::state) vhost_order: Vec, + pub(in crate::state) vhosts: HashMap, } #[derive(Debug, Default)] pub(in crate::state) struct TrafficComponentVersions { pub(in crate::state) listeners: HashMap, - pub(in crate::state) vhosts: HashMap, pub(in crate::state) routes: HashMap, + pub(in crate::state) vhosts: HashMap, } impl ResponseCounters { @@ -161,9 +164,10 @@ pub(in crate::state) fn build_traffic_stats_index( listener_name: listener.name.clone(), listen_addr: listener.server.listen_addr, http3_enabled: listener.http3.is_some(), - counters: current - .map(|entry| entry.counters.clone()) - .unwrap_or_else(|| Arc::new(ListenerTrafficCounters::default())), + counters: current.map_or_else( + || Arc::new(ListenerTrafficCounters::default()), + |entry| entry.counters.clone(), + ), }, ); } @@ -175,9 +179,10 @@ pub(in crate::state) fn build_traffic_stats_index( vhost.id.clone(), VhostTrafficEntry { server_names: vhost.server_names.clone(), - counters: current - .map(|entry| entry.counters.clone()) - .unwrap_or_else(|| Arc::new(RequestTrafficCounters::default())), + counters: current.map_or_else( + || Arc::new(RequestTrafficCounters::default()), + |entry| entry.counters.clone(), + ), }, ); @@ -188,9 +193,10 @@ pub(in crate::state) fn build_traffic_stats_index( route.id.clone(), RouteTrafficEntry { vhost_id: vhost.id.clone(), - counters: current - .map(|entry| entry.counters.clone()) - .unwrap_or_else(|| Arc::new(RouteTrafficCounters::default())), + counters: current.map_or_else( + || Arc::new(RouteTrafficCounters::default()), + |entry| entry.counters.clone(), + ), }, ); } diff --git a/crates/rginx-http/src/state/counters/upstreams.rs b/crates/rginx-http/src/state/counters/upstreams.rs index 6f0736d3..72b3f983 100644 --- a/crates/rginx-http/src/state/counters/upstreams.rs +++ b/crates/rginx-http/src/state/counters/upstreams.rs @@ -1,43 +1,43 @@ -use super::*; +use super::{Arc, AtomicU64, ConfigSnapshot, HashMap, RecentUpstreamStatsCounters}; #[derive(Debug, Default)] pub(in crate::state) struct UpstreamStats { - pub(in crate::state) downstream_requests_total: AtomicU64, - pub(in crate::state) downstream_request_bytes_total: AtomicU64, - pub(in crate::state) peer_attempts_total: AtomicU64, - pub(in crate::state) peer_successes_total: AtomicU64, - pub(in crate::state) peer_failures_total: AtomicU64, - pub(in crate::state) peer_timeouts_total: AtomicU64, - pub(in crate::state) failovers_total: AtomicU64, + pub(in crate::state) bad_gateway_responses_total: AtomicU64, + pub(in crate::state) bad_request_responses_total: AtomicU64, pub(in crate::state) completed_responses_total: AtomicU64, + pub(in crate::state) downstream_request_bytes_total: AtomicU64, + pub(in crate::state) downstream_requests_total: AtomicU64, pub(in crate::state) downstream_response_bytes_total: AtomicU64, - pub(in crate::state) bad_gateway_responses_total: AtomicU64, + pub(in crate::state) failovers_total: AtomicU64, pub(in crate::state) gateway_timeout_responses_total: AtomicU64, - pub(in crate::state) bad_request_responses_total: AtomicU64, - pub(in crate::state) payload_too_large_responses_total: AtomicU64, - pub(in crate::state) unsupported_media_type_responses_total: AtomicU64, pub(in crate::state) no_healthy_peers_total: AtomicU64, - pub(in crate::state) tls_failures_unknown_ca_total: AtomicU64, + pub(in crate::state) payload_too_large_responses_total: AtomicU64, + pub(in crate::state) peer_attempts_total: AtomicU64, + pub(in crate::state) peer_failures_total: AtomicU64, + pub(in crate::state) peer_successes_total: AtomicU64, + pub(in crate::state) peer_timeouts_total: AtomicU64, + pub(in crate::state) recent_60s: RecentUpstreamStatsCounters, pub(in crate::state) tls_failures_bad_certificate_total: AtomicU64, pub(in crate::state) tls_failures_certificate_revoked_total: AtomicU64, + pub(in crate::state) tls_failures_unknown_ca_total: AtomicU64, pub(in crate::state) tls_failures_verify_depth_exceeded_total: AtomicU64, - pub(in crate::state) recent_60s: RecentUpstreamStatsCounters, + pub(in crate::state) unsupported_media_type_responses_total: AtomicU64, } #[derive(Debug, Default)] pub(in crate::state) struct UpstreamPeerStats { pub(in crate::state) attempts_total: AtomicU64, - pub(in crate::state) successes_total: AtomicU64, pub(in crate::state) failures_total: AtomicU64, + pub(in crate::state) successes_total: AtomicU64, pub(in crate::state) timeouts_total: AtomicU64, } #[derive(Debug)] pub(in crate::state) struct UpstreamStatsEntry { - pub(in crate::state) upstream: Arc, pub(in crate::state) counters: Arc, - pub(in crate::state) peers: HashMap>, pub(in crate::state) peer_order: Vec, + pub(in crate::state) peers: HashMap>, + pub(in crate::state) upstream: Arc, } pub(in crate::state) fn build_upstream_stats_map( @@ -65,9 +65,10 @@ pub(in crate::state) fn build_upstream_stats_map( upstream.name.clone(), UpstreamStatsEntry { upstream: upstream.clone(), - counters: current - .map(|entry| entry.counters.clone()) - .unwrap_or_else(|| Arc::new(UpstreamStats::default())), + counters: current.map_or_else( + || Arc::new(UpstreamStats::default()), + |entry| entry.counters.clone(), + ), peers, peer_order: upstream.peers.iter().map(|peer| peer.url.clone()).collect(), }, diff --git a/crates/rginx-http/src/state/counters/versions.rs b/crates/rginx-http/src/state/counters/versions.rs index 790d85c3..8fba0391 100644 --- a/crates/rginx-http/src/state/counters/versions.rs +++ b/crates/rginx-http/src/state/counters/versions.rs @@ -1,11 +1,11 @@ -use super::*; +use super::AtomicU64; #[derive(Debug, Default)] pub(in crate::state) struct SnapshotComponentVersions { - pub(in crate::state) status: AtomicU64, + pub(in crate::state) cache: AtomicU64, pub(in crate::state) counters: AtomicU64, - pub(in crate::state) traffic: AtomicU64, pub(in crate::state) peer_health: AtomicU64, + pub(in crate::state) status: AtomicU64, + pub(in crate::state) traffic: AtomicU64, pub(in crate::state) upstreams: AtomicU64, - pub(in crate::state) cache: AtomicU64, } diff --git a/crates/rginx-http/src/state/helpers/mod.rs b/crates/rginx-http/src/state/helpers/mod.rs index c9488539..8a20b951 100644 --- a/crates/rginx-http/src/state/helpers/mod.rs +++ b/crates/rginx-http/src/state/helpers/mod.rs @@ -1,4 +1,9 @@ -use super::*; +use super::{ + Arc, CacheChangeNotifier, CacheManager, ConfigSnapshot, HashMap, HealthChangeNotifier, + Http3ListenerRuntimeSnapshot, JoinHandle, ListenerTrafficCounters, Mutex, Ordering, + PreparedState, ProxyClients, Result, SnapshotBusState, StdRwLock, SystemTime, TlsAcceptor, + UNIX_EPOCH, build_tls_acceptor, +}; pub(super) fn prepare_state( config: ConfigSnapshot, @@ -47,7 +52,7 @@ pub(super) fn build_peer_health_notifier( let version = snapshot_bus.mark_changed(false, false, false, true, false, false); peer_health_component_versions .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .insert(upstream_name.to_string(), version); snapshot_bus.notify_waiters(); }) @@ -61,7 +66,7 @@ pub(super) fn build_cache_notifier( let version = snapshot_bus.mark_changed(true, false, false, false, false, true); cache_component_versions .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .insert(zone_name.to_string(), version); snapshot_bus.notify_waiters(); }) @@ -69,8 +74,7 @@ pub(super) fn build_cache_notifier( pub(super) fn unix_time_ms(time: SystemTime) -> u64 { time.duration_since(UNIX_EPOCH) - .map(|duration| duration.as_millis().min(u128::from(u64::MAX)) as u64) - .unwrap_or(0) + .map_or(0, |duration| duration.as_millis().min(u128::from(u64::MAX)) as u64) } pub(super) fn http3_runtime_snapshot( @@ -128,6 +132,6 @@ pub(super) fn http3_runtime_snapshot( pub(super) fn take_background_tasks( background_tasks: &Arc>>>, ) -> Vec> { - let mut tasks = background_tasks.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); + let mut tasks = background_tasks.lock().unwrap_or_else(std::sync::PoisonError::into_inner); std::mem::take(&mut *tasks) } diff --git a/crates/rginx-http/src/state/lifecycle/acme.rs b/crates/rginx-http/src/state/lifecycle/acme.rs index db130875..ff24a370 100644 --- a/crates/rginx-http/src/state/lifecycle/acme.rs +++ b/crates/rginx-http/src/state/lifecycle/acme.rs @@ -1,47 +1,17 @@ -use super::super::*; +use super::super::{ConfigSnapshot, SharedState, unix_time_ms}; use std::time::{Duration, SystemTime}; impl SharedState { - pub fn register_acme_http01_challenge( - &self, - token: impl Into, - key_authorization: impl Into, - ) { - self.lifecycle - .acme_http01_challenges - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .insert(token.into(), key_authorization.into()); - } - - pub fn unregister_acme_http01_challenge(&self, token: &str) { - self.lifecycle - .acme_http01_challenges - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .remove(token); - } - + #[must_use] pub fn acme_http01_response(&self, token: &str) -> Option { self.lifecycle .acme_http01_challenges .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .get(token) .cloned() } - pub fn record_acme_refresh_success(&self, scope: &str) { - let mut statuses = - self.lifecycle.acme_statuses.write().unwrap_or_else(|poisoned| poisoned.into_inner()); - let entry = statuses.entry(scope.to_string()).or_default(); - entry.last_success_unix_ms = Some(unix_time_ms(SystemTime::now())); - entry.refreshes_total += 1; - entry.retry_after_unix_ms = None; - entry.last_error = None; - self.mark_status_snapshot_changed(); - } - pub fn record_acme_refresh_failure( &self, scope: &str, @@ -49,15 +19,38 @@ impl SharedState { retry_after: Option, ) { let mut statuses = - self.lifecycle.acme_statuses.write().unwrap_or_else(|poisoned| poisoned.into_inner()); + self.lifecycle.acme_statuses.write().unwrap_or_else(std::sync::PoisonError::into_inner); let entry = statuses.entry(scope.to_string()).or_default(); - entry.failures_total += 1; + entry.failures_total = entry.failures_total.saturating_add(1); entry.retry_after_unix_ms = retry_after.and_then(|delay| SystemTime::now().checked_add(delay).map(unix_time_ms)); entry.last_error = Some(error.into()); self.mark_status_snapshot_changed(); } + pub fn record_acme_refresh_success(&self, scope: &str) { + let mut statuses = + self.lifecycle.acme_statuses.write().unwrap_or_else(std::sync::PoisonError::into_inner); + let entry = statuses.entry(scope.to_string()).or_default(); + entry.last_success_unix_ms = Some(unix_time_ms(SystemTime::now())); + entry.refreshes_total = entry.refreshes_total.saturating_add(1); + entry.retry_after_unix_ms = None; + entry.last_error = None; + self.mark_status_snapshot_changed(); + } + + pub fn register_acme_http01_challenge( + &self, + token: impl Into, + key_authorization: impl Into, + ) { + self.lifecycle + .acme_http01_challenges + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .insert(token.into(), key_authorization.into()); + } + pub(super) fn sync_acme_statuses(&self, config: &ConfigSnapshot) { let managed_scopes = config .managed_certificates @@ -67,7 +60,15 @@ impl SharedState { self.lifecycle .acme_statuses .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .retain(|scope, _| managed_scopes.contains(scope.as_str())); } + + pub fn unregister_acme_http01_challenge(&self, token: &str) { + self.lifecycle + .acme_http01_challenges + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .remove(token); + } } diff --git a/crates/rginx-http/src/state/lifecycle/mtls.rs b/crates/rginx-http/src/state/lifecycle/mtls.rs index d73335ae..d71f702e 100644 --- a/crates/rginx-http/src/state/lifecycle/mtls.rs +++ b/crates/rginx-http/src/state/lifecycle/mtls.rs @@ -1,4 +1,4 @@ -use super::super::*; +use super::super::{ConfigSnapshot, MtlsStatusSnapshot, Ordering, SharedState}; impl SharedState { pub(super) fn mtls_status_snapshot(&self, config: &ConfigSnapshot) -> MtlsStatusSnapshot { @@ -22,54 +22,76 @@ impl SharedState { else { continue; }; - configured_listeners += 1; + configured_listeners = configured_listeners.saturating_add(1); match client_auth.mode { - rginx_core::ServerClientAuthMode::Optional => optional_listeners += 1, - rginx_core::ServerClientAuthMode::Required => required_listeners += 1, + rginx_core::ServerClientAuthMode::Optional => { + optional_listeners = optional_listeners.saturating_add(1); + } + rginx_core::ServerClientAuthMode::Required => { + required_listeners = required_listeners.saturating_add(1); + } } if let Some(counters) = self.listener_traffic_counters(&listener.id) { - authenticated_connections += - counters.downstream_mtls_authenticated_connections.load(Ordering::Relaxed); - authenticated_requests += - counters.downstream_mtls_authenticated_requests.load(Ordering::Relaxed); - anonymous_requests += - counters.downstream_mtls_anonymous_requests.load(Ordering::Relaxed); - handshake_failures_total += - counters.downstream_tls_handshake_failures.load(Ordering::Relaxed); - handshake_failures_missing_client_cert += counters - .downstream_tls_handshake_failures_missing_client_cert - .load(Ordering::Relaxed); - handshake_failures_unknown_ca += - counters.downstream_tls_handshake_failures_unknown_ca.load(Ordering::Relaxed); - handshake_failures_bad_certificate += counters - .downstream_tls_handshake_failures_bad_certificate - .load(Ordering::Relaxed); - handshake_failures_certificate_revoked += counters - .downstream_tls_handshake_failures_certificate_revoked - .load(Ordering::Relaxed); - handshake_failures_verify_depth_exceeded += counters - .downstream_tls_handshake_failures_verify_depth_exceeded - .load(Ordering::Relaxed); - handshake_failures_other += - counters.downstream_tls_handshake_failures_other.load(Ordering::Relaxed); + authenticated_connections = authenticated_connections.saturating_add( + counters.downstream_mtls_authenticated_connections.load(Ordering::Relaxed), + ); + authenticated_requests = authenticated_requests.saturating_add( + counters.downstream_mtls_authenticated_requests.load(Ordering::Relaxed), + ); + anonymous_requests = anonymous_requests.saturating_add( + counters.downstream_mtls_anonymous_requests.load(Ordering::Relaxed), + ); + handshake_failures_total = handshake_failures_total.saturating_add( + counters.downstream_tls_handshake_failures.load(Ordering::Relaxed), + ); + handshake_failures_missing_client_cert = handshake_failures_missing_client_cert + .saturating_add( + counters + .downstream_tls_handshake_failures_missing_client_cert + .load(Ordering::Relaxed), + ); + handshake_failures_unknown_ca = handshake_failures_unknown_ca.saturating_add( + counters.downstream_tls_handshake_failures_unknown_ca.load(Ordering::Relaxed), + ); + handshake_failures_bad_certificate = handshake_failures_bad_certificate + .saturating_add( + counters + .downstream_tls_handshake_failures_bad_certificate + .load(Ordering::Relaxed), + ); + handshake_failures_certificate_revoked = handshake_failures_certificate_revoked + .saturating_add( + counters + .downstream_tls_handshake_failures_certificate_revoked + .load(Ordering::Relaxed), + ); + handshake_failures_verify_depth_exceeded = handshake_failures_verify_depth_exceeded + .saturating_add( + counters + .downstream_tls_handshake_failures_verify_depth_exceeded + .load(Ordering::Relaxed), + ); + handshake_failures_other = handshake_failures_other.saturating_add( + counters.downstream_tls_handshake_failures_other.load(Ordering::Relaxed), + ); } } MtlsStatusSnapshot { - configured_listeners, - optional_listeners, - required_listeners, + anonymous_requests, authenticated_connections, authenticated_requests, - anonymous_requests, - handshake_failures_total, - handshake_failures_missing_client_cert, - handshake_failures_unknown_ca, + configured_listeners, handshake_failures_bad_certificate, handshake_failures_certificate_revoked, - handshake_failures_verify_depth_exceeded, + handshake_failures_missing_client_cert, handshake_failures_other, + handshake_failures_total, + handshake_failures_unknown_ca, + handshake_failures_verify_depth_exceeded, + optional_listeners, + required_listeners, } } } diff --git a/crates/rginx-http/src/state/lifecycle/reload.rs b/crates/rginx-http/src/state/lifecycle/reload.rs index 04220352..fd58073e 100644 --- a/crates/rginx-http/src/state/lifecycle/reload.rs +++ b/crates/rginx-http/src/state/lifecycle/reload.rs @@ -1,142 +1,22 @@ -use super::super::*; +use super::super::{ + ApplyOutcomeSnapshot, ApplyResultSnapshot, ApplyStatusSnapshot, Arc, + ConfigFailureStageSnapshot, ConfigSnapshot, PreparedState, ReloadOutcomeSnapshot, + ReloadResultSnapshot, Result, SharedState, SystemTime, build_cache_notifier, + build_peer_health_notifier, prepare_listener_tls_acceptors, prepare_state, unix_time_ms, +}; use super::topology::{traffic_topology_changed, upstream_topology_changed}; use crate::validate_config_transition; impl SharedState { - pub fn record_reload_success(&self, revision: u64, tls_certificate_changes: Vec) { - self.update_desired_revision(revision); - let mut history = - self.lifecycle.reload_history.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - history.attempts_total += 1; - history.successes_total += 1; - history.last_result = Some(ReloadResultSnapshot { - finished_at_unix_ms: unix_time_ms(SystemTime::now()), - outcome: ReloadOutcomeSnapshot::Success { revision }, - tls_certificate_changes, - active_revision: revision, - rollback_preserved_revision: None, - }); - self.mark_status_snapshot_changed(); - } - - pub fn record_reload_failure( - &self, - stage: ConfigFailureStageSnapshot, - error: impl Into, - active_revision: u64, - ) { - let mut history = - self.lifecycle.reload_history.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - history.attempts_total += 1; - history.failures_total += 1; - history.last_result = Some(ReloadResultSnapshot { - finished_at_unix_ms: unix_time_ms(SystemTime::now()), - outcome: ReloadOutcomeSnapshot::Failure { stage, error: error.into() }, - tls_certificate_changes: Vec::new(), - active_revision, - rollback_preserved_revision: Some(active_revision), - }); - self.mark_status_snapshot_changed(); - } - + #[must_use] pub fn apply_status_snapshot(&self) -> ApplyStatusSnapshot { self.lifecycle .apply_history .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .snapshot() } - pub fn record_apply_success(&self, accepted_revision: u64, active_revision: u64) { - self.update_desired_revision(accepted_revision); - let mut history = - self.lifecycle.apply_history.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - history.attempts_total += 1; - history.successes_total += 1; - history.last_result = Some(ApplyResultSnapshot { - finished_at_unix_ms: unix_time_ms(SystemTime::now()), - outcome: ApplyOutcomeSnapshot::Success { accepted_revision }, - active_revision, - rollback_preserved_revision: None, - }); - self.mark_status_snapshot_changed(); - } - - pub fn record_apply_failure( - &self, - stage: ConfigFailureStageSnapshot, - error: impl Into, - active_revision: u64, - rollback_preserved_revision: Option, - ) { - let mut history = - self.lifecycle.apply_history.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - history.attempts_total += 1; - history.failures_total += 1; - history.last_result = Some(ApplyResultSnapshot { - finished_at_unix_ms: unix_time_ms(SystemTime::now()), - outcome: ApplyOutcomeSnapshot::Failure { stage, error: error.into() }, - active_revision, - rollback_preserved_revision, - }); - self.mark_status_snapshot_changed(); - } - - pub fn record_ocsp_refresh_success(&self, scope: &str) { - let mut statuses = - self.lifecycle.ocsp_statuses.write().unwrap_or_else(|poisoned| poisoned.into_inner()); - let entry = statuses.entry(scope.to_string()).or_default(); - entry.last_refresh_unix_ms = Some(unix_time_ms(SystemTime::now())); - entry.refreshes_total += 1; - entry.last_error = None; - self.mark_status_snapshot_changed(); - } - - pub fn record_ocsp_refresh_failure(&self, scope: &str, error: impl Into) { - let mut statuses = - self.lifecycle.ocsp_statuses.write().unwrap_or_else(|poisoned| poisoned.into_inner()); - let entry = statuses.entry(scope.to_string()).or_default(); - entry.failures_total += 1; - entry.last_error = Some(error.into()); - self.mark_status_snapshot_changed(); - } - - pub async fn replace(&self, config: ConfigSnapshot) -> Result> { - let prepared = self.prepare_replacement(config).await?; - Ok(self.commit_prepared(prepared).await) - } - - pub async fn refresh_tls_acceptors_from_current_config(&self) -> Result<()> { - let config = self.current_config().await; - let listener_tls_acceptors = prepare_listener_tls_acceptors(config.as_ref())?; - *self.listener_runtime.tls_acceptors.write().await = listener_tls_acceptors; - self.mark_status_snapshot_changed(); - Ok(()) - } - - async fn prepare_replacement(&self, config: ConfigSnapshot) -> Result { - let current = self.current_config().await; - validate_config_transition(current.as_ref(), &config)?; - let mut prepared = prepare_state( - config, - Some(build_peer_health_notifier( - self.snapshot_bus.clone(), - self.observability.peer_health_component_versions.clone(), - )), - Some(build_cache_notifier( - self.snapshot_bus.clone(), - self.observability.cache_component_versions.clone(), - )), - )?; - prepared.retired_listeners = current - .listeners - .iter() - .filter(|listener| prepared.config.listener(&listener.id).is_none()) - .cloned() - .collect(); - Ok(prepared) - } - async fn commit_prepared(&self, prepared: PreparedState) -> Arc { let previous_config = self.current_config().await; if !prepared.retired_listeners.is_empty() { @@ -144,7 +24,7 @@ impl SharedState { .listener_runtime .retired .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); for listener in &prepared.retired_listeners { retired.insert(listener.id.clone(), listener.clone()); } @@ -167,7 +47,7 @@ impl SharedState { let next_revision = { let mut state = self.inner.write().await; - let next_revision = state.revision + 1; + let next_revision = state.revision.saturating_add(1); state.revision = next_revision; state.config = prepared.config.clone(); state.clients = prepared.clients; @@ -175,7 +55,7 @@ impl SharedState { next_revision }; - *self.lifecycle.node_identity.write().unwrap_or_else(|poisoned| poisoned.into_inner()) = + *self.lifecycle.node_identity.write().unwrap_or_else(std::sync::PoisonError::into_inner) = self.resolved_node_identity( prepared.config.agent.as_ref(), prepared.config.control_plane.as_ref(), @@ -209,4 +89,130 @@ impl SharedState { prepared.config } + + async fn prepare_replacement(&self, config: ConfigSnapshot) -> Result { + let current = self.current_config().await; + validate_config_transition(current.as_ref(), &config)?; + let mut prepared = prepare_state( + config, + Some(build_peer_health_notifier( + self.snapshot_bus.clone(), + self.observability.peer_health_component_versions.clone(), + )), + Some(build_cache_notifier( + self.snapshot_bus.clone(), + self.observability.cache_component_versions.clone(), + )), + )?; + prepared.retired_listeners = current + .listeners + .iter() + .filter(|listener| prepared.config.listener(&listener.id).is_none()) + .cloned() + .collect(); + Ok(prepared) + } + + pub fn record_apply_failure( + &self, + stage: ConfigFailureStageSnapshot, + error: impl Into, + active_revision: u64, + rollback_preserved_revision: Option, + ) { + let mut history = + self.lifecycle.apply_history.lock().unwrap_or_else(std::sync::PoisonError::into_inner); + history.attempts_total = history.attempts_total.saturating_add(1); + history.failures_total = history.failures_total.saturating_add(1); + history.last_result = Some(ApplyResultSnapshot { + finished_at_unix_ms: unix_time_ms(SystemTime::now()), + outcome: ApplyOutcomeSnapshot::Failure { stage, error: error.into() }, + active_revision, + rollback_preserved_revision, + }); + self.mark_status_snapshot_changed(); + } + + pub fn record_apply_success(&self, accepted_revision: u64, active_revision: u64) { + self.update_desired_revision(accepted_revision); + let mut history = + self.lifecycle.apply_history.lock().unwrap_or_else(std::sync::PoisonError::into_inner); + history.attempts_total = history.attempts_total.saturating_add(1); + history.successes_total = history.successes_total.saturating_add(1); + history.last_result = Some(ApplyResultSnapshot { + finished_at_unix_ms: unix_time_ms(SystemTime::now()), + outcome: ApplyOutcomeSnapshot::Success { accepted_revision }, + active_revision, + rollback_preserved_revision: None, + }); + self.mark_status_snapshot_changed(); + } + + pub fn record_ocsp_refresh_failure(&self, scope: &str, error: impl Into) { + let mut statuses = + self.lifecycle.ocsp_statuses.write().unwrap_or_else(std::sync::PoisonError::into_inner); + let entry = statuses.entry(scope.to_string()).or_default(); + entry.failures_total = entry.failures_total.saturating_add(1); + entry.last_error = Some(error.into()); + self.mark_status_snapshot_changed(); + } + + pub fn record_ocsp_refresh_success(&self, scope: &str) { + let mut statuses = + self.lifecycle.ocsp_statuses.write().unwrap_or_else(std::sync::PoisonError::into_inner); + let entry = statuses.entry(scope.to_string()).or_default(); + entry.last_refresh_unix_ms = Some(unix_time_ms(SystemTime::now())); + entry.refreshes_total = entry.refreshes_total.saturating_add(1); + entry.last_error = None; + self.mark_status_snapshot_changed(); + } + + pub fn record_reload_failure( + &self, + stage: ConfigFailureStageSnapshot, + error: impl Into, + active_revision: u64, + ) { + let mut history = + self.lifecycle.reload_history.lock().unwrap_or_else(std::sync::PoisonError::into_inner); + history.attempts_total = history.attempts_total.saturating_add(1); + history.failures_total = history.failures_total.saturating_add(1); + history.last_result = Some(ReloadResultSnapshot { + finished_at_unix_ms: unix_time_ms(SystemTime::now()), + outcome: ReloadOutcomeSnapshot::Failure { stage, error: error.into() }, + tls_certificate_changes: Vec::new(), + active_revision, + rollback_preserved_revision: Some(active_revision), + }); + self.mark_status_snapshot_changed(); + } + + pub fn record_reload_success(&self, revision: u64, tls_certificate_changes: Vec) { + self.update_desired_revision(revision); + let mut history = + self.lifecycle.reload_history.lock().unwrap_or_else(std::sync::PoisonError::into_inner); + history.attempts_total = history.attempts_total.saturating_add(1); + history.successes_total = history.successes_total.saturating_add(1); + history.last_result = Some(ReloadResultSnapshot { + finished_at_unix_ms: unix_time_ms(SystemTime::now()), + outcome: ReloadOutcomeSnapshot::Success { revision }, + tls_certificate_changes, + active_revision: revision, + rollback_preserved_revision: None, + }); + self.mark_status_snapshot_changed(); + } + + pub async fn refresh_tls_acceptors_from_current_config(&self) -> Result<()> { + let config = self.current_config().await; + let listener_tls_acceptors = prepare_listener_tls_acceptors(config.as_ref())?; + *self.listener_runtime.tls_acceptors.write().await = listener_tls_acceptors; + self.mark_status_snapshot_changed(); + Ok(()) + } + + pub async fn replace(&self, config: ConfigSnapshot) -> Result> { + let prepared = self.prepare_replacement(config).await?; + Ok(self.commit_prepared(prepared).await) + } } diff --git a/crates/rginx-http/src/state/lifecycle/status.rs b/crates/rginx-http/src/state/lifecycle/status.rs index 12a7f222..54542afb 100644 --- a/crates/rginx-http/src/state/lifecycle/status.rs +++ b/crates/rginx-http/src/state/lifecycle/status.rs @@ -1,23 +1,62 @@ +mod status_helpers; use super::super::tls_runtime::{ tls_runtime_snapshot_for_config_with_ocsp_statuses, upstream_tls_status_snapshots, }; -use super::super::*; -mod status_helpers; +use super::super::{ + Arc, AtomicUsize, CacheStatsSnapshot, ConfigSnapshot, Listener, Ordering, RateLimiters, + ReloadStatusSnapshot, RuntimeListenerSnapshot, RuntimeStatusSnapshot, SharedState, TlsAcceptor, + http3_runtime_snapshot, +}; use status_helpers::{acme_runtime_snapshot, node_identity_snapshot}; impl SharedState { + #[must_use] + pub fn next_request_id(&self) -> String { + uuid::Uuid::now_v7().to_string() + } + + #[must_use] pub fn rate_limiters(&self) -> RateLimiters { self.request_runtime.rate_limiters.clone() } + #[must_use] pub fn reload_status_snapshot(&self) -> ReloadStatusSnapshot { self.lifecycle .reload_history .lock() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .snapshot() } + pub async fn remove_retired_listener_runtime(&self, listener_id: &str) { + self.listener_runtime + .retired + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .remove(listener_id); + self.listener_runtime.tls_acceptors.write().await.remove(listener_id); + self.listener_runtime + .active_connections + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .remove(listener_id); + } + + pub fn retire_listener_runtime(&self, listener: &Listener) { + self.listener_runtime + .retired + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .insert(listener.id.clone(), listener.clone()); + self.listener_runtime + .active_connections + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .entry(listener.id.clone()) + .or_insert_with(|| Arc::new(AtomicUsize::new(0))); + } + pub async fn status_snapshot(&self) -> RuntimeStatusSnapshot { let (revision, config, cache_manager) = { let state = self.inner.read().await; @@ -25,20 +64,20 @@ impl SharedState { }; let desired_revision = self.desired_revision(); let node = node_identity_snapshot( - &self.lifecycle.node_identity.read().unwrap_or_else(|poisoned| poisoned.into_inner()), + &self.lifecycle.node_identity.read().unwrap_or_else(std::sync::PoisonError::into_inner), ); let cache = CacheStatsSnapshot { zones: cache_manager.snapshot_with_shared_sync().await }; let ocsp_statuses = self .lifecycle .ocsp_statuses .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .clone(); let acme_statuses = self .lifecycle .acme_statuses .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .clone(); let mtls = self.mtls_status_snapshot(config.as_ref()); let tls = tls_runtime_snapshot_for_config_with_ocsp_statuses( @@ -50,33 +89,40 @@ impl SharedState { .observability .traffic_stats .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - let mut http3_active_connections = 0; - let mut http3_active_request_streams = 0; - let mut http3_retry_issued_total = 0; - let mut http3_retry_failed_total = 0; - let mut http3_request_accept_errors_total = 0; - let mut http3_request_resolve_errors_total = 0; - let mut http3_request_body_stream_errors_total = 0; - let mut http3_response_stream_errors_total = 0; + .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut http3_active_connections = 0usize; + let mut http3_active_request_streams = 0usize; + let mut http3_retry_issued_total = 0u64; + let mut http3_retry_failed_total = 0u64; + let mut http3_request_accept_errors_total = 0u64; + let mut http3_request_resolve_errors_total = 0u64; + let mut http3_request_body_stream_errors_total = 0u64; + let mut http3_response_stream_errors_total = 0u64; for listener in &config.listeners { let Some(entry) = listener_traffic.listeners.get(&listener.id) else { continue; }; let counters = &entry.counters; - http3_active_connections += counters.active_http3_connections.load(Ordering::Acquire); - http3_active_request_streams += - counters.active_http3_request_streams.load(Ordering::Acquire); - http3_retry_issued_total += counters.http3_retry_issued_total.load(Ordering::Relaxed); - http3_retry_failed_total += counters.http3_retry_failed_total.load(Ordering::Relaxed); - http3_request_accept_errors_total += - counters.http3_request_accept_errors_total.load(Ordering::Relaxed); - http3_request_resolve_errors_total += - counters.http3_request_resolve_errors_total.load(Ordering::Relaxed); - http3_request_body_stream_errors_total += - counters.http3_request_body_stream_errors_total.load(Ordering::Relaxed); - http3_response_stream_errors_total += - counters.http3_response_stream_errors_total.load(Ordering::Relaxed); + http3_active_connections = http3_active_connections + .saturating_add(counters.active_http3_connections.load(Ordering::Acquire)); + http3_active_request_streams = http3_active_request_streams + .saturating_add(counters.active_http3_request_streams.load(Ordering::Acquire)); + http3_retry_issued_total = http3_retry_issued_total + .saturating_add(counters.http3_retry_issued_total.load(Ordering::Relaxed)); + http3_retry_failed_total = http3_retry_failed_total + .saturating_add(counters.http3_retry_failed_total.load(Ordering::Relaxed)); + http3_request_accept_errors_total = http3_request_accept_errors_total + .saturating_add(counters.http3_request_accept_errors_total.load(Ordering::Relaxed)); + http3_request_resolve_errors_total = http3_request_resolve_errors_total.saturating_add( + counters.http3_request_resolve_errors_total.load(Ordering::Relaxed), + ); + http3_request_body_stream_errors_total = http3_request_body_stream_errors_total + .saturating_add( + counters.http3_request_body_stream_errors_total.load(Ordering::Relaxed), + ); + http3_response_stream_errors_total = http3_response_stream_errors_total.saturating_add( + counters.http3_response_stream_errors_total.load(Ordering::Relaxed), + ); } RuntimeStatusSnapshot { revision, @@ -191,50 +237,18 @@ impl SharedState { } } - pub async fn tls_acceptor(&self, listener_id: &str) -> Option { - self.listener_runtime.tls_acceptors.read().await.get(listener_id).cloned().flatten() - } - - pub fn next_request_id(&self) -> String { - uuid::Uuid::now_v7().to_string() - } - - pub fn retire_listener_runtime(&self, listener: &Listener) { - self.listener_runtime - .retired - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .insert(listener.id.clone(), listener.clone()); - self.listener_runtime - .active_connections - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .entry(listener.id.clone()) - .or_insert_with(|| Arc::new(AtomicUsize::new(0))); - } - - pub async fn remove_retired_listener_runtime(&self, listener_id: &str) { - self.listener_runtime - .retired - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .remove(listener_id); - self.listener_runtime.tls_acceptors.write().await.remove(listener_id); - self.listener_runtime - .active_connections - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .remove(listener_id); - } - pub(super) fn sync_listener_active_connections(&self, config: &ConfigSnapshot) { let mut active = self .listener_runtime .active_connections .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); for listener in &config.listeners { active.entry(listener.id.clone()).or_insert_with(|| Arc::new(AtomicUsize::new(0))); } } + + pub async fn tls_acceptor(&self, listener_id: &str) -> Option { + self.listener_runtime.tls_acceptors.read().await.get(listener_id).cloned().flatten() + } } diff --git a/crates/rginx-http/src/state/lifecycle/status/status_helpers.rs b/crates/rginx-http/src/state/lifecycle/status/status_helpers.rs index 3f0a48d1..68905220 100644 --- a/crates/rginx-http/src/state/lifecycle/status/status_helpers.rs +++ b/crates/rginx-http/src/state/lifecycle/status/status_helpers.rs @@ -1,4 +1,7 @@ -use super::super::super::*; +use super::super::super::{ + AcmeManagedCertificateSnapshot, AcmeRuntimeSnapshot, AcmeRuntimeStatusEntry, ConfigSnapshot, + HashMap, NodeIdentitySnapshot, NodeIdentityState, TlsCertificateStatusSnapshot, +}; pub(super) fn node_identity_snapshot(node_identity: &NodeIdentityState) -> NodeIdentitySnapshot { NodeIdentitySnapshot { @@ -43,8 +46,8 @@ pub(super) fn acme_runtime_snapshot( key_path: spec.key_path.clone(), last_success_unix_ms: runtime.and_then(|entry| entry.last_success_unix_ms), next_renewal_unix_ms: next_renewal_unix_ms(certificate, config.acme.as_ref()), - refreshes_total: runtime.map(|entry| entry.refreshes_total).unwrap_or(0), - failures_total: runtime.map(|entry| entry.failures_total).unwrap_or(0), + refreshes_total: runtime.map_or(0, |entry| entry.refreshes_total), + failures_total: runtime.map_or(0, |entry| entry.failures_total), retry_after_unix_ms: runtime.and_then(|entry| entry.retry_after_unix_ms), last_error: runtime.and_then(|entry| entry.last_error.clone()), } diff --git a/crates/rginx-http/src/state/lifecycle/tasks.rs b/crates/rginx-http/src/state/lifecycle/tasks.rs index 63b1888c..92993d34 100644 --- a/crates/rginx-http/src/state/lifecycle/tasks.rs +++ b/crates/rginx-http/src/state/lifecycle/tasks.rs @@ -1,15 +1,19 @@ -use super::super::*; +use super::super::{Future, SharedState, take_background_tasks}; impl SharedState { - pub fn spawn_background_task(&self, task: F) - where - F: Future + Send + 'static, - { - let handle = tokio::spawn(task); - let mut tasks = - self.lifecycle.background_tasks.lock().unwrap_or_else(|poisoned| poisoned.into_inner()); - tasks.retain(|task| !task.is_finished()); - tasks.push(handle); + pub async fn abort_background_tasks(&self) { + let tasks = take_background_tasks(&self.lifecycle.background_tasks); + for task in &tasks { + task.abort(); + } + + for task in tasks { + if let Err(error) = task.await + && !error.is_cancelled() + { + tracing::warn!(%error, "background task failed after abort"); + } + } } pub async fn drain_background_tasks(&self) { @@ -24,18 +28,17 @@ impl SharedState { } } - pub async fn abort_background_tasks(&self) { - let tasks = take_background_tasks(&self.lifecycle.background_tasks); - for task in &tasks { - task.abort(); - } - - for task in tasks { - if let Err(error) = task.await - && !error.is_cancelled() - { - tracing::warn!(%error, "background task failed after abort"); - } - } + pub fn spawn_background_task(&self, task: F) + where + F: Future + Send + 'static, + { + let handle = tokio::spawn(task); + let mut tasks = self + .lifecycle + .background_tasks + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + tasks.retain(|task| !task.is_finished()); + tasks.push(handle); } } diff --git a/crates/rginx-http/src/state/lifecycle/topology.rs b/crates/rginx-http/src/state/lifecycle/topology.rs index 7b77770b..25a8b48a 100644 --- a/crates/rginx-http/src/state/lifecycle/topology.rs +++ b/crates/rginx-http/src/state/lifecycle/topology.rs @@ -1,4 +1,4 @@ -use super::super::*; +use super::super::ConfigSnapshot; pub(super) fn traffic_topology_changed(previous: &ConfigSnapshot, next: &ConfigSnapshot) -> bool { let listener_ids = |config: &ConfigSnapshot| { diff --git a/crates/rginx-http/src/state/mod.rs b/crates/rginx-http/src/state/mod.rs index f2177be6..c26731c7 100644 --- a/crates/rginx-http/src/state/mod.rs +++ b/crates/rginx-http/src/state/mod.rs @@ -1,3 +1,18 @@ +mod agent; +mod cache; +mod connections; +mod counters; +mod helpers; +mod lifecycle; +mod snapshot_bus; +mod snapshots; +mod structure; +#[cfg(test)] +mod tests; +mod tls_runtime; +mod traffic; +mod upstreams; + use std::collections::{BTreeMap, HashMap, VecDeque}; use std::future::Future; use std::path::PathBuf; @@ -17,33 +32,6 @@ use crate::rate_limit::RateLimiters; use crate::tls::build_tls_acceptor; use crate::tls::ocsp::ocsp_responder_urls_for_certificate; -mod agent; -mod cache; -mod connections; -mod counters; -mod helpers; -mod lifecycle; -mod snapshot_bus; -mod snapshots; -mod structure; -#[cfg(test)] -mod tests; -mod tls_runtime; -mod traffic; -mod upstreams; - -const RECENT_WINDOW_SECS: u64 = 60; -const MAX_RECENT_WINDOW_SECS: u64 = 300; -const TLS_EXPIRY_WARNING_DAYS: i64 = 30; - -pub(super) struct PreparedState { - config: Arc, - clients: ProxyClients, - cache: CacheManager, - listener_tls_acceptors: HashMap>, - retired_listeners: Vec, -} - use self::agent::AgentRuntimeState; pub(crate) use self::counters::http::TlsHandshakeFailureReason; pub use crate::cache::{CacheInvalidationResult, CachePurgeResult, CacheZoneRuntimeSnapshot}; @@ -86,41 +74,53 @@ pub use tls_runtime::{ tls_runtime_snapshot_for_config, }; +const RECENT_WINDOW_SECS: u64 = 60; +const MAX_RECENT_WINDOW_SECS: u64 = 300; +const TLS_EXPIRY_WARNING_DAYS: i64 = 30; + +pub(super) struct PreparedState { + cache: CacheManager, + clients: ProxyClients, + config: Arc, + listener_tls_acceptors: HashMap>, + retired_listeners: Vec, +} + #[derive(Clone)] pub struct SharedState { + config_path: Option>, inner: Arc>, - revisions: watch::Sender, - snapshot_bus: SnapshotBusState, - request_runtime: RequestRuntimeState, + lifecycle: LifecycleState, listener_runtime: ListenerRuntimeState, observability: ObservabilityState, - lifecycle: LifecycleState, - config_path: Option>, + request_runtime: RequestRuntimeState, + revisions: watch::Sender, + snapshot_bus: SnapshotBusState, } #[derive(Debug, Clone, Default, PartialEq, Eq)] struct NodeIdentityState { + labels: BTreeMap, node_id: Option, - region: Option, pop: Option, - labels: BTreeMap, + region: Option, } #[derive(Debug, Clone, Default)] struct OcspRuntimeStatusEntry { - last_refresh_unix_ms: Option, - refreshes_total: u64, failures_total: u64, last_error: Option, + last_refresh_unix_ms: Option, + refreshes_total: u64, } #[derive(Debug, Clone, Default)] struct AcmeRuntimeStatusEntry { + failures_total: u64, + last_error: Option, last_success_unix_ms: Option, refreshes_total: u64, - failures_total: u64, retry_after_unix_ms: Option, - last_error: Option, } impl SharedState { @@ -198,6 +198,15 @@ impl SharedState { } impl NodeIdentityState { + fn from_agent(agent: &rginx_core::AgentSettings) -> Self { + Self { + node_id: Some(agent.node_id.clone()), + region: agent.region.clone(), + pop: agent.pop.clone(), + labels: agent.labels.clone(), + } + } + fn from_config( agent: Option<&rginx_core::AgentSettings>, control_plane: Option<&rginx_core::ControlPlaneSettings>, @@ -208,15 +217,6 @@ impl NodeIdentityState { Self::from_control_plane(control_plane) } - fn from_agent(agent: &rginx_core::AgentSettings) -> Self { - Self { - node_id: Some(agent.node_id.clone()), - region: agent.region.clone(), - pop: agent.pop.clone(), - labels: agent.labels.clone(), - } - } - fn from_control_plane(control_plane: Option<&rginx_core::ControlPlaneSettings>) -> Self { let Some(control_plane) = control_plane else { return Self::default(); diff --git a/crates/rginx-http/src/state/snapshot_bus.rs b/crates/rginx-http/src/state/snapshot_bus.rs index e9812b41..bdb5e372 100644 --- a/crates/rginx-http/src/state/snapshot_bus.rs +++ b/crates/rginx-http/src/state/snapshot_bus.rs @@ -1,45 +1,21 @@ -use super::*; - mod delta; mod identity; +use super::{ + ActiveState, Arc, ConfigSnapshot, HashMap, Listener, NodeIdentityState, PathBuf, + RevisionStatusSnapshot, SharedState, StdRwLock, watch, +}; + impl SharedState { - pub async fn snapshot(&self) -> ActiveState { - self.inner.read().await.clone() + #[must_use] + pub fn config_path(&self) -> Option<&PathBuf> { + self.config_path.as_deref() } pub async fn current_config(&self) -> Arc { self.inner.read().await.config.clone() } - pub async fn current_revision(&self) -> u64 { - self.inner.read().await.revision - } - - pub fn desired_revision(&self) -> u64 { - self.snapshot_bus.desired_revision() - } - - pub async fn revision_status_snapshot(&self) -> RevisionStatusSnapshot { - let current_revision = self.current_revision().await; - let desired_revision = self.desired_revision(); - RevisionStatusSnapshot { - desired_revision, - current_revision, - converged: desired_revision == current_revision, - } - } - - pub fn set_desired_revision(&self, desired_revision: u64) { - if self.update_desired_revision(desired_revision) { - self.mark_status_snapshot_changed(); - } - } - - pub fn current_snapshot_version(&self) -> u64 { - self.snapshot_bus.current_version() - } - pub async fn current_listener(&self, listener_id: &str) -> Option { if let Some(listener) = self.inner.read().await.config.listener(listener_id).cloned() { return Some(listener); @@ -48,39 +24,23 @@ impl SharedState { self.listener_runtime .retired .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .get(listener_id) .cloned() } - pub fn config_path(&self) -> Option<&PathBuf> { - self.config_path.as_deref() + pub async fn current_revision(&self) -> u64 { + self.inner.read().await.revision } - pub fn subscribe_updates(&self) -> watch::Receiver { - self.revisions.subscribe() + #[must_use] + pub fn current_snapshot_version(&self) -> u64 { + self.snapshot_bus.current_version() } - pub async fn wait_for_snapshot_change( - &self, - since_version: u64, - timeout: Option, - ) -> u64 { - loop { - let notified = self.snapshot_bus.notify.notified(); - let current = self.current_snapshot_version(); - if current > since_version { - return current; - } - - if let Some(timeout) = timeout { - if tokio::time::timeout(timeout, notified).await.is_err() { - return self.current_snapshot_version(); - } - } else { - notified.await; - } - } + #[must_use] + pub fn desired_revision(&self) -> u64 { + self.snapshot_bus.desired_revision() } pub(crate) fn mark_named_component_target_changed( @@ -91,11 +51,11 @@ impl SharedState { ) { versions .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .insert(name.to_string(), version); } - pub(crate) fn mark_snapshot_changed_and_notify( + pub(crate) fn mark_snapshot_changed( &self, status: bool, counters: bool, @@ -103,7 +63,7 @@ impl SharedState { peer_health: bool, upstreams: bool, ) -> u64 { - self.mark_snapshot_changed_with_cache_and_notify( + self.mark_snapshot_changed_with_cache( status, counters, traffic, @@ -113,44 +73,37 @@ impl SharedState { ) } - pub(crate) fn mark_snapshot_changed_with_cache_and_notify( + pub(crate) fn mark_snapshot_changed_and_notify( &self, status: bool, counters: bool, traffic: bool, peer_health: bool, upstreams: bool, - cache: bool, ) -> u64 { - self.snapshot_bus.mark_changed_and_notify( + self.mark_snapshot_changed_with_cache_and_notify( status, counters, traffic, peer_health, upstreams, - cache, + false, ) } - pub(crate) fn mark_snapshot_changed( + pub(crate) fn mark_snapshot_changed_with_cache( &self, status: bool, counters: bool, traffic: bool, peer_health: bool, upstreams: bool, + cache: bool, ) -> u64 { - self.mark_snapshot_changed_with_cache( - status, - counters, - traffic, - peer_health, - upstreams, - false, - ) + self.snapshot_bus.mark_changed(status, counters, traffic, peer_health, upstreams, cache) } - pub(crate) fn mark_snapshot_changed_with_cache( + pub(crate) fn mark_snapshot_changed_with_cache_and_notify( &self, status: bool, counters: bool, @@ -159,11 +112,14 @@ impl SharedState { upstreams: bool, cache: bool, ) -> u64 { - self.snapshot_bus.mark_changed(status, counters, traffic, peer_health, upstreams, cache) - } - - pub(crate) fn notify_snapshot_waiters(&self) { - self.snapshot_bus.notify_waiters(); + self.snapshot_bus.mark_changed_and_notify( + status, + counters, + traffic, + peer_health, + upstreams, + cache, + ) } pub(crate) fn mark_status_snapshot_changed(&self) -> u64 { @@ -195,7 +151,58 @@ impl SharedState { version } + pub(crate) fn notify_snapshot_waiters(&self) { + self.snapshot_bus.notify_waiters(); + } + + pub async fn revision_status_snapshot(&self) -> RevisionStatusSnapshot { + let current_revision = self.current_revision().await; + let desired_revision = self.desired_revision(); + RevisionStatusSnapshot { + desired_revision, + current_revision, + converged: desired_revision == current_revision, + } + } + + pub fn set_desired_revision(&self, desired_revision: u64) { + if self.update_desired_revision(desired_revision) { + self.mark_status_snapshot_changed(); + } + } + + pub async fn snapshot(&self) -> ActiveState { + self.inner.read().await.clone() + } + + #[must_use] + pub fn subscribe_updates(&self) -> watch::Receiver { + self.revisions.subscribe() + } + pub(crate) fn update_desired_revision(&self, desired_revision: u64) -> bool { self.snapshot_bus.update_desired_revision(desired_revision) } + + pub async fn wait_for_snapshot_change( + &self, + since_version: u64, + timeout: Option, + ) -> u64 { + loop { + let notified = self.snapshot_bus.notify.notified(); + let current = self.current_snapshot_version(); + if current > since_version { + return current; + } + + if let Some(timeout) = timeout { + if tokio::time::timeout(timeout, notified).await.is_err() { + return self.current_snapshot_version(); + } + } else { + notified.await; + } + } + } } diff --git a/crates/rginx-http/src/state/snapshot_bus/delta.rs b/crates/rginx-http/src/state/snapshot_bus/delta.rs index 3ed6bd67..3a429d94 100644 --- a/crates/rginx-http/src/state/snapshot_bus/delta.rs +++ b/crates/rginx-http/src/state/snapshot_bus/delta.rs @@ -1,6 +1,22 @@ -use super::super::*; +use super::super::{ + Arc, HashMap, Ordering, SharedState, SnapshotDeltaSnapshot, SnapshotModule, StdRwLock, +}; impl SharedState { + pub(crate) fn changed_named_component_targets_since( + &self, + versions: &Arc>>, + since_version: u64, + ) -> Vec { + let versions = versions.read().unwrap_or_else(std::sync::PoisonError::into_inner); + let mut changed = versions + .iter() + .filter_map(|(name, version)| (*version > since_version).then_some(name.clone())) + .collect::>(); + changed.sort(); + changed + } + #[must_use] pub fn snapshot_delta_since( &self, since_version: u64, @@ -157,20 +173,6 @@ impl SharedState { ), } } - - pub(crate) fn changed_named_component_targets_since( - &self, - versions: &Arc>>, - since_version: u64, - ) -> Vec { - let versions = versions.read().unwrap_or_else(|poisoned| poisoned.into_inner()); - let mut changed = versions - .iter() - .filter_map(|(name, version)| (*version > since_version).then_some(name.clone())) - .collect::>(); - changed.sort(); - changed - } } fn module_value( diff --git a/crates/rginx-http/src/state/snapshot_bus/identity.rs b/crates/rginx-http/src/state/snapshot_bus/identity.rs index 5d1b8c92..a7fd7597 100644 --- a/crates/rginx-http/src/state/snapshot_bus/identity.rs +++ b/crates/rginx-http/src/state/snapshot_bus/identity.rs @@ -1,6 +1,18 @@ -use super::*; +use super::{NodeIdentityState, SharedState}; impl SharedState { + pub(in crate::state) fn resolved_node_identity( + &self, + agent: Option<&rginx_core::AgentSettings>, + control_plane: Option<&rginx_core::ControlPlaneSettings>, + ) -> NodeIdentityState { + self.lifecycle + .node_identity_override + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone() + .unwrap_or_else(|| NodeIdentityState::from_config(agent, control_plane)) + } pub fn set_agent_identity(&self, agent: &rginx_core::AgentSettings) { self.set_node_identity_override(NodeIdentityState::from_agent(agent)); } @@ -17,7 +29,7 @@ impl SharedState { .lifecycle .node_identity_override .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); if override_state.as_ref() != Some(&next) { *override_state = Some(next.clone()); changed = true; @@ -29,7 +41,7 @@ impl SharedState { .lifecycle .node_identity .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); if *node_identity != next { *node_identity = next; changed = true; @@ -40,17 +52,4 @@ impl SharedState { self.mark_status_snapshot_changed(); } } - - pub(in crate::state) fn resolved_node_identity( - &self, - agent: Option<&rginx_core::AgentSettings>, - control_plane: Option<&rginx_core::ControlPlaneSettings>, - ) -> NodeIdentityState { - self.lifecycle - .node_identity_override - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) - .clone() - .unwrap_or_else(|| NodeIdentityState::from_config(agent, control_plane)) - } } diff --git a/crates/rginx-http/src/state/snapshots/acme.rs b/crates/rginx-http/src/state/snapshots/acme.rs index 7675af02..5f0204d6 100644 --- a/crates/rginx-http/src/state/snapshots/acme.rs +++ b/crates/rginx-http/src/state/snapshots/acme.rs @@ -4,28 +4,28 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct AcmeRuntimeSnapshot { - pub enabled: bool, #[serde(skip_serializing_if = "Option::is_none")] pub directory_url: Option, + pub enabled: bool, pub managed_certificates: Vec, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AcmeManagedCertificateSnapshot { - pub scope: String, - pub domains: Vec, - pub managed: bool, - pub challenge_type: String, pub cert_path: PathBuf, + pub challenge_type: String, + pub domains: Vec, + pub failures_total: u64, pub key_path: PathBuf, #[serde(skip_serializing_if = "Option::is_none")] + pub last_error: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub last_success_unix_ms: Option, + pub managed: bool, #[serde(skip_serializing_if = "Option::is_none")] pub next_renewal_unix_ms: Option, pub refreshes_total: u64, - pub failures_total: u64, #[serde(skip_serializing_if = "Option::is_none")] pub retry_after_unix_ms: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub last_error: Option, + pub scope: String, } diff --git a/crates/rginx-http/src/state/snapshots/active.rs b/crates/rginx-http/src/state/snapshots/active.rs index b7d6a1bc..afeb96de 100644 --- a/crates/rginx-http/src/state/snapshots/active.rs +++ b/crates/rginx-http/src/state/snapshots/active.rs @@ -7,8 +7,8 @@ use crate::proxy::ProxyClients; #[derive(Clone)] pub struct ActiveState { - pub revision: u64, - pub config: Arc, - pub clients: ProxyClients, pub(crate) cache: CacheManager, + pub clients: ProxyClients, + pub config: Arc, + pub revision: u64, } diff --git a/crates/rginx-http/src/state/snapshots/apply.rs b/crates/rginx-http/src/state/snapshots/apply.rs index 465dfe3d..f6c0beed 100644 --- a/crates/rginx-http/src/state/snapshots/apply.rs +++ b/crates/rginx-http/src/state/snapshots/apply.rs @@ -3,38 +3,38 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ConfigFailureStageSnapshot { - LoadConfig, - ValidateTransition, - PrepareListenerBindings, - LoadManagedDocuments, - ApplyManagedMutation, - ResolveManagedPath, - CompileConfig, ActivateManagedFiles, + ApplyManagedMutation, BeforeReplaceHook, + CompileConfig, + LoadConfig, + LoadManagedDocuments, + PrepareListenerBindings, ReplaceRuntime, + ResolveManagedPath, RollbackManagedFiles, + ValidateTransition, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ApplyOutcomeSnapshot { - Success { accepted_revision: u64 }, Failure { stage: ConfigFailureStageSnapshot, error: String }, + Success { accepted_revision: u64 }, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ApplyResultSnapshot { + pub active_revision: u64, pub finished_at_unix_ms: u64, pub outcome: ApplyOutcomeSnapshot, - pub active_revision: u64, pub rollback_preserved_revision: Option, } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct ApplyStatusSnapshot { pub attempts_total: u64, - pub successes_total: u64, pub failures_total: u64, pub last_result: Option, + pub successes_total: u64, } diff --git a/crates/rginx-http/src/state/snapshots/delta.rs b/crates/rginx-http/src/state/snapshots/delta.rs index a3139d4b..fee1efa2 100644 --- a/crates/rginx-http/src/state/snapshots/delta.rs +++ b/crates/rginx-http/src/state/snapshots/delta.rs @@ -2,71 +2,71 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct SnapshotDeltaSnapshot { - pub schema_version: u32, - pub since_version: u64, - pub current_snapshot_version: u64, - pub included_modules: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub recent_window_secs: Option, + pub cache_changed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub status_version: Option, + pub cache_version: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub counters_version: Option, + pub changed_cache_zone_names: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub traffic_version: Option, + pub changed_listener_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub peer_health_version: Option, + pub changed_peer_health_upstream_names: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub upstreams_version: Option, + pub changed_recent_listener_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub cache_version: Option, + pub changed_recent_route_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub status_changed: Option, + pub changed_recent_upstream_names: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub counters_changed: Option, + pub changed_recent_vhost_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub traffic_changed: Option, + pub changed_route_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub traffic_recent_changed: Option, + pub changed_upstream_names: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub peer_health_changed: Option, + pub changed_vhost_ids: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub upstreams_changed: Option, + pub counters_changed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub cache_changed: Option, + pub counters_version: Option, + pub current_snapshot_version: u64, + pub included_modules: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub upstreams_recent_changed: Option, + pub peer_health_changed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_listener_ids: Option>, + pub peer_health_version: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_vhost_ids: Option>, + pub recent_window_secs: Option, + pub schema_version: u32, + pub since_version: u64, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_route_ids: Option>, + pub status_changed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_recent_listener_ids: Option>, + pub status_version: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_recent_vhost_ids: Option>, + pub traffic_changed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_recent_route_ids: Option>, + pub traffic_recent_changed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_peer_health_upstream_names: Option>, + pub traffic_version: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_upstream_names: Option>, + pub upstreams_changed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_recent_upstream_names: Option>, + pub upstreams_recent_changed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub changed_cache_zone_names: Option>, + pub upstreams_version: Option, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum SnapshotModule { - Status, + Cache, Counters, - Traffic, PeerHealth, + Status, + Traffic, Upstreams, - Cache, } impl SnapshotModule { @@ -79,10 +79,12 @@ impl SnapshotModule { Self::Cache, ]; + #[must_use] pub fn all() -> Vec { Self::ALL.to_vec() } + #[must_use] pub fn normalize(include: Option<&[Self]>) -> Vec { let requested = include.unwrap_or(&Self::ALL); Self::ALL.iter().copied().filter(|module| requested.contains(module)).collect() diff --git a/crates/rginx-http/src/state/snapshots/http.rs b/crates/rginx-http/src/state/snapshots/http.rs index 7f03eb67..68176e32 100644 --- a/crates/rginx-http/src/state/snapshots/http.rs +++ b/crates/rginx-http/src/state/snapshots/http.rs @@ -4,6 +4,11 @@ use serde::{Deserialize, Serialize}; pub struct HttpCountersSnapshot { pub downstream_connections_accepted: u64, pub downstream_connections_rejected: u64, + pub downstream_http3_early_data_accepted_requests: u64, + pub downstream_http3_early_data_rejected_requests: u64, + pub downstream_mtls_anonymous_requests: u64, + pub downstream_mtls_authenticated_connections: u64, + pub downstream_mtls_authenticated_requests: u64, pub downstream_requests: u64, pub downstream_responses: u64, pub downstream_responses_1xx: u64, @@ -11,33 +16,28 @@ pub struct HttpCountersSnapshot { pub downstream_responses_3xx: u64, pub downstream_responses_4xx: u64, pub downstream_responses_5xx: u64, - pub downstream_mtls_authenticated_connections: u64, - pub downstream_mtls_authenticated_requests: u64, - pub downstream_mtls_anonymous_requests: u64, pub downstream_tls_handshake_failures: u64, - pub downstream_tls_handshake_failures_missing_client_cert: u64, - pub downstream_tls_handshake_failures_unknown_ca: u64, pub downstream_tls_handshake_failures_bad_certificate: u64, pub downstream_tls_handshake_failures_certificate_revoked: u64, - pub downstream_tls_handshake_failures_verify_depth_exceeded: u64, + pub downstream_tls_handshake_failures_missing_client_cert: u64, pub downstream_tls_handshake_failures_other: u64, - pub downstream_http3_early_data_accepted_requests: u64, - pub downstream_http3_early_data_rejected_requests: u64, + pub downstream_tls_handshake_failures_unknown_ca: u64, + pub downstream_tls_handshake_failures_verify_depth_exceeded: u64, } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct MtlsStatusSnapshot { - pub configured_listeners: usize, - pub optional_listeners: usize, - pub required_listeners: usize, + pub anonymous_requests: u64, pub authenticated_connections: u64, pub authenticated_requests: u64, - pub anonymous_requests: u64, - pub handshake_failures_total: u64, - pub handshake_failures_missing_client_cert: u64, - pub handshake_failures_unknown_ca: u64, + pub configured_listeners: usize, pub handshake_failures_bad_certificate: u64, pub handshake_failures_certificate_revoked: u64, - pub handshake_failures_verify_depth_exceeded: u64, + pub handshake_failures_missing_client_cert: u64, pub handshake_failures_other: u64, + pub handshake_failures_total: u64, + pub handshake_failures_unknown_ca: u64, + pub handshake_failures_verify_depth_exceeded: u64, + pub optional_listeners: usize, + pub required_listeners: usize, } diff --git a/crates/rginx-http/src/state/snapshots/reload.rs b/crates/rginx-http/src/state/snapshots/reload.rs index aa698111..c0978c06 100644 --- a/crates/rginx-http/src/state/snapshots/reload.rs +++ b/crates/rginx-http/src/state/snapshots/reload.rs @@ -5,23 +5,23 @@ use super::ConfigFailureStageSnapshot; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ReloadOutcomeSnapshot { - Success { revision: u64 }, Failure { stage: ConfigFailureStageSnapshot, error: String }, + Success { revision: u64 }, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ReloadResultSnapshot { + pub active_revision: u64, pub finished_at_unix_ms: u64, pub outcome: ReloadOutcomeSnapshot, - pub tls_certificate_changes: Vec, - pub active_revision: u64, pub rollback_preserved_revision: Option, + pub tls_certificate_changes: Vec, } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct ReloadStatusSnapshot { pub attempts_total: u64, - pub successes_total: u64, pub failures_total: u64, pub last_result: Option, + pub successes_total: u64, } diff --git a/crates/rginx-http/src/state/snapshots/runtime.rs b/crates/rginx-http/src/state/snapshots/runtime.rs index 4847b37c..8ddb0c28 100644 --- a/crates/rginx-http/src/state/snapshots/runtime.rs +++ b/crates/rginx-http/src/state/snapshots/runtime.rs @@ -10,46 +10,46 @@ use super::{ #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct NodeIdentitySnapshot { + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + pub labels: BTreeMap, #[serde(default)] pub node_id: String, #[serde(default, skip_serializing_if = "Option::is_none")] - pub region: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] pub pop: Option, - #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] - pub labels: BTreeMap, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub region: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AgentRuntimeSnapshot { + #[serde(default, skip_serializing_if = "Option::is_none")] + pub command_cursor: Option, #[serde(default)] pub configured: bool, #[serde(default)] - pub enabled: bool, + pub connection_state: String, #[serde(default)] - pub locally_disabled: bool, + pub enabled: bool, #[serde(default, skip_serializing_if = "Option::is_none")] pub endpoint: Option, #[serde(default, skip_serializing_if = "Option::is_none")] - pub node_id: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub state_path: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub region: Option, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub pop: Option, + pub in_flight_command_id: Option, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] pub labels: BTreeMap, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub last_heartbeat_success_unix_ms: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub last_register_success_unix_ms: Option, #[serde(default)] - pub connection_state: String, + pub locally_disabled: bool, #[serde(default, skip_serializing_if = "Option::is_none")] - pub command_cursor: Option, + pub node_id: Option, #[serde(default, skip_serializing_if = "Option::is_none")] - pub in_flight_command_id: Option, + pub pop: Option, #[serde(default, skip_serializing_if = "Option::is_none")] - pub last_register_success_unix_ms: Option, + pub region: Option, #[serde(default, skip_serializing_if = "Option::is_none")] - pub last_heartbeat_success_unix_ms: Option, + pub state_path: Option, } impl Default for AgentRuntimeSnapshot { @@ -76,120 +76,120 @@ impl Default for AgentRuntimeSnapshot { #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct RevisionStatusSnapshot { #[serde(default)] - pub desired_revision: u64, + pub converged: bool, #[serde(default, alias = "revision")] pub current_revision: u64, #[serde(default)] - pub converged: bool, + pub desired_revision: u64, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RuntimeStatusSnapshot { - pub revision: u64, - #[serde(flatten)] - pub node: NodeIdentitySnapshot, + pub accept_workers: usize, #[serde(default)] - pub binary_version: String, + pub acme: AcmeRuntimeSnapshot, + pub active_connections: usize, #[serde(default)] - pub desired_revision: u64, + pub agent: AgentRuntimeSnapshot, #[serde(default)] - pub converged: bool, + pub apply: ApplyStatusSnapshot, #[serde(default)] - pub agent: AgentRuntimeSnapshot, + pub binary_version: String, + pub cache: CacheStatsSnapshot, pub config_path: Option, - pub listeners: Vec, - pub worker_threads: Option, - pub accept_workers: usize, - pub total_vhosts: usize, - pub total_routes: usize, - pub total_upstreams: usize, - pub tls_enabled: bool, + #[serde(default)] + pub converged: bool, + #[serde(default)] + pub desired_revision: u64, pub http3_active_connections: usize, pub http3_active_request_streams: usize, - pub http3_retry_issued_total: u64, - pub http3_retry_failed_total: u64, + pub http3_early_data_accepted_requests: u64, + pub http3_early_data_enabled_listeners: usize, + pub http3_early_data_rejected_requests: u64, pub http3_request_accept_errors_total: u64, - pub http3_request_resolve_errors_total: u64, pub http3_request_body_stream_errors_total: u64, + pub http3_request_resolve_errors_total: u64, pub http3_response_stream_errors_total: u64, - pub http3_early_data_enabled_listeners: usize, - pub http3_early_data_accepted_requests: u64, - pub http3_early_data_rejected_requests: u64, - #[serde(default)] - pub acme: AcmeRuntimeSnapshot, - pub tls: TlsRuntimeSnapshot, + pub http3_retry_failed_total: u64, + pub http3_retry_issued_total: u64, + pub listeners: Vec, pub mtls: MtlsStatusSnapshot, - pub upstream_tls: Vec, - pub cache: CacheStatsSnapshot, - pub active_connections: usize, - #[serde(default)] - pub apply: ApplyStatusSnapshot, + #[serde(flatten)] + pub node: NodeIdentitySnapshot, pub reload: ReloadStatusSnapshot, + pub revision: u64, + pub tls: TlsRuntimeSnapshot, + pub tls_enabled: bool, + pub total_routes: usize, + pub total_upstreams: usize, + pub total_vhosts: usize, + pub upstream_tls: Vec, + pub worker_threads: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RuntimeListenerSnapshot { - pub listener_id: String, - pub listener_name: String, - pub listen_addr: std::net::SocketAddr, + pub access_log_format_configured: bool, pub binding_count: usize, - pub http3_enabled: bool, - pub tls_enabled: bool, - pub proxy_protocol_enabled: bool, + pub bindings: Vec, pub default_certificate: Option, - pub keep_alive: bool, - pub max_connections: Option, - pub access_log_format_configured: bool, + pub http3_enabled: bool, #[serde(skip_serializing_if = "Option::is_none")] pub http3_runtime: Option, - pub bindings: Vec, + pub keep_alive: bool, + pub listen_addr: std::net::SocketAddr, + pub listener_id: String, + pub listener_name: String, + pub max_connections: Option, + pub proxy_protocol_enabled: bool, + pub tls_enabled: bool, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RuntimeListenerBindingSnapshot { - pub binding_name: String, - pub transport: String, - pub listen_addr: std::net::SocketAddr, - pub protocols: Vec, - pub worker_count: usize, - #[serde(skip_serializing_if = "Option::is_none")] - pub reuse_port_enabled: Option, #[serde(skip_serializing_if = "Option::is_none")] pub advertise_alt_svc: Option, #[serde(skip_serializing_if = "Option::is_none")] pub alt_svc_max_age_secs: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub http3_max_concurrent_streams: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub http3_stream_buffer_size: Option, + pub binding_name: String, #[serde(skip_serializing_if = "Option::is_none")] pub http3_active_connection_id_limit: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub http3_retry: Option, + pub http3_early_data_enabled: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub http3_gso: Option, #[serde(skip_serializing_if = "Option::is_none")] pub http3_host_key_path: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub http3_gso: Option, + pub http3_max_concurrent_streams: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub http3_early_data_enabled: Option, + pub http3_retry: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub http3_stream_buffer_size: Option, + pub listen_addr: std::net::SocketAddr, + pub protocols: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub reuse_port_enabled: Option, + pub transport: String, + pub worker_count: usize, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct Http3ListenerRuntimeSnapshot { pub active_connections: usize, pub active_request_streams: usize, - pub retry_issued_total: u64, - pub retry_failed_total: u64, - pub request_accept_errors_total: u64, - pub request_resolve_errors_total: u64, - pub request_body_stream_errors_total: u64, - pub response_stream_errors_total: u64, - pub connection_close_version_mismatch_total: u64, - pub connection_close_transport_error_total: u64, - pub connection_close_connection_closed_total: u64, pub connection_close_application_closed_total: u64, + pub connection_close_cids_exhausted_total: u64, + pub connection_close_connection_closed_total: u64, + pub connection_close_locally_closed_total: u64, pub connection_close_reset_total: u64, pub connection_close_timed_out_total: u64, - pub connection_close_locally_closed_total: u64, - pub connection_close_cids_exhausted_total: u64, + pub connection_close_transport_error_total: u64, + pub connection_close_version_mismatch_total: u64, + pub request_accept_errors_total: u64, + pub request_body_stream_errors_total: u64, + pub request_resolve_errors_total: u64, + pub response_stream_errors_total: u64, + pub retry_failed_total: u64, + pub retry_issued_total: u64, } diff --git a/crates/rginx-http/src/state/snapshots/tls.rs b/crates/rginx-http/src/state/snapshots/tls.rs index e7a71450..9cce93ed 100644 --- a/crates/rginx-http/src/state/snapshots/tls.rs +++ b/crates/rginx-http/src/state/snapshots/tls.rs @@ -10,133 +10,133 @@ pub struct TlsReloadBoundarySnapshot { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TlsListenerStatusSnapshot { - pub listener_id: String, - pub listener_name: String, - pub listen_addr: std::net::SocketAddr, - pub tls_enabled: bool, - pub http3_enabled: bool, - pub http3_listen_addr: Option, - pub default_certificate: Option, - pub versions: Option>, pub alpn_protocols: Vec, - pub http3_versions: Vec, - pub http3_alpn_protocols: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub http3_max_concurrent_streams: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub http3_stream_buffer_size: Option, + pub client_auth_crl_configured: bool, + pub client_auth_mode: Option, + pub client_auth_verify_depth: Option, + pub default_certificate: Option, #[serde(skip_serializing_if = "Option::is_none")] pub http3_active_connection_id_limit: Option, + pub http3_alpn_protocols: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub http3_retry: Option, + pub http3_early_data_enabled: Option, + pub http3_enabled: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub http3_gso: Option, #[serde(skip_serializing_if = "Option::is_none")] pub http3_host_key_path: Option, + pub http3_listen_addr: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub http3_gso: Option, + pub http3_max_concurrent_streams: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub http3_early_data_enabled: Option, - pub session_resumption_enabled: Option, - pub session_tickets_enabled: Option, + pub http3_retry: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub http3_stream_buffer_size: Option, + pub http3_versions: Vec, + pub listen_addr: std::net::SocketAddr, + pub listener_id: String, + pub listener_name: String, pub session_cache_size: Option, + pub session_resumption_enabled: Option, pub session_ticket_count: Option, - pub client_auth_mode: Option, - pub client_auth_verify_depth: Option, - pub client_auth_crl_configured: bool, + pub session_tickets_enabled: Option, pub sni_names: Vec, + pub tls_enabled: bool, + pub versions: Option>, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TlsCertificateStatusSnapshot { - pub scope: String, + pub additional_certificate_count: usize, + pub authority_key_identifier: Option, pub cert_path: PathBuf, - pub server_names: Vec, - pub subject: Option, - pub issuer: Option, - pub serial_number: Option, - pub san_dns_names: Vec, + pub chain_diagnostics: Vec, + pub chain_length: usize, + pub chain_subjects: Vec, + pub expires_in_days: Option, + pub extended_key_usage: Vec, pub fingerprint_sha256: Option, - pub subject_key_identifier: Option, - pub authority_key_identifier: Option, pub is_ca: Option, - pub path_len_constraint: Option, + pub issuer: Option, pub key_usage: Option, - pub extended_key_usage: Vec, - pub not_before_unix_ms: Option, pub not_after_unix_ms: Option, - pub expires_in_days: Option, - pub chain_length: usize, - pub chain_subjects: Vec, - pub chain_diagnostics: Vec, - pub selected_as_default_for_listeners: Vec, + pub not_before_unix_ms: Option, pub ocsp_staple_configured: bool, - pub additional_certificate_count: usize, + pub path_len_constraint: Option, + pub san_dns_names: Vec, + pub scope: String, + pub selected_as_default_for_listeners: Vec, + pub serial_number: Option, + pub server_names: Vec, + pub subject: Option, + pub subject_key_identifier: Option, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TlsOcspStatusSnapshot { - pub scope: String, - pub cert_path: PathBuf, - pub ocsp_staple_path: Option, - pub responder_urls: Vec, - pub nonce_mode: String, - pub responder_policy: String, + pub auto_refresh_enabled: bool, pub cache_loaded: bool, - pub cache_size_bytes: Option, pub cache_modified_unix_ms: Option, - pub auto_refresh_enabled: bool, - pub last_refresh_unix_ms: Option, - pub refreshes_total: u64, + pub cache_size_bytes: Option, + pub cert_path: PathBuf, pub failures_total: u64, pub last_error: Option, + pub last_refresh_unix_ms: Option, + pub nonce_mode: String, + pub ocsp_staple_path: Option, + pub refreshes_total: u64, + pub responder_policy: String, + pub responder_urls: Vec, + pub scope: String, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct TlsOcspRefreshSpec { - pub scope: String, - pub cert_path: PathBuf, - pub ocsp_staple_path: Option, - pub responder_urls: Vec, pub auto_refresh_enabled: bool, + pub cert_path: PathBuf, pub ocsp_nonce_mode: rginx_core::OcspNonceMode, pub ocsp_responder_policy: rginx_core::OcspResponderPolicy, + pub ocsp_staple_path: Option, + pub responder_urls: Vec, + pub scope: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TlsVhostBindingSnapshot { - pub listener_name: String, - pub vhost_id: String, - pub server_names: Vec, pub certificate_scopes: Vec, - pub fingerprints: Vec, pub default_selected: bool, + pub fingerprints: Vec, + pub listener_name: String, + pub server_names: Vec, + pub vhost_id: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TlsSniBindingSnapshot { - pub listener_name: String, - pub server_name: String, pub certificate_scopes: Vec, - pub fingerprints: Vec, pub default_selected: bool, + pub fingerprints: Vec, + pub listener_name: String, + pub server_name: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TlsDefaultCertificateBindingSnapshot { - pub listener_name: String, - pub server_name: String, pub certificate_scopes: Vec, pub fingerprints: Vec, + pub listener_name: String, + pub server_name: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TlsRuntimeSnapshot { - pub listeners: Vec, pub certificates: Vec, + pub default_certificate_bindings: Vec, + pub expiring_certificate_count: usize, + pub listeners: Vec, pub ocsp: Vec, - pub vhost_bindings: Vec, + pub reload_boundary: TlsReloadBoundarySnapshot, pub sni_bindings: Vec, pub sni_conflicts: Vec, - pub default_certificate_bindings: Vec, - pub reload_boundary: TlsReloadBoundarySnapshot, - pub expiring_certificate_count: usize, + pub vhost_bindings: Vec, } diff --git a/crates/rginx-http/src/state/snapshots/traffic.rs b/crates/rginx-http/src/state/snapshots/traffic.rs index 752db5d7..44597dd5 100644 --- a/crates/rginx-http/src/state/snapshots/traffic.rs +++ b/crates/rginx-http/src/state/snapshots/traffic.rs @@ -4,34 +4,28 @@ use super::Http3ListenerRuntimeSnapshot; #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct GrpcTrafficSnapshot { - pub requests_total: u64, pub protocol_grpc_total: u64, - pub protocol_grpc_web_total: u64, pub protocol_grpc_web_text_total: u64, + pub protocol_grpc_web_total: u64, + pub requests_total: u64, pub status_0_total: u64, + pub status_12_total: u64, + pub status_14_total: u64, pub status_1_total: u64, pub status_3_total: u64, pub status_4_total: u64, pub status_7_total: u64, pub status_8_total: u64, - pub status_12_total: u64, - pub status_14_total: u64, pub status_other_total: u64, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ListenerStatsSnapshot { - pub listener_id: String, - pub listener_name: String, - pub listen_addr: std::net::SocketAddr, pub active_connections: usize, - #[serde(skip_serializing_if = "Option::is_none")] - pub http3_runtime: Option, pub downstream_connections_accepted: u64, pub downstream_connections_rejected: u64, - pub downstream_requests: u64, pub downstream_request_bytes_total: u64, - pub unmatched_requests_total: u64, + pub downstream_requests: u64, pub downstream_response_bytes_total: u64, pub downstream_responses: u64, pub downstream_responses_1xx: u64, @@ -39,19 +33,22 @@ pub struct ListenerStatsSnapshot { pub downstream_responses_3xx: u64, pub downstream_responses_4xx: u64, pub downstream_responses_5xx: u64, + pub grpc: GrpcTrafficSnapshot, + #[serde(skip_serializing_if = "Option::is_none")] + pub http3_runtime: Option, + pub listen_addr: std::net::SocketAddr, + pub listener_id: String, + pub listener_name: String, pub recent_60s: RecentTrafficStatsSnapshot, #[serde(skip_serializing_if = "Option::is_none")] pub recent_window: Option, - pub grpc: GrpcTrafficSnapshot, + pub unmatched_requests_total: u64, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct VhostStatsSnapshot { - pub vhost_id: String, - pub server_names: Vec, - pub downstream_requests: u64, pub downstream_request_bytes_total: u64, - pub unmatched_requests_total: u64, + pub downstream_requests: u64, pub downstream_response_bytes_total: u64, pub downstream_responses: u64, pub downstream_responses_1xx: u64, @@ -59,18 +56,20 @@ pub struct VhostStatsSnapshot { pub downstream_responses_3xx: u64, pub downstream_responses_4xx: u64, pub downstream_responses_5xx: u64, + pub grpc: GrpcTrafficSnapshot, pub recent_60s: RecentTrafficStatsSnapshot, #[serde(skip_serializing_if = "Option::is_none")] pub recent_window: Option, - pub grpc: GrpcTrafficSnapshot, + pub server_names: Vec, + pub unmatched_requests_total: u64, + pub vhost_id: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct RouteStatsSnapshot { - pub route_id: String, - pub vhost_id: String, - pub downstream_requests: u64, + pub access_denied_total: u64, pub downstream_request_bytes_total: u64, + pub downstream_requests: u64, pub downstream_response_bytes_total: u64, pub downstream_responses: u64, pub downstream_responses_1xx: u64, @@ -78,30 +77,31 @@ pub struct RouteStatsSnapshot { pub downstream_responses_3xx: u64, pub downstream_responses_4xx: u64, pub downstream_responses_5xx: u64, - pub access_denied_total: u64, + pub grpc: GrpcTrafficSnapshot, pub rate_limited_total: u64, pub recent_60s: RecentTrafficStatsSnapshot, #[serde(skip_serializing_if = "Option::is_none")] pub recent_window: Option, - pub grpc: GrpcTrafficSnapshot, + pub route_id: String, + pub vhost_id: String, } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct RecentTrafficStatsSnapshot { - pub window_secs: u64, - pub downstream_requests_total: u64, pub downstream_request_bytes_total: u64, - pub downstream_responses_total: u64, + pub downstream_requests_total: u64, pub downstream_response_bytes_total: u64, pub downstream_responses_2xx_total: u64, pub downstream_responses_4xx_total: u64, pub downstream_responses_5xx_total: u64, + pub downstream_responses_total: u64, pub grpc_requests_total: u64, + pub window_secs: u64, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct TrafficStatsSnapshot { pub listeners: Vec, - pub vhosts: Vec, pub routes: Vec, + pub vhosts: Vec, } diff --git a/crates/rginx-http/src/state/snapshots/upstreams.rs b/crates/rginx-http/src/state/snapshots/upstreams.rs index 3bed4a03..f9e49b1e 100644 --- a/crates/rginx-http/src/state/snapshots/upstreams.rs +++ b/crates/rginx-http/src/state/snapshots/upstreams.rs @@ -2,70 +2,70 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UpstreamTlsStatusSnapshot { - pub upstream_name: String, - pub protocol: String, - pub verify_mode: String, - pub tls_versions: Option>, - pub pool_idle_timeout_secs: Option, - pub pool_max_idle_per_host: usize, - pub tcp_keepalive_secs: Option, + pub client_identity_configured: bool, + pub crl_configured: bool, pub http2_keep_alive_interval_secs: Option, pub http2_keep_alive_timeout_secs: u64, pub http2_keep_alive_while_idle: bool, + pub pool_idle_timeout_secs: Option, + pub pool_max_idle_per_host: usize, + pub protocol: String, pub server_name_enabled: bool, pub server_name_override: Option, + pub tcp_keepalive_secs: Option, + pub tls_versions: Option>, + pub upstream_name: String, pub verify_depth: Option, - pub crl_configured: bool, - pub client_identity_configured: bool, + pub verify_mode: String, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UpstreamPeerStatsSnapshot { - pub peer_url: String, pub attempts_total: u64, - pub successes_total: u64, pub failures_total: u64, + pub peer_url: String, + pub successes_total: u64, pub timeouts_total: u64, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct UpstreamStatsSnapshot { - pub upstream_name: String, - pub tls: UpstreamTlsStatusSnapshot, - pub downstream_requests_total: u64, - pub downstream_request_bytes_total: u64, - pub peer_attempts_total: u64, - pub peer_successes_total: u64, - pub peer_failures_total: u64, - pub peer_timeouts_total: u64, - pub failovers_total: u64, + pub bad_gateway_responses_total: u64, + pub bad_request_responses_total: u64, pub completed_responses_total: u64, + pub downstream_request_bytes_total: u64, + pub downstream_requests_total: u64, pub downstream_response_bytes_total: u64, - pub bad_gateway_responses_total: u64, + pub failovers_total: u64, pub gateway_timeout_responses_total: u64, - pub bad_request_responses_total: u64, - pub payload_too_large_responses_total: u64, - pub unsupported_media_type_responses_total: u64, pub no_healthy_peers_total: u64, - pub tls_failures_unknown_ca_total: u64, - pub tls_failures_bad_certificate_total: u64, - pub tls_failures_certificate_revoked_total: u64, - pub tls_failures_verify_depth_exceeded_total: u64, + pub payload_too_large_responses_total: u64, + pub peer_attempts_total: u64, + pub peer_failures_total: u64, + pub peer_successes_total: u64, + pub peer_timeouts_total: u64, + pub peers: Vec, pub recent_60s: RecentUpstreamStatsSnapshot, #[serde(skip_serializing_if = "Option::is_none")] pub recent_window: Option, - pub peers: Vec, + pub tls: UpstreamTlsStatusSnapshot, + pub tls_failures_bad_certificate_total: u64, + pub tls_failures_certificate_revoked_total: u64, + pub tls_failures_unknown_ca_total: u64, + pub tls_failures_verify_depth_exceeded_total: u64, + pub unsupported_media_type_responses_total: u64, + pub upstream_name: String, } #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct RecentUpstreamStatsSnapshot { - pub window_secs: u64, - pub downstream_requests_total: u64, - pub downstream_request_bytes_total: u64, - pub peer_attempts_total: u64, + pub bad_gateway_responses_total: u64, pub completed_responses_total: u64, + pub downstream_request_bytes_total: u64, + pub downstream_requests_total: u64, pub downstream_response_bytes_total: u64, - pub bad_gateway_responses_total: u64, - pub gateway_timeout_responses_total: u64, pub failovers_total: u64, + pub gateway_timeout_responses_total: u64, + pub peer_attempts_total: u64, + pub window_secs: u64, } diff --git a/crates/rginx-http/src/state/structure.rs b/crates/rginx-http/src/state/structure.rs index 2617206e..3988bde7 100644 --- a/crates/rginx-http/src/state/structure.rs +++ b/crates/rginx-http/src/state/structure.rs @@ -1,70 +1,69 @@ -use super::*; +use super::{ + AcmeRuntimeStatusEntry, AgentRuntimeState, ApplyHistory, Arc, AtomicBool, AtomicU64, + AtomicUsize, ConfigSnapshot, HashMap, HttpCounters, JoinHandle, Listener, Mutex, + NodeIdentityState, Notify, OcspRuntimeStatusEntry, Ordering, RateLimiters, ReloadHistory, + RwLock, SnapshotComponentVersions, StdRwLock, TlsAcceptor, TrafficComponentVersions, + TrafficStatsIndex, UpstreamStatsEntry, build_traffic_component_versions, + build_traffic_stats_index, build_upstream_name_versions, build_upstream_stats_map, cache, + watch, +}; #[derive(Clone)] pub(crate) struct SnapshotBusState { - pub(crate) version: Arc, - pub(crate) notify: Arc, pub(crate) components: Arc, desired_revision: Arc, + pub(crate) notify: Arc, + pub(crate) version: Arc, } #[derive(Clone)] pub(super) struct ListenerRuntimeState { - pub(super) tls_acceptors: Arc>>>, pub(super) active_connections: Arc>>>, pub(super) retired: Arc>>, + pub(super) tls_acceptors: Arc>>>, } #[derive(Clone)] pub(super) struct RequestRuntimeState { - pub(super) rate_limiters: RateLimiters, pub(super) active_connections: Arc, + pub(super) rate_limiters: RateLimiters, } #[derive(Clone)] pub(super) struct ObservabilityState { + pub(super) cache_component_versions: Arc>>, pub(super) counters: Arc, - pub(super) traffic_stats: Arc>, + pub(super) peer_health_component_versions: Arc>>, pub(super) traffic_component_versions: Arc>, - pub(super) upstream_stats: Arc>>, + pub(super) traffic_stats: Arc>, pub(super) upstream_component_versions: Arc>>, - pub(super) peer_health_component_versions: Arc>>, - pub(super) cache_component_versions: Arc>>, + pub(super) upstream_stats: Arc>>, } #[derive(Clone)] pub(super) struct LifecycleState { - pub(super) background_tasks: Arc>>>, - pub(super) node_identity: Arc>, - pub(super) node_identity_override: Arc>>, - pub(super) agent_runtime: Arc>, + pub(super) acme_http01_challenges: Arc>>, + pub(super) acme_statuses: Arc>>, pub(super) agent_disabled: watch::Sender, pub(super) agent_disabled_value: Arc, - pub(super) reload_history: Arc>, + pub(super) agent_runtime: Arc>, pub(super) apply_history: Arc>, + pub(super) background_tasks: Arc>>>, + pub(super) node_identity: Arc>, + pub(super) node_identity_override: Arc>>, pub(super) ocsp_statuses: Arc>>, - pub(super) acme_statuses: Arc>>, - pub(super) acme_http01_challenges: Arc>>, + pub(super) reload_history: Arc>, } impl SnapshotBusState { - pub(super) fn new(revision: u64) -> Self { - Self { - version: Arc::new(AtomicU64::new(0)), - notify: Arc::new(Notify::new()), - components: Arc::new(SnapshotComponentVersions::default()), - desired_revision: Arc::new(AtomicU64::new(revision)), - } + pub(super) fn current_version(&self) -> u64 { + self.version.load(Ordering::Acquire) } pub(super) fn desired_revision(&self) -> u64 { self.desired_revision.load(Ordering::Acquire) } - pub(super) fn current_version(&self) -> u64 { - self.version.load(Ordering::Acquire) - } - pub(super) fn mark_changed( &self, status: bool, @@ -74,7 +73,7 @@ impl SnapshotBusState { upstreams: bool, cache: bool, ) -> u64 { - let version = self.version.fetch_add(1, Ordering::AcqRel) + 1; + let version = self.version.fetch_add(1, Ordering::AcqRel).saturating_add(1); if status { self.components.status.store(version, Ordering::Release); } @@ -110,6 +109,15 @@ impl SnapshotBusState { version } + pub(super) fn new(revision: u64) -> Self { + Self { + version: Arc::new(AtomicU64::new(0)), + notify: Arc::new(Notify::new()), + components: Arc::new(SnapshotComponentVersions::default()), + desired_revision: Arc::new(AtomicU64::new(revision)), + } + } + pub(super) fn notify_waiters(&self) { self.notify.notify_waiters(); } @@ -147,9 +155,11 @@ impl ObservabilityState { peer_health_component_versions: Arc>>, cache_component_versions: Arc>>, ) -> Self { - *peer_health_component_versions.write().unwrap_or_else(|poisoned| poisoned.into_inner()) = + *peer_health_component_versions + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) = build_upstream_name_versions(config, None); - *cache_component_versions.write().unwrap_or_else(|poisoned| poisoned.into_inner()) = + *cache_component_versions.write().unwrap_or_else(std::sync::PoisonError::into_inner) = cache::build_cache_zone_versions(config, None); Self { diff --git a/crates/rginx-http/src/state/tests.rs b/crates/rginx-http/src/state/tests.rs index 41f3f2fc..67112687 100644 --- a/crates/rginx-http/src/state/tests.rs +++ b/crates/rginx-http/src/state/tests.rs @@ -1,3 +1,13 @@ +mod counters; +mod snapshots; +mod status; +mod status_cache; +mod support; +mod tls; +mod traffic; +mod transition; +mod upstreams; + use std::collections::HashMap; use std::fs; use std::path::PathBuf; @@ -19,14 +29,4 @@ use super::{ TlsHandshakeFailureReason, inspect_certificate, validate_config_transition, }; -mod counters; -mod snapshots; -mod status; -mod status_cache; -mod support; -mod tls; -mod traffic; -mod transition; -mod upstreams; - pub(crate) use support::*; diff --git a/crates/rginx-http/src/state/tests/status.rs b/crates/rginx-http/src/state/tests/status.rs index ad94ec42..fd9ee06b 100644 --- a/crates/rginx-http/src/state/tests/status.rs +++ b/crates/rginx-http/src/state/tests/status.rs @@ -1,8 +1,8 @@ +mod http3; + use super::*; use tempfile::tempdir; -mod http3; - fn managed_acme_status_config(cert_path: PathBuf, key_path: PathBuf) -> ConfigSnapshot { let mut config = snapshot("127.0.0.1:8080"); config.acme = Some(rginx_core::AcmeSettings { diff --git a/crates/rginx-http/src/state/tls_runtime/bindings.rs b/crates/rginx-http/src/state/tls_runtime/bindings.rs index 4266fae5..d3e2d6fe 100644 --- a/crates/rginx-http/src/state/tls_runtime/bindings.rs +++ b/crates/rginx-http/src/state/tls_runtime/bindings.rs @@ -1,4 +1,9 @@ -use super::*; +#[cfg(test)] +mod tests; +use super::{ + ConfigSnapshot, Listener, TlsCertificateStatusSnapshot, TlsDefaultCertificateBindingSnapshot, + TlsSniBindingSnapshot, TlsVhostBindingSnapshot, +}; pub(super) fn tls_binding_snapshots( config: &ConfigSnapshot, @@ -171,6 +176,3 @@ fn tls_certificate_scope_for_listener_vhost( None } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/state/tls_runtime/certificates.rs b/crates/rginx-http/src/state/tls_runtime/certificates.rs index 35cc6ed8..0bebb22e 100644 --- a/crates/rginx-http/src/state/tls_runtime/certificates.rs +++ b/crates/rginx-http/src/state/tls_runtime/certificates.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{ConfigSnapshot, Listener, TlsCertificateStatusSnapshot}; use crate::pki::inspect_certificate; @@ -60,7 +60,7 @@ fn build_listener_certificate_snapshot( .and_then(|certificate| certificate.not_before_unix_ms), not_after_unix_ms: inspected.as_ref().and_then(|certificate| certificate.not_after_unix_ms), expires_in_days: inspected.as_ref().and_then(|certificate| certificate.expires_in_days), - chain_length: inspected.as_ref().map(|certificate| certificate.chain_length).unwrap_or(0), + chain_length: inspected.as_ref().map_or(0, |certificate| certificate.chain_length), chain_subjects: inspected .as_ref() .map(|certificate| certificate.chain_subjects.clone()) @@ -119,7 +119,7 @@ fn build_vhost_certificate_snapshot( .and_then(|certificate| certificate.not_before_unix_ms), not_after_unix_ms: inspected.as_ref().and_then(|certificate| certificate.not_after_unix_ms), expires_in_days: inspected.as_ref().and_then(|certificate| certificate.expires_in_days), - chain_length: inspected.as_ref().map(|certificate| certificate.chain_length).unwrap_or(0), + chain_length: inspected.as_ref().map_or(0, |certificate| certificate.chain_length), chain_subjects: inspected .as_ref() .map(|certificate| certificate.chain_subjects.clone()) diff --git a/crates/rginx-http/src/state/tls_runtime/listeners.rs b/crates/rginx-http/src/state/tls_runtime/listeners.rs index 9963a9c9..0c9aa914 100644 --- a/crates/rginx-http/src/state/tls_runtime/listeners.rs +++ b/crates/rginx-http/src/state/tls_runtime/listeners.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{ConfigSnapshot, TlsListenerStatusSnapshot}; pub(super) fn tls_listener_status_snapshots( config: &ConfigSnapshot, diff --git a/crates/rginx-http/src/state/tls_runtime/mod.rs b/crates/rginx-http/src/state/tls_runtime/mod.rs index 60408dac..d5d1b793 100644 --- a/crates/rginx-http/src/state/tls_runtime/mod.rs +++ b/crates/rginx-http/src/state/tls_runtime/mod.rs @@ -1,5 +1,3 @@ -use super::*; - mod bindings; mod certificates; mod listeners; @@ -7,6 +5,14 @@ mod ocsp; mod reload_boundary; mod upstreams; +use super::{ + ConfigSnapshot, HashMap, Listener, OcspRuntimeStatusEntry, PathBuf, TLS_EXPIRY_WARNING_DAYS, + TlsCertificateStatusSnapshot, TlsDefaultCertificateBindingSnapshot, TlsListenerStatusSnapshot, + TlsOcspRefreshSpec, TlsOcspStatusSnapshot, TlsReloadBoundarySnapshot, TlsRuntimeSnapshot, + TlsSniBindingSnapshot, TlsVhostBindingSnapshot, UpstreamTlsStatusSnapshot, + ocsp_responder_urls_for_certificate, unix_time_ms, +}; + use bindings::tls_binding_snapshots; use certificates::tls_certificate_status_snapshots; use listeners::tls_listener_status_snapshots; diff --git a/crates/rginx-http/src/state/tls_runtime/ocsp.rs b/crates/rginx-http/src/state/tls_runtime/ocsp.rs index 813d5c76..707f73c8 100644 --- a/crates/rginx-http/src/state/tls_runtime/ocsp.rs +++ b/crates/rginx-http/src/state/tls_runtime/ocsp.rs @@ -1,11 +1,14 @@ -use super::*; +use super::{ + ConfigSnapshot, HashMap, OcspRuntimeStatusEntry, PathBuf, TlsOcspRefreshSpec, + TlsOcspStatusSnapshot, ocsp_responder_urls_for_certificate, unix_time_ms, +}; #[derive(Debug, Clone)] struct TlsOcspBundleSpec { - scope: String, cert_path: PathBuf, - ocsp_staple_path: Option, ocsp: rginx_core::OcspConfig, + ocsp_staple_path: Option, + scope: String, } pub fn tls_ocsp_refresh_specs_for_config(config: &ConfigSnapshot) -> Vec { @@ -106,11 +109,10 @@ fn build_tls_ocsp_status_snapshot( None }; - let (cache_loaded, cache_size_bytes, cache_modified_unix_ms, cache_error) = bundle - .ocsp_staple_path - .as_ref() - .map(|path| inspect_ocsp_cache_file(&bundle.cert_path, path, bundle.ocsp.responder_policy)) - .unwrap_or((false, None, None, None)); + let (cache_loaded, cache_size_bytes, cache_modified_unix_ms, cache_error) = + bundle.ocsp_staple_path.as_ref().map_or((false, None, None, None), |path| { + inspect_ocsp_cache_file(&bundle.cert_path, path, bundle.ocsp.responder_policy) + }); let runtime = runtime_statuses.and_then(|statuses| statuses.get(&bundle.scope)); let static_error = cache_error.or(responder_error).or_else(|| { if bundle.ocsp_staple_path.is_some() && refresh_spec.responder_urls.is_empty() { @@ -139,8 +141,8 @@ fn build_tls_ocsp_status_snapshot( cache_modified_unix_ms, auto_refresh_enabled: refresh_spec.auto_refresh_enabled, last_refresh_unix_ms: runtime.and_then(|entry| entry.last_refresh_unix_ms), - refreshes_total: runtime.map(|entry| entry.refreshes_total).unwrap_or(0), - failures_total: runtime.map(|entry| entry.failures_total).unwrap_or(0), + refreshes_total: runtime.map_or(0, |entry| entry.refreshes_total), + failures_total: runtime.map_or(0, |entry| entry.failures_total), last_error: runtime.and_then(|entry| entry.last_error.clone()).or(static_error), }) } @@ -222,7 +224,7 @@ fn inspect_ocsp_cache_file( let cache_error = match std::fs::File::open(path).and_then(|file| { use std::io::Read; - let mut reader = file.take(crate::MAX_OCSP_RESPONSE_BYTES as u64 + 1); + let mut reader = file.take((crate::MAX_OCSP_RESPONSE_BYTES as u64).saturating_add(1)); let mut bytes = Vec::new(); reader.read_to_end(&mut bytes)?; Ok(bytes) diff --git a/crates/rginx-http/src/state/tls_runtime/reload_boundary.rs b/crates/rginx-http/src/state/tls_runtime/reload_boundary.rs index 64bdc57b..92ac6a31 100644 --- a/crates/rginx-http/src/state/tls_runtime/reload_boundary.rs +++ b/crates/rginx-http/src/state/tls_runtime/reload_boundary.rs @@ -1,9 +1,11 @@ use crate::config_transition_boundary; +#[must_use] pub fn tls_reloadable_fields() -> Vec { config_transition_boundary().reloadable_fields } +#[must_use] pub fn tls_restart_required_fields() -> Vec { config_transition_boundary().restart_required_fields } diff --git a/crates/rginx-http/src/state/tls_runtime/upstreams.rs b/crates/rginx-http/src/state/tls_runtime/upstreams.rs index 6b9d4840..2834fa6a 100644 --- a/crates/rginx-http/src/state/tls_runtime/upstreams.rs +++ b/crates/rginx-http/src/state/tls_runtime/upstreams.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{ConfigSnapshot, UpstreamTlsStatusSnapshot}; pub(super) fn upstream_tls_status_snapshots( config: &ConfigSnapshot, diff --git a/crates/rginx-http/src/state/traffic/record.rs b/crates/rginx-http/src/state/traffic/record.rs index 6787c953..af917857 100644 --- a/crates/rginx-http/src/state/traffic/record.rs +++ b/crates/rginx-http/src/state/traffic/record.rs @@ -1,4 +1,4 @@ -use super::super::*; +use super::super::{Ordering, SharedState, StatusCode}; impl SharedState { pub(crate) fn record_downstream_request( @@ -40,30 +40,6 @@ impl SharedState { ); } - pub(crate) fn record_mtls_request(&self, listener_id: &str, authenticated: bool) { - if authenticated { - self.observability - .counters - .downstream_mtls_authenticated_requests - .fetch_add(1, Ordering::Relaxed); - } else { - self.observability - .counters - .downstream_mtls_anonymous_requests - .fetch_add(1, Ordering::Relaxed); - } - - if let Some(counters) = self.listener_traffic_counters(listener_id) { - if authenticated { - counters.downstream_mtls_authenticated_requests.fetch_add(1, Ordering::Relaxed); - } else { - counters.downstream_mtls_anonymous_requests.fetch_add(1, Ordering::Relaxed); - } - } - - self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); - } - pub(crate) fn record_downstream_response( &self, listener_id: &str, @@ -131,20 +107,6 @@ impl SharedState { ); } - pub(crate) fn record_route_access_denied(&self, route_id: &str) { - if let Some(counters) = self.route_traffic_counters(route_id) { - counters.access_denied_total.fetch_add(1, Ordering::Relaxed); - } - self.mark_traffic_snapshot_changed(false, false, None, None, Some(route_id)); - } - - pub(crate) fn record_route_rate_limited(&self, route_id: &str) { - if let Some(counters) = self.route_traffic_counters(route_id) { - counters.rate_limited_total.fetch_add(1, Ordering::Relaxed); - } - self.mark_traffic_snapshot_changed(false, false, None, None, Some(route_id)); - } - pub(crate) fn record_grpc_request( &self, listener_id: &str, @@ -199,4 +161,42 @@ impl SharedState { route_id, ); } + + pub(crate) fn record_mtls_request(&self, listener_id: &str, authenticated: bool) { + if authenticated { + self.observability + .counters + .downstream_mtls_authenticated_requests + .fetch_add(1, Ordering::Relaxed); + } else { + self.observability + .counters + .downstream_mtls_anonymous_requests + .fetch_add(1, Ordering::Relaxed); + } + + if let Some(counters) = self.listener_traffic_counters(listener_id) { + if authenticated { + counters.downstream_mtls_authenticated_requests.fetch_add(1, Ordering::Relaxed); + } else { + counters.downstream_mtls_anonymous_requests.fetch_add(1, Ordering::Relaxed); + } + } + + self.mark_traffic_snapshot_changed(true, true, Some(listener_id), None, None); + } + + pub(crate) fn record_route_access_denied(&self, route_id: &str) { + if let Some(counters) = self.route_traffic_counters(route_id) { + counters.access_denied_total.fetch_add(1, Ordering::Relaxed); + } + self.mark_traffic_snapshot_changed(false, false, None, None, Some(route_id)); + } + + pub(crate) fn record_route_rate_limited(&self, route_id: &str) { + if let Some(counters) = self.route_traffic_counters(route_id) { + counters.rate_limited_total.fetch_add(1, Ordering::Relaxed); + } + self.mark_traffic_snapshot_changed(false, false, None, None, Some(route_id)); + } } diff --git a/crates/rginx-http/src/state/traffic/refs.rs b/crates/rginx-http/src/state/traffic/refs.rs index 622cf7bb..78b09165 100644 --- a/crates/rginx-http/src/state/traffic/refs.rs +++ b/crates/rginx-http/src/state/traffic/refs.rs @@ -1,7 +1,7 @@ -use super::super::*; +use super::super::{Arc, ListenerTrafficCounters, RequestTrafficCounters, RouteTrafficCounters}; pub(super) struct TrafficCounterRefs { pub(super) listener: Option>, - pub(super) vhost: Option>, pub(super) route: Option>, + pub(super) vhost: Option>, } diff --git a/crates/rginx-http/src/state/traffic/snapshot.rs b/crates/rginx-http/src/state/traffic/snapshot.rs index 81c7c785..a7af0432 100644 --- a/crates/rginx-http/src/state/traffic/snapshot.rs +++ b/crates/rginx-http/src/state/traffic/snapshot.rs @@ -1,14 +1,65 @@ -use super::super::*; +use super::super::{ + Arc, HttpCountersSnapshot, ListenerStatsSnapshot, ListenerTrafficCounters, Ordering, + RouteStatsSnapshot, RouteTrafficCounters, SharedState, TrafficStatsSnapshot, + VhostStatsSnapshot, http3_runtime_snapshot, +}; impl SharedState { + #[must_use] pub fn counters_snapshot(&self) -> HttpCountersSnapshot { self.observability.counters.snapshot() } + pub(in crate::state) fn listener_traffic_counters( + &self, + listener_id: &str, + ) -> Option> { + let stats = self + .observability + .traffic_stats + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + stats.listeners.get(listener_id).map(|entry| entry.counters.clone()) + } + + pub(super) fn route_traffic_counters( + &self, + route_id: &str, + ) -> Option> { + let stats = self + .observability + .traffic_stats + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + stats.routes.get(route_id).map(|entry| entry.counters.clone()) + } + + pub(super) fn traffic_counter_refs( + &self, + listener_id: &str, + vhost_id: &str, + route_id: Option<&str>, + ) -> super::TrafficCounterRefs { + let stats = self + .observability + .traffic_stats + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + super::TrafficCounterRefs { + listener: stats.listeners.get(listener_id).map(|entry| entry.counters.clone()), + vhost: stats.vhosts.get(vhost_id).map(|entry| entry.counters.clone()), + route: route_id.and_then(|route_id| { + stats.routes.get(route_id).map(|entry| entry.counters.clone()) + }), + } + } + + #[must_use] pub fn traffic_stats_snapshot(&self) -> TrafficStatsSnapshot { self.traffic_stats_snapshot_with_window(None) } + #[must_use] pub fn traffic_stats_snapshot_with_window( &self, window_secs: Option, @@ -17,7 +68,7 @@ impl SharedState { .observability .traffic_stats .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); let listeners = stats .listener_order @@ -40,10 +91,9 @@ impl SharedState { .listener_runtime .active_connections .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()) + .unwrap_or_else(std::sync::PoisonError::into_inner) .get(listener_id) - .map(|connections| connections.load(Ordering::Acquire)) - .unwrap_or(0), + .map_or(0, |connections| connections.load(Ordering::Acquire)), http3_runtime: http3_runtime_snapshot( entry.http3_enabled, Some(entry.counters.as_ref()), @@ -170,50 +220,6 @@ impl SharedState { }) .collect(); - TrafficStatsSnapshot { listeners, vhosts, routes } - } - - pub(in crate::state) fn listener_traffic_counters( - &self, - listener_id: &str, - ) -> Option> { - let stats = self - .observability - .traffic_stats - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - stats.listeners.get(listener_id).map(|entry| entry.counters.clone()) - } - - pub(super) fn route_traffic_counters( - &self, - route_id: &str, - ) -> Option> { - let stats = self - .observability - .traffic_stats - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - stats.routes.get(route_id).map(|entry| entry.counters.clone()) - } - - pub(super) fn traffic_counter_refs( - &self, - listener_id: &str, - vhost_id: &str, - route_id: Option<&str>, - ) -> super::TrafficCounterRefs { - let stats = self - .observability - .traffic_stats - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - super::TrafficCounterRefs { - listener: stats.listeners.get(listener_id).map(|entry| entry.counters.clone()), - vhost: stats.vhosts.get(vhost_id).map(|entry| entry.counters.clone()), - route: route_id.and_then(|route_id| { - stats.routes.get(route_id).map(|entry| entry.counters.clone()) - }), - } + TrafficStatsSnapshot { listeners, routes, vhosts } } } diff --git a/crates/rginx-http/src/state/traffic/versions.rs b/crates/rginx-http/src/state/traffic/versions.rs index 3d918669..16e25952 100644 --- a/crates/rginx-http/src/state/traffic/versions.rs +++ b/crates/rginx-http/src/state/traffic/versions.rs @@ -1,34 +1,9 @@ -use super::super::*; +use super::super::{ + ConfigSnapshot, HashMap, SharedState, build_traffic_component_versions, + build_traffic_stats_index, +}; impl SharedState { - pub(crate) fn sync_traffic_stats(&self, config: &ConfigSnapshot) { - let existing = self - .observability - .traffic_stats - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - let next = build_traffic_stats_index(config, Some(&*existing)); - drop(existing); - *self - .observability - .traffic_stats - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) = next; - - let existing = self - .observability - .traffic_component_versions - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - let next = build_traffic_component_versions(config, Some(&*existing)); - drop(existing); - *self - .observability - .traffic_component_versions - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) = next; - } - pub(crate) fn changed_traffic_targets_since( &self, since_version: u64, @@ -37,7 +12,7 @@ impl SharedState { .observability .traffic_component_versions .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); let mut listeners = changed_ids_since(&versions.listeners, since_version); let mut vhosts = changed_ids_since(&versions.vhosts, since_version); let mut routes = changed_ids_since(&versions.routes, since_version); @@ -47,29 +22,6 @@ impl SharedState { (listeners, vhosts, routes) } - pub(crate) fn mark_traffic_targets_changed( - &self, - version: u64, - listener_id: Option<&str>, - vhost_id: Option<&str>, - route_id: Option<&str>, - ) { - let mut versions = self - .observability - .traffic_component_versions - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - if let Some(listener_id) = listener_id { - versions.listeners.insert(listener_id.to_string(), version); - } - if let Some(vhost_id) = vhost_id { - versions.vhosts.insert(vhost_id.to_string(), version); - } - if let Some(route_id) = route_id { - versions.routes.insert(route_id.to_string(), version); - } - } - pub(crate) fn mark_all_traffic_targets_changed( &self, previous: &ConfigSnapshot, @@ -80,7 +32,7 @@ impl SharedState { .observability .traffic_component_versions .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); for listener in &previous.listeners { versions.listeners.insert(listener.id.clone(), version); @@ -102,6 +54,57 @@ impl SharedState { } } } + + pub(crate) fn mark_traffic_targets_changed( + &self, + version: u64, + listener_id: Option<&str>, + vhost_id: Option<&str>, + route_id: Option<&str>, + ) { + let mut versions = self + .observability + .traffic_component_versions + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if let Some(listener_id) = listener_id { + versions.listeners.insert(listener_id.to_string(), version); + } + if let Some(vhost_id) = vhost_id { + versions.vhosts.insert(vhost_id.to_string(), version); + } + if let Some(route_id) = route_id { + versions.routes.insert(route_id.to_string(), version); + } + } + + pub(crate) fn sync_traffic_stats(&self, config: &ConfigSnapshot) { + let existing = self + .observability + .traffic_stats + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let next = build_traffic_stats_index(config, Some(&*existing)); + drop(existing); + *self + .observability + .traffic_stats + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) = next; + + let existing = self + .observability + .traffic_component_versions + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let next = build_traffic_component_versions(config, Some(&*existing)); + drop(existing); + *self + .observability + .traffic_component_versions + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) = next; + } } fn changed_ids_since(versions: &HashMap, since_version: u64) -> Vec { diff --git a/crates/rginx-http/src/state/upstreams/record.rs b/crates/rginx-http/src/state/upstreams/record.rs index 48459c9a..3c3be35c 100644 --- a/crates/rginx-http/src/state/upstreams/record.rs +++ b/crates/rginx-http/src/state/upstreams/record.rs @@ -1,63 +1,20 @@ -use super::super::*; +use super::super::{Ordering, SharedState}; impl SharedState { - pub(crate) fn record_upstream_request(&self, upstream_name: &str, request_bytes: u64) { + pub(crate) fn record_upstream_bad_gateway_response(&self, upstream_name: &str) { let Some(counters) = self.upstream_stats_counters(upstream_name) else { return; }; - counters.downstream_requests_total.fetch_add(1, Ordering::Relaxed); - counters.downstream_request_bytes_total.fetch_add(request_bytes, Ordering::Relaxed); - counters.recent_60s.record_downstream_request(request_bytes); - self.mark_upstream_snapshot_changed(upstream_name); - } - - pub(crate) fn record_upstream_peer_attempt(&self, upstream_name: &str, peer_url: &str) { - let Some((counters, peer)) = self.upstream_stats_peer_counters(upstream_name, peer_url) - else { - return; - }; - counters.peer_attempts_total.fetch_add(1, Ordering::Relaxed); - peer.attempts_total.fetch_add(1, Ordering::Relaxed); - counters.recent_60s.record_peer_attempt(); - self.mark_upstream_snapshot_changed(upstream_name); - } - - pub(crate) fn record_upstream_peer_success(&self, upstream_name: &str, peer_url: &str) { - let Some((counters, peer)) = self.upstream_stats_peer_counters(upstream_name, peer_url) - else { - return; - }; - counters.peer_successes_total.fetch_add(1, Ordering::Relaxed); - peer.successes_total.fetch_add(1, Ordering::Relaxed); - self.mark_upstream_snapshot_changed(upstream_name); - } - - pub(crate) fn record_upstream_peer_failure(&self, upstream_name: &str, peer_url: &str) { - let Some((counters, peer)) = self.upstream_stats_peer_counters(upstream_name, peer_url) - else { - return; - }; - counters.peer_failures_total.fetch_add(1, Ordering::Relaxed); - peer.failures_total.fetch_add(1, Ordering::Relaxed); - self.mark_upstream_snapshot_changed(upstream_name); - } - - pub(crate) fn record_upstream_peer_timeout(&self, upstream_name: &str, peer_url: &str) { - let Some((counters, peer)) = self.upstream_stats_peer_counters(upstream_name, peer_url) - else { - return; - }; - counters.peer_timeouts_total.fetch_add(1, Ordering::Relaxed); - peer.timeouts_total.fetch_add(1, Ordering::Relaxed); + counters.bad_gateway_responses_total.fetch_add(1, Ordering::Relaxed); + counters.recent_60s.record_bad_gateway_response(); self.mark_upstream_snapshot_changed(upstream_name); } - pub(crate) fn record_upstream_failover(&self, upstream_name: &str) { + pub(crate) fn record_upstream_bad_request_response(&self, upstream_name: &str) { let Some(counters) = self.upstream_stats_counters(upstream_name) else { return; }; - counters.failovers_total.fetch_add(1, Ordering::Relaxed); - counters.recent_60s.record_failover(); + counters.bad_request_responses_total.fetch_add(1, Ordering::Relaxed); self.mark_upstream_snapshot_changed(upstream_name); } @@ -70,21 +27,12 @@ impl SharedState { self.mark_upstream_snapshot_changed(upstream_name); } - pub(crate) fn record_upstream_response_bytes(&self, upstream_name: &str, response_bytes: u64) { - let Some(counters) = self.upstream_stats_counters(upstream_name) else { - return; - }; - counters.downstream_response_bytes_total.fetch_add(response_bytes, Ordering::Relaxed); - counters.recent_60s.record_downstream_response_bytes(response_bytes); - self.mark_upstream_snapshot_changed(upstream_name); - } - - pub(crate) fn record_upstream_bad_gateway_response(&self, upstream_name: &str) { + pub(crate) fn record_upstream_failover(&self, upstream_name: &str) { let Some(counters) = self.upstream_stats_counters(upstream_name) else { return; }; - counters.bad_gateway_responses_total.fetch_add(1, Ordering::Relaxed); - counters.recent_60s.record_bad_gateway_response(); + counters.failovers_total.fetch_add(1, Ordering::Relaxed); + counters.recent_60s.record_failover(); self.mark_upstream_snapshot_changed(upstream_name); } @@ -97,11 +45,11 @@ impl SharedState { self.mark_upstream_snapshot_changed(upstream_name); } - pub(crate) fn record_upstream_bad_request_response(&self, upstream_name: &str) { + pub(crate) fn record_upstream_no_healthy_peers(&self, upstream_name: &str) { let Some(counters) = self.upstream_stats_counters(upstream_name) else { return; }; - counters.bad_request_responses_total.fetch_add(1, Ordering::Relaxed); + counters.no_healthy_peers_total.fetch_add(1, Ordering::Relaxed); self.mark_upstream_snapshot_changed(upstream_name); } @@ -113,19 +61,24 @@ impl SharedState { self.mark_upstream_snapshot_changed(upstream_name); } - pub(crate) fn record_upstream_unsupported_media_type_response(&self, upstream_name: &str) { - let Some(counters) = self.upstream_stats_counters(upstream_name) else { + pub(crate) fn record_upstream_peer_attempt(&self, upstream_name: &str, peer_url: &str) { + let Some((counters, peer)) = self.upstream_stats_peer_counters(upstream_name, peer_url) + else { return; }; - counters.unsupported_media_type_responses_total.fetch_add(1, Ordering::Relaxed); + counters.peer_attempts_total.fetch_add(1, Ordering::Relaxed); + peer.attempts_total.fetch_add(1, Ordering::Relaxed); + counters.recent_60s.record_peer_attempt(); self.mark_upstream_snapshot_changed(upstream_name); } - pub(crate) fn record_upstream_no_healthy_peers(&self, upstream_name: &str) { - let Some(counters) = self.upstream_stats_counters(upstream_name) else { + pub(crate) fn record_upstream_peer_failure(&self, upstream_name: &str, peer_url: &str) { + let Some((counters, peer)) = self.upstream_stats_peer_counters(upstream_name, peer_url) + else { return; }; - counters.no_healthy_peers_total.fetch_add(1, Ordering::Relaxed); + counters.peer_failures_total.fetch_add(1, Ordering::Relaxed); + peer.failures_total.fetch_add(1, Ordering::Relaxed); self.mark_upstream_snapshot_changed(upstream_name); } @@ -154,4 +107,51 @@ impl SharedState { } self.mark_upstream_snapshot_changed(upstream_name); } + + pub(crate) fn record_upstream_peer_success(&self, upstream_name: &str, peer_url: &str) { + let Some((counters, peer)) = self.upstream_stats_peer_counters(upstream_name, peer_url) + else { + return; + }; + counters.peer_successes_total.fetch_add(1, Ordering::Relaxed); + peer.successes_total.fetch_add(1, Ordering::Relaxed); + self.mark_upstream_snapshot_changed(upstream_name); + } + + pub(crate) fn record_upstream_peer_timeout(&self, upstream_name: &str, peer_url: &str) { + let Some((counters, peer)) = self.upstream_stats_peer_counters(upstream_name, peer_url) + else { + return; + }; + counters.peer_timeouts_total.fetch_add(1, Ordering::Relaxed); + peer.timeouts_total.fetch_add(1, Ordering::Relaxed); + self.mark_upstream_snapshot_changed(upstream_name); + } + + pub(crate) fn record_upstream_request(&self, upstream_name: &str, request_bytes: u64) { + let Some(counters) = self.upstream_stats_counters(upstream_name) else { + return; + }; + counters.downstream_requests_total.fetch_add(1, Ordering::Relaxed); + counters.downstream_request_bytes_total.fetch_add(request_bytes, Ordering::Relaxed); + counters.recent_60s.record_downstream_request(request_bytes); + self.mark_upstream_snapshot_changed(upstream_name); + } + + pub(crate) fn record_upstream_response_bytes(&self, upstream_name: &str, response_bytes: u64) { + let Some(counters) = self.upstream_stats_counters(upstream_name) else { + return; + }; + counters.downstream_response_bytes_total.fetch_add(response_bytes, Ordering::Relaxed); + counters.recent_60s.record_downstream_response_bytes(response_bytes); + self.mark_upstream_snapshot_changed(upstream_name); + } + + pub(crate) fn record_upstream_unsupported_media_type_response(&self, upstream_name: &str) { + let Some(counters) = self.upstream_stats_counters(upstream_name) else { + return; + }; + counters.unsupported_media_type_responses_total.fetch_add(1, Ordering::Relaxed); + self.mark_upstream_snapshot_changed(upstream_name); + } } diff --git a/crates/rginx-http/src/state/upstreams/snapshot.rs b/crates/rginx-http/src/state/upstreams/snapshot.rs index b8fb7d42..e434e4f2 100644 --- a/crates/rginx-http/src/state/upstreams/snapshot.rs +++ b/crates/rginx-http/src/state/upstreams/snapshot.rs @@ -1,15 +1,120 @@ use super::super::tls_runtime::upstream_tls_status_snapshot; -use super::super::*; +use super::super::{ + Arc, ConfigSnapshot, Ordering, SharedState, UpstreamHealthSnapshot, UpstreamPeerStats, + UpstreamPeerStatsSnapshot, UpstreamStats, UpstreamStatsSnapshot, build_upstream_name_versions, + build_upstream_stats_map, +}; impl SharedState { + pub(crate) fn mark_all_upstream_targets_changed( + &self, + previous: &ConfigSnapshot, + next: &ConfigSnapshot, + version: u64, + ) { + let mut upstream_versions = self + .observability + .upstream_component_versions + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + for name in previous.upstreams.keys() { + upstream_versions.insert(name.clone(), version); + } + for name in next.upstreams.keys() { + upstream_versions.insert(name.clone(), version); + } + + let mut peer_health_versions = self + .observability + .peer_health_component_versions + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner); + for name in previous.upstreams.keys() { + peer_health_versions.insert(name.clone(), version); + } + for name in next.upstreams.keys() { + peer_health_versions.insert(name.clone(), version); + } + } pub async fn peer_health_snapshot(&self) -> Vec { self.inner.read().await.clients.peer_health_snapshot().await } + pub(crate) fn sync_peer_health_versions(&self, config: &ConfigSnapshot) { + let existing = self + .observability + .peer_health_component_versions + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let next = build_upstream_name_versions(config, Some(&*existing)); + drop(existing); + *self + .observability + .peer_health_component_versions + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) = next; + } + + pub(crate) fn sync_upstream_stats(&self, config: &ConfigSnapshot) { + let existing = self + .observability + .upstream_stats + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let next = build_upstream_stats_map(config, Some(&*existing)); + drop(existing); + *self + .observability + .upstream_stats + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) = next; + let existing = self + .observability + .upstream_component_versions + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let next = build_upstream_name_versions(config, Some(&*existing)); + drop(existing); + *self + .observability + .upstream_component_versions + .write() + .unwrap_or_else(std::sync::PoisonError::into_inner) = next; + } + + pub(super) fn upstream_stats_counters( + &self, + upstream_name: &str, + ) -> Option> { + let stats = self + .observability + .upstream_stats + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + stats.get(upstream_name).map(|entry| entry.counters.clone()) + } + + pub(super) fn upstream_stats_peer_counters( + &self, + upstream_name: &str, + peer_url: &str, + ) -> Option<(Arc, Arc)> { + let stats = self + .observability + .upstream_stats + .read() + .unwrap_or_else(std::sync::PoisonError::into_inner); + let entry = stats.get(upstream_name)?; + let peer = entry.peers.get(peer_url)?.clone(); + Some((entry.counters.clone(), peer)) + } + + #[must_use] pub fn upstream_stats_snapshot(&self) -> Vec { self.upstream_stats_snapshot_with_window(None) } + #[must_use] pub fn upstream_stats_snapshot_with_window( &self, window_secs: Option, @@ -18,7 +123,7 @@ impl SharedState { .observability .upstream_stats .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); + .unwrap_or_else(std::sync::PoisonError::into_inner); let mut upstream_names = stats.keys().cloned().collect::>(); upstream_names.sort(); @@ -117,104 +222,4 @@ impl SharedState { }) .collect() } - - pub(crate) fn sync_upstream_stats(&self, config: &ConfigSnapshot) { - let existing = self - .observability - .upstream_stats - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - let next = build_upstream_stats_map(config, Some(&*existing)); - drop(existing); - *self - .observability - .upstream_stats - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) = next; - let existing = self - .observability - .upstream_component_versions - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - let next = build_upstream_name_versions(config, Some(&*existing)); - drop(existing); - *self - .observability - .upstream_component_versions - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) = next; - } - - pub(crate) fn sync_peer_health_versions(&self, config: &ConfigSnapshot) { - let existing = self - .observability - .peer_health_component_versions - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - let next = build_upstream_name_versions(config, Some(&*existing)); - drop(existing); - *self - .observability - .peer_health_component_versions - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()) = next; - } - - pub(super) fn upstream_stats_counters( - &self, - upstream_name: &str, - ) -> Option> { - let stats = self - .observability - .upstream_stats - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - stats.get(upstream_name).map(|entry| entry.counters.clone()) - } - - pub(super) fn upstream_stats_peer_counters( - &self, - upstream_name: &str, - peer_url: &str, - ) -> Option<(Arc, Arc)> { - let stats = self - .observability - .upstream_stats - .read() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - let entry = stats.get(upstream_name)?; - let peer = entry.peers.get(peer_url)?.clone(); - Some((entry.counters.clone(), peer)) - } - - pub(crate) fn mark_all_upstream_targets_changed( - &self, - previous: &ConfigSnapshot, - next: &ConfigSnapshot, - version: u64, - ) { - let mut upstream_versions = self - .observability - .upstream_component_versions - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - for name in previous.upstreams.keys() { - upstream_versions.insert(name.clone(), version); - } - for name in next.upstreams.keys() { - upstream_versions.insert(name.clone(), version); - } - - let mut peer_health_versions = self - .observability - .peer_health_component_versions - .write() - .unwrap_or_else(|poisoned| poisoned.into_inner()); - for name in previous.upstreams.keys() { - peer_health_versions.insert(name.clone(), version); - } - for name in next.upstreams.keys() { - peer_health_versions.insert(name.clone(), version); - } - } } diff --git a/crates/rginx-http/src/timeout/body.rs b/crates/rginx-http/src/timeout/body.rs index 59b12122..8e8b0066 100644 --- a/crates/rginx-http/src/timeout/body.rs +++ b/crates/rginx-http/src/timeout/body.rs @@ -1,3 +1,7 @@ +mod grpc_deadline; +mod idle; +mod max_bytes; + use std::future::Future; use std::pin::Pin; use std::time::Duration; @@ -13,10 +17,6 @@ use crate::handler::BoxError; use super::timers::{poll_idle_timer, reset_idle_timer}; -mod grpc_deadline; -mod idle; -mod max_bytes; - pub(crate) use grpc_deadline::GrpcDeadlineBody; pub(crate) use idle::IdleTimeoutBody; pub(crate) use max_bytes::{MaxBytesBody, RequestBodyLimitError}; diff --git a/crates/rginx-http/src/timeout/body/grpc_deadline.rs b/crates/rginx-http/src/timeout/body/grpc_deadline.rs index ba684ffa..fe59df0d 100644 --- a/crates/rginx-http/src/timeout/body/grpc_deadline.rs +++ b/crates/rginx-http/src/timeout/body/grpc_deadline.rs @@ -1,4 +1,7 @@ -use super::*; +use super::{ + Body, BoxError, Bytes, Duration, Frame, Future, Instant, Pin, SizeHint, Sleep, + grpc_deadline_exceeded_trailers, pin_project, +}; pin_project! { #[derive(Debug)] @@ -40,6 +43,10 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -84,10 +91,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } diff --git a/crates/rginx-http/src/timeout/body/idle.rs b/crates/rginx-http/src/timeout/body/idle.rs index 44ea53c3..fa4ae206 100644 --- a/crates/rginx-http/src/timeout/body/idle.rs +++ b/crates/rginx-http/src/timeout/body/idle.rs @@ -1,4 +1,7 @@ -use super::*; +use super::{ + Body, BoxError, Bytes, Duration, Frame, Pin, SizeHint, Sleep, pin_project, poll_idle_timer, + reset_idle_timer, +}; pin_project! { #[derive(Debug)] @@ -26,6 +29,10 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done + } + fn poll_frame( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -64,10 +71,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.done - } - fn size_hint(&self) -> SizeHint { self.inner.size_hint() } diff --git a/crates/rginx-http/src/timeout/body/max_bytes.rs b/crates/rginx-http/src/timeout/body/max_bytes.rs index 0400c06b..871a9268 100644 --- a/crates/rginx-http/src/timeout/body/max_bytes.rs +++ b/crates/rginx-http/src/timeout/body/max_bytes.rs @@ -1,4 +1,15 @@ -use super::*; +use super::{Body, BoxError, Bytes, Frame, Pin, SizeHint, pin_project}; + +pin_project! { + #[derive(Debug)] + pub struct MaxBytesBody { + #[pin] + inner: B, + max_request_body_bytes: usize, + bytes_read: usize, + done: bool, + } +} #[derive(Debug)] pub struct RequestBodyLimitError { @@ -6,13 +17,12 @@ pub struct RequestBodyLimitError { } impl RequestBodyLimitError { - pub fn new(max_request_body_bytes: usize) -> Self { - Self { max_request_body_bytes } - } - pub fn max_request_body_bytes(&self) -> usize { self.max_request_body_bytes } + pub fn new(max_request_body_bytes: usize) -> Self { + Self { max_request_body_bytes } + } } impl std::fmt::Display for RequestBodyLimitError { @@ -27,17 +37,6 @@ impl std::fmt::Display for RequestBodyLimitError { impl std::error::Error for RequestBodyLimitError {} -pin_project! { - #[derive(Debug)] - pub struct MaxBytesBody { - #[pin] - inner: B, - max_request_body_bytes: usize, - bytes_read: usize, - done: bool, - } -} - impl MaxBytesBody { pub fn new(inner: B, max_request_body_bytes: usize) -> Self { Self { inner, max_request_body_bytes, bytes_read: 0, done: false } @@ -52,6 +51,10 @@ where type Data = Bytes; type Error = BoxError; + fn is_end_stream(&self) -> bool { + self.done || self.inner.is_end_stream() + } + fn poll_frame( self: Pin<&mut Self>, cx: &mut std::task::Context<'_>, @@ -72,7 +75,7 @@ where )))); } - *this.bytes_read += data.len(); + *this.bytes_read = (*this.bytes_read).saturating_add(data.len()); std::task::Poll::Ready(Some(Ok(Frame::data(data)))) } Err(frame) => match frame.into_trailers() { @@ -95,10 +98,6 @@ where } } - fn is_end_stream(&self) -> bool { - self.done || self.inner.is_end_stream() - } - fn size_hint(&self) -> SizeHint { let remaining = self.max_request_body_bytes.saturating_sub(self.bytes_read) as u64; let inner = self.inner.size_hint(); diff --git a/crates/rginx-http/src/timeout/io.rs b/crates/rginx-http/src/timeout/io.rs index 2a7650f0..cdcaaefe 100644 --- a/crates/rginx-http/src/timeout/io.rs +++ b/crates/rginx-http/src/timeout/io.rs @@ -44,14 +44,8 @@ impl AsyncWrite for WriteTimeoutIo where T: AsyncWrite, { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let mut this = self.project(); - let result = this.inner.as_mut().poll_write(cx, buf); - poll_write_side(cx, result, *this.timeout, this.sleep, this.label.as_str(), "write") + fn is_write_vectored(&self) -> bool { + self.inner.is_write_vectored() } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -66,6 +60,16 @@ where poll_write_side(cx, result, *this.timeout, this.sleep, this.label.as_str(), "shutdown") } + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let mut this = self.project(); + let result = this.inner.as_mut().poll_write(cx, buf); + poll_write_side(cx, result, *this.timeout, this.sleep, this.label.as_str(), "write") + } + fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -75,8 +79,4 @@ where let result = this.inner.as_mut().poll_write_vectored(cx, bufs); poll_write_side(cx, result, *this.timeout, this.sleep, this.label.as_str(), "write") } - - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } } diff --git a/crates/rginx-http/src/timeout/tests.rs b/crates/rginx-http/src/timeout/tests.rs index 2858b82b..44732bed 100644 --- a/crates/rginx-http/src/timeout/tests.rs +++ b/crates/rginx-http/src/timeout/tests.rs @@ -1,3 +1,8 @@ +mod grpc_deadline; +mod idle; +mod max_bytes; +mod write_timeout; + use std::future::{Future, poll_fn}; use std::io; use std::io::IoSlice; @@ -15,11 +20,6 @@ use tokio::time::{Instant, Sleep}; use super::{GrpcDeadlineBody, IdleTimeoutBody, MaxBytesBody, WriteTimeoutIo}; -mod grpc_deadline; -mod idle; -mod max_bytes; -mod write_timeout; - pin_project! { struct DelayedFrameBody { #[pin] @@ -28,6 +28,24 @@ pin_project! { } } +pin_project! { + struct TwoStageBody { + #[pin] + first_delay: Sleep, + #[pin] + second_delay: Sleep, + state: u8, + } +} + +pin_project! { + struct DelayedWriter { + #[pin] + delay: Sleep, + emitted: bool, + } +} + impl DelayedFrameBody { fn new(delay: Duration) -> Self { Self { delay: tokio::time::sleep(delay), emitted: false } @@ -38,6 +56,10 @@ impl Body for DelayedFrameBody { type Data = Bytes; type Error = io::Error; + fn is_end_stream(&self) -> bool { + self.emitted + } + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -57,10 +79,6 @@ impl Body for DelayedFrameBody { } } - fn is_end_stream(&self) -> bool { - self.emitted - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } @@ -80,6 +98,10 @@ impl Body for EarlyEndTrailersBody { type Data = Bytes; type Error = io::Error; + fn is_end_stream(&self) -> bool { + self.state >= 1 + } + fn poll_frame( self: Pin<&mut Self>, _cx: &mut Context<'_>, @@ -101,25 +123,11 @@ impl Body for EarlyEndTrailersBody { } } - fn is_end_stream(&self) -> bool { - self.state >= 1 - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } } -pin_project! { - struct TwoStageBody { - #[pin] - first_delay: Sleep, - #[pin] - second_delay: Sleep, - state: u8, - } -} - impl TwoStageBody { fn new(first_delay: Duration, second_delay: Duration) -> Self { Self { @@ -134,6 +142,10 @@ impl Body for TwoStageBody { type Data = Bytes; type Error = io::Error; + fn is_end_stream(&self) -> bool { + self.state >= 2 + } + fn poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -161,23 +173,11 @@ impl Body for TwoStageBody { } } - fn is_end_stream(&self) -> bool { - self.state >= 2 - } - fn size_hint(&self) -> SizeHint { SizeHint::default() } } -pin_project! { - struct DelayedWriter { - #[pin] - delay: Sleep, - emitted: bool, - } -} - impl DelayedWriter { fn new(delay: Duration) -> Self { Self { delay: tokio::time::sleep(delay), emitted: false } @@ -195,6 +195,18 @@ impl AsyncRead for DelayedWriter { } impl AsyncWrite for DelayedWriter { + fn is_write_vectored(&self) -> bool { + true + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -215,14 +227,6 @@ impl AsyncWrite for DelayedWriter { } } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -231,8 +235,4 @@ impl AsyncWrite for DelayedWriter { let total = bufs.iter().map(|buf| buf.len()).sum(); self.poll_write(cx, &vec![0u8; total]) } - - fn is_write_vectored(&self) -> bool { - true - } } diff --git a/crates/rginx-http/src/timeout/timers.rs b/crates/rginx-http/src/timeout/timers.rs index 008850e2..3a52457c 100644 --- a/crates/rginx-http/src/timeout/timers.rs +++ b/crates/rginx-http/src/timeout/timers.rs @@ -8,7 +8,9 @@ use tokio::time::Sleep; use crate::handler::BoxError; pub(super) fn reset_idle_timer(sleep: &mut Option>>, timeout: Duration) { - let deadline = tokio::time::Instant::now() + timeout; + let deadline = tokio::time::Instant::now() + .checked_add(timeout) + .expect("idle timeout deadline remains representable"); match sleep { Some(sleep) => sleep.as_mut().reset(deadline), None => *sleep = Some(Box::pin(tokio::time::sleep_until(deadline))), @@ -17,7 +19,11 @@ pub(super) fn reset_idle_timer(sleep: &mut Option>>, timeout: Dur fn arm_idle_timer(sleep: &mut Option>>, timeout: Duration) { if sleep.is_none() { - *sleep = Some(Box::pin(tokio::time::sleep_until(tokio::time::Instant::now() + timeout))); + *sleep = Some(Box::pin(tokio::time::sleep_until( + tokio::time::Instant::now() + .checked_add(timeout) + .expect("idle timeout deadline remains representable"), + ))); } } diff --git a/crates/rginx-http/src/tls/acceptor.rs b/crates/rginx-http/src/tls/acceptor.rs index 66551ef3..eeac01fe 100644 --- a/crates/rginx-http/src/tls/acceptor.rs +++ b/crates/rginx-http/src/tls/acceptor.rs @@ -17,8 +17,8 @@ use super::sni::{SniCertificateResolver, register_server_name_certificates}; struct TlsServerConfigOptions { alpn_protocols: Vec, - http3_only: bool, http3_early_data_enabled: bool, + http3_only: bool, } /// 构建支持 SNI 的 TLS acceptor diff --git a/crates/rginx-http/src/tls/certificates.rs b/crates/rginx-http/src/tls/certificates.rs index ebc99b07..986bdf88 100644 --- a/crates/rginx-http/src/tls/certificates.rs +++ b/crates/rginx-http/src/tls/certificates.rs @@ -1,8 +1,11 @@ +#[cfg(test)] +mod tests; use std::path::Path; use std::sync::Arc; use rginx_core::{Error, Result, ServerCertificateBundle, ServerTls, VirtualHostTls}; use rustls::RootCertStore; +use rustls::crypto::aws_lc_rs::sign::any_supported_type; use rustls::pki_types::{ CertificateDer, CertificateRevocationListDer, PrivateKeyDer, pem::{Error as PemError, PemObject}, @@ -39,7 +42,7 @@ fn load_certified_keys_from_material( ocsp_staple_path: Option<&std::path::PathBuf>, ocsp: &rginx_core::OcspConfig, ) -> Result>> { - let mut bundles = Vec::with_capacity(1 + additional_certificates.len()); + let mut bundles = Vec::with_capacity(additional_certificates.len().saturating_add(1)); bundles.push(ServerCertificateBundle { cert_path: cert_path.to_path_buf(), key_path: key_path.to_path_buf(), @@ -59,7 +62,7 @@ pub(crate) fn load_certified_key_bundle( let mut certified_key = rustls::sign::CertifiedKey::new( certs, - rustls::crypto::aws_lc_rs::sign::any_supported_type(&key).map_err(|_| { + any_supported_type(&key).map_err(|_| { Error::Server(format!( "server TLS private key file `{}` uses unsupported algorithm", bundle.key_path.display() @@ -189,6 +192,3 @@ fn map_pem_error(path: &Path, item: &str, error: PemError) -> Error { } } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-http/src/tls/certificates/tests.rs b/crates/rginx-http/src/tls/certificates/tests.rs index 460e314d..596de984 100644 --- a/crates/rginx-http/src/tls/certificates/tests.rs +++ b/crates/rginx-http/src/tls/certificates/tests.rs @@ -12,8 +12,8 @@ use rcgen::{ use super::load_certificate_revocation_lists; struct TestCertifiedKey { - signing_key: KeyPair, params: CertificateParams, + signing_key: KeyPair, } impl TestCertifiedKey { diff --git a/crates/rginx-http/src/tls/client_auth.rs b/crates/rginx-http/src/tls/client_auth.rs index c087cd15..5a0c2dd8 100644 --- a/crates/rginx-http/src/tls/client_auth.rs +++ b/crates/rginx-http/src/tls/client_auth.rs @@ -17,18 +17,22 @@ impl DepthLimitedClientVerifier { } impl ClientCertVerifier for DepthLimitedClientVerifier { - fn offer_client_auth(&self) -> bool { - self.inner.offer_client_auth() - } - fn client_auth_mandatory(&self) -> bool { self.inner.client_auth_mandatory() } + fn offer_client_auth(&self) -> bool { + self.inner.offer_client_auth() + } + fn root_hint_subjects(&self) -> &[rustls::DistinguishedName] { self.inner.root_hint_subjects() } + fn supported_verify_schemes(&self) -> Vec { + self.inner.supported_verify_schemes() + } + fn verify_client_cert( &self, end_entity: &CertificateDer<'_>, @@ -64,8 +68,4 @@ impl ClientCertVerifier for DepthLimitedClientVerifier { ) -> std::result::Result { self.inner.verify_tls13_signature(message, cert, dss) } - - fn supported_verify_schemes(&self) -> Vec { - self.inner.supported_verify_schemes() - } } diff --git a/crates/rginx-http/src/tls/mod.rs b/crates/rginx-http/src/tls/mod.rs index 50c255ab..7e990680 100644 --- a/crates/rginx-http/src/tls/mod.rs +++ b/crates/rginx-http/src/tls/mod.rs @@ -1,5 +1,3 @@ -use rginx_core::{OcspNonceMode, OcspResponderPolicy, Result}; - pub(crate) mod certificates; pub(crate) mod ocsp; @@ -11,6 +9,8 @@ mod sni; #[cfg(test)] mod tests; +use rginx_core::{OcspNonceMode, OcspResponderPolicy, Result}; + pub use acceptor::build_http3_server_config; pub use acceptor::build_tls_acceptor; #[cfg(test)] diff --git a/crates/rginx-http/src/tls/ocsp/der_helpers.rs b/crates/rginx-http/src/tls/ocsp/der_helpers.rs index 2834cc4e..74905a9a 100644 --- a/crates/rginx-http/src/tls/ocsp/der_helpers.rs +++ b/crates/rginx-http/src/tls/ocsp/der_helpers.rs @@ -1,4 +1,7 @@ -use super::*; +use super::{ + AlgorithmIdentifier, CertificateDer, Digest, Error, ExtKeyUsageSyntax, ID_SHA1, + ObjectIdentifier, OctetString, Path, RasnCertId, RasnCertificate, RasnKeyUsage, Result, Sha1, +}; pub(super) fn basic_ocsp_response_type_oid() -> ObjectIdentifier { ObjectIdentifier::new(vec![1, 3, 6, 1, 5, 5, 7, 48, 1, 1]) @@ -98,7 +101,7 @@ pub(super) fn certificate_key_usage(cert: &RasnCertificate) -> Result bool { - value.get(index).map(|bit| *bit).unwrap_or(false) + value.get(index).is_some_and(|bit| *bit) } pub(super) fn hex_string(bytes: &[u8]) -> String { diff --git a/crates/rginx-http/src/tls/ocsp/discover.rs b/crates/rginx-http/src/tls/ocsp/discover.rs index 281ce3ec..5a8b38f6 100644 --- a/crates/rginx-http/src/tls/ocsp/discover.rs +++ b/crates/rginx-http/src/tls/ocsp/discover.rs @@ -1,4 +1,7 @@ -use super::*; +use super::{ + AuthorityInfoAccessSyntax, Error, Path, RasnCertificate, RasnGeneralName, Result, + load_certificate_chain_from_path, +}; pub(crate) fn ocsp_responder_urls_for_certificate(path: &Path) -> Result> { let certs = load_certificate_chain_from_path(path)?; diff --git a/crates/rginx-http/src/tls/ocsp/mod.rs b/crates/rginx-http/src/tls/ocsp/mod.rs index ae62ce49..099917ac 100644 --- a/crates/rginx-http/src/tls/ocsp/mod.rs +++ b/crates/rginx-http/src/tls/ocsp/mod.rs @@ -1,5 +1,15 @@ //! OCSP responder discovery, request construction, and response validation. +mod der_helpers; +mod discover; +mod nonce; +mod request; +mod signer; +#[cfg(test)] +mod tests; +mod time; +mod validate; + use std::path::Path; use std::time::SystemTime; @@ -26,16 +36,6 @@ use super::certificates::load_certificate_chain_from_path; #[cfg(test)] use super::certificates::load_certified_key_bundle; -mod der_helpers; -mod discover; -mod nonce; -mod request; -mod signer; -#[cfg(test)] -mod tests; -mod time; -mod validate; - pub(crate) use discover::ocsp_responder_urls_for_certificate; pub(crate) use request::{ build_ocsp_request_for_certificate, build_ocsp_request_for_certificate_with_options, diff --git a/crates/rginx-http/src/tls/ocsp/nonce.rs b/crates/rginx-http/src/tls/ocsp/nonce.rs index cc72fa4e..26ef4caa 100644 --- a/crates/rginx-http/src/tls/ocsp/nonce.rs +++ b/crates/rginx-http/src/tls/ocsp/nonce.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{Error, ObjectIdentifier, OcspNonceMode, OctetString, Path, RasnExtension, Result}; pub(super) fn ocsp_nonce_oid() -> ObjectIdentifier { ObjectIdentifier::new(vec![1, 3, 6, 1, 5, 5, 7, 48, 1, 2]) diff --git a/crates/rginx-http/src/tls/ocsp/request.rs b/crates/rginx-http/src/tls/ocsp/request.rs index 7d9856cb..9c6226e7 100644 --- a/crates/rginx-http/src/tls/ocsp/request.rs +++ b/crates/rginx-http/src/tls/ocsp/request.rs @@ -1,4 +1,8 @@ -use super::*; +use super::{ + CertificateDer, Error, Integer, OcspNonceMode, Path, RasnOcspRequest, Request, Result, + TbsRequest, build_ocsp_nonce_extension, build_rasn_ocsp_cert_id_from_chain, + build_request_nonce, load_certificate_chain_from_path, +}; pub(crate) fn build_ocsp_request_for_certificate(path: &Path) -> Result> { build_ocsp_request_for_certificate_with_options(path, OcspNonceMode::Disabled) diff --git a/crates/rginx-http/src/tls/ocsp/signer.rs b/crates/rginx-http/src/tls/ocsp/signer.rs index 107f2d0c..a045c776 100644 --- a/crates/rginx-http/src/tls/ocsp/signer.rs +++ b/crates/rginx-http/src/tls/ocsp/signer.rs @@ -1,4 +1,10 @@ -use super::*; +use super::{ + ALL_VERIFICATION_ALGS, CertificateDer, Digest, EndEntityCert, Error, OcspResponderPolicy, Path, + RasnBasicOcspResponse, RasnCertificate, RasnResponderId, Result, Sha1, SystemTime, + algorithm_identifier_value_bytes, bit_string_flag, certificate_extended_key_usage, + certificate_key_usage, certificate_valid_now, hex_string, signature_bytes, + subject_public_key_bytes, +}; pub(super) fn validate_basic_ocsp_response_signature( path: &Path, diff --git a/crates/rginx-http/src/tls/ocsp/tests.rs b/crates/rginx-http/src/tls/ocsp/tests.rs index 5df20da1..1d6438a9 100644 --- a/crates/rginx-http/src/tls/ocsp/tests.rs +++ b/crates/rginx-http/src/tls/ocsp/tests.rs @@ -1,3 +1,8 @@ +mod discovery; +mod nonce; +mod support; +mod validation; + use std::env; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicU64, Ordering}; @@ -20,9 +25,4 @@ use rginx_core::ServerCertificateBundle; use super::*; -mod discovery; -mod nonce; -mod support; -mod validation; - pub(crate) use support::*; diff --git a/crates/rginx-http/src/tls/ocsp/tests/discovery.rs b/crates/rginx-http/src/tls/ocsp/tests/discovery.rs index 23f0391b..2219ef1a 100644 --- a/crates/rginx-http/src/tls/ocsp/tests/discovery.rs +++ b/crates/rginx-http/src/tls/ocsp/tests/discovery.rs @@ -1,20 +1,5 @@ use super::*; -#[test] -fn ocsp_responder_urls_for_certificate_extracts_aia_ocsp_uri() { - let temp_dir = temp_dir("rginx-ocsp-aia-responder-url"); - std::fs::create_dir_all(&temp_dir).expect("temp dir should exist"); - - let ca = generate_ca_cert("ocsp-test-ca"); - let leaf = generate_leaf_cert_with_ocsp_aia("localhost", &ca, "http://127.0.0.1:19090/ocsp"); - let cert_path = write_cert_chain(&temp_dir, "server", &leaf, &ca); - - let urls = ocsp_responder_urls_for_certificate(&cert_path) - .expect("AIA OCSP responder discovery should succeed"); - assert_eq!(urls, vec!["http://127.0.0.1:19090/ocsp".to_string()]); - - let _ = std::fs::remove_dir_all(temp_dir); -} proptest! { #![proptest_config(ProptestConfig::with_cases(48))] @@ -56,3 +41,19 @@ proptest! { let _ = std::fs::remove_dir_all(temp_dir); } } + +#[test] +fn ocsp_responder_urls_for_certificate_extracts_aia_ocsp_uri() { + let temp_dir = temp_dir("rginx-ocsp-aia-responder-url"); + std::fs::create_dir_all(&temp_dir).expect("temp dir should exist"); + + let ca = generate_ca_cert("ocsp-test-ca"); + let leaf = generate_leaf_cert_with_ocsp_aia("localhost", &ca, "http://127.0.0.1:19090/ocsp"); + let cert_path = write_cert_chain(&temp_dir, "server", &leaf, &ca); + + let urls = ocsp_responder_urls_for_certificate(&cert_path) + .expect("AIA OCSP responder discovery should succeed"); + assert_eq!(urls, vec!["http://127.0.0.1:19090/ocsp".to_string()]); + + let _ = std::fs::remove_dir_all(temp_dir); +} diff --git a/crates/rginx-http/src/tls/ocsp/tests/support.rs b/crates/rginx-http/src/tls/ocsp/tests/support.rs index 17631a7f..391dde81 100644 --- a/crates/rginx-http/src/tls/ocsp/tests/support.rs +++ b/crates/rginx-http/src/tls/ocsp/tests/support.rs @@ -1,7 +1,7 @@ -pub(crate) use super::*; - mod certs; mod response; +pub(crate) use super::*; + pub(crate) use certs::*; pub(crate) use response::*; diff --git a/crates/rginx-http/src/tls/ocsp/tests/support/certs.rs b/crates/rginx-http/src/tls/ocsp/tests/support/certs.rs index 2e178c2c..24de7adb 100644 --- a/crates/rginx-http/src/tls/ocsp/tests/support/certs.rs +++ b/crates/rginx-http/src/tls/ocsp/tests/support/certs.rs @@ -2,8 +2,8 @@ use super::*; pub(crate) struct TestCertifiedKey { pub(crate) cert: rcgen::Certificate, - pub(crate) signing_key: KeyPair, pub(crate) params: CertificateParams, + pub(crate) signing_key: KeyPair, } impl TestCertifiedKey { diff --git a/crates/rginx-http/src/tls/ocsp/tests/support/response.rs b/crates/rginx-http/src/tls/ocsp/tests/support/response.rs index 56cd2f1e..d3259fb5 100644 --- a/crates/rginx-http/src/tls/ocsp/tests/support/response.rs +++ b/crates/rginx-http/src/tls/ocsp/tests/support/response.rs @@ -1,42 +1,27 @@ use super::*; -pub(crate) fn build_ocsp_response_for_certificate( - cert_path: &Path, - issuer: &TestCertifiedKey, -) -> Vec { - build_ocsp_response_for_certificate_with_signer( - cert_path, - OcspResponseOptions::new(OcspResponseSigner::Issuer(issuer)), - ) -} - -pub(crate) fn build_ocsp_response_for_certificate_with_offsets( - cert_path: &Path, - issuer: &TestCertifiedKey, - this_update_offset: TimeOffset, - next_update_offset: TimeOffset, -) -> Vec { - build_ocsp_response_for_certificate_with_signer( - cert_path, - OcspResponseOptions::new(OcspResponseSigner::Issuer(issuer)) - .this_update_offset(this_update_offset) - .next_update_offset(Some(next_update_offset)) - .produced_at_offset(this_update_offset), - ) -} - pub(crate) struct OcspResponseOptions<'a> { - this_update_offset: TimeOffset, + cert_status: RasnCertStatus, + duplicate_matching_response: bool, next_update_offset: Option, produced_at_offset: TimeOffset, - cert_status: RasnCertStatus, - signer: OcspResponseSigner<'a>, response_nonce: Option<&'a [u8]>, - duplicate_matching_response: bool, + signer: OcspResponseSigner<'a>, tamper_signature: bool, + this_update_offset: TimeOffset, } impl<'a> OcspResponseOptions<'a> { + pub(crate) fn cert_status(mut self, status: RasnCertStatus) -> Self { + self.cert_status = status; + self + } + + pub(crate) fn duplicate_matching_response(mut self, duplicate: bool) -> Self { + self.duplicate_matching_response = duplicate; + self + } + pub(crate) fn new(signer: OcspResponseSigner<'a>) -> Self { Self { this_update_offset: TimeOffset::Before(Duration::from_secs(24 * 60 * 60)), @@ -50,11 +35,6 @@ impl<'a> OcspResponseOptions<'a> { } } - pub(crate) fn this_update_offset(mut self, offset: TimeOffset) -> Self { - self.this_update_offset = offset; - self - } - pub(crate) fn next_update_offset(mut self, offset: Option) -> Self { self.next_update_offset = offset; self @@ -65,27 +45,84 @@ impl<'a> OcspResponseOptions<'a> { self } - pub(crate) fn cert_status(mut self, status: RasnCertStatus) -> Self { - self.cert_status = status; - self - } - pub(crate) fn response_nonce(mut self, nonce: Option<&'a [u8]>) -> Self { self.response_nonce = nonce; self } - pub(crate) fn duplicate_matching_response(mut self, duplicate: bool) -> Self { - self.duplicate_matching_response = duplicate; + pub(crate) fn tamper_signature(mut self, tamper: bool) -> Self { + self.tamper_signature = tamper; self } - pub(crate) fn tamper_signature(mut self, tamper: bool) -> Self { - self.tamper_signature = tamper; + pub(crate) fn this_update_offset(mut self, offset: TimeOffset) -> Self { + self.this_update_offset = offset; self } } +#[derive(Clone, Copy)] +pub(crate) enum TimeOffset { + After(Duration), + Before(Duration), +} + +pub(crate) enum OcspResponseSigner<'a> { + Delegated(&'a TestCertifiedKey), + Issuer(&'a TestCertifiedKey), +} + +impl<'a> OcspResponseSigner<'a> { + pub(crate) fn embedded_certs(&self) -> Option> { + match self { + Self::Delegated(key) => Some(vec![ + rasn::der::decode(key.cert.der().as_ref()) + .expect("delegated responder certificate should decode"), + ]), + _ => None, + } + } + + pub(crate) fn responder_id(&self) -> RasnResponderId { + match self { + Self::Issuer(key) | Self::Delegated(key) => { + responder_id_for_certificate(key.cert.der().as_ref()) + } + } + } + + pub(crate) fn signing_key(&self) -> &KeyPair { + match self { + Self::Issuer(key) | Self::Delegated(key) => &key.signing_key, + } + } +} + +pub(crate) fn build_ocsp_response_for_certificate( + cert_path: &Path, + issuer: &TestCertifiedKey, +) -> Vec { + build_ocsp_response_for_certificate_with_signer( + cert_path, + OcspResponseOptions::new(OcspResponseSigner::Issuer(issuer)), + ) +} + +pub(crate) fn build_ocsp_response_for_certificate_with_offsets( + cert_path: &Path, + issuer: &TestCertifiedKey, + this_update_offset: TimeOffset, + next_update_offset: TimeOffset, +) -> Vec { + build_ocsp_response_for_certificate_with_signer( + cert_path, + OcspResponseOptions::new(OcspResponseSigner::Issuer(issuer)) + .this_update_offset(this_update_offset) + .next_update_offset(Some(next_update_offset)) + .produced_at_offset(this_update_offset), + ) +} + pub(crate) fn build_ocsp_response_for_certificate_with_signer( cert_path: &Path, options: OcspResponseOptions<'_>, @@ -156,7 +193,9 @@ pub(crate) fn ocsp_time_with_offset(base: SystemTime, offset: TimeOffset) -> Gen TimeOffset::Before(duration) => { base.checked_sub(duration).expect("time offset should stay after unix epoch") } - TimeOffset::After(duration) => base + duration, + TimeOffset::After(duration) => { + base.checked_add(duration).expect("time offset remains valid") + } }; generalized_time_from_system_time(time) } @@ -198,7 +237,7 @@ pub(crate) fn der_length(length: usize) -> Vec { } let bytes = length.to_be_bytes().into_iter().skip_while(|byte| *byte == 0).collect::>(); - let mut encoded = Vec::with_capacity(bytes.len() + 1); + let mut encoded = Vec::with_capacity(bytes.len().saturating_add(1)); encoded.push(0x80 | (bytes.len() as u8)); encoded.extend(bytes); encoded @@ -227,40 +266,3 @@ pub(crate) fn temp_dir(prefix: &str) -> PathBuf { let id = NEXT_ID.fetch_add(1, Ordering::Relaxed); env::temp_dir().join(format!("{prefix}-{unique}-{id}")) } - -#[derive(Clone, Copy)] -pub(crate) enum TimeOffset { - Before(Duration), - After(Duration), -} - -pub(crate) enum OcspResponseSigner<'a> { - Issuer(&'a TestCertifiedKey), - Delegated(&'a TestCertifiedKey), -} - -impl<'a> OcspResponseSigner<'a> { - pub(crate) fn signing_key(&self) -> &KeyPair { - match self { - Self::Issuer(key) | Self::Delegated(key) => &key.signing_key, - } - } - - pub(crate) fn responder_id(&self) -> RasnResponderId { - match self { - Self::Issuer(key) | Self::Delegated(key) => { - responder_id_for_certificate(key.cert.der().as_ref()) - } - } - } - - pub(crate) fn embedded_certs(&self) -> Option> { - match self { - Self::Delegated(key) => Some(vec![ - rasn::der::decode(key.cert.der().as_ref()) - .expect("delegated responder certificate should decode"), - ]), - _ => None, - } - } -} diff --git a/crates/rginx-http/src/tls/ocsp/time.rs b/crates/rginx-http/src/tls/ocsp/time.rs index 69f38fe6..512e122d 100644 --- a/crates/rginx-http/src/tls/ocsp/time.rs +++ b/crates/rginx-http/src/tls/ocsp/time.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{DateTime, GeneralizedTime, RasnCertificate, RasnTime, SystemTime, Utc}; pub(super) fn generalized_time_from_system_time(time: SystemTime) -> GeneralizedTime { let utc = DateTime::::from(time); diff --git a/crates/rginx-http/src/tls/ocsp/validate.rs b/crates/rginx-http/src/tls/ocsp/validate.rs index 1af1e7d9..981c2f88 100644 --- a/crates/rginx-http/src/tls/ocsp/validate.rs +++ b/crates/rginx-http/src/tls/ocsp/validate.rs @@ -1,4 +1,11 @@ -use super::*; +use super::{ + Error, OcspNonceMode, OcspResponderPolicy, Path, RasnBasicOcspResponse, RasnCertId, + RasnCertStatus, RasnCertificate, RasnOcspResponse, RasnOcspResponseStatus, RasnResponseData, + RasnSingleResponse, Result, SystemTime, basic_ocsp_response_type_oid, + build_rasn_ocsp_cert_id_from_chain, extract_ocsp_nonce, generalized_time_from_system_time, + load_certificate_chain_from_path, parse_leaf_and_issuer_certificates, + validate_basic_ocsp_response_signature, +}; pub(crate) fn validate_ocsp_response_for_certificate( path: &Path, @@ -191,7 +198,7 @@ pub(super) fn validate_ocsp_cert_status(path: &Path, response: &RasnSingleRespon "OCSP response for certificate `{}` reports the certificate as revoked", path.display() ))), - RasnCertStatus::Unknown(_) => Err(Error::Server(format!( + RasnCertStatus::Unknown(()) => Err(Error::Server(format!( "OCSP response for certificate `{}` reports an unknown certificate status", path.display() ))), diff --git a/crates/rginx-http/src/tls/provider.rs b/crates/rginx-http/src/tls/provider.rs index 0c11a51a..ffa0b268 100644 --- a/crates/rginx-http/src/tls/provider.rs +++ b/crates/rginx-http/src/tls/provider.rs @@ -45,7 +45,12 @@ pub(super) fn rustls_versions( } fn to_rustls_cipher_suite(suite: TlsCipherSuite) -> Result { - use rustls::crypto::aws_lc_rs::cipher_suite::*; + use rustls::crypto::aws_lc_rs::cipher_suite::{ + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + TLS13_AES_128_GCM_SHA256, TLS13_AES_256_GCM_SHA384, TLS13_CHACHA20_POLY1305_SHA256, + }; Ok(match suite { TlsCipherSuite::Tls13Aes256GcmSha384 => TLS13_AES_256_GCM_SHA384, @@ -65,7 +70,9 @@ fn to_rustls_cipher_suite(suite: TlsCipherSuite) -> Result } fn to_rustls_kx_group(group: TlsKeyExchangeGroup) -> Result<&'static dyn SupportedKxGroup> { - use rustls::crypto::aws_lc_rs::kx_group::*; + use rustls::crypto::aws_lc_rs::kx_group::{ + MLKEM768, MLKEM1024, SECP256R1, SECP256R1MLKEM768, SECP384R1, X25519, X25519MLKEM768, + }; Ok(match group { TlsKeyExchangeGroup::X25519 => X25519, diff --git a/crates/rginx-http/src/tls/session.rs b/crates/rginx-http/src/tls/session.rs index 2a9f79bc..01c1000a 100644 --- a/crates/rginx-http/src/tls/session.rs +++ b/crates/rginx-http/src/tls/session.rs @@ -4,6 +4,26 @@ use rginx_core::{Error, Result, ServerTls}; use rustls::ServerConfig; use rustls::server::{NoServerSessionStorage, ProducesTickets, ServerSessionMemoryCache}; +#[derive(Debug)] +struct DisabledTicketProducer {} + +impl ProducesTickets for DisabledTicketProducer { + fn decrypt(&self, _cipher: &[u8]) -> Option> { + None + } + fn enabled(&self) -> bool { + false + } + + fn encrypt(&self, _plain: &[u8]) -> Option> { + None + } + + fn lifetime(&self) -> u32 { + 0 + } +} + pub(super) fn apply_session_policy( config: &mut ServerConfig, tls: Option<&ServerTls>, @@ -39,24 +59,3 @@ pub(super) fn apply_session_policy( Ok(()) } - -#[derive(Debug)] -struct DisabledTicketProducer {} - -impl ProducesTickets for DisabledTicketProducer { - fn enabled(&self) -> bool { - false - } - - fn lifetime(&self) -> u32 { - 0 - } - - fn encrypt(&self, _plain: &[u8]) -> Option> { - None - } - - fn decrypt(&self, _cipher: &[u8]) -> Option> { - None - } -} diff --git a/crates/rginx-http/src/tls/sni.rs b/crates/rginx-http/src/tls/sni.rs index 36ddd095..e3842f22 100644 --- a/crates/rginx-http/src/tls/sni.rs +++ b/crates/rginx-http/src/tls/sni.rs @@ -8,8 +8,8 @@ use rustls::server::{ClientHello, ResolvesServerCert}; /// SNI 证书解析器,支持基于域名选择证书 #[derive(Debug)] pub(super) struct SniCertificateResolver { - default: Vec>, by_name: CompiledSniCertificates, + default: Vec>, } impl SniCertificateResolver { @@ -38,12 +38,22 @@ struct CompiledSniCertificates { } impl CompiledSniCertificates { + fn best_match(&self, server_name: &str) -> Option<&Vec>> { + self.best_match_candidate(server_name).map(|candidate| &candidate.certs) + } + + fn best_match_candidate(&self, server_name: &str) -> Option<&CompiledSniCertificateCandidate> { + let hostname = CompiledServerNamePattern::normalize_host(server_name); + self.candidates + .iter() + .find(|candidate| candidate.compiled.matches_normalized_host(&hostname).is_some()) + } fn compile(by_name: HashMap>>) -> Self { let mut candidates = by_name .into_iter() .filter_map(|(pattern, certs)| { let compiled = CompiledServerNamePattern::compile(&pattern)?; - Some(CompiledSniCertificateCandidate { compiled, certs }) + Some(CompiledSniCertificateCandidate { certs, compiled }) }) .collect::>(); @@ -56,23 +66,12 @@ impl CompiledSniCertificates { Self { candidates } } - - fn best_match(&self, server_name: &str) -> Option<&Vec>> { - self.best_match_candidate(server_name).map(|candidate| &candidate.certs) - } - - fn best_match_candidate(&self, server_name: &str) -> Option<&CompiledSniCertificateCandidate> { - let hostname = CompiledServerNamePattern::normalize_host(server_name); - self.candidates - .iter() - .find(|candidate| candidate.compiled.matches_normalized_host(&hostname).is_some()) - } } #[derive(Debug)] struct CompiledSniCertificateCandidate { - compiled: CompiledServerNamePattern, certs: Vec>, + compiled: CompiledServerNamePattern, } impl CompiledSniCertificateCandidate { diff --git a/crates/rginx-http/src/tls/tests.rs b/crates/rginx-http/src/tls/tests.rs index 2e5a5577..62603d89 100644 --- a/crates/rginx-http/src/tls/tests.rs +++ b/crates/rginx-http/src/tls/tests.rs @@ -1,3 +1,7 @@ +mod acceptor; +mod policy; +mod selection; + use std::collections::HashMap; use std::fs; use std::sync::Arc; @@ -11,10 +15,6 @@ use super::{ best_matching_sni_certificates, best_matching_wildcard_certificates, build_tls_acceptor, }; -mod acceptor; -mod policy; -mod selection; - static TEST_CERT_COUNTER: AtomicU64 = AtomicU64::new(0); const TEST_SERVER_CERT_PEM: &str = "-----BEGIN CERTIFICATE-----\nMIIDCTCCAfGgAwIBAgIUE+LKmhgfKie/YU/anMKv+Xgr5dYwDQYJKoZIhvcNAQEL\nBQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTI2MDMyMDE1MzIzMloXDTI2MDMy\nMTE1MzIzMlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF\nAAOCAQ8AMIIBCgKCAQEAvxn1IYqOORs2Ys/6Ou54G3alu+wZOeGkPy/ZLYUuO0pK\nh1WgvPvwGF3w3XZdEPhB0JXhqwqoz60SwGQJtEM9GGRHVnBV+BeE/4L1XO4H6Gz5\npMKFaCcJPwO4IrspjffpKQ217K9l9vbjK31tJKwOGaQ//icyzF13xuUvZms67PNc\nBqhZQchld9s90InnL3fCS+J58s9pjE0qlTr7bodvOXaYBxboDlBh4YV7PW/wjwBo\ngUwcbiJvtrRnY7ZlRi/C/bZUTGJ5kO7vSlAgMh2KL1DyY2Ws06n5KUNgpAuIjmew\nMtuYJ9H2xgRMrMjgWSD8N/RRFut4xnpm7jlRepzvwwIDAQABo1MwUTAdBgNVHQ4E\nFgQUIezWZPz8VZj6n2znyGWv76RsGMswHwYDVR0jBBgwFoAUIezWZPz8VZj6n2zn\nyGWv76RsGMswDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAbngq\np7KT2JaXL8BYQGThBZwRODtqv/jXwc34zE3DPPRb1F3i8/odH7+9ZLse35Hj0/gp\nqFQ0DNdOuNlrbrvny208P1OcBe2hYWOSsRGyhZpM5Ai+DkuHheZfhNKvWKdbFn8+\nyfeyN3orSsin9QG0Yx3eqtO/1/6D5TtLsnY2/yPV/j0pv2GCCuB0kcKfygOQTYW6\nJrmYzeFeR/bnQM/lOM49leURdgC/x7tveNG7KRvD0X85M9iuT9/0+VSu6yAkcEi5\nx23C/Chzu7FFVxwZRHD+RshbV4QTPewhi17EJwroMYFpjGUHJVUfzo6W6bsWqA59\nCiiHI87NdBZv4JUCOQ==\n-----END CERTIFICATE-----\n"; diff --git a/crates/rginx-http/src/transition.rs b/crates/rginx-http/src/transition.rs index cab3c2b9..64f28d18 100644 --- a/crates/rginx-http/src/transition.rs +++ b/crates/rginx-http/src/transition.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use rginx_core::{ConfigSnapshot, Error, Result}; const RELOADABLE_FIELDS: [&str; 13] = [ @@ -42,12 +44,13 @@ pub struct ConfigTransitionBoundary { #[derive(Debug, Clone, PartialEq, Eq)] pub struct ConfigTransitionPlan { - pub kind: ConfigTransitionKind, pub boundary: ConfigTransitionBoundary, pub changed_restart_required_fields: Vec, + pub kind: ConfigTransitionKind, } impl ConfigTransitionBoundary { + #[must_use] pub fn current() -> Self { Self { reloadable_fields: RELOADABLE_FIELDS.iter().map(|field| (*field).to_string()).collect(), @@ -60,10 +63,12 @@ impl ConfigTransitionBoundary { } impl ConfigTransitionPlan { + #[must_use] pub fn requires_restart(&self) -> bool { self.kind == ConfigTransitionKind::RestartRequired } + #[must_use] pub fn restart_required_message(&self) -> Option { self.requires_restart().then(|| { format!( @@ -75,6 +80,7 @@ impl ConfigTransitionPlan { } } +#[must_use] pub fn config_transition_boundary() -> ConfigTransitionBoundary { ConfigTransitionBoundary::current() } @@ -109,11 +115,9 @@ pub fn plan_config_transition( "{}.http3.listen {} -> {}", current_listener.id, current_http3 - .map(|listen_addr| listen_addr.to_string()) - .unwrap_or_else(|| "-".to_string()), + .map_or_else(|| "-".to_string(), |listen_addr| listen_addr.to_string()), next_http3 - .map(|listen_addr| listen_addr.to_string()) - .unwrap_or_else(|| "-".to_string()), + .map_or_else(|| "-".to_string(), |listen_addr| listen_addr.to_string()), )); } } @@ -202,6 +206,3 @@ pub fn validate_config_transition(current: &ConfigSnapshot, next: &ConfigSnapsho } Ok(()) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-observability/Cargo.toml b/crates/rginx-observability/Cargo.toml index 11390b35..5eed3015 100644 --- a/crates/rginx-observability/Cargo.toml +++ b/crates/rginx-observability/Cargo.toml @@ -11,5 +11,8 @@ readme.workspace = true rust-version.workspace = true description = "Tracing and logging initialization for rginx." +[lints] +workspace = true + [dependencies] tracing-subscriber.workspace = true diff --git a/crates/rginx-observability/src/lib.rs b/crates/rginx-observability/src/lib.rs index e68f2c78..e55843ec 100644 --- a/crates/rginx-observability/src/lib.rs +++ b/crates/rginx-observability/src/lib.rs @@ -1,3 +1,12 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] pub mod logging; pub use logging::init_logging; diff --git a/crates/rginx-runtime/Cargo.toml b/crates/rginx-runtime/Cargo.toml index 5c3c0deb..096456f0 100644 --- a/crates/rginx-runtime/Cargo.toml +++ b/crates/rginx-runtime/Cargo.toml @@ -11,6 +11,9 @@ readme.workspace = true rust-version.workspace = true description = "Runtime orchestration, reload, shutdown, and health checks for rginx." +[lints] +workspace = true + [dependencies] bytes.workspace = true http.workspace = true diff --git a/crates/rginx-runtime/src/acme/challenge.rs b/crates/rginx-runtime/src/acme/challenge.rs index 4041edf6..d97bce68 100644 --- a/crates/rginx-runtime/src/acme/challenge.rs +++ b/crates/rginx-runtime/src/acme/challenge.rs @@ -55,6 +55,10 @@ pub(crate) struct TemporaryChallengeServer { } impl TemporaryChallengeServer { + pub(crate) fn backend(&self) -> Arc { + self.backend.clone() + } + pub(crate) async fn bind_for_config(config: &ConfigSnapshot) -> Result { let listen_addrs = http01_listener_addrs(config); if listen_addrs.is_empty() { @@ -83,10 +87,6 @@ impl TemporaryChallengeServer { Ok(Self { backend, shutdown_tx, tasks }) } - pub(crate) fn backend(&self) -> Arc { - self.backend.clone() - } - pub(crate) async fn shutdown(self) { let _ = self.shutdown_tx.send(true); for task in self.tasks { diff --git a/crates/rginx-runtime/src/acme/mod.rs b/crates/rginx-runtime/src/acme/mod.rs index 5dcc9d4d..1598a0c4 100644 --- a/crates/rginx-runtime/src/acme/mod.rs +++ b/crates/rginx-runtime/src/acme/mod.rs @@ -1,7 +1,3 @@ -use rginx_core::{ConfigSnapshot, Error, Result}; -use rginx_http::SharedState; -use tokio::sync::watch; - mod account; mod challenge; mod lock; @@ -12,6 +8,10 @@ mod storage; mod tests; mod types; +use rginx_core::{ConfigSnapshot, Error, Result}; +use rginx_http::SharedState; +use tokio::sync::watch; + pub use types::{CertificateFailure, IssueSummary}; use account::load_or_create_account; @@ -39,7 +39,7 @@ pub async fn issue_once(config: &ConfigSnapshot) -> Result { match plan_reconcile(spec, certificate_statuses.get(&spec.scope), settings) { Some(plan) => pending.push((spec, plan)), None => { - summary.skipped += 1; + summary.skipped = summary.skipped.saturating_add(1); tracing::info!( scope = %spec.scope, "managed certificate already satisfies current ACME spec" @@ -68,7 +68,7 @@ pub async fn issue_once(config: &ConfigSnapshot) -> Result { .await { Ok(()) => { - summary.issued += 1; + summary.issued = summary.issued.saturating_add(1); tracing::info!(scope = %spec.scope, "managed ACME certificate issued"); } Err(error) => { diff --git a/crates/rginx-runtime/src/acme/scheduler.rs b/crates/rginx-runtime/src/acme/scheduler.rs index a5ffa6f6..a449d08a 100644 --- a/crates/rginx-runtime/src/acme/scheduler.rs +++ b/crates/rginx-runtime/src/acme/scheduler.rs @@ -27,31 +27,38 @@ struct RetryBackoff { } impl RetryBackoff { - fn retain_scopes(&mut self, scopes: &HashSet) { - self.entries.retain(|scope, _| scopes.contains(scope)); - } - - fn remaining_delay(&self, scope: &str) -> Option { - self.entries - .get(scope) - .and_then(|state| state.next_retry_at.checked_duration_since(Instant::now())) - } - - fn record_success(&mut self, scope: &str) { - self.entries.remove(scope); - } - fn record_failure(&mut self, scope: &str) -> Duration { let failures = self.entries.get(scope).map(|state| state.failures.saturating_add(1)).unwrap_or(1); let exponent = failures.saturating_sub(1).min(8); - let retry_delay = std::cmp::min(INITIAL_RETRY_DELAY * (1 << exponent), MAX_RETRY_DELAY); + let multiplier = 1u32.checked_shl(exponent).unwrap_or(u32::MAX); + let retry_delay = + std::cmp::min(INITIAL_RETRY_DELAY.saturating_mul(multiplier), MAX_RETRY_DELAY); self.entries.insert( scope.to_string(), - RetryState { failures, next_retry_at: Instant::now() + retry_delay }, + RetryState { + failures, + next_retry_at: Instant::now() + .checked_add(retry_delay) + .expect("ACME retry deadline remains representable"), + }, ); retry_delay } + + fn record_success(&mut self, scope: &str) { + self.entries.remove(scope); + } + + fn remaining_delay(&self, scope: &str) -> Option { + self.entries + .get(scope) + .and_then(|state| state.next_retry_at.checked_duration_since(Instant::now())) + } + + fn retain_scopes(&mut self, scopes: &HashSet) { + self.entries.retain(|scope, _| scopes.contains(scope)); + } } pub(super) async fn run(state: SharedState, mut shutdown: watch::Receiver) { diff --git a/crates/rginx-runtime/src/acme/storage.rs b/crates/rginx-runtime/src/acme/storage.rs index 4cba49e6..bf2d0aa7 100644 --- a/crates/rginx-runtime/src/acme/storage.rs +++ b/crates/rginx-runtime/src/acme/storage.rs @@ -15,8 +15,8 @@ static TEMP_FILE_SEQUENCE: AtomicU64 = AtomicU64::new(0); #[derive(Serialize, Deserialize)] pub(crate) struct PersistedAccountCredentials { - pub(crate) directory_url: String, pub(crate) credentials: AccountCredentials, + pub(crate) directory_url: String, } pub(crate) fn load_account_credentials( diff --git a/crates/rginx-runtime/src/acme/types.rs b/crates/rginx-runtime/src/acme/types.rs index 65066d23..eaaddd72 100644 --- a/crates/rginx-runtime/src/acme/types.rs +++ b/crates/rginx-runtime/src/acme/types.rs @@ -12,35 +12,34 @@ const ACME_ORDER_POLL_INITIAL_DELAY: Duration = Duration::from_millis(500); #[derive(Debug, Clone, PartialEq, Eq)] pub struct CertificateFailure { - pub scope: String, pub error: String, + pub scope: String, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct IssueSummary { - pub total: usize, + pub failures: Vec, pub issued: usize, pub skipped: usize, - pub failures: Vec, + pub total: usize, } impl IssueSummary { - pub(crate) fn new(total: usize) -> Self { - Self { total, issued: 0, skipped: 0, failures: Vec::new() } - } - pub fn is_success(&self) -> bool { self.failures.is_empty() } + pub(crate) fn new(total: usize) -> Self { + Self { total, issued: 0, skipped: 0, failures: Vec::new() } + } } #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) enum ReconcileReason { + Expiring { remaining_days: i64, renew_before_days: i64 }, MissingCertificate, - MissingPrivateKey, MissingCertificateMetadata, + MissingPrivateKey, SanMismatch { expected: Vec, actual: Vec }, - Expiring { remaining_days: i64, renew_before_days: i64 }, } #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/crates/rginx-runtime/src/admin/mod.rs b/crates/rginx-runtime/src/admin/mod.rs index e326c246..9c887210 100644 --- a/crates/rginx-runtime/src/admin/mod.rs +++ b/crates/rginx-runtime/src/admin/mod.rs @@ -1,12 +1,12 @@ -use rginx_http::SharedState; -use tokio::sync::watch; - mod model; mod service; mod socket; #[cfg(test)] mod tests; +use rginx_http::SharedState; +use tokio::sync::watch; + pub use model::{ AdminRequest, AdminResponse, AdminSnapshot, RevisionSnapshot, SnapshotVersionSnapshot, }; diff --git a/crates/rginx-runtime/src/admin/model.rs b/crates/rginx-runtime/src/admin/model.rs index 60ac38fc..d3166ee2 100644 --- a/crates/rginx-runtime/src/admin/model.rs +++ b/crates/rginx-runtime/src/admin/model.rs @@ -7,28 +7,28 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum AdminRequest { + ClearCacheInvalidations { zone_name: String }, + GetAgentStatus, + GetCacheStats, + GetCounters, + GetDelta { since_version: u64, include: Option>, window_secs: Option }, + GetPeerHealth, + GetRevision, GetSnapshot { include: Option>, window_secs: Option }, GetSnapshotVersion, - GetDelta { since_version: u64, include: Option>, window_secs: Option }, - WaitForSnapshotChange { since_version: u64, timeout_ms: Option }, GetStatus, - GetCacheStats, - GetCounters, GetTrafficStats { window_secs: Option }, - GetPeerHealth, GetUpstreamStats { window_secs: Option }, - GetAgentStatus, - SetAgentDisabled { disabled: bool }, - PurgeCacheZone { zone_name: String }, - PurgeCacheKey { zone_name: String, key: String }, - PurgeCachePrefix { zone_name: String, prefix: String }, - InvalidateCacheZone { zone_name: String }, InvalidateCacheKey { zone_name: String, key: String }, InvalidateCachePrefix { zone_name: String, prefix: String }, InvalidateCacheTag { zone_name: String, tag: String }, - ClearCacheInvalidations { zone_name: String }, - GetRevision, + InvalidateCacheZone { zone_name: String }, + PurgeCacheKey { zone_name: String, key: String }, + PurgeCachePrefix { zone_name: String, prefix: String }, + PurgeCacheZone { zone_name: String }, + SetAgentDisabled { disabled: bool }, SetDesiredRevision { desired_revision: u64 }, + WaitForSnapshotChange { since_version: u64, timeout_ms: Option }, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -39,11 +39,11 @@ pub struct RevisionSnapshot { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] struct RevisionSnapshotWire { #[serde(default)] - desired_revision: u64, + converged: bool, #[serde(default)] current_revision: Option, #[serde(default)] - converged: bool, + desired_revision: u64, #[serde(default)] revision: Option, } @@ -86,42 +86,42 @@ pub struct SnapshotVersionSnapshot { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct AdminSnapshot { - pub schema_version: u32, - pub snapshot_version: u64, - pub captured_at_unix_ms: u64, - pub pid: u32, pub binary_version: String, - pub included_modules: Vec, #[serde(skip_serializing_if = "Option::is_none")] - pub status: Option, + pub cache: Option, + pub captured_at_unix_ms: u64, #[serde(skip_serializing_if = "Option::is_none")] pub counters: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub traffic: Option, + pub included_modules: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub peer_health: Option>, + pub pid: u32, + pub schema_version: u32, + pub snapshot_version: u64, #[serde(skip_serializing_if = "Option::is_none")] - pub upstreams: Option>, + pub status: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub cache: Option, + pub traffic: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub upstreams: Option>, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "type", content = "data")] -#[allow(clippy::large_enum_variant)] +#[expect(clippy::large_enum_variant, reason = "admin API responses stay as direct typed variants")] pub enum AdminResponse { + AgentStatus(AgentRuntimeSnapshot), + CacheInvalidation(CacheInvalidationResult), + CachePurge(CachePurgeResult), + CacheStats(CacheStatsSnapshot), + Counters(HttpCountersSnapshot), + Delta(SnapshotDeltaSnapshot), + Error { message: String }, + PeerHealth(Vec), + Revision(RevisionSnapshot), Snapshot(AdminSnapshot), SnapshotVersion(SnapshotVersionSnapshot), - Delta(SnapshotDeltaSnapshot), Status(RuntimeStatusSnapshot), - CacheStats(CacheStatsSnapshot), - Counters(HttpCountersSnapshot), TrafficStats(TrafficStatsSnapshot), - PeerHealth(Vec), UpstreamStats(Vec), - AgentStatus(AgentRuntimeSnapshot), - CachePurge(CachePurgeResult), - CacheInvalidation(CacheInvalidationResult), - Revision(RevisionSnapshot), - Error { message: String }, } diff --git a/crates/rginx-runtime/src/admin/socket.rs b/crates/rginx-runtime/src/admin/socket.rs index ce482632..4bd5ca5a 100644 --- a/crates/rginx-runtime/src/admin/socket.rs +++ b/crates/rginx-runtime/src/admin/socket.rs @@ -8,23 +8,17 @@ use tokio::task::JoinError; use super::{INSTALLED_ADMIN_SOCKET_PATH, INSTALLED_CONFIG_PATH}; -pub fn admin_socket_path_for_config(config_path: &Path) -> PathBuf { - if config_path == Path::new(INSTALLED_CONFIG_PATH) { - return PathBuf::from(INSTALLED_ADMIN_SOCKET_PATH); - } - - let parent = config_path.parent().unwrap_or_else(|| Path::new(".")); - let stem = config_path.file_stem().and_then(|value| value.to_str()).unwrap_or("rginx"); - parent.join(format!("{stem}.admin.sock")) -} - pub(super) struct AdminSocketGuard { - path: PathBuf, device: u64, inode: u64, + path: PathBuf, } impl AdminSocketGuard { + pub(super) fn from_bound_path(path: &Path) -> io::Result { + let metadata = fs::metadata(path)?; + Ok(Self { path: path.to_path_buf(), device: metadata.dev(), inode: metadata.ino() }) + } pub(super) fn prepare_path(path: &Path) -> io::Result<()> { if let Some(parent) = path.parent() && !parent.as_os_str().is_empty() @@ -38,11 +32,6 @@ impl AdminSocketGuard { Ok(()) } - - pub(super) fn from_bound_path(path: &Path) -> io::Result { - let metadata = fs::metadata(path)?; - Ok(Self { path: path.to_path_buf(), device: metadata.dev(), inode: metadata.ino() }) - } } impl Drop for AdminSocketGuard { @@ -56,6 +45,16 @@ impl Drop for AdminSocketGuard { } } +pub fn admin_socket_path_for_config(config_path: &Path) -> PathBuf { + if config_path == Path::new(INSTALLED_CONFIG_PATH) { + return PathBuf::from(INSTALLED_ADMIN_SOCKET_PATH); + } + + let parent = config_path.parent().unwrap_or_else(|| Path::new(".")); + let stem = config_path.file_stem().and_then(|value| value.to_str()).unwrap_or("rginx"); + parent.join(format!("{stem}.admin.sock")) +} + pub(super) fn log_admin_connection_result(result: Result<(), JoinError>) { if let Err(error) = result { if error.is_panic() { diff --git a/crates/rginx-runtime/src/agent.rs b/crates/rginx-runtime/src/agent.rs index 92895870..a6752d1b 100644 --- a/crates/rginx-runtime/src/agent.rs +++ b/crates/rginx-runtime/src/agent.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::sync::Arc; use rginx_core::{AgentSettings, ControlPlaneSettings}; @@ -89,6 +91,3 @@ async fn run_outbound_with_client( tracing::info!(node_id = %agent.node_id, "outbound agent task stopped"); Ok(()) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-runtime/src/agent/tests.rs b/crates/rginx-runtime/src/agent/tests.rs index 86104267..b0bf7130 100644 --- a/crates/rginx-runtime/src/agent/tests.rs +++ b/crates/rginx-runtime/src/agent/tests.rs @@ -12,6 +12,48 @@ use tokio::sync::watch; use super::*; +struct NoopControlCenter; + +impl rginx_agent::OutboundControlPlaneClient for NoopControlCenter { + fn heartbeat( + &self, + _request: rginx_agent::AgentHeartbeatRequest, + ) -> Pin> + Send + 'static>> { + Box::pin(async { Ok(()) }) + } + + fn poll_commands( + &self, + _node_id: String, + _cursor: Option, + _timeout: Duration, + ) -> Pin< + Box< + dyn Future> + + Send + + 'static, + >, + > { + Box::pin(async { + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(rginx_agent::AgentPollResponse::empty()) + }) + } + + fn post_result( + &self, + _result: rginx_agent::AgentCommandResult, + ) -> Pin> + Send + 'static>> { + Box::pin(async { Ok(()) }) + } + fn register( + &self, + _request: rginx_agent::AgentRegisterRequest, + ) -> Pin> + Send + 'static>> { + Box::pin(async { Ok(()) }) + } +} + #[tokio::test] async fn outbound_agent_task_starts_and_stops_on_shutdown() { let tempdir = tempfile::tempdir().expect("tempdir should be created"); @@ -55,49 +97,6 @@ async fn outbound_agent_task_starts_and_stops_on_shutdown() { .expect("outbound agent task should succeed"); } -struct NoopControlCenter; - -impl rginx_agent::OutboundControlPlaneClient for NoopControlCenter { - fn register( - &self, - _request: rginx_agent::AgentRegisterRequest, - ) -> Pin> + Send + 'static>> { - Box::pin(async { Ok(()) }) - } - - fn heartbeat( - &self, - _request: rginx_agent::AgentHeartbeatRequest, - ) -> Pin> + Send + 'static>> { - Box::pin(async { Ok(()) }) - } - - fn poll_commands( - &self, - _node_id: String, - _cursor: Option, - _timeout: Duration, - ) -> Pin< - Box< - dyn Future> - + Send - + 'static, - >, - > { - Box::pin(async { - tokio::time::sleep(Duration::from_millis(10)).await; - Ok(rginx_agent::AgentPollResponse::empty()) - }) - } - - fn post_result( - &self, - _result: rginx_agent::AgentCommandResult, - ) -> Pin> + Send + 'static>> { - Box::pin(async { Ok(()) }) - } -} - async fn wait_for_agent_identity(state: &RuntimeState) { for _ in 0..20 { let status = state.http.status_snapshot().await; diff --git a/crates/rginx-runtime/src/apply.rs b/crates/rginx-runtime/src/apply.rs index 541a5065..78b034cb 100644 --- a/crates/rginx-runtime/src/apply.rs +++ b/crates/rginx-runtime/src/apply.rs @@ -1,3 +1,7 @@ +mod staged_change; +#[cfg(test)] +mod tests; + use rginx_agent::model::ConfigApplyResultView; use rginx_agent::{ConfigApplyExecutor, ConfigApplyFuture, ConfigApplyOutcome}; use rginx_config::managed::{self, ManagedResourceMutation}; @@ -7,10 +11,6 @@ use rginx_http::ConfigFailureStageSnapshot; use crate::reload::describe_tls_certificate_changes; use crate::state::RuntimeState; -mod staged_change; -#[cfg(test)] -mod tests; - use staged_change::StagedManagedFileChange; #[derive(Clone)] diff --git a/crates/rginx-runtime/src/apply/staged_change.rs b/crates/rginx-runtime/src/apply/staged_change.rs index fbb5302b..352fef89 100644 --- a/crates/rginx-runtime/src/apply/staged_change.rs +++ b/crates/rginx-runtime/src/apply/staged_change.rs @@ -9,8 +9,8 @@ use rginx_config::managed::{AppliedManagedMutation, ManagedResourceDocument}; static ARTIFACT_SEQUENCE: AtomicU64 = AtomicU64::new(0); pub(super) struct StagedManagedFileChange { - final_path: PathBuf, backup_path: Option, + final_path: PathBuf, } impl StagedManagedFileChange { @@ -26,6 +26,12 @@ impl StagedManagedFileChange { } } + fn activate_delete(final_path: &Path) -> io::Result { + let backup_path = sibling_artifact_path(final_path, "rollback"); + fs::rename(final_path, &backup_path)?; + Ok(Self { final_path: final_path.to_path_buf(), backup_path: Some(backup_path) }) + } + fn activate_upsert(final_path: &Path, document: &ManagedResourceDocument) -> io::Result { if let Some(parent) = final_path.parent() { fs::create_dir_all(parent)?; @@ -59,12 +65,6 @@ impl StagedManagedFileChange { Ok(Self { final_path: final_path.to_path_buf(), backup_path }) } - fn activate_delete(final_path: &Path) -> io::Result { - let backup_path = sibling_artifact_path(final_path, "rollback"); - fs::rename(final_path, &backup_path)?; - Ok(Self { final_path: final_path.to_path_buf(), backup_path: Some(backup_path) }) - } - pub(super) fn commit(self) { if let Some(backup_path) = self.backup_path && let Err(error) = fs::remove_file(&backup_path) diff --git a/crates/rginx-runtime/src/bootstrap/listeners/activate.rs b/crates/rginx-runtime/src/bootstrap/listeners/activate.rs index 94b6aaae..e508f2ce 100644 --- a/crates/rginx-runtime/src/bootstrap/listeners/activate.rs +++ b/crates/rginx-runtime/src/bootstrap/listeners/activate.rs @@ -7,8 +7,8 @@ use super::PreparedListenerWorkerGroup; use super::group::ListenerWorkerGroup; struct WorkerDrainGuard { - remaining_workers: Arc, drain_completion_notify: Arc, + remaining_workers: Arc, } impl Drop for WorkerDrainGuard { @@ -27,7 +27,7 @@ pub(crate) fn activate_prepared_listener_worker_group( let (shutdown_tx, _shutdown_rx) = watch::channel(false); let mut tasks = Vec::new(); let remaining_workers = Arc::new(AtomicUsize::new( - prepared.worker_listeners.len() + prepared.http3_endpoints.len(), + prepared.worker_listeners.len().saturating_add(prepared.http3_endpoints.len()), )); for (worker_index, listener_socket) in prepared.worker_listeners.into_iter().enumerate() { diff --git a/crates/rginx-runtime/src/bootstrap/listeners/bind_udp.rs b/crates/rginx-runtime/src/bootstrap/listeners/bind_udp.rs index 2cec0662..568468d1 100644 --- a/crates/rginx-runtime/src/bootstrap/listeners/bind_udp.rs +++ b/crates/rginx-runtime/src/bootstrap/listeners/bind_udp.rs @@ -32,7 +32,7 @@ pub(crate) fn normalize_inherited_udp_sockets( } else if sockets.len() < desired_socket_count { sockets.extend(bind_std_udp_sockets_with_reuse_port( listen_addr, - desired_socket_count - sockets.len(), + desired_socket_count.saturating_sub(sockets.len()), desired_socket_count > 1, )?); } diff --git a/crates/rginx-runtime/src/bootstrap/listeners/group.rs b/crates/rginx-runtime/src/bootstrap/listeners/group.rs index 919fad12..a557ce77 100644 --- a/crates/rginx-runtime/src/bootstrap/listeners/group.rs +++ b/crates/rginx-runtime/src/bootstrap/listeners/group.rs @@ -8,20 +8,18 @@ use tokio::task::JoinHandle; use crate::restart::ListenerHandle; pub(crate) struct ListenerWorkerGroup { + pub(super) joined_tasks: usize, pub(super) listener: Listener, + pub(super) shutdown_tx: watch::Sender, pub(super) std_listener: Arc, pub(super) std_udp_sockets: Vec>, - pub(super) shutdown_tx: watch::Sender, pub(super) tasks: Vec>>, - pub(super) joined_tasks: usize, } impl ListenerWorkerGroup { - pub(crate) fn restart_handle(&self) -> ListenerHandle { - ListenerHandle { - listener: self.listener.clone(), - std_listener: self.std_listener.clone(), - std_udp_sockets: self.std_udp_sockets.clone(), + pub(crate) fn abort(&self) { + for task in &self.tasks { + task.abort(); } } @@ -29,13 +27,14 @@ impl ListenerWorkerGroup { let _ = self.shutdown_tx.send(true); } - pub(crate) fn abort(&self) { - for task in &self.tasks { - task.abort(); - } - } - pub(crate) fn is_finished(&self) -> bool { self.tasks.iter().all(JoinHandle::is_finished) } + pub(crate) fn restart_handle(&self) -> ListenerHandle { + ListenerHandle { + listener: self.listener.clone(), + std_listener: self.std_listener.clone(), + std_udp_sockets: self.std_udp_sockets.clone(), + } + } } diff --git a/crates/rginx-runtime/src/bootstrap/listeners/join.rs b/crates/rginx-runtime/src/bootstrap/listeners/join.rs index e526b699..3ed3652a 100644 --- a/crates/rginx-runtime/src/bootstrap/listeners/join.rs +++ b/crates/rginx-runtime/src/bootstrap/listeners/join.rs @@ -61,7 +61,7 @@ pub(crate) async fn join_listener_worker_group(group: &mut ListenerWorkerGroup) } } } - group.joined_tasks += 1; + group.joined_tasks = group.joined_tasks.saturating_add(1); } group.tasks.clear(); group.joined_tasks = 0; @@ -80,7 +80,7 @@ async fn join_aborted_listener_worker_group(group: &mut ListenerWorkerGroup) { { tracing::warn!(%error, listener = %group.listener.name, "http worker failed after abort"); } - group.joined_tasks += 1; + group.joined_tasks = group.joined_tasks.saturating_add(1); } group.tasks.clear(); group.joined_tasks = 0; diff --git a/crates/rginx-runtime/src/bootstrap/listeners/mod.rs b/crates/rginx-runtime/src/bootstrap/listeners/mod.rs index 995f37e8..68cb0c6f 100644 --- a/crates/rginx-runtime/src/bootstrap/listeners/mod.rs +++ b/crates/rginx-runtime/src/bootstrap/listeners/mod.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - mod activate; mod bind_tcp; mod bind_udp; @@ -11,6 +9,8 @@ mod reconcile; #[cfg(test)] mod tests; +use std::collections::HashMap; + #[cfg(test)] pub(super) use bind_tcp::bind_std_listener; #[cfg(test)] diff --git a/crates/rginx-runtime/src/bootstrap/listeners/prepare.rs b/crates/rginx-runtime/src/bootstrap/listeners/prepare.rs index 281b9545..2e995076 100644 --- a/crates/rginx-runtime/src/bootstrap/listeners/prepare.rs +++ b/crates/rginx-runtime/src/bootstrap/listeners/prepare.rs @@ -14,18 +14,67 @@ use super::{ListenerGroupMap, ListenerWorkerGroup}; type ListenerBindingKey = (rginx_core::ListenerTransportKind, std::net::SocketAddr); pub(crate) struct PreparedListenerWorkerGroup { + pub(super) http3_endpoints: Vec, pub(super) listener: Listener, pub(super) std_listener: Arc, pub(super) std_udp_sockets: Vec>, pub(super) worker_listeners: Vec, - pub(super) http3_endpoints: Vec, } struct ListenerReloadBoundary { - active_listener_ids: HashSet, - draining_listener_ids: HashSet, active_bindings: HashSet, + active_listener_ids: HashSet, draining_bindings: HashSet, + draining_listener_ids: HashSet, +} + +impl ListenerReloadBoundary { + fn contains_active_listener(&self, listener_id: &str) -> bool { + self.active_listener_ids.contains(listener_id) + } + + fn ensure_listener_can_be_added(&self, listener: &Listener) -> Result<()> { + if self.draining_listener_ids.contains(&listener.id) { + return Err(Error::Server(format!( + "listener `{}` cannot be re-added until the previous generation has drained", + listener.name + ))); + } + + for binding in listener_binding_keys(listener) { + if self.active_bindings.contains(&binding) || self.draining_bindings.contains(&binding) + { + return Err(Error::Server(format!( + "listener `{}` reuses {} listen address `{}` with a different listener identity during reload", + listener.name, + binding.0.as_str(), + binding.1 + ))); + } + } + + Ok(()) + } + fn from_groups( + active_listener_groups: &ListenerGroupMap, + draining_listener_groups: &[ListenerWorkerGroup], + ) -> Self { + Self { + active_listener_ids: active_listener_groups.keys().cloned().collect(), + draining_listener_ids: draining_listener_groups + .iter() + .map(|group| group.listener.id.clone()) + .collect(), + active_bindings: active_listener_groups + .values() + .flat_map(|group| listener_binding_keys(&group.listener)) + .collect(), + draining_bindings: draining_listener_groups + .iter() + .flat_map(|group| listener_binding_keys(&group.listener)) + .collect(), + } + } } pub(crate) async fn build_initial_listener_groups( @@ -107,56 +156,6 @@ pub(crate) fn prepare_added_listener_bindings( Ok(prepared) } -impl ListenerReloadBoundary { - fn from_groups( - active_listener_groups: &ListenerGroupMap, - draining_listener_groups: &[ListenerWorkerGroup], - ) -> Self { - Self { - active_listener_ids: active_listener_groups.keys().cloned().collect(), - draining_listener_ids: draining_listener_groups - .iter() - .map(|group| group.listener.id.clone()) - .collect(), - active_bindings: active_listener_groups - .values() - .flat_map(|group| listener_binding_keys(&group.listener)) - .collect(), - draining_bindings: draining_listener_groups - .iter() - .flat_map(|group| listener_binding_keys(&group.listener)) - .collect(), - } - } - - fn contains_active_listener(&self, listener_id: &str) -> bool { - self.active_listener_ids.contains(listener_id) - } - - fn ensure_listener_can_be_added(&self, listener: &Listener) -> Result<()> { - if self.draining_listener_ids.contains(&listener.id) { - return Err(Error::Server(format!( - "listener `{}` cannot be re-added until the previous generation has drained", - listener.name - ))); - } - - for binding in listener_binding_keys(listener) { - if self.active_bindings.contains(&binding) || self.draining_bindings.contains(&binding) - { - return Err(Error::Server(format!( - "listener `{}` reuses {} listen address `{}` with a different listener identity during reload", - listener.name, - binding.0.as_str(), - binding.1 - ))); - } - } - - Ok(()) - } -} - fn listener_binding_keys(listener: &Listener) -> impl Iterator + '_ { listener.transport_bindings().into_iter().map(|binding| (binding.kind, binding.listen_addr)) } diff --git a/crates/rginx-runtime/src/bootstrap/listeners/reconcile.rs b/crates/rginx-runtime/src/bootstrap/listeners/reconcile.rs index 332fbe70..5dc49b0c 100644 --- a/crates/rginx-runtime/src/bootstrap/listeners/reconcile.rs +++ b/crates/rginx-runtime/src/bootstrap/listeners/reconcile.rs @@ -68,7 +68,7 @@ pub(crate) async fn prune_draining_listener_groups( let mut index = 0usize; while index < draining_listener_groups.len() { if !draining_listener_groups[index].is_finished() { - index += 1; + index = index.saturating_add(1); continue; } diff --git a/crates/rginx-runtime/src/bootstrap/mod.rs b/crates/rginx-runtime/src/bootstrap/mod.rs index aafe7447..cafc8c27 100644 --- a/crates/rginx-runtime/src/bootstrap/mod.rs +++ b/crates/rginx-runtime/src/bootstrap/mod.rs @@ -1,3 +1,8 @@ +mod listeners; +mod reload; +mod restart; +mod shutdown; + use std::path::PathBuf; use std::sync::Arc; @@ -13,11 +18,6 @@ use crate::ocsp; use crate::shutdown as runtime_shutdown; use crate::state::RuntimeState; -mod listeners; -mod reload; -mod restart; -mod shutdown; - use listeners::{ ListenerWorkerGroup, build_initial_listener_groups, prune_draining_listener_groups, }; diff --git a/crates/rginx-runtime/src/bootstrap/shutdown.rs b/crates/rginx-runtime/src/bootstrap/shutdown.rs index 170b9d75..2aceb81b 100644 --- a/crates/rginx-runtime/src/bootstrap/shutdown.rs +++ b/crates/rginx-runtime/src/bootstrap/shutdown.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::time::Duration; use rginx_core::{Error, Result}; @@ -12,12 +14,12 @@ use super::listeners::{ }; pub(super) struct ShutdownTasks<'a> { + pub(super) acme_task: &'a mut Option>, pub(super) admin_task: &'a mut Option>>, pub(super) agent_task: &'a mut Option>>, - pub(super) control_plane_task: &'a mut Option>>, pub(super) cache_task: &'a mut Option>, + pub(super) control_plane_task: &'a mut Option>>, pub(super) health_task: &'a mut Option>, - pub(super) acme_task: &'a mut Option>, pub(super) ocsp_task: &'a mut Option>, } @@ -167,6 +169,3 @@ async fn join_unit_task_after_abort(task: &mut Option>, name: &st tracing::warn!(%error, "{name} task failed after abort"); } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-runtime/src/health.rs b/crates/rginx-runtime/src/health.rs index eed68a03..9b9f10f1 100644 --- a/crates/rginx-runtime/src/health.rs +++ b/crates/rginx-runtime/src/health.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::collections::{HashMap, HashSet}; use std::sync::Arc; @@ -11,17 +13,17 @@ use tokio::time::Instant; #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct ProbeKey { - upstream_name: String, peer_url: String, + upstream_name: String, } #[derive(Clone)] struct ProbeTarget { - key: ProbeKey, clients: ProxyClients, - upstream: Arc, - peer: UpstreamPeer, health_check: ActiveHealthCheck, + key: ProbeKey, + peer: UpstreamPeer, + upstream: Arc, } pub async fn run(state: SharedState, mut shutdown: watch::Receiver) { @@ -54,7 +56,11 @@ pub async fn run(state: SharedState, mut shutdown: watch::Receiver) { let mut probes = JoinSet::new(); for target in due_targets { - next_due.insert(target.key.clone(), now + target.health_check.interval); + next_due.insert( + target.key.clone(), + now.checked_add(target.health_check.interval) + .expect("health check deadline remains representable"), + ); probes.spawn(async move { probe_upstream_peer(target.clients, target.upstream, target.peer).await; }); @@ -126,7 +132,8 @@ fn collect_probe_targets(snapshot: ActiveState) -> Vec { } fn initial_probe_due_at(now: Instant, key: &ProbeKey, interval: std::time::Duration) -> Instant { - now + initial_probe_delay(key, interval) + now.checked_add(initial_probe_delay(key, interval)) + .expect("initial health probe deadline remains representable") } fn initial_probe_delay(key: &ProbeKey, interval: std::time::Duration) -> std::time::Duration { @@ -135,7 +142,7 @@ fn initial_probe_delay(key: &ProbeKey, interval: std::time::Duration) -> std::ti return std::time::Duration::ZERO; } - let jitter_nanos = stable_probe_hash(key) % interval_nanos; + let jitter_nanos = stable_probe_hash(key).rem_euclid(interval_nanos); let jitter_nanos = jitter_nanos.min(u128::from(u64::MAX)) as u64; std::time::Duration::from_nanos(jitter_nanos) } @@ -160,6 +167,3 @@ fn log_probe_task_result(result: std::result::Result<(), JoinError>) { } } } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-runtime/src/lib.rs b/crates/rginx-runtime/src/lib.rs index f9efa12b..0d3ef0ca 100644 --- a/crates/rginx-runtime/src/lib.rs +++ b/crates/rginx-runtime/src/lib.rs @@ -1,3 +1,12 @@ +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] pub mod acme; pub mod admin; mod agent; diff --git a/crates/rginx-runtime/src/ocsp/mod.rs b/crates/rginx-runtime/src/ocsp/mod.rs index 028a6ae9..9856ff1d 100644 --- a/crates/rginx-runtime/src/ocsp/mod.rs +++ b/crates/rginx-runtime/src/ocsp/mod.rs @@ -1,3 +1,12 @@ +mod client; +mod persist; +mod refresh; +mod scheduler; +mod spec; +mod state; +#[cfg(test)] +mod tests; + use std::time::Duration; use client::{OcspClient, build_ocsp_client}; @@ -9,15 +18,6 @@ use refresh::fetch_ocsp_response_from_url; use rginx_http::SharedState; use tokio::sync::watch; -mod client; -mod persist; -mod refresh; -mod scheduler; -mod spec; -mod state; -#[cfg(test)] -mod tests; - const OCSP_REFRESH_INTERVAL: Duration = Duration::from_secs(6 * 60 * 60); const OCSP_FETCH_TIMEOUT: Duration = Duration::from_secs(15); diff --git a/crates/rginx-runtime/src/restart.rs b/crates/rginx-runtime/src/restart.rs index 006724e2..287915d8 100644 --- a/crates/rginx-runtime/src/restart.rs +++ b/crates/rginx-runtime/src/restart.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use std::collections::HashMap; use std::env; use std::io::{self, Write}; @@ -35,9 +37,9 @@ enum InheritedSocketKind { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] struct InheritedListenerFd { + fd: RawFd, kind: InheritedSocketKind, listen_addr: SocketAddr, - fd: RawFd, } pub struct InheritedListeners { @@ -182,6 +184,3 @@ fn set_fd_inheritable(fd: RawFd) -> Result<()> { Ok(()) } - -#[cfg(test)] -mod tests; diff --git a/crates/rginx-runtime/src/state.rs b/crates/rginx-runtime/src/state.rs index 0d72e016..3a455880 100644 --- a/crates/rginx-runtime/src/state.rs +++ b/crates/rginx-runtime/src/state.rs @@ -6,12 +6,15 @@ use tokio::sync::Mutex; #[derive(Clone)] pub struct RuntimeState { + pub(crate) config_change_lock: Arc>, pub config_path: PathBuf, pub http: rginx_http::SharedState, - pub(crate) config_change_lock: Arc>, } impl RuntimeState { + pub async fn current_config(&self) -> Arc { + self.http.current_config().await + } pub fn new(config_path: PathBuf, config: ConfigSnapshot) -> Result { Ok(Self { config_path: config_path.clone(), @@ -19,8 +22,4 @@ impl RuntimeState { config_change_lock: Arc::new(Mutex::new(())), }) } - - pub async fn current_config(&self) -> Arc { - self.http.current_config().await - } } diff --git a/crates/rginx-sdk/Cargo.toml b/crates/rginx-sdk/Cargo.toml index b950cc9a..57f54b45 100644 --- a/crates/rginx-sdk/Cargo.toml +++ b/crates/rginx-sdk/Cargo.toml @@ -10,6 +10,9 @@ documentation.workspace = true readme.workspace = true rust-version.workspace = true +[lints] +workspace = true + [dependencies] # HTTP client reqwest = { version = "0.13", features = ["json", "rustls", "rustls-native-certs"], default-features = false } @@ -37,4 +40,3 @@ url = "2.5" [dev-dependencies] tokio-test = "0.4" mockito = "1.7" - diff --git a/crates/rginx-sdk/src/client.rs b/crates/rginx-sdk/src/client.rs index 3c2a674b..df46e850 100644 --- a/crates/rginx-sdk/src/client.rs +++ b/crates/rginx-sdk/src/client.rs @@ -1,6 +1,10 @@ use crate::config::{AuthConfig, ClientConfig}; use crate::error::{Error, Result}; -use crate::models::*; +use crate::models::{ + CircuitBreakerConfig, CircuitBreakerStats, ConfigApplyRequest, ConfigMetadata, ConfigRevision, + ConfigValidationRequest, ConfigValidationResult, HealthStatus, NodeInfo, NodeRegistration, + ReadinessStatus, RolloutPlan, RolloutState, +}; use reqwest::{Client, RequestBuilder, Response, StatusCode}; use serde::Serialize; use serde::de::DeserializeOwned; @@ -13,65 +17,6 @@ pub struct ControlPlaneClient { } impl ControlPlaneClient { - /// Create a new control plane client - pub fn new(config: ClientConfig) -> Result { - let builder = Client::builder() - .timeout(config.timeout) - .danger_accept_invalid_certs(config.tls.insecure_skip_verify); - - // TODO: Add mTLS support when needed - let http_client = builder.build()?; - - Ok(Self { config, http_client }) - } - - // ======================================================================== - // Node Management - // ======================================================================== - - /// Register a new node with the control plane - pub async fn register_node( - &self, - node_id: &str, - registration: Option, - ) -> Result { - let reg = registration.unwrap_or_else(|| NodeRegistration { - node_id: node_id.to_string(), - region: None, - zone: None, - labels: HashMap::new(), - capabilities: vec![], - }); - - let response: serde_json::Value = self.post("/v1/nodes/register", ®).await?; - - Ok(response["node_id"].as_str().unwrap_or(node_id).to_string()) - } - - /// Send a heartbeat for a registered node - pub async fn heartbeat(&self, node_id: &str) -> Result<()> { - let _: serde_json::Value = - self.post(&format!("/v1/nodes/{}/heartbeat", node_id), &serde_json::json!({})).await?; - Ok(()) - } - - /// Unregister a node - pub async fn unregister_node(&self, node_id: &str) -> Result<()> { - let _: serde_json::Value = - self.post(&format!("/v1/nodes/{}/unregister", node_id), &serde_json::json!({})).await?; - Ok(()) - } - - /// List all registered nodes - pub async fn list_nodes(&self) -> Result> { - self.get("/v1/nodes").await - } - - /// Get information about a specific node - pub async fn get_node(&self, node_id: &str) -> Result { - self.get(&format!("/v1/nodes/{}", node_id)).await - } - // ======================================================================== // Configuration Management // ======================================================================== @@ -88,34 +33,25 @@ impl ControlPlaneClient { Ok(response["revision"].as_u64().unwrap_or(0)) } - /// Validate a configuration without applying it (dry-run) - pub async fn validate_config( - &self, - config: serde_json::Value, - ) -> Result { - let request = ConfigValidationRequest { config }; - self.post("/v1/config/validate", &request).await - } - - /// Get configuration history - pub async fn get_config_history(&self, limit: Option) -> Result> { - let path = if let Some(limit) = limit { - format!("/v1/config/history?limit={}", limit) - } else { - "/v1/config/history".to_string() - }; - self.get(&path).await + fn build_request(&self, request: RequestBuilder) -> RequestBuilder { + match &self.config.auth { + AuthConfig::None => request, + AuthConfig::ApiKey(key) => request.header("X-API-Key", key), + AuthConfig::MutualTls { .. } => { + // mTLS is handled at the HTTP client level + request + } + } } - /// Rollback to a previous configuration revision - pub async fn rollback_config(&self, revision: u64, reason: Option) -> Result { - let request = serde_json::json!({ - "revision": revision, - "reason": reason, - }); - let response: serde_json::Value = self.post("/v1/config/rollback", &request).await?; + // ======================================================================== + // Circuit Breaker + // ======================================================================== - Ok(response["new_revision"].as_u64().unwrap_or(0)) + /// Create a new circuit breaker + pub async fn create_circuit_breaker(&self, config: CircuitBreakerConfig) -> Result<()> { + let _: serde_json::Value = self.post("/v1/breakers", &config).await?; + Ok(()) } // ======================================================================== @@ -129,78 +65,88 @@ impl ControlPlaneClient { Ok(response["rollout_id"].as_str().unwrap_or("").to_string()) } - /// Start a rollout - pub async fn start_rollout(&self, rollout_id: &str) -> Result<()> { - let _: serde_json::Value = self - .post(&format!("/v1/rollouts/{}/start", rollout_id), &serde_json::json!({})) - .await?; - Ok(()) - } - - /// Pause a rollout - pub async fn pause_rollout(&self, rollout_id: &str) -> Result<()> { - let _: serde_json::Value = self - .post(&format!("/v1/rollouts/{}/pause", rollout_id), &serde_json::json!({})) - .await?; - Ok(()) - } + async fn delete(&self, path: &str) -> Result<()> { + let url = self.config.base_url.join(path)?; + let response = self.build_request(self.http_client.delete(url)).send().await?; - /// Resume a paused rollout - pub async fn resume_rollout(&self, rollout_id: &str) -> Result<()> { - let _: serde_json::Value = self - .post(&format!("/v1/rollouts/{}/resume", rollout_id), &serde_json::json!({})) - .await?; - Ok(()) + if response.status().is_success() { + Ok(()) + } else { + Err(self.error_from_response(response).await) + } } - /// Rollback a rollout - pub async fn rollback_rollout(&self, rollout_id: &str, reason: Option) -> Result<()> { - let request = serde_json::json!({ "reason": reason }); - let _: serde_json::Value = - self.post(&format!("/v1/rollouts/{}/rollback", rollout_id), &request).await?; - Ok(()) + /// Delete a circuit breaker + pub async fn delete_circuit_breaker(&self, name: &str) -> Result<()> { + self.delete(&format!("/v1/breakers/{name}")).await } - /// Get rollout state - pub async fn get_rollout(&self, rollout_id: &str) -> Result { - self.get(&format!("/v1/rollouts/{}", rollout_id)).await - } + async fn error_from_response(&self, response: Response) -> Error { + let status = response.status(); + let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); - /// List all rollouts - pub async fn list_rollouts(&self) -> Result> { - self.get("/v1/rollouts").await + match status { + StatusCode::UNAUTHORIZED => Error::Authentication(message), + StatusCode::NOT_FOUND => Error::NotFound(message), + StatusCode::REQUEST_TIMEOUT => Error::Timeout(message), + _ => Error::Api { status: status.as_u16(), message }, + } } // ======================================================================== - // Circuit Breaker + // HTTP Helpers // ======================================================================== - /// Create a new circuit breaker - pub async fn create_circuit_breaker(&self, config: CircuitBreakerConfig) -> Result<()> { - let _: serde_json::Value = self.post("/v1/breakers", &config).await?; - Ok(()) + async fn get(&self, path: &str) -> Result { + let url = self.config.base_url.join(path)?; + let response = self.build_request(self.http_client.get(url)).send().await?; + + self.handle_response(response).await } /// Get circuit breaker statistics pub async fn get_circuit_breaker(&self, name: &str) -> Result { - self.get(&format!("/v1/breakers/{}", name)).await + self.get(&format!("/v1/breakers/{name}")).await } - /// List all circuit breakers - pub async fn list_circuit_breakers(&self) -> Result> { - self.get("/v1/breakers").await + /// Get configuration history + pub async fn get_config_history(&self, limit: Option) -> Result> { + let path = if let Some(limit) = limit { + format!("/v1/config/history?limit={limit}") + } else { + "/v1/config/history".to_string() + }; + self.get(&path).await } - /// Reset a circuit breaker - pub async fn reset_circuit_breaker(&self, name: &str) -> Result<()> { - let _: serde_json::Value = - self.post(&format!("/v1/breakers/{}/reset", name), &serde_json::json!({})).await?; - Ok(()) + /// Get information about a specific node + pub async fn get_node(&self, node_id: &str) -> Result { + self.get(&format!("/v1/nodes/{node_id}")).await } - /// Delete a circuit breaker - pub async fn delete_circuit_breaker(&self, name: &str) -> Result<()> { - self.delete(&format!("/v1/breakers/{}", name)).await + /// Get rollout state + pub async fn get_rollout(&self, rollout_id: &str) -> Result { + self.get(&format!("/v1/rollouts/{rollout_id}")).await + } + + async fn handle_response(&self, response: Response) -> Result { + let status = response.status(); + + if status.is_success() { + Ok(response.json().await?) + } else { + Err(self.error_from_response(response).await) + } + } + + async fn handle_response_text(&self, response: Response) -> Result { + let status = response.status(); + + if status.is_success() { + Ok(response.text().await?) + } else { + Err(self.error_from_response(response).await) + } } // ======================================================================== @@ -212,9 +158,26 @@ impl ControlPlaneClient { self.get("/v1/health").await } - /// Check control plane readiness - pub async fn readiness(&self) -> Result { - self.get("/v1/ready").await + /// Send a heartbeat for a registered node + pub async fn heartbeat(&self, node_id: &str) -> Result<()> { + let _: serde_json::Value = + self.post(&format!("/v1/nodes/{node_id}/heartbeat"), &serde_json::json!({})).await?; + Ok(()) + } + + /// List all circuit breakers + pub async fn list_circuit_breakers(&self) -> Result> { + self.get("/v1/breakers").await + } + + /// List all registered nodes + pub async fn list_nodes(&self) -> Result> { + self.get("/v1/nodes").await + } + + /// List all rollouts + pub async fn list_rollouts(&self) -> Result> { + self.get("/v1/rollouts").await } /// Get Prometheus metrics @@ -225,15 +188,23 @@ impl ControlPlaneClient { self.handle_response_text(response).await } - // ======================================================================== - // HTTP Helpers - // ======================================================================== + /// Create a new control plane client + pub fn new(config: ClientConfig) -> Result { + let builder = Client::builder() + .timeout(config.timeout) + .danger_accept_invalid_certs(config.tls.insecure_skip_verify); - async fn get(&self, path: &str) -> Result { - let url = self.config.base_url.join(path)?; - let response = self.build_request(self.http_client.get(url)).send().await?; + // TODO: Add mTLS support when needed + let http_client = builder.build()?; - self.handle_response(response).await + Ok(Self { config, http_client }) + } + + /// Pause a rollout + pub async fn pause_rollout(&self, rollout_id: &str) -> Result<()> { + let _: serde_json::Value = + self.post(&format!("/v1/rollouts/{rollout_id}/pause"), &serde_json::json!({})).await?; + Ok(()) } async fn post(&self, path: &str, body: &B) -> Result { @@ -243,57 +214,87 @@ impl ControlPlaneClient { self.handle_response(response).await } - async fn delete(&self, path: &str) -> Result<()> { - let url = self.config.base_url.join(path)?; - let response = self.build_request(self.http_client.delete(url)).send().await?; + /// Check control plane readiness + pub async fn readiness(&self) -> Result { + self.get("/v1/ready").await + } - if response.status().is_success() { - Ok(()) - } else { - Err(self.error_from_response(response).await) - } + // ======================================================================== + // Node Management + // ======================================================================== + + /// Register a new node with the control plane + pub async fn register_node( + &self, + node_id: &str, + registration: Option, + ) -> Result { + let reg = registration.unwrap_or_else(|| NodeRegistration { + node_id: node_id.to_string(), + region: None, + zone: None, + labels: HashMap::new(), + capabilities: vec![], + }); + + let response: serde_json::Value = self.post("/v1/nodes/register", ®).await?; + + Ok(response["node_id"].as_str().unwrap_or(node_id).to_string()) } - fn build_request(&self, request: RequestBuilder) -> RequestBuilder { - match &self.config.auth { - AuthConfig::None => request, - AuthConfig::ApiKey(key) => request.header("X-API-Key", key), - AuthConfig::MutualTls { .. } => { - // mTLS is handled at the HTTP client level - request - } - } + /// Reset a circuit breaker + pub async fn reset_circuit_breaker(&self, name: &str) -> Result<()> { + let _: serde_json::Value = + self.post(&format!("/v1/breakers/{name}/reset"), &serde_json::json!({})).await?; + Ok(()) } - async fn handle_response(&self, response: Response) -> Result { - let status = response.status(); + /// Resume a paused rollout + pub async fn resume_rollout(&self, rollout_id: &str) -> Result<()> { + let _: serde_json::Value = + self.post(&format!("/v1/rollouts/{rollout_id}/resume"), &serde_json::json!({})).await?; + Ok(()) + } - if status.is_success() { - Ok(response.json().await?) - } else { - Err(self.error_from_response(response).await) - } + /// Rollback to a previous configuration revision + pub async fn rollback_config(&self, revision: u64, reason: Option) -> Result { + let request = serde_json::json!({ + "revision": revision, + "reason": reason, + }); + let response: serde_json::Value = self.post("/v1/config/rollback", &request).await?; + + Ok(response["new_revision"].as_u64().unwrap_or(0)) } - async fn handle_response_text(&self, response: Response) -> Result { - let status = response.status(); + /// Rollback a rollout + pub async fn rollback_rollout(&self, rollout_id: &str, reason: Option) -> Result<()> { + let request = serde_json::json!({ "reason": reason }); + let _: serde_json::Value = + self.post(&format!("/v1/rollouts/{rollout_id}/rollback"), &request).await?; + Ok(()) + } - if status.is_success() { - Ok(response.text().await?) - } else { - Err(self.error_from_response(response).await) - } + /// Start a rollout + pub async fn start_rollout(&self, rollout_id: &str) -> Result<()> { + let _: serde_json::Value = + self.post(&format!("/v1/rollouts/{rollout_id}/start"), &serde_json::json!({})).await?; + Ok(()) } - async fn error_from_response(&self, response: Response) -> Error { - let status = response.status(); - let message = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); + /// Unregister a node + pub async fn unregister_node(&self, node_id: &str) -> Result<()> { + let _: serde_json::Value = + self.post(&format!("/v1/nodes/{node_id}/unregister"), &serde_json::json!({})).await?; + Ok(()) + } - match status { - StatusCode::UNAUTHORIZED => Error::Authentication(message), - StatusCode::NOT_FOUND => Error::NotFound(message), - StatusCode::REQUEST_TIMEOUT => Error::Timeout(message), - _ => Error::Api { status: status.as_u16(), message }, - } + /// Validate a configuration without applying it (dry-run) + pub async fn validate_config( + &self, + config: serde_json::Value, + ) -> Result { + let request = ConfigValidationRequest { config }; + self.post("/v1/config/validate", &request).await } } diff --git a/crates/rginx-sdk/src/config.rs b/crates/rginx-sdk/src/config.rs index f804819f..738b476a 100644 --- a/crates/rginx-sdk/src/config.rs +++ b/crates/rginx-sdk/src/config.rs @@ -5,32 +5,31 @@ use url::Url; /// Client configuration for connecting to the rginx Control Plane #[derive(Debug, Clone)] pub struct ClientConfig { - /// Base URL of the control plane API - pub base_url: Url, - /// Authentication method pub auth: AuthConfig, - /// Request timeout - pub timeout: Duration, + /// Base URL of the control plane API + pub base_url: Url, /// Maximum number of retries pub max_retries: u32, + /// Request timeout + pub timeout: Duration, + /// TLS configuration pub tls: TlsConfig, } #[derive(Debug, Clone)] pub enum AuthConfig { - /// No authentication - None, - /// API key authentication ApiKey(String), /// mTLS authentication MutualTls { client_cert_path: String, client_key_path: String }, + /// No authentication + None, } #[derive(Debug, Clone)] @@ -43,6 +42,12 @@ pub struct TlsConfig { } impl ClientConfig { + /// Skip TLS verification (insecure, for testing only) + #[must_use] + pub fn insecure_skip_verify(mut self) -> Self { + self.tls.insecure_skip_verify = true; + self + } /// Create a new client configuration with the given base URL pub fn new(base_url: &str) -> Result { let url = Url::parse(base_url).map_err(Error::InvalidUrl)?; @@ -62,36 +67,32 @@ impl ClientConfig { self } - /// Set mTLS authentication - pub fn with_mtls(mut self, cert_path: impl Into, key_path: impl Into) -> Self { - self.auth = AuthConfig::MutualTls { - client_cert_path: cert_path.into(), - client_key_path: key_path.into(), - }; - self - } - - /// Set request timeout - pub fn with_timeout(mut self, timeout: Duration) -> Self { - self.timeout = timeout; + /// Set CA certificate path for server verification + pub fn with_ca_cert(mut self, ca_cert_path: impl Into) -> Self { + self.tls.ca_cert_path = Some(ca_cert_path.into()); self } /// Set maximum number of retries + #[must_use] pub fn with_max_retries(mut self, max_retries: u32) -> Self { self.max_retries = max_retries; self } - /// Set CA certificate path for server verification - pub fn with_ca_cert(mut self, ca_cert_path: impl Into) -> Self { - self.tls.ca_cert_path = Some(ca_cert_path.into()); + /// Set mTLS authentication + pub fn with_mtls(mut self, cert_path: impl Into, key_path: impl Into) -> Self { + self.auth = AuthConfig::MutualTls { + client_cert_path: cert_path.into(), + client_key_path: key_path.into(), + }; self } - /// Skip TLS verification (insecure, for testing only) - pub fn insecure_skip_verify(mut self) -> Self { - self.tls.insecure_skip_verify = true; + /// Set request timeout + #[must_use] + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; self } } diff --git a/crates/rginx-sdk/src/error.rs b/crates/rginx-sdk/src/error.rs index e41a3a1b..26d4a115 100644 --- a/crates/rginx-sdk/src/error.rs +++ b/crates/rginx-sdk/src/error.rs @@ -2,35 +2,34 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum Error { + #[error("API error: {status} - {message}")] + Api { status: u16, message: String }, + + #[error("Authentication failed: {0}")] + Authentication(String), + + #[error("Connection error: {0}")] + Connection(String), #[error("HTTP request failed: {0}")] Http(#[from] reqwest::Error), - #[error("JSON serialization/deserialization failed: {0}")] - Json(#[from] serde_json::Error), - - #[error("WebSocket error: {0}")] - WebSocket(String), + #[error("Invalid configuration: {0}")] + InvalidConfig(String), #[error("Invalid URL: {0}")] InvalidUrl(#[from] url::ParseError), - #[error("API error: {status} - {message}")] - Api { status: u16, message: String }, - - #[error("Authentication failed: {0}")] - Authentication(String), + #[error("JSON serialization/deserialization failed: {0}")] + Json(#[from] serde_json::Error), #[error("Resource not found: {0}")] NotFound(String), - #[error("Invalid configuration: {0}")] - InvalidConfig(String), - #[error("Timeout: {0}")] Timeout(String), - #[error("Connection error: {0}")] - Connection(String), + #[error("WebSocket error: {0}")] + WebSocket(String), } pub type Result = std::result::Result; diff --git a/crates/rginx-sdk/src/lib.rs b/crates/rginx-sdk/src/lib.rs index 2ed0a1b5..00060885 100644 --- a/crates/rginx-sdk/src/lib.rs +++ b/crates/rginx-sdk/src/lib.rs @@ -30,6 +30,15 @@ //! Ok(()) //! } //! ``` +#![expect(clippy::pedantic, reason = "pedantic is tracked as an explicit audit baseline")] +#![expect( + clippy::blanket_clippy_restriction_lints, + reason = "restriction is intentionally tracked as a whole-command audit baseline" +)] +#![expect( + clippy::restriction, + reason = "restriction is tracked as an explicit audit baseline; individual lints are enabled separately" +)] pub mod client; pub mod config; diff --git a/crates/rginx-sdk/src/models.rs b/crates/rginx-sdk/src/models.rs index 5886ee48..d9658030 100644 --- a/crates/rginx-sdk/src/models.rs +++ b/crates/rginx-sdk/src/models.rs @@ -7,23 +7,23 @@ use std::collections::HashMap; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeRegistration { + pub capabilities: Vec, + pub labels: HashMap, pub node_id: String, pub region: Option, pub zone: Option, - pub labels: HashMap, - pub capabilities: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NodeInfo { - pub node_id: String, - pub region: Option, - pub zone: Option, - pub labels: HashMap, pub capabilities: Vec, - pub status: NodeStatus, + pub labels: HashMap, pub last_heartbeat: u64, + pub node_id: String, + pub region: Option, pub registered_at: u64, + pub status: NodeStatus, + pub zone: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -47,8 +47,8 @@ pub struct ConfigApplyRequest { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConfigMetadata { pub reason: Option, - pub tags: Vec, pub rollback_from: Option, + pub tags: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -58,33 +58,33 @@ pub struct ConfigValidationRequest { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConfigValidationResult { - pub valid: bool, pub errors: Vec, - pub warnings: Vec, pub impact: ConfigImpact, + pub valid: bool, + pub warnings: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConfigImpact { - pub requires_reload: bool, pub affects_traffic: bool, pub estimated_downtime_ms: u64, + pub requires_reload: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ConfigRevision { - pub revision: u64, pub applied_at: u64, pub applied_by: String, + pub revision: u64, pub status: ConfigApplyStatus, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ConfigApplyStatus { - Pending, Applied, Failed, + Pending, RolledBack, } @@ -94,40 +94,40 @@ pub enum ConfigApplyStatus { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RolloutPlan { - pub config_revision: u64, - pub strategy: RolloutStrategy, pub auto_advance: bool, + pub config_revision: u64, pub health_check_interval: u64, + pub strategy: RolloutStrategy, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type")] pub enum RolloutStrategy { - Percentage { target_percentage: u8 }, - NodeLabels { labels: HashMap }, - Canary { canary_nodes: Vec }, BlueGreen { active_group: String }, + Canary { canary_nodes: Vec }, + NodeLabels { labels: HashMap }, + Percentage { target_percentage: u8 }, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RolloutState { - pub rollout_id: String, - pub plan: RolloutPlan, - pub phase: RolloutPhase, - pub started_at: Option, + pub affected_nodes: Vec, pub completed_at: Option, pub current_percentage: u8, - pub affected_nodes: Vec, + pub phase: RolloutPhase, + pub plan: RolloutPlan, + pub rollout_id: String, + pub started_at: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum RolloutPhase { - Pending, - InProgress, - Paused, Completed, Failed, + InProgress, + Paused, + Pending, RolledBack, } @@ -137,29 +137,29 @@ pub enum RolloutPhase { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CircuitBreakerConfig { - pub name: String, pub failure_threshold: u32, + pub half_open_max_requests: u32, + pub name: String, pub success_threshold: u32, pub timeout_secs: u64, - pub half_open_max_requests: u32, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CircuitBreakerStats { + pub failure_count: u64, + pub last_state_change: u64, pub name: String, pub state: CircuitState, - pub total_requests: u64, pub success_count: u64, - pub failure_count: u64, - pub last_state_change: u64, + pub total_requests: u64, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum CircuitState { Closed, - Open, HalfOpen, + Open, } // ============================================================================ @@ -168,22 +168,22 @@ pub enum CircuitState { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Event { + pub data: serde_json::Value, pub event_type: EventType, - pub timestamp: u64, pub source: String, - pub data: serde_json::Value, + pub timestamp: u64, } #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum EventType { - NodeRegistered, - NodeUnregistered, - NodeHealthChanged, ConfigApplied, ConfigFailed, - RolloutStarted, + NodeHealthChanged, + NodeRegistered, + NodeUnregistered, RolloutCompleted, + RolloutStarted, } // ============================================================================ @@ -193,12 +193,12 @@ pub enum EventType { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HealthStatus { pub status: String, - pub version: String, pub uptime_secs: u64, + pub version: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ReadinessStatus { - pub ready: bool, pub checks: HashMap, + pub ready: bool, } diff --git a/crates/rginx-sdk/src/websocket.rs b/crates/rginx-sdk/src/websocket.rs index 55013d93..f794a057 100644 --- a/crates/rginx-sdk/src/websocket.rs +++ b/crates/rginx-sdk/src/websocket.rs @@ -1,3 +1,5 @@ +#[cfg(test)] +mod tests; use crate::config::ClientConfig; use crate::error::{Error, Result}; use crate::models::Event; @@ -11,39 +13,6 @@ pub struct EventSubscriber { } impl EventSubscriber { - /// Create a new event subscriber - pub fn new(config: ClientConfig) -> Self { - Self { config } - } - - /// Subscribe to control plane events - /// - /// Returns a channel receiver that will receive events as they arrive. - /// The connection will automatically reconnect on failure. - pub async fn subscribe(&self) -> Result> { - let (tx, rx) = mpsc::channel(100); - - let ws_url = self.build_websocket_url()?; - let config = self.config.clone(); - - tokio::spawn(async move { - loop { - match Self::connect_and_listen(&ws_url, &config, tx.clone()).await { - Ok(_) => { - tracing::info!("WebSocket connection closed normally"); - break; - } - Err(e) => { - tracing::warn!("WebSocket connection error: {}, reconnecting...", e); - tokio::time::sleep(std::time::Duration::from_secs(5)).await; - } - } - } - }); - - Ok(rx) - } - fn build_websocket_url(&self) -> Result { let mut url = self.config.base_url.clone(); @@ -55,7 +24,7 @@ impl EventSubscriber { }; url.set_scheme(scheme) - .map_err(|_| Error::InvalidConfig("Failed to set WebSocket scheme".to_string()))?; + .map_err(|()| Error::InvalidConfig("Failed to set WebSocket scheme".to_string()))?; url.set_path("/v1/events"); @@ -119,7 +88,37 @@ impl EventSubscriber { Ok(()) } -} + /// Create a new event subscriber + #[must_use] + pub fn new(config: ClientConfig) -> Self { + Self { config } + } -#[cfg(test)] -mod tests; + /// Subscribe to control plane events + /// + /// Returns a channel receiver that will receive events as they arrive. + /// The connection will automatically reconnect on failure. + pub async fn subscribe(&self) -> Result> { + let (tx, rx) = mpsc::channel(100); + + let ws_url = self.build_websocket_url()?; + let config = self.config.clone(); + + tokio::spawn(async move { + loop { + match Self::connect_and_listen(&ws_url, &config, tx.clone()).await { + Ok(()) => { + tracing::info!("WebSocket connection closed normally"); + break; + } + Err(e) => { + tracing::warn!("WebSocket connection error: {}, reconnecting...", e); + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + } + } + } + }); + + Ok(rx) + } +} From 0bb4a23d9fbfa7d01e44d1ac1d4317847f08fe70 Mon Sep 17 00:00:00 2001 From: vansour Date: Sun, 17 May 2026 14:23:47 +0800 Subject: [PATCH 2/4] fix: address PR review and modularization gate --- crates/rginx-agent/src/circuit_breaker.rs | 4 +++- crates/rginx-agent/src/outbound/runner.rs | 4 ---- crates/rginx-core/src/config/route.rs | 17 ++--------------- crates/rginx-core/src/config/route/prefix.rs | 11 +++++++++++ crates/rginx-http/src/cache/shared/memory.rs | 3 --- 5 files changed, 16 insertions(+), 23 deletions(-) create mode 100644 crates/rginx-core/src/config/route/prefix.rs diff --git a/crates/rginx-agent/src/circuit_breaker.rs b/crates/rginx-agent/src/circuit_breaker.rs index 15c01b22..86f702a5 100644 --- a/crates/rginx-agent/src/circuit_breaker.rs +++ b/crates/rginx-agent/src/circuit_breaker.rs @@ -286,5 +286,7 @@ impl Default for CircuitBreakerRegistry { } fn current_timestamp() -> u64 { - std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, |elapsed| elapsed.as_secs()) } diff --git a/crates/rginx-agent/src/outbound/runner.rs b/crates/rginx-agent/src/outbound/runner.rs index 39b7db0d..2b5a6776 100644 --- a/crates/rginx-agent/src/outbound/runner.rs +++ b/crates/rginx-agent/src/outbound/runner.rs @@ -71,14 +71,12 @@ impl OutboundAgent { if let Some(result) = self.state.cached_result(&command.id) { return Ok(result); } - self.state.in_flight_command = Some(AgentInFlightCommand::received(command.clone())); self.persist_state()?; if let Some(in_flight) = self.state.in_flight_command.as_mut() { in_flight.mark(AgentCommandExecutionState::Executing); } self.persist_state()?; - let started_at_unix_ms = unix_ms(); let outcome = execute( &self.core, @@ -88,7 +86,6 @@ impl OutboundAgent { ) .await; let finished_at_unix_ms = unix_ms(); - let result = match outcome { Ok(result) => AgentCommandResult { command_id: command.id, @@ -183,7 +180,6 @@ impl OutboundAgent { self.state.command_cursor = Some(next_cursor); self.persist_state()?; } - Ok(OutboundAgentCycleOutcome { commands_received, results_posted }) } diff --git a/crates/rginx-core/src/config/route.rs b/crates/rginx-core/src/config/route.rs index 40ef1c3b..cc7be50b 100644 --- a/crates/rginx-core/src/config/route.rs +++ b/crates/rginx-core/src/config/route.rs @@ -1,3 +1,4 @@ +mod prefix; mod proxy_header; mod regex_matcher; @@ -149,7 +150,7 @@ impl RouteMatcher { pub fn matches(&self, path: &str) -> bool { match self { Self::Exact(expected) => path == expected, - Self::PreferredPrefix(prefix) | Self::Prefix(prefix) => prefix_matches(prefix, path), + Self::PreferredPrefix(prefix) | Self::Prefix(prefix) => prefix::matches(prefix, path), Self::Regex(regex) => regex.matches(path), Self::Named(_) => false, } @@ -291,17 +292,3 @@ pub struct ReturnAction { pub location: String, pub status: StatusCode, } - -fn prefix_matches(prefix: &str, path: &str) -> bool { - if prefix == "/" { - return true; - } - - if path == prefix { - return true; - } - - path.strip_prefix(prefix).is_some_and(|remainder| { - if prefix.ends_with('/') { true } else { remainder.starts_with('/') } - }) -} diff --git a/crates/rginx-core/src/config/route/prefix.rs b/crates/rginx-core/src/config/route/prefix.rs new file mode 100644 index 00000000..602b53ff --- /dev/null +++ b/crates/rginx-core/src/config/route/prefix.rs @@ -0,0 +1,11 @@ +pub(super) fn matches(prefix: &str, path: &str) -> bool { + if prefix == "/" { + return true; + } + if path == prefix { + return true; + } + path.strip_prefix(prefix).is_some_and(|remainder| { + if prefix.ends_with('/') { true } else { remainder.starts_with('/') } + }) +} diff --git a/crates/rginx-http/src/cache/shared/memory.rs b/crates/rginx-http/src/cache/shared/memory.rs index 2a5dfae7..87cf6b98 100644 --- a/crates/rginx-http/src/cache/shared/memory.rs +++ b/crates/rginx-http/src/cache/shared/memory.rs @@ -2,14 +2,12 @@ dead_code, reason = "shared memory primitives are reused across platform and test targets" )] - mod config; mod sys; #[cfg(test)] #[path = "memory/tests.rs"] mod tests; - use std::fs; use std::io; use std::mem::size_of; @@ -18,7 +16,6 @@ use std::ptr::NonNull; pub(crate) use self::config::{SharedMemorySegmentBacking, SharedMemorySegmentConfig}; use self::sys::{close_fd, fd_size, ftruncate, invalid_header, open_file, shm_open}; - const SHM_MAGIC: u64 = 0x4d48_5358_4e49_4752; const SHM_ABI_VERSION: u32 = 1; const DEFAULT_HASH_BUCKET_COUNT: u64 = 1_024; From 3c540bdd369398d474c4a9c7578b3e46ca2d6de4 Mon Sep 17 00:00:00 2001 From: vansour Date: Sun, 17 May 2026 14:28:55 +0800 Subject: [PATCH 3/4] test: relax streaming range timeout --- .../rginx-http/src/cache/tests/storage_p4/ranges.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/crates/rginx-http/src/cache/tests/storage_p4/ranges.rs b/crates/rginx-http/src/cache/tests/storage_p4/ranges.rs index 24342f9c..c25dee1a 100644 --- a/crates/rginx-http/src/cache/tests/storage_p4/ranges.rs +++ b/crates/rginx-http/src/cache/tests/storage_p4/ranges.rs @@ -45,7 +45,7 @@ async fn sliced_range_requests_stream_trim_while_filling_cache() { .header(CONTENT_LENGTH, "8") .body(boxed_body(StreamBody::new(stream))) .expect("response should build"); - let stored = timeout(Duration::from_millis(50), manager.store_response(context, response)) + let stored = timeout(Duration::from_millis(200), manager.store_response(context, response)) .await .expect("slice trim must not wait for the whole upstream slice"); assert_eq!(stored.headers().get(CACHE_STATUS_HEADER).unwrap(), "MISS"); @@ -57,14 +57,14 @@ async fn sliced_range_requests_stream_trim_while_filling_cache() { .await .expect("first slice chunk should send"); assert!( - timeout(Duration::from_millis(50), body.frame()).await.is_err(), + timeout(Duration::from_millis(200), body.frame()).await.is_err(), "trimmed downstream body should wait until requested bytes arrive" ); tx.send(Ok::, BoxError>(Frame::data(Bytes::from_static(b"cdef")))) .await .expect("second slice chunk should send"); - let frame = timeout(Duration::from_millis(50), body.frame()) + let frame = timeout(Duration::from_millis(200), body.frame()) .await .expect("trimmed data frame should arrive") .expect("body should yield a frame") @@ -120,7 +120,7 @@ async fn sliced_range_requests_share_inflight_fill_for_subranges() { .body(boxed_body(StreamBody::new(stream))) .expect("response should build"); let stored = - timeout(Duration::from_millis(50), manager.store_response(first_context, response)) + timeout(Duration::from_millis(200), manager.store_response(first_context, response)) .await .expect("streaming slice fill should start immediately"); let mut first_body = stored.into_body(); @@ -129,14 +129,14 @@ async fn sliced_range_requests_share_inflight_fill_for_subranges() { .await .expect("slice prefix should send"); assert!( - timeout(Duration::from_millis(50), first_body.frame()).await.is_err(), + timeout(Duration::from_millis(200), first_body.frame()).await.is_err(), "first trimmed subrange should wait until requested bytes arrive" ); tx.send(Ok::, BoxError>(Frame::data(Bytes::from_static(b"cdef")))) .await .expect("slice middle should send"); - let first_frame = timeout(Duration::from_millis(50), first_body.frame()) + let first_frame = timeout(Duration::from_millis(200), first_body.frame()) .await .expect("first trimmed frame should arrive") .expect("first trimmed body should yield a frame") From 36b53f6f29abc255c5accf78f7333b16ee8fc719 Mon Sep 17 00:00:00 2001 From: vansour Date: Sun, 17 May 2026 14:53:00 +0800 Subject: [PATCH 4/4] fix: address PR review feedback --- crates/rginx-agent/src/agent_core.rs | 53 ++++++-------- crates/rginx-agent/src/agent_core/command.rs | 31 +++++++++ crates/rginx-config/src/load/env_expand.rs | 73 ++++++++++++-------- 3 files changed, 96 insertions(+), 61 deletions(-) create mode 100644 crates/rginx-agent/src/agent_core/command.rs diff --git a/crates/rginx-agent/src/agent_core.rs b/crates/rginx-agent/src/agent_core.rs index ec5c069e..4252b8c5 100644 --- a/crates/rginx-agent/src/agent_core.rs +++ b/crates/rginx-agent/src/agent_core.rs @@ -16,6 +16,13 @@ use crate::model::{ use crate::server::control::{ConfigApplyExecutor, ReloadExecutor, UnsupportedConfigApplyExecutor}; use crate::system::collect_system_view; +mod command; + +pub use command::{ + CacheClearInvalidationsCommand, CacheInvalidateCommand, CacheInvalidateTarget, + CachePurgeCommand, CachePurgeTarget, +}; + const RELOAD_COMPLETION_TIMEOUT: Duration = Duration::from_secs(30); #[derive(Clone)] @@ -26,6 +33,7 @@ pub struct AgentCore { } impl AgentCore { + #[must_use] pub async fn action_status(&self, accepted_revision: u64) -> NodeActionStatusView { NodeActionStatusView { accepted_revision, @@ -46,6 +54,7 @@ impl AgentCore { }) } + #[must_use = "query result should be handled"] pub async fn cache(&self) -> Result { Ok(NodeCacheView::from(self.state.cache_stats_snapshot().await)) } @@ -62,6 +71,7 @@ impl AgentCore { Ok(self.wrap_result(result).await) } + #[must_use = "query result should be handled"] pub async fn delta_since( &self, since_version: u64, @@ -97,6 +107,7 @@ impl AgentCore { Ok(self.wrap_result(result).await) } + #[must_use] pub fn new(state: SharedState, reload_executor: Arc) -> Self { Self { state, @@ -130,6 +141,7 @@ impl AgentCore { Ok(self.reload_action_status(fallback_revision).await) } + #[must_use] async fn reload_action_status(&self, fallback_revision: u64) -> NodeActionStatusView { let current_revision = self.state.current_revision().await; let last_reload_result = last_reload_result(&self.state).await; @@ -145,6 +157,7 @@ impl AgentCore { } } + #[must_use = "query result should be handled"] pub async fn revision(&self) -> Result { Ok(NodeRevisionView::from(self.state.revision_status_snapshot().await)) } @@ -162,6 +175,7 @@ impl AgentCore { &self.state } + #[must_use = "query result should be handled"] pub async fn snapshot(&self, window_secs: Option) -> Result { Ok(NodeSnapshotView { snapshot_version: self.state.current_snapshot_version(), @@ -174,10 +188,12 @@ impl AgentCore { }) } + #[must_use = "query result should be handled"] pub async fn status(&self) -> Result { Ok(NodeStatusView::from(self.state.status_snapshot().await)) } + #[must_use = "query result should be handled"] pub async fn system(&self) -> Result { let config = self.state.current_config().await; let cache_zone_paths = @@ -187,10 +203,12 @@ impl AgentCore { .map_err(|error| Error::Server(error.to_string()))? } + #[must_use = "query result should be handled"] pub async fn traffic(&self, window_secs: Option) -> Result { Ok(NodeTrafficView::from(self.state.traffic_stats_snapshot_with_window(window_secs))) } + #[must_use = "query result should be handled"] pub async fn upstreams(&self, window_secs: Option) -> Result { Ok(NodeUpstreamsView { peer_health: self.state.peer_health_snapshot().await, @@ -224,6 +242,7 @@ impl AgentCore { } } + #[must_use = "query result should be handled"] pub async fn wait_for_snapshot_change( &self, since_version: u64, @@ -233,6 +252,7 @@ impl AgentCore { Ok(NodeWaitView { snapshot_version }) } + #[must_use] pub fn with_config_apply_executor( mut self, config_apply_executor: Arc, @@ -241,44 +261,13 @@ impl AgentCore { self } + #[must_use] pub async fn wrap_result(&self, result: T) -> NodeControlResultView { let current_revision = self.state.current_revision().await; NodeControlResultView { status: self.action_status(current_revision).await, result } } } -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CachePurgeCommand { - pub target: CachePurgeTarget, - pub zone_name: String, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum CachePurgeTarget { - Key(String), - Prefix(String), - Zone, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CacheInvalidateCommand { - pub target: CacheInvalidateTarget, - pub zone_name: String, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum CacheInvalidateTarget { - Key(String), - Prefix(String), - Tag(String), - Zone, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct CacheClearInvalidationsCommand { - pub zone_name: String, -} - async fn last_reload_result(state: &SharedState) -> Option { state.status_snapshot().await.reload.last_result } diff --git a/crates/rginx-agent/src/agent_core/command.rs b/crates/rginx-agent/src/agent_core/command.rs new file mode 100644 index 00000000..2e29ae77 --- /dev/null +++ b/crates/rginx-agent/src/agent_core/command.rs @@ -0,0 +1,31 @@ +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CachePurgeCommand { + pub target: CachePurgeTarget, + pub zone_name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CachePurgeTarget { + Key(String), + Prefix(String), + Zone, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CacheInvalidateCommand { + pub target: CacheInvalidateTarget, + pub zone_name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CacheInvalidateTarget { + Key(String), + Prefix(String), + Tag(String), + Zone, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CacheClearInvalidationsCommand { + pub zone_name: String, +} diff --git a/crates/rginx-config/src/load/env_expand.rs b/crates/rginx-config/src/load/env_expand.rs index 6b86b486..a241cc0a 100644 --- a/crates/rginx-config/src/load/env_expand.rs +++ b/crates/rginx-config/src/load/env_expand.rs @@ -19,14 +19,14 @@ pub(super) fn expand_env_placeholders_in_ron_strings( if ch == '"' { in_string = true; } - advance(&mut index, 1); + advance(&mut index, 1, source_path)?; continue; } if escaped { expanded.push(ch); escaped = false; - advance(&mut index, 1); + advance(&mut index, 1, source_path)?; continue; } @@ -34,37 +34,46 @@ pub(super) fn expand_env_placeholders_in_ron_strings( '\\' => { expanded.push(ch); escaped = true; - advance(&mut index, 1); + advance(&mut index, 1, source_path)?; } '"' => { expanded.push(ch); in_string = false; - advance(&mut index, 1); + advance(&mut index, 1, source_path)?; } - '$' if chars.get(checked_add(index, 1)) == Some(&'$') => { - expanded.push('$'); - advance(&mut index, 2); - } - '$' if chars.get(checked_add(index, 1)) == Some(&'{') => { - let token_start = checked_add(index, 2); - let end = chars[token_start..] - .iter() - .position(|candidate| *candidate == '}') - .map(|offset| checked_add(token_start, offset)) - .ok_or_else(|| { - Error::Config(format!( - "unterminated environment placeholder in `{}`", - source_path.display() - )) - })?; - let token = chars[token_start..end].iter().collect::(); - let replacement = resolve_env_placeholder(&token, source_path)?; - expanded.push_str(&escape_ron_string_fragment(&replacement)); - index = checked_add(end, 1); + '$' => { + let next_index = checked_add(index, 1, source_path)?; + match chars.get(next_index) { + Some('$') => { + expanded.push('$'); + advance(&mut index, 2, source_path)?; + } + Some('{') => { + let token_start = checked_add(index, 2, source_path)?; + let end_offset = chars[token_start..] + .iter() + .position(|candidate| *candidate == '}') + .ok_or_else(|| { + Error::Config(format!( + "unterminated environment placeholder in `{}`", + source_path.display() + )) + })?; + let end = checked_add(token_start, end_offset, source_path)?; + let token = chars[token_start..end].iter().collect::(); + let replacement = resolve_env_placeholder(&token, source_path)?; + expanded.push_str(&escape_ron_string_fragment(&replacement)); + index = checked_add(end, 1, source_path)?; + } + _ => { + expanded.push(ch); + advance(&mut index, 1, source_path)?; + } + } } _ => { expanded.push(ch); - advance(&mut index, 1); + advance(&mut index, 1, source_path)?; } } } @@ -114,10 +123,16 @@ fn escape_ron_string_fragment(value: &str) -> String { escaped } -fn advance(value: &mut usize, increment: usize) { - *value = checked_add(*value, increment); +fn advance(value: &mut usize, increment: usize, source_path: &Path) -> Result<()> { + *value = checked_add(*value, increment, source_path)?; + Ok(()) } -fn checked_add(value: usize, increment: usize) -> usize { - value.checked_add(increment).expect("RON environment expansion index remains representable") +fn checked_add(value: usize, increment: usize, source_path: &Path) -> Result { + value.checked_add(increment).ok_or_else(|| { + Error::Config(format!( + "environment placeholder index overflow while loading `{}`", + source_path.display() + )) + }) }