Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 64 additions & 43 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
use crate::{db, AppState};
use axum::{
extract::{ConnectInfo, Request, State},
http::StatusCode,
http::{header, HeaderValue, StatusCode},
middleware::{self, Next},
response::IntoResponse,
routing::{get, post},
Extension, Json,
Json,
};
use serde_json::json;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::num::NonZeroU32;
use std::sync::{Arc, Mutex};
use tower_http::{
cors::CorsLayer,
limit::RequestBodyLimitLayer,
Expand All @@ -26,9 +28,24 @@ const MAX_BODY_BYTES: usize = 256 * 1024;
#[derive(Clone)]
pub struct AuthenticatedMerchant(pub String);

#[derive(Clone)]
struct RateLimitState {
requests_per_sec: u32,
limiters: Arc<Mutex<HashMap<String, governor::DefaultDirectRateLimiter>>>,
}

impl RateLimitState {
fn new(requests_per_sec: u32) -> Self {
Self {
requests_per_sec: requests_per_sec.max(1),
limiters: Arc::new(Mutex::new(HashMap::new())),
}
}
}

pub fn router(state: Arc<AppState>) -> axum::Router {
let cors = build_cors(&state.config);
let rate_limit_rps = state.config.rate_limit_requests_per_sec;
let rate_limit = RateLimitState::new(state.config.rate_limit_requests_per_sec);

axum::Router::new()
.route("/", get(|| async { "StellarGate API v0.1.0" }))
Expand All @@ -55,17 +72,16 @@ pub fn router(state: Arc<AppState>) -> axum::Router {
"/:id/webhooks/:delivery_id/redeliver",
post(payments::redeliver_webhook),
)
.layer(middleware::from_fn(
move |ConnectInfo(addr): ConnectInfo<SocketAddr>, req: Request, next: Next| {
rate_limit_middleware(addr, rate_limit_rps, req, next)
},
))
})
.fallback(not_found)
.layer(PropagateRequestIdLayer::x_request_id())
.layer(TraceLayer::new_for_http())
.layer(SetRequestIdLayer::x_request_id(MakeRequestUuid))
.layer(RequestBodyLimitLayer::new(MAX_BODY_BYTES))
.layer(middleware::from_fn_with_state(
rate_limit,
rate_limit_middleware,
))
.layer(cors)
.with_state(state)
}
Expand Down Expand Up @@ -139,49 +155,54 @@ async fn provision_merchant(
}

async fn rate_limit_middleware(
addr: SocketAddr,
rate_limit_rps: u32,
State(rate_limit): State<RateLimitState>,
req: Request,
next: Next,
) -> axum::response::Response {
static LIMITERS: std::sync::OnceLock<
std::sync::Mutex<std::collections::HashMap<String, governor::DefaultDirectRateLimiter>>,
> = std::sync::OnceLock::new();

let limiters = LIMITERS.get_or_init(|| std::sync::Mutex::new(std::collections::HashMap::new()));
let ip = addr.ip().to_string();

// Scoped so the `MutexGuard` is dropped before the `next.run().await`
// below, keeping the returned future `Send`.
let allowed = {
let mut map = limiters.lock().unwrap();
let limiter = map.entry(ip).or_insert_with(|| {
governor::RateLimiter::direct(governor::Quota::per_second(
std::num::NonZeroU32::new(rate_limit_rps).unwrap(),
))
});
limiter.check().is_ok()
};
if req.method() == axum::http::Method::POST && req.uri().path() == "/payments" {
let key = rate_limit_key(&req);
let limited = {
let mut map = rate_limit.limiters.lock().unwrap();
let limiter = map.entry(key).or_insert_with(|| {
governor::RateLimiter::direct(governor::Quota::per_second(
NonZeroU32::new(rate_limit.requests_per_sec).unwrap(),
))
});
limiter.check().is_err()
};

if !allowed {
let retry_after = (1000 / rate_limit_rps).max(1);
return (
StatusCode::TOO_MANY_REQUESTS,
[(
axum::http::header::RETRY_AFTER,
axum::http::HeaderValue::from_str(&retry_after.to_string()).unwrap(),
)],
axum::Json(json!({
"error": "rate limit exceeded",
"code": "rate_limit_exceeded"
})),
)
.into_response();
if limited {
return (
StatusCode::TOO_MANY_REQUESTS,
[(header::RETRY_AFTER, HeaderValue::from_static("1"))],
Json(json!({
"error": "rate limit exceeded",
"code": "rate_limit_exceeded"
})),
)
.into_response();
}
}

next.run(req).await
}

fn rate_limit_key(req: &Request) -> String {
if let Some(ConnectInfo(addr)) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
return addr.ip().to_string();
}

for name in ["x-forwarded-for", "x-real-ip"] {
if let Some(value) = req.headers().get(name).and_then(|v| v.to_str().ok()) {
if let Some(first) = value.split(',').map(str::trim).find(|s| !s.is_empty()) {
return first.to_string();
}
}
}

"local".to_string()
}

fn build_cors(cfg: &crate::config::Config) -> CorsLayer {
use axum::http::HeaderName;
use tower_http::cors::AllowOrigin;
Expand Down
10 changes: 8 additions & 2 deletions src/api/payments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,14 @@ pub async fn list(
} else {
// Legacy offset pagination — kept for backward compatibility.
let offset = q.offset.unwrap_or(0).max(0);
let (payments, total) =
db::list_payments(&state.pool, &merchant_id, q.status.as_deref(), limit, offset).await?;
let (payments, total) = db::list_payments(
&state.pool,
&merchant_id,
q.status.as_deref(),
limit,
offset,
)
.await?;

// Provide next_cursor to ease migration to keyset pagination.
let next_cursor = payments.last().map(|p| encode_cursor(&p.created_at, &p.id));
Expand Down
33 changes: 16 additions & 17 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,13 @@ pub async fn list_payments(
.fetch_all(pool)
.await?;

let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM payments WHERE merchant_id = ? AND status = ?")
.bind(merchant_id)
.bind(s)
.fetch_one(pool)
.await?;
let total: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM payments WHERE merchant_id = ? AND status = ?",
)
.bind(merchant_id)
.bind(s)
.fetch_one(pool)
.await?;

(rows, total)
} else {
Expand Down Expand Up @@ -646,25 +648,22 @@ fn hash_api_key(raw: &str) -> String {
/// shown to the user by the caller and is not recoverable afterward.
pub async fn create_merchant(pool: &Db, id: &str, raw_key: &str) -> Result<()> {
let hash = hash_api_key(raw_key);
sqlx::query(
"INSERT INTO merchants (id, api_key_hash) VALUES (?, ?)",
)
.bind(id)
.bind(hash)
.execute(pool)
.await?;
sqlx::query("INSERT INTO merchants (id, api_key_hash) VALUES (?, ?)")
.bind(id)
.bind(hash)
.execute(pool)
.await?;
Ok(())
}

/// Look up a merchant by their raw API key. Returns `None` if the key does
/// not match any registered merchant.
pub async fn find_merchant_by_key(pool: &Db, raw_key: &str) -> Result<Option<String>> {
let hash = hash_api_key(raw_key);
let id: Option<String> =
sqlx::query_scalar("SELECT id FROM merchants WHERE api_key_hash = ?")
.bind(hash)
.fetch_optional(pool)
.await?;
let id: Option<String> = sqlx::query_scalar("SELECT id FROM merchants WHERE api_key_hash = ?")
.bind(hash)
.fetch_optional(pool)
.await?;
Ok(id)
}

Expand Down
11 changes: 10 additions & 1 deletion src/horizon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,16 @@ mod tests {
}

fn test_assets() -> Vec<crate::config::AcceptedAsset> {
crate::config::AcceptedAsset::parse_list("XLM,USDC:GUSDC")
vec![
crate::config::AcceptedAsset {
code: "XLM".into(),
issuer: None,
},
crate::config::AcceptedAsset {
code: "USDC".into(),
issuer: Some("GUSDC".into()),
},
]
}

#[test]
Expand Down
5 changes: 1 addition & 4 deletions tests/api_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ async fn test_server() -> TestServer {
async fn provision_merchant(server: &TestServer) -> String {
let res = server.post("/merchants").await;
res.assert_status(StatusCode::CREATED);
res.json::<Value>()["api_key"]
.as_str()
.unwrap()
.to_string()
res.json::<Value>()["api_key"].as_str().unwrap().to_string()
}

#[tokio::test]
Expand Down
17 changes: 12 additions & 5 deletions tests/rate_limit_tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
//! Rate-limit behaviour lives in its own integration binary on purpose.
//!
//! The limiter keeps a process-global table of per-IP limiters, created lazily
//! on the first request from each IP. Sharing that table with the broader API
//! tests (which run at a high limit) would let an earlier test create the limiter
//! for the test client's IP at the wrong rate. A dedicated test binary gives this
//! test a fresh, uncontaminated limiter table.
//! The broader API tests run at a high limit and exercise merchant auth heavily.
//! Keeping the low-quota assertion here makes the expected 429 path explicit.

use axum::http::StatusCode;
use axum_test::TestServer;
Expand Down Expand Up @@ -60,20 +57,30 @@ async fn server_with_config(cfg: Config) -> TestServer {
TestServer::new(router).unwrap()
}

async fn provision_merchant(server: &TestServer) -> String {
let res = server.post("/merchants").await;
res.assert_status(StatusCode::CREATED);
res.json::<Value>()["api_key"].as_str().unwrap().to_string()
}

#[tokio::test]
async fn test_rate_limit_exceeded_returns_429() {
let server = server_with_config(make_config(1)).await;
let key = provision_merchant(&server).await;
let auth = format!("Bearer {key}");

// The first request consumes the single per-second token.
let first = server
.post("/payments")
.add_header("Authorization", auth.clone())
.json(&json!({ "amount": "1", "asset": "XLM" }))
.await;
first.assert_status(StatusCode::CREATED);

// A second immediate request exceeds the quota and is rejected.
let second = server
.post("/payments")
.add_header("Authorization", auth)
.json(&json!({ "amount": "1", "asset": "XLM" }))
.await;
second.assert_status(StatusCode::TOO_MANY_REQUESTS);
Expand Down
Loading