diff --git a/Cargo.lock b/Cargo.lock index 0dd0e763..e06d7ff0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1716,6 +1716,7 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", + "tracing-test", "tree_hash", "url", ] @@ -6193,6 +6194,27 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tracing-test" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "557b891436fe0d5e0e363427fc7f217abf9ccd510d5136549847bdcbcd011d68" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracing-test-macro", +] + +[[package]] +name = "tracing-test-macro" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04659ddb06c87d233c566112c1c9c5b9e98256d9af50ec3bc9c8327f873a7568" +dependencies = [ + "quote", + "syn 2.0.106", +] + [[package]] name = "tree_hash" version = "0.9.1" diff --git a/Cargo.toml b/Cargo.toml index dc5ee88d..6ea8ba96 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ tower-http = { version = "0.6", features = ["trace"] } tracing = "0.1.40" tracing-appender = "0.2.3" tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } +tracing-test = { version = "0.2.5", features = ["no-env-filter"] } tree_hash = "0.9" tree_hash_derive = "0.9" typenum = "1.17.0" diff --git a/config.example.toml b/config.example.toml index 67085409..5b69f108 100644 --- a/config.example.toml +++ b/config.example.toml @@ -165,7 +165,8 @@ port = 20000 # Number of JWT authentication attempts a client can fail before blocking that client temporarily from Signer access # OPTIONAL, DEFAULT: 3 jwt_auth_fail_limit = 3 -# How long to block a client from Signer access, in seconds, if it failed JWT authentication too many times +# How long to block a client from Signer access, in seconds, if it failed JWT authentication too many times. +# This also defines the interval at which failed attempts are regularly checked and expired ones are cleaned up. # OPTIONAL, DEFAULT: 300 jwt_auth_fail_timeout_seconds = 300 diff --git a/crates/common/src/config/signer.rs b/crates/common/src/config/signer.rs index 4e040701..b4c5db16 100644 --- a/crates/common/src/config/signer.rs +++ b/crates/common/src/config/signer.rs @@ -88,7 +88,8 @@ pub struct SignerConfig { pub jwt_auth_fail_limit: u32, /// Duration in seconds to rate limit an endpoint after the JWT auth failure - /// limit has been reached + /// limit has been reached. This also defines the interval at which failed + /// attempts are regularly checked and expired ones are cleaned up. #[serde(default = "default_u32::")] pub jwt_auth_fail_timeout_seconds: u32, diff --git a/crates/signer/src/service.rs b/crates/signer/src/service.rs index e59ca920..e0dd6de6 100644 --- a/crates/signer/src/service.rs +++ b/crates/signer/src/service.rs @@ -143,6 +143,22 @@ impl SigningService { .route_layer(middleware::from_fn(log_request)) .route(STATUS_PATH, get(handle_status)); + // Run the JWT cleaning task + let jwt_cleaning_task = tokio::spawn(async move { + let mut interval = tokio::time::interval(state.jwt_auth_fail_timeout); + loop { + interval.tick().await; + let mut failures = state.jwt_auth_failures.write().await; + let before = failures.len(); + failures + .retain(|_, info| info.last_failure.elapsed() < state.jwt_auth_fail_timeout); + let after = failures.len(); + if before != after { + debug!("Cleaned up {} old JWT auth failure entries", before - after); + } + } + }); + let server_result = if let Some(tls_config) = config.tls_certificates { if CryptoProvider::get_default().is_none() { // Install the AWS-LC provider if no default is set, usually for CI @@ -184,6 +200,10 @@ impl SigningService { ) .await }; + + // Shutdown the JWT cleaning task + jwt_cleaning_task.abort(); + server_result.wrap_err("signer service exited") } diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 6cd2b829..5b373706 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -21,3 +21,6 @@ tracing.workspace = true tracing-subscriber.workspace = true tree_hash.workspace = true url.workspace = true + +[dev-dependencies] +tracing-test.workspace = true diff --git a/tests/tests/signer_jwt_auth_cleanup.rs b/tests/tests/signer_jwt_auth_cleanup.rs new file mode 100644 index 00000000..d6fde2a4 --- /dev/null +++ b/tests/tests/signer_jwt_auth_cleanup.rs @@ -0,0 +1,70 @@ +use std::{collections::HashMap, time::Duration}; + +use alloy::primitives::b256; +use cb_common::{ + commit::constants::GET_PUBKEYS_PATH, + config::{ModuleSigningConfig, load_module_signing_configs}, + types::ModuleId, + utils::create_jwt, +}; +use cb_tests::{ + signer_service::start_server, + utils::{self}, +}; +use eyre::Result; +use reqwest::StatusCode; + +const JWT_MODULE: &str = "test-module"; +const JWT_SECRET: &str = "test-jwt-secret"; +const ADMIN_SECRET: &str = "test-admin-secret"; + +async fn create_mod_signing_configs() -> HashMap { + let mut cfg = + utils::get_commit_boost_config(utils::get_pbs_static_config(utils::get_pbs_config(0))); + + let module_id = ModuleId(JWT_MODULE.to_string()); + let signing_id = b256!("0101010101010101010101010101010101010101010101010101010101010101"); + + cfg.modules = Some(vec![utils::create_module_config(module_id.clone(), signing_id)]); + + let jwts = HashMap::from([(module_id.clone(), JWT_SECRET.to_string())]); + + load_module_signing_configs(&cfg, &jwts).unwrap() +} + +#[tokio::test] +#[tracing_test::traced_test] +async fn test_signer_jwt_fail_cleanup() -> Result<()> { + // setup_test_env() isn't used because we want to capture logs with tracing_test + let module_id = ModuleId(JWT_MODULE.to_string()); + let mod_cfgs = create_mod_signing_configs().await; + let start_config = start_server(20102, &mod_cfgs, ADMIN_SECRET.to_string(), false).await?; + let mod_cfg = mod_cfgs.get(&module_id).expect("JWT config for test module not found"); + + // Run as many pubkeys requests as the fail limit + let jwt = create_jwt(&module_id, "incorrect secret", GET_PUBKEYS_PATH, None)?; + let client = reqwest::Client::new(); + let url = format!("http://{}{}", start_config.endpoint, GET_PUBKEYS_PATH); + for _ in 0..start_config.jwt_auth_fail_limit { + let response = client.get(&url).bearer_auth(&jwt).send().await?; + assert!(response.status() == StatusCode::UNAUTHORIZED); + } + + // Run another request - this should fail due to rate limiting now + let jwt = create_jwt(&module_id, &mod_cfg.jwt_secret, GET_PUBKEYS_PATH, None)?; + let response = client.get(&url).bearer_auth(&jwt).send().await?; + assert!(response.status() == StatusCode::TOO_MANY_REQUESTS); + + // Wait until the cleanup task should have run properly, takes a while for the + // timing to work out + tokio::time::sleep(Duration::from_secs( + (start_config.jwt_auth_fail_timeout_seconds * 3) as u64, + )) + .await; + + // Make sure the cleanup message was logged - it's all internal state so without + // refactoring or exposing it, this is the easiest way to check if it triggered + assert!(logs_contain("Cleaned up 1 old JWT auth failure entries")); + + Ok(()) +}