diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index f5576b6a..a4331ee5 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -14,7 +14,7 @@ use hyper::{Request, StatusCode}; use hyper_rustls::HttpsConnectorBuilder; use hyper_util::{client::legacy::Client, rt::TokioExecutor}; use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; -use miette::{IntoDiagnostic, Result, WrapErr}; +use miette::{IntoDiagnostic, Result, WrapErr, miette}; use openshell_bootstrap::{ DeployOptions, GatewayMetadata, RemoteOptions, clear_active_gateway, container_name, extract_host_from_ssh_destination, get_gateway_metadata, list_gateways, load_active_gateway, @@ -1653,6 +1653,7 @@ pub fn doctor_exec( ssh_key: Option<&str>, command: &[String], ) -> Result<()> { + validate_gateway_name(name)?; let container = container_name(name); let is_tty = std::io::stdin().is_terminal(); @@ -1676,7 +1677,15 @@ pub fn doctor_exec( }; let mut cmd = if let Some(ref host) = remote_host { + validate_ssh_host(host)?; + // Remote: ssh docker exec [-it] sh -lc '' + // + // SSH concatenates all arguments after the hostname into a single + // string for the remote shell, so inner_cmd must be escaped twice: + // once for `sh -lc` (already done above) and once for the SSH + // remote shell (done here). + let ssh_escaped_cmd = shell_escape(&inner_cmd); let mut c = Command::new("ssh"); if let Some(key) = ssh_key { c.args(["-i", key]); @@ -1693,7 +1702,7 @@ pub fn doctor_exec( } else { c.arg("-i"); } - c.args([&container, "sh", "-lc", &inner_cmd]); + c.args([&container, "sh", "-lc", &ssh_escaped_cmd]); c } else { // Local: docker exec [-it] sh -lc '' @@ -1790,6 +1799,42 @@ fn shell_escape(s: &str) -> String { format!("'{}'", s.replace('\'', "'\\''")) } +/// Validate that a gateway name is safe for use in container/volume/network +/// names and shell commands. Rejects names with characters outside the set +/// `[a-zA-Z0-9._-]`. +fn validate_gateway_name(name: &str) -> Result<()> { + if name.is_empty() { + return Err(miette!("gateway name is empty")); + } + if !name + .bytes() + .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b'_')) + { + return Err(miette!( + "gateway name contains invalid characters (allowed: alphanumeric, '.', '-', '_')" + )); + } + Ok(()) +} + +/// Validate that an SSH host string is a reasonable hostname or IP address. +/// Rejects values with shell metacharacters, spaces, or control characters +/// that could be used for injection via a poisoned metadata.json. +fn validate_ssh_host(host: &str) -> Result<()> { + if host.is_empty() { + return Err(miette!("SSH host is empty")); + } + // Allow: alphanumeric, dots, hyphens, colons (IPv6), square brackets ([::1]), + // and @ (user@host). + if !host + .bytes() + .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'-' | b':' | b'[' | b']' | b'@')) + { + return Err(miette!("SSH host contains invalid characters: {host}")); + } + Ok(()) +} + /// Create a sandbox when no gateway is configured. /// /// Bootstraps a new gateway first, then delegates to [`sandbox_create`]. @@ -4942,7 +4987,7 @@ mod tests { git_sync_files, http_health_check, image_requests_gpu, inferred_provider_type, parse_cli_setting_value, parse_credential_pairs, provisioning_timeout_message, ready_false_condition_message, resolve_gateway_control_target_from, sandbox_should_persist, - source_requests_gpu, + shell_escape, source_requests_gpu, validate_gateway_name, validate_ssh_host, }; use crate::TEST_ENV_LOCK; use hyper::StatusCode; @@ -5457,4 +5502,58 @@ mod tests { server.join().expect("server thread"); assert_eq!(status, Some(StatusCode::OK)); } + + // ---- SEC-004: validate_gateway_name, validate_ssh_host, shell_escape ---- + + #[test] + fn validate_gateway_name_accepts_valid_names() { + assert!(validate_gateway_name("openshell").is_ok()); + assert!(validate_gateway_name("my-gateway").is_ok()); + assert!(validate_gateway_name("gateway_v2").is_ok()); + assert!(validate_gateway_name("gw.prod").is_ok()); + } + + #[test] + fn validate_gateway_name_rejects_invalid_names() { + assert!(validate_gateway_name("").is_err()); + assert!(validate_gateway_name("gw;rm -rf /").is_err()); + assert!(validate_gateway_name("gw name").is_err()); + assert!(validate_gateway_name("gw$(id)").is_err()); + assert!(validate_gateway_name("gw\nmalicious").is_err()); + } + + #[test] + fn validate_ssh_host_accepts_valid_hosts() { + assert!(validate_ssh_host("192.168.1.1").is_ok()); + assert!(validate_ssh_host("example.com").is_ok()); + assert!(validate_ssh_host("user@host.com").is_ok()); + assert!(validate_ssh_host("[::1]").is_ok()); + assert!(validate_ssh_host("2001:db8::1").is_ok()); + } + + #[test] + fn validate_ssh_host_rejects_invalid_hosts() { + assert!(validate_ssh_host("").is_err()); + assert!(validate_ssh_host("host;rm -rf /").is_err()); + assert!(validate_ssh_host("host$(id)").is_err()); + assert!(validate_ssh_host("host name").is_err()); + assert!(validate_ssh_host("host\nmalicious").is_err()); + } + + #[test] + fn shell_escape_double_escape_for_ssh() { + // Simulate the double-escape path for SSH: + // First escape for sh -lc, then escape again for SSH remote shell. + let inner_cmd = "KUBECONFIG=/etc/rancher/k3s/k3s.yaml echo 'hello world'"; + let ssh_escaped = shell_escape(inner_cmd); + // The result should be single-quoted (wrapping the entire inner_cmd) + assert!( + ssh_escaped.starts_with('\''), + "should be single-quoted: {ssh_escaped}" + ); + assert!( + ssh_escaped.ends_with('\''), + "should end with single-quote: {ssh_escaped}" + ); + } } diff --git a/crates/openshell-sandbox/src/l7/inference.rs b/crates/openshell-sandbox/src/l7/inference.rs index 022fe4fa..59dafdab 100644 --- a/crates/openshell-sandbox/src/l7/inference.rs +++ b/crates/openshell-sandbox/src/l7/inference.rs @@ -171,43 +171,69 @@ pub fn try_parse_http_request(buf: &[u8]) -> ParseResult { ) } +/// Maximum decoded body size from chunked transfer encoding (10 MiB). +/// Matches the caller's `MAX_INFERENCE_BUF` limit. +const MAX_CHUNKED_BODY: usize = 10 * 1024 * 1024; + +/// Maximum number of chunks to process. Normal HTTP clients send the body +/// in a handful of large chunks; thousands of tiny chunks indicate abuse. +const MAX_CHUNK_COUNT: usize = 4096; + /// Parse an HTTP chunked body from `buf[start..]`. /// /// Returns `(decoded_body, total_consumed_bytes_from_buf_start)` when complete, -/// or `None` if more bytes are needed. +/// or `None` if more bytes are needed or resource limits are exceeded. fn parse_chunked_body(buf: &[u8], start: usize) -> Option<(Vec, usize)> { let mut pos = start; let mut body = Vec::new(); + let mut chunk_count: usize = 0; loop { + chunk_count += 1; + if chunk_count > MAX_CHUNK_COUNT { + return None; + } + let size_line_end = find_crlf(buf, pos)?; let size_line = std::str::from_utf8(&buf[pos..size_line_end]).ok()?; let size_token = size_line.split(';').next()?.trim(); let chunk_size = usize::from_str_radix(size_token, 16).ok()?; - pos = size_line_end + 2; + pos = size_line_end.checked_add(2)?; if chunk_size == 0 { // Parse trailers (if any). Terminates on empty trailer line. loop { let trailer_end = find_crlf(buf, pos)?; let trailer_line = &buf[pos..trailer_end]; - pos = trailer_end + 2; + pos = trailer_end.checked_add(2)?; if trailer_line.is_empty() { return Some((body, pos)); } } } + // Early reject: chunk cannot possibly fit in remaining buffer. + let remaining = buf.len().saturating_sub(pos); + if chunk_size > remaining { + return None; + } + + // Reject if decoded body would exceed size limit. + if body.len().saturating_add(chunk_size) > MAX_CHUNKED_BODY { + return None; + } + let chunk_end = pos.checked_add(chunk_size)?; - if buf.len() < chunk_end + 2 { + let chunk_crlf_end = chunk_end.checked_add(2)?; + if buf.len() < chunk_crlf_end { return None; } - if &buf[chunk_end..chunk_end + 2] != b"\r\n" { + if &buf[chunk_end..chunk_crlf_end] != b"\r\n" { return None; } body.extend_from_slice(&buf[pos..chunk_end]); - pos = chunk_end + 2; + pos = chunk_crlf_end; } } @@ -484,4 +510,46 @@ mod tests { assert!(chunk.ends_with(b"\r\n")); assert_eq!(chunk.len(), 3 + 2 + 256 + 2); // "100" + \r\n + data + \r\n } + + // ---- SEC-010: parse_chunked_body resource limits ---- + + #[test] + fn parse_chunked_multi_chunk_body() { + // Two chunks: 5 bytes + 6 bytes + let request = b"POST /v1/chat HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n"; + let ParseResult::Complete(parsed, _) = try_parse_http_request(request) else { + panic!("expected Complete"); + }; + assert_eq!(parsed.body, b"hello world"); + } + + #[test] + fn parse_chunked_rejects_too_many_chunks() { + // Build a request with MAX_CHUNK_COUNT + 1 tiny chunks + let mut buf = Vec::new(); + buf.extend_from_slice(b"POST /v1/chat HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"); + for _ in 0..=MAX_CHUNK_COUNT { + buf.extend_from_slice(b"1\r\nX\r\n"); + } + buf.extend_from_slice(b"0\r\n\r\n"); + assert!(matches!( + try_parse_http_request(&buf), + ParseResult::Incomplete + )); + } + + #[test] + fn parse_chunked_within_chunk_count_limit() { + // MAX_CHUNK_COUNT chunks should succeed + let mut buf = Vec::new(); + buf.extend_from_slice(b"POST /v1/chat HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"); + for _ in 0..100 { + buf.extend_from_slice(b"1\r\nX\r\n"); + } + buf.extend_from_slice(b"0\r\n\r\n"); + let ParseResult::Complete(parsed, _) = try_parse_http_request(&buf) else { + panic!("expected Complete for 100 chunks"); + }; + assert_eq!(parsed.body.len(), 100); + } } diff --git a/crates/openshell-sandbox/src/l7/rest.rs b/crates/openshell-sandbox/src/l7/rest.rs index 61b026f5..26b8d4e8 100644 --- a/crates/openshell-sandbox/src/l7/rest.rs +++ b/crates/openshell-sandbox/src/l7/rest.rs @@ -88,7 +88,24 @@ async fn parse_http_request(client: &mut C) -> Result(client: &mut C) -> Result( } /// Parse Content-Length or Transfer-Encoding from HTTP headers. -fn parse_body_length(headers: &str) -> BodyLength { +/// +/// Per RFC 7230 Section 3.3.3, rejects requests containing both +/// `Content-Length` and `Transfer-Encoding` headers to prevent request +/// smuggling via CL/TE ambiguity. +fn parse_body_length(headers: &str) -> Result { + let mut has_te_chunked = false; + let mut cl_value: Option = None; + for line in headers.lines().skip(1) { let lower = line.to_ascii_lowercase(); if lower.starts_with("transfer-encoding:") { let val = lower.split_once(':').map_or("", |(_, v)| v.trim()); if val.contains("chunked") { - return BodyLength::Chunked; + has_te_chunked = true; } } if lower.starts_with("content-length:") && let Some(val) = lower.split_once(':').map(|(_, v)| v.trim()) && let Ok(len) = val.parse::() { - return BodyLength::ContentLength(len); + cl_value = Some(len); } } - BodyLength::None + + if has_te_chunked && cl_value.is_some() { + return Err(miette!( + "Request contains both Transfer-Encoding and Content-Length headers" + )); + } + + if has_te_chunked { + return Ok(BodyLength::Chunked); + } + if let Some(len) = cl_value { + return Ok(BodyLength::ContentLength(len)); + } + Ok(BodyLength::None) } /// Relay exactly `len` bytes from reader to writer. @@ -413,7 +456,7 @@ where let header_str = String::from_utf8_lossy(&buf[..header_end]); let status_code = parse_status_code(&header_str).unwrap_or(200); let server_wants_close = parse_connection_close(&header_str); - let body_length = parse_body_length(&header_str); + let body_length = parse_body_length(&header_str)?; debug!( status_code, @@ -596,7 +639,7 @@ mod tests { #[test] fn parse_content_length() { let headers = "POST /api HTTP/1.1\r\nHost: example.com\r\nContent-Length: 42\r\n\r\n"; - match parse_body_length(headers) { + match parse_body_length(headers).unwrap() { BodyLength::ContentLength(42) => {} other => panic!("Expected ContentLength(42), got {other:?}"), } @@ -606,7 +649,7 @@ mod tests { fn parse_chunked() { let headers = "POST /api HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\n\r\n"; - match parse_body_length(headers) { + match parse_body_length(headers).unwrap() { BodyLength::Chunked => {} other => panic!("Expected Chunked, got {other:?}"), } @@ -615,12 +658,77 @@ mod tests { #[test] fn parse_no_body() { let headers = "GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n"; - match parse_body_length(headers) { + match parse_body_length(headers).unwrap() { BodyLength::None => {} other => panic!("Expected None, got {other:?}"), } } + /// SEC-009: Reject requests with both Content-Length and Transfer-Encoding + /// to prevent CL/TE request smuggling (RFC 7230 Section 3.3.3). + #[test] + fn reject_dual_content_length_and_transfer_encoding() { + let headers = "POST /api HTTP/1.1\r\nHost: x\r\nContent-Length: 5\r\nTransfer-Encoding: chunked\r\n\r\n"; + assert!( + parse_body_length(headers).is_err(), + "Must reject request with both CL and TE" + ); + } + + /// SEC-009: Same rejection regardless of header order. + #[test] + fn reject_dual_transfer_encoding_and_content_length() { + let headers = "POST /api HTTP/1.1\r\nHost: x\r\nTransfer-Encoding: chunked\r\nContent-Length: 5\r\n\r\n"; + assert!( + parse_body_length(headers).is_err(), + "Must reject request with both TE and CL" + ); + } + + /// SEC-009: Bare LF in headers enables header injection. + #[tokio::test] + async fn reject_bare_lf_in_headers() { + let (mut client, mut writer) = tokio::io::duplex(4096); + tokio::spawn(async move { + // Bare \n between two header values creates a parsing discrepancy + writer + .write_all( + b"GET /api HTTP/1.1\r\nX-Injected: value\nEvil: header\r\nHost: x\r\n\r\n", + ) + .await + .unwrap(); + }); + let result = parse_http_request(&mut client).await; + assert!(result.is_err(), "Must reject headers with bare LF"); + } + + /// SEC-009: Invalid UTF-8 in headers creates interpretation gap. + #[tokio::test] + async fn reject_invalid_utf8_in_headers() { + let (mut client, mut writer) = tokio::io::duplex(4096); + tokio::spawn(async move { + let mut raw = Vec::new(); + raw.extend_from_slice(b"GET /api HTTP/1.1\r\nHost: x\r\nX-Bad: \xc0\xaf\r\n\r\n"); + writer.write_all(&raw).await.unwrap(); + }); + let result = parse_http_request(&mut client).await; + assert!(result.is_err(), "Must reject headers with invalid UTF-8"); + } + + /// SEC-009: Reject unsupported HTTP versions. + #[tokio::test] + async fn reject_invalid_http_version() { + let (mut client, mut writer) = tokio::io::duplex(4096); + tokio::spawn(async move { + writer + .write_all(b"GET /api JUNK/9.9\r\nHost: x\r\n\r\n") + .await + .unwrap(); + }); + let result = parse_http_request(&mut client).await; + assert!(result.is_err(), "Must reject unsupported HTTP version"); + } + /// Regression test: two pipelined requests in a single write must be /// parsed independently. Before the fix, the 1024-byte `read()` buffer /// could capture bytes from the second request, which were forwarded diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index d662399b..7fc267e0 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -1034,19 +1034,23 @@ async fn route_inference_request( } } +/// Map router errors to HTTP status codes and sanitized messages. +/// +/// Returns generic error messages instead of verbatim internal details. +/// Full error context (upstream URLs, hostnames, TLS details) is logged +/// server-side by the caller at `warn` level for debugging. fn router_error_to_http(err: &openshell_router::RouterError) -> (u16, String) { use openshell_router::RouterError; match err { - RouterError::RouteNotFound(hint) => { - (400, format!("no route configured for route '{hint}'")) + RouterError::RouteNotFound(_) => (400, "no inference route configured".to_string()), + RouterError::NoCompatibleRoute(_) => { + (400, "no compatible inference route available".to_string()) + } + RouterError::Unauthorized(_) => (401, "unauthorized".to_string()), + RouterError::UpstreamUnavailable(_) => (503, "inference service unavailable".to_string()), + RouterError::UpstreamProtocol(_) | RouterError::Internal(_) => { + (502, "inference service error".to_string()) } - RouterError::NoCompatibleRoute(protocol) => ( - 400, - format!("no compatible route for source protocol '{protocol}'"), - ), - RouterError::Unauthorized(msg) => (401, msg.clone()), - RouterError::UpstreamUnavailable(msg) => (503, msg.clone()), - RouterError::UpstreamProtocol(msg) | RouterError::Internal(msg) => (502, msg.clone()), } } @@ -1249,6 +1253,13 @@ async fn resolve_and_check_allowed_ips( )); } + // Block control-plane ports regardless of IP match. + if BLOCKED_CONTROL_PLANE_PORTS.contains(&port) { + return Err(format!( + "port {port} is a blocked control-plane port, connection rejected" + )); + } + for addr in &addrs { // Always block loopback and link-local if is_always_blocked_ip(addr.ip()) { @@ -1271,11 +1282,27 @@ async fn resolve_and_check_allowed_ips( Ok(addrs) } +/// Minimum CIDR prefix length before logging a breadth warning. +/// CIDRs broader than /16 (65,536+ addresses) may unintentionally expose +/// control-plane services on the same network. +const MIN_SAFE_PREFIX_LEN: u8 = 16; + +/// Ports that are always blocked in `resolve_and_check_allowed_ips`, even +/// when the resolved IP matches an `allowed_ips` entry. These ports belong +/// to control-plane services that should never be reachable from a sandbox. +const BLOCKED_CONTROL_PLANE_PORTS: &[u16] = &[ + 2379, // etcd client + 2380, // etcd peer + 6443, // Kubernetes API server + 10250, // kubelet API + 10255, // kubelet read-only +]; + /// Parse CIDR/IP strings into `IpNet` values, rejecting invalid entries and /// entries that cover loopback or link-local ranges. /// /// Returns parsed networks on success, or an error describing which entries -/// are invalid. +/// are invalid. Logs a warning for overly broad CIDRs. fn parse_allowed_ips(raw: &[String]) -> std::result::Result, String> { let mut nets = Vec::with_capacity(raw.len()); let mut errors = Vec::new(); @@ -1293,7 +1320,17 @@ fn parse_allowed_ips(raw: &[String]) -> std::result::Result, S }); match parsed { - Ok(n) => nets.push(n), + Ok(n) => { + if n.prefix_len() < MIN_SAFE_PREFIX_LEN { + warn!( + cidr = %n, + prefix_len = n.prefix_len(), + "allowed_ips entry has a very broad CIDR (< /{MIN_SAFE_PREFIX_LEN}); \ + this may expose control-plane services on the same network" + ); + } + nets.push(n); + } Err(_) => errors.push(format!("invalid CIDR/IP in allowed_ips: {entry}")), } } @@ -2070,10 +2107,9 @@ mod tests { let err = openshell_router::RouterError::RouteNotFound("local".into()); let (status, msg) = router_error_to_http(&err); assert_eq!(status, 400); - assert!( - msg.contains("local"), - "message should contain the hint: {msg}" - ); + assert_eq!(msg, "no inference route configured"); + // SEC-008: must NOT leak the route hint to sandboxed code + assert!(!msg.contains("local")); } #[test] @@ -2081,42 +2117,56 @@ mod tests { let err = openshell_router::RouterError::NoCompatibleRoute("anthropic_messages".into()); let (status, msg) = router_error_to_http(&err); assert_eq!(status, 400); - assert!( - msg.contains("anthropic_messages"), - "message should contain the protocol: {msg}" - ); + assert_eq!(msg, "no compatible inference route available"); + // SEC-008: must NOT leak the protocol name to sandboxed code + assert!(!msg.contains("anthropic_messages")); } #[test] fn router_error_unauthorized_maps_to_401() { - let err = openshell_router::RouterError::Unauthorized("bad token".into()); + let err = + openshell_router::RouterError::Unauthorized("bad token from 10.0.0.5:8080".into()); let (status, msg) = router_error_to_http(&err); assert_eq!(status, 401); - assert_eq!(msg, "bad token"); + assert_eq!(msg, "unauthorized"); + // SEC-008: must NOT leak upstream details to sandboxed code + assert!(!msg.contains("10.0.0.5")); } #[test] fn router_error_upstream_unavailable_maps_to_503() { - let err = openshell_router::RouterError::UpstreamUnavailable("connection refused".into()); + let err = openshell_router::RouterError::UpstreamUnavailable( + "connection refused to 10.0.0.5:8080".into(), + ); let (status, msg) = router_error_to_http(&err); assert_eq!(status, 503); - assert_eq!(msg, "connection refused"); + assert_eq!(msg, "inference service unavailable"); + // SEC-008: must NOT leak upstream address to sandboxed code + assert!(!msg.contains("10.0.0.5")); } #[test] fn router_error_upstream_protocol_maps_to_502() { - let err = openshell_router::RouterError::UpstreamProtocol("bad gateway".into()); + let err = openshell_router::RouterError::UpstreamProtocol( + "TLS handshake failed for nim.internal.svc:443".into(), + ); let (status, msg) = router_error_to_http(&err); assert_eq!(status, 502); - assert_eq!(msg, "bad gateway"); + assert_eq!(msg, "inference service error"); + // SEC-008: must NOT leak internal hostnames to sandboxed code + assert!(!msg.contains("nim.internal")); } #[test] fn router_error_internal_maps_to_502() { - let err = openshell_router::RouterError::Internal("unexpected".into()); + let err = openshell_router::RouterError::Internal( + "failed to read /etc/openshell/routes.json".into(), + ); let (status, msg) = router_error_to_http(&err); assert_eq!(status, 502); - assert_eq!(msg, "unexpected"); + assert_eq!(msg, "inference service error"); + // SEC-008: must NOT leak file paths to sandboxed code + assert!(!msg.contains("/etc/openshell")); } #[test] @@ -2308,6 +2358,42 @@ mod tests { ); } + // --- SEC-005: CIDR breadth warning and control-plane port blocklist --- + + #[tokio::test] + async fn test_resolve_check_allowed_ips_blocks_control_plane_ports() { + let nets = parse_allowed_ips(&["0.0.0.0/0".to_string()]).unwrap(); + // K8s API server port + let result = resolve_and_check_allowed_ips("8.8.8.8", 6443, &nets).await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("blocked control-plane port")); + + // etcd client port + let result = resolve_and_check_allowed_ips("8.8.8.8", 2379, &nets).await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("blocked control-plane port")); + + // kubelet API port + let result = resolve_and_check_allowed_ips("8.8.8.8", 10250, &nets).await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("blocked control-plane port")); + } + + #[tokio::test] + async fn test_resolve_check_allowed_ips_allows_non_control_plane_ports() { + // Port 443 should not be blocked by the control-plane port list + let nets = parse_allowed_ips(&["8.8.8.0/24".to_string()]).unwrap(); + let result = resolve_and_check_allowed_ips("8.8.8.8", 443, &nets).await; + assert!(result.is_ok()); + } + + #[test] + fn test_parse_allowed_ips_broad_cidr_is_accepted() { + // Broad CIDRs are accepted (just warned about) -- design trade-off + let result = parse_allowed_ips(&["10.0.0.0/8".to_string()]); + assert!(result.is_ok()); + } + // --- extract_host_from_uri tests --- #[test] diff --git a/crates/openshell-sandbox/src/ssh.rs b/crates/openshell-sandbox/src/ssh.rs index 5c4ef862..10eab8c4 100644 --- a/crates/openshell-sandbox/src/ssh.rs +++ b/crates/openshell-sandbox/src/ssh.rs @@ -330,12 +330,21 @@ impl russh::server::Handler for SshHandler { _originator_port: u32, _session: &mut Session, ) -> Result { + // Validate port range before truncating u32 -> u16. The SSH protocol + // uses u32 for ports, but valid TCP ports are 0-65535. Without this + // check, port 65537 truncates to port 1 (privileged). + if port_to_connect > u32::from(u16::MAX) { + warn!( + host = host_to_connect, + port = port_to_connect, + "direct-tcpip rejected: port exceeds valid TCP range (0-65535)" + ); + return Ok(false); + } + // Only allow forwarding to loopback destinations to prevent the // sandbox SSH server from being used as a generic proxy. - let is_loopback = host_to_connect == "127.0.0.1" - || host_to_connect == "localhost" - || host_to_connect == "::1"; - if !is_loopback { + if !is_loopback_host(host_to_connect) { warn!( host = host_to_connect, port = port_to_connect, @@ -345,7 +354,6 @@ impl russh::server::Handler for SshHandler { } let host = host_to_connect.to_string(); - #[allow(clippy::cast_possible_truncation)] let port = port_to_connect as u16; let netns_fd = self.netns_fd; @@ -1074,6 +1082,41 @@ fn to_u16(value: u32) -> u16 { u16::try_from(value.min(u32::from(u16::MAX))).unwrap_or(u16::MAX) } +/// Check whether a host string refers to a loopback address. +/// +/// Covers all representations that resolve to loopback: +/// - `127.0.0.0/8` (the entire IPv4 loopback range, not just `127.0.0.1`) +/// - `localhost` +/// - `::1` and long-form IPv6 loopback (`0:0:0:0:0:0:0:1`) +/// - `::ffff:127.x.x.x` (IPv4-mapped IPv6 loopback) +/// - Bracketed forms like `[::1]` +fn is_loopback_host(host: &str) -> bool { + // Strip brackets for IPv6 addresses like [::1] + let host = host + .strip_prefix('[') + .and_then(|h| h.strip_suffix(']')) + .unwrap_or(host); + + if host.eq_ignore_ascii_case("localhost") { + return true; + } + + match host.parse::() { + Ok(std::net::IpAddr::V4(v4)) => v4.is_loopback(), // covers all 127.x.x.x + Ok(std::net::IpAddr::V6(v6)) => { + if v6.is_loopback() { + return true; // covers ::1 and long form + } + // Check IPv4-mapped IPv6 addresses like ::ffff:127.0.0.1 + if let Some(v4) = v6.to_ipv4_mapped() { + return v4.is_loopback(); + } + false + } + Err(_) => false, + } +} + #[cfg(test)] mod tests { use super::*; @@ -1311,4 +1354,54 @@ mod tests { assert!(!stdout.contains(SSH_HANDSHAKE_SECRET_ENV)); assert!(stdout.contains("ANTHROPIC_API_KEY=openshell:resolve:env:ANTHROPIC_API_KEY")); } + + // ----------------------------------------------------------------------- + // SEC-007: is_loopback_host tests + // ----------------------------------------------------------------------- + + #[test] + fn loopback_host_accepts_standard_ipv4() { + assert!(is_loopback_host("127.0.0.1")); + } + + #[test] + fn loopback_host_accepts_full_ipv4_range() { + assert!(is_loopback_host("127.0.0.2")); + assert!(is_loopback_host("127.255.255.255")); + } + + #[test] + fn loopback_host_accepts_localhost() { + assert!(is_loopback_host("localhost")); + assert!(is_loopback_host("LOCALHOST")); + assert!(is_loopback_host("Localhost")); + } + + #[test] + fn loopback_host_accepts_ipv6_loopback() { + assert!(is_loopback_host("::1")); + assert!(is_loopback_host("[::1]")); + assert!(is_loopback_host("0:0:0:0:0:0:0:1")); + } + + #[test] + fn loopback_host_accepts_ipv4_mapped_ipv6() { + assert!(is_loopback_host("::ffff:127.0.0.1")); + } + + #[test] + fn loopback_host_rejects_non_loopback() { + assert!(!is_loopback_host("10.0.0.1")); + assert!(!is_loopback_host("192.168.1.1")); + assert!(!is_loopback_host("8.8.8.8")); + assert!(!is_loopback_host("example.com")); + assert!(!is_loopback_host("::ffff:10.0.0.1")); + } + + #[test] + fn loopback_host_rejects_empty_and_garbage() { + assert!(!is_loopback_host("")); + assert!(!is_loopback_host("not-an-ip")); + assert!(!is_loopback_host("[]")); + } } diff --git a/crates/openshell-server/src/grpc.rs b/crates/openshell-server/src/grpc.rs index b968be27..de73da6f 100644 --- a/crates/openshell-server/src/grpc.rs +++ b/crates/openshell-server/src/grpc.rs @@ -1010,6 +1010,7 @@ impl OpenShell for OpenShellService { "environment keys must match ^[A-Za-z_][A-Za-z0-9_]*$", )); } + validate_exec_request_fields(&req)?; let sandbox = self .state @@ -1024,7 +1025,8 @@ impl OpenShell for OpenShellService { } let (target_host, target_port) = resolve_sandbox_exec_target(&self.state, &sandbox).await?; - let command_str = build_remote_exec_command(&req); + let command_str = build_remote_exec_command(&req) + .map_err(|e| Status::invalid_argument(format!("command construction failed: {e}")))?; let stdin_payload = req.stdin; let timeout_seconds = req.timeout_seconds; let sandbox_id = sandbox.id; @@ -3407,34 +3409,122 @@ async fn resolve_sandbox_exec_target( )) } -fn shell_escape(value: &str) -> String { +/// Maximum number of arguments in the command array. +const MAX_EXEC_COMMAND_ARGS: usize = 1024; +/// Maximum length of a single command argument or environment value (bytes). +const MAX_EXEC_ARG_LEN: usize = 32 * 1024; // 32 KiB +/// Maximum length of the workdir field (bytes). +const MAX_EXEC_WORKDIR_LEN: usize = 4096; + +/// Validate fields of an ExecSandboxRequest for control characters and size +/// limits before constructing a shell command string. +fn validate_exec_request_fields(req: &ExecSandboxRequest) -> Result<(), Status> { + if req.command.len() > MAX_EXEC_COMMAND_ARGS { + return Err(Status::invalid_argument(format!( + "command array exceeds {} argument limit", + MAX_EXEC_COMMAND_ARGS + ))); + } + for (i, arg) in req.command.iter().enumerate() { + if arg.len() > MAX_EXEC_ARG_LEN { + return Err(Status::invalid_argument(format!( + "command argument {i} exceeds {} byte limit", + MAX_EXEC_ARG_LEN + ))); + } + reject_control_chars(arg, &format!("command argument {i}"))?; + } + for (key, value) in &req.environment { + if value.len() > MAX_EXEC_ARG_LEN { + return Err(Status::invalid_argument(format!( + "environment value for '{key}' exceeds {} byte limit", + MAX_EXEC_ARG_LEN + ))); + } + reject_control_chars(value, &format!("environment value for '{key}'"))?; + } + if !req.workdir.is_empty() { + if req.workdir.len() > MAX_EXEC_WORKDIR_LEN { + return Err(Status::invalid_argument(format!( + "workdir exceeds {} byte limit", + MAX_EXEC_WORKDIR_LEN + ))); + } + reject_control_chars(&req.workdir, "workdir")?; + } + Ok(()) +} + +/// Reject null bytes and newlines in a user-supplied value. +fn reject_control_chars(value: &str, field_name: &str) -> Result<(), Status> { + if value.bytes().any(|b| b == 0) { + return Err(Status::invalid_argument(format!( + "{field_name} contains null bytes" + ))); + } + if value.bytes().any(|b| b == b'\n' || b == b'\r') { + return Err(Status::invalid_argument(format!( + "{field_name} contains newline or carriage return characters" + ))); + } + Ok(()) +} + +/// Shell-escape a value for embedding in a POSIX shell command. +/// +/// Wraps unsafe values in single quotes with the standard `'\''` idiom for +/// embedded single-quote characters. Rejects null bytes which can truncate +/// shell parsing at the C level. +fn shell_escape(value: &str) -> Result { + // Reject null bytes — can truncate shell parsing at the C level. + if value.bytes().any(|b| b == 0) { + return Err("value contains null bytes".to_string()); + } + // Reject newlines and carriage returns — safe within single quotes for + // one shell layer, but dangerous when the command string traverses + // multiple interpretation boundaries (SSH transport + bash -lc). + if value.bytes().any(|b| b == b'\n' || b == b'\r') { + return Err("value contains newline or carriage return".to_string()); + } if value.is_empty() { - return "''".to_string(); + return Ok("''".to_string()); } let safe = value .bytes() .all(|byte| byte.is_ascii_alphanumeric() || matches!(byte, b'.' | b'/' | b'-' | b'_')); if safe { - return value.to_string(); + return Ok(value.to_string()); } let escaped = value.replace('\'', "'\"'\"'"); - format!("'{escaped}'") + Ok(format!("'{escaped}'")) } -fn build_remote_exec_command(req: &ExecSandboxRequest) -> String { +/// Maximum total length of the assembled shell command string. +const MAX_COMMAND_STRING_LEN: usize = 256 * 1024; // 256 KiB + +fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result { let mut parts = Vec::new(); let mut env_entries = req.environment.iter().collect::>(); env_entries.sort_by(|(a, _), (b, _)| a.cmp(b)); for (key, value) in env_entries { - parts.push(format!("{key}={}", shell_escape(value))); + parts.push(format!("{key}={}", shell_escape(value)?)); + } + for arg in &req.command { + parts.push(shell_escape(arg)?); } - parts.extend(req.command.iter().map(|arg| shell_escape(arg))); let command = parts.join(" "); - if req.workdir.is_empty() { + let result = if req.workdir.is_empty() { command } else { - format!("cd {} && {command}", shell_escape(&req.workdir)) + format!("cd {} && {command}", shell_escape(&req.workdir)?) + }; + if result.len() > MAX_COMMAND_STRING_LEN { + return Err(format!( + "assembled command string exceeds {} byte limit", + MAX_COMMAND_STRING_LEN + )); } + Ok(result) } /// Resolve provider credentials into environment variables. @@ -3523,10 +3613,14 @@ async fn stream_exec_over_ssh( timeout_seconds: u32, handshake_secret: &str, ) -> Result<(), Status> { + let command_preview: String = command.chars().take(120).collect(); info!( sandbox_id = %sandbox_id, target_host = %target_host, target_port, + command_len = command.len(), + stdin_len = stdin_payload.len(), + command_preview = %command_preview, "ExecSandbox command started" ); @@ -3646,6 +3740,20 @@ async fn run_exec_with_russh( stdin_payload: Vec, tx: mpsc::Sender>, ) -> Result { + // Defense-in-depth: validate command at the transport boundary even though + // exec_sandbox and build_remote_exec_command already validate upstream. + if command.as_bytes().contains(&0) { + return Err(Status::invalid_argument( + "command contains null bytes at transport boundary", + )); + } + if command.len() > MAX_COMMAND_STRING_LEN { + return Err(Status::invalid_argument(format!( + "command exceeds {} byte limit at transport boundary", + MAX_COMMAND_STRING_LEN + ))); + } + let stream = TcpStream::connect(("127.0.0.1", local_proxy_port)) .await .map_err(|e| Status::internal(format!("failed to connect to ssh proxy: {e}")))?; @@ -3732,6 +3840,29 @@ async fn run_exec_with_russh( Ok(exit_code.unwrap_or(1)) } +/// Check whether an IP address is safe to use as an SSH proxy target. +/// +/// Blocks loopback (prevents connecting back to the gateway server itself) +/// and link-local addresses (prevents cloud metadata SSRF via 169.254.169.254). +fn is_safe_ssh_proxy_target(ip: std::net::IpAddr) -> bool { + match ip { + std::net::IpAddr::V4(v4) => { + !v4.is_loopback() // 127.0.0.0/8 + && !v4.is_link_local() // 169.254.0.0/16 (cloud metadata) + } + std::net::IpAddr::V6(v6) => { + if v6.is_loopback() { + return false; // ::1 + } + // Check IPv4-mapped IPv6 addresses like ::ffff:127.0.0.1 + if let Some(v4) = v6.to_ipv4_mapped() { + return !v4.is_loopback() && !v4.is_link_local(); + } + true + } + } +} + async fn start_single_use_ssh_proxy( target_host: &str, target_port: u16, @@ -3747,9 +3878,45 @@ async fn start_single_use_ssh_proxy( warn!("SSH proxy: failed to accept local connection"); return; }; - let Ok(mut sandbox_conn) = TcpStream::connect((target_host.as_str(), target_port)).await - else { - warn!(target_host = %target_host, target_port, "SSH proxy: failed to connect to sandbox"); + + // Resolve DNS and validate the target IP before connecting. + // This prevents SSRF if the sandbox status record were poisoned + // to point at loopback, cloud metadata, or other internal services. + let addr_str = format!("{}:{}", target_host, target_port); + let resolved = match tokio::net::lookup_host(&addr_str).await { + Ok(mut addrs) => match addrs.next() { + Some(addr) => addr, + None => { + warn!(target_host = %target_host, "SSH proxy: DNS resolution returned no addresses"); + return; + } + }, + Err(e) => { + warn!(target_host = %target_host, error = %e, "SSH proxy: DNS resolution failed"); + return; + } + }; + + if !is_safe_ssh_proxy_target(resolved.ip()) { + warn!( + target_host = %target_host, + resolved_ip = %resolved.ip(), + "SSH proxy: target resolved to blocked IP range (loopback or link-local)" + ); + return; + } + + debug!( + target_host = %target_host, + resolved_ip = %resolved.ip(), + target_port, + "SSH proxy: connecting to validated target" + ); + + // Connect to the resolved address directly (not the hostname) to + // prevent TOCTOU between validation and connection. + let Ok(mut sandbox_conn) = TcpStream::connect(resolved).await else { + warn!(target_host = %target_host, resolved_ip = %resolved.ip(), target_port, "SSH proxy: failed to connect to sandbox"); return; }; let Ok(preface) = build_preface(&uuid::Uuid::new_v4().to_string(), &handshake_secret) @@ -4016,9 +4183,10 @@ mod tests { MAX_ENVIRONMENT_ENTRIES, MAX_LOG_LEVEL_LEN, MAX_MAP_KEY_LEN, MAX_MAP_VALUE_LEN, MAX_NAME_LEN, MAX_PAGE_SIZE, MAX_POLICY_SIZE, MAX_PROVIDER_CONFIG_ENTRIES, MAX_PROVIDER_CREDENTIALS_ENTRIES, MAX_PROVIDER_TYPE_LEN, MAX_PROVIDERS, - MAX_TEMPLATE_MAP_ENTRIES, MAX_TEMPLATE_STRING_LEN, MAX_TEMPLATE_STRUCT_SIZE, clamp_limit, - create_provider_record, delete_provider_record, get_provider_record, is_valid_env_key, - list_provider_records, merge_chunk_into_policy, resolve_provider_environment, + MAX_TEMPLATE_MAP_ENTRIES, MAX_TEMPLATE_STRING_LEN, MAX_TEMPLATE_STRUCT_SIZE, + build_remote_exec_command, clamp_limit, create_provider_record, delete_provider_record, + get_provider_record, is_safe_ssh_proxy_target, is_valid_env_key, list_provider_records, + merge_chunk_into_policy, reject_control_chars, resolve_provider_environment, shell_escape, update_provider_record, validate_provider_fields, validate_sandbox_spec, }; use crate::persistence::{DraftChunkRecord, Store}; @@ -4044,6 +4212,176 @@ mod tests { assert!(!is_valid_env_key("X;rm -rf /")); } + // ---- SEC-002: shell_escape, reject_control_chars, build_remote_exec_command ---- + + #[test] + fn shell_escape_safe_chars_pass_through() { + assert_eq!(shell_escape("ls").unwrap(), "ls"); + assert_eq!(shell_escape("/usr/bin/python").unwrap(), "/usr/bin/python"); + assert_eq!(shell_escape("file.txt").unwrap(), "file.txt"); + assert_eq!(shell_escape("my-cmd_v2").unwrap(), "my-cmd_v2"); + } + + #[test] + fn shell_escape_empty_string() { + assert_eq!(shell_escape("").unwrap(), "''"); + } + + #[test] + fn shell_escape_wraps_unsafe_chars() { + assert_eq!(shell_escape("hello world").unwrap(), "'hello world'"); + assert_eq!(shell_escape("$(id)").unwrap(), "'$(id)'"); + assert_eq!(shell_escape("; rm -rf /").unwrap(), "'; rm -rf /'"); + } + + #[test] + fn shell_escape_handles_single_quotes() { + assert_eq!(shell_escape("it's").unwrap(), "'it'\"'\"'s'"); + } + + #[test] + fn shell_escape_rejects_null_bytes() { + assert!(shell_escape("hello\x00world").is_err()); + } + + #[test] + fn shell_escape_rejects_newlines() { + assert!(shell_escape("line1\nline2").is_err()); + assert!(shell_escape("line1\rline2").is_err()); + assert!(shell_escape("line1\r\nline2").is_err()); + } + + #[test] + fn reject_control_chars_allows_normal_values() { + assert!(reject_control_chars("hello world", "test").is_ok()); + assert!(reject_control_chars("$(cmd)", "test").is_ok()); + assert!(reject_control_chars("", "test").is_ok()); + } + + #[test] + fn reject_control_chars_rejects_null_bytes() { + assert!(reject_control_chars("hello\x00world", "test").is_err()); + } + + #[test] + fn reject_control_chars_rejects_newlines() { + assert!(reject_control_chars("line1\nline2", "test").is_err()); + assert!(reject_control_chars("line1\rline2", "test").is_err()); + } + + #[test] + fn build_remote_exec_command_basic() { + use openshell_core::proto::ExecSandboxRequest; + let req = ExecSandboxRequest { + sandbox_id: "test".to_string(), + command: vec!["ls".to_string(), "-la".to_string()], + ..Default::default() + }; + assert_eq!(build_remote_exec_command(&req).unwrap(), "ls -la"); + } + + #[test] + fn build_remote_exec_command_with_env_and_workdir() { + use openshell_core::proto::ExecSandboxRequest; + let req = ExecSandboxRequest { + sandbox_id: "test".to_string(), + command: vec![ + "python".to_string(), + "-c".to_string(), + "print('ok')".to_string(), + ], + environment: [("HOME".to_string(), "/home/user".to_string())] + .into_iter() + .collect(), + workdir: "/workspace".to_string(), + ..Default::default() + }; + let cmd = build_remote_exec_command(&req).unwrap(); + assert!(cmd.starts_with("cd /workspace && ")); + assert!(cmd.contains("HOME=/home/user")); + assert!(cmd.contains("'print('\"'\"'ok'\"'\"')'")); + } + + #[test] + fn build_remote_exec_command_rejects_null_bytes_in_args() { + use openshell_core::proto::ExecSandboxRequest; + let req = ExecSandboxRequest { + sandbox_id: "test".to_string(), + command: vec!["echo".to_string(), "hello\x00world".to_string()], + ..Default::default() + }; + assert!(build_remote_exec_command(&req).is_err()); + } + + #[test] + fn build_remote_exec_command_rejects_newlines_in_workdir() { + use openshell_core::proto::ExecSandboxRequest; + let req = ExecSandboxRequest { + sandbox_id: "test".to_string(), + command: vec!["ls".to_string()], + workdir: "/tmp\nmalicious".to_string(), + ..Default::default() + }; + assert!(build_remote_exec_command(&req).is_err()); + } + + // ---- SEC-006: is_safe_ssh_proxy_target ---- + + #[test] + fn ssh_proxy_target_allows_pod_network_ips() { + use std::net::{IpAddr, Ipv4Addr}; + // Typical pod network IPs should be allowed + assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( + 10, 0, 0, 5 + )))); + assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( + 172, 16, 0, 1 + )))); + assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( + 192, 168, 1, 100 + )))); + } + + #[test] + fn ssh_proxy_target_blocks_loopback() { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( + 127, 0, 0, 1 + )))); + assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( + 127, 0, 0, 2 + )))); + assert!(!is_safe_ssh_proxy_target(IpAddr::V6(Ipv6Addr::LOCALHOST))); + } + + #[test] + fn ssh_proxy_target_blocks_link_local() { + use std::net::{IpAddr, Ipv4Addr}; + // 169.254.169.254 is the cloud metadata endpoint + assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( + 169, 254, 169, 254 + )))); + assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( + 169, 254, 0, 1 + )))); + } + + #[test] + fn ssh_proxy_target_blocks_ipv4_mapped_ipv6_loopback() { + use std::net::IpAddr; + // ::ffff:127.0.0.1 + let ip: IpAddr = "::ffff:127.0.0.1".parse().unwrap(); + assert!(!is_safe_ssh_proxy_target(ip)); + } + + #[test] + fn ssh_proxy_target_blocks_ipv4_mapped_ipv6_link_local() { + use std::net::IpAddr; + // ::ffff:169.254.169.254 + let ip: IpAddr = "::ffff:169.254.169.254".parse().unwrap(); + assert!(!is_safe_ssh_proxy_target(ip)); + } + // ---- clamp_limit tests ---- #[test] diff --git a/e2e/python/test_inference_routing.py b/e2e/python/test_inference_routing.py index 5f402752..bda0d8cb 100644 --- a/e2e/python/test_inference_routing.py +++ b/e2e/python/test_inference_routing.py @@ -319,7 +319,7 @@ def call_anthropic_messages() -> str: assert result.exit_code == 0, f"stderr: {result.stderr}" output = result.stdout.strip() assert output.startswith("http_error_400"), output - assert "no compatible route" in output + assert "no compatible inference route" in output def test_non_inference_host_is_not_intercepted(