From af46e8c3cf00a1af1fda70d4a731d3c4ad32508d Mon Sep 17 00:00:00 2001 From: jemilahabiodun-oss Date: Mon, 30 Mar 2026 07:51:54 +0000 Subject: [PATCH] fix: secure and configurable CORS (closes #380) --- services/api/src/config.rs | 25 ++++++++++++++++++++++ services/api/src/lib.rs | 43 +++++++++++++++++++++++++++++++++----- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/services/api/src/config.rs b/services/api/src/config.rs index ce5032b..d985155 100644 --- a/services/api/src/config.rs +++ b/services/api/src/config.rs @@ -63,6 +63,11 @@ pub struct Config { pub admin_whitelist_ips: Vec, pub request_signing_secret: Option, pub sendgrid_webhook_secret: Option, + // CORS config + pub cors_allowed_origins: Vec, + pub cors_allowed_methods: Vec, + pub cors_allowed_headers: Vec, + pub cors_allow_credentials: bool, } impl Config { @@ -127,6 +132,22 @@ impl Config { .filter(|&s| s > 0) .map(Duration::from_secs); + let cors_allowed_origins = env::var("CORS_ALLOWED_ORIGINS") + .map(|v| v.split(',').map(|s| s.trim().to_string()).collect()) + .unwrap_or_else(|_| vec!["https://yourdomain.com".to_string()]); + + let cors_allowed_methods = env::var("CORS_ALLOWED_METHODS") + .map(|v| v.split(',').map(|s| s.trim().to_string()).collect()) + .unwrap_or_else(|_| vec!["GET".to_string(), "POST".to_string(), "OPTIONS".to_string()]); + + let cors_allowed_headers = env::var("CORS_ALLOWED_HEADERS") + .map(|v| v.split(',').map(|s| s.trim().to_string()).collect()) + .unwrap_or_else(|_| vec!["Content-Type".to_string(), "Authorization".to_string()]); + + let cors_allow_credentials = env::var("CORS_ALLOW_CREDENTIALS") + .map(|v| v == "true" || v == "1") + .unwrap_or(false); + Self { bind_addr, redis_url: env::var("REDIS_URL") @@ -194,6 +215,10 @@ impl Config { .unwrap_or_default(), request_signing_secret: env::var("REQUEST_SIGNING_SECRET").ok(), sendgrid_webhook_secret: env::var("SENDGRID_WEBHOOK_SECRET").ok(), + cors_allowed_origins, + cors_allowed_methods, + cors_allowed_headers, + cors_allow_credentials, } } diff --git a/services/api/src/lib.rs b/services/api/src/lib.rs index 9537f9e..cfe83dd 100644 --- a/services/api/src/lib.rs +++ b/services/api/src/lib.rs @@ -33,7 +33,7 @@ use security::{ApiKeyAuth, IpWhitelist, RateLimiter}; use shutdown::ShutdownCoordinator; use tokio::net::TcpListener; use tower_http::{ - cors::{Any, CorsLayer}, + cors::{CorsLayer, Origin, Any}, trace::TraceLayer, }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -149,10 +149,43 @@ pub async fn run() -> anyhow::Result<()> { tracing::warn!("cache warming skipped: {err}"); } - let cors = CorsLayer::new() - .allow_origin(Any) - .allow_methods(Any) - .allow_headers(Any); + let cors = { + let origins = &state.config.cors_allowed_origins; + let methods = &state.config.cors_allowed_methods; + let headers = &state.config.cors_allowed_headers; + let allow_credentials = state.config.cors_allow_credentials; + + let mut cors = CorsLayer::new(); + + // Allowed origins + if origins.len() == 1 && origins[0] == "*" { + cors = cors.allow_origin(Any); + } else { + let origins_vec: Vec = origins.iter().filter_map(|o| Origin::try_from(o.as_str()).ok()).collect(); + cors = cors.allow_origin(origins_vec); + } + + // Allowed methods + if methods.len() == 1 && methods[0] == "*" { + cors = cors.allow_methods(Any); + } else { + cors = cors.allow_methods(methods.clone()); + } + + // Allowed headers + if headers.len() == 1 && headers[0] == "*" { + cors = cors.allow_headers(Any); + } else { + cors = cors.allow_headers(headers.clone()); + } + + // Allow credentials + if allow_credentials { + cors = cors.allow_credentials(true); + } + + cors + }; let public_routes = Router::new() .route("/health", get(handlers::health))