diff --git a/src/api/mod.rs b/src/api/mod.rs index 54b5f87..1122e6d 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,15 +1,17 @@ use crate::{db, AppState}; use axum::{ extract::{ConnectInfo, Request, State}, - http::StatusCode, + http::{header, HeaderValue, StatusCode}, middleware::{self, Next}, response::IntoResponse, routing::{get, post}, - Extension, Json, + Json, }; use serde_json::json; +use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; +use std::num::NonZeroU32; +use std::sync::{Arc, Mutex}; use tower_http::{ cors::CorsLayer, limit::RequestBodyLimitLayer, @@ -26,9 +28,24 @@ const MAX_BODY_BYTES: usize = 256 * 1024; #[derive(Clone)] pub struct AuthenticatedMerchant(pub String); +#[derive(Clone)] +struct RateLimitState { + requests_per_sec: u32, + limiters: Arc>>, +} + +impl RateLimitState { + fn new(requests_per_sec: u32) -> Self { + Self { + requests_per_sec: requests_per_sec.max(1), + limiters: Arc::new(Mutex::new(HashMap::new())), + } + } +} + pub fn router(state: Arc) -> axum::Router { let cors = build_cors(&state.config); - let rate_limit_rps = state.config.rate_limit_requests_per_sec; + let rate_limit = RateLimitState::new(state.config.rate_limit_requests_per_sec); axum::Router::new() .route("/", get(|| async { "StellarGate API v0.1.0" })) @@ -55,17 +72,16 @@ pub fn router(state: Arc) -> axum::Router { "/:id/webhooks/:delivery_id/redeliver", post(payments::redeliver_webhook), ) - .layer(middleware::from_fn( - move |ConnectInfo(addr): ConnectInfo, req: Request, next: Next| { - rate_limit_middleware(addr, rate_limit_rps, req, next) - }, - )) }) .fallback(not_found) .layer(PropagateRequestIdLayer::x_request_id()) .layer(TraceLayer::new_for_http()) .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)) .layer(RequestBodyLimitLayer::new(MAX_BODY_BYTES)) + .layer(middleware::from_fn_with_state( + rate_limit, + rate_limit_middleware, + )) .layer(cors) .with_state(state) } @@ -139,49 +155,54 @@ async fn provision_merchant( } async fn rate_limit_middleware( - addr: SocketAddr, - rate_limit_rps: u32, + State(rate_limit): State, req: Request, next: Next, ) -> axum::response::Response { - static LIMITERS: std::sync::OnceLock< - std::sync::Mutex>, - > = std::sync::OnceLock::new(); - - let limiters = LIMITERS.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new())); - let ip = addr.ip().to_string(); - - // Scoped so the `MutexGuard` is dropped before the `next.run().await` - // below, keeping the returned future `Send`. - let allowed = { - let mut map = limiters.lock().unwrap(); - let limiter = map.entry(ip).or_insert_with(|| { - governor::RateLimiter::direct(governor::Quota::per_second( - std::num::NonZeroU32::new(rate_limit_rps).unwrap(), - )) - }); - limiter.check().is_ok() - }; + if req.method() == axum::http::Method::POST && req.uri().path() == "/payments" { + let key = rate_limit_key(&req); + let limited = { + let mut map = rate_limit.limiters.lock().unwrap(); + let limiter = map.entry(key).or_insert_with(|| { + governor::RateLimiter::direct(governor::Quota::per_second( + NonZeroU32::new(rate_limit.requests_per_sec).unwrap(), + )) + }); + limiter.check().is_err() + }; - if !allowed { - let retry_after = (1000 / rate_limit_rps).max(1); - return ( - StatusCode::TOO_MANY_REQUESTS, - [( - axum::http::header::RETRY_AFTER, - axum::http::HeaderValue::from_str(&retry_after.to_string()).unwrap(), - )], - axum::Json(json!({ - "error": "rate limit exceeded", - "code": "rate_limit_exceeded" - })), - ) - .into_response(); + if limited { + return ( + StatusCode::TOO_MANY_REQUESTS, + [(header::RETRY_AFTER, HeaderValue::from_static("1"))], + Json(json!({ + "error": "rate limit exceeded", + "code": "rate_limit_exceeded" + })), + ) + .into_response(); + } } next.run(req).await } +fn rate_limit_key(req: &Request) -> String { + if let Some(ConnectInfo(addr)) = req.extensions().get::>() { + return addr.ip().to_string(); + } + + for name in ["x-forwarded-for", "x-real-ip"] { + if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) { + if let Some(first) = value.split(',').map(str::trim).find(|s| !s.is_empty()) { + return first.to_string(); + } + } + } + + "local".to_string() +} + fn build_cors(cfg: &crate::config::Config) -> CorsLayer { use axum::http::HeaderName; use tower_http::cors::AllowOrigin; diff --git a/src/api/payments.rs b/src/api/payments.rs index 0cbf433..6475d6b 100644 --- a/src/api/payments.rs +++ b/src/api/payments.rs @@ -263,8 +263,14 @@ pub async fn list( } else { // Legacy offset pagination — kept for backward compatibility. let offset = q.offset.unwrap_or(0).max(0); - let (payments, total) = - db::list_payments(&state.pool, &merchant_id, q.status.as_deref(), limit, offset).await?; + let (payments, total) = db::list_payments( + &state.pool, + &merchant_id, + q.status.as_deref(), + limit, + offset, + ) + .await?; // Provide next_cursor to ease migration to keyset pagination. let next_cursor = payments.last().map(|p| encode_cursor(&p.created_at, &p.id)); diff --git a/src/db.rs b/src/db.rs index 602f979..0c8da4b 100644 --- a/src/db.rs +++ b/src/db.rs @@ -307,11 +307,13 @@ pub async fn list_payments( .fetch_all(pool) .await?; - let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM payments WHERE merchant_id = ? AND status = ?") - .bind(merchant_id) - .bind(s) - .fetch_one(pool) - .await?; + let total: i64 = sqlx::query_scalar( + "SELECT COUNT(*) FROM payments WHERE merchant_id = ? AND status = ?", + ) + .bind(merchant_id) + .bind(s) + .fetch_one(pool) + .await?; (rows, total) } else { @@ -646,13 +648,11 @@ fn hash_api_key(raw: &str) -> String { /// shown to the user by the caller and is not recoverable afterward. pub async fn create_merchant(pool: &Db, id: &str, raw_key: &str) -> Result<()> { let hash = hash_api_key(raw_key); - sqlx::query( - "INSERT INTO merchants (id, api_key_hash) VALUES (?, ?)", - ) - .bind(id) - .bind(hash) - .execute(pool) - .await?; + sqlx::query("INSERT INTO merchants (id, api_key_hash) VALUES (?, ?)") + .bind(id) + .bind(hash) + .execute(pool) + .await?; Ok(()) } @@ -660,11 +660,10 @@ pub async fn create_merchant(pool: &Db, id: &str, raw_key: &str) -> Result<()> { /// not match any registered merchant. pub async fn find_merchant_by_key(pool: &Db, raw_key: &str) -> Result> { let hash = hash_api_key(raw_key); - let id: Option = - sqlx::query_scalar("SELECT id FROM merchants WHERE api_key_hash = ?") - .bind(hash) - .fetch_optional(pool) - .await?; + let id: Option = sqlx::query_scalar("SELECT id FROM merchants WHERE api_key_hash = ?") + .bind(hash) + .fetch_optional(pool) + .await?; Ok(id) } diff --git a/src/horizon.rs b/src/horizon.rs index 3676126..cac1680 100644 --- a/src/horizon.rs +++ b/src/horizon.rs @@ -694,7 +694,16 @@ mod tests { } fn test_assets() -> Vec { - crate::config::AcceptedAsset::parse_list("XLM,USDC:GUSDC") + vec![ + crate::config::AcceptedAsset { + code: "XLM".into(), + issuer: None, + }, + crate::config::AcceptedAsset { + code: "USDC".into(), + issuer: Some("GUSDC".into()), + }, + ] } #[test] diff --git a/tests/api_tests.rs b/tests/api_tests.rs index 8c53510..07a44fa 100644 --- a/tests/api_tests.rs +++ b/tests/api_tests.rs @@ -68,10 +68,7 @@ async fn test_server() -> TestServer { async fn provision_merchant(server: &TestServer) -> String { let res = server.post("/merchants").await; res.assert_status(StatusCode::CREATED); - res.json::()["api_key"] - .as_str() - .unwrap() - .to_string() + res.json::()["api_key"].as_str().unwrap().to_string() } #[tokio::test] diff --git a/tests/rate_limit_tests.rs b/tests/rate_limit_tests.rs index 983b632..c909278 100644 --- a/tests/rate_limit_tests.rs +++ b/tests/rate_limit_tests.rs @@ -1,10 +1,7 @@ //! Rate-limit behaviour lives in its own integration binary on purpose. //! -//! The limiter keeps a process-global table of per-IP limiters, created lazily -//! on the first request from each IP. Sharing that table with the broader API -//! tests (which run at a high limit) would let an earlier test create the limiter -//! for the test client's IP at the wrong rate. A dedicated test binary gives this -//! test a fresh, uncontaminated limiter table. +//! The broader API tests run at a high limit and exercise merchant auth heavily. +//! Keeping the low-quota assertion here makes the expected 429 path explicit. use axum::http::StatusCode; use axum_test::TestServer; @@ -60,13 +57,22 @@ async fn server_with_config(cfg: Config) -> TestServer { TestServer::new(router).unwrap() } +async fn provision_merchant(server: &TestServer) -> String { + let res = server.post("/merchants").await; + res.assert_status(StatusCode::CREATED); + res.json::()["api_key"].as_str().unwrap().to_string() +} + #[tokio::test] async fn test_rate_limit_exceeded_returns_429() { let server = server_with_config(make_config(1)).await; + let key = provision_merchant(&server).await; + let auth = format!("Bearer {key}"); // The first request consumes the single per-second token. let first = server .post("/payments") + .add_header("Authorization", auth.clone()) .json(&json!({ "amount": "1", "asset": "XLM" })) .await; first.assert_status(StatusCode::CREATED); @@ -74,6 +80,7 @@ async fn test_rate_limit_exceeded_returns_429() { // A second immediate request exceeds the quota and is rejected. let second = server .post("/payments") + .add_header("Authorization", auth) .json(&json!({ "amount": "1", "asset": "XLM" })) .await; second.assert_status(StatusCode::TOO_MANY_REQUESTS);