diff --git a/Cargo.lock b/Cargo.lock index 0dd0e763..285bbd94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1679,6 +1679,7 @@ dependencies = [ "blsful", "cb-common", "cb-metrics", + "client-ip", "eyre", "futures", "headers", @@ -1716,6 +1717,7 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", + "tracing-test", "tree_hash", "url", ] @@ -1836,6 +1838,16 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" +[[package]] +name = "client-ip" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31211fc26899744f5b22521fdc971e5f3875991d8880537537470685a0e9552d" +dependencies = [ + "forwarded-header-value", + "http", +] + [[package]] name = "cmake" version = "0.1.54" @@ -2828,6 +2840,16 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "forwarded-header-value" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" +dependencies = [ + "nonempty", + "thiserror 1.0.69", +] + [[package]] name = "fs-err" version = "3.1.0" @@ -3948,6 +3970,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonempty" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7" + [[package]] name = "nu-ansi-term" version = "0.50.1" @@ -6193,6 +6221,27 @@ dependencies = [ "tracing-serde", ] +[[package]] +name = "tracing-test" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "557b891436fe0d5e0e363427fc7f217abf9ccd510d5136549847bdcbcd011d68" +dependencies = [ + "tracing-core", + "tracing-subscriber", + "tracing-test-macro", +] + +[[package]] +name = "tracing-test-macro" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04659ddb06c87d233c566112c1c9c5b9e98256d9af50ec3bc9c8327f873a7568" +dependencies = [ + "quote", + "syn 2.0.106", +] + [[package]] name = "tree_hash" version = "0.9.1" diff --git a/Cargo.toml b/Cargo.toml index dc5ee88d..b0533144 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ cb-pbs = { path = "crates/pbs" } cb-signer = { path = "crates/signer" } cipher = "0.4" clap = { version = "4.5.4", features = ["derive", "env"] } +client-ip = { version = "0.1.1", features = [ "forwarded-header" ] } color-eyre = "0.6.3" const_format = "0.2.34" ctr = "0.9.2" @@ -74,6 +75,7 @@ tower-http = { version = "0.6", features = ["trace"] } tracing = "0.1.40" tracing-appender = "0.2.3" tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } +tracing-test = { version = "0.2.5", features = ["no-env-filter"] } tree_hash = "0.9" tree_hash_derive = "0.9" typenum = "1.17.0" diff --git a/api/signer-api.yml b/api/signer-api.yml index 9e11da34..95897ecd 100644 --- a/api/signer-api.yml +++ b/api/signer-api.yml @@ -15,6 +15,7 @@ paths: The token **must include** the following claims: - `exp` (integer): Expiration timestamp + - `route` (string): The route being requested (must be `/signer/v1/get_pubkeys` for this endpoint). - `module` (string): The ID of the module making the request, which must match a module ID in the Commit-Boost configuration file. tags: - Signer @@ -73,6 +74,7 @@ paths: The token **must include** the following claims: - `exp` (integer): Expiration timestamp - `module` (string): The ID of the module making the request, which must match a module ID in the Commit-Boost configuration file. + - `route` (string): The route being requested (must be `/signer/v1/request_signature/bls` for this endpoint). - `payload_hash` (string): The Keccak-256 hash of the JSON-encoded request body, with optional `0x` prefix. This is required to prevent JWT replay attacks. tags: - Signer @@ -220,6 +222,7 @@ paths: The token **must include** the following claims: - `exp` (integer): Expiration timestamp - `module` (string): The ID of the module making the request, which must match a module ID in the Commit-Boost configuration file. + - `route` (string): The route being requested (must be `/signer/v1/request_signature/proxy-bls` for this endpoint). - `payload_hash` (string): The Keccak-256 hash of the JSON-encoded request body, with optional `0x` prefix. This is required to prevent JWT replay attacks. tags: - Signer @@ -367,6 +370,7 @@ paths: The token **must include** the following claims: - `exp` (integer): Expiration timestamp - `module` (string): The ID of the module making the request, which must match a module ID in the Commit-Boost configuration file. + - `route` (string): The route being requested (must be `/signer/v1/request_signature/proxy-ecdsa` for this endpoint). - `payload_hash` (string): The Keccak-256 hash of the JSON-encoded request body, with optional `0x` prefix. This is required to prevent JWT replay attacks. tags: - Signer @@ -514,6 +518,7 @@ paths: The token **must include** the following claims: - `exp` (integer): Expiration timestamp - `module` (string): The ID of the module making the request, which must match a module ID in the Commit-Boost configuration file. + - `route` (string): The route being requested (must be `/signer/v1/generate_proxy_key` for this endpoint). - `payload_hash` (string): The Keccak-256 hash of the JSON-encoded request body, with optional `0x` prefix. This is required to prevent JWT replay attacks. tags: - Signer diff --git a/config.example.toml b/config.example.toml index 67085409..5b69f108 100644 --- a/config.example.toml +++ b/config.example.toml @@ -165,7 +165,8 @@ port = 20000 # Number of JWT authentication attempts a client can fail before blocking that client temporarily from Signer access # OPTIONAL, DEFAULT: 3 jwt_auth_fail_limit = 3 -# How long to block a client from Signer access, in seconds, if it failed JWT authentication too many times +# How long to block a client from Signer access, in seconds, if it failed JWT authentication too many times. +# This also defines the interval at which failed attempts are regularly checked and expired ones are cleaned up. # OPTIONAL, DEFAULT: 300 jwt_auth_fail_timeout_seconds = 300 diff --git a/crates/common/src/commit/client.rs b/crates/common/src/commit/client.rs index 1151eb6f..98d8c26d 100644 --- a/crates/common/src/commit/client.rs +++ b/crates/common/src/commit/client.rs @@ -2,10 +2,7 @@ use std::path::PathBuf; use alloy::primitives::Address; use eyre::WrapErr; -use reqwest::{ - Certificate, - header::{AUTHORIZATION, HeaderMap, HeaderValue}, -}; +use reqwest::Certificate; use serde::{Deserialize, Serialize}; use url::Url; @@ -60,30 +57,13 @@ impl SignerClient { Ok(Self { url: signer_server_url, client: builder.build()?, module_id, jwt_secret }) } - fn refresh_jwt(&mut self) -> Result<(), SignerClientError> { - let jwt = create_jwt(&self.module_id, &self.jwt_secret, None)?; - - let mut auth_value = - HeaderValue::from_str(&format!("Bearer {jwt}")).wrap_err("invalid jwt")?; - auth_value.set_sensitive(true); - - let mut headers = HeaderMap::new(); - headers.insert(AUTHORIZATION, auth_value); - - self.client = reqwest::Client::builder() - .timeout(DEFAULT_REQUEST_TIMEOUT) - .default_headers(headers) - .build()?; - - Ok(()) - } - fn create_jwt_for_payload( &mut self, + route: &str, payload: &T, ) -> Result { let payload_vec = serde_json::to_vec(payload)?; - create_jwt(&self.module_id, &self.jwt_secret, Some(&payload_vec)) + create_jwt(&self.module_id, &self.jwt_secret, route, Some(&payload_vec)) .wrap_err("failed to create JWT for payload") .map_err(SignerClientError::JWTError) } @@ -92,10 +72,12 @@ impl SignerClient { /// requested. // TODO: add more docs on how proxy keys work pub async fn get_pubkeys(&mut self) -> Result { - self.refresh_jwt()?; + let jwt = create_jwt(&self.module_id, &self.jwt_secret, GET_PUBKEYS_PATH, None) + .wrap_err("failed to create JWT for payload") + .map_err(SignerClientError::JWTError)?; let url = self.url.join(GET_PUBKEYS_PATH)?; - let res = self.client.get(url).send().await?; + let res = self.client.get(url).bearer_auth(jwt).send().await?; if !res.status().is_success() { return Err(SignerClientError::FailedRequest { @@ -117,7 +99,7 @@ impl SignerClient { Q: Serialize, T: for<'de> Deserialize<'de>, { - let jwt = self.create_jwt_for_payload(request)?; + let jwt = self.create_jwt_for_payload(route, request)?; let url = self.url.join(route)?; let res = self.client.post(url).json(&request).bearer_auth(jwt).send().await?; @@ -165,7 +147,7 @@ impl SignerClient { where T: ProxyId + for<'de> Deserialize<'de>, { - let jwt = self.create_jwt_for_payload(request)?; + let jwt = self.create_jwt_for_payload(GENERATE_PROXY_KEY_PATH, request)?; let url = self.url.join(GENERATE_PROXY_KEY_PATH)?; let res = self.client.post(url).json(&request).bearer_auth(jwt).send().await?; diff --git a/crates/common/src/config/signer.rs b/crates/common/src/config/signer.rs index 4e040701..b4c5db16 100644 --- a/crates/common/src/config/signer.rs +++ b/crates/common/src/config/signer.rs @@ -88,7 +88,8 @@ pub struct SignerConfig { pub jwt_auth_fail_limit: u32, /// Duration in seconds to rate limit an endpoint after the JWT auth failure - /// limit has been reached + /// limit has been reached. This also defines the interval at which failed + /// attempts are regularly checked and expired ones are cleaned up. #[serde(default = "default_u32::")] pub jwt_auth_fail_timeout_seconds: u32, diff --git a/crates/common/src/signer/store.rs b/crates/common/src/signer/store.rs index bd4aa103..d70ea8a0 100644 --- a/crates/common/src/signer/store.rs +++ b/crates/common/src/signer/store.rs @@ -244,14 +244,14 @@ impl ProxyStore { serde_json::from_str(&file_content)?; let signer = EcdsaSigner::new_from_bytes(&key_and_delegation.secret)?; - let pubkey = signer.address(); + let address = signer.address(); let proxy_signer = EcdsaProxySigner { signer, delegation: key_and_delegation.delegation, }; - proxy_signers.ecdsa_signers.insert(pubkey, proxy_signer); - ecdsa_map.entry(module_id.clone()).or_default().push(pubkey); + proxy_signers.ecdsa_signers.insert(address, proxy_signer); + ecdsa_map.entry(module_id.clone()).or_default().push(address); } } } diff --git a/crates/common/src/types.rs b/crates/common/src/types.rs index 13c6b501..9fa3b40b 100644 --- a/crates/common/src/types.rs +++ b/crates/common/src/types.rs @@ -26,6 +26,7 @@ pub struct Jwt(pub String); pub struct JwtClaims { pub exp: u64, pub module: ModuleId, + pub route: String, pub payload_hash: Option, } @@ -33,6 +34,7 @@ pub struct JwtClaims { pub struct JwtAdminClaims { pub exp: u64, pub admin: bool, + pub route: String, pub payload_hash: Option, } diff --git a/crates/common/src/utils.rs b/crates/common/src/utils.rs index 91c3b11a..bb26edb5 100644 --- a/crates/common/src/utils.rs +++ b/crates/common/src/utils.rs @@ -346,11 +346,17 @@ pub fn print_logo() { } /// Create a JWT for the given module id with expiration -pub fn create_jwt(module_id: &ModuleId, secret: &str, payload: Option<&[u8]>) -> eyre::Result { +pub fn create_jwt( + module_id: &ModuleId, + secret: &str, + route: &str, + payload: Option<&[u8]>, +) -> eyre::Result { jsonwebtoken::encode( &jsonwebtoken::Header::default(), &JwtClaims { module: module_id.clone(), + route: route.to_string(), exp: jsonwebtoken::get_current_timestamp() + SIGNER_JWT_EXPIRATION, payload_hash: payload.map(keccak256), }, @@ -361,11 +367,16 @@ pub fn create_jwt(module_id: &ModuleId, secret: &str, payload: Option<&[u8]>) -> } // Creates a JWT for module administration -pub fn create_admin_jwt(admin_secret: String, payload: Option<&[u8]>) -> eyre::Result { +pub fn create_admin_jwt( + admin_secret: String, + route: &str, + payload: Option<&[u8]>, +) -> eyre::Result { jsonwebtoken::encode( &jsonwebtoken::Header::default(), &JwtAdminClaims { admin: true, + route: route.to_string(), exp: jsonwebtoken::get_current_timestamp() + SIGNER_JWT_EXPIRATION, payload_hash: payload.map(keccak256), }, @@ -408,7 +419,12 @@ pub fn decode_admin_jwt(jwt: Jwt) -> eyre::Result { } /// Validate a JWT with the given secret -pub fn validate_jwt(jwt: Jwt, secret: &str, payload: Option<&[u8]>) -> eyre::Result<()> { +pub fn validate_jwt( + jwt: Jwt, + secret: &str, + route: &str, + payload: Option<&[u8]>, +) -> eyre::Result<()> { let mut validation = jsonwebtoken::Validation::default(); validation.leeway = 10; @@ -419,6 +435,11 @@ pub fn validate_jwt(jwt: Jwt, secret: &str, payload: Option<&[u8]>) -> eyre::Res )? .claims; + // Validate the route + if claims.route != route { + eyre::bail!("Token route does not match"); + } + // Validate the payload hash if provided if let Some(payload_bytes) = payload { if let Some(expected_hash) = claims.payload_hash { @@ -436,7 +457,12 @@ pub fn validate_jwt(jwt: Jwt, secret: &str, payload: Option<&[u8]>) -> eyre::Res } /// Validate an admin JWT with the given secret -pub fn validate_admin_jwt(jwt: Jwt, secret: &str, payload: Option<&[u8]>) -> eyre::Result<()> { +pub fn validate_admin_jwt( + jwt: Jwt, + secret: &str, + route: &str, + payload: Option<&[u8]>, +) -> eyre::Result<()> { let mut validation = jsonwebtoken::Validation::default(); validation.leeway = 10; @@ -451,6 +477,11 @@ pub fn validate_admin_jwt(jwt: Jwt, secret: &str, payload: Option<&[u8]>) -> eyr eyre::bail!("Token is not admin") } + // Validate the route + if claims.route != route { + eyre::bail!("Token route does not match"); + } + // Validate the payload hash if provided if let Some(payload_bytes) = payload { if let Some(expected_hash) = claims.payload_hash { @@ -546,24 +577,25 @@ mod test { #[test] fn test_jwt_validation_no_payload_hash() { // Check valid JWT - let jwt = create_jwt(&ModuleId("DA_COMMIT".to_string()), "secret", None).unwrap(); + let jwt = + create_jwt(&ModuleId("DA_COMMIT".to_string()), "secret", "/test/route", None).unwrap(); let claims = decode_jwt(jwt.clone()).unwrap(); let module_id = claims.module; let payload_hash = claims.payload_hash; assert_eq!(module_id, ModuleId("DA_COMMIT".to_string())); assert!(payload_hash.is_none()); - let response = validate_jwt(jwt, "secret", None); + let response = validate_jwt(jwt, "secret", "/test/route", None); assert!(response.is_ok()); // Check expired JWT - let expired_jwt = Jwt::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3NDI5OTU5NDYsIm1vZHVsZSI6IkRBX0NPTU1JVCJ9.iiq4Z2ed2hk3c3c-cn2QOQJWE5XUOc5BoaIPT-I8q-s".to_string()); - let response = validate_jwt(expired_jwt, "secret", None); + let expired_jwt = Jwt::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3NTgyOTkxNzIsIm1vZHVsZSI6IkRBX0NPTU1JVCIsInJvdXRlIjoiL3Rlc3Qvcm91dGUiLCJwYXlsb2FkX2hhc2giOm51bGx9._OBsNC67KLkk6f6ZQ2_CDbhYUJ2OtZ9egKAmi1L-ymA".to_string()); + let response = validate_jwt(expired_jwt, "secret", "/test/route", None); assert!(response.is_err()); assert_eq!(response.unwrap_err().to_string(), "ExpiredSignature"); // Check invalid signature JWT - let invalid_jwt = Jwt::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3NDI5OTU5NDYsIm1vZHVsZSI6IkRBX0NPTU1JVCJ9.w9WYdDNzgDjYTvjBkk4GGzywGNBYPxnzU2uJWzPUT1s".to_string()); - let response = validate_jwt(invalid_jwt, "secret", None); + let invalid_jwt = Jwt::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3NTgyOTkxMzQsIm1vZHVsZSI6IkRBX0NPTU1JVCIsInJvdXRlIjoiL3Rlc3Qvcm91dGUiLCJwYXlsb2FkX2hhc2giOm51bGx9.58QXayg2XeX5lXhIPw-a8kl04DWBEj5wBsqsedTeClo".to_string()); + let response = validate_jwt(invalid_jwt, "secret", "/test/route", None); assert!(response.is_err()); assert_eq!(response.unwrap_err().to_string(), "InvalidSignature"); } @@ -577,25 +609,30 @@ mod test { let payload_bytes = serde_json::to_vec(&payload).unwrap(); // Check valid JWT - let jwt = - create_jwt(&ModuleId("DA_COMMIT".to_string()), "secret", Some(&payload_bytes)).unwrap(); + let jwt = create_jwt( + &ModuleId("DA_COMMIT".to_string()), + "secret", + "/test/route", + Some(&payload_bytes), + ) + .unwrap(); let claims = decode_jwt(jwt.clone()).unwrap(); let module_id = claims.module; let payload_hash = claims.payload_hash; assert_eq!(module_id, ModuleId("DA_COMMIT".to_string())); assert_eq!(payload_hash, Some(keccak256(&payload_bytes))); - let response = validate_jwt(jwt, "secret", Some(&payload_bytes)); + let response = validate_jwt(jwt, "secret", "/test/route", Some(&payload_bytes)); assert!(response.is_ok()); // Check expired JWT - let expired_jwt = Jwt::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3NDI5OTU5NDYsIm1vZHVsZSI6IkRBX0NPTU1JVCJ9.iiq4Z2ed2hk3c3c-cn2QOQJWE5XUOc5BoaIPT-I8q-s".to_string()); - let response = validate_jwt(expired_jwt, "secret", Some(&payload_bytes)); + let expired_jwt = Jwt::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3NTgyOTgzNDQsIm1vZHVsZSI6IkRBX0NPTU1JVCIsInJvdXRlIjoiL3Rlc3Qvcm91dGUiLCJwYXlsb2FkX2hhc2giOiIweGFmODk2MjY0MzUzNTFmYzIwMDBkYmEwM2JiNTlhYjcyZWE0ODJiOWEwMDBmZWQzNmNkMjBlMDU0YjE2NjZmZjEifQ.PYrSxLXadKBgYZlmLam8RBSL32I1T_zAxlZpG6xnnII".to_string()); + let response = validate_jwt(expired_jwt, "secret", "/test/route", Some(&payload_bytes)); assert!(response.is_err()); assert_eq!(response.unwrap_err().to_string(), "ExpiredSignature"); // Check invalid signature JWT - let invalid_jwt = Jwt::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3NDI5OTU5NDYsIm1vZHVsZSI6IkRBX0NPTU1JVCJ9.w9WYdDNzgDjYTvjBkk4GGzywGNBYPxnzU2uJWzPUT1s".to_string()); - let response = validate_jwt(invalid_jwt, "secret", Some(&payload_bytes)); + let invalid_jwt = Jwt::from("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJleHAiOjE3NTgyOTkwMDAsIm1vZHVsZSI6IkRBX0NPTU1JVCIsInJvdXRlIjoiL3Rlc3Qvcm91dGUiLCJwYXlsb2FkX2hhc2giOiIweGFmODk2MjY0MzUzNTFmYzIwMDBkYmEwM2JiNTlhYjcyZWE0ODJiOWEwMDBmZWQzNmNkMjBlMDU0YjE2NjZmZjEifQ.mnC-AexkLlR9l98SJbln3DmV6r9XyHYdbjcUVcWdi_8".to_string()); + let response = validate_jwt(invalid_jwt, "secret", "/test/route", Some(&payload_bytes)); assert!(response.is_err()); assert_eq!(response.unwrap_err().to_string(), "InvalidSignature"); } diff --git a/crates/signer/Cargo.toml b/crates/signer/Cargo.toml index 7c6e63fa..1a688e1b 100644 --- a/crates/signer/Cargo.toml +++ b/crates/signer/Cargo.toml @@ -14,6 +14,7 @@ bimap.workspace = true blsful.workspace = true cb-common.workspace = true cb-metrics.workspace = true +client-ip.workspace = true eyre.workspace = true futures.workspace = true headers.workspace = true diff --git a/crates/signer/src/service.rs b/crates/signer/src/service.rs index eb284289..e9480db1 100644 --- a/crates/signer/src/service.rs +++ b/crates/signer/src/service.rs @@ -5,12 +5,12 @@ use std::{ time::{Duration, Instant}, }; -use alloy::primitives::{Address, B256, U256, keccak256}; +use alloy::primitives::{Address, B256, U256}; use axum::{ Extension, Json, body::{Body, to_bytes}, extract::{ConnectInfo, Request, State}, - http::StatusCode, + http::{HeaderMap, StatusCode}, middleware::{self, Next}, response::{IntoResponse, Response}, routing::{get, post}, @@ -36,6 +36,7 @@ use cb_common::{ utils::{decode_jwt, validate_admin_jwt, validate_jwt}, }; use cb_metrics::provider::MetricsProvider; +use client_ip::*; use eyre::Context; use headers::{Authorization, authorization::Bearer}; use parking_lot::RwLock as ParkingRwLock; @@ -144,13 +145,49 @@ impl SigningService { .route_layer(middleware::from_fn(log_request)) .route(STATUS_PATH, get(handle_status)); - if CryptoProvider::get_default().is_none() { - aws_lc_rs::default_provider() - .install_default() - .map_err(|_| eyre::eyre!("Failed to install TLS provider"))?; - } + // Run the JWT cleaning task + let jwt_cleaning_task = tokio::spawn(async move { + let mut interval = tokio::time::interval(state.jwt_auth_fail_timeout); + loop { + interval.tick().await; + let mut failures = state.jwt_auth_failures.write(); + let before = failures.len(); + failures + .retain(|_, info| info.last_failure.elapsed() < state.jwt_auth_fail_timeout); + let after = failures.len(); + if before != after { + debug!("Cleaned up {} old JWT auth failure entries", before - after); + } + } + }); let server_result = if let Some(tls_config) = config.tls_certificates { + if CryptoProvider::get_default().is_none() { + // Install the AWS-LC provider if no default is set, usually for CI + debug!("Installing AWS-LC as default TLS provider"); + let mut attempts = 0; + loop { + match aws_lc_rs::default_provider().install_default() { + Ok(_) => { + debug!("Successfully installed AWS-LC as default TLS provider"); + break; + } + Err(e) => { + error!( + "Failed to install AWS-LC as default TLS provider: {e:?}. Retrying..." + ); + if attempts >= 3 { + error!( + "Exceeded maximum attempts to install AWS-LC as default TLS provider" + ); + break; + } + attempts += 1; + } + } + } + } + let tls_config = RustlsConfig::from_pem(tls_config.0, tls_config.1).await?; axum_server::bind_rustls(config.endpoint, tls_config) .serve( @@ -165,6 +202,10 @@ impl SigningService { ) .await }; + + // Shutdown the JWT cleaning task + jwt_cleaning_task.abort(); + server_result.wrap_err("signer service exited") } @@ -173,39 +214,81 @@ impl SigningService { } } +/// Marks a JWT authentication failure for a given client IP +fn mark_jwt_failure(state: &SigningState, client_ip: IpAddr) { + let mut failures = state.jwt_auth_failures.write(); + let failure_info = failures + .entry(client_ip) + .or_insert(JwtAuthFailureInfo { failure_count: 0, last_failure: Instant::now() }); + failure_info.failure_count += 1; + failure_info.last_failure = Instant::now(); +} + +/// Get the true client IP from the request headers or fallback to the socket +/// address +fn get_true_ip(req_headers: &HeaderMap, addr: &SocketAddr) -> eyre::Result { + let ip_extractors = [ + cf_connecting_ip, + cloudfront_viewer_address, + fly_client_ip, + rightmost_forwarded, + rightmost_x_forwarded_for, + true_client_ip, + x_real_ip, + ]; + + // Run each extractor in order and return the first valid IP found + for extractor in ip_extractors { + match extractor(req_headers) { + Ok(true_ip) => { + return Ok(true_ip); + } + Err(e) => { + match e { + Error::AbsentHeader { .. } => continue, // Missing headers are fine + _ => return Err(eyre::eyre!(e.to_string())), // Report anything else + } + } + } + } + + // Fallback to the socket IP + Ok(addr.ip()) +} + /// Authentication middleware layer async fn jwt_auth( State(state): State, + req_headers: HeaderMap, TypedHeader(auth): TypedHeader>, addr: ConnectInfo, req: Request, next: Next, ) -> Result { // Check if the request needs to be rate limited - let client_ip = addr.ip(); + let client_ip = get_true_ip(&req_headers, &addr).map_err(|e| { + error!("Failed to get client IP: {e}"); + SignerModuleError::RequestError("failed to get client IP".to_string()) + })?; check_jwt_rate_limit(&state, &client_ip)?; // Clone the request so we can read the body let (parts, body) = req.into_parts(); + let path = parts.uri.path(); let bytes = to_bytes(body, REQUEST_MAX_BODY_LENGTH).await.map_err(|e| { error!("Failed to read request body: {e}"); SignerModuleError::RequestError(e.to_string()) })?; // Process JWT authorization - match check_jwt_auth(&auth, &state, &bytes) { + match check_jwt_auth(&auth, &state, path, &bytes) { Ok(module_id) => { let mut req = Request::from_parts(parts, Body::from(bytes)); req.extensions_mut().insert(module_id); Ok(next.run(req).await) } Err(SignerModuleError::Unauthorized) => { - let mut failures = state.jwt_auth_failures.write(); - let failure_info = failures - .entry(client_ip) - .or_insert(JwtAuthFailureInfo { failure_count: 0, last_failure: Instant::now() }); - failure_info.failure_count += 1; - failure_info.last_failure = Instant::now(); + mark_jwt_failure(&state, client_ip); Err(SignerModuleError::Unauthorized) } Err(err) => Err(err), @@ -253,6 +336,7 @@ fn check_jwt_rate_limit(state: &SigningState, client_ip: &IpAddr) -> Result<(), fn check_jwt_auth( auth: &Authorization, state: &SigningState, + path: &str, body: &[u8], ) -> Result { let jwt: Jwt = auth.token().to_string().into(); @@ -270,44 +354,33 @@ fn check_jwt_auth( SignerModuleError::Unauthorized })?; - if body.is_empty() { - // Skip payload hash comparison for requests without a body - validate_jwt(jwt, &jwt_config.jwt_secret, None).map_err(|e| { - error!("Unauthorized request. Invalid JWT: {e}"); - SignerModuleError::Unauthorized - })?; - } else { - validate_jwt(jwt, &jwt_config.jwt_secret, Some(body)).map_err(|e| { - error!("Unauthorized request. Invalid JWT: {e}"); - SignerModuleError::Unauthorized - })?; - - // Make sure the request contains a hash of the payload in its claims - if !body.is_empty() { - let payload_hash = keccak256(body); - if claims.payload_hash.is_none() || claims.payload_hash != Some(payload_hash) { - error!("Unauthorized request. Invalid payload hash in JWT claims"); - return Err(SignerModuleError::Unauthorized); - } - } - } + let body_bytes = if body.is_empty() { None } else { Some(body) }; + validate_jwt(jwt, &jwt_config.jwt_secret, path, body_bytes).map_err(|e| { + error!("Unauthorized request. Invalid JWT: {e}"); + SignerModuleError::Unauthorized + })?; Ok(claims.module) } async fn admin_auth( State(state): State, + req_headers: HeaderMap, TypedHeader(auth): TypedHeader>, addr: ConnectInfo, req: Request, next: Next, ) -> Result { // Check if the request needs to be rate limited - let client_ip = addr.ip(); + let client_ip = get_true_ip(&req_headers, &addr).map_err(|e| { + error!("Failed to get client IP: {e}"); + SignerModuleError::RequestError("failed to get client IP".to_string()) + })?; check_jwt_rate_limit(&state, &client_ip)?; // Clone the request so we can read the body let (parts, body) = req.into_parts(); + let path = parts.uri.path(); let bytes = to_bytes(body, REQUEST_MAX_BODY_LENGTH).await.map_err(|e| { error!("Failed to read request body: {e}"); SignerModuleError::RequestError(e.to_string()) @@ -316,18 +389,12 @@ async fn admin_auth( let jwt: Jwt = auth.token().to_string().into(); // Validate the admin JWT - if bytes.is_empty() { - // Skip payload hash comparison for requests without a body - validate_admin_jwt(jwt, &state.admin_secret.read(), None).map_err(|e| { - error!("Unauthorized request. Invalid JWT: {e}"); - SignerModuleError::Unauthorized - })?; - } else { - validate_admin_jwt(jwt, &state.admin_secret.read(), Some(&bytes)).map_err(|e| { - error!("Unauthorized request. Invalid payload hash in JWT claims: {e}"); - SignerModuleError::Unauthorized - })?; - } + let body_bytes: Option<&[u8]> = if bytes.is_empty() { None } else { Some(&bytes) }; + validate_admin_jwt(jwt, &state.admin_secret.read(), path, body_bytes).map_err(|e| { + error!("Unauthorized request. Invalid JWT: {e}"); + mark_jwt_failure(&state, client_ip); + SignerModuleError::Unauthorized + })?; let req = Request::from_parts(parts, Body::from(bytes)); Ok(next.run(req).await) @@ -598,6 +665,7 @@ async fn handle_reload( debug!(event = "reload", ?req_id, "New request"); + // Regenerate the config let config = match StartSignerConfig::load_from_env() { Ok(config) => config, Err(err) => { @@ -606,6 +674,16 @@ async fn handle_reload( } }; + // Start a new manager with the updated config + let new_manager = match start_manager(config).await { + Ok(manager) => manager, + Err(err) => { + error!(event = "reload", ?req_id, error = ?err, "Failed to reload manager"); + return Err(SignerModuleError::Internal("failed to reload config".to_string())); + } + }; + + // Update the JWT configs if provided in the request if let Some(jwt_secrets) = request.jwt_secrets { let mut jwt_configs = state.jwts.write(); let mut new_configs = HashMap::new(); @@ -627,23 +705,11 @@ async fn handle_reload( *jwt_configs = new_configs; } + // Update the rest of the state once everything has passed if let Some(admin_secret) = request.admin_secret { *state.admin_secret.write() = admin_secret; } - - let new_manager = match start_manager(config).await { - Ok(manager) => manager, - Err(err) => { - error!(event = "reload", ?req_id, error = ?err, "Failed to reload manager"); - return Err(SignerModuleError::Internal("failed to reload config".to_string())); - } - }; - - // Replace the contents of the manager RwLock - { - let mut manager_guard = state.manager.write().await; - *manager_guard = new_manager; - } + *state.manager.write().await = new_manager; Ok(StatusCode::OK) } diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 6cd2b829..5b373706 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -21,3 +21,6 @@ tracing.workspace = true tracing-subscriber.workspace = true tree_hash.workspace = true url.workspace = true + +[dev-dependencies] +tracing-test.workspace = true diff --git a/tests/tests/signer_jwt_auth.rs b/tests/tests/signer_jwt_auth.rs index 37561428..d1b65b3f 100644 --- a/tests/tests/signer_jwt_auth.rs +++ b/tests/tests/signer_jwt_auth.rs @@ -45,7 +45,7 @@ async fn test_signer_jwt_auth_success() -> Result<()> { let jwt_config = mod_cfgs.get(&module_id).expect("JWT config for test module not found"); // Run a pubkeys request - let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, None)?; + let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, GET_PUBKEYS_PATH, None)?; let client = reqwest::Client::new(); let url = format!("http://{}{}", start_config.endpoint, GET_PUBKEYS_PATH); let response = client.get(&url).bearer_auth(&jwt).send().await?; @@ -64,7 +64,7 @@ async fn test_signer_jwt_auth_fail() -> Result<()> { let start_config = start_server(20101, &mod_cfgs, ADMIN_SECRET.to_string(), false).await?; // Run a pubkeys request - this should fail due to invalid JWT - let jwt = create_jwt(&module_id, "incorrect secret", None)?; + let jwt = create_jwt(&module_id, "incorrect secret", GET_PUBKEYS_PATH, None)?; let client = reqwest::Client::new(); let url = format!("http://{}{}", start_config.endpoint, GET_PUBKEYS_PATH); let response = client.get(&url).bearer_auth(&jwt).send().await?; @@ -86,7 +86,7 @@ async fn test_signer_jwt_rate_limit() -> Result<()> { let mod_cfg = mod_cfgs.get(&module_id).expect("JWT config for test module not found"); // Run as many pubkeys requests as the fail limit - let jwt = create_jwt(&module_id, "incorrect secret", None)?; + let jwt = create_jwt(&module_id, "incorrect secret", GET_PUBKEYS_PATH, None)?; let client = reqwest::Client::new(); let url = format!("http://{}{}", start_config.endpoint, GET_PUBKEYS_PATH); for _ in 0..start_config.jwt_auth_fail_limit { @@ -95,7 +95,7 @@ async fn test_signer_jwt_rate_limit() -> Result<()> { } // Run another request - this should fail due to rate limiting now - let jwt = create_jwt(&module_id, &mod_cfg.jwt_secret, None)?; + let jwt = create_jwt(&module_id, &mod_cfg.jwt_secret, GET_PUBKEYS_PATH, None)?; let response = client.get(&url).bearer_auth(&jwt).send().await?; assert!(response.status() == StatusCode::TOO_MANY_REQUESTS); @@ -119,7 +119,7 @@ async fn test_signer_revoked_jwt_fail() -> Result<()> { let start_config = start_server(20400, &mod_cfgs, admin_secret.clone(), false).await?; // Run as many pubkeys requests as the fail limit - let jwt = create_jwt(&module_id, JWT_SECRET, None)?; + let jwt = create_jwt(&module_id, JWT_SECRET, GET_PUBKEYS_PATH, None)?; let client = reqwest::Client::new(); // At first, test module should be allowed to request pubkeys @@ -129,7 +129,7 @@ async fn test_signer_revoked_jwt_fail() -> Result<()> { let revoke_body = RevokeModuleRequest { module_id: ModuleId(JWT_MODULE.to_string()) }; let body_bytes = serde_json::to_vec(&revoke_body)?; - let admin_jwt = create_admin_jwt(admin_secret, Some(&body_bytes))?; + let admin_jwt = create_admin_jwt(admin_secret, REVOKE_MODULE_PATH, Some(&body_bytes))?; let revoke_url = format!("http://{}{}", start_config.endpoint, REVOKE_MODULE_PATH); let response = @@ -155,7 +155,7 @@ async fn test_signer_only_admin_can_revoke() -> Result<()> { let body_bytes = serde_json::to_vec(&revoke_body)?; // Run as many pubkeys requests as the fail limit - let jwt = create_jwt(&module_id, JWT_SECRET, Some(&body_bytes))?; + let jwt = create_jwt(&module_id, JWT_SECRET, REVOKE_MODULE_PATH, Some(&body_bytes))?; let client = reqwest::Client::new(); let url = format!("http://{}{}", start_config.endpoint, REVOKE_MODULE_PATH); @@ -164,7 +164,45 @@ async fn test_signer_only_admin_can_revoke() -> Result<()> { assert!(response.status() == StatusCode::UNAUTHORIZED); // Admin should be able to revoke modules - let admin_jwt = create_admin_jwt(admin_secret, Some(&body_bytes))?; + let admin_jwt = create_admin_jwt(admin_secret, REVOKE_MODULE_PATH, Some(&body_bytes))?; + let response = client.post(&url).json(&revoke_body).bearer_auth(&admin_jwt).send().await?; + assert!(response.status() == StatusCode::OK); + + Ok(()) +} + +#[tokio::test] +async fn test_signer_admin_jwt_rate_limit() -> Result<()> { + setup_test_env(); + let admin_secret = ADMIN_SECRET.to_string(); + let module_id = ModuleId(JWT_MODULE.to_string()); + let mod_cfgs = create_mod_signing_configs().await; + let start_config = start_server(20510, &mod_cfgs, admin_secret.clone(), false).await?; + + let revoke_body = RevokeModuleRequest { module_id: ModuleId(JWT_MODULE.to_string()) }; + let body_bytes = serde_json::to_vec(&revoke_body)?; + + // Run as many pubkeys requests as the fail limit + let jwt = create_jwt(&module_id, JWT_SECRET, REVOKE_MODULE_PATH, Some(&body_bytes))?; + let client = reqwest::Client::new(); + let url = format!("http://{}{}", start_config.endpoint, REVOKE_MODULE_PATH); + + // Module JWT shouldn't be able to revoke modules + for _ in 0..start_config.jwt_auth_fail_limit { + let response = client.post(&url).json(&revoke_body).bearer_auth(&jwt).send().await?; + assert!(response.status() == StatusCode::UNAUTHORIZED); + } + + // Run another request - this should fail due to rate limiting now + let admin_jwt = create_admin_jwt(admin_secret, REVOKE_MODULE_PATH, Some(&body_bytes))?; + let response = client.post(&url).json(&revoke_body).bearer_auth(&admin_jwt).send().await?; + assert!(response.status() == StatusCode::TOO_MANY_REQUESTS); + + // Wait for the rate limit timeout + tokio::time::sleep(Duration::from_secs(start_config.jwt_auth_fail_timeout_seconds as u64)) + .await; + + // Now the next request should succeed let response = client.post(&url).json(&revoke_body).bearer_auth(&admin_jwt).send().await?; assert!(response.status() == StatusCode::OK); diff --git a/tests/tests/signer_jwt_auth_cleanup.rs b/tests/tests/signer_jwt_auth_cleanup.rs new file mode 100644 index 00000000..d6fde2a4 --- /dev/null +++ b/tests/tests/signer_jwt_auth_cleanup.rs @@ -0,0 +1,70 @@ +use std::{collections::HashMap, time::Duration}; + +use alloy::primitives::b256; +use cb_common::{ + commit::constants::GET_PUBKEYS_PATH, + config::{ModuleSigningConfig, load_module_signing_configs}, + types::ModuleId, + utils::create_jwt, +}; +use cb_tests::{ + signer_service::start_server, + utils::{self}, +}; +use eyre::Result; +use reqwest::StatusCode; + +const JWT_MODULE: &str = "test-module"; +const JWT_SECRET: &str = "test-jwt-secret"; +const ADMIN_SECRET: &str = "test-admin-secret"; + +async fn create_mod_signing_configs() -> HashMap { + let mut cfg = + utils::get_commit_boost_config(utils::get_pbs_static_config(utils::get_pbs_config(0))); + + let module_id = ModuleId(JWT_MODULE.to_string()); + let signing_id = b256!("0101010101010101010101010101010101010101010101010101010101010101"); + + cfg.modules = Some(vec![utils::create_module_config(module_id.clone(), signing_id)]); + + let jwts = HashMap::from([(module_id.clone(), JWT_SECRET.to_string())]); + + load_module_signing_configs(&cfg, &jwts).unwrap() +} + +#[tokio::test] +#[tracing_test::traced_test] +async fn test_signer_jwt_fail_cleanup() -> Result<()> { + // setup_test_env() isn't used because we want to capture logs with tracing_test + let module_id = ModuleId(JWT_MODULE.to_string()); + let mod_cfgs = create_mod_signing_configs().await; + let start_config = start_server(20102, &mod_cfgs, ADMIN_SECRET.to_string(), false).await?; + let mod_cfg = mod_cfgs.get(&module_id).expect("JWT config for test module not found"); + + // Run as many pubkeys requests as the fail limit + let jwt = create_jwt(&module_id, "incorrect secret", GET_PUBKEYS_PATH, None)?; + let client = reqwest::Client::new(); + let url = format!("http://{}{}", start_config.endpoint, GET_PUBKEYS_PATH); + for _ in 0..start_config.jwt_auth_fail_limit { + let response = client.get(&url).bearer_auth(&jwt).send().await?; + assert!(response.status() == StatusCode::UNAUTHORIZED); + } + + // Run another request - this should fail due to rate limiting now + let jwt = create_jwt(&module_id, &mod_cfg.jwt_secret, GET_PUBKEYS_PATH, None)?; + let response = client.get(&url).bearer_auth(&jwt).send().await?; + assert!(response.status() == StatusCode::TOO_MANY_REQUESTS); + + // Wait until the cleanup task should have run properly, takes a while for the + // timing to work out + tokio::time::sleep(Duration::from_secs( + (start_config.jwt_auth_fail_timeout_seconds * 3) as u64, + )) + .await; + + // Make sure the cleanup message was logged - it's all internal state so without + // refactoring or exposing it, this is the easiest way to check if it triggered + assert!(logs_contain("Cleaned up 1 old JWT auth failure entries")); + + Ok(()) +} diff --git a/tests/tests/signer_request_sig.rs b/tests/tests/signer_request_sig.rs index 15680587..78efbf9e 100644 --- a/tests/tests/signer_request_sig.rs +++ b/tests/tests/signer_request_sig.rs @@ -62,7 +62,12 @@ async fn test_signer_sign_request_good() -> Result<()> { let pubkey = BlsPublicKey::deserialize(&PUBKEY_1).unwrap(); let request = SignConsensusRequest { pubkey: pubkey.clone(), object_root, nonce }; let payload_bytes = serde_json::to_vec(&request)?; - let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, Some(&payload_bytes))?; + let jwt = create_jwt( + &module_id, + &jwt_config.jwt_secret, + REQUEST_SIGNATURE_BLS_PATH, + Some(&payload_bytes), + )?; let client = reqwest::Client::new(); let url = format!("http://{}{}", start_config.endpoint, REQUEST_SIGNATURE_BLS_PATH); let response = client.post(&url).json(&request).bearer_auth(&jwt).send().await?; @@ -100,7 +105,12 @@ async fn test_signer_sign_request_different_module() -> Result<()> { let pubkey = BlsPublicKey::deserialize(&PUBKEY_1).unwrap(); let request = SignConsensusRequest { pubkey: pubkey.clone(), object_root, nonce }; let payload_bytes = serde_json::to_vec(&request)?; - let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, Some(&payload_bytes))?; + let jwt = create_jwt( + &module_id, + &jwt_config.jwt_secret, + REQUEST_SIGNATURE_BLS_PATH, + Some(&payload_bytes), + )?; let client = reqwest::Client::new(); let url = format!("http://{}{}", start_config.endpoint, REQUEST_SIGNATURE_BLS_PATH); let response = client.post(&url).json(&request).bearer_auth(&jwt).send().await?; @@ -146,7 +156,12 @@ async fn test_signer_sign_request_incorrect_hash() -> Result<()> { let true_object_root = b256!("0x0123456789012345678901234567890123456789012345678901234567890123"); let true_request = SignConsensusRequest { pubkey, object_root: true_object_root, nonce }; - let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, Some(&fake_payload_bytes))?; + let jwt = create_jwt( + &module_id, + &jwt_config.jwt_secret, + REQUEST_SIGNATURE_BLS_PATH, + Some(&fake_payload_bytes), + )?; let client = reqwest::Client::new(); let url = format!("http://{}{}", start_config.endpoint, REQUEST_SIGNATURE_BLS_PATH); let response = client.post(&url).json(&true_request).bearer_auth(&jwt).send().await?; @@ -171,7 +186,7 @@ async fn test_signer_sign_request_missing_hash() -> Result<()> { let pubkey = BlsPublicKey::deserialize(&PUBKEY_1).unwrap(); let object_root = b256!("0x0123456789012345678901234567890123456789012345678901234567890123"); let request = SignConsensusRequest { pubkey, object_root, nonce }; - let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, None)?; + let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, REQUEST_SIGNATURE_BLS_PATH, None)?; let client = reqwest::Client::new(); let url = format!("http://{}{}", start_config.endpoint, REQUEST_SIGNATURE_BLS_PATH); let response = client.post(&url).json(&request).bearer_auth(&jwt).send().await?; diff --git a/tests/tests/signer_tls.rs b/tests/tests/signer_tls.rs index 4f53bb92..2df98d73 100644 --- a/tests/tests/signer_tls.rs +++ b/tests/tests/signer_tls.rs @@ -41,7 +41,7 @@ async fn test_signer_tls() -> Result<()> { let jwt_config = mod_cfgs.get(&module_id).expect("JWT config for test module not found"); // Run a pubkeys request - let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, None)?; + let jwt = create_jwt(&module_id, &jwt_config.jwt_secret, GET_PUBKEYS_PATH, None)?; let cert = match start_config.tls_certificates { Some(ref certificates) => &certificates.0, None => bail!("TLS certificates not found in start config"),