diff --git a/server/Cargo.toml b/server/Cargo.toml index 197c7c0..26ee5d5 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -24,3 +24,5 @@ time = "0.3.46" openssl = { version = "0.10", features = ["vendored"] } openssl-sys = { version = "0.9", features = ["vendored"] } url = "2.5.8" +dashmap = "6.1" +governor = "0.10" diff --git a/server/src/config.rs b/server/src/config.rs index d2bcdd3..f23f552 100644 --- a/server/src/config.rs +++ b/server/src/config.rs @@ -3,6 +3,7 @@ use sqlx::postgres::PgPoolOptions; use std::sync::{Arc, Mutex}; use std::collections::HashMap; use crate::state::AppState; +use dashmap::DashMap; pub async fn setup_database_and_docker() -> Result> { // 1. Docker setup @@ -22,6 +23,7 @@ pub async fn setup_database_and_docker() -> Result = sqlx::query_as( "SELECT id FROM projects WHERE owner_id = $1 AND LOWER(slug) = LOWER($2)", @@ -637,6 +652,15 @@ pub async fn add_to_whitelist( .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Session Error: {}", e)))?; let user = user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized".to_string()))?; + // Apply rate limiting: 20 requests per minute per user + let rate_limiter = state.get_or_create_whitelist_rate_limiter(user.id); + if rate_limiter.check().is_err() { + return Err(( + StatusCode::TOO_MANY_REQUESTS, + "Rate limit exceeded. Please try again later.".to_string() + )); + } + let trimmed_url = payload.allowed_url.trim(); if trimmed_url.is_empty() { return Err((StatusCode::BAD_REQUEST, "allowed_url is required".to_string())); @@ -662,21 +686,66 @@ pub async fn add_to_whitelist( "Invalid URL format. URL must use http or https scheme and include a valid host.".to_string(), ))?; - // Unique(project_id, allowed_url) is enforced by the DB; ignore conflicts - let result = sqlx::query( - "INSERT INTO project_whitelists (project_id, allowed_url) \ - VALUES ($1, $2) \ - ON CONFLICT (project_id, allowed_url) DO NOTHING", + // Use a database transaction with table-level advisory lock to prevent race conditions + let mut tx = state.db.begin() + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Transaction Error: {}", e)))?; + + // Get an advisory lock for this project's whitelist to prevent concurrent modifications + sqlx::query("SELECT pg_advisory_xact_lock($1)") + .bind(project_id) + .execute(&mut *tx) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Lock Error: {}", e)))?; + + // Check if entry already exists (using normalized URL) + let exists: (bool,) = sqlx::query_as( + "SELECT EXISTS(SELECT 1 FROM project_whitelists WHERE project_id = $1 AND allowed_url = $2)", ) .bind(project_id) .bind(&normalized_url) - .execute(&state.db) - .await; + .fetch_one(&mut *tx) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("DB Error: {}", e)))?; - if let Err(e) = result { - return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("DB Error: {}", e))); + if exists.0 { + // Entry already exists, commit transaction and return success + tx.commit().await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Commit Error: {}", e)))?; + return Ok(StatusCode::OK); // 200 OK - idempotent operation, entry already exists } + // Check current count + let count_row: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM project_whitelists WHERE project_id = $1", + ) + .bind(project_id) + .fetch_one(&mut *tx) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("DB Error: {}", e)))?; + + if count_row.0 >= MAX_WHITELIST_ENTRIES { + // Transaction will automatically rollback when dropped + return Err(( + StatusCode::FORBIDDEN, + format!("Maximum whitelist entries ({}) reached for this project", MAX_WHITELIST_ENTRIES) + )); + } + + // Insert the new entry (normalized) + sqlx::query( + "INSERT INTO project_whitelists (project_id, allowed_url) VALUES ($1, $2)", + ) + .bind(project_id) + .bind(normalized_url) + .execute(&mut *tx) + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Insert Error: {}", e)))?; + + // Commit transaction + tx.commit() + .await + .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Commit Error: {}", e)))?; + Ok(StatusCode::CREATED) } @@ -697,6 +766,15 @@ pub async fn remove_from_whitelist( .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Session Error: {}", e)))?; let user = user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized".to_string()))?; + // Apply rate limiting: 20 requests per minute per user + let rate_limiter = state.get_or_create_whitelist_rate_limiter(user.id); + if rate_limiter.check().is_err() { + return Err(( + StatusCode::TOO_MANY_REQUESTS, + "Rate limit exceeded. Please try again later.".to_string() + )); + } + let trimmed_url = payload.allowed_url.trim(); if trimmed_url.is_empty() { return Err((StatusCode::BAD_REQUEST, "allowed_url is required".to_string())); diff --git a/server/src/state.rs b/server/src/state.rs index 7dcac10..e93230a 100644 --- a/server/src/state.rs +++ b/server/src/state.rs @@ -2,6 +2,10 @@ use bollard::Docker; use std::sync::{Arc, Mutex, MutexGuard}; use std::collections::HashMap; use std::time::Instant; +use dashmap::DashMap; +use governor::{Quota, RateLimiter, clock::DefaultClock, state::{InMemoryState, NotKeyed}}; +use std::num::NonZeroU32; +use crate::handlers::project::WHITELIST_RATE_LIMIT_PER_MINUTE; // New Struct to track container details AND ownership #[derive(Clone, Debug)] @@ -25,6 +29,8 @@ pub struct AppState { pub github_id: String, pub github_secret: String, pub sessions: SessionMap, + // Rate limiter for whitelist operations: per-user tracking + pub whitelist_rate_limiters: Arc>>>, } impl AppState { @@ -37,4 +43,18 @@ impl AppState { } } } + + /// Get or create a rate limiter for a user's whitelist operations + pub fn get_or_create_whitelist_rate_limiter(&self, user_id: i64) -> Arc> { + self.whitelist_rate_limiters + .entry(user_id) + .or_insert_with(|| { + let quota = Quota::per_minute( + NonZeroU32::new(WHITELIST_RATE_LIMIT_PER_MINUTE) + .expect("WHITELIST_RATE_LIMIT_PER_MINUTE must be non-zero") + ); + Arc::new(RateLimiter::direct(quota)) + }) + .clone() + } } \ No newline at end of file