diff --git a/packages/edge/infra/guard/core/Cargo.toml b/packages/edge/infra/guard/core/Cargo.toml
index 5dc5123a1b..cf266606e0 100644
--- a/packages/edge/infra/guard/core/Cargo.toml
+++ b/packages/edge/infra/guard/core/Cargo.toml
@@ -27,7 +27,7 @@ prometheus = "0.13.3"
 rivet-config.workspace = true
 rand = "0.8.5"
 cluster.workspace = true
-scc = "2.0.7"
+moka = { version = "0.12", features = ["future"] }
 pegboard.workspace = true
 regex = "1.10.3"
 futures-util = "0.3.30"
diff --git a/packages/edge/infra/guard/core/src/proxy_service.rs b/packages/edge/infra/guard/core/src/proxy_service.rs
index d599ecc609..c503cc292c 100644
--- a/packages/edge/infra/guard/core/src/proxy_service.rs
+++ b/packages/edge/infra/guard/core/src/proxy_service.rs
@@ -8,6 +8,8 @@ use std::{
 	time::{Duration, Instant},
 };
 
+use tokio::sync::Mutex;
+
 use bytes::Bytes;
 use futures_util::{SinkExt, StreamExt};
 use global_error::*;
@@ -18,8 +20,8 @@ use hyper::{Request, Response, StatusCode};
 use hyper_tungstenite;
 use hyper_util::client::legacy::Client;
 use hyper_util::rt::TokioExecutor;
+use moka::future::Cache;
 use rand;
-use scc::HashMap as SccHashMap;
 use serde_json;
 use tokio::time::timeout;
 use tracing::Instrument;
@@ -29,6 +31,8 @@ use uuid::Uuid;
 use crate::metrics;
 
 const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
+const ROUTE_CACHE_TTL: Duration = Duration::from_secs(60 * 10); // 10 minutes
+const PROXY_STATE_CACHE_TTL: Duration = Duration::from_secs(60 * 60); // 1 hour
 
 // Routing types
 #[derive(Clone, Debug)]
@@ -154,38 +158,40 @@ pub type MiddlewareFn = Arc<
 
 // Cache for routing results
 struct RouteCache {
-	cache: SccHashMap<(String, String), RouteConfig>,
+	cache: Cache<(String, String), RouteConfig>,
 }
 
 impl RouteCache {
 	fn new() -> Self {
 		Self {
-			cache: SccHashMap::new(),
+			cache: Cache::builder()
+				.max_capacity(10_000)
+				.time_to_live(ROUTE_CACHE_TTL)
+				.build(),
 		}
 	}
 
 	#[tracing::instrument(skip_all)]
 	async fn get(&self, hostname: &str, path: &str) -> Option<RouteConfig> {
 		self.cache
-			.get_async(&(hostname.to_owned(), path.to_owned()))
+			.get(&(hostname.to_owned(), path.to_owned()))
 			.await
-			.map(|v| v.clone())
 	}
 
 	#[tracing::instrument(skip_all)]
 	async fn insert(&self, hostname: String, path: String, result: RouteConfig) {
-		self.cache.upsert_async((hostname, path), result).await;
+		self.cache.insert((hostname, path), result).await;
 
-		metrics::ROUTE_CACHE_SIZE.set(self.cache.len() as i64);
+		metrics::ROUTE_CACHE_SIZE.set(self.cache.entry_count() as i64);
 	}
 
 	#[tracing::instrument(skip_all)]
 	async fn purge(&self, hostname: &str, path: &str) {
 		self.cache
-			.remove_async(&(hostname.to_owned(), path.to_owned()))
+			.invalidate(&(hostname.to_owned(), path.to_owned()))
 			.await;
 
-		metrics::ROUTE_CACHE_SIZE.set(self.cache.len() as i64);
+		metrics::ROUTE_CACHE_SIZE.set(self.cache.entry_count() as i64);
 	}
 }
 
@@ -257,8 +263,8 @@ pub struct ProxyState {
 	routing_fn: RoutingFn,
 	middleware_fn: MiddlewareFn,
 	route_cache: RouteCache,
-	rate_limiters: SccHashMap<(Uuid, std::net::IpAddr), RateLimiter>,
-	in_flight_counters: SccHashMap<(Uuid, std::net::IpAddr), InFlightCounter>,
+	rate_limiters: Cache<(Uuid, std::net::IpAddr), Arc<Mutex<RateLimiter>>>,
+	in_flight_counters: Cache<(Uuid, std::net::IpAddr), Arc<Mutex<InFlightCounter>>>,
 	port_type: PortType,
 }
 
@@ -274,8 +280,14 @@ impl ProxyState {
 			routing_fn,
 			middleware_fn,
 			route_cache: RouteCache::new(),
-			rate_limiters: SccHashMap::new(),
-			in_flight_counters: SccHashMap::new(),
+			rate_limiters: Cache::builder()
+				.max_capacity(10_000)
+				.time_to_live(PROXY_STATE_CACHE_TTL)
+				.build(),
+			in_flight_counters: Cache::builder()
+				.max_capacity(10_000)
+				.time_to_live(PROXY_STATE_CACHE_TTL)
+				.build(),
 			port_type,
 		}
 	}
@@ -465,28 +477,33 @@ impl ProxyState {
 		let middleware_config = self.get_middleware_config(&actor_id).await?;
 
 		let cache_key = (actor_id, ip_addr);
-		let entry = self
-			.rate_limiters
-			.entry_async(cache_key)
-			.instrument(tracing::info_span!("entry_async"))
-			.await;
-		if let scc::hash_map::Entry::Occupied(mut entry) = entry {
-			// Key exists, get and mutate existing RateLimiter
-			let write_guard = entry.get_mut();
-			Ok(write_guard.try_acquire())
-		} else {
-			// Key doesn't exist, insert a new RateLimiter
-			let mut limiter = RateLimiter::new(
-				middleware_config.rate_limit.requests,
-				middleware_config.rate_limit.period,
-			);
-			let result = limiter.try_acquire();
-			entry.insert_entry(limiter);
-
-			metrics::RATE_LIMITER_COUNT.set(self.rate_limiters.len() as i64);
-
-			Ok(result)
-		}
+
+		// Atomically get the limiter for this (actor, IP) key or create it. A separate
+		// `get` + `insert` lets concurrent first requests each insert their own limiter,
+		// so acquires made on an overwritten limiter would be lost.
+		let entry = self
+			.rate_limiters
+			.entry(cache_key)
+			.or_insert_with(async {
+				Arc::new(Mutex::new(RateLimiter::new(
+					middleware_config.rate_limit.requests,
+					middleware_config.rate_limit.period,
+				)))
+			})
+			.await;
+		if entry.is_fresh() {
+			// Only a newly inserted limiter changes the cache size
+			metrics::RATE_LIMITER_COUNT.set(self.rate_limiters.entry_count() as i64);
+		}
+		let limiter_arc = entry.into_value();
+
+		// Try to acquire from the limiter
+		let result = {
+			let mut limiter = limiter_arc.lock().await;
+			limiter.try_acquire()
+		};
+
+		Ok(result)
 	}
 
 	#[tracing::instrument(skip_all)]
@@ -504,25 +521,33 @@ impl ProxyState {
 		let middleware_config = self.get_middleware_config(&actor_id).await?;
 
 		let cache_key = (actor_id, ip_addr);
-		let entry = self
-			.in_flight_counters
-			.entry_async(cache_key)
-			.instrument(tracing::info_span!("entry_async"))
-			.await;
-		if let scc::hash_map::Entry::Occupied(mut entry) = entry {
-			// Key exists, get and mutate existing InFlightCounter
-			let write_guard = entry.get_mut();
-			Ok(write_guard.try_acquire())
-		} else {
-			// Key doesn't exist, insert a new InFlightCounter
-			let mut counter = InFlightCounter::new(middleware_config.max_in_flight.amount);
-			let result = counter.try_acquire();
-			entry.insert_entry(counter);
-
-			metrics::IN_FLIGHT_COUNTER_COUNT.set(self.in_flight_counters.len() as i64);
-
-			Ok(result)
-		}
+
+		// Atomically get the counter for this (actor, IP) key or create it. A separate
+		// `get` + `insert` lets concurrent first requests each insert their own counter;
+		// an acquire made on an overwritten counter would later be released against
+		// the surviving one.
+		let entry = self
+			.in_flight_counters
+			.entry(cache_key)
+			.or_insert_with(async {
+				Arc::new(Mutex::new(InFlightCounter::new(
+					middleware_config.max_in_flight.amount,
+				)))
+			})
+			.await;
+		if entry.is_fresh() {
+			// Only a newly inserted counter changes the cache size
+			metrics::IN_FLIGHT_COUNTER_COUNT.set(self.in_flight_counters.entry_count() as i64);
+		}
+		let counter_arc = entry.into_value();
+
+		// Try to acquire from the counter
+		let result = {
+			let mut counter = counter_arc.lock().await;
+			counter.try_acquire()
+		};
+
+		Ok(result)
 	}
 
 	#[tracing::instrument(skip_all)]
@@ -534,12 +559,8 @@ impl ProxyState {
 		};
 
 		let cache_key = (actor_id, ip_addr);
-		if let Some(mut counter) = self
-			.in_flight_counters
-			.get_async(&cache_key)
-			.instrument(tracing::info_span!("get_async"))
-			.await
-		{
+		if let Some(counter_arc) = self.in_flight_counters.get(&cache_key).await {
+			let mut counter = counter_arc.lock().await;
 			counter.release();
 		}
 	}
@@ -651,16 +672,16 @@ impl ProxyService {
 				.status(StatusCode::TOO_MANY_REQUESTS)
 				.body(Full::<Bytes>::new(Bytes::new()))
 				.map_err(Into::into)
-		} else { 
+		} else {
 			// Increment metrics
 			metrics::PROXY_REQUEST_PENDING
 				.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
 				.inc();
-			
+
 			metrics::PROXY_REQUEST_TOTAL
 				.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
 				.inc();
-			
+
 			// Prepare to release in-flight counter when done
 			let state_clone = self.state.clone();
 			crate::defer! {
@@ -668,7 +689,7 @@ impl ProxyService {
 					state_clone.release_in_flight(client_ip, &actor_id).await;
 				}.instrument(tracing::info_span!("release_in_flight_task")));
 			}
-			
+
 			// Branch for WebSocket vs HTTP handling
 			// Both paths will handle their own metrics and error handling
 			if hyper_tungstenite::is_upgrade_request(&req) {
@@ -688,20 +709,11 @@ impl ProxyService {
 			// Record metrics
 			let duration = start_time.elapsed();
 			metrics::PROXY_REQUEST_DURATION
-				.with_label_values(&[
-					&actor_id_str,
-					&server_id_str,
-					&status,
-				])
+				.with_label_values(&[&actor_id_str, &server_id_str, &status])
 				.observe(duration.as_secs_f64());
 
 			metrics::PROXY_REQUEST_PENDING
-				.with_label_values(&[
-					&actor_id_str,
-					&server_id_str,
-					method_str,
-					&path,
-				])
+				.with_label_values(&[&actor_id_str, &server_id_str, method_str, &path])
 				.dec();
 
 			res
@@ -1614,8 +1626,6 @@ impl ProxyService {
 			"Request received"
 		);
 
-		let start_time = Instant::now();
-
 		// Process the request
 		let result = self.handle_request(req).await;
 