From ce254d03f4cce47ad048f9b3ea01ae9dde3a94f4 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Sat, 3 Jan 2026 19:02:55 +0100 Subject: [PATCH 01/11] rate limiting config prototype --- docs/configuration.md | 7 +++++++ docs/rules.md | 13 +++++++++++++ pingoo/config/config.rs | 1 + pingoo/config/config_file.rs | 1 + pingoo/rules.rs | 1 + rules/rules.rs | 9 +++++++++ 6 files changed, 32 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index 07602cc..1039081 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -68,6 +68,13 @@ rules: !http_request.user_agent.starts_with("Mozilla/") && !http_request.user_agent.contains("curl/") actions: - action: captcha + rate_limit_api_routes: + expression: http_request.path.starts_with("/api/") + actions: + - action: block + limit: + max: 10 + window: 60 # (optional) Lists can be used in rule expressions to match against a large number of values lists: diff --git a/docs/rules.md b/docs/rules.md index 0e521df..724abf5 100644 --- a/docs/rules.md +++ b/docs/rules.md @@ -118,3 +118,16 @@ Valid lists types: - `String` - `Ip` +## Rate limiting + +**pingoo.yml** +```yml +rules: + rate_limit_api_routes: + expression: http_request.path.starts_with("/api/") + actions: + - action: block + limit: + max: 10 + window: 60 +``` diff --git a/pingoo/config/config.rs b/pingoo/config/config.rs index 8c1bdc2..963fe3a 100644 --- a/pingoo/config/config.rs +++ b/pingoo/config/config.rs @@ -263,6 +263,7 @@ pub async fn load_and_validate() -> Result { .map(|expression| rules::compile_expression(&expression)) .map_or(Ok(None), |r| r.map(Some))?, actions: rule_config.actions, + limit: rule_config.limit, }) }) .collect::>() diff --git a/pingoo/config/config_file.rs b/pingoo/config/config_file.rs index 20e733d..2ab785e 100644 --- a/pingoo/config/config_file.rs +++ b/pingoo/config/config_file.rs @@ -98,6 +98,7 @@ pub struct ServiceConfigFileStaticNotFound { pub struct RuleConfigFile { pub expression: Option, pub actions: Vec, + pub limit: Option, } impl Default for ServiceConfigFileStaticNotFound { diff --git a/pingoo/rules.rs b/pingoo/rules.rs index 952c3ee..4a5ab77 100644 --- a/pingoo/rules.rs +++ b/pingoo/rules.rs @@ -11,6 +11,7 @@ pub struct Rule { pub name: String, pub expression: Option, pub actions: Vec, + pub limit: Option, } #[derive(Debug, Serialize)] diff --git a/rules/rules.rs b/rules/rules.rs index 3577b1b..c064a98 100644 --- a/rules/rules.rs +++ b/rules/rules.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -22,6 +24,13 @@ pub struct Rule { pub type CompiledExpression = bel::Program; pub type Context<'a> = bel::Context<'a>; +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "limit", rename_all = "snake_case")] +pub struct RateLimit { + pub max: u16, + pub window: Duration, +} + // pub struct CompiledRule { // pub id: Uuid, // pub updated_at: DateTime, From e1241267302247ad178ed046f2dc8a546108724c Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Mon, 5 Jan 2026 03:13:14 +0100 Subject: [PATCH 02/11] rate limiting sliding window --- Cargo.toml | 3 + Makefile | 4 + pingoo/main.rs | 1 + pingoo/rate_limiter.rs | 291 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 299 insertions(+) create mode 100644 pingoo/rate_limiter.rs diff --git a/Cargo.toml b/Cargo.toml index 5519415..de9acc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,9 @@ zeroize = { workspace = true, features = ["simd", "derive"] } zstd = { workspace = true } +[features] +test-utils = ["tokio/test-util"] + [workspace] resolver = "2" diff --git a/Makefile b/Makefile index a9221e5..ddbd542 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,10 @@ fmt: check: cargo check +.PHONY: test +test: + cargo test --features test-utils + .PHONY: clean clean: rm -rf $(DIST_DIR) captcha/dist diff --git a/pingoo/main.rs b/pingoo/main.rs index b48be8c..5bbb9f6 100644 --- a/pingoo/main.rs +++ b/pingoo/main.rs @@ -13,6 +13,7 @@ mod error; mod geoip; mod listeners; mod lists; +mod rate_limiter; mod rules; mod serde_utils; mod service_discovery; diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs new file mode 100644 index 0000000..96d6e27 --- /dev/null +++ b/pingoo/rate_limiter.rs @@ -0,0 +1,291 @@ +use std::collections::HashMap; + +use tokio::time::Duration; +use tokio::time::Instant; + +pub struct RateLimiter { + limit: u16, + window: Duration, + state: HashMap, +} + +#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] +pub enum IpAddrBits { + V4([u8; 4]), // as in Ipv4Addr::octets() + V6([u8; 16]), // as in Ipv6Addr::octets() +} + +pub struct SlidingWindow { + limit: u16, + window: Duration, + previous_sampler: InMemorySampler, + current_sampler: InMemorySampler, +} + +trait Sampler { + fn new(window: Duration) -> Self; + fn increment(&mut self, limit: u16) -> Option<()>; + fn get_count(&self) -> u16; + fn get_created_at(&self) -> Instant; + fn get_approx(&self, next_window_duration: Duration) -> u64; +} + +#[derive(Debug, Copy, Clone)] +struct InMemorySampler { + window: Duration, + count: u16, + created_at: Instant, +} + +impl RateLimiter { + pub fn new(limit: u16, window: Duration) -> Self { + RateLimiter { + limit, + window, + state: HashMap::new(), + } + } + + pub fn can_resume(&mut self, ip: IpAddrBits) -> bool { + let mut new_ip_state = SlidingWindow::new(self.limit, self.window); + let mut result: bool = new_ip_state.can_resume(); + + self.state + .entry(ip) + .and_modify(|x| result = x.can_resume()) + .or_insert(new_ip_state); + result + } + + pub fn garbage_collect(&mut self) { + self.state + .retain(|_, v| v.get_last_sample_created_at().elapsed() < 2 * self.window); + } + + pub fn len(&self) -> usize { + self.state.len() + } +} + +impl SlidingWindow { + pub fn new(limit: u16, window: Duration) -> Self { + let mut sanitized_limit = limit; + if limit == u16::MAX { + sanitized_limit = limit - 1; + } + + SlidingWindow { + limit: sanitized_limit, + window, + previous_sampler: InMemorySampler::new(window), + current_sampler: InMemorySampler::new(window), + } + } + + pub fn can_resume(&mut self) -> bool { + if self.limit == 0 { + return false; + } + + if self.current_sampler.increment(self.limit).is_none() { + self.shuffle_samples(); + self.current_sampler.increment(self.limit); + } + + let elapsed = self.current_sampler.get_created_at() + self.current_sampler.get_created_at().elapsed() + - (self.previous_sampler.get_created_at() + self.window); + let approx = self.previous_sampler.get_approx(elapsed); + let current_count = self.current_sampler.get_count(); + u64::from(self.limit) >= approx + u64::from(current_count) + } + + pub fn get_last_sample_created_at(&self) -> Instant { + self.current_sampler.created_at + } + + fn shuffle_samples(&mut self) { + self.previous_sampler = self.current_sampler; + self.current_sampler = InMemorySampler::new(self.window); + } +} + +impl InMemorySampler { + fn is_expired(&self) -> bool { + self.created_at.elapsed().as_millis() > self.window.as_millis() + } +} + +impl Sampler for InMemorySampler { + fn new(window: Duration) -> Self { + InMemorySampler { + window, + count: 0, + created_at: Instant::now(), + } + } + + fn increment(&mut self, limit: u16) -> Option<()> { + if self.is_expired() { + return None; + } + + if limit >= self.count { + self.count += 1; + } + Some(()) + } + + fn get_count(&self) -> u16 { + self.count + } + + fn get_created_at(&self) -> Instant { + self.created_at + } + + fn get_approx(&self, next_window_duration: Duration) -> u64 { + if self.window > next_window_duration { + return u64::from(self.count) * (self.window.as_secs() - next_window_duration.as_secs()) + / self.window.as_secs(); + } + + 0 + } +} + +#[cfg(feature = "test-utils")] +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use tokio::time::sleep; + + use super::*; + + #[tokio::test(start_paused = true)] + async fn test_sliding_window_alg() { + // Tests example form https://blog.cloudflare.com/counting-things-a-lot-of-different-things/ + // + // "Let's say I set a limit of 50 requests per minute on an API endpoint. + // In this situation, I did 18 requests during the current minute, which started 15 seconds ago + // and 42 requests during the entire previous minute." + // + // rate = 42 * ((60-15)/60) + 18 + // = 42 * 0.75 + 18 + // = 49.5 requests + let mut r = RateLimiter::new(50, Duration::new(60, 0)); + let ip = IpAddrBits::V4(Ipv4Addr::new(1, 1, 1, 1).octets()); + for _ in 0..42 { + assert!(r.can_resume(ip), "should resume until limit is not reached") + } + + sleep(Duration::from_secs(60 + 15)).await; + for _ in 0..19 { + assert!(r.can_resume(ip), "should resume for 42 * ((60-15)/60) + 19 = 50"); + } + + assert!(!r.can_resume(ip), "should break for 42 * ((60-15)/60) + 20 = 51"); + + sleep(Duration::from_secs(3)).await; + assert!(r.can_resume(ip), "should resume for 42 * ((60-(15+3))/60) + 21 = 50"); + } + + #[tokio::test(start_paused = true)] + async fn test_rate_limiter_gc() { + let mut limiter = RateLimiter::new(10, Duration::new(60, 0)); + let ips = [ + IpAddrBits::V4(Ipv4Addr::new(1, 1, 1, 1).octets()), + IpAddrBits::V4(Ipv4Addr::new(2, 2, 2, 2).octets()), + ]; + assert!(limiter.can_resume(ips[0])); + assert!(limiter.can_resume(ips[1])); + + sleep(Duration::from_secs(61)).await; + assert!(limiter.can_resume(ips[0])); + limiter.garbage_collect(); + assert_eq!(2, limiter.len()); + + sleep(Duration::from_secs(60)).await; + limiter.garbage_collect(); + assert_eq!( + 1, + limiter.len(), + "should garbage collect entries that were not updated for 2 * window" + ); + } + + #[tokio::test(start_paused = true)] + async fn test_sliding_window_boundary() { + let ip = IpAddrBits::V4(Ipv4Addr::new(1, 1, 1, 1).octets()); + let mut r = RateLimiter::new(1, Duration::new(1, 0)); + + assert!(r.can_resume(ip), "should allow once when limit is 1"); + assert!(!r.can_resume(ip), "should block on 2nd attempt when limit is 1"); + + let mut r = RateLimiter::new(0, Duration::new(1, 0)); + assert!(!r.can_resume(ip), "should treat zero limit as always limited"); + assert!(!r.can_resume(ip), "should treat zero limit as always limited"); + + let mut r = RateLimiter::new(0, Duration::new(0, 0)); + assert!( + !r.can_resume(ip), + "should treat zero limit as always limited, even when zero window" + ); + assert!( + !r.can_resume(ip), + "should treat zero limit as always limited, even when zero window" + ); + + let mut r = RateLimiter::new(1, Duration::new(1, 0)); + assert!(r.can_resume(ip), "allow - limit should take precedense over zero window"); + assert!(!r.can_resume(ip), "block - limit should take precedense over zero window"); + + let mut r = RateLimiter::new(u16::MAX, Duration::new(1, 0)); + for _ in 1..u16::MAX { + assert!(r.can_resume(ip), "allow - should handle limit overflow"); + } + assert!(!r.can_resume(ip), "block - should handle limit overflow"); + } + + #[tokio::test(start_paused = true)] + async fn test_inmemory_is_expired() { + let mut r = InMemorySampler::new(Duration::new(60, 0)); + let limit = 50; + assert!(r.increment(limit).is_some(), "should return Some when not expired"); + + sleep(Duration::from_secs(60)).await; + assert!(r.increment(limit).is_some(), "should return Some when still not expired"); + + sleep(Duration::from_secs(1)).await; + assert!(r.increment(limit).is_none(), "should return None when expired"); + } + + #[tokio::test] + async fn test_inmemory_get_count() { + let mut r = InMemorySampler::new(Duration::new(1, 0)); + let limit = 1; + assert_eq!(0, r.get_count()); + r.increment(limit).unwrap(); + assert_eq!(1, r.get_count()); + + for _ in 1..5 { + r.increment(limit).unwrap(); + assert_eq!(2, r.get_count(), "counter should not increase after crossing the limit"); + } + } + + #[tokio::test(start_paused = true)] + async fn test_inmemory_get_approx() { + let mut r = InMemorySampler::new(Duration::new(60, 0)); + sleep(Duration::from_secs(60)).await; + for _ in 0..42 { + r.increment(50); + } + + let start = Instant::now(); + assert_eq!(42, r.get_count()); + assert_eq!(42, r.get_approx(start.elapsed())); + sleep(Duration::from_secs(15)).await; + assert_eq!(31, r.get_approx(start.elapsed())); + } +} From c1afec76c622ed0f78e70de1bd98fa3396fbe7b2 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Mon, 5 Jan 2026 21:17:00 +0100 Subject: [PATCH 03/11] integrate rate limiter into http listener --- docs/configuration.md | 2 +- docs/rules.md | 2 +- pingoo/config/config.rs | 20 +++++++++- pingoo/listeners/http_listener.rs | 27 ++++++++++++- pingoo/rate_limiter.rs | 63 ++++++++++++++++++++----------- pingoo/rules.rs | 5 ++- pingoo/services/http_utils.rs | 12 ++++++ rules/rules.rs | 3 +- 8 files changed, 104 insertions(+), 30 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 1039081..564d69f 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -71,7 +71,7 @@ rules: rate_limit_api_routes: expression: http_request.path.starts_with("/api/") actions: - - action: block + - action: limit limit: max: 10 window: 60 diff --git a/docs/rules.md b/docs/rules.md index 724abf5..f24e1a3 100644 --- a/docs/rules.md +++ b/docs/rules.md @@ -126,7 +126,7 @@ rules: rate_limit_api_routes: expression: http_request.path.starts_with("/api/") actions: - - action: block + - action: limit limit: max: 10 window: 60 diff --git a/pingoo/config/config.rs b/pingoo/config/config.rs index 963fe3a..a6f7a48 100644 --- a/pingoo/config/config.rs +++ b/pingoo/config/config.rs @@ -9,13 +9,18 @@ use std::{ use http::StatusCode; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use tokio::fs; +use tokio::{ + fs, + sync::mpsc::{self, Sender}, + task::JoinHandle, +}; use tracing::{debug, warn}; use crate::{ Error, config::config_file::{ConfigFile, RuleConfigFile, parse_service}, lists::ListType, + rate_limiter::{Probe, get_rate_limit_manager}, rules::Rule, service_discovery::service_registry::Upstream, tls::acme::LETSENCRYPT_PRODUCTION_URL, @@ -46,6 +51,7 @@ pub struct Config { pub service_discovery: ServiceDiscoveryConfig, pub lists: HashMap, pub child_process: Option, + pub limiter_workers: Vec>, } #[derive(Clone, Copy, Debug, Deserialize, Serialize, Eq, PartialEq)] @@ -252,10 +258,19 @@ pub async fn load_and_validate() -> Result { .collect(); validate_listeners_config(&listeners, &services)?; + let mut limiter_workers: Vec> = vec![]; let rules: Vec = config_file .rules .into_iter() .map(|(rule_name, rule_config)| { + let mut limiter_tx: Option> = None; + if let Some(limiter_cfg) = rule_config.limit { + let buffer = 1024; // todo make configurable + let (tx, rx) = mpsc::channel(buffer); + limiter_workers.push(get_rate_limit_manager(rx, limiter_cfg)); + limiter_tx = Some(tx); + } + Ok(Rule { name: rule_name, expression: rule_config @@ -263,7 +278,7 @@ pub async fn load_and_validate() -> Result { .map(|expression| rules::compile_expression(&expression)) .map_or(Ok(None), |r| r.map(Some))?, actions: rule_config.actions, - limit: rule_config.limit, + limiter_tx, }) }) .collect::>() @@ -318,6 +333,7 @@ pub async fn load_and_validate() -> Result { service_discovery: config_file.service_discovery.unwrap_or_default(), lists, child_process: config_file.child_process, + limiter_workers, }; return Ok(config); diff --git a/pingoo/listeners/http_listener.rs b/pingoo/listeners/http_listener.rs index 703df50..96c98b9 100644 --- a/pingoo/listeners/http_listener.rs +++ b/pingoo/listeners/http_listener.rs @@ -17,12 +17,14 @@ use crate::{ config::ListenerConfig, geoip::{self, GeoipDB, GeoipRecord}, listeners::{GRACEFUL_SHUTDOWN_TIMEOUT, Listener, accept_tcp_connection, bind_tcp_socket}, + rate_limiter::get_probe, rules, services::{ HttpService, http_utils::{ HOSTNAME_MAX_LENGTH, RequestContext, RequestExtensionContext, USER_AGENT_MAX_LENGTH, get_path, - new_blocked_response, new_not_found_error, + new_blocked_response, new_internal_error_response_500, new_not_found_error, + new_too_many_requests_response_429, }, }, }; @@ -258,6 +260,29 @@ pub(super) async fn serve_http_requests { + // todo: if "action: limit" then this must be defined - not Option + if let Some(tx) = rule.limiter_tx.clone() { + let (probe, rx) = get_probe(client_data.ip); + if let Err(err) = tx.send(probe).await { + error!("couldn't send request probe to rate limiter: {err}"); + return Ok(new_internal_error_response_500()); + } + + let result = rx.await; + let can_resume; + if let Err(err) = result { + error!("couldn't receive rate limiter result: {err}"); + return Ok(new_internal_error_response_500()); + } else { + can_resume = result.expect("rate limiter result should be received"); + } + + if !can_resume { + return Ok(new_too_many_requests_response_429()); + } + } + } } } } diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs index 96d6e27..31850af 100644 --- a/pingoo/rate_limiter.rs +++ b/pingoo/rate_limiter.rs @@ -1,21 +1,41 @@ use std::collections::HashMap; +use std::net::IpAddr; +use rules::RateLimit; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; use tokio::time::Duration; use tokio::time::Instant; -pub struct RateLimiter { - limit: u16, - window: Duration, - state: HashMap, +pub fn get_rate_limit_handle(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + tokio::spawn(async move { + let mut limiter = RateLimiter::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) +} + +pub fn get_probe(ip: IpAddr) -> (Probe, oneshot::Receiver) { + let (tx, rx) = oneshot::channel(); + (Probe { ip, resp: tx }, rx) } -#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone)] -pub enum IpAddrBits { - V4([u8; 4]), // as in Ipv4Addr::octets() - V6([u8; 16]), // as in Ipv6Addr::octets() +type Responder = oneshot::Sender; +pub struct Probe { + ip: IpAddr, + resp: Responder, } -pub struct SlidingWindow { +struct RateLimiter { + limit: u16, + window: Duration, + state: HashMap, +} + +struct SlidingWindow { limit: u16, window: Duration, previous_sampler: InMemorySampler, @@ -46,14 +66,16 @@ impl RateLimiter { } } - pub fn can_resume(&mut self, ip: IpAddrBits) -> bool { - let mut new_ip_state = SlidingWindow::new(self.limit, self.window); - let mut result: bool = new_ip_state.can_resume(); - + pub fn can_resume(&mut self, ip: IpAddr) -> bool { + let mut result = false; self.state .entry(ip) .and_modify(|x| result = x.can_resume()) - .or_insert(new_ip_state); + .or_insert_with(|| { + let mut new_ip_state = SlidingWindow::new(self.limit, self.window); + result = new_ip_state.can_resume(); + new_ip_state + }); result } @@ -163,7 +185,7 @@ mod tests { use super::*; #[tokio::test(start_paused = true)] - async fn test_sliding_window_alg() { + async fn test_rate_limiter_alg() { // Tests example form https://blog.cloudflare.com/counting-things-a-lot-of-different-things/ // // "Let's say I set a limit of 50 requests per minute on an API endpoint. @@ -174,7 +196,7 @@ mod tests { // = 42 * 0.75 + 18 // = 49.5 requests let mut r = RateLimiter::new(50, Duration::new(60, 0)); - let ip = IpAddrBits::V4(Ipv4Addr::new(1, 1, 1, 1).octets()); + let ip = Ipv4Addr::new(1, 1, 1, 1).into(); for _ in 0..42 { assert!(r.can_resume(ip), "should resume until limit is not reached") } @@ -193,10 +215,7 @@ mod tests { #[tokio::test(start_paused = true)] async fn test_rate_limiter_gc() { let mut limiter = RateLimiter::new(10, Duration::new(60, 0)); - let ips = [ - IpAddrBits::V4(Ipv4Addr::new(1, 1, 1, 1).octets()), - IpAddrBits::V4(Ipv4Addr::new(2, 2, 2, 2).octets()), - ]; + let ips = [Ipv4Addr::new(1, 1, 1, 1).into(), Ipv4Addr::new(2, 2, 2, 2).into()]; assert!(limiter.can_resume(ips[0])); assert!(limiter.can_resume(ips[1])); @@ -215,8 +234,8 @@ mod tests { } #[tokio::test(start_paused = true)] - async fn test_sliding_window_boundary() { - let ip = IpAddrBits::V4(Ipv4Addr::new(1, 1, 1, 1).octets()); + async fn test_rate_limiter_boundary() { + let ip = Ipv4Addr::new(1, 1, 1, 1).into(); let mut r = RateLimiter::new(1, Duration::new(1, 0)); assert!(r.can_resume(ip), "should allow once when limit is 1"); diff --git a/pingoo/rules.rs b/pingoo/rules.rs index 4a5ab77..11299bb 100644 --- a/pingoo/rules.rs +++ b/pingoo/rules.rs @@ -2,16 +2,17 @@ use std::net::IpAddr; use http::Uri; use serde::Serialize; +use tokio::sync::mpsc::Sender; use tracing::warn; -use crate::{geoip::CountryCode, serde_utils}; +use crate::{geoip::CountryCode, rate_limiter::Probe, serde_utils}; #[derive(Debug, Clone)] pub struct Rule { pub name: String, pub expression: Option, pub actions: Vec, - pub limit: Option, + pub limiter_tx: Option>, } #[derive(Debug, Serialize)] diff --git a/pingoo/services/http_utils.rs b/pingoo/services/http_utils.rs index 733b6ec..b785f57 100644 --- a/pingoo/services/http_utils.rs +++ b/pingoo/services/http_utils.rs @@ -111,6 +111,18 @@ pub fn new_method_not_allowed_error() -> Response> .expect("error building new_method_not_allowed_error"); } +pub fn new_too_many_requests_response_429() -> Response> { + const ERROR_MESSAGE: &[u8] = b"429 Too Many Requests"; + let res_body = Full::new(Bytes::from_static(ERROR_MESSAGE)) + .map_err(|never| match never {}) + .boxed(); + return Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .header(header::CACHE_CONTROL, &CACHE_CONTROL_NO_CACHE) + .body(res_body) + .expect("error building new_method_not_allowed_error"); +} + pub fn get_path(req: &Request) -> &str { req.uri().path().trim_end_matches('/') } diff --git a/rules/rules.rs b/rules/rules.rs index c064a98..bcf32a5 100644 --- a/rules/rules.rs +++ b/rules/rules.rs @@ -28,7 +28,7 @@ pub type Context<'a> = bel::Context<'a>; #[serde(tag = "limit", rename_all = "snake_case")] pub struct RateLimit { pub max: u16, - pub window: Duration, + pub window: u16, } // pub struct CompiledRule { @@ -41,6 +41,7 @@ pub struct RateLimit { pub enum Action { Block {}, Captcha {}, + Limit {}, } #[derive(Debug, thiserror::Error)] From 041cfb148c6949ed3ffcf10e7a5c23dbf3399325 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Tue, 6 Jan 2026 12:31:49 +0100 Subject: [PATCH 04/11] naive garbage collection --- pingoo/rate_limiter.rs | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs index 31850af..1caaf3d 100644 --- a/pingoo/rate_limiter.rs +++ b/pingoo/rate_limiter.rs @@ -14,6 +14,8 @@ pub fn get_rate_limit_handle(mut rx: mpsc::Receiver, limiter_cfg: RateLim while let Some(probe) = rx.recv().await { let result = limiter.can_resume(probe.ip); let _ = probe.resp.send(result); + + limiter.garbage_collect(); } }) } @@ -80,8 +82,24 @@ impl RateLimiter { } pub fn garbage_collect(&mut self) { - self.state - .retain(|_, v| v.get_last_sample_created_at().elapsed() < 2 * self.window); + // inspired by https://blog.nginx.org/blog/rate-limiting-nginx + // + // "Additionally, to prevent memory from being exhausted, every time NGINX creates a new + // entry it removes up to two entries that have not been used in the previous 60 + // seconds." + const ITEMS: usize = 2; + + let garbage: heapless::Vec = self + .state + .iter() + .filter(|(_, v)| v.get_last_sample_created_at().elapsed() > 2 * self.window) + .take(ITEMS) + .map(|(k, _)| k.clone()) + .collect(); + + for ip in garbage { + let _ = self.state.remove(&ip); + } } pub fn len(&self) -> usize { @@ -215,21 +233,29 @@ mod tests { #[tokio::test(start_paused = true)] async fn test_rate_limiter_gc() { let mut limiter = RateLimiter::new(10, Duration::new(60, 0)); - let ips = [Ipv4Addr::new(1, 1, 1, 1).into(), Ipv4Addr::new(2, 2, 2, 2).into()]; - assert!(limiter.can_resume(ips[0])); - assert!(limiter.can_resume(ips[1])); + let ips = [ + Ipv4Addr::new(1, 1, 1, 1).into(), + Ipv4Addr::new(2, 2, 2, 2).into(), + Ipv4Addr::new(3, 3, 3, 3).into(), + Ipv4Addr::new(4, 4, 4, 4).into(), + Ipv4Addr::new(5, 5, 5, 5).into(), + ]; + for ip in ips { + assert!(limiter.can_resume(ip)); + } sleep(Duration::from_secs(61)).await; assert!(limiter.can_resume(ips[0])); limiter.garbage_collect(); - assert_eq!(2, limiter.len()); + assert_eq!(ips.len(), limiter.len()); sleep(Duration::from_secs(60)).await; + assert!(limiter.can_resume(ips[0])); limiter.garbage_collect(); assert_eq!( - 1, + ips.len() - 2, limiter.len(), - "should garbage collect entries that were not updated for 2 * window" + "should garbage collect entries that were not updated for 2 * window, but no more than 2" ); } From ceae32cc2b5325f8fda2b3fa64924c52fa8f7b82 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Wed, 7 Jan 2026 20:03:09 +0100 Subject: [PATCH 05/11] reduce memory footprint of SlidingWindow --- pingoo/config/config.rs | 4 +- pingoo/rate_limiter.rs | 212 +++++++++++++++++++--------------------- 2 files changed, 103 insertions(+), 113 deletions(-) diff --git a/pingoo/config/config.rs b/pingoo/config/config.rs index a6f7a48..fe6c184 100644 --- a/pingoo/config/config.rs +++ b/pingoo/config/config.rs @@ -20,7 +20,7 @@ use crate::{ Error, config::config_file::{ConfigFile, RuleConfigFile, parse_service}, lists::ListType, - rate_limiter::{Probe, get_rate_limit_manager}, + rate_limiter::{Probe, get_rate_limit_handle}, rules::Rule, service_discovery::service_registry::Upstream, tls::acme::LETSENCRYPT_PRODUCTION_URL, @@ -267,7 +267,7 @@ pub async fn load_and_validate() -> Result { if let Some(limiter_cfg) = rule_config.limit { let buffer = 1024; // todo make configurable let (tx, rx) = mpsc::channel(buffer); - limiter_workers.push(get_rate_limit_manager(rx, limiter_cfg)); + limiter_workers.push(get_rate_limit_handle(rx, limiter_cfg)); limiter_tx = Some(tx); } diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs index 1caaf3d..d9bf277 100644 --- a/pingoo/rate_limiter.rs +++ b/pingoo/rate_limiter.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::net::IpAddr; +use std::sync::Arc; use rules::RateLimit; use tokio::sync::mpsc; @@ -33,66 +34,74 @@ pub struct Probe { struct RateLimiter { limit: u16, - window: Duration, - state: HashMap, + sampling_period: Duration, + current_window: Instant, + state: HashMap, // todo: from heapless crate } struct SlidingWindow { - limit: u16, - window: Duration, - previous_sampler: InMemorySampler, - current_sampler: InMemorySampler, + sampler_green: InMemorySampler, + sampler_blue: InMemorySampler, + curr_sampler: Arc, + prev_sampler: Arc, } trait Sampler { - fn new(window: Duration) -> Self; - fn increment(&mut self, limit: u16) -> Option<()>; + fn new(starts_at: Instant) -> Self; + fn increment(&mut self); + fn reset(&mut self, starts_at: Instant); fn get_count(&self) -> u16; - fn get_created_at(&self) -> Instant; - fn get_approx(&self, next_window_duration: Duration) -> u64; + fn get_starts_at(&self) -> Instant; + fn get_approx(&self, sampling_period: Duration, next_window_needle: Duration) -> u64; } #[derive(Debug, Copy, Clone)] struct InMemorySampler { - window: Duration, count: u16, - created_at: Instant, + starts_at: Instant, } impl RateLimiter { - pub fn new(limit: u16, window: Duration) -> Self { + pub fn new(limit: u16, sampling_period: Duration) -> Self { + let current_window = Instant::now(); + let mut sanitized_limit = limit; + if limit == u16::MAX { + sanitized_limit = limit - 1; + } + RateLimiter { - limit, - window, + limit: sanitized_limit, + sampling_period, + current_window, state: HashMap::new(), } } pub fn can_resume(&mut self, ip: IpAddr) -> bool { + let now = Instant::now(); + if now >= self.current_window + self.sampling_period { + self.current_window = now; + } + let mut result = false; self.state .entry(ip) - .and_modify(|x| result = x.can_resume()) + .and_modify(|x| result = x.can_resume(self.limit, self.current_window, self.sampling_period)) .or_insert_with(|| { - let mut new_ip_state = SlidingWindow::new(self.limit, self.window); - result = new_ip_state.can_resume(); + let mut new_ip_state = SlidingWindow::new(self.sampling_period, self.current_window); + result = new_ip_state.can_resume(self.limit, self.current_window, self.sampling_period); new_ip_state }); result } pub fn garbage_collect(&mut self) { - // inspired by https://blog.nginx.org/blog/rate-limiting-nginx - // - // "Additionally, to prevent memory from being exhausted, every time NGINX creates a new - // entry it removes up to two entries that have not been used in the previous 60 - // seconds." const ITEMS: usize = 2; let garbage: heapless::Vec = self .state .iter() - .filter(|(_, v)| v.get_last_sample_created_at().elapsed() > 2 * self.window) + .filter(|(_, v)| v.get_last_sample_created_at().elapsed() > 2 * self.sampling_period) .take(ITEMS) .map(|(k, _)| k.clone()) .collect(); @@ -108,88 +117,83 @@ impl RateLimiter { } impl SlidingWindow { - pub fn new(limit: u16, window: Duration) -> Self { - let mut sanitized_limit = limit; - if limit == u16::MAX { - sanitized_limit = limit - 1; - } + pub fn new(sampling_period: Duration, current_window: Instant) -> Self { + let prev_window = current_window - sampling_period; + let curr = InMemorySampler::new(current_window); + let prev = InMemorySampler::new(prev_window); SlidingWindow { - limit: sanitized_limit, - window, - previous_sampler: InMemorySampler::new(window), - current_sampler: InMemorySampler::new(window), + sampler_green: prev, + sampler_blue: curr, + curr_sampler: Arc::new(curr), + prev_sampler: Arc::new(prev), } } - pub fn can_resume(&mut self) -> bool { - if self.limit == 0 { + pub fn can_resume(&mut self, limit: u16, current_window: Instant, sampling_period: Duration) -> bool { + if limit == 0 { return false; } - if self.current_sampler.increment(self.limit).is_none() { - self.shuffle_samples(); - self.current_sampler.increment(self.limit); + if current_window != self.curr_sampler.get_starts_at() { + self.shuffle_samplers(current_window, sampling_period); } - let elapsed = self.current_sampler.get_created_at() + self.current_sampler.get_created_at().elapsed() - - (self.previous_sampler.get_created_at() + self.window); - let approx = self.previous_sampler.get_approx(elapsed); - let current_count = self.current_sampler.get_count(); - u64::from(self.limit) >= approx + u64::from(current_count) + Arc::make_mut(&mut self.curr_sampler).increment(); + + let approx = self + .prev_sampler + .get_approx(sampling_period, self.prev_sampler.get_starts_at().elapsed() - sampling_period); + let current_count = self.curr_sampler.get_count(); + + u64::from(limit) >= approx + u64::from(current_count) } pub fn get_last_sample_created_at(&self) -> Instant { - self.current_sampler.created_at + self.curr_sampler.starts_at } - fn shuffle_samples(&mut self) { - self.previous_sampler = self.current_sampler; - self.current_sampler = InMemorySampler::new(self.window); - } -} + fn shuffle_samplers(&mut self, current_window: Instant, sampling_period: Duration) { + let mut next_sampler = self.prev_sampler.clone(); + Arc::make_mut(&mut next_sampler).reset(current_window); -impl InMemorySampler { - fn is_expired(&self) -> bool { - self.created_at.elapsed().as_millis() > self.window.as_millis() + if current_window.elapsed() > sampling_period + self.curr_sampler.get_starts_at().elapsed() { + Arc::make_mut(&mut self.curr_sampler).reset(current_window - sampling_period); + } + + self.prev_sampler = self.curr_sampler.clone(); + self.curr_sampler = next_sampler; } } impl Sampler for InMemorySampler { - fn new(window: Duration) -> Self { - InMemorySampler { - window, - count: 0, - created_at: Instant::now(), - } + fn new(starts_at: Instant) -> Self { + InMemorySampler { count: 0, starts_at } } - fn increment(&mut self, limit: u16) -> Option<()> { - if self.is_expired() { - return None; - } + fn increment(&mut self) { + self.count = self.count.saturating_add(1); + } - if limit >= self.count { - self.count += 1; - } - Some(()) + fn reset(&mut self, starts_at: Instant) { + self.count = 0; + self.starts_at = starts_at; } fn get_count(&self) -> u16 { self.count } - fn get_created_at(&self) -> Instant { - self.created_at + fn get_starts_at(&self) -> Instant { + self.starts_at } - fn get_approx(&self, next_window_duration: Duration) -> u64 { - if self.window > next_window_duration { - return u64::from(self.count) * (self.window.as_secs() - next_window_duration.as_secs()) - / self.window.as_secs(); + fn get_approx(&self, sampling_period: Duration, next_window_needle: Duration) -> u64 { + if next_window_needle >= sampling_period { + return 0; } - 0 + u64::from(self.count) * (sampling_period.as_secs() - next_window_needle.as_secs()) / sampling_period.as_secs() } } @@ -213,13 +217,24 @@ mod tests { // rate = 42 * ((60-15)/60) + 18 // = 42 * 0.75 + 18 // = 49.5 requests - let mut r = RateLimiter::new(50, Duration::new(60, 0)); + let limit = 50; + let sampling_period = Duration::from_secs(60); + let mut r = RateLimiter::new(limit, sampling_period); let ip = Ipv4Addr::new(1, 1, 1, 1).into(); + + for _ in 0..limit { + assert!(r.can_resume(ip), "should allow until limit is not reached"); + } + for _ in 0..u16::MAX { + assert!(!r.can_resume(ip), "should break when limit reached"); + } + sleep(2 * sampling_period).await; + for _ in 0..42 { - assert!(r.can_resume(ip), "should resume until limit is not reached") + assert!(r.can_resume(ip), "should resume until limit is not reached"); } - sleep(Duration::from_secs(60 + 15)).await; + sleep(sampling_period + Duration::from_secs(15)).await; for _ in 0..19 { assert!(r.can_resume(ip), "should resume for 42 * ((60-15)/60) + 19 = 50"); } @@ -232,7 +247,8 @@ mod tests { #[tokio::test(start_paused = true)] async fn test_rate_limiter_gc() { - let mut limiter = RateLimiter::new(10, Duration::new(60, 0)); + let sampling_period = Duration::from_secs(60); + let mut limiter = RateLimiter::new(10, sampling_period); let ips = [ Ipv4Addr::new(1, 1, 1, 1).into(), Ipv4Addr::new(2, 2, 2, 2).into(), @@ -244,12 +260,12 @@ mod tests { assert!(limiter.can_resume(ip)); } - sleep(Duration::from_secs(61)).await; + sleep(sampling_period + Duration::from_secs(1)).await; assert!(limiter.can_resume(ips[0])); limiter.garbage_collect(); assert_eq!(ips.len(), limiter.len()); - sleep(Duration::from_secs(60)).await; + sleep(sampling_period).await; assert!(limiter.can_resume(ips[0])); limiter.garbage_collect(); assert_eq!( @@ -292,45 +308,19 @@ mod tests { assert!(!r.can_resume(ip), "block - should handle limit overflow"); } - #[tokio::test(start_paused = true)] - async fn test_inmemory_is_expired() { - let mut r = InMemorySampler::new(Duration::new(60, 0)); - let limit = 50; - assert!(r.increment(limit).is_some(), "should return Some when not expired"); - - sleep(Duration::from_secs(60)).await; - assert!(r.increment(limit).is_some(), "should return Some when still not expired"); - - sleep(Duration::from_secs(1)).await; - assert!(r.increment(limit).is_none(), "should return None when expired"); - } - - #[tokio::test] - async fn test_inmemory_get_count() { - let mut r = InMemorySampler::new(Duration::new(1, 0)); - let limit = 1; - assert_eq!(0, r.get_count()); - r.increment(limit).unwrap(); - assert_eq!(1, r.get_count()); - - for _ in 1..5 { - r.increment(limit).unwrap(); - assert_eq!(2, r.get_count(), "counter should not increase after crossing the limit"); - } - } - #[tokio::test(start_paused = true)] async fn test_inmemory_get_approx() { - let mut r = InMemorySampler::new(Duration::new(60, 0)); - sleep(Duration::from_secs(60)).await; + let sampling_period = Duration::from_secs(60); + let mut r = InMemorySampler::new(Instant::now()); + sleep(sampling_period).await; for _ in 0..42 { - r.increment(50); + r.increment(); } let start = Instant::now(); assert_eq!(42, r.get_count()); - assert_eq!(42, r.get_approx(start.elapsed())); + assert_eq!(42, r.get_approx(sampling_period, start.elapsed(),)); sleep(Duration::from_secs(15)).await; - assert_eq!(31, r.get_approx(start.elapsed())); + assert_eq!(31, r.get_approx(sampling_period, start.elapsed(),)); } } From 0c8ecb451c7ad9ee8baa79314d7285f49d088788 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Wed, 7 Jan 2026 23:21:06 +0100 Subject: [PATCH 06/11] reduce heap allocations --- docs/configuration.md | 7 - docs/rules.md | 30 ++++- pingoo/listeners/http_listener.rs | 19 ++- pingoo/rate_limiter.rs | 207 +++++++++++++++++++++--------- pingoo/services/http_utils.rs | 12 ++ rules/rules.rs | 9 ++ 6 files changed, 211 insertions(+), 73 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 564d69f..07602cc 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -68,13 +68,6 @@ rules: !http_request.user_agent.starts_with("Mozilla/") && !http_request.user_agent.contains("curl/") actions: - action: captcha - rate_limit_api_routes: - expression: http_request.path.starts_with("/api/") - actions: - - action: limit - limit: - max: 10 - window: 60 # (optional) Lists can be used in rule expressions to match against a large number of values lists: diff --git a/docs/rules.md b/docs/rules.md index f24e1a3..fe7a4a8 100644 --- a/docs/rules.md +++ b/docs/rules.md @@ -120,6 +120,23 @@ Valid lists types: ## Rate limiting +Algorithm used to evaluate a request is a [sliding +window](https://blog.cloudflare.com/counting-things-a-lot-of-different-things/) +that uses request count from both current and previous period. + +`max` (u16) number of requests in given `period` (u16) denominated in seconds. +Rate limiters have finite `capacity` measured in buckets. E.g. `bucket_8` can +store no more than 256 entries in a timeframe of 2x `period`. + +Available bucket sizes range from `bucket_8` up to `bucket_30`. + +For a case when `max` threshold is crossed Pingoo responds with [HTTP 429 Too +Many +Requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Status/429). +For a case where `capacity` bucket is full Pingoo responds with [HTTP 503 +Service +Unavailable](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Status/503). + **pingoo.yml** ```yml rules: @@ -129,5 +146,16 @@ rules: - action: limit limit: max: 10 - window: 60 + period: 60 + capacity: bucket_10 ``` + +In this example Pingoo: +* protects resources under `/api` route +* allows no more than 10 requests per ONE minute +* starts returning HTTP 429 to the specific client, when number of incoming + requests from IP address of that client crossed the threshold of 10 in + sampling period of ONE minute +* can count requests for 1024 (2^10) unique IP addresses on every TWO minutes +* starts returning HTTP 503 to any client when on every TWO minutes period at + least one request came from 1024 unique IP addresses diff --git a/pingoo/listeners/http_listener.rs b/pingoo/listeners/http_listener.rs index 96c98b9..89f16a8 100644 --- a/pingoo/listeners/http_listener.rs +++ b/pingoo/listeners/http_listener.rs @@ -24,7 +24,7 @@ use crate::{ http_utils::{ HOSTNAME_MAX_LENGTH, RequestContext, RequestExtensionContext, USER_AGENT_MAX_LENGTH, get_path, new_blocked_response, new_internal_error_response_500, new_not_found_error, - new_too_many_requests_response_429, + new_service_unavailable_error_503, new_too_many_requests_response_429, }, }, }; @@ -269,15 +269,20 @@ pub(super) async fn serve_http_requests, limiter_cfg: RateLimit) -> JoinHandle<()> { - tokio::spawn(async move { - let mut limiter = RateLimiter::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - - limiter.garbage_collect(); - } - }) + match limiter_cfg.capacity { + RateLimitBucketSize::Bucket8 => get_rate_limit_handle_b8(rx, limiter_cfg), + RateLimitBucketSize::Bucket9 => get_rate_limit_handle_b9(rx, limiter_cfg), + } } -pub fn get_probe(ip: IpAddr) -> (Probe, oneshot::Receiver) { +pub fn get_probe(ip: IpAddr) -> (Probe, oneshot::Receiver) { let (tx, rx) = oneshot::channel(); (Probe { ip, resp: tx }, rx) } -type Responder = oneshot::Sender; +type Response = Result; +type Responder = oneshot::Sender; pub struct Probe { ip: IpAddr, resp: Responder, } -struct RateLimiter { +struct RateLimiterBucket { + inner: FnvIndexMap, +} +struct RateLimiter { limit: u16, sampling_period: Duration, current_window: Instant, - state: HashMap, // todo: from heapless crate + bucket: RateLimiterBucket, } struct SlidingWindow { @@ -61,7 +63,31 @@ struct InMemorySampler { starts_at: Instant, } -impl RateLimiter { +impl RateLimiterBucket { + pub fn new() -> Self { + Self { + inner: FnvIndexMap::new(), + } + } + + pub fn entry(&mut self, key: IpAddr) -> Entry<'_, IpAddr, SlidingWindow, N> { + self.inner.entry(key) + } + + pub fn iter(&self) -> Iter<'_, IpAddr, SlidingWindow> { + self.inner.iter() + } + + pub fn remove(&mut self, key: &IpAddr) -> Option { + self.inner.remove(key) + } + + pub fn len(&self) -> usize { + self.inner.len() + } +} + +impl RateLimiter { pub fn new(limit: u16, sampling_period: Duration) -> Self { let current_window = Instant::now(); let mut sanitized_limit = limit; @@ -69,37 +95,42 @@ impl RateLimiter { sanitized_limit = limit - 1; } - RateLimiter { + Self { limit: sanitized_limit, sampling_period, current_window, - state: HashMap::new(), + bucket: RateLimiterBucket::new(), } } - pub fn can_resume(&mut self, ip: IpAddr) -> bool { + pub fn can_resume(&mut self, ip: IpAddr) -> Result { let now = Instant::now(); if now >= self.current_window + self.sampling_period { self.current_window = now; } - let mut result = false; - self.state + let mut can_resume = false; + if let Ok(_) = self + .bucket .entry(ip) - .and_modify(|x| result = x.can_resume(self.limit, self.current_window, self.sampling_period)) + .and_modify(|x| can_resume = x.can_resume(self.limit, self.current_window, self.sampling_period)) .or_insert_with(|| { - let mut new_ip_state = SlidingWindow::new(self.sampling_period, self.current_window); - result = new_ip_state.can_resume(self.limit, self.current_window, self.sampling_period); - new_ip_state - }); - result + let mut new_ip_bucket = SlidingWindow::new(self.sampling_period, self.current_window); + can_resume = new_ip_bucket.can_resume(self.limit, self.current_window, self.sampling_period); + new_ip_bucket + }) + { + return Ok(can_resume); + } + + Err(()) } pub fn garbage_collect(&mut self) { const ITEMS: usize = 2; let garbage: heapless::Vec = self - .state + .bucket .iter() .filter(|(_, v)| v.get_last_sample_created_at().elapsed() > 2 * self.sampling_period) .take(ITEMS) @@ -107,12 +138,12 @@ impl RateLimiter { .collect(); for ip in garbage { - let _ = self.state.remove(&ip); + let _ = self.bucket.remove(&ip); } } pub fn len(&self) -> usize { - self.state.len() + self.bucket.len() } } @@ -168,7 +199,7 @@ impl SlidingWindow { impl Sampler for InMemorySampler { fn new(starts_at: Instant) -> Self { - InMemorySampler { count: 0, starts_at } + Self { count: 0, starts_at } } fn increment(&mut self) { @@ -197,6 +228,33 @@ impl Sampler for InMemorySampler { } } +// todo: some makro/crate to avoid this ugly pattern, which is a consequence of using heapless::index_map::FnvIndexMap +// todo: alternatively decide which buckets we want to support -- for reference see test_memory_footprint() +fn get_rate_limit_handle_b8(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + let mut limiter = + RateLimiter::<{ 2usize.pow(8) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + tokio::spawn(async move { + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + + limiter.garbage_collect(); + } + }) +} +fn get_rate_limit_handle_b9(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + let mut limiter = + RateLimiter::<{ 2usize.pow(9) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + tokio::spawn(async move { + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + + limiter.garbage_collect(); + } + }) +} + #[cfg(feature = "test-utils")] #[cfg(test)] mod tests { @@ -206,6 +264,21 @@ mod tests { use super::*; + #[test] + fn test_memory_footprint() { + assert_eq!(17, std::mem::size_of::()); + assert_eq!(64, std::mem::size_of::()); + + assert_eq!(96_469_040, std::mem::size_of::>()); // ~one milion IPs -> 96 MB of mem footprint + + assert_eq!(23_600, std::mem::size_of::>()); // bucket_8 can store 256 IPs and consume 23.6 kB + assert_eq!(94_256, std::mem::size_of::>()); // bucket_10 -> 94 kB + assert_eq!(1_507_376, std::mem::size_of::>()); // bucket_14 -> 16 384 IPs -> 1.5 MB + assert_eq!(6_029_360, std::mem::size_of::>()); // 65 536 IPs -> 6 MB + assert_eq!(12_058_672, std::mem::size_of::>()); // 131 072 IPs -> 12 MB + assert_eq!(48_234_544, std::mem::size_of::>()); // 524 288 IPs -> 48 MB + } + #[tokio::test(start_paused = true)] async fn test_rate_limiter_alg() { // Tests example form https://blog.cloudflare.com/counting-things-a-lot-of-different-things/ @@ -219,36 +292,38 @@ mod tests { // = 49.5 requests let limit = 50; let sampling_period = Duration::from_secs(60); - let mut r = RateLimiter::new(limit, sampling_period); + let mut r = RateLimiter::<2>::new(limit, sampling_period); let ip = Ipv4Addr::new(1, 1, 1, 1).into(); for _ in 0..limit { - assert!(r.can_resume(ip), "should allow until limit is not reached"); + assert!(r.can_resume(ip).unwrap(), "should allow until limit is not reached"); } for _ in 0..u16::MAX { - assert!(!r.can_resume(ip), "should break when limit reached"); + assert!(!r.can_resume(ip).unwrap(), "should break when limit reached"); } sleep(2 * sampling_period).await; for _ in 0..42 { - assert!(r.can_resume(ip), "should resume until limit is not reached"); + assert!(r.can_resume(ip).unwrap(), "should resume until limit is not reached"); } sleep(sampling_period + Duration::from_secs(15)).await; for _ in 0..19 { - assert!(r.can_resume(ip), "should resume for 42 * ((60-15)/60) + 19 = 50"); + assert!(r.can_resume(ip).unwrap(), "should resume for 42 * ((60-15)/60) + 19 = 50"); } - assert!(!r.can_resume(ip), "should break for 42 * ((60-15)/60) + 20 = 51"); + assert!(!r.can_resume(ip).unwrap(), "should break for 42 * ((60-15)/60) + 20 = 51"); sleep(Duration::from_secs(3)).await; - assert!(r.can_resume(ip), "should resume for 42 * ((60-(15+3))/60) + 21 = 50"); + assert!(r.can_resume(ip).unwrap(), "should resume for 42 * ((60-(15+3))/60) + 21 = 50"); } + // todo: test backpressure behavior + #[tokio::test(start_paused = true)] async fn test_rate_limiter_gc() { let sampling_period = Duration::from_secs(60); - let mut limiter = RateLimiter::new(10, sampling_period); + let mut limiter = RateLimiter::<8>::new(10, sampling_period); let ips = [ Ipv4Addr::new(1, 1, 1, 1).into(), Ipv4Addr::new(2, 2, 2, 2).into(), @@ -257,55 +332,71 @@ mod tests { Ipv4Addr::new(5, 5, 5, 5).into(), ]; for ip in ips { - assert!(limiter.can_resume(ip)); + assert!(limiter.can_resume(ip).unwrap()); } sleep(sampling_period + Duration::from_secs(1)).await; - assert!(limiter.can_resume(ips[0])); + assert!(limiter.can_resume(ips[0]).unwrap()); limiter.garbage_collect(); assert_eq!(ips.len(), limiter.len()); sleep(sampling_period).await; - assert!(limiter.can_resume(ips[0])); + assert!(limiter.can_resume(ips[0]).unwrap()); limiter.garbage_collect(); assert_eq!( ips.len() - 2, limiter.len(), - "should garbage collect entries that were not updated for 2 * window, but no more than 2" + "should garbage collect up to 2 entries that were not updated for 2 * window" ); } #[tokio::test(start_paused = true)] async fn test_rate_limiter_boundary() { - let ip = Ipv4Addr::new(1, 1, 1, 1).into(); - let mut r = RateLimiter::new(1, Duration::new(1, 0)); + let ips = [ + Ipv4Addr::new(1, 1, 1, 1).into(), + Ipv4Addr::new(2, 2, 2, 2).into(), + Ipv4Addr::new(3, 3, 3, 3).into(), + ]; + let ip = ips[0]; + let mut r = RateLimiter::<2>::new(1, Duration::new(1, 0)); - assert!(r.can_resume(ip), "should allow once when limit is 1"); - assert!(!r.can_resume(ip), "should block on 2nd attempt when limit is 1"); + assert!(r.can_resume(ip).unwrap(), "should allow once when limit is 1"); + assert!(!r.can_resume(ip).unwrap(), "should block on 2nd attempt when limit is 1"); - let mut r = RateLimiter::new(0, Duration::new(1, 0)); - assert!(!r.can_resume(ip), "should treat zero limit as always limited"); - assert!(!r.can_resume(ip), "should treat zero limit as always limited"); + let mut r = RateLimiter::<2>::new(0, Duration::new(1, 0)); + assert!(!r.can_resume(ip).unwrap(), "should treat zero limit as always limited"); + assert!(!r.can_resume(ip).unwrap(), "should treat zero limit as always limited"); - let mut r = RateLimiter::new(0, Duration::new(0, 0)); + let mut r = RateLimiter::<2>::new(0, Duration::new(0, 0)); assert!( - !r.can_resume(ip), + !r.can_resume(ip).unwrap(), "should treat zero limit as always limited, even when zero window" ); assert!( - !r.can_resume(ip), + !r.can_resume(ip).unwrap(), "should treat zero limit as always limited, even when zero window" ); - let mut r = RateLimiter::new(1, Duration::new(1, 0)); - assert!(r.can_resume(ip), "allow - limit should take precedense over zero window"); - assert!(!r.can_resume(ip), "block - limit should take precedense over zero window"); + let mut r = RateLimiter::<2>::new(1, Duration::new(0, 0)); + assert!( + r.can_resume(ip).unwrap(), + "allow - limit should take precedense over zero window" + ); + assert!( + !r.can_resume(ip).unwrap(), + "block - limit should take precedense over zero window" + ); - let mut r = RateLimiter::new(u16::MAX, Duration::new(1, 0)); + let mut r = RateLimiter::<2>::new(u16::MAX, Duration::new(1, 0)); for _ in 1..u16::MAX { - assert!(r.can_resume(ip), "allow - should handle limit overflow"); + assert!(r.can_resume(ip).unwrap(), "allow - should handle limit overflow"); } - assert!(!r.can_resume(ip), "block - should handle limit overflow"); + assert!(!r.can_resume(ip).unwrap(), "block - should handle limit overflow"); + + let mut r = RateLimiter::<2>::new(1, Duration::new(1, 0)); + assert!(r.can_resume(ips[0]).unwrap(), "allow - should handle this IP"); + assert!(r.can_resume(ips[1]).unwrap(), "allow - should handle that IP"); + assert!(r.can_resume(ips[2]).is_err(), "error - should backpressure on another IP"); } #[tokio::test(start_paused = true)] diff --git a/pingoo/services/http_utils.rs b/pingoo/services/http_utils.rs index b785f57..610338e 100644 --- a/pingoo/services/http_utils.rs +++ b/pingoo/services/http_utils.rs @@ -75,6 +75,18 @@ pub fn new_bad_gateway_error() -> Response> { .expect("error building new_bad_gateway_error"); } +pub fn new_service_unavailable_error_503() -> Response> { + const ERROR_MESSAGE: &[u8] = b"503 Service Unavailable"; + let res_body = Full::new(Bytes::from_static(ERROR_MESSAGE)) + .map_err(|never| match never {}) + .boxed(); + return Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .header(header::CACHE_CONTROL, &CACHE_CONTROL_NO_CACHE) + .body(res_body) + .expect("error building new_bad_gateway_error"); +} + pub fn new_not_found_error() -> Response> { const NOT_FOUND_ERROR_MESSAGE: &[u8] = b"404 Not Found."; let res_body = Full::new(Bytes::from_static(NOT_FOUND_ERROR_MESSAGE)) diff --git a/rules/rules.rs b/rules/rules.rs index bcf32a5..9e3d411 100644 --- a/rules/rules.rs +++ b/rules/rules.rs @@ -29,6 +29,15 @@ pub type Context<'a> = bel::Context<'a>; pub struct RateLimit { pub max: u16, pub window: u16, + pub capacity: RateLimitBucketSize, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "limit", rename_all = "snake_case")] +pub enum RateLimitBucketSize { + // todo: some makro/crate to avoid this ugly pattern, which is a consequence of using heapless::index_map::FnvIndexMap + Bucket8, + Bucket9, } // pub struct CompiledRule { From 92bdcbafe35bd900a5bd4ddbcb899d2d684d0be9 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Sat, 10 Jan 2026 11:40:22 +0100 Subject: [PATCH 07/11] reduce mem footprint --- pingoo/rate_limiter.rs | 345 ++++++++++++++++------------------------- 1 file changed, 137 insertions(+), 208 deletions(-) diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs index eac0eb9..df46220 100644 --- a/pingoo/rate_limiter.rs +++ b/pingoo/rate_limiter.rs @@ -1,18 +1,14 @@ -use std::net::IpAddr; -use std::sync::Arc; - -use heapless::index_map::Entry; use heapless::index_map::FnvIndexMap; -use heapless::index_map::Iter; use rules::RateLimit; use rules::RateLimitBucketSize; +use std::net::IpAddr; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio::task::JoinHandle; use tokio::time::Duration; use tokio::time::Instant; -pub fn get_rate_limit_handle(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { +pub fn get_rate_limit_handle(rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { match limiter_cfg.capacity { RateLimitBucketSize::Bucket8 => get_rate_limit_handle_b8(rx, limiter_cfg), RateLimitBucketSize::Bucket9 => get_rate_limit_handle_b9(rx, limiter_cfg), @@ -32,200 +28,143 @@ pub struct Probe { } struct RateLimiterBucket { - inner: FnvIndexMap, + starts_at: Instant, + inner: FnvIndexMap, } struct RateLimiter { limit: u16, sampling_period: Duration, - current_window: Instant, - bucket: RateLimiterBucket, + bucket_green: RateLimiterBucket, + bucket_blue: RateLimiterBucket, } -struct SlidingWindow { - sampler_green: InMemorySampler, - sampler_blue: InMemorySampler, - curr_sampler: Arc, - prev_sampler: Arc, -} - -trait Sampler { - fn new(starts_at: Instant) -> Self; - fn increment(&mut self); - fn reset(&mut self, starts_at: Instant); - fn get_count(&self) -> u16; - fn get_starts_at(&self) -> Instant; - fn get_approx(&self, sampling_period: Duration, next_window_needle: Duration) -> u64; -} - -#[derive(Debug, Copy, Clone)] -struct InMemorySampler { - count: u16, - starts_at: Instant, +#[derive(Debug)] +struct Counter { + pub sum: u16, } impl RateLimiterBucket { - pub fn new() -> Self { + pub fn new(starts_at: Instant) -> Self { Self { + starts_at, inner: FnvIndexMap::new(), } } - - pub fn entry(&mut self, key: IpAddr) -> Entry<'_, IpAddr, SlidingWindow, N> { - self.inner.entry(key) - } - - pub fn iter(&self) -> Iter<'_, IpAddr, SlidingWindow> { - self.inner.iter() - } - - pub fn remove(&mut self, key: &IpAddr) -> Option { - self.inner.remove(key) - } - - pub fn len(&self) -> usize { - self.inner.len() - } } impl RateLimiter { pub fn new(limit: u16, sampling_period: Duration) -> Self { - let current_window = Instant::now(); let mut sanitized_limit = limit; if limit == u16::MAX { sanitized_limit = limit - 1; } + let now = Instant::now(); + let before = create_prev_window(now, sampling_period); + + let bucket_green = RateLimiterBucket::new(now); + let bucket_blue = RateLimiterBucket::new(before); + Self { limit: sanitized_limit, sampling_period, - current_window, - bucket: RateLimiterBucket::new(), + bucket_green, + bucket_blue, } } pub fn can_resume(&mut self, ip: IpAddr) -> Result { - let now = Instant::now(); - if now >= self.current_window + self.sampling_period { - self.current_window = now; - } - - let mut can_resume = false; - if let Ok(_) = self - .bucket - .entry(ip) - .and_modify(|x| can_resume = x.can_resume(self.limit, self.current_window, self.sampling_period)) - .or_insert_with(|| { - let mut new_ip_bucket = SlidingWindow::new(self.sampling_period, self.current_window); - can_resume = new_ip_bucket.can_resume(self.limit, self.current_window, self.sampling_period); - new_ip_bucket - }) - { - return Ok(can_resume); + if self.limit == 0 { + return Ok(false); } - Err(()) - } - - pub fn garbage_collect(&mut self) { - const ITEMS: usize = 2; - - let garbage: heapless::Vec = self - .bucket - .iter() - .filter(|(_, v)| v.get_last_sample_created_at().elapsed() > 2 * self.sampling_period) - .take(ITEMS) - .map(|(k, _)| k.clone()) - .collect(); - - for ip in garbage { - let _ = self.bucket.remove(&ip); + let is_green_bucket_current = self.bucket_green.starts_at >= self.bucket_blue.starts_at; + let starts_at = match is_green_bucket_current { + true => self.bucket_green.starts_at, + false => self.bucket_blue.starts_at, + }; + let now = Instant::now(); + if !self.is_within_curr_bucket_window(now, starts_at) && self.sampling_period > Duration::from_nanos(0) { + if self.is_outside_next_monothonic_window(now, self.bucket_green.starts_at) { + self.bucket_green = RateLimiterBucket::new(now); + self.bucket_blue = RateLimiterBucket::new(now); + } else { + // current bucket becomes previous + if is_green_bucket_current { + let starts_at = self.bucket_green.starts_at + self.sampling_period; + self.bucket_blue = RateLimiterBucket::new(starts_at); + } else { + let starts_at = self.bucket_blue.starts_at + self.sampling_period; + self.bucket_green = RateLimiterBucket::new(starts_at); + } + } } - } - - pub fn len(&self) -> usize { - self.bucket.len() - } -} -impl SlidingWindow { - pub fn new(sampling_period: Duration, current_window: Instant) -> Self { - let prev_window = current_window - sampling_period; - - let curr = InMemorySampler::new(current_window); - let prev = InMemorySampler::new(prev_window); - SlidingWindow { - sampler_green: prev, - sampler_blue: curr, - curr_sampler: Arc::new(curr), - prev_sampler: Arc::new(prev), - } - } + let is_green_bucket_current = self.bucket_green.starts_at >= self.bucket_blue.starts_at; + let (curr_bucket, prev_bucket) = match is_green_bucket_current { + true => (&mut self.bucket_green, &mut self.bucket_blue), + false => (&mut self.bucket_blue, &mut self.bucket_green), + }; - pub fn can_resume(&mut self, limit: u16, current_window: Instant, sampling_period: Duration) -> bool { - if limit == 0 { - return false; - } - - if current_window != self.curr_sampler.get_starts_at() { - self.shuffle_samplers(current_window, sampling_period); + let curr_counter = curr_bucket + .inner + .entry(ip) + .and_modify(|x| { + x.increment(); + }) + .or_insert_with(|| { + let mut x = Counter::new(); + x.increment(); + x + }); + if curr_counter.is_err() { + return Err(()); } + let curr_sum = curr_counter.expect("counter should exist").sum; - Arc::make_mut(&mut self.curr_sampler).increment(); - - let approx = self - .prev_sampler - .get_approx(sampling_period, self.prev_sampler.get_starts_at().elapsed() - sampling_period); - let current_count = self.curr_sampler.get_count(); + let prev_sum = match prev_bucket.inner.get(&ip) { + Some(c) => c.sum, + None => 0, + }; - u64::from(limit) >= approx + u64::from(current_count) + let approx = get_approx(prev_sum, curr_bucket.starts_at.elapsed(), self.sampling_period); + Ok(u64::from(self.limit) >= approx + u64::from(curr_sum)) } - pub fn get_last_sample_created_at(&self) -> Instant { - self.curr_sampler.starts_at + fn is_within_curr_bucket_window(&self, now: Instant, curr_starts_at: Instant) -> bool { + let next_monothonic_window = curr_starts_at + self.sampling_period; + now < next_monothonic_window } - fn shuffle_samplers(&mut self, current_window: Instant, sampling_period: Duration) { - let mut next_sampler = self.prev_sampler.clone(); - Arc::make_mut(&mut next_sampler).reset(current_window); - - if current_window.elapsed() > sampling_period + self.curr_sampler.get_starts_at().elapsed() { - Arc::make_mut(&mut self.curr_sampler).reset(current_window - sampling_period); - } - - self.prev_sampler = self.curr_sampler.clone(); - self.curr_sampler = next_sampler; + fn is_outside_next_monothonic_window(&self, now: Instant, curr_starts_at: Instant) -> bool { + let next_monothonic_window = curr_starts_at + 2 * self.sampling_period; + now >= next_monothonic_window } } -impl Sampler for InMemorySampler { - fn new(starts_at: Instant) -> Self { - Self { count: 0, starts_at } - } - - fn increment(&mut self) { - self.count = self.count.saturating_add(1); - } - - fn reset(&mut self, starts_at: Instant) { - self.count = 0; - self.starts_at = starts_at; +impl Counter { + pub fn new() -> Self { + Self { sum: 0 } } - fn get_count(&self) -> u16 { - self.count + pub fn increment(&mut self) { + self.sum = self.sum.saturating_add(1); } +} - fn get_starts_at(&self) -> Instant { - self.starts_at +fn get_approx(prev_counter: u16, window_needle: Duration, sampling_period: Duration) -> u64 { + if window_needle >= sampling_period { + return 0; } - fn get_approx(&self, sampling_period: Duration, next_window_needle: Duration) -> u64 { - if next_window_needle >= sampling_period { - return 0; - } + u64::from(prev_counter) * (sampling_period.as_secs() - window_needle.as_secs()) / sampling_period.as_secs() +} - u64::from(self.count) * (sampling_period.as_secs() - next_window_needle.as_secs()) / sampling_period.as_secs() +fn create_prev_window(instant: Instant, sampling_period: Duration) -> Instant { + if instant.elapsed() < sampling_period { + return instant; } + instant - sampling_period } // todo: some makro/crate to avoid this ugly pattern, which is a consequence of using heapless::index_map::FnvIndexMap @@ -237,8 +176,6 @@ fn get_rate_limit_handle_b8(mut rx: mpsc::Receiver, limiter_cfg: RateLimi while let Some(probe) = rx.recv().await { let result = limiter.can_resume(probe.ip); let _ = probe.resp.send(result); - - limiter.garbage_collect(); } }) } @@ -249,8 +186,6 @@ fn get_rate_limit_handle_b9(mut rx: mpsc::Receiver, limiter_cfg: RateLimi while let Some(probe) = rx.recv().await { let result = limiter.can_resume(probe.ip); let _ = probe.resp.send(result); - - limiter.garbage_collect(); } }) } @@ -265,18 +200,20 @@ mod tests { use super::*; #[test] + // this test case serves more for memory footprint documentation fn test_memory_footprint() { assert_eq!(17, std::mem::size_of::()); - assert_eq!(64, std::mem::size_of::()); + assert_eq!(2, std::mem::size_of::()); - assert_eq!(96_469_040, std::mem::size_of::>()); // ~one milion IPs -> 96 MB of mem footprint + assert_eq!(54_526_024, std::mem::size_of::>()); // ~million IPs -> 54.5 MB of mem footprint - assert_eq!(23_600, std::mem::size_of::>()); // bucket_8 can store 256 IPs and consume 23.6 kB - assert_eq!(94_256, std::mem::size_of::>()); // bucket_10 -> 94 kB - assert_eq!(1_507_376, std::mem::size_of::>()); // bucket_14 -> 16 384 IPs -> 1.5 MB - assert_eq!(6_029_360, std::mem::size_of::>()); // 65 536 IPs -> 6 MB - assert_eq!(12_058_672, std::mem::size_of::>()); // 131 072 IPs -> 12 MB - assert_eq!(48_234_544, std::mem::size_of::>()); // 524 288 IPs -> 48 MB + assert_eq!(53_320, std::mem::size_of::>()); // bucket_10 can store 1024 IPs and consume 53 kB + assert_eq!(852_040, std::mem::size_of::>()); // bucket_14 -> 16 384 IPs -> 852 kB + assert_eq!(3_407_944, std::mem::size_of::>()); // 65 536 IPs -> 3.4 MB + assert_eq!(6_815_816, std::mem::size_of::>()); // 131 072 IPs -> 6.8 MB + assert_eq!(27_263_048, std::mem::size_of::>()); // 524 288 IPs -> 27.2 MB + assert_eq!(436_207_688, std::mem::size_of::>()); // ~9 million IPs -> 436.2 MB + assert_eq!(872_415_304, std::mem::size_of::>()); // ~17 million IPs -> 872.4 MB } #[tokio::test(start_paused = true)] @@ -295,14 +232,6 @@ mod tests { let mut r = RateLimiter::<2>::new(limit, sampling_period); let ip = Ipv4Addr::new(1, 1, 1, 1).into(); - for _ in 0..limit { - assert!(r.can_resume(ip).unwrap(), "should allow until limit is not reached"); - } - for _ in 0..u16::MAX { - assert!(!r.can_resume(ip).unwrap(), "should break when limit reached"); - } - sleep(2 * sampling_period).await; - for _ in 0..42 { assert!(r.can_resume(ip).unwrap(), "should resume until limit is not reached"); } @@ -316,52 +245,38 @@ mod tests { sleep(Duration::from_secs(3)).await; assert!(r.can_resume(ip).unwrap(), "should resume for 42 * ((60-(15+3))/60) + 21 = 50"); + + sleep(2 * sampling_period).await; + for _ in 0..limit { + assert!(r.can_resume(ip).unwrap(), "should allow until limit is not reached"); + } + for _ in 0..u16::MAX { + assert!(!r.can_resume(ip).unwrap(), "should break when limit reached"); + } } // todo: test backpressure behavior #[tokio::test(start_paused = true)] - async fn test_rate_limiter_gc() { - let sampling_period = Duration::from_secs(60); - let mut limiter = RateLimiter::<8>::new(10, sampling_period); + async fn test_rate_limiter_backpressure() { let ips = [ Ipv4Addr::new(1, 1, 1, 1).into(), Ipv4Addr::new(2, 2, 2, 2).into(), Ipv4Addr::new(3, 3, 3, 3).into(), - Ipv4Addr::new(4, 4, 4, 4).into(), - Ipv4Addr::new(5, 5, 5, 5).into(), ]; - for ip in ips { - assert!(limiter.can_resume(ip).unwrap()); - } - - sleep(sampling_period + Duration::from_secs(1)).await; - assert!(limiter.can_resume(ips[0]).unwrap()); - limiter.garbage_collect(); - assert_eq!(ips.len(), limiter.len()); - - sleep(sampling_period).await; - assert!(limiter.can_resume(ips[0]).unwrap()); - limiter.garbage_collect(); - assert_eq!( - ips.len() - 2, - limiter.len(), - "should garbage collect up to 2 entries that were not updated for 2 * window" - ); + let mut r = RateLimiter::<2>::new(1, Duration::new(1, 0)); + assert!(r.can_resume(ips[0]).unwrap(), "allow - should handle this IP"); + assert!(r.can_resume(ips[1]).unwrap(), "allow - should handle that IP"); + assert!(r.can_resume(ips[2]).is_err(), "error - should backpressure on another IP"); } #[tokio::test(start_paused = true)] - async fn test_rate_limiter_boundary() { - let ips = [ - Ipv4Addr::new(1, 1, 1, 1).into(), - Ipv4Addr::new(2, 2, 2, 2).into(), - Ipv4Addr::new(3, 3, 3, 3).into(), - ]; - let ip = ips[0]; + async fn test_rate_limiter_boundary_single_ip() { + let ip = Ipv4Addr::new(1, 1, 1, 1).into(); let mut r = RateLimiter::<2>::new(1, Duration::new(1, 0)); assert!(r.can_resume(ip).unwrap(), "should allow once when limit is 1"); - assert!(!r.can_resume(ip).unwrap(), "should block on 2nd attempt when limit is 1"); + assert!(!r.can_resume(ip).unwrap(), "should block n 2nd attempt when limit is 1"); let mut r = RateLimiter::<2>::new(0, Duration::new(1, 0)); assert!(!r.can_resume(ip).unwrap(), "should treat zero limit as always limited"); @@ -392,26 +307,40 @@ mod tests { assert!(r.can_resume(ip).unwrap(), "allow - should handle limit overflow"); } assert!(!r.can_resume(ip).unwrap(), "block - should handle limit overflow"); - - let mut r = RateLimiter::<2>::new(1, Duration::new(1, 0)); - assert!(r.can_resume(ips[0]).unwrap(), "allow - should handle this IP"); - assert!(r.can_resume(ips[1]).unwrap(), "allow - should handle that IP"); - assert!(r.can_resume(ips[2]).is_err(), "error - should backpressure on another IP"); } #[tokio::test(start_paused = true)] async fn test_inmemory_get_approx() { let sampling_period = Duration::from_secs(60); - let mut r = InMemorySampler::new(Instant::now()); + let mut counter: u16 = 0; sleep(sampling_period).await; for _ in 0..42 { - r.increment(); + counter = counter.saturating_add(1); } let start = Instant::now(); - assert_eq!(42, r.get_count()); - assert_eq!(42, r.get_approx(sampling_period, start.elapsed(),)); + assert_eq!(42, counter); + assert_eq!(42, get_approx(counter, start.elapsed(), sampling_period)); sleep(Duration::from_secs(15)).await; - assert_eq!(31, r.get_approx(sampling_period, start.elapsed(),)); + assert_eq!(31, get_approx(counter, start.elapsed(), sampling_period)); + } + + #[tokio::test(start_paused = true)] + async fn test_create_prev_window() { + let sampling_window = Duration::from_secs(60); + + for tc in vec![0, 13, 59] { + let now = Instant::now(); + let offset = Duration::from_secs(tc); + sleep(offset).await; + assert_eq!(now, create_prev_window(now, sampling_window)); + } + + for tc in vec![60, 73, 119] { + let now = Instant::now(); + let offset = Duration::from_secs(tc); + sleep(offset).await; + assert_eq!(now - sampling_window, create_prev_window(now, sampling_window)); + } } } From 08e7e9ce1391de78b806fd98ce0c5c1a363147fb Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Sat, 10 Jan 2026 22:14:43 +0100 Subject: [PATCH 08/11] proposed set of buckets + some impr --- docs/rules.md | 21 ++++++--- pingoo/rate_limiter.rs | 98 +++++++++++++++++++++++++++++++++++++----- rules/rules.rs | 14 +++--- 3 files changed, 110 insertions(+), 23 deletions(-) diff --git a/docs/rules.md b/docs/rules.md index fe7a4a8..34df576 100644 --- a/docs/rules.md +++ b/docs/rules.md @@ -125,10 +125,18 @@ window](https://blog.cloudflare.com/counting-things-a-lot-of-different-things/) that uses request count from both current and previous period. `max` (u16) number of requests in given `period` (u16) denominated in seconds. -Rate limiters have finite `capacity` measured in buckets. E.g. `bucket_8` can -store no more than 256 entries in a timeframe of 2x `period`. - -Available bucket sizes range from `bucket_8` up to `bucket_30`. +Rate limiters have finite `capacity` measured in buckets. E.g. `bucket_10` can +store no more than 1024 entries in a timeframe of 2x `period`. + +Available bucket sizes are: +* `bucket_10` --> stores up to 1024 IPs (2^10), consumes additional 53.3 kB of memory +* `bucket_14` --> stores up to ~16k IPs (2^14), consumes additional 852 kB of memory +* `bucket_16` --> stores up to ~65k IPs (2^16), consumes additional 3.4 MB of memory +* `bucket_17` --> stores up to ~130k IPs (2^17), consumes additional 6.8 MB of memory +* `bucket_19` --> stores up to ~524k IPs (2^19), consumes additional 27.2 MB of memory +* `bucket_20` --> stores up to ~1 million IPs (2^20), consumes additional 54.5 MB of memory +* `bucket_23` --> stores up to ~9 million IPs (2^23), consumes additional 436.2 MB of memory +* `bucket_24` --> stores up to ~130k IPs (2^24), consumes additional 872.4 MB of memory For a case when `max` threshold is crossed Pingoo responds with [HTTP 429 Too Many @@ -156,6 +164,5 @@ In this example Pingoo: * starts returning HTTP 429 to the specific client, when number of incoming requests from IP address of that client crossed the threshold of 10 in sampling period of ONE minute -* can count requests for 1024 (2^10) unique IP addresses on every TWO minutes -* starts returning HTTP 503 to any client when on every TWO minutes period at - least one request came from 1024 unique IP addresses +* can count requests for 1024 (2^10) unique IP addresses on every minute +* starts returning HTTP 503 to new clients if their IP is not in the bucket diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs index df46220..44948e4 100644 --- a/pingoo/rate_limiter.rs +++ b/pingoo/rate_limiter.rs @@ -10,8 +10,17 @@ use tokio::time::Instant; pub fn get_rate_limit_handle(rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { match limiter_cfg.capacity { - RateLimitBucketSize::Bucket8 => get_rate_limit_handle_b8(rx, limiter_cfg), - RateLimitBucketSize::Bucket9 => get_rate_limit_handle_b9(rx, limiter_cfg), + // duplicated logic in each function is a consequence of using heapless::index_map::FnvIndexMap + // the only difference between them is Map capacity + // todo: some cleaner solution -- macro maybe? + RateLimitBucketSize::Bucket10 => get_rate_limit_handle_b10(rx, limiter_cfg), + RateLimitBucketSize::Bucket14 => get_rate_limit_handle_b14(rx, limiter_cfg), + RateLimitBucketSize::Bucket16 => get_rate_limit_handle_b16(rx, limiter_cfg), + RateLimitBucketSize::Bucket17 => get_rate_limit_handle_b17(rx, limiter_cfg), + RateLimitBucketSize::Bucket19 => get_rate_limit_handle_b19(rx, limiter_cfg), + RateLimitBucketSize::Bucket20 => get_rate_limit_handle_b20(rx, limiter_cfg), + RateLimitBucketSize::Bucket23 => get_rate_limit_handle_b23(rx, limiter_cfg), + RateLimitBucketSize::Bucket24 => get_rate_limit_handle_b24(rx, limiter_cfg), } } @@ -167,11 +176,9 @@ fn create_prev_window(instant: Instant, sampling_period: Duration) -> Instant { instant - sampling_period } -// todo: some makro/crate to avoid this ugly pattern, which is a consequence of using heapless::index_map::FnvIndexMap -// todo: alternatively decide which buckets we want to support -- for reference see test_memory_footprint() -fn get_rate_limit_handle_b8(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { +fn get_rate_limit_handle_b10(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { let mut limiter = - RateLimiter::<{ 2usize.pow(8) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + RateLimiter::<{ 2usize.pow(10) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); tokio::spawn(async move { while let Some(probe) = rx.recv().await { let result = limiter.can_resume(probe.ip); @@ -179,9 +186,69 @@ fn get_rate_limit_handle_b8(mut rx: mpsc::Receiver, limiter_cfg: RateLimi } }) } -fn get_rate_limit_handle_b9(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { +fn get_rate_limit_handle_b14(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { let mut limiter = - RateLimiter::<{ 2usize.pow(9) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + RateLimiter::<{ 2usize.pow(14) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + tokio::spawn(async move { + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) +} +fn get_rate_limit_handle_b16(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + let mut limiter = + RateLimiter::<{ 2usize.pow(16) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + tokio::spawn(async move { + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) +} +fn get_rate_limit_handle_b17(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + let mut limiter = + RateLimiter::<{ 2usize.pow(17) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + tokio::spawn(async move { + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) +} +fn get_rate_limit_handle_b19(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + let mut limiter = + RateLimiter::<{ 2usize.pow(19) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + tokio::spawn(async move { + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) +} +fn get_rate_limit_handle_b20(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + let mut limiter = + RateLimiter::<{ 2usize.pow(20) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + tokio::spawn(async move { + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) +} +fn get_rate_limit_handle_b23(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + let mut limiter = + RateLimiter::<{ 2usize.pow(23) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); + tokio::spawn(async move { + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) +} +fn get_rate_limit_handle_b24(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + let mut limiter = + RateLimiter::<{ 2usize.pow(24) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); tokio::spawn(async move { while let Some(probe) = rx.recv().await { let result = limiter.can_resume(probe.ip); @@ -255,19 +322,28 @@ mod tests { } } - // todo: test backpressure behavior - #[tokio::test(start_paused = true)] async fn test_rate_limiter_backpressure() { + let sampling_period = Duration::new(1, 0); let ips = [ Ipv4Addr::new(1, 1, 1, 1).into(), Ipv4Addr::new(2, 2, 2, 2).into(), Ipv4Addr::new(3, 3, 3, 3).into(), + Ipv4Addr::new(4, 4, 4, 4).into(), ]; - let mut r = RateLimiter::<2>::new(1, Duration::new(1, 0)); + let mut r = RateLimiter::<2>::new(10, sampling_period); assert!(r.can_resume(ips[0]).unwrap(), "allow - should handle this IP"); assert!(r.can_resume(ips[1]).unwrap(), "allow - should handle that IP"); assert!(r.can_resume(ips[2]).is_err(), "error - should backpressure on another IP"); + assert!(r.can_resume(ips[0]).unwrap(), "allow - should still handle this IP"); + assert!(r.can_resume(ips[1]).unwrap(), "allow - should still handle that IP"); + assert!(r.can_resume(ips[2]).is_err(), "error - should again backpressure on another IP"); + + sleep(sampling_period).await; + assert!(r.can_resume(ips[2]).unwrap(), "allow - another IP after bucket rotation"); + assert!(r.can_resume(ips[3]).unwrap(), "allow - new IP after bucket rotation"); + assert!(r.can_resume(ips[0]).is_err(), "error - should backpressure on this IP"); + assert!(r.can_resume(ips[1]).is_err(), "error - should backpressure on that IP"); } #[tokio::test(start_paused = true)] diff --git a/rules/rules.rs b/rules/rules.rs index 9e3d411..f574283 100644 --- a/rules/rules.rs +++ b/rules/rules.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -34,10 +32,16 @@ pub struct RateLimit { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "limit", rename_all = "snake_case")] +/// Number is power (exponent) of 2 -- it defines number of unique IPs that can be tracked pub enum RateLimitBucketSize { - // todo: some makro/crate to avoid this ugly pattern, which is a consequence of using heapless::index_map::FnvIndexMap - Bucket8, - Bucket9, + Bucket10 = 2isize.pow(10), + Bucket14 = 2isize.pow(14), + Bucket16 = 2isize.pow(16), + Bucket17 = 2isize.pow(17), + Bucket19 = 2isize.pow(19), + Bucket20 = 2isize.pow(20), + Bucket23 = 2isize.pow(23), + Bucket24 = 2isize.pow(24), } // pub struct CompiledRule { From a4bed25b7019c32d67eb1705644b01a6f5b8b384 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Sun, 11 Jan 2026 17:20:52 +0100 Subject: [PATCH 09/11] minor fixes --- .devcontainer/devcontainer.json | 14 ++++++++++++++ docs/rules.md | 25 ++++++++++++------------- pingoo/rate_limiter.rs | 10 +++++++--- pingoo/services/http_utils.rs | 4 ++-- rules/rules.rs | 2 +- 5 files changed, 36 insertions(+), 19 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 16c426a..7a2d4ae 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,5 +1,19 @@ { "dockerFile": "Dockerfile", + "customizations": { + // Configure properties specific to VS Code. + "vscode": { + // Set *default* container specific settings.json values on container create. + "settings": { + "rust-analyzer.cargo.features": ["test-utils"], + "rust-analyzer.check.features": ["test-utils"], + "rust-analyzer.runnables.extraArgs": [ + "--features", + "test-utils" + ], + }, + } + }, "extensions": [ "rust-lang.rust-analyzer" ], diff --git a/docs/rules.md b/docs/rules.md index 34df576..74e3c2c 100644 --- a/docs/rules.md +++ b/docs/rules.md @@ -120,23 +120,22 @@ Valid lists types: ## Rate limiting -Algorithm used to evaluate a request is a [sliding -window](https://blog.cloudflare.com/counting-things-a-lot-of-different-things/) +Algorithm used is the [sliding window](https://blog.cloudflare.com/counting-things-a-lot-of-different-things/) that uses request count from both current and previous period. `max` (u16) number of requests in given `period` (u16) denominated in seconds. -Rate limiters have finite `capacity` measured in buckets. E.g. `bucket_10` can +Rate limiters have finite `capacity` measured in buckets. E.g. `bucket10` can store no more than 1024 entries in a timeframe of 2x `period`. Available bucket sizes are: -* `bucket_10` --> stores up to 1024 IPs (2^10), consumes additional 53.3 kB of memory -* `bucket_14` --> stores up to ~16k IPs (2^14), consumes additional 852 kB of memory -* `bucket_16` --> stores up to ~65k IPs (2^16), consumes additional 3.4 MB of memory -* `bucket_17` --> stores up to ~130k IPs (2^17), consumes additional 6.8 MB of memory -* `bucket_19` --> stores up to ~524k IPs (2^19), consumes additional 27.2 MB of memory -* `bucket_20` --> stores up to ~1 million IPs (2^20), consumes additional 54.5 MB of memory -* `bucket_23` --> stores up to ~9 million IPs (2^23), consumes additional 436.2 MB of memory -* `bucket_24` --> stores up to ~130k IPs (2^24), consumes additional 872.4 MB of memory +* `bucket10` --> stores up to 1024 IPs (2^10), consumes additional 53.3 kB of memory +* `bucket14` --> stores up to ~16k IPs (2^14), consumes additional 852 kB of memory +* `bucket16` --> stores up to ~65k IPs (2^16), consumes additional 3.4 MB of memory +* `bucket17` --> stores up to ~130k IPs (2^17), consumes additional 6.8 MB of memory +* `bucket19` --> stores up to ~524k IPs (2^19), consumes additional 27.2 MB of memory +* `bucket20` --> stores up to ~1 million IPs (2^20), consumes additional 54.5 MB of memory +* `bucket23` --> stores up to ~9 million IPs (2^23), consumes additional 436.2 MB of memory +* `bucket24` --> stores up to ~17 million IPs (2^24), consumes additional 872.4 MB of memory For a case when `max` threshold is crossed Pingoo responds with [HTTP 429 Too Many @@ -155,7 +154,7 @@ rules: limit: max: 10 period: 60 - capacity: bucket_10 + capacity: bucket10 ``` In this example Pingoo: @@ -165,4 +164,4 @@ In this example Pingoo: requests from IP address of that client crossed the threshold of 10 in sampling period of ONE minute * can count requests for 1024 (2^10) unique IP addresses on every minute -* starts returning HTTP 503 to new clients if their IP is not in the bucket +* starts returning HTTP 503 to new clients if bucket is full and their IP is not in the bucket diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs index 44948e4..b83bc28 100644 --- a/pingoo/rate_limiter.rs +++ b/pingoo/rate_limiter.rs @@ -176,6 +176,8 @@ fn create_prev_window(instant: Instant, sampling_period: Duration) -> Instant { instant - sampling_period } +// start of duplicated fuctions. Only power of 2 changes but it must be constant so ... +// todo: perhaps could be impr with macro fn get_rate_limit_handle_b10(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { let mut limiter = RateLimiter::<{ 2usize.pow(10) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); @@ -267,7 +269,7 @@ mod tests { use super::*; #[test] - // this test case serves more for memory footprint documentation + // this test case documents memory footprint fn test_memory_footprint() { assert_eq!(17, std::mem::size_of::()); assert_eq!(2, std::mem::size_of::()); @@ -382,11 +384,13 @@ mod tests { for _ in 1..u16::MAX { assert!(r.can_resume(ip).unwrap(), "allow - should handle limit overflow"); } - assert!(!r.can_resume(ip).unwrap(), "block - should handle limit overflow"); + for _ in 0..u16::MAX { + assert!(!r.can_resume(ip).unwrap(), "block - should handle limit overflow"); + } } #[tokio::test(start_paused = true)] - async fn test_inmemory_get_approx() { + async fn test_get_approx() { let sampling_period = Duration::from_secs(60); let mut counter: u16 = 0; sleep(sampling_period).await; diff --git a/pingoo/services/http_utils.rs b/pingoo/services/http_utils.rs index 610338e..4efec3a 100644 --- a/pingoo/services/http_utils.rs +++ b/pingoo/services/http_utils.rs @@ -84,7 +84,7 @@ pub fn new_service_unavailable_error_503() -> Response Response> { @@ -132,7 +132,7 @@ pub fn new_too_many_requests_response_429() -> Response) -> &str { diff --git a/rules/rules.rs b/rules/rules.rs index f574283..a35bbd7 100644 --- a/rules/rules.rs +++ b/rules/rules.rs @@ -31,7 +31,7 @@ pub struct RateLimit { } #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "limit", rename_all = "snake_case")] +#[serde(rename_all = "snake_case")] /// Number is power (exponent) of 2 -- it defines number of unique IPs that can be tracked pub enum RateLimitBucketSize { Bucket10 = 2isize.pow(10), From 561677aa64d07f478add284d3132b777e59f1f64 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Sun, 11 Jan 2026 20:45:09 +0100 Subject: [PATCH 10/11] switch to std::collections::HashMap due to stack overflow of release build --- .devcontainer/devcontainer.json | 6 +- docs/rules.md | 17 +-- pingoo/rate_limiter.rs | 227 +++++++++----------------------- rules/rules.rs | 16 +-- 4 files changed, 70 insertions(+), 196 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 7a2d4ae..d987484 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,17 +1,15 @@ { "dockerFile": "Dockerfile", "customizations": { - // Configure properties specific to VS Code. "vscode": { - // Set *default* container specific settings.json values on container create. "settings": { "rust-analyzer.cargo.features": ["test-utils"], "rust-analyzer.check.features": ["test-utils"], "rust-analyzer.runnables.extraArgs": [ "--features", "test-utils" - ], - }, + ] + } } }, "extensions": [ diff --git a/docs/rules.md b/docs/rules.md index 74e3c2c..d4115e7 100644 --- a/docs/rules.md +++ b/docs/rules.md @@ -124,18 +124,9 @@ Algorithm used is the [sliding window](https://blog.cloudflare.com/counting-thin that uses request count from both current and previous period. `max` (u16) number of requests in given `period` (u16) denominated in seconds. -Rate limiters have finite `capacity` measured in buckets. E.g. `bucket10` can -store no more than 1024 entries in a timeframe of 2x `period`. - -Available bucket sizes are: -* `bucket10` --> stores up to 1024 IPs (2^10), consumes additional 53.3 kB of memory -* `bucket14` --> stores up to ~16k IPs (2^14), consumes additional 852 kB of memory -* `bucket16` --> stores up to ~65k IPs (2^16), consumes additional 3.4 MB of memory -* `bucket17` --> stores up to ~130k IPs (2^17), consumes additional 6.8 MB of memory -* `bucket19` --> stores up to ~524k IPs (2^19), consumes additional 27.2 MB of memory -* `bucket20` --> stores up to ~1 million IPs (2^20), consumes additional 54.5 MB of memory -* `bucket23` --> stores up to ~9 million IPs (2^23), consumes additional 436.2 MB of memory -* `bucket24` --> stores up to ~17 million IPs (2^24), consumes additional 872.4 MB of memory +Rate limiters have finite `capacity` that should be a power of 2 (otherwise +next power of 2 will be allocated). E.g. `capacity: 1024` can store no more +than 1024 entries in a `period`. For a case when `max` threshold is crossed Pingoo responds with [HTTP 429 Too Many @@ -154,7 +145,7 @@ rules: limit: max: 10 period: 60 - capacity: bucket10 + capacity: 1024 ``` In this example Pingoo: diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs index b83bc28..d7e0f09 100644 --- a/pingoo/rate_limiter.rs +++ b/pingoo/rate_limiter.rs @@ -1,6 +1,5 @@ -use heapless::index_map::FnvIndexMap; use rules::RateLimit; -use rules::RateLimitBucketSize; +use std::collections::HashMap; use std::net::IpAddr; use tokio::sync::mpsc; use tokio::sync::oneshot; @@ -8,20 +7,18 @@ use tokio::task::JoinHandle; use tokio::time::Duration; use tokio::time::Instant; -pub fn get_rate_limit_handle(rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - match limiter_cfg.capacity { - // duplicated logic in each function is a consequence of using heapless::index_map::FnvIndexMap - // the only difference between them is Map capacity - // todo: some cleaner solution -- macro maybe? - RateLimitBucketSize::Bucket10 => get_rate_limit_handle_b10(rx, limiter_cfg), - RateLimitBucketSize::Bucket14 => get_rate_limit_handle_b14(rx, limiter_cfg), - RateLimitBucketSize::Bucket16 => get_rate_limit_handle_b16(rx, limiter_cfg), - RateLimitBucketSize::Bucket17 => get_rate_limit_handle_b17(rx, limiter_cfg), - RateLimitBucketSize::Bucket19 => get_rate_limit_handle_b19(rx, limiter_cfg), - RateLimitBucketSize::Bucket20 => get_rate_limit_handle_b20(rx, limiter_cfg), - RateLimitBucketSize::Bucket23 => get_rate_limit_handle_b23(rx, limiter_cfg), - RateLimitBucketSize::Bucket24 => get_rate_limit_handle_b24(rx, limiter_cfg), - } +pub fn get_rate_limit_handle(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + tokio::spawn(async move { + let mut limiter = RateLimiter::new( + limiter_cfg.max, + Duration::from_secs(u64::from(limiter_cfg.window)), + limiter_cfg.capacity, + ); + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) } pub fn get_probe(ip: IpAddr) -> (Probe, oneshot::Receiver) { @@ -36,15 +33,15 @@ pub struct Probe { resp: Responder, } -struct RateLimiterBucket { +struct RateLimiterBucket { starts_at: Instant, - inner: FnvIndexMap, + inner: HashMap, } -struct RateLimiter { +struct RateLimiter { limit: u16, sampling_period: Duration, - bucket_green: RateLimiterBucket, - bucket_blue: RateLimiterBucket, + bucket_green: RateLimiterBucket, + bucket_blue: RateLimiterBucket, } #[derive(Debug)] @@ -52,17 +49,17 @@ struct Counter { pub sum: u16, } -impl RateLimiterBucket { - pub fn new(starts_at: Instant) -> Self { +impl RateLimiterBucket { + pub fn new(starts_at: Instant, capacity: usize) -> Self { Self { starts_at, - inner: FnvIndexMap::new(), + inner: HashMap::with_capacity(capacity), } } } -impl RateLimiter { - pub fn new(limit: u16, sampling_period: Duration) -> Self { +impl RateLimiter { + pub fn new(limit: u16, sampling_period: Duration, capacity: usize) -> Self { let mut sanitized_limit = limit; if limit == u16::MAX { sanitized_limit = limit - 1; @@ -71,8 +68,8 @@ impl RateLimiter { let now = Instant::now(); let before = create_prev_window(now, sampling_period); - let bucket_green = RateLimiterBucket::new(now); - let bucket_blue = RateLimiterBucket::new(before); + let bucket_green = RateLimiterBucket::new(now, capacity); + let bucket_blue = RateLimiterBucket::new(before, capacity); Self { limit: sanitized_limit, @@ -94,17 +91,18 @@ impl RateLimiter { }; let now = Instant::now(); if !self.is_within_curr_bucket_window(now, starts_at) && self.sampling_period > Duration::from_nanos(0) { + let capacity = self.bucket_green.inner.capacity(); if self.is_outside_next_monothonic_window(now, self.bucket_green.starts_at) { - self.bucket_green = RateLimiterBucket::new(now); - self.bucket_blue = RateLimiterBucket::new(now); + self.bucket_green = RateLimiterBucket::new(now, capacity); + self.bucket_blue = RateLimiterBucket::new(now, capacity); } else { // current bucket becomes previous if is_green_bucket_current { let starts_at = self.bucket_green.starts_at + self.sampling_period; - self.bucket_blue = RateLimiterBucket::new(starts_at); + self.bucket_blue = RateLimiterBucket::new(starts_at, capacity); } else { let starts_at = self.bucket_blue.starts_at + self.sampling_period; - self.bucket_green = RateLimiterBucket::new(starts_at); + self.bucket_green = RateLimiterBucket::new(starts_at, capacity); } } } @@ -115,21 +113,18 @@ impl RateLimiter { false => (&mut self.bucket_blue, &mut self.bucket_green), }; - let curr_counter = curr_bucket - .inner - .entry(ip) - .and_modify(|x| { - x.increment(); - }) - .or_insert_with(|| { - let mut x = Counter::new(); - x.increment(); - x - }); - if curr_counter.is_err() { + let curr_sum; + if let Some(counter) = curr_bucket.inner.get_mut(&ip) { + counter.increment(); + curr_sum = counter.sum; + } else if curr_bucket.inner.capacity() == curr_bucket.inner.len() { return Err(()); + } else { + let mut counter = Counter::new(); + counter.increment(); + curr_sum = counter.sum; + curr_bucket.inner.insert(ip, counter); } - let curr_sum = curr_counter.expect("counter should exist").sum; let prev_sum = match prev_bucket.inner.get(&ip) { Some(c) => c.sum, @@ -176,89 +171,6 @@ fn create_prev_window(instant: Instant, sampling_period: Duration) -> Instant { instant - sampling_period } -// start of duplicated fuctions. Only power of 2 changes but it must be constant so ... -// todo: perhaps could be impr with macro -fn get_rate_limit_handle_b10(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - let mut limiter = - RateLimiter::<{ 2usize.pow(10) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - tokio::spawn(async move { - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - } - }) -} -fn get_rate_limit_handle_b14(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - let mut limiter = - RateLimiter::<{ 2usize.pow(14) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - tokio::spawn(async move { - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - } - }) -} -fn get_rate_limit_handle_b16(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - let mut limiter = - RateLimiter::<{ 2usize.pow(16) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - tokio::spawn(async move { - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - } - }) -} -fn get_rate_limit_handle_b17(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - let mut limiter = - RateLimiter::<{ 2usize.pow(17) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - tokio::spawn(async move { - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - } - }) -} -fn get_rate_limit_handle_b19(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - let mut limiter = - RateLimiter::<{ 2usize.pow(19) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - tokio::spawn(async move { - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - } - }) -} -fn get_rate_limit_handle_b20(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - let mut limiter = - RateLimiter::<{ 2usize.pow(20) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - tokio::spawn(async move { - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - } - }) -} -fn get_rate_limit_handle_b23(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - let mut limiter = - RateLimiter::<{ 2usize.pow(23) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - tokio::spawn(async move { - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - } - }) -} -fn get_rate_limit_handle_b24(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { - let mut limiter = - RateLimiter::<{ 2usize.pow(24) }>::new(limiter_cfg.max, Duration::from_secs(u64::from(limiter_cfg.window))); - tokio::spawn(async move { - while let Some(probe) = rx.recv().await { - let result = limiter.can_resume(probe.ip); - let _ = probe.resp.send(result); - } - }) -} - #[cfg(feature = "test-utils")] #[cfg(test)] mod tests { @@ -268,23 +180,6 @@ mod tests { use super::*; - #[test] - // this test case documents memory footprint - fn test_memory_footprint() { - assert_eq!(17, std::mem::size_of::()); - assert_eq!(2, std::mem::size_of::()); - - assert_eq!(54_526_024, std::mem::size_of::>()); // ~million IPs -> 54.5 MB of mem footprint - - assert_eq!(53_320, std::mem::size_of::>()); // bucket_10 can store 1024 IPs and consume 53 kB - assert_eq!(852_040, std::mem::size_of::>()); // bucket_14 -> 16 384 IPs -> 852 kB - assert_eq!(3_407_944, std::mem::size_of::>()); // 65 536 IPs -> 3.4 MB - assert_eq!(6_815_816, std::mem::size_of::>()); // 131 072 IPs -> 6.8 MB - assert_eq!(27_263_048, std::mem::size_of::>()); // 524 288 IPs -> 27.2 MB - assert_eq!(436_207_688, std::mem::size_of::>()); // ~9 million IPs -> 436.2 MB - assert_eq!(872_415_304, std::mem::size_of::>()); // ~17 million IPs -> 872.4 MB - } - #[tokio::test(start_paused = true)] async fn test_rate_limiter_alg() { // Tests example form https://blog.cloudflare.com/counting-things-a-lot-of-different-things/ @@ -298,7 +193,7 @@ mod tests { // = 49.5 requests let limit = 50; let sampling_period = Duration::from_secs(60); - let mut r = RateLimiter::<2>::new(limit, sampling_period); + let mut r = RateLimiter::new(limit, sampling_period, 2); let ip = Ipv4Addr::new(1, 1, 1, 1).into(); for _ in 0..42 { @@ -328,39 +223,43 @@ mod tests { async fn test_rate_limiter_backpressure() { let sampling_period = Duration::new(1, 0); let ips = [ + Ipv4Addr::new(0, 0, 0, 0).into(), Ipv4Addr::new(1, 1, 1, 1).into(), Ipv4Addr::new(2, 2, 2, 2).into(), Ipv4Addr::new(3, 3, 3, 3).into(), Ipv4Addr::new(4, 4, 4, 4).into(), + Ipv4Addr::new(5, 5, 5, 5).into(), + Ipv4Addr::new(6, 6, 6, 6).into(), + Ipv4Addr::new(7, 7, 7, 7).into(), + Ipv4Addr::new(8, 8, 8, 8).into(), ]; - let mut r = RateLimiter::<2>::new(10, sampling_period); - assert!(r.can_resume(ips[0]).unwrap(), "allow - should handle this IP"); - assert!(r.can_resume(ips[1]).unwrap(), "allow - should handle that IP"); - assert!(r.can_resume(ips[2]).is_err(), "error - should backpressure on another IP"); - assert!(r.can_resume(ips[0]).unwrap(), "allow - should still handle this IP"); - assert!(r.can_resume(ips[1]).unwrap(), "allow - should still handle that IP"); - assert!(r.can_resume(ips[2]).is_err(), "error - should again backpressure on another IP"); - - sleep(sampling_period).await; - assert!(r.can_resume(ips[2]).unwrap(), "allow - another IP after bucket rotation"); - assert!(r.can_resume(ips[3]).unwrap(), "allow - new IP after bucket rotation"); - assert!(r.can_resume(ips[0]).is_err(), "error - should backpressure on this IP"); - assert!(r.can_resume(ips[1]).is_err(), "error - should backpressure on that IP"); + // quote from https://doc.rust-lang.org/std/collections/struct.HashMap.html#method.with_capacity + // "This method is allowed to allocate for more elements than capacity. If capacity is zero, the hash map will not allocate." + // + // 8 is the first number for which it is not allocating for more elements. + let capacity = 8; + let mut r = RateLimiter::new(10, sampling_period, capacity); + for i in 0..ips.len() { + if i > capacity { + assert!(r.can_resume(ips[i]).is_err(), "error - should backpressure on this IP"); + } + assert!(r.can_resume(ips[i]).unwrap()); + } } #[tokio::test(start_paused = true)] async fn test_rate_limiter_boundary_single_ip() { let ip = Ipv4Addr::new(1, 1, 1, 1).into(); - let mut r = RateLimiter::<2>::new(1, Duration::new(1, 0)); + let mut r = RateLimiter::new(1, Duration::new(1, 0), 2); assert!(r.can_resume(ip).unwrap(), "should allow once when limit is 1"); assert!(!r.can_resume(ip).unwrap(), "should block n 2nd attempt when limit is 1"); - let mut r = RateLimiter::<2>::new(0, Duration::new(1, 0)); + let mut r = RateLimiter::new(0, Duration::new(1, 0), 2); assert!(!r.can_resume(ip).unwrap(), "should treat zero limit as always limited"); assert!(!r.can_resume(ip).unwrap(), "should treat zero limit as always limited"); - let mut r = RateLimiter::<2>::new(0, Duration::new(0, 0)); + let mut r = RateLimiter::new(0, Duration::new(0, 0), 2); assert!( !r.can_resume(ip).unwrap(), "should treat zero limit as always limited, even when zero window" @@ -370,7 +269,7 @@ mod tests { "should treat zero limit as always limited, even when zero window" ); - let mut r = RateLimiter::<2>::new(1, Duration::new(0, 0)); + let mut r = RateLimiter::new(1, Duration::new(0, 0), 2); assert!( r.can_resume(ip).unwrap(), "allow - limit should take precedense over zero window" @@ -380,7 +279,7 @@ mod tests { "block - limit should take precedense over zero window" ); - let mut r = RateLimiter::<2>::new(u16::MAX, Duration::new(1, 0)); + let mut r = RateLimiter::new(u16::MAX, Duration::new(1, 0), 2); for _ in 1..u16::MAX { assert!(r.can_resume(ip).unwrap(), "allow - should handle limit overflow"); } diff --git a/rules/rules.rs b/rules/rules.rs index a35bbd7..7604d6d 100644 --- a/rules/rules.rs +++ b/rules/rules.rs @@ -27,21 +27,7 @@ pub type Context<'a> = bel::Context<'a>; pub struct RateLimit { pub max: u16, pub window: u16, - pub capacity: RateLimitBucketSize, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -/// Number is power (exponent) of 2 -- it defines number of unique IPs that can be tracked -pub enum RateLimitBucketSize { - Bucket10 = 2isize.pow(10), - Bucket14 = 2isize.pow(14), - Bucket16 = 2isize.pow(16), - Bucket17 = 2isize.pow(17), - Bucket19 = 2isize.pow(19), - Bucket20 = 2isize.pow(20), - Bucket23 = 2isize.pow(23), - Bucket24 = 2isize.pow(24), + pub capacity: usize, } // pub struct CompiledRule { From f6cad491a90894c4f0b1b9ecc853f727a51c1f94 Mon Sep 17 00:00:00 2001 From: mrl5 <31549762+mrl5@users.noreply.github.com> Date: Fri, 30 Jan 2026 01:34:56 +0100 Subject: [PATCH 11/11] refactor --- pingoo/listeners/http_listener.rs | 31 +++++--------------------- pingoo/rate_limiter.rs | 36 ++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/pingoo/listeners/http_listener.rs b/pingoo/listeners/http_listener.rs index 89f16a8..5f13bd5 100644 --- a/pingoo/listeners/http_listener.rs +++ b/pingoo/listeners/http_listener.rs @@ -17,14 +17,13 @@ use crate::{ config::ListenerConfig, geoip::{self, GeoipDB, GeoipRecord}, listeners::{GRACEFUL_SHUTDOWN_TIMEOUT, Listener, accept_tcp_connection, bind_tcp_socket}, - rate_limiter::get_probe, + rate_limiter::limit_http_request, rules, services::{ HttpService, http_utils::{ HOSTNAME_MAX_LENGTH, RequestContext, RequestExtensionContext, USER_AGENT_MAX_LENGTH, get_path, - new_blocked_response, new_internal_error_response_500, new_not_found_error, - new_service_unavailable_error_503, new_too_many_requests_response_429, + new_blocked_response, new_not_found_error, }, }, }; @@ -119,7 +118,7 @@ impl Listener for HttpListener { } } -pub(super) async fn serve_http_requests( +pub async fn serve_http_requests( tcp_stream: IO, services: Arc>>, client_socket_addr: SocketAddr, @@ -263,28 +262,8 @@ pub(super) async fn serve_http_requests { // todo: if "action: limit" then this must be defined - not Option if let Some(tx) = rule.limiter_tx.clone() { - let (probe, rx) = get_probe(client_data.ip); - if let Err(err) = tx.send(probe).await { - error!("couldn't send request probe to rate limiter: {err}"); - return Ok(new_internal_error_response_500()); - } - - let resp = rx.await; - if let Err(err) = resp { - error!("error on receiving rate limiter result: {err}"); - return Ok(new_internal_error_response_500()); - } - - let result = resp.expect("error on receiving rate limiter result"); - if let Err(_) = result { - error!("rate limiter capacity reached for current timeframe"); - return Ok(new_service_unavailable_error_503()); - } - - let can_resume = - result.expect("rate limiter capacity reached for current timeframe"); - if !can_resume { - return Ok(new_too_many_requests_response_429()); + if let Some(res) = limit_http_request(client_data.ip, tx).await { + return Ok(res); } } } diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs index d7e0f09..1e70882 100644 --- a/pingoo/rate_limiter.rs +++ b/pingoo/rate_limiter.rs @@ -1,11 +1,45 @@ +use bytes::Bytes; +use http_body_util::combinators::BoxBody; use rules::RateLimit; use std::collections::HashMap; use std::net::IpAddr; use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; use tokio::sync::oneshot; use tokio::task::JoinHandle; use tokio::time::Duration; use tokio::time::Instant; +use tracing::error; + +use crate::services::http_utils::new_internal_error_response_500; +use crate::services::http_utils::new_service_unavailable_error_503; +use crate::services::http_utils::new_too_many_requests_response_429; + +pub async fn limit_http_request(ip: IpAddr, tx: Sender) -> Option>> { + let (probe, rx) = get_probe(ip); + if let Err(err) = tx.send(probe).await { + error!("couldn't send request probe to rate limiter: {err}"); + return Some(new_internal_error_response_500()); + } + + let resp = rx.await; + if let Err(err) = resp { + error!("error on receiving rate limiter result: {err}"); + return Some(new_internal_error_response_500()); + } + + let result = resp.expect("error on receiving rate limiter result"); + if let Err(_) = result { + error!("rate limiter capacity reached for current timeframe"); + return Some(new_service_unavailable_error_503()); + } + + let can_resume = result.expect("rate limiter capacity reached for current timeframe"); + if !can_resume { + return Some(new_too_many_requests_response_429()); + } + None +} pub fn get_rate_limit_handle(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { tokio::spawn(async move { @@ -21,7 +55,7 @@ pub fn get_rate_limit_handle(mut rx: mpsc::Receiver, limiter_cfg: RateLim }) } -pub fn get_probe(ip: IpAddr) -> (Probe, oneshot::Receiver) { +fn get_probe(ip: IpAddr) -> (Probe, oneshot::Receiver) { let (tx, rx) = oneshot::channel(); (Probe { ip, resp: tx }, rx) }