diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 16c426a..d987484 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,5 +1,17 @@ { "dockerFile": "Dockerfile", + "customizations": { + "vscode": { + "settings": { + "rust-analyzer.cargo.features": ["test-utils"], + "rust-analyzer.check.features": ["test-utils"], + "rust-analyzer.runnables.extraArgs": [ + "--features", + "test-utils" + ] + } + } + }, "extensions": [ "rust-lang.rust-analyzer" ], diff --git a/Cargo.toml b/Cargo.toml index 5519415..de9acc0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,9 @@ zeroize = { workspace = true, features = ["simd", "derive"] } zstd = { workspace = true } +[features] +test-utils = ["tokio/test-util"] + [workspace] resolver = "2" diff --git a/Makefile b/Makefile index a9221e5..ddbd542 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,10 @@ fmt: check: cargo check +.PHONY: test +test: + cargo test --features test-utils + .PHONY: clean clean: rm -rf $(DIST_DIR) captcha/dist diff --git a/docs/rules.md b/docs/rules.md index 0e521df..d4115e7 100644 --- a/docs/rules.md +++ b/docs/rules.md @@ -118,3 +118,41 @@ Valid lists types: - `String` - `Ip` +## Rate limiting + +The algorithm used is the [sliding window](https://blog.cloudflare.com/counting-things-a-lot-of-different-things/) +that uses the request counts from both the current and the previous window. + +`max` (u16) is the number of requests allowed in a given `window` (u16), expressed in seconds. +Rate limiters have a finite `capacity` that should be a power of 2 (otherwise +the next power of 2 will be allocated). E.g. `capacity: 1024` can store no more +than 1024 entries in a `window`. + +When the `max` threshold is crossed, Pingoo responds with [HTTP 429 Too +Many +Requests](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Status/429).
+When the `capacity` bucket is full, Pingoo responds with [HTTP 503 +Service +Unavailable](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Status/503). + +**pingoo.yml** +```yml +rules: + rate_limit_api_routes: + expression: http_request.path.starts_with("/api/") + actions: + - action: limit + limit: + max: 10 + window: 60 + capacity: 1024 +``` + +In this example, Pingoo: +* protects resources under the `/api` route +* allows no more than 10 requests per ONE minute +* starts returning HTTP 429 to a specific client when the number of incoming + requests from that client's IP address crosses the threshold of 10 within a + sampling window of ONE minute +* can count requests for 1024 (2^10) unique IP addresses in every minute +* starts returning HTTP 503 to new clients if the bucket is full and their IP is not yet in the bucket diff --git a/pingoo/config/config.rs b/pingoo/config/config.rs index 8c1bdc2..fe6c184 100644 --- a/pingoo/config/config.rs +++ b/pingoo/config/config.rs @@ -9,13 +9,18 @@ use std::{ use http::StatusCode; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use tokio::fs; +use tokio::{ + fs, + sync::mpsc::{self, Sender}, + task::JoinHandle, +}; use tracing::{debug, warn}; use crate::{ Error, config::config_file::{ConfigFile, RuleConfigFile, parse_service}, lists::ListType, + rate_limiter::{Probe, get_rate_limit_handle}, rules::Rule, service_discovery::service_registry::Upstream, tls::acme::LETSENCRYPT_PRODUCTION_URL, @@ -46,6 +51,7 @@ pub struct Config { pub service_discovery: ServiceDiscoveryConfig, pub lists: HashMap, pub child_process: Option, + pub limiter_workers: Vec>, } #[derive(Clone, Copy, Debug, Deserialize, Serialize, Eq, PartialEq)] @@ -252,10 +258,19 @@ pub async fn load_and_validate() -> Result { .collect(); validate_listeners_config(&listeners, &services)?; + let mut limiter_workers: Vec> = vec![]; let rules: Vec = config_file .rules .into_iter() .map(|(rule_name, rule_config)| { + let mut limiter_tx: Option> = None; + if 
let Some(limiter_cfg) = rule_config.limit { + let buffer = 1024; // todo make configurable + let (tx, rx) = mpsc::channel(buffer); + limiter_workers.push(get_rate_limit_handle(rx, limiter_cfg)); + limiter_tx = Some(tx); + } + Ok(Rule { name: rule_name, expression: rule_config @@ -263,6 +278,7 @@ pub async fn load_and_validate() -> Result { .map(|expression| rules::compile_expression(&expression)) .map_or(Ok(None), |r| r.map(Some))?, actions: rule_config.actions, + limiter_tx, }) }) .collect::>() @@ -317,6 +333,7 @@ pub async fn load_and_validate() -> Result { service_discovery: config_file.service_discovery.unwrap_or_default(), lists, child_process: config_file.child_process, + limiter_workers, }; return Ok(config); diff --git a/pingoo/config/config_file.rs b/pingoo/config/config_file.rs index 20e733d..2ab785e 100644 --- a/pingoo/config/config_file.rs +++ b/pingoo/config/config_file.rs @@ -98,6 +98,7 @@ pub struct ServiceConfigFileStaticNotFound { pub struct RuleConfigFile { pub expression: Option, pub actions: Vec, + pub limit: Option, } impl Default for ServiceConfigFileStaticNotFound { diff --git a/pingoo/listeners/http_listener.rs b/pingoo/listeners/http_listener.rs index 703df50..5f13bd5 100644 --- a/pingoo/listeners/http_listener.rs +++ b/pingoo/listeners/http_listener.rs @@ -17,6 +17,7 @@ use crate::{ config::ListenerConfig, geoip::{self, GeoipDB, GeoipRecord}, listeners::{GRACEFUL_SHUTDOWN_TIMEOUT, Listener, accept_tcp_connection, bind_tcp_socket}, + rate_limiter::limit_http_request, rules, services::{ HttpService, @@ -117,7 +118,7 @@ impl Listener for HttpListener { } } -pub(super) async fn serve_http_requests( +pub async fn serve_http_requests( tcp_stream: IO, services: Arc>>, client_socket_addr: SocketAddr, @@ -258,6 +259,14 @@ pub(super) async fn serve_http_requests { + // todo: if "action: limit" then this must be defined - not Option + if let Some(tx) = rule.limiter_tx.clone() { + if let Some(res) = limit_http_request(client_data.ip, tx).await { + 
return Ok(res); + } + } + } } } } diff --git a/pingoo/main.rs b/pingoo/main.rs index b48be8c..5bbb9f6 100644 --- a/pingoo/main.rs +++ b/pingoo/main.rs @@ -13,6 +13,7 @@ mod error; mod geoip; mod listeners; mod lists; +mod rate_limiter; mod rules; mod serde_utils; mod service_discovery; diff --git a/pingoo/rate_limiter.rs b/pingoo/rate_limiter.rs new file mode 100644 index 0000000..1e70882 --- /dev/null +++ b/pingoo/rate_limiter.rs @@ -0,0 +1,359 @@ +use bytes::Bytes; +use http_body_util::combinators::BoxBody; +use rules::RateLimit; +use std::collections::HashMap; +use std::net::IpAddr; +use tokio::sync::mpsc; +use tokio::sync::mpsc::Sender; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; +use tokio::time::Duration; +use tokio::time::Instant; +use tracing::error; + +use crate::services::http_utils::new_internal_error_response_500; +use crate::services::http_utils::new_service_unavailable_error_503; +use crate::services::http_utils::new_too_many_requests_response_429; + +pub async fn limit_http_request(ip: IpAddr, tx: Sender) -> Option>> { + let (probe, rx) = get_probe(ip); + if let Err(err) = tx.send(probe).await { + error!("couldn't send request probe to rate limiter: {err}"); + return Some(new_internal_error_response_500()); + } + + let resp = rx.await; + if let Err(err) = resp { + error!("error on receiving rate limiter result: {err}"); + return Some(new_internal_error_response_500()); + } + + let result = resp.expect("error on receiving rate limiter result"); + if let Err(_) = result { + error!("rate limiter capacity reached for current timeframe"); + return Some(new_service_unavailable_error_503()); + } + + let can_resume = result.expect("rate limiter capacity reached for current timeframe"); + if !can_resume { + return Some(new_too_many_requests_response_429()); + } + None +} + +pub fn get_rate_limit_handle(mut rx: mpsc::Receiver, limiter_cfg: RateLimit) -> JoinHandle<()> { + tokio::spawn(async move { + let mut limiter = RateLimiter::new( + 
limiter_cfg.max, + Duration::from_secs(u64::from(limiter_cfg.window)), + limiter_cfg.capacity, + ); + while let Some(probe) = rx.recv().await { + let result = limiter.can_resume(probe.ip); + let _ = probe.resp.send(result); + } + }) +} + +fn get_probe(ip: IpAddr) -> (Probe, oneshot::Receiver) { + let (tx, rx) = oneshot::channel(); + (Probe { ip, resp: tx }, rx) +} + +type Response = Result; +type Responder = oneshot::Sender; +pub struct Probe { + ip: IpAddr, + resp: Responder, +} + +struct RateLimiterBucket { + starts_at: Instant, + inner: HashMap, +} +struct RateLimiter { + limit: u16, + sampling_period: Duration, + bucket_green: RateLimiterBucket, + bucket_blue: RateLimiterBucket, +} + +#[derive(Debug)] +struct Counter { + pub sum: u16, +} + +impl RateLimiterBucket { + pub fn new(starts_at: Instant, capacity: usize) -> Self { + Self { + starts_at, + inner: HashMap::with_capacity(capacity), + } + } +} + +impl RateLimiter { + pub fn new(limit: u16, sampling_period: Duration, capacity: usize) -> Self { + let mut sanitized_limit = limit; + if limit == u16::MAX { + sanitized_limit = limit - 1; + } + + let now = Instant::now(); + let before = create_prev_window(now, sampling_period); + + let bucket_green = RateLimiterBucket::new(now, capacity); + let bucket_blue = RateLimiterBucket::new(before, capacity); + + Self { + limit: sanitized_limit, + sampling_period, + bucket_green, + bucket_blue, + } + } + + pub fn can_resume(&mut self, ip: IpAddr) -> Result { + if self.limit == 0 { + return Ok(false); + } + + let is_green_bucket_current = self.bucket_green.starts_at >= self.bucket_blue.starts_at; + let starts_at = match is_green_bucket_current { + true => self.bucket_green.starts_at, + false => self.bucket_blue.starts_at, + }; + let now = Instant::now(); + if !self.is_within_curr_bucket_window(now, starts_at) && self.sampling_period > Duration::from_nanos(0) { + let capacity = self.bucket_green.inner.capacity(); + if self.is_outside_next_monothonic_window(now, 
/// Weights the previous window's request count by the share of that window
/// which still overlaps the sliding window.
///
/// * `prev_counter` - number of requests counted in the previous window.
/// * `window_needle` - time elapsed since the current window started.
/// * `sampling_period` - length of one window.
///
/// Returns `prev_counter * (sampling_period - window_needle) / sampling_period`
/// rounded down, or `0` once `window_needle >= sampling_period` (which also
/// covers a zero-length `sampling_period`).
///
/// The ratio is computed in nanoseconds. Truncating both durations to whole
/// seconds would keep the previous window fully weighted for the entire first
/// second of a new window and would divide by zero for sub-second periods.
fn get_approx(prev_counter: u16, window_needle: Duration, sampling_period: Duration) -> u64 {
    if window_needle >= sampling_period {
        return 0;
    }

    // Past the guard: sampling_period > window_needle >= 0, so the divisor is
    // non-zero and the subtraction cannot underflow.
    let period_nanos = sampling_period.as_nanos();
    let remaining_nanos = period_nanos - window_needle.as_nanos();

    // u16::MAX * Duration::MAX (in nanoseconds) is below 2^111, so the product fits in u128.
    let weighted = u128::from(prev_counter) * remaining_nanos / period_nanos;

    // `weighted` never exceeds `prev_counter`, so the conversion is lossless.
    u64::try_from(weighted).unwrap_or(u64::MAX)
}
true)] + async fn test_rate_limiter_backpressure() { + let sampling_period = Duration::new(1, 0); + let ips = [ + Ipv4Addr::new(0, 0, 0, 0).into(), + Ipv4Addr::new(1, 1, 1, 1).into(), + Ipv4Addr::new(2, 2, 2, 2).into(), + Ipv4Addr::new(3, 3, 3, 3).into(), + Ipv4Addr::new(4, 4, 4, 4).into(), + Ipv4Addr::new(5, 5, 5, 5).into(), + Ipv4Addr::new(6, 6, 6, 6).into(), + Ipv4Addr::new(7, 7, 7, 7).into(), + Ipv4Addr::new(8, 8, 8, 8).into(), + ]; + // quote from https://doc.rust-lang.org/std/collections/struct.HashMap.html#method.with_capacity + // "This method is allowed to allocate for more elements than capacity. If capacity is zero, the hash map will not allocate." + // + // 8 is the first number for which it is not allocating for more elements. + let capacity = 8; + let mut r = RateLimiter::new(10, sampling_period, capacity); + for i in 0..ips.len() { + if i > capacity { + assert!(r.can_resume(ips[i]).is_err(), "error - should backpressure on this IP"); + } + assert!(r.can_resume(ips[i]).unwrap()); + } + } + + #[tokio::test(start_paused = true)] + async fn test_rate_limiter_boundary_single_ip() { + let ip = Ipv4Addr::new(1, 1, 1, 1).into(); + let mut r = RateLimiter::new(1, Duration::new(1, 0), 2); + + assert!(r.can_resume(ip).unwrap(), "should allow once when limit is 1"); + assert!(!r.can_resume(ip).unwrap(), "should block n 2nd attempt when limit is 1"); + + let mut r = RateLimiter::new(0, Duration::new(1, 0), 2); + assert!(!r.can_resume(ip).unwrap(), "should treat zero limit as always limited"); + assert!(!r.can_resume(ip).unwrap(), "should treat zero limit as always limited"); + + let mut r = RateLimiter::new(0, Duration::new(0, 0), 2); + assert!( + !r.can_resume(ip).unwrap(), + "should treat zero limit as always limited, even when zero window" + ); + assert!( + !r.can_resume(ip).unwrap(), + "should treat zero limit as always limited, even when zero window" + ); + + let mut r = RateLimiter::new(1, Duration::new(0, 0), 2); + assert!( + r.can_resume(ip).unwrap(), + 
"allow - limit should take precedense over zero window" + ); + assert!( + !r.can_resume(ip).unwrap(), + "block - limit should take precedense over zero window" + ); + + let mut r = RateLimiter::new(u16::MAX, Duration::new(1, 0), 2); + for _ in 1..u16::MAX { + assert!(r.can_resume(ip).unwrap(), "allow - should handle limit overflow"); + } + for _ in 0..u16::MAX { + assert!(!r.can_resume(ip).unwrap(), "block - should handle limit overflow"); + } + } + + #[tokio::test(start_paused = true)] + async fn test_get_approx() { + let sampling_period = Duration::from_secs(60); + let mut counter: u16 = 0; + sleep(sampling_period).await; + for _ in 0..42 { + counter = counter.saturating_add(1); + } + + let start = Instant::now(); + assert_eq!(42, counter); + assert_eq!(42, get_approx(counter, start.elapsed(), sampling_period)); + sleep(Duration::from_secs(15)).await; + assert_eq!(31, get_approx(counter, start.elapsed(), sampling_period)); + } + + #[tokio::test(start_paused = true)] + async fn test_create_prev_window() { + let sampling_window = Duration::from_secs(60); + + for tc in vec![0, 13, 59] { + let now = Instant::now(); + let offset = Duration::from_secs(tc); + sleep(offset).await; + assert_eq!(now, create_prev_window(now, sampling_window)); + } + + for tc in vec![60, 73, 119] { + let now = Instant::now(); + let offset = Duration::from_secs(tc); + sleep(offset).await; + assert_eq!(now - sampling_window, create_prev_window(now, sampling_window)); + } + } +} diff --git a/pingoo/rules.rs b/pingoo/rules.rs index 952c3ee..11299bb 100644 --- a/pingoo/rules.rs +++ b/pingoo/rules.rs @@ -2,15 +2,17 @@ use std::net::IpAddr; use http::Uri; use serde::Serialize; +use tokio::sync::mpsc::Sender; use tracing::warn; -use crate::{geoip::CountryCode, serde_utils}; +use crate::{geoip::CountryCode, rate_limiter::Probe, serde_utils}; #[derive(Debug, Clone)] pub struct Rule { pub name: String, pub expression: Option, pub actions: Vec, + pub limiter_tx: Option>, } #[derive(Debug, Serialize)] 
diff --git a/pingoo/services/http_utils.rs b/pingoo/services/http_utils.rs index 733b6ec..4efec3a 100644 --- a/pingoo/services/http_utils.rs +++ b/pingoo/services/http_utils.rs @@ -75,6 +75,18 @@ pub fn new_bad_gateway_error() -> Response> { .expect("error building new_bad_gateway_error"); } +pub fn new_service_unavailable_error_503() -> Response> { + const ERROR_MESSAGE: &[u8] = b"503 Service Unavailable"; + let res_body = Full::new(Bytes::from_static(ERROR_MESSAGE)) + .map_err(|never| match never {}) + .boxed(); + return Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .header(header::CACHE_CONTROL, &CACHE_CONTROL_NO_CACHE) + .body(res_body) + .expect("error building new_service_unavailable_error_503"); +} + pub fn new_not_found_error() -> Response> { const NOT_FOUND_ERROR_MESSAGE: &[u8] = b"404 Not Found."; let res_body = Full::new(Bytes::from_static(NOT_FOUND_ERROR_MESSAGE)) @@ -111,6 +123,18 @@ pub fn new_method_not_allowed_error() -> Response> .expect("error building new_method_not_allowed_error"); } +pub fn new_too_many_requests_response_429() -> Response> { + const ERROR_MESSAGE: &[u8] = b"429 Too Many Requests"; + let res_body = Full::new(Bytes::from_static(ERROR_MESSAGE)) + .map_err(|never| match never {}) + .boxed(); + return Response::builder() + .status(StatusCode::TOO_MANY_REQUESTS) + .header(header::CACHE_CONTROL, &CACHE_CONTROL_NO_CACHE) + .body(res_body) + .expect("error building new_too_many_requests_response_429"); +} + pub fn get_path(req: &Request) -> &str { req.uri().path().trim_end_matches('/') } diff --git a/rules/rules.rs b/rules/rules.rs index 3577b1b..7604d6d 100644 --- a/rules/rules.rs +++ b/rules/rules.rs @@ -22,6 +22,14 @@ pub struct Rule { pub type CompiledExpression = bel::Program; pub type Context<'a> = bel::Context<'a>; +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "limit", rename_all = "snake_case")] +pub struct RateLimit { + pub max: u16, + pub window: u16, + pub capacity: usize, +} + // pub 
struct CompiledRule { // pub id: Uuid, // pub updated_at: DateTime, @@ -32,6 +40,7 @@ pub type Context<'a> = bel::Context<'a>; pub enum Action { Block {}, Captcha {}, + Limit {}, } #[derive(Debug, thiserror::Error)]