diff --git a/Cargo.lock b/Cargo.lock index 182c0a172..96b17a02d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4017,9 +4017,11 @@ dependencies = [ name = "serverbee-common" version = "1.0.0-alpha.6" dependencies = [ + "anyhow", "chrono", "serde", "serde_json", + "url", "utoipa", "uuid", ] diff --git a/apps/docs/content/docs/en/service-monitors.mdx b/apps/docs/content/docs/en/service-monitors.mdx index 7b2c58ca7..ecef073e1 100644 --- a/apps/docs/content/docs/en/service-monitors.mdx +++ b/apps/docs/content/docs/en/service-monitors.mdx @@ -30,6 +30,10 @@ Unlike Ping Monitoring, which asks agents to probe network targets, Service Moni The background checker wakes every 10 seconds, schedules due monitors based on their `interval`, and runs up to 20 checks concurrently. + +As an SSRF safeguard, targets that resolve to loopback (`127.0.0.0/8`, `::1`, `localhost`), link-local, the cloud metadata endpoint (`169.254.169.254`), or other non-routable addresses are blocked. SSL, TCP, and HTTP Keyword monitors reject such targets when you save the monitor, and the guard is re-applied on every HTTP redirect hop. Private LAN ranges (`10/8`, `172.16/12`, `192.168/16`, and IPv6 ULA) remain allowed so you can monitor internal hosts. + + ## Type-Specific Configuration ### SSL Certificate diff --git a/apps/docs/content/docs/zh/service-monitors.mdx b/apps/docs/content/docs/zh/service-monitors.mdx index 81378c3fd..5925339a8 100644 --- a/apps/docs/content/docs/zh/service-monitors.mdx +++ b/apps/docs/content/docs/zh/service-monitors.mdx @@ -30,6 +30,10 @@ icon: Radar 后台检查器每 10 秒唤醒一次,根据每个监控的 `interval` 判断是否到期,并最多并发执行 20 个检查。 + +作为 SSRF 防护,解析到回环地址(`127.0.0.0/8`、`::1`、`localhost`)、链路本地地址、云元数据端点(`169.254.169.254`)或其他不可路由地址的目标会被拦截。SSL、TCP 和 HTTP Keyword 监控在保存时即拒绝此类目标,并在每一次 HTTP 重定向跳转时重新校验。私有局域网网段(`10/8`、`172.16/12`、`192.168/16` 及 IPv6 ULA)仍然允许,以便监控内网主机。 + + ## 类型配置 ### SSL 证书 diff --git a/crates/agent/src/ip_quality/ssrf.rs b/crates/agent/src/ip_quality/ssrf.rs index 79e0eb0f3..ec232b8b8 100644 --- a/crates/agent/src/ip_quality/ssrf.rs +++ b/crates/agent/src/ip_quality/ssrf.rs @@ -1,361 +1,7 @@ -use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; - -use anyhow::{bail, Result}; -use url::Url; - -/// Allowed URL schemes for unlock check requests. -const ALLOWED_SCHEMES: &[&str] = &["http", "https"]; - -/// Allowed explicit ports (absent means scheme default — 80/443). -const ALLOWED_PORTS: &[u16] = &[80, 443]; - -/// Validate that a URL is safe to fetch: -/// - scheme must be `http` or `https` -/// - port must be 80, 443, or absent (scheme default) -/// - no embedded credentials (`user:pass@host`) -/// -/// Returns the parsed `Url` on success. -pub fn validate_url(raw: &str) -> Result { - let url = Url::parse(raw)?; - - if !ALLOWED_SCHEMES.contains(&url.scheme()) { - bail!("SSRF guard: scheme '{}' is not allowed (only http/https)", url.scheme()); - } - - if url.port().is_some_and(|port| !ALLOWED_PORTS.contains(&port)) { - bail!( - "SSRF guard: port {} is not allowed (only 80/443 or scheme default)", - url.port().unwrap() - ); - } - - // Reject embedded credentials: they are not a guard bypass (the host is - // still resolved and checked) but they would leak into the request and - // logs. - if !url.username().is_empty() || url.password().is_some() { - bail!("SSRF guard: URL must not contain embedded credentials"); - } - - Ok(url) -} - -/// Returns `true` if `addr` is globally routable (safe to connect to). -/// -/// Rejects: -/// IPv4: this-network (0.0.0.0/8), loopback (127.0.0.0/8), -/// private (10/8, 172.16/12, 192.168/16), -/// link-local (169.254.0.0/16), broadcast (255.255.255.255), -/// documentation (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24), -/// shared address space (100.64.0.0/10) -/// IPv6: IPv4-mapped/-compatible (`::ffff:a.b.c.d` / `::a.b.c.d` — unwrapped -/// and re-checked through the IPv4 rules), loopback (::1), -/// unspecified (::), link-local (fe80::/10), ULA (fc00::/7), -/// documentation (2001:db8::/32), NAT64 well-known prefix -/// (64:ff9b::/96, RFC 6052) -pub fn is_global_addr(addr: IpAddr) -> bool { - match addr { - IpAddr::V4(v4) => { - if v4.is_loopback() { - return false; - } - if v4.is_private() { - return false; - } - if v4.is_link_local() { - return false; - } - if v4.is_broadcast() { - return false; - } - let octets = v4.octets(); - // "This network" (RFC 791): 0.0.0.0/8 — covers 0.0.0.0 as well. - if octets[0] == 0 { - return false; - } - // Documentation ranges: 192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24 - if octets[0] == 192 && octets[1] == 0 && octets[2] == 2 { - return false; - } - if octets[0] == 198 && octets[1] == 51 && octets[2] == 100 { - return false; - } - if octets[0] == 203 && octets[1] == 0 && octets[2] == 113 { - return false; - } - // Shared address space (RFC 6598): 100.64.0.0/10 - if octets[0] == 100 && (octets[1] & 0b1100_0000) == 0b0100_0000 { - return false; - } - true - } - IpAddr::V6(v6) => { - // Unwrap IPv4-mapped (`::ffff:a.b.c.d`) and IPv4-compatible - // (`::a.b.c.d`) addresses and re-check through the IPv4 rules. - // Without this, `[::ffff:127.0.0.1]` would slip past the v6 - // checks (its `.is_loopback()` is false) and defeat the guard. - if let Some(v4) = v6.to_ipv4() { - return is_global_addr(IpAddr::V4(v4)); - } - if v6.is_loopback() { - return false; - } - if v6.is_unspecified() { - return false; - } - let segs = v6.segments(); - // Link-local: fe80::/10 — first 10 bits are 1111111010 - if (segs[0] & 0xffc0) == 0xfe80 { - return false; - } - // ULA: fc00::/7 — first 7 bits are 1111110 - if (segs[0] & 0xfe00) == 0xfc00 { - return false; - } - // Documentation: 2001:db8::/32 — first two segments are 2001:0db8 - if segs[0] == 0x2001 && segs[1] == 0x0db8 { - return false; - } - // NAT64 well-known prefix: 64:ff9b::/96 (RFC 6052). The low 32 - // bits embed an IPv4 address (e.g. 64:ff9b::7f00:1 = 127.0.0.1), - // so the whole /96 is rejected. - if segs[0] == 0x0064 - && segs[1] == 0xff9b - && segs[2] == 0 - && segs[3] == 0 - && segs[4] == 0 - && segs[5] == 0 - { - return false; - } - true - } - } -} - -/// Resolve `host` to its socket addresses (on `port`) and reject if **any** -/// resolved address is non-global. This is the DNS-rebinding defense: if any -/// address is private the host is unsafe. -/// -/// Returns the list of resolved `SocketAddr` on success. -pub fn resolve_and_check(host: &str, port: u16) -> Result> { - let addrs: Vec = (host, port).to_socket_addrs()?.collect(); - - if addrs.is_empty() { - bail!("SSRF guard: could not resolve host '{}'", host); - } - - for addr in &addrs { - if !is_global_addr(addr.ip()) { - bail!( - "SSRF guard: host '{}' resolved to non-global address {} — request blocked", - host, - addr.ip() - ); - } - } - - Ok(addrs) -} - -#[cfg(test)] -mod tests { - use super::*; - - // ── validate_url ────────────────────────────────────────────────────────── - - #[test] - fn validate_url_accepts_http_default_port() { - assert!(validate_url("http://example.com/check").is_ok()); - } - - #[test] - fn validate_url_accepts_https_default_port() { - assert!(validate_url("https://example.com/check").is_ok()); - } - - #[test] - fn validate_url_accepts_explicit_port_80() { - assert!(validate_url("http://example.com:80/check").is_ok()); - } - - #[test] - fn validate_url_accepts_explicit_port_443() { - assert!(validate_url("https://example.com:443/check").is_ok()); - } - - #[test] - fn validate_url_rejects_non_http_scheme_ftp() { - let err = validate_url("ftp://example.com/file").unwrap_err(); - assert!(err.to_string().contains("scheme"), "expected scheme error, got: {err}"); - } - - #[test] - fn validate_url_rejects_non_http_scheme_file() { - assert!(validate_url("file:///etc/passwd").is_err()); - } - - #[test] - fn validate_url_rejects_non_http_scheme_gopher() { - assert!(validate_url("gopher://example.com/").is_err()); - } - - #[test] - fn validate_url_rejects_port_8080() { - let err = validate_url("http://example.com:8080/").unwrap_err(); - assert!(err.to_string().contains("port"), "expected port error, got: {err}"); - } - - #[test] - fn validate_url_rejects_port_3000() { - assert!(validate_url("http://example.com:3000/").is_err()); - } - - #[test] - fn validate_url_rejects_embedded_username() { - let err = validate_url("http://user@example.com/").unwrap_err(); - assert!(err.to_string().contains("credentials"), "expected credentials error, got: {err}"); - } - - #[test] - fn validate_url_rejects_embedded_user_and_password() { - let err = validate_url("http://user:pass@example.com/").unwrap_err(); - assert!(err.to_string().contains("credentials"), "expected credentials error, got: {err}"); - } - - // ── is_global_addr ──────────────────────────────────────────────────────── - - #[test] - fn is_global_addr_rejects_ipv4_loopback() { - assert!(!is_global_addr("127.0.0.1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv4_private_10() { - assert!(!is_global_addr("10.0.0.1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv4_private_192_168() { - assert!(!is_global_addr("192.168.1.1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv4_link_local() { - assert!(!is_global_addr("169.254.169.254".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv4_this_network() { - // 0.0.0.0/8 — the whole "this network" range, not just 0.0.0.0. - assert!(!is_global_addr("0.0.0.0".parse().unwrap())); - assert!(!is_global_addr("0.1.2.3".parse().unwrap())); - assert!(!is_global_addr("0.255.255.255".parse().unwrap())); - } - - #[test] - fn is_global_addr_accepts_ipv4_public() { - assert!(is_global_addr("8.8.8.8".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv6_loopback() { - assert!(!is_global_addr("::1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv6_ula() { - assert!(!is_global_addr("fc00::1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv6_link_local() { - assert!(!is_global_addr("fe80::1".parse().unwrap())); - } - - #[test] - fn is_global_addr_accepts_ipv6_public() { - assert!(is_global_addr("2606:4700:4700::1111".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv4_mapped_loopback() { - // ::ffff:127.0.0.1 — IPv4-mapped, must be unwrapped and rejected. - assert!(!is_global_addr("::ffff:127.0.0.1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv4_mapped_private() { - assert!(!is_global_addr("::ffff:10.0.0.1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv4_mapped_metadata_ip() { - // ::ffff:169.254.169.254 — cloud metadata via IPv4-mapped form. - assert!(!is_global_addr("::ffff:169.254.169.254".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv4_compatible_loopback() { - // ::127.0.0.1 — IPv4-compatible form, must also be unwrapped. - assert!(!is_global_addr("::127.0.0.1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv6_documentation() { - // 2001:db8::/32 — IPv6 documentation range. - assert!(!is_global_addr("2001:db8::1".parse().unwrap())); - } - - #[test] - fn is_global_addr_rejects_ipv6_nat64_well_known_prefix() { - // 64:ff9b::/96 embeds an IPv4 address; 64:ff9b::7f00:1 = 127.0.0.1. - assert!(!is_global_addr("64:ff9b::7f00:1".parse().unwrap())); - } - - // ── resolve_and_check ───────────────────────────────────────────────────── - - #[test] - fn resolve_and_check_rejects_localhost() { - let err = resolve_and_check("localhost", 80).unwrap_err(); - assert!( - err.to_string().contains("SSRF guard"), - "expected SSRF guard error, got: {err}" - ); - } - - #[test] - fn resolve_and_check_rejects_ipv4_loopback_literal() { - let err = resolve_and_check("127.0.0.1", 80).unwrap_err(); - assert!(err.to_string().contains("SSRF guard"), "got: {err}"); - } - - #[test] - fn resolve_and_check_rejects_ipv4_private() { - let err = resolve_and_check("10.0.0.1", 80).unwrap_err(); - assert!(err.to_string().contains("SSRF guard"), "got: {err}"); - } - - #[test] - fn resolve_and_check_rejects_link_local_metadata_ip() { - let err = resolve_and_check("169.254.169.254", 80).unwrap_err(); - assert!(err.to_string().contains("SSRF guard"), "got: {err}"); - } - - #[test] - fn resolve_and_check_rejects_ipv6_loopback() { - let err = resolve_and_check("::1", 80).unwrap_err(); - assert!(err.to_string().contains("SSRF guard"), "got: {err}"); - } - - #[test] - fn resolve_and_check_rejects_ipv6_ula() { - let err = resolve_and_check("fc00::1", 80).unwrap_err(); - assert!(err.to_string().contains("SSRF guard"), "got: {err}"); - } - - #[test] - fn resolve_and_check_accepts_public_ipv4() { - let result = resolve_and_check("8.8.8.8", 80); - assert!(result.is_ok(), "expected success, got: {:?}", result.err()); - } -} +//! SSRF guard. +//! +//! The implementation moved to `serverbee_common::ssrf` so the server-side +//! service-monitor checkers can reuse the exact same validation. It is +//! re-exported here so existing `super::ssrf::*` references keep working and the +//! agent retains a single, shared source of truth. +pub use serverbee_common::ssrf::*; diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index a0eb6bf1b..a9c98b979 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -8,4 +8,6 @@ serde.workspace = true serde_json.workspace = true chrono.workspace = true uuid.workspace = true +anyhow.workspace = true +url = "2" utoipa = { version = "5", optional = true } diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 780c2260b..60a3ca4cf 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -3,4 +3,5 @@ pub mod docker_types; pub mod firewall; pub mod protocol; pub mod security; +pub mod ssrf; pub mod types; diff --git a/crates/common/src/ssrf.rs b/crates/common/src/ssrf.rs new file mode 100644 index 000000000..c480d1bf5 --- /dev/null +++ b/crates/common/src/ssrf.rs @@ -0,0 +1,557 @@ +use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; + +use anyhow::{Result, bail}; +use url::Url; + +/// Allowed URL schemes for outbound check requests. +const ALLOWED_SCHEMES: &[&str] = &["http", "https"]; + +/// Allowed explicit ports (absent means scheme default — 80/443). +const ALLOWED_PORTS: &[u16] = &[80, 443]; + +/// Shared URL validation: scheme must be `http`/`https`, no embedded +/// credentials, and—when `restrict_ports` is set—the port must be 80/443 or the +/// scheme default. Returns the parsed `Url` on success. +fn validate_url_inner(raw: &str, restrict_ports: bool) -> Result { + let url = Url::parse(raw)?; + + if !ALLOWED_SCHEMES.contains(&url.scheme()) { + bail!( + "SSRF guard: scheme '{}' is not allowed (only http/https)", + url.scheme() + ); + } + + if restrict_ports && url.port().is_some_and(|port| !ALLOWED_PORTS.contains(&port)) { + bail!( + "SSRF guard: port {} is not allowed (only 80/443 or scheme default)", + url.port().unwrap() + ); + } + + // Reject embedded credentials: they are not a guard bypass (the host is + // still resolved and checked) but they would leak into the request and + // logs. + if !url.username().is_empty() || url.password().is_some() { + bail!("SSRF guard: URL must not contain embedded credentials"); + } + + Ok(url) +} + +/// Validate that a URL is safe to fetch on the strict (global-only) path: +/// - scheme must be `http` or `https` +/// - port must be 80, 443, or absent (scheme default) +/// - no embedded credentials (`user:pass@host`) +/// +/// Returns the parsed `Url` on success. +pub fn validate_url(raw: &str) -> Result { + validate_url_inner(raw, true) +} + +/// Like [`validate_url`] but allows any port, for the service-monitor checkers +/// where operators legitimately monitor HTTP services on non-standard ports +/// (e.g. `:8080`, `:3000`). The scheme and embedded-credentials checks still +/// apply, and the address-level guard ([`is_monitor_safe_addr`]) still blocks +/// loopback/link-local/metadata regardless of port. +pub fn validate_monitor_url(raw: &str) -> Result { + validate_url_inner(raw, false) +} + +/// Returns `true` if `addr` is globally routable (safe to connect to). +/// +/// Rejects: +/// IPv4: this-network (0.0.0.0/8), loopback (127.0.0.0/8), +/// private (10/8, 172.16/12, 192.168/16), +/// link-local (169.254.0.0/16), broadcast (255.255.255.255), +/// documentation (192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24), +/// shared address space (100.64.0.0/10) +/// IPv6: IPv4-mapped/-compatible (`::ffff:a.b.c.d` / `::a.b.c.d` — unwrapped +/// and re-checked through the IPv4 rules), loopback (::1), +/// unspecified (::), link-local (fe80::/10), ULA (fc00::/7), +/// documentation (2001:db8::/32), NAT64 well-known prefix +/// (64:ff9b::/96, RFC 6052) +pub fn is_global_addr(addr: IpAddr) -> bool { + match addr { + IpAddr::V4(v4) => { + if v4.is_loopback() { + return false; + } + if v4.is_private() { + return false; + } + if v4.is_link_local() { + return false; + } + if v4.is_broadcast() { + return false; + } + let octets = v4.octets(); + // "This network" (RFC 791): 0.0.0.0/8 — covers 0.0.0.0 as well. + if octets[0] == 0 { + return false; + } + // Documentation ranges: 192.0.2.0/24, 198.51.100.0/24, 203.0.113.0/24 + if octets[0] == 192 && octets[1] == 0 && octets[2] == 2 { + return false; + } + if octets[0] == 198 && octets[1] == 51 && octets[2] == 100 { + return false; + } + if octets[0] == 203 && octets[1] == 0 && octets[2] == 113 { + return false; + } + // Shared address space (RFC 6598): 100.64.0.0/10 + if octets[0] == 100 && (octets[1] & 0b1100_0000) == 0b0100_0000 { + return false; + } + true + } + IpAddr::V6(v6) => { + // Unwrap IPv4-mapped (`::ffff:a.b.c.d`) and IPv4-compatible + // (`::a.b.c.d`) addresses and re-check through the IPv4 rules. + // Without this, `[::ffff:127.0.0.1]` would slip past the v6 + // checks (its `.is_loopback()` is false) and defeat the guard. + if let Some(v4) = v6.to_ipv4() { + return is_global_addr(IpAddr::V4(v4)); + } + if v6.is_loopback() { + return false; + } + if v6.is_unspecified() { + return false; + } + let segs = v6.segments(); + // Link-local: fe80::/10 — first 10 bits are 1111111010 + if (segs[0] & 0xffc0) == 0xfe80 { + return false; + } + // ULA: fc00::/7 — first 7 bits are 1111110 + if (segs[0] & 0xfe00) == 0xfc00 { + return false; + } + // Documentation: 2001:db8::/32 — first two segments are 2001:0db8 + if segs[0] == 0x2001 && segs[1] == 0x0db8 { + return false; + } + // NAT64 well-known prefix: 64:ff9b::/96 (RFC 6052). The low 32 + // bits embed an IPv4 address (e.g. 64:ff9b::7f00:1 = 127.0.0.1), + // so the whole /96 is rejected. + if segs[0] == 0x0064 + && segs[1] == 0xff9b + && segs[2] == 0 + && segs[3] == 0 + && segs[4] == 0 + && segs[5] == 0 + { + return false; + } + true + } + } +} + +/// Resolve `host` to its socket addresses (on `port`) and reject if **any** +/// resolved address is non-global. This is the DNS-rebinding defense: if any +/// address is private the host is unsafe. +/// +/// Returns the list of resolved `SocketAddr` on success. Callers should connect +/// to the returned addresses directly (rather than re-resolving the host) to +/// keep the validated result and the connected address identical. +pub fn resolve_and_check(host: &str, port: u16) -> Result> { + let addrs: Vec = (host, port).to_socket_addrs()?.collect(); + + if addrs.is_empty() { + bail!("SSRF guard: could not resolve host '{}'", host); + } + + for addr in &addrs { + if !is_global_addr(addr.ip()) { + bail!( + "SSRF guard: host '{}' resolved to non-global address {} — request blocked", + host, + addr.ip() + ); + } + } + + Ok(addrs) +} + +/// Returns `true` if `addr` is safe for the **service-monitor checkers** to +/// connect to. +/// +/// Unlike [`is_global_addr`], this intentionally ALLOWS RFC1918 private ranges +/// (10/8, 172.16/12, 192.168/16) and IPv6 ULA (fc00::/7) so operators can +/// legitimately monitor internal/LAN hosts. It still blocks the addresses that +/// have no business being a monitoring target and are the real SSRF prizes: +/// loopback, link-local (incl. the cloud-metadata 169.254.169.254), unspecified, +/// broadcast, the NAT64 well-known prefix, and the IPv4-mapped/-compatible IPv6 +/// forms of any of those. +pub fn is_monitor_safe_addr(addr: IpAddr) -> bool { + match addr { + IpAddr::V4(v4) => { + if v4.is_loopback() { + return false; + } + // 169.254.0.0/16 — link-local, incl. the cloud metadata endpoint. + if v4.is_link_local() { + return false; + } + if v4.is_broadcast() { + return false; + } + // "This network" (RFC 791): 0.0.0.0/8. + if v4.octets()[0] == 0 { + return false; + } + // RFC1918 private ranges are intentionally ALLOWED for internal monitoring. + true + } + IpAddr::V6(v6) => { + // Unwrap IPv4-mapped/-compatible forms and re-check through the v4 + // rules so `[::ffff:169.254.169.254]` cannot slip past. + if let Some(v4) = v6.to_ipv4() { + return is_monitor_safe_addr(IpAddr::V4(v4)); + } + if v6.is_loopback() { + return false; + } + if v6.is_unspecified() { + return false; + } + let segs = v6.segments(); + // Link-local: fe80::/10. + if (segs[0] & 0xffc0) == 0xfe80 { + return false; + } + // NAT64 well-known prefix (64:ff9b::/96) embeds an IPv4 address and + // can reach metadata/loopback over IPv6 — block it wholesale. + if segs[0] == 0x0064 + && segs[1] == 0xff9b + && segs[2] == 0 + && segs[3] == 0 + && segs[4] == 0 + && segs[5] == 0 + { + return false; + } + // ULA (fc00::/7) is the IPv6 analogue of private space — ALLOWED. + true + } + } +} + +/// Like [`resolve_and_check`] but uses [`is_monitor_safe_addr`] (allows private +/// ranges; blocks loopback/link-local/metadata/NAT64). For the service-monitor +/// checkers, which legitimately need to reach internal hosts. +/// +/// Returns the validated `SocketAddr`s; callers should connect to them directly +/// (without re-resolving the host) to close the DNS-rebinding window. +pub fn resolve_and_check_monitor(host: &str, port: u16) -> Result> { + let addrs: Vec = (host, port).to_socket_addrs()?.collect(); + + if addrs.is_empty() { + bail!("SSRF guard: could not resolve host '{}'", host); + } + + for addr in &addrs { + if !is_monitor_safe_addr(addr.ip()) { + bail!( + "SSRF guard: host '{}' resolved to blocked address {} (loopback/link-local/metadata) — request blocked", + host, + addr.ip() + ); + } + } + + Ok(addrs) +} + +#[cfg(test)] +mod tests { + use super::*; + + // ── validate_url ────────────────────────────────────────────────────────── + + #[test] + fn validate_url_accepts_http_default_port() { + assert!(validate_url("http://example.com/check").is_ok()); + } + + #[test] + fn validate_url_accepts_https_default_port() { + assert!(validate_url("https://example.com/check").is_ok()); + } + + #[test] + fn validate_url_accepts_explicit_port_80() { + assert!(validate_url("http://example.com:80/check").is_ok()); + } + + #[test] + fn validate_url_accepts_explicit_port_443() { + assert!(validate_url("https://example.com:443/check").is_ok()); + } + + #[test] + fn validate_url_rejects_non_http_scheme_ftp() { + let err = validate_url("ftp://example.com/file").unwrap_err(); + assert!( + err.to_string().contains("scheme"), + "expected scheme error, got: {err}" + ); + } + + #[test] + fn validate_url_rejects_non_http_scheme_file() { + assert!(validate_url("file:///etc/passwd").is_err()); + } + + #[test] + fn validate_url_rejects_non_http_scheme_gopher() { + assert!(validate_url("gopher://example.com/").is_err()); + } + + #[test] + fn validate_url_rejects_port_8080() { + let err = validate_url("http://example.com:8080/").unwrap_err(); + assert!( + err.to_string().contains("port"), + "expected port error, got: {err}" + ); + } + + #[test] + fn validate_url_rejects_port_3000() { + assert!(validate_url("http://example.com:3000/").is_err()); + } + + #[test] + fn validate_url_rejects_embedded_username() { + let err = validate_url("http://user@example.com/").unwrap_err(); + assert!( + err.to_string().contains("credentials"), + "expected credentials error, got: {err}" + ); + } + + #[test] + fn validate_url_rejects_embedded_user_and_password() { + let err = validate_url("http://user:pass@example.com/").unwrap_err(); + assert!( + err.to_string().contains("credentials"), + "expected credentials error, got: {err}" + ); + } + + // ── validate_monitor_url (any port; scheme/credentials still enforced) ──── + + #[test] + fn validate_monitor_url_allows_custom_port() { + // The whole point of the relaxed validator: non-standard ports work. + assert!(validate_monitor_url("http://example.com:8080/health").is_ok()); + assert!(validate_monitor_url("https://example.com:3000/").is_ok()); + assert!(validate_monitor_url("http://example.com:9000/").is_ok()); + } + + #[test] + fn validate_monitor_url_still_allows_standard_ports() { + assert!(validate_monitor_url("http://example.com/").is_ok()); + assert!(validate_monitor_url("https://example.com:443/").is_ok()); + } + + #[test] + fn validate_monitor_url_still_rejects_non_http_scheme() { + assert!(validate_monitor_url("file:///etc/passwd").is_err()); + assert!(validate_monitor_url("gopher://example.com:8080/").is_err()); + } + + #[test] + fn validate_monitor_url_still_rejects_embedded_credentials() { + let err = validate_monitor_url("http://user:pass@example.com:8080/").unwrap_err(); + assert!( + err.to_string().contains("credentials"), + "expected credentials error, got: {err}" + ); + } + + // ── is_global_addr ──────────────────────────────────────────────────────── + + #[test] + fn is_global_addr_rejects_ipv4_loopback() { + assert!(!is_global_addr("127.0.0.1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv4_private_10() { + assert!(!is_global_addr("10.0.0.1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv4_private_192_168() { + assert!(!is_global_addr("192.168.1.1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv4_link_local() { + assert!(!is_global_addr("169.254.169.254".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv4_this_network() { + // 0.0.0.0/8 — the whole "this network" range, not just 0.0.0.0. + assert!(!is_global_addr("0.0.0.0".parse().unwrap())); + assert!(!is_global_addr("0.1.2.3".parse().unwrap())); + assert!(!is_global_addr("0.255.255.255".parse().unwrap())); + } + + #[test] + fn is_global_addr_accepts_ipv4_public() { + assert!(is_global_addr("8.8.8.8".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv6_loopback() { + assert!(!is_global_addr("::1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv6_ula() { + assert!(!is_global_addr("fc00::1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv6_link_local() { + assert!(!is_global_addr("fe80::1".parse().unwrap())); + } + + #[test] + fn is_global_addr_accepts_ipv6_public() { + assert!(is_global_addr("2606:4700:4700::1111".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv4_mapped_loopback() { + // ::ffff:127.0.0.1 — IPv4-mapped, must be unwrapped and rejected. + assert!(!is_global_addr("::ffff:127.0.0.1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv4_mapped_private() { + assert!(!is_global_addr("::ffff:10.0.0.1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv4_mapped_metadata_ip() { + // ::ffff:169.254.169.254 — cloud metadata via IPv4-mapped form. + assert!(!is_global_addr("::ffff:169.254.169.254".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv4_compatible_loopback() { + // ::127.0.0.1 — IPv4-compatible form, must also be unwrapped. + assert!(!is_global_addr("::127.0.0.1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv6_documentation() { + // 2001:db8::/32 — IPv6 documentation range. + assert!(!is_global_addr("2001:db8::1".parse().unwrap())); + } + + #[test] + fn is_global_addr_rejects_ipv6_nat64_well_known_prefix() { + // 64:ff9b::/96 embeds an IPv4 address; 64:ff9b::7f00:1 = 127.0.0.1. + assert!(!is_global_addr("64:ff9b::7f00:1".parse().unwrap())); + } + + // ── resolve_and_check ───────────────────────────────────────────────────── + + #[test] + fn resolve_and_check_rejects_localhost() { + let err = resolve_and_check("localhost", 80).unwrap_err(); + assert!( + err.to_string().contains("SSRF guard"), + "expected SSRF guard error, got: {err}" + ); + } + + #[test] + fn resolve_and_check_rejects_ipv4_loopback_literal() { + let err = resolve_and_check("127.0.0.1", 80).unwrap_err(); + assert!(err.to_string().contains("SSRF guard"), "got: {err}"); + } + + #[test] + fn resolve_and_check_rejects_ipv4_private() { + let err = resolve_and_check("10.0.0.1", 80).unwrap_err(); + assert!(err.to_string().contains("SSRF guard"), "got: {err}"); + } + + #[test] + fn resolve_and_check_rejects_link_local_metadata_ip() { + let err = resolve_and_check("169.254.169.254", 80).unwrap_err(); + assert!(err.to_string().contains("SSRF guard"), "got: {err}"); + } + + #[test] + fn resolve_and_check_rejects_ipv6_loopback() { + let err = resolve_and_check("::1", 80).unwrap_err(); + assert!(err.to_string().contains("SSRF guard"), "got: {err}"); + } + + #[test] + fn resolve_and_check_rejects_ipv6_ula() { + let err = resolve_and_check("fc00::1", 80).unwrap_err(); + assert!(err.to_string().contains("SSRF guard"), "got: {err}"); + } + + #[test] + fn resolve_and_check_accepts_public_ipv4() { + let result = resolve_and_check("8.8.8.8", 80); + assert!(result.is_ok(), "expected success, got: {:?}", result.err()); + } + + // ── is_monitor_safe_addr (private allowed, metadata/loopback blocked) ────── + + #[test] + fn monitor_safe_blocks_loopback() { + assert!(!is_monitor_safe_addr("127.0.0.1".parse().unwrap())); + assert!(!is_monitor_safe_addr("::1".parse().unwrap())); + } + + #[test] + fn monitor_safe_blocks_cloud_metadata() { + assert!(!is_monitor_safe_addr("169.254.169.254".parse().unwrap())); + assert!(!is_monitor_safe_addr("::ffff:169.254.169.254".parse().unwrap())); + } + + #[test] + fn monitor_safe_blocks_ipv6_link_local_and_nat64() { + assert!(!is_monitor_safe_addr("fe80::1".parse().unwrap())); + // 64:ff9b::a9fe:a9fe = NAT64-wrapped 169.254.169.254. + assert!(!is_monitor_safe_addr("64:ff9b::a9fe:a9fe".parse().unwrap())); + } + + #[test] + fn monitor_safe_allows_rfc1918_private() { + // Internal/LAN monitoring is a legitimate use case — these are allowed. + assert!(is_monitor_safe_addr("10.0.0.5".parse().unwrap())); + assert!(is_monitor_safe_addr("172.16.0.1".parse().unwrap())); + assert!(is_monitor_safe_addr("192.168.1.10".parse().unwrap())); + assert!(is_monitor_safe_addr("fc00::1".parse().unwrap())); + } + + #[test] + fn monitor_safe_allows_public() { + assert!(is_monitor_safe_addr("8.8.8.8".parse().unwrap())); + } + + #[test] + fn resolve_and_check_monitor_blocks_metadata_allows_private() { + assert!(resolve_and_check_monitor("169.254.169.254", 80).is_err()); + assert!(resolve_and_check_monitor("127.0.0.1", 80).is_err()); + assert!(resolve_and_check_monitor("10.0.0.5", 80).is_ok()); + } +} diff --git a/crates/server/src/router/api/oauth.rs b/crates/server/src/router/api/oauth.rs index 971d0caa2..17ae3d72a 100644 --- a/crates/server/src/router/api/oauth.rs +++ b/crates/server/src/router/api/oauth.rs @@ -11,7 +11,7 @@ use axum::routing::get; use crate::router::utils::extract_client_ip; use chrono::Utc; use oauth2::reqwest::async_http_client; -use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; +use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse}; use sea_orm::EntityTrait; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -19,7 +19,8 @@ use uuid::Uuid; use crate::error::AppError; use crate::service::auth::AuthService; use crate::service::oauth::OAuthService; -use crate::state::AppState; +use crate::state::{AppState, OAuthFlowState}; +use dashmap::DashMap; #[derive(Debug, Deserialize)] pub struct CallbackQuery { @@ -65,6 +66,62 @@ pub async fn list_providers( }) } +/// Name of the short-lived HttpOnly pre-auth cookie that binds an OAuth login +/// flow to the browser that started it. +const OAUTH_NONCE_COOKIE: &str = "oauth_nonce"; + +/// Extract the `oauth_nonce` pre-auth cookie value from the request headers. +fn extract_oauth_nonce(headers: &HeaderMap) -> Option { + headers + .get("cookie")? + .to_str() + .ok()? + .split(';') + .find_map(|c| { + let c = c.trim(); + c.strip_prefix("oauth_nonce=").map(|v| v.to_string()) + }) +} + +/// Constant-time byte comparison (no early return) for the nonce check. +fn ct_eq(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + let mut diff = 0u8; + for (x, y) in a.iter().zip(b.iter()) { + diff |= x ^ y; + } + diff == 0 +} + +/// Validate and atomically consume the stored OAuth flow state. +/// +/// `remove` makes the state single-use (replay-safe). Returns the flow state on +/// success, or a `BadRequest` for the first failed check. +fn validate_and_consume_state( + states: &DashMap, + state_param: &str, + provider: &str, + nonce_cookie: Option<&str>, +) -> Result { + let (_, flow) = states + .remove(state_param) + .ok_or_else(|| AppError::BadRequest("Invalid or expired OAuth state".to_string()))?; + + if flow.provider != provider { + return Err(AppError::BadRequest("OAuth state mismatch".to_string())); + } + if Utc::now() - flow.created_at > chrono::Duration::minutes(10) { + return Err(AppError::BadRequest("OAuth state expired".to_string())); + } + match nonce_cookie { + Some(cookie) if ct_eq(cookie.as_bytes(), flow.nonce.as_bytes()) => {} + _ => return Err(AppError::BadRequest("OAuth session mismatch".to_string())), + } + Ok(flow) +} + /// Redirect user to the OAuth provider's authorization page. #[utoipa::path( get, @@ -79,7 +136,7 @@ pub async fn list_providers( pub async fn oauth_authorize( State(state): State>, Path(provider): Path, -) -> Result { +) -> Result<(HeaderMap, Redirect), AppError> { if !OAuthService::is_configured(&provider, &state.config.oauth) { return Err(AppError::BadRequest(format!( "OAuth provider '{provider}' is not configured" @@ -88,7 +145,10 @@ pub async fn oauth_authorize( let client = OAuthService::build_client(&provider, &state.config.oauth)?; - let mut auth_request = client.authorize_url(CsrfToken::new_random); + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let mut auth_request = client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); // Add scopes based on provider let scopes = match provider.as_str() { @@ -96,25 +156,47 @@ pub async fn oauth_authorize( "google" => vec!["openid", "email", "profile"], _ => vec![], }; - for scope in scopes { auth_request = auth_request.add_scope(Scope::new(scope.to_string())); } let (auth_url, csrf_token) = auth_request.url(); - // Store CSRF state → provider mapping with 10-minute TTL - state - .oauth_states - .insert(csrf_token.secret().clone(), (provider, Utc::now())); + // Browser-binding nonce: mirrored into a short-lived pre-auth cookie and + // re-checked on callback to defend against login CSRF / session fixation. + let nonce = AuthService::generate_session_token(); + + state.oauth_states.insert( + csrf_token.secret().clone(), + OAuthFlowState { + provider, + created_at: Utc::now(), + nonce: nonce.clone(), + pkce_verifier: pkce_verifier.secret().clone(), + }, + ); // Evict expired states (older than 10 minutes) to prevent memory leak let cutoff = Utc::now() - chrono::Duration::minutes(10); - state - .oauth_states - .retain(|_, (_, created)| *created > cutoff); + state.oauth_states.retain(|_, flow| flow.created_at > cutoff); + + let secure_flag = if state.config.auth.secure_cookie { + "; Secure" + } else { + "" + }; + let cookie = format!( + "{OAUTH_NONCE_COOKIE}={nonce}; HttpOnly; SameSite=Lax; Path=/api/auth/oauth; Max-Age=600{secure_flag}" + ); + let mut response_headers = HeaderMap::new(); + response_headers.insert( + SET_COOKIE, + cookie + .parse() + .map_err(|_| AppError::Internal("Failed to set cookie".to_string()))?, + ); - Ok(Redirect::temporary(auth_url.as_str())) + Ok((response_headers, Redirect::temporary(auth_url.as_str()))) } /// Handle the OAuth callback from the provider. @@ -139,31 +221,23 @@ pub async fn oauth_callback( Query(query): Query, headers: HeaderMap, ) -> Result<(HeaderMap, Redirect), AppError> { - // Validate CSRF state token - let stored = state.oauth_states.remove(&query.state); - match stored { - Some((_, (stored_provider, created_at))) => { - // Check provider matches - if stored_provider != provider { - return Err(AppError::BadRequest("OAuth state mismatch".to_string())); - } - // Check not expired (10 minute window) - if Utc::now() - created_at > chrono::Duration::minutes(10) { - return Err(AppError::BadRequest("OAuth state expired".to_string())); - } - } - None => { - return Err(AppError::BadRequest( - "Invalid or expired OAuth state".to_string(), - )); - } - } + // Validate CSRF state + browser-binding nonce, atomically consuming the state. + let nonce_cookie = extract_oauth_nonce(&headers); + let flow = validate_and_consume_state( + &state.oauth_states, + &query.state, + &provider, + nonce_cookie.as_deref(), + )?; let client = OAuthService::build_client(&provider, &state.config.oauth)?; - // Exchange authorization code for access token + // Exchange authorization code for access token, supplying the PKCE verifier + // so the provider only honors a code bound to the verifier we generated at + // authorize time (defends against authorization-code interception/injection). let token_result = client .exchange_code(AuthorizationCode::new(query.code)) + .set_pkce_verifier(PkceCodeVerifier::new(flow.pkce_verifier)) .request_async(async_http_client) .await .map_err(|e| AppError::Internal(format!("OAuth token exchange failed: {e}")))?; @@ -229,6 +303,11 @@ pub async fn oauth_callback( token, state.config.auth.session_ttl, secure_flag ); + // Clear the pre-auth nonce cookie now that the flow is complete. + let clear_cookie = format!( + "{OAUTH_NONCE_COOKIE}=; HttpOnly; SameSite=Lax; Path=/api/auth/oauth; Max-Age=0{secure_flag}" + ); + let mut response_headers = HeaderMap::new(); response_headers.insert( SET_COOKIE, @@ -236,6 +315,91 @@ pub async fn oauth_callback( .parse() .map_err(|_| AppError::Internal("Failed to set cookie".to_string()))?, ); + response_headers.append( + SET_COOKIE, + clear_cookie + .parse() + .map_err(|_| AppError::Internal("Failed to set cookie".to_string()))?, + ); Ok((response_headers, Redirect::temporary("/"))) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::state::OAuthFlowState; + use dashmap::DashMap; + + fn make_states( + state: &str, + provider: &str, + nonce: &str, + age_min: i64, + ) -> DashMap { + let states = DashMap::new(); + states.insert( + state.to_string(), + OAuthFlowState { + provider: provider.to_string(), + created_at: Utc::now() - chrono::Duration::minutes(age_min), + nonce: nonce.to_string(), + pkce_verifier: "verifier1".to_string(), + }, + ); + states + } + + #[test] + fn rejects_unknown_state() { + let states: DashMap = DashMap::new(); + let err = + validate_and_consume_state(&states, "missing", "github", Some("n")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn rejects_provider_mismatch() { + let states = make_states("s1", "github", "nonce1", 0); + let err = + validate_and_consume_state(&states, "s1", "google", Some("nonce1")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn rejects_expired_state() { + let states = make_states("s1", "github", "nonce1", 11); + let err = + validate_and_consume_state(&states, "s1", "github", Some("nonce1")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn rejects_missing_nonce_cookie() { + let states = make_states("s1", "github", "nonce1", 0); + let err = validate_and_consume_state(&states, "s1", "github", None).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn rejects_mismatched_nonce_cookie() { + let states = make_states("s1", "github", "nonce1", 0); + let err = + validate_and_consume_state(&states, "s1", "github", Some("wrong")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn accepts_valid_state_and_is_single_use() { + let states = make_states("s1", "github", "nonce1", 0); + let flow = + validate_and_consume_state(&states, "s1", "github", Some("nonce1")).unwrap(); + assert_eq!(flow.provider, "github"); + // the PKCE verifier must round-trip back to the caller for token exchange + assert_eq!(flow.pkce_verifier, "verifier1"); + // second use must fail: state was consumed (replay protection) + let err = + validate_and_consume_state(&states, "s1", "github", Some("nonce1")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } +} diff --git a/crates/server/src/router/api/setting.rs b/crates/server/src/router/api/setting.rs index f5829c060..29f4ffe64 100644 --- a/crates/server/src/router/api/setting.rs +++ b/crates/server/src/router/api/setting.rs @@ -1,13 +1,19 @@ +use std::net::SocketAddr; use std::sync::Arc; use axum::body::Body; -use axum::extract::State; -use axum::http::header; +use axum::extract::{ConnectInfo, State}; +use axum::http::{HeaderMap, header}; use axum::routing::{get, post, put}; -use axum::{Json, Router}; +use axum::{Extension, Json, Router}; +use sea_orm::SqlxSqliteConnector; use serde::{Deserialize, Serialize}; +use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions}; use crate::error::{ApiResponse, AppError, ok}; +use crate::middleware::auth::CurrentUser; +use crate::router::utils::extract_client_ip; +use crate::service::audit::AuditService; use crate::service::config::ConfigService; use crate::state::AppState; @@ -77,6 +83,9 @@ async fn update_settings( )] pub async fn create_backup( State(state): State>, + Extension(actor): Extension, + ConnectInfo(addr): ConnectInfo, + headers: HeaderMap, ) -> Result { let db_path = resolve_db_path(&state.config); @@ -100,6 +109,23 @@ pub async fn create_backup( // Clean up backup file let _ = tokio::fs::remove_file(&backup_path).await; + // Audit: exporting the full DB (password hashes, 2FA secrets, tokens) is a + // high-risk admin action and must leave a forensic trail. + let caller_ip = extract_client_ip( + &ConnectInfo(addr), + &headers, + &state.config.server.trusted_proxies, + ) + .to_string(); + let _ = AuditService::log( + &state.db, + &actor.user_id, + "settings.backup", + Some(&format!("bytes={}", bytes.len())), + &caller_ip, + ) + .await; + let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S"); let filename = format!("serverbee_backup_{timestamp}.db"); @@ -129,6 +155,9 @@ pub async fn create_backup( )] pub async fn restore_backup( State(state): State>, + Extension(actor): Extension, + ConnectInfo(addr): ConnectInfo, + headers: HeaderMap, body: axum::body::Bytes, ) -> Result>, AppError> { if body.len() < 16 { @@ -170,6 +199,49 @@ pub async fn restore_backup( .await .map_err(|e| AppError::Internal(format!("Failed to restore DB: {e}")))?; + // Audit: replacing the live DB with an uploaded file can backdoor the whole + // instance — record it before the operator restarts. + // + // `state.db`'s pooled connections still hold the pre-restore inode (the file + // we just renamed to `.pre-restore`), so a row written through `state.db` + // would land in the now-discarded database and be lost after the mandatory + // restart. Open a short-lived connection to the freshly restored file so the + // forensic record persists into the DB the server will reopen. A tracing + // line is emitted unconditionally as a durable fallback. + let caller_ip = extract_client_ip( + &ConnectInfo(addr), + &headers, + &state.config.server.trusted_proxies, + ) + .to_string(); + tracing::warn!( + user_id = %actor.user_id, + ip = %caller_ip, + bytes = body.len(), + "audit: settings.restore — live database replaced via restore endpoint" + ); + match SqlitePoolOptions::new() + .max_connections(1) + .connect_with(SqliteConnectOptions::new().filename(&db_path)) + .await + { + Ok(pool) => { + let restored_db = SqlxSqliteConnector::from_sqlx_sqlite_pool(pool); + let _ = AuditService::log( + &restored_db, + &actor.user_id, + "settings.restore", + Some(&format!("bytes={}", body.len())), + &caller_ip, + ) + .await; + let _ = restored_db.close().await; + } + Err(e) => { + tracing::error!("failed to open restored DB to persist restore audit log: {e}"); + } + } + ok("Database restored. Please restart the server.") } diff --git a/crates/server/src/router/api/user.rs b/crates/server/src/router/api/user.rs index 58078b355..b09abd4c9 100644 --- a/crates/server/src/router/api/user.rs +++ b/crates/server/src/router/api/user.rs @@ -1,10 +1,15 @@ +use std::net::SocketAddr; use std::sync::Arc; -use axum::extract::{Path, State}; +use axum::extract::{ConnectInfo, Path, State}; +use axum::http::HeaderMap; use axum::routing::{delete, get, post, put}; -use axum::{Json, Router}; +use axum::{Extension, Json, Router}; use crate::error::{ApiResponse, AppError, ok}; +use crate::middleware::auth::CurrentUser; +use crate::router::utils::extract_client_ip; +use crate::service::audit::AuditService; use crate::service::user::{CreateUserInput, UpdateUserInput, UserResponse, UserService}; use crate::state::AppState; @@ -70,10 +75,33 @@ pub async fn get_user( )] pub async fn create_user( State(state): State>, + Extension(actor): Extension, + ConnectInfo(addr): ConnectInfo, + headers: HeaderMap, Json(body): Json, ) -> Result>, AppError> { let user = UserService::create_user(&state.db, &body.username, &body.password, &body.role).await?; + + // Audit: account creation (a member or a stealth admin) is privilege-sensitive. + let caller_ip = extract_client_ip( + &ConnectInfo(addr), + &headers, + &state.config.server.trusted_proxies, + ) + .to_string(); + let _ = AuditService::log( + &state.db, + &actor.user_id, + "user.create", + Some(&format!( + "id={} username={} role={}", + user.id, user.username, user.role + )), + &caller_ip, + ) + .await; + ok(UserResponse::from(user)) } @@ -92,10 +120,40 @@ pub async fn create_user( )] pub async fn update_user( State(state): State>, + Extension(actor): Extension, + ConnectInfo(addr): ConnectInfo, + headers: HeaderMap, Path(id): Path, Json(body): Json, ) -> Result>, AppError> { + let password_reset = body.password.is_some(); + let old_role = UserService::get_user(&state.db, &id).await?.role; let user = UserService::update_user(&state.db, &id, body).await?; + + // Audit: role promotion/demotion and password resets are privilege-sensitive. + let role_change = if old_role == user.role { + format!("role={}", user.role) + } else { + format!("role {old_role}->{}", user.role) + }; + let caller_ip = extract_client_ip( + &ConnectInfo(addr), + &headers, + &state.config.server.trusted_proxies, + ) + .to_string(); + let _ = AuditService::log( + &state.db, + &actor.user_id, + "user.update", + Some(&format!( + "id={} username={} {role_change} password_reset={password_reset}", + user.id, user.username + )), + &caller_ip, + ) + .await; + ok(UserResponse::from(user)) } @@ -114,8 +172,32 @@ pub async fn update_user( )] pub async fn delete_user( State(state): State>, + Extension(actor): Extension, + ConnectInfo(addr): ConnectInfo, + headers: HeaderMap, Path(id): Path, ) -> Result>, AppError> { + // Capture the target's identity before deletion so the audit entry is meaningful. + let target = UserService::get_user(&state.db, &id).await?; UserService::delete_user(&state.db, &id).await?; + + let caller_ip = extract_client_ip( + &ConnectInfo(addr), + &headers, + &state.config.server.trusted_proxies, + ) + .to_string(); + let _ = AuditService::log( + &state.db, + &actor.user_id, + "user.delete", + Some(&format!( + "id={} username={} role={}", + target.id, target.username, target.role + )), + &caller_ip, + ) + .await; + ok("ok") } diff --git a/crates/server/src/router/ws/agent.rs b/crates/server/src/router/ws/agent.rs index 7d548a831..da7e30cec 100644 --- a/crates/server/src/router/ws/agent.rs +++ b/crates/server/src/router/ws/agent.rs @@ -37,17 +37,20 @@ pub struct OptionalWsQuery { } fn extract_agent_token(headers: &HeaderMap, query: &OptionalWsQuery) -> Option { - // Prefer query param (reliable through reverse proxies / cloud load balancers) - if let Some(ref token) = query.token { - return Some(token.clone()); - } - // Fallback to Authorization header (direct connections) + // Prefer the Authorization header. Unlike the query string, it is not + // captured in reverse-proxy access logs, browser history, or Referer headers + // (CWE-598). The agent always sends this header alongside the query param. if let Some(auth) = headers.get("authorization") && let Ok(val) = auth.to_str() && let Some(token) = val.strip_prefix("Bearer ") { return Some(token.to_string()); } + // Fall back to the query param for proxies/load balancers that strip the + // Authorization header. + if let Some(ref token) = query.token { + return Some(token.clone()); + } None } @@ -1618,6 +1621,39 @@ mod tests { SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080) } + #[test] + fn extract_agent_token_prefers_authorization_header() { + let mut headers = HeaderMap::new(); + headers.insert("authorization", "Bearer header-token".parse().unwrap()); + let query = OptionalWsQuery { + token: Some("query-token".to_string()), + }; + // Header wins so the secret stays out of proxy access logs. + assert_eq!( + extract_agent_token(&headers, &query), + Some("header-token".to_string()) + ); + } + + #[test] + fn extract_agent_token_falls_back_to_query() { + let headers = HeaderMap::new(); + let query = OptionalWsQuery { + token: Some("query-token".to_string()), + }; + assert_eq!( + extract_agent_token(&headers, &query), + Some("query-token".to_string()) + ); + } + + #[test] + fn extract_agent_token_none_when_absent() { + let headers = HeaderMap::new(); + let query = OptionalWsQuery { token: None }; + assert_eq!(extract_agent_token(&headers, &query), None); + } + #[tokio::test] async fn current_connection_frame_handler_waits_for_server_lock() { let (db, _tmp) = setup_test_db().await; diff --git a/crates/server/src/service/checker/dns.rs b/crates/server/src/service/checker/dns.rs index c818da77c..b505e229e 100644 --- a/crates/server/src/service/checker/dns.rs +++ b/crates/server/src/service/checker/dns.rs @@ -7,6 +7,7 @@ use hickory_resolver::config::{NameServerConfig, ResolverConfig}; use hickory_resolver::proto::rr::RecordType; use hickory_resolver::proto::xfer::Protocol; use serde_json::{Value, json}; +use serverbee_common::ssrf; use super::CheckResult; @@ -113,6 +114,14 @@ pub async fn check(target: &str, config: &Value) -> CheckResult { fn build_resolver(nameserver: Option<&str>) -> Result { if let Some(ns) = nameserver { let ip = IpAddr::from_str(ns).map_err(|e| format!("Invalid nameserver IP '{ns}': {e}"))?; + // SSRF guard: an attacker-supplied nameserver must not point the server + // at loopback/link-local/metadata resolvers (private resolvers are + // allowed for internal monitoring). + if !ssrf::is_monitor_safe_addr(ip) { + return Err(format!( + "nameserver '{ns}' is a blocked address (loopback/link-local/metadata)" + )); + } let ns_config = NameServerConfig::new(SocketAddr::new(ip, 53), Protocol::Udp); let mut resolver_config = ResolverConfig::new(); resolver_config.add_name_server(ns_config); @@ -197,4 +206,17 @@ mod tests { let result = build_resolver(Some("not-an-ip")); assert!(result.is_err()); } + + #[test] + fn test_build_resolver_rejects_loopback_nameserver() { + // SSRF guard: cannot point the resolver at loopback or cloud metadata. + assert!(build_resolver(Some("127.0.0.1")).is_err()); + assert!(build_resolver(Some("169.254.169.254")).is_err()); + } + + #[test] + fn test_build_resolver_allows_private_nameserver() { + // Internal resolvers (RFC1918) are a legitimate monitoring setup. + assert!(build_resolver(Some("10.0.0.53")).is_ok()); + } } diff --git a/crates/server/src/service/checker/http_keyword.rs b/crates/server/src/service/checker/http_keyword.rs index 9731cb816..cb43a4dd5 100644 --- a/crates/server/src/service/checker/http_keyword.rs +++ b/crates/server/src/service/checker/http_keyword.rs @@ -2,6 +2,7 @@ use std::time::{Duration, Instant}; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use serde_json::{Value, json}; +use serverbee_common::ssrf; use super::CheckResult; @@ -58,83 +59,149 @@ pub async fn check(target: &str, config: &Value) -> CheckResult { } }; - // Build the HTTP client - let client = match reqwest::Client::builder() - .timeout(Duration::from_secs(timeout_secs)) - .danger_accept_invalid_certs(false) - .build() - { - Ok(c) => c, - Err(e) => { - let latency = start.elapsed().as_secs_f64() * 1000.0; - return CheckResult { - success: false, - latency: Some(latency), - detail: Value::Null, - error: Some(format!("Failed to build HTTP client: {e}")), - }; - } - }; + // SSRF guard with per-hop revalidation. `.resolve_to_addrs` only pins the + // original host, and reqwest's default policy would auto-follow up to 10 + // redirects — so a monitored endpoint could 3xx-redirect to + // 169.254.169.254 or an internal host and slip past the guard entirely. + // Disable auto-redirect and follow manually, validating the URL (scheme + + // credentials; any port allowed for non-standard service ports) and pinning + // the client to the freshly validated addresses on every hop. Mirrors the + // agent's ip_quality fetcher. + const MAX_REDIRECTS: usize = 10; - // Build the request - let mut request = match method.as_str() { - "GET" => client.get(target), - "POST" => { - let mut req = client.post(target); - if let Some(body) = body_str { - req = req.body(body.to_string()); - } - req - } - other => { - let latency = start.elapsed().as_secs_f64() * 1000.0; - return CheckResult { - success: false, - latency: Some(latency), - detail: Value::Null, - error: Some(format!("Unsupported HTTP method: {other}")), - }; - } + // Early-return helper for validation/build errors (Value::Null detail). + let fail = move |msg: String| CheckResult { + success: false, + latency: Some(start.elapsed().as_secs_f64() * 1000.0), + detail: Value::Null, + error: Some(msg), }; - request = request.headers(custom_headers); + let mut current_url = target.to_string(); + let mut hop = 0usize; - // Execute the request - let response = match request.send().await { - Ok(r) => r, - Err(e) => { - let latency = start.elapsed().as_secs_f64() * 1000.0; - return CheckResult { - success: false, - latency: Some(latency), - detail: json!({ - "status_code": null, - "keyword_found": null, - "response_time_ms": latency, - }), - error: Some(format!("HTTP request failed: {e}")), + let (status_code, response_body) = loop { + let url = match ssrf::validate_monitor_url(¤t_url) { + Ok(u) => u, + Err(e) => return fail(e.to_string()), + }; + let host = match url.host_str() { + Some(h) => h.to_string(), + None => return fail("URL has no host".to_string()), + }; + let port = url.port_or_known_default().unwrap_or(80); + let validated_addrs = match ssrf::resolve_and_check_monitor(&host, port) { + Ok(addrs) => addrs, + Err(e) => return fail(e.to_string()), + }; + + let client = match reqwest::Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .danger_accept_invalid_certs(false) + .redirect(reqwest::redirect::Policy::none()) + .resolve_to_addrs(&host, &validated_addrs) + .build() + { + Ok(c) => c, + Err(e) => return fail(format!("Failed to build HTTP client: {e}")), + }; + + // The configured method/body/custom headers apply to the first hop + // only; redirects are followed as GET without the body or custom + // headers, so a POST body and any secrets in those headers are never + // replayed to a redirected (possibly attacker-controlled) host. + let request = if hop == 0 { + let base = match method.as_str() { + "GET" => client.get(url.clone()), + "POST" => { + let mut req = client.post(url.clone()); + if let Some(body) = body_str { + req = req.body(body.to_string()); + } + req + } + other => return fail(format!("Unsupported HTTP method: {other}")), }; - } - }; + base.headers(custom_headers.clone()) + } else { + client.get(url.clone()) + }; - let status_code = response.status().as_u16(); + // Execute the request + let response = match request.send().await { + Ok(r) => r, + Err(e) => { + let latency = start.elapsed().as_secs_f64() * 1000.0; + return CheckResult { + success: false, + latency: Some(latency), + detail: json!({ + "status_code": null, + "keyword_found": null, + "response_time_ms": latency, + }), + error: Some(format!("HTTP request failed: {e}")), + }; + } + }; - // Read response body - let response_body = match response.text().await { - Ok(text) => text, - Err(e) => { - let latency = start.elapsed().as_secs_f64() * 1000.0; - return CheckResult { - success: false, - latency: Some(latency), - detail: json!({ - "status_code": status_code, - "keyword_found": null, - "response_time_ms": latency, - }), - error: Some(format!("Failed to read response body: {e}")), + let status = response.status(); + + // Follow redirects manually so every hop is SSRF-validated. + if status.is_redirection() { + if hop >= MAX_REDIRECTS { + let latency = start.elapsed().as_secs_f64() * 1000.0; + return CheckResult { + success: false, + latency: Some(latency), + detail: json!({ + "status_code": status.as_u16(), + "keyword_found": null, + "response_time_ms": latency, + }), + error: Some(format!("Too many redirects (max {MAX_REDIRECTS})")), + }; + } + let location = match response + .headers() + .get(reqwest::header::LOCATION) + .and_then(|v| v.to_str().ok()) + { + Some(loc) => loc.to_string(), + None => return fail("Redirect response had no Location header".to_string()), }; + // Resolve the redirect target relative to the current URL; the next + // loop iteration re-validates and re-pins it. + let next = match url.join(&location) { + Ok(u) => u, + Err(e) => return fail(format!("Invalid redirect Location: {e}")), + }; + current_url = next.to_string(); + hop += 1; + continue; } + + let status_code = status.as_u16(); + + // Read response body + let response_body = match response.text().await { + Ok(text) => text, + Err(e) => { + let latency = start.elapsed().as_secs_f64() * 1000.0; + return CheckResult { + success: false, + latency: Some(latency), + detail: json!({ + "status_code": status_code, + "keyword_found": null, + "response_time_ms": latency, + }), + error: Some(format!("Failed to read response body: {e}")), + }; + } + }; + + break (status_code, response_body); }; let latency = start.elapsed().as_secs_f64() * 1000.0; diff --git a/crates/server/src/service/checker/ssl.rs b/crates/server/src/service/checker/ssl.rs index d2959f738..bfff01b84 100644 --- a/crates/server/src/service/checker/ssl.rs +++ b/crates/server/src/service/checker/ssl.rs @@ -4,6 +4,7 @@ use std::time::{Duration, Instant}; use rustls::ClientConfig; use rustls::pki_types::ServerName; use serde_json::{Value, json}; +use serverbee_common::ssrf; use sha2::{Digest, Sha256}; use tokio::net::TcpStream; use tokio_rustls::TlsConnector; @@ -79,8 +80,25 @@ pub async fn check(target: &str, config: &Value) -> CheckResult { let addr = format!("{host}:{port}"); + // SSRF guard: reject hosts resolving to blocked (loopback/link-local/ + // metadata) addresses, and connect only to the validated addresses (the TLS + // handshake still uses `server_name` for SNI / certificate validation). + let validated_addrs = match ssrf::resolve_and_check_monitor(&host, port) { + Ok(addrs) => addrs, + Err(e) => { + let latency = start.elapsed().as_secs_f64() * 1000.0; + return CheckResult { + success: false, + latency: Some(latency), + detail: Value::Null, + error: Some(e.to_string()), + }; + } + }; + // Connect TCP - let tcp_stream = match tokio::time::timeout(timeout, TcpStream::connect(&addr)).await { + let tcp_stream = match tokio::time::timeout(timeout, TcpStream::connect(&validated_addrs[..])).await + { Ok(Ok(s)) => s, Ok(Err(e)) => { let latency = start.elapsed().as_secs_f64() * 1000.0; diff --git a/crates/server/src/service/checker/tcp.rs b/crates/server/src/service/checker/tcp.rs index 79fa803f7..c53d57f6b 100644 --- a/crates/server/src/service/checker/tcp.rs +++ b/crates/server/src/service/checker/tcp.rs @@ -1,10 +1,23 @@ use std::time::{Duration, Instant}; use serde_json::{Value, json}; +use serverbee_common::ssrf; use tokio::net::TcpStream; use super::CheckResult; +/// Parse a "host:port" (or "[ipv6]:port") target into its components. +fn split_host_port(target: &str) -> Option<(String, u16)> { + if let Some(rest) = target.strip_prefix('[') { + // [ipv6]:port + let (host, after) = rest.split_once(']')?; + let port = after.strip_prefix(':')?.parse().ok()?; + return Some((host.to_string(), port)); + } + let (host, port) = target.rsplit_once(':')?; + Some((host.to_string(), port.parse().ok()?)) +} + /// Check TCP connectivity to `target` (expected format: "host:port"). /// /// Config options: @@ -15,7 +28,33 @@ pub async fn check(target: &str, config: &Value) -> CheckResult { let start = Instant::now(); - match tokio::time::timeout(timeout, TcpStream::connect(target)).await { + // SSRF guard: parse host:port, reject targets resolving to blocked + // (loopback/link-local/metadata) addresses, and connect only to the + // validated addresses so the host cannot rebind to a different IP. + let (host, port) = match split_host_port(target) { + Some(hp) => hp, + None => { + return CheckResult { + success: false, + latency: None, + detail: json!({ "connected": false }), + error: Some(format!("Invalid TCP target '{target}' (expected host:port)")), + }; + } + }; + let addrs = match ssrf::resolve_and_check_monitor(&host, port) { + Ok(addrs) => addrs, + Err(e) => { + return CheckResult { + success: false, + latency: None, + detail: json!({ "connected": false }), + error: Some(e.to_string()), + }; + } + }; + + match tokio::time::timeout(timeout, TcpStream::connect(&addrs[..])).await { Ok(Ok(_stream)) => { let latency = start.elapsed().as_secs_f64() * 1000.0; CheckResult { @@ -49,39 +88,63 @@ pub async fn check(target: &str, config: &Value) -> CheckResult { #[cfg(test)] mod tests { use super::*; - use tokio::net::TcpListener; - #[tokio::test] - async fn test_tcp_connect_success() { - // Bind a listener on a random port - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); + #[test] + fn test_split_host_port_ipv4() { + assert_eq!( + split_host_port("10.0.0.5:6379"), + Some(("10.0.0.5".to_string(), 6379)) + ); + } - let config = json!({ "timeout": 5 }); - let result = check(&addr.to_string(), &config).await; + #[test] + fn test_split_host_port_ipv6_bracketed() { + assert_eq!( + split_host_port("[fd00::1]:443"), + Some(("fd00::1".to_string(), 443)) + ); + } - assert!(result.success); - assert!(result.latency.is_some()); - assert!(result.error.is_none()); - assert_eq!(result.detail["connected"], true); + #[test] + fn test_split_host_port_invalid() { + assert_eq!(split_host_port("no-port"), None); + assert_eq!(split_host_port("host:notaport"), None); } #[tokio::test] - async fn test_tcp_connect_refused() { - // Use a port that is almost certainly not listening - let config = json!({ "timeout": 2 }); - let result = check("127.0.0.1:1", &config).await; + async fn test_tcp_blocks_loopback() { + // SSRF guard: loopback is never a valid monitoring target. + let result = check("127.0.0.1:80", &json!({ "timeout": 2 })).await; + assert!(!result.success); + assert!( + result.error.unwrap_or_default().contains("SSRF guard"), + "loopback should be rejected by the SSRF guard" + ); + } + #[tokio::test] + async fn test_tcp_blocks_cloud_metadata() { + let result = check("169.254.169.254:80", &json!({ "timeout": 2 })).await; assert!(!result.success); - assert!(result.error.is_some()); - assert_eq!(result.detail["connected"], false); + assert!(result.error.unwrap_or_default().contains("SSRF guard")); + } + + #[tokio::test] + async fn test_tcp_allows_private_target_through_guard() { + // Private RFC1918 is allowed past the guard (internal monitoring). + // Whether the connection itself succeeds or fails is environment- + // dependent; the invariant is that the guard does NOT reject it. + let result = check("10.255.255.1:1", &json!({ "timeout": 1 })).await; + assert!( + !result.error.unwrap_or_default().contains("SSRF guard"), + "private targets must be allowed past the SSRF guard" + ); } #[tokio::test] - async fn test_tcp_default_timeout() { - let config = json!({}); - // Just verify it doesn't panic with default config - let result = check("127.0.0.1:1", &config).await; + async fn test_tcp_invalid_target() { + let result = check("no-port-here", &json!({})).await; assert!(!result.success); + assert!(result.error.unwrap_or_default().contains("Invalid TCP target")); } } diff --git a/crates/server/src/service/service_monitor.rs b/crates/server/src/service/service_monitor.rs index 53b6dcf4c..43e456b82 100644 --- a/crates/server/src/service/service_monitor.rs +++ b/crates/server/src/service/service_monitor.rs @@ -23,6 +23,63 @@ fn default_true() -> bool { true } +/// Extract the host portion of a "host", "host:port", or "[ipv6]:port" target. +fn host_of(target: &str) -> &str { + if let Some(rest) = target.strip_prefix('[') + && let Some((host, _)) = rest.split_once(']') + { + return host; + } + if let Some((host, port)) = target.rsplit_once(':') + && port.parse::().is_ok() + { + return host; + } + target +} + +/// Reject monitor targets whose literal host is a loopback / link-local / +/// cloud-metadata address (or "localhost") at create/update time, so the SSRF +/// guard's rejection surfaces as a clear 4xx at configuration time instead of +/// as silent runtime check failures (and false alerts) on existing monitors. +/// +/// Only the connect-based types are SSRF-guarded at runtime (tcp/ssl/ +/// http_keyword); the dns checker guards only its nameserver and whois targets +/// a registry, so both are skipped here. This is a literal-host check with no +/// DNS resolution — hostnames that resolve to blocked addresses are still +/// caught by the runtime guard. +fn validate_target_addr(monitor_type: &str, target: &str) -> Result<(), AppError> { + let host = match monitor_type { + "http_keyword" => match serverbee_common::ssrf::validate_monitor_url(target) { + Ok(url) => url.host_str().map(str::to_string), + // A malformed URL is left for the runtime checker to report. + Err(_) => None, + }, + "tcp" | "ssl" => Some(host_of(target).to_string()), + _ => None, + }; + + let Some(host) = host else { return Ok(()) }; + let host = host.trim_start_matches('[').trim_end_matches(']'); + + if host.eq_ignore_ascii_case("localhost") { + return Err(AppError::Validation( + "target 'localhost' is not allowed: loopback/link-local/cloud-metadata \ + addresses are blocked for monitoring" + .to_string(), + )); + } + if let Ok(ip) = host.parse::() + && !serverbee_common::ssrf::is_monitor_safe_addr(ip) + { + return Err(AppError::Validation(format!( + "target address '{host}' is not allowed: loopback/link-local/cloud-metadata \ + addresses are blocked for monitoring" + ))); + } + Ok(()) +} + #[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)] pub struct CreateServiceMonitor { pub name: String, @@ -90,6 +147,7 @@ impl ServiceMonitorService { VALID_TYPES.join(", ") ))); } + validate_target_addr(&input.monitor_type, &input.target)?; let config_json = serde_json::to_string(&input.config_json) .map_err(|e| AppError::Validation(format!("Invalid config_json: {e}")))?; @@ -131,12 +189,14 @@ impl ServiceMonitorService { input: UpdateServiceMonitor, ) -> Result { let existing = Self::get(db, id).await?; + let monitor_type = existing.monitor_type.clone(); let mut model: service_monitor::ActiveModel = existing.into(); if let Some(name) = input.name { model.name = Set(name); } if let Some(target) = input.target { + validate_target_addr(&monitor_type, &target)?; model.target = Set(target); } if let Some(interval) = input.interval { @@ -303,6 +363,47 @@ mod tests { use super::*; use crate::test_utils::setup_test_db; + #[test] + fn host_of_extracts_host() { + assert_eq!(host_of("example.com:443"), "example.com"); + assert_eq!(host_of("example.com"), "example.com"); + assert_eq!(host_of("[fd00::1]:6379"), "fd00::1"); + assert_eq!(host_of("[::1]"), "::1"); + // a non-numeric ":suffix" is not a port, so the whole string is the host + assert_eq!(host_of("host:notaport"), "host:notaport"); + } + + #[test] + fn validate_target_addr_rejects_loopback_and_metadata() { + // localhost literal + assert!(validate_target_addr("tcp", "localhost:9527").is_err()); + assert!(validate_target_addr("http_keyword", "http://localhost/health").is_err()); + // loopback / metadata / unspecified literals across connect-based types + assert!(validate_target_addr("tcp", "127.0.0.1:80").is_err()); + assert!(validate_target_addr("ssl", "169.254.169.254").is_err()); + assert!(validate_target_addr("ssl", "[::1]:443").is_err()); + assert!(validate_target_addr("http_keyword", "http://127.0.0.1:8080/").is_err()); + assert!(validate_target_addr("http_keyword", "http://169.254.169.254/").is_err()); + } + + #[test] + fn validate_target_addr_allows_public_and_private() { + // public + assert!(validate_target_addr("tcp", "8.8.8.8:53").is_ok()); + assert!(validate_target_addr("http_keyword", "https://example.com:8443/").is_ok()); + // RFC1918 private is a legitimate internal-monitoring target + assert!(validate_target_addr("tcp", "10.0.0.5:6379").is_ok()); + assert!(validate_target_addr("ssl", "192.168.1.10:443").is_ok()); + } + + #[test] + fn validate_target_addr_skips_unguarded_types() { + // dns guards only the nameserver, whois targets a registry — the literal + // query target is not address-checked here. + assert!(validate_target_addr("dns", "localhost").is_ok()); + assert!(validate_target_addr("whois", "127.0.0.1").is_ok()); + } + fn sample_create() -> CreateServiceMonitor { CreateServiceMonitor { name: "Test SSL Monitor".to_string(), diff --git a/crates/server/src/state.rs b/crates/server/src/state.rs index d1d5ee5e1..68a084800 100644 --- a/crates/server/src/state.rs +++ b/crates/server/src/state.rs @@ -29,6 +29,32 @@ pub struct PendingTotp { pub created_at: chrono::DateTime, } +/// In-flight OAuth login flow state, keyed by the CSRF `state` token. +/// +/// `nonce` is mirrored into a short-lived HttpOnly pre-auth cookie set on the +/// authorize redirect and re-checked on the callback, binding the flow to the +/// browser that initiated it (defends against login CSRF / session fixation). +pub struct OAuthFlowState { + pub provider: String, + pub created_at: chrono::DateTime, + pub nonce: String, + pub pkce_verifier: String, +} + +// Manual `Debug` that redacts the browser-binding nonce and the PKCE verifier so +// these single-use secrets never land in logs or panic messages. (`Debug` is +// still required because `Result::unwrap_err` formats the Ok value in tests.) +impl std::fmt::Debug for OAuthFlowState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OAuthFlowState") + .field("provider", &self.provider) + .field("created_at", &self.created_at) + .field("nonce", &"") + .field("pkce_verifier", &"") + .finish() + } +} + /// Pending mobile pairing code, keyed by code string. pub struct PendingPair { pub user_id: String, @@ -52,8 +78,8 @@ pub struct AppState { pub geoip_downloading: AtomicBool, pub asn: Arc>>, pub asn_downloading: AtomicBool, - /// CSRF state tokens for OAuth flow, keyed by state string → provider. - pub oauth_states: DashMap)>, + /// In-flight OAuth login flows, keyed by the CSRF `state` token. + pub oauth_states: DashMap, /// Pending TOTP secrets for 2FA setup, keyed by user_id. pub pending_totp: DashMap, /// Rate limiter for login attempts, keyed by IP. diff --git a/crates/server/tests/integration.rs b/crates/server/tests/integration.rs index 036a37042..f65ec4f14 100644 --- a/crates/server/tests/integration.rs +++ b/crates/server/tests/integration.rs @@ -2624,12 +2624,12 @@ async fn test_service_monitor_crud_and_check() { let client = http_client(); login_admin(&client, &base_url).await; - // ── Step 1: Create a TCP monitor targeting the test server's own address ── + // ── Step 1a: loopback targets are rejected at create time (SSRF guard) ── let addr = base_url.trim_start_matches("http://"); - let create_resp = client + let loopback_resp = client .post(format!("{}/api/service-monitors", base_url)) .json(&json!({ - "name": "Localhost TCP Check", + "name": "Loopback TCP Check", "monitor_type": "tcp", "target": addr, "interval": 300, @@ -2639,6 +2639,30 @@ async fn test_service_monitor_crud_and_check() { .send() .await .expect("POST /api/service-monitors failed"); + assert_eq!( + loopback_resp.status(), + 422, + "creating a loopback-target monitor must be rejected at config time" + ); + + // ── Step 1b: create a monitor with a safe, non-routable target ── + // 192.0.2.1 is TEST-NET-1 (RFC 5737): it passes the monitor SSRF guard (not + // loopback/link-local/metadata) but is non-routable, so the live check below + // fails to connect rather than being blocked — exercising the full CRUD + + // check lifecycle without depending on an external host. + let create_resp = client + .post(format!("{}/api/service-monitors", base_url)) + .json(&json!({ + "name": "TCP Check", + "monitor_type": "tcp", + "target": "192.0.2.1:9", + "interval": 300, + "config_json": { "timeout": 2 }, + "enabled": true + })) + .send() + .await + .expect("POST /api/service-monitors failed"); assert_eq!( create_resp.status(), @@ -2649,7 +2673,7 @@ async fn test_service_monitor_crud_and_check() { let monitor_id = create_body["data"]["id"] .as_str() .expect("monitor id missing"); - assert_eq!(create_body["data"]["name"], "Localhost TCP Check"); + assert_eq!(create_body["data"]["name"], "TCP Check"); assert_eq!(create_body["data"]["monitor_type"], "tcp"); // ── Step 2: List monitors — verify it appears ── @@ -2669,7 +2693,7 @@ async fn test_service_monitor_crud_and_check() { "Created monitor should appear in list" ); - // ── Step 3: Trigger check — the test server is listening on the target port ── + // ── Step 3: Trigger check — non-routable target fails to connect ── let check_resp = client .post(format!( "{}/api/service-monitors/{}/check", @@ -2684,10 +2708,15 @@ async fn test_service_monitor_crud_and_check() { let record = &check_body["data"]; assert!(record["id"].is_number(), "record should have a numeric id"); assert_eq!(record["monitor_id"], monitor_id); - // TCP connection to our own test server should succeed - assert_eq!( - record["success"], true, - "TCP check to localhost test server should succeed" + // Whether the TCP connect to a non-routable address succeeds, is refused, + // or times out is environment-dependent; the invariant here is that the + // check endpoint runs the checker and persists a record with a real boolean + // success field. (Loopback/metadata rejection is covered by the checker unit + // tests and the Step 1a create-time guard above.) + assert!( + record["success"].is_boolean(), + "check should persist a boolean success, got: {:?}", + record["success"] ); // ── Step 4: Get records — verify the check created a record ── @@ -2706,7 +2735,7 @@ async fn test_service_monitor_crud_and_check() { .as_array() .expect("data should be array"); assert_eq!(records.len(), 1, "should have 1 record after one check"); - assert_eq!(records[0]["success"], true); + assert!(records[0]["success"].is_boolean()); // ── Step 5: Delete monitor ── let delete_resp = client diff --git a/docs/superpowers/plans/2026-06-08-oauth-login-csrf-fix.md b/docs/superpowers/plans/2026-06-08-oauth-login-csrf-fix.md new file mode 100644 index 000000000..801769de1 --- /dev/null +++ b/docs/superpowers/plans/2026-06-08-oauth-login-csrf-fix.md @@ -0,0 +1,541 @@ +# OAuth Login CSRF / Session Fixation Fix Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Bind the OAuth login flow to the initiating browser (pre-auth cookie nonce) and adopt PKCE (S256) so a victim cannot be silently logged into an attacker's account. + +**Architecture:** `oauth_authorize` mints a CSPRNG nonce + PKCE challenge, stores them with the CSRF state in the in-memory `oauth_states` map, and mirrors the nonce into a short-lived HttpOnly pre-auth cookie. `oauth_callback` consumes the state atomically and rejects the request unless the request's `oauth_nonce` cookie matches the stored nonce, then exchanges the code with the PKCE verifier. + +**Tech Stack:** Rust, Axum 0.8, `oauth2` v4.4.2, `dashmap`, sea-orm. + +Spec: `docs/superpowers/specs/2026-06-08-oauth-login-csrf-fix-design.md` + +--- + +## File Structure + +- `crates/server/src/state.rs` — add `OAuthFlowState` struct; change `oauth_states` field type. +- `crates/server/src/router/api/oauth.rs` — pure validation helper + cookie helpers + handler wiring + unit tests. + +No migration, no config change, no frontend change. The two endpoints stay HTTP 302 redirects, so the generated OpenAPI / `api-types.ts` is unaffected. + +--- + +## Task 1: Browser-bound OAuth state via pre-auth cookie nonce + +This task fully closes the login-CSRF / session-fixation hole. PKCE is added in Task 2. + +**Files:** +- Modify: `crates/server/src/state.rs` (struct near `PendingTotp` ~line 24; field ~line 56) +- Modify: `crates/server/src/router/api/oauth.rs` +- Test: `crates/server/src/router/api/oauth.rs` (`#[cfg(test)] mod tests`) + +- [ ] **Step 1: Add `OAuthFlowState` struct and migrate the field in `state.rs`** + +Add the struct just after the `PendingTotp` struct definition: + +```rust +/// In-flight OAuth login flow state, keyed by the CSRF `state` token. +/// +/// `nonce` is mirrored into a short-lived HttpOnly pre-auth cookie set on the +/// authorize redirect and re-checked on the callback, binding the flow to the +/// browser that initiated it (defends against login CSRF / session fixation). +pub struct OAuthFlowState { + pub provider: String, + pub created_at: chrono::DateTime, + pub nonce: String, +} +``` + +Change the field declaration (was `DashMap)>`): + +```rust + /// In-flight OAuth login flows, keyed by the CSRF `state` token. + pub oauth_states: DashMap, +``` + +The initializer `oauth_states: DashMap::new(),` stays as-is. + +- [ ] **Step 2: Write the failing unit tests in `oauth.rs`** + +Append this module at the end of `crates/server/src/router/api/oauth.rs`: + +```rust +#[cfg(test)] +mod tests { + use super::*; + use crate::state::OAuthFlowState; + use dashmap::DashMap; + + fn make_states( + state: &str, + provider: &str, + nonce: &str, + age_min: i64, + ) -> DashMap { + let states = DashMap::new(); + states.insert( + state.to_string(), + OAuthFlowState { + provider: provider.to_string(), + created_at: Utc::now() - chrono::Duration::minutes(age_min), + nonce: nonce.to_string(), + }, + ); + states + } + + #[test] + fn rejects_unknown_state() { + let states: DashMap = DashMap::new(); + let err = + validate_and_consume_state(&states, "missing", "github", Some("n")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn rejects_provider_mismatch() { + let states = make_states("s1", "github", "nonce1", 0); + let err = + validate_and_consume_state(&states, "s1", "google", Some("nonce1")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn rejects_expired_state() { + let states = make_states("s1", "github", "nonce1", 11); + let err = + validate_and_consume_state(&states, "s1", "github", Some("nonce1")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn rejects_missing_nonce_cookie() { + let states = make_states("s1", "github", "nonce1", 0); + let err = validate_and_consume_state(&states, "s1", "github", None).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn rejects_mismatched_nonce_cookie() { + let states = make_states("s1", "github", "nonce1", 0); + let err = + validate_and_consume_state(&states, "s1", "github", Some("wrong")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } + + #[test] + fn accepts_valid_state_and_is_single_use() { + let states = make_states("s1", "github", "nonce1", 0); + let flow = + validate_and_consume_state(&states, "s1", "github", Some("nonce1")).unwrap(); + assert_eq!(flow.provider, "github"); + // second use must fail: state was consumed (replay protection) + let err = + validate_and_consume_state(&states, "s1", "github", Some("nonce1")).unwrap_err(); + assert!(matches!(err, AppError::BadRequest(_))); + } +} +``` + +- [ ] **Step 3: Run the tests to confirm they fail (RED)** + +Run: `cargo test -p serverbee-server --lib router::api::oauth` +Expected: compile error — `validate_and_consume_state` not found (and possibly `OAuthFlowState` import). This is the RED state. + +- [ ] **Step 4: Add imports, cookie name, helpers, and the validator to `oauth.rs`** + +Add to the top-of-file imports (alongside the existing `use` lines): + +```rust +use crate::state::OAuthFlowState; +use dashmap::DashMap; +``` + +Add a module-level constant (near the top, after the imports): + +```rust +/// Name of the short-lived HttpOnly pre-auth cookie that binds an OAuth login +/// flow to the browser that started it. +const OAUTH_NONCE_COOKIE: &str = "oauth_nonce"; +``` + +Add these three functions (place them above `oauth_authorize`): + +```rust +/// Extract the `oauth_nonce` pre-auth cookie value from the request headers. +fn extract_oauth_nonce(headers: &HeaderMap) -> Option { + headers + .get("cookie")? + .to_str() + .ok()? + .split(';') + .find_map(|c| { + let c = c.trim(); + c.strip_prefix("oauth_nonce=").map(|v| v.to_string()) + }) +} + +/// Constant-time byte comparison (no early return) for the nonce check. +fn ct_eq(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + let mut diff = 0u8; + for (x, y) in a.iter().zip(b.iter()) { + diff |= x ^ y; + } + diff == 0 +} + +/// Validate and atomically consume the stored OAuth flow state. +/// +/// `remove` makes the state single-use (replay-safe). Returns the flow state on +/// success, or a `BadRequest` for the first failed check. +fn validate_and_consume_state( + states: &DashMap, + state_param: &str, + provider: &str, + nonce_cookie: Option<&str>, +) -> Result { + let (_, flow) = states + .remove(state_param) + .ok_or_else(|| AppError::BadRequest("Invalid or expired OAuth state".to_string()))?; + + if flow.provider != provider { + return Err(AppError::BadRequest("OAuth state mismatch".to_string())); + } + if Utc::now() - flow.created_at > chrono::Duration::minutes(10) { + return Err(AppError::BadRequest("OAuth state expired".to_string())); + } + match nonce_cookie { + Some(cookie) if ct_eq(cookie.as_bytes(), flow.nonce.as_bytes()) => {} + _ => return Err(AppError::BadRequest("OAuth session mismatch".to_string())), + } + Ok(flow) +} +``` + +- [ ] **Step 5: Rewrite `oauth_authorize` to mint the nonce, store the struct, and set the pre-auth cookie** + +Replace the whole `oauth_authorize` function body (keep the `#[utoipa::path(...)]` attribute above it; change only the return type and body): + +```rust +pub async fn oauth_authorize( + State(state): State>, + Path(provider): Path, +) -> Result<(HeaderMap, Redirect), AppError> { + if !OAuthService::is_configured(&provider, &state.config.oauth) { + return Err(AppError::BadRequest(format!( + "OAuth provider '{provider}' is not configured" + ))); + } + + let client = OAuthService::build_client(&provider, &state.config.oauth)?; + + let mut auth_request = client.authorize_url(CsrfToken::new_random); + + // Add scopes based on provider + let scopes = match provider.as_str() { + "github" => vec!["read:user", "user:email"], + "google" => vec!["openid", "email", "profile"], + _ => vec![], + }; + for scope in scopes { + auth_request = auth_request.add_scope(Scope::new(scope.to_string())); + } + + let (auth_url, csrf_token) = auth_request.url(); + + // Browser-binding nonce: mirrored into a short-lived pre-auth cookie and + // re-checked on callback to defend against login CSRF / session fixation. + let nonce = AuthService::generate_session_token(); + + state.oauth_states.insert( + csrf_token.secret().clone(), + OAuthFlowState { + provider, + created_at: Utc::now(), + nonce: nonce.clone(), + }, + ); + + // Evict expired states (older than 10 minutes) to prevent memory leak + let cutoff = Utc::now() - chrono::Duration::minutes(10); + state.oauth_states.retain(|_, flow| flow.created_at > cutoff); + + let secure_flag = if state.config.auth.secure_cookie { + "; Secure" + } else { + "" + }; + let cookie = format!( + "{OAUTH_NONCE_COOKIE}={nonce}; HttpOnly; SameSite=Lax; Path=/api/auth/oauth; Max-Age=600{secure_flag}" + ); + let mut response_headers = HeaderMap::new(); + response_headers.insert( + SET_COOKIE, + cookie + .parse() + .map_err(|_| AppError::Internal("Failed to set cookie".to_string()))?, + ); + + Ok((response_headers, Redirect::temporary(auth_url.as_str()))) +} +``` + +- [ ] **Step 6: Rewrite the state-validation block in `oauth_callback` and clear the cookie on success** + +In `oauth_callback`, replace the existing CSRF-validation block (the `let stored = state.oauth_states.remove(&query.state); match stored { ... }`) with: + +```rust + // Validate CSRF state + browser-binding nonce, atomically consuming the state. + let nonce_cookie = extract_oauth_nonce(&headers); + validate_and_consume_state( + &state.oauth_states, + &query.state, + &provider, + nonce_cookie.as_deref(), + )?; +``` + +Then, at the end of `oauth_callback`, replace the cookie-setting block so it also clears the pre-auth cookie. Replace: + +```rust + let mut response_headers = HeaderMap::new(); + response_headers.insert( + SET_COOKIE, + cookie + .parse() + .map_err(|_| AppError::Internal("Failed to set cookie".to_string()))?, + ); + + Ok((response_headers, Redirect::temporary("/"))) +``` + +with: + +```rust + // Clear the pre-auth nonce cookie now that the flow is complete. + let clear_cookie = format!( + "{OAUTH_NONCE_COOKIE}=; HttpOnly; SameSite=Lax; Path=/api/auth/oauth; Max-Age=0{secure_flag}" + ); + + let mut response_headers = HeaderMap::new(); + response_headers.insert( + SET_COOKIE, + cookie + .parse() + .map_err(|_| AppError::Internal("Failed to set cookie".to_string()))?, + ); + response_headers.append( + SET_COOKIE, + clear_cookie + .parse() + .map_err(|_| AppError::Internal("Failed to set cookie".to_string()))?, + ); + + Ok((response_headers, Redirect::temporary("/"))) +``` + +(`secure_flag` is already in scope in `oauth_callback` from the existing session-cookie code.) + +- [ ] **Step 7: Run the tests to confirm they pass (GREEN)** + +Run: `cargo test -p serverbee-server --lib router::api::oauth` +Expected: 6 tests pass. + +- [ ] **Step 8: Build and lint** + +Run: `cargo clippy -p serverbee-server -- -D warnings` +Expected: no warnings, no errors. + +- [ ] **Step 9: Commit** + +```bash +git add crates/server/src/state.rs crates/server/src/router/api/oauth.rs +git commit -m "fix(server): bind OAuth login state to initiating browser via pre-auth cookie" +``` + +--- + +## Task 2: Add PKCE (S256) to the OAuth login flow + +Defense-in-depth against authorization-code interception/injection. The PKCE +verifier rides in the same `OAuthFlowState` and is bound to the same single-use +state, so no extra validation branch is needed. + +**Files:** +- Modify: `crates/server/src/state.rs` (`OAuthFlowState`) +- Modify: `crates/server/src/router/api/oauth.rs` +- Test: `crates/server/src/router/api/oauth.rs` (update `make_states`) + +- [ ] **Step 1: Add the `pkce_verifier` field to `OAuthFlowState`** + +In `crates/server/src/state.rs`, add the field to the struct: + +```rust +pub struct OAuthFlowState { + pub provider: String, + pub created_at: chrono::DateTime, + pub nonce: String, + pub pkce_verifier: String, +} +``` + +- [ ] **Step 2: Update the PKCE imports in `oauth.rs`** + +Extend the existing `oauth2` import line to add the PKCE types. Change: + +```rust +use oauth2::{AuthorizationCode, CsrfToken, Scope, TokenResponse}; +``` + +to: + +```rust +use oauth2::{AuthorizationCode, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, Scope, TokenResponse}; +``` + +- [ ] **Step 3: Generate the PKCE challenge in `oauth_authorize` and store the verifier** + +In `oauth_authorize`, change the `auth_request` creation to attach the PKCE +challenge, and generate the verifier just before it. Replace: + +```rust + let mut auth_request = client.authorize_url(CsrfToken::new_random); +``` + +with: + +```rust + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + let mut auth_request = client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); +``` + +Then add the verifier to the stored struct. Change the `OAuthFlowState { ... }` +literal in `oauth_authorize` to: + +```rust + OAuthFlowState { + provider, + created_at: Utc::now(), + nonce: nonce.clone(), + pkce_verifier: pkce_verifier.secret().clone(), + }, +``` + +- [ ] **Step 4: Send the PKCE verifier during token exchange in `oauth_callback`** + +Capture the validated flow state (it was discarded in Task 1). Change: + +```rust + validate_and_consume_state( + &state.oauth_states, + &query.state, + &provider, + nonce_cookie.as_deref(), + )?; +``` + +to: + +```rust + let flow = validate_and_consume_state( + &state.oauth_states, + &query.state, + &provider, + nonce_cookie.as_deref(), + )?; +``` + +Then change the code exchange. Replace: + +```rust + let token_result = client + .exchange_code(AuthorizationCode::new(query.code)) + .request_async(async_http_client) + .await + .map_err(|e| AppError::Internal(format!("OAuth token exchange failed: {e}")))?; +``` + +with: + +```rust + let token_result = client + .exchange_code(AuthorizationCode::new(query.code)) + .set_pkce_verifier(PkceCodeVerifier::new(flow.pkce_verifier)) + .request_async(async_http_client) + .await + .map_err(|e| AppError::Internal(format!("OAuth token exchange failed: {e}")))?; +``` + +- [ ] **Step 5: Update the test helper to populate `pkce_verifier`** + +In the `#[cfg(test)] mod tests` `make_states` helper, add the field to the +struct literal so it compiles: + +```rust + OAuthFlowState { + provider: provider.to_string(), + created_at: Utc::now() - chrono::Duration::minutes(age_min), + nonce: nonce.to_string(), + pkce_verifier: "verifier1".to_string(), + }, +``` + +- [ ] **Step 6: Run the tests (still GREEN)** + +Run: `cargo test -p serverbee-server --lib router::api::oauth` +Expected: 6 tests pass. + +- [ ] **Step 7: Build and lint** + +Run: `cargo clippy -p serverbee-server -- -D warnings` +Expected: no warnings, no errors. + +- [ ] **Step 8: Commit** + +```bash +git add crates/server/src/state.rs crates/server/src/router/api/oauth.rs +git commit -m "feat(server): adopt PKCE (S256) in OAuth login flow" +``` + +--- + +## Task 3: Final verification + +- [ ] **Step 1: Full workspace build + clippy** + +Run: `cargo clippy --workspace -- -D warnings` +Expected: clean. + +- [ ] **Step 2: Run the server test suite** + +Run: `cargo test -p serverbee-server` +Expected: all tests pass (including the 6 new oauth tests). + +- [ ] **Step 3: Manual verification note** + +The full network round-trip (provider token exchange + userinfo) is not unit +tested. If an OAuth provider is configured locally, manually confirm: +1. Clicking "Sign in with GitHub/Google" still completes login. +2. Replaying a captured `callback?code=...&state=...` URL in a *different* + browser (no `oauth_nonce` cookie) returns HTTP 400 "OAuth session mismatch". +3. The provider request now carries `code_challenge` / `code_challenge_method=S256`. + +--- + +## Self-Review + +- **Spec coverage:** browser-binding nonce (Task 1), PKCE (Task 2), single-use + state (Task 1 test `accepts_valid_state_and_is_single_use`), cookie clearing + (Task 1 Step 6), constant-time compare (`ct_eq`), no migration/config/frontend + change — all covered. +- **Placeholder scan:** none — every step shows full code/commands. +- **Type consistency:** `OAuthFlowState` fields (`provider`, `created_at`, + `nonce`, then `pkce_verifier` in Task 2) match across state.rs, the validator, + the handlers, and the test helper. `validate_and_consume_state` signature is + identical everywhere it appears. diff --git a/docs/superpowers/specs/2026-06-08-oauth-login-csrf-fix-design.md b/docs/superpowers/specs/2026-06-08-oauth-login-csrf-fix-design.md new file mode 100644 index 000000000..2cc463ffa --- /dev/null +++ b/docs/superpowers/specs/2026-06-08-oauth-login-csrf-fix-design.md @@ -0,0 +1,143 @@ +# OAuth Login CSRF / Session Fixation Fix — Design + +Date: 2026-06-08 +Status: Approved +Area: `crates/server` — OAuth login flow + +## Problem + +The OAuth login flow does not bind its CSRF `state` to the initiating browser +and uses no PKCE. `oauth_authorize` stores `state -> (provider, created_at)` in a +server-global `DashMap` keyed only by the CSRF token value; `oauth_callback` +validates only that the state exists, the provider matches, and the 10-minute TTL +has not elapsed, then unconditionally issues a `session_token` cookie. + +Because the state is never tied to the browser that started the flow, an attacker +can pre-initiate a flow with their own provider account, capture an unspent +`code` + `state`, and lure a victim to the callback URL within the TTL. The +victim's browser is then silently logged into the **attacker's** account +(login CSRF / session fixation). Any data, credentials, or 2FA the victim +subsequently configures lands in the attacker's account. This is the only +audit finding reachable with no credentials (only when an OAuth provider is +configured; providers are disabled by default). + +This deviates from the OAuth 2.0 Security BCP (RFC 9700), which requires either +PKCE or a state value bound to the user agent. + +## Goal + +Make the OAuth login flow resistant to login CSRF / session fixation and align +it with the OAuth 2.0 Security BCP, without changing the frontend, adding +migrations, or introducing new configuration. + +## Approach + +Add browser binding via a short-lived pre-auth cookie nonce, and adopt PKCE +(S256): + +1. On authorize, generate a random `nonce` (CSPRNG) and a PKCE + challenge/verifier. Send the challenge to the provider, store the nonce and + verifier alongside the state, and set the nonce in a short-lived HttpOnly + pre-auth cookie. +2. On callback, require the pre-auth cookie nonce to match the stored nonce for + the presented state before exchanging the code (with the PKCE verifier). + +An attacker cannot set the victim's pre-auth cookie, so a forged callback fails +the nonce check. PKCE additionally defeats authorization-code interception / +injection. + +The frontend already initiates the flow via a top-level anchor navigation +(``), so the `SameSite=Lax` pre-auth cookie +is sent both on the authorize redirect and on the provider's callback redirect. +No frontend change is required. + +## Design + +### Flow-state structure (`crates/server/src/state.rs`) + +Replace the tuple value with a named struct: + +```rust +pub struct OAuthFlowState { + pub provider: String, + pub created_at: DateTime, + pub nonce: String, // must match the pre-auth cookie on callback + pub pkce_verifier: String, // PKCE code_verifier secret +} +// oauth_states: DashMap +``` + +### `oauth_authorize` (`router/api/oauth.rs`) + +- Generate PKCE: `let (challenge, verifier) = PkceCodeChallenge::new_random_sha256();` + and `auth_request.set_pkce_challenge(challenge)`. +- Generate a CSPRNG nonce (reuse `AuthService::generate_session_token()`). +- Insert `OAuthFlowState { provider, created_at: now, nonce, pkce_verifier: + verifier.secret().clone() }` keyed by the CSRF state secret. +- Keep the existing expired-state eviction. +- Change the return type to `(HeaderMap, Redirect)` and set the pre-auth cookie: + `oauth_nonce=; HttpOnly; SameSite=Lax; Path=/api/auth/oauth; Max-Age=600` + (append `; Secure` when `auth.secure_cookie`). + +### Callback validation — extracted pure function + +```rust +fn validate_and_consume_state( + states: &DashMap, + state_param: &str, + provider: &str, + nonce_cookie: Option<&str>, +) -> Result +``` + +Order (note: `remove` consumes the state, making it single-use / replay-safe): + +1. `remove(state_param)` returns `None` -> `BadRequest("Invalid or expired OAuth state")`. +2. provider mismatch -> `BadRequest("OAuth state mismatch")`. +3. `now - created_at > 10min` -> `BadRequest("OAuth state expired")`. +4. `nonce_cookie` missing or not equal to stored nonce (constant-time compare) + -> `BadRequest("OAuth session mismatch")`. +5. Return the `OAuthFlowState` (carrying `pkce_verifier`). + +### `oauth_callback` (`router/api/oauth.rs`) + +- Parse `oauth_nonce` from the `Cookie` header via a small local helper that + mirrors `extract_session_cookie`. +- Call `validate_and_consume_state(...)` to obtain the `pkce_verifier`. +- Exchange the code with `.set_pkce_verifier(PkceCodeVerifier::new(flow.pkce_verifier))`. +- On success, set the `session_token` cookie **and** clear the pre-auth cookie + (`oauth_nonce=; Max-Age=0; Path=/api/auth/oauth`). + +## Testing + +Unit tests on `validate_and_consume_state` (no network required): + +- rejects an unknown state +- rejects a provider mismatch +- rejects an expired state +- rejects a missing nonce cookie +- rejects a mismatched nonce cookie +- accepts a valid state + matching nonce, returns the verifier, and a second + call fails (replay protection) + +Full network flow (token exchange, userinfo) stays manually verified, matching +the existing code, which has no OAuth integration tests. + +## Decisions / scope + +- Callback errors keep returning HTTP 400, consistent with the existing + state-error behavior. A legitimate user whose pre-auth cookie expired + (>10 min between click and callback) sees 400 — same effective window as the + existing state TTL. +- The pre-auth cookie is scoped to `Path=/api/auth/oauth` to limit its surface. +- The OIDC dead-code path (`userinfo` unimplemented, always 500s) is left + untouched — out of scope. +- No migration, no config change, no frontend change. +- Files touched: `crates/server/src/state.rs`, + `crates/server/src/router/api/oauth.rs`. + +## Out of scope (tracked separately) + +Other audit findings (server-side SSRF in service-monitor checkers, missing +audit logging on backup/restore and user management, agent token in WS query +string, default `trusted_proxies` breadth) are not addressed here.