From c6f5d5764ba56cd9851ff25e26411259691ddbbf Mon Sep 17 00:00:00 2001 From: Tomoya0k Date: Tue, 23 Jun 2026 13:28:37 -0600 Subject: [PATCH] Fix config assets and payment rate limiting --- Cargo.lock | 121 +++++++++++++++++++++++++++ src/api/mod.rs | 121 +++++++++++++++++---------- src/config.rs | 58 ++++++++++--- src/horizon.rs | 105 +++++++++++++++++++---- src/main.rs | 15 +++- tests/api_tests.rs | 201 +++++++++++++++++++++++++++++++++++---------- 6 files changed, 506 insertions(+), 115 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9efeb5e..611c3c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -308,6 +308,20 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "6.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "der" version = "0.7.10" @@ -545,6 +559,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" +[[package]] +name = "futures-timer" +version = "3.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af43fadb8a98512d547e37b4e92e0ced13e205c061b87b4623eff01d918d6968" + [[package]] name = "futures-util" version = "0.3.32" @@ -595,6 +615,27 @@ dependencies = [ "wasip3", ] +[[package]] +name = "governor" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0746aa765db78b521451ef74221663b57ba595bf83f75d0ce23cc09447c8139f" +dependencies = [ + "cfg-if", + "dashmap", + "futures-sink", + "futures-timer", + "futures-util", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand", + "smallvec", + "spinning_top", +] + [[package]] name = "h2" version = "0.4.15" @@ -614,6 +655,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -1110,6 +1157,18 @@ dependencies = [ "tempfile", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1304,6 +1363,12 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.5" @@ -1357,6 +1422,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quote" version = "1.0.45" @@ -1402,6 +1482,15 @@ dependencies = [ "getrandom 0.2.17", ] +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -1798,6 +1887,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -2012,6 +2110,7 @@ dependencies = [ "axum-test", "dotenvy", "futures-util", + "governor", "hex", "hmac", "reqwest", @@ -2662,6 +2761,28 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-link" version = "0.2.1" diff --git a/src/api/mod.rs b/src/api/mod.rs index 5c2745c..c0308c1 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,15 +1,17 @@ use crate::{db, AppState}; use axum::{ - extract::State, - http::StatusCode, + extract::{ConnectInfo, Request, State}, + http::{header, HeaderValue, StatusCode}, middleware::{self, Next}, response::IntoResponse, routing::{get, post}, - Json, Request, + Json, }; use serde_json::json; +use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; +use std::num::NonZeroU32; +use std::sync::{Arc, Mutex}; use tower_http::{ cors::CorsLayer, limit::RequestBodyLimitLayer, @@ -22,10 +24,25 @@ mod payments; /// Reject request bodies larger than this (256 KiB) before they hit a handler. const MAX_BODY_BYTES: usize = 256 * 1024; +#[derive(Clone)] +struct RateLimitState { + requests_per_sec: u32, + limiters: Arc>>, +} + +impl RateLimitState { + fn new(requests_per_sec: u32) -> Self { + Self { + requests_per_sec: requests_per_sec.max(1), + limiters: Arc::new(Mutex::new(HashMap::new())), + } + } +} + pub fn router(state: Arc) -> axum::Router { let cors = build_cors(&state.config); - let rate_limit_rps = state.config.rate_limit_requests_per_sec; - + let rate_limit = RateLimitState::new(state.config.rate_limit_requests_per_sec); + axum::Router::new() .route("/", get(|| async { "StellarGate API v0.1.0" })) .route("/health", get(health)) @@ -33,55 +50,73 @@ pub fn router(state: Arc) -> axum::Router { .route("/payments", post(payments::create).get(payments::list)) .route("/payments/:id", get(payments::get_by_id)) .route("/payments/:id/webhooks", get(payments::list_webhooks)) - .route("/payments/:id/webhooks/:delivery_id/redeliver", post(payments::redeliver_webhook)) + .route( + "/payments/:id/webhooks/:delivery_id/redeliver", + post(payments::redeliver_webhook), + ) .fallback(not_found) .layer(PropagateRequestIdLayer::x_request_id()) .layer(TraceLayer::new_for_http()) .layer(SetRequestIdLayer::x_request_id(MakeRequestUuid)) .layer(RequestBodyLimitLayer::new(MAX_BODY_BYTES)) + .layer(middleware::from_fn_with_state( + rate_limit, + rate_limit_middleware, + )) .layer(cors) .with_state(state) } async fn rate_limit_middleware( - ConnectInfo(addr): ConnectInfo, - rate_limit_rps: u32, + State(rate_limit): State, req: Request, next: Next, ) -> axum::response::Response { - static LIMITERS: std::sync::OnceLock>> = std::sync::OnceLock::new(); - - let limiters = LIMITERS.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new())); - let ip = addr.ip().to_string(); - - let mut map = limiters.lock().unwrap(); - let limiter = map.entry(ip).or_insert_with(|| { - governor::RateLimiter::direct( - governor::Quota::per_second( - std::num::NonZeroU32::new(rate_limit_rps).unwrap() + if req.method() == axum::http::Method::POST && req.uri().path() == "/payments" { + let key = rate_limit_key(&req); + let limited = { + let mut map = rate_limit.limiters.lock().unwrap(); + let limiter = map.entry(key).or_insert_with(|| { + governor::RateLimiter::direct(governor::Quota::per_second( + NonZeroU32::new(rate_limit.requests_per_sec).unwrap(), + )) + }); + limiter.check().is_err() + }; + + if limited { + let retry_after = HeaderValue::from_static("1"); + return ( + StatusCode::TOO_MANY_REQUESTS, + [(header::RETRY_AFTER, retry_after)], + Json(json!({ + "error": "rate limit exceeded", + "code": "rate_limit_exceeded" + })), ) - ) - }); - - if limiter.check().is_err() { - let retry_after = (1000 / rate_limit_rps).max(1); - return ( - StatusCode::TOO_MANY_REQUESTS, - [( - axum::http::header::RETRY_AFTER, - axum::http::HeaderValue::from_str(&retry_after.to_string()).unwrap(), - )], - axum::Json(json!({ - "error": "rate limit exceeded", - "code": "rate_limit_exceeded" - })), - ) - .into_response(); + .into_response(); + } } - + next.run(req).await } +fn rate_limit_key(req: &Request) -> String { + if let Some(ConnectInfo(addr)) = req.extensions().get::>() { + return addr.ip().to_string(); + } + + for name in ["x-forwarded-for", "x-real-ip"] { + if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) { + if let Some(first) = value.split(',').map(str::trim).find(|s| !s.is_empty()) { + return first.to_string(); + } + } + } + + "local".to_string() +} + fn build_cors(cfg: &crate::config::Config) -> CorsLayer { use axum::http::HeaderName; use tower_http::cors::AllowOrigin; @@ -98,10 +133,8 @@ fn build_cors(cfg: &crate::config::Config) -> CorsLayer { return CorsLayer::permissive(); } - let allow_origins: Vec = origins - .iter() - .filter_map(|o| o.parse().ok()) - .collect(); + let allow_origins: Vec = + origins.iter().filter_map(|o| o.parse().ok()).collect(); CorsLayer::new() .allow_origin(AllowOrigin::list(allow_origins)) @@ -123,7 +156,11 @@ async fn health() -> impl IntoResponse { async fn ready(State(state): State>) -> impl IntoResponse { match db::ping(&state.pool).await { Ok(()) => (StatusCode::OK, Json(json!({ "status": "ok" }))).into_response(), - Err(_) => (StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "status": "unavailable" }))).into_response(), + Err(_) => ( + StatusCode::SERVICE_UNAVAILABLE, + Json(json!({ "status": "unavailable" })), + ) + .into_response(), } } diff --git a/src/config.rs b/src/config.rs index 9614b1c..a8c2c9f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -58,18 +58,19 @@ impl AcceptedAsset { pub fn default_list() -> Vec { vec![ - AcceptedAsset { code: "XLM".into(), issuer: None }, + AcceptedAsset { + code: "XLM".into(), + issuer: None, + }, AcceptedAsset { code: "USDC".into(), - issuer: Some( - "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5".into(), - ), + issuer: Some("GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5".into()), }, ] } } -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct Config { pub port: u16, pub database_url: String, @@ -84,6 +85,8 @@ pub struct Config { pub webhook_retry_attempts: u32, pub webhook_retry_delay_ms: u64, pub poll_interval_secs: u64, + /// Rate limit for POST /payments, counted per client key. + pub rate_limit_requests_per_sec: u32, /// How long a payment intent stays `pending` before the expiry sweeper /// transitions it to `expired`. Counted from the intent's `created_at`. pub payment_ttl_secs: u64, @@ -114,6 +117,7 @@ impl Config { webhook_retry_attempts: parse_env("WEBHOOK_RETRY_ATTEMPTS", 3), webhook_retry_delay_ms: parse_env("WEBHOOK_RETRY_DELAY_MS", 5000), poll_interval_secs: parse_env("POLL_INTERVAL_SECS", 10), + rate_limit_requests_per_sec: parse_env("RATE_LIMIT_REQUESTS_PER_SEC", 10), payment_ttl_secs: parse_env("PAYMENT_TTL_SECS", 3600), cors_allowed_origins: std::env::var("CORS_ALLOWED_ORIGINS") .unwrap_or_default() @@ -149,6 +153,10 @@ impl std::fmt::Debug for Config { .field("webhook_retry_attempts", &self.webhook_retry_attempts) .field("webhook_retry_delay_ms", &self.webhook_retry_delay_ms) .field("poll_interval_secs", &self.poll_interval_secs) + .field( + "rate_limit_requests_per_sec", + &self.rate_limit_requests_per_sec, + ) .field("cors_allowed_origins", &self.cors_allowed_origins) .field("listener_mode", &self.listener_mode) .finish() @@ -173,23 +181,51 @@ mod tests { webhook_retry_attempts: 3, webhook_retry_delay_ms: 5000, poll_interval_secs: 10, + rate_limit_requests_per_sec: 10, payment_ttl_secs: 3600, cors_allowed_origins: vec![], listener_mode: ListenerMode::Stream, }; let output = format!("{cfg:?}"); - assert!(!output.contains("super-secret-key"), "gateway_secret must not appear in Debug output"); - assert!(!output.contains("webhook-hmac-secret"), "webhook_secret must not appear in Debug output"); - assert!(output.contains("***"), "redacted marker must appear in Debug output"); + assert!( + !output.contains("super-secret-key"), + "gateway_secret must not appear in Debug output" + ); + assert!( + !output.contains("webhook-hmac-secret"), + "webhook_secret must not appear in Debug output" + ); + assert!( + output.contains("***"), + "redacted marker must appear in Debug output" + ); } #[test] fn parse_accepted_assets_from_env_string() { let assets = AcceptedAsset::parse_list("XLM,USDC:GISSUER,EURC:GISSUER2"); assert_eq!(assets.len(), 3); - assert_eq!(assets[0], AcceptedAsset { code: "XLM".into(), issuer: None }); - assert_eq!(assets[1], AcceptedAsset { code: "USDC".into(), issuer: Some("GISSUER".into()) }); - assert_eq!(assets[2], AcceptedAsset { code: "EURC".into(), issuer: Some("GISSUER2".into()) }); + assert_eq!( + assets[0], + AcceptedAsset { + code: "XLM".into(), + issuer: None + } + ); + assert_eq!( + assets[1], + AcceptedAsset { + code: "USDC".into(), + issuer: Some("GISSUER".into()) + } + ); + assert_eq!( + assets[2], + AcceptedAsset { + code: "EURC".into(), + issuer: Some("GISSUER2".into()) + } + ); } } diff --git a/src/horizon.rs b/src/horizon.rs index 9c1bf0d..67eb0a0 100644 --- a/src/horizon.rs +++ b/src/horizon.rs @@ -91,13 +91,22 @@ struct Embedded { #[derive(Debug, PartialEq, Eq)] pub enum Verdict { /// Cumulative paid amount equals the requested amount exactly. - Completed { tx_hash: String, paid_amount: String }, + Completed { + tx_hash: String, + paid_amount: String, + }, /// Cumulative paid amount exceeds the requested amount. /// The intent is fulfilled; `delta` is the excess the merchant should refund. - Overpaid { tx_hash: String, paid_amount: String }, + Overpaid { + tx_hash: String, + paid_amount: String, + }, /// Cumulative paid amount is still below the requested amount. /// The intent remains open; `delta` is the shortfall still owed. - Underpaid { tx_hash: String, paid_amount: String }, + Underpaid { + tx_hash: String, + paid_amount: String, + }, } impl HorizonPayment { @@ -154,9 +163,18 @@ pub fn verify( use std::cmp::Ordering; match total_paid.cmp(&expected) { - Ordering::Equal => Some(Verdict::Completed { tx_hash, paid_amount }), - Ordering::Greater => Some(Verdict::Overpaid { tx_hash, paid_amount }), - Ordering::Less => Some(Verdict::Underpaid { tx_hash, paid_amount }), + Ordering::Equal => Some(Verdict::Completed { + tx_hash, + paid_amount, + }), + Ordering::Greater => Some(Verdict::Overpaid { + tx_hash, + paid_amount, + }), + Ordering::Less => Some(Verdict::Underpaid { + tx_hash, + paid_amount, + }), } } @@ -222,7 +240,12 @@ async fn starting_cursor(state: &Arc) -> anyhow::Result { .json() .await?; - match page.embedded.records.first().and_then(|p| p.paging_token.clone()) { + match page + .embedded + .records + .first() + .and_then(|p| p.paging_token.clone()) + { Some(token) => { // Persist immediately so a crash before the first page still leaves // us baselined rather than replaying history next time. @@ -307,22 +330,54 @@ async fn reconcile_payment(state: &Arc, hp: &HorizonPayment) -> anyhow .and_then(money::parse_stroops) .unwrap_or(0); - match verify(&payment, hp, &state.config.usdc_issuer, already_paid_stroops) { - Some(Verdict::Completed { tx_hash, paid_amount }) => { - settle(state, &payment, "completed", &tx_hash, &paid_amount, "payment.completed", None).await; + match verify( + &payment, + hp, + &state.config.accepted_assets, + already_paid_stroops, + ) { + Some(Verdict::Completed { + tx_hash, + paid_amount, + }) => { + settle( + state, + &payment, + "completed", + &tx_hash, + &paid_amount, + "payment.completed", + None, + ) + .await; Ok(true) } - Some(Verdict::Overpaid { tx_hash, paid_amount }) => { + Some(Verdict::Overpaid { + tx_hash, + paid_amount, + }) => { let delta = delta_str(&paid_amount, &payment.amount); info!( payment_id = %payment.id, excess = %delta.as_deref().unwrap_or("?"), "overpayment — intent completed, excess should be refunded" ); - settle(state, &payment, "completed", &tx_hash, &paid_amount, "payment.overpaid", delta.as_deref()).await; + settle( + state, + &payment, + "completed", + &tx_hash, + &paid_amount, + "payment.overpaid", + delta.as_deref(), + ) + .await; Ok(true) } - Some(Verdict::Underpaid { tx_hash, paid_amount }) => { + Some(Verdict::Underpaid { + tx_hash, + paid_amount, + }) => { let delta = delta_str(&payment.amount, &paid_amount); warn!( payment_id = %payment.id, @@ -331,7 +386,16 @@ async fn reconcile_payment(state: &Arc, hp: &HorizonPayment) -> anyhow remaining = %delta.as_deref().unwrap_or("?"), "underpayment — intent remains open for a top-up" ); - settle(state, &payment, "underpaid", &tx_hash, &paid_amount, "payment.underpaid", delta.as_deref()).await; + settle( + state, + &payment, + "underpaid", + &tx_hash, + &paid_amount, + "payment.underpaid", + delta.as_deref(), + ) + .await; Ok(true) } None => Ok(false), @@ -594,7 +658,16 @@ mod tests { } fn test_assets() -> Vec { - crate::config::AcceptedAsset::parse_list("XLM,USDC:GUSDC") + vec![ + crate::config::AcceptedAsset { + code: "XLM".into(), + issuer: None, + }, + crate::config::AcceptedAsset { + code: "USDC".into(), + issuer: Some("GUSDC".into()), + }, + ] } #[test] @@ -791,7 +864,7 @@ mod tests { let hp: HorizonPayment = serde_json::from_str(data).unwrap(); let p = pending("XLM", "10.00"); assert!(matches!( - verify(&p, &hp, USDC_ISSUER, 0), + verify(&p, &hp, &test_assets(), 0), Some(Verdict::Completed { .. }) )); } diff --git a/src/main.rs b/src/main.rs index 13792a1..4628f3a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,11 @@ use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use stellargate::{api, config::{Config, ListenerMode}, db, expiry, horizon, AppState}; +use stellargate::{ + api, + config::{Config, ListenerMode}, + db, expiry, horizon, AppState, +}; use tracing::info; use tracing_subscriber::EnvFilter; @@ -47,9 +51,12 @@ async fn main() -> Result<()> { let listener = tokio::net::TcpListener::bind(&addr).await?; info!("StellarGate API listening on {addr}"); - axum::serve(listener, api::router(state)) - .with_graceful_shutdown(shutdown_signal()) - .await?; + axum::serve( + listener, + api::router(state).into_make_service_with_connect_info::(), + ) + .with_graceful_shutdown(shutdown_signal()) + .await?; info!("shutdown complete"); Ok(()) diff --git a/tests/api_tests.rs b/tests/api_tests.rs index 1c48898..399e6bc 100644 --- a/tests/api_tests.rs +++ b/tests/api_tests.rs @@ -4,7 +4,11 @@ use serde_json::{json, Value}; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use std::str::FromStr; use std::sync::Arc; -use stellargate::{api, config::{Config, ListenerMode}, db, AppState}; +use stellargate::{ + api, + config::{AcceptedAsset, Config, ListenerMode}, + db, AppState, +}; use time::format_description::well_known::Rfc3339; fn make_config() -> Config { @@ -15,11 +19,12 @@ fn make_config() -> Config { horizon_url: String::new(), gateway_public: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5".into(), gateway_secret: String::new(), - usdc_issuer: "GBBD47IF6LWK7P7MDEVSCWR7DPUWV3NY3DTQEVFL4NAT4AQH3ZLLFLA5".into(), + accepted_assets: AcceptedAsset::default_list(), webhook_secret: String::new(), webhook_retry_attempts: 1, webhook_retry_delay_ms: 0, poll_interval_secs: 10, + rate_limit_requests_per_sec: 10, payment_ttl_secs: 3600, cors_allowed_origins: vec![], listener_mode: ListenerMode::Poll, @@ -28,13 +33,26 @@ fn make_config() -> Config { async fn test_server_with_pool() -> (TestServer, db::Db) { let cfg = make_config(); + test_server_with_config_and_pool(cfg).await +} + +async fn test_server_with_config_and_pool(cfg: Config) -> (TestServer, db::Db) { let pool = SqlitePoolOptions::new() - .connect_with(SqliteConnectOptions::from_str(&cfg.database_url).unwrap().create_if_missing(true)) + .connect_with( + SqliteConnectOptions::from_str(&cfg.database_url) + .unwrap() + .create_if_missing(true), + ) .await .unwrap(); db::migrate(&pool).await.unwrap(); let http = reqwest::Client::new(); - let server = TestServer::new(api::router(Arc::new(AppState { pool: pool.clone(), config: cfg, http }))).unwrap(); + let server = TestServer::new(api::router(Arc::new(AppState { + pool: pool.clone(), + config: cfg, + http, + }))) + .unwrap(); (server, pool) } @@ -42,6 +60,10 @@ async fn test_server() -> TestServer { test_server_with_pool().await.0 } +async fn test_server_with_config(cfg: Config) -> TestServer { + test_server_with_config_and_pool(cfg).await.0 +} + #[tokio::test] async fn test_health() { let res = test_server().await.get("/health").await; @@ -58,7 +80,8 @@ async fn test_ready_ok_with_live_db() { #[tokio::test] async fn test_create_payment() { - let res = test_server().await + let res = test_server() + .await .post("/payments") .json(&json!({ "amount": "10", "asset": "XLM" })) .await; @@ -75,7 +98,8 @@ async fn test_create_payment() { /// "2026-04-29T15:00:00Z" succeeds. #[tokio::test] async fn test_timestamps_are_rfc3339_utc() { - let res = test_server().await + let res = test_server() + .await .post("/payments") .json(&json!({ "amount": "1", "asset": "XLM" })) .await; @@ -83,16 +107,22 @@ async fn test_timestamps_are_rfc3339_utc() { let body: Value = res.json(); for field in ["created_at", "updated_at"] { - let ts = body[field].as_str().unwrap_or_else(|| panic!("{field} missing")); + let ts = body[field] + .as_str() + .unwrap_or_else(|| panic!("{field} missing")); time::OffsetDateTime::parse(ts, &Rfc3339) .unwrap_or_else(|e| panic!("{field} = {ts:?} is not valid RFC 3339: {e}")); - assert!(ts.ends_with('Z'), "{field} = {ts:?} must have explicit Z suffix"); + assert!( + ts.ends_with('Z'), + "{field} = {ts:?} must have explicit Z suffix" + ); } } #[tokio::test] async fn test_create_invalid_asset() { - let res = test_server().await + let res = test_server() + .await .post("/payments") .json(&json!({ "amount": "10", "asset": "BTC" })) .await; @@ -103,7 +133,8 @@ async fn test_create_invalid_asset() { #[tokio::test] async fn test_create_invalid_amount() { - let res = test_server().await + let res = test_server() + .await .post("/payments") .json(&json!({ "amount": "-1", "asset": "XLM" })) .await; @@ -113,10 +144,14 @@ async fn test_create_invalid_amount() { #[tokio::test] async fn test_get_by_id() { let server = test_server().await; - let id = server.post("/payments") + let id = server + .post("/payments") .json(&json!({ "amount": "5", "asset": "USDC" })) .await - .json::()["id"].as_str().unwrap().to_string(); + .json::()["id"] + .as_str() + .unwrap() + .to_string(); let res = server.get(&format!("/payments/{id}")).await; res.assert_status_ok(); @@ -125,7 +160,9 @@ async fn test_get_by_id() { // Timestamps on the GET response must also be strict RFC 3339. for field in ["created_at", "updated_at"] { - let ts = body[field].as_str().unwrap_or_else(|| panic!("{field} missing")); + let ts = body[field] + .as_str() + .unwrap_or_else(|| panic!("{field} missing")); time::OffsetDateTime::parse(ts, &Rfc3339) .unwrap_or_else(|e| panic!("{field} = {ts:?} is not valid RFC 3339: {e}")); } @@ -139,7 +176,8 @@ async fn test_get_not_found() { #[tokio::test] async fn test_reject_too_many_decimals() { - let res = test_server().await + let res = test_server() + .await .post("/payments") .json(&json!({ "amount": "1.00000001", "asset": "XLM" })) .await; @@ -148,7 +186,8 @@ async fn test_reject_too_many_decimals() { #[tokio::test] async fn test_asset_is_case_insensitive() { - let res = test_server().await + let res = test_server() + .await .post("/payments") .json(&json!({ "amount": "1", "asset": "usdc" })) .await; @@ -158,18 +197,44 @@ async fn test_asset_is_case_insensitive() { #[tokio::test] async fn test_reject_bad_webhook_url() { - let res = test_server().await + let res = test_server() + .await .post("/payments") .json(&json!({ "amount": "1", "asset": "XLM", "webhook_url": "ftp://x" })) .await; res.assert_status(StatusCode::BAD_REQUEST); } +#[tokio::test] +async fn test_create_payment_rate_limit() { + let mut cfg = make_config(); + cfg.rate_limit_requests_per_sec = 1; + let server = test_server_with_config(cfg).await; + + let first = server + .post("/payments") + .add_header("x-forwarded-for", "203.0.113.10") + .json(&json!({ "amount": "1", "asset": "XLM" })) + .await; + first.assert_status(StatusCode::CREATED); + + let second = server + .post("/payments") + .add_header("x-forwarded-for", "203.0.113.10") + .json(&json!({ "amount": "2", "asset": "XLM" })) + .await; + second.assert_status(StatusCode::TOO_MANY_REQUESTS); + assert_eq!(second.json::()["code"], "rate_limit_exceeded"); +} + #[tokio::test] async fn test_list_payments() { let server = test_server().await; for amt in ["1", "2", "3"] { - server.post("/payments").json(&json!({ "amount": amt, "asset": "XLM" })).await; + server + .post("/payments") + .json(&json!({ "amount": amt, "asset": "XLM" })) + .await; } let res = server.get("/payments").await; @@ -182,7 +247,10 @@ async fn test_list_payments() { #[tokio::test] async fn test_list_filter_by_status() { let server = test_server().await; - server.post("/payments").json(&json!({ "amount": "1", "asset": "XLM" })).await; + server + .post("/payments") + .json(&json!({ "amount": "1", "asset": "XLM" })) + .await; // All created payments start pending, so completed should be empty. let res = server.get("/payments?status=completed").await; @@ -203,7 +271,10 @@ async fn test_list_invalid_status() { async fn test_list_cursor_pagination() { let server = test_server().await; for amt in ["1", "2", "3", "4", "5"] { - server.post("/payments").json(&json!({ "amount": amt, "asset": "XLM" })).await; + server + .post("/payments") + .json(&json!({ "amount": amt, "asset": "XLM" })) + .await; } // Page 1 via offset path — also returns next_cursor for migration. @@ -211,21 +282,32 @@ async fn test_list_cursor_pagination() { res.assert_status_ok(); let body: Value = res.json(); assert_eq!(body["payments"].as_array().unwrap().len(), 2); - let cursor = body["next_cursor"].as_str().expect("next_cursor must be present on a full page"); + let cursor = body["next_cursor"] + .as_str() + .expect("next_cursor must be present on a full page"); // Page 2 via keyset cursor. - let res2 = server.get(&format!("/payments?cursor={cursor}&limit=2")).await; + let res2 = server + .get(&format!("/payments?cursor={cursor}&limit=2")) + .await; res2.assert_status_ok(); let body2: Value = res2.json(); assert_eq!(body2["payments"].as_array().unwrap().len(), 2); - let cursor2 = body2["next_cursor"].as_str().expect("next_cursor must be present on a full page"); + let cursor2 = body2["next_cursor"] + .as_str() + .expect("next_cursor must be present on a full page"); // Page 3 — last page, fewer items than limit. - let res3 = server.get(&format!("/payments?cursor={cursor2}&limit=2")).await; + let res3 = server + .get(&format!("/payments?cursor={cursor2}&limit=2")) + .await; res3.assert_status_ok(); let body3: Value = res3.json(); assert_eq!(body3["payments"].as_array().unwrap().len(), 1); - assert!(body3["next_cursor"].is_null(), "last page must have null next_cursor"); + assert!( + body3["next_cursor"].is_null(), + "last page must have null next_cursor" + ); // All 5 IDs are unique across all pages. let ids: Vec = [&body, &body2, &body3] @@ -239,7 +321,10 @@ async fn test_list_cursor_pagination() { #[tokio::test] async fn test_list_cursor_invalid() { - let res = test_server().await.get("/payments?cursor=notvalidhex!!").await; + let res = test_server() + .await + .get("/payments?cursor=notvalidhex!!") + .await; res.assert_status(StatusCode::BAD_REQUEST); } @@ -255,7 +340,10 @@ async fn test_unknown_route_returns_json_404() { #[tokio::test] async fn test_list_webhooks_not_found() { - let res = test_server().await.get("/payments/nonexistent/webhooks").await; + let res = test_server() + .await + .get("/payments/nonexistent/webhooks") + .await; res.assert_status(StatusCode::NOT_FOUND); let body: Value = res.json(); assert_eq!(body["error"], "payment not found"); @@ -265,10 +353,14 @@ async fn test_list_webhooks_not_found() { #[tokio::test] async fn test_list_webhooks_empty() { let server = test_server().await; - let id = server.post("/payments") + let id = server + .post("/payments") .json(&json!({ "amount": "5", "asset": "XLM" })) .await - .json::()["id"].as_str().unwrap().to_string(); + .json::()["id"] + .as_str() + .unwrap() + .to_string(); let res = server.get(&format!("/payments/{id}/webhooks")).await; res.assert_status_ok(); @@ -279,19 +371,28 @@ async fn test_list_webhooks_empty() { #[tokio::test] async fn test_redeliver_webhook_not_found() { - let res = test_server().await.post("/payments/nonexistent/webhooks/xyz/redeliver").await; + let res = test_server() + .await + .post("/payments/nonexistent/webhooks/xyz/redeliver") + .await; res.assert_status(StatusCode::NOT_FOUND); } #[tokio::test] async fn test_redeliver_delivery_not_found() { let server = test_server().await; - let id = server.post("/payments") + let id = server + .post("/payments") .json(&json!({ "amount": "5", "asset": "XLM" })) .await - .json::()["id"].as_str().unwrap().to_string(); + .json::()["id"] + .as_str() + .unwrap() + .to_string(); - let res = server.post(&format!("/payments/{id}/webhooks/nonexistent/redeliver")).await; + let res = server + .post(&format!("/payments/{id}/webhooks/nonexistent/redeliver")) + .await; res.assert_status(StatusCode::NOT_FOUND); } @@ -300,15 +401,23 @@ async fn test_webhook_delivery_isolation() { let (server, pool) = test_server_with_pool().await; // Create two payments - let id1 = server.post("/payments") + let id1 = server + .post("/payments") .json(&json!({ "amount": "5", "asset": "XLM" })) .await - .json::()["id"].as_str().unwrap().to_string(); + .json::()["id"] + .as_str() + .unwrap() + .to_string(); - let id2 = server.post("/payments") + let id2 = server + .post("/payments") .json(&json!({ "amount": "10", "asset": "USDC" })) .await - .json::()["id"].as_str().unwrap().to_string(); + .json::()["id"] + .as_str() + .unwrap() + .to_string(); // Manually insert a delivery for payment 1 stellargate::db::save_webhook_delivery( @@ -320,18 +429,26 @@ async fn test_webhook_delivery_isolation() { ) .await .unwrap(); - + // List webhooks for payment 1 should find it let res1 = server.get(&format!("/payments/{id1}/webhooks")).await; res1.assert_status_ok(); - assert_eq!(res1.json::()["deliveries"].as_array().unwrap().len(), 1); - + assert_eq!( + res1.json::()["deliveries"].as_array().unwrap().len(), + 1 + ); + // List webhooks for payment 2 should be empty let res2 = server.get(&format!("/payments/{id2}/webhooks")).await; res2.assert_status_ok(); - assert_eq!(res2.json::()["deliveries"].as_array().unwrap().len(), 0); - + assert_eq!( + res2.json::()["deliveries"].as_array().unwrap().len(), + 0 + ); + // Try to redeliver delivery from payment 1 on payment 2 (should fail) - let res_cross = server.post(&format!("/payments/{id2}/webhooks/delivery-1/redeliver")).await; + let res_cross = server + .post(&format!("/payments/{id2}/webhooks/delivery-1/redeliver")) + .await; res_cross.assert_status(StatusCode::NOT_FOUND); }