Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ time = "0.3.46"
openssl = { version = "0.10", features = ["vendored"] }
openssl-sys = { version = "0.9", features = ["vendored"] }
url = "2.5.8"
dashmap = "6.1"
governor = "0.10"
2 changes: 2 additions & 0 deletions server/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use sqlx::postgres::PgPoolOptions;
use std::sync::{Arc, Mutex};
use std::collections::HashMap;
use crate::state::AppState;
use dashmap::DashMap;

pub async fn setup_database_and_docker() -> Result<AppState, Box<dyn std::error::Error>> {
// 1. Docker setup
Expand All @@ -22,6 +23,7 @@ pub async fn setup_database_and_docker() -> Result<AppState, Box<dyn std::error:
github_id: std::env::var("GITHUB_CLIENT_ID").expect("Missing GITHUB_CLIENT_ID"),
github_secret: std::env::var("GITHUB_CLIENT_SECRET").expect("Missing GITHUB_CLIENT_SECRET"),
sessions: Arc::new(Mutex::new(HashMap::new())),
whitelist_rate_limiters: Arc::new(DashMap::new()),
};

Ok(state)
Expand Down
96 changes: 87 additions & 9 deletions server/src/handlers/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ use crate::models::{User, ProjectSummary, PublishRequest, WhitelistRequest};
use std::collections::HashMap;
use url::Url;

// Maximum number of whitelist entries allowed per project
const MAX_WHITELIST_ENTRIES: i64 = 100;

// Rate limit for whitelist operations (requests per minute per user)
pub const WHITELIST_RATE_LIMIT_PER_MINUTE: u32 = 20;

#[derive(Deserialize)]
pub struct SearchQuery {
q: String,
Expand Down Expand Up @@ -592,6 +598,15 @@ pub async fn get_whitelist(
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Session Error: {}", e)))?;
let user = user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized".to_string()))?;

// Apply rate limiting: 20 requests per minute per user
let rate_limiter = state.get_or_create_whitelist_rate_limiter(user.id);
if rate_limiter.check().is_err() {
return Err((
StatusCode::TOO_MANY_REQUESTS,
"Rate limit exceeded. Please try again later.".to_string()
));
}

// Resolve project_id for this owner + slug
let project_row: Option<(i64,)> = sqlx::query_as(
"SELECT id FROM projects WHERE owner_id = $1 AND LOWER(slug) = LOWER($2)",
Expand Down Expand Up @@ -637,6 +652,15 @@ pub async fn add_to_whitelist(
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Session Error: {}", e)))?;
let user = user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized".to_string()))?;

// Apply rate limiting: 20 requests per minute per user
let rate_limiter = state.get_or_create_whitelist_rate_limiter(user.id);
if rate_limiter.check().is_err() {
return Err((
StatusCode::TOO_MANY_REQUESTS,
"Rate limit exceeded. Please try again later.".to_string()
));
}

let trimmed_url = payload.allowed_url.trim();
if trimmed_url.is_empty() {
return Err((StatusCode::BAD_REQUEST, "allowed_url is required".to_string()));
Expand All @@ -662,21 +686,66 @@ pub async fn add_to_whitelist(
"Invalid URL format. URL must use http or https scheme and include a valid host.".to_string(),
))?;

// Unique(project_id, allowed_url) is enforced by the DB; ignore conflicts
let result = sqlx::query(
"INSERT INTO project_whitelists (project_id, allowed_url) \
VALUES ($1, $2) \
ON CONFLICT (project_id, allowed_url) DO NOTHING",
// Use a database transaction with table-level advisory lock to prevent race conditions
let mut tx = state.db.begin()
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Transaction Error: {}", e)))?;

// Get an advisory lock for this project's whitelist to prevent concurrent modifications
sqlx::query("SELECT pg_advisory_xact_lock($1)")
.bind(project_id)
.execute(&mut *tx)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Lock Error: {}", e)))?;

// Check if entry already exists (using normalized URL)
let exists: (bool,) = sqlx::query_as(
"SELECT EXISTS(SELECT 1 FROM project_whitelists WHERE project_id = $1 AND allowed_url = $2)",
)
.bind(project_id)
.bind(&normalized_url)
.execute(&state.db)
.await;
.fetch_one(&mut *tx)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("DB Error: {}", e)))?;

if let Err(e) = result {
return Err((StatusCode::INTERNAL_SERVER_ERROR, format!("DB Error: {}", e)));
if exists.0 {
// Entry already exists, commit transaction and return success
tx.commit().await.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Commit Error: {}", e)))?;
return Ok(StatusCode::OK); // 200 OK - idempotent operation, entry already exists
}

// Check current count
let count_row: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM project_whitelists WHERE project_id = $1",
)
.bind(project_id)
.fetch_one(&mut *tx)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("DB Error: {}", e)))?;

if count_row.0 >= MAX_WHITELIST_ENTRIES {
// Transaction will automatically rollback when dropped
return Err((
StatusCode::FORBIDDEN,
format!("Maximum whitelist entries ({}) reached for this project", MAX_WHITELIST_ENTRIES)
));
}

// Insert the new entry (normalized)
sqlx::query(
"INSERT INTO project_whitelists (project_id, allowed_url) VALUES ($1, $2)",
)
.bind(project_id)
.bind(normalized_url)
.execute(&mut *tx)
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Insert Error: {}", e)))?;

// Commit transaction
tx.commit()
.await
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Commit Error: {}", e)))?;

Ok(StatusCode::CREATED)
}

Expand All @@ -697,6 +766,15 @@ pub async fn remove_from_whitelist(
.map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Session Error: {}", e)))?;
let user = user.ok_or((StatusCode::UNAUTHORIZED, "Unauthorized".to_string()))?;

// Apply rate limiting: 20 requests per minute per user
let rate_limiter = state.get_or_create_whitelist_rate_limiter(user.id);
if rate_limiter.check().is_err() {
return Err((
StatusCode::TOO_MANY_REQUESTS,
"Rate limit exceeded. Please try again later.".to_string()
));
}

let trimmed_url = payload.allowed_url.trim();
if trimmed_url.is_empty() {
return Err((StatusCode::BAD_REQUEST, "allowed_url is required".to_string()));
Expand Down
20 changes: 20 additions & 0 deletions server/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ use bollard::Docker;
use std::sync::{Arc, Mutex, MutexGuard};
use std::collections::HashMap;
use std::time::Instant;
use dashmap::DashMap;
use governor::{Quota, RateLimiter, clock::DefaultClock, state::{InMemoryState, NotKeyed}};
use std::num::NonZeroU32;
use crate::handlers::project::WHITELIST_RATE_LIMIT_PER_MINUTE;

// New Struct to track container details AND ownership
#[derive(Clone, Debug)]
Expand All @@ -25,6 +29,8 @@ pub struct AppState {
pub github_id: String,
pub github_secret: String,
pub sessions: SessionMap,
// Rate limiter for whitelist operations: per-user tracking
pub whitelist_rate_limiters: Arc<DashMap<i64, Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>>>,
}

impl AppState {
Expand All @@ -37,4 +43,18 @@ impl AppState {
}
}
}

/// Get or create a rate limiter for a user's whitelist operations
pub fn get_or_create_whitelist_rate_limiter(&self, user_id: i64) -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
self.whitelist_rate_limiters
.entry(user_id)
.or_insert_with(|| {
let quota = Quota::per_minute(
NonZeroU32::new(WHITELIST_RATE_LIMIT_PER_MINUTE)
.expect("WHITELIST_RATE_LIMIT_PER_MINUTE must be non-zero")
);
Arc::new(RateLimiter::direct(quota))
})
.clone()
}
}