From 3069a2b9ce31c8cfab83dd9e3510f231ac358f19 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 08:56:17 -0700 Subject: [PATCH 01/12] test(e2e): add JSON-RPC L7 proxy coverage Add a Rust e2e test that drives MCP-style JSON-RPC requests through both the forward proxy and CONNECT tunnel paths. Cover method rules, params rules, batch handling, and invalid JSON denial expectations so the JSON-RPC implementation can be built against one failing scenario. Signed-off-by: Kris Hicks --- e2e/rust/Cargo.toml | 5 + e2e/rust/tests/forward_proxy_jsonrpc_l7.rs | 373 +++++++++++++++++++++ 2 files changed, 378 insertions(+) create mode 100644 e2e/rust/tests/forward_proxy_jsonrpc_l7.rs diff --git a/e2e/rust/Cargo.toml b/e2e/rust/Cargo.toml index 083c622df..2f61f2d86 100644 --- a/e2e/rust/Cargo.toml +++ b/e2e/rust/Cargo.toml @@ -97,6 +97,11 @@ name = "forward_proxy_graphql_l7" path = "tests/forward_proxy_graphql_l7.rs" required-features = ["e2e-host-gateway"] +[[test]] +name = "forward_proxy_jsonrpc_l7" +path = "tests/forward_proxy_jsonrpc_l7.rs" +required-features = ["e2e-host-gateway"] + [[test]] name = "gpu_device_selection" path = "tests/gpu_device_selection.rs" diff --git a/e2e/rust/tests/forward_proxy_jsonrpc_l7.rs b/e2e/rust/tests/forward_proxy_jsonrpc_l7.rs new file mode 100644 index 000000000..feba98ceb --- /dev/null +++ b/e2e/rust/tests/forward_proxy_jsonrpc_l7.rs @@ -0,0 +1,373 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! E2E tests for JSON-RPC L7 inspection across both proxy entry points. +//! +//! The upstream server deliberately does not implement JSON-RPC. `OpenShell` +//! parses and enforces JSON-RPC before forwarding, so any HTTP server that +//! accepts POST /mcp is enough to prove allowed requests reach upstream +//! and denied requests are stopped by the sandbox proxy. + +#![cfg(feature = "e2e")] + +use std::io::Write; + +use openshell_e2e::harness::container::ContainerHttpServer; +use openshell_e2e::harness::sandbox::SandboxGuard; +use tempfile::NamedTempFile; + +const TEST_SERVER_ALIAS: &str = "jsonrpc-l7.openshell.test"; + +async fn start_test_server() -> Result { + let script = r#"from http.server import BaseHTTPRequestHandler, HTTPServer + +class Handler(BaseHTTPRequestHandler): + def read_body(self): + if self.headers.get("Transfer-Encoding", "").lower() == "chunked": + data = b"" + while True: + size_line = self.rfile.readline() + if not size_line: + break + size = int(size_line.split(b";", 1)[0].strip(), 16) + if size == 0: + while self.rfile.readline().strip(): + pass + break + data += self.rfile.read(size) + self.rfile.read(2) + return data + return self.rfile.read(int(self.headers.get("Content-Length", "0"))) + + def do_GET(self): + self.send_response(200) + self.end_headers() + + def do_POST(self): + self.read_body() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(b'{"jsonrpc":"2.0","id":1,"result":{}}') + + def log_message(self, format, *args): + pass + +HTTPServer(("0.0.0.0", 8000), Handler).serve_forever() +"#; + + ContainerHttpServer::start_python(TEST_SERVER_ALIAS, script).await +} + +fn write_jsonrpc_policy(host: &str, port: u16) -> Result { + let mut file = NamedTempFile::new().map_err(|e| format!("create temp policy file: {e}"))?; + let policy = format!( + r#"version: 1 + +filesystem_policy: + include_workdir: true + read_only: + - /usr + - /lib + - /proc + - /dev/urandom + - /app + - /etc + - /var/log + read_write: + - /sandbox + - /tmp + - /dev/null + +landlock: + compatibility: best_effort + +process: + run_as_user: sandbox + run_as_group: sandbox + +network_policies: + test_jsonrpc_l7: + name: test_jsonrpc_l7 + endpoints: + - host: {host} + port: {port} + path: /mcp + protocol: json-rpc + enforcement: enforce + allowed_ips: + - "10.0.0.0/8" + - "172.0.0.0/8" + - "192.168.0.0/16" + - "fc00::/7" + json_rpc: + max_body_bytes: 65536 + on_parse_error: deny + batch_policy: deny_if_any_denied + rules: + - allow: + rpc_method: initialize + - allow: + rpc_method: tools/list + - allow: + rpc_method: tools/call + params: + name: read_status + - allow: + rpc_method: tools/call + params: + name: submit_report + arguments.scope: workspace/main + deny_rules: + - rpc_method: tools/call + params: + name: blocked_action + binaries: + - path: /usr/bin/python* + - path: /usr/local/bin/python* + - path: /sandbox/.uv/python/*/bin/python* +"# + ); + file.write_all(policy.as_bytes()) + .map_err(|e| format!("write temp policy file: {e}"))?; + file.flush() + .map_err(|e| format!("flush temp policy file: {e}"))?; + Ok(file) +} + +#[tokio::test] +#[allow(clippy::too_many_lines)] +async fn jsonrpc_l7_enforces_method_and_params_rules_on_forward_and_connect_paths() { + let server = start_test_server().await.expect("start test server"); + let policy = write_jsonrpc_policy(&server.host, server.port).expect("write custom policy"); + let policy_path = policy + .path() + .to_str() + .expect("temp policy path should be utf-8") + .to_string(); + + let script = format!( + r#" +import json +import os +import socket +import time +import urllib.error +import urllib.parse +import urllib.request + +HOST = {host:?} +PORT = {port} +DETAILS = {{}} + +def post_jsonrpc(method, params=None, req_id=1): + body = {{"jsonrpc": "2.0", "id": req_id, "method": method}} + if params is not None: + body["params"] = params + encoded = json.dumps(body).encode() + request = urllib.request.Request( + f"http://{{HOST}}:{{PORT}}/mcp", + data=encoded, + headers={{"Content-Type": "application/json"}}, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=15) as response: + response.read() + return response.status + except urllib.error.HTTPError as error: + error.read() + return error.code + +def post_jsonrpc_batch(requests): + encoded = json.dumps(requests).encode() + request = urllib.request.Request( + f"http://{{HOST}}:{{PORT}}/mcp", + data=encoded, + headers={{"Content-Type": "application/json"}}, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=15) as response: + response.read() + return response.status + except urllib.error.HTTPError as error: + error.read() + return error.code + +def post_invalid_json(): + encoded = b"not valid json {{" + request = urllib.request.Request( + f"http://{{HOST}}:{{PORT}}/mcp", + data=encoded, + headers={{"Content-Type": "application/json", "Content-Length": str(len(encoded))}}, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=15) as response: + response.read() + return response.status + except urllib.error.HTTPError as error: + error.read() + return error.code + +def proxy_parts(*names): + proxy_url = next((os.environ.get(name) for name in names if os.environ.get(name)), None) + parsed = urllib.parse.urlparse(proxy_url) + return parsed.hostname, parsed.port or 80 + +def read_until(sock, marker): + data = b"" + while marker not in data: + chunk = sock.recv(4096) + if not chunk: + break + data += chunk + return data + +def read_response(sock): + response = read_until(sock, b"\r\n\r\n") + headers, _, body = response.partition(b"\r\n\r\n") + content_length = 0 + for line in headers.split(b"\r\n")[1:]: + if line.lower().startswith(b"content-length:"): + content_length = int(line.split(b":", 1)[1].strip()) + break + while len(body) < content_length: + chunk = sock.recv(4096) + if not chunk: + break + body += chunk + return response, body + +def status_code(response, label): + parts = response.split() + if len(parts) < 2: + DETAILS[f"{{label}}_raw"] = response.decode(errors="replace") + raise RuntimeError(f"{{label}}: malformed HTTP response: {{response!r}}") + try: + return int(parts[1]) + except ValueError as error: + DETAILS[f"{{label}}_raw"] = response.decode(errors="replace") + raise RuntimeError(f"{{label}}: non-numeric HTTP status: {{response!r}}") from error + +def connect_http_status(label, request): + proxy_host, proxy_port = proxy_parts("HTTP_PROXY", "http_proxy", "HTTPS_PROXY", "https_proxy") + target = f"{{HOST}}:{{PORT}}" + + last_error = None + for attempt in range(5): + try: + with socket.create_connection((proxy_host, proxy_port), timeout=15) as sock: + sock.sendall( + f"CONNECT {{target}} HTTP/1.1\r\nHost: {{target}}\r\n\r\n".encode() + ) + connect_response = read_until(sock, b"\r\n\r\n") + connect_code = status_code(connect_response, f"{{label}}_connect") + if connect_code != 200: + return connect_code + sock.sendall(request) + sock.shutdown(socket.SHUT_WR) + response = read_until(sock, b"\r\n\r\n") + return status_code(response, f"{{label}}_response") + except (OSError, RuntimeError) as error: + last_error = error + DETAILS[f"{{label}}_attempt_{{attempt + 1}}_error"] = str(error) + time.sleep(0.2) + + raise RuntimeError(f"{{label}}: failed after 5 attempts: {{last_error}}") + +def connect_jsonrpc_status(method, params, label): + target = f"{{HOST}}:{{PORT}}" + body = {{"jsonrpc": "2.0", "id": 1, "method": method}} + if params is not None: + body["params"] = params + encoded = json.dumps(body).encode() + request = ( + f"POST /mcp HTTP/1.1\r\n" + f"Host: {{target}}\r\n" + f"Content-Type: application/json\r\n" + f"Content-Length: {{len(encoded)}}\r\n" + f"Connection: close\r\n" + f"\r\n" + ).encode() + encoded + return connect_http_status(label, request) + +results = {{ + # forward proxy — method-only allow rules + "forward_method_initialize_allowed": post_jsonrpc("initialize", {{"protocolVersion": "2025-11-25", "capabilities": {{}}}}), + "forward_method_tools_list_allowed": post_jsonrpc("tools/list"), + + # forward proxy — params allow rules + "forward_tools_call_params_name_no_args_allowed": post_jsonrpc("tools/call", {{"name": "read_status"}}), + "forward_tools_call_params_nested_args_allowed": post_jsonrpc("tools/call", {{"name": "submit_report", "arguments": {{"scope": "workspace/main", "title": "test"}}}}), + + # forward proxy — params denied + "forward_tools_call_params_name_no_args_denied": post_jsonrpc("tools/call", {{"name": "blocked_action"}}), + "forward_tools_call_params_name_with_args_denied": post_jsonrpc("tools/call", {{"name": "blocked_action", "arguments": {{"reason": "test"}}}}), + + # forward proxy — batch: all requests allowed + "forward_batch_all_allowed": post_jsonrpc_batch([ + {{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}}, + {{"jsonrpc": "2.0", "id": 2, "method": "tools/call", "params": {{"name": "read_status"}}}}, + ]), + + # forward proxy — batch: one denied request causes full batch denial + "forward_batch_one_denied": post_jsonrpc_batch([ + {{"jsonrpc": "2.0", "id": 1, "method": "tools/list"}}, + {{"jsonrpc": "2.0", "id": 2, "method": "tools/call", "params": {{"name": "blocked_action"}}}}, + ]), + + # forward proxy — invalid JSON body denied by on_parse_error: deny + "forward_invalid_json_denied": post_invalid_json(), + + # CONNECT path — representative allowed and denied cases + "connect_method_initialize_allowed": connect_jsonrpc_status("initialize", {{"protocolVersion": "2025-11-25", "capabilities": {{}}}}, "connect_method_initialize_allowed"), + "connect_method_tools_list_allowed": connect_jsonrpc_status("tools/list", None, "connect_method_tools_list_allowed"), + "connect_tools_call_params_name_no_args_allowed": connect_jsonrpc_status("tools/call", {{"name": "read_status"}}, "connect_tools_call_params_name_no_args_allowed"), + "connect_tools_call_params_nested_args_allowed": connect_jsonrpc_status("tools/call", {{"name": "submit_report", "arguments": {{"scope": "workspace/main"}}}}, "connect_tools_call_params_nested_args_allowed"), + "connect_tools_call_params_name_no_args_denied": connect_jsonrpc_status("tools/call", {{"name": "blocked_action"}}, "connect_tools_call_params_name_no_args_denied"), + "connect_tools_call_params_name_with_args_denied": connect_jsonrpc_status("tools/call", {{"name": "blocked_action", "arguments": {{"reason": "test"}}}}, "connect_tools_call_params_name_with_args_denied"), +}} +results.update(DETAILS) +print(json.dumps(results, sort_keys=True)) +"#, + host = server.host, + port = server.port, + ); + + let guard = SandboxGuard::create(&["--policy", &policy_path, "--", "python3", "-c", &script]) + .await + .expect("sandbox create"); + + for (key, expected) in [ + // forward proxy — allowed + ("forward_method_initialize_allowed", 200), + ("forward_method_tools_list_allowed", 200), + ("forward_tools_call_params_name_no_args_allowed", 200), + ("forward_tools_call_params_nested_args_allowed", 200), + // forward proxy — params denied + ("forward_tools_call_params_name_no_args_denied", 403), + ("forward_tools_call_params_name_with_args_denied", 403), + // forward proxy — batch + ("forward_batch_all_allowed", 200), + ("forward_batch_one_denied", 403), + // forward proxy — parse error + ("forward_invalid_json_denied", 403), + // CONNECT path — allowed + ("connect_method_initialize_allowed", 200), + ("connect_method_tools_list_allowed", 200), + ("connect_tools_call_params_name_no_args_allowed", 200), + ("connect_tools_call_params_nested_args_allowed", 200), + // CONNECT path — params denied + ("connect_tools_call_params_name_no_args_denied", 403), + ("connect_tools_call_params_name_with_args_denied", 403), + ] { + let expected_fragment = format!(r#""{key}": {expected}"#); + assert!( + guard.create_output.contains(&expected_fragment), + "expected {key}={expected}, got:\n{}", + guard.create_output + ); + } +} From f1e29aa2e2d6fe9db0f44e0f9e76c4c5e0192890 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 09:03:21 -0700 Subject: [PATCH 02/12] feat(policy): recognize JSON-RPC L7 endpoints Add json-rpc as a policy protocol and carry JSON-RPC rule fields through policy parsing and validation. Wire the protocol into the L7 dispatcher with a passthrough placeholder so later commits can add enforcement without changing endpoint recognition. Signed-off-by: Kris Hicks --- crates/openshell-policy/src/lib.rs | 26 + crates/openshell-sandbox/src/l7/mod.rs | 2420 +++++++++++++++++++++ crates/openshell-sandbox/src/l7/relay.rs | 2429 ++++++++++++++++++++++ 3 files changed, 4875 insertions(+) create mode 100644 crates/openshell-sandbox/src/l7/mod.rs create mode 100644 crates/openshell-sandbox/src/l7/relay.rs diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index aaabbf926..62c2868df 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -135,6 +135,8 @@ struct NetworkEndpointDef { graphql_persisted_queries: BTreeMap, #[serde(default, skip_serializing_if = "is_zero_u32")] graphql_max_body_bytes: u32, + #[serde(default, skip_serializing_if = "Option::is_none")] + json_rpc: Option, } // Signature dictated by serde's `skip_serializing_if`, which requires `&T`. @@ -149,6 +151,17 @@ fn is_zero_u32(v: &u32) -> bool { *v == 0 } +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct JsonRpcConfigDef { + #[serde(default, skip_serializing_if = "is_zero_u32")] + max_body_bytes: u32, + #[serde(default, skip_serializing_if = "String::is_empty")] + on_parse_error: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + batch_policy: String, +} + #[derive(Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] struct GraphqlOperationDef { @@ -183,6 +196,10 @@ struct L7AllowDef { operation_name: String, #[serde(default, skip_serializing_if = "Vec::is_empty")] fields: Vec, + #[serde(default, skip_serializing_if = "String::is_empty")] + rpc_method: String, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + params: BTreeMap, } #[derive(Debug, Serialize, Deserialize)] @@ -216,6 +233,10 @@ struct L7DenyRuleDef { operation_name: String, #[serde(default, skip_serializing_if = "Vec::is_empty")] fields: Vec, + #[serde(default, skip_serializing_if = "String::is_empty")] + rpc_method: String, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + params: BTreeMap, } #[derive(Debug, Serialize, Deserialize)] @@ -462,6 +483,8 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { (key, yaml_matcher) }) .collect(), + rpc_method: String::new(), + params: BTreeMap::new(), }, } }) @@ -491,6 +514,8 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { (key.clone(), yaml_matcher) }) .collect(), + rpc_method: String::new(), + params: BTreeMap::new(), }) .collect(), allow_encoded_slash: e.allow_encoded_slash, @@ -512,6 +537,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(), graphql_max_body_bytes: e.graphql_max_body_bytes, + json_rpc: None, } }) .collect(), diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs new file mode 100644 index 000000000..ce6747c6d --- /dev/null +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -0,0 +1,2420 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! L7 protocol-aware inspection for the CONNECT proxy. +//! +//! When an endpoint is configured with a `protocol` field (e.g. `rest`, `sql`), +//! the proxy inspects application-layer traffic within the tunnel instead of +//! doing a raw `copy_bidirectional`. Each request within the tunnel is parsed, +//! evaluated against OPA policy, and either forwarded or denied. + +pub mod graphql; +pub mod inference; +pub mod path; +pub mod provider; +pub mod relay; +pub mod rest; +pub mod tls; +pub(crate) mod token_grant_injection; +pub(crate) mod websocket; + +/// Application-layer protocol for L7 inspection. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum L7Protocol { + Rest, + Websocket, + Graphql, + Sql, + JsonRpc, +} + +impl L7Protocol { + pub fn parse(s: &str) -> Option { + match s.to_ascii_lowercase().as_str() { + "rest" => Some(Self::Rest), + "websocket" => Some(Self::Websocket), + "graphql" => Some(Self::Graphql), + "sql" => Some(Self::Sql), + "json-rpc" => Some(Self::JsonRpc), + _ => None, + } + } +} + +/// TLS handling mode for proxy connections. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum TlsMode { + /// Auto-detect TLS by peeking the first bytes. If TLS is detected, + /// terminate it transparently. This is the default for all endpoints. + #[default] + Auto, + /// Explicit opt-out: raw tunnel with no TLS termination and no credential + /// injection. Use for client-cert mTLS to upstream or non-standard protocols. + Skip, +} + +/// Enforcement mode for L7 policy decisions. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum EnforcementMode { + /// Log violations but allow traffic through (safe migration path). + #[default] + Audit, + /// Deny violations — blocked requests never reach upstream. + Enforce, +} + +/// L7 configuration for an endpoint, extracted from policy data. +#[allow( + clippy::struct_excessive_bools, + reason = "Endpoint config mirrors independent policy schema toggles." +)] +#[derive(Debug, Clone)] +pub struct L7EndpointConfig { + pub protocol: L7Protocol, + /// Optional endpoint-level HTTP path glob used to select between L7 + /// protocols that share the same host:port. + pub path: String, + pub tls: TlsMode, + pub enforcement: EnforcementMode, + /// Maximum GraphQL request body bytes to buffer for inspection. + pub graphql_max_body_bytes: usize, + /// When true, percent-encoded `/` (`%2F`) is preserved in path segments + /// rather than rejected at the parser. Needed by upstreams like GitLab + /// that embed `%2F` in namespaced project paths. Defaults to false. + pub allow_encoded_slash: bool, + /// Opt-in rewrite of credential placeholders in client-to-server + /// WebSocket text messages after an allowed HTTP 101 upgrade. + pub websocket_credential_rewrite: bool, + /// Opt-in rewrite of credential placeholders in supported textual REST + /// request bodies before forwarding upstream. + pub request_body_credential_rewrite: bool, + /// When true, client-to-server GraphQL-over-WebSocket operation messages + /// are classified with the same operation policy used by GraphQL-over-HTTP. + pub websocket_graphql_policy: bool, +} + +/// Result of an L7 policy decision for a single request. +#[derive(Debug, Clone)] +pub struct L7Decision { + pub allowed: bool, + pub reason: String, + pub matched_rule: Option, +} + +/// Parsed L7 request metadata used for policy evaluation and logging. +#[derive(Debug, Clone)] +pub struct L7RequestInfo { + /// Protocol action: HTTP method (GET, POST, ...) or SQL command (SELECT, INSERT, ...). + pub action: String, + /// Target: URL path for REST, or empty for SQL. + pub target: String, + /// Decoded query parameter multimap for REST requests. + pub query_params: std::collections::HashMap>, + /// Parsed GraphQL operation metadata for GraphQL endpoints. + pub graphql: Option, +} + +/// Parse an L7 endpoint config from a regorus Value (returned by Rego query). +/// +/// The value is expected to be the raw endpoint object from the Rego data, +/// containing fields: `protocol`, optionally `tls`, `enforcement`. +pub fn parse_l7_config(val: ®orus::Value) -> Option { + let protocol_val = get_object_str(val, "protocol")?; + let protocol = L7Protocol::parse(&protocol_val)?; + + let tls = match get_object_str(val, "tls").as_deref() { + Some("skip") => TlsMode::Skip, + Some("terminate") => { + let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Other) + .severity(openshell_ocsf::SeverityId::Medium) + .message( + "'tls: terminate' is deprecated; TLS termination is now automatic. \ + Use 'tls: skip' to explicitly disable. This field will be removed in a future version.", + ) + .build(); + openshell_ocsf::ocsf_emit!(event); + TlsMode::Auto + } + Some("passthrough") => { + let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(openshell_ocsf::ActivityId::Other) + .severity(openshell_ocsf::SeverityId::Medium) + .message( + "'tls: passthrough' is deprecated; TLS termination is now automatic. \ + Use 'tls: skip' to explicitly disable. This field will be removed in a future version.", + ) + .build(); + openshell_ocsf::ocsf_emit!(event); + TlsMode::Auto + } + _ => TlsMode::Auto, + }; + + let enforcement = match get_object_str(val, "enforcement").as_deref() { + Some("enforce") => EnforcementMode::Enforce, + _ => EnforcementMode::Audit, + }; + + let allow_encoded_slash = get_object_bool(val, "allow_encoded_slash").unwrap_or(false); + let websocket_credential_rewrite = + get_object_bool(val, "websocket_credential_rewrite").unwrap_or(false); + let request_body_credential_rewrite = + get_object_bool(val, "request_body_credential_rewrite").unwrap_or(false); + let websocket_graphql_policy = + protocol == L7Protocol::Websocket && endpoint_has_graphql_policy(val); + let graphql_max_body_bytes = get_object_u64(val, "graphql_max_body_bytes") + .and_then(|v| usize::try_from(v).ok()) + .filter(|v| *v > 0) + .unwrap_or(graphql::DEFAULT_MAX_BODY_BYTES); + + Some(L7EndpointConfig { + protocol, + path: get_object_str(val, "path").unwrap_or_default(), + tls, + enforcement, + graphql_max_body_bytes, + allow_encoded_slash, + websocket_credential_rewrite, + request_body_credential_rewrite, + websocket_graphql_policy, + }) +} + +impl L7EndpointConfig { + pub fn matches_path(&self, path: &str) -> bool { + endpoint_path_matches(&self.path, path) + } + + pub fn path_specificity(&self) -> usize { + if self.path.is_empty() { + 0 + } else { + self.path.chars().filter(|c| *c != '*').count() + } + } +} + +pub fn endpoint_path_matches(pattern: &str, path: &str) -> bool { + if pattern.is_empty() || pattern == "**" || pattern == "/**" { + return true; + } + if pattern == path { + return true; + } + if let Some(prefix) = pattern.strip_suffix("/**") { + return path == prefix || path.starts_with(&format!("{prefix}/")); + } + glob::Pattern::new(pattern).is_ok_and(|glob| glob.matches(path)) +} + +/// Parse the `tls` field from an endpoint config, independent of L7 protocol. +/// +/// Used to check for `tls: skip` even on L4-only endpoints (no `protocol` +/// field) that explicitly opt out of TLS auto-detection. +pub fn parse_tls_mode(val: ®orus::Value) -> TlsMode { + match get_object_str(val, "tls").as_deref() { + Some("skip") => TlsMode::Skip, + // "terminate" and "passthrough" are deprecated aliases (logged by parse_l7_config); fall through to Auto. + _ => TlsMode::Auto, + } +} + +/// Extract a bool value from a regorus object. Returns `None` when the key +/// is absent or not a boolean. +fn get_object_bool(val: ®orus::Value, key: &str) -> Option { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => match map.get(&key_val) { + Some(regorus::Value::Bool(b)) => Some(*b), + _ => None, + }, + _ => None, + } +} + +fn get_object_u64(val: ®orus::Value, key: &str) -> Option { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => match map.get(&key_val) { + Some(regorus::Value::Number(n)) => n.as_u64(), + _ => None, + }, + _ => None, + } +} + +/// Extract a string value from a regorus object. +fn get_object_str(val: ®orus::Value, key: &str) -> Option { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => match map.get(&key_val) { + Some(regorus::Value::String(s)) => { + let s = s.to_string(); + if s.is_empty() { None } else { Some(s) } + } + _ => None, + }, + _ => None, + } +} + +fn endpoint_has_graphql_policy(val: ®orus::Value) -> bool { + has_non_empty_object_field(val, "graphql_persisted_queries") + || has_graphql_persisted_query_mode(val) + || rules_have_graphql_policy(val, "rules", true) + || rules_have_graphql_policy(val, "deny_rules", false) +} + +fn rules_have_graphql_policy(val: ®orus::Value, key: &str, allow_wrapped: bool) -> bool { + let Some(regorus::Value::Array(rules)) = get_object_value(val, key) else { + return false; + }; + rules.iter().any(|rule| { + let rule = if allow_wrapped { + get_object_value(rule, "allow").unwrap_or(rule) + } else { + rule + }; + has_graphql_rule_fields(rule) + }) +} + +fn has_graphql_rule_fields(val: ®orus::Value) -> bool { + has_non_empty_string_field(val, "operation_type") + || has_non_empty_string_field(val, "operation_name") + || has_non_empty_array_field(val, "fields") +} + +fn has_non_empty_string_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::String(s)) if !s.is_empty()) +} + +fn has_non_empty_array_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Array(values)) if !values.is_empty()) +} + +fn has_non_empty_object_field(val: ®orus::Value, key: &str) -> bool { + matches!(get_object_value(val, key), Some(regorus::Value::Object(values)) if !values.is_empty()) +} + +fn has_graphql_persisted_query_mode(val: ®orus::Value) -> bool { + matches!( + get_object_value(val, "persisted_queries"), + Some(regorus::Value::String(mode)) if !mode.is_empty() && mode.as_ref() != "deny" + ) +} + +fn get_object_value<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => None, + } +} + +/// Check a glob pattern for obvious syntax issues. +/// +/// Returns `Some(warning_message)` if the pattern looks malformed. +/// OPA's `glob.match` is forgiving, so these are warnings (not errors) +/// to surface likely typos without blocking policy loading. +fn check_glob_syntax(pattern: &str) -> Option { + let mut bracket_depth: i32 = 0; + for c in pattern.chars() { + match c { + '[' => bracket_depth += 1, + ']' => { + if bracket_depth == 0 { + return Some(format!("glob pattern '{pattern}' has unmatched ']'")); + } + bracket_depth -= 1; + } + _ => {} + } + } + if bracket_depth > 0 { + return Some(format!("glob pattern '{pattern}' has unclosed '['")); + } + + let mut brace_depth: i32 = 0; + for c in pattern.chars() { + match c { + '{' => brace_depth += 1, + '}' => { + if brace_depth == 0 { + return Some(format!("glob pattern '{pattern}' has unmatched '}}'")); + } + brace_depth -= 1; + } + _ => {} + } + } + if brace_depth > 0 { + return Some(format!("glob pattern '{pattern}' has unclosed '{{'")); + } + + None +} + +fn validate_host_wildcard(errors: &mut Vec, loc: &str, host: &str) { + if !host.contains('*') { + return; + } + + if host == "*" || host == "**" { + errors.push(format!( + "{loc}: host wildcard '{host}' matches all hosts; use specific patterns like '*.example.com'" + )); + return; + } + + let labels: Vec<&str> = host.split('.').collect(); + let first_label = labels.first().copied().unwrap_or_default(); + if labels.iter().skip(1).any(|label| label.contains('*')) { + errors.push(format!( + "{loc}: host wildcard may only appear in the first DNS label, got '{host}'" + )); + return; + } + if first_label.contains("**") && first_label != "**" { + errors.push(format!( + "{loc}: recursive host wildcard '**' is only allowed as the entire first DNS label, got '{host}'" + )); + return; + } + + // Reject TLD or single-label wildcards. They are accepted by the policy + // engine but silently fail at the proxy layer (see #787). + if labels.len() <= 2 { + errors.push(format!( + "{loc}: TLD wildcard '{host}' is not allowed; \ + use subdomain wildcards like '*.example.com' instead" + )); + } +} + +fn validate_graphql_operation_type( + errors: &mut Vec, + loc: &str, + value: Option<&str>, + required: bool, +) { + let Some(value) = value.filter(|v| !v.is_empty()) else { + if required { + errors.push(format!( + "{loc}.operation_type: required for GraphQL L7 rules" + )); + } + return; + }; + + let valid = ["query", "mutation", "subscription", "*"]; + if !valid.contains(&value.to_ascii_lowercase().as_str()) { + errors.push(format!( + "{loc}.operation_type: expected query, mutation, subscription, or *, got '{value}'" + )); + } +} + +fn validate_graphql_fields( + errors: &mut Vec, + warnings: &mut Vec, + loc: &str, + fields: Option<&serde_json::Value>, +) { + let Some(fields) = fields else { + return; + }; + let Some(items) = fields.as_array() else { + errors.push(format!( + "{loc}.fields: expected array of GraphQL root field globs" + )); + return; + }; + if items.is_empty() { + errors.push(format!( + "{loc}.fields: list must not be empty; omit fields to match all root fields" + )); + return; + } + for item in items { + let Some(field) = item.as_str() else { + errors.push(format!("{loc}.fields: all values must be strings")); + continue; + }; + if field.is_empty() { + errors.push(format!("{loc}.fields: field glob must not be empty")); + } else if let Some(warning) = check_glob_syntax(field) { + warnings.push(format!("{loc}.fields: {warning}")); + } + } +} + +fn validate_graphql_rule( + errors: &mut Vec, + warnings: &mut Vec, + loc: &str, + rule: &serde_json::Value, + required: bool, +) { + validate_graphql_operation_type( + errors, + loc, + rule.get("operation_type").and_then(|v| v.as_str()), + required, + ); + if let Some(name) = rule.get("operation_name").and_then(|v| v.as_str()) + && !name.is_empty() + && let Some(warning) = check_glob_syntax(name) + { + warnings.push(format!("{loc}.operation_name: {warning}")); + } + validate_graphql_fields(errors, warnings, loc, rule.get("fields")); +} + +fn json_rule_has_graphql_fields(rule: &serde_json::Value) -> bool { + rule.get("operation_type") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule + .get("operation_name") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty()) + || rule.get("fields").is_some() +} + +fn json_rule_has_transport_fields(rule: &serde_json::Value) -> bool { + rule.get("method").is_some() || rule.get("path").is_some() || rule.get("query").is_some() +} + +fn json_endpoint_has_graphql_policy(ep: &serde_json::Value) -> bool { + ep.get("graphql_persisted_queries") + .and_then(|v| v.as_object()) + .is_some_and(|v| !v.is_empty()) + || ep + .get("persisted_queries") + .and_then(|v| v.as_str()) + .is_some_and(|v| !v.is_empty() && v != "deny") + || ep + .get("rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| { + rules.iter().any(|rule| { + rule.get("allow") + .or(Some(rule)) + .is_some_and(json_rule_has_graphql_fields) + }) + }) + || ep + .get("deny_rules") + .and_then(|v| v.as_array()) + .is_some_and(|rules| rules.iter().any(json_rule_has_graphql_fields)) +} + +/// Validate L7 policy configuration in the loaded OPA data. +/// +/// Returns a list of errors and warnings. Errors should prevent sandbox startup; +/// warnings are logged but don't block. +pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec) { + let mut errors = Vec::new(); + let mut warnings = Vec::new(); + + let Some(policies) = data_json + .get("network_policies") + .and_then(|v| v.as_object()) + else { + return (errors, warnings); + }; + + for (name, policy) in policies { + let Some(endpoints) = policy.get("endpoints").and_then(|v| v.as_array()) else { + continue; + }; + + for (i, ep) in endpoints.iter().enumerate() { + let protocol = ep.get("protocol").and_then(|v| v.as_str()).unwrap_or(""); + let tls = ep.get("tls").and_then(|v| v.as_str()).unwrap_or(""); + let enforcement = ep.get("enforcement").and_then(|v| v.as_str()).unwrap_or(""); + let access = ep.get("access").and_then(|v| v.as_str()).unwrap_or(""); + let has_rules = ep + .get("rules") + .and_then(|v| v.as_array()) + .is_some_and(|a| !a.is_empty()); + let websocket_has_graphql_policy = + protocol == "websocket" && json_endpoint_has_graphql_policy(ep); + let host = ep.get("host").and_then(|v| v.as_str()).unwrap_or(""); + let endpoint_path = ep.get("path").and_then(|v| v.as_str()).unwrap_or(""); + + // Read ports from either "ports" array or scalar "port". + let ports: Vec = ep.get("ports").and_then(|v| v.as_array()).map_or_else( + || { + ep.get("port") + .and_then(serde_json::Value::as_u64) + .filter(|p| *p > 0) + .into_iter() + .collect() + }, + |arr| arr.iter().filter_map(serde_json::Value::as_u64).collect(), + ); + let loc = format!("{name}.endpoints[{i}]"); + + if !endpoint_path.is_empty() { + if !endpoint_path.starts_with('/') && endpoint_path != "**" { + errors.push(format!( + "{loc}: endpoint path must start with '/' or be '**', got '{endpoint_path}'" + )); + } + if let Some(warning) = check_glob_syntax(endpoint_path) { + warnings.push(format!("{loc}.path: {warning}")); + } + } + + validate_host_wildcard(&mut errors, &loc, host); + + // port + ports mutual exclusion + let has_scalar_port = ep + .get("port") + .and_then(serde_json::Value::as_u64) + .is_some_and(|p| p > 0); + let has_ports_array = ep + .get("ports") + .and_then(|v| v.as_array()) + .is_some_and(|a| !a.is_empty()); + if has_scalar_port && has_ports_array { + errors.push(format!( + "{loc}: port and ports are mutually exclusive; use ports for multiple ports" + )); + } + + // rules + access mutual exclusion + if has_rules && !access.is_empty() { + errors.push(format!("{loc}: rules and access are mutually exclusive")); + } + + // protocol requires rules or access + if !protocol.is_empty() && !has_rules && access.is_empty() { + errors.push(format!( + "{loc}: protocol requires rules or access to define allowed traffic" + )); + } + + if !protocol.is_empty() && L7Protocol::parse(protocol).is_none() { + errors.push(format!( + "{loc}: unknown protocol '{protocol}' (expected rest, websocket, graphql, sql, or json-rpc)" + )); + } + + if let Some(mode) = ep.get("persisted_queries").and_then(|v| v.as_str()) + && !mode.is_empty() + && mode != "deny" + && mode != "allow_registered" + { + errors.push(format!( + "{loc}: persisted_queries must be 'deny' or 'allow_registered', got '{mode}'" + )); + } + + if ep.get("graphql_max_body_bytes").is_some() { + let valid_max = ep + .get("graphql_max_body_bytes") + .and_then(serde_json::Value::as_u64) + .is_some_and(|v| v > 0); + if !valid_max { + errors.push(format!( + "{loc}: graphql_max_body_bytes must be a positive integer" + )); + } + } + + if protocol != "graphql" + && protocol != "websocket" + && (ep.get("persisted_queries").is_some() + || ep.get("graphql_persisted_queries").is_some() + || ep.get("graphql_max_body_bytes").is_some()) + { + warnings.push(format!( + "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql or websocket" + )); + } + + if ep + .get("websocket_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + && protocol != "websocket" + { + warnings.push(format!( + "{loc}: websocket_credential_rewrite is ignored unless protocol is rest or websocket" + )); + } + + if ep + .get("request_body_credential_rewrite") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false) + && protocol != "rest" + { + warnings.push(format!( + "{loc}: request_body_credential_rewrite is ignored unless protocol is rest" + )); + } + + if let Some(registry_value) = ep.get("graphql_persisted_queries") { + let Some(registry) = registry_value.as_object() else { + errors.push(format!( + "{loc}: graphql_persisted_queries must be a map keyed by hash or saved-query id" + )); + continue; + }; + for (key, op) in registry { + let registry_loc = format!("{loc}.graphql_persisted_queries[{key}]"); + validate_graphql_rule(&mut errors, &mut warnings, ®istry_loc, op, true); + } + } + + // Deprecated tls values: warn but don't error + if tls == "terminate" || tls == "passthrough" { + warnings.push(format!( + "{loc}: 'tls: {tls}' is deprecated; TLS termination is now automatic. Use 'tls: skip' to disable." + )); + } + + // tls: skip with L7 on port 443 won't work + if tls == "skip" && !protocol.is_empty() && ports.contains(&443) { + warnings.push(format!( + "{loc}: 'tls: skip' with L7 rules on port 443 — L7 inspection cannot work on encrypted traffic" + )); + } + + // sql + enforce blocked in v1 + if protocol == "sql" && enforcement == "enforce" { + errors.push(format!( + "{loc}: SQL enforcement requires full SQL parsing (not available in v1). Use `enforcement: audit`." + )); + } + + // rules with empty list + if ep + .get("rules") + .and_then(|v| v.as_array()) + .is_some_and(Vec::is_empty) + { + errors.push(format!( + "{loc}: rules list cannot be empty (would deny all traffic). Use `access: full` or remove rules." + )); + } + + // port 443 + rest + tls: skip — L7 won't work (already handled above) + // The old warning about missing `tls: terminate` is no longer needed + // because TLS termination is now automatic. + + // Validate deny_rules + let has_deny_rules = ep + .get("deny_rules") + .and_then(|v| v.as_array()) + .is_some_and(|a| !a.is_empty()); + if has_deny_rules { + // deny_rules require L7 inspection + if protocol.is_empty() { + errors.push(format!( + "{loc}: deny_rules require protocol (L7 inspection must be enabled)" + )); + } + + // deny_rules require some allow base (access or rules) + if !has_rules && access.is_empty() { + errors.push(format!( + "{loc}: deny_rules require rules or access to define the base allow set" + )); + } + + if let Some(deny_rules) = ep.get("deny_rules").and_then(|v| v.as_array()) { + for (deny_idx, deny_rule) in deny_rules.iter().enumerate() { + let deny_loc = format!("{loc}.deny_rules[{deny_idx}]"); + + // Validate method + if let Some(method) = deny_rule.get("method").and_then(|m| m.as_str()) + && !method.is_empty() + && (protocol == "rest" || protocol == "websocket") + { + let valid_methods = valid_methods_for_protocol(protocol); + if !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { + warnings.push(format!( + "{deny_loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") + )); + } + } + + // Validate path glob syntax + if let Some(path) = deny_rule.get("path").and_then(|p| p.as_str()) + && let Some(warning) = check_glob_syntax(path) + { + warnings.push(format!("{deny_loc}.path: {warning}")); + } + + // Validate query matchers — mirrors allow-side validation exactly + if let Some(query) = deny_rule.get("query").filter(|v| !v.is_null()) { + let Some(query_obj) = query.as_object() else { + errors.push(format!( + "{deny_loc}.query: expected map of query matchers" + )); + continue; + }; + + for (param, matcher) in query_obj { + if let Some(glob_str) = matcher.as_str() { + if let Some(warning) = check_glob_syntax(glob_str) { + warnings + .push(format!("{deny_loc}.query.{param}: {warning}")); + } + continue; + } + + let Some(matcher_obj) = matcher.as_object() else { + errors.push(format!( + "{deny_loc}.query.{param}: expected string glob or object with `any`" + )); + continue; + }; + + let has_any = matcher_obj.get("any").is_some(); + let has_glob = matcher_obj.get("glob").is_some(); + let has_unknown = + matcher_obj.keys().any(|k| k != "any" && k != "glob"); + if has_unknown { + errors.push(format!( + "{deny_loc}.query.{param}: unknown matcher keys; only `glob` or `any` are supported" + )); + continue; + } + + if has_glob && has_any { + errors.push(format!( + "{deny_loc}.query.{param}: matcher cannot specify both `glob` and `any`" + )); + continue; + } + + if !has_glob && !has_any { + errors.push(format!( + "{deny_loc}.query.{param}: object matcher requires `glob` string or non-empty `any` list" + )); + continue; + } + + if has_glob { + match matcher_obj.get("glob").and_then(|v| v.as_str()) { + None => { + errors.push(format!( + "{deny_loc}.query.{param}.glob: expected glob string" + )); + } + Some(g) => { + if let Some(warning) = check_glob_syntax(g) { + warnings.push(format!( + "{deny_loc}.query.{param}.glob: {warning}" + )); + } + } + } + continue; + } + + let any = matcher_obj.get("any").and_then(|v| v.as_array()); + let Some(any) = any else { + errors.push(format!( + "{deny_loc}.query.{param}.any: expected array of glob strings" + )); + continue; + }; + + if any.is_empty() { + errors.push(format!( + "{deny_loc}.query.{param}.any: list must not be empty" + )); + continue; + } + + if any.iter().any(|v| v.as_str().is_none()) { + errors.push(format!( + "{deny_loc}.query.{param}.any: all values must be strings" + )); + } + + for item in any.iter().filter_map(|v| v.as_str()) { + if let Some(warning) = check_glob_syntax(item) { + warnings.push(format!( + "{deny_loc}.query.{param}.any: {warning}" + )); + } + } + } + } + + // SQL command validation + if let Some(command) = deny_rule.get("command").and_then(|c| c.as_str()) + && !command.is_empty() + && protocol == "rest" + { + warnings + .push(format!("{deny_loc}: command is for SQL protocol, not REST")); + } + + let deny_has_graphql = json_rule_has_graphql_fields(deny_rule); + if protocol == "websocket" + && deny_has_graphql + && json_rule_has_transport_fields(deny_rule) + { + errors.push(format!( + "{deny_loc}: WebSocket GraphQL deny rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + + if protocol == "graphql" || (protocol == "websocket" && deny_has_graphql) { + validate_graphql_rule( + &mut errors, + &mut warnings, + &deny_loc, + deny_rule, + true, + ); + } else if deny_has_graphql { + warnings.push(format!( + "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" + )); + } + } + } + } + + // Empty deny_rules list (explicitly set but empty) + if ep + .get("deny_rules") + .and_then(|v| v.as_array()) + .is_some_and(Vec::is_empty) + { + errors.push(format!( + "{loc}: deny_rules list cannot be empty (would have no effect). Remove it if no denials are needed." + )); + } + + // Validate HTTP methods in rules + if has_rules && (protocol == "rest" || protocol == "websocket") { + let valid_methods = valid_methods_for_protocol(protocol); + if let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { + for (rule_idx, rule) in rules.iter().enumerate() { + if let Some(method) = rule + .get("allow") + .and_then(|a| a.get("method")) + .and_then(|m| m.as_str()) + && !method.is_empty() + && !valid_methods.contains(&method.to_ascii_uppercase().as_str()) + { + warnings.push(format!( + "{loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." + , valid_methods.join(", ") + )); + } + + let Some(query) = rule + .get("allow") + .and_then(|a| a.get("query")) + .filter(|v| !v.is_null()) + else { + continue; + }; + + let Some(query_obj) = query.as_object() else { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query: expected map of query matchers" + )); + continue; + }; + + for (param, matcher) in query_obj { + if let Some(glob_str) = matcher.as_str() { + if let Some(warning) = check_glob_syntax(glob_str) { + warnings.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: {warning}" + )); + } + continue; + } + + let Some(matcher_obj) = matcher.as_object() else { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: expected string glob or object with `any`" + )); + continue; + }; + + let has_any = matcher_obj.get("any").is_some(); + let has_glob = matcher_obj.get("glob").is_some(); + let has_unknown = matcher_obj.keys().any(|k| k != "any" && k != "glob"); + if has_unknown { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: unknown matcher keys; only `glob` or `any` are supported" + )); + continue; + } + + if has_glob && has_any { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: matcher cannot specify both `glob` and `any`" + )); + continue; + } + + if !has_glob && !has_any { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}: object matcher requires `glob` string or non-empty `any` list" + )); + continue; + } + + if has_glob { + match matcher_obj.get("glob").and_then(|v| v.as_str()) { + None => { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.glob: expected glob string" + )); + } + Some(g) => { + if let Some(warning) = check_glob_syntax(g) { + warnings.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.glob: {warning}" + )); + } + } + } + continue; + } + + let any = matcher_obj.get("any").and_then(|v| v.as_array()); + let Some(any) = any else { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.any: expected array of glob strings" + )); + continue; + }; + + if any.is_empty() { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.any: list must not be empty" + )); + continue; + } + + if any.iter().any(|v| v.as_str().is_none()) { + errors.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.any: all values must be strings" + )); + } + + for item in any.iter().filter_map(|v| v.as_str()) { + if let Some(warning) = check_glob_syntax(item) { + warnings.push(format!( + "{loc}.rules[{rule_idx}].allow.query.{param}.any: {warning}" + )); + } + } + } + } + } + } + + if has_rules && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { + for (rule_idx, rule) in rules.iter().enumerate() { + let allow = rule.get("allow").unwrap_or(rule); + let rule_loc = format!("{loc}.rules[{rule_idx}].allow"); + let allow_has_graphql = json_rule_has_graphql_fields(allow); + if websocket_has_graphql_policy + && allow + .get("method") + .and_then(|m| m.as_str()) + .is_some_and(|method| method.eq_ignore_ascii_case("WEBSOCKET_TEXT")) + { + errors.push(format!( + "{rule_loc}: WebSocket endpoints with GraphQL operation policy must use operation_type/operation_name/fields rules for client messages instead of WEBSOCKET_TEXT" + )); + } + if protocol == "websocket" + && allow_has_graphql + && json_rule_has_transport_fields(allow) + { + errors.push(format!( + "{rule_loc}: WebSocket GraphQL allow rules must not combine method/path/query with operation_type/operation_name/fields" + )); + } + if protocol == "graphql" || (protocol == "websocket" && allow_has_graphql) { + validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); + } else if allow_has_graphql { + warnings.push(format!( + "{rule_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" + )); + } + } + } + } + } + + (errors, warnings) +} + +/// Expand `access` presets into explicit `rules` in the policy data. +/// +/// This preprocesses the JSON data so Rego only needs to handle explicit rules. +pub fn expand_access_presets(data: &mut serde_json::Value) { + let Some(policies) = data + .get_mut("network_policies") + .and_then(|v| v.as_object_mut()) + else { + return; + }; + + for (_name, policy) in policies.iter_mut() { + let Some(endpoints) = policy.get_mut("endpoints").and_then(|v| v.as_array_mut()) else { + continue; + }; + + for ep in endpoints.iter_mut() { + let access = ep + .get("access") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + if access.is_empty() { + continue; + } + + // Don't expand if rules already exist (validation will catch this) + if ep + .get("rules") + .and_then(|v| v.as_array()) + .is_some_and(|a| !a.is_empty()) + { + continue; + } + + let protocol = ep + .get("protocol") + .and_then(|v| v.as_str()) + .unwrap_or("rest"); + let rules = if protocol == "graphql" { + match access.as_str() { + "read-only" => vec![graphql_rule_json("query")], + "read-write" => vec![graphql_rule_json("query"), graphql_rule_json("mutation")], + "full" => vec![graphql_rule_json("*")], + _ => continue, + } + } else if protocol == "websocket" { + match access.as_str() { + "read-only" => vec![rule_json("GET", "**")], + "read-write" => vec![rule_json("GET", "**"), rule_json("WEBSOCKET_TEXT", "**")], + "full" => vec![rule_json("*", "**")], + _ => continue, + } + } else { + match access.as_str() { + "read-only" => vec![ + rule_json("GET", "**"), + rule_json("HEAD", "**"), + rule_json("OPTIONS", "**"), + ], + "read-write" => vec![ + rule_json("GET", "**"), + rule_json("HEAD", "**"), + rule_json("OPTIONS", "**"), + rule_json("POST", "**"), + rule_json("PUT", "**"), + rule_json("PATCH", "**"), + ], + "full" => vec![rule_json("*", "**")], + _ => continue, + } + }; + + ep.as_object_mut() + .unwrap() + .insert("rules".to_string(), serde_json::Value::Array(rules)); + } + } +} + +fn rule_json(method: &str, path: &str) -> serde_json::Value { + serde_json::json!({ + "allow": { + "method": method, + "path": path + } + }) +} + +fn valid_methods_for_protocol(protocol: &str) -> &'static [&'static str] { + match protocol { + "websocket" => &["GET", "WEBSOCKET_TEXT", "*"], + _ => &[ + "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", + ], + } +} + +fn graphql_rule_json(operation_type: &str) -> serde_json::Value { + serde_json::json!({ + "allow": { + "operation_type": operation_type + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_l7_config_rest_enforce() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "tls": "terminate", "enforcement": "enforce", "host": "api.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.protocol, L7Protocol::Rest); + // "terminate" is deprecated and treated as Auto. + assert_eq!(config.tls, TlsMode::Auto); + assert_eq!(config.enforcement, EnforcementMode::Enforce); + } + + #[test] + fn parse_l7_config_defaults() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "api.example.com", "port": 80}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.protocol, L7Protocol::Rest); + assert_eq!(config.tls, TlsMode::Auto); + assert_eq!(config.enforcement, EnforcementMode::Audit); + } + + #[test] + fn parse_l7_config_websocket_protocol() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.protocol, L7Protocol::Websocket); + } + + #[test] + fn parse_l7_config_skip() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "tls": "skip", "host": "api.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.tls, TlsMode::Skip); + } + + #[test] + fn parse_l7_config_no_protocol() { + let val = + regorus::Value::from_json_str(r#"{"host": "api.example.com", "port": 443}"#).unwrap(); + assert!(parse_l7_config(&val).is_none()); + } + + #[test] + fn parse_l7_config_allow_encoded_slash_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "api.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.allow_encoded_slash); + } + + #[test] + fn parse_l7_config_allow_encoded_slash_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gitlab.example.com", "port": 443, "allow_encoded_slash": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.allow_encoded_slash); + } + + #[test] + fn parse_l7_config_websocket_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443, "websocket_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_request_body_credential_rewrite_opt_in() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "rest", "host": "slack.com", "port": 443, "request_body_credential_rewrite": true}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.request_body_credential_rewrite); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_defaults_false() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(!config.websocket_graphql_policy); + } + + #[test] + fn parse_l7_config_websocket_graphql_policy_detects_operation_rules() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}}]}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert!(config.websocket_graphql_policy); + } + + #[test] + fn validate_websocket_credential_rewrite_warns_unless_rest_or_websocket() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "websocket_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("websocket_credential_rewrite is ignored")), + "expected websocket_credential_rewrite warning: {warnings:?}" + ); + } + + #[test] + fn validate_request_body_credential_rewrite_warns_unless_rest() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "request_body_credential_rewrite": true + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings + .iter() + .any(|w| w.contains("request_body_credential_rewrite is ignored")), + "expected request_body_credential_rewrite warning: {warnings:?}" + ); + } + + #[test] + fn expand_websocket_read_write_access_includes_text_messages() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "access": "read-write" + }], + "binaries": [] + } + } + }); + + expand_access_presets(&mut data); + let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] + .as_array() + .unwrap(); + let methods: Vec<&str> = rules + .iter() + .map(|r| r["allow"]["method"].as_str().unwrap()) + .collect(); + assert!(methods.contains(&"GET")); + assert!(methods.contains(&"WEBSOCKET_TEXT")); + } + + #[test] + fn validate_websocket_accepts_graphql_operation_rules() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!(errors.is_empty(), "expected no errors: {errors:?}"); + assert!(warnings.is_empty(), "expected no warnings: {warnings:?}"); + } + + #[test] + fn validate_websocket_graphql_rule_requires_operation_type() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"fields": ["messageAdded"]}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("operation_type")), + "expected missing operation_type error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_rule_rejects_mixed_transport_fields() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql", "operation_type": "subscription"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("must not combine")), + "expected mixed-field error: {errors:?}" + ); + } + + #[test] + fn validate_websocket_graphql_policy_rejects_raw_text_message_rule() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "gateway.example.com", + "port": 443, + "protocol": "websocket", + "rules": [ + {"allow": {"method": "GET", "path": "/graphql"}}, + {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}, + {"allow": {"operation_type": "query"}} + ] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("instead of WEBSOCKET_TEXT")), + "expected raw WEBSOCKET_TEXT rejection: {errors:?}" + ); + } + + #[test] + fn validate_rules_and_access_mutual_exclusion() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest", + "access": "read-only", + "rules": [{"allow": {"method": "GET", "path": "**"}}] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!(errors.iter().any(|e| e.contains("mutually exclusive"))); + } + + #[test] + fn validate_protocol_requires_rules_or_access() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest" + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("requires rules or access")) + ); + } + + #[test] + fn validate_sql_enforce_blocked() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "db.internal", + "port": 5432, + "protocol": "sql", + "enforcement": "enforce", + "rules": [{"allow": {"command": "SELECT"}}] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!(errors.iter().any(|e| e.contains("SQL enforcement"))); + } + + #[test] + fn validate_tls_terminate_deprecated_warning() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "tls": "terminate", + "protocol": "rest", + "access": "full" + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "deprecated tls should not error: {errors:?}" + ); + assert!( + warnings.iter().any(|w| w.contains("deprecated")), + "should warn about deprecated tls: {warnings:?}" + ); + } + + #[test] + fn validate_tls_skip_with_l7_on_443_warns() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "tls": "skip", + "protocol": "rest", + "access": "read-only" + }], + "binaries": [] + } + } + }); + let (_errors, warnings) = validate_l7_policies(&data); + assert!( + warnings.iter().any(|w| w.contains("tls: skip")), + "should warn about skip + L7 on 443: {warnings:?}" + ); + } + + #[test] + fn validate_port_443_rest_no_tls_no_warning() { + // With auto-TLS, no warning is needed for port 443 + rest without + // explicit tls field — TLS will be auto-detected. + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest", + "access": "read-only" + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!(errors.is_empty(), "should have no errors: {errors:?}"); + assert!( + !warnings.iter().any(|w| w.contains("tls")), + "should have no tls warnings with auto-detect: {warnings:?}" + ); + } + + #[test] + fn expand_read_only_preset() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 80, + "protocol": "rest", + "access": "read-only" + }], + "binaries": [] + } + } + }); + expand_access_presets(&mut data); + let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] + .as_array() + .unwrap(); + assert_eq!(rules.len(), 3); + let methods: Vec<&str> = rules + .iter() + .map(|r| r["allow"]["method"].as_str().unwrap()) + .collect(); + assert!(methods.contains(&"GET")); + assert!(methods.contains(&"HEAD")); + assert!(methods.contains(&"OPTIONS")); + } + + #[test] + fn expand_full_preset() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 80, + "protocol": "rest", + "access": "full" + }], + "binaries": [] + } + } + }); + expand_access_presets(&mut data); + let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] + .as_array() + .unwrap(); + assert_eq!(rules.len(), 1); + assert_eq!(rules[0]["allow"]["method"].as_str().unwrap(), "*"); + assert_eq!(rules[0]["allow"]["path"].as_str().unwrap(), "**"); + } + + #[test] + fn expand_graphql_readonly_preset() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "graphql", + "access": "read-only" + }], + "binaries": [] + } + } + }); + expand_access_presets(&mut data); + let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] + .as_array() + .unwrap(); + assert_eq!(rules.len(), 1); + assert_eq!( + rules[0]["allow"]["operation_type"].as_str().unwrap(), + "query" + ); + } + + #[test] + fn validate_graphql_rule_requires_operation_type() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "graphql", + "rules": [{ + "allow": { + "fields": ["viewer"] + } + }] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("operation_type")), + "GraphQL rules should require operation_type: {errors:?}" + ); + } + + #[test] + fn validate_graphql_persisted_query_mode() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "graphql", + "access": "full", + "persisted_queries": "allow_all" + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("persisted_queries")), + "invalid persisted query mode should be rejected: {errors:?}" + ); + } + + #[test] + fn l4_only_endpoint_untouched() { + let mut data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443 + }], + "binaries": [] + } + } + }); + expand_access_presets(&mut data); + assert!( + data["network_policies"]["test"]["endpoints"][0] + .get("rules") + .is_none() + ); + } + + // ---- Host wildcard validation tests ---- + + #[test] + fn validate_wildcard_host_star_only_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "*", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("matches all hosts")), + "Bare * host should be rejected, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_double_star_only_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "**", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("matches all hosts")), + "Bare ** host should be rejected, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_mid_label_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "foo.*.example.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("first DNS label")), + "Mid-label wildcard should be rejected, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_single_label_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "*com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("TLD wildcard")), + "Single-label wildcard should be rejected, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_recursive_intra_label_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "foo**.example.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("recursive host wildcard")), + "Recursive intra-label wildcard should be rejected, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_tld_rejected() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "*.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("TLD wildcard")), + "*.com should be rejected as TLD wildcard, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_double_star_tld_rejected() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "**.org", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("TLD wildcard")), + "**.org should be rejected as TLD wildcard, got errors: {errors:?}" + ); + } + + #[test] + fn validate_wildcard_host_valid_no_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "*.example.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "*.example.com should be valid, got errors: {errors:?}" + ); + assert!( + warnings.is_empty(), + "*.example.com should not warn, got warnings: {warnings:?}" + ); + } + + #[test] + fn validate_wildcard_host_double_star_valid_no_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "**.example.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "**.example.com should be valid, got errors: {errors:?}" + ); + assert!( + warnings.is_empty(), + "**.example.com should not warn, got warnings: {warnings:?}" + ); + } + + #[test] + fn validate_wildcard_host_intra_label_valid_no_error() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "*-aiplatform.googleapis.com", + "port": 443 + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "*-aiplatform.googleapis.com should be valid, got errors: {errors:?}" + ); + assert!( + warnings.is_empty(), + "*-aiplatform.googleapis.com should not warn, got warnings: {warnings:?}" + ); + } + + #[test] + fn validate_port_and_ports_mutually_exclusive() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "ports": [443, 8443] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("port and ports are mutually exclusive")), + "Should reject both port and ports, got errors: {errors:?}" + ); + } + + #[test] + fn validate_ports_array_rest_443_no_warning() { + // With auto-TLS, no warning needed for ports array containing 443. + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "ports": [443, 8080], + "protocol": "rest", + "access": "read-only" + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!(errors.is_empty(), "should have no errors: {errors:?}"); + assert!( + !warnings.iter().any(|w| w.contains("tls")), + "should have no tls warnings with auto-detect: {warnings:?}" + ); + } + + #[test] + fn validate_query_any_requires_non_empty_array() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "tag": { "any": [] } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("allow.query.tag.any")), + "expected query any validation error, got: {errors:?}" + ); + } + + #[test] + fn validate_query_object_rejects_unknown_keys() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "tag": { "mode": "foo-*" } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.iter().any(|e| e.contains("unknown matcher keys")), + "expected unknown query matcher key error, got: {errors:?}" + ); + } + + #[test] + fn validate_query_glob_warns_on_unclosed_bracket() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "tag": "[unclosed" + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "malformed glob should warn, not error: {errors:?}" + ); + assert!( + warnings + .iter() + .any(|w| w.contains("unclosed '['") && w.contains("allow.query.tag")), + "expected glob syntax warning, got: {warnings:?}" + ); + } + + #[test] + fn validate_query_glob_warns_on_unclosed_brace() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "format": { "glob": "{json,xml" } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "malformed glob should warn, not error: {errors:?}" + ); + assert!( + warnings + .iter() + .any(|w| w.contains("unclosed '{'") && w.contains("allow.query.format.glob")), + "expected glob syntax warning, got: {warnings:?}" + ); + } + + #[test] + fn validate_query_any_warns_on_malformed_glob_item() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "tag": { "any": ["valid-*", "[bad"] } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "malformed glob in any should warn, not error: {errors:?}" + ); + assert!( + warnings + .iter() + .any(|w| w.contains("unclosed '['") && w.contains("allow.query.tag.any")), + "expected glob syntax warning for any item, got: {warnings:?}" + ); + } + + #[test] + fn validate_query_string_and_any_matchers_are_accepted() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 8080, + "protocol": "rest", + "rules": [{ + "allow": { + "method": "GET", + "path": "/download", + "query": { + "slug": "my-*", + "tag": { "any": ["foo-*", "bar-*"] }, + "owner": { "glob": "org-*" } + } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "valid query matcher shapes should not error: {errors:?}" + ); + } + + // --- Deny rules validation tests --- + + #[test] + fn validate_deny_rules_require_protocol() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "deny_rules": [{ "method": "POST", "path": "/admin" }] + }], + "binaries": [] + } + } + }); + let (errors, _) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("deny_rules require protocol")), + "should require protocol for deny_rules: {errors:?}" + ); + } + + #[test] + fn validate_deny_rules_require_allow_base() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest", + "deny_rules": [{ "method": "POST", "path": "/admin" }] + }], + "binaries": [] + } + } + }); + let (errors, _) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("deny_rules require rules or access")), + "should require rules or access for deny_rules: {errors:?}" + ); + } + + #[test] + fn validate_deny_rules_empty_list_rejected() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest", + "access": "full", + "deny_rules": [] + }], + "binaries": [] + } + } + }); + let (errors, _) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("deny_rules list cannot be empty")), + "should reject empty deny_rules: {errors:?}" + ); + } + + #[test] + fn validate_deny_rules_valid_config_accepted() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest", + "access": "read-write", + "deny_rules": [ + { "method": "POST", "path": "/repos/*/pulls/*/reviews" }, + { "method": "PUT", "path": "/repos/*/branches/*/protection" } + ] + }], + "binaries": [] + } + } + }); + let (errors, _) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "valid deny_rules should not error: {errors:?}" + ); + } + + #[test] + fn validate_deny_rules_query_empty_any_rejected() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest", + "access": "full", + "deny_rules": [{ + "method": "POST", + "path": "/admin", + "query": { "type": { "any": [] } } + }] + }], + "binaries": [] + } + } + }); + let (errors, _) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("any: list must not be empty")), + "should reject empty any list in deny query: {errors:?}" + ); + } + + #[test] + fn validate_deny_rules_query_non_string_rejected() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest", + "access": "full", + "deny_rules": [{ + "method": "POST", + "path": "/admin", + "query": { "force": 123 } + }] + }], + "binaries": [] + } + } + }); + let (errors, _) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("expected string glob or object")), + "should reject non-string/non-object matcher in deny query: {errors:?}" + ); + } + + #[test] + fn validate_deny_rules_query_valid_matchers_accepted() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "api.example.com", + "port": 443, + "protocol": "rest", + "access": "full", + "deny_rules": [{ + "method": "POST", + "path": "/admin/**", + "query": { + "force": "true", + "type": { "any": ["admin-*", "root-*"] }, + "scope": { "glob": "org-*" } + } + }] + }], + "binaries": [] + } + } + }); + let (errors, _) = validate_l7_policies(&data); + assert!( + errors.is_empty(), + "valid deny query matchers should not error: {errors:?}" + ); + } +} diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs new file mode 100644 index 000000000..6baa2ab03 --- /dev/null +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -0,0 +1,2429 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Protocol-aware bidirectional relay with L7 inspection. +//! +//! Replaces `copy_bidirectional` for endpoints with L7 configuration. +//! Parses each request within the tunnel, evaluates it against OPA policy, +//! and either forwards or denies the request. + +use crate::activity_aggregator::{ActivitySender, try_record_activity}; +use crate::l7::provider::{L7Provider, RelayOutcome}; +use crate::l7::rest::WebSocketExtensionMode; +use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; +use crate::opa::{PolicyGenerationGuard, TunnelPolicyEngine}; +use crate::secrets::{self, SecretResolver}; +use miette::{IntoDiagnostic, Result, miette}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, + NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, +}; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +use tracing::{debug, warn}; + +/// Context for L7 request policy evaluation. +pub struct L7EvalContext { + /// Host from the CONNECT request. + pub host: String, + /// Port from the CONNECT request. + pub port: u16, + /// Matched policy name from L4 evaluation. + pub policy_name: String, + /// Binary path (for cross-layer Rego evaluation). + pub binary_path: String, + /// Ancestor paths. + pub ancestors: Vec, + /// Cmdline paths. + pub cmdline_paths: Vec, + /// Supervisor-only placeholder resolver for outbound headers. + pub(crate) secret_resolver: Option>, + /// Anonymous activity counter channel. + pub(crate) activity_tx: Option, + /// Dynamic credentials (token grants) keyed by endpoint-bound provider metadata. + pub(crate) dynamic_credentials: Option< + Arc< + std::sync::RwLock< + std::collections::HashMap, + >, + >, + >, + /// Dynamic token grant resolver for endpoint-bound credentials. + pub(crate) token_grant_resolver: + Option>, +} + +#[derive(Default)] +pub(crate) struct UpgradeRelayOptions<'a> { + pub(crate) websocket_request: bool, + pub(crate) websocket: WebSocketUpgradeBehavior, + pub(crate) secret_resolver: Option>, + pub(crate) engine: Option<&'a TunnelPolicyEngine>, + pub(crate) ctx: Option<&'a L7EvalContext>, + pub(crate) enforcement: EnforcementMode, + pub(crate) target: String, + pub(crate) query_params: std::collections::HashMap>, + pub(crate) policy_name: String, +} + +#[derive(Default)] +pub(crate) struct WebSocketUpgradeBehavior { + pub(crate) credential_rewrite: bool, + pub(crate) message_policy: WebSocketMessagePolicy, + pub(crate) permessage_deflate: bool, +} + +#[derive(Clone, Copy, Default, PartialEq, Eq)] +pub(crate) enum WebSocketMessagePolicy { + #[default] + None, + Transport, + Graphql, +} + +impl WebSocketMessagePolicy { + fn inspects_messages(self) -> bool { + self != Self::None + } + + fn is_graphql(self) -> bool { + self == Self::Graphql + } +} + +#[derive(Debug, Clone, Copy)] +enum ParseRejectionMode { + L7Endpoint, + Passthrough, +} + +fn parse_rejection_detail(error: &str, mode: ParseRejectionMode) -> String { + if error.contains("encoded '/' (%2F)") { + match mode { + ParseRejectionMode::L7Endpoint => format!( + "{error}; set allow_encoded_slash: true on this endpoint if the upstream requires encoded slashes" + ), + ParseRejectionMode::Passthrough => format!( + "{error}; passthrough credential relay uses strict path parsing, so configure this endpoint with protocol: rest and allow_encoded_slash: true for encoded-slash APIs, or use tls: skip if HTTP parsing is not needed" + ), + } + } else { + error.to_string() + } +} + +fn emit_parse_rejection(ctx: &L7EvalContext, detail: &str, engine_type: &str) { + let policy_name = if ctx.policy_name.is_empty() { + "-" + } else { + &ctx.policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(policy_name, engine_type) + .message(format!( + "HTTP request rejected before policy evaluation for {}:{}", + ctx.host, ctx.port + )) + .status_detail(detail) + .build(); + ocsf_emit!(event); + emit_activity(ctx, true, "l7_parse_rejection"); +} + +/// Run protocol-aware L7 inspection on a tunnel. +/// +/// This replaces `copy_bidirectional` for L7-enabled endpoints. +/// Protocol detection (peek) is the caller's responsibility — this function +/// assumes the streams are already proven to carry the expected protocol. +/// For TLS-terminated connections, ALPN proves HTTP; for plaintext, the +/// caller peeks on the raw `TcpStream` before calling this. +pub async fn relay_with_inspection( + config: &L7EndpointConfig, + engine: TunnelPolicyEngine, + client: &mut C, + upstream: &mut U, + ctx: &L7EvalContext, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + match config.protocol { + L7Protocol::Rest | L7Protocol::Websocket => { + relay_rest(config, &engine, client, upstream, ctx).await + } + L7Protocol::Graphql => relay_graphql(config, &engine, client, upstream, ctx).await, + L7Protocol::Sql => { + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + // SQL provider is Phase 3 — fall through to passthrough with warning + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .severity(SeverityId::Low) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .message("SQL L7 provider not yet implemented, falling back to passthrough") + .build(); + ocsf_emit!(event); + } + tokio::io::copy_bidirectional(client, upstream) + .await + .into_diagnostic()?; + Ok(()) + } + L7Protocol::JsonRpc => { + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + // JSON-RPC provider not yet implemented — fall through to passthrough + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .severity(SeverityId::Low) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .message("JSON-RPC L7 provider not yet implemented, falling back to passthrough") + .build(); + ocsf_emit!(event); + } + tokio::io::copy_bidirectional(client, upstream) + .await + .into_diagnostic()?; + Ok(()) + } + } +} + +/// Run HTTP L7 inspection with per-request protocol selection. +/// +/// This is used when multiple L7 endpoints share a host:port, for example a +/// REST API under `/repos/**` and a GraphQL API under `/graphql`. +pub async fn relay_with_route_selection( + configs: &[L7EndpointConfig], + engine: TunnelPolicyEngine, + client: &mut C, + upstream: &mut U, + ctx: &L7EvalContext, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + let provider = + crate::l7::rest::RestProvider::with_options(crate::l7::path::CanonicalizeOptions { + allow_encoded_slash: configs.iter().any(|config| config.allow_encoded_slash), + ..Default::default() + }); + + loop { + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let mut req = match provider.parse_request(client).await { + Ok(Some(req)) => req, + Ok(None) => return Ok(()), + Err(e) => { + if is_benign_connection_error(&e) { + debug!( + host = %ctx.host, + port = ctx.port, + error = %e, + "L7 route-selected connection closed" + ); + } else { + let detail = + parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); + emit_parse_rejection(ctx, &detail, "l7"); + } + return Ok(()); + } + }; + + let Some(config) = select_l7_config_for_path(configs, &req.target) else { + crate::l7::rest::RestProvider::default() + .deny( + &req, + &ctx.policy_name, + "no L7 endpoint path matched request", + client, + ) + .await?; + return Ok(()); + }; + + let graphql_info = if config.protocol == L7Protocol::Graphql { + match crate::l7::graphql::inspect_graphql_request( + client, + &mut req, + config.graphql_max_body_bytes, + ) + .await + { + Ok(info) => Some(info), + Err(e) => { + if is_benign_connection_error(&e) { + debug!( + host = %ctx.host, + port = ctx.port, + error = %e, + "GraphQL L7 connection closed" + ); + } else { + let detail = + parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); + emit_parse_rejection(ctx, &detail, "l7-graphql"); + } + return Ok(()); + } + } + } else { + None + }; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let (eval_target, redacted_target) = if let Some(ref resolver) = ctx.secret_resolver { + match secrets::rewrite_target_for_eval(&req.target, resolver) { + Ok(result) => (result.resolved, result.redacted), + Err(e) => { + warn!( + host = %ctx.host, + port = ctx.port, + error = %e, + "credential resolution failed in request target, rejecting" + ); + let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + client.write_all(response).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return Ok(()); + } + } + } else { + (req.target.clone(), req.target.clone()) + }; + + let request_info = L7RequestInfo { + action: req.action.clone(), + target: redacted_target.clone(), + query_params: req.query_params.clone(), + graphql: graphql_info.clone(), + }; + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + + let parse_error_reason = graphql_info + .as_ref() + .and_then(|info| info.error.as_deref()) + .map(|error| format!("GraphQL request rejected: {error}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = if let Some(reason) = parse_error_reason { + (false, reason) + } else { + evaluate_l7_request(&engine, ctx, &request_info)? + }; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let decision_str = match (allowed, config.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + let engine_type = match config.protocol { + L7Protocol::Graphql => "l7-graphql", + L7Protocol::Websocket => "l7-websocket", + L7Protocol::Rest | L7Protocol::Sql | L7Protocol::JsonRpc => "l7", + }; + emit_l7_request_log( + ctx, + &request_info, + &redacted_target, + decision_str, + engine_type, + &reason, + graphql_info.as_ref(), + ); + + let _ = &eval_target; + + if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( + &req, + client, + upstream, + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, + }, + ) + .await?; + match outcome { + RelayOutcome::Reusable => {} + RelayOutcome::Consumed => return Ok(()), + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( + config, + ctx, + websocket_request, + &redacted_target, + &req.query_params, + Some(&engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; + } + } + } else { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + } +} + +fn select_l7_config_for_path<'a>( + configs: &'a [L7EndpointConfig], + path: &str, +) -> Option<&'a L7EndpointConfig> { + configs + .iter() + .filter(|config| config.matches_path(path)) + .max_by_key(|config| config.path_specificity()) +} + +fn emit_l7_request_log( + ctx: &L7EvalContext, + request_info: &L7RequestInfo, + redacted_target: &str, + decision_str: &str, + engine_type: &str, + reason: &str, + graphql_info: Option<&crate::l7::graphql::GraphqlRequestInfo>, +) { + let (action_id, disposition_id, severity) = match decision_str { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let summary = graphql_info + .map(|info| format!(" {}", graphql_log_summary(info))) + .unwrap_or_default(); + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .http_request(HttpRequest::new( + &request_info.action, + OcsfUrl::new("http", &ctx.host, redacted_target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, engine_type) + .message(format!( + "L7_REQUEST {decision_str} {} {}:{}{}{} reason={}", + request_info.action, ctx.host, ctx.port, redacted_target, summary, reason, + )) + .build(); + ocsf_emit!(event); + emit_activity(ctx, decision_str == "deny", "l7_policy"); +} + +fn emit_activity(ctx: &L7EvalContext, denied: bool, deny_group: &'static str) { + if let Some(tx) = &ctx.activity_tx { + let _ = try_record_activity(tx, denied, deny_group); + } +} + +/// Handle an upgraded connection (101 Switching Protocols). +/// +/// Forwards any overflow bytes from the upgrade response to the client, then +/// either switches to a parsed WebSocket relay for opted-in message policy / +/// credential rewriting or to raw bidirectional TCP copy for other upgrades. +pub(crate) async fn handle_upgrade( + client: &mut C, + upstream: &mut U, + overflow: Vec, + host: &str, + port: u16, + options: UpgradeRelayOptions<'_>, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + let use_websocket_relay = options.websocket_request + && (options.websocket.message_policy.inspects_messages() + || options.websocket.permessage_deflate + || (options.websocket.credential_rewrite && options.secret_resolver.is_some())); + let relay_mode = if use_websocket_relay { + "websocket parsed relay" + } else { + "raw bidirectional relay (L7 enforcement no longer active)" + }; + ocsf_emit!( + NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .activity_name("Upgrade") + .severity(SeverityId::Informational) + .dst_endpoint(Endpoint::from_domain(host, port)) + .message(format!( + "101 Switching Protocols — {relay_mode} [host:{host} port:{port} overflow_bytes:{}]", + overflow.len() + )) + .build() + ); + if use_websocket_relay { + let resolver = if options.websocket.credential_rewrite { + options.secret_resolver.as_deref() + } else { + None + }; + let inspector = if options.websocket.message_policy.inspects_messages() { + match (options.engine, options.ctx) { + (Some(engine), Some(ctx)) => Some(crate::l7::websocket::InspectionOptions { + engine, + ctx, + enforcement: options.enforcement, + target: options.target.clone(), + query_params: options.query_params.clone(), + graphql_policy: options.websocket.message_policy.is_graphql(), + }), + _ => { + return Err(miette!( + "websocket message inspection missing policy context" + )); + } + } + } else { + None + }; + let compression = if options.websocket.permessage_deflate { + crate::l7::websocket::WebSocketCompression::PermessageDeflate + } else { + crate::l7::websocket::WebSocketCompression::None + }; + return crate::l7::websocket::relay_with_options( + client, + upstream, + overflow, + host, + port, + crate::l7::websocket::RelayOptions { + policy_name: &options.policy_name, + resolver, + inspector, + compression, + }, + ) + .await; + } + if !overflow.is_empty() { + client.write_all(&overflow).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + } + tokio::io::copy_bidirectional(client, upstream) + .await + .into_diagnostic()?; + Ok(()) +} + +pub(crate) fn upgrade_options<'a>( + config: &L7EndpointConfig, + ctx: &'a L7EvalContext, + websocket_request: bool, + target: &str, + query_params: &std::collections::HashMap>, + engine: Option<&'a TunnelPolicyEngine>, +) -> UpgradeRelayOptions<'a> { + let websocket_credential_rewrite = + matches!(config.protocol, L7Protocol::Rest | L7Protocol::Websocket) + && config.websocket_credential_rewrite; + let websocket_message_policy = if config.protocol == L7Protocol::Websocket { + if config.websocket_graphql_policy { + WebSocketMessagePolicy::Graphql + } else { + WebSocketMessagePolicy::Transport + } + } else { + WebSocketMessagePolicy::None + }; + UpgradeRelayOptions { + websocket_request, + websocket: WebSocketUpgradeBehavior { + credential_rewrite: websocket_credential_rewrite, + message_policy: websocket_message_policy, + permessage_deflate: false, + }, + secret_resolver: if websocket_credential_rewrite { + ctx.secret_resolver.clone() + } else { + None + }, + engine, + ctx: engine.map(|_| ctx), + enforcement: config.enforcement, + target: target.to_string(), + query_params: query_params.clone(), + policy_name: ctx.policy_name.clone(), + } +} + +pub(crate) fn websocket_extension_mode(config: &L7EndpointConfig) -> WebSocketExtensionMode { + if config.protocol == L7Protocol::Websocket + || (config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite) + { + WebSocketExtensionMode::PermessageDeflate + } else { + WebSocketExtensionMode::Preserve + } +} + +/// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. +async fn relay_rest( + config: &L7EndpointConfig, + engine: &TunnelPolicyEngine, + client: &mut C, + upstream: &mut U, + ctx: &L7EvalContext, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + // Build a provider carrying the per-endpoint canonicalization options so + // request parsing honors the endpoint's `allow_encoded_slash` setting + // (e.g. APIs like GitLab that embed `%2F` in path segments). + let provider = + crate::l7::rest::RestProvider::with_options(crate::l7::path::CanonicalizeOptions { + allow_encoded_slash: config.allow_encoded_slash, + ..Default::default() + }); + loop { + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + // Parse one HTTP request from client + let req = match provider.parse_request(client).await { + Ok(Some(req)) => req, + Ok(None) => return Ok(()), // Client closed connection + Err(e) => { + if is_benign_connection_error(&e) { + debug!( + host = %ctx.host, + port = ctx.port, + error = %e, + "L7 connection closed" + ); + } else { + let detail = + parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); + emit_parse_rejection(ctx, &detail, "l7"); + } + return Ok(()); // Close connection on parse error + } + }; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + // Rewrite credential placeholders in the request target BEFORE OPA + // evaluation. OPA sees the redacted path; the resolved path goes only + // to the upstream write. + let (eval_target, redacted_target) = if let Some(ref resolver) = ctx.secret_resolver { + match secrets::rewrite_target_for_eval(&req.target, resolver) { + Ok(result) => (result.resolved, result.redacted), + Err(e) => { + warn!( + host = %ctx.host, + port = ctx.port, + error = %e, + "credential resolution failed in request target, rejecting" + ); + let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + client.write_all(response).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return Ok(()); + } + } + } else { + (req.target.clone(), req.target.clone()) + }; + + let request_info = L7RequestInfo { + action: req.action.clone(), + target: redacted_target.clone(), + query_params: req.query_params.clone(), + graphql: None, + }; + let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); + if config.protocol == L7Protocol::Websocket && !websocket_request { + provider + .deny_with_redacted_target( + &req, + &ctx.policy_name, + "websocket endpoint requires a valid WebSocket upgrade request", + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + + // Evaluate L7 policy via Rego (using redacted target) + let (allowed, reason) = evaluate_l7_request(engine, ctx, &request_info)?; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + // Check if this is an upgrade request for logging purposes. + let header_end = req + .raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(req.raw_header.len(), |p| p + 4); + let is_upgrade_request = { + let h = String::from_utf8_lossy(&req.raw_header[..header_end]); + h.lines() + .skip(1) + .any(|l| l.to_ascii_lowercase().starts_with("upgrade:")) + }; + + let decision_str = match (allowed, config.enforcement, is_upgrade_request) { + (true, _, true) => "allow_upgrade", + (true, _, false) => "allow", + (false, EnforcementMode::Audit, _) => "audit", + (false, EnforcementMode::Enforce, _) => "deny", + }; + + // Log every L7 decision as an OCSF HTTP Activity event. + // Uses redacted_target (path only, no query params) to avoid logging secrets. + { + let (action_id, disposition_id, severity) = match decision_str { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .http_request(HttpRequest::new( + &request_info.action, + OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "l7") + .message(format!( + "L7_REQUEST {decision_str} {} {}:{}{} reason={}", + request_info.action, ctx.host, ctx.port, redacted_target, reason, + )) + .build(); + ocsf_emit!(event); + } + + // Store the resolved target for the deny response redaction + let _ = &eval_target; + + if allowed || config.enforcement == EnforcementMode::Audit { + let req_with_auth = + match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { + Ok(req) => req, + Err(e) => { + warn!( + host = %ctx.host, + port = ctx.port, + error = %e, + "Token grant failed in L7 relay" + ); + write_bad_gateway_response(client).await?; + return Ok(()); + } + }; + + // Forward request to upstream and relay response + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( + &req_with_auth, + client, + upstream, + crate::l7::rest::RelayRequestOptions { + resolver: ctx.secret_resolver.as_deref(), + generation_guard: Some(engine.generation_guard()), + websocket_extensions: websocket_extension_mode(config), + request_body_credential_rewrite: config.protocol == L7Protocol::Rest + && config.request_body_credential_rewrite, + }, + ) + .await?; + match outcome { + RelayOutcome::Reusable => {} // continue loop + RelayOutcome::Consumed => { + debug!( + host = %ctx.host, + port = ctx.port, + "Upstream connection not reusable, closing L7 relay" + ); + return Ok(()); + } + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let mut options = upgrade_options( + config, + ctx, + websocket_request, + &redacted_target, + &req_with_auth.query_params, + Some(engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; + } + } + } else { + // Enforce mode: deny with 403 and close connection (use redacted target) + provider + .deny_with_redacted_target( + &req, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + } +} + +fn close_if_stale(guard: &PolicyGenerationGuard, ctx: &L7EvalContext) -> bool { + if !guard.is_stale() { + return false; + } + + ocsf_emit!( + NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "l7") + .message(format!( + "L7 tunnel closed after policy reload [host:{} port:{} captured_generation:{} current_generation:{}]", + ctx.host, + ctx.port, + guard.captured_generation(), + guard.current_generation(), + )) + .build() + ); + true +} + +async fn relay_graphql( + config: &L7EndpointConfig, + engine: &TunnelPolicyEngine, + client: &mut C, + upstream: &mut U, + ctx: &L7EvalContext, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + loop { + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let parsed = match crate::l7::graphql::parse_graphql_http_request( + client, + config.graphql_max_body_bytes, + crate::l7::path::CanonicalizeOptions { + allow_encoded_slash: config.allow_encoded_slash, + ..Default::default() + }, + ) + .await + { + Ok(Some(parsed)) => parsed, + Ok(None) => return Ok(()), + Err(e) => { + if is_benign_connection_error(&e) { + debug!( + host = %ctx.host, + port = ctx.port, + error = %e, + "GraphQL L7 connection closed" + ); + } else { + let detail = + parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); + emit_parse_rejection(ctx, &detail, "l7-graphql"); + } + return Ok(()); + } + }; + + let req = parsed.request; + let graphql_info = parsed.info; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let (eval_target, redacted_target) = if let Some(ref resolver) = ctx.secret_resolver { + match secrets::rewrite_target_for_eval(&req.target, resolver) { + Ok(result) => (result.resolved, result.redacted), + Err(e) => { + warn!( + host = %ctx.host, + port = ctx.port, + error = %e, + "credential resolution failed in GraphQL request target, rejecting" + ); + let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + client.write_all(response).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return Ok(()); + } + } + } else { + (req.target.clone(), req.target.clone()) + }; + + let request_info = L7RequestInfo { + action: req.action.clone(), + target: redacted_target.clone(), + query_params: req.query_params.clone(), + graphql: Some(graphql_info.clone()), + }; + + // Malformed or ambiguous GraphQL requests, such as duplicated GET + // control parameters, are rejected before policy evaluation. This + // keeps parser-differential cases fail-closed even if the endpoint is + // otherwise in audit mode. + let parse_error_reason = graphql_info + .error + .as_deref() + .map(|error| format!("GraphQL request rejected: {error}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = if let Some(reason) = parse_error_reason { + (false, reason) + } else { + evaluate_l7_request(engine, ctx, &request_info)? + }; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let decision_str = match (allowed, config.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + + { + let (action_id, disposition_id, severity) = match decision_str { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let gql_summary = graphql_log_summary(&graphql_info); + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .http_request(HttpRequest::new( + &request_info.action, + OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "l7-graphql") + .message(format!( + "GRAPHQL_L7_REQUEST {decision_str} {} {}:{}{} {gql_summary} reason={}", + request_info.action, ctx.host, ctx.port, redacted_target, reason, + )) + .build(); + ocsf_emit!(event); + } + + let _ = &eval_target; + + if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + &req, + client, + upstream, + ctx.secret_resolver.as_deref(), + Some(engine.generation_guard()), + ) + .await?; + match outcome { + RelayOutcome::Reusable => {} + RelayOutcome::Consumed => { + debug!( + host = %ctx.host, + port = ctx.port, + "Upstream connection not reusable, closing GraphQL L7 relay" + ); + return Ok(()); + } + RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } => { + let options = UpgradeRelayOptions { + websocket: WebSocketUpgradeBehavior { + permessage_deflate: websocket_permessage_deflate, + ..Default::default() + }, + ..Default::default() + }; + return handle_upgrade( + client, upstream, overflow, &ctx.host, ctx.port, options, + ) + .await; + } + } + } else { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + } +} + +fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String { + if let Some(error) = &info.error { + return format!("graphql_error={error:?}"); + } + let ops: Vec = info + .operations + .iter() + .map(|op| { + let name = op.operation_name.as_deref().unwrap_or("-"); + let fields = if op.fields.is_empty() { + "-".to_string() + } else { + op.fields.join(",") + }; + let persisted = op + .persisted_query_hash + .as_deref() + .or(op.persisted_query_id.as_deref()) + .unwrap_or("-"); + format!( + "type={} name={} fields={} persisted={}", + op.operation_type, name, fields, persisted + ) + }) + .collect(); + format!("graphql_ops={}", ops.join(";")) +} + +/// Check if a miette error represents a benign connection close. +/// +/// TLS handshake EOF, missing `close_notify`, connection resets, and broken +/// pipes are all normal lifecycle events for proxied connections — not worth +/// a WARN that interrupts the user's terminal. +fn is_benign_connection_error(err: &miette::Report) -> bool { + const BENIGN: &[&str] = &[ + "close_notify", + "tls handshake eof", + "connection reset", + "broken pipe", + "unexpected eof", + "client disconnected mid-request", + ]; + let msg = err.to_string().to_ascii_lowercase(); + BENIGN.iter().any(|pat| msg.contains(pat)) +} + +/// Evaluate an L7 request against the OPA engine. +/// +/// Returns `(allowed, deny_reason)`. +pub fn evaluate_l7_request( + engine: &TunnelPolicyEngine, + ctx: &L7EvalContext, + request: &L7RequestInfo, +) -> Result<(bool, String)> { + if engine.is_stale() { + return Err(miette!( + "L7 tunnel policy generation is stale [captured_generation:{} current_generation:{}]", + engine.captured_generation(), + engine.current_generation(), + )); + } + + let input_json = serde_json::json!({ + "network": { + "host": ctx.host, + "port": ctx.port, + }, + "exec": { + "path": ctx.binary_path, + "ancestors": ctx.ancestors, + "cmdline_paths": ctx.cmdline_paths, + }, + "request": { + "method": request.action, + "path": request.target, + "query_params": request.query_params.clone(), + "graphql": request.graphql.clone(), + } + }); + + let mut engine = engine + .engine() + .lock() + .map_err(|_| miette!("OPA engine lock poisoned"))?; + + engine + .set_input_json(&input_json.to_string()) + .map_err(|e| miette!("{e}"))?; + + let allowed = engine + .eval_rule("data.openshell.sandbox.allow_request".into()) + .map_err(|e| miette!("{e}"))?; + let allowed = allowed == regorus::Value::from(true); + + let reason = if allowed { + String::new() + } else { + let val = engine + .eval_rule("data.openshell.sandbox.request_deny_reason".into()) + .map_err(|e| miette!("{e}"))?; + match val { + regorus::Value::String(s) => s.to_string(), + regorus::Value::Undefined => "request denied by policy".to_string(), + other => other.to_string(), + } + }; + + Ok((allowed, reason)) +} + +/// Relay HTTP traffic with credential injection only (no L7 OPA evaluation). +/// +/// Used when TLS is auto-terminated but no L7 policy (`protocol` + `access`/`rules`) +/// is configured. Parses HTTP requests minimally to rewrite credential +/// placeholders and log requests for observability, then forwards everything. +pub async fn relay_passthrough_with_credentials( + client: &mut C, + upstream: &mut U, + ctx: &L7EvalContext, + generation_guard: &PolicyGenerationGuard, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + // Passthrough path: no L7 policy is enforced here, so use default + // (strict) canonicalization options. Calls to GitLab-style APIs that + // need `%2F` must be configured as L7 endpoints so the per-endpoint + // `allow_encoded_slash` opt-in applies. + let provider = crate::l7::rest::RestProvider::default(); + let mut request_count: u64 = 0; + let resolver = ctx.secret_resolver.as_deref(); + + loop { + if close_if_stale(generation_guard, ctx) { + return Ok(()); + } + + // Read next request from client. + let req = match provider.parse_request(client).await { + Ok(Some(req)) => req, + Ok(None) => break, // Client closed connection. + Err(e) => { + if is_benign_connection_error(&e) { + break; + } + let detail = + parse_rejection_detail(&e.to_string(), ParseRejectionMode::Passthrough); + emit_parse_rejection(ctx, &detail, "http-parser"); + return Ok(()); + } + }; + + if close_if_stale(generation_guard, ctx) { + return Ok(()); + } + + request_count += 1; + + // Resolve and redact the target for logging. + let redacted_target = if let Some(ref res) = ctx.secret_resolver { + match secrets::rewrite_target_for_eval(&req.target, res) { + Ok(result) => result.redacted, + Err(e) => { + warn!( + host = %ctx.host, + port = ctx.port, + error = %e, + "credential resolution failed in request target, rejecting" + ); + let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + client.write_all(response).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + return Ok(()); + } + } + } else { + req.target.clone() + }; + + // Log for observability via OCSF HTTP Activity event. + // Uses redacted_target (path only, no query params) to avoid logging secrets. + let has_creds = resolver.is_some(); + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .http_request(HttpRequest::new( + &req.action, + OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .message(format!( + "HTTP_REQUEST {} {}:{}{} credentials_injected={has_creds} request_num={request_count}", + req.action, ctx.host, ctx.port, redacted_target, + )) + .build(); + ocsf_emit!(event); + } + + let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await + { + Ok(req) => req, + Err(e) => { + warn!( + host = %ctx.host, + port = ctx.port, + error = %e, + "Token grant failed in passthrough relay" + ); + write_bad_gateway_response(client).await?; + return Ok(()); + } + }; + + // Forward request with credential rewriting and relay the response. + // relay_http_request_with_resolver handles both directions: it sends + // the request upstream and reads the response back to the client. + let outcome = crate::l7::rest::relay_http_request_with_options_guarded( + &req_with_auth, + client, + upstream, + crate::l7::rest::RelayRequestOptions { + resolver, + generation_guard: Some(generation_guard), + ..Default::default() + }, + ) + .await?; + + match outcome { + RelayOutcome::Reusable => {} // continue loop + RelayOutcome::Consumed => break, + RelayOutcome::Upgraded { overflow, .. } => { + return handle_upgrade( + client, + upstream, + overflow, + &ctx.host, + ctx.port, + UpgradeRelayOptions::default(), + ) + .await; + } + } + } + + debug!( + host = %ctx.host, + port = ctx.port, + total_requests = request_count, + "Credential injection relay completed" + ); + + Ok(()) +} + +async fn write_bad_gateway_response(client: &mut W) -> Result<()> +where + W: AsyncWrite + Unpin, +{ + let response = b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + client.write_all(response).await.into_diagnostic()?; + client.flush().await.into_diagnostic()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::opa::{NetworkInput, OpaEngine}; + use std::path::PathBuf; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + + fn rest_token_grant_relay_context( + resolver_response: std::result::Result<&str, &str>, + ) -> ( + L7EndpointConfig, + TunnelPolicyEngine, + L7EvalContext, + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture, + ) { + let data = r#" +network_policies: + rest_api: + name: rest_api + endpoints: + - host: api.example.test + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/v1/**" + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.test".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let provider_key = "api.example.test\t8080\t/v1/**\tprovider:access_token"; + let fixture = match resolver_response { + Ok(token) => { + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::success( + provider_key, + token, + ) + } + Err(error) => { + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::failure( + provider_key, + error, + ) + } + }; + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: Some(fixture.dynamic_credentials()), + token_grant_resolver: Some(fixture.resolver()), + }; + + (config, tunnel_engine, ctx, fixture) + } + + fn passthrough_token_grant_relay_context( + resolver_response: std::result::Result<&str, &str>, + ) -> ( + PolicyGenerationGuard, + L7EvalContext, + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture, + ) { + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(TEST_POLICY, policy_data).unwrap(); + let generation_guard = engine + .generation_guard(engine.current_generation()) + .unwrap(); + let provider_key = "api.example.test\t8080\t/v1/**\tprovider:access_token"; + let fixture = match resolver_response { + Ok(token) => { + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::success( + provider_key, + token, + ) + } + Err(error) => { + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::failure( + provider_key, + error, + ) + } + }; + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: Some(fixture.dynamic_credentials()), + token_grant_resolver: Some(fixture.resolver()), + }; + + (generation_guard, ctx, fixture) + } + + fn authorization_header_count(headers: &str) -> usize { + headers + .lines() + .filter(|line| { + line.split_once(':') + .is_some_and(|(name, _)| name.eq_ignore_ascii_case("authorization")) + }) + .count() + } + + #[test] + fn parse_rejection_detail_adds_l7_hint_for_encoded_slash() { + let detail = parse_rejection_detail( + "HTTP request-target rejected: request-target contains an encoded '/' (%2F) which is not allowed on this endpoint", + ParseRejectionMode::L7Endpoint, + ); + + assert!(detail.contains("allow_encoded_slash: true")); + assert!(detail.contains("upstream requires encoded slashes")); + } + + #[test] + fn parse_rejection_detail_adds_passthrough_hint_for_encoded_slash() { + let detail = parse_rejection_detail( + "HTTP request-target rejected: request-target contains an encoded '/' (%2F) which is not allowed on this endpoint", + ParseRejectionMode::Passthrough, + ); + + assert!(detail.contains("protocol: rest")); + assert!(detail.contains("allow_encoded_slash: true")); + assert!(detail.contains("tls: skip")); + } + + #[test] + fn parse_rejection_detail_preserves_other_errors() { + let error = "HTTP headers contain invalid UTF-8"; + + assert_eq!( + parse_rejection_detail(error, ParseRejectionMode::L7Endpoint), + error + ); + } + + #[tokio::test] + async fn l7_rest_relay_injects_token_grant_authorization_header() { + let (config, tunnel_engine, ctx, fixture) = + rest_token_grant_relay_context(Ok("grant-token")); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nAuthorization: Bearer stale-token\r\nConnection: close\r\n\r\n", + ) + .await + .unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + + assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); + assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); + assert!(!upstream_request.contains("stale-token")); + assert_eq!(authorization_header_count(&upstream_request), 1); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + + fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + } + + #[tokio::test] + async fn l7_rest_relay_token_grant_failure_does_not_forward_request() { + let (config, tunnel_engine, ctx, fixture) = + rest_token_grant_relay_context(Err("oauth unavailable")); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nConnection: close\r\n\r\n", + ) + .await + .unwrap(); + + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("bad gateway response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("502 Bad Gateway")); + + let mut upstream_request = [0u8; 128]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("upstream should close without forwarded data") + .unwrap(); + assert_eq!(n, 0, "unauthenticated request must not reach upstream"); + + fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + } + + #[tokio::test] + async fn passthrough_relay_injects_token_grant_authorization_header() { + let (generation_guard, ctx, fixture) = + passthrough_token_grant_relay_context(Ok("grant-token")); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_passthrough_with_credentials( + &mut relay_client, + &mut relay_upstream, + &ctx, + &generation_guard, + ) + .await + }); + + app.write_all( + b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nAuthorization: Bearer stale-token\r\nConnection: close\r\n\r\n", + ) + .await + .unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + + assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); + assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); + assert!(!upstream_request.contains("stale-token")); + assert_eq!(authorization_header_count(&upstream_request), 1); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + + fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + } + + #[tokio::test] + async fn passthrough_relay_token_grant_failure_returns_bad_gateway_without_forwarding() { + let (generation_guard, ctx, fixture) = + passthrough_token_grant_relay_context(Err("oauth unavailable")); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_passthrough_with_credentials( + &mut relay_client, + &mut relay_upstream, + &ctx, + &generation_guard, + ) + .await + }); + + app.write_all( + b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nConnection: close\r\n\r\n", + ) + .await + .unwrap(); + + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("bad gateway response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("502 Bad Gateway")); + + let mut upstream_request = [0u8; 128]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("upstream should close without forwarded data") + .unwrap(); + assert_eq!(n, 0, "unauthenticated request must not reach upstream"); + + fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + } + + #[test] + fn websocket_text_policy_requires_explicit_message_rule() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&input) + .unwrap() + .1; + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "ws_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let request = L7RequestInfo { + action: "WEBSOCKET_TEXT".into(), + target: "/ws".into(), + query_params: std::collections::HashMap::new(), + graphql: None, + }; + + let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + + assert!(!allowed); + assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); + } + + #[tokio::test] + async fn route_selected_websocket_upgrade_rejects_invalid_accept_without_forwarding_101() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Rest, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }]; + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + assert!(forwarded.contains("Connection: Upgrade\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", + ) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should fail closed on invalid accept") + .unwrap() + .expect_err("invalid accept must fail the route-selected relay"); + assert!(err.to_string().contains("Sec-WebSocket-Accept")); + + let mut response = [0u8; 1]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client side should close without 101") + .unwrap(); + assert_eq!(n, 0, "invalid response must not forward 101 headers"); + } + + #[tokio::test] + async fn route_selected_websocket_rewrites_text_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/ws".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); + + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten websocket text should reach upstream") + .unwrap(); + assert!(masked, "client-to-server frame must remain masked"); + assert_eq!(rewritten, r#"{"op":2,"d":{"token":"real-token"}}"#); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + + #[tokio::test] + async fn route_selected_graphql_websocket_rewrites_connection_init_credentials_after_upgrade() { + let data = r#" +network_policies: + route_api: + name: route_api + endpoints: + - host: gateway.example.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + websocket_credential_rewrite: true + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let configs = vec![L7EndpointConfig { + protocol: L7Protocol::Websocket, + path: "/graphql".into(), + tls: crate::l7::TlsMode::Auto, + enforcement: EnforcementMode::Enforce, + graphql_max_body_bytes: 0, + allow_encoded_slash: false, + websocket_credential_rewrite: true, + request_body_credential_rewrite: false, + websocket_graphql_policy: true, + }]; + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let placeholder = child_env.get("T").expect("placeholder env"); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "route_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver.map(Arc::new), + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_route_selection( + &configs, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /graphql HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut forwarded = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut forwarded), + ) + .await + .expect("upgrade request should reach upstream") + .unwrap(); + let forwarded = String::from_utf8_lossy(&forwarded[..n]); + assert!(forwarded.contains("GET /graphql HTTP/1.1")); + assert!(forwarded.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 1024]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("client should receive upgrade response") + .unwrap(); + assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); + + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let (masked, rewritten) = tokio::time::timeout( + std::time::Duration::from_secs(1), + read_text_frame(&mut upstream), + ) + .await + .expect("rewritten GraphQL WebSocket control message should reach upstream") + .unwrap(); + assert!(masked, "client-to-server frame must remain masked"); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + + drop(app); + drop(upstream); + let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; + } + + fn masked_text_frame(payload: &[u8]) -> Vec { + let mask = [0x11, 0x22, 0x33, 0x44]; + assert!( + payload.len() <= 125, + "test helper only supports small frames" + ); + let payload_len = u8::try_from(payload.len()).expect("small frame length"); + let mut frame = vec![0x81, 0x80 | payload_len]; + frame.extend_from_slice(&mask); + frame.extend( + payload + .iter() + .enumerate() + .map(|(idx, byte)| byte ^ mask[idx % 4]), + ); + frame + } + + async fn read_text_frame( + reader: &mut R, + ) -> std::io::Result<(bool, String)> { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await?; + assert_eq!(header[0] & 0x0f, 0x1, "expected text frame"); + let masked = header[1] & 0x80 != 0; + let payload_len = usize::from(header[1] & 0x7f); + assert!(payload_len <= 125, "test helper only supports small frames"); + let mut mask = [0u8; 4]; + if masked { + reader.read_exact(&mut mask).await?; + } + let mut payload = vec![0u8; payload_len]; + reader.read_exact(&mut payload).await?; + if masked { + for (idx, byte) in payload.iter_mut().enumerate() { + *byte ^= mask[idx % 4]; + } + } + Ok((masked, String::from_utf8(payload).expect("text payload"))) + } + + #[tokio::test] + async fn l7_relay_closes_keep_alive_tunnel_after_policy_generation_change() { + let initial_data = r#" +network_policies: + rest_api: + name: rest_api + endpoints: + - host: api.example.test + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: POST + path: "/write" + binaries: + - { path: /usr/bin/curl } +"#; + let reloaded_data = r#" +network_policies: + rest_api: + name: rest_api + endpoints: + - host: api.example.test + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/write" + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, initial_data).unwrap(); + let input = NetworkInput { + host: "api.example.test".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"POST /write HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n", + ) + .await + .unwrap(); + + let mut first_upstream = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut first_upstream), + ) + .await + .expect("first request should reach upstream") + .unwrap(); + let first_upstream = String::from_utf8_lossy(&first_upstream[..n]); + assert!(first_upstream.starts_with("POST /write HTTP/1.1")); + + upstream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nOK") + .await + .unwrap(); + + let mut first_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut first_response), + ) + .await + .expect("first response should reach client") + .unwrap(); + let first_response = String::from_utf8_lossy(&first_response[..n]); + assert!(first_response.contains("200 OK")); + + engine.reload(TEST_POLICY, reloaded_data).unwrap(); + app.write_all( + b"POST /write HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n", + ) + .await + .unwrap(); + + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should close stale tunnel") + .unwrap() + .unwrap(); + + let mut second_upstream = [0u8; 128]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut second_upstream), + ) + .await + .expect("upstream side should close") + .unwrap(); + assert_eq!(n, 0, "stale request must not be forwarded upstream"); + } + + #[tokio::test] + async fn passthrough_relay_closes_keep_alive_tunnel_after_policy_generation_change() { + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(TEST_POLICY, policy_data).unwrap(); + let generation_guard = engine + .generation_guard(engine.current_generation()) + .unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_passthrough_with_credentials( + &mut relay_client, + &mut relay_upstream, + &ctx, + &generation_guard, + ) + .await + }); + + app.write_all( + b"GET /first HTTP/1.1\r\nHost: api.example.test\r\nConnection: keep-alive\r\n\r\n", + ) + .await + .unwrap(); + + let mut first_upstream = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut first_upstream), + ) + .await + .expect("first passthrough request should reach upstream") + .unwrap(); + let first_upstream = String::from_utf8_lossy(&first_upstream[..n]); + assert!(first_upstream.starts_with("GET /first HTTP/1.1")); + + upstream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nOK") + .await + .unwrap(); + + let mut first_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut first_response), + ) + .await + .expect("first passthrough response should reach client") + .unwrap(); + let first_response = String::from_utf8_lossy(&first_response[..n]); + assert!(first_response.contains("200 OK")); + + engine.reload(TEST_POLICY, policy_data).unwrap(); + app.write_all( + b"GET /second HTTP/1.1\r\nHost: api.example.test\r\nConnection: keep-alive\r\n\r\n", + ) + .await + .unwrap(); + + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("passthrough relay should close stale tunnel") + .unwrap() + .unwrap(); + + let mut second_upstream = [0u8; 128]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut second_upstream), + ) + .await + .expect("upstream side should close") + .unwrap(); + assert_eq!( + n, 0, + "stale passthrough request must not be forwarded upstream" + ); + } +} From 96dbfc198e69e239cbfe64dedd533fb07b872fa9 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 16:15:06 -0700 Subject: [PATCH 03/12] refactor(l7): share HTTP body inspection helper Move HTTP request body buffering and chunked-body normalization out of the GraphQL module so other HTTP-carried L7 protocols can inspect request bodies without depending on GraphQL internals. Signed-off-by: Kris Hicks --- crates/openshell-sandbox/src/l7/graphql.rs | 636 +++++++++++++++++++++ crates/openshell-sandbox/src/l7/http.rs | 199 +++++++ crates/openshell-sandbox/src/l7/mod.rs | 1 + 3 files changed, 836 insertions(+) create mode 100644 crates/openshell-sandbox/src/l7/graphql.rs create mode 100644 crates/openshell-sandbox/src/l7/http.rs diff --git a/crates/openshell-sandbox/src/l7/graphql.rs b/crates/openshell-sandbox/src/l7/graphql.rs new file mode 100644 index 000000000..f548f2a10 --- /dev/null +++ b/crates/openshell-sandbox/src/l7/graphql.rs @@ -0,0 +1,636 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! GraphQL-over-HTTP L7 inspection. + +use crate::l7::provider::{L7Provider, L7Request}; +use apollo_parser::Parser; +use apollo_parser::cst; +use miette::{Result, miette}; +use serde::Serialize; +use serde_json::Value; +use std::collections::{HashMap, HashSet}; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub const DEFAULT_MAX_BODY_BYTES: usize = 64 * 1024; + +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub struct GraphqlRequestInfo { + pub operations: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, PartialEq, Eq)] +pub struct GraphqlOperationInfo { + pub operation_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub operation_name: Option, + pub fields: Vec, + pub persisted_query: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub persisted_query_hash: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub persisted_query_id: Option, +} + +pub struct GraphqlHttpRequest { + pub request: L7Request, + pub info: GraphqlRequestInfo, +} + +pub async fn parse_graphql_http_request( + client: &mut C, + max_body_bytes: usize, + canonicalize_options: crate::l7::path::CanonicalizeOptions, +) -> Result> { + let provider = crate::l7::rest::RestProvider::with_options(canonicalize_options); + let Some(mut request) = provider.parse_request(client).await? else { + return Ok(None); + }; + + let info = inspect_graphql_request(client, &mut request, max_body_bytes).await?; + + Ok(Some(GraphqlHttpRequest { request, info })) +} + +pub(crate) async fn inspect_graphql_request( + client: &mut C, + request: &mut L7Request, + max_body_bytes: usize, +) -> Result { + let header_str = header_str(request)?; + reject_unsupported_headers(header_str)?; + let body = crate::l7::http::read_body_for_inspection(client, request, max_body_bytes).await?; + Ok(classify_request(request, &body)) +} + +pub fn classify_request(request: &L7Request, body: &[u8]) -> GraphqlRequestInfo { + match classify_request_inner(request, body) { + Ok(operations) => GraphqlRequestInfo { + operations, + error: None, + }, + Err(err) => GraphqlRequestInfo { + operations: Vec::new(), + error: Some(err), + }, + } +} + +pub fn classify_json_envelope_value(value: &Value) -> GraphqlRequestInfo { + match classify_json_envelope(value) { + Ok(operations) => GraphqlRequestInfo { + operations, + error: None, + }, + Err(err) => GraphqlRequestInfo { + operations: Vec::new(), + error: Some(err), + }, + } +} + +fn classify_request_inner( + request: &L7Request, + body: &[u8], +) -> std::result::Result, String> { + match request.action.to_ascii_uppercase().as_str() { + "GET" => classify_get(request), + "POST" => classify_post(body), + method => Err(format!("unsupported GraphQL HTTP method {method}")), + } +} + +fn classify_get(request: &L7Request) -> std::result::Result, String> { + let query = unique_query_value(&request.query_params, "query")?; + let operation_name = unique_query_value(&request.query_params, "operationName")?; + let extensions = unique_query_value(&request.query_params, "extensions")? + .and_then(|raw| serde_json::from_str::(&raw).ok()); + let id = unique_persisted_query_id(&request.query_params)?; + + classify_envelope( + query.as_deref(), + operation_name.as_deref(), + extensions.as_ref(), + id, + ) +} + +fn classify_post(body: &[u8]) -> std::result::Result, String> { + if body.is_empty() { + return Err("GraphQL POST body is empty".to_string()); + } + let value: Value = serde_json::from_slice(body) + .map_err(|err| format!("GraphQL request body is not valid JSON: {err}"))?; + + match value { + Value::Array(items) => { + if items.is_empty() { + return Err("GraphQL batch request is empty".to_string()); + } + let mut operations = Vec::new(); + for item in items { + operations.extend(classify_json_envelope(&item)?); + } + Ok(operations) + } + Value::Object(_) => classify_json_envelope(&value), + _ => Err("GraphQL JSON envelope must be an object or array".to_string()), + } +} + +fn classify_json_envelope(value: &Value) -> std::result::Result, String> { + let obj = value + .as_object() + .ok_or_else(|| "GraphQL batch item must be an object".to_string())?; + let query = obj.get("query").and_then(Value::as_str); + let operation_name = obj.get("operationName").and_then(Value::as_str); + let extensions = obj.get("extensions"); + let id = obj + .get("id") + .or_else(|| obj.get("documentId")) + .or_else(|| obj.get("queryId")) + .and_then(Value::as_str) + .map(ToString::to_string); + + classify_envelope(query, operation_name, extensions, id) +} + +fn classify_envelope( + query: Option<&str>, + operation_name: Option<&str>, + extensions: Option<&Value>, + persisted_id: Option, +) -> std::result::Result, String> { + let persisted_hash = persisted_query_hash(extensions); + let query = query.filter(|q| !q.trim().is_empty()); + + if let Some(query) = query { + let mut operation = classify_document(query, operation_name)?; + if let Some(hash) = persisted_hash { + operation.persisted_query = true; + operation.persisted_query_hash = Some(hash); + } + if let Some(id) = persisted_id { + operation.persisted_query = true; + operation.persisted_query_id = Some(id); + } + return Ok(vec![operation]); + } + + if persisted_hash.is_some() || persisted_id.is_some() { + return Ok(vec![GraphqlOperationInfo { + operation_type: String::new(), + operation_name: operation_name.map(ToString::to_string), + fields: Vec::new(), + persisted_query: true, + persisted_query_hash: persisted_hash, + persisted_query_id: persisted_id, + }]); + } + + Err("GraphQL request has no query document or persisted query identifier".to_string()) +} + +fn classify_document( + query: &str, + operation_name: Option<&str>, +) -> std::result::Result { + let parser = Parser::new(query).recursion_limit(128).token_limit(20_000); + let cst = parser.parse(); + let mut parse_errors = cst.errors(); + if let Some(err) = parse_errors.next() { + return Err(format!("GraphQL document parse error: {err}")); + } + + let document = cst.document(); + let mut operations = Vec::new(); + let mut fragments = HashMap::new(); + + for definition in document.definitions() { + match definition { + cst::Definition::OperationDefinition(operation) => operations.push(operation), + cst::Definition::FragmentDefinition(fragment) => { + if let Some(name) = fragment.fragment_name().and_then(|n| n.name()) { + fragments.insert(name.text().to_string(), fragment); + } + } + _ => {} + } + } + + if operations.is_empty() { + return Err("GraphQL document contains no executable operation".to_string()); + } + + let selected = if let Some(expected_name) = operation_name.filter(|name| !name.is_empty()) { + operations + .into_iter() + .find(|op| { + op.name() + .is_some_and(|name| name.text().as_ref() == expected_name) + }) + .ok_or_else(|| format!("GraphQL operationName {expected_name:?} was not found"))? + } else if operations.len() == 1 { + operations.remove(0) + } else { + return Err("GraphQL document has multiple operations but no operationName".to_string()); + }; + + let operation_type = operation_type(&selected); + let operation_name = selected.name().map(|name| name.text().to_string()); + let selection_set = selected + .selection_set() + .ok_or_else(|| "GraphQL operation has no selection set".to_string())?; + let mut fields = HashSet::new(); + let mut visited_fragments = HashSet::new(); + collect_root_fields( + selection_set, + &fragments, + &mut visited_fragments, + &mut fields, + ); + let mut fields: Vec<_> = fields.into_iter().collect(); + fields.sort(); + + Ok(GraphqlOperationInfo { + operation_type, + operation_name, + fields, + persisted_query: false, + persisted_query_hash: None, + persisted_query_id: None, + }) +} + +fn operation_type(operation: &cst::OperationDefinition) -> String { + let Some(operation_type) = operation.operation_type() else { + return "query".to_string(); + }; + if operation_type.mutation_token().is_some() { + "mutation".to_string() + } else if operation_type.subscription_token().is_some() { + "subscription".to_string() + } else { + "query".to_string() + } +} + +fn collect_root_fields( + selection_set: cst::SelectionSet, + fragments: &HashMap, + visited_fragments: &mut HashSet, + fields: &mut HashSet, +) { + for selection in selection_set.selections() { + match selection { + cst::Selection::Field(field) => { + if let Some(name) = field.name() { + fields.insert(name.text().to_string()); + } + } + cst::Selection::InlineFragment(fragment) => { + if let Some(selection_set) = fragment.selection_set() { + collect_root_fields(selection_set, fragments, visited_fragments, fields); + } + } + cst::Selection::FragmentSpread(spread) => { + let Some(name) = spread.fragment_name().and_then(|n| n.name()) else { + continue; + }; + let name = name.text().to_string(); + if !visited_fragments.insert(name.clone()) { + continue; + } + if let Some(fragment) = fragments.get(&name) + && let Some(selection_set) = fragment.selection_set() + { + collect_root_fields(selection_set, fragments, visited_fragments, fields); + } + } + } + } +} + +fn persisted_query_hash(extensions: Option<&Value>) -> Option { + extensions? + .get("persistedQuery")? + .get("sha256Hash")? + .as_str() + .filter(|hash| !hash.is_empty()) + .map(ToString::to_string) +} + +fn unique_query_value( + params: &HashMap>, + key: &str, +) -> std::result::Result, String> { + let Some(values) = params.get(key) else { + return Ok(None); + }; + if values.len() > 1 { + return Err(format!( + "GraphQL GET parameter {key:?} must not appear more than once" + )); + } + Ok(values.first().filter(|value| !value.is_empty()).cloned()) +} + +fn unique_persisted_query_id( + params: &HashMap>, +) -> std::result::Result, String> { + let mut selected: Option<(String, String)> = None; + for key in ["id", "documentId", "queryId"] { + let Some(value) = unique_query_value(params, key)? else { + continue; + }; + if let Some((existing_key, _)) = selected { + return Err(format!( + "GraphQL GET persisted-query id parameters {existing_key:?} and {key:?} must not be combined" + )); + } + selected = Some((key.to_string(), value)); + } + Ok(selected.map(|(_, value)| value)) +} + +fn header_str(request: &L7Request) -> Result<&str> { + let header_end = request + .raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(request.raw_header.len(), |p| p + 4); + std::str::from_utf8(&request.raw_header[..header_end]) + .map_err(|_| miette!("GraphQL HTTP headers contain invalid UTF-8")) +} + +fn reject_unsupported_headers(headers: &str) -> Result<()> { + for line in headers.lines().skip(1) { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("content-encoding:") { + let encoding = lower.split_once(':').map_or("", |(_, v)| v.trim()); + if !encoding.is_empty() && encoding != "identity" { + return Err(miette!( + "GraphQL request content-encoding {encoding:?} is not supported" + )); + } + } + if lower.starts_with("content-type:") { + let content_type = lower.split_once(':').map_or("", |(_, v)| v.trim()); + if content_type.starts_with("multipart/") { + return Err(miette!("GraphQL multipart requests are not supported")); + } + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::l7::provider::BodyLength; + + fn request(method: &str, target: &str) -> L7Request { + L7Request { + action: method.to_string(), + target: target.to_string(), + query_params: crate::l7::rest::parse_target_query(target).unwrap().1, + raw_header: format!("{method} {target} HTTP/1.1\r\nHost: example.com\r\n\r\n") + .into_bytes(), + body_length: BodyLength::None, + } + } + + #[test] + fn classifies_simple_query() { + let req = request("POST", "/graphql"); + let info = classify_request(&req, br#"{"query":"query Viewer { viewer { login } }"}"#); + assert_eq!(info.error, None); + assert_eq!(info.operations[0].operation_type, "query"); + assert_eq!(info.operations[0].fields, vec!["viewer"]); + } + + #[test] + fn classifies_mutation_field_not_alias() { + let req = request("POST", "/graphql"); + let info = classify_request( + &req, + br#"{"query":"mutation M { safeAlias: volumeDelete(volumeId:\"x\") { id } }","operationName":"M"}"#, + ); + assert_eq!(info.error, None); + assert_eq!(info.operations[0].operation_type, "mutation"); + assert_eq!(info.operations[0].operation_name.as_deref(), Some("M")); + assert_eq!(info.operations[0].fields, vec!["volumeDelete"]); + } + + #[test] + fn expands_root_fragments() { + let req = request("POST", "/graphql"); + let info = classify_request( + &req, + br#"{"query":"query Q { ...RootFields } fragment RootFields on Query { viewer repository(owner:\"o\", name:\"r\") { id } }"}"#, + ); + assert_eq!(info.error, None); + assert_eq!(info.operations[0].fields, vec!["repository", "viewer"]); + } + + #[test] + fn multiple_operations_without_name_errors() { + let req = request("POST", "/graphql"); + let info = classify_request( + &req, + br#"{"query":"query A { viewer { login } } query B { rateLimit { limit } }"}"#, + ); + assert!(info.error.unwrap().contains("multiple operations")); + } + + #[test] + fn detects_hash_only_apollo_persisted_query() { + let req = request("POST", "/graphql"); + let info = classify_request( + &req, + br#"{"operationName":"Viewer","extensions":{"persistedQuery":{"version":1,"sha256Hash":"abc123"}}}"#, + ); + assert_eq!(info.error, None); + let op = &info.operations[0]; + assert!(op.persisted_query); + assert_eq!(op.operation_name.as_deref(), Some("Viewer")); + assert_eq!(op.persisted_query_hash.as_deref(), Some("abc123")); + } + + #[test] + fn graphql_get_rejects_duplicate_query_parameter() { + let req = request( + "GET", + "/graphql?query=query+Viewer+%7B+viewer+%7B+login+%7D+%7D&query=mutation+Delete+%7B+volumeDelete(volumeId%3A%22x%22)+%7B+id+%7D+%7D", + ); + let info = classify_request(&req, b""); + assert!( + info.error + .as_deref() + .is_some_and(|err| err.contains("must not appear more than once")), + "expected duplicate control parameter error, got {info:?}" + ); + } + + #[test] + fn graphql_get_rejects_ambiguous_persisted_query_ids() { + let req = request("GET", "/graphql?id=one&queryId=two"); + let info = classify_request(&req, b""); + assert!( + info.error + .as_deref() + .is_some_and(|err| err.contains("must not be combined")), + "expected ambiguous persisted-query id error, got {info:?}" + ); + } + + #[tokio::test] + async fn chunked_graphql_post_is_normalized_after_inspection() { + let body = br#"{"query":"query Viewer { viewer { login } }"}"#; + let mut raw_header = + b"POST /graphql HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\nTrailer: X-Sig\r\nX-Test: yes\r\n\r\n" + .to_vec(); + raw_header.extend_from_slice(format!("{:x}\r\n", body.len()).as_bytes()); + raw_header.extend_from_slice(body); + raw_header.extend_from_slice(b"\r\n0\r\nX-Sig: ignored\r\n\r\n"); + + let mut req = L7Request { + action: "POST".to_string(), + target: "/graphql".to_string(), + query_params: HashMap::new(), + raw_header, + body_length: BodyLength::Chunked, + }; + let mut client = tokio::io::empty(); + + let info = inspect_graphql_request(&mut client, &mut req, DEFAULT_MAX_BODY_BYTES) + .await + .expect("chunked body should inspect"); + + assert_eq!(info.error, None); + assert!(matches!( + req.body_length, + BodyLength::ContentLength(len) if len == body.len() as u64 + )); + let forwarded = String::from_utf8_lossy(&req.raw_header); + assert!(forwarded.contains(&format!("Content-Length: {}", body.len()))); + assert!(forwarded.contains("X-Test: yes\r\n")); + assert!(!forwarded.to_ascii_lowercase().contains("transfer-encoding")); + assert!(!forwarded.to_ascii_lowercase().contains("trailer:")); + assert!(req.raw_header.ends_with(body)); + } + + #[tokio::test] + async fn absolute_form_chunked_graphql_post_classifies_after_inspection() { + let body = br#"{"query":"query Viewer { viewer { login } }"}"#; + let mut raw_header = + b"POST http://example.com/graphql HTTP/1.1\r\nHost: example.com\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\nConnection: close\r\n\r\n" + .to_vec(); + raw_header.extend_from_slice(format!("{:x}\r\n", body.len()).as_bytes()); + raw_header.extend_from_slice(body); + raw_header.extend_from_slice(b"\r\n0\r\n\r\n"); + + let mut req = L7Request { + action: "POST".to_string(), + target: "/graphql".to_string(), + query_params: HashMap::new(), + raw_header, + body_length: BodyLength::Chunked, + }; + let mut client = tokio::io::empty(); + + let info = inspect_graphql_request(&mut client, &mut req, DEFAULT_MAX_BODY_BYTES) + .await + .expect("absolute-form chunked body should inspect"); + + assert_eq!(info.error, None); + assert_eq!(info.operations[0].operation_type, "query"); + assert_eq!(info.operations[0].fields, vec!["viewer"]); + } + + #[tokio::test] + async fn absolute_form_chunked_graphql_post_is_allowed_by_field_policy() { + let body = br#"{"query":"query Viewer { viewer { login } }"}"#; + let mut raw_header = + b"POST http://host.openshell.internal:8080/graphql HTTP/1.1\r\nHost: host.openshell.internal:8080\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\nConnection: close\r\n\r\n" + .to_vec(); + raw_header.extend_from_slice(format!("{:x}\r\n", body.len()).as_bytes()); + raw_header.extend_from_slice(body); + raw_header.extend_from_slice(b"\r\n0\r\n\r\n"); + + let mut req = L7Request { + action: "POST".to_string(), + target: "/graphql".to_string(), + query_params: HashMap::new(), + raw_header, + body_length: BodyLength::Chunked, + }; + let mut client = tokio::io::empty(); + let info = inspect_graphql_request(&mut client, &mut req, DEFAULT_MAX_BODY_BYTES) + .await + .expect("chunked body should inspect"); + + let data = r" +network_policies: + test_graphql_l7: + name: test_graphql_l7 + endpoints: + - host: host.openshell.internal + port: 8080 + protocol: graphql + enforcement: enforce + persisted_queries: allow_registered + graphql_persisted_queries: + abc123: + operation_type: query + operation_name: Viewer + fields: [viewer] + rules: + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: mutation + fields: [createIssue] + deny_rules: + - operation_type: mutation + fields: [deleteRepository] + binaries: + - { path: /usr/bin/python3 } +"; + let engine = crate::opa::OpaEngine::from_strings( + include_str!("../../data/sandbox-policy.rego"), + data, + ) + .expect("policy should load"); + let ctx = crate::l7::relay::L7EvalContext { + host: "host.openshell.internal".to_string(), + port: 8080, + policy_name: "test_graphql_l7".to_string(), + binary_path: "/usr/bin/python3".to_string(), + ancestors: Vec::new(), + cmdline_paths: Vec::new(), + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let request_info = crate::l7::L7RequestInfo { + action: req.action, + target: req.target, + query_params: req.query_params, + graphql: Some(info), + }; + + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .expect("tunnel engine should clone"); + let (allowed, reason) = + crate::l7::relay::evaluate_l7_request(&tunnel_engine, &ctx, &request_info) + .expect("evaluation should complete"); + + assert!(allowed, "expected query to be allowed, got {reason}"); + } +} diff --git a/crates/openshell-sandbox/src/l7/http.rs b/crates/openshell-sandbox/src/l7/http.rs new file mode 100644 index 000000000..66269f6ba --- /dev/null +++ b/crates/openshell-sandbox/src/l7/http.rs @@ -0,0 +1,199 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared HTTP/1.1 request helpers for L7 protocols carried over HTTP. + +use crate::l7::provider::{BodyLength, L7Request}; +use miette::{IntoDiagnostic, Result, miette}; +use tokio::io::{AsyncRead, AsyncReadExt}; + +const READ_BUF_SIZE: usize = 8192; + +pub async fn read_body_for_inspection( + client: &mut C, + request: &mut L7Request, + max_body_bytes: usize, +) -> Result> { + let header_end = request + .raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(request.raw_header.len(), |p| p + 4); + let overflow = request.raw_header[header_end..].to_vec(); + + match request.body_length { + BodyLength::None => Ok(Vec::new()), + BodyLength::ContentLength(len) => { + let len = usize::try_from(len) + .map_err(|_| miette!("HTTP request body length exceeds platform limit"))?; + if len > max_body_bytes { + return Err(miette!( + "HTTP request body exceeds {max_body_bytes} byte inspection limit" + )); + } + if overflow.len() > len { + return Err(miette!( + "HTTP request contains more body bytes than Content-Length" + )); + } + let remaining = len - overflow.len(); + let mut body = overflow; + if remaining > 0 { + let start = body.len(); + body.resize(len, 0); + client + .read_exact(&mut body[start..]) + .await + .into_diagnostic()?; + } + request.raw_header.truncate(header_end); + request.raw_header.extend_from_slice(&body); + Ok(body) + } + BodyLength::Chunked => { + let body = read_chunked_body_for_inspection( + client, + request, + header_end, + overflow, + max_body_bytes, + ) + .await?; + normalize_chunked_request_to_content_length(request, header_end, &body)?; + Ok(body) + } + } +} + +fn normalize_chunked_request_to_content_length( + request: &mut L7Request, + header_end: usize, + body: &[u8], +) -> Result<()> { + let header_str = std::str::from_utf8(&request.raw_header[..header_end]) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let header_str = header_str + .strip_suffix("\r\n\r\n") + .ok_or_else(|| miette!("HTTP headers missing terminator"))?; + + let mut normalized = Vec::with_capacity(header_str.len() + body.len() + 32); + for (idx, line) in header_str.split("\r\n").enumerate() { + if idx > 0 { + let name = line + .split_once(':') + .map(|(name, _)| name.trim().to_ascii_lowercase()); + if matches!( + name.as_deref(), + Some("transfer-encoding" | "content-length" | "trailer") + ) { + continue; + } + } + normalized.extend_from_slice(line.as_bytes()); + normalized.extend_from_slice(b"\r\n"); + } + normalized.extend_from_slice(format!("Content-Length: {}\r\n\r\n", body.len()).as_bytes()); + normalized.extend_from_slice(body); + + request.raw_header = normalized; + request.body_length = BodyLength::ContentLength(body.len() as u64); + Ok(()) +} + +async fn read_chunked_body_for_inspection( + client: &mut C, + request: &mut L7Request, + header_end: usize, + overflow: Vec, + max_body_bytes: usize, +) -> Result> { + let mut raw = overflow; + let mut decoded = Vec::new(); + let mut pos = 0usize; + + loop { + let size_line_end = loop { + if let Some(end) = find_crlf(&raw, pos) { + break end; + } + read_more(client, &mut raw, max_body_bytes).await?; + }; + let size_line = std::str::from_utf8(&raw[pos..size_line_end]) + .into_diagnostic() + .map_err(|_| miette!("Invalid UTF-8 in HTTP chunk-size line"))?; + let size_token = size_line + .split(';') + .next() + .map(str::trim) + .unwrap_or_default(); + let chunk_size = usize::from_str_radix(size_token, 16) + .into_diagnostic() + .map_err(|_| miette!("Invalid HTTP chunk size token: {size_token:?}"))?; + pos = size_line_end + 2; + + if decoded.len().saturating_add(chunk_size) > max_body_bytes { + return Err(miette!( + "HTTP request body exceeds {max_body_bytes} byte inspection limit" + )); + } + + if chunk_size == 0 { + loop { + let trailer_end = loop { + if let Some(end) = find_crlf(&raw, pos) { + break end; + } + read_more(client, &mut raw, max_body_bytes).await?; + }; + let trailer_line = &raw[pos..trailer_end]; + pos = trailer_end + 2; + if trailer_line.is_empty() { + request.raw_header.truncate(header_end); + request.raw_header.extend_from_slice(&raw[..pos]); + return Ok(decoded); + } + } + } + + let chunk_end = pos + .checked_add(chunk_size) + .ok_or_else(|| miette!("HTTP chunk size overflow"))?; + let chunk_with_crlf_end = chunk_end + .checked_add(2) + .ok_or_else(|| miette!("HTTP chunk size overflow"))?; + while raw.len() < chunk_with_crlf_end { + read_more(client, &mut raw, max_body_bytes).await?; + } + decoded.extend_from_slice(&raw[pos..chunk_end]); + if raw.get(chunk_end..chunk_with_crlf_end) != Some(&b"\r\n"[..]) { + return Err(miette!("HTTP chunk payload missing terminating CRLF")); + } + pos = chunk_with_crlf_end; + } +} + +async fn read_more( + client: &mut C, + raw: &mut Vec, + max_body_bytes: usize, +) -> Result<()> { + if raw.len() > max_body_bytes.saturating_mul(2).max(max_body_bytes) { + return Err(miette!( + "HTTP chunked request body exceeds inspection framing limit" + )); + } + let mut buf = [0u8; READ_BUF_SIZE]; + let n = client.read(&mut buf).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("HTTP chunked body ended before terminator")); + } + raw.extend_from_slice(&buf[..n]); + Ok(()) +} + +fn find_crlf(buf: &[u8], start: usize) -> Option { + buf.get(start..)? + .windows(2) + .position(|w| w == b"\r\n") + .map(|p| start + p) +} diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index ce6747c6d..17d2e4c5e 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -9,6 +9,7 @@ //! evaluated against OPA policy, and either forwarded or denied. pub mod graphql; +pub(crate) mod http; pub mod inference; pub mod path; pub mod provider; From 1cf6f026e62d42710d2455de93308cec7d737fbf Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 09:46:04 -0700 Subject: [PATCH 04/12] feat(l7): enforce JSON-RPC method rules Add the JSON-RPC HTTP parser and relay path, extract request methods, and pass JSON-RPC metadata into L7 policy evaluation. Wire rpc_method through proto and policy conversion, add Rego matching for JSON-RPC methods, and inspect forward-proxy JSON-RPC bodies before relaying upstream. Signed-off-by: Kris Hicks --- crates/openshell-cli/src/policy_update.rs | 2 + crates/openshell-policy/src/lib.rs | 6 +- crates/openshell-policy/src/merge.rs | 2 + crates/openshell-providers/src/profiles.rs | 2 + .../data/sandbox-policy.rego | 752 ++ crates/openshell-sandbox/src/l7/graphql.rs | 1 + crates/openshell-sandbox/src/l7/jsonrpc.rs | 99 + crates/openshell-sandbox/src/l7/mod.rs | 3 + crates/openshell-sandbox/src/l7/relay.rs | 280 +- crates/openshell-sandbox/src/l7/websocket.rs | 1943 +++++ .../src/mechanistic_mapper.rs | 1 + crates/openshell-sandbox/src/opa.rs | 5248 +++++++++++ crates/openshell-sandbox/src/policy_local.rs | 2027 +++++ crates/openshell-sandbox/src/proxy.rs | 7661 +++++++++++++++++ proto/sandbox.proto | 4 + 15 files changed, 18010 insertions(+), 21 deletions(-) create mode 100644 crates/openshell-sandbox/data/sandbox-policy.rego create mode 100644 crates/openshell-sandbox/src/l7/jsonrpc.rs create mode 100644 crates/openshell-sandbox/src/l7/websocket.rs create mode 100644 crates/openshell-sandbox/src/opa.rs create mode 100644 crates/openshell-sandbox/src/policy_local.rs create mode 100644 crates/openshell-sandbox/src/proxy.rs diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 57656b878..03695d48e 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -205,6 +205,7 @@ fn group_allow_rules(specs: &[String]) -> Result Result SandboxPolicy { operation_type: r.allow.operation_type, operation_name: r.allow.operation_name, fields: r.allow.fields, + rpc_method: r.allow.rpc_method, query: r .allow .query @@ -328,6 +329,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { operation_type: d.operation_type, operation_name: d.operation_name, fields: d.fields, + rpc_method: d.rpc_method, query: d .query .into_iter() @@ -469,6 +471,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { operation_type: a.operation_type, operation_name: a.operation_name, fields: a.fields, + rpc_method: a.rpc_method, query: a .query .into_iter() @@ -483,7 +486,6 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { (key, yaml_matcher) }) .collect(), - rpc_method: String::new(), params: BTreeMap::new(), }, } @@ -500,6 +502,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { operation_type: d.operation_type.clone(), operation_name: d.operation_name.clone(), fields: d.fields.clone(), + rpc_method: d.rpc_method.clone(), query: d .query .iter() @@ -514,7 +517,6 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { (key.clone(), yaml_matcher) }) .collect(), - rpc_method: String::new(), params: BTreeMap::new(), }) .collect(), diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index f191cd272..73c40316e 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -747,6 +747,7 @@ fn expand_access_preset(protocol: &str, access: &str) -> Option> { operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + rpc_method: String::new(), }), }) .collect(), @@ -961,6 +962,7 @@ mod tests { operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + rpc_method: String::new(), }), } } diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index d2a35ca80..a6d282256 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -816,6 +816,7 @@ fn allow_to_proto(allow: &L7AllowProfile) -> L7Allow { operation_type: allow.operation_type.clone(), operation_name: allow.operation_name.clone(), fields: allow.fields.clone(), + rpc_method: String::new(), } } @@ -848,6 +849,7 @@ fn deny_rule_to_proto(rule: &L7DenyRuleProfile) -> L7DenyRule { operation_type: rule.operation_type.clone(), operation_name: rule.operation_name.clone(), fields: rule.fields.clone(), + rpc_method: String::new(), } } diff --git a/crates/openshell-sandbox/data/sandbox-policy.rego b/crates/openshell-sandbox/data/sandbox-policy.rego new file mode 100644 index 000000000..c25b42af0 --- /dev/null +++ b/crates/openshell-sandbox/data/sandbox-policy.rego @@ -0,0 +1,752 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +package openshell.sandbox + +default allow_network = false + +# --- Static policy data passthrough (queried at sandbox startup) --- + +filesystem_policy := data.filesystem_policy + +landlock_policy := data.landlock + +process_policy := data.process + +# --- Network access decision (queried per-CONNECT request) --- + +allow_network if { + network_policy_for_request +} + +# --- Deny reasons (specific diagnostics for debugging policy denials) --- + +deny_reason := "missing input.network" if { + not input.network +} + +deny_reason := "missing input.exec" if { + input.network + not input.exec +} + +deny_reason := reason if { + input.network + input.exec + not network_policy_for_request + not endpoint_policy_for_request + count(data.network_policies) > 0 + reason := sprintf("endpoint %s:%d is not allowed by any policy", [input.network.host, input.network.port]) +} + +deny_reason := reason if { + input.network + input.exec + not network_policy_for_request + endpoint_policy_for_request + ancestors_str := concat(" -> ", input.exec.ancestors) + cmdline_str := concat(", ", input.exec.cmdline_paths) + binary_misses := [r | + some name + policy := data.network_policies[name] + endpoint_allowed(policy, input.network) + not binary_allowed(policy, input.exec) + r := sprintf("binary '%s' not allowed in policy '%s' (ancestors: [%s], cmdline: [%s]). SYMLINK HINT: the binary path is the kernel-resolved target from /proc//exe, not the symlink. If your policy specifies a symlink (e.g., /usr/bin/python3) but the actual binary is /usr/bin/python3.11, either: (1) use the canonical path in your policy (run 'readlink -f /usr/bin/python3' inside the sandbox), or (2) ensure symlink resolution is working (check sandbox logs for 'Cannot access container filesystem')", [input.exec.path, name, ancestors_str, cmdline_str]) + ] + count(binary_misses) > 0 + reason := concat("; ", binary_misses) +} + +deny_reason := "network connections not allowed by policy" if { + input.network + input.exec + not network_policy_for_request + count(data.network_policies) == 0 +} + +# --- Matched policy name (for audit logging) --- +# +# Collects all matching policy names into a set, then deterministically picks +# the lexicographically smallest. This avoids a "complete rule conflict" when +# multiple policies cover the same endpoint (e.g. after draft approval adds an +# overlapping rule). + +_matching_policy_names contains name if { + some name + policy := data.network_policies[name] + endpoint_allowed(policy, input.network) + binary_allowed(policy, input.exec) +} + +matched_network_policy := min(_matching_policy_names) if { + count(_matching_policy_names) > 0 +} + +# --- Core matching logic --- + +# True when at least one network policy matches the request (endpoint + binary). +# Expressed as a boolean so that multiple matching policies don't cause a +# "complete rule conflict". +network_policy_for_request if { + some name + data.network_policies[name] + endpoint_allowed(data.network_policies[name], input.network) + binary_allowed(data.network_policies[name], input.exec) +} + +endpoint_policy_for_request if { + some name + data.network_policies[name] + endpoint_allowed(data.network_policies[name], input.network) +} + +# Endpoint matching: exact host (case-insensitive) + port in ports list. +endpoint_allowed(policy, network) if { + some endpoint + endpoint := policy.endpoints[_] + not contains(endpoint.host, "*") + lower(endpoint.host) == lower(network.host) + endpoint.ports[_] == network.port +} + +# Endpoint matching: glob host pattern + port in ports list. +# Uses "." as delimiter so "*" matches a single DNS label and "**" matches +# across label boundaries — consistent with TLS certificate wildcard semantics. +endpoint_allowed(policy, network) if { + some endpoint + endpoint := policy.endpoints[_] + contains(endpoint.host, "*") + glob.match(lower(endpoint.host), ["."], lower(network.host)) + endpoint.ports[_] == network.port +} + +# Endpoint matching: hostless with allowed_ips — match any host on port. +# When an endpoint has allowed_ips but no host, it matches any hostname on the +# given port. The actual IP validation happens in Rust post-DNS-resolution. +endpoint_allowed(policy, network) if { + some endpoint + endpoint := policy.endpoints[_] + object.get(endpoint, "host", "") == "" + count(object.get(endpoint, "allowed_ips", [])) > 0 + endpoint.ports[_] == network.port +} + +# Binary matching: exact path. +# SHA256 integrity is enforced in Rust via trust-on-first-use (TOFU) cache, +# not in Rego. The proxy computes and caches binary hashes at runtime. +binary_allowed(policy, exec) if { + some b + b := policy.binaries[_] + not contains(b.path, "*") + b.path == exec.path +} + +# Binary matching: ancestor exact path (e.g., claude spawns node). +binary_allowed(policy, exec) if { + some b + b := policy.binaries[_] + not contains(b.path, "*") + ancestor := exec.ancestors[_] + b.path == ancestor +} + +# Binary matching: glob pattern against exe path or any ancestor. +# NOTE: cmdline_paths are intentionally excluded — argv[0] is trivially +# spoofable via execve and must not be used as a grant-access signal. +binary_allowed(policy, exec) if { + some b in policy.binaries + contains(b.path, "*") + all_paths := array.concat([exec.path], exec.ancestors) + some p in all_paths + glob.match(b.path, ["/"], p) +} + +user_declared_binary_allowed(policy, exec) if { + some b + b := policy.binaries[_] + not object.get(b, "advisor_proposed", false) + not contains(b.path, "*") + b.path == exec.path +} + +user_declared_binary_allowed(policy, exec) if { + some b + b := policy.binaries[_] + not object.get(b, "advisor_proposed", false) + not contains(b.path, "*") + ancestor := exec.ancestors[_] + b.path == ancestor +} + +user_declared_binary_allowed(policy, exec) if { + some b in policy.binaries + not object.get(b, "advisor_proposed", false) + contains(b.path, "*") + all_paths := array.concat([exec.path], exec.ancestors) + some p in all_paths + glob.match(b.path, ["/"], p) +} + +# --- Network action (allow / deny) --- +# +# These rules are mutually exclusive by construction: +# - "allow" requires `network_policy_for_request` (binary+endpoint matched) +# - default is "deny" when no policy matches. + +default network_action := "deny" + +# Explicitly allowed: endpoint + binary match in a network policy → allow. +network_action := "allow" if { + network_policy_for_request +} + +# =========================================================================== +# L7 request evaluation (queried per-request within a tunnel) +# =========================================================================== + +default allow_request = false + +# Per-policy helper: true when this single policy has at least one endpoint +# matching the L4 request whose L7 rules also permit the specific request. +# Isolating the endpoint iteration inside a function avoids the regorus +# "duplicated definition of local variable" error that occurs when the +# outer `some name` iterates over multiple policies that share a host:port. +_policy_allows_l7(policy) if { + some ep + ep := policy.endpoints[_] + endpoint_matches_l7_request(ep, input.network, input.request) + request_allowed_for_endpoint(input.request, ep) +} + +# L7 request allowed if any matching L4 policy also allows the L7 request +# AND no deny rule blocks it. Deny rules take precedence over allow rules. +allow_request if { + some name + policy := data.network_policies[name] + endpoint_allowed(policy, input.network) + binary_allowed(policy, input.exec) + _policy_allows_l7(policy) + not deny_request +} + +# --- L7 deny rules --- +# +# Deny rules are evaluated after allow rules and take precedence. +# If a request matches any deny rule on any matching endpoint, it is blocked +# even if it would otherwise be allowed. + +default deny_request = false + +# Per-policy helper: true when this policy has at least one endpoint matching +# the L4 request whose deny_rules also match the specific L7 request. +_policy_denies_l7(policy) if { + some ep + ep := policy.endpoints[_] + endpoint_matches_l7_request(ep, input.network, input.request) + request_denied_for_endpoint(input.request, ep) +} + +deny_request if { + some name + policy := data.network_policies[name] + endpoint_allowed(policy, input.network) + binary_allowed(policy, input.exec) + _policy_denies_l7(policy) +} + +# --- L7 deny rule matching: REST method + path + query --- + +request_denied_for_endpoint(request, endpoint) if { + some deny_rule + deny_rule := endpoint.deny_rules[_] + deny_rule.method + method_matches(request.method, deny_rule.method) + path_matches(request.path, deny_rule.path) + deny_query_params_match(request, deny_rule) +} + +# --- L7 deny rule matching: SQL command --- + +request_denied_for_endpoint(request, endpoint) if { + some deny_rule + deny_rule := endpoint.deny_rules[_] + deny_rule.command + command_matches(request.command, deny_rule.command) +} + +# --- L7 deny rule matching: GraphQL operation --- + +request_denied_for_endpoint(request, endpoint) if { + graphql_request_has_operations(request) + some deny_rule + deny_rule := endpoint.deny_rules[_] + deny_rule.operation_type + op := request.graphql.operations[_] + graphql_deny_rule_matches_operation(op, deny_rule, endpoint) +} + +# A GraphQL endpoint path is authoritative once it matches. If the parsed +# GraphQL request is malformed, hash-only without a trusted registry entry, or +# contains an operation outside the GraphQL allow rules, a broader REST rule on +# the same host:port must not allow it through. +request_denied_for_endpoint(request, endpoint) if { + endpoint.protocol == "graphql" + is_object(request.graphql) + not graphql_request_allowed(request, endpoint) +} + +# The same authority applies when a WebSocket endpoint opts into GraphQL +# operation policy. Once the relay classifies a client text message as a +# GraphQL-over-WebSocket operation, generic WEBSOCKET_TEXT rules must not bypass +# operation_type / operation_name / fields policy. +request_denied_for_endpoint(request, endpoint) if { + endpoint.protocol == "websocket" + is_object(request.graphql) + not graphql_request_allowed(request, endpoint) +} + +# Deny query matching: fail-closed semantics. +# If no query rules on the deny rule, match unconditionally (any query params). +# If query rules present, trigger the deny if ANY value for a configured key +# matches the matcher. This is the inverse of allow-side semantics where ALL +# values must match. For deny logic, a single matching value is enough to block. +deny_query_params_match(request, deny_rule) if { + deny_query_rules := object.get(deny_rule, "query", {}) + count(deny_query_rules) == 0 +} + +deny_query_params_match(request, deny_rule) if { + deny_query_rules := object.get(deny_rule, "query", {}) + count(deny_query_rules) > 0 + not deny_query_key_missing(request, deny_query_rules) + not deny_query_value_mismatch_all(request, deny_query_rules) +} + +# A configured deny query key is missing from the request entirely. +# Missing key means the deny rule doesn't apply (fail-open on absence). +deny_query_key_missing(request, query_rules) if { + some key + query_rules[key] + request_query := object.get(request, "query_params", {}) + values := object.get(request_query, key, null) + values == null +} + +# ALL values for a configured key fail to match the matcher. +# If even one value matches, deny fires. This rule checks the opposite: +# true when NO value matches (i.e., every value is a mismatch). +deny_query_value_mismatch_all(request, query_rules) if { + some key + matcher := query_rules[key] + request_query := object.get(request, "query_params", {}) + values := object.get(request_query, key, []) + count(values) > 0 + not deny_any_value_matches(values, matcher) +} + +# True if at least one value in the list matches the matcher. +deny_any_value_matches(values, matcher) if { + some i + query_value_matches(values[i], matcher) +} + +# --- L7 deny reason --- + +request_deny_reason := reason if { + input.request + graphql_request_error(input.request) + reason := sprintf("GraphQL request rejected: %s", [input.request.graphql.error]) +} + +request_deny_reason := reason if { + input.request + not graphql_request_error(input.request) + graphql_request_has_unregistered_persisted_query(input.request, matched_endpoint_config) + reason := "GraphQL persisted query is not registered" +} + +request_deny_reason := reason if { + input.request + deny_request + graphql_request_has_operations(input.request) + not graphql_request_has_unregistered_persisted_query(input.request, matched_endpoint_config) + reason := "GraphQL operation blocked by endpoint policy" +} + +request_deny_reason := reason if { + input.request + not deny_request + not allow_request + graphql_request_has_operations(input.request) + not graphql_request_has_unregistered_persisted_query(input.request, matched_endpoint_config) + reason := "GraphQL operation not permitted by policy" +} + +request_deny_reason := reason if { + input.request + deny_request + not graphql_request_has_operations(input.request) + reason := sprintf("%s %s blocked by deny rule", [input.request.method, input.request.path]) +} + +request_deny_reason := reason if { + input.request + not deny_request + not allow_request + not graphql_request_has_operations(input.request) + reason := sprintf("%s %s not permitted by policy", [input.request.method, input.request.path]) +} + +# --- L7 rule matching: REST method + path --- + +request_allowed_for_endpoint(request, endpoint) if { + some rule + rule := endpoint.rules[_] + rule.allow.method + method_matches(request.method, rule.allow.method) + path_matches(request.path, rule.allow.path) + query_params_match(request, rule) +} + +# --- L7 rule matching: SQL command --- + +request_allowed_for_endpoint(request, endpoint) if { + some rule + rule := endpoint.rules[_] + rule.allow.command + command_matches(request.command, rule.allow.command) +} + +# --- L7 rule matching: JSON-RPC method --- + +request_allowed_for_endpoint(request, endpoint) if { + some rule + rule := endpoint.rules[_] + rule.allow.rpc_method + jsonrpc := object.get(request, "jsonrpc", {}) + method := object.get(jsonrpc, "method", null) + method != null + glob.match(rule.allow.rpc_method, [], method) +} + +# --- L7 rule matching: GraphQL operation --- + +request_allowed_for_endpoint(request, endpoint) if { + graphql_request_allowed(request, endpoint) +} + +graphql_request_allowed(request, endpoint) if { + graphql_request_has_operations(request) + not graphql_request_error(request) + not graphql_request_has_unregistered_persisted_query(request, endpoint) + not graphql_request_has_unallowed_operation(request, endpoint) +} + +graphql_request_has_operations(request) if { + is_object(request.graphql) + operations := object.get(request.graphql, "operations", []) + count(operations) > 0 +} + +graphql_request_error(request) if { + is_object(request.graphql) + error := object.get(request.graphql, "error", "") + error != "" +} + +graphql_request_has_unallowed_operation(request, endpoint) if { + op := request.graphql.operations[_] + not graphql_operation_allowed(op, endpoint) +} + +graphql_operation_allowed(op, endpoint) if { + rule := endpoint.rules[_] + rule.allow.operation_type + graphql_allow_rule_matches_operation(op, rule.allow, endpoint) +} + +graphql_request_has_unregistered_persisted_query(request, endpoint) if { + op := request.graphql.operations[_] + graphql_operation_needs_registry(op) + not graphql_registered_operation(op, endpoint) +} + +graphql_operation_needs_registry(op) if { + object.get(op, "persisted_query", false) == true + object.get(op, "operation_type", "") == "" +} + +graphql_registered_operation(op, endpoint) if { + object.get(endpoint, "persisted_queries", "deny") == "allow_registered" + id := graphql_operation_registry_key(op) + endpoint.graphql_persisted_queries[id] +} + +graphql_operation_registry_key(op) := key if { + key := object.get(op, "persisted_query_hash", "") + key != "" +} + +graphql_operation_registry_key(op) := key if { + object.get(op, "persisted_query_hash", "") == "" + key := object.get(op, "persisted_query_id", "") + key != "" +} + +graphql_effective_operation(op, endpoint) := registered if { + graphql_operation_needs_registry(op) + key := graphql_operation_registry_key(op) + registered := endpoint.graphql_persisted_queries[key] +} + +graphql_effective_operation(op, _) := op if { + not graphql_operation_needs_registry(op) +} + +graphql_allow_rule_matches_operation(op, rule, endpoint) if { + effective := graphql_effective_operation(op, endpoint) + graphql_operation_type_matches(effective, rule) + graphql_operation_name_matches(effective, rule) + graphql_allow_fields_match(effective, rule) +} + +graphql_deny_rule_matches_operation(op, rule, endpoint) if { + effective := graphql_effective_operation(op, endpoint) + graphql_operation_type_matches(effective, rule) + graphql_operation_name_matches(effective, rule) + graphql_deny_fields_match(effective, rule) +} + +graphql_operation_type_matches(_, rule) if { + object.get(rule, "operation_type", "") == "*" +} + +graphql_operation_type_matches(op, rule) if { + expected := object.get(rule, "operation_type", "") + expected != "" + expected != "*" + lower(object.get(op, "operation_type", "")) == lower(expected) +} + +graphql_operation_name_matches(_, rule) if { + object.get(rule, "operation_name", "") == "" +} + +graphql_operation_name_matches(op, rule) if { + pattern := object.get(rule, "operation_name", "") + pattern != "" + name := object.get(op, "operation_name", "") + glob.match(pattern, [], name) +} + +# Allow-side field constraints are intentionally all-selected-fields semantics: +# if a rule declares fields, every root field selected by the operation must +# match one of the rule patterns. This prevents mixed-operation requests from +# allowing an unlisted field because one safe field also appeared. +graphql_allow_fields_match(_, rule) if { + count(object.get(rule, "fields", [])) == 0 +} + +graphql_allow_fields_match(op, rule) if { + count(object.get(rule, "fields", [])) > 0 + count(object.get(op, "fields", [])) > 0 + not graphql_operation_has_unmatched_field(op, rule) +} + +graphql_operation_has_unmatched_field(op, rule) if { + field := object.get(op, "fields", [])[_] + not graphql_field_matches_any(field, object.get(rule, "fields", [])) +} + +graphql_deny_fields_match(_, rule) if { + count(object.get(rule, "fields", [])) == 0 +} + +graphql_deny_fields_match(op, rule) if { + field := object.get(op, "fields", [])[_] + graphql_field_matches_any(field, object.get(rule, "fields", [])) +} + +graphql_field_matches_any(field, patterns) if { + pattern := patterns[_] + glob.match(pattern, [], field) +} + +# Wildcard "*" matches any method; otherwise case-insensitive exact match. +# RFC 9110 §9.3.2: HEAD is semantically identical to GET except no response body. +method_matches(_, "*") if true + +method_matches(actual, expected) if { + expected != "*" + upper(actual) == upper(expected) +} + +method_matches(actual, expected) if { + upper(actual) == "HEAD" + upper(expected) == "GET" +} + +# Path matching: "**" matches everything; otherwise glob.match with "/" delimiter. +# +# INVARIANT: `input.request.path` is canonicalized by the sandbox before +# policy evaluation — percent-decoded, dot-segments resolved, doubled +# slashes collapsed, `;params` stripped, `%2F` rejected (unless an +# endpoint opts in). Patterns here must therefore match canonical paths; +# do not attempt defensive matching against `..` or `%2e%2e` — those +# inputs are rejected at the L7 parser boundary before this rule runs. +path_matches(_, "**") if true + +path_matches(actual, pattern) if { + pattern != "**" + glob.match(pattern, ["/"], actual) +} + +# Query matching: +# - If no query rules are configured, allow any query params. +# - For configured keys, all request values for that key must match. +# - Matcher shape supports either `glob` or `any`. +query_params_match(request, rule) if { + query_rules := object.get(rule.allow, "query", {}) + not query_mismatch(request, query_rules) +} + +query_mismatch(request, query_rules) if { + some key + matcher := query_rules[key] + not query_key_matches(request, key, matcher) +} + +query_key_matches(request, key, matcher) if { + request_query := object.get(request, "query_params", {}) + values := object.get(request_query, key, null) + values != null + count(values) > 0 + not query_value_mismatch(values, matcher) +} + +query_value_mismatch(values, matcher) if { + some i + value := values[i] + not query_value_matches(value, matcher) +} + +query_value_matches(value, matcher) if { + is_string(matcher) + glob.match(matcher, [], value) +} + +query_value_matches(value, matcher) if { + is_object(matcher) + glob_pattern := object.get(matcher, "glob", "") + glob_pattern != "" + glob.match(glob_pattern, [], value) +} + +query_value_matches(value, matcher) if { + is_object(matcher) + any_patterns := object.get(matcher, "any", []) + count(any_patterns) > 0 + some i + glob.match(any_patterns[i], [], value) +} + +# SQL command matching: "*" matches any; otherwise case-insensitive. +command_matches(_, "*") if true + +command_matches(actual, expected) if { + expected != "*" + upper(actual) == upper(expected) +} + +# --- Matched endpoint config (for L7 and allowed_ips extraction) --- +# Returns the raw endpoint object for the matched policy + host:port. +# Used by Rust to extract L7 config (protocol, tls, enforcement, +# allow_encoded_slash) and/or allowed_ips for SSRF allowlist validation. + +# Per-policy helper: returns matching endpoint configs for a single policy. +_policy_endpoint_configs(policy) := [ep | + some ep + ep := policy.endpoints[_] + endpoint_matches_request(ep, input.network) + endpoint_has_extended_config(ep) +] + +# Collect matching endpoint configs across all policies. Iterates over +# _matching_policy_names (a set, safe from regorus variable collisions) +# then collects per-policy configs via the helper function. +_matching_endpoint_configs := [cfg | + some pname + _matching_policy_names[pname] + cfgs := _policy_endpoint_configs(data.network_policies[pname]) + cfg := cfgs[_] +] + +matched_endpoint_config := _matching_endpoint_configs[0] if { + count(_matching_endpoint_configs) > 0 +} + +_policy_has_exact_declared_endpoint(policy) if { + some ep + ep := policy.endpoints[_] + not object.get(ep, "advisor_proposed", false) + not contains(ep.host, "*") + lower(ep.host) == lower(input.network.host) + ep.ports[_] == input.network.port +} + +exact_declared_endpoint_host if { + some pname + policy := data.network_policies[pname] + user_declared_binary_allowed(policy, input.exec) + _policy_has_exact_declared_endpoint(policy) +} + +# Hosted endpoint: exact host match + port in ports list. +endpoint_matches_request(ep, network) if { + not contains(ep.host, "*") + lower(ep.host) == lower(network.host) + ep.ports[_] == network.port +} + +# Hosted endpoint: glob host match + port in ports list. +endpoint_matches_request(ep, network) if { + contains(ep.host, "*") + glob.match(lower(ep.host), ["."], lower(network.host)) + ep.ports[_] == network.port +} + +# Hostless endpoint with allowed_ips: match on port only. +endpoint_matches_request(ep, network) if { + object.get(ep, "host", "") == "" + count(object.get(ep, "allowed_ips", [])) > 0 + ep.ports[_] == network.port +} + +endpoint_matches_l7_request(ep, network, request) if { + endpoint_matches_request(ep, network) + endpoint_path_matches_request(ep, request) +} + +endpoint_path_matches_request(ep, request) if { + object.get(ep, "path", "") == "" +} + +endpoint_path_matches_request(ep, request) if { + path := object.get(ep, "path", "") + path != "" + path_matches(request.path, path) +} + +# An endpoint has extended config if it specifies L7 protocol, allowed_ips, +# or an explicit tls mode (e.g. tls: skip). +endpoint_has_extended_config(ep) if { + ep.protocol +} + +endpoint_has_extended_config(ep) if { + count(object.get(ep, "allowed_ips", [])) > 0 +} + +endpoint_has_extended_config(ep) if { + ep.tls +} diff --git a/crates/openshell-sandbox/src/l7/graphql.rs b/crates/openshell-sandbox/src/l7/graphql.rs index f548f2a10..77ec3b6fd 100644 --- a/crates/openshell-sandbox/src/l7/graphql.rs +++ b/crates/openshell-sandbox/src/l7/graphql.rs @@ -622,6 +622,7 @@ network_policies: target: req.target, query_params: req.query_params, graphql: Some(info), + jsonrpc: None, }; let tunnel_engine = engine diff --git a/crates/openshell-sandbox/src/l7/jsonrpc.rs b/crates/openshell-sandbox/src/l7/jsonrpc.rs new file mode 100644 index 000000000..977c8046f --- /dev/null +++ b/crates/openshell-sandbox/src/l7/jsonrpc.rs @@ -0,0 +1,99 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! JSON-RPC 2.0 over HTTP L7 inspection. + +use miette::Result; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::l7::provider::{L7Provider, L7Request}; + +pub struct JsonRpcHttpRequest { + pub request: L7Request, + pub info: JsonRpcRequestInfo, +} + +pub(crate) async fn parse_jsonrpc_http_request( + client: &mut C, + max_body_bytes: usize, + canonicalize_options: crate::l7::path::CanonicalizeOptions, +) -> Result> { + let provider = crate::l7::rest::RestProvider::with_options(canonicalize_options); + let Some(mut request) = provider.parse_request(client).await? else { + return Ok(None); + }; + let body = + crate::l7::http::read_body_for_inspection(client, &mut request, max_body_bytes).await?; + let info = parse_jsonrpc_body(&body); + Ok(Some(JsonRpcHttpRequest { request, info })) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct JsonRpcRequestInfo { + pub method: Option, + pub error: Option, +} + +/// Returns true if the parsed request's method matches the given `rpc_method` rule pattern. +/// +/// An empty `rpc_method` pattern matches any method. +pub fn rpc_method_rule_matches(info: &JsonRpcRequestInfo, rpc_method: &str) -> bool { + if rpc_method.is_empty() { + return true; + } + info.method.as_deref() == Some(rpc_method) +} + +/// Parse a JSON-RPC 2.0 request body and extract the `method` field. +/// +/// Returns an info struct with `method` set on success, or `error` set if the +/// body is not valid JSON-RPC 2.0. +pub fn parse_jsonrpc_body(body: &[u8]) -> JsonRpcRequestInfo { + let Ok(value) = serde_json::from_slice::(body) else { + return JsonRpcRequestInfo { + method: None, + error: Some("invalid JSON".to_string()), + }; + }; + let Some(method) = value.get("method").and_then(|m| m.as_str()) else { + return JsonRpcRequestInfo { + method: None, + error: Some("missing or non-string 'method' field".to_string()), + }; + }; + JsonRpcRequestInfo { + method: Some(method.to_string()), + error: None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_method_from_request_body() { + let body = br#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#; + let info = parse_jsonrpc_body(body); + assert_eq!(info.method.as_deref(), Some("initialize")); + assert!(info.error.is_none()); + } + + #[test] + fn rpc_method_rule_empty_matches_any() { + let info = parse_jsonrpc_body(br#"{"jsonrpc":"2.0","id":1,"method":"tools/call"}"#); + assert!(rpc_method_rule_matches(&info, "")); + } + + #[test] + fn rpc_method_rule_matches_exact_method() { + let info = parse_jsonrpc_body(br#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#); + assert!(rpc_method_rule_matches(&info, "initialize")); + } + + #[test] + fn rpc_method_rule_does_not_match_different_method() { + let info = parse_jsonrpc_body(br#"{"jsonrpc":"2.0","id":1,"method":"tools/call"}"#); + assert!(!rpc_method_rule_matches(&info, "initialize")); + } +} diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index 17d2e4c5e..d83a4dfa5 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -11,6 +11,7 @@ pub mod graphql; pub(crate) mod http; pub mod inference; +pub mod jsonrpc; pub mod path; pub mod provider; pub mod relay; @@ -113,6 +114,8 @@ pub struct L7RequestInfo { pub query_params: std::collections::HashMap>, /// Parsed GraphQL operation metadata for GraphQL endpoints. pub graphql: Option, + /// Parsed JSON-RPC request metadata for JSON-RPC endpoints. + pub jsonrpc: Option, } /// Parse an L7 endpoint config from a regorus Value (returned by Rego query). diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 6baa2ab03..ec723df50 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -178,25 +178,7 @@ where .into_diagnostic()?; Ok(()) } - L7Protocol::JsonRpc => { - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - // JSON-RPC provider not yet implemented — fall through to passthrough - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .severity(SeverityId::Low) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .message("JSON-RPC L7 provider not yet implemented, falling back to passthrough") - .build(); - ocsf_emit!(event); - } - tokio::io::copy_bidirectional(client, upstream) - .await - .into_diagnostic()?; - Ok(()) - } + L7Protocol::JsonRpc => relay_jsonrpc(config, &engine, client, upstream, ctx).await, } } @@ -316,6 +298,7 @@ where target: redacted_target.clone(), query_params: req.query_params.clone(), graphql: graphql_info.clone(), + jsonrpc: None, }; let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); if config.protocol == L7Protocol::Websocket && !websocket_request { @@ -713,6 +696,7 @@ where target: redacted_target.clone(), query_params: req.query_params.clone(), graphql: None, + jsonrpc: None, }; let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); if config.protocol == L7Protocol::Websocket && !websocket_request { @@ -904,6 +888,162 @@ fn close_if_stale(guard: &PolicyGenerationGuard, ctx: &L7EvalContext) -> bool { true } +async fn relay_jsonrpc( + config: &L7EndpointConfig, + engine: &TunnelPolicyEngine, + client: &mut C, + upstream: &mut U, + ctx: &L7EvalContext, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + loop { + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let parsed = match crate::l7::jsonrpc::parse_jsonrpc_http_request( + client, + 64 * 1024, + crate::l7::path::CanonicalizeOptions { + allow_encoded_slash: config.allow_encoded_slash, + ..Default::default() + }, + ) + .await + { + Ok(Some(parsed)) => parsed, + Ok(None) => return Ok(()), + Err(e) => { + if is_benign_connection_error(&e) { + debug!( + host = %ctx.host, + port = ctx.port, + error = %e, + "JSON-RPC L7 connection closed" + ); + } else { + let detail = + parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); + emit_parse_rejection(ctx, &detail, "l7-jsonrpc"); + } + return Ok(()); + } + }; + + let req = parsed.request; + let jsonrpc_info = parsed.info; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let redacted_target = req.target.clone(); + + let request_info = L7RequestInfo { + action: req.action.clone(), + target: redacted_target.clone(), + query_params: req.query_params.clone(), + graphql: None, + jsonrpc: Some(jsonrpc_info.clone()), + }; + + let parse_error_reason = jsonrpc_info + .error + .as_deref() + .map(|e| format!("JSON-RPC request rejected: {e}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = if let Some(reason) = parse_error_reason { + (false, reason) + } else { + evaluate_l7_request(engine, ctx, &request_info)? + }; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let decision_str = match (allowed, config.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + + { + let (action_id, disposition_id, severity) = match decision_str { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + _ => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + }; + let rpc_method = jsonrpc_info.method.as_deref().unwrap_or("-"); + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .http_request(HttpRequest::new( + &request_info.action, + OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "l7-jsonrpc") + .message(format!( + "JSONRPC_L7_REQUEST {decision_str} {} {}:{}{} rpc_method={rpc_method} reason={}", + request_info.action, ctx.host, ctx.port, redacted_target, reason, + )) + .build(); + ocsf_emit!(event); + } + + if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + &req, + client, + upstream, + ctx.secret_resolver.as_deref(), + Some(engine.generation_guard()), + ) + .await?; + match outcome { + RelayOutcome::Reusable => {} + RelayOutcome::Consumed => { + debug!( + host = %ctx.host, + port = ctx.port, + "Upstream connection not reusable, closing JSON-RPC L7 relay" + ); + return Ok(()); + } + RelayOutcome::Upgraded { .. } => { + return Ok(()); + } + } + } else { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + } +} + async fn relay_graphql( config: &L7EndpointConfig, engine: &TunnelPolicyEngine, @@ -981,6 +1121,7 @@ where target: redacted_target.clone(), query_params: req.query_params.clone(), graphql: Some(graphql_info.clone()), + jsonrpc: None, }; // Malformed or ambiguous GraphQL requests, such as duplicated GET @@ -1178,6 +1319,10 @@ pub fn evaluate_l7_request( "path": request.target, "query_params": request.query_params.clone(), "graphql": request.graphql.clone(), + "jsonrpc": request.jsonrpc.as_ref().map(|j| serde_json::json!({ + "method": j.method, + "error": j.error, + })), } }); @@ -1811,6 +1956,7 @@ network_policies: target: "/ws".into(), query_params: std::collections::HashMap::new(), graphql: None, + jsonrpc: None, }; let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); @@ -2426,4 +2572,100 @@ network_policies: "stale passthrough request must not be forwarded upstream" ); } + + #[tokio::test] + async fn jsonrpc_relay_denies_method_not_in_allow_list() { + let data = r" +network_policies: + mcp_api: + name: mcp_api + endpoints: + - host: mcp.example.test + port: 8000 + path: /mcp + protocol: json-rpc + enforcement: enforce + rules: + - allow: + rpc_method: initialize + binaries: + - { path: /usr/bin/python3 } +"; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "mcp.example.test".into(), + port: 8000, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "mcp.example.test".into(), + port: 8000, + policy_name: "mcp_api".into(), + binary_path: "/usr/bin/python3".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = + br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"list_repos"}}"#; + let request = format!( + "POST /mcp HTTP/1.1\r\nHost: mcp.example.test:8000\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + body.len() + ); + app.write_all(request.as_bytes()).await.unwrap(); + app.write_all(body).await.unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(2), app.read(&mut response)) + .await + .expect("relay should respond without reaching upstream") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!( + response.contains("403"), + "tools/call not in allow list must be denied with 403, got: {response:?}" + ); + + let mut upstream_buf = [0u8; 128]; + let n = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_buf), + ) + .await + .unwrap_or(Ok(0)) + .unwrap_or(0); + assert_eq!(n, 0, "denied request must not be forwarded to upstream"); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should complete") + .unwrap() + .unwrap(); + } } diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs new file mode 100644 index 000000000..70c11f330 --- /dev/null +++ b/crates/openshell-sandbox/src/l7/websocket.rs @@ -0,0 +1,1943 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! WebSocket relay for opt-in credential placeholder rewriting and message policy. +//! +//! The relay parses only client-to-server frames. Server-to-client bytes stay +//! raw passthrough so inspection and rewriting cannot expose response payloads. + +use crate::l7::relay::{L7EvalContext, evaluate_l7_request}; +use crate::l7::{EnforcementMode, L7RequestInfo}; +use crate::opa::TunnelPolicyEngine; +use crate::secrets::SecretResolver; +use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; +use miette::{IntoDiagnostic, Result, miette}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, NetworkActivityBuilder, SeverityId, StatusId, + ocsf_emit, +}; +use std::collections::HashMap; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +const MAX_TEXT_MESSAGE_BYTES: usize = 1024 * 1024; +const MAX_RAW_FRAME_PAYLOAD_BYTES: u64 = 16 * 1024 * 1024; +const COPY_BUF_SIZE: usize = 8192; +const OPCODE_CONTINUATION: u8 = 0x0; +const OPCODE_TEXT: u8 = 0x1; +const OPCODE_BINARY: u8 = 0x2; +const OPCODE_CLOSE: u8 = 0x8; +const OPCODE_PING: u8 = 0x9; +const OPCODE_PONG: u8 = 0xA; + +#[derive(Debug)] +struct FrameHeader { + fin: bool, + rsv: u8, + opcode: u8, + masked: bool, + payload_len: u64, + mask_key: Option<[u8; 4]>, + raw_header: Vec, +} + +#[derive(Debug)] +enum FragmentState { + None, + Text { payload: Vec, compressed: bool }, + Binary, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(super) enum WebSocketCompression { + None, + PermessageDeflate, +} + +pub(super) struct InspectionOptions<'a> { + pub(super) engine: &'a TunnelPolicyEngine, + pub(super) ctx: &'a L7EvalContext, + pub(super) enforcement: EnforcementMode, + pub(super) target: String, + pub(super) query_params: HashMap>, + pub(super) graphql_policy: bool, +} + +pub(super) struct RelayOptions<'a> { + pub(super) policy_name: &'a str, + pub(super) resolver: Option<&'a SecretResolver>, + pub(super) inspector: Option>, + pub(super) compression: WebSocketCompression, +} + +/// Relay an upgraded WebSocket connection with optional client text inspection, +/// credential rewriting, and strict permessage-deflate handling. +pub(super) async fn relay_with_options( + client: &mut C, + upstream: &mut U, + overflow: Vec, + host: &str, + port: u16, + options: RelayOptions<'_>, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + let (mut client_read, mut client_write) = tokio::io::split(client); + let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream); + + if !overflow.is_empty() { + client_write.write_all(&overflow).await.into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + } + + let client_to_server = + relay_client_to_server(&mut client_read, &mut upstream_write, host, port, &options); + let server_to_client = async { + tokio::io::copy(&mut upstream_read, &mut client_write) + .await + .into_diagnostic()?; + client_write.flush().await.into_diagnostic()?; + Ok::<(), miette::Report>(()) + }; + + let result = tokio::select! { + result = client_to_server => result, + result = server_to_client => result, + }; + let _ = upstream_write.shutdown().await; + let _ = client_write.shutdown().await; + result +} + +async fn relay_client_to_server( + reader: &mut R, + writer: &mut W, + host: &str, + port: u16, + options: &RelayOptions<'_>, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut fragments = FragmentState::None; + let mut close_seen = false; + + loop { + let Some(frame) = read_frame_header(reader).await.inspect_err(|e| { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(e)); + })? + else { + writer.shutdown().await.into_diagnostic()?; + return Ok(()); + }; + + if close_seen { + let e = miette!("websocket frame received after close frame"); + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); + return Err(e); + } + + if let Err(e) = validate_frame_header(&frame, &fragments, options.compression) { + emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); + return Err(e); + } + + match frame.opcode { + OPCODE_TEXT => { + let payload = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + let compressed = frame.rsv == 0x40; + if frame.fin { + relay_text_payload( + writer, &frame, payload, false, compressed, host, port, options, + ) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } else { + fragments = FragmentState::Text { + payload, + compressed, + }; + } + } + OPCODE_CONTINUATION => match &mut fragments { + FragmentState::Text { + payload, + compressed, + } => { + let next = read_masked_payload(reader, &frame).await.inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if let Err(e) = append_text_fragment(payload, next) { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); + return Err(e); + } + if frame.fin { + let complete = std::mem::take(payload); + let was_compressed = *compressed; + fragments = FragmentState::None; + relay_text_payload( + writer, + &frame, + complete, + true, + was_compressed, + host, + port, + options, + ) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } + } + FragmentState::Binary => { + copy_raw_frame_payload(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if frame.fin { + fragments = FragmentState::None; + } + } + FragmentState::None => { + let e = + miette!("websocket continuation frame without active fragmented message"); + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(&e), + ); + return Err(e); + } + }, + OPCODE_BINARY => { + if !frame.fin { + fragments = FragmentState::Binary; + } + copy_raw_frame_payload(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + } + OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG => { + relay_control_frame(reader, writer, &frame) + .await + .inspect_err(|e| { + emit_protocol_failure( + host, + port, + options.policy_name, + protocol_failure_class(e), + ); + })?; + if frame.opcode == OPCODE_CLOSE { + close_seen = true; + } + } + _ => unreachable!("validated opcode"), + } + } +} + +async fn read_frame_header(reader: &mut R) -> Result> { + let first = match reader.read_u8().await { + Ok(byte) => byte, + Err(e) + if matches!( + e.kind(), + std::io::ErrorKind::UnexpectedEof + | std::io::ErrorKind::ConnectionReset + | std::io::ErrorKind::BrokenPipe + ) => + { + return Ok(None); + } + Err(e) => return Err(miette!("{e}")), + }; + let second = reader + .read_u8() + .await + .map_err(|e| miette!("malformed websocket frame header: {e}"))?; + + let mut raw_header = vec![first, second]; + let len_code = second & 0x7F; + let payload_len = match len_code { + 0..=125 => u64::from(len_code), + 126 => { + let mut bytes = [0u8; 2]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + raw_header.extend_from_slice(&bytes); + let len = u64::from(u16::from_be_bytes(bytes)); + if len < 126 { + return Err(miette!( + "websocket frame uses non-minimal 16-bit extended length" + )); + } + len + } + 127 => { + let mut bytes = [0u8; 8]; + reader + .read_exact(&mut bytes) + .await + .map_err(|e| miette!("malformed websocket extended length: {e}"))?; + if bytes[0] & 0x80 != 0 { + return Err(miette!("websocket frame uses non-canonical 64-bit length")); + } + raw_header.extend_from_slice(&bytes); + let len = u64::from_be_bytes(bytes); + if u16::try_from(len).is_ok() { + return Err(miette!( + "websocket frame uses non-minimal 64-bit extended length" + )); + } + len + } + _ => unreachable!("7-bit length code"), + }; + + let masked = second & 0x80 != 0; + let mask_key = if masked { + let mut key = [0u8; 4]; + reader + .read_exact(&mut key) + .await + .map_err(|e| miette!("malformed websocket mask key: {e}"))?; + raw_header.extend_from_slice(&key); + Some(key) + } else { + None + }; + + Ok(Some(FrameHeader { + fin: first & 0x80 != 0, + rsv: first & 0x70, + opcode: first & 0x0F, + masked, + payload_len, + mask_key, + raw_header, + })) +} + +fn validate_frame_header( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> Result<()> { + if !valid_rsv_bits(frame, fragments, compression) { + return Err(miette!( + "websocket frame has unsupported RSV bits or extension state" + )); + } + if !frame.masked { + return Err(miette!("websocket client frame is not masked")); + } + if !matches!( + frame.opcode, + OPCODE_CONTINUATION + | OPCODE_TEXT + | OPCODE_BINARY + | OPCODE_CLOSE + | OPCODE_PING + | OPCODE_PONG + ) { + return Err(miette!("websocket frame uses reserved opcode")); + } + if matches!(frame.opcode, OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG) { + if !frame.fin { + return Err(miette!("websocket control frame is fragmented")); + } + if frame.payload_len > 125 { + return Err(miette!("websocket control frame exceeds 125 bytes")); + } + } + if matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) + && !matches!(fragments, FragmentState::None) + { + return Err(miette!( + "websocket data frame started before previous fragmented message completed" + )); + } + if matches!(frame.opcode, OPCODE_CONTINUATION) && matches!(fragments, FragmentState::None) { + return Err(miette!( + "websocket continuation frame without active fragmented message" + )); + } + if (frame.opcode == OPCODE_BINARY + || (frame.opcode == OPCODE_CONTINUATION && matches!(fragments, FragmentState::Binary))) + && frame.payload_len > MAX_RAW_FRAME_PAYLOAD_BYTES + { + return Err(miette!( + "websocket binary frame exceeds {MAX_RAW_FRAME_PAYLOAD_BYTES} byte relay limit" + )); + } + Ok(()) +} + +fn valid_rsv_bits( + frame: &FrameHeader, + fragments: &FragmentState, + compression: WebSocketCompression, +) -> bool { + if frame.rsv == 0 { + return true; + } + if compression != WebSocketCompression::PermessageDeflate || frame.rsv != 0x40 { + return false; + } + matches!(fragments, FragmentState::None) && matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) +} + +async fn read_masked_payload( + reader: &mut R, + frame: &FrameHeader, +) -> Result> { + let payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket text frame is too large to buffer"))?; + if payload_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + let mut payload = vec![0u8; payload_len]; + reader + .read_exact(&mut payload) + .await + .map_err(|e| miette!("malformed websocket payload: {e}"))?; + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + Ok(payload) +} + +fn append_text_fragment(buffer: &mut Vec, next: Vec) -> Result<()> { + let new_len = buffer + .len() + .checked_add(next.len()) + .ok_or_else(|| miette!("websocket text message length overflow"))?; + if new_len > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + buffer.extend_from_slice(&next); + Ok(()) +} + +#[allow(clippy::too_many_arguments)] +async fn relay_text_payload( + writer: &mut W, + frame: &FrameHeader, + payload: Vec, + force_reframe: bool, + compressed: bool, + host: &str, + port: u16, + options: &RelayOptions<'_>, +) -> Result<()> { + let message_payload = if compressed { + decompress_permessage_deflate(&payload)? + } else { + payload + }; + let mut text = String::from_utf8(message_payload) + .map_err(|_| miette!("websocket text message is not valid UTF-8"))?; + let replacements = if let Some(resolver) = options.resolver { + resolver + .rewrite_websocket_text_placeholders(&mut text) + .map_err(|_| miette!("websocket credential placeholder resolution failed"))? + } else { + 0 + }; + + if let Some(inspector) = options.inspector.as_ref() { + inspect_websocket_text_message(host, port, options.policy_name, inspector, &text)?; + } + + if replacements == 0 && !force_reframe && !compressed { + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut payload = text.into_bytes(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + writer.write_all(&payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + return Ok(()); + } + + if replacements > 0 { + emit_rewrite_event(host, port, options.policy_name, replacements); + } + if compressed { + let compressed_payload = compress_permessage_deflate(text.as_bytes())?; + return write_masked_frame_with_rsv(writer, OPCODE_TEXT, 0x40, &compressed_payload).await; + } + write_masked_frame(writer, OPCODE_TEXT, text.as_bytes()).await +} + +fn inspect_websocket_text_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + if inspector.graphql_policy { + return inspect_graphql_websocket_message(host, port, policy_name, inspector, text); + } + + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + jsonrpc: None, + }; + let (allowed, reason) = evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)?; + let decision = match (allowed, inspector.enforcement) { + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + None, + ); + if !allowed && inspector.enforcement == EnforcementMode::Enforce { + return Err(miette!("websocket text message denied by policy")); + } + Ok(()) +} + +fn inspect_graphql_websocket_message( + host: &str, + port: u16, + policy_name: &str, + inspector: &InspectionOptions<'_>, + text: &str, +) -> Result<()> { + match classify_graphql_websocket_message(text) { + GraphqlWebSocketMessage::Control { message_type } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_CONTROL".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: None, + jsonrpc: None, + }; + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + "allow", + &format!("GraphQL WebSocket control message {message_type}"), + None, + ); + Ok(()) + } + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + let request_info = L7RequestInfo { + action: "WEBSOCKET_TEXT".to_string(), + target: inspector.target.clone(), + query_params: inspector.query_params.clone(), + graphql: Some(graphql.clone()), + jsonrpc: None, + }; + let parse_error_reason = graphql + .error + .as_deref() + .map(|error| format!("GraphQL WebSocket message rejected: {error}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = if let Some(reason) = parse_error_reason { + (false, reason) + } else { + evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)? + }; + let decision = match (allowed, inspector.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + let reason = format!("graphql_ws_type={message_type} {reason}"); + emit_websocket_l7_event( + host, + port, + policy_name, + &request_info, + decision, + &reason, + Some(&graphql), + ); + if (!allowed && inspector.enforcement == EnforcementMode::Enforce) || force_deny { + return Err(miette!("websocket GraphQL message denied by policy")); + } + Ok(()) + } + } +} + +#[derive(Debug)] +enum GraphqlWebSocketMessage { + Control { + message_type: String, + }, + Operation { + message_type: String, + graphql: crate::l7::graphql::GraphqlRequestInfo, + }, +} + +fn classify_graphql_websocket_message(text: &str) -> GraphqlWebSocketMessage { + let value = match serde_json::from_str::(text) { + Ok(value) => value, + Err(err) => { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error(format!( + "GraphQL WebSocket message is not valid JSON: {err}" + )), + }; + } + }; + let Some(obj) = value.as_object() else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message must be a JSON object"), + }; + }; + let Some(message_type) = obj.get("type").and_then(serde_json::Value::as_str) else { + return GraphqlWebSocketMessage::Operation { + message_type: "unknown".to_string(), + graphql: graphql_error("GraphQL WebSocket message missing string type"), + }; + }; + + match message_type { + "subscribe" | "start" => { + if obj + .get("id") + .and_then(serde_json::Value::as_str) + .is_none_or(str::is_empty) + { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing non-empty id", + ), + }; + } + let Some(payload) = obj.get("payload").filter(|value| value.is_object()) else { + return GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error( + "GraphQL WebSocket operation message missing object payload", + ), + }; + }; + GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: crate::l7::graphql::classify_json_envelope_value(payload), + } + } + "connection_init" | "connection_terminate" | "ping" | "pong" | "complete" | "stop" => { + GraphqlWebSocketMessage::Control { + message_type: message_type.to_string(), + } + } + _ => GraphqlWebSocketMessage::Operation { + message_type: message_type.to_string(), + graphql: graphql_error(format!( + "unsupported GraphQL WebSocket client message type {message_type:?}" + )), + }, + } +} + +fn graphql_error(message: impl Into) -> crate::l7::graphql::GraphqlRequestInfo { + crate::l7::graphql::GraphqlRequestInfo { + operations: Vec::new(), + error: Some(message.into()), + } +} + +async fn relay_control_frame( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let raw_payload_len = usize::try_from(frame.payload_len) + .map_err(|_| miette!("websocket control frame payload length overflow"))?; + let mut raw_payload = vec![0u8; raw_payload_len]; + reader + .read_exact(&mut raw_payload) + .await + .map_err(|e| miette!("malformed websocket control payload: {e}"))?; + + if frame.opcode == OPCODE_CLOSE { + let mut payload = raw_payload.clone(); + let mask_key = frame + .mask_key + .ok_or_else(|| miette!("websocket client frame is not masked"))?; + apply_mask(&mut payload, mask_key); + validate_close_payload(&payload)?; + } + + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + writer.write_all(&raw_payload).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn validate_close_payload(payload: &[u8]) -> Result<()> { + if payload.len() == 1 { + return Err(miette!( + "websocket close frame payload cannot be exactly one byte" + )); + } + if payload.len() < 2 { + return Ok(()); + } + + let code = u16::from_be_bytes([payload[0], payload[1]]); + if !valid_close_code(code) { + return Err(miette!("websocket close frame uses invalid close code")); + } + if std::str::from_utf8(&payload[2..]).is_err() { + return Err(miette!("websocket close frame reason is not valid UTF-8")); + } + Ok(()) +} + +fn valid_close_code(code: u16) -> bool { + (matches!(code, 1000..=1014) && !matches!(code, 1004..=1006)) || (3000..=4999).contains(&code) +} + +async fn copy_raw_frame_payload( + reader: &mut R, + writer: &mut W, + frame: &FrameHeader, +) -> Result<()> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + writer + .write_all(&frame.raw_header) + .await + .into_diagnostic()?; + let mut remaining = frame.payload_len; + let mut buf = [0u8; COPY_BUF_SIZE]; + while remaining > 0 { + let to_read = usize::try_from(remaining) + .unwrap_or(buf.len()) + .min(buf.len()); + let n = reader.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!("websocket payload ended before declared length")); + } + writer.write_all(&buf[..n]).await.into_diagnostic()?; + remaining -= n as u64; + } + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +async fn write_masked_frame( + writer: &mut W, + opcode: u8, + payload: &[u8], +) -> Result<()> { + write_masked_frame_with_rsv(writer, opcode, 0, payload).await +} + +async fn write_masked_frame_with_rsv( + writer: &mut W, + opcode: u8, + rsv: u8, + payload: &[u8], +) -> Result<()> { + let mut header = Vec::with_capacity(14); + header.push(0x80 | rsv | opcode); + match payload.len() { + 0..=125 => header.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + header.push(0x80 | 0x7e); + header.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + header.push(0x80 | 127); + header.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + let mask_key = new_mask_key(); + header.extend_from_slice(&mask_key); + + let mut masked = payload.to_vec(); + apply_mask(&mut masked, mask_key); + writer.write_all(&header).await.into_diagnostic()?; + writer.write_all(&masked).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +fn decompress_permessage_deflate(payload: &[u8]) -> Result> { + let mut decoder = Decompress::new(false); + let mut input = Vec::with_capacity(payload.len() + 4); + input.extend_from_slice(payload); + input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); + let mut out = Vec::with_capacity(payload.len().saturating_mul(2).min(MAX_TEXT_MESSAGE_BYTES)); + let mut input_pos = 0usize; + let mut scratch = [0u8; COPY_BUF_SIZE]; + loop { + let before_in = decoder.total_in(); + let before_out = decoder.total_out(); + let status = decoder + .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate decompression failed: {e}"))?; + let read = usize::try_from(decoder.total_in() - before_in) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + let written = usize::try_from(decoder.total_out() - before_out) + .map_err(|_| miette!("websocket permessage-deflate output length overflow"))?; + input_pos = input_pos + .checked_add(read) + .ok_or_else(|| miette!("websocket permessage-deflate input length overflow"))?; + if out.len().saturating_add(written) > MAX_TEXT_MESSAGE_BYTES { + return Err(miette!( + "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" + )); + } + out.extend_from_slice(&scratch[..written]); + if matches!(status, Status::StreamEnd) { + break; + } + if input_pos >= input.len() && written < scratch.len() { + break; + } + if read == 0 && written == 0 { + return Err(miette!( + "websocket permessage-deflate decompression did not make progress" + )); + } + } + Ok(out) +} + +fn compress_permessage_deflate(payload: &[u8]) -> Result> { + let mut compressor = Compress::new(Compression::fast(), false); + let expansion = payload.len() / 16; + let mut out = Vec::with_capacity(payload.len().saturating_add(expansion).saturating_add(128)); + loop { + let consumed = usize::try_from(compressor.total_in()) + .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; + if consumed >= payload.len() { + break; + } + let before_in = compressor.total_in(); + let before_out = compressor.total_out(); + let status = compressor + .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if matches!(status, Status::BufError) + || (compressor.total_in() == before_in && compressor.total_out() == before_out) + { + out.reserve(out.capacity().max(1024)); + } + } + loop { + out.reserve(64); + let before_out = compressor.total_out(); + compressor + .compress_vec(&[], &mut out, FlushCompress::Sync) + .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; + if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + break; + } + if compressor.total_out() == before_out { + out.reserve(out.capacity().max(1024)); + } + } + if !out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { + return Err(miette!( + "websocket permessage-deflate compression missing sync marker" + )); + } + out.truncate(out.len() - 4); + Ok(out) +} + +fn new_mask_key() -> [u8; 4] { + let bytes = uuid::Uuid::new_v4().into_bytes(); + [bytes[0], bytes[1], bytes[2], bytes[3]] +} + +fn apply_mask(payload: &mut [u8], mask_key: [u8; 4]) { + for (i, byte) in payload.iter_mut().enumerate() { + *byte ^= mask_key[i % 4]; + } +} + +fn emit_rewrite_event(host: &str, port: u16, policy_name: &str, replacements: usize) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(rewrite_event_message(host, port, replacements)) + .build(); + ocsf_emit!(event); +} + +fn rewrite_event_message(host: &str, port: u16, replacements: usize) -> String { + format!( + "WEBSOCKET_CREDENTIAL_REWRITE rewrote client text message [host:{host} port:{port} replacements:{replacements}]" + ) +} + +fn emit_websocket_l7_event( + host: &str, + port: u16, + policy_name: &str, + request_info: &L7RequestInfo, + decision: &str, + reason: &str, + graphql: Option<&crate::l7::graphql::GraphqlRequestInfo>, +) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let (action_id, disposition_id, severity) = match decision { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let summary = graphql.map(graphql_log_summary).unwrap_or_default(); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(format!( + "WEBSOCKET_L7_REQUEST {decision} {} {host}:{port}{}{} reason={reason}", + request_info.action, request_info.target, summary + )) + .build(); + ocsf_emit!(event); +} + +fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String { + if let Some(error) = info.error.as_deref() { + return format!(" graphql_error={error:?}"); + } + let ops: Vec = info + .operations + .iter() + .map(|op| { + let name = op.operation_name.as_deref().unwrap_or("-"); + let fields = if op.fields.is_empty() { + "-".to_string() + } else { + op.fields.join(",") + }; + let persisted = op + .persisted_query_hash + .as_deref() + .or(op.persisted_query_id.as_deref()) + .unwrap_or("-"); + format!( + "type={} name={} fields={} persisted={}", + op.operation_type, name, fields, persisted + ) + }) + .collect(); + format!(" graphql_ops={}", ops.join(";")) +} + +fn protocol_failure_class(error: &miette::Report) -> &'static str { + let msg = error.to_string().to_ascii_lowercase(); + if msg.contains("credential") { + "credential_resolution_failed" + } else if msg.contains("utf-8") { + "invalid_utf8" + } else if msg.contains("close frame") || msg.contains("after close") { + "invalid_close_frame" + } else if msg.contains("control frame") { + "invalid_control_frame" + } else if msg.contains("length") + || msg.contains("too large") + || msg.contains("exceeds") + || msg.contains("overflow") + { + "invalid_length" + } else if msg.contains("continuation") || msg.contains("fragmented") { + "invalid_fragmentation" + } else if msg.contains("reserved opcode") { + "reserved_opcode" + } else if msg.contains("not masked") { + "unmasked_client_frame" + } else if msg.contains("rsv") { + "rsv_bits" + } else if msg.contains("malformed") { + "malformed_frame" + } else { + "protocol_error" + } +} + +fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, failure_class: &str) { + let policy_name = if policy_name.is_empty() { + "-" + } else { + policy_name + }; + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .firewall_rule(policy_name, "l7-websocket") + .message(protocol_failure_message(host, port)) + .status_detail(failure_class) + .build(); + ocsf_emit!(event); +} + +fn protocol_failure_message(host: &str, port: u16) -> String { + format!("WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::l7::relay::L7EvalContext; + use crate::opa::{NetworkInput, OpaEngine}; + use crate::secrets::SecretResolver; + use std::path::PathBuf; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); + const GRAPHQL_WS_POLICY: &str = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.test + port: 443 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + binaries: + - { path: /usr/bin/node } +"#; + + fn resolver() -> (HashMap, SecretResolver) { + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), + ); + (child_env, resolver.expect("resolver")) + } + + fn masked_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec { + masked_frame_with_rsv(fin, opcode, 0, payload) + } + + fn masked_frame_with_rsv(fin: bool, opcode: u8, rsv: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push((if fin { 0x80 } else { 0 }) | rsv | opcode); + match payload.len() { + 0..=125 => frame.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), + 126..=65_535 => { + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("payload <= 65535") + .to_be_bytes(), + ); + } + _ => { + frame.push(0x80 | 127); + frame.extend_from_slice(&(payload.len() as u64).to_be_bytes()); + } + } + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + frame + } + + fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(u8::try_from(payload.len()).expect("test payload fits in one byte")); + frame.extend_from_slice(payload); + frame + } + + fn masked_frame_with_declared_len(opcode: u8, declared_len: u64) -> Vec { + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 127); + frame.extend_from_slice(&declared_len.to_be_bytes()); + frame.extend_from_slice(&[0x37, 0xfa, 0x21, 0x3d]); + frame + } + + fn masked_frame_with_non_minimal_16_bit_len(opcode: u8, payload: &[u8]) -> Vec { + let mask_key = [0x37, 0xfa, 0x21, 0x3d]; + let mut frame = Vec::new(); + frame.push(0x80 | opcode); + frame.push(0x80 | 0x7e); + frame.extend_from_slice( + &u16::try_from(payload.len()) + .expect("test payload fits u16") + .to_be_bytes(), + ); + frame.extend_from_slice(&mask_key); + for (i, byte) in payload.iter().enumerate() { + frame.push(byte ^ mask_key[i % 4]); + } + frame + } + + fn close_payload(code: u16, reason: &[u8]) -> Vec { + let mut payload = Vec::with_capacity(2 + reason.len()); + payload.extend_from_slice(&code.to_be_bytes()); + payload.extend_from_slice(reason); + payload + } + + async fn run_client_to_server(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_with_graphql_policy( + input: Vec, + resolver: Option<&SecretResolver>, + ) -> Result> { + let engine = OpaEngine::from_strings(TEST_POLICY, GRAPHQL_WS_POLICY) + .expect("GraphQL WebSocket policy should load"); + let network_input = NetworkInput { + host: "realtime.graphql.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let generation = engine + .evaluate_network_action_with_generation(&network_input) + .expect("network action should evaluate") + .1; + let tunnel_engine = engine + .clone_engine_for_tunnel(generation) + .expect("tunnel engine"); + let ctx = L7EvalContext { + host: "realtime.graphql.test".into(), + port: 443, + policy_name: "graphql_ws".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "graphql_ws", + resolver, + inspector: Some(InspectionOptions { + engine: &tunnel_engine, + ctx: &ctx, + enforcement: EnforcementMode::Enforce, + target: "/graphql".to_string(), + query_params: HashMap::new(), + graphql_policy: true, + }), + compression: WebSocketCompression::None, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "realtime.graphql.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + async fn run_client_to_server_compressed(input: Vec) -> Result> { + let (_, resolver) = resolver(); + let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); + + client_write.write_all(&input).await.unwrap(); + drop(client_write); + + let options = RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::PermessageDeflate, + }; + let result = relay_client_to_server( + &mut relay_read, + &mut relay_write, + "gateway.example.test", + 443, + &options, + ) + .await; + drop(relay_write); + + let mut output = Vec::new(); + upstream_read.read_to_end(&mut output).await.unwrap(); + result.map(|()| output) + } + + fn decode_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_ne!(frame[1] & 0x80, 0); + String::from_utf8(decode_masked_payload(frame)).unwrap() + } + + fn decode_masked_payload(frame: &[u8]) -> Vec { + assert_ne!(frame[1] & 0x80, 0); + let len_code = frame[1] & 0x7F; + let (payload_len, mask_offset) = match len_code { + 0..=125 => (usize::from(len_code), 2), + 126 => (usize::from(u16::from_be_bytes([frame[2], frame[3]])), 4), + 127 => { + let len = u64::from_be_bytes(frame[2..10].try_into().unwrap()); + (usize::try_from(len).unwrap(), 10) + } + _ => unreachable!(), + }; + let mask_key: [u8; 4] = frame[mask_offset..mask_offset + 4].try_into().unwrap(); + let mut payload = frame[mask_offset + 4..mask_offset + 4 + payload_len].to_vec(); + apply_mask(&mut payload, mask_key); + payload + } + + fn decode_compressed_masked_text_frame(frame: &[u8]) -> String { + assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); + assert_eq!(frame[0] & 0x40, 0x40); + let payload = decode_masked_payload(frame); + String::from_utf8(decompress_permessage_deflate(&payload).unwrap()).unwrap() + } + + async fn read_one_frame(reader: &mut R) -> Vec { + let mut header = [0u8; 2]; + reader.read_exact(&mut header).await.unwrap(); + let len_code = header[1] & 0x7F; + let extended_len = match len_code { + 0..=125 => Vec::new(), + 126 => { + let mut bytes = vec![0u8; 2]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + 127 => { + let mut bytes = vec![0u8; 8]; + reader.read_exact(&mut bytes).await.unwrap(); + bytes + } + _ => unreachable!(), + }; + let payload_len = match len_code { + 0..=125 => usize::from(len_code), + 126 => usize::from(u16::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )), + 127 => usize::try_from(u64::from_be_bytes( + extended_len.as_slice().try_into().unwrap(), + )) + .unwrap(), + _ => unreachable!(), + }; + let mask_len = if header[1] & 0x80 != 0 { 4 } else { 0 }; + let mut rest = vec![0u8; extended_len.len() + mask_len + payload_len]; + rest[..extended_len.len()].copy_from_slice(&extended_len); + reader + .read_exact(&mut rest[extended_len.len()..]) + .await + .unwrap(); + + let mut frame = header.to_vec(); + frame.extend_from_slice(&rest); + frame + } + + #[test] + fn classifies_graphql_transport_ws_subscribe_operation() { + let message = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "subscribe"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations.len(), 1); + assert_eq!(graphql.operations[0].operation_type, "subscription"); + assert_eq!( + graphql.operations[0].operation_name.as_deref(), + Some("NewMessages") + ); + assert_eq!(graphql.operations[0].fields, vec!["messageAdded"]); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_legacy_graphql_ws_start_operation() { + let message = r#"{"type":"start","id":"1","payload":{"query":"query Viewer { viewer }"}}"#; + + match classify_graphql_websocket_message(message) { + GraphqlWebSocketMessage::Operation { + message_type, + graphql, + } => { + assert_eq!(message_type, "start"); + assert!( + graphql.error.is_none(), + "unexpected error: {:?}", + graphql.error + ); + assert_eq!(graphql.operations[0].operation_type, "query"); + assert_eq!(graphql.operations[0].fields, vec!["viewer"]); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + } + } + + #[test] + fn classifies_graphql_websocket_control_message_without_payload_logging() { + match classify_graphql_websocket_message( + r#"{"type":"connection_init","payload":{"authorization":"secret"}}"#, + ) { + GraphqlWebSocketMessage::Control { message_type } => { + assert_eq!(message_type, "connection_init"); + } + other @ GraphqlWebSocketMessage::Operation { .. } => { + panic!("expected control message, got {other:?}") + } + } + } + + #[test] + fn unsupported_graphql_websocket_message_type_fails_closed() { + match classify_graphql_websocket_message(r#"{"type":"next","id":"1"}"#) { + GraphqlWebSocketMessage::Operation { graphql, .. } => { + assert!( + graphql + .error + .as_deref() + .is_some_and(|error| error.contains("unsupported")) + ); + } + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation error, got {other:?}") + } + } + } + + #[test] + fn graphql_websocket_log_summary_excludes_payload_variables_and_secrets() { + let placeholder = "openshell:resolve:env:T"; + let message = format!( + r#"{{"type":"subscribe","id":"1","payload":{{"query":"query Viewer {{ viewer }}","variables":{{"token":"{placeholder}"}}}}}}"# + ); + let graphql = match classify_graphql_websocket_message(&message) { + GraphqlWebSocketMessage::Operation { graphql, .. } => graphql, + other @ GraphqlWebSocketMessage::Control { .. } => { + panic!("expected operation, got {other:?}") + } + }; + let summary = graphql_log_summary(&graphql); + + assert!(summary.contains("type=query")); + assert!(summary.contains("fields=viewer")); + assert!(!summary.contains(placeholder)); + assert!(!summary.contains("real-token")); + assert!(!summary.contains("variables")); + assert!(!summary.contains("token")); + assert!(!summary.contains("secret_len")); + } + + #[tokio::test] + async fn rewrites_discord_like_identify_text_payload() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let output = run_client_to_server(masked_frame(true, OPCODE_TEXT, payload.as_bytes())) + .await + .expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + } + + #[tokio::test] + async fn upgraded_relay_rewrites_client_text_before_upstream_receives_it() { + let (child_env, resolver) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + let client_frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + assert!( + !String::from_utf8_lossy(&client_frame).contains("real-token"), + "client-side fixture must not contain the real token" + ); + + let (mut client_app, mut relay_client) = tokio::io::duplex(4096); + let (mut relay_upstream, mut upstream_app) = tokio::io::duplex(4096); + let relay = tokio::spawn(async move { + relay_with_options( + &mut relay_client, + &mut relay_upstream, + Vec::new(), + "gateway.example.test", + 443, + RelayOptions { + policy_name: "test-policy", + resolver: Some(&resolver), + inspector: None, + compression: WebSocketCompression::None, + }, + ) + .await + }); + + client_app.write_all(&client_frame).await.unwrap(); + client_app.flush().await.unwrap(); + + let upstream_frame = tokio::time::timeout( + std::time::Duration::from_secs(2), + read_one_frame(&mut upstream_app), + ) + .await + .expect("upstream should receive rewritten frame"); + assert_eq!( + decode_masked_text_frame(&upstream_frame), + r#"{"op":2,"d":{"token":"real-token"}}"# + ); + + drop(client_app); + drop(upstream_app); + let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay).await; + } + + #[tokio::test] + async fn graphql_websocket_policy_allows_subscription_operation() { + let payload = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame.clone(), None) + .await + .expect("allowed subscription should relay"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), payload); + } + + #[tokio::test] + async fn graphql_websocket_policy_denies_unlisted_operation_field() { + let payload = + r#"{"type":"subscribe","id":"1","payload":{"query":"query Admin { adminAuditLog }"}}"#; + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let err = run_client_to_server_with_graphql_policy(frame, None) + .await + .expect_err("unlisted field should be denied"); + + assert!(err.to_string().contains("websocket GraphQL message denied")); + } + + #[tokio::test] + async fn graphql_websocket_control_message_rewrites_credentials_before_relay() { + let (child_env, resolver) = SecretResolver::from_provider_env( + std::iter::once(("T".to_string(), "real-token".to_string())).collect(), + ); + let resolver = resolver.expect("resolver"); + let placeholder = child_env.get("T").expect("placeholder env"); + let payload = format!( + r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# + ); + let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); + + let output = run_client_to_server_with_graphql_policy(frame, Some(&resolver)) + .await + .expect("control message should relay after credential rewrite"); + + let rewritten = decode_masked_text_frame(&output); + assert_eq!( + rewritten, + r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# + ); + assert!(!rewritten.contains(placeholder)); + } + + #[tokio::test] + async fn text_without_placeholder_passes_semantically_unchanged() { + let frame = masked_frame(true, OPCODE_TEXT, br#"{"op":1,"d":42}"#); + let output = run_client_to_server(frame.clone()) + .await + .expect("relay should succeed"); + + assert_eq!(output, frame); + assert_eq!(decode_masked_text_frame(&output), r#"{"op":1,"d":42}"#); + } + + #[tokio::test] + async fn unknown_placeholder_fails_closed() { + let frame = masked_frame( + true, + OPCODE_TEXT, + br#"{"token":"openshell:resolve:env:UNKNOWN"}"#, + ); + + let err = run_client_to_server(frame) + .await + .expect_err("unknown placeholder should fail"); + + assert!( + err.to_string() + .contains("credential placeholder resolution") + ); + } + + #[tokio::test] + async fn fragmented_text_rewrites_after_final_continuation() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let second = r#""}"#; + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend(masked_frame(true, OPCODE_CONTINUATION, second.as_bytes())); + + let output = run_client_to_server(input) + .await + .expect("relay should succeed"); + + assert_eq!( + decode_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn rejects_rsv_bits() { + let mut frame = masked_frame(true, OPCODE_TEXT, b"hello"); + frame[0] |= 0x40; + + let err = run_client_to_server(frame) + .await + .expect_err("RSV frame should fail"); + + assert!(err.to_string().contains("RSV bits")); + } + + #[tokio::test] + async fn rejects_unmasked_client_frame() { + let err = run_client_to_server(unmasked_frame(OPCODE_TEXT, b"hello")) + .await + .expect_err("unmasked frame should fail"); + + assert!(err.to_string().contains("not masked")); + } + + #[tokio::test] + async fn rejects_invalid_utf8_text() { + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &[0xff])) + .await + .expect_err("invalid UTF-8 should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_oversize_text_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &payload)) + .await + .expect_err("oversize text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + + #[tokio::test] + async fn fragmented_text_allows_interleaved_ping_pong_and_rewrites_at_completion() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let first = format!(r#"{{"token":"{placeholder}"#); + let first_control_frame = masked_frame(true, OPCODE_PING, b"p"); + let second_control_frame = masked_frame(true, OPCODE_PONG, b"q"); + let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); + input.extend_from_slice(&first_control_frame); + input.extend_from_slice(&second_control_frame); + input.extend(masked_frame(true, OPCODE_CONTINUATION, br#""}"#)); + + let output = run_client_to_server(input) + .await + .expect("relay should allow interleaved control frames"); + + assert!(output.starts_with(&first_control_frame)); + assert_eq!( + &output + [first_control_frame.len()..first_control_frame.len() + second_control_frame.len()], + second_control_frame.as_slice() + ); + assert_eq!( + decode_masked_text_frame( + &output[first_control_frame.len() + second_control_frame.len()..] + ), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rewrites_with_permessage_deflate() { + let (child_env, _) = resolver(); + let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); + let payload = format!(r#"{{"token":"{placeholder}"}}"#); + let compressed = compress_permessage_deflate(payload.as_bytes()).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let output = run_client_to_server_compressed(input) + .await + .expect("compressed text should relay"); + + assert_eq!( + decode_compressed_masked_text_frame(&output), + r#"{"token":"real-token"}"# + ); + } + + #[tokio::test] + async fn compressed_text_rejects_decompressed_oversize_message() { + let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; + let compressed = compress_permessage_deflate(&payload).unwrap(); + let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); + + let err = run_client_to_server_compressed(input) + .await + .expect_err("oversize decompressed text should fail"); + + assert!(err.to_string().contains("exceeds")); + } + + #[tokio::test] + async fn binary_frame_passes_through_unchanged() { + let frame = masked_frame(true, OPCODE_BINARY, &[0, 1, 2, 3, 255]); + + let output = run_client_to_server(frame.clone()) + .await + .expect("binary frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_reserved_opcode() { + let err = run_client_to_server(masked_frame(true, 0x3, b"reserved")) + .await + .expect_err("reserved opcode should fail"); + + assert!(err.to_string().contains("reserved opcode")); + } + + #[tokio::test] + async fn rejects_continuation_without_active_message() { + let err = run_client_to_server(masked_frame(true, OPCODE_CONTINUATION, b"orphan")) + .await + .expect_err("orphan continuation should fail"); + + assert!(err.to_string().contains("continuation")); + } + + #[tokio::test] + async fn rejects_new_data_frame_before_fragment_completion() { + let mut input = masked_frame(false, OPCODE_TEXT, b"partial"); + input.extend(masked_frame(true, OPCODE_TEXT, b"second")); + + let err = run_client_to_server(input) + .await + .expect_err("new data frame during fragmentation should fail"); + + assert!(err.to_string().contains("previous fragmented message")); + } + + #[tokio::test] + async fn rejects_fragmented_control_frame() { + let err = run_client_to_server(masked_frame(false, OPCODE_PING, b"ping")) + .await + .expect_err("fragmented control frame should fail"); + + assert!(err.to_string().contains("control frame is fragmented")); + } + + #[tokio::test] + async fn rejects_control_frame_over_125_bytes() { + let payload = vec![b'a'; 126]; + let err = run_client_to_server(masked_frame(true, OPCODE_PING, &payload)) + .await + .expect_err("oversize control frame should fail"); + + assert!(err.to_string().contains("control frame exceeds")); + } + + #[tokio::test] + async fn rejects_non_minimal_extended_length() { + let err = run_client_to_server(masked_frame_with_non_minimal_16_bit_len( + OPCODE_TEXT, + b"hello", + )) + .await + .expect_err("non-minimal length should fail"); + + assert!(err.to_string().contains("non-minimal")); + } + + #[tokio::test] + async fn rejects_oversize_binary_frame_before_payload_buffering() { + let err = run_client_to_server(masked_frame_with_declared_len( + OPCODE_BINARY, + MAX_RAW_FRAME_PAYLOAD_BYTES + 1, + )) + .await + .expect_err("oversize binary frame should fail"); + + assert!(err.to_string().contains("binary frame exceeds")); + } + + #[tokio::test] + async fn validates_close_frame_payloads() { + let frame = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + + let output = run_client_to_server(frame.clone()) + .await + .expect("valid close frame should pass through"); + + assert_eq!(output, frame); + } + + #[tokio::test] + async fn rejects_close_frame_with_one_byte_payload() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &[0x03])) + .await + .expect_err("one-byte close frame should fail"); + + assert!(err.to_string().contains("exactly one byte")); + } + + #[tokio::test] + async fn rejects_reserved_close_code() { + let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &close_payload(1005, b""))) + .await + .expect_err("reserved close code should fail"); + + assert!(err.to_string().contains("invalid close code")); + } + + #[tokio::test] + async fn rejects_close_reason_with_invalid_utf8() { + let err = run_client_to_server(masked_frame( + true, + OPCODE_CLOSE, + &close_payload(1000, &[0xff]), + )) + .await + .expect_err("invalid close reason should fail"); + + assert!(err.to_string().contains("valid UTF-8")); + } + + #[tokio::test] + async fn rejects_frames_after_client_close_frame() { + let mut input = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); + input.extend(masked_frame(true, OPCODE_TEXT, b"late")); + + let err = run_client_to_server(input) + .await + .expect_err("frames after close should fail"); + + assert!(err.to_string().contains("after close")); + } + + #[test] + fn websocket_ocsf_messages_do_not_include_payload_or_secret_material() { + let placeholder = "openshell:resolve:env:DISCORD_BOT_TOKEN"; + let secret = "real-token"; + let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); + + let rewrite = rewrite_event_message("gateway.example.test", 443, 1); + let failure = protocol_failure_message("gateway.example.test", 443); + let messages = [rewrite, failure]; + + for message in messages { + assert!(!message.contains(placeholder)); + assert!(!message.contains(secret)); + assert!(!message.contains(&payload)); + assert!(!message.contains("secret_len")); + assert!(!message.contains("payload_len")); + } + } +} diff --git a/crates/openshell-sandbox/src/mechanistic_mapper.rs b/crates/openshell-sandbox/src/mechanistic_mapper.rs index ba7c51de9..273bdee3b 100644 --- a/crates/openshell-sandbox/src/mechanistic_mapper.rs +++ b/crates/openshell-sandbox/src/mechanistic_mapper.rs @@ -355,6 +355,7 @@ fn build_l7_rules(samples: &HashMap<(String, String), u32>) -> Vec { operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + rpc_method: String::new(), }), }); } diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs new file mode 100644 index 000000000..aa49509e5 --- /dev/null +++ b/crates/openshell-sandbox/src/opa.rs @@ -0,0 +1,5248 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Embedded OPA policy engine using regorus. +//! +//! Wraps [`regorus::Engine`] to evaluate Rego policies for sandbox network +//! access decisions. The engine is loaded once at sandbox startup and queried +//! on every proxy CONNECT request. + +use crate::policy::{FilesystemPolicy, LandlockCompatibility, LandlockPolicy, ProcessPolicy}; +use miette::Result; +use openshell_core::proto::SandboxPolicy as ProtoSandboxPolicy; +use std::path::{Path, PathBuf}; +use std::sync::{ + Arc, Mutex, + atomic::{AtomicU64, Ordering}, +}; + +/// Baked-in rego rules for OPA policy evaluation. +/// These rules define the network access decision logic and static config +/// passthroughs. They reference `data.sandbox.*` for policy data. +const BAKED_POLICY_RULES: &str = include_str!("../data/sandbox-policy.rego"); + +/// Result of evaluating a network access request against OPA policy. +pub struct PolicyDecision { + pub allowed: bool, + pub reason: String, + pub matched_policy: Option, +} + +/// Network action returned by OPA `network_action` rule. +/// +/// - `Allow`: endpoint + binary explicitly matched in a network policy +/// - `Deny`: no matching policy +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NetworkAction { + Allow { matched_policy: Option }, + Deny { reason: String }, +} + +/// Input for a network access policy evaluation. +pub struct NetworkInput { + pub host: String, + pub port: u16, + pub binary_path: PathBuf, + pub binary_sha256: String, + /// Ancestor binary paths from process tree walk (parent, grandparent, ...). + pub ancestors: Vec, + /// Absolute paths extracted from `/proc//cmdline` of the socket-owning + /// process and its ancestors. Captures script paths (e.g. `/usr/local/bin/claude`) + /// that don't appear in `/proc//exe` because the interpreter (node) is the exe. + pub cmdline_paths: Vec, +} + +/// Sandbox configuration extracted from OPA data at startup. +pub struct SandboxConfig { + pub filesystem: FilesystemPolicy, + pub landlock: LandlockPolicy, + pub process: ProcessPolicy, +} + +/// Embedded OPA policy engine. +/// +/// Thread-safe: the inner `regorus::Engine` requires `&mut self` for +/// evaluation, so access is serialized via a `Mutex`. This is acceptable +/// because policy evaluation is fast (microseconds) and contention is low +/// (one eval per CONNECT request). +pub struct OpaEngine { + engine: Mutex, + generation: Arc, +} + +/// Generation guard captured when an HTTP tunnel or request path starts. +#[derive(Clone)] +pub struct PolicyGenerationGuard { + captured_generation: u64, + current_generation: Arc, +} + +impl PolicyGenerationGuard { + pub fn captured_generation(&self) -> u64 { + self.captured_generation + } + + pub fn current_generation(&self) -> u64 { + self.current_generation.load(Ordering::Acquire) + } + + pub fn is_stale(&self) -> bool { + self.current_generation() != self.captured_generation + } + + pub fn ensure_current(&self) -> Result<()> { + if self.is_stale() { + return Err(miette::miette!( + "policy generation is stale [captured_generation:{} current_generation:{}]", + self.captured_generation(), + self.current_generation(), + )); + } + Ok(()) + } +} + +/// Per-tunnel L7 policy evaluator bound to the engine generation captured when +/// the tunnel was established. +pub struct TunnelPolicyEngine { + engine: Mutex, + generation_guard: PolicyGenerationGuard, +} + +impl TunnelPolicyEngine { + pub fn captured_generation(&self) -> u64 { + self.generation_guard.captured_generation() + } + + pub fn current_generation(&self) -> u64 { + self.generation_guard.current_generation() + } + + pub fn is_stale(&self) -> bool { + self.generation_guard.is_stale() + } + + pub fn generation_guard(&self) -> &PolicyGenerationGuard { + &self.generation_guard + } + + pub(crate) fn engine(&self) -> &Mutex { + &self.engine + } +} + +impl OpaEngine { + /// Load policy from a `.rego` rules file and data from a YAML file. + /// + /// Preprocesses the YAML data to expand access presets and validate L7 config. + pub fn from_files(policy_path: &Path, data_path: &Path) -> Result { + let yaml_str = std::fs::read_to_string(data_path).map_err(|e| { + miette::miette!("failed to read YAML data from {}: {e}", data_path.display()) + })?; + let mut engine = regorus::Engine::new(); + engine + .add_policy_from_file(policy_path) + .map_err(|e| miette::miette!("{e}"))?; + let data_json = preprocess_yaml_data(&yaml_str)?; + engine + .add_data_json(&data_json) + .map_err(|e| miette::miette!("{e}"))?; + Ok(Self { + engine: Mutex::new(engine), + generation: Arc::new(AtomicU64::new(0)), + }) + } + + /// Load policy rules and data from strings (data is YAML). + /// + /// Preprocesses the YAML data to expand access presets and validate L7 config. + pub fn from_strings(policy: &str, data_yaml: &str) -> Result { + let mut engine = regorus::Engine::new(); + engine + .add_policy("policy.rego".into(), policy.into()) + .map_err(|e| miette::miette!("{e}"))?; + let data_json = preprocess_yaml_data(data_yaml)?; + engine + .add_data_json(&data_json) + .map_err(|e| miette::miette!("{e}"))?; + Ok(Self { + engine: Mutex::new(engine), + generation: Arc::new(AtomicU64::new(0)), + }) + } + + /// Create OPA engine from a typed proto policy. + /// + /// Uses baked-in rego rules and converts the proto's typed fields to JSON + /// data under the `sandbox` key (matching `data.sandbox.*` references in + /// the rego rules). + /// + /// Expands access presets and validates L7 config. + pub fn from_proto(proto: &ProtoSandboxPolicy) -> Result { + Self::from_proto_with_pid(proto, 0) + } + + /// Create OPA engine from a typed proto policy with symlink resolution. + /// + /// When `entrypoint_pid` is non-zero, binary paths in the policy that are + /// symlinks inside the container filesystem are resolved via + /// `/proc//root/` and added as additional entries. This bridges the + /// gap between user-specified symlink paths (e.g., `/usr/bin/python3`) and + /// kernel-resolved canonical paths (e.g., `/usr/bin/python3.11`). + pub fn from_proto_with_pid(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> Result { + let data_json_str = proto_to_opa_data_json(proto, entrypoint_pid); + + // Parse back to Value for preprocessing, then re-serialize + let mut data: serde_json::Value = serde_json::from_str(&data_json_str) + .map_err(|e| miette::miette!("internal: failed to parse proto JSON: {e}"))?; + + // Validate BEFORE expanding presets + let (errors, warnings) = crate::l7::validate_l7_policies(&data); + for w in &warnings { + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Medium) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "validated") + .unmapped("warning", serde_json::json!(w.clone())) + .message(format!("L7 policy validation warning: {w}")) + .build() + ); + } + if !errors.is_empty() { + return Err(miette::miette!( + "L7 policy validation failed:\n{}", + errors.join("\n") + )); + } + + // Expand access presets to explicit rules after validation + crate::l7::expand_access_presets(&mut data); + + let data_json = data.to_string(); + let mut engine = regorus::Engine::new(); + engine + .add_policy("policy.rego".into(), BAKED_POLICY_RULES.into()) + .map_err(|e| miette::miette!("{e}"))?; + engine + .add_data_json(&data_json) + .map_err(|e| miette::miette!("{e}"))?; + Ok(Self { + engine: Mutex::new(engine), + generation: Arc::new(AtomicU64::new(0)), + }) + } + + /// Evaluate a network access request against the loaded policy. + /// + /// Builds an OPA input document from the `NetworkInput`, evaluates the + /// `allow_network` rule, and returns a `PolicyDecision` with the result, + /// deny reason, and matched policy name. + pub fn evaluate_network(&self, input: &NetworkInput) -> Result { + let ancestor_strs: Vec = input + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let cmdline_strs: Vec = input + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let input_json = serde_json::json!({ + "exec": { + "path": input.binary_path.to_string_lossy(), + "ancestors": ancestor_strs, + "cmdline_paths": cmdline_strs, + }, + "network": { + "host": input.host, + "port": input.port, + } + }); + + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + + engine + .set_input_json(&input_json.to_string()) + .map_err(|e| miette::miette!("{e}"))?; + + let allowed = engine + .eval_rule("data.openshell.sandbox.allow_network".into()) + .map_err(|e| miette::miette!("{e}"))?; + let allowed = allowed == regorus::Value::from(true); + + let reason = engine + .eval_rule("data.openshell.sandbox.deny_reason".into()) + .map_err(|e| miette::miette!("{e}"))?; + let reason = value_to_string(&reason); + + let matched = engine + .eval_rule("data.openshell.sandbox.matched_network_policy".into()) + .map_err(|e| miette::miette!("{e}"))?; + let matched_policy = if matched == regorus::Value::Undefined { + None + } else { + Some(value_to_string(&matched)) + }; + + Ok(PolicyDecision { + allowed, + reason, + matched_policy, + }) + } + + /// Evaluate a network access request and return a routing action. + /// + /// Uses the OPA `network_action` rule which returns one of: + /// `"allow"` or `"deny"`. + pub fn evaluate_network_action(&self, input: &NetworkInput) -> Result { + Ok(self.evaluate_network_action_with_generation(input)?.0) + } + + /// Evaluate network action and return the policy generation used for the evaluation. + pub fn evaluate_network_action_with_generation( + &self, + input: &NetworkInput, + ) -> Result<(NetworkAction, u64)> { + let ancestor_strs: Vec = input + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let cmdline_strs: Vec = input + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let input_json = serde_json::json!({ + "exec": { + "path": input.binary_path.to_string_lossy(), + "ancestors": ancestor_strs, + "cmdline_paths": cmdline_strs, + }, + "network": { + "host": input.host, + "port": input.port, + } + }); + + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + let generation = self.current_generation(); + + engine + .set_input_json(&input_json.to_string()) + .map_err(|e| miette::miette!("{e}"))?; + + let action_val = engine + .eval_rule("data.openshell.sandbox.network_action".into()) + .map_err(|e| miette::miette!("{e}"))?; + let action_str = value_to_string(&action_val); + + let matched = engine + .eval_rule("data.openshell.sandbox.matched_network_policy".into()) + .map_err(|e| miette::miette!("{e}"))?; + let matched_policy = if matched == regorus::Value::Undefined { + None + } else { + Some(value_to_string(&matched)) + }; + + if action_str == "allow" { + Ok((NetworkAction::Allow { matched_policy }, generation)) + } else { + let reason_val = engine + .eval_rule("data.openshell.sandbox.deny_reason".into()) + .map_err(|e| miette::miette!("{e}"))?; + let reason = value_to_string(&reason_val); + Ok((NetworkAction::Deny { reason }, generation)) + } + } + + /// Reload policy and data from strings (data is YAML). + /// + /// Designed for future gRPC hot-reload from the openshell gateway. + /// Replaces the entire engine atomically. Routes through the full + /// preprocessing pipeline (port normalization, L7 validation, preset + /// expansion) to maintain consistency with `from_strings()`. + pub fn reload(&self, policy: &str, data_yaml: &str) -> Result<()> { + let new = Self::from_strings(policy, data_yaml)?; + let new_engine = new + .engine + .into_inner() + .map_err(|_| miette::miette!("lock poisoned on new engine"))?; + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + *engine = new_engine; + self.generation.fetch_add(1, Ordering::AcqRel); + Ok(()) + } + + /// Reload policy from a proto `SandboxPolicy` message. + /// + /// Reuses the full `from_proto()` pipeline (proto-to-JSON conversion, L7 + /// validation, access preset expansion) so the reload has identical + /// validation guarantees as initial load. Atomically replaces the inner + /// engine on success; on failure the previous engine is untouched (LKG). + pub fn reload_from_proto(&self, proto: &ProtoSandboxPolicy) -> Result<()> { + self.reload_from_proto_with_pid(proto, 0) + } + + /// Reload policy from a proto with symlink resolution. + /// + /// When `entrypoint_pid` is non-zero, binary paths that are symlinks + /// inside the container filesystem are resolved and added as additional + /// match entries. See [`from_proto_with_pid`] for details. + pub fn reload_from_proto_with_pid( + &self, + proto: &ProtoSandboxPolicy, + entrypoint_pid: u32, + ) -> Result<()> { + // Build a complete new engine through the same validated pipeline. + let new = Self::from_proto_with_pid(proto, entrypoint_pid)?; + let new_engine = new + .engine + .into_inner() + .map_err(|_| miette::miette!("lock poisoned on new engine"))?; + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + *engine = new_engine; + self.generation.fetch_add(1, Ordering::AcqRel); + Ok(()) + } + + /// Current policy generation. Successful reloads increment this value. + pub fn current_generation(&self) -> u64 { + self.generation.load(Ordering::Acquire) + } + + /// Return a guard for a previously captured policy generation. + pub fn generation_guard(&self, expected_generation: u64) -> Result { + let generation = self.current_generation(); + if generation != expected_generation { + return Err(miette::miette!( + "policy changed before HTTP relay started [expected_generation:{expected_generation} current_generation:{generation}]" + )); + } + Ok(PolicyGenerationGuard { + captured_generation: generation, + current_generation: Arc::clone(&self.generation), + }) + } + + /// Query static sandbox configuration from the OPA data module. + /// + /// Extracts `filesystem_policy`, `landlock`, and `process` from the Rego + /// data and converts them into the Rust policy structs used by the sandbox + /// runtime for filesystem preparation, Landlock setup, and privilege dropping. + pub fn query_sandbox_config(&self) -> Result { + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + + // Query filesystem policy + let fs_val = engine + .eval_rule("data.openshell.sandbox.filesystem_policy".into()) + .map_err(|e| miette::miette!("{e}"))?; + let filesystem = parse_filesystem_policy(&fs_val); + + // Query landlock policy + let ll_val = engine + .eval_rule("data.openshell.sandbox.landlock_policy".into()) + .map_err(|e| miette::miette!("{e}"))?; + let landlock = parse_landlock_policy(&ll_val); + + // Query process policy + let proc_val = engine + .eval_rule("data.openshell.sandbox.process_policy".into()) + .map_err(|e| miette::miette!("{e}"))?; + let process = parse_process_policy(&proc_val); + + Ok(SandboxConfig { + filesystem, + landlock, + process, + }) + } + + /// Query the L7 endpoint config for a matched policy and host:port. + /// + /// After L4 evaluation allows a CONNECT, this method queries the Rego data + /// to get the full endpoint object for the matched policy. Returns the raw + /// `regorus::Value` which can be parsed by `l7::parse_l7_config()`. + pub fn query_endpoint_config(&self, input: &NetworkInput) -> Result> { + Ok(self.query_endpoint_config_with_generation(input)?.0) + } + + /// Query L7 endpoint config and return the policy generation used for the query. + pub fn query_endpoint_config_with_generation( + &self, + input: &NetworkInput, + ) -> Result<(Option, u64)> { + let (configs, generation) = self.query_endpoint_configs_with_generation(input)?; + Ok((configs.into_iter().next(), generation)) + } + + /// Query all matching endpoint configs and return the policy generation used for the query. + pub fn query_endpoint_configs_with_generation( + &self, + input: &NetworkInput, + ) -> Result<(Vec, u64)> { + let ancestor_strs: Vec = input + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let cmdline_strs: Vec = input + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let input_json = serde_json::json!({ + "exec": { + "path": input.binary_path.to_string_lossy(), + "ancestors": ancestor_strs, + "cmdline_paths": cmdline_strs, + }, + "network": { + "host": input.host, + "port": input.port, + } + }); + + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + let generation = self.current_generation(); + + engine + .set_input_json(&input_json.to_string()) + .map_err(|e| miette::miette!("{e}"))?; + + let val = engine + .eval_rule("data.openshell.sandbox._matching_endpoint_configs".into()) + .map_err(|e| miette::miette!("{e}"))?; + + match val { + regorus::Value::Undefined => Ok((Vec::new(), generation)), + regorus::Value::Array(values) => Ok((values.to_vec(), generation)), + other => Ok((vec![other], generation)), + } + } + + /// Query `allowed_ips` from the matched endpoint config for a given request. + /// + /// Returns the list of CIDR/IP strings from the endpoint's `allowed_ips` + /// field, or an empty vec if the field is absent or the endpoint has no + /// match. This is used by the proxy to decide between full SSRF blocking + /// and allowlist-based IP validation. + pub fn query_allowed_ips(&self, input: &NetworkInput) -> Result> { + Ok(self + .query_endpoint_config(input)? + .map(|val| get_str_array(&val, "allowed_ips")) + .unwrap_or_default()) + } + + /// Return true when the matched endpoint is an exact declared hostname. + /// + /// This intentionally excludes wildcard and hostless endpoints. The proxy + /// uses this as a narrow signal that the operator explicitly declared the + /// destination hostname, which can safely skip the default private-IP SSRF + /// denial while preserving separate handling for `allowed_ips` and advisor + /// proposals. + pub fn query_exact_declared_endpoint_host(&self, input: &NetworkInput) -> Result { + let ancestor_strs: Vec = input + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let cmdline_strs: Vec = input + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let input_json = serde_json::json!({ + "exec": { + "path": input.binary_path.to_string_lossy(), + "ancestors": ancestor_strs, + "cmdline_paths": cmdline_strs, + }, + "network": { + "host": input.host, + "port": input.port, + } + }); + + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + + engine + .set_input_json(&input_json.to_string()) + .map_err(|e| miette::miette!("{e}"))?; + + let val = engine + .eval_rule("data.openshell.sandbox.exact_declared_endpoint_host".into()) + .map_err(|e| miette::miette!("{e}"))?; + + Ok(val == regorus::Value::from(true)) + } + + /// Clone the inner regorus engine for per-tunnel L7 evaluation. + /// + /// With the `arc` feature enabled, this shares compiled policy via Arc + /// and only duplicates interpreter state (~microseconds). The cloned + /// engine can be used without Mutex contention. + pub fn clone_engine_for_tunnel(&self, expected_generation: u64) -> Result { + let engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + let generation = self.current_generation(); + if generation != expected_generation { + return Err(miette::miette!( + "policy changed before L7 tunnel started [expected_generation:{expected_generation} current_generation:{generation}]" + )); + } + Ok(TunnelPolicyEngine { + engine: Mutex::new(engine.clone()), + generation_guard: PolicyGenerationGuard { + captured_generation: generation, + current_generation: Arc::clone(&self.generation), + }, + }) + } +} + +/// Convert a `regorus::Value` to a string, handling various types. +fn value_to_string(val: ®orus::Value) -> String { + match val { + regorus::Value::String(s) => s.to_string(), + regorus::Value::Undefined => String::new(), + other => other.to_string(), + } +} + +/// Extract a string from a `regorus::Value` object field. +fn get_str(val: ®orus::Value, key: &str) -> Option { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => match map.get(&key_val) { + Some(regorus::Value::String(s)) => Some(s.to_string()), + _ => None, + }, + _ => None, + } +} + +/// Extract a bool from a `regorus::Value` object field. +fn get_bool(val: ®orus::Value, key: &str) -> Option { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => match map.get(&key_val) { + Some(regorus::Value::Bool(b)) => Some(*b), + _ => None, + }, + _ => None, + } +} + +/// Extract a string array from a `regorus::Value` object field. +fn get_str_array(val: ®orus::Value, key: &str) -> Vec { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => match map.get(&key_val) { + Some(regorus::Value::Array(arr)) => arr + .iter() + .filter_map(|v| { + if let regorus::Value::String(s) = v { + Some(s.to_string()) + } else { + None + } + }) + .collect(), + _ => vec![], + }, + _ => vec![], + } +} + +fn parse_filesystem_policy(val: ®orus::Value) -> FilesystemPolicy { + FilesystemPolicy { + read_only: get_str_array(val, "read_only") + .into_iter() + .map(PathBuf::from) + .collect(), + read_write: get_str_array(val, "read_write") + .into_iter() + .map(PathBuf::from) + .collect(), + include_workdir: get_bool(val, "include_workdir").unwrap_or(true), + } +} + +fn parse_landlock_policy(val: ®orus::Value) -> LandlockPolicy { + let compat = get_str(val, "compatibility").unwrap_or_default(); + LandlockPolicy { + compatibility: if compat == "hard_requirement" { + LandlockCompatibility::HardRequirement + } else { + LandlockCompatibility::BestEffort + }, + } +} + +fn parse_process_policy(val: ®orus::Value) -> ProcessPolicy { + ProcessPolicy { + run_as_user: get_str(val, "run_as_user"), + run_as_group: get_str(val, "run_as_group"), + } +} + +/// Preprocess YAML policy data: parse, normalize, validate, expand access presets, return JSON. +fn preprocess_yaml_data(yaml_str: &str) -> Result { + let mut data: serde_json::Value = serde_yml::from_str(yaml_str) + .map_err(|e| miette::miette!("failed to parse YAML data: {e}"))?; + + // Normalize port → ports for all endpoints so Rego always sees "ports" array. + normalize_endpoint_ports(&mut data); + + // Validate BEFORE expanding presets (catches user errors like rules+access) + let (errors, warnings) = crate::l7::validate_l7_policies(&data); + for w in &warnings { + openshell_ocsf::ocsf_emit!( + openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(openshell_ocsf::SeverityId::Medium) + .status(openshell_ocsf::StatusId::Success) + .state(openshell_ocsf::StateId::Enabled, "validated") + .unmapped("warning", serde_json::json!(w.clone())) + .message(format!("L7 policy validation warning: {w}")) + .build() + ); + } + if !errors.is_empty() { + return Err(miette::miette!( + "L7 policy validation failed:\n{}", + errors.join("\n") + )); + } + + // Expand access presets to explicit rules after validation + crate::l7::expand_access_presets(&mut data); + + serde_json::to_string(&data).map_err(|e| miette::miette!("failed to serialize data: {e}")) +} + +/// Normalize endpoint port/ports in JSON data. +/// +/// YAML policies may use `port: N` (single) or `ports: [N, M]` (multi). +/// This normalizes all endpoints to have a `ports` array so Rego rules +/// only need to reference `endpoint.ports[_]`. +fn normalize_endpoint_ports(data: &mut serde_json::Value) { + let Some(policies) = data + .get_mut("network_policies") + .and_then(|v| v.as_object_mut()) + else { + return; + }; + + for (_name, policy) in policies.iter_mut() { + let Some(endpoints) = policy.get_mut("endpoints").and_then(|v| v.as_array_mut()) else { + continue; + }; + + for ep in endpoints.iter_mut() { + let Some(ep_obj) = ep.as_object_mut() else { + continue; + }; + + // If "ports" already exists and is non-empty, keep it. + let has_ports = ep_obj + .get("ports") + .and_then(|v| v.as_array()) + .is_some_and(|a| !a.is_empty()); + + if !has_ports { + // Promote scalar "port" to "ports" array. + let port = ep_obj + .get("port") + .and_then(serde_json::Value::as_u64) + .unwrap_or(0); + if port > 0 { + ep_obj.insert( + "ports".to_string(), + serde_json::Value::Array(vec![serde_json::json!(port)]), + ); + } + } + + // Remove scalar "port" — Rego only uses "ports". + ep_obj.remove("port"); + } + } +} + +/// Resolve a policy binary path through the container's root filesystem. +/// +/// On Linux, `/proc//root/` provides access to the container's mount +/// namespace. If the policy path is a symlink inside the container +/// (e.g., `/usr/bin/python3` → `/usr/bin/python3.11`), returns the +/// canonical target path. Returns `None` if: +/// - Not on Linux +/// - `entrypoint_pid` is 0 (container not yet started) +/// - Path contains glob characters +/// - Path is not a symlink +/// - Resolution fails (binary doesn't exist in container) +/// - Resolved path equals the original +/// +/// Normalize a path by resolving `.` and `..` components without touching +/// the filesystem. Only works correctly for absolute paths. +#[cfg(any(target_os = "linux", test))] +fn normalize_path(path: &Path) -> PathBuf { + let mut result = PathBuf::new(); + for component in path.components() { + match component { + std::path::Component::ParentDir => { + result.pop(); + } + std::path::Component::CurDir => {} + other => result.push(other), + } + } + result +} + +#[cfg(target_os = "linux")] +fn resolve_binary_in_container(policy_path: &str, entrypoint_pid: u32) -> Option { + if policy_path.contains('*') || entrypoint_pid == 0 { + return None; + } + + // Walk the symlink chain inside the container filesystem using + // read_link rather than canonicalize. canonicalize resolves + // /proc//root itself (a kernel pseudo-symlink to /) which + // strips the prefix we need. read_link only reads the target of + // the specified symlink, keeping us in the container's namespace. + let mut resolved = PathBuf::from(policy_path); + + // Linux SYMLOOP_MAX is 40; stop before infinite loops + for _ in 0..40 { + let container_path = format!("/proc/{entrypoint_pid}/root{}", resolved.display()); + + tracing::debug!( + "Symlink resolution: probing container_path={container_path} for policy_path={policy_path} pid={entrypoint_pid}" + ); + + let meta = match std::fs::symlink_metadata(&container_path) { + Ok(m) => m, + Err(e) => { + // Only warn on the first iteration (the original policy path). + // On subsequent iterations, the intermediate target may + // legitimately not exist (broken symlink chain). + if resolved.as_os_str() == policy_path { + tracing::warn!( + "Cannot access container filesystem for symlink resolution: \ + path={policy_path} container_path={container_path} pid={entrypoint_pid} \ + error={e}. Binary paths in policy will be matched literally. \ + If this binary is a symlink (e.g., /usr/bin/python3 -> python3.11), \ + use the canonical path instead, or run with CAP_SYS_PTRACE." + ); + } else { + tracing::warn!( + "Symlink chain broken during resolution: \ + original={policy_path} current={} pid={entrypoint_pid} error={e}. \ + Binary will be matched by original path only.", + resolved.display() + ); + } + return None; + } + }; + + if !meta.file_type().is_symlink() { + // Reached a non-symlink — this is the final resolved target + break; + } + + let target = match std::fs::read_link(&container_path) { + Ok(t) => t, + Err(e) => { + tracing::warn!( + "Symlink detected but read_link failed: \ + path={policy_path} current={} pid={entrypoint_pid} error={e}. \ + Binary will be matched by original path only.", + resolved.display() + ); + return None; + } + }; + + if target.is_absolute() { + resolved = target; + } else { + // Relative symlink: resolve against the containing directory + // e.g., /usr/bin/python3 -> python3.11 becomes /usr/bin/python3.11 + if let Some(parent) = resolved.parent() { + resolved = normalize_path(&parent.join(&target)); + } else { + break; + } + } + } + + let resolved_str = resolved.to_string_lossy().into_owned(); + + if resolved_str == policy_path { + None + } else { + tracing::info!( + "Resolved policy binary symlink via container filesystem: \ + original={policy_path} resolved={resolved_str} pid={entrypoint_pid}" + ); + Some(resolved_str) + } +} + +#[cfg(not(target_os = "linux"))] +fn resolve_binary_in_container(_policy_path: &str, _entrypoint_pid: u32) -> Option { + None +} + +/// Convert typed proto policy fields to JSON suitable for `engine.add_data_json()`. +/// +/// The rego rules reference `data.*` directly, so the JSON structure has +/// top-level keys matching the data expectations: +/// - `data.filesystem_policy` +/// - `data.landlock` +/// - `data.process` +/// - `data.network_policies` +/// +/// When `entrypoint_pid` is non-zero, binary paths that are symlinks inside +/// the container filesystem are resolved via `/proc//root/` and added +/// as additional entries alongside the original path. This ensures that +/// user-specified symlink paths (e.g., `/usr/bin/python3`) match the +/// kernel-resolved canonical paths reported by `/proc//exe` (e.g., +/// `/usr/bin/python3.11`). +fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> String { + let filesystem_policy = proto.filesystem.as_ref().map_or_else( + || { + serde_json::json!({ + "include_workdir": true, + "read_only": [], + "read_write": [], + }) + }, + |fs| { + serde_json::json!({ + "include_workdir": fs.include_workdir, + "read_only": fs.read_only, + "read_write": fs.read_write, + }) + }, + ); + + let landlock = proto.landlock.as_ref().map_or_else( + || serde_json::json!({"compatibility": "best_effort"}), + |ll| serde_json::json!({"compatibility": ll.compatibility}), + ); + + let process = proto.process.as_ref().map_or_else( + || { + serde_json::json!({ + "run_as_user": "sandbox", + "run_as_group": "sandbox", + }) + }, + |p| { + serde_json::json!({ + "run_as_user": p.run_as_user, + "run_as_group": p.run_as_group, + }) + }, + ); + + let network_policies: serde_json::Map = proto + .network_policies + .iter() + .map(|(key, rule)| { + let endpoints: Vec = rule + .endpoints + .iter() + .map(|e| { + // Normalize port/ports: ports takes precedence, then + // single port promoted to array. Rego always sees "ports". + let ports: Vec = if !e.ports.is_empty() { + e.ports.clone() + } else if e.port > 0 { + vec![e.port] + } else { + vec![] + }; + let mut ep = serde_json::json!({"host": e.host, "ports": ports}); + if !e.path.is_empty() { + ep["path"] = e.path.clone().into(); + } + if !e.protocol.is_empty() { + ep["protocol"] = e.protocol.clone().into(); + } + if !e.tls.is_empty() { + ep["tls"] = e.tls.clone().into(); + } + if !e.enforcement.is_empty() { + ep["enforcement"] = e.enforcement.clone().into(); + } + if !e.access.is_empty() { + ep["access"] = e.access.clone().into(); + } + if !e.rules.is_empty() { + let rules: Vec = e + .rules + .iter() + .map(|r| { + let a = r.allow.as_ref(); + let mut allow = serde_json::json!({ + "method": a.map_or("", |a| &a.method), + "path": a.map_or("", |a| &a.path), + "command": a.map_or("", |a| &a.command), + "operation_type": a.map_or("", |a| &a.operation_type), + "operation_name": a.map_or("", |a| &a.operation_name), + "rpc_method": a.map_or("", |a| &a.rpc_method), + }); + if let Some(a) = a + && !a.fields.is_empty() + { + allow["fields"] = a.fields.clone().into(); + } + let query: serde_json::Map = a + .map(|allow| { + allow + .query + .iter() + .map(|(key, matcher)| { + let mut matcher_json = serde_json::json!({}); + if !matcher.glob.is_empty() { + matcher_json["glob"] = + matcher.glob.clone().into(); + } + if !matcher.any.is_empty() { + matcher_json["any"] = + matcher.any.clone().into(); + } + (key.clone(), matcher_json) + }) + .collect() + }) + .unwrap_or_default(); + if !query.is_empty() { + allow["query"] = query.into(); + } + serde_json::json!({ "allow": allow }) + }) + .collect(); + ep["rules"] = rules.into(); + } + if !e.allowed_ips.is_empty() { + ep["allowed_ips"] = e.allowed_ips.clone().into(); + } + if e.advisor_proposed { + ep["advisor_proposed"] = true.into(); + } + if !e.deny_rules.is_empty() { + let deny_rules: Vec = e + .deny_rules + .iter() + .map(|d| { + let mut deny = serde_json::json!({}); + if !d.method.is_empty() { + deny["method"] = d.method.clone().into(); + } + if !d.path.is_empty() { + deny["path"] = d.path.clone().into(); + } + if !d.command.is_empty() { + deny["command"] = d.command.clone().into(); + } + if !d.operation_type.is_empty() { + deny["operation_type"] = d.operation_type.clone().into(); + } + if !d.operation_name.is_empty() { + deny["operation_name"] = d.operation_name.clone().into(); + } + if !d.fields.is_empty() { + deny["fields"] = d.fields.clone().into(); + } + let query: serde_json::Map = d + .query + .iter() + .map(|(key, matcher)| { + let mut matcher_json = serde_json::json!({}); + if !matcher.glob.is_empty() { + matcher_json["glob"] = matcher.glob.clone().into(); + } + if !matcher.any.is_empty() { + matcher_json["any"] = matcher.any.clone().into(); + } + (key.clone(), matcher_json) + }) + .collect(); + if !query.is_empty() { + deny["query"] = query.into(); + } + deny + }) + .collect(); + ep["deny_rules"] = deny_rules.into(); + } + if e.allow_encoded_slash { + ep["allow_encoded_slash"] = true.into(); + } + if e.websocket_credential_rewrite { + ep["websocket_credential_rewrite"] = true.into(); + } + if e.request_body_credential_rewrite { + ep["request_body_credential_rewrite"] = true.into(); + } + if !e.persisted_queries.is_empty() { + ep["persisted_queries"] = e.persisted_queries.clone().into(); + } + if !e.graphql_persisted_queries.is_empty() { + let persisted: serde_json::Map = e + .graphql_persisted_queries + .iter() + .map(|(key, op)| { + ( + key.clone(), + serde_json::json!({ + "operation_type": op.operation_type, + "operation_name": op.operation_name, + "fields": op.fields, + }), + ) + }) + .collect(); + ep["graphql_persisted_queries"] = persisted.into(); + } + if e.graphql_max_body_bytes > 0 { + ep["graphql_max_body_bytes"] = e.graphql_max_body_bytes.into(); + } + ep + }) + .collect(); + let binaries: Vec = rule + .binaries + .iter() + .flat_map(|b| { + // The deprecated harness bit is ignored by policy YAML, but + // advisor-generated proposals use it as internal provenance. + #[allow(deprecated)] + let advisor_proposed = b.harness; + let binary_entry = |path: &str| { + let mut entry = serde_json::json!({"path": path}); + if advisor_proposed { + entry["advisor_proposed"] = true.into(); + } + entry + }; + let mut entries = vec![binary_entry(&b.path)]; + if let Some(resolved) = resolve_binary_in_container(&b.path, entrypoint_pid) { + entries.push(binary_entry(&resolved)); + } + entries + }) + .collect(); + ( + key.clone(), + serde_json::json!({ + "name": rule.name, + "endpoints": endpoints, + "binaries": binaries, + }), + ) + }) + .collect(); + + serde_json::json!({ + "filesystem_policy": filesystem_policy, + "landlock": landlock, + "process": process, + "network_policies": network_policies, + }) + .to_string() +} + +#[cfg(test)] +#[allow( + clippy::needless_raw_string_hashes, + clippy::similar_names, + clippy::doc_markdown, + clippy::match_wildcard_for_single_variants, + reason = "Test code: test fixtures and panic-on-unexpected matches are idiomatic in tests." +)] +mod tests { + use super::*; + + use openshell_core::proto::{ + FilesystemPolicy as ProtoFs, L7Allow, L7QueryMatcher, L7Rule, NetworkBinary, + NetworkEndpoint, NetworkPolicyRule, ProcessPolicy as ProtoProc, + SandboxPolicy as ProtoSandboxPolicy, + }; + + const TEST_POLICY: &str = include_str!("../data/sandbox-policy.rego"); + const TEST_DATA_YAML: &str = include_str!("../testdata/sandbox-policy.yaml"); + + fn test_engine() -> OpaEngine { + OpaEngine::from_strings(TEST_POLICY, TEST_DATA_YAML).expect("Failed to load test policy") + } + + fn test_proto() -> ProtoSandboxPolicy { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "claude_code".to_string(), + NetworkPolicyRule { + name: "claude_code".to_string(), + endpoints: vec![ + NetworkEndpoint { + host: "api.anthropic.com".to_string(), + port: 443, + ..Default::default() + }, + NetworkEndpoint { + host: "statsig.anthropic.com".to_string(), + port: 443, + ..Default::default() + }, + ], + binaries: vec![NetworkBinary { + path: "/usr/local/bin/claude".to_string(), + ..Default::default() + }], + }, + ); + network_policies.insert( + "gitlab".to_string(), + NetworkPolicyRule { + name: "gitlab".to_string(), + endpoints: vec![NetworkEndpoint { + host: "gitlab.com".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/glab".to_string(), + ..Default::default() + }], + }, + ); + ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec!["/usr".to_string(), "/lib".to_string()], + read_write: vec!["/sandbox".to_string(), "/tmp".to_string()], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + } + } + + #[test] + fn allowed_binary_and_endpoint() { + let engine = test_engine(); + // Simulates Claude Code: exe is /usr/bin/node, script is /usr/local/bin/claude + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Expected allow, got deny: {}", + decision.reason + ); + assert_eq!(decision.matched_policy.as_deref(), Some("claude_code")); + } + + #[test] + fn wrong_binary_denied() { + let engine = test_engine(); + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + assert!( + decision.reason.contains("not allowed"), + "Expected specific deny reason, got: {}", + decision.reason + ); + } + + #[test] + fn wrong_endpoint_denied() { + let engine = test_engine(); + let input = NetworkInput { + host: "evil.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + assert!( + decision.reason.contains("endpoint"), + "Expected endpoint deny reason, got: {}", + decision.reason + ); + } + + #[test] + fn unknown_binary_default_deny() { + let engine = test_engine(); + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/tmp/malicious"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + } + + #[test] + fn github_policy_allows_git() { + let engine = test_engine(); + let input = NetworkInput { + host: "github.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/git"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Expected allow, got deny: {}", + decision.reason + ); + assert_eq!( + decision.matched_policy.as_deref(), + Some("github_ssh_over_https") + ); + } + + #[test] + fn case_insensitive_host_matching() { + let engine = test_engine(); + let input = NetworkInput { + host: "API.ANTHROPIC.COM".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Expected case-insensitive match, got deny: {}", + decision.reason + ); + } + + #[test] + fn wrong_port_denied() { + let engine = test_engine(); + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 80, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + } + + #[test] + fn query_sandbox_config_extracts_filesystem() { + let engine = test_engine(); + let config = engine.query_sandbox_config().unwrap(); + assert!(config.filesystem.include_workdir); + assert!(config.filesystem.read_only.contains(&PathBuf::from("/usr"))); + assert!( + config + .filesystem + .read_write + .contains(&PathBuf::from("/tmp")) + ); + } + + #[test] + fn query_sandbox_config_extracts_process() { + let engine = test_engine(); + let config = engine.query_sandbox_config().unwrap(); + assert_eq!(config.process.run_as_user.as_deref(), Some("sandbox")); + assert_eq!(config.process.run_as_group.as_deref(), Some("sandbox")); + } + + #[test] + fn from_strings_and_from_files_produce_same_results() { + let engine = test_engine(); + + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(decision.allowed); + } + + #[test] + fn reload_replaces_policy() { + let engine = test_engine(); + + // Verify initial policy works + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(decision.allowed); + + // Reload with a policy that has no network policies (deny all) + let empty_data = r" +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +network_policies: {} +"; + engine.reload(TEST_POLICY, empty_data).unwrap(); + + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + !decision.allowed, + "Expected deny after reload with empty policies" + ); + } + + #[test] + fn ancestor_binary_allowed() { + // Use github policy: binary /usr/bin/git is the policy binary. + // If the socket process is /usr/bin/python3 but its ancestor is /usr/bin/git, allow. + let engine = test_engine(); + let input = NetworkInput { + host: "github.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![PathBuf::from("/usr/bin/git")], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Expected allow via ancestor match, got deny: {}", + decision.reason + ); + assert_eq!( + decision.matched_policy.as_deref(), + Some("github_ssh_over_https") + ); + } + + #[test] + fn no_ancestor_match_denied() { + let engine = test_engine(); + let input = NetworkInput { + host: "github.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![PathBuf::from("/usr/bin/bash")], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + assert!( + decision.reason.contains("not allowed"), + "Expected 'not allowed' in deny reason, got: {}", + decision.reason + ); + } + + #[test] + fn deep_ancestor_chain_matches() { + let engine = test_engine(); + let input = NetworkInput { + host: "github.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![PathBuf::from("/usr/bin/sh"), PathBuf::from("/usr/bin/git")], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Expected allow via deep ancestor match, got deny: {}", + decision.reason + ); + } + + #[test] + fn empty_ancestors_falls_back_to_direct() { + let engine = test_engine(); + // Direct binary path match still works with empty ancestors and cmdline + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Direct path match should still work with empty ancestors" + ); + } + + #[test] + fn glob_pattern_matches_binary() { + // Test with a policy that uses glob patterns + let glob_data = r#" +network_policies: + glob_test: + name: glob_test + endpoints: + - { host: example.com, port: 443 } + binaries: + - { path: "/usr/bin/*" } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, glob_data).unwrap(); + let input = NetworkInput { + host: "example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Expected glob pattern to match binary, got deny: {}", + decision.reason + ); + } + + #[test] + fn glob_pattern_matches_ancestor() { + let glob_data = r#" +network_policies: + glob_test: + name: glob_test + endpoints: + - { host: example.com, port: 443 } + binaries: + - { path: "/usr/local/bin/*" } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, glob_data).unwrap(); + let input = NetworkInput { + host: "example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![PathBuf::from("/usr/local/bin/claude")], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Expected glob pattern to match ancestor, got deny: {}", + decision.reason + ); + } + + #[test] + fn glob_pattern_no_cross_segment() { + // * should NOT match across / boundaries + let glob_data = r#" +network_policies: + glob_test: + name: glob_test + endpoints: + - { host: example.com, port: 443 } + binaries: + - { path: "/usr/bin/*" } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, glob_data).unwrap(); + let input = NetworkInput { + host: "example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/subdir/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed, "Glob * should not cross / boundaries"); + } + + #[test] + fn cmdline_path_does_not_grant_access() { + // Simulates: node runs /usr/local/bin/my-tool (a script with shebang). + // exe = /usr/bin/node, cmdline contains /usr/local/bin/my-tool. + // cmdline_paths are attacker-controlled (argv[0] spoofing) and must + // NOT be used as a grant-access signal. + let cmdline_data = r" +network_policies: + script_test: + name: script_test + endpoints: + - { host: example.com, port: 443 } + binaries: + - { path: /usr/local/bin/my-tool } +"; + let engine = OpaEngine::from_strings(TEST_POLICY, cmdline_data).unwrap(); + let input = NetworkInput { + host: "example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![PathBuf::from("/usr/bin/bash")], + cmdline_paths: vec![PathBuf::from("/usr/local/bin/my-tool")], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + !decision.allowed, + "cmdline_paths must not grant network access (argv[0] is spoofable)" + ); + } + + #[test] + fn cmdline_path_no_match_denied() { + let cmdline_data = r" +network_policies: + script_test: + name: script_test + endpoints: + - { host: example.com, port: 443 } + binaries: + - { path: /usr/local/bin/my-tool } +"; + let engine = OpaEngine::from_strings(TEST_POLICY, cmdline_data).unwrap(); + let input = NetworkInput { + host: "example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![PathBuf::from("/usr/bin/bash")], + cmdline_paths: vec![ + PathBuf::from("/usr/bin/node"), + PathBuf::from("/tmp/script.js"), + ], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + } + + #[test] + fn cmdline_glob_pattern_does_not_grant_access() { + let glob_data = r#" +network_policies: + glob_test: + name: glob_test + endpoints: + - { host: example.com, port: 443 } + binaries: + - { path: "/usr/local/bin/*" } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, glob_data).unwrap(); + let input = NetworkInput { + host: "example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + !decision.allowed, + "cmdline_paths must not match globs for granting access (argv[0] is spoofable)" + ); + } + + #[test] + fn from_proto_allows_matching_request() { + let proto = test_proto(); + let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Expected allow from proto-based engine, got deny: {}", + decision.reason + ); + assert_eq!(decision.matched_policy.as_deref(), Some("claude_code")); + } + + #[test] + fn from_proto_denies_unmatched_request() { + let proto = test_proto(); + let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); + let input = NetworkInput { + host: "evil.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + } + + #[test] + fn from_proto_extracts_sandbox_config() { + let proto = test_proto(); + let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); + let config = engine.query_sandbox_config().unwrap(); + assert!(config.filesystem.include_workdir); + assert!(config.filesystem.read_only.contains(&PathBuf::from("/usr"))); + assert!( + config + .filesystem + .read_write + .contains(&PathBuf::from("/tmp")) + ); + assert_eq!(config.process.run_as_user.as_deref(), Some("sandbox")); + assert_eq!(config.process.run_as_group.as_deref(), Some("sandbox")); + } + + // ======================================================================== + // L7 request evaluation tests + // ======================================================================== + + const L7_TEST_DATA: &str = r#" +network_policies: + rest_api: + name: rest_api + endpoints: + - host: api.example.com + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/repos/**" + - allow: + method: POST + path: "/repos/*/issues" + binaries: + - { path: /usr/bin/curl } + readonly_api: + name: readonly_api + endpoints: + - host: api.readonly.com + port: 8080 + protocol: rest + enforcement: enforce + access: read-only + binaries: + - { path: /usr/bin/curl } + full_api: + name: full_api + endpoints: + - host: api.full.com + port: 8080 + protocol: rest + enforcement: audit + access: full + binaries: + - { path: /usr/bin/curl } + query_api: + name: query_api + endpoints: + - host: api.query.com + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "/download" + query: + tag: "foo-*" + - allow: + method: GET + path: "/search" + query: + tag: + any: ["foo-*", "bar-*"] + binaries: + - { path: /usr/bin/curl } + graphql_api: + name: graphql_api + endpoints: + - host: api.graphql.com + port: 443 + protocol: graphql + enforcement: enforce + persisted_queries: allow_registered + graphql_persisted_queries: + abc123: + operation_type: query + operation_name: Viewer + fields: [viewer] + rules: + - allow: + operation_type: query + fields: [viewer, repository] + - allow: + operation_type: mutation + operation_name: Issue* + fields: [createIssue, deleteRepository] + deny_rules: + - operation_type: mutation + fields: [deleteRepository] + binaries: + - { path: /usr/bin/curl } + graphql_readonly: + name: graphql_readonly + endpoints: + - host: gql.readonly.com + port: 443 + protocol: graphql + enforcement: enforce + access: read-only + binaries: + - { path: /usr/bin/curl } + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + - allow: + operation_type: subscription + fields: [messageAdded] + deny_rules: + - operation_type: mutation + binaries: + - { path: /usr/bin/curl } + l4_only: + name: l4_only + endpoints: + - { host: l4only.example.com, port: 443 } + binaries: + - { path: /usr/bin/curl } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + + fn l7_engine() -> OpaEngine { + OpaEngine::from_strings(TEST_POLICY, L7_TEST_DATA).expect("Failed to load L7 test data") + } + + fn l7_input(host: &str, port: u16, method: &str, path: &str) -> serde_json::Value { + l7_input_with_query(host, port, method, path, serde_json::json!({})) + } + + fn l7_input_with_query( + host: &str, + port: u16, + method: &str, + path: &str, + query_params: serde_json::Value, + ) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": port }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": method, + "path": path, + "query_params": query_params + } + }) + } + + fn l7_jsonrpc_input(host: &str, port: u16, path: &str, rpc_method: &str) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": port }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "POST", + "path": path, + "query_params": {}, + "jsonrpc": { + "method": rpc_method + } + } + }) + } + + fn l7_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": 443 }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "POST", + "path": "/graphql", + "query_params": {}, + "graphql": { + "operations": operations + } + } + }) + } + + fn l7_graphql_error_input(host: &str, error: &str) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": 443 }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "POST", + "path": "/graphql", + "query_params": {}, + "graphql": { + "operations": [], + "error": error + } + } + }) + } + + fn l7_websocket_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": 443 }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "WEBSOCKET_TEXT", + "path": "/graphql", + "query_params": {}, + "graphql": { + "operations": operations + } + } + }) + } + + fn eval_l7(engine: &OpaEngine, input: &serde_json::Value) -> bool { + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + val == regorus::Value::from(true) + } + + #[test] + fn l7_get_allowed_by_rules() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "GET", "/repos/myorg/foo"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_post_allowed_by_rules() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "POST", "/repos/myorg/issues"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_delete_denied_by_rules() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "DELETE", "/repos/myorg/foo"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_get_wrong_path_denied() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "GET", "/admin/settings"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_readonly_preset_allows_get() { + let engine = l7_engine(); + let input = l7_input("api.readonly.com", 8080, "GET", "/anything"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_readonly_preset_allows_head() { + let engine = l7_engine(); + let input = l7_input("api.readonly.com", 8080, "HEAD", "/anything"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_readonly_preset_allows_options() { + let engine = l7_engine(); + let input = l7_input("api.readonly.com", 8080, "OPTIONS", "/anything"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_readonly_preset_denies_post() { + let engine = l7_engine(); + let input = l7_input("api.readonly.com", 8080, "POST", "/anything"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_readonly_preset_denies_delete() { + let engine = l7_engine(); + let input = l7_input("api.readonly.com", 8080, "DELETE", "/anything"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_full_preset_allows_everything() { + let engine = l7_engine(); + for method in &["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"] { + let input = l7_input("api.full.com", 8080, method, "/any/path"); + assert!( + eval_l7(&engine, &input), + "{method} should be allowed with full preset" + ); + } + } + + #[test] + fn l7_graphql_query_allowed_by_field_rule() { + let engine = l7_engine(); + let input = l7_graphql_input( + "api.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "operation_name": "RepoLookup", + "fields": ["repository"], + "persisted_query": false + }]), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_graphql_unlisted_field_denied() { + let engine = l7_engine(); + let input = l7_graphql_input( + "api.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["viewer", "adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_graphql_batch_denied_if_any_operation_unallowed() { + let engine = l7_engine(); + let input = l7_graphql_input( + "api.graphql.com", + serde_json::json!([ + { + "operation_type": "query", + "fields": ["viewer"], + "persisted_query": false + }, + { + "operation_type": "mutation", + "operation_name": "DeleteRepo", + "fields": ["deleteRepository"], + "persisted_query": false + } + ]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_graphql_deny_rule_takes_precedence() { + let engine = l7_engine(); + let input = l7_graphql_input( + "api.graphql.com", + serde_json::json!([{ + "operation_type": "mutation", + "operation_name": "IssueDelete", + "fields": ["deleteRepository"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_graphql_registered_hash_only_query_allowed() { + let engine = l7_engine(); + let input = l7_graphql_input( + "api.graphql.com", + serde_json::json!([{ + "operation_type": "", + "operation_name": "Viewer", + "fields": [], + "persisted_query": true, + "persisted_query_hash": "abc123" + }]), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_graphql_unregistered_hash_only_query_denied() { + let engine = l7_engine(); + let input = l7_graphql_input( + "api.graphql.com", + serde_json::json!([{ + "operation_type": "", + "operation_name": "Viewer", + "fields": [], + "persisted_query": true, + "persisted_query_hash": "missing" + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_graphql_unregistered_hash_only_query_has_deny_reason() { + let engine = l7_engine(); + let input = l7_graphql_input( + "api.graphql.com", + serde_json::json!([{ + "operation_type": "", + "operation_name": "Viewer", + "fields": [], + "persisted_query": true, + "persisted_query_hash": "missing" + }]), + ); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.request_deny_reason".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::String("GraphQL persisted query is not registered".into()) + ); + } + + #[test] + fn l7_graphql_parse_error_denied() { + let engine = l7_engine(); + let input = l7_graphql_error_input("api.graphql.com", "GraphQL document parse error"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_graphql_readonly_access_allows_query_and_denies_mutation() { + let engine = l7_engine(); + let query = l7_graphql_input( + "gql.readonly.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["viewer"], + "persisted_query": false + }]), + ); + assert!(eval_l7(&engine, &query)); + + let mutation = l7_graphql_input( + "gql.readonly.com", + serde_json::json!([{ + "operation_type": "mutation", + "fields": ["createIssue"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &mutation)); + } + + #[test] + fn l7_websocket_graphql_subscription_allowed_by_field_rule() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "subscription", + "operation_name": "NewMessages", + "fields": ["messageAdded"], + "persisted_query": false + }]), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_unlisted_field_denied() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_deny_rule_takes_precedence() { + let engine = l7_engine(); + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "mutation", + "operation_name": "DeleteRepo", + "fields": ["deleteRepository"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_websocket_graphql_not_bypassed_by_generic_text_rule() { + let data = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: realtime.graphql.com + ports: [443] + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + method: WEBSOCKET_TEXT + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + binaries: + - { path: /usr/bin/curl } +"#; + let data_json: serde_json::Value = + serde_yml::from_str(data).expect("fixture should parse as YAML"); + let mut rego = regorus::Engine::new(); + rego.add_policy("policy.rego".into(), TEST_POLICY.into()) + .expect("policy should load"); + rego.add_data_json(&data_json.to_string()) + .expect("data should load"); + let engine = OpaEngine { + engine: Mutex::new(rego), + generation: Arc::new(AtomicU64::new(0)), + }; + let input = l7_websocket_graphql_input( + "realtime.graphql.com", + serde_json::json!([{ + "operation_type": "query", + "fields": ["adminAuditLog"], + "persisted_query": false + }]), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_endpoint_path_scopes_rest_and_graphql_on_same_host() { + let data = r#" +network_policies: + mixed_api: + name: mixed_api + endpoints: + - host: api.github.test + port: 443 + path: "/repos/**" + protocol: rest + enforcement: enforce + rules: + - allow: + method: "*" + path: "/**" + - host: api.github.test + port: 443 + path: "/graphql" + protocol: graphql + enforcement: enforce + rules: + - allow: + operation_type: query + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + + let rest_write = l7_input("api.github.test", 443, "POST", "/repos/org/repo/issues"); + assert!(eval_l7(&engine, &rest_write)); + + let graphql_query = l7_graphql_input( + "api.github.test", + serde_json::json!([{ + "operation_type": "query", + "fields": ["viewer"], + "persisted_query": false + }]), + ); + assert!(eval_l7(&engine, &graphql_query)); + + let graphql_mutation = l7_graphql_input( + "api.github.test", + serde_json::json!([{ + "operation_type": "mutation", + "fields": ["deleteRepository"], + "persisted_query": false + }]), + ); + assert!( + !eval_l7(&engine, &graphql_mutation), + "REST rules on the same host must not allow a GraphQL mutation" + ); + } + + #[test] + fn l7_method_matching_case_insensitive() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "get", "/repos/myorg/foo"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_path_glob_matching() { + let engine = l7_engine(); + // /repos/** should match /repos/org/repo + let input = l7_input("api.example.com", 8080, "GET", "/repos/org/repo"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_glob_allows_matching_duplicate_values() { + let engine = l7_engine(); + let input = l7_input_with_query( + "api.query.com", + 8080, + "GET", + "/download", + serde_json::json!({ + "tag": ["foo-a", "foo-b"], + "extra": ["ignored"], + }), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_glob_denies_on_mismatched_duplicate_value() { + let engine = l7_engine(); + let input = l7_input_with_query( + "api.query.com", + 8080, + "GET", + "/download", + serde_json::json!({ + "tag": ["foo-a", "evil"], + }), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_any_allows_if_every_value_matches_any_pattern() { + let engine = l7_engine(); + let input = l7_input_with_query( + "api.query.com", + 8080, + "GET", + "/search", + serde_json::json!({ + "tag": ["foo-a", "bar-b"], + }), + ); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_missing_required_key_denied() { + let engine = l7_engine(); + let input = l7_input_with_query( + "api.query.com", + 8080, + "GET", + "/download", + serde_json::json!({}), + ); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_query_rules_from_proto_are_enforced() { + let mut query = std::collections::HashMap::new(); + query.insert( + "tag".to_string(), + L7QueryMatcher { + glob: "foo-*".to_string(), + any: vec![], + }, + ); + + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "query_proto".to_string(), + NetworkPolicyRule { + name: "query_proto".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.proto.com".to_string(), + port: 8080, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + rules: vec![L7Rule { + allow: Some(L7Allow { + method: "GET".to_string(), + path: "/download".to_string(), + command: String::new(), + query, + operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + rpc_method: String::new(), + }), + }], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let allow_input = l7_input_with_query( + "api.proto.com", + 8080, + "GET", + "/download", + serde_json::json!({ "tag": ["foo-a"] }), + ); + assert!(eval_l7(&engine, &allow_input)); + + let deny_input = l7_input_with_query( + "api.proto.com", + 8080, + "GET", + "/download", + serde_json::json!({ "tag": ["evil"] }), + ); + assert!(!eval_l7(&engine, &deny_input)); + } + + #[test] + fn l7_jsonrpc_rpc_method_from_proto_is_enforced() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "jsonrpc_proto".to_string(), + NetworkPolicyRule { + name: "jsonrpc_proto".to_string(), + endpoints: vec![NetworkEndpoint { + host: "mcp.proto.com".to_string(), + port: 8000, + path: "/mcp".to_string(), + protocol: "json-rpc".to_string(), + enforcement: "enforce".to_string(), + rules: vec![L7Rule { + allow: Some(L7Allow { + method: String::new(), + path: String::new(), + command: String::new(), + query: std::collections::HashMap::new(), + operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + rpc_method: "initialize".to_string(), + }), + }], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let allow_input = l7_jsonrpc_input("mcp.proto.com", 8000, "/mcp", "initialize"); + assert!(eval_l7(&engine, &allow_input)); + + let deny_input = l7_jsonrpc_input("mcp.proto.com", 8000, "/mcp", "tools/list"); + assert!(!eval_l7(&engine, &deny_input)); + } + + #[test] + fn l7_no_request_on_l4_only_endpoint() { + // L4-only endpoint should not match L7 allow_request + let engine = l7_engine(); + let input = l7_input("l4only.example.com", 443, "GET", "/anything"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_wrong_binary_denied_even_with_matching_rules() { + let engine = l7_engine(); + let input = serde_json::json!({ + "network": { "host": "api.example.com", "port": 8080 }, + "exec": { + "path": "/usr/bin/python3", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "GET", + "path": "/repos/myorg/foo" + } + }); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_deny_reason_populated() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "DELETE", "/repos/myorg/foo"); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.request_deny_reason".into()) + .unwrap(); + let reason = match val { + regorus::Value::String(s) => s.to_string(), + _ => String::new(), + }; + assert!( + reason.contains("not permitted"), + "Expected deny reason, got: {reason}" + ); + } + + #[test] + fn l7_endpoint_config_returned_for_l7_endpoint() { + let engine = l7_engine(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let config = engine.query_endpoint_config(&input).unwrap(); + assert!(config.is_some(), "Expected L7 config for rest endpoint"); + let config = config.unwrap(); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert_eq!(l7.protocol, crate::l7::L7Protocol::Rest); + assert_eq!(l7.enforcement, crate::l7::EnforcementMode::Enforce); + } + + #[test] + fn l7_endpoint_config_preserves_proto_allow_encoded_slash() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "npm".to_string(), + NetworkPolicyRule { + name: "npm".to_string(), + endpoints: vec![NetworkEndpoint { + host: "registry.npmjs.org".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "read-only".to_string(), + allow_encoded_slash: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "registry.npmjs.org".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.allow_encoded_slash); + } + + #[test] + fn l7_endpoint_config_preserves_proto_websocket_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "gateway".to_string(), + NetworkPolicyRule { + name: "gateway".to_string(), + endpoints: vec![NetworkEndpoint { + host: "gateway.example.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "full".to_string(), + websocket_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "gateway.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.websocket_credential_rewrite); + } + + #[test] + fn l7_endpoint_config_preserves_proto_request_body_credential_rewrite() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "slack".to_string(), + NetworkPolicyRule { + name: "slack".to_string(), + endpoints: vec![NetworkEndpoint { + host: "slack.com".to_string(), + port: 443, + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + access: "read-write".to_string(), + request_body_credential_rewrite: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/node".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "slack.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let config = engine + .query_endpoint_config(&input) + .unwrap() + .expect("endpoint config"); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert!(l7.request_body_credential_rewrite); + } + + #[test] + fn l7_endpoint_config_none_for_l4_only() { + let engine = l7_engine(); + let input = NetworkInput { + host: "l4only.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let config = engine.query_endpoint_config(&input).unwrap(); + assert!( + config.is_none(), + "Expected no L7 config for L4-only endpoint" + ); + } + + #[test] + fn l7_clone_engine_for_tunnel() { + let engine = l7_engine(); + let cloned = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + // Verify the cloned engine can evaluate + let input_json = l7_input("api.example.com", 8080, "GET", "/repos/myorg/foo"); + let mut eng = cloned.engine().lock().unwrap(); + eng.set_input_json(&input_json.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!(val, regorus::Value::from(true)); + } + + #[test] + fn policy_generation_starts_at_zero_and_increments_on_successful_reload() { + let engine = l7_engine(); + assert_eq!(engine.current_generation(), 0); + + engine.reload(TEST_POLICY, L7_TEST_DATA).unwrap(); + + assert_eq!(engine.current_generation(), 1); + } + + #[test] + fn policy_generation_does_not_increment_on_failed_reload() { + let engine = l7_engine(); + engine.reload(TEST_POLICY, L7_TEST_DATA).unwrap(); + assert_eq!(engine.current_generation(), 1); + + let invalid_l7_data = r#" +network_policies: + bad_api: + name: bad_api + endpoints: + - host: api.example.com + port: 8080 + protocol: invalid-protocol + binaries: + - { path: /usr/bin/curl } +"#; + assert!(engine.reload(TEST_POLICY, invalid_l7_data).is_err()); + assert_eq!(engine.current_generation(), 1); + + let input_json = l7_input("api.example.com", 8080, "GET", "/repos/myorg/foo"); + let cloned = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let mut eng = cloned.engine().lock().unwrap(); + eng.set_input_json(&input_json.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!(val, regorus::Value::from(true)); + } + + #[test] + fn endpoint_config_generation_matches_query_generation() { + let engine = l7_engine(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let (config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + assert!(config.is_some()); + assert_eq!(generation, engine.current_generation()); + + engine.reload(TEST_POLICY, L7_TEST_DATA).unwrap(); + + let (config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + assert!(config.is_some()); + assert_eq!(generation, engine.current_generation()); + assert_eq!(generation, 1); + } + + #[test] + fn tunnel_clone_rejects_stale_generation() { + let engine = l7_engine(); + let captured_generation = engine.current_generation(); + engine.reload(TEST_POLICY, L7_TEST_DATA).unwrap(); + + assert!(engine.clone_engine_for_tunnel(captured_generation).is_err()); + } + + // ======================================================================== + // Deny rules tests + // ======================================================================== + + const L7_DENY_TEST_DATA: &str = r#" +network_policies: + github_api: + name: github_api + endpoints: + - host: api.github.com + port: 443 + protocol: rest + enforcement: enforce + access: read-write + deny_rules: + - method: POST + path: "/repos/*/pulls/*/reviews" + - method: PUT + path: "/repos/*/branches/*/protection" + - method: "*" + path: "/repos/*/rulesets" + binaries: + - { path: /usr/bin/curl } + deny_with_query: + name: deny_with_query + endpoints: + - host: api.restricted.com + port: 443 + protocol: rest + enforcement: enforce + access: full + deny_rules: + - method: POST + path: "/admin/**" + query: + force: "true" + binaries: + - { path: /usr/bin/curl } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + + fn l7_deny_engine() -> OpaEngine { + OpaEngine::from_strings(TEST_POLICY, L7_DENY_TEST_DATA) + .expect("Failed to load deny test data") + } + + #[test] + fn l7_deny_rule_blocks_allowed_method_path() { + let engine = l7_deny_engine(); + // POST to reviews is allowed by read-write preset but denied by deny rule + let input = l7_input( + "api.github.com", + 443, + "POST", + "/repos/myorg/pulls/123/reviews", + ); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(false), + "deny rule should block POST to reviews" + ); + } + + #[test] + fn l7_deny_rule_allows_non_matching_requests() { + let engine = l7_deny_engine(); + // GET repos/issues is allowed and not denied + let input = l7_input("api.github.com", 443, "GET", "/repos/myorg/issues"); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(true), + "non-denied GET should be allowed" + ); + } + + #[test] + fn l7_deny_rule_allows_same_method_different_path() { + let engine = l7_deny_engine(); + // POST to issues is allowed (deny only targets reviews) + let input = l7_input("api.github.com", 443, "POST", "/repos/myorg/issues"); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(true), + "POST to issues should be allowed" + ); + } + + #[test] + fn l7_deny_rule_blocks_wildcard_method() { + let engine = l7_deny_engine(); + // GET /repos/myorg/rulesets should be denied (method: "*") + let input = l7_input("api.github.com", 443, "GET", "/repos/myorg/rulesets"); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(false), + "wildcard method deny should block GET" + ); + } + + #[test] + fn l7_deny_rule_blocks_put_protection() { + let engine = l7_deny_engine(); + let input = l7_input( + "api.github.com", + 443, + "PUT", + "/repos/myorg/branches/main/protection", + ); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(false), + "PUT to branch protection should be denied" + ); + } + + #[test] + fn l7_deny_reason_populated_when_deny_rule_matches() { + let engine = l7_deny_engine(); + let input = l7_input( + "api.github.com", + 443, + "POST", + "/repos/myorg/pulls/123/reviews", + ); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.request_deny_reason".into()) + .unwrap(); + let reason = match val { + regorus::Value::String(s) => s.to_string(), + _ => String::new(), + }; + assert!( + reason.contains("deny rule"), + "Expected deny rule reason, got: {reason}" + ); + } + + #[test] + fn l7_deny_rule_with_query_blocks_matching_params() { + let engine = l7_deny_engine(); + // POST /admin/settings with force=true should be denied + let input = l7_input_with_query( + "api.restricted.com", + 443, + "POST", + "/admin/settings", + serde_json::json!({"force": ["true"]}), + ); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(false), + "deny with matching query should block" + ); + } + + #[test] + fn l7_deny_rule_with_query_allows_non_matching_params() { + let engine = l7_deny_engine(); + // POST /admin/settings with force=false should be allowed (query doesn't match deny) + let input = l7_input_with_query( + "api.restricted.com", + 443, + "POST", + "/admin/settings", + serde_json::json!({"force": ["false"]}), + ); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(true), + "deny with non-matching query should allow" + ); + } + + #[test] + fn l7_deny_rule_with_query_blocks_when_any_value_matches() { + let engine = l7_deny_engine(); + // POST /admin/settings with force=true&force=false should STILL be denied + // because at least one value ("true") matches the deny rule. + // This is fail-closed: any matching value triggers the deny. + let input = l7_input_with_query( + "api.restricted.com", + 443, + "POST", + "/admin/settings", + serde_json::json!({"force": ["true", "false"]}), + ); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(false), + "deny should fire when ANY value matches, even with mixed values" + ); + } + + #[test] + fn l7_deny_rule_without_matching_query_key_allows() { + let engine = l7_deny_engine(); + // POST /admin/settings with no query params -- deny rule has query.force=true, + // so no match (key not present) and request should be allowed + let input = l7_input("api.restricted.com", 443, "POST", "/admin/settings"); + let mut eng = engine.engine.lock().unwrap(); + eng.set_input_json(&input.to_string()).unwrap(); + let val = eng + .eval_rule("data.openshell.sandbox.allow_request".into()) + .unwrap(); + assert_eq!( + val, + regorus::Value::from(true), + "deny without matching query key should allow" + ); + } + + // ======================================================================== + // Overlapping policies (duplicate host:port) — regression tests + // ======================================================================== + + /// Two network_policies entries covering the same host:port with L7 rules. + /// Before the fix, this caused regorus to fail with + /// "duplicated definition of local variable ep" in allow_request. + const OVERLAPPING_L7_TEST_DATA: &str = r#" +network_policies: + test_server: + name: test_server + endpoints: + - host: 192.168.1.100 + port: 8567 + protocol: rest + enforcement: enforce + rules: + - allow: + method: GET + path: "**" + binaries: + - { path: /usr/bin/curl } + allow_192_168_1_100_8567: + name: allow_192_168_1_100_8567 + endpoints: + - host: 192.168.1.100 + port: 8567 + protocol: rest + enforcement: enforce + allowed_ips: + - 192.168.1.100 + rules: + - allow: + method: GET + path: "**" + binaries: + - { path: /usr/bin/curl } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + + #[test] + fn l7_overlapping_policies_allow_request_does_not_crash() { + let engine = OpaEngine::from_strings(TEST_POLICY, OVERLAPPING_L7_TEST_DATA) + .expect("engine should load overlapping data"); + let input = l7_input("192.168.1.100", 8567, "GET", "/test"); + // Should not panic or error — must evaluate to true. + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_overlapping_policies_deny_request_does_not_crash() { + let engine = OpaEngine::from_strings(TEST_POLICY, OVERLAPPING_L7_TEST_DATA) + .expect("engine should load overlapping data"); + let input = l7_input("192.168.1.100", 8567, "DELETE", "/test"); + // DELETE is not in the rules, so should deny — but must not crash. + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn overlapping_policies_endpoint_config_returns_result() { + let engine = OpaEngine::from_strings(TEST_POLICY, OVERLAPPING_L7_TEST_DATA) + .expect("engine should load overlapping data"); + let input = NetworkInput { + host: "192.168.1.100".into(), + port: 8567, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: String::new(), + ancestors: vec![], + cmdline_paths: vec![], + }; + // Should return config from one of the entries without error. + let config = engine.query_endpoint_config(&input).unwrap(); + assert!( + config.is_some(), + "Expected endpoint config for overlapping policies" + ); + } + + // ======================================================================== + // network_action tests + // ======================================================================== + + const INFERENCE_TEST_DATA: &str = r#" +network_policies: + claude_code: + name: claude_code + endpoints: + - { host: api.anthropic.com, port: 443 } + binaries: + - { path: /usr/local/bin/claude } + gitlab: + name: gitlab + endpoints: + - { host: gitlab.com, port: 443 } + binaries: + - { path: /usr/bin/glab } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + + const NO_INFERENCE_TEST_DATA: &str = r#" +network_policies: + gitlab: + name: gitlab + endpoints: + - { host: gitlab.com, port: 443 } + binaries: + - { path: /usr/bin/glab } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + + fn inference_engine() -> OpaEngine { + OpaEngine::from_strings(TEST_POLICY, INFERENCE_TEST_DATA) + .expect("Failed to load inference test data") + } + + fn no_inference_engine() -> OpaEngine { + OpaEngine::from_strings(TEST_POLICY, NO_INFERENCE_TEST_DATA) + .expect("Failed to load no-inference test data") + } + + #[test] + fn explicitly_allowed_endpoint_binary_returns_allow() { + let engine = inference_engine(); + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + assert_eq!( + action, + NetworkAction::Allow { + matched_policy: Some("claude_code".to_string()) + }, + ); + } + + #[test] + fn unknown_endpoint_returns_deny() { + let engine = inference_engine(); + let input = NetworkInput { + host: "api.openai.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + match &action { + NetworkAction::Deny { .. } => {} + other => panic!("Expected Deny, got: {other:?}"), + } + } + + #[test] + fn unknown_endpoint_without_inference_returns_deny() { + let engine = no_inference_engine(); + let input = NetworkInput { + host: "api.openai.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + match &action { + NetworkAction::Deny { .. } => {} + other => panic!("Expected Deny, got: {other:?}"), + } + } + + #[test] + fn endpoint_in_policy_binary_not_allowed_returns_deny() { + // api.anthropic.com is declared but python3 is not in the binary list. + // With binary allow/deny, this is denied. + let engine = inference_engine(); + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + match &action { + NetworkAction::Deny { .. } => {} + other => panic!("Expected Deny, got: {other:?}"), + } + } + + #[test] + fn endpoint_in_policy_binary_not_allowed_without_inference_returns_deny() { + let engine = no_inference_engine(); + let input = NetworkInput { + host: "gitlab.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + match &action { + NetworkAction::Deny { .. } => {} + other => panic!("Expected Deny, got: {other:?}"), + } + } + + #[test] + fn from_proto_explicitly_allowed_returns_allow() { + let proto = test_proto(); + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + assert_eq!( + action, + NetworkAction::Allow { + matched_policy: Some("claude_code".to_string()) + }, + ); + } + + #[test] + fn from_proto_unknown_endpoint_returns_deny() { + let proto = test_proto(); + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "api.openai.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + match &action { + NetworkAction::Deny { .. } => {} + other => panic!("Expected Deny, got: {other:?}"), + } + } + + #[test] + fn network_action_with_dev_policy() { + let engine = test_engine(); + // claude direct to api.anthropic.com → allow (explicit match) + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + assert_eq!( + action, + NetworkAction::Allow { + matched_policy: Some("claude_code".to_string()) + }, + ); + + // git to github.com → allow + let input = NetworkInput { + host: "github.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/git"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let action = engine.evaluate_network_action(&input).unwrap(); + assert_eq!( + action, + NetworkAction::Allow { + matched_policy: Some("github_ssh_over_https".to_string()) + }, + ); + } + + // ======================================================================== + // allowed_ips tests + // ======================================================================== + + const ALLOWED_IPS_TEST_DATA: &str = r#" +network_policies: + # Mode 2: host + allowed_ips + internal_api: + name: internal_api + endpoints: + - host: my-service.corp.net + port: 8080 + allowed_ips: ["10.0.5.0/24"] + binaries: + - { path: /usr/bin/curl } + # Mode 3: allowed_ips only (no host) — uses port 9443 to avoid overlap + private_network: + name: private_network + endpoints: + - port: 9443 + allowed_ips: ["172.16.0.0/12", "192.168.1.1"] + binaries: + - { path: /usr/bin/curl } + # Mode 1: host only (no allowed_ips) — standard behavior + public_api: + name: public_api + endpoints: + - { host: api.github.com, port: 443 } + binaries: + - { path: /usr/bin/curl } + # Wildcard host endpoint should not count as an exact declared hostname. + wildcard_api: + name: wildcard_api + endpoints: + - { host: "*.corp.net", port: 443 } + binaries: + - { path: /usr/bin/curl } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + + fn allowed_ips_engine() -> OpaEngine { + OpaEngine::from_strings(TEST_POLICY, ALLOWED_IPS_TEST_DATA) + .expect("Failed to load allowed_ips test data") + } + + #[test] + fn allowed_ips_mode2_host_plus_ips_allows() { + let engine = allowed_ips_engine(); + let input = NetworkInput { + host: "my-service.corp.net".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Mode 2 (host+IPs) should allow: {}", + decision.reason + ); + assert_eq!(decision.matched_policy.as_deref(), Some("internal_api")); + } + + #[test] + fn allowed_ips_mode2_returns_allowed_ips() { + let engine = allowed_ips_engine(); + let input = NetworkInput { + host: "my-service.corp.net".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let ips = engine.query_allowed_ips(&input).unwrap(); + assert_eq!(ips, vec!["10.0.5.0/24"]); + } + + #[test] + fn allowed_ips_mode3_hostless_allows_any_domain() { + let engine = allowed_ips_engine(); + // Any hostname on port 9443 should match the private_network policy + let input = NetworkInput { + host: "anything.example.com".into(), + port: 9443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Mode 3 (IPs only) should allow any domain on matching port: {}", + decision.reason + ); + } + + #[test] + fn allowed_ips_mode3_returns_allowed_ips() { + let engine = allowed_ips_engine(); + let input = NetworkInput { + host: "anything.example.com".into(), + port: 9443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let ips = engine.query_allowed_ips(&input).unwrap(); + assert_eq!(ips, vec!["172.16.0.0/12", "192.168.1.1"]); + } + + #[test] + fn allowed_ips_mode1_no_ips_returns_empty() { + let engine = allowed_ips_engine(); + let input = NetworkInput { + host: "api.github.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let ips = engine.query_allowed_ips(&input).unwrap(); + assert!(ips.is_empty(), "Mode 1 should return no allowed_ips"); + } + + #[test] + fn exact_declared_endpoint_host_true_for_l4_host_only() { + let engine = allowed_ips_engine(); + let input = NetworkInput { + host: "api.github.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + assert!(engine.query_endpoint_config(&input).unwrap().is_none()); + assert!(engine.query_exact_declared_endpoint_host(&input).unwrap()); + } + + #[test] + fn exact_declared_endpoint_host_true_for_host_with_allowed_ips() { + let engine = allowed_ips_engine(); + let input = NetworkInput { + host: "my-service.corp.net".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + assert!(engine.query_exact_declared_endpoint_host(&input).unwrap()); + } + + #[test] + fn exact_declared_endpoint_host_false_for_hostless_allowed_ips() { + let engine = allowed_ips_engine(); + let input = NetworkInput { + host: "anything.example.com".into(), + port: 9443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + assert!(!engine.query_exact_declared_endpoint_host(&input).unwrap()); + } + + #[test] + fn exact_declared_endpoint_host_false_for_wildcard_host() { + let engine = allowed_ips_engine(); + let input = NetworkInput { + host: "api.corp.net".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let decision = engine.evaluate_network(&input).unwrap(); + assert!(decision.allowed, "wildcard endpoint should still allow"); + assert!(!engine.query_exact_declared_endpoint_host(&input).unwrap()); + } + + #[test] + fn exact_declared_endpoint_host_false_for_advisor_proposed_binary() { + let mut network_policies = std::collections::HashMap::new(); + let mut proposal_binary = NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }; + #[allow(deprecated)] + { + proposal_binary.harness = true; + } + network_policies.insert( + "allow_mcp_internal_corp_example_com_8443".to_string(), + NetworkPolicyRule { + name: "allow_mcp_internal_corp_example_com_8443".to_string(), + endpoints: vec![NetworkEndpoint { + host: "mcp-internal.corp.example.com".to_string(), + port: 8443, + ..Default::default() + }], + binaries: vec![proposal_binary], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "mcp-internal.corp.example.com".into(), + port: 8443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "advisor proposal should still allow at OPA L4" + ); + assert!(!engine.query_exact_declared_endpoint_host(&input).unwrap()); + } + + #[test] + fn exact_declared_endpoint_host_false_for_advisor_proposed_endpoint() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "app-api".to_string(), + NetworkPolicyRule { + name: "app-api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "internal-admin.local".to_string(), + port: 443, + ports: vec![443], + advisor_proposed: true, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/python".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let input = NetworkInput { + host: "internal-admin.local".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let decision = engine.evaluate_network(&input).unwrap(); + assert!(decision.allowed, "policy should still allow at OPA L4"); + assert!( + !engine.query_exact_declared_endpoint_host(&input).unwrap(), + "advisor endpoint provenance should block exact-host SSRF trust" + ); + } + + #[test] + fn allowed_ips_mode3_wrong_port_denied() { + let engine = allowed_ips_engine(); + // Port 12345 doesn't match any policy + let input = NetworkInput { + host: "anything.example.com".into(), + port: 12345, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed, "Mode 3: wrong port should deny"); + } + + #[test] + fn allowed_ips_proto_round_trip() { + // Test that allowed_ips survives proto → OPA data → query + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "internal".to_string(), + NetworkPolicyRule { + name: "internal".to_string(), + endpoints: vec![NetworkEndpoint { + host: "internal.corp.net".to_string(), + port: 8080, + allowed_ips: vec!["10.0.5.0/24".to_string(), "10.0.6.0/24".to_string()], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); + + let input = NetworkInput { + host: "internal.corp.net".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let ips = engine.query_allowed_ips(&input).unwrap(); + assert_eq!(ips, vec!["10.0.5.0/24", "10.0.6.0/24"]); + } + + // ======================================================================== + // Multi-port endpoint tests + // ======================================================================== + + #[test] + fn multi_port_endpoint_matches_first_port() { + let data = r#" +network_policies: + multi: + name: multi + endpoints: + - { host: api.example.com, ports: [443, 8443] } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "First port in multi-port should match: {}", + decision.reason + ); + } + + #[test] + fn multi_port_endpoint_matches_second_port() { + let data = r#" +network_policies: + multi: + name: multi + endpoints: + - { host: api.example.com, ports: [443, 8443] } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 8443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Second port in multi-port should match: {}", + decision.reason + ); + } + + #[test] + fn multi_port_endpoint_rejects_unlisted_port() { + let data = r#" +network_policies: + multi: + name: multi + endpoints: + - { host: api.example.com, ports: [443, 8443] } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 80, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed, "Unlisted port should be denied"); + } + + #[test] + fn single_port_backwards_compat() { + // Old-style YAML with just `port: 443` should still work + let data = r#" +network_policies: + compat: + name: compat + endpoints: + - { host: api.example.com, port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Single port backwards compat: {}", + decision.reason + ); + + // Wrong port should still deny + let input_bad = NetworkInput { + host: "api.example.com".into(), + port: 80, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input_bad).unwrap(); + assert!(!decision.allowed); + } + + #[test] + fn hostless_endpoint_multi_port() { + let data = r#" +network_policies: + private: + name: private + endpoints: + - ports: [80, 443] + allowed_ips: ["10.0.0.0/8"] + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + // Port 80 + let input80 = NetworkInput { + host: "anything.internal".into(), + port: 80, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input80).unwrap(); + assert!( + decision.allowed, + "Hostless multi-port should match port 80: {}", + decision.reason + ); + // Port 443 + let input443 = NetworkInput { + host: "anything.internal".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input443).unwrap(); + assert!( + decision.allowed, + "Hostless multi-port should match port 443: {}", + decision.reason + ); + // Port 8080 should deny + let input_bad = NetworkInput { + host: "anything.internal".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input_bad).unwrap(); + assert!(!decision.allowed); + } + + #[test] + fn from_proto_multi_port_allows_matching() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "multi".to_string(), + NetworkPolicyRule { + name: "multi".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + ports: vec![443, 8443], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + let engine = OpaEngine::from_proto(&proto).unwrap(); + // Port 443 + let input443 = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + assert!(engine.evaluate_network(&input443).unwrap().allowed); + // Port 8443 + let input8443 = NetworkInput { + host: "api.example.com".into(), + port: 8443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + assert!(engine.evaluate_network(&input8443).unwrap().allowed); + // Port 80 denied + let input80 = NetworkInput { + host: "api.example.com".into(), + port: 80, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + assert!(!engine.evaluate_network(&input80).unwrap().allowed); + } + + // ======================================================================== + // Host wildcard tests + // ======================================================================== + + #[test] + fn wildcard_host_matches_subdomain() { + let data = r#" +network_policies: + wildcard: + name: wildcard + endpoints: + - { host: "*.example.com", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "*.example.com should match api.example.com: {}", + decision.reason + ); + } + + #[test] + fn wildcard_host_rejects_deep_subdomain() { + // * should match single DNS label only (does not cross .) + let data = r#" +network_policies: + wildcard: + name: wildcard + endpoints: + - { host: "*.example.com", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "deep.sub.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + !decision.allowed, + "*.example.com should NOT match deep.sub.example.com" + ); + } + + #[test] + fn wildcard_host_rejects_exact_domain() { + let data = r#" +network_policies: + wildcard: + name: wildcard + endpoints: + - { host: "*.example.com", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + !decision.allowed, + "*.example.com should NOT match example.com (requires at least one label)" + ); + } + + #[test] + fn wildcard_host_case_insensitive() { + let data = r#" +network_policies: + wildcard: + name: wildcard + endpoints: + - { host: "*.EXAMPLE.COM", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Host wildcards should be case-insensitive: {}", + decision.reason + ); + } + + #[test] + fn wildcard_host_plus_port() { + let data = r#" +network_policies: + wildcard: + name: wildcard + endpoints: + - { host: "*.example.com", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + // Right host, wrong port + let input = NetworkInput { + host: "api.example.com".into(), + port: 80, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed, "Wildcard host on wrong port should deny"); + } + + #[test] + fn wildcard_host_intra_label_matches() { + // First-label intra-label wildcard: `*` matches the variable prefix + // within a single DNS label. Locks validator/runtime alignment for + // the pattern accepted by `validate_host_wildcard`. + let data = r#" +network_policies: + intra_label: + name: intra_label + endpoints: + - { host: "*-aiplatform.googleapis.com", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "us-central1-aiplatform.googleapis.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "*-aiplatform.googleapis.com should match us-central1-aiplatform.googleapis.com: {}", + decision.reason + ); + } + + #[test] + fn wildcard_host_intra_label_does_not_cross_dot() { + // `glob.match(..., ["."])` treats `.` as a label boundary that `*` + // cannot cross. `*-aiplatform.googleapis.com` must not match a host + // whose first label is `us-central1` and where `aiplatform` is a + // separate label. + let data = r#" +network_policies: + intra_label: + name: intra_label + endpoints: + - { host: "*-aiplatform.googleapis.com", port: 443 } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "us-central1.aiplatform.googleapis.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + !decision.allowed, + "*-aiplatform.googleapis.com must NOT match us-central1.aiplatform.googleapis.com \ + (would cross a `.` boundary)" + ); + } + + #[test] + fn wildcard_host_multi_port() { + let data = r#" +network_policies: + wildcard: + name: wildcard + endpoints: + - { host: "*.example.com", ports: [443, 8443] } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 8443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Wildcard host + multi-port should match: {}", + decision.reason + ); + } + + #[test] + fn wildcard_host_l7_rules_apply() { + let data = r#" +network_policies: + wildcard_l7: + name: wildcard_l7 + endpoints: + - host: "*.example.com" + port: 8080 + protocol: rest + enforcement: enforce + tls: terminate + rules: + - allow: + method: GET + path: "/api/**" + binaries: + - { path: /usr/bin/curl } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + // L7 GET to /api/foo — should be allowed + let input = l7_input("api.example.com", 8080, "GET", "/api/foo"); + assert!( + eval_l7(&engine, &input), + "L7 rule should apply to wildcard-matched host" + ); + // L7 DELETE to /api/foo — should be denied by L7 rule + let input_bad = l7_input("api.example.com", 8080, "DELETE", "/api/foo"); + assert!( + !eval_l7(&engine, &input_bad), + "L7 DELETE should be denied even on wildcard host" + ); + } + + #[test] + fn wildcard_host_l7_endpoint_config_returned() { + let data = r#" +network_policies: + wildcard_l7: + name: wildcard_l7 + endpoints: + - host: "*.example.com" + port: 8080 + protocol: rest + enforcement: enforce + tls: terminate + rules: + - allow: + method: GET + path: "**" + binaries: + - { path: /usr/bin/curl } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let config = engine.query_endpoint_config(&input).unwrap(); + assert!( + config.is_some(), + "Should return endpoint config for wildcard-matched host" + ); + let config = config.unwrap(); + let l7 = crate::l7::parse_l7_config(&config).unwrap(); + assert_eq!(l7.protocol, crate::l7::L7Protocol::Rest); + assert_eq!(l7.enforcement, crate::l7::EnforcementMode::Enforce); + } + + #[test] + fn l7_multi_port_request_evaluation() { + let data = r#" +network_policies: + multi_l7: + name: multi_l7 + endpoints: + - host: api.example.com + ports: [8080, 9090] + protocol: rest + enforcement: enforce + tls: terminate + rules: + - allow: + method: GET + path: "**" + binaries: + - { path: /usr/bin/curl } +filesystem_policy: + include_workdir: true + read_only: [] + read_write: [] +landlock: + compatibility: best_effort +process: + run_as_user: sandbox + run_as_group: sandbox +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + // GET on port 8080 — allowed + let input1 = l7_input("api.example.com", 8080, "GET", "/anything"); + assert!( + eval_l7(&engine, &input1), + "L7 on first port of multi-port should work" + ); + // GET on port 9090 — allowed + let input2 = l7_input("api.example.com", 9090, "GET", "/anything"); + assert!( + eval_l7(&engine, &input2), + "L7 on second port of multi-port should work" + ); + } + + // ======================================================================== + // Symlink resolution tests (issue #770) + // ======================================================================== + + #[test] + fn normalize_path_resolves_parent_and_current() { + use std::path::{Path, PathBuf}; + assert_eq!( + normalize_path(Path::new("/usr/bin/../lib/python3")), + PathBuf::from("/usr/lib/python3") + ); + assert_eq!( + normalize_path(Path::new("/usr/bin/./python3")), + PathBuf::from("/usr/bin/python3") + ); + assert_eq!( + normalize_path(Path::new("/a/b/c/../../d")), + PathBuf::from("/a/d") + ); + assert_eq!( + normalize_path(Path::new("/usr/bin/python3")), + PathBuf::from("/usr/bin/python3") + ); + } + + #[test] + fn resolve_binary_skips_glob_paths() { + // Glob patterns should never be resolved — they're matched differently + assert!(resolve_binary_in_container("/usr/bin/*", 1).is_none()); + assert!(resolve_binary_in_container("/usr/local/bin/**", 1).is_none()); + } + + #[test] + fn resolve_binary_skips_pid_zero() { + // pid=0 means the container hasn't started yet + assert!(resolve_binary_in_container("/usr/bin/python3", 0).is_none()); + } + + #[test] + fn resolve_binary_returns_none_for_nonexistent_path() { + // A path that doesn't exist in any container should gracefully return None + assert!( + resolve_binary_in_container("/nonexistent/binary/path/that/will/never/exist", 1) + .is_none() + ); + } + + #[test] + fn proto_to_opa_data_json_pid_zero_no_expansion() { + // With pid=0, proto_to_opa_data_json should produce the same output + // as the original (no symlink expansion) + let proto = test_proto(); + let data_no_pid = proto_to_opa_data_json(&proto, 0); + let parsed: serde_json::Value = serde_json::from_str(&data_no_pid).unwrap(); + + // Verify the claude_code policy has exactly 1 binary entry (no expansion) + let binaries = parsed["network_policies"]["claude_code"]["binaries"] + .as_array() + .unwrap(); + assert_eq!( + binaries.len(), + 1, + "With pid=0, should have no expanded binaries" + ); + assert_eq!(binaries[0]["path"], "/usr/local/bin/claude"); + } + + #[test] + fn symlink_expanded_binary_allows_resolved_path() { + // Simulate what happens after symlink resolution: the OPA data + // contains both the original symlink path and the resolved path. + // A request using the resolved path should be allowed. + let data = r#" +network_policies: + python_policy: + name: python_policy + endpoints: + - { host: pypi.org, port: 443 } + binaries: + - { path: /usr/bin/python3 } + - { path: /usr/bin/python3.11 } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + + // Request with the resolved path (what the kernel reports) + let input = NetworkInput { + host: "pypi.org".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3.11"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Resolved symlink path should be allowed: {}", + decision.reason + ); + assert_eq!(decision.matched_policy.as_deref(), Some("python_policy")); + } + + #[test] + fn symlink_expanded_binary_still_allows_original_path() { + // Even with expansion, the original path must still work + let data = r#" +network_policies: + python_policy: + name: python_policy + endpoints: + - { host: pypi.org, port: 443 } + binaries: + - { path: /usr/bin/python3 } + - { path: /usr/bin/python3.11 } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + + // Request with the original symlink path (unlikely at runtime, but must not break) + let input = NetworkInput { + host: "pypi.org".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Original symlink path should still be allowed: {}", + decision.reason + ); + } + + #[test] + fn symlink_expanded_binary_does_not_weaken_security() { + // A binary NOT in the policy should still be denied, even if + // the expanded entries exist for other binaries. + let data = r#" +network_policies: + python_policy: + name: python_policy + endpoints: + - { host: pypi.org, port: 443 } + binaries: + - { path: /usr/bin/python3 } + - { path: /usr/bin/python3.11 } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + + let input = NetworkInput { + host: "pypi.org".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed, "Unrelated binary should still be denied"); + } + + #[test] + fn symlink_expansion_works_with_ancestors() { + // Ancestor binary matching should also work with expanded paths + let data = r#" +network_policies: + python_policy: + name: python_policy + endpoints: + - { host: pypi.org, port: 443 } + binaries: + - { path: /usr/bin/python3 } + - { path: /usr/bin/python3.11 } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + + // The exe is curl, but an ancestor is the resolved python3.11 + let input = NetworkInput { + host: "pypi.org".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![PathBuf::from("/usr/bin/python3.11")], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Resolved symlink path should match as ancestor: {}", + decision.reason + ); + } + + #[test] + fn symlink_expansion_via_proto_with_pid_zero() { + // from_proto_with_pid(proto, 0) should produce same results as from_proto(proto) + let proto = test_proto(); + let engine_default = OpaEngine::from_proto(&proto).expect("from_proto should succeed"); + let engine_pid0 = OpaEngine::from_proto_with_pid(&proto, 0) + .expect("from_proto_with_pid(0) should succeed"); + + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + + let decision_default = engine_default.evaluate_network(&input).unwrap(); + let decision_pid0 = engine_pid0.evaluate_network(&input).unwrap(); + + assert_eq!( + decision_default.allowed, decision_pid0.allowed, + "from_proto and from_proto_with_pid(0) should produce identical results" + ); + } + + #[test] + fn reload_from_proto_with_pid_zero_works() { + // reload_from_proto_with_pid(proto, 0) should function identically to reload_from_proto + let proto = test_proto(); + let engine = OpaEngine::from_proto(&proto).expect("from_proto should succeed"); + + // Verify initial policy works + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(decision.allowed); + + // Reload with same proto at pid=0 + engine + .reload_from_proto_with_pid(&proto, 0) + .expect("reload_from_proto_with_pid should succeed"); + + // Should still work + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "reload_from_proto_with_pid(0) should preserve behavior" + ); + } + + #[test] + fn hot_reload_preserves_symlink_expansion_behavior() { + // Simulates the hot-reload path: initial load at pid=0, then reload + // with a new proto that would have expanded binaries at a real PID. + // Since we can't mock /proc//root/ in unit tests, we test + // that reload_from_proto_with_pid at pid=0 still works correctly + // and that the engine is properly replaced. + let proto = test_proto(); + let engine = OpaEngine::from_proto(&proto).expect("initial load should succeed"); + + // Verify initial policy allows claude + let claude_input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + assert!(engine.evaluate_network(&claude_input).unwrap().allowed); + + // Create a new proto with an additional policy + let mut new_proto = test_proto(); + new_proto.network_policies.insert( + "python_api".to_string(), + NetworkPolicyRule { + name: "python_api".to_string(), + endpoints: vec![NetworkEndpoint { + host: "pypi.org".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/python3".to_string(), + ..Default::default() + }], + }, + ); + + // Hot-reload with pid=0 + engine + .reload_from_proto_with_pid(&new_proto, 0) + .expect("hot-reload should succeed"); + + // Old policy should still work + assert!( + engine.evaluate_network(&claude_input).unwrap().allowed, + "Old policies should survive hot-reload" + ); + + // New policy should also work + let python_input = NetworkInput { + host: "pypi.org".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + assert!( + engine.evaluate_network(&python_input).unwrap().allowed, + "New policy should be active after hot-reload" + ); + } + + #[test] + fn hot_reload_replaces_engine_atomically() { + // Test that a failed reload preserves the last-known-good engine + let proto = test_proto(); + let engine = OpaEngine::from_proto(&proto).expect("initial load should succeed"); + + let claude_input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + assert!(engine.evaluate_network(&claude_input).unwrap().allowed); + + // Reload with same proto — should succeed and preserve behavior + engine + .reload_from_proto_with_pid(&proto, 0) + .expect("reload should succeed"); + + assert!( + engine.evaluate_network(&claude_input).unwrap().allowed, + "Engine should work after successful reload" + ); + } + + #[test] + fn deny_reason_includes_symlink_hint() { + // Verify the deny reason includes an actionable symlink hint + let engine = test_engine(); + let input = NetworkInput { + host: "api.anthropic.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/python3.11"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + assert!( + decision.reason.contains("SYMLINK HINT"), + "Deny reason should include prominent symlink hint, got: {}", + decision.reason + ); + assert!( + decision.reason.contains("readlink -f"), + "Deny reason should include actionable fix command, got: {}", + decision.reason + ); + } + + #[test] + fn deny_reason_collapses_endpoint_misses() { + let engine = test_engine(); + let input = NetworkInput { + host: "not-configured.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/local/bin/claude"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!(!decision.allowed); + assert_eq!( + decision.reason, + "endpoint not-configured.example.com:443 is not allowed by any policy" + ); + } + + /// Check if symlink resolution through `/proc//root/` actually works. + /// Creates a real symlink in a tempdir and attempts to resolve it via + /// the procfs root path. This catches environments where the probe path + /// is readable but canonicalization/read_link fails (e.g., containers + /// with restricted ptrace scope, rootless containers). + #[cfg(target_os = "linux")] + fn procfs_root_accessible() -> bool { + use std::os::unix::fs::symlink; + let Ok(dir) = tempfile::tempdir() else { + return false; + }; + let target = dir.path().join("probe_target"); + let link = dir.path().join("probe_link"); + if std::fs::write(&target, b"probe").is_err() { + return false; + } + if symlink(&target, &link).is_err() { + return false; + } + let pid = std::process::id(); + let link_path = link.to_string_lossy().to_string(); + // Actually attempt the same resolution our production code uses + resolve_binary_in_container(&link_path, pid).is_some() + } + + #[cfg(target_os = "linux")] + #[test] + fn resolve_binary_with_real_symlink() { + use std::os::unix::fs::symlink; + + if !procfs_root_accessible() { + eprintln!("Skipping: /proc//root/ not accessible in this environment"); + return; + } + + // Create a real symlink in a temp directory and verify resolution + // works through /proc/self/root (which maps to / on the host) + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("python3.11"); + let link = dir.path().join("python3"); + + // Create the target file + std::fs::write(&target, b"#!/usr/bin/env python3\n").unwrap(); + // Create symlink + symlink(&target, &link).unwrap(); + + // Use our own PID — /proc//root/ points to / + let our_pid = std::process::id(); + let link_path = link.to_string_lossy().to_string(); + let result = resolve_binary_in_container(&link_path, our_pid); + + assert!( + result.is_some(), + "Should resolve symlink via /proc//root/" + ); + let resolved = result.unwrap(); + assert!( + resolved.ends_with("python3.11"), + "Resolved path should point to target: {resolved}" + ); + } + + #[cfg(target_os = "linux")] + #[test] + fn resolve_binary_non_symlink_returns_none() { + use std::io::Write; + + if !procfs_root_accessible() { + eprintln!("Skipping: /proc//root/ not accessible in this environment"); + return; + } + + // A regular file should return None (no expansion needed) + let mut tmp = tempfile::NamedTempFile::new().unwrap(); + tmp.write_all(b"regular file").unwrap(); + tmp.flush().unwrap(); + + let our_pid = std::process::id(); + let path = tmp.path().to_string_lossy().to_string(); + let result = resolve_binary_in_container(&path, our_pid); + + assert!( + result.is_none(), + "Non-symlink file should return None, got: {result:?}" + ); + } + + #[cfg(target_os = "linux")] + #[test] + fn resolve_binary_multi_level_symlink() { + use std::os::unix::fs::symlink; + + if !procfs_root_accessible() { + eprintln!("Skipping: /proc//root/ not accessible in this environment"); + return; + } + + // Test multi-level symlink resolution: python3 -> python3.11 -> cpython3.11 + let dir = tempfile::tempdir().unwrap(); + let final_target = dir.path().join("cpython3.11"); + let mid_link = dir.path().join("python3.11"); + let top_link = dir.path().join("python3"); + + std::fs::write(&final_target, b"final binary").unwrap(); + symlink(&final_target, &mid_link).unwrap(); + symlink(&mid_link, &top_link).unwrap(); + + let our_pid = std::process::id(); + let link_path = top_link.to_string_lossy().to_string(); + let result = resolve_binary_in_container(&link_path, our_pid); + + assert!(result.is_some(), "Should resolve multi-level symlink chain"); + let resolved = result.unwrap(); + assert!( + resolved.ends_with("cpython3.11"), + "Should resolve to final target: {resolved}" + ); + } + + #[cfg(target_os = "linux")] + #[test] + fn from_proto_with_pid_expands_symlinks_in_container() { + use std::os::unix::fs::symlink; + + if !procfs_root_accessible() { + eprintln!("Skipping: /proc//root/ not accessible in this environment"); + return; + } + + // End-to-end test: create a symlink, build engine with our PID, + // verify the resolved path is allowed + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("node22"); + let link = dir.path().join("node"); + + std::fs::write(&target, b"node binary").unwrap(); + symlink(&target, &link).unwrap(); + + let link_path = link.to_string_lossy().to_string(); + let target_path = target.to_string_lossy().to_string(); + + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "test".to_string(), + NetworkPolicyRule { + name: "test".to_string(), + endpoints: vec![NetworkEndpoint { + host: "example.com".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: link_path, + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + // Build engine with our PID (symlink resolution will work via /proc/self/root/) + let our_pid = std::process::id(); + let engine = OpaEngine::from_proto_with_pid(&proto, our_pid) + .expect("from_proto_with_pid should succeed"); + + // Request using the resolved target path should be allowed + let input = NetworkInput { + host: "example.com".into(), + port: 443, + binary_path: PathBuf::from(&target_path), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input).unwrap(); + assert!( + decision.allowed, + "Resolved symlink target should be allowed after expansion: {}", + decision.reason + ); + } + + #[cfg(target_os = "linux")] + #[test] + fn reload_from_proto_with_pid_resolves_symlinks() { + use std::os::unix::fs::symlink; + + if !procfs_root_accessible() { + eprintln!("Skipping: /proc//root/ not accessible in this environment"); + return; + } + + // Test hot-reload path: initial engine at pid=0, then reload with + // real PID to trigger symlink resolution + let dir = tempfile::tempdir().unwrap(); + let target = dir.path().join("python3.11"); + let link = dir.path().join("python3"); + + std::fs::write(&target, b"python binary").unwrap(); + symlink(&target, &link).unwrap(); + + let link_path = link.to_string_lossy().to_string(); + let target_path = target.to_string_lossy().to_string(); + + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "python".to_string(), + NetworkPolicyRule { + name: "python".to_string(), + endpoints: vec![NetworkEndpoint { + host: "pypi.org".to_string(), + port: 443, + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: link_path, + ..Default::default() + }], + }, + ); + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + // Initial load at pid=0 — no symlink expansion + let engine = OpaEngine::from_proto(&proto).expect("initial load"); + + // Request with resolved path should be DENIED (no expansion yet) + let input_resolved = NetworkInput { + host: "pypi.org".into(), + port: 443, + binary_path: PathBuf::from(&target_path), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let decision = engine.evaluate_network(&input_resolved).unwrap(); + assert!( + !decision.allowed, + "Before reload with PID, resolved path should be denied" + ); + + // Hot-reload with real PID — symlinks resolved + let our_pid = std::process::id(); + engine + .reload_from_proto_with_pid(&proto, our_pid) + .expect("reload with PID"); + + // Now the resolved path should be ALLOWED + let decision = engine.evaluate_network(&input_resolved).unwrap(); + assert!( + decision.allowed, + "After reload with PID, resolved path should be allowed: {}", + decision.reason + ); + } + + #[test] + fn l7_head_allowed_where_get_is_allowed() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "HEAD", "/repos/myorg/foo"); + assert!(eval_l7(&engine, &input)); + } + + #[test] + fn l7_head_denied_when_only_post_allowed() { + let engine = OpaEngine::from_strings( + TEST_POLICY, + "network_policies:\n p:\n name: p\n endpoints:\n - host: h.test\n port: 80\n protocol: rest\n enforcement: enforce\n rules:\n - allow: {method: POST, path: \"/\"}\n binaries:\n - {path: /usr/bin/curl}\n", + ) + .unwrap(); + let input = l7_input("h.test", 80, "HEAD", "/"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_options_not_implicitly_allowed_by_get() { + let engine = l7_engine(); + let input = l7_input("api.example.com", 8080, "OPTIONS", "/repos/myorg/foo"); + assert!(!eval_l7(&engine, &input)); + } + + #[test] + fn l7_head_blocked_by_deny_rule_targeting_get() { + // deny_rules use method_matches() too; a deny on GET must also block HEAD. + let engine = OpaEngine::from_strings( + TEST_POLICY, + "network_policies:\n p:\n name: p\n endpoints:\n - host: h.test\n port: 80\n protocol: rest\n enforcement: enforce\n access: full\n deny_rules:\n - method: GET\n path: \"/protected\"\n binaries:\n - {path: /usr/bin/curl}\n", + ) + .unwrap(); + let input = l7_input("h.test", 80, "HEAD", "/protected"); + assert!(!eval_l7(&engine, &input)); + } +} diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs new file mode 100644 index 000000000..34eeaada5 --- /dev/null +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -0,0 +1,2027 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Sandbox-local policy advisor HTTP API. + +use miette::{IntoDiagnostic, Result}; +use openshell_core::proto::{ + L7Allow, L7DenyRule, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, PolicyChunk, + SandboxPolicy as ProtoSandboxPolicy, +}; +use openshell_ocsf::{ConfigStateChangeBuilder, SeverityId, StateId, StatusId, ocsf_emit}; +use serde::Deserialize; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::RwLock; + +pub const POLICY_LOCAL_HOST: &str = "policy.local"; + +/// Filesystem path of the static agent guidance bundle inside the sandbox. +/// Single source of truth: the skill installer writes here, the L7 deny body +/// references this path in `next_steps`, and the skill's own documentation +/// renders the same path. Changing the location is a one-line update here. +pub const SKILL_PATH: &str = "/etc/openshell/skills/policy_advisor.md"; + +/// Human-readable guidance for agents that are more likely to follow plain +/// instructions than structured next-step JSON alone. +pub const AGENT_GUIDANCE: &str = "OpenShell blocked this request with sandbox policy. If the user's task still needs this network action, do not stop here: read /etc/openshell/skills/policy_advisor.md, submit the narrowest policy proposal to http://policy.local/v1/proposals, wait for approval and `policy_reloaded: true`, then retry the original request."; + +/// Routes served by the in-sandbox policy advisor API. Held in one place so +/// the L7 deny `next_steps` array, the route dispatcher, the skill content, +/// and tests all stay in sync — change the wire path here and every caller +/// follows. See `agent_next_steps()` for the consumer that surfaces these +/// to the agent on a 403. +pub const ROUTE_POLICY_CURRENT: &str = "/v1/policy/current"; +pub const ROUTE_DENIALS: &str = "/v1/denials"; +pub const ROUTE_PROPOSALS: &str = "/v1/proposals"; +/// Per-proposal status and long-poll routes live below this prefix: +/// `GET /v1/proposals/{chunk_id}` — immediate status +/// `GET /v1/proposals/{chunk_id}/wait?timeout` — long-poll until terminal +/// Trailing slash differentiates from the bare `POST /v1/proposals` submit. +const ROUTE_PROPOSALS_PREFIX: &str = "/v1/proposals/"; + +/// Long-poll bounds for `GET /v1/proposals/{id}/wait?timeout=`. The agent +/// re-issues on timeout, so the cap is a hold ceiling, not a hard limit on +/// how long the agent can wait overall. +const PROPOSAL_WAIT_DEFAULT_SECS: u64 = 60; +const PROPOSAL_WAIT_MIN_SECS: u64 = 1; +const PROPOSAL_WAIT_MAX_SECS: u64 = 300; +const PROPOSAL_WAIT_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1); +/// Minimum window the reload-readiness phase gets after a chunk +/// terminalizes, even if the caller's deadline is shorter. Without this, +/// approvals that arrive at T-50ms always return `policy_reloaded=false` +/// and force a re-issue. 500ms is well below typical supervisor poll +/// latency but enough to cover the in-memory coverage check. +const RELOAD_WAIT_MIN_FLOOR: std::time::Duration = std::time::Duration::from_millis(500); + +const MAX_POLICY_LOCAL_BODY_BYTES: usize = 64 * 1024; +/// Hard ceiling on how long a single request body read can stall. Bounds a +/// slowloris-style upload from an in-sandbox process; the proxy listener only +/// accepts loopback connections, so practical impact is limited, but this is +/// cheap defense-in-depth. +const POLICY_LOCAL_BODY_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15); +const DEFAULT_DENIALS_LIMIT: usize = 10; +const MAX_DENIALS_LIMIT: usize = 100; +/// The shorthand rolling appender keeps three files (daily rotation); read the +/// most recent two so a request just past midnight still has yesterday's +/// denials. +const DENIAL_LOG_FILES_TO_SCAN: usize = 2; +const LOG_DIR: &str = "/var/log"; +/// Shorthand log filenames are `openshell.YYYY-MM-DD.log`. The trailing dot in +/// the prefix is intentional: it disambiguates from the OCSF JSONL appender's +/// `openshell-ocsf.YYYY-MM-DD.log`, which we never want to surface here (the +/// JSONL is opt-in via `ocsf_json_enabled` and not the source of truth for +/// `/v1/denials`). +const SHORTHAND_LOG_PREFIX: &str = "openshell."; +/// Defensive cap on per-line length returned to the agent so a pathological +/// log entry (very long URL path, etc.) cannot blow up the response. +const MAX_DENIAL_LINE_BYTES: usize = 4096; + +#[derive(Debug)] +pub struct PolicyLocalContext { + current_policy: Arc>>, + gateway_endpoint: Option, + sandbox_name: Option, + shorthand_log_dir: PathBuf, +} + +impl PolicyLocalContext { + pub fn new( + current_policy: Option, + gateway_endpoint: Option, + sandbox_name: Option, + ) -> Self { + Self::with_log_dir( + current_policy, + gateway_endpoint, + sandbox_name, + PathBuf::from(LOG_DIR), + ) + } + + fn with_log_dir( + current_policy: Option, + gateway_endpoint: Option, + sandbox_name: Option, + shorthand_log_dir: PathBuf, + ) -> Self { + Self { + current_policy: Arc::new(RwLock::new(current_policy)), + gateway_endpoint, + sandbox_name, + shorthand_log_dir, + } + } + + pub async fn set_current_policy(&self, policy: ProtoSandboxPolicy) { + *self.current_policy.write().await = Some(policy); + } +} + +pub async fn handle_forward_request( + ctx: &PolicyLocalContext, + method: &str, + path: &str, + initial_request: &[u8], + client: &mut S, +) -> Result<()> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let body = read_request_body(initial_request, client).await?; + let (status, payload) = route_request(ctx, method, path, &body).await; + write_json_response(client, status, payload).await +} + +async fn route_request( + ctx: &PolicyLocalContext, + method: &str, + path: &str, + body: &[u8], +) -> (u16, serde_json::Value) { + let (route, query) = path.split_once('?').map_or((path, ""), |(r, q)| (r, q)); + // Gate every route on the feature flag so the agent surface is fully off + // when the flag is off — including the diagnostic `current_policy` and + // `denials` routes. The skill is also not installed in that mode, so a + // disabled sandbox has no entry point into this API at all. + if !crate::agent_proposals_enabled() { + return ( + 404, + serde_json::json!({ + "error": "feature_disabled", + "detail": "agent-driven policy proposals are not enabled in this sandbox; set the `agent_policy_proposals_enabled` setting to true to enable" + }), + ); + } + match (method, route) { + ("GET", ROUTE_POLICY_CURRENT) => current_policy_response(ctx).await, + ("GET", ROUTE_DENIALS) => recent_denials_response(ctx, query).await, + ("POST", ROUTE_PROPOSALS) => submit_proposal(ctx, body).await, + ("GET", path) if path.starts_with(ROUTE_PROPOSALS_PREFIX) => { + proposal_state_route(ctx, path, query).await + } + _ => ( + 404, + serde_json::json!({ + "error": "not_found", + "detail": format!("policy.local route not found: {method} {route}") + }), + ), + } +} + +/// Parse `{chunk_id}` (status) or `{chunk_id}/wait` (long-poll) from the path +/// suffix and dispatch. Empty `chunk_id` or extra segments return 404 so a +/// malformed path cannot trigger a gateway call. +async fn proposal_state_route( + ctx: &PolicyLocalContext, + path: &str, + query: &str, +) -> (u16, serde_json::Value) { + let suffix = path + .strip_prefix(ROUTE_PROPOSALS_PREFIX) + .unwrap_or_default(); + let (chunk_id, wait) = match suffix.split_once('/') { + Some((id, "wait")) => (id, true), + Some(_) => return not_found_payload(path), + None => (suffix, false), + }; + if chunk_id.is_empty() { + return not_found_payload(path); + } + if wait { + proposal_wait_response(ctx, chunk_id, query).await + } else { + proposal_status_response(ctx, chunk_id).await + } +} + +fn not_found_payload(path: &str) -> (u16, serde_json::Value) { + ( + 404, + serde_json::json!({ + "error": "not_found", + "detail": format!("policy.local proposal sub-route not found: {path}") + }), + ) +} + +/// Build the `next_steps` array embedded in the L7 deny body so the agent has +/// machine-readable pointers to this API. Centralizes the shape here to keep +/// the deny body and the actual route table from drifting — adding or +/// renaming a route only requires touching the route constants above. +/// +/// Returns an empty array when `agent_proposals_enabled()` is false so a +/// disabled sandbox doesn't advertise a surface that 404s. The deny body +/// caller still emits the field (with `[]`) so the wire shape is stable. +#[must_use] +pub fn agent_next_steps() -> serde_json::Value { + if !crate::agent_proposals_enabled() { + return serde_json::json!([]); + } + let host = POLICY_LOCAL_HOST; + serde_json::json!([ + { + "action": "read_skill", + "path": SKILL_PATH, + }, + { + "action": "inspect_policy", + "method": "GET", + "url": format!("http://{host}{ROUTE_POLICY_CURRENT}"), + }, + { + "action": "inspect_recent_denials", + "method": "GET", + "url": format!("http://{host}{ROUTE_DENIALS}?last=5"), + }, + { + "action": "submit_proposal", + "method": "POST", + "url": format!("http://{host}{ROUTE_PROPOSALS}"), + "body_type": "PolicyMergeOperation", + }, + ]) +} + +/// Build the optional natural-language guidance embedded in L7 deny bodies. +#[must_use] +pub fn agent_guidance() -> Option<&'static str> { + crate::agent_proposals_enabled().then_some(AGENT_GUIDANCE) +} + +async fn current_policy_response(ctx: &PolicyLocalContext) -> (u16, serde_json::Value) { + let Some(policy) = ctx.current_policy.read().await.clone() else { + return ( + 404, + serde_json::json!({ + "error": "policy_unavailable", + "detail": "no current sandbox policy is loaded" + }), + ); + }; + + match openshell_policy::serialize_sandbox_policy(&policy) { + Ok(policy_yaml) => ( + 200, + serde_json::json!({ + "format": "yaml", + "policy_yaml": policy_yaml + }), + ), + Err(error) => ( + 500, + serde_json::json!({ + "error": "policy_serialize_failed", + "detail": error.to_string() + }), + ), + } +} + +async fn recent_denials_response( + ctx: &PolicyLocalContext, + query: &str, +) -> (u16, serde_json::Value) { + let limit = parse_last_query(query).unwrap_or(DEFAULT_DENIALS_LIMIT); + let log_dir = ctx.shorthand_log_dir.clone(); + + // Distinguish "shorthand log exists and no denials happened" from "no log + // file yet, so we have nothing to read." Without this flag the agent sees + // `[]` in both cases and cannot tell the difference. The shorthand log is + // always-on (no setting gates it), so the only way `log_available=false` + // happens in practice is if the supervisor has not flushed any events to + // disk yet, or `/var/log` is not writable in this image. + let log_available = matches!( + collect_shorthand_log_files(&log_dir, 1), + Ok(files) if !files.is_empty() + ); + + let denials = tokio::task::spawn_blocking(move || read_recent_denial_lines(&log_dir, limit)) + .await + .unwrap_or_default(); + + let mut payload = serde_json::json!({ + "denials": denials, + "log_available": log_available, + }); + if !log_available { + payload["note"] = serde_json::json!( + "no shorthand log file is present yet at /var/log/openshell.YYYY-MM-DD.log; the supervisor may not have emitted any events to disk yet" + ); + } + + (200, payload) +} + +fn parse_last_query(query: &str) -> Option { + if query.is_empty() { + return None; + } + for pair in query.split('&') { + let Some((key, value)) = pair.split_once('=') else { + continue; + }; + if key == "last" { + return value + .parse::() + .ok() + .map(|n| n.clamp(1, MAX_DENIALS_LIMIT)); + } + } + None +} + +/// Walk the shorthand log files (most-recent first) and return up to `limit` +/// raw denial lines in newest-first order. The agent receives the same +/// human-readable text that `openshell logs` displays — no parsing back into +/// structured form. Updating the shorthand format adds fields automatically; +/// no schema rev required. +/// +/// Reads files synchronously and is intended to run inside `spawn_blocking`. +fn read_recent_denial_lines(log_dir: &Path, limit: usize) -> Vec { + let Ok(files) = collect_shorthand_log_files(log_dir, DENIAL_LOG_FILES_TO_SCAN) else { + return Vec::new(); + }; + + let mut lines: Vec = Vec::with_capacity(limit); + for path in files { + let Ok(contents) = std::fs::read_to_string(&path) else { + continue; + }; + // Walk lines newest-first. Within a single file, the last line written + // is the freshest event. + for line in contents.lines().rev() { + if !is_ocsf_denial_line(line) { + continue; + } + // Defense-in-depth: redact query strings before truncation. The + // FORWARD deny path in `proxy.rs` populates the OCSF `message` + // and URL with the raw request path including `?query=...`, which + // the shorthand layer then renders verbatim. Stripping queries + // here means the agent never sees the secret even if an upstream + // emit site forgets to redact (TODO: harden the emit sites in + // proxy.rs FORWARD path so the on-disk shorthand log itself is + // clean — tracked separately). Redact first so truncation cannot + // slice mid-secret. + let redacted = redact_query_strings(line); + let surfaced = truncate_at_char_boundary(&redacted, MAX_DENIAL_LINE_BYTES); + lines.push(surfaced); + if lines.len() >= limit { + return lines; + } + } + } + lines +} + +/// Replace any `?` substring with `?[redacted]` to keep query-string +/// secrets out of the agent's view. Walks per Unicode scalar value so multi-byte +/// content is safe. A query is everything from `?` until the next whitespace or +/// `]` (the shorthand format uses `[...]` for context tags). +fn redact_query_strings(line: &str) -> String { + let mut out = String::with_capacity(line.len()); + let mut chars = line.chars(); + while let Some(c) = chars.next() { + if c == '?' { + out.push('?'); + out.push_str("[redacted]"); + // Consume until whitespace or `]` (preserved as the next token's + // boundary by writing it back out). + for next in chars.by_ref() { + if next.is_whitespace() || next == ']' { + out.push(next); + break; + } + } + } else { + out.push(c); + } + } + out +} + +/// Truncate `s` at the largest UTF-8 char boundary <= `max_bytes`, appending a +/// `...[truncated]` suffix. Returning a `String` (not `&str`) avoids surprising +/// callers about lifetime relationships with `s`. +fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> String { + if s.len() <= max_bytes { + return s.to_string(); + } + let mut end = max_bytes; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; + } + let mut out = String::with_capacity(end + "...[truncated]".len()); + out.push_str(&s[..end]); + out.push_str("...[truncated]"); + out +} + +/// True for OCSF denial events as rendered by the shorthand layer. The format +/// is ` OCSF <[SEV]> ...`. The literal +/// ` OCSF ` substring identifies an OCSF event (vs. a non-OCSF tracing line); +/// ` DENIED ` is the OCSF action label uppercased and surrounded by spaces, so +/// matching it is safe against substring collisions in URLs or hostnames. +fn is_ocsf_denial_line(line: &str) -> bool { + line.contains(" OCSF ") && line.contains(" DENIED ") +} + +fn collect_shorthand_log_files(log_dir: &Path, max_files: usize) -> std::io::Result> { + let mut entries: Vec<(std::time::SystemTime, PathBuf)> = std::fs::read_dir(log_dir)? + .filter_map(std::result::Result::ok) + .filter_map(|entry| { + let path = entry.path(); + let name = entry.file_name(); + let name = name.to_string_lossy(); + // `openshell.YYYY-MM-DD.log` only — the trailing dot in the prefix + // disambiguates from `openshell-ocsf.YYYY-MM-DD.log`. + if !name.starts_with(SHORTHAND_LOG_PREFIX) || !name.ends_with(".log") { + return None; + } + let modified = entry.metadata().and_then(|m| m.modified()).ok()?; + Some((modified, path)) + }) + .collect(); + + entries.sort_by_key(|entry| std::cmp::Reverse(entry.0)); + Ok(entries + .into_iter() + .take(max_files) + .map(|(_, p)| p) + .collect()) +} + +async fn submit_proposal(ctx: &PolicyLocalContext, body: &[u8]) -> (u16, serde_json::Value) { + let Some(endpoint) = ctx.gateway_endpoint.as_deref() else { + return ( + 503, + serde_json::json!({ + "error": "gateway_unavailable", + "detail": "policy proposal submission requires a gateway-connected sandbox" + }), + ); + }; + let Some(sandbox_name) = ctx + .sandbox_name + .as_deref() + .map(str::trim) + .filter(|name| !name.is_empty()) + else { + return ( + 503, + serde_json::json!({ + "error": "sandbox_name_unavailable", + "detail": "policy proposal submission requires a sandbox name" + }), + ); + }; + + let chunks = match proposal_chunks_from_body(body) { + Ok(chunks) => chunks, + Err(error) => return (400, error_payload("invalid_proposal", error)), + }; + + let client = match crate::grpc_client::CachedOpenShellClient::connect(endpoint).await { + Ok(client) => client, + Err(error) => { + return ( + 502, + serde_json::json!({ + "error": "gateway_connect_failed", + "detail": error.to_string() + }), + ); + } + }; + + // Pre-compute the audit summaries before handing `chunks` to the + // gateway client (which consumes the vec). The summaries pair up with + // the gateway's `accepted_chunk_ids` by index for the propose events + // emitted after submit returns. + let audit_summaries: Vec = chunks.iter().map(summarize_chunk_for_audit).collect(); + + let response = match client + .submit_policy_analysis(sandbox_name, vec![], chunks, vec![], "agent_authored") + .await + { + Ok(response) => response, + Err(error) => { + return ( + 502, + serde_json::json!({ + "error": "proposal_submit_failed", + "detail": error.to_string() + }), + ); + } + }; + + // One OCSF event per accepted chunk so the audit trace in + // `openshell logs ` carries the propose beat alongside the + // proxy deny and policy reload that bracket it. + // + // The gateway compresses its `accepted_chunk_ids` by skipping rejected + // chunks (`grpc/policy.rs:1357-1436`); the proto does not promise 1:1 + // ordering against the request. Today client-side validation catches + // both rejection causes (missing rule_name, missing proposed_rule) + // before submit, so the lengths match in practice. If they don't, we + // can't safely pair audit_summaries by index — fall back to a generic + // event per accepted chunk_id rather than mis-attribute a summary. + let pairing_is_safe = response.accepted_chunk_ids.len() == audit_summaries.len(); + for (idx, chunk_id) in response.accepted_chunk_ids.iter().enumerate() { + let summary = if pairing_is_safe { + audit_summaries[idx].as_str() + } else { + "(summary unavailable: gateway partially accepted)" + }; + emit_policy_propose_event(chunk_id, summary); + } + + ( + 202, + serde_json::json!({ + "status": "submitted", + "accepted_chunks": response.accepted_chunks, + "rejected_chunks": response.rejected_chunks, + "rejection_reasons": response.rejection_reasons, + "accepted_chunk_ids": response.accepted_chunk_ids, + }), + ) +} + +/// Emit one CONFIG:PROPOSED audit event for an agent-authored proposal that +/// the gateway just accepted. The message names the `chunk_id`, the binary, +/// and the endpoint the agent is asking to reach — what a developer needs +/// to see in the audit trace to correlate against the inbox card. +fn emit_policy_propose_event(chunk_id: &str, summary: &str) { + ocsf_emit!( + ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Other, "PROPOSED") + .unmapped("chunk_id", serde_json::json!(chunk_id)) + .message(format!( + "agent_authored proposal chunk:{chunk_id} {summary}" + )) + .build() + ); +} + +/// Emit one CONFIG:APPROVED or CONFIG:REJECTED audit event observed by the +/// `/wait` poll loop. The reviewer's free-form `rejection_reason` (if any) +/// is included verbatim so the audit trace shows what guidance the agent +/// received. +fn emit_policy_decision_event(chunk: &PolicyChunk) { + let summary = summarize_chunk_for_audit(chunk); + match chunk.status.as_str() { + "approved" => ocsf_emit!( + ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "APPROVED") + .unmapped("chunk_id", serde_json::json!(chunk.id)) + .message(format!("chunk:{} approved {summary}", chunk.id)) + .build() + ), + "rejected" => { + // The reviewer's free-form rejection_reason is opaque user + // input. The agent reads the raw text via `GET /v1/proposals/ + // {id}` to redraft; the OCSF surface (which can be shipped to + // external SIEMs per AGENTS.md) gets a sanitized copy — caps + // length and strips control characters so a stray credential + // or escape sequence cannot leak into the audit log. + let sanitized = sanitize_reason_for_audit(&chunk.rejection_reason); + let reason_display = if sanitized.is_empty() { + "(no guidance)".to_string() + } else { + format!("\"{sanitized}\"") + }; + ocsf_emit!( + ConfigStateChangeBuilder::new(crate::ocsf_ctx()) + .severity(SeverityId::Low) + .status(StatusId::Success) + .state(StateId::Disabled, "REJECTED") + .unmapped("chunk_id", serde_json::json!(chunk.id)) + .unmapped("rejection_reason", serde_json::json!(sanitized)) + .message(format!( + "chunk:{} rejected {summary} reason:{reason_display}", + chunk.id + )) + .build() + ); + } + // Caller is gated on `is_terminal_status`, so a non-terminal status + // here is a code change that broke the invariant. Warn loudly so + // the audit gap doesn't go silent. + other => tracing::warn!( + chunk_id = %chunk.id, + status = %other, + "emit_policy_decision_event called on non-terminal status; no audit event emitted" + ), + } +} + +/// Sanitize a free-form reviewer-typed string before it lands in the OCSF +/// audit surface. The agent still reads the raw text via the API — this is +/// audit-side defense only. +fn sanitize_reason_for_audit(raw: &str) -> String { + const MAX_CHARS: usize = 200; + let cleaned: String = raw + .chars() + .filter(|c| !c.is_control() || *c == ' ') + .take(MAX_CHARS) + .collect(); + if raw.chars().count() > MAX_CHARS { + format!("{cleaned}…") + } else { + cleaned + } +} + +/// One-line audit description of a chunk's target: binary, host, port, and +/// L7 method/path if present. Used by both the propose and approve/reject +/// audit events so the trace can be grepped by endpoint without parsing +/// JSON. +fn summarize_chunk_for_audit(chunk: &PolicyChunk) -> String { + let Some(rule) = chunk.proposed_rule.as_ref() else { + return format!("rule_name:{}", chunk.rule_name); + }; + let endpoint = rule.endpoints.first().map_or_else( + || "unknown".to_string(), + |ep| format!("{}:{}", ep.host, ep.port), + ); + let l7 = rule + .endpoints + .first() + .and_then(|ep| ep.rules.first()) + .and_then(|r| r.allow.as_ref()) + .map(|a| format!(" {} {}", a.method, a.path)) + .unwrap_or_default(); + let binary = if chunk.binary.is_empty() { + String::new() + } else { + format!(" by {}", chunk.binary) + }; + format!("on {endpoint}{l7}{binary}") +} + +/// `GET /v1/proposals/{chunk_id}` — immediate state. One gateway call, no loop. +async fn proposal_status_response( + ctx: &PolicyLocalContext, + chunk_id: &str, +) -> (u16, serde_json::Value) { + let session = match open_lookup_session(ctx).await { + Ok(session) => session, + Err(err) => return err, + }; + fetch_chunk_or_404(&session, chunk_id, false).await +} + +/// `GET /v1/proposals/{chunk_id}/wait?timeout=` — block until terminal or +/// timeout. Returns the chunk's current state on a status transition; on +/// timeout, returns the still-pending state with `timed_out: true` so the +/// agent can re-issue without ambiguity. The agent's wait costs zero LLM +/// tokens — the tool call sits in a socket recv until we return. +async fn proposal_wait_response( + ctx: &PolicyLocalContext, + chunk_id: &str, + query: &str, +) -> (u16, serde_json::Value) { + let session = match open_lookup_session(ctx).await { + Ok(session) => session, + Err(err) => return err, + }; + let timeout_secs = parse_timeout_query(query); + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(timeout_secs); + loop { + match fetch_chunk(&session, chunk_id).await { + Ok(Some(chunk)) if is_terminal_status(&chunk.status) => { + // Audit beat: emit at the moment this sandbox observes the + // decision so the trace correlates with the proxy events + // bracketing the loop. Multiple waiters on the same chunk + // each fire one event — acceptable for a wakeup audit. + emit_policy_decision_event(&chunk); + let policy_reloaded = if chunk.status == "approved" { + // Hold the wait until the local supervisor has loaded a + // policy that semantically contains this chunk's + // proposed rule. Reloads triggered by *other* chunks or + // settings changes do not wake us; a missing + // proposed_rule (defensive) skips the check and + // returns reloaded=false so the agent can decide. + // + // Floor the reload-wait window to RELOAD_WAIT_MIN_FLOOR + // so an approval that arrives at T-50ms still gets a + // realistic shot at seeing the reload. Worst case we + // overshoot the caller's deadline by this floor — + // preferable to returning reloaded=false on every + // short-budget call and forcing the agent to re-issue. + let reload_deadline = std::cmp::max( + deadline, + tokio::time::Instant::now() + RELOAD_WAIT_MIN_FLOOR, + ); + match chunk.proposed_rule.as_ref() { + Some(rule) => { + wait_for_local_policy_to_cover(ctx, rule, reload_deadline).await + } + None => false, + } + } else { + // Rejected: no reload semantics — the agent reads + // rejection_reason and redrafts. + false + }; + return (200, chunk_state_payload(&chunk, false, policy_reloaded)); + } + Ok(Some(chunk)) => { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + return (200, chunk_state_payload(&chunk, true, false)); + } + let sleep_for = std::cmp::min(remaining, PROPOSAL_WAIT_POLL_INTERVAL); + tokio::time::sleep(sleep_for).await; + } + Ok(None) => return chunk_not_found_payload(chunk_id), + Err(err) => return err, + } + } +} + +fn chunk_not_found_payload(chunk_id: &str) -> (u16, serde_json::Value) { + ( + 404, + error_payload( + "chunk_not_found", + format!("chunk '{chunk_id}' is not present in this sandbox's draft policy"), + ), + ) +} + +async fn fetch_chunk_or_404( + session: &LookupSession<'_>, + chunk_id: &str, + timed_out: bool, +) -> (u16, serde_json::Value) { + match fetch_chunk(session, chunk_id).await { + Ok(Some(chunk)) => (200, chunk_state_payload(&chunk, timed_out, false)), + Ok(None) => chunk_not_found_payload(chunk_id), + Err(err) => err, + } +} + +/// Build the agent-facing response for a chunk. +/// +/// Selection rule: include the fields the agent needs to decide what to do +/// next on the redraft loop — identity (`chunk_id`, `status`), the proposal +/// it submitted (`rule_name`, `binary`), the two feedback signals +/// (`rejection_reason` from the reviewer, `validation_result` from the +/// gateway prover), and (on /wait) `policy_reloaded` so the agent can tell +/// "approved AND the new rule is loaded — safe to retry" from "approved +/// but the supervisor hasn't reloaded yet — re-issue /wait or surface to +/// user". Display-only proto fields (`hit_count`, `confidence`, `stage`, +/// timing) are left off until a concrete agent need surfaces them. +fn chunk_state_payload( + chunk: &PolicyChunk, + timed_out: bool, + policy_reloaded: bool, +) -> serde_json::Value { + let mut payload = serde_json::json!({ + "chunk_id": chunk.id, + "status": chunk.status, + "rule_name": chunk.rule_name, + "binary": chunk.binary, + "rejection_reason": chunk.rejection_reason, + "validation_result": chunk.validation_result, + }); + if timed_out { + payload["timed_out"] = serde_json::json!(true); + } + if chunk.status == "approved" { + payload["policy_reloaded"] = serde_json::json!(policy_reloaded); + } + payload +} + +fn is_terminal_status(status: &str) -> bool { + matches!(status, "approved" | "rejected") +} + +/// After a chunk is approved upstream, wait until the local supervisor has +/// loaded a policy that semantically contains the chunk's proposed rule. +/// Returns `true` if coverage was observed before the deadline, `false` +/// otherwise — the caller reports that bool back to the agent as +/// `policy_reloaded` so it can decide whether to retry immediately or +/// re-issue `/wait`. +/// +/// Why rule-coverage instead of whole-policy diff (as we used to do): +/// +/// 1. **False sleep.** If the agent re-issues `/wait` after a `timed_out` +/// response, the chunk may have approved AND the supervisor may have +/// reloaded between the two `/wait` calls. A diff-based check snapshots +/// the already-updated policy as baseline and then waits forever for +/// another change. The skill tells the agent to re-issue on +/// `timed_out`, so the diff approach is broken on the happy path. +/// 2. **False wakeup.** Any unrelated reload (another agent's approval, +/// settings change) flips a whole-policy diff, but the chunk's actual +/// rule may not be loaded yet. The agent retries, hits another +/// `policy_denied`, and the revise-loop fires with no real signal to +/// revise on. +/// +/// The polling cadence here is faster than `PROPOSAL_WAIT_POLL_INTERVAL` +/// (which paces upstream gateway calls). This loop only reads in-memory +/// state, so 200ms gives a responsive handoff to the agent's retry once +/// the supervisor's own policy poll catches up. +async fn wait_for_local_policy_to_cover( + ctx: &PolicyLocalContext, + proposed_rule: &NetworkPolicyRule, + deadline: tokio::time::Instant, +) -> bool { + const TICK: std::time::Duration = std::time::Duration::from_millis(200); + loop { + // Clone the snapshot out of the RwLock before running coverage — + // otherwise the read guard is held across `policy_covers_rule`'s + // iteration of `network_policies`, serializing a writer (supervisor + // reload) on the very thing we're waiting for. Clone-per-tick on + // a few-KB struct is cheap for the bounded wait window here. + let snapshot = ctx.current_policy.read().await.clone(); + if let Some(policy) = snapshot.as_ref() + && openshell_policy::policy_covers_rule(policy, proposed_rule) + { + return true; + } + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + return false; + } + tokio::time::sleep(std::cmp::min(remaining, TICK)).await; + } +} + +/// Parse `?timeout=` from the query string. Default applies for missing +/// or unparseable values; bounds clamp to keep the agent's hold ceiling +/// sane. Re-issue is the right pattern for longer waits. +fn parse_timeout_query(query: &str) -> u64 { + let raw = query + .split('&') + .filter_map(|kv| kv.split_once('=')) + .find(|(k, _)| *k == "timeout") + .map_or("", |(_, v)| v); + raw.parse::() + .unwrap_or(PROPOSAL_WAIT_DEFAULT_SECS) + .clamp(PROPOSAL_WAIT_MIN_SECS, PROPOSAL_WAIT_MAX_SECS) +} + +/// One connected gateway client + the validated sandbox name. Built once +/// per request and reused for every `fetch_chunk` call in a wait loop so a +/// 60-second wait does one TLS handshake, not sixty. +struct LookupSession<'a> { + client: crate::grpc_client::CachedOpenShellClient, + sandbox_name: &'a str, +} + +/// Validate ctx and open one gateway channel. Failures map to the canonical +/// error payload shape used by both `/proposals/{id}` and `/wait`. +async fn open_lookup_session( + ctx: &PolicyLocalContext, +) -> std::result::Result, (u16, serde_json::Value)> { + let endpoint = ctx.gateway_endpoint.as_deref().ok_or_else(|| { + ( + 503, + error_payload( + "gateway_unavailable", + "proposal state lookup requires a gateway-connected sandbox".to_string(), + ), + ) + })?; + let sandbox_name = ctx + .sandbox_name + .as_deref() + .map(str::trim) + .filter(|name| !name.is_empty()) + .ok_or_else(|| { + ( + 503, + error_payload( + "sandbox_name_unavailable", + "proposal state lookup requires a sandbox name".to_string(), + ), + ) + })?; + let client = crate::grpc_client::CachedOpenShellClient::connect(endpoint) + .await + .map_err(|e| (502, error_payload("gateway_connect_failed", e.to_string())))?; + Ok(LookupSession { + client, + sandbox_name, + }) +} + +/// One gateway call: list the sandbox's draft chunks and find the matching +/// id. Returns `Ok(None)` only when the gateway responded successfully but +/// no chunk in this sandbox matches. +async fn fetch_chunk( + session: &LookupSession<'_>, + chunk_id: &str, +) -> std::result::Result, (u16, serde_json::Value)> { + let chunks = session + .client + .get_draft_policy(session.sandbox_name, "") + .await + .map_err(|e| (502, error_payload("gateway_lookup_failed", e.to_string())))?; + Ok(chunks.into_iter().find(|c| c.id == chunk_id)) +} + +fn proposal_chunks_from_body(body: &[u8]) -> std::result::Result, String> { + let request: ProposalRequest = serde_json::from_slice(body).map_err(|e| e.to_string())?; + if request.operations.is_empty() { + return Err("proposal requires at least one operation".to_string()); + } + + let mut chunks = Vec::new(); + for operation in request.operations { + let Some(add_rule) = operation.get("addRule").cloned() else { + return Err( + "this MVP accepts `addRule` operations; submit a full narrow NetworkPolicyRule" + .to_string(), + ); + }; + let add_rule: AddNetworkRuleJson = + serde_json::from_value(add_rule).map_err(|e| e.to_string())?; + chunks.push(policy_chunk_from_add_rule( + add_rule, + request.intent_summary.as_deref().unwrap_or_default(), + )?); + } + + Ok(chunks) +} + +fn policy_chunk_from_add_rule( + add_rule: AddNetworkRuleJson, + intent_summary: &str, +) -> std::result::Result { + let mut rule = network_rule_from_json(add_rule.rule)?; + let rule_name = add_rule + .rule_name + .as_deref() + .map(str::trim) + .filter(|name| !name.is_empty()) + .map_or_else(|| rule.name.clone(), ToString::to_string); + if rule_name.trim().is_empty() { + return Err("addRule.ruleName or rule.name is required".to_string()); + } + if rule.name.trim().is_empty() { + rule.name.clone_from(&rule_name); + } + + let binary = rule + .binaries + .first() + .map(|binary| binary.path.clone()) + .unwrap_or_default(); + + Ok(PolicyChunk { + id: String::new(), + status: "pending".to_string(), + rule_name, + proposed_rule: Some(rule), + rationale: intent_summary.to_string(), + security_notes: String::new(), + confidence: 0.75, + denial_summary_ids: vec![], + created_at_ms: 0, + decided_at_ms: 0, + stage: "agent".to_string(), + supersedes_chunk_id: String::new(), + hit_count: 1, + first_seen_ms: 0, + last_seen_ms: 0, + binary, + validation_result: String::new(), + rejection_reason: String::new(), + }) +} + +fn network_rule_from_json( + rule: NetworkPolicyRuleJson, +) -> std::result::Result { + if rule.endpoints.is_empty() { + return Err("rule.endpoints must contain at least one endpoint".to_string()); + } + + let endpoints = rule + .endpoints + .into_iter() + .map(|endpoint| { + let mut endpoint = network_endpoint_from_json(endpoint)?; + endpoint.advisor_proposed = true; + Ok::(endpoint) + }) + .collect::, _>>()?; + let binaries = rule + .binaries + .into_iter() + .map(|binary| { + let mut proposal_binary = NetworkBinary { + path: binary.path, + ..Default::default() + }; + // The deprecated harness bit is ignored by policy YAML, but OPA + // maps it to advisor_proposed to preserve the SSRF two-step flow. + #[allow(deprecated)] + { + proposal_binary.harness = true; + } + proposal_binary + }) + .collect(); + + Ok(NetworkPolicyRule { + name: rule.name.unwrap_or_default(), + endpoints, + binaries, + }) +} + +fn network_endpoint_from_json( + endpoint: NetworkEndpointJson, +) -> std::result::Result { + if endpoint.host.trim().is_empty() { + return Err("endpoint.host is required".to_string()); + } + + let mut ports = endpoint.ports; + if ports.is_empty() && endpoint.port > 0 { + ports.push(endpoint.port); + } + if ports.is_empty() { + return Err("endpoint.port or endpoint.ports is required".to_string()); + } + if endpoint + .rules + .iter() + .any(|rule| rule.allow.path.contains('?')) + { + return Err("L7 allow paths must not include query strings".to_string()); + } + + let port = ports.first().copied().unwrap_or_default(); + let rules = endpoint + .rules + .into_iter() + .map(|rule| L7Rule { + allow: Some(L7Allow { + method: rule.allow.method, + path: rule.allow.path, + command: rule.allow.command, + query: HashMap::new(), + // GraphQL fields default empty — agent-authored proposals from + // policy.local target REST/SQL/L4 endpoints; GraphQL operation + // matching is set on the policy server side or via direct YAML. + operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + rpc_method: String::new(), + }), + }) + .collect(); + let deny_rules = endpoint + .deny_rules + .into_iter() + .map(|rule| L7DenyRule { + method: rule.method, + path: rule.path, + command: rule.command, + query: HashMap::new(), + operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + rpc_method: String::new(), + }) + .collect(); + + Ok(NetworkEndpoint { + host: endpoint.host, + port, + protocol: endpoint.protocol, + tls: endpoint.tls, + enforcement: endpoint.enforcement, + access: endpoint.access, + rules, + allowed_ips: endpoint.allowed_ips, + ports, + deny_rules, + allow_encoded_slash: endpoint.allow_encoded_slash, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + advisor_proposed: false, + // GraphQL persisted-query knobs and path scoping default empty — + // agent proposals don't author them today. + persisted_queries: String::new(), + graphql_persisted_queries: HashMap::new(), + graphql_max_body_bytes: 0, + path: String::new(), + }) +} + +async fn read_request_body(initial_request: &[u8], client: &mut S) -> Result> +where + S: AsyncRead + Unpin, +{ + let Some(header_end) = find_header_end(initial_request) else { + return Ok(Vec::new()); + }; + let content_length = parse_content_length(&initial_request[..header_end])?; + if content_length > MAX_POLICY_LOCAL_BODY_BYTES { + return Err(miette::miette!( + "policy.local request body exceeds {MAX_POLICY_LOCAL_BODY_BYTES} bytes" + )); + } + + let mut body = initial_request[header_end..].to_vec(); + if body.len() > content_length { + body.truncate(content_length); + } + let read_loop = async { + while body.len() < content_length { + let remaining = content_length - body.len(); + let mut chunk = vec![0u8; remaining.min(8192)]; + let n = client.read(&mut chunk).await.into_diagnostic()?; + if n == 0 { + return Err(miette::miette!("policy.local request body ended early")); + } + body.extend_from_slice(&chunk[..n]); + } + Ok::<(), miette::Report>(()) + }; + tokio::time::timeout(POLICY_LOCAL_BODY_READ_TIMEOUT, read_loop) + .await + .map_err(|_| miette::miette!("policy.local request body read timed out"))??; + + Ok(body) +} + +fn parse_content_length(headers: &[u8]) -> Result { + let headers = String::from_utf8_lossy(headers); + for line in headers.lines().skip(1) { + if let Some((name, value)) = line.split_once(':') + && name.eq_ignore_ascii_case("content-length") + { + return value + .trim() + .parse::() + .into_diagnostic() + .map_err(|_| miette::miette!("invalid policy.local Content-Length")); + } + } + Ok(0) +} + +fn find_header_end(buf: &[u8]) -> Option { + buf.windows(4) + .position(|window| window == b"\r\n\r\n") + .map(|idx| idx + 4) +} + +async fn write_json_response( + client: &mut S, + status: u16, + payload: serde_json::Value, +) -> Result<()> +where + S: AsyncWrite + Unpin, +{ + let body = payload.to_string(); + let response = format!( + "HTTP/1.1 {status} {}\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + Connection: close\r\n\ + \r\n\ + {}", + status_text(status), + body.len(), + body + ); + client + .write_all(response.as_bytes()) + .await + .into_diagnostic()?; + client.flush().await.into_diagnostic()?; + Ok(()) +} + +fn status_text(status: u16) -> &'static str { + match status { + 202 => "Accepted", + 400 => "Bad Request", + 404 => "Not Found", + 500 => "Internal Server Error", + 502 => "Bad Gateway", + 503 => "Service Unavailable", + _ => "OK", + } +} + +fn error_payload(error: &str, detail: String) -> serde_json::Value { + serde_json::json!({ + "error": error, + "detail": detail + }) +} + +#[derive(Debug, Deserialize)] +struct ProposalRequest { + #[serde(default)] + intent_summary: Option, + #[serde(default)] + operations: Vec, +} + +#[derive(Debug, Deserialize)] +struct AddNetworkRuleJson { + #[serde(default, rename = "ruleName")] + rule_name: Option, + rule: NetworkPolicyRuleJson, +} + +#[derive(Debug, Deserialize)] +struct NetworkPolicyRuleJson { + #[serde(default)] + name: Option, + #[serde(default)] + endpoints: Vec, + #[serde(default)] + binaries: Vec, +} + +#[derive(Debug, Deserialize)] +struct NetworkEndpointJson { + host: String, + #[serde(default)] + port: u32, + #[serde(default)] + ports: Vec, + #[serde(default)] + protocol: String, + #[serde(default)] + tls: String, + #[serde(default)] + enforcement: String, + #[serde(default)] + access: String, + #[serde(default)] + rules: Vec, + #[serde(default)] + allowed_ips: Vec, + #[serde(default)] + deny_rules: Vec, + #[serde(default)] + allow_encoded_slash: bool, +} + +#[derive(Debug, Deserialize)] +struct NetworkBinaryJson { + path: String, +} + +#[derive(Debug, Deserialize)] +struct L7RuleJson { + allow: L7AllowJson, +} + +#[derive(Debug, Deserialize)] +struct L7AllowJson { + #[serde(default)] + method: String, + #[serde(default)] + path: String, + #[serde(default)] + command: String, +} + +#[derive(Debug, Deserialize)] +struct L7DenyRuleJson { + #[serde(default)] + method: String, + #[serde(default)] + path: String, + #[serde(default)] + command: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn proposal_chunks_from_body_accepts_add_rule_operation() { + let body = br#"{ + "intent_summary": "Allow gh to create one repo.", + "operations": [ + { + "addRule": { + "ruleName": "github_api_repo_create", + "rule": { + "endpoints": [ + { + "host": "api.github.com", + "port": 443, + "protocol": "rest", + "tls": "terminate", + "enforcement": "enforce", + "rules": [ + { + "allow": { + "method": "POST", + "path": "/user/repos" + } + } + ] + } + ], + "binaries": [ + { + "path": "/usr/bin/gh" + } + ] + } + } + } + ] + }"#; + + let chunks = proposal_chunks_from_body(body).unwrap(); + + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].rule_name, "github_api_repo_create"); + assert_eq!(chunks[0].rationale, "Allow gh to create one repo."); + assert_eq!(chunks[0].binary, "/usr/bin/gh"); + let rule = chunks[0].proposed_rule.as_ref().unwrap(); + assert_eq!(rule.name, "github_api_repo_create"); + assert_eq!(rule.endpoints[0].host, "api.github.com"); + assert_eq!(rule.endpoints[0].port, 443); + assert_eq!(rule.endpoints[0].ports, vec![443]); + assert_eq!(rule.endpoints[0].protocol, "rest"); + #[allow(deprecated)] + { + assert!(rule.binaries[0].harness); + } + assert_eq!( + rule.endpoints[0].rules[0].allow.as_ref().unwrap().path, + "/user/repos" + ); + } + + #[test] + fn proposal_chunks_from_body_rejects_query_in_l7_path() { + let body = br#"{ + "operations": [ + { + "addRule": { + "ruleName": "bad", + "rule": { + "endpoints": [ + { + "host": "api.github.com", + "port": 443, + "rules": [ + { + "allow": { + "method": "GET", + "path": "/repos?token=secret" + } + } + ] + } + ] + } + } + } + ] + }"#; + + let error = proposal_chunks_from_body(body).unwrap_err(); + assert!(error.contains("query strings")); + assert!(!error.contains("secret")); + } + + #[test] + fn parse_last_query_clamps_to_max() { + assert_eq!(parse_last_query("last=5"), Some(5)); + assert_eq!(parse_last_query("foo=bar&last=20"), Some(20)); + assert_eq!(parse_last_query("last=999"), Some(MAX_DENIALS_LIMIT)); + assert_eq!(parse_last_query("last=0"), Some(1)); + assert_eq!(parse_last_query(""), None); + assert_eq!(parse_last_query("other=1"), None); + } + + #[test] + fn is_ocsf_denial_line_filters_correctly() { + // OCSF denial — match. + assert!(is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com:443/x [policy:p engine:l7]" + )); + assert!(is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF NET:OPEN [MED] DENIED curl(42) -> blocked.com:443 [policy:- engine:opa]" + )); + + // OCSF allowed — must not match. + assert!(!is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z OCSF NET:OPEN [INFO] ALLOWED curl(42) -> api.example.com:443" + )); + + // Non-OCSF tracing line — must not match even if it contains the word DENIED. + assert!(!is_ocsf_denial_line( + "2026-05-06T17:02:00.000Z INFO some::module: request DENIED in upstream" + )); + + // Empty line — must not match. + assert!(!is_ocsf_denial_line("")); + } + + #[tokio::test] + async fn recent_denials_returns_newest_first_from_shorthand_lines() { + let dir = tempfile::tempdir().unwrap(); + let log_path = dir.path().join("openshell.2026-05-06.log"); + // Mixed file: allowed events, non-OCSF info lines, two denials. + // Lines are written in chronological order; reader walks newest-first. + let body = "\ +2026-05-06T17:02:00.000Z OCSF NET:OPEN [INFO] ALLOWED curl(10) -> api.example.com:443 [policy:default engine:opa] +2026-05-06T17:02:01.000Z INFO some::module: routine status check +2026-05-06T17:02:02.000Z OCSF HTTP:GET [MED] DENIED GET http://blocked.example/v1/data [policy:default-deny engine:l7] +2026-05-06T17:02:03.000Z OCSF NET:OPEN [INFO] ALLOWED curl(11) -> api.example.com:443 +2026-05-06T17:02:04.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com:443/repos/x/y/contents/z [policy:gh_readonly engine:l7] +"; + std::fs::write(&log_path, body).unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "last=10").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], true); + let denials = payload["denials"].as_array().unwrap(); + assert_eq!(denials.len(), 2); + // Newest first. + assert!(denials[0].as_str().unwrap().contains("HTTP:PUT")); + assert!( + denials[0] + .as_str() + .unwrap() + .contains("/repos/x/y/contents/z") + ); + assert!(denials[1].as_str().unwrap().contains("HTTP:GET")); + assert!(denials[1].as_str().unwrap().contains("blocked.example")); + } + + #[tokio::test] + async fn recent_denials_skips_jsonl_log_files() { + // The shorthand reader must not surface `openshell-ocsf.*.log` content + // even if a deny-looking line is present, so the response stays + // independent of the JSONL appender's enabled state. + let dir = tempfile::tempdir().unwrap(); + let jsonl = dir.path().join("openshell-ocsf.2026-05-06.log"); + std::fs::write( + &jsonl, + r#"{"class_uid":4002,"action_id":2,"message":"DENIED","time":1}"#, + ) + .unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], false); + assert_eq!(payload["denials"].as_array().unwrap().len(), 0); + } + + #[tokio::test] + async fn recent_denials_signals_when_log_is_missing() { + let dir = tempfile::tempdir().unwrap(); + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (status, payload) = recent_denials_response(&ctx, "").await; + assert_eq!(status, 200); + assert_eq!(payload["log_available"], false); + assert_eq!(payload["denials"].as_array().unwrap().len(), 0); + assert!( + payload["note"] + .as_str() + .unwrap() + .contains("/var/log/openshell.") + ); + } + + #[test] + fn redact_query_strings_removes_query_from_url_token() { + let line = "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com/x?access_token=secret-token-1234 [policy:p engine:l7]"; + let redacted = redact_query_strings(line); + assert!(!redacted.contains("secret-token-1234")); + assert!(!redacted.contains("access_token")); + assert!(redacted.contains("?[redacted]")); + // Bracketed tag after the URL preserved. + assert!(redacted.contains("[policy:p engine:l7]")); + } + + #[test] + fn redact_query_strings_removes_query_in_reason_tag() { + // The FORWARD deny path's `message` becomes `[reason:...]` and may + // include a path with query string lacking a `://` prefix. + let line = "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com/x [policy:p engine:opa] [reason:FORWARD denied PUT api.github.com:443/x?token=secret-456]"; + let redacted = redact_query_strings(line); + assert!(!redacted.contains("secret-456")); + assert!(!redacted.contains("token=secret")); + assert!(redacted.contains("?[redacted]]")); + } + + #[test] + fn redact_query_strings_handles_multibyte_chars() { + let line = "ÜLÅUTF8 ? secret-x [policy:p]"; + // No `?` here, so no redaction — but must not panic. + let _ = redact_query_strings(line); + } + + #[test] + fn truncate_at_char_boundary_does_not_panic_on_multibyte() { + // 4-byte emoji sequence so byte-naive slicing would panic. + let s = "🚀".repeat(2000); // 8000 bytes + let truncated = truncate_at_char_boundary(&s, 4096); + assert!(truncated.len() <= 4096 + "...[truncated]".len()); + assert!(truncated.ends_with("...[truncated]")); + // Result must be valid UTF-8 — implicit if we return without panic. + } + + #[tokio::test] + async fn recent_denials_truncates_pathological_lines() { + let dir = tempfile::tempdir().unwrap(); + let log_path = dir.path().join("openshell.2026-05-06.log"); + // A single OCSF denial line exceeding MAX_DENIAL_LINE_BYTES. + let huge_path = "/".to_string() + &"a".repeat(MAX_DENIAL_LINE_BYTES + 100); + let line = format!( + "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://x{huge_path} [policy:p engine:l7]\n" + ); + std::fs::write(&log_path, line).unwrap(); + + let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); + let (_, payload) = recent_denials_response(&ctx, "last=1").await; + let denials = payload["denials"].as_array().unwrap(); + assert_eq!(denials.len(), 1); + let surfaced = denials[0].as_str().unwrap(); + assert!(surfaced.len() <= MAX_DENIAL_LINE_BYTES + "...[truncated]".len()); + assert!(surfaced.ends_with("...[truncated]")); + } + + use crate::test_helpers::ProposalsFlagGuard; + + #[test] + fn agent_next_steps_returns_empty_when_flag_off() { + let _guard = ProposalsFlagGuard::set_blocking(false); + let steps = agent_next_steps(); + let arr = steps.as_array().expect("agent_next_steps is an array"); + assert!( + arr.is_empty(), + "expected empty next_steps when feature is off, got {steps}" + ); + } + + #[test] + fn agent_next_steps_returns_full_array_when_flag_on() { + let _guard = ProposalsFlagGuard::set_blocking(true); + let steps = agent_next_steps(); + let arr = steps.as_array().expect("agent_next_steps is an array"); + assert_eq!(arr.len(), 4, "expected 4 next_steps when feature is on"); + let actions: Vec<&str> = arr + .iter() + .filter_map(|v| v.get("action").and_then(serde_json::Value::as_str)) + .collect(); + assert!(actions.contains(&"read_skill")); + assert!(actions.contains(&"submit_proposal")); + } + + #[test] + fn agent_guidance_is_absent_when_flag_off() { + let _guard = ProposalsFlagGuard::set_blocking(false); + assert!(agent_guidance().is_none()); + } + + #[test] + fn agent_guidance_points_to_policy_advisor_when_flag_on() { + let _guard = ProposalsFlagGuard::set_blocking(true); + let guidance = agent_guidance().expect("guidance when proposals are enabled"); + assert!(guidance.contains("do not stop")); + assert!(guidance.contains("/etc/openshell/skills/policy_advisor.md")); + assert!(guidance.contains("http://policy.local/v1/proposals")); + assert!(guidance.contains("policy_reloaded: true")); + } + + #[tokio::test] + async fn route_request_returns_feature_disabled_when_flag_off() { + let _guard = ProposalsFlagGuard::set(false).await; + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + + // Even the otherwise-public `current_policy` route returns 404 with + // a feature_disabled error: when the surface is off it's off + // entirely, not selectively. + let (status, payload) = route_request(&ctx, "GET", ROUTE_POLICY_CURRENT, &[]).await; + assert_eq!(status, 404); + assert_eq!(payload["error"], "feature_disabled"); + assert!( + payload["detail"] + .as_str() + .unwrap() + .contains("agent_policy_proposals_enabled"), + "feature_disabled detail must name the setting key for actionability" + ); + } + + #[tokio::test] + async fn current_policy_route_returns_yaml_envelope() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + + let (mut client, mut server) = tokio::io::duplex(4096); + let request = + b"GET http://policy.local/v1/policy/current HTTP/1.1\r\nHost: policy.local\r\n\r\n"; + let task = tokio::spawn(async move { + handle_forward_request(&ctx, "GET", "/v1/policy/current", request, &mut server) + .await + .unwrap(); + }); + + let mut received = Vec::new(); + client.read_to_end(&mut received).await.unwrap(); + task.await.unwrap(); + + let response = String::from_utf8(received).unwrap(); + assert!(response.starts_with("HTTP/1.1 200 OK")); + let (_, body) = response.split_once("\r\n\r\n").unwrap(); + let body: serde_json::Value = serde_json::from_str(body).unwrap(); + assert_eq!(body["format"], "yaml"); + assert!(body["policy_yaml"].as_str().unwrap().contains("version: 1")); + } + + #[test] + fn parse_timeout_query_defaults_and_clamps() { + assert_eq!(parse_timeout_query(""), PROPOSAL_WAIT_DEFAULT_SECS); + assert_eq!(parse_timeout_query("timeout="), PROPOSAL_WAIT_DEFAULT_SECS); + assert_eq!( + parse_timeout_query("timeout=abc"), + PROPOSAL_WAIT_DEFAULT_SECS + ); + assert_eq!(parse_timeout_query("timeout=30"), 30); + assert_eq!(parse_timeout_query("foo=1&timeout=45"), 45); + // Below floor clamps up; above ceiling clamps down. + assert_eq!(parse_timeout_query("timeout=0"), PROPOSAL_WAIT_MIN_SECS); + assert_eq!(parse_timeout_query("timeout=9999"), PROPOSAL_WAIT_MAX_SECS); + } + + #[test] + fn is_terminal_status_matches_only_approved_and_rejected() { + assert!(!is_terminal_status("pending")); + assert!(is_terminal_status("approved")); + assert!(is_terminal_status("rejected")); + assert!(!is_terminal_status("")); + } + + #[test] + fn chunk_state_payload_surfaces_loop_fields() { + let chunk = PolicyChunk { + id: "chunk-x".to_string(), + status: "rejected".to_string(), + rule_name: "allow_example".to_string(), + binary: "/usr/bin/curl".to_string(), + rejection_reason: "scope too broad".to_string(), + validation_result: "no exfil paths".to_string(), + ..Default::default() + }; + let pending = chunk_state_payload(&chunk, false, false); + assert_eq!(pending["chunk_id"], "chunk-x"); + assert_eq!(pending["status"], "rejected"); + assert_eq!(pending["rejection_reason"], "scope too broad"); + assert_eq!(pending["validation_result"], "no exfil paths"); + // timed_out and policy_reloaded only appear when relevant. + assert!(pending.get("timed_out").is_none()); + assert!( + pending.get("policy_reloaded").is_none(), + "policy_reloaded is only meaningful for approved chunks" + ); + + let timed = chunk_state_payload(&chunk, true, false); + assert_eq!(timed["timed_out"], true); + } + + #[test] + fn chunk_state_payload_includes_policy_reloaded_when_approved() { + let chunk = PolicyChunk { + id: "chunk-y".to_string(), + status: "approved".to_string(), + rule_name: "allow_github".to_string(), + binary: "/usr/bin/curl".to_string(), + ..Default::default() + }; + let reloaded = chunk_state_payload(&chunk, false, true); + assert_eq!(reloaded["status"], "approved"); + assert_eq!(reloaded["policy_reloaded"], true); + + let not_reloaded = chunk_state_payload(&chunk, false, false); + assert_eq!(not_reloaded["policy_reloaded"], false); + } + + #[tokio::test] + async fn proposal_routes_reject_malformed_paths() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new(None, None, None); + + // Empty chunk_id after the prefix is 404, not a wildcard list. + let (status, _) = route_request(&ctx, "GET", "/v1/proposals/", &[]).await; + assert_eq!(status, 404); + + // More than one segment after the id (not "/wait") is 404, not a + // partial match. Prevents `/v1/proposals/abc/extra` from silently + // dispatching as a status lookup for "abc/extra". + let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/extra", &[]).await; + assert_eq!(status, 404); + + // Trailing path after `/wait` also 404 — must not match the wait + // arm as a wildcard. + let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/wait/extra", &[]).await; + assert_eq!(status, 404); + } + + #[tokio::test] + async fn proposal_status_route_returns_503_when_no_gateway() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); + + let (status, body) = route_request(&ctx, "GET", "/v1/proposals/chunk-id", &[]).await; + assert_eq!(status, 503); + assert_eq!(body["error"], "gateway_unavailable"); + } + + #[tokio::test] + async fn proposal_wait_route_returns_503_when_no_gateway() { + let _guard = ProposalsFlagGuard::set(true).await; + let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); + + let (status, body) = + route_request(&ctx, "GET", "/v1/proposals/chunk-id/wait?timeout=1", &[]).await; + assert_eq!(status, 503); + assert_eq!(body["error"], "gateway_unavailable"); + } + + #[tokio::test] + async fn proposal_routes_return_feature_disabled_when_flag_off() { + let _guard = ProposalsFlagGuard::set(false).await; + let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); + + let (status, body) = route_request(&ctx, "GET", "/v1/proposals/abc", &[]).await; + assert_eq!(status, 404); + assert_eq!(body["error"], "feature_disabled"); + + let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/wait", &[]).await; + assert_eq!(status, 404); + } + + #[test] + fn summarize_chunk_for_audit_includes_endpoint_l7_path_and_binary() { + let chunk = PolicyChunk { + id: "ignored".to_string(), + rule_name: "github_write".to_string(), + binary: "/usr/bin/curl".to_string(), + proposed_rule: Some(NetworkPolicyRule { + name: "github_write".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + rules: vec![L7Rule { + allow: Some(L7Allow { + method: "PUT".to_string(), + path: "/repos/foo/bar/contents/x.md".to_string(), + ..Default::default() + }), + }], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }), + ..Default::default() + }; + let summary = summarize_chunk_for_audit(&chunk); + assert!(summary.contains("api.github.com:443")); + assert!(summary.contains("PUT /repos/foo/bar/contents/x.md")); + assert!(summary.contains("/usr/bin/curl")); + } + + // Helpers — synthetic proposed rule + policy with that rule already + // merged. Both reused across reload-readiness tests. + fn proposed_curl_rule_for_github() -> NetworkPolicyRule { + NetworkPolicyRule { + name: "agent_proposed".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + } + } + + fn policy_with_rule(rule: NetworkPolicyRule) -> ProtoSandboxPolicy { + ProtoSandboxPolicy { + version: 1, + network_policies: HashMap::from([(rule.name.clone(), rule)]), + ..Default::default() + } + } + + #[tokio::test] + async fn wait_returns_reloaded_true_when_rule_already_loaded() { + // John's false-sleep case: the supervisor has already reloaded a + // policy containing the proposed rule before /wait starts. A + // whole-policy diff would never see another change and burn the + // full timeout. Rule-coverage must return immediately. + let proposed = proposed_curl_rule_for_github(); + let ctx = PolicyLocalContext::new(Some(policy_with_rule(proposed.clone())), None, None); + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2); + + let start = tokio::time::Instant::now(); + let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; + let elapsed = start.elapsed(); + + assert!(reloaded, "should report reloaded=true on coverage"); + assert!( + elapsed < std::time::Duration::from_millis(200), + "should return immediately, not poll-and-wait; took {elapsed:?}" + ); + } + + #[tokio::test] + async fn wait_does_not_wake_on_unrelated_policy_change() { + // John's false-wakeup case: a *different* rule gets added to the + // local policy (other agent's approval, settings change, etc.). + // The agent's specific rule is still not loaded. A diff-based + // check would wake here; coverage must not. + let proposed = proposed_curl_rule_for_github(); + // Start with a policy that does NOT contain the proposed rule. + let initial = ProtoSandboxPolicy { + version: 1, + ..Default::default() + }; + let ctx = PolicyLocalContext::new(Some(initial), None, None); + + // Concurrently, an unrelated rule lands. We must not return. + let unrelated_load = { + let policy = ctx.current_policy.clone(); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + *policy.write().await = Some(policy_with_rule(NetworkPolicyRule { + name: "unrelated".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + ports: vec![443], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + })); + }) + }; + + let deadline = tokio::time::Instant::now() + std::time::Duration::from_millis(400); + let start = tokio::time::Instant::now(); + let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; + unrelated_load.await.unwrap(); + let elapsed = start.elapsed(); + + assert!( + !reloaded, + "must not wake on an unrelated reload; coverage was never satisfied" + ); + assert!( + elapsed >= std::time::Duration::from_millis(350), + "should have held until the deadline; only waited {elapsed:?}" + ); + } + + #[tokio::test] + async fn wait_wakes_when_matching_rule_arrives_mid_flight() { + // Sandbox starts without the rule, then a reload lands containing + // it. /wait should observe coverage and return reloaded=true. + let proposed = proposed_curl_rule_for_github(); + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + + let matching_load = { + let policy = ctx.current_policy.clone(); + let target = proposed.clone(); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + *policy.write().await = Some(policy_with_rule(target)); + }) + }; + + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2); + let start = tokio::time::Instant::now(); + let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; + matching_load.await.unwrap(); + let elapsed = start.elapsed(); + + assert!(reloaded, "should report reloaded=true after coverage lands"); + assert!( + elapsed < std::time::Duration::from_millis(800), + "should return shortly after coverage; took {elapsed:?}" + ); + } + + #[tokio::test] + async fn wait_returns_reloaded_false_at_deadline_when_no_coverage() { + // Deadline budget exhausted, the proposed rule never showed up. + // Coverage check returns false — the agent gets policy_reloaded= + // false and decides whether to retry blind or re-issue /wait. + let proposed = proposed_curl_rule_for_github(); + let ctx = PolicyLocalContext::new( + Some(ProtoSandboxPolicy { + version: 1, + ..Default::default() + }), + None, + None, + ); + let deadline = tokio::time::Instant::now() + std::time::Duration::from_millis(300); + let start = tokio::time::Instant::now(); + let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; + let elapsed = start.elapsed(); + + assert!(!reloaded); + assert!( + elapsed >= std::time::Duration::from_millis(250), + "should wait until ~deadline; only waited {elapsed:?}" + ); + assert!( + elapsed < std::time::Duration::from_millis(800), + "should not extend past deadline by much; took {elapsed:?}" + ); + } + + #[test] + fn sanitize_reason_for_audit_strips_control_chars_and_caps_length() { + // Tabs and newlines are stripped; ordinary printable chars survive; + // multi-byte characters count as one char in the cap. + let raw = "line one\nline\ttwo\u{0001}\u{0007}"; + let cleaned = sanitize_reason_for_audit(raw); + assert!(!cleaned.contains('\n')); + assert!(!cleaned.contains('\t')); + assert!(!cleaned.contains('\u{0001}')); + assert!(cleaned.contains("line one")); + assert!(cleaned.contains("linetwo")); + + // Length cap with ellipsis marker so a downstream reader can tell + // the audit string is truncated. + let long: String = "x".repeat(500); + let capped = sanitize_reason_for_audit(&long); + assert!(capped.chars().count() <= 201); + assert!(capped.ends_with('…')); + + // Empty input maps to empty output (caller renders "(no guidance)"). + assert_eq!(sanitize_reason_for_audit(""), ""); + } + + #[test] + fn summarize_chunk_for_audit_falls_back_to_rule_name_without_rule() { + let chunk = PolicyChunk { + rule_name: "fallback".to_string(), + proposed_rule: None, + ..Default::default() + }; + assert_eq!(summarize_chunk_for_audit(&chunk), "rule_name:fallback"); + } +} diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs new file mode 100644 index 000000000..7782f69d7 --- /dev/null +++ b/crates/openshell-sandbox/src/proxy.rs @@ -0,0 +1,7661 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! HTTP CONNECT proxy with OPA policy evaluation and process-identity binding. + +use crate::activity_aggregator::{ActivitySender, try_record_activity}; +use crate::denial_aggregator::DenialEvent; +use crate::identity::BinaryIdentityCache; +use crate::l7::tls::ProxyTlsState; +use crate::opa::{NetworkAction, OpaEngine, PolicyGenerationGuard}; +use crate::policy::ProxyPolicy; +use crate::policy_local::{POLICY_LOCAL_HOST, PolicyLocalContext}; +use crate::provider_credentials::ProviderCredentialState; +use crate::secrets::{SecretResolver, rewrite_header_line_checked}; +use miette::{IntoDiagnostic, Result}; +use openshell_core::net::{is_always_blocked_ip, is_internal_ip, is_link_local_ip}; +use openshell_ocsf::{ + ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, + NetworkActivityBuilder, Process, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, +}; +use std::net::{IpAddr, SocketAddr}; +use std::path::PathBuf; +use std::sync::Arc; +use std::sync::atomic::{AtomicU32, Ordering}; +use tokio::io::{ + AsyncRead as TokioAsyncRead, AsyncReadExt, AsyncWrite as TokioAsyncWrite, AsyncWriteExt, +}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tracing::{debug, warn}; + +const MAX_HEADER_BYTES: usize = 8192; +const INFERENCE_LOCAL_HOST: &str = "inference.local"; +const INFERENCE_LOCAL_PORT: u16 = 443; + +/// Hostnames injected by compute drivers as `/etc/hosts` aliases for the host +/// machine. Traffic to these names is eligible for the trusted-gateway SSRF +/// exemption when the resolved IP matches the driver-injected value read from +/// `/etc/hosts` at proxy startup. +const HOST_GATEWAY_ALIASES: &[&str] = &[ + "host.openshell.internal", + "host.containers.internal", + "host.docker.internal", +]; + +/// Cloud instance metadata IPs that are NEVER exempted from SSRF blocking, +/// even when they coincidentally match a host-gateway alias resolution. +/// This list covers the well-known IMDS endpoints across major cloud providers. +const CLOUD_METADATA_IPS: &[IpAddr] = &[ + // AWS / GCP / Azure instance metadata service + IpAddr::V4(std::net::Ipv4Addr::new(169, 254, 169, 254)), +]; + +/// Maximum total bytes for a streaming inference response body (32 MiB). +#[cfg(not(test))] +const MAX_STREAMING_BODY: usize = 32 * 1024 * 1024; +// Keep unit tests deterministic without pushing tens of MiB through loopback. +#[cfg(test)] +const MAX_STREAMING_BODY: usize = 1024; + +/// Idle timeout per chunk when relaying streaming inference responses. +/// +/// Reasoning models (e.g. nemotron-3-super, o1, o3) can pause for 60+ seconds +/// between "thinking" and output phases. 120s provides headroom while still +/// catching genuinely stuck streams. +#[cfg(not(test))] +const CHUNK_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); +// Exercise idle-timeout truncation without slowing the full package test suite. +#[cfg(test)] +const CHUNK_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100); + +/// Result of a proxy CONNECT policy decision. +struct ConnectDecision { + action: NetworkAction, + /// Policy generation used for the L4 network decision. + generation: u64, + /// Resolved binary path. + binary: Option, + /// PID owning the socket. + binary_pid: Option, + /// Ancestor binary paths from process tree walk. + ancestors: Vec, + /// Cmdline-derived absolute paths (for script detection). + cmdline_paths: Vec, +} + +/// Outcome of an inference interception attempt. +/// +/// Returned by [`handle_inference_interception`] so the call site can emit +/// a structured CONNECT deny log when the connection is not successfully routed. +#[derive(Debug)] +enum InferenceOutcome { + /// At least one request was successfully routed to a local inference backend. + Routed, + /// The connection was denied (TLS failure, non-inference request, etc.). + Denied { reason: String }, +} + +/// Inference routing context for sandbox-local execution. +/// +/// Holds a `Router` (HTTP client) and cached sets of resolved routes. +/// User routes serve `inference.local` traffic; system routes are consumed +/// in-process by the supervisor for platform functions (e.g. agent harness). +pub struct InferenceContext { + pub patterns: Vec, + router: openshell_router::Router, + /// Routes for the user-facing `inference.local` endpoint. + routes: Arc>>, + /// Routes for supervisor-only system inference (`sandbox-system`). + system_routes: Arc>>, +} + +impl InferenceContext { + // `router`/`routes` are intentionally distinct nouns (the router and the + // route list it consumes); both names are clearer than alternatives. + #[allow(clippy::similar_names)] + pub fn new( + patterns: Vec, + router: openshell_router::Router, + routes: Vec, + system_routes: Vec, + ) -> Self { + Self { + patterns, + router, + routes: Arc::new(tokio::sync::RwLock::new(routes)), + system_routes: Arc::new(tokio::sync::RwLock::new(system_routes)), + } + } + + /// Get a handle to the user route cache for background refresh. + pub fn route_cache( + &self, + ) -> Arc>> { + self.routes.clone() + } + + /// Get a handle to the system route cache for background refresh. + pub fn system_route_cache( + &self, + ) -> Arc>> { + self.system_routes.clone() + } + + /// Make an inference call using system routes (supervisor-only). + /// + /// This is the in-process API for platform functions. It bypasses the + /// CONNECT proxy entirely — the supervisor calls the router directly + /// from the host network namespace. + pub async fn system_inference( + &self, + protocol: &str, + method: &str, + path: &str, + headers: Vec<(String, String)>, + body: bytes::Bytes, + ) -> Result { + let routes = self.system_routes.read().await; + self.router + .proxy_with_candidates(protocol, method, path, headers, body, &routes) + .await + } +} + +#[derive(Debug)] +pub struct ProxyHandle { + #[allow(dead_code)] + http_addr: Option, + join: JoinHandle<()>, +} + +impl ProxyHandle { + /// Start the proxy with OPA engine for policy evaluation. + /// + /// The proxy uses OPA for network decisions with process-identity binding + /// via `/proc/net/tcp`. All connections are evaluated through OPA policy. + #[allow(clippy::too_many_arguments)] + pub(crate) async fn start_with_bind_addr( + policy: &ProxyPolicy, + bind_addr: Option, + opa_engine: Arc, + identity_cache: Arc, + entrypoint_pid: Arc, + tls_state: Option>, + inference_ctx: Option>, + provider_credentials: Option, + policy_local_ctx: Option>, + denial_tx: Option>, + activity_tx: Option, + ) -> Result { + // Use override bind_addr, fall back to policy http_addr, then default + // to loopback:3128. The default allows the proxy to function when no + // network namespace is available (e.g. missing CAP_NET_ADMIN) and the + // policy doesn't specify an explicit address. + let default_addr: SocketAddr = ([127, 0, 0, 1], 3128).into(); + let http_addr = bind_addr.or(policy.http_addr).unwrap_or(default_addr); + + // Only enforce loopback restriction when not using network namespace override + if bind_addr.is_none() && !http_addr.ip().is_loopback() { + return Err(miette::miette!( + "Proxy http_addr must be loopback-only: {http_addr}" + )); + } + + let listener = TcpListener::bind(http_addr).await.into_diagnostic()?; + let local_addr = listener.local_addr().into_diagnostic()?; + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Listen) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_ip(local_addr.ip(), local_addr.port())) + .message(format!("Proxy listening on {local_addr}")) + .build(); + ocsf_emit!(event); + } + + // Detect the trusted host gateway IP from /etc/hosts before user code + // runs. This is read once at startup so later /etc/hosts modifications + // by sandbox workloads cannot influence the stored value. + let trusted_host_gateway: Arc> = Arc::new(detect_trusted_host_gateway()); + if let Some(ref ip) = *trusted_host_gateway { + tracing::info!( + %ip, + "Trusted host gateway detected from /etc/hosts; \ + host-gateway aliases exempt from SSRF always-blocked check" + ); + } + + let join = tokio::spawn(async move { + loop { + match listener.accept().await { + Ok((stream, _addr)) => { + let opa = opa_engine.clone(); + let cache = identity_cache.clone(); + let spid = entrypoint_pid.clone(); + let tls = tls_state.clone(); + let inf = inference_ctx.clone(); + let policy_local = policy_local_ctx.clone(); + let gw = trusted_host_gateway.clone(); + let resolver = provider_credentials + .as_ref() + .and_then(ProviderCredentialState::resolver); + let dynamic_credentials = provider_credentials.as_ref().map(|state| { + Arc::new(std::sync::RwLock::new( + state.snapshot().dynamic_credentials.clone(), + )) + }); + let dtx = denial_tx.clone(); + let atx = activity_tx.clone(); + tokio::spawn(async move { + if let Err(err) = handle_tcp_connection( + stream, + opa, + cache, + spid, + tls, + inf, + policy_local, + gw, + resolver, + dynamic_credentials, + dtx, + atx, + ) + .await + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .message(format!("Proxy connection error: {err}")) + .build(); + ocsf_emit!(event); + } + }); + } + Err(err) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .message(format!("Proxy accept error: {err}")) + .build(); + ocsf_emit!(event); + break; + } + } + } + }); + + Ok(Self { + http_addr: Some(local_addr), + join, + }) + } + + #[allow(dead_code)] + pub const fn http_addr(&self) -> Option { + self.http_addr + } +} + +impl Drop for ProxyHandle { + fn drop(&mut self) { + self.join.abort(); + } +} + +fn emit_activity(tx: &Option, denied: bool, deny_group: &'static str) { + if let Some(tx) = tx { + let _ = try_record_activity(tx, denied, deny_group); + } +} + +fn l7_inspection_active(l7_route: Option<&L7RouteSnapshot>) -> bool { + l7_route.is_some_and(|route| !route.configs.is_empty()) +} + +fn emit_connect_activity_if_l4_only( + tx: &Option, + l7_route: Option<&L7RouteSnapshot>, +) { + if !l7_inspection_active(l7_route) { + emit_activity(tx, false, "unknown"); + } +} + +fn emit_activity_simple(tx: Option<&ActivitySender>, denied: bool, deny_group: &'static str) { + if let Some(tx) = tx { + let _ = try_record_activity(tx, denied, deny_group); + } +} + +fn emit_forward_success_activity(tx: Option<&ActivitySender>, l7_activity_pending: bool) { + emit_activity_simple( + tx, + false, + if l7_activity_pending { + "l7_policy" + } else { + "unknown" + }, + ); +} + +/// Emit a denial event to the aggregator channel (if configured). +/// Used by `handle_tcp_connection` which owns `Option`. +fn emit_denial( + tx: &Option>, + host: &str, + port: u16, + binary: &str, + decision: &ConnectDecision, + reason: &str, + stage: &str, +) { + if let Some(tx) = tx { + let _ = tx.send(DenialEvent { + host: host.to_string(), + port, + binary: binary.to_string(), + ancestors: decision + .ancestors + .iter() + .map(|p| p.display().to_string()) + .collect(), + deny_reason: reason.to_string(), + denial_stage: stage.to_string(), + l7_method: None, + l7_path: None, + }); + } +} + +/// Emit a denial event from a borrowed sender reference. +/// Used by `handle_forward_proxy` which borrows `Option<&Sender>`. +fn emit_denial_simple( + tx: Option<&mpsc::UnboundedSender>, + host: &str, + port: u16, + binary: &str, + decision: &ConnectDecision, + reason: &str, + stage: &str, +) { + if let Some(tx) = tx { + let _ = tx.send(DenialEvent { + host: host.to_string(), + port, + binary: binary.to_string(), + ancestors: decision + .ancestors + .iter() + .map(|p| p.display().to_string()) + .collect(), + deny_reason: reason.to_string(), + denial_stage: stage.to_string(), + l7_method: None, + l7_path: None, + }); + } +} + +// Many distinct, non-related context parameters are required for a CONNECT +// dispatch; bundling them into a struct would just shift the noise into call +// sites. +#[allow(clippy::too_many_arguments)] +async fn handle_tcp_connection( + mut client: TcpStream, + opa_engine: Arc, + identity_cache: Arc, + entrypoint_pid: Arc, + tls_state: Option>, + inference_ctx: Option>, + policy_local_ctx: Option>, + trusted_host_gateway: Arc>, + secret_resolver: Option>, + dynamic_credentials: Option< + Arc< + std::sync::RwLock< + std::collections::HashMap, + >, + >, + >, + denial_tx: Option>, + activity_tx: Option, +) -> Result<()> { + let mut buf = vec![0u8; MAX_HEADER_BYTES]; + let mut used = 0usize; + + loop { + if used == buf.len() { + respond( + &mut client, + b"HTTP/1.1 431 Request Header Fields Too Large\r\n\r\n", + ) + .await?; + return Ok(()); + } + + let n = client.read(&mut buf[used..]).await.into_diagnostic()?; + if n == 0 { + return Ok(()); + } + used += n; + + if buf[..used].windows(4).any(|win| win == b"\r\n\r\n") { + break; + } + } + + let request = String::from_utf8_lossy(&buf[..used]); + let mut lines = request.split("\r\n"); + let request_line = lines.next().unwrap_or(""); + let mut parts = request_line.split_whitespace(); + let method = parts.next().unwrap_or(""); + let target = parts.next().unwrap_or(""); + + if method != "CONNECT" { + return handle_forward_proxy( + method, + target, + &buf[..], + used, + &mut client, + opa_engine, + identity_cache, + entrypoint_pid, + policy_local_ctx, + trusted_host_gateway, + secret_resolver, + dynamic_credentials, + denial_tx.as_ref(), + activity_tx.as_ref(), + ) + .await; + } + + let (host, port) = parse_target(target)?; + let host_lc = host.to_ascii_lowercase(); + + if host_lc == INFERENCE_LOCAL_HOST && port == INFERENCE_LOCAL_PORT { + respond(&mut client, b"HTTP/1.1 200 Connection Established\r\n\r\n").await?; + let outcome = handle_inference_interception( + client, + INFERENCE_LOCAL_HOST, + port, + tls_state.as_ref(), + inference_ctx.as_ref(), + ) + .await?; + if let InferenceOutcome::Denied { reason } = outcome { + emit_activity(&activity_tx, true, "forward_policy"); + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, port)) + .message(format!("Inference interception denied: {reason}")) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + return Ok(()); + } + + let peer_addr = client.peer_addr().into_diagnostic()?; + let _local_addr = client.local_addr().into_diagnostic()?; + + // Evaluate OPA policy with process-identity binding. + // Wrapped in spawn_blocking because identity resolution does heavy sync I/O: + // /proc scanning + SHA256 hashing of binaries (e.g. node at 124MB). + let opa_clone = opa_engine.clone(); + let cache_clone = identity_cache.clone(); + let pid_clone = entrypoint_pid.clone(); + let host_clone = host_lc.clone(); + let decision = tokio::task::spawn_blocking(move || { + evaluate_opa_tcp( + peer_addr, + &opa_clone, + &cache_clone, + &pid_clone, + &host_clone, + port, + ) + }) + .await + .map_err(|e| miette::miette!("identity resolution task panicked: {e}"))?; + + // Extract action string and matched policy for logging + let (matched_policy, deny_reason) = match &decision.action { + NetworkAction::Allow { matched_policy } => (matched_policy.clone(), String::new()), + NetworkAction::Deny { reason } => (None, reason.clone()), + }; + + // Build log context fields (shared by deny log below and deferred allow log after L7 check) + let binary_str = decision + .binary + .as_ref() + .map_or_else(|| "-".to_string(), |p| p.display().to_string()); + let pid_str = decision + .binary_pid + .map_or_else(|| "-".to_string(), |p| p.to_string()); + let ancestors_str = if decision.ancestors.is_empty() { + "-".to_string() + } else { + decision + .ancestors + .iter() + .map(|p| p.display().to_string()) + .collect::>() + .join(" -> ") + }; + let cmdline_str = if decision.cmdline_paths.is_empty() { + "-".to_string() + } else { + decision + .cmdline_paths + .iter() + .map(|p| p.display().to_string()) + .collect::>() + .join(", ") + }; + let policy_str = matched_policy.as_deref().unwrap_or("-"); + + // Log denied connections immediately — they never reach L7. + // Allowed connections are logged after the L7 config check (below) + // so we can distinguish CONNECT (L4-only) from CONNECT_L7 (L7 follows). + if matches!(decision.action, NetworkAction::Deny { .. }) { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "opa") + .message(format!("CONNECT denied {host_lc}:{port}")) + .status_detail(&deny_reason) + .build(); + ocsf_emit!(event); + emit_denial( + &denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &deny_reason, + "connect", + ); + emit_activity(&activity_tx, true, "connect_policy"); + respond( + &mut client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("CONNECT {host_lc}:{port} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + + let sandbox_entrypoint_pid = entrypoint_pid.load(Ordering::Acquire); + + // Query allowed_ips from the matched endpoint config (if any). + // When present, the SSRF check validates resolved IPs against this + // allowlist instead of blanket-blocking all private IPs. + // When the policy host is already a literal IP address, treat it as + // implicitly allowed — the user explicitly declared the destination. + // Exact declared hostnames also skip the private-IP blanket block below, + // while keeping loopback/link-local/unspecified addresses denied. + let mut raw_allowed_ips = query_allowed_ips(&opa_engine, &decision, &host_lc, port); + if raw_allowed_ips.is_empty() { + raw_allowed_ips = implicit_allowed_ips_for_ip_host(&host); + } + let exact_declared_endpoint_host = + query_exact_declared_endpoint_host(&opa_engine, &decision, &host_lc, port); + + // Defense-in-depth: resolve DNS and reject connections to internal IPs. + let dns_connect_start = std::time::Instant::now(); + // The "non-empty" branch is the explicit-allowlist path; reading it first + // matches the policy decision narrative. + #[allow(clippy::if_not_else)] + let mut upstream = if is_host_gateway_alias(&host_lc) + && let Some(gw) = *trusted_host_gateway + { + // Trusted host-gateway path. The compute driver injected this hostname + // into /etc/hosts pointing at a known IP (read at proxy startup before + // user code runs). Bypass the normal SSRF tiers so link-local gateway + // addresses (used by rootless Podman with pasta) are not hard-blocked. + // Cloud metadata IPs and control-plane ports are still rejected. + match resolve_and_check_trusted_gateway(&host, port, gw, sandbox_entrypoint_pid).await { + Ok(addrs) => TcpStream::connect(addrs.as_slice()) + .await + .into_diagnostic()?, + Err(reason) => { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: trusted-gateway check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial( + &denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + emit_activity(&activity_tx, true, "ssrf"); + respond( + &mut client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!("CONNECT {host_lc}:{port} blocked: trusted-gateway check failed"), + ), + ) + .await?; + return Ok(()); + } + } + } else if !raw_allowed_ips.is_empty() { + // allowed_ips mode: validate resolved IPs against CIDR allowlist. + // Loopback and link-local are still always blocked. + match parse_allowed_ips(&raw_allowed_ips) { + Ok(nets) => { + match resolve_and_check_allowed_ips(&host, port, &nets, sandbox_entrypoint_pid) + .await + { + Ok(addrs) => TcpStream::connect(addrs.as_slice()) + .await + .into_diagnostic()?, + Err(reason) => { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: allowed_ips check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial( + &denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + emit_activity(&activity_tx, true, "ssrf"); + respond( + &mut client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!( + "CONNECT {host_lc}:{port} blocked: allowed_ips check failed" + ), + ), + ) + .await?; + return Ok(()); + } + } + } + Err(reason) => { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: invalid allowed_ips in policy for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial( + &denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + emit_activity(&activity_tx, true, "ssrf"); + respond( + &mut client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!("CONNECT {host_lc}:{port} blocked: invalid allowed_ips in policy"), + ), + ) + .await?; + return Ok(()); + } + } + } else if exact_declared_endpoint_host { + // Exact declared hostname mode: the operator explicitly allowed this + // host:port, so private IP resolution is permitted without duplicating + // the resolved IP in allowed_ips. Always-blocked addresses and + // control-plane ports remain denied. + match resolve_and_check_declared_endpoint(&host, port, sandbox_entrypoint_pid).await { + Ok(addrs) => TcpStream::connect(addrs.as_slice()) + .await + .into_diagnostic()?, + Err(reason) => { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: declared endpoint check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial( + &denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + respond( + &mut client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!( + "CONNECT {host_lc}:{port} blocked: declared endpoint check failed" + ), + ), + ) + .await?; + return Ok(()); + } + } + } else { + // Default: reject all internal IPs (loopback, RFC 1918, link-local). + match resolve_and_reject_internal(&host, port, sandbox_entrypoint_pid).await { + Ok(addrs) => TcpStream::connect(addrs.as_slice()) + .await + .into_diagnostic()?, + Err(reason) => { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "ssrf") + .message(format!( + "CONNECT blocked: internal address {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial( + &denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + emit_activity(&activity_tx, true, "ssrf"); + respond( + &mut client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!("CONNECT {host_lc}:{port} blocked: internal address"), + ), + ) + .await?; + return Ok(()); + } + } + }; + + debug!( + "handle_tcp_connection dns_resolve_and_tcp_connect: {}ms host={host_lc}", + dns_connect_start.elapsed().as_millis() + ); + + respond(&mut client, b"HTTP/1.1 200 Connection Established\r\n\r\n").await?; + + // Check if endpoint has L7 config for protocol-aware inspection, and + // retain the generation for HTTP passthrough keep-alive tunnels. + let l7_route = query_l7_route_snapshot(&opa_engine, &decision, &host_lc, port); + let should_inspect_l7 = l7_inspection_active(l7_route.as_ref()); + + // Log the allowed CONNECT — use CONNECT_L7 when L7 inspection follows, + // so log consumers can distinguish L4-only decisions from tunnel lifecycle events. + let connect_msg = if should_inspect_l7 { + "CONNECT_L7" + } else { + "CONNECT" + }; + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "opa") + .message(format!("{connect_msg} allowed {host_lc}:{port}")) + .build(); + ocsf_emit!(event); + } + emit_connect_activity_if_l4_only(&activity_tx, l7_route.as_ref()); + + // Determine effective TLS mode. Check the raw endpoint config for + // `tls: skip` independently of L7 config (which requires `protocol`). + let effective_tls_skip = + query_tls_mode(&opa_engine, &decision, &host_lc, port) == crate::l7::TlsMode::Skip; + + // Build L7 eval context (shared by TLS-terminated and plaintext paths). + let ctx = crate::l7::relay::L7EvalContext { + host: host_lc.clone(), + port, + policy_name: matched_policy.clone().unwrap_or_default(), + binary_path: decision + .binary + .as_ref() + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_default(), + ancestors: decision + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + cmdline_paths: decision + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + secret_resolver: secret_resolver.clone(), + activity_tx: activity_tx.clone(), + dynamic_credentials: dynamic_credentials.clone(), + token_grant_resolver: dynamic_credentials + .as_ref() + .map(|_| crate::l7::token_grant_injection::default_resolver()), + }; + + if effective_tls_skip { + // tls: skip — raw tunnel, no termination, no credential injection. + debug!( + host = %host_lc, + port = port, + "tls: skip — bypassing TLS auto-detection, raw tunnel" + ); + let _ = tokio::io::copy_bidirectional(&mut client, &mut upstream) + .await + .into_diagnostic()?; + return Ok(()); + } + + // Auto-detect TLS by peeking the first bytes. + let mut peek_buf = [0u8; 8]; + let n = client.peek(&mut peek_buf).await.into_diagnostic()?; + if n == 0 { + return Ok(()); + } + + let is_tls = crate::l7::tls::looks_like_tls(&peek_buf[..n]); + let is_http = crate::l7::rest::looks_like_http(&peek_buf[..n]); + + if is_tls { + // TLS detected — terminate unconditionally. + if let Some(ref tls) = tls_state { + let tls_result = async { + let mut tls_client = + crate::l7::tls::tls_terminate_client(client, tls, &host_lc).await?; + let mut tls_upstream = + crate::l7::tls::tls_connect_upstream(upstream, &host_lc, tls.upstream_config()) + .await?; + + if let Some(route) = l7_route.as_ref().filter(|route| !route.configs.is_empty()) { + // L7 inspection on terminated TLS traffic. + let tunnel_engine = match opa_engine.clone_engine_for_tunnel(route.generation) { + Ok(engine) => engine, + Err(e) => { + emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); + return Ok(()); + } + }; + if route.configs.len() == 1 { + crate::l7::relay::relay_with_inspection( + &route.configs[0].config, + tunnel_engine, + &mut tls_client, + &mut tls_upstream, + &ctx, + ) + .await + } else { + let configs: Vec = route + .configs + .iter() + .map(|snapshot| snapshot.config.clone()) + .collect(); + crate::l7::relay::relay_with_route_selection( + &configs, + tunnel_engine, + &mut tls_client, + &mut tls_upstream, + &ctx, + ) + .await + } + } else { + // No L7 config — relay with credential injection only. + let generation = l7_route + .as_ref() + .map_or(decision.generation, |route| route.generation); + let generation_guard = match opa_engine.generation_guard(generation) { + Ok(guard) => guard, + Err(e) => { + emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); + return Ok(()); + } + }; + crate::l7::relay::relay_passthrough_with_credentials( + &mut tls_client, + &mut tls_upstream, + &ctx, + &generation_guard, + ) + .await + } + }; + if let Err(e) = tls_result.await { + if is_benign_relay_error(&e) { + debug!( + host = %host_lc, + port = port, + error = %e, + "TLS connection closed" + ); + } else { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("TLS relay error: {e}")) + .build(); + ocsf_emit!(event); + } + } + } else { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!( + "TLS detected but TLS state not configured for {host_lc}:{port}, falling back to raw tunnel" + )) + .build(); + ocsf_emit!(event); + } + let _ = tokio::io::copy_bidirectional(&mut client, &mut upstream) + .await + .into_diagnostic()?; + } + } else if is_http { + // Plaintext HTTP detected. + if let Some(route) = l7_route.as_ref().filter(|route| !route.configs.is_empty()) { + let tunnel_engine = match opa_engine.clone_engine_for_tunnel(route.generation) { + Ok(engine) => engine, + Err(e) => { + emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); + return Ok(()); + } + }; + let relay_result = if route.configs.len() == 1 { + crate::l7::relay::relay_with_inspection( + &route.configs[0].config, + tunnel_engine, + &mut client, + &mut upstream, + &ctx, + ) + .await + } else { + let configs: Vec = route + .configs + .iter() + .map(|snapshot| snapshot.config.clone()) + .collect(); + crate::l7::relay::relay_with_route_selection( + &configs, + tunnel_engine, + &mut client, + &mut upstream, + &ctx, + ) + .await + }; + if let Err(e) = relay_result { + if is_benign_relay_error(&e) { + debug!(host = %host_lc, port = port, error = %e, "L7 connection closed"); + } else { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("L7 relay error: {e}")) + .build(); + ocsf_emit!(event); + } + } + } else { + // Plaintext HTTP, no L7 config — relay with credential injection. + let generation = l7_route + .as_ref() + .map_or(decision.generation, |route| route.generation); + let generation_guard = match opa_engine.generation_guard(generation) { + Ok(guard) => guard, + Err(e) => { + emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); + return Ok(()); + } + }; + if let Err(e) = crate::l7::relay::relay_passthrough_with_credentials( + &mut client, + &mut upstream, + &ctx, + &generation_guard, + ) + .await + { + if is_benign_relay_error(&e) { + debug!(host = %host_lc, port = port, error = %e, "HTTP relay closed"); + } else { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("HTTP relay error: {e}")) + .build(); + ocsf_emit!(event); + } + } + } + } else { + // Neither TLS nor HTTP — raw binary relay. + debug!( + host = %host_lc, + port = port, + "Non-TLS non-HTTP traffic detected, raw tunnel" + ); + let _ = tokio::io::copy_bidirectional(&mut client, &mut upstream) + .await + .into_diagnostic()?; + } + + Ok(()) +} + +/// Resolved process identity for a TCP peer: binary path, PID, ancestor chain, +/// cmdline paths, and the TOFU-verified binary hash. +/// +/// Produced by [`resolve_process_identity`]; consumed by [`evaluate_opa_tcp`] +/// and by the identity-chain regression tests. +#[cfg(target_os = "linux")] +struct ResolvedIdentity { + bin_path: PathBuf, + binary_pid: u32, + ancestors: Vec, + cmdline_paths: Vec, + bin_hash: String, +} + +#[cfg(target_os = "linux")] +#[derive(Debug, Eq, PartialEq)] +struct PolicyIdentityKey { + bin_path: PathBuf, + ancestors: Vec, + cmdline_paths: Vec, + bin_hash: String, +} + +#[cfg(target_os = "linux")] +impl ResolvedIdentity { + fn policy_key(&self) -> PolicyIdentityKey { + PolicyIdentityKey { + bin_path: self.bin_path.clone(), + ancestors: self.ancestors.clone(), + cmdline_paths: self.cmdline_paths.clone(), + bin_hash: self.bin_hash.clone(), + } + } +} + +/// Error from [`resolve_process_identity`]. Carries the deny reason and +/// whatever partial identity data was resolved before the failure so the +/// caller can include it in the [`ConnectDecision`] and OCSF event. +#[cfg(target_os = "linux")] +struct IdentityError { + reason: String, + binary: Option, + binary_pid: Option, + ancestors: Vec, +} + +#[cfg(target_os = "linux")] +fn resolve_owner_identity( + owner_pid: u32, + entrypoint_pid: u32, + identity_cache: &BinaryIdentityCache, +) -> std::result::Result { + let bin_path = + crate::procfs::binary_path(owner_pid.cast_signed()).map_err(|e| IdentityError { + reason: format!("failed to resolve peer binary for PID {owner_pid}: {e}"), + binary: None, + binary_pid: Some(owner_pid), + ancestors: vec![], + })?; + + let bin_hash = identity_cache + .verify_or_cache(&bin_path) + .map_err(|e| IdentityError { + reason: format!("binary integrity check failed: {e}"), + binary: Some(bin_path.clone()), + binary_pid: Some(owner_pid), + ancestors: vec![], + })?; + + let ancestors = crate::procfs::collect_ancestor_binaries(owner_pid, entrypoint_pid); + + for ancestor in &ancestors { + identity_cache + .verify_or_cache(ancestor) + .map_err(|e| IdentityError { + reason: format!( + "ancestor integrity check failed for {}: {e}", + ancestor.display() + ), + binary: Some(bin_path.clone()), + binary_pid: Some(owner_pid), + ancestors: ancestors.clone(), + })?; + } + + let mut exclude = ancestors.clone(); + exclude.push(bin_path.clone()); + let cmdline_paths = crate::procfs::collect_cmdline_paths(owner_pid, entrypoint_pid, &exclude); + + Ok(ResolvedIdentity { + bin_path, + binary_pid: owner_pid, + ancestors, + cmdline_paths, + bin_hash, + }) +} + +/// Resolve the identity of the process owning a TCP peer connection. +/// +/// Walks `/proc//net/tcp` to find the socket inode, locates +/// every owning PID, reads `/proc//exe`, TOFU-verifies each binary hash, +/// walks each ancestor chain verifying every ancestor, and collects +/// cmdline-derived absolute paths for script detection. +/// +/// This is the identity-resolution block of [`evaluate_opa_tcp`] extracted +/// into a standalone helper so it can be exercised by Linux-only regression +/// tests without a full OPA engine. The key invariant under test is that on +/// a hot-swap of the peer binary, the failure mode is +/// `"Binary integrity violation"` (from the identity cache) rather than +/// `"Failed to stat ... (deleted)"` (from the kernel-tainted path). +#[cfg(target_os = "linux")] +fn resolve_process_identity( + entrypoint_pid: u32, + peer_port: u16, + identity_cache: &BinaryIdentityCache, +) -> std::result::Result { + let socket_owners = crate::procfs::resolve_tcp_peer_socket_owners(entrypoint_pid, peer_port) + .map_err(|e| IdentityError { + reason: format!("failed to resolve peer binary: {e}"), + binary: None, + binary_pid: None, + ancestors: vec![], + })?; + + let mut identities = Vec::with_capacity(socket_owners.owners.len()); + for owner in &socket_owners.owners { + identities.push(resolve_owner_identity( + owner.pid, + entrypoint_pid, + identity_cache, + )?); + } + + let Some(first_identity) = identities.first() else { + return Err(IdentityError { + reason: format!( + "failed to resolve peer binary: no process found owning socket inode {}", + socket_owners.inode + ), + binary: None, + binary_pid: None, + ancestors: vec![], + }); + }; + + let first_key = first_identity.policy_key(); + if identities + .iter() + .skip(1) + .any(|identity| identity.policy_key() != first_key) + { + let mut pids: Vec = identities + .iter() + .map(|identity| identity.binary_pid) + .collect(); + pids.sort_unstable(); + return Err(IdentityError { + reason: format!( + "ambiguous shared socket ownership: inode {} is held by PIDs [{}] with different policy identities", + socket_owners.inode, + pids.iter() + .map(u32::to_string) + .collect::>() + .join(", ") + ), + binary: None, + binary_pid: None, + ancestors: vec![], + }); + } + + let mut identity = identities.swap_remove(0); + if let Some(lowest_pid) = socket_owners.owners.iter().map(|owner| owner.pid).min() { + identity.binary_pid = lowest_pid; + } + Ok(identity) +} + +/// Evaluate OPA policy for a TCP connection with identity binding via /proc/net/tcp. +#[cfg(target_os = "linux")] +fn evaluate_opa_tcp( + peer_addr: SocketAddr, + engine: &OpaEngine, + identity_cache: &BinaryIdentityCache, + entrypoint_pid: &AtomicU32, + host: &str, + port: u16, +) -> ConnectDecision { + use crate::opa::NetworkInput; + use std::sync::atomic::Ordering; + + let deny = |reason: String, + binary: Option, + binary_pid: Option, + ancestors: Vec, + cmdline_paths: Vec| + -> ConnectDecision { + ConnectDecision { + action: NetworkAction::Deny { reason }, + generation: engine.current_generation(), + binary, + binary_pid, + ancestors, + cmdline_paths, + } + }; + + let pid = entrypoint_pid.load(Ordering::Acquire); + if pid == 0 { + return deny( + "entrypoint process not yet spawned".into(), + None, + None, + vec![], + vec![], + ); + } + + let total_start = std::time::Instant::now(); + let peer_port = peer_addr.port(); + + let identity = match resolve_process_identity(pid, peer_port, identity_cache) { + Ok(id) => id, + Err(err) => { + return deny( + err.reason, + err.binary, + err.binary_pid, + err.ancestors, + vec![], + ); + } + }; + + let ResolvedIdentity { + bin_path, + binary_pid, + ancestors, + cmdline_paths, + bin_hash, + } = identity; + + let input = NetworkInput { + host: host.to_string(), + port, + binary_path: bin_path.clone(), + binary_sha256: bin_hash, + ancestors: ancestors.clone(), + cmdline_paths: cmdline_paths.clone(), + }; + + let result = match engine.evaluate_network_action_with_generation(&input) { + Ok((action, generation)) => ConnectDecision { + action, + generation, + binary: Some(bin_path), + binary_pid: Some(binary_pid), + ancestors, + cmdline_paths, + }, + Err(e) => deny( + format!("policy evaluation error: {e}"), + Some(bin_path), + Some(binary_pid), + ancestors, + cmdline_paths, + ), + }; + debug!( + "evaluate_opa_tcp TOTAL: {}ms host={host} port={port}", + total_start.elapsed().as_millis() + ); + result +} + +/// Non-Linux stub: OPA identity binding requires /proc. +#[cfg(not(target_os = "linux"))] +fn evaluate_opa_tcp( + _peer_addr: SocketAddr, + engine: &OpaEngine, + _identity_cache: &BinaryIdentityCache, + _entrypoint_pid: &AtomicU32, + _host: &str, + _port: u16, +) -> ConnectDecision { + ConnectDecision { + action: NetworkAction::Deny { + reason: "identity binding unavailable on this platform".into(), + }, + generation: engine.current_generation(), + binary: None, + binary_pid: None, + ancestors: vec![], + cmdline_paths: vec![], + } +} + +/// Maximum buffer size for inference request parsing (10 MiB). +const MAX_INFERENCE_BUF: usize = 10 * 1024 * 1024; + +/// Initial buffer size for inference request parsing (64 KiB). +const INITIAL_INFERENCE_BUF: usize = 65536; + +/// Handle an intercepted connection for inference routing. +/// +/// TLS-terminates the client connection, parses HTTP requests, and executes +/// inference API calls locally via `openshell-router`. +/// Non-inference requests are denied with 403. +/// +/// Returns [`InferenceOutcome::Routed`] if at least one request was successfully +/// routed, or [`InferenceOutcome::Denied`] with a reason for all denial cases. +async fn handle_inference_interception( + client: TcpStream, + host: &str, + port: u16, + tls_state: Option<&Arc>, + inference_ctx: Option<&Arc>, +) -> Result { + let Some(ctx) = inference_ctx else { + return Ok(InferenceOutcome::Denied { + reason: "cluster inference context not configured".to_string(), + }); + }; + + let Some(tls) = tls_state else { + return Ok(InferenceOutcome::Denied { + reason: "missing TLS state".to_string(), + }); + }; + + // TLS-terminate the client side (present a cert for the target host) + let mut tls_client = match crate::l7::tls::tls_terminate_client(client, tls, host).await { + Ok(c) => c, + Err(e) => { + return Ok(InferenceOutcome::Denied { + reason: format!("TLS handshake failed: {e}"), + }); + } + }; + + process_inference_keepalive(&mut tls_client, ctx, port).await +} + +/// Read and process HTTP requests from a TLS-terminated inference connection. +/// +/// Each request is matched against inference patterns and routed locally. +/// Any non-inference request is immediately denied and the connection is closed, +/// even if previous requests on the same keep-alive connection were routed +/// successfully. +async fn process_inference_keepalive( + stream: &mut S, + ctx: &InferenceContext, + port: u16, +) -> Result { + use crate::l7::inference::{ParseResult, format_http_response, try_parse_http_request}; + + let mut buf = vec![0u8; INITIAL_INFERENCE_BUF]; + let mut used = 0usize; + let mut routed_any = false; + + loop { + let n = match stream.read(&mut buf[used..]).await { + Ok(n) => n, + Err(e) => { + if routed_any { + break; + } + return Ok(InferenceOutcome::Denied { + reason: format!("I/O error: {e}"), + }); + } + }; + if n == 0 { + if routed_any { + break; + } + return Ok(InferenceOutcome::Denied { + reason: "client closed connection".to_string(), + }); + } + used += n; + + // Try to parse a complete HTTP request + match try_parse_http_request(&buf[..used]) { + ParseResult::Complete(request, consumed) => { + let was_routed = route_inference_request(&request, ctx, stream).await?; + if was_routed { + routed_any = true; + } else { + // Deny and close: a non-inference request must not be silently + // ignored on a keep-alive connection that previously routed + // inference traffic. + return Ok(InferenceOutcome::Denied { + reason: "connection not allowed by policy".to_string(), + }); + } + + // Shift buffer for next request + buf.copy_within(consumed..used, 0); + used -= consumed; + } + ParseResult::Incomplete => { + // Need more data — grow buffer if full + if used == buf.len() { + if buf.len() >= MAX_INFERENCE_BUF { + let response = format_http_response(413, &[], b"Payload Too Large"); + write_all(stream, &response).await?; + if routed_any { + break; + } + return Ok(InferenceOutcome::Denied { + reason: "payload too large".to_string(), + }); + } + buf.resize((buf.len() * 2).min(MAX_INFERENCE_BUF), 0); + } + } + ParseResult::Invalid(reason) => { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Refuse) + .action(ActionId::Denied) + .disposition(DispositionId::Rejected) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, port)) + .message(format!("Rejecting malformed inference request: {reason}")) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + let response = format_http_response(400, &[], b"Bad Request"); + write_all(stream, &response).await?; + return Ok(InferenceOutcome::Denied { reason }); + } + } + } + + Ok(InferenceOutcome::Routed) +} + +/// Route a parsed inference request locally via the sandbox router, or deny it. +/// +/// Returns `Ok(true)` if the request was routed to an inference backend, +/// `Ok(false)` if it was denied as a non-inference request. +async fn route_inference_request( + request: &crate::l7::inference::ParsedHttpRequest, + ctx: &InferenceContext, + tls_client: &mut (impl tokio::io::AsyncWrite + Unpin), +) -> Result { + use crate::l7::inference::{detect_inference_pattern, format_http_response}; + + let normalized_path = normalize_inference_path(&request.path); + + if let Some(pattern) = + detect_inference_pattern(&request.method, &normalized_path, &ctx.patterns) + { + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Allowed) + .disposition(DispositionId::Detected) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "Intercepted inference request, routing locally: {} {} (protocol={}, kind={})", + request.method, normalized_path, pattern.protocol, pattern.kind + )) + .build(); + ocsf_emit!(event); + } + + let routes = ctx.routes.read().await; + + if routes.is_empty() { + let body = serde_json::json!({ + "error": "cluster inference is not configured", + "hint": "run: openshell cluster inference set --help" + }); + let body_bytes = body.to_string(); + let response = format_http_response( + 503, + &[("content-type".to_string(), "application/json".to_string())], + body_bytes.as_bytes(), + ); + write_all(tls_client, &response).await?; + return Ok(true); + } + + // Buffered protocols (embeddings, model discovery) return a single JSON + // object, not an SSE token stream. Serve them buffered with an accurate + // Content-Length: the streaming path would append an SSE error frame to + // the body on a size-cap or idle-timeout truncation, corrupting a + // payload the client parses as one JSON object. Framing is declared per + // protocol on the matched pattern. + if pattern.is_buffered() { + match ctx + .router + .proxy_with_candidates( + &pattern.protocol, + &request.method, + &normalized_path, + request.headers.clone(), + bytes::Bytes::from(request.body.clone()), + &routes, + ) + .await + { + Ok(resp) => { + let resp_headers = sanitize_inference_response_headers(resp.headers); + let response = format_http_response(resp.status, &resp_headers, &resp.body); + write_all(tls_client, &response).await?; + } + Err(e) => write_inference_router_error(tls_client, &e).await?, + } + return Ok(true); + } + + match ctx + .router + .proxy_with_candidates_streaming( + &pattern.protocol, + &request.method, + &normalized_path, + request.headers.clone(), + bytes::Bytes::from(request.body.clone()), + &routes, + ) + .await + { + Ok(mut resp) => { + use crate::l7::inference::{ + format_chunk, format_chunk_terminator, format_http_response_header, + format_sse_error, + }; + + let resp_headers = sanitize_inference_response_headers( + std::mem::take(&mut resp.headers).into_iter().collect(), + ); + + // Write response headers immediately (chunked TE). + let header_bytes = format_http_response_header(resp.status, &resp_headers); + write_all(tls_client, &header_bytes).await?; + + // Stream body chunks with byte cap and idle timeout. + // + // Each upstream chunk is wrapped in HTTP chunked framing and + // flushed immediately so SSE events reach the client without + // delay. Unlike the previous per-byte write_all+flush, we + // coalesce the framing header + data + trailer into a single + // write_all call, reducing the number of TLS records per chunk + // from 3 to 1 while preserving incremental delivery. + let mut total_bytes: usize = 0; + loop { + match tokio::time::timeout(CHUNK_IDLE_TIMEOUT, resp.next_chunk()).await { + Ok(Ok(Some(chunk))) => { + total_bytes += chunk.len(); + if total_bytes > MAX_STREAMING_BODY { + warn!( + total_bytes = total_bytes, + limit = MAX_STREAMING_BODY, + "streaming response exceeded byte limit, truncating" + ); + let err = format_sse_error( + "response truncated: exceeded maximum streaming body size", + ); + let _ = write_all(tls_client, &format_chunk(&err)).await; + break; + } + let encoded = format_chunk(&chunk); + write_all(tls_client, &encoded).await?; + } + Ok(Ok(None)) => break, + Ok(Err(e)) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "error reading upstream response chunk after \ + {total_bytes} bytes: {e}" + )) + .build(); + ocsf_emit!(event); + let err = format_sse_error("response truncated: upstream read error"); + let _ = write_all(tls_client, &format_chunk(&err)).await; + break; + } + Err(_) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "streaming response chunk idle timeout after \ + {total_bytes} bytes, closing" + )) + .build(); + ocsf_emit!(event); + let err = + format_sse_error("response truncated: chunk idle timeout exceeded"); + let _ = write_all(tls_client, &format_chunk(&err)).await; + break; + } + } + } + + // Terminate the chunked stream. + write_all(tls_client, format_chunk_terminator()).await?; + } + Err(e) => write_inference_router_error(tls_client, &e).await?, + } + Ok(true) + } else { + // Not an inference request — deny + { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "connection not allowed by policy: {} {}", + request.method, normalized_path + )) + .build(); + ocsf_emit!(event); + } + let body = serde_json::json!({"error": "connection not allowed by policy"}); + let body_bytes = body.to_string(); + let response = format_http_response( + 403, + &[("content-type".to_string(), "application/json".to_string())], + body_bytes.as_bytes(), + ); + write_all(tls_client, &response).await?; + Ok(false) + } +} + +/// Emit an OCSF failure event and write a buffered JSON error response for a +/// router error hit while proxying an inference request. +/// +/// Shared by the streaming and buffered routing paths so both surface upstream +/// failures with the same status mapping and the same audit record. +async fn write_inference_router_error( + tls_client: &mut (impl tokio::io::AsyncWrite + Unpin), + err: &openshell_router::RouterError, +) -> Result<()> { + use crate::l7::inference::format_http_response; + + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) + .message(format!( + "inference endpoint detected but upstream service failed: {err}" + )) + .build(); + ocsf_emit!(event); + + let (status, msg) = router_error_to_http(err); + let body = serde_json::json!({ "error": msg }).to_string(); + let response = format_http_response( + status, + &[("content-type".to_string(), "application/json".to_string())], + body.as_bytes(), + ); + write_all(tls_client, &response).await +} + +/// Map router errors to HTTP status codes and sanitized messages. +/// +/// Returns generic, client-safe messages instead of verbatim internal details; +/// the full error is recorded in the OCSF failure event by the caller. +fn router_error_to_http(err: &openshell_router::RouterError) -> (u16, String) { + use openshell_router::RouterError; + match err { + RouterError::RouteNotFound(_) => (400, "no inference route configured".to_string()), + RouterError::NoCompatibleRoute(_) => { + (400, "no compatible inference route available".to_string()) + } + RouterError::Unauthorized(_) => (401, "unauthorized".to_string()), + RouterError::UpstreamUnavailable(_) => (503, "inference service unavailable".to_string()), + RouterError::UpstreamProtocol(_) | RouterError::Internal(_) => { + (502, "inference service error".to_string()) + } + } +} + +fn sanitize_inference_response_headers(headers: Vec<(String, String)>) -> Vec<(String, String)> { + headers + .into_iter() + .filter(|(name, _)| !should_strip_response_header(name)) + .collect() +} + +fn should_strip_response_header(name: &str) -> bool { + let name_lc = name.to_ascii_lowercase(); + matches!(name_lc.as_str(), "content-length") || is_hop_by_hop_header(&name_lc) +} + +fn is_hop_by_hop_header(name: &str) -> bool { + matches!( + name, + "connection" + | "keep-alive" + | "proxy-authenticate" + | "proxy-authorization" + | "proxy-connection" + | "te" + | "trailer" + | "transfer-encoding" + | "upgrade" + ) +} + +/// Write all bytes to an async writer. +async fn write_all(writer: &mut (impl tokio::io::AsyncWrite + Unpin), data: &[u8]) -> Result<()> { + use tokio::io::AsyncWriteExt; + writer.write_all(data).await.into_diagnostic()?; + writer.flush().await.into_diagnostic()?; + Ok(()) +} + +#[derive(Debug, Clone)] +struct L7ConfigSnapshot { + config: crate::l7::L7EndpointConfig, +} + +#[derive(Debug, Clone)] +struct L7RouteSnapshot { + configs: Vec, + generation: u64, +} + +fn emit_l7_tunnel_close_after_policy_change(host: &str, port: u16, error: miette::Report) { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Open) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .message(format!( + "L7 tunnel closed before inspection because policy changed: {error}" + )) + .build(); + ocsf_emit!(event); +} + +/// Query L7 endpoint config from the OPA engine for a matched CONNECT decision. +/// +/// Returns `Some(L7EndpointConfig)` if the matched endpoint has L7 config (protocol field), +/// `None` for L4-only endpoints. +fn query_l7_route_snapshot( + engine: &OpaEngine, + decision: &ConnectDecision, + host: &str, + port: u16, +) -> Option { + // Only query if action is Allow (not Deny) + let has_policy = match &decision.action { + NetworkAction::Allow { matched_policy } => matched_policy.is_some(), + NetworkAction::Deny { .. } => false, + }; + if !has_policy { + return None; + } + + let input = crate::opa::NetworkInput { + host: host.to_string(), + port, + binary_path: decision.binary.clone().unwrap_or_default(), + binary_sha256: String::new(), + ancestors: decision.ancestors.clone(), + cmdline_paths: decision.cmdline_paths.clone(), + }; + + match engine.query_endpoint_configs_with_generation(&input) { + Ok((vals, generation)) => Some(L7RouteSnapshot { + configs: vals + .into_iter() + .filter_map(|val| crate::l7::parse_l7_config(&val)) + .map(|config| L7ConfigSnapshot { config }) + .collect(), + generation, + }), + Err(e) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .message(format!("Failed to query L7 endpoint config: {e}")) + .build(); + ocsf_emit!(event); + None + } + } +} + +fn select_l7_config_for_path<'a>( + configs: &'a [L7ConfigSnapshot], + path: &str, +) -> Option<&'a L7ConfigSnapshot> { + configs + .iter() + .filter(|snapshot| snapshot.config.matches_path(path)) + .max_by_key(|snapshot| snapshot.config.path_specificity()) +} + +/// Query the TLS mode for an endpoint, independent of L7 config. +/// +/// This extracts `tls: skip` from the endpoint even when no `protocol` is set. +fn query_tls_mode( + engine: &OpaEngine, + decision: &ConnectDecision, + host: &str, + port: u16, +) -> crate::l7::TlsMode { + let has_policy = match &decision.action { + NetworkAction::Allow { matched_policy } => matched_policy.is_some(), + NetworkAction::Deny { .. } => false, + }; + if !has_policy { + return crate::l7::TlsMode::Auto; + } + + let input = crate::opa::NetworkInput { + host: host.to_string(), + port, + binary_path: decision.binary.clone().unwrap_or_default(), + binary_sha256: String::new(), + ancestors: decision.ancestors.clone(), + cmdline_paths: decision.cmdline_paths.clone(), + }; + + match engine.query_endpoint_config(&input) { + Ok(Some(val)) => crate::l7::parse_tls_mode(&val), + _ => crate::l7::TlsMode::Auto, + } +} + +/// When the policy endpoint host is a literal IP address, the user has +/// explicitly declared intent to allow that destination. Synthesize an +/// `allowed_ips` entry so the existing allowlist-validation path is used +/// instead of the blanket internal-IP rejection. +/// +/// Always-blocked addresses (loopback, link-local, unspecified) are skipped +/// — synthesizing an `allowed_ips` entry for them would be silently +/// un-enforceable at runtime. +fn implicit_allowed_ips_for_ip_host(host: &str) -> Vec { + let lookup_host = normalize_host_lookup_key(host); + if let Ok(ip) = lookup_host.parse::() { + if is_always_blocked_ip(ip) { + warn!( + host, + "Policy host is an always-blocked address; \ + implicit allowed_ips skipped — SSRF hardening prevents \ + traffic to this destination regardless of policy" + ); + return vec![]; + } + vec![lookup_host.to_string()] + } else { + vec![] + } +} + +fn normalize_host_lookup_key(host: &str) -> &str { + host.strip_prefix('[') + .and_then(|trimmed| trimmed.strip_suffix(']')) + .unwrap_or(host) +} + +/// Returns `true` if `host` is one of the well-known driver-injected aliases +/// for the host machine (e.g. `host.openshell.internal`). +fn is_host_gateway_alias(host: &str) -> bool { + let h = normalize_host_lookup_key(host); + HOST_GATEWAY_ALIASES + .iter() + .any(|alias| alias.eq_ignore_ascii_case(h)) +} + +/// Returns `true` if `ip` is a known cloud instance metadata endpoint that +/// must never be exempted from SSRF blocking. +/// +/// IPv4-mapped IPv6 addresses (e.g. `::ffff:169.254.169.254`) are normalized +/// to their embedded IPv4 representation before comparison, so the invariant +/// holds regardless of how the address is represented. +fn is_cloud_metadata_ip(ip: IpAddr) -> bool { + match ip { + IpAddr::V4(_) => CLOUD_METADATA_IPS.contains(&ip), + IpAddr::V6(v6) => v6 + .to_ipv4_mapped() + .is_some_and(|v4| CLOUD_METADATA_IPS.contains(&IpAddr::V4(v4))), + } +} + +/// Read the proxy's own `/etc/hosts` at startup and return the IP mapped to +/// `host.openshell.internal`, if present and safe. +/// +/// This is called once before user code runs, so the returned value is immune +/// to later `/etc/hosts` tampering by sandbox workloads. Returns `None` if no +/// entry exists, the entry cannot be parsed, or the mapped IP is a cloud +/// metadata address. +#[cfg(any(target_os = "linux", test))] +fn detect_trusted_host_gateway() -> Option { + let contents = std::fs::read_to_string("/etc/hosts").ok()?; + let ips = parse_hosts_file_for_host(&contents, "host.openshell.internal"); + + // Multiple distinct IPs for the alias is unexpected — compute drivers + // always inject exactly one. Warn loudly so operators can diagnose the + // inconsistency; we still proceed with the first entry rather than + // disabling the exemption entirely, because the mismatch guard in + // resolve_and_check_trusted_gateway() will reject any runtime resolution + // that returns a different IP. + if ips.len() > 1 { + warn!( + ips = ?ips, + "host.openshell.internal has {} distinct IPs in /etc/hosts; \ + expected exactly one. Using first entry. \ + Connections resolving to any other IP will be rejected.", + ips.len() + ); + } + + let ip = ips.into_iter().next()?; + + if is_cloud_metadata_ip(ip) { + warn!( + %ip, + "host.openshell.internal resolves to a cloud metadata IP; \ + trusted-gateway SSRF exemption disabled" + ); + return None; + } + // The exemption exists solely for link-local IPs used by rootless Podman + // with pasta. Private RFC 1918 addresses (e.g. Docker bridge 172.17.0.1, + // Kubernetes node 192.168.x.x), loopback, unspecified, and all other + // non-link-local addresses are never legitimate candidates for the + // link-local SSRF exemption — they must fall through to the normal + // allowed_ips / resolve_and_reject_internal() enforcement path. + if !is_link_local_ip(ip) { + warn!( + %ip, + "host.openshell.internal maps to a non-link-local IP; \ + trusted-gateway SSRF exemption disabled" + ); + return None; + } + Some(ip) +} + +#[cfg(not(any(target_os = "linux", test)))] +fn detect_trusted_host_gateway() -> Option { + None +} + +/// Resolve `host:port` and validate that every resolved address matches the +/// trusted host gateway IP. +/// +/// This bypasses the normal SSRF tiers (always-blocked and internal-IP) for +/// driver-injected host-gateway aliases, allowing link-local addresses used +/// by rootless Podman with pasta without opening up arbitrary link-local or +/// cloud metadata access. +/// +/// Rejects: +/// - Any resolved IP that is a cloud metadata address (defense-in-depth) +/// - Any resolved IP that does not match `trusted_gw` (prevents /etc/hosts tampering) +/// - Control-plane ports (etcd, K8s API, kubelet) regardless of IP +async fn resolve_and_check_trusted_gateway( + host: &str, + port: u16, + trusted_gw: IpAddr, + entrypoint_pid: u32, +) -> std::result::Result, String> { + if BLOCKED_CONTROL_PLANE_PORTS.contains(&port) { + return Err(format!( + "port {port} is a blocked control-plane port, connection rejected" + )); + } + let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; + if addrs.is_empty() { + return Err(format!( + "DNS resolution returned no addresses for {}", + normalize_host_lookup_key(host) + )); + } + for addr in &addrs { + if is_cloud_metadata_ip(addr.ip()) { + return Err(format!( + "{host} resolves to cloud metadata address {}, connection rejected", + addr.ip() + )); + } + if addr.ip() != trusted_gw { + return Err(format!( + "{host} resolves to {} which does not match trusted host gateway \ + {trusted_gw}, connection rejected", + addr.ip() + )); + } + // Defense-in-depth: even if the resolved IP matches trusted_gw, reject + // any non-link-local address. detect_trusted_host_gateway() already + // enforces this at startup, but we re-check here to guard against any + // unanticipated code path that might admit a private or loopback IP. + if !is_link_local_ip(addr.ip()) { + return Err(format!( + "{host} resolves to non-link-local address {}, \ + connection rejected", + addr.ip() + )); + } + } + Ok(addrs) +} + +fn resolve_ip_literal(host: &str, port: u16) -> Option> { + normalize_host_lookup_key(host) + .parse::() + .ok() + .map(|ip| vec![SocketAddr::new(ip, port)]) +} + +#[cfg(any(target_os = "linux", test))] +fn parse_hosts_file_for_host(contents: &str, host: &str) -> Vec { + let lookup_host = normalize_host_lookup_key(host); + let mut addrs = Vec::new(); + + for raw_line in contents.lines() { + let line = raw_line.split('#').next().unwrap_or("").trim(); + if line.is_empty() { + continue; + } + + let mut fields = line.split_whitespace(); + let Some(ip_str) = fields.next() else { + continue; + }; + let Ok(ip) = ip_str.parse::() else { + continue; + }; + + if fields.any(|alias| alias.eq_ignore_ascii_case(lookup_host)) && !addrs.contains(&ip) { + addrs.push(ip); + } + } + + addrs +} + +#[cfg(any(target_os = "linux", test))] +fn resolve_from_hosts_file_contents(contents: &str, host: &str, port: u16) -> Vec { + parse_hosts_file_for_host(contents, host) + .into_iter() + .map(|ip| SocketAddr::new(ip, port)) + .collect() +} + +#[cfg(target_os = "linux")] +async fn resolve_from_sandbox_hosts( + host: &str, + port: u16, + entrypoint_pid: u32, +) -> Option> { + if entrypoint_pid == 0 { + return None; + } + + let hosts_path = format!("/proc/{entrypoint_pid}/root/etc/hosts"); + let contents = match tokio::fs::read_to_string(&hosts_path).await { + Ok(contents) => contents, + Err(error) => { + debug!( + pid = entrypoint_pid, + path = %hosts_path, + host, + "Falling back to DNS; failed to read sandbox hosts file: {error}" + ); + return None; + } + }; + + let addrs = resolve_from_hosts_file_contents(&contents, host, port); + if addrs.is_empty() { None } else { Some(addrs) } +} + +// Mirrors the Linux signature so call sites can `.await` uniformly across +// platforms; the non-Linux path has nothing to await. +#[cfg(not(target_os = "linux"))] +#[allow(clippy::unused_async)] +async fn resolve_from_sandbox_hosts( + _host: &str, + _port: u16, + _entrypoint_pid: u32, +) -> Option> { + None +} + +async fn resolve_socket_addrs( + host: &str, + port: u16, + entrypoint_pid: u32, +) -> std::result::Result, String> { + if let Some(addrs) = resolve_ip_literal(host, port) { + return Ok(addrs); + } + + if let Some(addrs) = resolve_from_sandbox_hosts(host, port, entrypoint_pid).await { + return Ok(addrs); + } + + let lookup_host = normalize_host_lookup_key(host); + let addrs: Vec = tokio::net::lookup_host((lookup_host, port)) + .await + .map_err(|e| format!("DNS resolution failed for {lookup_host}:{port}: {e}"))? + .collect(); + + if addrs.is_empty() { + return Err(format!( + "DNS resolution returned no addresses for {lookup_host}:{port}" + )); + } + + Ok(addrs) +} + +fn reject_internal_resolved_addrs( + host: &str, + addrs: &[SocketAddr], +) -> std::result::Result<(), String> { + if addrs.is_empty() { + return Err(format!( + "DNS resolution returned no addresses for {}", + normalize_host_lookup_key(host) + )); + } + + for addr in addrs { + if is_internal_ip(addr.ip()) { + return Err(format!( + "{host} resolves to internal address {}, connection rejected", + addr.ip() + )); + } + } + + Ok(()) +} + +fn validate_allowed_ips_for_resolved_addrs( + host: &str, + port: u16, + addrs: &[SocketAddr], + allowed_ips: &[ipnet::IpNet], +) -> std::result::Result<(), String> { + if addrs.is_empty() { + return Err(format!( + "DNS resolution returned no addresses for {}", + normalize_host_lookup_key(host) + )); + } + + // Block control-plane ports regardless of IP match. + if BLOCKED_CONTROL_PLANE_PORTS.contains(&port) { + return Err(format!( + "port {port} is a blocked control-plane port, connection rejected" + )); + } + + for addr in addrs { + // Always block loopback and link-local + if is_always_blocked_ip(addr.ip()) { + return Err(format!( + "{host} resolves to always-blocked address {}, connection rejected", + addr.ip() + )); + } + + // Check resolved IP against the allowlist + let ip_allowed = allowed_ips.iter().any(|net| net.contains(&addr.ip())); + if !ip_allowed { + return Err(format!( + "{host} resolves to {} which is not in allowed_ips, connection rejected", + addr.ip() + )); + } + } + + Ok(()) +} + +fn validate_declared_endpoint_resolved_addrs( + host: &str, + port: u16, + addrs: &[SocketAddr], +) -> std::result::Result<(), String> { + if addrs.is_empty() { + return Err(format!( + "DNS resolution returned no addresses for {}", + normalize_host_lookup_key(host) + )); + } + + if BLOCKED_CONTROL_PLANE_PORTS.contains(&port) { + return Err(format!( + "port {port} is a blocked control-plane port, connection rejected" + )); + } + + for addr in addrs { + if is_always_blocked_ip(addr.ip()) { + return Err(format!( + "{host} resolves to always-blocked address {}, connection rejected", + addr.ip() + )); + } + } + + Ok(()) +} + +/// Resolve a host:port using sandbox `/etc/hosts` first (when available), then +/// reject if any resolved address is internal. +/// +/// Returns the resolved `SocketAddr` list on success. Returns an error string +/// if any resolved IP is in an internal range or if DNS resolution fails. +async fn resolve_and_reject_internal( + host: &str, + port: u16, + entrypoint_pid: u32, +) -> std::result::Result, String> { + let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; + reject_internal_resolved_addrs(host, &addrs)?; + Ok(addrs) +} + +/// Resolve a host:port using sandbox `/etc/hosts` first (when available), then +/// validate resolved addresses against a CIDR/IP allowlist. +/// +/// Rejects loopback and link-local unconditionally. For all other resolved +/// addresses, checks that each one matches at least one entry in `allowed_ips`. +/// Entries can be CIDR notation ("10.0.5.0/24") or exact IPs ("10.0.5.20"). +/// +/// Returns the resolved `SocketAddr` list on success. +async fn resolve_and_check_allowed_ips( + host: &str, + port: u16, + allowed_ips: &[ipnet::IpNet], + entrypoint_pid: u32, +) -> std::result::Result, String> { + let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; + validate_allowed_ips_for_resolved_addrs(host, port, &addrs, allowed_ips)?; + Ok(addrs) +} + +/// Resolve a host:port that was explicitly declared by hostname in policy. +/// +/// Exact declared hostnames are the operator's trust signal, so RFC1918 and +/// other private ranges are allowed without a duplicated `allowed_ips` entry. +/// Loopback, link-local, unspecified, and control-plane ports remain blocked. +async fn resolve_and_check_declared_endpoint( + host: &str, + port: u16, + entrypoint_pid: u32, +) -> std::result::Result, String> { + let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; + validate_declared_endpoint_resolved_addrs(host, port, &addrs)?; + Ok(addrs) +} + +/// Minimum CIDR prefix length before logging a breadth warning. +/// CIDRs broader than /16 (65,536+ addresses) may unintentionally expose +/// control-plane services on the same network. +const MIN_SAFE_PREFIX_LEN: u8 = 16; + +/// Ports that are always blocked in `resolve_and_check_allowed_ips`, even +/// when the resolved IP matches an `allowed_ips` entry. These ports belong +/// to control-plane services that should never be reachable from a sandbox. +const BLOCKED_CONTROL_PLANE_PORTS: &[u16] = &[ + 2379, // etcd client + 2380, // etcd peer + 6443, // Kubernetes API server + 10250, // kubelet API + 10255, // kubelet read-only +]; + +/// Parse CIDR/IP strings into `IpNet` values, rejecting invalid entries and +/// entries that overlap always-blocked ranges (loopback, link-local, +/// unspecified). +/// +/// Returns parsed networks on success, or an error describing which entries +/// are invalid or always-blocked. Logs a warning for overly broad CIDRs +/// that are not outright blocked. +fn parse_allowed_ips(raw: &[String]) -> std::result::Result, String> { + use openshell_core::net::is_always_blocked_net; + + let mut nets = Vec::with_capacity(raw.len()); + let mut errors = Vec::new(); + + for entry in raw { + // Try as CIDR first, then as bare IP (convert to /32 or /128) + let parsed = entry.parse::().or_else(|_| { + entry + .parse::() + .map(|ip| match ip { + IpAddr::V4(v4) => ipnet::IpNet::V4(ipnet::Ipv4Net::from(v4)), + IpAddr::V6(v6) => ipnet::IpNet::V6(ipnet::Ipv6Net::from(v6)), + }) + .map_err(|_| ()) + }); + + match parsed { + Ok(n) => { + // Reject entries that overlap always-blocked ranges — these + // would be silently denied at runtime by is_always_blocked_ip + // and cause confusing UX (accepted in policy, never works). + if is_always_blocked_net(n) { + errors.push(format!( + "allowed_ips entry {entry} falls within always-blocked range \ + (loopback/link-local/unspecified); remove this entry — \ + SSRF hardening prevents traffic to these destinations \ + regardless of policy" + )); + continue; + } + + if n.prefix_len() < MIN_SAFE_PREFIX_LEN { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .severity(SeverityId::Medium) + .message(format!( + "allowed_ips entry has a very broad CIDR {n} (/{}) < /{MIN_SAFE_PREFIX_LEN}; \ + this may expose control-plane services on the same network", + n.prefix_len() + )) + .build(); + ocsf_emit!(event); + } + nets.push(n); + } + Err(()) => errors.push(format!("invalid CIDR/IP in allowed_ips: {entry}")), + } + } + + if errors.is_empty() { + Ok(nets) + } else { + Err(errors.join("; ")) + } +} + +/// Query `allowed_ips` from the matched endpoint config for a CONNECT decision. +fn query_allowed_ips( + engine: &OpaEngine, + decision: &ConnectDecision, + host: &str, + port: u16, +) -> Vec { + // Only query if action is Allow with a matched policy + let has_policy = match &decision.action { + NetworkAction::Allow { matched_policy } => matched_policy.is_some(), + NetworkAction::Deny { .. } => false, + }; + if !has_policy { + return vec![]; + } + + let input = crate::opa::NetworkInput { + host: host.to_string(), + port, + binary_path: decision.binary.clone().unwrap_or_default(), + binary_sha256: String::new(), + ancestors: decision.ancestors.clone(), + cmdline_paths: decision.cmdline_paths.clone(), + }; + + match engine.query_allowed_ips(&input) { + Ok(ips) => ips, + Err(e) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .message(format!( + "Failed to query allowed_ips from endpoint config: {e}" + )) + .build(); + ocsf_emit!(event); + vec![] + } + } +} + +/// Query whether the matched endpoint was declared as this exact hostname. +fn query_exact_declared_endpoint_host( + engine: &OpaEngine, + decision: &ConnectDecision, + host: &str, + port: u16, +) -> bool { + let has_policy = match &decision.action { + NetworkAction::Allow { matched_policy } => matched_policy.is_some(), + NetworkAction::Deny { .. } => false, + }; + if !has_policy { + return false; + } + + let input = crate::opa::NetworkInput { + host: host.to_string(), + port, + binary_path: decision.binary.clone().unwrap_or_default(), + binary_sha256: String::new(), + ancestors: decision.ancestors.clone(), + cmdline_paths: decision.cmdline_paths.clone(), + }; + + match engine.query_exact_declared_endpoint_host(&input) { + Ok(is_exact_declared) => is_exact_declared, + Err(e) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(host, port)) + .message(format!("Failed to query exact declared endpoint host: {e}")) + .build(); + ocsf_emit!(event); + false + } + } +} + +/// Canonicalize the request-target for inference pattern detection. +/// +/// Falls back to the raw path on canonicalization error: the request is then +/// routed through the normal forward path, where `rest.rs::parse_http_request` +/// will reject it properly. Returning the raw path here prevents a crafted +/// target from bypassing inference routing without our detection logic having +/// to implement a second, duplicate error-response surface. +fn normalize_inference_path(path: &str) -> String { + match crate::l7::path::canonicalize_request_target( + path, + &crate::l7::path::CanonicalizeOptions::default(), + ) { + Ok((canon, _)) => canon.path, + Err(_) => path.to_string(), + } +} + +/// Extract the hostname from an absolute-form URI used in plain HTTP proxy requests. +/// +/// For example, `"http://example.com/path"` yields `"example.com"` and +/// `"http://example.com:8080/path"` yields `"example.com"`. Returns `"unknown"` +/// if the URI cannot be parsed. +#[cfg(test)] +fn extract_host_from_uri(uri: &str) -> String { + // Absolute-form URIs look like "http://host[:port]/path" + // Strip the scheme prefix, then extract the authority (host[:port]) before the first '/'. + let after_scheme = uri.find("://").map_or(uri, |i| &uri[i + 3..]); + let authority = after_scheme.split('/').next().unwrap_or(after_scheme); + // Strip port if present (handle IPv6 bracket notation) + let host = if authority.starts_with('[') { + // IPv6: [::1]:port + authority.find(']').map_or(authority, |i| &authority[..=i]) + } else { + authority.split(':').next().unwrap_or(authority) + }; + if host.is_empty() { + "unknown".to_string() + } else { + host.to_string() + } +} + +/// Parse an absolute-form proxy request URI into its components. +/// +/// For example, `"http://10.86.8.223:8000/screenshot/"` yields +/// `("http", "10.86.8.223", 8000, "/screenshot/")`. +/// +/// Handles: +/// - Default port 80 for `http`, 443 for `https` +/// - IPv6 bracket notation (`[::1]`) +/// - Missing path (defaults to `/`) +/// - Query strings (preserved in path) +fn parse_proxy_uri(uri: &str) -> Result<(String, String, u16, String)> { + // Extract scheme + let (scheme, rest) = uri + .split_once("://") + .ok_or_else(|| miette::miette!("Missing scheme in proxy URI: {uri}"))?; + let scheme = scheme.to_ascii_lowercase(); + + // Split authority from path + let (authority, path) = if rest.starts_with('[') { + // IPv6: [::1]:port/path + let bracket_end = rest + .find(']') + .ok_or_else(|| miette::miette!("Unclosed IPv6 bracket in URI: {uri}"))?; + let after_bracket = &rest[bracket_end + 1..]; + after_bracket.find('/').map_or((rest, "/"), |slash_pos| { + ( + &rest[..=bracket_end + slash_pos], + &after_bracket[slash_pos..], + ) + }) + } else if let Some(slash_pos) = rest.find('/') { + (&rest[..slash_pos], &rest[slash_pos..]) + } else { + (rest, "/") + }; + + // Parse host and port from authority + let (host, port) = if authority.starts_with('[') { + // IPv6: [::1]:port or [::1] + let bracket_end = authority + .find(']') + .ok_or_else(|| miette::miette!("Unclosed IPv6 bracket: {uri}"))?; + let host = &authority[1..bracket_end]; // strip brackets + let port_str = &authority[bracket_end + 1..]; + let port = if let Some(port_str) = port_str.strip_prefix(':') { + port_str + .parse::() + .map_err(|_| miette::miette!("Invalid port in URI: {uri}"))? + } else { + match scheme.as_str() { + "https" => 443, + _ => 80, + } + }; + (host.to_string(), port) + } else if let Some((h, p)) = authority.rsplit_once(':') { + let port = p + .parse::() + .map_err(|_| miette::miette!("Invalid port in URI: {uri}"))?; + (h.to_string(), port) + } else { + let port = match scheme.as_str() { + "https" => 443, + _ => 80, + }; + (authority.to_string(), port) + }; + + if host.is_empty() { + return Err(miette::miette!("Empty host in URI: {uri}")); + } + + let path = if path.is_empty() { "/" } else { path }; + + Ok((scheme, host, port, path.to_string())) +} + +/// Rewrite an absolute-form HTTP proxy request to origin-form for upstream. +/// +/// Transforms `GET http://host:port/path HTTP/1.1` into `GET /path HTTP/1.1`, +/// strips proxy hop-by-hop headers, injects `Connection: close` and `Via`. +/// +/// Returns the rewritten request bytes (headers + any overflow body bytes). +fn rewrite_forward_request( + raw: &[u8], + used: usize, + path: &str, + secret_resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, +) -> Result, crate::secrets::UnresolvedPlaceholderError> { + let header_end = raw[..used] + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(used, |p| p + 4); + let websocket_upgrade = crate::l7::rest::request_is_websocket_upgrade(&raw[..header_end]); + let upstream_path = match secret_resolver { + Some(resolver) => crate::secrets::rewrite_target_for_eval(path, resolver)?.resolved, + None => path.to_string(), + }; + + let header_str = String::from_utf8_lossy(&raw[..header_end]); + let lines = header_str.split("\r\n").collect::>(); + + // Rebuild headers, stripping hop-by-hop and adding proxy headers + let mut output = Vec::with_capacity(header_end + 128); + let mut has_connection = false; + let mut has_via = false; + + for (i, line) in lines.iter().enumerate() { + if i == 0 { + // Rewrite request line: METHOD absolute-uri HTTP/1.1 → METHOD path HTTP/1.1 + let parts: Vec<&str> = line.splitn(3, ' ').collect(); + if parts.len() == 3 { + output.extend_from_slice(parts[0].as_bytes()); + output.push(b' '); + output.extend_from_slice(upstream_path.as_bytes()); + output.push(b' '); + output.extend_from_slice(parts[2].as_bytes()); + } else { + output.extend_from_slice(line.as_bytes()); + } + output.extend_from_slice(b"\r\n"); + continue; + } + if line.is_empty() { + // End of headers + break; + } + + let lower = line.to_ascii_lowercase(); + + // Strip proxy hop-by-hop headers + if lower.starts_with("proxy-connection:") + || lower.starts_with("proxy-authorization:") + || lower.starts_with("proxy-authenticate:") + { + continue; + } + + // Replace Connection header + if lower.starts_with("connection:") { + has_connection = true; + if websocket_upgrade { + output.extend_from_slice(line.as_bytes()); + output.extend_from_slice(b"\r\n"); + continue; + } + output.extend_from_slice(b"Connection: close\r\n"); + continue; + } + + let rewritten_line = match secret_resolver { + Some(resolver) => rewrite_header_line_checked(line, resolver)?, + None => line.to_string(), + }; + + output.extend_from_slice(rewritten_line.as_bytes()); + output.extend_from_slice(b"\r\n"); + + if lower.starts_with("via:") { + has_via = true; + } + } + + // Inject missing headers + if !has_connection && !websocket_upgrade { + output.extend_from_slice(b"Connection: close\r\n"); + } + if !has_via { + output.extend_from_slice(b"Via: 1.1 openshell-sandbox\r\n"); + } + + // End of headers + output.extend_from_slice(b"\r\n"); + let rewritten_header_end = output.len(); + + // Append any overflow body bytes from the original buffer + if header_end < used { + output.extend_from_slice(&raw[header_end..used]); + } + + // Fail-closed: scan for any remaining unresolved placeholders + if secret_resolver.is_some() { + let scan_end = if request_body_credential_rewrite { + rewritten_header_end + } else { + output.len() + }; + let output_str = String::from_utf8_lossy(&output[..scan_end]); + if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) + || output_str.contains(crate::secrets::PROVIDER_ALIAS_MARKER_PUBLIC) + { + return Err(crate::secrets::UnresolvedPlaceholderError { location: "header" }); + } + } + + Ok(output) +} + +struct ForwardRelayOptions<'a> { + generation_guard: &'a PolicyGenerationGuard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode, + secret_resolver: Option<&'a SecretResolver>, + request_body_credential_rewrite: bool, +} + +async fn relay_rewritten_forward_request( + method: &str, + path: &str, + rewritten: Vec, + client: &mut C, + upstream: &mut U, + options: ForwardRelayOptions<'_>, +) -> Result +where + C: TokioAsyncRead + TokioAsyncWrite + Unpin, + U: TokioAsyncRead + TokioAsyncWrite + Unpin, +{ + let header_end = rewritten + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(rewritten.len(), |p| p + 4); + let header_str = String::from_utf8_lossy(&rewritten[..header_end]); + let body_length = crate::l7::rest::parse_body_length(&header_str)?; + let (_, query_params) = crate::l7::rest::parse_target_query(path)?; + let req = crate::l7::provider::L7Request { + action: method.to_string(), + target: path.to_string(), + query_params, + raw_header: rewritten, + body_length, + }; + + crate::l7::rest::relay_http_request_with_options_guarded( + &req, + client, + upstream, + crate::l7::rest::RelayRequestOptions { + resolver: options.secret_resolver, + generation_guard: Some(options.generation_guard), + websocket_extensions: options.websocket_extensions, + request_body_credential_rewrite: options.request_body_credential_rewrite, + }, + ) + .await +} + +async fn inject_token_grant_for_forward_request( + method: &str, + upstream_target: &str, + forward_request_bytes: Vec, + l7_ctx: &crate::l7::relay::L7EvalContext, +) -> Result> { + let header_end = forward_request_bytes + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(forward_request_bytes.len(), |p| p + 4); + let header_str = std::str::from_utf8(&forward_request_bytes[..header_end]) + .into_diagnostic() + .map_err(|_| miette::miette!("Forward HTTP headers contain invalid UTF-8"))?; + let body_length = crate::l7::rest::parse_body_length(header_str)?; + let forward_request_for_token_grant = crate::l7::provider::L7Request { + action: method.to_string(), + target: upstream_target.to_string(), + query_params: std::collections::HashMap::new(), + raw_header: forward_request_bytes, + body_length, + }; + + crate::l7::token_grant_injection::inject_if_needed(forward_request_for_token_grant, l7_ctx) + .await + .map(|req| req.raw_header) +} + +/// Handle a plain HTTP forward proxy request (non-CONNECT). +/// +/// Public IPs are allowed through when the endpoint passes OPA evaluation. +/// Private IPs require explicit `allowed_ips` on the endpoint config (SSRF +/// override). Rewrites the absolute-form request to origin-form, connects +/// upstream, and relays the request/response using the guarded HTTP relay. +// Many distinct, non-related context parameters are required for forward proxy +// dispatch; bundling them into a struct would just shift the noise into call sites. +#[allow(clippy::too_many_arguments)] +async fn handle_forward_proxy( + method: &str, + target_uri: &str, + buf: &[u8], + used: usize, + client: &mut TcpStream, + opa_engine: Arc, + identity_cache: Arc, + entrypoint_pid: Arc, + policy_local_ctx: Option>, + trusted_host_gateway: Arc>, + secret_resolver: Option>, + dynamic_credentials: Option< + Arc< + std::sync::RwLock< + std::collections::HashMap, + >, + >, + >, + denial_tx: Option<&mpsc::UnboundedSender>, + activity_tx: Option<&ActivitySender>, +) -> Result<()> { + // 1. Parse the absolute-form URI. `path` is marked `mut` so that, when an + // L7 config applies, the canonicalized form produced below replaces it + // in-place — keeping OPA evaluation and the bytes written onto the wire + // in sync. See the L7 block below. + let (scheme, host, port, mut path) = match parse_proxy_uri(target_uri) { + Ok(parsed) => parsed, + Err(e) => { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .message(format!("FORWARD parse error for {target_uri}: {e}")) + .build(); + ocsf_emit!(event); + respond(client, b"HTTP/1.1 400 Bad Request\r\n\r\n").await?; + return Ok(()); + } + }; + let host_lc = host.to_ascii_lowercase(); + + if host_lc == POLICY_LOCAL_HOST { + if scheme != "http" || port != 80 { + respond( + client, + &build_json_error_response( + 400, + "Bad Request", + "invalid_policy_local_scheme", + "Use http://policy.local only", + ), + ) + .await?; + return Ok(()); + } + if let Some(ctx) = policy_local_ctx { + return crate::policy_local::handle_forward_request( + &ctx, + method, + &path, + &buf[..used], + client, + ) + .await; + } + respond( + client, + b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 31\r\n\r\npolicy.local is not configured", + ) + .await?; + return Ok(()); + } + + // 2. Reject HTTPS — must use CONNECT for TLS + if scheme == "https" { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Refuse) + .action(ActionId::Denied) + .disposition(DispositionId::Rejected) + .severity(SeverityId::Informational) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!( + "FORWARD rejected: HTTPS requires CONNECT for {host_lc}:{port}" + )) + .build(); + ocsf_emit!(event); + } + respond( + client, + b"HTTP/1.1 400 Bad Request\r\nContent-Length: 27\r\n\r\nUse CONNECT for HTTPS URLs", + ) + .await?; + return Ok(()); + } + + // 3. Evaluate OPA policy (same identity binding as CONNECT) + let peer_addr = client.peer_addr().into_diagnostic()?; + let _local_addr = client.local_addr().into_diagnostic()?; + + let opa_clone = opa_engine.clone(); + let cache_clone = identity_cache.clone(); + let pid_clone = entrypoint_pid.clone(); + let host_clone = host_lc.clone(); + let decision = tokio::task::spawn_blocking(move || { + evaluate_opa_tcp( + peer_addr, + &opa_clone, + &cache_clone, + &pid_clone, + &host_clone, + port, + ) + }) + .await + .map_err(|e| miette::miette!("identity resolution task panicked: {e}"))?; + + // Build log context + let binary_str = decision + .binary + .as_ref() + .map_or_else(|| "-".to_string(), |p| p.display().to_string()); + let pid_str = decision + .binary_pid + .map_or_else(|| "-".to_string(), |p| p.to_string()); + let ancestors_str = if decision.ancestors.is_empty() { + "-".to_string() + } else { + decision + .ancestors + .iter() + .map(|p| p.display().to_string()) + .collect::>() + .join(" -> ") + }; + let cmdline_str = if decision.cmdline_paths.is_empty() { + "-".to_string() + } else { + decision + .cmdline_paths + .iter() + .map(|p| p.display().to_string()) + .collect::>() + .join(", ") + }; + + // 4. Only proceed on explicit Allow — reject Deny + let matched_policy = match &decision.action { + NetworkAction::Allow { matched_policy } => matched_policy.clone(), + NetworkAction::Deny { reason } => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule("-", "opa") + .message(format!("FORWARD denied {method} {host_lc}:{port}{path}")) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + reason, + "forward", + ); + emit_activity_simple(activity_tx, true, "forward_policy"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + }; + let policy_str = matched_policy.as_deref().unwrap_or("-"); + let sandbox_entrypoint_pid = entrypoint_pid.load(Ordering::Acquire); + let forward_generation_guard = match opa_engine.generation_guard(decision.generation) { + Ok(guard) => guard, + Err(e) => { + emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); + emit_activity_simple(activity_tx, true, "policy_stale"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + }; + let mut forward_request_bytes = buf[..used].to_vec(); + let mut upstream_target = path.clone(); + let mut websocket_extensions = crate::l7::rest::WebSocketExtensionMode::Preserve; + let mut forward_tunnel_engine: Option = None; + let mut forward_upgrade_config: Option = None; + let mut forward_upgrade_target = String::new(); + let mut forward_upgrade_query_params = std::collections::HashMap::new(); + let mut forward_websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); + let mut request_body_credential_rewrite = false; + let l7_ctx = crate::l7::relay::L7EvalContext { + host: host_lc.clone(), + port, + policy_name: matched_policy.clone().unwrap_or_default(), + binary_path: decision + .binary + .as_ref() + .map(|p| p.to_string_lossy().into_owned()) + .unwrap_or_default(), + ancestors: decision + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + cmdline_paths: decision + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(), + secret_resolver: secret_resolver.clone(), + activity_tx: activity_tx.cloned(), + dynamic_credentials: dynamic_credentials.clone(), + token_grant_resolver: dynamic_credentials + .as_ref() + .map(|_| crate::l7::token_grant_injection::default_resolver()), + }; + let mut l7_activity_pending = false; + + // 4b. If the endpoint has L7 config, evaluate the request against + // L7 policy. The forward proxy handles exactly one request per + // connection (Connection: close), so a single evaluation suffices. + if let Some(route) = query_l7_route_snapshot(&opa_engine, &decision, &host_lc, port) + && !route.configs.is_empty() + { + if route.generation != forward_generation_guard.captured_generation() { + emit_l7_tunnel_close_after_policy_change( + &host_lc, + port, + miette::miette!( + "policy changed before forward L7 evaluation [expected_generation:{} current_generation:{}]", + forward_generation_guard.captured_generation(), + route.generation, + ), + ); + emit_activity_simple(activity_tx, true, "policy_stale"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + let tunnel_engine = match opa_engine.clone_engine_for_tunnel(route.generation) { + Ok(engine) => engine, + Err(e) => { + emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); + emit_activity_simple(activity_tx, true, "policy_stale"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + }; + + // Canonicalize the request-target. The canonical form is fed to OPA + // AND reassigned to the outer `path` variable so the later call to + // `rewrite_forward_request` writes canonical bytes to the upstream. + // This closes the policy/upstream parser-differential at this site; + // without this reassignment, OPA would evaluate the canonical form + // while the upstream re-normalizes the raw input and dispatches on a + // potentially different path. + let canonicalize_options = crate::l7::path::CanonicalizeOptions { + allow_encoded_slash: route + .configs + .iter() + .any(|snapshot| snapshot.config.allow_encoded_slash), + ..Default::default() + }; + let query_params = + match crate::l7::path::canonicalize_request_target(&path, &canonicalize_options) { + Ok((canon, query)) => { + upstream_target = match query.as_deref() { + Some(raw_query) if !raw_query.is_empty() => { + format!("{}?{raw_query}", canon.path) + } + _ => canon.path.clone(), + }; + let params = query + .as_deref() + .map_or_else(std::collections::HashMap::new, |q| { + crate::l7::rest::parse_query_params(q).unwrap_or_default() + }); + path = canon.path; + params + } + Err(e) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!( + "FORWARD_L7 rejecting non-canonical request-target: {e}" + )) + .build(); + ocsf_emit!(event); + emit_activity_simple(activity_tx, true, "l7_parse_rejection"); + respond( + client, + &build_json_error_response( + 400, + "Bad Request", + "invalid_request_target", + "request-target must be canonical", + ), + ) + .await?; + return Ok(()); + } + }; + let Some(l7_config) = select_l7_config_for_path(&route.configs, &path) else { + emit_activity_simple(activity_tx, true, "l7_policy"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} did not match an L7 endpoint path"), + ), + ) + .await?; + return Ok(()); + }; + forward_websocket_request = + crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); + websocket_extensions = crate::l7::relay::websocket_extension_mode(&l7_config.config); + request_body_credential_rewrite = l7_config.config.protocol == crate::l7::L7Protocol::Rest + && l7_config.config.request_body_credential_rewrite; + forward_upgrade_config = Some(l7_config.config.clone()); + forward_upgrade_target = path.clone(); + forward_upgrade_query_params = query_params.clone(); + let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { + let header_end = forward_request_bytes + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(forward_request_bytes.len(), |p| p + 4); + let header_str = std::str::from_utf8(&forward_request_bytes[..header_end]) + .map_err(|_| miette::miette!("Forward GraphQL headers contain invalid UTF-8"))?; + let body_length = crate::l7::rest::parse_body_length(header_str)?; + let mut graphql_request = crate::l7::provider::L7Request { + action: method.to_string(), + target: path.clone(), + query_params: query_params.clone(), + raw_header: forward_request_bytes, + body_length, + }; + let info = match crate::l7::graphql::inspect_graphql_request( + client, + &mut graphql_request, + l7_config.config.graphql_max_body_bytes, + ) + .await + { + Ok(info) => info, + Err(e) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("FORWARD_GRAPHQL_L7 request rejected: {e}")) + .build(); + ocsf_emit!(event); + emit_activity_simple(activity_tx, true, "l7_parse_rejection"); + respond( + client, + &build_json_error_response( + 400, + "Bad Request", + "invalid_graphql_request", + &format!("GraphQL request rejected before policy evaluation: {e}"), + ), + ) + .await?; + return Ok(()); + } + }; + forward_request_bytes = graphql_request.raw_header; + Some(info) + } else { + None + }; + let jsonrpc = if l7_config.config.protocol == crate::l7::L7Protocol::JsonRpc { + let header_end = forward_request_bytes + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(forward_request_bytes.len(), |p| p + 4); + let header_str = std::str::from_utf8(&forward_request_bytes[..header_end]) + .map_err(|_| miette::miette!("Forward JSON-RPC headers contain invalid UTF-8"))?; + let body_length = crate::l7::rest::parse_body_length(header_str)?; + let mut jsonrpc_request = crate::l7::provider::L7Request { + action: method.to_string(), + target: path.clone(), + query_params: query_params.clone(), + raw_header: forward_request_bytes, + body_length, + }; + let body = match crate::l7::http::read_body_for_inspection( + client, + &mut jsonrpc_request, + 64 * 1024, + ) + .await + { + Ok(body) => body, + Err(e) => { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("FORWARD_JSONRPC_L7 request rejected: {e}")) + .build(); + ocsf_emit!(event); + emit_activity_simple(activity_tx, true, "l7_parse_rejection"); + respond( + client, + &build_json_error_response( + 400, + "Bad Request", + "invalid_jsonrpc_request", + &format!("JSON-RPC request rejected before policy evaluation: {e}"), + ), + ) + .await?; + return Ok(()); + } + }; + forward_request_bytes = jsonrpc_request.raw_header; + Some(crate::l7::jsonrpc::parse_jsonrpc_body(&body)) + } else { + None + }; + let request_info = crate::l7::L7RequestInfo { + action: method.to_string(), + target: path.clone(), + query_params, + graphql, + jsonrpc, + }; + + let parse_error_reason = request_info + .graphql + .as_ref() + .and_then(|info| info.error.as_deref()) + .map(|error| format!("GraphQL request rejected: {error}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason) = parse_error_reason.map_or_else( + || { + crate::l7::relay::evaluate_l7_request(&tunnel_engine, &l7_ctx, &request_info) + .unwrap_or_else(|e| { + let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("L7 eval failed, denying request: {e}")) + .build(); + ocsf_emit!(event); + (false, format!("L7 evaluation error: {e}")) + }) + }, + |reason| (false, reason), + ); + + let decision_str = match (allowed, l7_config.config.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, crate::l7::EnforcementMode::Audit) => "audit", + (false, crate::l7::EnforcementMode::Enforce) => "deny", + }; + + { + let (action_id, disposition_id, severity) = match decision_str { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + "allow" | "audit" => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + _ => ( + ActionId::Other, + DispositionId::Other, + SeverityId::Informational, + ), + }; + let engine_type = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { + "l7-graphql" + } else { + "l7" + }; + let message_prefix = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { + "FORWARD_GRAPHQL_L7" + } else { + "FORWARD_L7" + }; + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, engine_type) + .message(format!( + "{message_prefix} {decision_str} {method} {host_lc}:{port}{path} reason={reason}" + )) + .build(); + ocsf_emit!(event); + } + + let effectively_denied = force_deny + || (!allowed && l7_config.config.enforcement == crate::l7::EnforcementMode::Enforce); + + if effectively_denied { + emit_activity_simple(activity_tx, true, "l7_policy"); + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "forward-l7-deny", + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} denied by L7 policy: {reason}"), + ), + ) + .await?; + return Ok(()); + } + l7_activity_pending = true; + forward_tunnel_engine = Some(tunnel_engine); + } + + // 5. DNS resolution + SSRF defence (mirrors the CONNECT path logic). + // - If the host is a driver-injected host-gateway alias: bypass SSRF + // tiers and validate only against the trusted gateway IP. + // - If allowed_ips is set: validate resolved IPs against the allowlist + // (this is the SSRF override for private IP destinations). + // - If the endpoint is an exact declared hostname: allow private IPs, + // but still reject always-blocked addresses and control-plane ports. + // - Otherwise: reject internal IPs, allow public IPs through. + // When the policy host is already a literal IP address, treat it as + // implicitly allowed — the user explicitly declared the destination. + let mut raw_allowed_ips = query_allowed_ips(&opa_engine, &decision, &host_lc, port); + if raw_allowed_ips.is_empty() { + raw_allowed_ips = implicit_allowed_ips_for_ip_host(&host); + } + let exact_declared_endpoint_host = + query_exact_declared_endpoint_host(&opa_engine, &decision, &host_lc, port); + + // The trusted-gateway branch is the first path; reading it before the + // allowed_ips and default branches matches the policy decision narrative. + #[allow(clippy::if_not_else)] + let addrs = if is_host_gateway_alias(&host_lc) + && let Some(gw) = *trusted_host_gateway + { + // Trusted host-gateway path. Mirrors the CONNECT path logic. + match resolve_and_check_trusted_gateway(&host, port, gw, sandbox_entrypoint_pid).await { + Ok(addrs) => addrs, + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: trusted-gateway check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + emit_activity_simple(activity_tx, true, "ssrf"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!("{method} {host_lc}:{port} blocked: trusted-gateway check failed"), + ), + ) + .await?; + return Ok(()); + } + } + } else if !raw_allowed_ips.is_empty() { + // allowed_ips mode: validate resolved IPs against CIDR allowlist. + match parse_allowed_ips(&raw_allowed_ips) { + Ok(nets) => { + match resolve_and_check_allowed_ips(&host, port, &nets, sandbox_entrypoint_pid) + .await + { + Ok(addrs) => addrs, + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: allowed_ips check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + emit_activity_simple(activity_tx, true, "ssrf"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!( + "{method} {host_lc}:{port} blocked: allowed_ips check failed" + ), + ), + ) + .await?; + return Ok(()); + } + } + } + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: invalid allowed_ips in policy for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + emit_activity_simple(activity_tx, true, "ssrf"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!( + "{method} {host_lc}:{port} blocked: invalid allowed_ips in policy" + ), + ), + ) + .await?; + return Ok(()); + } + } + } else if exact_declared_endpoint_host { + // Exact declared hostname mode mirrors CONNECT: private resolved + // addresses are allowed for this operator-declared host:port, while + // always-blocked addresses and control-plane ports remain denied. + match resolve_and_check_declared_endpoint(&host, port, sandbox_entrypoint_pid).await { + Ok(addrs) => addrs, + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: declared endpoint check failed for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!( + "{method} {host_lc}:{port} blocked: declared endpoint check failed" + ), + ), + ) + .await?; + return Ok(()); + } + } + } else { + // No allowed_ips: reject internal IPs, allow public IPs through. + match resolve_and_reject_internal(&host, port, sandbox_entrypoint_pid).await { + Ok(addrs) => addrs, + Err(reason) => { + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Denied) + .disposition(DispositionId::Blocked) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "ssrf") + .message(format!( + "FORWARD blocked: internal IP without allowed_ips for {host_lc}:{port}" + )) + .status_detail(&reason) + .build(); + ocsf_emit!(event); + } + emit_denial_simple( + denial_tx, + &host_lc, + port, + &binary_str, + &decision, + &reason, + "ssrf", + ); + emit_activity_simple(activity_tx, true, "ssrf"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "ssrf_denied", + &format!("{method} {host_lc}:{port} blocked: internal address"), + ), + ) + .await?; + return Ok(()); + } + } + }; + + if let Err(e) = forward_generation_guard.ensure_current() { + emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); + emit_activity_simple(activity_tx, true, "policy_stale"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + + // 6. Connect upstream + let mut upstream = match TcpStream::connect(addrs.as_slice()).await { + Ok(s) => s, + Err(e) => { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Low) + .status(StatusId::Failure) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .message(format!( + "FORWARD upstream connect failed for {host_lc}:{port}: {e}" + )) + .build(); + ocsf_emit!(event); + respond( + client, + &build_json_error_response( + 502, + "Bad Gateway", + "upstream_unreachable", + &format!("connection to {host_lc}:{port} failed"), + ), + ) + .await?; + return Ok(()); + } + }; + + // Log success + { + let event = HttpActivityBuilder::new(crate::ocsf_ctx()) + .activity(ActivityId::Other) + .action(ActionId::Allowed) + .disposition(DispositionId::Allowed) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .http_request(HttpRequest::new( + method, + OcsfUrl::new("http", &host_lc, &path, port), + )) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) + .actor_process( + Process::from_bypass(&binary_str, &pid_str, &ancestors_str) + .with_cmd_line(&cmdline_str), + ) + .firewall_rule(policy_str, "opa") + .message(format!("FORWARD allowed {method} {host_lc}:{port}{path}")) + .build(); + ocsf_emit!(event); + } + emit_forward_success_activity(activity_tx, l7_activity_pending); + + forward_request_bytes = match inject_token_grant_for_forward_request( + method, + &upstream_target, + forward_request_bytes, + &l7_ctx, + ) + .await + { + Ok(bytes) => bytes, + Err(e) => { + warn!( + dst_host = %host_lc, + dst_port = port, + error = %e, + "token grant failed in forward proxy" + ); + respond( + client, + &build_json_error_response( + 502, + "Bad Gateway", + "token_grant_failed", + "dynamic token grant failed", + ), + ) + .await?; + return Ok(()); + } + }; + + // 9. Rewrite request and forward to upstream + let rewritten = match rewrite_forward_request( + &forward_request_bytes, + forward_request_bytes.len(), + &upstream_target, + secret_resolver.as_deref(), + request_body_credential_rewrite, + ) { + Ok(bytes) => bytes, + Err(e) => { + warn!( + dst_host = %host_lc, + dst_port = port, + error = %e, + "credential injection failed in forward proxy" + ); + respond( + client, + &build_json_error_response( + 500, + "Internal Server Error", + "credential_injection_failed", + "unresolved credential placeholder in request", + ), + ) + .await?; + return Ok(()); + } + }; + if let Err(e) = forward_generation_guard.ensure_current() { + emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + let outcome = relay_rewritten_forward_request( + method, + &path, + rewritten, + client, + &mut upstream, + ForwardRelayOptions { + generation_guard: &forward_generation_guard, + websocket_extensions, + secret_resolver: secret_resolver.as_deref(), + request_body_credential_rewrite, + }, + ) + .await?; + if let crate::l7::provider::RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + { + let mut upgrade_options = if let (Some(config), Some(engine)) = ( + forward_upgrade_config.as_ref(), + forward_tunnel_engine.as_ref(), + ) { + crate::l7::relay::upgrade_options( + config, + &l7_ctx, + forward_websocket_request, + &forward_upgrade_target, + &forward_upgrade_query_params, + Some(engine), + ) + } else { + crate::l7::relay::UpgradeRelayOptions { + websocket_request: forward_websocket_request, + ..Default::default() + } + }; + upgrade_options.websocket.permessage_deflate = websocket_permessage_deflate; + crate::l7::relay::handle_upgrade( + client, + &mut upstream, + overflow, + &host_lc, + port, + upgrade_options, + ) + .await?; + } + + Ok(()) +} + +fn parse_target(target: &str) -> Result<(String, u16)> { + let (host, port_str) = target + .split_once(':') + .ok_or_else(|| miette::miette!("CONNECT target missing port: {target}"))?; + let port: u16 = port_str + .parse() + .map_err(|_| miette::miette!("Invalid port in CONNECT target: {target}"))?; + Ok((host.to_string(), port)) +} + +async fn respond(client: &mut TcpStream, bytes: &[u8]) -> Result<()> { + client.write_all(bytes).await.into_diagnostic()?; + Ok(()) +} + +/// Build an HTTP error response with a JSON body. +/// +/// Returns bytes ready to write to the client socket. The body is a JSON +/// object with `error` and `detail` fields, matching the format used by the +/// L7 deny path in `l7/rest.rs`. +fn build_json_error_response(status: u16, status_text: &str, error: &str, detail: &str) -> Vec { + let body = serde_json::json!({ + "error": error, + "detail": detail, + }); + let body_str = body.to_string(); + format!( + "HTTP/1.1 {status} {status_text}\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\ + Connection: close\r\n\ + \r\n\ + {}", + body_str.len(), + body_str, + ) + .into_bytes() +} + +/// Check if a miette error represents a benign connection close. +/// +/// TLS handshake EOF, missing `close_notify`, connection resets, and broken +/// pipes are all normal lifecycle events for proxied connections — not worth +/// a WARN that interrupts the user's terminal. +fn is_benign_relay_error(err: &miette::Report) -> bool { + const BENIGN: &[&str] = &[ + "close_notify", + "tls handshake eof", + "connection reset", + "broken pipe", + "unexpected eof", + ]; + let msg = err.to_string().to_ascii_lowercase(); + BENIGN.iter().any(|pat| msg.contains(pat)) +} + +#[cfg(test)] +#[allow( + clippy::needless_raw_string_hashes, + clippy::iter_on_single_items, + clippy::needless_continue, + reason = "Test code: test fixtures and explicit control-flow markers are idiomatic in tests." +)] +mod tests { + use super::*; + use std::future::Future; + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use std::sync::Arc; + use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; + use tokio::net::{TcpListener, TcpStream}; + + fn websocket_l7_config( + protocol: crate::l7::L7Protocol, + websocket_credential_rewrite: bool, + ) -> crate::l7::L7EndpointConfig { + crate::l7::L7EndpointConfig { + protocol, + path: "/**".to_string(), + tls: crate::l7::TlsMode::Auto, + enforcement: crate::l7::EnforcementMode::Enforce, + graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + allow_encoded_slash: false, + websocket_credential_rewrite, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + } + } + + #[test] + fn connect_activity_is_skipped_when_l7_will_count_the_request() { + let (tx, mut rx) = mpsc::channel(4); + let activity_tx = Some(tx); + let l7_route = L7RouteSnapshot { + configs: vec![L7ConfigSnapshot { + config: websocket_l7_config(crate::l7::L7Protocol::Rest, false), + }], + generation: 1, + }; + let l4_route = L7RouteSnapshot { + configs: Vec::new(), + generation: 1, + }; + + emit_connect_activity_if_l4_only(&activity_tx, Some(&l7_route)); + assert!( + rx.try_recv().is_err(), + "L7-inspected CONNECT should not emit an extra L4 activity event" + ); + + emit_connect_activity_if_l4_only(&activity_tx, Some(&l4_route)); + let event = rx.try_recv().expect("L4-only CONNECT should emit activity"); + assert!(!event.denied); + assert_eq!(event.deny_group, "unknown"); + + emit_connect_activity_if_l4_only(&activity_tx, None); + let event = rx + .try_recv() + .expect("CONNECT without an L7 route should emit activity"); + assert!(!event.denied); + assert_eq!(event.deny_group, "unknown"); + } + + #[test] + fn forward_l7_allowed_activity_is_deferred_until_after_ssrf() { + let (tx, mut rx) = mpsc::channel(4); + let activity_tx = Some(tx); + + let l7_activity_pending = true; + assert!( + rx.try_recv().is_err(), + "allowed L7 evaluation must not emit activity before SSRF succeeds" + ); + + emit_activity_simple(activity_tx.as_ref(), true, "ssrf"); + let event = rx + .try_recv() + .expect("SSRF denial should emit the request activity"); + assert!(event.denied); + assert_eq!(event.deny_group, "ssrf"); + assert!( + rx.try_recv().is_err(), + "SSRF-denied forward request must not also emit allowed L7 activity" + ); + + emit_forward_success_activity(activity_tx.as_ref(), l7_activity_pending); + let event = rx + .try_recv() + .expect("L7 activity should emit after SSRF succeeds"); + assert!(!event.denied); + assert_eq!(event.deny_group, "l7_policy"); + } + + #[test] + fn forward_success_activity_uses_unknown_without_l7() { + let (tx, mut rx) = mpsc::channel(4); + let activity_tx = Some(tx); + + emit_forward_success_activity(activity_tx.as_ref(), false); + let event = rx + .try_recv() + .expect("non-L7 forward success should emit activity"); + assert!(!event.denied); + assert_eq!(event.deny_group, "unknown"); + } + + fn forward_test_guard() -> PolicyGenerationGuard { + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + engine + .generation_guard(engine.current_generation()) + .unwrap() + } + + async fn relay_forward_request_and_capture( + method: &str, + path: &str, + raw: &[u8], + resolver: Option<&SecretResolver>, + request_body_credential_rewrite: bool, + ) -> Result { + let guard = forward_test_guard(); + let rewritten = rewrite_forward_request( + raw, + raw.len(), + path, + resolver, + request_body_credential_rewrite, + ) + .map_err(|e| miette::miette!("{e}"))?; + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let upstream_task = tokio::spawn(async move { + let mut buf = vec![0u8; 8192]; + let mut total = 0usize; + let mut expected_total = None; + loop { + let n = upstream_side.read(&mut buf[total..]).await.unwrap(); + if n == 0 { + break; + } + total += n; + if expected_total.is_none() + && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") + { + let header_end = end + 4; + let headers = String::from_utf8_lossy(&buf[..header_end]); + let len = headers + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().ok()) + .flatten() + }) + .unwrap_or(0); + expected_total = Some(header_end + len); + } + if expected_total.is_some_and(|expected| total >= expected) { + break; + } + } + upstream_side + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + upstream_side.flush().await.unwrap(); + String::from_utf8_lossy(&buf[..total]).to_string() + }); + + relay_rewritten_forward_request( + method, + path, + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: resolver, + request_body_credential_rewrite, + }, + ) + .await?; + + upstream_task + .await + .map_err(|e| miette::miette!("upstream task failed: {e}")) + } + + fn forward_token_grant_context( + resolver_response: std::result::Result<&str, &str>, + ) -> ( + crate::l7::relay::L7EvalContext, + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture, + ) { + let provider_key = "api.example.test\t8080\t/v1/**\tprovider:access_token"; + let fixture = match resolver_response { + Ok(token) => { + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::success( + provider_key, + token, + ) + } + Err(error) => { + crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::failure( + provider_key, + error, + ) + } + }; + let ctx = crate::l7::relay::L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: Some(fixture.dynamic_credentials()), + token_grant_resolver: Some(fixture.resolver()), + }; + + (ctx, fixture) + } + + fn authorization_header_count(headers: &str) -> usize { + headers + .lines() + .filter(|line| { + line.split_once(':') + .is_some_and(|(name, _)| name.eq_ignore_ascii_case("authorization")) + }) + .count() + } + + fn forward_websocket_policy_parts( + data: &str, + host: &str, + port: u16, + path: &str, + policy_name: &str, + ) -> ( + crate::l7::L7EndpointConfig, + crate::opa::TunnelPolicyEngine, + crate::l7::relay::L7EvalContext, + ) { + let policy = include_str!("../data/sandbox-policy.rego"); + let engine = OpaEngine::from_strings(policy, data).unwrap(); + let decision = ConnectDecision { + action: NetworkAction::Allow { + matched_policy: Some(policy_name.to_string()), + }, + generation: engine.current_generation(), + binary: Some(PathBuf::from("/usr/bin/node")), + binary_pid: None, + ancestors: vec![], + cmdline_paths: vec![], + }; + let route = + query_l7_route_snapshot(&engine, &decision, host, port).expect("L7 route should match"); + let config = select_l7_config_for_path(&route.configs, path) + .expect("path-specific L7 config should match") + .config + .clone(); + let tunnel_engine = engine + .clone_engine_for_tunnel(route.generation) + .expect("tunnel engine"); + let ctx = crate::l7::relay::L7EvalContext { + host: host.to_string(), + port, + policy_name: policy_name.to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + (config, tunnel_engine, ctx) + } + + async fn read_http_headers(reader: &mut R) -> Vec { + let mut bytes = Vec::new(); + let mut chunk = [0u8; 256]; + loop { + let n = + tokio::time::timeout(std::time::Duration::from_secs(1), reader.read(&mut chunk)) + .await + .expect("HTTP headers should arrive") + .expect("header read should succeed"); + assert!(n > 0, "stream closed before HTTP headers"); + bytes.extend_from_slice(&chunk[..n]); + if bytes.windows(4).any(|w| w == b"\r\n\r\n") { + return bytes; + } + } + } + + fn masked_text_frame(payload: &[u8]) -> Vec { + let mask = [0x11, 0x22, 0x33, 0x44]; + assert!( + payload.len() <= 125, + "test helper only supports small frames" + ); + let payload_len = u8::try_from(payload.len()).expect("small frame length"); + let mut frame = vec![0x81, 0x80 | payload_len]; + frame.extend_from_slice(&mask); + frame.extend( + payload + .iter() + .enumerate() + .map(|(idx, byte)| byte ^ mask[idx % 4]), + ); + frame + } + + async fn forward_websocket_denied_after_upgrade( + config: crate::l7::L7EndpointConfig, + tunnel_engine: crate::opa::TunnelPolicyEngine, + ctx: crate::l7::relay::L7EvalContext, + path: &str, + payload: &str, + ) -> (miette::Report, Vec) { + let host = ctx.host.clone(); + let port = ctx.port; + let raw = format!( + "GET http://{host}{path} HTTP/1.1\r\n\ + Host: {host}\r\n\ + Upgrade: websocket\r\n\ + Connection: Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n" + ); + let rewritten = rewrite_forward_request(raw.as_bytes(), raw.len(), path, None, false) + .expect("forward websocket request should rewrite to origin form"); + let websocket_extensions = crate::l7::relay::websocket_extension_mode(&config); + let target = path.to_string(); + let query_params = std::collections::HashMap::new(); + let (mut proxy_to_upstream, mut upstream) = tokio::io::duplex(8192); + let (mut app, mut proxy_to_client) = tokio::io::duplex(8192); + + let relay = tokio::spawn(async move { + let guard = tunnel_engine.generation_guard(); + let outcome = relay_rewritten_forward_request( + "GET", + &target, + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: guard, + websocket_extensions, + secret_resolver: None, + request_body_credential_rewrite: false, + }, + ) + .await?; + if let crate::l7::provider::RelayOutcome::Upgraded { + overflow, + websocket_permessage_deflate, + } = outcome + { + let mut options = crate::l7::relay::upgrade_options( + &config, + &ctx, + true, + &target, + &query_params, + Some(&tunnel_engine), + ); + options.websocket.permessage_deflate = websocket_permessage_deflate; + crate::l7::relay::handle_upgrade( + &mut proxy_to_client, + &mut proxy_to_upstream, + overflow, + &host, + port, + options, + ) + .await?; + } + Ok::<(), miette::Report>(()) + }); + + let forwarded_headers = read_http_headers(&mut upstream).await; + let forwarded_headers = String::from_utf8_lossy(&forwarded_headers); + assert!(forwarded_headers.starts_with(&format!("GET {path} HTTP/1.1\r\n"))); + assert!(forwarded_headers.contains("Upgrade: websocket\r\n")); + + upstream + .write_all( + b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", + ) + .await + .unwrap(); + + let response = read_http_headers(&mut app).await; + assert!(String::from_utf8_lossy(&response).contains("101 Switching Protocols")); + + app.write_all(&masked_text_frame(payload.as_bytes())) + .await + .unwrap(); + + let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("websocket relay should fail closed after denied frame") + .expect("relay task should not panic") + .expect_err("denied websocket frame should fail the forward relay"); + + let mut leaked = Vec::new(); + tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read_to_end(&mut leaked), + ) + .await + .expect("upstream side should close") + .expect("upstream read should succeed"); + (err, leaked) + } + + #[test] + fn forward_websocket_upgrade_options_enable_native_policy_context() { + let (_, resolver) = SecretResolver::from_provider_env( + [("DISCORD_BOT_TOKEN".to_string(), "discord-real".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.map(Arc::new); + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let ctx = crate::l7::relay::L7EvalContext { + host: "gateway.example.test".to_string(), + port: 80, + policy_name: "ws_api".to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: resolver, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let query_params = std::collections::HashMap::new(); + + let extensions = crate::l7::relay::websocket_extension_mode(&websocket_l7_config( + crate::l7::L7Protocol::Websocket, + true, + )); + let options = crate::l7::relay::upgrade_options( + &websocket_l7_config(crate::l7::L7Protocol::Websocket, true), + &ctx, + true, + "/ws", + &query_params, + Some(&tunnel_engine), + ); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::PermessageDeflate + ); + assert!(options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_some()); + assert!(options.engine.is_some()); + assert!(options.ctx.is_some()); + assert!(matches!( + options.websocket.message_policy, + crate::l7::relay::WebSocketMessagePolicy::Transport + )); + } + + #[test] + fn forward_websocket_upgrade_options_preserve_rest_without_rewrite() { + let ctx = crate::l7::relay::L7EvalContext { + host: "gateway.example.test".to_string(), + port: 80, + policy_name: "rest_api".to_string(), + binary_path: "/usr/bin/node".to_string(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let query_params = std::collections::HashMap::new(); + let config = websocket_l7_config(crate::l7::L7Protocol::Rest, false); + let extensions = crate::l7::relay::websocket_extension_mode(&config); + let options = + crate::l7::relay::upgrade_options(&config, &ctx, true, "/ws", &query_params, None); + + assert_eq!( + extensions, + crate::l7::rest::WebSocketExtensionMode::Preserve + ); + assert!(!options.websocket.credential_rewrite); + assert!(options.secret_resolver.is_none()); + assert!(options.engine.is_none()); + assert!(options.ctx.is_none()); + assert!(matches!( + options.websocket.message_policy, + crate::l7::relay::WebSocketMessagePolicy::None + )); + } + + #[tokio::test] + async fn forward_websocket_upgrade_blocks_text_frame_by_policy() { + let data = r#" +network_policies: + ws_api: + name: ws_api + endpoints: + - host: gateway.example.test + port: 80 + path: "/ws" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + - allow: + method: WEBSOCKET_TEXT + path: "/ws" + deny_rules: + - method: WEBSOCKET_TEXT + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let (config, tunnel_engine, ctx) = + forward_websocket_policy_parts(data, "gateway.example.test", 80, "/ws", "ws_api"); + + let (err, leaked) = forward_websocket_denied_after_upgrade( + config, + tunnel_engine, + ctx, + "/ws", + r#"{"type":"unsafe"}"#, + ) + .await; + + assert!(err.to_string().contains("websocket text message denied")); + assert!( + leaked.is_empty(), + "denied forward-proxy WebSocket text frames must not reach upstream" + ); + } + + #[tokio::test] + async fn forward_graphql_websocket_upgrade_blocks_unallowed_operation() { + let data = r#" +network_policies: + graphql_ws: + name: graphql_ws + endpoints: + - host: gateway.example.test + port: 80 + path: "/graphql" + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/graphql" + - allow: + operation_type: query + fields: [viewer] + deny_rules: + - operation_type: query + fields: [admin] + binaries: + - { path: /usr/bin/node } +"#; + let (config, tunnel_engine, ctx) = forward_websocket_policy_parts( + data, + "gateway.example.test", + 80, + "/graphql", + "graphql_ws", + ); + assert!( + config.websocket_graphql_policy, + "operation rules should enable GraphQL-over-WebSocket inspection" + ); + + let (err, leaked) = forward_websocket_denied_after_upgrade( + config, + tunnel_engine, + ctx, + "/graphql", + r#"{"id":"1","type":"subscribe","payload":{"query":"query { admin }"}}"#, + ) + .await; + + assert!(err.to_string().contains("websocket GraphQL message denied")); + assert!( + leaked.is_empty(), + "denied forward-proxy GraphQL WebSocket operations must not reach upstream" + ); + } + + #[test] + fn l7_route_selection_prefers_path_specific_graphql_endpoint() { + let configs = vec![ + L7ConfigSnapshot { + config: crate::l7::L7EndpointConfig { + protocol: crate::l7::L7Protocol::Rest, + path: "/**".to_string(), + tls: crate::l7::TlsMode::Auto, + enforcement: crate::l7::EnforcementMode::Enforce, + graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + allow_encoded_slash: false, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }, + }, + L7ConfigSnapshot { + config: crate::l7::L7EndpointConfig { + protocol: crate::l7::L7Protocol::Graphql, + path: "/graphql".to_string(), + tls: crate::l7::TlsMode::Auto, + enforcement: crate::l7::EnforcementMode::Enforce, + graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + allow_encoded_slash: false, + websocket_credential_rewrite: false, + request_body_credential_rewrite: false, + websocket_graphql_policy: false, + }, + }, + ]; + + let selected = + select_l7_config_for_path(&configs, "/graphql").expect("expected path-specific route"); + assert_eq!(selected.config.protocol, crate::l7::L7Protocol::Graphql); + + let selected = + select_l7_config_for_path(&configs, "/repos/org/repo").expect("expected REST route"); + assert_eq!(selected.config.protocol, crate::l7::L7Protocol::Rest); + } + + // -- is_internal_ip: IPv4 -- + + #[test] + fn test_rejects_ipv4_loopback() { + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))); + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)))); + } + + #[test] + fn test_rejects_ipv4_private_10() { + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))); + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255)))); + } + + #[test] + fn test_rejects_ipv4_private_172_16() { + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)))); + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255)))); + } + + #[test] + fn test_rejects_ipv4_private_192_168() { + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)))); + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new( + 192, 168, 255, 255 + )))); + } + + #[test] + fn test_rejects_ipv4_link_local_metadata() { + // Cloud metadata endpoint + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 169, 254 + )))); + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(169, 254, 0, 1)))); + } + + #[test] + fn test_rejects_ipv4_unspecified() { + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED))); + } + + #[test] + fn test_rejects_ipv4_cgnat() { + // 100.64.0.0/10 — CGNAT / shared address space (RFC 6598) + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 1)))); + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 100, 50, 3)))); + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new( + 100, 127, 255, 255 + )))); + // Just outside the /10 boundary + assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 128, 0, 1)))); + assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new( + 100, 63, 255, 255 + )))); + } + + #[test] + fn test_rejects_ipv4_special_use_ranges() { + // 192.0.0.0/24 — IETF protocol assignments + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(192, 0, 0, 1)))); + // 198.18.0.0/15 — benchmarking + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(198, 18, 0, 1)))); + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(198, 19, 255, 255)))); + // 198.51.100.0/24 — TEST-NET-2 + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 1)))); + // 203.0.113.0/24 — TEST-NET-3 + assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 1)))); + } + + #[test] + fn test_rejects_ipv6_mapped_cgnat() { + // ::ffff:100.64.0.1 should be caught via IPv4-mapped unwrapping + let v6 = Ipv4Addr::new(100, 64, 0, 1).to_ipv6_mapped(); + assert!(is_internal_ip(IpAddr::V6(v6))); + } + + #[test] + fn test_allows_ipv4_public() { + assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); + assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)))); + assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)))); + } + + #[test] + fn test_allows_ipv4_non_private_172() { + // 172.32.0.0 is outside the 172.16/12 private range + assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(172, 32, 0, 1)))); + } + + // -- is_internal_ip: IPv6 -- + + #[test] + fn test_rejects_ipv6_loopback() { + assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))); + } + + #[test] + fn test_rejects_ipv6_unspecified() { + assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED))); + } + + #[test] + fn test_rejects_ipv6_link_local() { + // fe80::1 + assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::new( + 0xfe80, 0, 0, 0, 0, 0, 0, 1 + )))); + } + + #[test] + fn test_rejects_ipv6_unique_local_address() { + // fdc4:f303:9324::254 + assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::new( + 0xfdc4, 0xf303, 0x9324, 0, 0, 0, 0, 0x0254 + )))); + } + + #[test] + fn test_rejects_ipv4_mapped_ipv6_private() { + // ::ffff:10.0.0.1 + let v6 = Ipv4Addr::new(10, 0, 0, 1).to_ipv6_mapped(); + assert!(is_internal_ip(IpAddr::V6(v6))); + } + + #[test] + fn test_rejects_ipv4_mapped_ipv6_loopback() { + // ::ffff:127.0.0.1 + let v6 = Ipv4Addr::LOCALHOST.to_ipv6_mapped(); + assert!(is_internal_ip(IpAddr::V6(v6))); + } + + #[test] + fn test_rejects_ipv4_mapped_ipv6_link_local() { + // ::ffff:169.254.169.254 + let v6 = Ipv4Addr::new(169, 254, 169, 254).to_ipv6_mapped(); + assert!(is_internal_ip(IpAddr::V6(v6))); + } + + #[test] + fn test_allows_ipv6_public() { + // 2001:4860:4860::8888 (Google DNS) + assert!(!is_internal_ip(IpAddr::V6(Ipv6Addr::new( + 0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888 + )))); + } + + #[test] + fn test_allows_ipv4_mapped_ipv6_public() { + // ::ffff:8.8.8.8 + let v6 = Ipv4Addr::new(8, 8, 8, 8).to_ipv6_mapped(); + assert!(!is_internal_ip(IpAddr::V6(v6))); + } + + // -- resolve_and_reject_internal -- + + #[test] + fn test_parse_hosts_file_for_host_handles_comments_invalid_rows_and_case() { + let contents = r#" + # comment + 192.168.1.105 searxng.local searxng + bad-ip ignored.local + 93.184.216.34 Example.Local # trailing comment + ::1 loopback.local + 192.168.1.105 searxng.local + "#; + + let result = parse_hosts_file_for_host(contents, "SEARXNG.LOCAL"); + assert_eq!(result, vec![IpAddr::V4(Ipv4Addr::new(192, 168, 1, 105))]); + + let public = parse_hosts_file_for_host(contents, "example.local"); + assert_eq!(public, vec![IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34))]); + } + + #[test] + fn test_resolve_from_hosts_file_contents_requires_exact_alias_match() { + let contents = "192.168.1.105 searxng.local\n"; + + assert!( + resolve_from_hosts_file_contents(contents, "searxng", 8080).is_empty(), + "partial alias match should not resolve" + ); + + let result = resolve_from_hosts_file_contents(contents, "searxng.local", 8080); + assert_eq!( + result, + vec![SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 105)), + 8080 + )] + ); + } + + #[test] + fn test_resolve_from_hosts_file_contents_public_ip_passes_default_ssrf_check() { + let addrs = + resolve_from_hosts_file_contents("93.184.216.34 example.local\n", "example.local", 80); + assert!(reject_internal_resolved_addrs("example.local", &addrs).is_ok()); + } + + #[test] + fn test_resolve_from_hosts_file_contents_private_ip_requires_allowed_ips() { + let addrs = resolve_from_hosts_file_contents( + "192.168.1.105 searxng.local\n", + "searxng.local", + 8080, + ); + + let err = reject_internal_resolved_addrs("searxng.local", &addrs).unwrap_err(); + assert!( + err.contains("internal address"), + "expected private hosts-file resolution to remain blocked: {err}" + ); + + let nets = parse_allowed_ips(&["192.168.1.105/32".to_string()]).unwrap(); + assert!( + validate_allowed_ips_for_resolved_addrs("searxng.local", 8080, &addrs, &nets).is_ok() + ); + } + + #[test] + fn test_declared_endpoint_private_hosts_file_resolution_allowed() { + let addrs = resolve_from_hosts_file_contents( + "192.168.1.105 searxng.local\n", + "searxng.local", + 8080, + ); + + assert!(validate_declared_endpoint_resolved_addrs("searxng.local", 8080, &addrs).is_ok()); + } + + #[test] + fn test_declared_endpoint_loopback_stays_blocked() { + let addrs = + resolve_from_hosts_file_contents("127.0.0.1 loopback.local\n", "loopback.local", 80); + + let err = + validate_declared_endpoint_resolved_addrs("loopback.local", 80, &addrs).unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected loopback to stay blocked: {err}" + ); + } + + #[test] + fn test_declared_endpoint_link_local_stays_blocked() { + let addrs = resolve_from_hosts_file_contents( + "169.254.169.254 metadata.local\n", + "metadata.local", + 80, + ); + + let err = + validate_declared_endpoint_resolved_addrs("metadata.local", 80, &addrs).unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected link-local to stay blocked: {err}" + ); + } + + #[test] + fn test_declared_endpoint_blocks_control_plane_ports() { + let addrs = + resolve_from_hosts_file_contents("10.0.0.5 kube-api.local\n", "kube-api.local", 6443); + + let err = + validate_declared_endpoint_resolved_addrs("kube-api.local", 6443, &addrs).unwrap_err(); + assert!( + err.contains("blocked control-plane port"), + "expected control-plane port to stay blocked: {err}" + ); + } + + #[test] + fn test_resolve_from_hosts_file_contents_always_blocked_ip_stays_blocked() { + let addrs = + resolve_from_hosts_file_contents("127.0.0.1 loopback.local\n", "loopback.local", 80); + let nets = vec!["127.0.0.0/8".parse::().unwrap()]; + let err = validate_allowed_ips_for_resolved_addrs("loopback.local", 80, &addrs, &nets) + .unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected always-blocked hosts-file resolution to stay blocked: {err}" + ); + } + + #[test] + fn test_resolve_from_hosts_file_contents_returns_empty_without_match() { + let result = + resolve_from_hosts_file_contents("192.168.1.105 searxng.local\n", "missing.local", 80); + assert!(result.is_empty()); + } + + // -- is_host_gateway_alias -- + + #[test] + fn test_is_host_gateway_alias_recognises_known_aliases() { + assert!(is_host_gateway_alias("host.openshell.internal")); + assert!(is_host_gateway_alias("host.containers.internal")); + assert!(is_host_gateway_alias("host.docker.internal")); + } + + #[test] + fn test_is_host_gateway_alias_is_case_insensitive() { + assert!(is_host_gateway_alias("HOST.OPENSHELL.INTERNAL")); + assert!(is_host_gateway_alias("Host.Containers.Internal")); + assert!(is_host_gateway_alias("HOST.DOCKER.INTERNAL")); + } + + #[test] + fn test_is_host_gateway_alias_rejects_unknown_hosts() { + assert!(!is_host_gateway_alias("api.example.com")); + assert!(!is_host_gateway_alias("host.openshell.internal.evil.com")); + assert!(!is_host_gateway_alias("evil.host.openshell.internal")); + assert!(!is_host_gateway_alias("openshell.internal")); + assert!(!is_host_gateway_alias("")); + } + + // -- is_cloud_metadata_ip -- + + #[test] + fn test_is_cloud_metadata_ip_blocks_known_metadata_ip() { + assert!(is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 169, 254 + )))); + } + + #[test] + fn test_is_cloud_metadata_ip_allows_other_link_local() { + // The pasta gateway address on this test host — not a metadata IP. + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 1, 2 + )))); + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 0, 1 + )))); + } + + #[test] + fn test_is_cloud_metadata_ip_allows_private_and_public() { + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 10, 0, 0, 1 + )))); + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( + 192, 168, 1, 1 + )))); + assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); + } + + #[test] + fn test_is_cloud_metadata_ip_blocks_ipv4_mapped_metadata() { + // ::ffff:169.254.169.254 is the IPv4-mapped IPv6 representation of the + // AWS/GCP/Azure IMDS endpoint. is_link_local_ip() recognizes it as + // link-local, so is_cloud_metadata_ip() must also catch it — otherwise + // the trusted-gateway exemption would be granted to the metadata service. + let mapped = Ipv4Addr::new(169, 254, 169, 254).to_ipv6_mapped(); + assert!( + is_cloud_metadata_ip(IpAddr::V6(mapped)), + "::ffff:169.254.169.254 must be recognized as cloud metadata" + ); + } + + #[test] + fn test_is_cloud_metadata_ip_allows_other_ipv4_mapped_link_local() { + // Other IPv4-mapped link-local addresses are NOT metadata. + let mapped = Ipv4Addr::new(169, 254, 1, 2).to_ipv6_mapped(); + assert!( + !is_cloud_metadata_ip(IpAddr::V6(mapped)), + "::ffff:169.254.1.2 should not be flagged as cloud metadata" + ); + } + + // -- detect_trusted_host_gateway -- + + #[test] + fn test_detect_trusted_host_gateway_returns_ip_from_hosts_content() { + // We test the underlying parser directly since detect_trusted_host_gateway + // reads the real /etc/hosts. The production code composes these same primitives. + let contents = "169.254.1.2\thost.openshell.internal host.containers.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + #[test] + fn test_detect_trusted_host_gateway_ignores_cloud_metadata_ip() { + // Simulate a /etc/hosts where the driver injected the cloud metadata IP — + // this should be caught and suppressed. + let contents = "169.254.169.254\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))]); + // is_cloud_metadata_ip should flag it, preventing the exemption. + assert!(is_cloud_metadata_ip(ips[0])); + } + + #[test] + fn test_detect_trusted_host_gateway_no_entry_returns_empty() { + let contents = "127.0.0.1 localhost\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert!(ips.is_empty()); + } + + #[test] + fn test_detect_trusted_host_gateway_rejects_loopback() { + // Loopback is not link-local — must not receive the SSRF exemption. + let ip = IpAddr::V4(Ipv4Addr::LOCALHOST); + assert!(!is_cloud_metadata_ip(ip)); + assert!(!is_link_local_ip(ip)); + // The guard: !link-local → reject. + assert!(!is_link_local_ip(ip)); + } + + #[test] + fn test_detect_trusted_host_gateway_rejects_unspecified() { + // Unspecified (0.0.0.0) is not link-local — must not be trusted. + let ip = IpAddr::V4(Ipv4Addr::UNSPECIFIED); + assert!(!is_cloud_metadata_ip(ip)); + assert!(!is_link_local_ip(ip)); + assert!(!is_link_local_ip(ip)); + } + + #[test] + fn test_detect_trusted_host_gateway_rejects_loopback_v6() { + let ip = IpAddr::V6(Ipv6Addr::LOCALHOST); + assert!(!is_cloud_metadata_ip(ip)); + assert!(!is_link_local_ip(ip)); + } + + #[test] + fn test_detect_trusted_host_gateway_rejects_private_ip() { + // Docker bridge (172.17.0.1) and K8s host gateway (192.168.x.x) are + // RFC 1918 private addresses — not link-local. Before this fix they + // slipped through the old always-blocked guard and received the SSRF + // exemption. The new guard (!is_link_local_ip) rejects them, so + // connections to these hosts fall through to resolve_and_reject_internal(). + for ip in [ + IpAddr::V4(Ipv4Addr::new(172, 17, 0, 1)), + IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), + IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), + ] { + assert!(!is_cloud_metadata_ip(ip), "{ip} should not be metadata"); + assert!(!is_link_local_ip(ip), "{ip} should not be link-local"); + // Guard fires — exemption disabled. + assert!(!is_link_local_ip(ip), "{ip}: guard must reject"); + } + } + + #[test] + fn test_detect_trusted_host_gateway_allows_link_local_non_metadata() { + // 169.254.1.2 (rootless Podman pasta gateway) IS link-local and is + // not a cloud metadata IP — it is the only address class the exemption + // is designed for. + let ip = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + assert!(!is_cloud_metadata_ip(ip)); + assert!(is_link_local_ip(ip)); + // Guard does NOT fire — this IP is eligible for the exemption. + assert!(is_link_local_ip(ip)); + } + + // -- parse_hosts_file_for_host: multi-entry / duplicate scenarios -- + + #[test] + fn test_parse_hosts_file_single_entry() { + // Normal driver-injected case: exactly one IP for the alias. + let contents = "169.254.1.2\thost.openshell.internal host.containers.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + #[test] + fn test_parse_hosts_file_duplicate_same_ip_deduplicated() { + // Same IP on two separate lines for the same alias — deduplicated to one. + let contents = "169.254.1.2\thost.openshell.internal\n\ + 169.254.1.2\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!( + ips, + vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))], + "identical IPs across lines must be deduplicated" + ); + } + + #[test] + fn test_parse_hosts_file_multiple_distinct_ips() { + // Two distinct IPs for the same alias — both returned, first entry wins + // in detect_trusted_host_gateway(), second would cause mismatch rejection + // in resolve_and_check_trusted_gateway(). + let contents = "169.254.1.2\thost.openshell.internal\n\ + 169.254.1.3\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips.len(), 2, "two distinct IPs must both be returned"); + assert_eq!(ips[0], IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))); + assert_eq!(ips[1], IpAddr::V4(Ipv4Addr::new(169, 254, 1, 3))); + } + + #[test] + fn test_parse_hosts_file_first_entry_wins_on_ambiguity() { + // detect_trusted_host_gateway() pins to the first entry via .next(). + // Verify the ordering guarantee: first line wins. + let contents = "169.254.1.3\thost.openshell.internal\n\ + 169.254.1.2\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!( + ips[0], + IpAddr::V4(Ipv4Addr::new(169, 254, 1, 3)), + "first line must be first in the returned vec" + ); + } + + #[test] + fn test_parse_hosts_file_ignores_other_aliases_on_same_line() { + // An entry with multiple aliases — only the matching alias counts. + let contents = + "169.254.1.2\thost.containers.internal host.openshell.internal host.docker.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + // Non-matching aliases on the same line do not produce extra entries. + let ips2 = parse_hosts_file_for_host(contents, "host.docker.internal"); + assert_eq!(ips2, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + #[test] + fn test_parse_hosts_file_alias_not_present() { + let contents = "127.0.0.1\tlocalhost\n\ + ::1\t\tlocalhost\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert!(ips.is_empty()); + } + + #[test] + fn test_parse_hosts_file_comment_lines_skipped() { + let contents = "# 169.254.1.2 host.openshell.internal\n\ + 169.254.1.2\thost.openshell.internal\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + // Commented-out line must not produce an entry. + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + #[test] + fn test_parse_hosts_file_inline_comment_stripped() { + // Anything after '#' on a data line is treated as a comment. + let contents = "169.254.1.2\thost.openshell.internal # injected by driver\n"; + let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); + assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); + } + + // -- resolve_and_check_trusted_gateway -- + + #[tokio::test] + async fn test_trusted_gateway_allows_link_local_gateway_ip() { + // Simulate the rootless Podman pasta case: host.openshell.internal + // points to a link-local address which is the only path to the host. + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + + // We resolve via /etc/hosts (pid=0 falls back to system), so we + // exercise the trusted_gw mismatch / cloud-metadata guards directly + // against a known resolved address. + let addrs = [SocketAddr::new(trusted_gw, 8080)]; + + // Validate the guard logic inline (mirrors resolve_and_check_trusted_gateway). + assert!(!is_cloud_metadata_ip(trusted_gw)); + assert_eq!(addrs[0].ip(), trusted_gw); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_cloud_metadata_ip() { + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + let metadata_ip = IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)); + + // Simulate resolution returning the metadata IP. + let addrs = [SocketAddr::new(metadata_ip, 80)]; + + // Cloud metadata check must fire before the trusted_gw equality check. + let err: Result<(), String> = if is_cloud_metadata_ip(addrs[0].ip()) { + Err(format!( + "host resolves to cloud metadata address {}, connection rejected", + addrs[0].ip() + )) + } else if addrs[0].ip() != trusted_gw { + Err(format!( + "host resolves to {} which does not match trusted host gateway \ + {trusted_gw}, connection rejected", + addrs[0].ip() + )) + } else { + Ok(()) + }; + + assert!(err.is_err()); + assert!( + err.unwrap_err().contains("cloud metadata"), + "expected cloud-metadata rejection" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_mismatched_ip() { + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + let other_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)); + + let addrs = [SocketAddr::new(other_ip, 8080)]; + + let err: Result<(), String> = if is_cloud_metadata_ip(addrs[0].ip()) { + Err("cloud metadata".to_string()) + } else if addrs[0].ip() != trusted_gw { + Err(format!( + "{} does not match trusted host gateway {trusted_gw}", + addrs[0].ip() + )) + } else { + Ok(()) + }; + + assert!(err.is_err()); + assert!( + err.unwrap_err() + .contains("does not match trusted host gateway"), + "expected mismatch rejection" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_control_plane_port() { + // Control-plane port check runs before resolution. + let result = resolve_and_check_trusted_gateway( + "host.openshell.internal", + 6443, + IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)), + 0, + ) + .await; + assert!(result.is_err()); + assert!( + result.unwrap_err().contains("blocked control-plane port"), + "expected control-plane port rejection" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_all_control_plane_ports() { + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + for &port in BLOCKED_CONTROL_PLANE_PORTS { + let result = + resolve_and_check_trusted_gateway("host.openshell.internal", port, trusted_gw, 0) + .await; + assert!( + result.is_err(), + "port {port} should be blocked by control-plane guard" + ); + assert!( + result.unwrap_err().contains("blocked control-plane port"), + "expected control-plane rejection for port {port}" + ); + } + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_loopback_as_trusted_gw() { + // Defense-in-depth: even if detect_trusted_host_gateway somehow admitted + // a loopback IP, resolve_and_check_trusted_gateway must reject it. + // Using an IP literal as the host bypasses DNS and gives a deterministic + // resolved address, allowing us to exercise the actual function. + let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST); + let result = resolve_and_check_trusted_gateway("127.0.0.1", 8080, loopback, 0).await; + assert!(result.is_err(), "loopback must be rejected"); + let err = result.unwrap_err(); + assert!( + err.contains("non-link-local"), + "expected non-link-local rejection, got: {err}" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_unspecified_as_trusted_gw() { + // Defense-in-depth: 0.0.0.0 as trusted_gw must be rejected. + // IP literal resolves to 0.0.0.0 directly, bypassing DNS. + let unspecified = IpAddr::V4(Ipv4Addr::UNSPECIFIED); + let result = resolve_and_check_trusted_gateway("0.0.0.0", 8080, unspecified, 0).await; + assert!(result.is_err(), "unspecified must be rejected"); + let err = result.unwrap_err(); + assert!( + err.contains("non-link-local"), + "expected non-link-local rejection, got: {err}" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_ip_literal_mismatch() { + // If the requested IP literal doesn't match trusted_gw, the mismatch + // guard fires. This exercises the full resolution→validation path. + let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); + let other_ip = "10.0.0.1"; // RFC1918, resolves as a literal + let result = resolve_and_check_trusted_gateway(other_ip, 8080, trusted_gw, 0).await; + assert!(result.is_err(), "IP mismatch must be rejected"); + let err = result.unwrap_err(); + assert!( + err.contains("does not match trusted host gateway"), + "expected mismatch rejection, got: {err}" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_cloud_metadata_literal() { + // Cloud metadata IP as a literal address — must be rejected even when + // it matches trusted_gw (which detect_trusted_host_gateway prevents, + // but this is the defense-in-depth layer). + let metadata = IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)); + let result = resolve_and_check_trusted_gateway("169.254.169.254", 80, metadata, 0).await; + assert!(result.is_err(), "cloud metadata IP must be rejected"); + let err = result.unwrap_err(); + assert!( + err.contains("cloud metadata"), + "expected cloud-metadata rejection, got: {err}" + ); + } + + #[tokio::test] + async fn test_trusted_gateway_rejects_private_ip_as_trusted_gw() { + // Defense-in-depth: a private RFC 1918 IP (e.g. Docker bridge 172.17.0.1) + // must be rejected even if it somehow matched trusted_gw. + // detect_trusted_host_gateway() already blocks these via !is_link_local_ip(), + // but resolve_and_check_trusted_gateway() must enforce the same invariant. + let docker_bridge = IpAddr::V4(Ipv4Addr::new(172, 17, 0, 1)); + let result = resolve_and_check_trusted_gateway("172.17.0.1", 8080, docker_bridge, 0).await; + assert!(result.is_err(), "private RFC 1918 IP must be rejected"); + let err = result.unwrap_err(); + assert!( + err.contains("non-link-local"), + "expected non-link-local rejection for private IP, got: {err}" + ); + } + + #[tokio::test] + async fn test_rejects_localhost_resolution() { + let result = resolve_and_reject_internal("localhost", 80, 0).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("internal address"), + "expected 'internal address' in error: {err}" + ); + } + + #[tokio::test] + async fn test_rejects_loopback_ip_literal() { + let result = resolve_and_reject_internal("127.0.0.1", 443, 0).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("internal address"), + "expected 'internal address' in error: {err}" + ); + } + + #[tokio::test] + async fn test_rejects_metadata_ip() { + let result = resolve_and_reject_internal("169.254.169.254", 80, 0).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("internal address"), + "expected 'internal address' in error: {err}" + ); + } + + #[tokio::test] + async fn test_dns_failure_returns_error() { + let result = resolve_and_reject_internal("this-host-does-not-exist.invalid", 80, 0).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("DNS resolution failed"), + "expected 'DNS resolution failed' in error: {err}" + ); + } + + #[tokio::test] + async fn inference_interception_applies_router_header_allowlist() { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_task = tokio::spawn(async move { + use crate::l7::inference::{ParseResult, try_parse_http_request}; + + let (mut upstream, _) = listener.accept().await.unwrap(); + let mut buf = Vec::new(); + let mut chunk = [0u8; 4096]; + + loop { + let n = upstream.read(&mut chunk).await.unwrap(); + assert!(n > 0, "upstream request closed before request completed"); + buf.extend_from_slice(&chunk[..n]); + + match try_parse_http_request(&buf) { + ParseResult::Complete(_, consumed) => { + upstream + .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") + .await + .unwrap(); + return String::from_utf8_lossy(&buf[..consumed]).to_string(); + } + ParseResult::Incomplete => continue, + ParseResult::Invalid(reason) => { + panic!("forwarded request should parse cleanly: {reason}"); + } + } + } + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint: format!("http://{upstream_addr}"), + model: "meta/llama-3.1-8b-instruct".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["openai_chat_completions".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![ + "openai-organization".to_string(), + "x-model-id".to_string(), + ], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + model_in_path: false, + request_path_override: None, + }], + vec![], + ); + + let body = r#"{"model":"ignored","messages":[{"role":"user","content":"hi"}]}"#; + let request = format!( + "POST /v1/chat/completions HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Type: application/json\r\n\ + OpenAI-Organization: org_123\r\n\ + Authorization: Bearer client-key\r\n\ + Cookie: session=abc\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body, + ); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response_text = String::from_utf8_lossy(&response); + assert!(response_text.starts_with("HTTP/1.1 200")); + + let outcome = server_task.await.unwrap().unwrap(); + assert!( + matches!(outcome, InferenceOutcome::Routed), + "expected Routed outcome, got: {outcome:?}" + ); + + let forwarded = upstream_task.await.unwrap(); + let forwarded_lc = forwarded.to_ascii_lowercase(); + assert!(forwarded_lc.contains("openai-organization: org_123")); + assert!(forwarded_lc.contains("authorization: bearer test-api-key")); + assert!(!forwarded_lc.contains("authorization: bearer client-key")); + assert!(!forwarded_lc.contains("cookie:")); + } + + fn streaming_inference_route(endpoint: String) -> openshell_router::config::ResolvedRoute { + openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint, + model: "meta/llama-3.1-8b-instruct".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["openai_chat_completions".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + model_in_path: false, + request_path_override: None, + } + } + + fn embeddings_inference_route(endpoint: String) -> openshell_router::config::ResolvedRoute { + openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint, + model: "text-embedding-3-small".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["openai_embeddings".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + model_in_path: false, + request_path_override: None, + } + } + + /// Embeddings responses are a single buffered JSON object, not an SSE + /// stream. They must be framed with `Content-Length` and must never be sent + /// through the chunked streaming path, whose truncation handlers would + /// append an SSE `proxy_stream_error` frame into the JSON body. + #[tokio::test] + async fn inference_embeddings_served_buffered_with_content_length() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = r#"{"object":"list","data":[{"object":"embedding","index":0,"embedding":[0.1,0.2]}],"model":"text-embedding-3-small"}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + read_forwarded_inference_request(&mut upstream).await; + // Buffered upstream response with Content-Length (no chunked TE). + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![embeddings_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + let body = r#"{"model":"text-embedding-3-small","input":"hello"}"#; + let request = format!( + "POST /v1/embeddings HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body, + ); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + + server_task.await.unwrap().unwrap(); + upstream_task.await.unwrap(); + + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n"), + "expected buffered 200 response, got: {response}" + ); + let lower = response.to_ascii_lowercase(); + assert!( + lower.contains("content-length:"), + "embeddings response must be Content-Length framed, got: {response}" + ); + assert!( + !lower.contains("transfer-encoding: chunked"), + "embeddings response must NOT be chunked, got: {response}" + ); + assert!( + !response.contains("proxy_stream_error"), + "embeddings response must not carry an SSE error frame, got: {response}" + ); + assert!( + response.contains(r#""object":"list""#), + "embeddings JSON body must be forwarded intact, got: {response}" + ); + } + + fn model_discovery_inference_route( + endpoint: String, + ) -> openshell_router::config::ResolvedRoute { + openshell_router::config::ResolvedRoute { + name: "inference.local".to_string(), + endpoint, + model: "text-embedding-3-small".to_string(), + api_key: "test-api-key".to_string(), + protocols: vec!["model_discovery".to_string()], + auth: openshell_router::config::AuthHeader::Bearer, + default_headers: vec![], + passthrough_headers: vec![], + timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, + model_in_path: false, + request_path_override: None, + } + } + + /// `GET /v1/models` (model discovery) returns one JSON object — a model + /// list — exactly like embeddings. It must be served buffered with + /// `Content-Length`, never through the chunked streaming path whose + /// truncation handlers would append an SSE `proxy_stream_error` frame into + /// the JSON body. This guards the framing classification for the protocol. + #[tokio::test] + async fn inference_model_discovery_served_buffered_with_content_length() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = + r#"{"object":"list","data":[{"id":"text-embedding-3-small","object":"model"}]}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + read_forwarded_inference_request(&mut upstream).await; + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + // GET model discovery carries no request body. + let request = "GET /v1/models HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + + server_task.await.unwrap().unwrap(); + upstream_task.await.unwrap(); + + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n"), + "expected buffered 200 response, got: {response}" + ); + let lower = response.to_ascii_lowercase(); + assert!( + lower.contains("content-length:"), + "model discovery response must be Content-Length framed, got: {response}" + ); + assert!( + !lower.contains("transfer-encoding: chunked"), + "model discovery response must NOT be chunked, got: {response}" + ); + assert!( + !response.contains("proxy_stream_error"), + "model discovery response must not carry an SSE error frame, got: {response}" + ); + assert!( + response.contains(r#""object":"list""#), + "model discovery JSON body must be forwarded intact, got: {response}" + ); + } + + /// `GET /v1/models/{id}` (model discovery glob) must forward the model id in + /// the path through the buffered path with the id intact, never streamed. + #[tokio::test] + async fn inference_model_discovery_glob_path_served_buffered() { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_body = r#"{"id":"gpt-4.1","object":"model"}"#; + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + let forwarded = read_forwarded_request_line(&mut upstream).await; + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + upstream_body.len(), + upstream_body, + ); + upstream.write_all(resp.as_bytes()).await.unwrap(); + forwarded + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{upstream_addr}" + ))], + vec![], + ); + + let request = "GET /v1/models/gpt-4.1 HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + server_task.await.unwrap().unwrap(); + let (method, forwarded_path) = upstream_task.await.unwrap(); + + assert_eq!(method, "GET"); + assert_eq!( + forwarded_path, "/v1/models/gpt-4.1", + "the model id in the glob path must be forwarded intact" + ); + let lower = response.to_ascii_lowercase(); + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n") + && lower.contains("content-length:") + && !lower.contains("transfer-encoding: chunked") + && !response.contains("proxy_stream_error"), + "glob model discovery must be buffered and Content-Length framed, got: {response}" + ); + } + + /// A failed model-discovery upstream must produce a buffered, Content-Length + /// framed JSON error, never a chunked SSE `proxy_stream_error` frame. + #[tokio::test] + async fn inference_model_discovery_error_served_buffered() { + // A port with no listener so the upstream connection is refused. + let dead_addr = { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + drop(listener); + addr + }; + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![model_discovery_inference_route(format!( + "http://{dead_addr}" + ))], + vec![], + ); + + let request = "GET /v1/models HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Length: 0\r\n\r\n" + .to_string(); + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + let response = String::from_utf8(response).unwrap(); + server_task.await.unwrap().unwrap(); + + let lower = response.to_ascii_lowercase(); + assert!( + response.starts_with("HTTP/1.1 5"), + "a refused upstream should yield a 5xx, got: {response}" + ); + assert!( + lower.contains("content-length:") + && !lower.contains("transfer-encoding: chunked") + && !response.contains("proxy_stream_error"), + "buffered model-discovery error must be Content-Length framed JSON, got: {response}" + ); + assert!( + response.contains("error"), + "error response should carry a JSON error body, got: {response}" + ); + } + + async fn read_forwarded_inference_request(stream: &mut S) { + use crate::l7::inference::{ParseResult, try_parse_http_request}; + + let mut buf = Vec::new(); + let mut chunk = [0u8; 4096]; + loop { + let n = stream.read(&mut chunk).await.unwrap(); + assert!(n > 0, "upstream request closed before completion"); + buf.extend_from_slice(&chunk[..n]); + + match try_parse_http_request(&buf) { + ParseResult::Complete(_, _) => return, + ParseResult::Incomplete => continue, + ParseResult::Invalid(reason) => { + panic!("forwarded request should parse cleanly: {reason}"); + } + } + } + } + + /// Like [`read_forwarded_inference_request`] but returns the forwarded + /// request line (method, path) so a test can assert the upstream URL path. + async fn read_forwarded_request_line(stream: &mut S) -> (String, String) { + use crate::l7::inference::{ParseResult, try_parse_http_request}; + + let mut buf = Vec::new(); + let mut chunk = [0u8; 4096]; + loop { + let n = stream.read(&mut chunk).await.unwrap(); + assert!(n > 0, "upstream request closed before completion"); + buf.extend_from_slice(&chunk[..n]); + + match try_parse_http_request(&buf) { + ParseResult::Complete(req, _) => return (req.method, req.path), + ParseResult::Incomplete => continue, + ParseResult::Invalid(reason) => { + panic!("forwarded request should parse cleanly: {reason}"); + } + } + } + } + + async fn run_live_streaming_inference(serve_upstream: F) -> String + where + F: FnOnce(TcpStream) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let upstream_addr = listener.local_addr().unwrap(); + let upstream_task = tokio::spawn(async move { + let (mut upstream, _) = listener.accept().await.unwrap(); + read_forwarded_inference_request(&mut upstream).await; + serve_upstream(upstream).await; + }); + + let router = openshell_router::Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + let ctx = InferenceContext::new( + patterns, + router, + vec![streaming_inference_route(format!("http://{upstream_addr}"))], + vec![], + ); + + let body = r#"{"model":"ignored","messages":[{"role":"user","content":"hi"}]}"#; + let request = format!( + "POST /v1/chat/completions HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Type: application/json\r\n\ + Accept: text/event-stream\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body, + ); + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + client_write.write_all(request.as_bytes()).await.unwrap(); + client_write.shutdown().await.unwrap(); + + let mut response = Vec::new(); + client_read.read_to_end(&mut response).await.unwrap(); + + let outcome = server_task.await.unwrap().unwrap(); + assert!( + matches!(outcome, InferenceOutcome::Routed), + "expected Routed outcome, got: {outcome:?}" + ); + upstream_task.await.unwrap(); + + String::from_utf8(response).unwrap() + } + + fn assert_streaming_sse_error(response: &str, message: &str) { + assert!( + response.starts_with("HTTP/1.1 200 OK\r\n"), + "expected successful streaming response, got: {response}" + ); + assert!( + response + .to_ascii_lowercase() + .contains("transfer-encoding: chunked"), + "expected chunked streaming response, got: {response}" + ); + assert!( + response.contains("\"type\":\"proxy_stream_error\""), + "expected proxy_stream_error SSE event, got: {response}" + ); + assert!( + response.contains(&format!("\"message\":\"{message}\"")), + "expected SSE message {message:?}, got: {response}" + ); + assert!( + response.ends_with("0\r\n\r\n"), + "streaming response must end with chunked terminator, got: {response}" + ); + } + + #[tokio::test] + async fn inference_stream_byte_limit_injects_sse_error() { + let response = run_live_streaming_inference(|mut upstream| async move { + use crate::l7::inference::{format_chunk, format_chunk_terminator}; + + upstream + .write_all( + b"HTTP/1.1 200 OK\r\n\ + Content-Type: text/event-stream\r\n\ + Transfer-Encoding: chunked\r\n\r\n", + ) + .await + .unwrap(); + let body = vec![b'a'; MAX_STREAMING_BODY + 1]; + let _ = upstream.write_all(&format_chunk(&body)).await; + let _ = upstream.write_all(format_chunk_terminator()).await; + }) + .await; + + assert_streaming_sse_error( + &response, + "response truncated: exceeded maximum streaming body size", + ); + } + + #[tokio::test] + async fn inference_stream_upstream_read_error_injects_sse_error() { + let response = run_live_streaming_inference(|mut upstream| async move { + upstream + .write_all( + b"HTTP/1.1 200 OK\r\n\ + Content-Type: text/event-stream\r\n\ + Content-Length: 64\r\n\r\n\ + partial", + ) + .await + .unwrap(); + }) + .await; + + assert!( + response.contains("partial"), + "expected initial upstream bytes before truncation, got: {response}" + ); + assert_streaming_sse_error(&response, "response truncated: upstream read error"); + } + + #[tokio::test] + async fn inference_stream_idle_timeout_injects_sse_error() { + let response = run_live_streaming_inference(|mut upstream| async move { + upstream + .write_all( + b"HTTP/1.1 200 OK\r\n\ + Content-Type: text/event-stream\r\n\ + Transfer-Encoding: chunked\r\n\r\n", + ) + .await + .unwrap(); + tokio::time::sleep(CHUNK_IDLE_TIMEOUT + std::time::Duration::from_millis(50)).await; + }) + .await; + + assert_streaming_sse_error(&response, "response truncated: chunk idle timeout exceeded"); + } + + // -- router_error_to_http -- + + #[test] + fn router_error_route_not_found_maps_to_400() { + let err = openshell_router::RouterError::RouteNotFound("local".into()); + let (status, msg) = router_error_to_http(&err); + assert_eq!(status, 400); + assert_eq!(msg, "no inference route configured"); + // SEC-008: must NOT leak the route hint to sandboxed code + assert!(!msg.contains("local")); + } + + #[test] + fn router_error_no_compatible_route_maps_to_400() { + let err = openshell_router::RouterError::NoCompatibleRoute("anthropic_messages".into()); + let (status, msg) = router_error_to_http(&err); + assert_eq!(status, 400); + assert_eq!(msg, "no compatible inference route available"); + // SEC-008: must NOT leak the protocol name to sandboxed code + assert!(!msg.contains("anthropic_messages")); + } + + #[test] + fn router_error_unauthorized_maps_to_401() { + let err = + openshell_router::RouterError::Unauthorized("bad token from 10.0.0.5:8080".into()); + let (status, msg) = router_error_to_http(&err); + assert_eq!(status, 401); + assert_eq!(msg, "unauthorized"); + // SEC-008: must NOT leak upstream details to sandboxed code + assert!(!msg.contains("10.0.0.5")); + } + + #[test] + fn router_error_upstream_unavailable_maps_to_503() { + let err = openshell_router::RouterError::UpstreamUnavailable( + "connection refused to 10.0.0.5:8080".into(), + ); + let (status, msg) = router_error_to_http(&err); + assert_eq!(status, 503); + assert_eq!(msg, "inference service unavailable"); + // SEC-008: must NOT leak upstream address to sandboxed code + assert!(!msg.contains("10.0.0.5")); + } + + #[test] + fn router_error_upstream_protocol_maps_to_502() { + let err = openshell_router::RouterError::UpstreamProtocol( + "TLS handshake failed for nim.internal.svc:443".into(), + ); + let (status, msg) = router_error_to_http(&err); + assert_eq!(status, 502); + assert_eq!(msg, "inference service error"); + // SEC-008: must NOT leak internal hostnames to sandboxed code + assert!(!msg.contains("nim.internal")); + } + + #[test] + fn router_error_internal_maps_to_502() { + let err = openshell_router::RouterError::Internal( + "failed to read /etc/openshell/routes.json".into(), + ); + let (status, msg) = router_error_to_http(&err); + assert_eq!(status, 502); + assert_eq!(msg, "inference service error"); + // SEC-008: must NOT leak file paths to sandboxed code + assert!(!msg.contains("/etc/openshell")); + } + + #[test] + fn sanitize_response_headers_strips_hop_by_hop() { + let headers = vec![ + ("transfer-encoding".to_string(), "chunked".to_string()), + ("content-length".to_string(), "128".to_string()), + ("connection".to_string(), "keep-alive".to_string()), + ("content-type".to_string(), "text/event-stream".to_string()), + ("cache-control".to_string(), "no-cache".to_string()), + ]; + + let kept = sanitize_inference_response_headers(headers); + + assert!( + kept.iter() + .all(|(k, _)| !k.eq_ignore_ascii_case("transfer-encoding")), + "transfer-encoding should be stripped" + ); + assert!( + kept.iter() + .all(|(k, _)| !k.eq_ignore_ascii_case("content-length")), + "content-length should be stripped" + ); + assert!( + kept.iter() + .all(|(k, _)| !k.eq_ignore_ascii_case("connection")), + "connection should be stripped" + ); + assert!( + kept.iter() + .any(|(k, _)| k.eq_ignore_ascii_case("content-type")), + "content-type should be preserved" + ); + assert!( + kept.iter() + .any(|(k, _)| k.eq_ignore_ascii_case("cache-control")), + "cache-control should be preserved" + ); + } + + // -- is_always_blocked_ip -- + + #[test] + fn test_always_blocked_loopback_v4() { + assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))); + assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( + 127, 0, 0, 2 + )))); + } + + #[test] + fn test_always_blocked_link_local_v4() { + assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 169, 254 + )))); + assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( + 169, 254, 0, 1 + )))); + } + + #[test] + fn test_always_blocked_loopback_v6() { + assert!(is_always_blocked_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))); + } + + #[test] + fn test_always_blocked_link_local_v6() { + assert!(is_always_blocked_ip(IpAddr::V6(Ipv6Addr::new( + 0xfe80, 0, 0, 0, 0, 0, 0, 1 + )))); + } + + #[test] + fn test_always_blocked_ipv4_unspecified() { + assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED))); + } + + #[test] + fn test_always_blocked_ipv6_unspecified() { + assert!(is_always_blocked_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED))); + } + + #[test] + fn test_always_blocked_ipv4_mapped_v6_loopback() { + let v6 = Ipv4Addr::LOCALHOST.to_ipv6_mapped(); + assert!(is_always_blocked_ip(IpAddr::V6(v6))); + } + + #[test] + fn test_always_blocked_ipv4_mapped_v6_link_local() { + let v6 = Ipv4Addr::new(169, 254, 169, 254).to_ipv6_mapped(); + assert!(is_always_blocked_ip(IpAddr::V6(v6))); + } + + #[test] + fn test_always_blocked_allows_rfc1918() { + // RFC 1918 addresses should NOT be always-blocked (they're allowed + // when allowed_ips is configured) + assert!(!is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( + 10, 0, 0, 1 + )))); + assert!(!is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( + 172, 16, 0, 1 + )))); + assert!(!is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( + 192, 168, 0, 1 + )))); + } + + #[test] + fn test_always_blocked_allows_public() { + assert!(!is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); + assert!(!is_always_blocked_ip(IpAddr::V6(Ipv6Addr::new( + 0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888 + )))); + } + + // -- parse_allowed_ips -- + + #[test] + fn test_parse_cidr_notation() { + let raw = vec!["10.0.5.0/24".to_string()]; + let nets = parse_allowed_ips(&raw).unwrap(); + assert_eq!(nets.len(), 1); + assert!(nets[0].contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 5, 1)))); + assert!(!nets[0].contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 6, 1)))); + } + + #[test] + fn test_parse_exact_ip() { + let raw = vec!["10.0.5.20".to_string()]; + let nets = parse_allowed_ips(&raw).unwrap(); + assert_eq!(nets.len(), 1); + assert!(nets[0].contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 5, 20)))); + assert!(!nets[0].contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 5, 21)))); + } + + #[test] + fn test_parse_multiple_entries() { + let raw = vec![ + "10.0.0.0/8".to_string(), + "172.16.0.0/12".to_string(), + "192.168.1.1".to_string(), + ]; + let nets = parse_allowed_ips(&raw).unwrap(); + assert_eq!(nets.len(), 3); + } + + #[test] + fn test_parse_invalid_entry_errors() { + let raw = vec!["not-an-ip".to_string()]; + let result = parse_allowed_ips(&raw); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid CIDR/IP")); + } + + #[test] + fn test_parse_mixed_valid_invalid_errors() { + let raw = vec!["10.0.5.0/24".to_string(), "garbage".to_string()]; + let result = parse_allowed_ips(&raw); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_resolve_check_allowed_ips_blocks_loopback() { + // Construct nets directly (parse_allowed_ips now rejects always-blocked). + let nets = vec!["127.0.0.0/8".parse::().unwrap()]; + let result = resolve_and_check_allowed_ips("127.0.0.1", 80, &nets, 0).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected 'always-blocked' in error: {err}" + ); + } + + #[tokio::test] + async fn test_resolve_check_allowed_ips_blocks_metadata() { + // Construct nets directly (parse_allowed_ips now rejects always-blocked). + let nets = vec!["169.254.0.0/16".parse::().unwrap()]; + let result = resolve_and_check_allowed_ips("169.254.169.254", 80, &nets, 0).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected 'always-blocked' in error: {err}" + ); + } + + #[tokio::test] + async fn test_resolve_check_allowed_ips_blocks_unspecified() { + // Construct nets directly (parse_allowed_ips now rejects always-blocked). + let nets = vec!["0.0.0.0/0".parse::().unwrap()]; + let result = resolve_and_check_allowed_ips("0.0.0.0", 80, &nets, 0).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected 'always-blocked' in error: {err}" + ); + } + + #[tokio::test] + async fn test_resolve_check_allowed_ips_rejects_outside_allowlist() { + // 8.8.8.8 resolves to a public IP which is NOT in 10.0.0.0/8 + let nets = parse_allowed_ips(&["10.0.0.0/8".to_string()]).unwrap(); + let result = resolve_and_check_allowed_ips("dns.google", 443, &nets, 0).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("not in allowed_ips"), + "expected 'not in allowed_ips' in error: {err}" + ); + } + + // --- SEC-005: CIDR breadth warning and control-plane port blocklist --- + + #[tokio::test] + async fn test_resolve_check_allowed_ips_blocks_control_plane_ports() { + // Use a public CIDR (parse_allowed_ips now rejects 0.0.0.0/0). + let nets = parse_allowed_ips(&["8.8.8.0/24".to_string()]).unwrap(); + // K8s API server port + let result = resolve_and_check_allowed_ips("8.8.8.8", 6443, &nets, 0).await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("blocked control-plane port")); + + // etcd client port + let result = resolve_and_check_allowed_ips("8.8.8.8", 2379, &nets, 0).await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("blocked control-plane port")); + + // kubelet API port + let result = resolve_and_check_allowed_ips("8.8.8.8", 10250, &nets, 0).await; + assert!(result.is_err()); + assert!(result.unwrap_err().contains("blocked control-plane port")); + } + + #[tokio::test] + async fn test_resolve_check_allowed_ips_allows_non_control_plane_ports() { + // Port 443 should not be blocked by the control-plane port list + let nets = parse_allowed_ips(&["8.8.8.0/24".to_string()]).unwrap(); + let result = resolve_and_check_allowed_ips("8.8.8.8", 443, &nets, 0).await; + assert!(result.is_ok()); + } + + #[test] + fn test_parse_allowed_ips_broad_cidr_is_accepted() { + // Broad CIDRs are accepted (just warned about) -- design trade-off + let result = parse_allowed_ips(&["10.0.0.0/8".to_string()]); + assert!(result.is_ok()); + } + + // --- parse_allowed_ips: always-blocked rejection tests --- + + #[test] + fn test_parse_allowed_ips_rejects_loopback_cidr() { + let result = parse_allowed_ips(&["127.0.0.0/8".to_string()]); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("always-blocked")); + } + + #[test] + fn test_parse_allowed_ips_rejects_link_local_cidr() { + let result = parse_allowed_ips(&["169.254.0.0/16".to_string()]); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("always-blocked")); + } + + #[test] + fn test_parse_allowed_ips_rejects_unspecified() { + let result = parse_allowed_ips(&["0.0.0.0".to_string()]); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("always-blocked")); + } + + #[test] + fn test_parse_allowed_ips_rejects_single_loopback_ip() { + let result = parse_allowed_ips(&["127.0.0.1".to_string()]); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("always-blocked")); + } + + #[test] + fn test_parse_allowed_ips_rejects_single_metadata_ip() { + let result = parse_allowed_ips(&["169.254.169.254".to_string()]); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("always-blocked")); + } + + #[test] + fn test_parse_allowed_ips_rejects_wildcard_cidr() { + let result = parse_allowed_ips(&["0.0.0.0/0".to_string()]); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("always-blocked")); + } + + #[test] + fn test_parse_allowed_ips_mixed_valid_and_blocked() { + // A blocked entry taints the whole batch. + let result = parse_allowed_ips(&["10.0.5.0/24".to_string(), "127.0.0.1".to_string()]); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("always-blocked")); + } + + #[test] + fn test_parse_allowed_ips_accepts_rfc1918() { + let result = parse_allowed_ips(&["10.0.5.0/24".to_string(), "192.168.1.0/24".to_string()]); + assert!(result.is_ok()); + } + + // --- implicit_allowed_ips_for_ip_host: always-blocked skip tests --- + + #[test] + fn test_implicit_allowed_ips_skips_loopback() { + let result = implicit_allowed_ips_for_ip_host("127.0.0.1"); + assert!(result.is_empty()); + } + + #[test] + fn test_implicit_allowed_ips_skips_link_local() { + let result = implicit_allowed_ips_for_ip_host("169.254.169.254"); + assert!(result.is_empty()); + } + + #[test] + fn test_implicit_allowed_ips_skips_unspecified() { + let result = implicit_allowed_ips_for_ip_host("0.0.0.0"); + assert!(result.is_empty()); + } + + #[test] + fn test_implicit_allowed_ips_allows_rfc1918() { + let result = implicit_allowed_ips_for_ip_host("10.0.5.20"); + assert_eq!(result, vec!["10.0.5.20"]); + } + + // --- extract_host_from_uri tests --- + + #[test] + fn test_extract_host_from_http_uri() { + assert_eq!( + extract_host_from_uri("http://example.com/path"), + "example.com" + ); + } + + #[test] + fn test_extract_host_from_https_uri() { + assert_eq!( + extract_host_from_uri("https://api.openai.com/v1/chat/completions"), + "api.openai.com" + ); + } + + #[test] + fn test_extract_host_from_uri_with_port() { + assert_eq!( + extract_host_from_uri("http://example.com:8080/path"), + "example.com" + ); + } + + #[test] + fn test_extract_host_from_uri_ipv6() { + assert_eq!(extract_host_from_uri("http://[::1]:8080/path"), "[::1]"); + } + + #[test] + fn test_extract_host_from_uri_no_path() { + assert_eq!(extract_host_from_uri("http://example.com"), "example.com"); + } + + #[test] + fn test_extract_host_from_uri_empty() { + assert_eq!(extract_host_from_uri(""), "unknown"); + } + + #[test] + fn test_extract_host_from_uri_malformed() { + // Gracefully handles garbage input + let result = extract_host_from_uri("not-a-uri"); + assert!(!result.is_empty()); + } + + // --- parse_proxy_uri tests --- + + #[test] + fn test_parse_proxy_uri_standard() { + let (scheme, host, port, path) = + parse_proxy_uri("http://10.86.8.223:8000/screenshot/").unwrap(); + assert_eq!(scheme, "http"); + assert_eq!(host, "10.86.8.223"); + assert_eq!(port, 8000); + assert_eq!(path, "/screenshot/"); + } + + #[test] + fn test_parse_proxy_uri_default_port() { + let (scheme, host, port, path) = parse_proxy_uri("http://example.com/path").unwrap(); + assert_eq!(scheme, "http"); + assert_eq!(host, "example.com"); + assert_eq!(port, 80); + assert_eq!(path, "/path"); + } + + #[test] + fn test_parse_proxy_uri_https_default_port() { + let (scheme, host, port, path) = + parse_proxy_uri("https://api.example.com/v1/chat").unwrap(); + assert_eq!(scheme, "https"); + assert_eq!(host, "api.example.com"); + assert_eq!(port, 443); + assert_eq!(path, "/v1/chat"); + } + + #[test] + fn test_parse_proxy_uri_missing_path() { + let (_, host, port, path) = parse_proxy_uri("http://10.0.0.1:9090").unwrap(); + assert_eq!(host, "10.0.0.1"); + assert_eq!(port, 9090); + assert_eq!(path, "/"); + } + + #[test] + fn test_parse_proxy_uri_with_query() { + let (_, _, _, path) = parse_proxy_uri("http://host:80/api?key=val&foo=bar").unwrap(); + assert_eq!(path, "/api?key=val&foo=bar"); + } + + #[test] + fn test_parse_proxy_uri_ipv6() { + let (_, host, port, path) = parse_proxy_uri("http://[::1]:8080/test").unwrap(); + assert_eq!(host, "::1"); + assert_eq!(port, 8080); + assert_eq!(path, "/test"); + } + + #[test] + fn test_parse_proxy_uri_ipv6_default_port() { + let (_, host, port, path) = parse_proxy_uri("http://[fe80::1]/path").unwrap(); + assert_eq!(host, "fe80::1"); + assert_eq!(port, 80); + assert_eq!(path, "/path"); + } + + #[test] + fn test_parse_proxy_uri_missing_scheme() { + let result = parse_proxy_uri("example.com/path"); + assert!(result.is_err()); + } + + #[test] + fn test_parse_proxy_uri_empty_host() { + let result = parse_proxy_uri("http:///path"); + assert!(result.is_err()); + } + + // --- rewrite_forward_request tests --- + + #[tokio::test] + async fn forward_proxy_injects_token_grant_before_rewriting_request() { + let (ctx, fixture) = forward_token_grant_context(Ok("grant-token")); + let raw = b"GET http://api.example.test:8080/v1/projects HTTP/1.1\r\nHost: api.example.test:8080\r\nAuthorization: Bearer stale-token\r\nConnection: close\r\n\r\n".to_vec(); + + let with_token = inject_token_grant_for_forward_request("GET", "/v1/projects", raw, &ctx) + .await + .expect("forward token grant should inject"); + let rewritten = + rewrite_forward_request(&with_token, with_token.len(), "/v1/projects", None, false) + .expect("forward request should rewrite"); + let rewritten = String::from_utf8_lossy(&rewritten); + + assert!(rewritten.starts_with("GET /v1/projects HTTP/1.1\r\n")); + assert!(rewritten.contains("Authorization: Bearer grant-token\r\n")); + assert!(!rewritten.contains("stale-token")); + assert_eq!(authorization_header_count(&rewritten), 1); + fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + } + + #[tokio::test] + async fn forward_proxy_token_grant_failure_returns_error_before_rewrite() { + let (ctx, fixture) = forward_token_grant_context(Err("oauth unavailable")); + let raw = b"GET http://api.example.test:8080/v1/projects HTTP/1.1\r\nHost: api.example.test:8080\r\nConnection: close\r\n\r\n".to_vec(); + + let err = inject_token_grant_for_forward_request("GET", "/v1/projects", raw, &ctx) + .await + .expect_err("forward token grant failure should stop request rewriting"); + + assert!(err.to_string().contains("Token grant failed")); + assert!(err.to_string().contains("oauth unavailable")); + fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); + } + + #[test] + fn test_rewrite_get_request() { + let raw = + b"GET http://10.0.0.1:8000/api HTTP/1.1\r\nHost: 10.0.0.1:8000\r\nAccept: */*\r\n\r\n"; + let result = + rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); + let result_str = String::from_utf8_lossy(&result); + assert!(result_str.starts_with("GET /api HTTP/1.1\r\n")); + assert!(result_str.contains("Host: 10.0.0.1:8000")); + assert!(result_str.contains("Connection: close")); + assert!(result_str.contains("Via: 1.1 openshell-sandbox")); + } + + #[test] + fn test_rewrite_strips_proxy_headers() { + let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nProxy-Authorization: Basic abc\r\nProxy-Connection: keep-alive\r\nAccept: */*\r\n\r\n"; + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); + let result_str = String::from_utf8_lossy(&result); + assert!( + !result_str + .to_ascii_lowercase() + .contains("proxy-authorization") + ); + assert!(!result_str.to_ascii_lowercase().contains("proxy-connection")); + assert!(result_str.contains("Accept: */*")); + } + + #[test] + fn test_rewrite_replaces_connection_header() { + let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nConnection: keep-alive\r\n\r\n"; + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); + let result_str = String::from_utf8_lossy(&result); + assert!(result_str.contains("Connection: close")); + assert!(!result_str.contains("keep-alive")); + } + + #[test] + fn test_rewrite_preserves_body_overflow() { + let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 13\r\n\r\n{\"key\":\"val\"}"; + let result = + rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); + let result_str = String::from_utf8_lossy(&result); + assert!(result_str.contains("{\"key\":\"val\"}")); + assert!(result_str.contains("POST /api HTTP/1.1")); + } + + #[test] + fn test_rewrite_preserves_existing_via() { + let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nVia: 1.0 upstream\r\n\r\n"; + let result = + rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); + let result_str = String::from_utf8_lossy(&result); + assert!(result_str.contains("Via: 1.0 upstream")); + // Should not add a second Via header + assert!(!result_str.contains("Via: 1.1 openshell-sandbox")); + } + + #[test] + fn test_rewrite_forward_request_uses_canonical_path_on_the_wire() { + // Regression: the forward-proxy caller must canonicalize first and + // then pass the canonical form to rewrite_forward_request so that + // OPA's policy evaluation and the bytes dispatched to the upstream + // agree. Prior to this guarantee, OPA saw the canonical form while + // the upstream re-normalized the raw path independently, re-opening + // the parser-differential this PR closes. + let raw = b"GET http://host/public/../secret HTTP/1.1\r\nHost: host\r\n\r\n"; + let (canon, _) = crate::l7::path::canonicalize_request_target( + "/public/../secret", + &crate::l7::path::CanonicalizeOptions::default(), + ) + .expect("canonicalization should succeed for the attack payload"); + assert_eq!(canon.path, "/secret"); + + let rewritten = rewrite_forward_request(raw, raw.len(), &canon.path, None, false) + .expect("rewrite_forward_request should succeed"); + let rewritten_str = String::from_utf8_lossy(&rewritten); + assert!( + rewritten_str.starts_with("GET /secret HTTP/1.1\r\n"), + "outbound request line must use canonical path, got: {rewritten_str:?}" + ); + assert!( + !rewritten_str.contains(".."), + "outbound bytes must not leak the pre-canonical form, got: {rewritten_str:?}" + ); + } + + #[test] + fn test_rewrite_forward_request_preserves_canonical_query_on_the_wire() { + let raw = b"GET http://host/public/../graphql?query=query+Viewer+%7B+viewer+%7B+login+%7D+%7D HTTP/1.1\r\nHost: host\r\n\r\n"; + let (canon, raw_query) = crate::l7::path::canonicalize_request_target( + "/public/../graphql?query=query+Viewer+%7B+viewer+%7B+login+%7D+%7D", + &crate::l7::path::CanonicalizeOptions::default(), + ) + .expect("canonicalization should preserve query separately"); + let upstream_target = match raw_query.as_deref() { + Some(raw_query) if !raw_query.is_empty() => format!("{}?{raw_query}", canon.path), + _ => canon.path, + }; + + let rewritten = rewrite_forward_request(raw, raw.len(), &upstream_target, None, false) + .expect("rewrite_forward_request should succeed"); + let rewritten_str = String::from_utf8_lossy(&rewritten); + assert!( + rewritten_str.starts_with( + "GET /graphql?query=query+Viewer+%7B+viewer+%7B+login+%7D+%7D HTTP/1.1\r\n" + ), + "outbound request line must preserve canonical query, got: {rewritten_str:?}" + ); + } + + #[test] + fn test_rewrite_resolves_placeholder_auth_headers() { + let (_, resolver) = SecretResolver::from_provider_env( + [("ANTHROPIC_API_KEY".to_string(), "sk-test".to_string())] + .into_iter() + .collect(), + ); + let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nAuthorization: Bearer openshell:resolve:env:ANTHROPIC_API_KEY\r\n\r\n"; + let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref(), false) + .expect("should succeed"); + let result_str = String::from_utf8_lossy(&result); + assert!(result_str.contains("Authorization: Bearer sk-test")); + assert!(!result_str.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); + } + + #[tokio::test] + async fn forward_relay_rewrites_urlencoded_body_alias_from_initial_read() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = format!("token={alias}&channel=C123"); + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_forward_request_and_capture( + "POST", + "/api/messages", + raw.as_bytes(), + Some(&resolver), + true, + ) + .await + .expect("forward relay should rewrite credentials"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.starts_with("POST /api/messages HTTP/1.1\r\n")); + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); + } + + #[tokio::test] + async fn forward_relay_rewrites_urlencoded_canonical_body_from_initial_read() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = "token=openshell%3Aresolve%3Aenv%3AAPI_TOKEN&channel=C123"; + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + + let forwarded = relay_forward_request_and_capture( + "POST", + "/api/messages", + raw.as_bytes(), + Some(&resolver), + true, + ) + .await + .expect("forward relay should rewrite credentials"); + + let expected_body = "token=provider-real-token&channel=C123"; + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); + assert!(forwarded.ends_with(expected_body)); + assert!(!forwarded.contains("openshell%3Aresolve%3Aenv%3AAPI_TOKEN")); + assert!(!forwarded.contains("openshell:resolve:env:API_TOKEN")); + } + + #[tokio::test] + async fn forward_relay_unresolved_body_placeholder_fails_before_upstream_write() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let body = "token=provider-OPENSHELL-RESOLVE-ENV-MISSING_TOKEN"; + let raw = format!( + "POST http://api.example.com/api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Content-Type: application/x-www-form-urlencoded\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let guard = forward_test_guard(); + let rewritten = rewrite_forward_request( + raw.as_bytes(), + raw.len(), + "/api/messages", + Some(&resolver), + true, + ) + .expect("header rewrite should defer body overflow to body rewriter"); + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let err = relay_rewritten_forward_request( + "POST", + "/api/messages", + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: Some(&resolver), + request_body_credential_rewrite: true, + }, + ) + .await + .expect_err("unresolved body placeholder should fail closed"); + + assert!(!err.to_string().contains("provider-real-token")); + assert!(!err.to_string().contains("MISSING_TOKEN")); + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "failed forward body rewrite must not reach upstream" + ); + } + + #[test] + fn test_forward_rewrite_preserves_websocket_upgrade_connection_header() { + let raw = "GET http://gateway.example.test/ws HTTP/1.1\r\n\ + Host: gateway.example.test\r\n\ + Upgrade: websocket\r\n\ + Connection: keep-alive, Upgrade\r\n\ + Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ + Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\ + Sec-WebSocket-Version: 13\r\n\r\n"; + + let result = rewrite_forward_request(raw.as_bytes(), raw.len(), "/ws", None, false) + .expect("websocket forward rewrite should succeed"); + let result_str = String::from_utf8_lossy(&result); + + assert!(result_str.starts_with("GET /ws HTTP/1.1\r\n")); + assert!(result_str.contains("Connection: keep-alive, Upgrade\r\n")); + assert!( + !result_str.contains("Connection: close\r\n"), + "websocket forward proxy must not strip the upgrade token" + ); + } + + #[tokio::test] + async fn test_forward_relay_guard_blocks_stale_generation_before_upstream_write() { + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + let guard = engine + .generation_guard(engine.current_generation()) + .unwrap(); + engine.reload(policy, policy_data).unwrap(); + + let raw = b"GET http://host/api HTTP/1.1\r\nHost: host\r\n\r\n"; + let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) + .expect("rewrite should succeed"); + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let result = relay_rewritten_forward_request( + "GET", + "/api", + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, + ) + .await; + assert!( + result.is_err(), + "stale generation must stop forward relay before upstream write" + ); + + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "stale forward request bytes must not reach upstream" + ); + } + + #[tokio::test] + async fn test_forward_relay_rejects_cl_te_smuggling_before_upstream_write() { + let policy = include_str!("../data/sandbox-policy.rego"); + let policy_data = "network_policies: {}\n"; + let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); + let guard = engine + .generation_guard(engine.current_generation()) + .unwrap(); + + let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 4\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"; + let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) + .expect("rewrite should succeed"); + let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); + let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); + + let result = relay_rewritten_forward_request( + "POST", + "/api", + rewritten, + &mut proxy_to_client, + &mut proxy_to_upstream, + ForwardRelayOptions { + generation_guard: &guard, + websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, + secret_resolver: None, + request_body_credential_rewrite: false, + }, + ) + .await; + assert!(result.is_err(), "forward relay must reject CL/TE ambiguity"); + + drop(proxy_to_upstream); + let mut forwarded = Vec::new(); + upstream_side.read_to_end(&mut forwarded).await.unwrap(); + assert!( + forwarded.is_empty(), + "smuggled forward request bytes must not reach upstream" + ); + } + + // --- Forward proxy SSRF defence tests --- + // + // The forward proxy handler uses the same SSRF logic as the CONNECT path: + // - No allowed_ips: resolve_and_reject_internal blocks private IPs, allows public. + // - With allowed_ips: resolve_and_check_allowed_ips validates against allowlist. + // + // These tests document that contract for the forward proxy path specifically. + + #[tokio::test] + async fn test_forward_public_ip_allowed_without_allowed_ips() { + // Public IPs (e.g. dns.google -> 8.8.8.8) should pass through + // resolve_and_reject_internal without needing allowed_ips. + let result = resolve_and_reject_internal("dns.google", 80, 0).await; + assert!( + result.is_ok(), + "Public IP should be allowed without allowed_ips: {result:?}" + ); + let addrs = result.unwrap(); + assert!(!addrs.is_empty(), "Should resolve to at least one address"); + // All resolved addresses should be public. + for addr in &addrs { + assert!( + !is_internal_ip(addr.ip()), + "dns.google should resolve to public IPs, got {}", + addr.ip() + ); + } + } + + #[tokio::test] + async fn test_forward_private_ip_rejected_without_allowed_ips() { + // Private IP literals should be rejected by resolve_and_reject_internal. + let result = resolve_and_reject_internal("10.0.0.1", 80, 0).await; + assert!( + result.is_err(), + "Private IP should be rejected without allowed_ips" + ); + let err = result.unwrap_err(); + assert!( + err.contains("internal address"), + "expected 'internal address' in error: {err}" + ); + } + + #[tokio::test] + async fn test_forward_private_ip_accepted_with_allowed_ips() { + // Private IP with matching allowed_ips should pass through. + let nets = parse_allowed_ips(&["10.0.0.0/8".to_string()]).unwrap(); + let result = resolve_and_check_allowed_ips("10.0.0.1", 80, &nets, 0).await; + assert!( + result.is_ok(), + "Private IP with matching allowed_ips should be accepted: {result:?}" + ); + } + + #[tokio::test] + async fn test_forward_private_ip_rejected_with_wrong_allowed_ips() { + // Private IP not in allowed_ips should be rejected. + let nets = parse_allowed_ips(&["192.168.0.0/16".to_string()]).unwrap(); + let result = resolve_and_check_allowed_ips("10.0.0.1", 80, &nets, 0).await; + assert!( + result.is_err(), + "Private IP not in allowed_ips should be rejected" + ); + let err = result.unwrap_err(); + assert!( + err.contains("not in allowed_ips"), + "expected 'not in allowed_ips' in error: {err}" + ); + } + + #[tokio::test] + async fn test_forward_loopback_always_blocked_even_with_allowed_ips() { + // Loopback addresses are always blocked, even if in allowed_ips. + // Construct nets directly (parse_allowed_ips now rejects always-blocked). + let nets = vec!["127.0.0.0/8".parse::().unwrap()]; + let result = resolve_and_check_allowed_ips("127.0.0.1", 80, &nets, 0).await; + assert!(result.is_err(), "Loopback should be always blocked"); + let err = result.unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected 'always-blocked' in error: {err}" + ); + } + + #[tokio::test] + async fn test_forward_link_local_always_blocked_even_with_allowed_ips() { + // Link-local / cloud metadata addresses are always blocked. + // Construct nets directly (parse_allowed_ips now rejects always-blocked). + let nets = vec!["169.254.0.0/16".parse::().unwrap()]; + let result = resolve_and_check_allowed_ips("169.254.169.254", 80, &nets, 0).await; + assert!(result.is_err(), "Link-local should be always blocked"); + let err = result.unwrap_err(); + assert!( + err.contains("always-blocked"), + "expected 'always-blocked' in error: {err}" + ); + } + + // -- implicit_allowed_ips_for_ip_host -- + + #[test] + fn test_implicit_allowed_ips_returns_ip_for_ipv4_literal() { + let result = implicit_allowed_ips_for_ip_host("192.168.1.100"); + assert_eq!(result, vec!["192.168.1.100"]); + } + + #[test] + fn test_implicit_allowed_ips_skips_ipv6_loopback() { + // ::1 is always-blocked, so implicit allowed_ips should be empty. + let result = implicit_allowed_ips_for_ip_host("::1"); + assert!(result.is_empty()); + } + + #[test] + fn test_implicit_allowed_ips_returns_empty_for_hostname() { + let result = implicit_allowed_ips_for_ip_host("api.github.com"); + assert!(result.is_empty()); + } + + #[test] + fn test_implicit_allowed_ips_returns_empty_for_wildcard() { + let result = implicit_allowed_ips_for_ip_host("*.example.com"); + assert!(result.is_empty()); + } + + /// Regression test: exercises the actual keep-alive interception loop to + /// verify that a non-inference request is denied even after a previous + /// inference request was successfully routed on the same connection. + /// + /// Before the fix, `handle_inference_interception` used + /// `else if !routed_any` which silently dropped denials once `routed_any` + /// was true, allowing non-inference HTTP requests to piggyback on a + /// keep-alive connection that had previously handled inference traffic. + /// Regression test: exercises the actual keep-alive interception loop to + /// verify that a non-inference request is denied even after a previous + /// inference request was successfully routed on the same connection. + /// + /// The server runs in a spawned task with empty routes (the inference + /// request gets a 503 "not configured" but is still recognized as + /// inference and returns Ok(true)). The client sends the inference + /// request, reads the 503 response, then sends a non-inference request + /// on the same connection. The server must return Denied. + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_keepalive_denies_non_inference_after_routed() { + use openshell_router::Router; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let router = Router::new().unwrap(); + let patterns = crate::l7::inference::default_patterns(); + // Empty routes: inference request gets 503 but returns Ok(true). + let ctx = InferenceContext::new(patterns, router, vec![], vec![]); + + let body = r#"{"model":"test","messages":[{"role":"user","content":"hi"}]}"#; + let inference_req = format!( + "POST /v1/chat/completions HTTP/1.1\r\n\ + Host: inference.local\r\n\ + Content-Type: application/json\r\n\ + Content-Length: {}\r\n\r\n{}", + body.len(), + body, + ); + let non_inference_req = "GET /admin/config HTTP/1.1\r\nHost: inference.local\r\n\r\n"; + + let (client, mut server) = tokio::io::duplex(65536); + let (mut client_read, mut client_write) = tokio::io::split(client); + + // Spawn the server task so it runs concurrently. + let server_task = + tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); + + // Client: send inference request, read response, send non-inference. + client_write + .write_all(inference_req.as_bytes()) + .await + .unwrap(); + + // Read the 503 response so the server loops back to read. + let mut buf = vec![0u8; 4096]; + let _ = client_read.read(&mut buf).await.unwrap(); + + // Send non-inference request on the same keep-alive connection. + client_write + .write_all(non_inference_req.as_bytes()) + .await + .unwrap(); + drop(client_write); + + // Drain remaining response bytes. + tokio::spawn(async move { + let mut buf = vec![0u8; 4096]; + loop { + match client_read.read(&mut buf).await { + Ok(0) | Err(_) => break, + Ok(_) => continue, + } + } + }); + + let outcome = server_task.await.unwrap().unwrap(); + + assert!( + matches!(outcome, InferenceOutcome::Denied { .. }), + "expected Denied after non-inference request on keep-alive, got: {outcome:?}" + ); + } + + // -- build_json_error_response -- + + #[test] + fn test_json_error_response_403() { + let resp = build_json_error_response( + 403, + "Forbidden", + "policy_denied", + "CONNECT api.example.com:443 not permitted by policy", + ); + let resp_str = String::from_utf8(resp).unwrap(); + + assert!(resp_str.starts_with("HTTP/1.1 403 Forbidden\r\n")); + assert!(resp_str.contains("Content-Type: application/json\r\n")); + assert!(resp_str.contains("Connection: close\r\n")); + + // Extract body after \r\n\r\n + let body_start = resp_str.find("\r\n\r\n").unwrap() + 4; + let body: serde_json::Value = serde_json::from_str(&resp_str[body_start..]).unwrap(); + assert_eq!(body["error"], "policy_denied"); + assert_eq!( + body["detail"], + "CONNECT api.example.com:443 not permitted by policy" + ); + } + + #[test] + fn test_json_error_response_502() { + let resp = build_json_error_response( + 502, + "Bad Gateway", + "upstream_unreachable", + "connection to api.example.com:443 failed", + ); + let resp_str = String::from_utf8(resp).unwrap(); + + assert!(resp_str.starts_with("HTTP/1.1 502 Bad Gateway\r\n")); + + let body_start = resp_str.find("\r\n\r\n").unwrap() + 4; + let body: serde_json::Value = serde_json::from_str(&resp_str[body_start..]).unwrap(); + assert_eq!(body["error"], "upstream_unreachable"); + assert_eq!(body["detail"], "connection to api.example.com:443 failed"); + } + + #[test] + fn test_json_error_response_content_length_matches() { + let resp = build_json_error_response(403, "Forbidden", "test", "detail"); + let resp_str = String::from_utf8(resp).unwrap(); + + // Extract Content-Length value + let cl_line = resp_str + .lines() + .find(|l| l.starts_with("Content-Length:")) + .unwrap(); + let cl: usize = cl_line.split(": ").nth(1).unwrap().trim().parse().unwrap(); + + // Verify body length matches + let body_start = resp_str.find("\r\n\r\n").unwrap() + 4; + assert_eq!(resp_str[body_start..].len(), cl); + } + + /// End-to-end regression for the `docker cp` hot-swap hazard that + /// motivated `binary_path()` stripping the kernel's `" (deleted)"` + /// suffix (PR #844). + /// + /// Before the strip, the identity-resolution chain inside + /// `evaluate_opa_tcp` failed with `"Failed to stat + /// /opt/openshell/bin/openshell-sandbox (deleted)"` because + /// `BinaryIdentityCache::verify_or_cache()` tried to `metadata()` the + /// tainted path. That masked the real security signal: a live process + /// was now bound to a *different* binary on disk than the one that was + /// TOFU-cached. After the strip, `binary_path()` returns a path that + /// stats fine, the cache rehashes the new bytes, and the hash mismatch + /// surfaces as a `Binary integrity violation` error — the contract this + /// PR is trying to establish. + /// + /// Test shape (from the review comment on the initial PR): + /// 1. Start a `TcpListener` in the test process. + /// 2. Copy `/bin/bash` to a temp path we control. + /// 3. Prime `BinaryIdentityCache` with that temp binary's hash. + /// 4. Spawn the temp bash as a child with a `/dev/tcp` one-liner that + /// opens a real TCP connection to the listener and holds it open. + /// 5. Accept the connection on the listener side and capture the peer's + /// ephemeral port — that's what `resolve_process_identity` uses to + /// walk `/proc/net/tcp` back to the child PID. + /// 6. Overwrite the temp bash on disk with different bytes to simulate + /// a `docker cp` hot-swap. The running child is unaffected (it still + /// executes from its in-memory image), but `/proc//exe` will + /// now readlink to `" (deleted)"` OR the overwritten file, depending + /// on whether the filesystem reused the inode. + /// 7. Call `resolve_process_identity` and assert: + /// - the error reason contains `"Binary integrity violation"` (the + /// cache detected the tampered on-disk bytes), and + /// - the error reason does NOT contain `"Failed to stat"` or + /// `"(deleted)"` (the old pre-strip failure mode). + #[cfg(target_os = "linux")] + #[test] + fn resolve_process_identity_surfaces_binary_integrity_violation_on_hot_swap() { + use crate::identity::BinaryIdentityCache; + use std::io::Read; + use std::net::TcpListener; + use std::os::unix::fs::PermissionsExt; + use std::process::{Command, Stdio}; + use std::time::Duration; + + // Skip if /bin/bash is not present (e.g. minimal containers). + if !std::path::Path::new("/bin/bash").exists() { + eprintln!("skipping: /bin/bash not available"); + return; + } + + // 1. Start a listener on loopback. + let listener = TcpListener::bind("127.0.0.1:0").expect("bind"); + let listener_port = listener.local_addr().unwrap().port(); + + // 2. Copy /bin/bash to a temp path. + let tmp = tempfile::TempDir::new().unwrap(); + let bash_v1 = tmp.path().join("hotswap-bash"); + std::fs::copy("/bin/bash", &bash_v1).expect("copy bash"); + std::fs::set_permissions(&bash_v1, std::fs::Permissions::from_mode(0o755)).unwrap(); + + // 3. Prime the cache with the v1 hash of the temp bash. + let cache = BinaryIdentityCache::new(); + let v1_hash = cache + .verify_or_cache(&bash_v1) + .expect("prime cache with v1 bash hash"); + assert!(!v1_hash.is_empty()); + + // 4. Spawn the temp bash with a /dev/tcp one-liner that opens a real + // connection to the listener and sleeps to keep it open. The + // `read -t` blocks on stdin so the shell stays resident. + let script = format!("exec 3<>/dev/tcp/127.0.0.1/{listener_port}; sleep 30 <&3"); + let mut child = Command::new(&bash_v1) + .arg("-c") + .arg(&script) + .stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .spawn() + .expect("spawn hotswap-bash child"); + + // 5. Accept on the listener side, capture the peer port. + listener.set_nonblocking(false).expect("blocking listener"); + let (mut stream, peer_addr) = match listener.accept() { + Ok(pair) => pair, + Err(e) => { + let _ = child.kill(); + let _ = child.wait(); + panic!("failed to accept child connection: {e}"); + } + }; + let peer_port = peer_addr.port(); + // Drain any spurious data; we just need the socket open. + stream + .set_read_timeout(Some(Duration::from_millis(50))) + .ok(); + let mut buf = [0u8; 16]; + let _ = stream.read(&mut buf); + + // Give the kernel a moment so /proc//net/tcp and + // /proc//fd/ both reflect the ESTABLISHED socket. + std::thread::sleep(Duration::from_millis(50)); + + // 6. Simulate `docker cp`: unlink the running binary and create a + // fresh file with different bytes at the same path. Writing + // in place via O_TRUNC is rejected by the kernel with ETXTBSY + // because the inode is still being executed. Unlink is cheap: + // the inode persists in memory via the child's exec mapping, + // so the child keeps running, but a new inode now lives at + // `bash_v1` with a different SHA-256. + std::fs::remove_file(&bash_v1).expect("unlink running bash_v1"); + let tampered_bytes = b"#!/bin/sh\n# tampered bash v2 from hotswap test\nexit 0\n"; + std::fs::write(&bash_v1, tampered_bytes).expect("write replacement bytes"); + + // 7. Resolve identity through the real helper and assert the + // contract: we want "Binary integrity violation", not + // "Failed to stat ... (deleted)". + let test_pid = std::process::id(); + let result = resolve_process_identity(test_pid, peer_port, &cache); + + // Always clean up the child before asserting so a failure doesn't + // leak a sleeping process across test runs. + let _ = child.kill(); + let _ = child.wait(); + + match result { + Ok(_) => panic!( + "resolve_process_identity unexpectedly succeeded after hot-swap; \ + the cache should have detected the tampered on-disk bytes" + ), + Err(err) => { + assert!( + err.reason.contains("Binary integrity violation"), + "expected 'Binary integrity violation' error, got: {}", + err.reason + ); + assert!( + !err.reason.contains("Failed to stat"), + "pre-PR-#844 failure mode leaked: {}", + err.reason + ); + assert!( + !err.reason.contains("(deleted)"), + "resolved path still contains '(deleted)' suffix: {}", + err.reason + ); + // The binary field should be populated — we did resolve a + // path before failing. + assert!( + err.binary.is_some(), + "expected resolved binary path on integrity failure" + ); + if let Some(path) = &err.binary { + assert!( + !path.to_string_lossy().contains("(deleted)"), + "resolved binary path still tainted: {}", + path.display() + ); + } + } + } + } + + #[cfg(target_os = "linux")] + #[test] + // TODO: exec'ing /bin/sleep (SELinux label bin_t) from a user_home_t test + // binary causes /proc//exe readlink to return ENOENT on + // SELinux-enforcing hosts. Fix by building a test-sleep-helper binary in + // the same crate so it inherits the user_home_t label. + fn resolve_process_identity_denies_fork_exec_shared_socket_ambiguity() { + use crate::identity::BinaryIdentityCache; + use std::ffi::CString; + use std::net::{TcpListener, TcpStream}; + use std::os::fd::AsRawFd; + use std::time::{Duration, Instant}; + + struct ChildGuard(libc::pid_t); + impl Drop for ChildGuard { + fn drop(&mut self) { + #[allow(unsafe_code)] + unsafe { + libc::kill(self.0, libc::SIGKILL); + libc::waitpid(self.0, std::ptr::null_mut(), 0); + } + } + } + + if !std::path::Path::new("/bin/sleep").exists() { + eprintln!("skipping: /bin/sleep not available"); + return; + } + + if std::process::Command::new("getenforce") + .output() + .is_ok_and(|o| String::from_utf8_lossy(&o.stdout).trim() == "Enforcing") + { + eprintln!( + "skipping: SELinux is enforcing — cross-label /proc//exe readlink fails" + ); + return; + } + + let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let listener_port = listener.local_addr().unwrap().port(); + let stream = TcpStream::connect(("127.0.0.1", listener_port)).expect("connect"); + let peer_port = stream.local_addr().unwrap().port(); + let (_accepted, _) = listener.accept().expect("accept"); + + let fd = stream.as_raw_fd(); + // libc/syscall FFI requires unsafe + #[allow(unsafe_code)] + unsafe { + let flags = libc::fcntl(fd, libc::F_GETFD); + assert!(flags >= 0, "F_GETFD failed"); + assert_eq!( + libc::fcntl(fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC), + 0, + "F_SETFD failed" + ); + } + + let sleep_path = CString::new("/bin/sleep").unwrap(); + let arg0 = CString::new("sleep").unwrap(); + let arg1 = CString::new("30").unwrap(); + // libc/syscall FFI requires unsafe + #[allow(unsafe_code)] + let child_pid = unsafe { libc::fork() }; + assert!(child_pid >= 0, "fork failed"); + if child_pid == 0 { + // libc/syscall FFI requires unsafe + #[allow(unsafe_code)] + unsafe { + libc::execl( + sleep_path.as_ptr(), + arg0.as_ptr(), + arg1.as_ptr(), + std::ptr::null::(), + ); + libc::_exit(127); + } + } + + let _guard = ChildGuard(child_pid); + let entrypoint_pid = std::process::id(); + + let deadline = Instant::now() + Duration::from_secs(5); + loop { + if let Ok(link) = std::fs::read_link(format!("/proc/{child_pid}/exe")) + && link.to_string_lossy().contains("sleep") + { + break; + } + assert!( + Instant::now() < deadline, + "child pid {child_pid} did not exec into sleep within 5s" + ); + std::thread::sleep(Duration::from_millis(20)); + } + + let cache = BinaryIdentityCache::new(); + + let mut result = resolve_process_identity(entrypoint_pid, peer_port, &cache); + for _ in 0..10 { + match &result { + Err(err) + if err.reason.contains("No such file or directory") + || err.reason.contains("os error 2") => + { + // /proc//fd scan transiently failed; give procfs time to settle. + std::thread::sleep(Duration::from_millis(50)); + result = resolve_process_identity(entrypoint_pid, peer_port, &cache); + } + Ok(_) => { + // On arm64 under heavy CI load the /proc fd scan can transiently + // miss the parent process's socket fd, making the scan return only + // the child as owner and yielding a spurious Ok. Retry to give + // both owners time to appear consistently in /proc//fd. + std::thread::sleep(Duration::from_millis(50)); + result = resolve_process_identity(entrypoint_pid, peer_port, &cache); + } + _ => break, + } + } + + match result { + Ok(identity) => panic!( + "resolve_process_identity unexpectedly succeeded for shared socket owned by PID {}", + identity.binary_pid + ), + Err(err) => { + assert!( + err.reason.contains("ambiguous shared socket ownership"), + "expected ambiguous socket ownership error, got: {}", + err.reason + ); + assert!( + err.reason.contains(&entrypoint_pid.to_string()), + "error should include parent PID; got: {}", + err.reason + ); + assert!( + err.reason.contains(&child_pid.to_string()), + "error should include child PID; got: {}", + err.reason + ); + } + } + } +} diff --git a/proto/sandbox.proto b/proto/sandbox.proto index ef0b0540f..cf0fc902b 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -160,6 +160,8 @@ message L7DenyRule { // GraphQL root field globs. Deny rules match when any selected root field // matches any configured glob. repeated string fields = 7; + // JSON-RPC method name (JSON-RPC): exact name or glob, e.g. "tools/call". + string rpc_method = 8; } // An L7 policy rule (allow-only). @@ -186,6 +188,8 @@ message L7Allow { // GraphQL root field globs. Allow rules match only when every selected root // field matches one of the configured globs. Omit to match all fields. repeated string fields = 7; + // JSON-RPC method name (JSON-RPC): exact name or glob, e.g. "tools/call". + string rpc_method = 8; } // Query value matcher for one query parameter key. From e482e6195c4a0919c5c8b38eadc8c5ed5dd57875 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 17:05:38 -0700 Subject: [PATCH 05/12] fix(l7): honor JSON-RPC body size config Carry JSON-RPC max body bytes from policy into runtime endpoint config and use it on both CONNECT and forward JSON-RPC inspection paths instead of hardcoding 64 KiB. Signed-off-by: Kris Hicks --- crates/openshell-policy/src/lib.rs | 43 ++++++++++++- crates/openshell-providers/src/profiles.rs | 4 ++ crates/openshell-sandbox/src/l7/jsonrpc.rs | 2 + crates/openshell-sandbox/src/l7/mod.rs | 65 ++++++++++++++++++++ crates/openshell-sandbox/src/l7/relay.rs | 5 +- crates/openshell-sandbox/src/opa.rs | 3 + crates/openshell-sandbox/src/policy_local.rs | 1 + crates/openshell-sandbox/src/proxy.rs | 5 +- proto/sandbox.proto | 3 + 9 files changed, 128 insertions(+), 3 deletions(-) diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 1a692d8d8..d839782a7 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -162,6 +162,14 @@ struct JsonRpcConfigDef { batch_policy: String, } +fn json_rpc_config_from_proto(max_body_bytes: u32) -> Option { + (max_body_bytes > 0).then_some(JsonRpcConfigDef { + max_body_bytes, + on_parse_error: String::new(), + batch_policy: String::new(), + }) +} + #[derive(Debug, Serialize, Deserialize)] #[serde(deny_unknown_fields)] struct GraphqlOperationDef { @@ -370,6 +378,10 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { }) .collect(), graphql_max_body_bytes: e.graphql_max_body_bytes, + json_rpc_max_body_bytes: e + .json_rpc + .as_ref() + .map_or(0, |config| config.max_body_bytes), } }) .collect(), @@ -539,7 +551,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { }) .collect(), graphql_max_body_bytes: e.graphql_max_body_bytes, - json_rpc: None, + json_rpc: json_rpc_config_from_proto(e.json_rpc_max_body_bytes), } }) .collect(), @@ -1727,6 +1739,35 @@ network_policies: assert_eq!(ep.deny_rules[0].fields, vec!["deleteRepository"]); } + #[test] + fn round_trip_preserves_json_rpc_max_body_bytes() { + let yaml = r" +version: 1 +network_policies: + mcp: + name: mcp + endpoints: + - host: mcp.example.com + port: 443 + protocol: json-rpc + enforcement: enforce + json_rpc: + max_body_bytes: 131072 + rules: + - allow: + rpc_method: initialize + binaries: + - path: /usr/bin/curl +"; + let proto1 = parse_sandbox_policy(yaml).expect("parse failed"); + let yaml_out = serialize_sandbox_policy(&proto1).expect("serialize failed"); + let proto2 = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + + let ep = &proto2.network_policies["mcp"].endpoints[0]; + assert_eq!(ep.protocol, "json-rpc"); + assert_eq!(ep.json_rpc_max_body_bytes, 131_072); + } + #[test] fn round_trip_preserves_websocket_credential_rewrite() { let yaml = r" diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index a6d282256..86e3928f0 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -200,6 +200,8 @@ pub struct EndpointProfile { pub graphql_persisted_queries: HashMap, #[serde(default, skip_serializing_if = "is_zero")] pub graphql_max_body_bytes: u32, + #[serde(default, skip_serializing_if = "is_zero")] + pub json_rpc_max_body_bytes: u32, #[serde(default, skip_serializing_if = "String::is_empty")] pub path: String, } @@ -743,6 +745,7 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { .map(|(name, operation)| (name.clone(), graphql_operation_to_proto(operation))) .collect(), graphql_max_body_bytes: endpoint.graphql_max_body_bytes, + json_rpc_max_body_bytes: endpoint.json_rpc_max_body_bytes, path: endpoint.path.clone(), } } @@ -773,6 +776,7 @@ fn endpoint_from_proto(endpoint: &NetworkEndpoint) -> EndpointProfile { .map(|(name, operation)| (name.clone(), graphql_operation_from_proto(operation))) .collect(), graphql_max_body_bytes: endpoint.graphql_max_body_bytes, + json_rpc_max_body_bytes: endpoint.json_rpc_max_body_bytes, path: endpoint.path.clone(), } } diff --git a/crates/openshell-sandbox/src/l7/jsonrpc.rs b/crates/openshell-sandbox/src/l7/jsonrpc.rs index 977c8046f..2dc83c12d 100644 --- a/crates/openshell-sandbox/src/l7/jsonrpc.rs +++ b/crates/openshell-sandbox/src/l7/jsonrpc.rs @@ -8,6 +8,8 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::l7::provider::{L7Provider, L7Request}; +pub const DEFAULT_MAX_BODY_BYTES: usize = 64 * 1024; + pub struct JsonRpcHttpRequest { pub request: L7Request, pub info: JsonRpcRequestInfo, diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs index d83a4dfa5..14fee09da 100644 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ b/crates/openshell-sandbox/src/l7/mod.rs @@ -80,6 +80,8 @@ pub struct L7EndpointConfig { pub enforcement: EnforcementMode, /// Maximum GraphQL request body bytes to buffer for inspection. pub graphql_max_body_bytes: usize, + /// Maximum JSON-RPC request body bytes to buffer for inspection. + pub json_rpc_max_body_bytes: usize, /// When true, percent-encoded `/` (`%2F`) is preserved in path segments /// rather than rejected at the parser. Needed by upstreams like GitLab /// that embed `%2F` in namespaced project paths. Defaults to false. @@ -171,6 +173,10 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { .and_then(|v| usize::try_from(v).ok()) .filter(|v| *v > 0) .unwrap_or(graphql::DEFAULT_MAX_BODY_BYTES); + let json_rpc_max_body_bytes = get_object_u64(val, "json_rpc_max_body_bytes") + .and_then(|v| usize::try_from(v).ok()) + .filter(|v| *v > 0) + .unwrap_or(jsonrpc::DEFAULT_MAX_BODY_BYTES); Some(L7EndpointConfig { protocol, @@ -178,6 +184,7 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { tls, enforcement, graphql_max_body_bytes, + json_rpc_max_body_bytes, allow_encoded_slash, websocket_credential_rewrite, request_body_credential_rewrite, @@ -630,6 +637,18 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } } + if ep.get("json_rpc_max_body_bytes").is_some() { + let valid_max = ep + .get("json_rpc_max_body_bytes") + .and_then(serde_json::Value::as_u64) + .is_some_and(|v| v > 0); + if !valid_max { + errors.push(format!( + "{loc}: json_rpc_max_body_bytes must be a positive integer" + )); + } + } + if protocol != "graphql" && protocol != "websocket" && (ep.get("persisted_queries").is_some() @@ -641,6 +660,12 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< )); } + if protocol != "json-rpc" && ep.get("json_rpc_max_body_bytes").is_some() { + warnings.push(format!( + "{loc}: JSON-RPC-specific endpoint fields are ignored unless protocol is json-rpc" + )); + } + if ep .get("websocket_credential_rewrite") .and_then(serde_json::Value::as_bool) @@ -1201,6 +1226,46 @@ mod tests { assert_eq!(config.protocol, L7Protocol::Rest); assert_eq!(config.tls, TlsMode::Auto); assert_eq!(config.enforcement, EnforcementMode::Audit); + assert_eq!( + config.json_rpc_max_body_bytes, + jsonrpc::DEFAULT_MAX_BODY_BYTES + ); + } + + #[test] + fn parse_l7_config_jsonrpc_max_body_bytes() { + let val = regorus::Value::from_json_str( + r#"{"protocol": "json-rpc", "host": "mcp.example.com", "port": 443, "json_rpc_max_body_bytes": 131072}"#, + ) + .unwrap(); + let config = parse_l7_config(&val).unwrap(); + assert_eq!(config.protocol, L7Protocol::JsonRpc); + assert_eq!(config.json_rpc_max_body_bytes, 131_072); + } + + #[test] + fn validate_jsonrpc_max_body_bytes_must_be_positive() { + let data = serde_json::json!({ + "network_policies": { + "test": { + "endpoints": [{ + "host": "mcp.example.com", + "port": 443, + "protocol": "json-rpc", + "access": "full", + "json_rpc_max_body_bytes": 0 + }], + "binaries": [] + } + } + }); + let (errors, _warnings) = validate_l7_policies(&data); + assert!( + errors + .iter() + .any(|e| e.contains("json_rpc_max_body_bytes must be a positive integer")), + "should reject non-positive JSON-RPC max body size, got errors: {errors:?}" + ); } #[test] diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index ec723df50..0df2fe9b8 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -906,7 +906,7 @@ where let parsed = match crate::l7::jsonrpc::parse_jsonrpc_http_request( client, - 64 * 1024, + config.json_rpc_max_body_bytes, crate::l7::path::CanonicalizeOptions { allow_encoded_slash: config.allow_encoded_slash, ..Default::default() @@ -1993,6 +1993,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: EnforcementMode::Enforce, graphql_max_body_bytes: 0, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: true, request_body_credential_rewrite: false, @@ -2096,6 +2097,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: EnforcementMode::Enforce, graphql_max_body_bytes: 0, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: true, request_body_credential_rewrite: false, @@ -2216,6 +2218,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: EnforcementMode::Enforce, graphql_max_body_bytes: 0, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: true, request_body_credential_rewrite: false, diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index aa49509e5..de6ff5257 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -1140,6 +1140,9 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St if e.graphql_max_body_bytes > 0 { ep["graphql_max_body_bytes"] = e.graphql_max_body_bytes.into(); } + if e.json_rpc_max_body_bytes > 0 { + ep["json_rpc_max_body_bytes"] = e.json_rpc_max_body_bytes.into(); + } ep }) .collect(); diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs index 34eeaada5..20536ac2b 100644 --- a/crates/openshell-sandbox/src/policy_local.rs +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -1122,6 +1122,7 @@ fn network_endpoint_from_json( persisted_queries: String::new(), graphql_persisted_queries: HashMap::new(), graphql_max_body_bytes: 0, + json_rpc_max_body_bytes: 0, path: String::new(), }) } diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index 7782f69d7..a45f1bc60 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -3412,7 +3412,7 @@ async fn handle_forward_proxy( let body = match crate::l7::http::read_body_for_inspection( client, &mut jsonrpc_request, - 64 * 1024, + l7_config.config.json_rpc_max_body_bytes, ) .await { @@ -4142,6 +4142,7 @@ mod tests { tls: crate::l7::TlsMode::Auto, enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite, request_body_credential_rewrite: false, @@ -4741,6 +4742,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: false, request_body_credential_rewrite: false, @@ -4754,6 +4756,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: false, request_body_credential_rewrite: false, diff --git a/proto/sandbox.proto b/proto/sandbox.proto index cf0fc902b..9d6ec2824 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -128,6 +128,9 @@ message NetworkEndpoint { // Advisor-proposed endpoints must not satisfy exact-host SSRF trust unless // they are converted through an explicit user-authored policy path. bool advisor_proposed = 18; + // Maximum JSON-RPC-over-HTTP request body bytes to buffer for inspection. + // Defaults to 65536 when unset. + uint32 json_rpc_max_body_bytes = 19; } // Trusted GraphQL operation classification. From a2e05209b181c8bd84b7fd5ee65b35e74793d5b2 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 11:42:53 -0700 Subject: [PATCH 06/12] feat(l7): match JSON-RPC params in rules Add JSON-RPC params matcher maps to proto and YAML policy conversion, including shared matcher conversion helpers. Flatten object params into dot-separated keys for policy input and extend Rego allow and deny matching to filter JSON-RPC calls by params. Signed-off-by: Kris Hicks --- crates/openshell-cli/src/policy_update.rs | 2 + crates/openshell-policy/src/lib.rs | 89 +++++----- crates/openshell-policy/src/merge.rs | 2 + crates/openshell-providers/src/profiles.rs | 2 + .../data/sandbox-policy.rego | 43 ++++- crates/openshell-sandbox/src/l7/jsonrpc.rs | 52 ++++++ crates/openshell-sandbox/src/l7/relay.rs | 1 + .../src/mechanistic_mapper.rs | 1 + crates/openshell-sandbox/src/opa.rs | 158 ++++++++++++++---- crates/openshell-sandbox/src/policy_local.rs | 2 + crates/openshell-server/src/grpc/policy.rs | 6 + proto/sandbox.proto | 6 + 12 files changed, 285 insertions(+), 79 deletions(-) diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 03695d48e..22363c920 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -206,6 +206,7 @@ fn group_allow_rules(specs: &[String]) -> Result Result L7QueryMatcher { + match matcher { + QueryMatcherDef::Glob(glob) => L7QueryMatcher { glob, any: vec![] }, + QueryMatcherDef::Any(any) => L7QueryMatcher { + glob: String::new(), + any: any.any, + }, + } +} + +fn matcher_proto_to_def(matcher: L7QueryMatcher) -> QueryMatcherDef { + if matcher.any.is_empty() { + QueryMatcherDef::Glob(matcher.glob) + } else { + QueryMatcherDef::Any(QueryAnyDef { any: matcher.any }) + } +} + fn to_proto(raw: PolicyFile) -> SandboxPolicy { let network_policies = raw .network_policies @@ -311,16 +329,15 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { .query .into_iter() .map(|(key, matcher)| { - let proto = match matcher { - QueryMatcherDef::Glob(glob) => { - L7QueryMatcher { glob, any: vec![] } - } - QueryMatcherDef::Any(any) => L7QueryMatcher { - glob: String::new(), - any: any.any, - }, - }; - (key, proto) + (key, matcher_def_to_proto(matcher)) + }) + .collect(), + params: r + .allow + .params + .into_iter() + .map(|(key, matcher)| { + (key, matcher_def_to_proto(matcher)) }) .collect(), }), @@ -341,18 +358,12 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { query: d .query .into_iter() - .map(|(key, matcher)| { - let proto = match matcher { - QueryMatcherDef::Glob(glob) => { - L7QueryMatcher { glob, any: vec![] } - } - QueryMatcherDef::Any(any) => L7QueryMatcher { - glob: String::new(), - any: any.any, - }, - }; - (key, proto) - }) + .map(|(key, matcher)| (key, matcher_def_to_proto(matcher))) + .collect(), + params: d + .params + .into_iter() + .map(|(key, matcher)| (key, matcher_def_to_proto(matcher))) .collect(), }) .collect(), @@ -488,17 +499,16 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { .query .into_iter() .map(|(key, matcher)| { - let yaml_matcher = if matcher.any.is_empty() { - QueryMatcherDef::Glob(matcher.glob) - } else { - QueryMatcherDef::Any(QueryAnyDef { - any: matcher.any, - }) - }; - (key, yaml_matcher) + (key, matcher_proto_to_def(matcher)) + }) + .collect(), + params: a + .params + .into_iter() + .map(|(key, matcher)| { + (key, matcher_proto_to_def(matcher)) }) .collect(), - params: BTreeMap::new(), }, } }) @@ -519,17 +529,16 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { .query .iter() .map(|(key, matcher)| { - let yaml_matcher = if matcher.any.is_empty() { - QueryMatcherDef::Glob(matcher.glob.clone()) - } else { - QueryMatcherDef::Any(QueryAnyDef { - any: matcher.any.clone(), - }) - }; - (key.clone(), yaml_matcher) + (key.clone(), matcher_proto_to_def(matcher.clone())) + }) + .collect(), + params: d + .params + .iter() + .map(|(key, matcher)| { + (key.clone(), matcher_proto_to_def(matcher.clone())) }) .collect(), - params: BTreeMap::new(), }) .collect(), allow_encoded_slash: e.allow_encoded_slash, diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 73c40316e..6be185ca0 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -748,6 +748,7 @@ fn expand_access_preset(protocol: &str, access: &str) -> Option> { operation_name: String::new(), fields: Vec::new(), rpc_method: String::new(), + params: HashMap::default(), }), }) .collect(), @@ -963,6 +964,7 @@ mod tests { operation_name: String::new(), fields: Vec::new(), rpc_method: String::new(), + params: HashMap::default(), }), } } diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 86e3928f0..f94c09ef9 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -821,6 +821,7 @@ fn allow_to_proto(allow: &L7AllowProfile) -> L7Allow { operation_name: allow.operation_name.clone(), fields: allow.fields.clone(), rpc_method: String::new(), + params: HashMap::new(), } } @@ -854,6 +855,7 @@ fn deny_rule_to_proto(rule: &L7DenyRuleProfile) -> L7DenyRule { operation_name: rule.operation_name.clone(), fields: rule.fields.clone(), rpc_method: String::new(), + params: HashMap::new(), } } diff --git a/crates/openshell-sandbox/data/sandbox-policy.rego b/crates/openshell-sandbox/data/sandbox-policy.rego index c25b42af0..38070911c 100644 --- a/crates/openshell-sandbox/data/sandbox-policy.rego +++ b/crates/openshell-sandbox/data/sandbox-policy.rego @@ -274,6 +274,15 @@ request_denied_for_endpoint(request, endpoint) if { command_matches(request.command, deny_rule.command) } +# --- L7 deny rule matching: JSON-RPC method + params --- + +request_denied_for_endpoint(request, endpoint) if { + some deny_rule + deny_rule := endpoint.deny_rules[_] + deny_rule.rpc_method + jsonrpc_rule_matches(request, deny_rule) +} + # --- L7 deny rule matching: GraphQL operation --- request_denied_for_endpoint(request, endpoint) if { @@ -423,10 +432,7 @@ request_allowed_for_endpoint(request, endpoint) if { some rule rule := endpoint.rules[_] rule.allow.rpc_method - jsonrpc := object.get(request, "jsonrpc", {}) - method := object.get(jsonrpc, "method", null) - method != null - glob.match(rule.allow.rpc_method, [], method) + jsonrpc_rule_matches(request, rule.allow) } # --- L7 rule matching: GraphQL operation --- @@ -650,6 +656,35 @@ query_value_matches(value, matcher) if { glob.match(any_patterns[i], [], value) } +# JSON-RPC method and params matching. The sandbox flattens object params into +# dot-separated keys before policy evaluation, e.g. arguments.scope. +jsonrpc_rule_matches(request, rule) if { + jsonrpc := object.get(request, "jsonrpc", {}) + method := object.get(jsonrpc, "method", null) + method != null + glob.match(rule.rpc_method, [], method) + jsonrpc_params_match(jsonrpc, rule) +} + +jsonrpc_params_match(jsonrpc, rule) if { + param_rules := object.get(rule, "params", {}) + not jsonrpc_param_mismatch(jsonrpc, param_rules) +} + +jsonrpc_param_mismatch(jsonrpc, param_rules) if { + some key + matcher := param_rules[key] + not jsonrpc_param_key_matches(jsonrpc, key, matcher) +} + +jsonrpc_param_key_matches(jsonrpc, key, matcher) if { + params := object.get(jsonrpc, "params", {}) + value := object.get(params, key, null) + value != null + is_string(value) + query_value_matches(value, matcher) +} + # SQL command matching: "*" matches any; otherwise case-insensitive. command_matches(_, "*") if true diff --git a/crates/openshell-sandbox/src/l7/jsonrpc.rs b/crates/openshell-sandbox/src/l7/jsonrpc.rs index 2dc83c12d..e75ac761f 100644 --- a/crates/openshell-sandbox/src/l7/jsonrpc.rs +++ b/crates/openshell-sandbox/src/l7/jsonrpc.rs @@ -4,6 +4,7 @@ //! JSON-RPC 2.0 over HTTP L7 inspection. use miette::Result; +use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncWrite}; use crate::l7::provider::{L7Provider, L7Request}; @@ -33,6 +34,7 @@ pub(crate) async fn parse_jsonrpc_http_request, + pub params: HashMap, pub error: Option, } @@ -54,21 +56,57 @@ pub fn parse_jsonrpc_body(body: &[u8]) -> JsonRpcRequestInfo { let Ok(value) = serde_json::from_slice::(body) else { return JsonRpcRequestInfo { method: None, + params: HashMap::new(), error: Some("invalid JSON".to_string()), }; }; let Some(method) = value.get("method").and_then(|m| m.as_str()) else { return JsonRpcRequestInfo { method: None, + params: HashMap::new(), error: Some("missing or non-string 'method' field".to_string()), }; }; JsonRpcRequestInfo { method: Some(method.to_string()), + params: value + .get("params") + .map_or_else(HashMap::new, flatten_jsonrpc_params), error: None, } } +fn flatten_jsonrpc_params(value: &serde_json::Value) -> HashMap { + let mut params = HashMap::new(); + flatten_json_value("", value, &mut params); + params +} + +fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap) { + match value { + serde_json::Value::Object(map) => { + for (key, child) in map { + let next = if prefix.is_empty() { + key.clone() + } else { + format!("{prefix}.{key}") + }; + flatten_json_value(&next, child, out); + } + } + serde_json::Value::String(s) if !prefix.is_empty() => { + out.insert(prefix.to_string(), s.clone()); + } + serde_json::Value::Number(n) if !prefix.is_empty() => { + out.insert(prefix.to_string(), n.to_string()); + } + serde_json::Value::Bool(b) if !prefix.is_empty() => { + out.insert(prefix.to_string(), b.to_string()); + } + _ => {} + } +} + #[cfg(test)] mod tests { use super::*; @@ -81,6 +119,20 @@ mod tests { assert!(info.error.is_none()); } + #[test] + fn flattens_object_params_for_policy_matching() { + let body = br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"submit_report","arguments":{"scope":"workspace/main"}}}"#; + let info = parse_jsonrpc_body(body); + assert_eq!( + info.params.get("name").map(String::as_str), + Some("submit_report") + ); + assert_eq!( + info.params.get("arguments.scope").map(String::as_str), + Some("workspace/main") + ); + } + #[test] fn rpc_method_rule_empty_matches_any() { let info = parse_jsonrpc_body(br#"{"jsonrpc":"2.0","id":1,"method":"tools/call"}"#); diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 0df2fe9b8..7d998f1e1 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -1321,6 +1321,7 @@ pub fn evaluate_l7_request( "graphql": request.graphql.clone(), "jsonrpc": request.jsonrpc.as_ref().map(|j| serde_json::json!({ "method": j.method, + "params": j.params, "error": j.error, })), } diff --git a/crates/openshell-sandbox/src/mechanistic_mapper.rs b/crates/openshell-sandbox/src/mechanistic_mapper.rs index 273bdee3b..bbe7b93b8 100644 --- a/crates/openshell-sandbox/src/mechanistic_mapper.rs +++ b/crates/openshell-sandbox/src/mechanistic_mapper.rs @@ -356,6 +356,7 @@ fn build_l7_rules(samples: &HashMap<(String, String), u32>) -> Vec { operation_name: String::new(), fields: Vec::new(), rpc_method: String::new(), + params: HashMap::new(), }), }); } diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs index de6ff5257..ac50340d1 100644 --- a/crates/openshell-sandbox/src/opa.rs +++ b/crates/openshell-sandbox/src/opa.rs @@ -923,6 +923,24 @@ fn resolve_binary_in_container(_policy_path: &str, _entrypoint_pid: u32) -> Opti None } +fn l7_matchers_to_json( + matchers: &std::collections::HashMap, +) -> serde_json::Map { + matchers + .iter() + .map(|(key, matcher)| { + let mut matcher_json = serde_json::json!({}); + if !matcher.glob.is_empty() { + matcher_json["glob"] = matcher.glob.clone().into(); + } + if !matcher.any.is_empty() { + matcher_json["any"] = matcher.any.clone().into(); + } + (key.clone(), matcher_json) + }) + .collect() +} + /// Convert typed proto policy fields to JSON suitable for `engine.add_data_json()`. /// /// The rego rules reference `data.*` directly, so the JSON structure has @@ -1028,29 +1046,18 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St { allow["fields"] = a.fields.clone().into(); } - let query: serde_json::Map = a - .map(|allow| { - allow - .query - .iter() - .map(|(key, matcher)| { - let mut matcher_json = serde_json::json!({}); - if !matcher.glob.is_empty() { - matcher_json["glob"] = - matcher.glob.clone().into(); - } - if !matcher.any.is_empty() { - matcher_json["any"] = - matcher.any.clone().into(); - } - (key.clone(), matcher_json) - }) - .collect() - }) - .unwrap_or_default(); + let query = a.map_or_else(serde_json::Map::new, |allow| { + l7_matchers_to_json(&allow.query) + }); if !query.is_empty() { allow["query"] = query.into(); } + let params = a.map_or_else(serde_json::Map::new, |allow| { + l7_matchers_to_json(&allow.params) + }); + if !params.is_empty() { + allow["params"] = params.into(); + } serde_json::json!({ "allow": allow }) }) .collect(); @@ -1086,23 +1093,17 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St if !d.fields.is_empty() { deny["fields"] = d.fields.clone().into(); } - let query: serde_json::Map = d - .query - .iter() - .map(|(key, matcher)| { - let mut matcher_json = serde_json::json!({}); - if !matcher.glob.is_empty() { - matcher_json["glob"] = matcher.glob.clone().into(); - } - if !matcher.any.is_empty() { - matcher_json["any"] = matcher.any.clone().into(); - } - (key.clone(), matcher_json) - }) - .collect(); + if !d.rpc_method.is_empty() { + deny["rpc_method"] = d.rpc_method.clone().into(); + } + let query = l7_matchers_to_json(&d.query); if !query.is_empty() { deny["query"] = query.into(); } + let params = l7_matchers_to_json(&d.params); + if !params.is_empty() { + deny["params"] = params.into(); + } deny }) .collect(); @@ -1951,6 +1952,16 @@ process: } fn l7_jsonrpc_input(host: &str, port: u16, path: &str, rpc_method: &str) -> serde_json::Value { + l7_jsonrpc_input_with_params(host, port, path, rpc_method, serde_json::json!({})) + } + + fn l7_jsonrpc_input_with_params( + host: &str, + port: u16, + path: &str, + rpc_method: &str, + params: serde_json::Value, + ) -> serde_json::Value { serde_json::json!({ "network": { "host": host, "port": port }, "exec": { @@ -1963,7 +1974,8 @@ process: "path": path, "query_params": {}, "jsonrpc": { - "method": rpc_method + "method": rpc_method, + "params": params } } }) @@ -2516,6 +2528,7 @@ network_policies: operation_name: String::new(), fields: Vec::new(), rpc_method: String::new(), + params: std::collections::HashMap::new(), }), }], ..Default::default() @@ -2587,6 +2600,7 @@ network_policies: operation_name: String::new(), fields: Vec::new(), rpc_method: "initialize".to_string(), + params: std::collections::HashMap::new(), }), }], ..Default::default() @@ -2623,6 +2637,80 @@ network_policies: assert!(!eval_l7(&engine, &deny_input)); } + #[test] + fn l7_jsonrpc_params_rules_filter_tools_call() { + let data = r#" +network_policies: + jsonrpc_params: + name: jsonrpc_params + endpoints: + - host: mcp.params.test + port: 8000 + path: /mcp + protocol: json-rpc + enforcement: enforce + rules: + - allow: + rpc_method: tools/call + params: + name: read_status + - allow: + rpc_method: tools/call + params: + name: submit_report + arguments.scope: workspace/main + deny_rules: + - rpc_method: tools/call + params: + name: blocked_action + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).expect("engine from yaml"); + + let read_status = l7_jsonrpc_input_with_params( + "mcp.params.test", + 8000, + "/mcp", + "tools/call", + serde_json::json!({"name": "read_status"}), + ); + assert!(eval_l7(&engine, &read_status)); + + let submit_report = l7_jsonrpc_input_with_params( + "mcp.params.test", + 8000, + "/mcp", + "tools/call", + serde_json::json!({ + "name": "submit_report", + "arguments.scope": "workspace/main" + }), + ); + assert!(eval_l7(&engine, &submit_report)); + + let blocked_without_args = l7_jsonrpc_input_with_params( + "mcp.params.test", + 8000, + "/mcp", + "tools/call", + serde_json::json!({"name": "blocked_action"}), + ); + assert!(!eval_l7(&engine, &blocked_without_args)); + + let blocked_with_args = l7_jsonrpc_input_with_params( + "mcp.params.test", + 8000, + "/mcp", + "tools/call", + serde_json::json!({ + "name": "blocked_action", + "arguments.reason": "test" + }), + ); + assert!(!eval_l7(&engine, &blocked_with_args)); + } + #[test] fn l7_no_request_on_l4_only_endpoint() { // L4-only endpoint should not match L7 allow_request diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs index 20536ac2b..aa270d017 100644 --- a/crates/openshell-sandbox/src/policy_local.rs +++ b/crates/openshell-sandbox/src/policy_local.rs @@ -1084,6 +1084,7 @@ fn network_endpoint_from_json( operation_name: String::new(), fields: Vec::new(), rpc_method: String::new(), + params: HashMap::new(), }), }) .collect(); @@ -1099,6 +1100,7 @@ fn network_endpoint_from_json( operation_name: String::new(), fields: Vec::new(), rpc_method: String::new(), + params: HashMap::new(), }) .collect(); diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 2e2210f44..ebae4809a 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -8049,6 +8049,8 @@ mod tests { operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + params: HashMap::default(), + rpc_method: String::new(), }), }], }; @@ -8444,6 +8446,8 @@ mod tests { operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + params: HashMap::default(), + rpc_method: String::new(), }), }], }]; @@ -8590,6 +8594,8 @@ mod tests { operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + params: HashMap::default(), + rpc_method: String::new(), }), }], }; diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 9d6ec2824..afe1d3301 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -165,6 +165,9 @@ message L7DenyRule { repeated string fields = 7; // JSON-RPC method name (JSON-RPC): exact name or glob, e.g. "tools/call". string rpc_method = 8; + // JSON-RPC params matcher map. Dot-separated keys select nested params + // fields, e.g. "arguments.scope". + map params = 9; } // An L7 policy rule (allow-only). @@ -193,6 +196,9 @@ message L7Allow { repeated string fields = 7; // JSON-RPC method name (JSON-RPC): exact name or glob, e.g. "tools/call". string rpc_method = 8; + // JSON-RPC params matcher map. Dot-separated keys select nested params + // fields, e.g. "arguments.scope". + map params = 9; } // Query value matcher for one query parameter key. From 18aa4050e3901f76a2e677daa32565470a4b86ee Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 11:59:38 -0700 Subject: [PATCH 07/12] feat(l7): support JSON-RPC batch calls Parse JSON-RPC batch arrays into per-call metadata and evaluate each batch item with the existing method and params policy rules. Deny the whole batch when any call is denied. Signed-off-by: Kris Hicks --- crates/openshell-sandbox/src/l7/jsonrpc.rs | 109 ++++++++++++------- crates/openshell-sandbox/src/l7/relay.rs | 117 +++++++++++++++++++-- 2 files changed, 184 insertions(+), 42 deletions(-) diff --git a/crates/openshell-sandbox/src/l7/jsonrpc.rs b/crates/openshell-sandbox/src/l7/jsonrpc.rs index e75ac761f..81c7c62fa 100644 --- a/crates/openshell-sandbox/src/l7/jsonrpc.rs +++ b/crates/openshell-sandbox/src/l7/jsonrpc.rs @@ -33,19 +33,15 @@ pub(crate) async fn parse_jsonrpc_http_request, - pub params: HashMap, + pub calls: Vec, + pub is_batch: bool, pub error: Option, } -/// Returns true if the parsed request's method matches the given `rpc_method` rule pattern. -/// -/// An empty `rpc_method` pattern matches any method. -pub fn rpc_method_rule_matches(info: &JsonRpcRequestInfo, rpc_method: &str) -> bool { - if rpc_method.is_empty() { - return true; - } - info.method.as_deref() == Some(rpc_method) +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct JsonRpcCallInfo { + pub method: String, + pub params: HashMap, } /// Parse a JSON-RPC 2.0 request body and extract the `method` field. @@ -55,25 +51,60 @@ pub fn rpc_method_rule_matches(info: &JsonRpcRequestInfo, rpc_method: &str) -> b pub fn parse_jsonrpc_body(body: &[u8]) -> JsonRpcRequestInfo { let Ok(value) = serde_json::from_slice::(body) else { return JsonRpcRequestInfo { - method: None, - params: HashMap::new(), + calls: Vec::new(), + is_batch: false, error: Some("invalid JSON".to_string()), }; }; - let Some(method) = value.get("method").and_then(|m| m.as_str()) else { + + if let serde_json::Value::Array(items) = value { + if items.is_empty() { + return JsonRpcRequestInfo { + calls: Vec::new(), + is_batch: true, + error: Some("empty batch".to_string()), + }; + } + let mut calls = Vec::new(); + for item in &items { + let Some(call) = parse_jsonrpc_call(item) else { + return JsonRpcRequestInfo { + calls: Vec::new(), + is_batch: true, + error: Some("batch item missing or non-string 'method' field".to_string()), + }; + }; + calls.push(call); + } + return JsonRpcRequestInfo { + calls, + is_batch: true, + error: None, + }; + } + + let Some(call) = parse_jsonrpc_call(&value) else { return JsonRpcRequestInfo { - method: None, - params: HashMap::new(), + calls: Vec::new(), + is_batch: false, error: Some("missing or non-string 'method' field".to_string()), }; }; JsonRpcRequestInfo { - method: Some(method.to_string()), + calls: vec![call], + is_batch: false, + error: None, + } +} + +fn parse_jsonrpc_call(value: &serde_json::Value) -> Option { + let method = value.get("method").and_then(|m| m.as_str())?; + Some(JsonRpcCallInfo { + method: method.to_string(), params: value .get("params") .map_or_else(HashMap::new, flatten_jsonrpc_params), - error: None, - } + }) } fn flatten_jsonrpc_params(value: &serde_json::Value) -> HashMap { @@ -115,7 +146,12 @@ mod tests { fn parses_method_from_request_body() { let body = br#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#; let info = parse_jsonrpc_body(body); - assert_eq!(info.method.as_deref(), Some("initialize")); + assert_eq!( + info.calls.first().map(|call| call.method.as_str()), + Some("initialize") + ); + assert_eq!(info.calls.len(), 1); + assert!(!info.is_batch); assert!(info.error.is_none()); } @@ -123,31 +159,32 @@ mod tests { fn flattens_object_params_for_policy_matching() { let body = br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"submit_report","arguments":{"scope":"workspace/main"}}}"#; let info = parse_jsonrpc_body(body); + let params = &info.calls.first().expect("single request call").params; assert_eq!( - info.params.get("name").map(String::as_str), + params.get("name").map(String::as_str), Some("submit_report") ); assert_eq!( - info.params.get("arguments.scope").map(String::as_str), + params.get("arguments.scope").map(String::as_str), Some("workspace/main") ); } #[test] - fn rpc_method_rule_empty_matches_any() { - let info = parse_jsonrpc_body(br#"{"jsonrpc":"2.0","id":1,"method":"tools/call"}"#); - assert!(rpc_method_rule_matches(&info, "")); - } - - #[test] - fn rpc_method_rule_matches_exact_method() { - let info = parse_jsonrpc_body(br#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#); - assert!(rpc_method_rule_matches(&info, "initialize")); - } - - #[test] - fn rpc_method_rule_does_not_match_different_method() { - let info = parse_jsonrpc_body(br#"{"jsonrpc":"2.0","id":1,"method":"tools/call"}"#); - assert!(!rpc_method_rule_matches(&info, "initialize")); + fn parses_valid_batch_without_error() { + let body = br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_status"}} + ]"#; + let info = parse_jsonrpc_body(body); + assert!(info.error.is_none()); + assert!(info.is_batch); + assert_eq!(info.calls.len(), 2); + assert_eq!(info.calls[0].method, "tools/list"); + assert_eq!(info.calls[1].method, "tools/call"); + assert_eq!( + info.calls[1].params.get("name").map(String::as_str), + Some("read_status") + ); } } diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index 7d998f1e1..f79dc7851 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -981,7 +981,11 @@ where SeverityId::Informational, ), }; - let rpc_method = jsonrpc_info.method.as_deref().unwrap_or("-"); + let rpc_method = jsonrpc_info + .calls + .first() + .map(|call| call.method.as_str()) + .unwrap_or("-"); let event = HttpActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) .action(action_id) @@ -1295,6 +1299,33 @@ pub fn evaluate_l7_request( engine: &TunnelPolicyEngine, ctx: &L7EvalContext, request: &L7RequestInfo, +) -> Result<(bool, String)> { + if let Some(jsonrpc) = &request.jsonrpc + && jsonrpc.is_batch + && !jsonrpc.calls.is_empty() + { + for call in &jsonrpc.calls { + let mut item_request = request.clone(); + item_request.jsonrpc = Some(crate::l7::jsonrpc::JsonRpcRequestInfo { + calls: vec![call.clone()], + is_batch: false, + error: None, + }); + let (allowed, reason) = evaluate_l7_request_once(engine, ctx, &item_request)?; + if !allowed { + return Ok((false, reason)); + } + } + return Ok((true, String::new())); + } + + evaluate_l7_request_once(engine, ctx, request) +} + +fn evaluate_l7_request_once( + engine: &TunnelPolicyEngine, + ctx: &L7EvalContext, + request: &L7RequestInfo, ) -> Result<(bool, String)> { if engine.is_stale() { return Err(miette!( @@ -1319,11 +1350,14 @@ pub fn evaluate_l7_request( "path": request.target, "query_params": request.query_params.clone(), "graphql": request.graphql.clone(), - "jsonrpc": request.jsonrpc.as_ref().map(|j| serde_json::json!({ - "method": j.method, - "params": j.params, - "error": j.error, - })), + "jsonrpc": request.jsonrpc.as_ref().map(|j| { + let call = if j.is_batch { None } else { j.calls.first() }; + serde_json::json!({ + "method": call.map(|call| call.method.as_str()), + "params": call.map(|call| call.params.clone()).unwrap_or_default(), + "error": j.error, + }) + }), } }); @@ -1966,6 +2000,77 @@ network_policies: assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); } + #[test] + fn jsonrpc_batch_evaluates_each_call() { + let data = r#" +network_policies: + jsonrpc_api: + name: jsonrpc_api + endpoints: + - host: api.example.test + port: 443 + protocol: json-rpc + enforcement: enforce + rules: + - allow: + method: POST + path: "/mcp" + rpc_method: "tools/list" + - allow: + method: POST + path: "/mcp" + rpc_method: "tools/call" + params: + name: read_status + deny_rules: + - rpc_method: "tools/call" + params: + name: blocked_action + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "jsonrpc_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let mut request = L7RequestInfo { + action: "POST".into(), + target: "/mcp".into(), + query_params: std::collections::HashMap::new(), + graphql: None, + jsonrpc: Some(crate::l7::jsonrpc::parse_jsonrpc_body( + br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_status"}} + ]"#, + )), + }; + + let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + assert!(allowed, "{reason}"); + + request.jsonrpc = Some(crate::l7::jsonrpc::parse_jsonrpc_body( + br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"blocked_action"}} + ]"#, + )); + let (allowed, _) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + assert!(!allowed); + } + #[tokio::test] async fn route_selected_websocket_upgrade_rejects_invalid_accept_without_forwarding_101() { let data = r#" From 2df897c627f50ca761c51df3784735553b73c98e Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 14:02:40 -0700 Subject: [PATCH 08/12] fix(l7): redact JSON-RPC params in logs Log JSON-RPC endpoint, RPC methods, params SHA-256 digest, and policy version without recording raw params. Use when no params are present. Signed-off-by: Kris Hicks --- crates/openshell-sandbox/src/l7/jsonrpc.rs | 78 +++++++ crates/openshell-sandbox/src/l7/relay.rs | 233 +++++++++++++++++++-- 2 files changed, 293 insertions(+), 18 deletions(-) diff --git a/crates/openshell-sandbox/src/l7/jsonrpc.rs b/crates/openshell-sandbox/src/l7/jsonrpc.rs index 81c7c62fa..563767200 100644 --- a/crates/openshell-sandbox/src/l7/jsonrpc.rs +++ b/crates/openshell-sandbox/src/l7/jsonrpc.rs @@ -4,6 +4,8 @@ //! JSON-RPC 2.0 over HTTP L7 inspection. use miette::Result; +use sha2::{Digest, Sha256}; +use std::collections::BTreeMap; use std::collections::HashMap; use tokio::io::{AsyncRead, AsyncWrite}; @@ -44,6 +46,27 @@ pub struct JsonRpcCallInfo { pub params: HashMap, } +impl JsonRpcRequestInfo { + pub(crate) fn params_sha256(&self) -> Option { + if self.is_batch { + if self.calls.is_empty() || self.calls.iter().all(|call| call.params.is_empty()) { + return None; + } + let canonical_params = self + .calls + .iter() + .map(|call| canonical_params_map(&call.params)) + .collect::>(); + return Some(sha256_json(&canonical_params)); + } + + let call = self.calls.first()?; + if call.params.is_empty() { + return None; + } + Some(sha256_json(&canonical_params_map(&call.params))) + } +} /// Parse a JSON-RPC 2.0 request body and extract the `method` field. /// /// Returns an info struct with `method` set on success, or `error` set if the @@ -113,6 +136,18 @@ fn flatten_jsonrpc_params(value: &serde_json::Value) -> HashMap params } +fn canonical_params_map(params: &HashMap) -> BTreeMap { + params + .iter() + .map(|(key, value)| (key.clone(), value.clone())) + .collect() +} + +fn sha256_json(value: &impl serde::Serialize) -> String { + let encoded = serde_json::to_vec(value).expect("canonical JSON-RPC params should serialize"); + hex::encode(Sha256::digest(&encoded)) +} + fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap) { match value { serde_json::Value::Object(map) => { @@ -187,4 +222,47 @@ mod tests { Some("read_status") ); } + + #[test] + fn params_digest_is_canonical_and_redacted() { + let first = parse_jsonrpc_body( + br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"submit_report","arguments":{"scope":"workspace/main"}}}"#, + ); + let reordered = parse_jsonrpc_body( + br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"arguments":{"scope":"workspace/main"},"name":"submit_report"}}"#, + ); + let changed = parse_jsonrpc_body( + br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"submit_report","arguments":{"scope":"workspace/other"}}}"#, + ); + + let digest = first.params_sha256().expect("params digest"); + assert_eq!(Some(digest.as_str()), reordered.params_sha256().as_deref()); + assert_ne!(Some(digest.as_str()), changed.params_sha256().as_deref()); + assert_eq!(digest.len(), 64); + assert!(digest.chars().all(|c| c.is_ascii_hexdigit())); + assert!(!digest.contains("workspace/main")); + assert!(!digest.contains("submit_report")); + } + + #[test] + fn batch_params_digest_covers_call_params_without_raw_values() { + let batch = parse_jsonrpc_body( + br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"blocked_action"}} + ]"#, + ); + let empty_batch = parse_jsonrpc_body( + br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"initialize"} + ]"#, + ); + + let digest = batch.params_sha256().expect("batch params digest"); + assert_eq!(digest.len(), 64); + assert!(digest.chars().all(|c| c.is_ascii_hexdigit())); + assert!(!digest.contains("blocked_action")); + assert!(empty_batch.params_sha256().is_none()); + } } diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index f79dc7851..dfc2ee389 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -955,10 +955,12 @@ where .as_deref() .map(|e| format!("JSON-RPC request rejected: {e}")); let force_deny = parse_error_reason.is_some(); - let (allowed, reason) = if let Some(reason) = parse_error_reason { - (false, reason) + let (allowed, reason, jsonrpc_log_info) = if let Some(reason) = parse_error_reason { + (false, reason, jsonrpc_info.clone()) } else { - evaluate_l7_request(engine, ctx, &request_info)? + let evaluation = + evaluate_jsonrpc_l7_request_for_log(engine, ctx, &request_info, &jsonrpc_info)?; + (evaluation.allowed, evaluation.reason, evaluation.log_info) }; if close_if_stale(engine.generation_guard(), ctx) { @@ -981,11 +983,11 @@ where SeverityId::Informational, ), }; - let rpc_method = jsonrpc_info - .calls - .first() - .map(|call| call.method.as_str()) - .unwrap_or("-"); + let endpoint = format!("{}:{}{}", ctx.host, ctx.port, redacted_target); + let params_sha256 = jsonrpc_log_info + .params_sha256() + .unwrap_or_else(|| "".to_string()); + let policy_version = engine.captured_generation(); let event = HttpActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) .action(action_id) @@ -997,9 +999,14 @@ where )) .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) .firewall_rule(&ctx.policy_name, "l7-jsonrpc") - .message(format!( - "JSONRPC_L7_REQUEST {decision_str} {} {}:{}{} rpc_method={rpc_method} reason={}", - request_info.action, ctx.host, ctx.port, redacted_target, reason, + .message(jsonrpc_log_message( + decision_str, + &request_info.action, + &endpoint, + &jsonrpc_log_info, + ¶ms_sha256, + policy_version, + &reason, )) .build(); ocsf_emit!(event); @@ -1274,6 +1281,38 @@ fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String format!("graphql_ops={}", ops.join(";")) } +fn jsonrpc_log_message( + decision: &str, + http_method: &str, + endpoint: &str, + info: &crate::l7::jsonrpc::JsonRpcRequestInfo, + params_sha256: &str, + policy_version: u64, + reason: &str, +) -> String { + let rpc_methods = jsonrpc_methods_for_log(info); + format!( + "JSONRPC_L7_REQUEST decision={decision} http_method={http_method} endpoint={endpoint} rpc_methods={rpc_methods} params_sha256={params_sha256} policy_version={policy_version} reason={reason}" + ) +} + +fn jsonrpc_methods_for_log(info: &crate::l7::jsonrpc::JsonRpcRequestInfo) -> String { + if info.calls.is_empty() { + return "-".to_string(); + } + info.calls + .iter() + .map(|call| call.method.as_str()) + .collect::>() + .join(",") +} + +struct JsonRpcEvaluation { + allowed: bool, + reason: String, + log_info: crate::l7::jsonrpc::JsonRpcRequestInfo, +} + /// Check if a miette error represents a benign connection close. /// /// TLS handshake EOF, missing `close_notify`, connection resets, and broken @@ -1305,12 +1344,7 @@ pub fn evaluate_l7_request( && !jsonrpc.calls.is_empty() { for call in &jsonrpc.calls { - let mut item_request = request.clone(); - item_request.jsonrpc = Some(crate::l7::jsonrpc::JsonRpcRequestInfo { - calls: vec![call.clone()], - is_batch: false, - error: None, - }); + let item_request = jsonrpc_request_for_call(request, call); let (allowed, reason) = evaluate_l7_request_once(engine, ctx, &item_request)?; if !allowed { return Ok((false, reason)); @@ -1322,6 +1356,66 @@ pub fn evaluate_l7_request( evaluate_l7_request_once(engine, ctx, request) } +fn evaluate_jsonrpc_l7_request_for_log( + engine: &TunnelPolicyEngine, + ctx: &L7EvalContext, + request: &L7RequestInfo, + jsonrpc: &crate::l7::jsonrpc::JsonRpcRequestInfo, +) -> Result { + if jsonrpc.is_batch && !jsonrpc.calls.is_empty() { + let mut denied_calls = Vec::new(); + let mut first_denied_reason = None; + for call in &jsonrpc.calls { + let item_request = jsonrpc_request_for_call(request, call); + let (allowed, reason) = evaluate_l7_request_once(engine, ctx, &item_request)?; + if !allowed { + if first_denied_reason.is_none() { + first_denied_reason = Some(reason); + } + denied_calls.push(call.clone()); + } + } + + if denied_calls.is_empty() { + return Ok(JsonRpcEvaluation { + allowed: true, + reason: String::new(), + log_info: jsonrpc.clone(), + }); + } + + return Ok(JsonRpcEvaluation { + allowed: false, + reason: first_denied_reason.unwrap_or_else(|| "request denied by policy".to_string()), + log_info: crate::l7::jsonrpc::JsonRpcRequestInfo { + calls: denied_calls, + is_batch: true, + error: None, + }, + }); + } + + let (allowed, reason) = evaluate_l7_request_once(engine, ctx, request)?; + Ok(JsonRpcEvaluation { + allowed, + reason, + log_info: jsonrpc.clone(), + }) +} + +fn jsonrpc_request_for_call( + request: &L7RequestInfo, + call: &crate::l7::jsonrpc::JsonRpcCallInfo, +) -> L7RequestInfo { + let mut item_request = request.clone(); + item_request.jsonrpc = Some(crate::l7::jsonrpc::JsonRpcRequestInfo { + calls: vec![call.clone()], + is_batch: false, + error: None, + }); + item_request +} + fn evaluate_l7_request_once( engine: &TunnelPolicyEngine, ctx: &L7EvalContext, @@ -2026,6 +2120,7 @@ network_policies: - rpc_method: "tools/call" params: name: blocked_action + - rpc_method: "tools/delete" binaries: - { path: /usr/bin/node } "#; @@ -2064,11 +2159,113 @@ network_policies: request.jsonrpc = Some(crate::l7::jsonrpc::parse_jsonrpc_body( br#"[ {"jsonrpc":"2.0","id":1,"method":"tools/list"}, - {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"blocked_action"}} + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"blocked_action"}}, + {"jsonrpc":"2.0","id":3,"method":"tools/delete","params":{"name":"purge_cache"}} ]"#, )); let (allowed, _) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); assert!(!allowed); + + let jsonrpc = request.jsonrpc.as_ref().expect("jsonrpc request"); + let evaluation = + evaluate_jsonrpc_l7_request_for_log(&tunnel_engine, &ctx, &request, jsonrpc).unwrap(); + assert!(!evaluation.allowed); + assert!(evaluation.log_info.is_batch); + assert_eq!( + jsonrpc_methods_for_log(&evaluation.log_info), + "tools/call,tools/delete" + ); + + let full_params_sha256 = jsonrpc.params_sha256().expect("full batch params digest"); + let log_params_sha256 = evaluation + .log_info + .params_sha256() + .expect("logged batch params digest"); + assert_ne!(full_params_sha256, log_params_sha256); + let message = jsonrpc_log_message( + "deny", + "POST", + "api.example.test:443/mcp", + &evaluation.log_info, + &log_params_sha256, + 42, + &evaluation.reason, + ); + assert!(message.contains("rpc_methods=tools/call,tools/delete")); + assert!(message.contains("params_sha256=")); + assert!(!message.contains("params_sha256=sha256:")); + assert!(message.contains("policy_version=42")); + assert!(!message.contains("tools/list")); + assert!(!message.contains("blocked_action")); + assert!(!message.contains("purge_cache")); + } + + #[test] + fn jsonrpc_log_records_digest_not_args() { + let info = crate::l7::jsonrpc::parse_jsonrpc_body( + br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"delete_resource","arguments":{"scope":"secret-scope"}}}"#, + ); + let params_sha256 = info.params_sha256().expect("params digest"); + let message = jsonrpc_log_message( + "deny", + "POST", + "mcp.example.com:443/mcp", + &info, + ¶ms_sha256, + 42, + "request denied by policy", + ); + + assert!(message.contains("endpoint=mcp.example.com:443/mcp")); + assert!(message.contains("rpc_methods=tools/call")); + assert!(message.contains("params_sha256=")); + assert!(!message.contains("params_sha256=sha256:")); + assert!(message.contains("policy_version=42")); + assert!(!message.contains("delete_resource")); + assert!(!message.contains("secret-scope")); + + let batch = crate::l7::jsonrpc::parse_jsonrpc_body( + br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"delete_resource"}} + ]"#, + ); + let batch_params_sha256 = batch.params_sha256().expect("batch params digest"); + let batch_message = jsonrpc_log_message( + "allow", + "POST", + "mcp.example.com:443/mcp", + &batch, + &batch_params_sha256, + 43, + "", + ); + + assert!(batch_message.starts_with("JSONRPC_L7_REQUEST ")); + assert!(batch_message.contains("rpc_methods=tools/list,tools/call")); + assert!(batch_message.contains("params_sha256=")); + assert!(!batch_message.contains("params_sha256=sha256:")); + assert!(batch_message.contains("policy_version=43")); + assert!(!batch_message.contains("rpc_method=")); + assert!(!batch_message.contains("delete_resource")); + + let no_params = crate::l7::jsonrpc::parse_jsonrpc_body( + br#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#, + ); + let no_params_sha256 = no_params + .params_sha256() + .unwrap_or_else(|| "".to_string()); + let no_params_message = jsonrpc_log_message( + "allow", + "POST", + "mcp.example.com:443/mcp", + &no_params, + &no_params_sha256, + 44, + "", + ); + assert!(no_params_message.contains("rpc_methods=initialize")); + assert!(no_params_message.contains("params_sha256=")); } #[tokio::test] From f1d3c1525302dac857221a183b7efad985ea1b78 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Wed, 10 Jun 2026 13:06:04 -0700 Subject: [PATCH 09/12] docs(policy): document JSON-RPC L7 rules Document JSON-RPC endpoint configuration, rpc_method and params matchers, batch denial behavior, current directionality limits, matcher scope, and the current policy update CLI limitation. Signed-off-by: Kris Hicks --- architecture/sandbox.md | 6 +++ docs/reference/policy-schema.mdx | 78 +++++++++++++++++++++++++++++--- docs/sandboxes/policies.mdx | 54 ++++++++++++++++++++-- 3 files changed, 128 insertions(+), 10 deletions(-) diff --git a/architecture/sandbox.md b/architecture/sandbox.md index e60b727a5..0e2d1c559 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -49,6 +49,12 @@ paths, such as proxy support files or GPU device paths when a GPU is present. All ordinary agent egress is routed through the sandbox proxy. The proxy identifies the calling binary, checks trust-on-first-use binary identity, rejects unsafe internal destinations, and evaluates the active policy. +For inspected HTTP traffic, the proxy can enforce REST method/path rules, +WebSocket upgrade and text-message rules, GraphQL operation rules, and +JSON-RPC method and params rules on sandbox-to-server request bodies. JSON-RPC +request inspection buffers up to the endpoint `json_rpc.max_body_bytes` limit. +JSON-RPC responses and server-to-client MCP messages on response or SSE streams +are relayed but are not currently parsed for policy enforcement. `https://inference.local` is special. It bypasses OPA network policy and is handled by the inference interception path: diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 59f72c9f7..fa540e4dd 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -155,7 +155,7 @@ Each endpoint defines a reachable destination and optional inspection rules. | `host` | string | Yes | Hostname or IP address. Supports a `*` wildcard inside the first DNS label only: `*.example.com`, `**.example.com`, and intra-label patterns like `*-aiplatform.googleapis.com` are accepted; bare `*`/`**`, TLD wildcards (`*.com`), and wildcards outside the first label are rejected at load time. | | `port` | integer | Yes | TCP port number. | | `path` | string | No | Optional HTTP path glob used to select between L7 endpoints that share the same host and port. Empty means all paths. Use this when REST and GraphQL live under the same host, such as `/repos/**` and `/graphql`. | -| `protocol` | string | No | Set to `rest` for HTTP method/path inspection, `websocket` for RFC 6455 upgrade and client text-message inspection, or `graphql` for GraphQL-over-HTTP operation inspection. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket traffic. Omit for TCP passthrough. | +| `protocol` | string | No | Set to `rest` for HTTP method/path inspection, `websocket` for RFC 6455 upgrade and client text-message inspection, `graphql` for GraphQL-over-HTTP operation inspection, or `json-rpc` for sandbox-to-server JSON-RPC-over-HTTP method and params inspection. WebSocket endpoints can also use GraphQL operation rules for GraphQL-over-WebSocket traffic. Omit for TCP passthrough. | | `tls` | string | No | TLS handling mode. The proxy auto-detects TLS by peeking the first bytes of each connection and terminates it for inspected HTTPS traffic, so this field is optional in most cases. Set to `skip` to disable auto-detection for edge cases such as client-certificate mTLS or non-standard protocols. The values `terminate` and `passthrough` are deprecated and log a warning; they are still accepted for backward compatibility but have no effect on behavior. | | `enforcement` | string | No | `enforce` actively blocks disallowed requests. `audit` logs violations but allows traffic through. | | `access` | string | No | Access preset. One of `read-only`, `read-write`, or `full`. Mutually exclusive with `rules`. | @@ -168,6 +168,7 @@ Each endpoint defines a reachable destination and optional inspection rules. | `persisted_queries` | string | No | GraphQL hash-only behavior for `protocol: graphql` and GraphQL-over-WebSocket operation policy. Default is `deny`; use `allow_registered` only with `graphql_persisted_queries`. | | `graphql_persisted_queries` | map | No | Trusted GraphQL persisted-query registry keyed by hash or saved-query ID. Values contain `operation_type`, optional `operation_name`, and optional root `fields`. | | `graphql_max_body_bytes` | integer | No | Maximum GraphQL-over-HTTP request body bytes buffered for inspection. Defaults to `65536`. | +| `json_rpc` | object | No | JSON-RPC endpoint options. For `protocol: json-rpc`, `json_rpc.max_body_bytes` sets the maximum JSON-RPC-over-HTTP request body bytes buffered for inspection. Defaults to `65536`. | Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placeholder form and whole-token provider-shaped aliases such as `provider-OPENSHELL-RESOLVE-ENV-API_TOKEN` when the referenced environment key exists in the configured provider credentials. @@ -175,11 +176,13 @@ Credential rewrite recognizes the canonical `openshell:resolve:env:KEY` placehol The `access` field accepts one of the following values: -| Value | REST expansion | WebSocket expansion | GraphQL expansion | -|---|---|---|---| -| `full` | All methods and paths. | WebSocket upgrade and all inspected client text-message paths. | All operation types. | -| `read-only` | `GET`, `HEAD`, `OPTIONS`. | WebSocket upgrade handshake only. | `query` operations. | -| `read-write` | `GET`, `HEAD`, `OPTIONS`, `POST`, `PUT`, `PATCH`. | WebSocket upgrade handshake and client text messages. | `query` and `mutation` operations. | +| Value | REST expansion | WebSocket expansion | GraphQL expansion | JSON-RPC behavior | +|---|---|---|---|---| +| `full` | All methods and paths. | WebSocket upgrade and all inspected client text-message paths. | All operation types. | Allows matching HTTP requests without constraining JSON-RPC methods. | +| `read-only` | `GET`, `HEAD`, `OPTIONS`. | WebSocket upgrade handshake only. | `query` operations. | Expands to HTTP read methods and does not allow typical JSON-RPC `POST` calls. | +| `read-write` | `GET`, `HEAD`, `OPTIONS`, `POST`, `PUT`, `PATCH`. | WebSocket upgrade handshake and client text messages. | `query` and `mutation` operations. | Allows matching HTTP write requests without constraining JSON-RPC methods. | + +For JSON-RPC endpoints, prefer explicit `rules` with `rpc_method` and optional `params` when you need method-level control. #### Allow Rule Objects @@ -274,6 +277,42 @@ rules: Do not combine `method`, `path`, or `query` with `operation_type`, `operation_name`, or `fields` inside the same WebSocket rule. When a WebSocket endpoint has GraphQL operation policy, use GraphQL rules for client messages instead of a raw `WEBSOCKET_TEXT` allow rule. +##### JSON-RPC Allow Rule (`protocol: json-rpc`) + +JSON-RPC allow rules match sandbox-to-server JSON-RPC-over-HTTP request objects by RPC method and optional params. They apply to single JSON-RPC requests and batch requests. For a batch, OpenShell evaluates each call independently. JSON-RPC responses and server-to-client messages on response bodies or MCP SSE streams are relayed but are not currently parsed for policy enforcement. + +| Field | Type | Required | Description | +|---|---|---|---| +| `rpc_method` | string | Yes | JSON-RPC method name or glob, such as `initialize`, `tools/list`, or `tools/*`. | +| `params` | map | No | Params matchers keyed by flattened object-param path. Use dot-separated keys for nested object params, such as `arguments.scope`. Matcher value can be a glob string or an object with `any`. Strings, numbers, and booleans are converted to strings; arrays, `null`, and non-object top-level params do not produce matcher keys. | + +Example JSON-RPC allow rules: + +```yaml showLineNumbers={false} +endpoints: + - host: mcp.example.com + port: 443 + path: /mcp + protocol: json-rpc + enforcement: enforce + json_rpc: + max_body_bytes: 131072 + rules: + - allow: + rpc_method: initialize + - allow: + rpc_method: tools/list + - allow: + rpc_method: tools/call + params: + name: read_status + - allow: + rpc_method: tools/call + params: + name: submit_report + arguments.scope: workspace/main +``` + #### Deny Rule Objects Blocks specific operations on endpoints that otherwise have broad access. Deny rules are evaluated after allow rules and take precedence: if a request matches any deny rule, it is blocked regardless of what the allow rules or access preset permit. @@ -356,6 +395,33 @@ endpoints: operation_name: Admin* ``` +##### JSON-RPC Deny Rule (`protocol: json-rpc`) + +JSON-RPC deny rules use the same field names as JSON-RPC allow rules, but they appear directly under each `deny_rules` entry instead of under an `allow` wrapper. Deny rules take precedence over allow rules. In a batch request, one denied call denies the full batch. + +| Field | Type | Required | Description | +|---|---|---|---| +| `rpc_method` | string | Yes | JSON-RPC method name or glob to deny. | +| `params` | map | No | Params matchers keyed by flattened object-param path. Omit to deny every call matching `rpc_method`. Strings, numbers, and booleans are converted to strings; arrays, `null`, and non-object top-level params do not produce matcher keys. | + +Example JSON-RPC deny rules: + +```yaml showLineNumbers={false} +endpoints: + - host: mcp.example.com + port: 443 + path: /mcp + protocol: json-rpc + enforcement: enforce + rules: + - allow: + rpc_method: tools/* + deny_rules: + - rpc_method: tools/call + params: + name: delete_resource +``` + ### Binary Object Identifies an executable that is permitted to use the associated endpoints. diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 406ed12b8..bae6fa279 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -148,7 +148,7 @@ The following steps outline the hot-reload policy update workflow. To inspect a stored sandbox-authored revision instead of the current effective policy, pass `--rev `. -5. Edit the YAML: add or adjust `network_policies` entries, binaries, `access`, or `rules`. +5. Edit the YAML: add or adjust `network_policies` entries, binaries, `access`, `rules`, or protocol-specific matchers such as GraphQL operation fields and JSON-RPC `rpc_method` / `params` rules. 6. Push the updated policy when you need a full replacement. Exit codes: 0 = loaded, 1 = validation failed, 124 = timeout. @@ -173,7 +173,7 @@ Use `openshell policy update` when you want to merge network policy changes into - remove one endpoint or one named rule without rewriting the rest of the file. - preview a merged result locally with `--dry-run` before you send it to the gateway. -Use `openshell policy set` instead when you want to replace the full policy, update static sections, or make broader edits that are easier to express in YAML. +Use `openshell policy set` instead when you want to replace the full policy, update static sections, or make broader edits that are easier to express in YAML. Use full YAML for GraphQL and JSON-RPC rule shapes. ### Update Commands @@ -210,6 +210,7 @@ This is the practical difference: Current constraints: - `--add-allow` and `--add-deny` work on `protocol: rest` and `protocol: websocket` endpoints. +- GraphQL and JSON-RPC fine-grained rules require full policy YAML applied with `openshell policy set`. - `--add-deny` requires the endpoint to already have an allow base, either an `access` preset or explicit allow `rules`. - `protocol: sql` is not a practical incremental workflow today. OpenShell does not do full SQL parsing, and SQL enforcement is not meaningfully supported yet. @@ -228,7 +229,7 @@ Each segment has a fixed meaning: | `host` | Yes | Destination hostname. | | `port` | Yes | Destination port, `1` through `65535`. | | `access` | No | Access preset for L7 endpoints: `read-only`, `read-write`, or `full`. Incremental updates expand presets into protocol-specific method/path rules for REST and WebSocket endpoints. | -| `protocol` | No | L7 inspection mode: `rest`, `websocket`, or `sql`. `sql` is audit-only and not a recommended workflow today. | +| `protocol` | No | L7 inspection mode accepted by `openshell policy update`: `rest`, `websocket`, or `sql`. `sql` is audit-only and not a recommended workflow today. Full policy YAML also supports `graphql` and `json-rpc`. | | `enforcement` | No | Enforcement mode for inspected traffic: `enforce` or `audit`. | | `options` | No | Comma-separated endpoint options. Use `websocket-credential-rewrite` with `protocol: websocket` or REST compatibility endpoints that perform a WebSocket upgrade. Use `request-body-credential-rewrite` only with `protocol: rest`. | @@ -548,7 +549,7 @@ For an end-to-end walkthrough that combines this policy with a GitHub credential - { path: /usr/bin/gh } ``` -Endpoints with `protocol: rest` enable HTTP request inspection and can opt in to supported text request body credential rewrite. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. WebSocket endpoints can also classify GraphQL-over-WebSocket operation messages with the same operation rules used by GraphQL-over-HTTP. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. +Endpoints with `protocol: rest` enable HTTP request inspection and can opt in to supported text request body credential rewrite. Endpoints with `protocol: websocket` validate WebSocket upgrades and inspect client text messages on the upgraded request path. WebSocket endpoints can also classify GraphQL-over-WebSocket operation messages with the same operation rules used by GraphQL-over-HTTP. Endpoints with `protocol: graphql` parse GraphQL-over-HTTP payloads before evaluating rules. Endpoints with `protocol: json-rpc` parse JSON-RPC-over-HTTP request bodies and evaluate `rpc_method` and optional params rules. The endpoint-level `path` field lets these protocols share `api.github.com:443` without treating GraphQL payloads as plain REST `POST /graphql` requests. @@ -579,6 +580,51 @@ REST rules can also constrain query parameter values: `query` matchers are case-sensitive and run on decoded values. If a request has duplicate keys (for example, `tag=a&tag=b`), every value for that key must match the configured glob(s). +### JSON-RPC matching + +JSON-RPC endpoints use `protocol: json-rpc`. The proxy parses sandbox-to-server JSON-RPC-over-HTTP request bodies, evaluates the `method` field against `rpc_method`, and can match object params through dot-separated `params` keys. + +JSON-RPC policy enforcement is directional. It applies to HTTP request bodies sent by the sandboxed process to the configured endpoint. JSON-RPC responses and server-to-client messages carried on response bodies or MCP SSE streams are relayed but are not currently parsed for policy enforcement. + +JSON-RPC endpoint policies currently require full policy YAML applied with `openshell policy set`; the incremental `openshell policy update --add-endpoint` parser does not accept `json-rpc` as a protocol. + +```yaml showLineNumbers={false} + mcp_server: + name: mcp_server + endpoints: + - host: mcp.example.com + port: 443 + path: /mcp + protocol: json-rpc + enforcement: enforce + json_rpc: + max_body_bytes: 131072 + rules: + - allow: + rpc_method: initialize + - allow: + rpc_method: tools/list + - allow: + rpc_method: tools/call + params: + name: read_status + - allow: + rpc_method: tools/call + params: + name: submit_report + arguments.scope: workspace/main + deny_rules: + - rpc_method: tools/call + params: + name: delete_resource + binaries: + - { path: /usr/bin/python3 } +``` + +`json_rpc.max_body_bytes` controls how many JSON-RPC-over-HTTP request body bytes OpenShell buffers for inspection. It defaults to `65536`. + +`params` matchers are case-sensitive and use the same string glob or `{ any: [...] }` matcher syntax as REST query parameters. They match scalar leaf values from object params: strings, numbers, and booleans are converted to strings, and nested JSON object params are flattened with dot-separated keys before matching. Arrays, `null`, and non-object top-level params do not produce matcher keys. This is useful for controls such as matching MCP `tools/call` by `params.name`, but it is not a complete MCP payload policy for rich nested content. For batch requests, OpenShell evaluates each JSON-RPC call independently and denies the whole batch if any call is denied. + ### GraphQL matching GraphQL endpoints use `protocol: graphql`. The proxy parses GraphQL-over-HTTP `GET` and `POST` requests, classifies each operation, and evaluates rules against the operation type, optional operation name, and selected root fields. From b25c1a5e172fe963d6af8fd3e44fd98ed75bd913 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Mon, 15 Jun 2026 10:36:00 -0700 Subject: [PATCH 10/12] fix(sandbox): fail closed on ambiguous JSON-RPC requests Reject ambiguous JSON-RPC params before policy evaluation by refusing literal dotted object keys and flattened selector collisions. Force-deny JSON-RPC parse errors in the forward-proxy path so broad REST-style access presets cannot bypass malformed JSON-RPC bodies. Require JSON-RPC 2.0 request objects, use JSON-RPC-specific forward audit logs with RPC methods, params digest, and policy generation, and reject unsupported json_rpc YAML knobs instead of accepting unused fields. Signed-off-by: Kris Hicks --- architecture/sandbox.md | 2 + crates/openshell-policy/src/lib.rs | 35 +++-- crates/openshell-sandbox/src/l7/jsonrpc.rs | 156 +++++++++++++++++---- crates/openshell-sandbox/src/l7/relay.rs | 4 +- crates/openshell-sandbox/src/proxy.rs | 88 +++++++++--- docs/reference/policy-schema.mdx | 2 +- docs/sandboxes/policies.mdx | 2 +- 7 files changed, 233 insertions(+), 56 deletions(-) diff --git a/architecture/sandbox.md b/architecture/sandbox.md index 0e2d1c559..eb78eb6ad 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -53,6 +53,8 @@ For inspected HTTP traffic, the proxy can enforce REST method/path rules, WebSocket upgrade and text-message rules, GraphQL operation rules, and JSON-RPC method and params rules on sandbox-to-server request bodies. JSON-RPC request inspection buffers up to the endpoint `json_rpc.max_body_bytes` limit. +Literal dotted keys in JSON-RPC params are rejected before policy evaluation so +they cannot be confused with flattened nested selector paths. JSON-RPC responses and server-to-client MCP messages on response or SSE streams are relayed but are not currently parsed for policy enforcement. diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 41e3e38ed..7a380eb9d 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -156,18 +156,10 @@ fn is_zero_u32(v: &u32) -> bool { struct JsonRpcConfigDef { #[serde(default, skip_serializing_if = "is_zero_u32")] max_body_bytes: u32, - #[serde(default, skip_serializing_if = "String::is_empty")] - on_parse_error: String, - #[serde(default, skip_serializing_if = "String::is_empty")] - batch_policy: String, } fn json_rpc_config_from_proto(max_body_bytes: u32) -> Option { - (max_body_bytes > 0).then_some(JsonRpcConfigDef { - max_body_bytes, - on_parse_error: String::new(), - batch_policy: String::new(), - }) + (max_body_bytes > 0).then_some(JsonRpcConfigDef { max_body_bytes }) } #[derive(Debug, Serialize, Deserialize)] @@ -1777,6 +1769,31 @@ network_policies: assert_eq!(ep.json_rpc_max_body_bytes, 131_072); } + #[test] + fn parse_rejects_unsupported_json_rpc_config_fields() { + let yaml = r" +version: 1 +network_policies: + mcp: + endpoints: + - host: mcp.example.com + port: 443 + protocol: json-rpc + json_rpc: + max_body_bytes: 131072 + on_parse_error: deny + batch_policy: all + access: full + binaries: + - path: /usr/bin/curl +"; + + assert!( + parse_sandbox_policy(yaml).is_err(), + "unsupported json_rpc fields must not be silently accepted" + ); + } + #[test] fn round_trip_preserves_websocket_credential_rewrite() { let yaml = r" diff --git a/crates/openshell-sandbox/src/l7/jsonrpc.rs b/crates/openshell-sandbox/src/l7/jsonrpc.rs index 563767200..a27297811 100644 --- a/crates/openshell-sandbox/src/l7/jsonrpc.rs +++ b/crates/openshell-sandbox/src/l7/jsonrpc.rs @@ -90,12 +90,15 @@ pub fn parse_jsonrpc_body(body: &[u8]) -> JsonRpcRequestInfo { } let mut calls = Vec::new(); for item in &items { - let Some(call) = parse_jsonrpc_call(item) else { - return JsonRpcRequestInfo { - calls: Vec::new(), - is_batch: true, - error: Some("batch item missing or non-string 'method' field".to_string()), - }; + let call = match parse_jsonrpc_call(item) { + Ok(call) => call, + Err(error) => { + return JsonRpcRequestInfo { + calls: Vec::new(), + is_batch: true, + error: Some(format!("batch item invalid: {error}")), + }; + } }; calls.push(call); } @@ -106,12 +109,15 @@ pub fn parse_jsonrpc_body(body: &[u8]) -> JsonRpcRequestInfo { }; } - let Some(call) = parse_jsonrpc_call(&value) else { - return JsonRpcRequestInfo { - calls: Vec::new(), - is_batch: false, - error: Some("missing or non-string 'method' field".to_string()), - }; + let call = match parse_jsonrpc_call(&value) { + Ok(call) => call, + Err(error) => { + return JsonRpcRequestInfo { + calls: Vec::new(), + is_batch: false, + error: Some(error), + }; + } }; JsonRpcRequestInfo { calls: vec![call], @@ -120,20 +126,33 @@ pub fn parse_jsonrpc_body(body: &[u8]) -> JsonRpcRequestInfo { } } -fn parse_jsonrpc_call(value: &serde_json::Value) -> Option { - let method = value.get("method").and_then(|m| m.as_str())?; - Some(JsonRpcCallInfo { +fn parse_jsonrpc_call(value: &serde_json::Value) -> std::result::Result { + let version = value + .get("jsonrpc") + .and_then(|v| v.as_str()) + .ok_or_else(|| "missing or non-string 'jsonrpc' field".to_string())?; + if version != "2.0" { + return Err(format!("unsupported JSON-RPC version '{version}'")); + } + let method = value + .get("method") + .and_then(|m| m.as_str()) + .ok_or_else(|| "missing or non-string 'method' field".to_string())?; + let params = value + .get("params") + .map_or_else(|| Ok(HashMap::new()), flatten_jsonrpc_params)?; + Ok(JsonRpcCallInfo { method: method.to_string(), - params: value - .get("params") - .map_or_else(HashMap::new, flatten_jsonrpc_params), + params, }) } -fn flatten_jsonrpc_params(value: &serde_json::Value) -> HashMap { +fn flatten_jsonrpc_params( + value: &serde_json::Value, +) -> std::result::Result, String> { let mut params = HashMap::new(); - flatten_json_value("", value, &mut params); - params + flatten_json_value("", value, &mut params)?; + Ok(params) } fn canonical_params_map(params: &HashMap) -> BTreeMap { @@ -148,29 +167,50 @@ fn sha256_json(value: &impl serde::Serialize) -> String { hex::encode(Sha256::digest(&encoded)) } -fn flatten_json_value(prefix: &str, value: &serde_json::Value, out: &mut HashMap) { +fn flatten_json_value( + prefix: &str, + value: &serde_json::Value, + out: &mut HashMap, +) -> std::result::Result<(), String> { match value { serde_json::Value::Object(map) => { for (key, child) in map { + if key.contains('.') { + return Err(format!( + "ambiguous dotted params key '{key}' is not allowed" + )); + } let next = if prefix.is_empty() { key.clone() } else { format!("{prefix}.{key}") }; - flatten_json_value(&next, child, out); + flatten_json_value(&next, child, out)?; } } serde_json::Value::String(s) if !prefix.is_empty() => { - out.insert(prefix.to_string(), s.clone()); + insert_flattened_param(out, prefix, s.clone())?; } serde_json::Value::Number(n) if !prefix.is_empty() => { - out.insert(prefix.to_string(), n.to_string()); + insert_flattened_param(out, prefix, n.to_string())?; } serde_json::Value::Bool(b) if !prefix.is_empty() => { - out.insert(prefix.to_string(), b.to_string()); + insert_flattened_param(out, prefix, b.to_string())?; } _ => {} } + Ok(()) +} + +fn insert_flattened_param( + out: &mut HashMap, + key: &str, + value: String, +) -> std::result::Result<(), String> { + if out.insert(key.to_string(), value).is_some() { + return Err(format!("ambiguous params key collision at '{key}'")); + } + Ok(()) } #[cfg(test)] @@ -205,6 +245,70 @@ mod tests { ); } + #[test] + fn rejects_literal_dotted_param_keys() { + let body = br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"arguments.scope":"workspace/other","arguments":{"scope":"workspace/main"}}}"#; + let info = parse_jsonrpc_body(body); + + assert!(info.calls.is_empty()); + assert!( + info.error + .as_deref() + .is_some_and(|error| error.contains("ambiguous dotted params key")), + "expected dotted params key error, got {info:?}" + ); + } + + #[test] + fn rejects_requests_missing_jsonrpc_version() { + let body = br#"{"id":1,"method":"tools/list"}"#; + let info = parse_jsonrpc_body(body); + + assert!(info.calls.is_empty()); + assert_eq!( + info.error.as_deref(), + Some("missing or non-string 'jsonrpc' field") + ); + } + + #[test] + fn rejects_batch_items_missing_jsonrpc_version() { + let body = br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"id":2,"method":"tools/call","params":{"name":"read_status"}} + ]"#; + let info = parse_jsonrpc_body(body); + + assert!(info.calls.is_empty()); + assert!(info.is_batch); + assert_eq!( + info.error.as_deref(), + Some("batch item invalid: missing or non-string 'jsonrpc' field") + ); + } + + #[test] + fn rejects_unsupported_jsonrpc_version() { + let body = br#"{"jsonrpc":"1.0","id":1,"method":"tools/list"}"#; + let info = parse_jsonrpc_body(body); + + assert!(info.calls.is_empty()); + assert_eq!( + info.error.as_deref(), + Some("unsupported JSON-RPC version '1.0'") + ); + } + + #[test] + fn detects_flattened_param_collisions() { + let mut params = HashMap::from([("arguments.scope".to_string(), "first".to_string())]); + + let error = insert_flattened_param(&mut params, "arguments.scope", "second".to_string()) + .expect_err("duplicate flattened key should be ambiguous"); + + assert!(error.contains("ambiguous params key collision")); + } + #[test] fn parses_valid_batch_without_error() { let body = br#"[ diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs index dfc2ee389..08bce6ff5 100644 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ b/crates/openshell-sandbox/src/l7/relay.rs @@ -1281,7 +1281,7 @@ fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String format!("graphql_ops={}", ops.join(";")) } -fn jsonrpc_log_message( +pub(crate) fn jsonrpc_log_message( decision: &str, http_method: &str, endpoint: &str, @@ -1296,7 +1296,7 @@ fn jsonrpc_log_message( ) } -fn jsonrpc_methods_for_log(info: &crate::l7::jsonrpc::JsonRpcRequestInfo) -> String { +pub(crate) fn jsonrpc_methods_for_log(info: &crate::l7::jsonrpc::JsonRpcRequestInfo) -> String { if info.calls.is_empty() { return "-".to_string(); } diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs index a45f1bc60..d6a1807c7 100644 --- a/crates/openshell-sandbox/src/proxy.rs +++ b/crates/openshell-sandbox/src/proxy.rs @@ -345,6 +345,21 @@ fn emit_forward_success_activity(tx: Option<&ActivitySender>, l7_activity_pendin ); } +fn l7_parse_error_reason(request_info: &crate::l7::L7RequestInfo) -> Option { + request_info + .graphql + .as_ref() + .and_then(|info| info.error.as_deref()) + .map(|error| format!("GraphQL request rejected: {error}")) + .or_else(|| { + request_info + .jsonrpc + .as_ref() + .and_then(|info| info.error.as_deref()) + .map(|error| format!("JSON-RPC request rejected: {error}")) + }) +} + /// Emit a denial event to the aggregator channel (if configured). /// Used by `handle_tcp_connection` which owns `Option`. fn emit_denial( @@ -3453,11 +3468,7 @@ async fn handle_forward_proxy( jsonrpc, }; - let parse_error_reason = request_info - .graphql - .as_ref() - .and_then(|info| info.error.as_deref()) - .map(|error| format!("GraphQL request rejected: {error}")); + let parse_error_reason = l7_parse_error_reason(&request_info); let force_deny = parse_error_reason.is_some(); let (allowed, reason) = parse_error_reason.map_or_else( || { @@ -3498,16 +3509,39 @@ async fn handle_forward_proxy( SeverityId::Informational, ), }; - let engine_type = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { - "l7-graphql" - } else { - "l7" - }; - let message_prefix = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { - "FORWARD_GRAPHQL_L7" - } else { - "FORWARD_L7" + let engine_type = match l7_config.config.protocol { + crate::l7::L7Protocol::Graphql => "l7-graphql", + crate::l7::L7Protocol::JsonRpc => "l7-jsonrpc", + _ => "l7", }; + let log_message = request_info.jsonrpc.as_ref().map_or_else( + || { + let message_prefix = + if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { + "FORWARD_GRAPHQL_L7" + } else { + "FORWARD_L7" + }; + format!( + "{message_prefix} {decision_str} {method} {host_lc}:{port}{path} reason={reason}" + ) + }, + |jsonrpc_info| { + let endpoint = format!("{host_lc}:{port}{path}"); + let params_sha256 = jsonrpc_info + .params_sha256() + .unwrap_or_else(|| "".to_string()); + crate::l7::relay::jsonrpc_log_message( + decision_str, + method, + &endpoint, + jsonrpc_info, + ¶ms_sha256, + tunnel_engine.captured_generation(), + &reason, + ) + }, + ); let event = HttpActivityBuilder::new(crate::ocsf_ctx()) .activity(ActivityId::Other) .action(action_id) @@ -3524,9 +3558,7 @@ async fn handle_forward_proxy( .with_cmd_line(&cmdline_str), ) .firewall_rule(policy_str, engine_type) - .message(format!( - "{message_prefix} {decision_str} {method} {host_lc}:{port}{path} reason={reason}" - )) + .message(log_message) .build(); ocsf_emit!(event); } @@ -4184,6 +4216,28 @@ mod tests { assert_eq!(event.deny_group, "unknown"); } + #[test] + fn l7_parse_error_reason_includes_jsonrpc_errors() { + let request_info = crate::l7::L7RequestInfo { + action: "POST".to_string(), + target: "/mcp".to_string(), + query_params: std::collections::HashMap::new(), + graphql: None, + jsonrpc: Some(crate::l7::jsonrpc::JsonRpcRequestInfo { + calls: Vec::new(), + is_batch: false, + error: Some("ambiguous dotted params key 'arguments.scope'".to_string()), + }), + }; + + let reason = l7_parse_error_reason(&request_info).expect("JSON-RPC parse error"); + + assert_eq!( + reason, + "JSON-RPC request rejected: ambiguous dotted params key 'arguments.scope'" + ); + } + #[test] fn forward_l7_allowed_activity_is_deferred_until_after_ssrf() { let (tx, mut rx) = mpsc::channel(4); diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index fa540e4dd..1a0705cda 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -284,7 +284,7 @@ JSON-RPC allow rules match sandbox-to-server JSON-RPC-over-HTTP request objects | Field | Type | Required | Description | |---|---|---|---| | `rpc_method` | string | Yes | JSON-RPC method name or glob, such as `initialize`, `tools/list`, or `tools/*`. | -| `params` | map | No | Params matchers keyed by flattened object-param path. Use dot-separated keys for nested object params, such as `arguments.scope`. Matcher value can be a glob string or an object with `any`. Strings, numbers, and booleans are converted to strings; arrays, `null`, and non-object top-level params do not produce matcher keys. | +| `params` | map | No | Params matchers keyed by flattened object-param path. Use dot-separated keys for nested object params, such as `arguments.scope`. Matcher value can be a glob string or an object with `any`. Strings, numbers, and booleans are converted to strings; arrays, `null`, and non-object top-level params do not produce matcher keys. Requests with literal `.` characters in params object keys are rejected before policy evaluation because they are ambiguous with flattened nested paths. | Example JSON-RPC allow rules: diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index bae6fa279..f0a3464b3 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -623,7 +623,7 @@ JSON-RPC endpoint policies currently require full policy YAML applied with `open `json_rpc.max_body_bytes` controls how many JSON-RPC-over-HTTP request body bytes OpenShell buffers for inspection. It defaults to `65536`. -`params` matchers are case-sensitive and use the same string glob or `{ any: [...] }` matcher syntax as REST query parameters. They match scalar leaf values from object params: strings, numbers, and booleans are converted to strings, and nested JSON object params are flattened with dot-separated keys before matching. Arrays, `null`, and non-object top-level params do not produce matcher keys. This is useful for controls such as matching MCP `tools/call` by `params.name`, but it is not a complete MCP payload policy for rich nested content. For batch requests, OpenShell evaluates each JSON-RPC call independently and denies the whole batch if any call is denied. +`params` matchers are case-sensitive and use the same string glob or `{ any: [...] }` matcher syntax as REST query parameters. They match scalar leaf values from object params: strings, numbers, and booleans are converted to strings, and nested JSON object params are flattened with dot-separated keys before matching. Arrays, `null`, and non-object top-level params do not produce matcher keys. Requests with literal `.` characters in params object keys are rejected before policy evaluation because they are ambiguous with flattened nested paths. This is useful for controls such as matching MCP `tools/call` by `params.name`, but it is not a complete MCP payload policy for rich nested content. For batch requests, OpenShell evaluates each JSON-RPC call independently and denies the whole batch if any call is denied. ### GraphQL matching From 645ff478f6103b7eeab616914cf38f34ba12afe3 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Mon, 15 Jun 2026 11:04:43 -0700 Subject: [PATCH 11/12] ci(e2e): add MCP conformance coverage Add a reusable MCP conformance workflow that runs upstream client scenarios through an OpenShell sandbox. Add a client image, wrapper, policy template, and expected-failures baseline for expanding MCP conformance coverage. Remove stale JSON-RPC e2e policy fields that are no longer accepted. Signed-off-by: Kris Hicks --- .github/workflows/branch-e2e.yml | 16 ++- .github/workflows/mcp-conformance.yml | 99 ++++++++++++++++ e2e/mcp-conformance/Dockerfile.client | 25 ++++ e2e/mcp-conformance/README.md | 27 +++++ .../client-through-openshell.sh | 108 ++++++++++++++++++ e2e/mcp-conformance/expected-failures.yml | 7 ++ e2e/mcp-conformance/policy-template.yaml | 56 +++++++++ e2e/rust/tests/forward_proxy_jsonrpc_l7.rs | 4 +- 8 files changed, 337 insertions(+), 5 deletions(-) create mode 100644 .github/workflows/mcp-conformance.yml create mode 100644 e2e/mcp-conformance/Dockerfile.client create mode 100644 e2e/mcp-conformance/README.md create mode 100755 e2e/mcp-conformance/client-through-openshell.sh create mode 100644 e2e/mcp-conformance/expected-failures.yml create mode 100644 e2e/mcp-conformance/policy-template.yaml diff --git a/.github/workflows/branch-e2e.yml b/.github/workflows/branch-e2e.yml index 8a9e7fe29..46d9528b0 100644 --- a/.github/workflows/branch-e2e.yml +++ b/.github/workflows/branch-e2e.yml @@ -111,6 +111,16 @@ jobs: with: image-tag: ${{ github.sha }} + mcp-conformance: + needs: [pr_metadata, build-gateway, build-supervisor] + if: needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_core_e2e == 'true' + permissions: + contents: read + packages: read + uses: ./.github/workflows/mcp-conformance.yml + with: + image-tag: ${{ github.sha }} + kubernetes-ha-e2e: needs: [pr_metadata, build-gateway, build-supervisor] if: needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_kubernetes_ha_e2e == 'true' @@ -126,7 +136,7 @@ jobs: core-e2e-result: name: Core E2E result - needs: [pr_metadata, build-gateway, build-supervisor, e2e, kubernetes-e2e] + needs: [pr_metadata, build-gateway, build-supervisor, e2e, kubernetes-e2e, mcp-conformance] if: always() && needs.pr_metadata.outputs.should_run == 'true' && needs.pr_metadata.outputs.run_core_e2e == 'true' runs-on: ubuntu-latest steps: @@ -136,6 +146,7 @@ jobs: BUILD_SUPERVISOR_RESULT: ${{ needs.build-supervisor.result }} E2E_RESULT: ${{ needs.e2e.result }} KUBERNETES_E2E_RESULT: ${{ needs.kubernetes-e2e.result }} + MCP_CONFORMANCE_RESULT: ${{ needs.mcp-conformance.result }} run: | set -euo pipefail failed=0 @@ -143,7 +154,8 @@ jobs: "build-gateway:$BUILD_GATEWAY_RESULT" \ "build-supervisor:$BUILD_SUPERVISOR_RESULT" \ "e2e:$E2E_RESULT" \ - "kubernetes-e2e:$KUBERNETES_E2E_RESULT"; do + "kubernetes-e2e:$KUBERNETES_E2E_RESULT" \ + "mcp-conformance:$MCP_CONFORMANCE_RESULT"; do name="${item%%:*}" result="${item#*:}" if [ "$result" != "success" ]; then diff --git a/.github/workflows/mcp-conformance.yml b/.github/workflows/mcp-conformance.yml new file mode 100644 index 000000000..853c3dd8e --- /dev/null +++ b/.github/workflows/mcp-conformance.yml @@ -0,0 +1,99 @@ +name: MCP Conformance Test + +on: + workflow_call: + inputs: + image-tag: + description: "Image tag to test (typically the commit SHA)" + required: true + type: string + runner: + description: "GitHub Actions runner label" + required: false + type: string + default: "linux-amd64-cpu8" + checkout-ref: + description: "Git ref to check out for test inputs (defaults to the workflow SHA)" + required: false + type: string + default: "" + +permissions: + contents: read + packages: read + +jobs: + mcp-conformance: + name: MCP Conformance + runs-on: ${{ inputs.runner }} + timeout-minutes: 40 + defaults: + run: + shell: bash + container: + image: ghcr.io/nvidia/openshell/ci:latest + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --privileged + volumes: + - /var/run/docker.sock:/var/run/docker.sock + - /home/runner/_work:/home/runner/_work + env: + MISE_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + IMAGE_TAG: ${{ inputs.image-tag }} + OPENSHELL_REGISTRY: ghcr.io/nvidia/openshell + OPENSHELL_REGISTRY_HOST: ghcr.io + OPENSHELL_REGISTRY_NAMESPACE: nvidia/openshell + OPENSHELL_REGISTRY_USERNAME: ${{ github.actor }} + OPENSHELL_REGISTRY_PASSWORD: ${{ secrets.GITHUB_TOKEN }} + OPENSHELL_SUPERVISOR_IMAGE: ${{ format('ghcr.io/nvidia/openshell/supervisor:{0}', inputs.image-tag) }} + OPENSHELL_MCP_CONFORMANCE_CLIENT_IMAGE: openshell-mcp-conformance-client:${{ github.sha }} + steps: + - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + with: + ref: ${{ inputs['checkout-ref'] || github.sha }} + + - name: Check out MCP conformance tests + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + with: + repository: modelcontextprotocol/conformance + ref: v0.1.16 + path: .cache/mcp-conformance + + - name: Set up Node.js + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6 + with: + node-version: "22" + cache: npm + cache-dependency-path: .cache/mcp-conformance/package-lock.json + + - name: Build MCP conformance runner + working-directory: .cache/mcp-conformance + run: | + npm ci + npm run build + + - name: Log in to GHCR with Docker + run: echo "${OPENSHELL_REGISTRY_PASSWORD}" | docker login ghcr.io -u "${{ github.actor }}" --password-stdin + + - name: Build OpenShell e2e binaries + run: | + cargo build -p openshell-server --bin openshell-gateway --features openshell-core/dev-settings + cargo build -p openshell-cli --bin openshell --features openshell-core/dev-settings + + - name: Build MCP conformance client image + run: docker build --pull -f e2e/mcp-conformance/Dockerfile.client -t "${OPENSHELL_MCP_CONFORMANCE_CLIENT_IMAGE}" .cache/mcp-conformance + + - name: Run MCP conformance through OpenShell + run: | + set -euo pipefail + for scenario in initialize tools_call; do + echo "::group::MCP conformance: ${scenario}" + node .cache/mcp-conformance/dist/index.js client \ + --command "bash e2e/mcp-conformance/client-through-openshell.sh" \ + --scenario "${scenario}" \ + --expected-failures e2e/mcp-conformance/expected-failures.yml \ + --timeout 900000 + echo "::endgroup::" + done diff --git a/e2e/mcp-conformance/Dockerfile.client b/e2e/mcp-conformance/Dockerfile.client new file mode 100644 index 000000000..79810bbe9 --- /dev/null +++ b/e2e/mcp-conformance/Dockerfile.client @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +FROM public.ecr.aws/docker/library/node:22-bookworm-slim + +RUN apt-get update \ + && apt-get install -y --no-install-recommends ca-certificates iproute2 \ + && rm -rf /var/lib/apt/lists/* + +ARG SANDBOX_UID=1000660000 +ARG SANDBOX_GID=1000660000 + +# Match the sandbox user expected by OpenShell policies and supervisor setup. +# The UID/GID are intentionally outside Debian's default login.defs range. +RUN groupadd -K "GID_MAX=${SANDBOX_GID}" -g "${SANDBOX_GID}" sandbox \ + && useradd -K "UID_MAX=${SANDBOX_UID}" --no-log-init -m -u "${SANDBOX_UID}" -g sandbox sandbox + +WORKDIR /opt/mcp-conformance + +COPY . . +RUN if [ -f package-lock.json ]; then npm ci; else npm install; fi +RUN chown -R sandbox:sandbox /opt/mcp-conformance /home/sandbox + +USER sandbox +CMD ["sleep", "infinity"] diff --git a/e2e/mcp-conformance/README.md b/e2e/mcp-conformance/README.md new file mode 100644 index 000000000..1dd47b87e --- /dev/null +++ b/e2e/mcp-conformance/README.md @@ -0,0 +1,27 @@ +# MCP Conformance E2E + +This directory contains the OpenShell wrapper for the upstream +`modelcontextprotocol/conformance` runner. + +The workflow checks out and builds the upstream conformance repository, then +runs its CLI in client mode. The upstream runner starts a real MCP test server, +then invokes `client-through-openshell.sh` with that server URL. The wrapper +starts the Docker-backed OpenShell e2e gateway and runs the upstream TypeScript +`everything-client` inside an OpenShell sandbox, so the MCP traffic crosses the +sandbox proxy. + +The conformance server URL uses `localhost` from the GitHub Actions job +container's perspective. Sandboxes run in separate Docker containers, so the +wrapper rewrites local URLs to `host.openshell.internal`, the alias that +`e2e/with-docker-gateway.sh` attaches to the job container on the e2e Docker +network. + +The generated policy allows valid JSON-RPC requests to the conformance server +with `rpc_method: "*"`. That keeps OpenShell deny-by-default at the network +boundary while allowing the upstream scenarios to exercise MCP behavior. The +policy body lives in `policy-template.yaml`; the wrapper renders its host, port, +and path placeholders from the upstream server URL. + +When enabling broader upstream suites, add scenarios that OpenShell does not yet +support through the JSON-RPC proxy to `expected-failures.yml`. The upstream +runner treats listed failures as allowed and treats stale entries as failures. diff --git a/e2e/mcp-conformance/client-through-openshell.sh b/e2e/mcp-conformance/client-through-openshell.sh new file mode 100755 index 000000000..5b3e0c6fd --- /dev/null +++ b/e2e/mcp-conformance/client-through-openshell.sh @@ -0,0 +1,108 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Runs the upstream MCP conformance client through an OpenShell sandbox. +# +# The modelcontextprotocol/conformance runner starts a real MCP test server in +# the GitHub Actions job container and invokes this script with that server URL. +# This script starts the normal Docker-backed OpenShell e2e gateway, creates a +# sandbox from the prebuilt conformance client image, and runs the upstream +# TypeScript everything-client inside that sandbox. That keeps the MCP +# client/server traffic in the OpenShell proxy data path. +# +# Conformance server URLs usually point at localhost in the job container. +# Sandboxes are separate Docker containers, so localhost would point back at the +# sandbox itself. The wrapper rewrites local URLs to host.openshell.internal, +# which e2e/with-docker-gateway.sh attaches to the job container on the e2e +# Docker network. + +set -euo pipefail + +usage() { + echo "usage: $0 " >&2 +} + +if [ "$#" -ne 1 ]; then + usage + exit 2 +fi + +# Parse the conformance runner's server URL and render the OpenShell policy. +prepare_conformance_target() { + local server_url=$1 + local policy_file=$2 + local policy_template=$3 + + python3 - "${server_url}" "${policy_file}" "${policy_template}" <<'PY' +import json +import string +import sys +from pathlib import Path +from urllib.parse import urlparse, urlunparse + +raw_url, policy_file, policy_template = sys.argv[1:4] +parsed = urlparse(raw_url) + +if parsed.scheme not in ("http", "https"): + raise SystemExit(f"unsupported conformance server URL scheme: {parsed.scheme!r}") + +host = parsed.hostname +if not host: + raise SystemExit(f"conformance server URL is missing a host: {raw_url}") + +target_host = "host.openshell.internal" if host in {"localhost", "127.0.0.1", "::1"} else host +port = parsed.port or (443 if parsed.scheme == "https" else 80) +path = parsed.path or "/" +netloc_host = f"[{target_host}]" if ":" in target_host and not target_host.startswith("[") else target_host +netloc = f"{netloc_host}:{port}" +rewritten = urlunparse((parsed.scheme, netloc, path, parsed.params, parsed.query, parsed.fragment)) + +template = string.Template(Path(policy_template).read_text(encoding="utf-8")) +policy = template.substitute( + host=json.dumps(target_host), + port=str(port), + path=json.dumps(path), +) +Path(policy_file).write_text(policy, encoding="utf-8") + +print(rewritten) +PY +} + +SERVER_URL="$1" +CLIENT_IMAGE="${OPENSHELL_MCP_CONFORMANCE_CLIENT_IMAGE:?set OPENSHELL_MCP_CONFORMANCE_CLIENT_IMAGE to the prebuilt conformance client image}" +ROOT="$(git rev-parse --show-toplevel 2>/dev/null || pwd)" +POLICY_TEMPLATE="${ROOT}/e2e/mcp-conformance/policy-template.yaml" + +POLICY_FILE="$(mktemp "${TMPDIR:-/tmp}/openshell-mcp-conformance-policy.XXXXXX.yaml")" +trap 'rm -f "${POLICY_FILE}"' EXIT + +CLIENT_SERVER_URL="$(prepare_conformance_target "${SERVER_URL}" "${POLICY_FILE}" "${POLICY_TEMPLATE}")" + +ENV_ARGS=() +# These environment variables are set by the upstream conformance test runner +# before it invokes the configured client command. Forward them into the +# sandbox because the sandboxed TypeScript client depends on them to select the +# scenario and read scenario-specific context. +for NAME in MCP_CONFORMANCE_SCENARIO MCP_CONFORMANCE_CONTEXT MCP_CONFORMANCE_PROTOCOL_VERSION; do + if [ -n "${!NAME+x}" ]; then + ENV_ARGS+=(--env "${NAME}=${!NAME}") + fi +done + +# shellcheck source=e2e/support/gateway-common.sh disable=SC1091 +source "${ROOT}/e2e/support/gateway-common.sh" +TARGET_DIR="$(e2e_cargo_target_dir "${ROOT}")" +OPENSHELL_BIN="${OPENSHELL_BIN:-${TARGET_DIR}/debug/openshell}" +export OPENSHELL_E2E_DOCKER_SANDBOX_IMAGE="${OPENSHELL_E2E_DOCKER_SANDBOX_IMAGE:-${CLIENT_IMAGE}}" + +# shellcheck disable=SC2016 +"${ROOT}/e2e/with-docker-gateway.sh" \ + "${OPENSHELL_BIN}" sandbox create \ + --from "${CLIENT_IMAGE}" \ + --policy "${POLICY_FILE}" \ + "${ENV_ARGS[@]}" \ + -- \ + sh -c 'cd /opt/mcp-conformance && exec ./node_modules/.bin/tsx examples/clients/typescript/everything-client.ts "$1"' \ + sh "${CLIENT_SERVER_URL}" diff --git a/e2e/mcp-conformance/expected-failures.yml b/e2e/mcp-conformance/expected-failures.yml new file mode 100644 index 000000000..5b226631f --- /dev/null +++ b/e2e/mcp-conformance/expected-failures.yml @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# Add client scenarios here when enabling broader MCP conformance suites that +# exercise features OpenShell does not yet support through the JSON-RPC proxy. +client: [] +server: [] diff --git a/e2e/mcp-conformance/policy-template.yaml b/e2e/mcp-conformance/policy-template.yaml new file mode 100644 index 000000000..2a02f6374 --- /dev/null +++ b/e2e/mcp-conformance/policy-template.yaml @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +version: 1 + +filesystem_policy: + include_workdir: true + read_only: + - /bin + - /usr + - /lib + - /lib64 + - /proc + - /sys + - /dev/urandom + - /etc + - /opt + - /var/log + read_write: + - /sandbox + - /tmp + - /dev/null + - /home/sandbox + +landlock: + compatibility: best_effort + +process: + run_as_user: sandbox + run_as_group: sandbox + +network_policies: + mcp_conformance: + name: mcp_conformance + endpoints: + - host: ${host} + port: ${port} + path: ${path} + protocol: json-rpc + enforcement: enforce + allowed_ips: + - "10.0.0.0/8" + - "172.0.0.0/8" + - "192.168.0.0/16" + - "fc00::/7" + json_rpc: + max_body_bytes: 131072 + rules: + - allow: + rpc_method: "*" + binaries: + - path: /bin/sh + - path: /usr/bin/env + - path: /usr/local/bin/node + - path: /usr/bin/node + - path: /opt/mcp-conformance/node_modules/.bin/* diff --git a/e2e/rust/tests/forward_proxy_jsonrpc_l7.rs b/e2e/rust/tests/forward_proxy_jsonrpc_l7.rs index feba98ceb..ada51ba2e 100644 --- a/e2e/rust/tests/forward_proxy_jsonrpc_l7.rs +++ b/e2e/rust/tests/forward_proxy_jsonrpc_l7.rs @@ -102,8 +102,6 @@ network_policies: - "fc00::/7" json_rpc: max_body_bytes: 65536 - on_parse_error: deny - batch_policy: deny_if_any_denied rules: - allow: rpc_method: initialize @@ -318,7 +316,7 @@ results = {{ {{"jsonrpc": "2.0", "id": 2, "method": "tools/call", "params": {{"name": "blocked_action"}}}}, ]), - # forward proxy — invalid JSON body denied by on_parse_error: deny + # forward proxy — invalid JSON body fails closed before generic rules apply "forward_invalid_json_denied": post_invalid_json(), # CONNECT path — representative allowed and denied cases From 7101a37b3dd1fcfa36cb357cf3f5133b6c2fc295 Mon Sep 17 00:00:00 2001 From: Kris Hicks Date: Mon, 15 Jun 2026 14:06:34 -0700 Subject: [PATCH 12/12] fix(l7): port JSON-RPC L7 to supervisor network Signed-off-by: Kris Hicks --- .../data/sandbox-policy.rego | 787 -- crates/openshell-sandbox/src/l7/graphql.rs | 637 -- crates/openshell-sandbox/src/l7/mod.rs | 2489 ------ crates/openshell-sandbox/src/l7/relay.rs | 2977 ------- crates/openshell-sandbox/src/l7/websocket.rs | 1943 ----- crates/openshell-sandbox/src/opa.rs | 5339 ------------ crates/openshell-sandbox/src/policy_local.rs | 2030 ----- crates/openshell-sandbox/src/proxy.rs | 7718 ----------------- .../data/sandbox-policy.rego | 47 + .../src/l7/graphql.rs | 1 + .../src/l7/http.rs | 0 .../src/l7/jsonrpc.rs | 0 .../src/l7/mod.rs | 33 +- .../src/l7/relay.rs | 569 +- .../src/l7/websocket.rs | 3 + .../openshell-supervisor-network/src/opa.rs | 239 +- .../src/policy_local.rs | 5 + .../openshell-supervisor-network/src/proxy.rs | 160 +- 18 files changed, 995 insertions(+), 23982 deletions(-) delete mode 100644 crates/openshell-sandbox/data/sandbox-policy.rego delete mode 100644 crates/openshell-sandbox/src/l7/graphql.rs delete mode 100644 crates/openshell-sandbox/src/l7/mod.rs delete mode 100644 crates/openshell-sandbox/src/l7/relay.rs delete mode 100644 crates/openshell-sandbox/src/l7/websocket.rs delete mode 100644 crates/openshell-sandbox/src/opa.rs delete mode 100644 crates/openshell-sandbox/src/policy_local.rs delete mode 100644 crates/openshell-sandbox/src/proxy.rs rename crates/{openshell-sandbox => openshell-supervisor-network}/src/l7/http.rs (100%) rename crates/{openshell-sandbox => openshell-supervisor-network}/src/l7/jsonrpc.rs (100%) diff --git a/crates/openshell-sandbox/data/sandbox-policy.rego b/crates/openshell-sandbox/data/sandbox-policy.rego deleted file mode 100644 index 38070911c..000000000 --- a/crates/openshell-sandbox/data/sandbox-policy.rego +++ /dev/null @@ -1,787 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -package openshell.sandbox - -default allow_network = false - -# --- Static policy data passthrough (queried at sandbox startup) --- - -filesystem_policy := data.filesystem_policy - -landlock_policy := data.landlock - -process_policy := data.process - -# --- Network access decision (queried per-CONNECT request) --- - -allow_network if { - network_policy_for_request -} - -# --- Deny reasons (specific diagnostics for debugging policy denials) --- - -deny_reason := "missing input.network" if { - not input.network -} - -deny_reason := "missing input.exec" if { - input.network - not input.exec -} - -deny_reason := reason if { - input.network - input.exec - not network_policy_for_request - not endpoint_policy_for_request - count(data.network_policies) > 0 - reason := sprintf("endpoint %s:%d is not allowed by any policy", [input.network.host, input.network.port]) -} - -deny_reason := reason if { - input.network - input.exec - not network_policy_for_request - endpoint_policy_for_request - ancestors_str := concat(" -> ", input.exec.ancestors) - cmdline_str := concat(", ", input.exec.cmdline_paths) - binary_misses := [r | - some name - policy := data.network_policies[name] - endpoint_allowed(policy, input.network) - not binary_allowed(policy, input.exec) - r := sprintf("binary '%s' not allowed in policy '%s' (ancestors: [%s], cmdline: [%s]). SYMLINK HINT: the binary path is the kernel-resolved target from /proc//exe, not the symlink. If your policy specifies a symlink (e.g., /usr/bin/python3) but the actual binary is /usr/bin/python3.11, either: (1) use the canonical path in your policy (run 'readlink -f /usr/bin/python3' inside the sandbox), or (2) ensure symlink resolution is working (check sandbox logs for 'Cannot access container filesystem')", [input.exec.path, name, ancestors_str, cmdline_str]) - ] - count(binary_misses) > 0 - reason := concat("; ", binary_misses) -} - -deny_reason := "network connections not allowed by policy" if { - input.network - input.exec - not network_policy_for_request - count(data.network_policies) == 0 -} - -# --- Matched policy name (for audit logging) --- -# -# Collects all matching policy names into a set, then deterministically picks -# the lexicographically smallest. This avoids a "complete rule conflict" when -# multiple policies cover the same endpoint (e.g. after draft approval adds an -# overlapping rule). - -_matching_policy_names contains name if { - some name - policy := data.network_policies[name] - endpoint_allowed(policy, input.network) - binary_allowed(policy, input.exec) -} - -matched_network_policy := min(_matching_policy_names) if { - count(_matching_policy_names) > 0 -} - -# --- Core matching logic --- - -# True when at least one network policy matches the request (endpoint + binary). -# Expressed as a boolean so that multiple matching policies don't cause a -# "complete rule conflict". -network_policy_for_request if { - some name - data.network_policies[name] - endpoint_allowed(data.network_policies[name], input.network) - binary_allowed(data.network_policies[name], input.exec) -} - -endpoint_policy_for_request if { - some name - data.network_policies[name] - endpoint_allowed(data.network_policies[name], input.network) -} - -# Endpoint matching: exact host (case-insensitive) + port in ports list. -endpoint_allowed(policy, network) if { - some endpoint - endpoint := policy.endpoints[_] - not contains(endpoint.host, "*") - lower(endpoint.host) == lower(network.host) - endpoint.ports[_] == network.port -} - -# Endpoint matching: glob host pattern + port in ports list. -# Uses "." as delimiter so "*" matches a single DNS label and "**" matches -# across label boundaries — consistent with TLS certificate wildcard semantics. -endpoint_allowed(policy, network) if { - some endpoint - endpoint := policy.endpoints[_] - contains(endpoint.host, "*") - glob.match(lower(endpoint.host), ["."], lower(network.host)) - endpoint.ports[_] == network.port -} - -# Endpoint matching: hostless with allowed_ips — match any host on port. -# When an endpoint has allowed_ips but no host, it matches any hostname on the -# given port. The actual IP validation happens in Rust post-DNS-resolution. -endpoint_allowed(policy, network) if { - some endpoint - endpoint := policy.endpoints[_] - object.get(endpoint, "host", "") == "" - count(object.get(endpoint, "allowed_ips", [])) > 0 - endpoint.ports[_] == network.port -} - -# Binary matching: exact path. -# SHA256 integrity is enforced in Rust via trust-on-first-use (TOFU) cache, -# not in Rego. The proxy computes and caches binary hashes at runtime. -binary_allowed(policy, exec) if { - some b - b := policy.binaries[_] - not contains(b.path, "*") - b.path == exec.path -} - -# Binary matching: ancestor exact path (e.g., claude spawns node). -binary_allowed(policy, exec) if { - some b - b := policy.binaries[_] - not contains(b.path, "*") - ancestor := exec.ancestors[_] - b.path == ancestor -} - -# Binary matching: glob pattern against exe path or any ancestor. -# NOTE: cmdline_paths are intentionally excluded — argv[0] is trivially -# spoofable via execve and must not be used as a grant-access signal. -binary_allowed(policy, exec) if { - some b in policy.binaries - contains(b.path, "*") - all_paths := array.concat([exec.path], exec.ancestors) - some p in all_paths - glob.match(b.path, ["/"], p) -} - -user_declared_binary_allowed(policy, exec) if { - some b - b := policy.binaries[_] - not object.get(b, "advisor_proposed", false) - not contains(b.path, "*") - b.path == exec.path -} - -user_declared_binary_allowed(policy, exec) if { - some b - b := policy.binaries[_] - not object.get(b, "advisor_proposed", false) - not contains(b.path, "*") - ancestor := exec.ancestors[_] - b.path == ancestor -} - -user_declared_binary_allowed(policy, exec) if { - some b in policy.binaries - not object.get(b, "advisor_proposed", false) - contains(b.path, "*") - all_paths := array.concat([exec.path], exec.ancestors) - some p in all_paths - glob.match(b.path, ["/"], p) -} - -# --- Network action (allow / deny) --- -# -# These rules are mutually exclusive by construction: -# - "allow" requires `network_policy_for_request` (binary+endpoint matched) -# - default is "deny" when no policy matches. - -default network_action := "deny" - -# Explicitly allowed: endpoint + binary match in a network policy → allow. -network_action := "allow" if { - network_policy_for_request -} - -# =========================================================================== -# L7 request evaluation (queried per-request within a tunnel) -# =========================================================================== - -default allow_request = false - -# Per-policy helper: true when this single policy has at least one endpoint -# matching the L4 request whose L7 rules also permit the specific request. -# Isolating the endpoint iteration inside a function avoids the regorus -# "duplicated definition of local variable" error that occurs when the -# outer `some name` iterates over multiple policies that share a host:port. -_policy_allows_l7(policy) if { - some ep - ep := policy.endpoints[_] - endpoint_matches_l7_request(ep, input.network, input.request) - request_allowed_for_endpoint(input.request, ep) -} - -# L7 request allowed if any matching L4 policy also allows the L7 request -# AND no deny rule blocks it. Deny rules take precedence over allow rules. -allow_request if { - some name - policy := data.network_policies[name] - endpoint_allowed(policy, input.network) - binary_allowed(policy, input.exec) - _policy_allows_l7(policy) - not deny_request -} - -# --- L7 deny rules --- -# -# Deny rules are evaluated after allow rules and take precedence. -# If a request matches any deny rule on any matching endpoint, it is blocked -# even if it would otherwise be allowed. - -default deny_request = false - -# Per-policy helper: true when this policy has at least one endpoint matching -# the L4 request whose deny_rules also match the specific L7 request. -_policy_denies_l7(policy) if { - some ep - ep := policy.endpoints[_] - endpoint_matches_l7_request(ep, input.network, input.request) - request_denied_for_endpoint(input.request, ep) -} - -deny_request if { - some name - policy := data.network_policies[name] - endpoint_allowed(policy, input.network) - binary_allowed(policy, input.exec) - _policy_denies_l7(policy) -} - -# --- L7 deny rule matching: REST method + path + query --- - -request_denied_for_endpoint(request, endpoint) if { - some deny_rule - deny_rule := endpoint.deny_rules[_] - deny_rule.method - method_matches(request.method, deny_rule.method) - path_matches(request.path, deny_rule.path) - deny_query_params_match(request, deny_rule) -} - -# --- L7 deny rule matching: SQL command --- - -request_denied_for_endpoint(request, endpoint) if { - some deny_rule - deny_rule := endpoint.deny_rules[_] - deny_rule.command - command_matches(request.command, deny_rule.command) -} - -# --- L7 deny rule matching: JSON-RPC method + params --- - -request_denied_for_endpoint(request, endpoint) if { - some deny_rule - deny_rule := endpoint.deny_rules[_] - deny_rule.rpc_method - jsonrpc_rule_matches(request, deny_rule) -} - -# --- L7 deny rule matching: GraphQL operation --- - -request_denied_for_endpoint(request, endpoint) if { - graphql_request_has_operations(request) - some deny_rule - deny_rule := endpoint.deny_rules[_] - deny_rule.operation_type - op := request.graphql.operations[_] - graphql_deny_rule_matches_operation(op, deny_rule, endpoint) -} - -# A GraphQL endpoint path is authoritative once it matches. If the parsed -# GraphQL request is malformed, hash-only without a trusted registry entry, or -# contains an operation outside the GraphQL allow rules, a broader REST rule on -# the same host:port must not allow it through. -request_denied_for_endpoint(request, endpoint) if { - endpoint.protocol == "graphql" - is_object(request.graphql) - not graphql_request_allowed(request, endpoint) -} - -# The same authority applies when a WebSocket endpoint opts into GraphQL -# operation policy. Once the relay classifies a client text message as a -# GraphQL-over-WebSocket operation, generic WEBSOCKET_TEXT rules must not bypass -# operation_type / operation_name / fields policy. -request_denied_for_endpoint(request, endpoint) if { - endpoint.protocol == "websocket" - is_object(request.graphql) - not graphql_request_allowed(request, endpoint) -} - -# Deny query matching: fail-closed semantics. -# If no query rules on the deny rule, match unconditionally (any query params). -# If query rules present, trigger the deny if ANY value for a configured key -# matches the matcher. This is the inverse of allow-side semantics where ALL -# values must match. For deny logic, a single matching value is enough to block. -deny_query_params_match(request, deny_rule) if { - deny_query_rules := object.get(deny_rule, "query", {}) - count(deny_query_rules) == 0 -} - -deny_query_params_match(request, deny_rule) if { - deny_query_rules := object.get(deny_rule, "query", {}) - count(deny_query_rules) > 0 - not deny_query_key_missing(request, deny_query_rules) - not deny_query_value_mismatch_all(request, deny_query_rules) -} - -# A configured deny query key is missing from the request entirely. -# Missing key means the deny rule doesn't apply (fail-open on absence). -deny_query_key_missing(request, query_rules) if { - some key - query_rules[key] - request_query := object.get(request, "query_params", {}) - values := object.get(request_query, key, null) - values == null -} - -# ALL values for a configured key fail to match the matcher. -# If even one value matches, deny fires. This rule checks the opposite: -# true when NO value matches (i.e., every value is a mismatch). -deny_query_value_mismatch_all(request, query_rules) if { - some key - matcher := query_rules[key] - request_query := object.get(request, "query_params", {}) - values := object.get(request_query, key, []) - count(values) > 0 - not deny_any_value_matches(values, matcher) -} - -# True if at least one value in the list matches the matcher. -deny_any_value_matches(values, matcher) if { - some i - query_value_matches(values[i], matcher) -} - -# --- L7 deny reason --- - -request_deny_reason := reason if { - input.request - graphql_request_error(input.request) - reason := sprintf("GraphQL request rejected: %s", [input.request.graphql.error]) -} - -request_deny_reason := reason if { - input.request - not graphql_request_error(input.request) - graphql_request_has_unregistered_persisted_query(input.request, matched_endpoint_config) - reason := "GraphQL persisted query is not registered" -} - -request_deny_reason := reason if { - input.request - deny_request - graphql_request_has_operations(input.request) - not graphql_request_has_unregistered_persisted_query(input.request, matched_endpoint_config) - reason := "GraphQL operation blocked by endpoint policy" -} - -request_deny_reason := reason if { - input.request - not deny_request - not allow_request - graphql_request_has_operations(input.request) - not graphql_request_has_unregistered_persisted_query(input.request, matched_endpoint_config) - reason := "GraphQL operation not permitted by policy" -} - -request_deny_reason := reason if { - input.request - deny_request - not graphql_request_has_operations(input.request) - reason := sprintf("%s %s blocked by deny rule", [input.request.method, input.request.path]) -} - -request_deny_reason := reason if { - input.request - not deny_request - not allow_request - not graphql_request_has_operations(input.request) - reason := sprintf("%s %s not permitted by policy", [input.request.method, input.request.path]) -} - -# --- L7 rule matching: REST method + path --- - -request_allowed_for_endpoint(request, endpoint) if { - some rule - rule := endpoint.rules[_] - rule.allow.method - method_matches(request.method, rule.allow.method) - path_matches(request.path, rule.allow.path) - query_params_match(request, rule) -} - -# --- L7 rule matching: SQL command --- - -request_allowed_for_endpoint(request, endpoint) if { - some rule - rule := endpoint.rules[_] - rule.allow.command - command_matches(request.command, rule.allow.command) -} - -# --- L7 rule matching: JSON-RPC method --- - -request_allowed_for_endpoint(request, endpoint) if { - some rule - rule := endpoint.rules[_] - rule.allow.rpc_method - jsonrpc_rule_matches(request, rule.allow) -} - -# --- L7 rule matching: GraphQL operation --- - -request_allowed_for_endpoint(request, endpoint) if { - graphql_request_allowed(request, endpoint) -} - -graphql_request_allowed(request, endpoint) if { - graphql_request_has_operations(request) - not graphql_request_error(request) - not graphql_request_has_unregistered_persisted_query(request, endpoint) - not graphql_request_has_unallowed_operation(request, endpoint) -} - -graphql_request_has_operations(request) if { - is_object(request.graphql) - operations := object.get(request.graphql, "operations", []) - count(operations) > 0 -} - -graphql_request_error(request) if { - is_object(request.graphql) - error := object.get(request.graphql, "error", "") - error != "" -} - -graphql_request_has_unallowed_operation(request, endpoint) if { - op := request.graphql.operations[_] - not graphql_operation_allowed(op, endpoint) -} - -graphql_operation_allowed(op, endpoint) if { - rule := endpoint.rules[_] - rule.allow.operation_type - graphql_allow_rule_matches_operation(op, rule.allow, endpoint) -} - -graphql_request_has_unregistered_persisted_query(request, endpoint) if { - op := request.graphql.operations[_] - graphql_operation_needs_registry(op) - not graphql_registered_operation(op, endpoint) -} - -graphql_operation_needs_registry(op) if { - object.get(op, "persisted_query", false) == true - object.get(op, "operation_type", "") == "" -} - -graphql_registered_operation(op, endpoint) if { - object.get(endpoint, "persisted_queries", "deny") == "allow_registered" - id := graphql_operation_registry_key(op) - endpoint.graphql_persisted_queries[id] -} - -graphql_operation_registry_key(op) := key if { - key := object.get(op, "persisted_query_hash", "") - key != "" -} - -graphql_operation_registry_key(op) := key if { - object.get(op, "persisted_query_hash", "") == "" - key := object.get(op, "persisted_query_id", "") - key != "" -} - -graphql_effective_operation(op, endpoint) := registered if { - graphql_operation_needs_registry(op) - key := graphql_operation_registry_key(op) - registered := endpoint.graphql_persisted_queries[key] -} - -graphql_effective_operation(op, _) := op if { - not graphql_operation_needs_registry(op) -} - -graphql_allow_rule_matches_operation(op, rule, endpoint) if { - effective := graphql_effective_operation(op, endpoint) - graphql_operation_type_matches(effective, rule) - graphql_operation_name_matches(effective, rule) - graphql_allow_fields_match(effective, rule) -} - -graphql_deny_rule_matches_operation(op, rule, endpoint) if { - effective := graphql_effective_operation(op, endpoint) - graphql_operation_type_matches(effective, rule) - graphql_operation_name_matches(effective, rule) - graphql_deny_fields_match(effective, rule) -} - -graphql_operation_type_matches(_, rule) if { - object.get(rule, "operation_type", "") == "*" -} - -graphql_operation_type_matches(op, rule) if { - expected := object.get(rule, "operation_type", "") - expected != "" - expected != "*" - lower(object.get(op, "operation_type", "")) == lower(expected) -} - -graphql_operation_name_matches(_, rule) if { - object.get(rule, "operation_name", "") == "" -} - -graphql_operation_name_matches(op, rule) if { - pattern := object.get(rule, "operation_name", "") - pattern != "" - name := object.get(op, "operation_name", "") - glob.match(pattern, [], name) -} - -# Allow-side field constraints are intentionally all-selected-fields semantics: -# if a rule declares fields, every root field selected by the operation must -# match one of the rule patterns. This prevents mixed-operation requests from -# allowing an unlisted field because one safe field also appeared. -graphql_allow_fields_match(_, rule) if { - count(object.get(rule, "fields", [])) == 0 -} - -graphql_allow_fields_match(op, rule) if { - count(object.get(rule, "fields", [])) > 0 - count(object.get(op, "fields", [])) > 0 - not graphql_operation_has_unmatched_field(op, rule) -} - -graphql_operation_has_unmatched_field(op, rule) if { - field := object.get(op, "fields", [])[_] - not graphql_field_matches_any(field, object.get(rule, "fields", [])) -} - -graphql_deny_fields_match(_, rule) if { - count(object.get(rule, "fields", [])) == 0 -} - -graphql_deny_fields_match(op, rule) if { - field := object.get(op, "fields", [])[_] - graphql_field_matches_any(field, object.get(rule, "fields", [])) -} - -graphql_field_matches_any(field, patterns) if { - pattern := patterns[_] - glob.match(pattern, [], field) -} - -# Wildcard "*" matches any method; otherwise case-insensitive exact match. -# RFC 9110 §9.3.2: HEAD is semantically identical to GET except no response body. -method_matches(_, "*") if true - -method_matches(actual, expected) if { - expected != "*" - upper(actual) == upper(expected) -} - -method_matches(actual, expected) if { - upper(actual) == "HEAD" - upper(expected) == "GET" -} - -# Path matching: "**" matches everything; otherwise glob.match with "/" delimiter. -# -# INVARIANT: `input.request.path` is canonicalized by the sandbox before -# policy evaluation — percent-decoded, dot-segments resolved, doubled -# slashes collapsed, `;params` stripped, `%2F` rejected (unless an -# endpoint opts in). Patterns here must therefore match canonical paths; -# do not attempt defensive matching against `..` or `%2e%2e` — those -# inputs are rejected at the L7 parser boundary before this rule runs. -path_matches(_, "**") if true - -path_matches(actual, pattern) if { - pattern != "**" - glob.match(pattern, ["/"], actual) -} - -# Query matching: -# - If no query rules are configured, allow any query params. -# - For configured keys, all request values for that key must match. -# - Matcher shape supports either `glob` or `any`. -query_params_match(request, rule) if { - query_rules := object.get(rule.allow, "query", {}) - not query_mismatch(request, query_rules) -} - -query_mismatch(request, query_rules) if { - some key - matcher := query_rules[key] - not query_key_matches(request, key, matcher) -} - -query_key_matches(request, key, matcher) if { - request_query := object.get(request, "query_params", {}) - values := object.get(request_query, key, null) - values != null - count(values) > 0 - not query_value_mismatch(values, matcher) -} - -query_value_mismatch(values, matcher) if { - some i - value := values[i] - not query_value_matches(value, matcher) -} - -query_value_matches(value, matcher) if { - is_string(matcher) - glob.match(matcher, [], value) -} - -query_value_matches(value, matcher) if { - is_object(matcher) - glob_pattern := object.get(matcher, "glob", "") - glob_pattern != "" - glob.match(glob_pattern, [], value) -} - -query_value_matches(value, matcher) if { - is_object(matcher) - any_patterns := object.get(matcher, "any", []) - count(any_patterns) > 0 - some i - glob.match(any_patterns[i], [], value) -} - -# JSON-RPC method and params matching. The sandbox flattens object params into -# dot-separated keys before policy evaluation, e.g. arguments.scope. -jsonrpc_rule_matches(request, rule) if { - jsonrpc := object.get(request, "jsonrpc", {}) - method := object.get(jsonrpc, "method", null) - method != null - glob.match(rule.rpc_method, [], method) - jsonrpc_params_match(jsonrpc, rule) -} - -jsonrpc_params_match(jsonrpc, rule) if { - param_rules := object.get(rule, "params", {}) - not jsonrpc_param_mismatch(jsonrpc, param_rules) -} - -jsonrpc_param_mismatch(jsonrpc, param_rules) if { - some key - matcher := param_rules[key] - not jsonrpc_param_key_matches(jsonrpc, key, matcher) -} - -jsonrpc_param_key_matches(jsonrpc, key, matcher) if { - params := object.get(jsonrpc, "params", {}) - value := object.get(params, key, null) - value != null - is_string(value) - query_value_matches(value, matcher) -} - -# SQL command matching: "*" matches any; otherwise case-insensitive. -command_matches(_, "*") if true - -command_matches(actual, expected) if { - expected != "*" - upper(actual) == upper(expected) -} - -# --- Matched endpoint config (for L7 and allowed_ips extraction) --- -# Returns the raw endpoint object for the matched policy + host:port. -# Used by Rust to extract L7 config (protocol, tls, enforcement, -# allow_encoded_slash) and/or allowed_ips for SSRF allowlist validation. - -# Per-policy helper: returns matching endpoint configs for a single policy. -_policy_endpoint_configs(policy) := [ep | - some ep - ep := policy.endpoints[_] - endpoint_matches_request(ep, input.network) - endpoint_has_extended_config(ep) -] - -# Collect matching endpoint configs across all policies. Iterates over -# _matching_policy_names (a set, safe from regorus variable collisions) -# then collects per-policy configs via the helper function. -_matching_endpoint_configs := [cfg | - some pname - _matching_policy_names[pname] - cfgs := _policy_endpoint_configs(data.network_policies[pname]) - cfg := cfgs[_] -] - -matched_endpoint_config := _matching_endpoint_configs[0] if { - count(_matching_endpoint_configs) > 0 -} - -_policy_has_exact_declared_endpoint(policy) if { - some ep - ep := policy.endpoints[_] - not object.get(ep, "advisor_proposed", false) - not contains(ep.host, "*") - lower(ep.host) == lower(input.network.host) - ep.ports[_] == input.network.port -} - -exact_declared_endpoint_host if { - some pname - policy := data.network_policies[pname] - user_declared_binary_allowed(policy, input.exec) - _policy_has_exact_declared_endpoint(policy) -} - -# Hosted endpoint: exact host match + port in ports list. -endpoint_matches_request(ep, network) if { - not contains(ep.host, "*") - lower(ep.host) == lower(network.host) - ep.ports[_] == network.port -} - -# Hosted endpoint: glob host match + port in ports list. -endpoint_matches_request(ep, network) if { - contains(ep.host, "*") - glob.match(lower(ep.host), ["."], lower(network.host)) - ep.ports[_] == network.port -} - -# Hostless endpoint with allowed_ips: match on port only. -endpoint_matches_request(ep, network) if { - object.get(ep, "host", "") == "" - count(object.get(ep, "allowed_ips", [])) > 0 - ep.ports[_] == network.port -} - -endpoint_matches_l7_request(ep, network, request) if { - endpoint_matches_request(ep, network) - endpoint_path_matches_request(ep, request) -} - -endpoint_path_matches_request(ep, request) if { - object.get(ep, "path", "") == "" -} - -endpoint_path_matches_request(ep, request) if { - path := object.get(ep, "path", "") - path != "" - path_matches(request.path, path) -} - -# An endpoint has extended config if it specifies L7 protocol, allowed_ips, -# or an explicit tls mode (e.g. tls: skip). -endpoint_has_extended_config(ep) if { - ep.protocol -} - -endpoint_has_extended_config(ep) if { - count(object.get(ep, "allowed_ips", [])) > 0 -} - -endpoint_has_extended_config(ep) if { - ep.tls -} diff --git a/crates/openshell-sandbox/src/l7/graphql.rs b/crates/openshell-sandbox/src/l7/graphql.rs deleted file mode 100644 index 77ec3b6fd..000000000 --- a/crates/openshell-sandbox/src/l7/graphql.rs +++ /dev/null @@ -1,637 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! GraphQL-over-HTTP L7 inspection. - -use crate::l7::provider::{L7Provider, L7Request}; -use apollo_parser::Parser; -use apollo_parser::cst; -use miette::{Result, miette}; -use serde::Serialize; -use serde_json::Value; -use std::collections::{HashMap, HashSet}; -use tokio::io::{AsyncRead, AsyncWrite}; - -pub const DEFAULT_MAX_BODY_BYTES: usize = 64 * 1024; - -#[derive(Debug, Clone, Serialize, PartialEq, Eq)] -pub struct GraphqlRequestInfo { - pub operations: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Clone, Serialize, PartialEq, Eq)] -pub struct GraphqlOperationInfo { - pub operation_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub operation_name: Option, - pub fields: Vec, - pub persisted_query: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub persisted_query_hash: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub persisted_query_id: Option, -} - -pub struct GraphqlHttpRequest { - pub request: L7Request, - pub info: GraphqlRequestInfo, -} - -pub async fn parse_graphql_http_request( - client: &mut C, - max_body_bytes: usize, - canonicalize_options: crate::l7::path::CanonicalizeOptions, -) -> Result> { - let provider = crate::l7::rest::RestProvider::with_options(canonicalize_options); - let Some(mut request) = provider.parse_request(client).await? else { - return Ok(None); - }; - - let info = inspect_graphql_request(client, &mut request, max_body_bytes).await?; - - Ok(Some(GraphqlHttpRequest { request, info })) -} - -pub(crate) async fn inspect_graphql_request( - client: &mut C, - request: &mut L7Request, - max_body_bytes: usize, -) -> Result { - let header_str = header_str(request)?; - reject_unsupported_headers(header_str)?; - let body = crate::l7::http::read_body_for_inspection(client, request, max_body_bytes).await?; - Ok(classify_request(request, &body)) -} - -pub fn classify_request(request: &L7Request, body: &[u8]) -> GraphqlRequestInfo { - match classify_request_inner(request, body) { - Ok(operations) => GraphqlRequestInfo { - operations, - error: None, - }, - Err(err) => GraphqlRequestInfo { - operations: Vec::new(), - error: Some(err), - }, - } -} - -pub fn classify_json_envelope_value(value: &Value) -> GraphqlRequestInfo { - match classify_json_envelope(value) { - Ok(operations) => GraphqlRequestInfo { - operations, - error: None, - }, - Err(err) => GraphqlRequestInfo { - operations: Vec::new(), - error: Some(err), - }, - } -} - -fn classify_request_inner( - request: &L7Request, - body: &[u8], -) -> std::result::Result, String> { - match request.action.to_ascii_uppercase().as_str() { - "GET" => classify_get(request), - "POST" => classify_post(body), - method => Err(format!("unsupported GraphQL HTTP method {method}")), - } -} - -fn classify_get(request: &L7Request) -> std::result::Result, String> { - let query = unique_query_value(&request.query_params, "query")?; - let operation_name = unique_query_value(&request.query_params, "operationName")?; - let extensions = unique_query_value(&request.query_params, "extensions")? - .and_then(|raw| serde_json::from_str::(&raw).ok()); - let id = unique_persisted_query_id(&request.query_params)?; - - classify_envelope( - query.as_deref(), - operation_name.as_deref(), - extensions.as_ref(), - id, - ) -} - -fn classify_post(body: &[u8]) -> std::result::Result, String> { - if body.is_empty() { - return Err("GraphQL POST body is empty".to_string()); - } - let value: Value = serde_json::from_slice(body) - .map_err(|err| format!("GraphQL request body is not valid JSON: {err}"))?; - - match value { - Value::Array(items) => { - if items.is_empty() { - return Err("GraphQL batch request is empty".to_string()); - } - let mut operations = Vec::new(); - for item in items { - operations.extend(classify_json_envelope(&item)?); - } - Ok(operations) - } - Value::Object(_) => classify_json_envelope(&value), - _ => Err("GraphQL JSON envelope must be an object or array".to_string()), - } -} - -fn classify_json_envelope(value: &Value) -> std::result::Result, String> { - let obj = value - .as_object() - .ok_or_else(|| "GraphQL batch item must be an object".to_string())?; - let query = obj.get("query").and_then(Value::as_str); - let operation_name = obj.get("operationName").and_then(Value::as_str); - let extensions = obj.get("extensions"); - let id = obj - .get("id") - .or_else(|| obj.get("documentId")) - .or_else(|| obj.get("queryId")) - .and_then(Value::as_str) - .map(ToString::to_string); - - classify_envelope(query, operation_name, extensions, id) -} - -fn classify_envelope( - query: Option<&str>, - operation_name: Option<&str>, - extensions: Option<&Value>, - persisted_id: Option, -) -> std::result::Result, String> { - let persisted_hash = persisted_query_hash(extensions); - let query = query.filter(|q| !q.trim().is_empty()); - - if let Some(query) = query { - let mut operation = classify_document(query, operation_name)?; - if let Some(hash) = persisted_hash { - operation.persisted_query = true; - operation.persisted_query_hash = Some(hash); - } - if let Some(id) = persisted_id { - operation.persisted_query = true; - operation.persisted_query_id = Some(id); - } - return Ok(vec![operation]); - } - - if persisted_hash.is_some() || persisted_id.is_some() { - return Ok(vec![GraphqlOperationInfo { - operation_type: String::new(), - operation_name: operation_name.map(ToString::to_string), - fields: Vec::new(), - persisted_query: true, - persisted_query_hash: persisted_hash, - persisted_query_id: persisted_id, - }]); - } - - Err("GraphQL request has no query document or persisted query identifier".to_string()) -} - -fn classify_document( - query: &str, - operation_name: Option<&str>, -) -> std::result::Result { - let parser = Parser::new(query).recursion_limit(128).token_limit(20_000); - let cst = parser.parse(); - let mut parse_errors = cst.errors(); - if let Some(err) = parse_errors.next() { - return Err(format!("GraphQL document parse error: {err}")); - } - - let document = cst.document(); - let mut operations = Vec::new(); - let mut fragments = HashMap::new(); - - for definition in document.definitions() { - match definition { - cst::Definition::OperationDefinition(operation) => operations.push(operation), - cst::Definition::FragmentDefinition(fragment) => { - if let Some(name) = fragment.fragment_name().and_then(|n| n.name()) { - fragments.insert(name.text().to_string(), fragment); - } - } - _ => {} - } - } - - if operations.is_empty() { - return Err("GraphQL document contains no executable operation".to_string()); - } - - let selected = if let Some(expected_name) = operation_name.filter(|name| !name.is_empty()) { - operations - .into_iter() - .find(|op| { - op.name() - .is_some_and(|name| name.text().as_ref() == expected_name) - }) - .ok_or_else(|| format!("GraphQL operationName {expected_name:?} was not found"))? - } else if operations.len() == 1 { - operations.remove(0) - } else { - return Err("GraphQL document has multiple operations but no operationName".to_string()); - }; - - let operation_type = operation_type(&selected); - let operation_name = selected.name().map(|name| name.text().to_string()); - let selection_set = selected - .selection_set() - .ok_or_else(|| "GraphQL operation has no selection set".to_string())?; - let mut fields = HashSet::new(); - let mut visited_fragments = HashSet::new(); - collect_root_fields( - selection_set, - &fragments, - &mut visited_fragments, - &mut fields, - ); - let mut fields: Vec<_> = fields.into_iter().collect(); - fields.sort(); - - Ok(GraphqlOperationInfo { - operation_type, - operation_name, - fields, - persisted_query: false, - persisted_query_hash: None, - persisted_query_id: None, - }) -} - -fn operation_type(operation: &cst::OperationDefinition) -> String { - let Some(operation_type) = operation.operation_type() else { - return "query".to_string(); - }; - if operation_type.mutation_token().is_some() { - "mutation".to_string() - } else if operation_type.subscription_token().is_some() { - "subscription".to_string() - } else { - "query".to_string() - } -} - -fn collect_root_fields( - selection_set: cst::SelectionSet, - fragments: &HashMap, - visited_fragments: &mut HashSet, - fields: &mut HashSet, -) { - for selection in selection_set.selections() { - match selection { - cst::Selection::Field(field) => { - if let Some(name) = field.name() { - fields.insert(name.text().to_string()); - } - } - cst::Selection::InlineFragment(fragment) => { - if let Some(selection_set) = fragment.selection_set() { - collect_root_fields(selection_set, fragments, visited_fragments, fields); - } - } - cst::Selection::FragmentSpread(spread) => { - let Some(name) = spread.fragment_name().and_then(|n| n.name()) else { - continue; - }; - let name = name.text().to_string(); - if !visited_fragments.insert(name.clone()) { - continue; - } - if let Some(fragment) = fragments.get(&name) - && let Some(selection_set) = fragment.selection_set() - { - collect_root_fields(selection_set, fragments, visited_fragments, fields); - } - } - } - } -} - -fn persisted_query_hash(extensions: Option<&Value>) -> Option { - extensions? - .get("persistedQuery")? - .get("sha256Hash")? - .as_str() - .filter(|hash| !hash.is_empty()) - .map(ToString::to_string) -} - -fn unique_query_value( - params: &HashMap>, - key: &str, -) -> std::result::Result, String> { - let Some(values) = params.get(key) else { - return Ok(None); - }; - if values.len() > 1 { - return Err(format!( - "GraphQL GET parameter {key:?} must not appear more than once" - )); - } - Ok(values.first().filter(|value| !value.is_empty()).cloned()) -} - -fn unique_persisted_query_id( - params: &HashMap>, -) -> std::result::Result, String> { - let mut selected: Option<(String, String)> = None; - for key in ["id", "documentId", "queryId"] { - let Some(value) = unique_query_value(params, key)? else { - continue; - }; - if let Some((existing_key, _)) = selected { - return Err(format!( - "GraphQL GET persisted-query id parameters {existing_key:?} and {key:?} must not be combined" - )); - } - selected = Some((key.to_string(), value)); - } - Ok(selected.map(|(_, value)| value)) -} - -fn header_str(request: &L7Request) -> Result<&str> { - let header_end = request - .raw_header - .windows(4) - .position(|w| w == b"\r\n\r\n") - .map_or(request.raw_header.len(), |p| p + 4); - std::str::from_utf8(&request.raw_header[..header_end]) - .map_err(|_| miette!("GraphQL HTTP headers contain invalid UTF-8")) -} - -fn reject_unsupported_headers(headers: &str) -> Result<()> { - for line in headers.lines().skip(1) { - let lower = line.to_ascii_lowercase(); - if lower.starts_with("content-encoding:") { - let encoding = lower.split_once(':').map_or("", |(_, v)| v.trim()); - if !encoding.is_empty() && encoding != "identity" { - return Err(miette!( - "GraphQL request content-encoding {encoding:?} is not supported" - )); - } - } - if lower.starts_with("content-type:") { - let content_type = lower.split_once(':').map_or("", |(_, v)| v.trim()); - if content_type.starts_with("multipart/") { - return Err(miette!("GraphQL multipart requests are not supported")); - } - } - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::l7::provider::BodyLength; - - fn request(method: &str, target: &str) -> L7Request { - L7Request { - action: method.to_string(), - target: target.to_string(), - query_params: crate::l7::rest::parse_target_query(target).unwrap().1, - raw_header: format!("{method} {target} HTTP/1.1\r\nHost: example.com\r\n\r\n") - .into_bytes(), - body_length: BodyLength::None, - } - } - - #[test] - fn classifies_simple_query() { - let req = request("POST", "/graphql"); - let info = classify_request(&req, br#"{"query":"query Viewer { viewer { login } }"}"#); - assert_eq!(info.error, None); - assert_eq!(info.operations[0].operation_type, "query"); - assert_eq!(info.operations[0].fields, vec!["viewer"]); - } - - #[test] - fn classifies_mutation_field_not_alias() { - let req = request("POST", "/graphql"); - let info = classify_request( - &req, - br#"{"query":"mutation M { safeAlias: volumeDelete(volumeId:\"x\") { id } }","operationName":"M"}"#, - ); - assert_eq!(info.error, None); - assert_eq!(info.operations[0].operation_type, "mutation"); - assert_eq!(info.operations[0].operation_name.as_deref(), Some("M")); - assert_eq!(info.operations[0].fields, vec!["volumeDelete"]); - } - - #[test] - fn expands_root_fragments() { - let req = request("POST", "/graphql"); - let info = classify_request( - &req, - br#"{"query":"query Q { ...RootFields } fragment RootFields on Query { viewer repository(owner:\"o\", name:\"r\") { id } }"}"#, - ); - assert_eq!(info.error, None); - assert_eq!(info.operations[0].fields, vec!["repository", "viewer"]); - } - - #[test] - fn multiple_operations_without_name_errors() { - let req = request("POST", "/graphql"); - let info = classify_request( - &req, - br#"{"query":"query A { viewer { login } } query B { rateLimit { limit } }"}"#, - ); - assert!(info.error.unwrap().contains("multiple operations")); - } - - #[test] - fn detects_hash_only_apollo_persisted_query() { - let req = request("POST", "/graphql"); - let info = classify_request( - &req, - br#"{"operationName":"Viewer","extensions":{"persistedQuery":{"version":1,"sha256Hash":"abc123"}}}"#, - ); - assert_eq!(info.error, None); - let op = &info.operations[0]; - assert!(op.persisted_query); - assert_eq!(op.operation_name.as_deref(), Some("Viewer")); - assert_eq!(op.persisted_query_hash.as_deref(), Some("abc123")); - } - - #[test] - fn graphql_get_rejects_duplicate_query_parameter() { - let req = request( - "GET", - "/graphql?query=query+Viewer+%7B+viewer+%7B+login+%7D+%7D&query=mutation+Delete+%7B+volumeDelete(volumeId%3A%22x%22)+%7B+id+%7D+%7D", - ); - let info = classify_request(&req, b""); - assert!( - info.error - .as_deref() - .is_some_and(|err| err.contains("must not appear more than once")), - "expected duplicate control parameter error, got {info:?}" - ); - } - - #[test] - fn graphql_get_rejects_ambiguous_persisted_query_ids() { - let req = request("GET", "/graphql?id=one&queryId=two"); - let info = classify_request(&req, b""); - assert!( - info.error - .as_deref() - .is_some_and(|err| err.contains("must not be combined")), - "expected ambiguous persisted-query id error, got {info:?}" - ); - } - - #[tokio::test] - async fn chunked_graphql_post_is_normalized_after_inspection() { - let body = br#"{"query":"query Viewer { viewer { login } }"}"#; - let mut raw_header = - b"POST /graphql HTTP/1.1\r\nHost: example.com\r\nTransfer-Encoding: chunked\r\nTrailer: X-Sig\r\nX-Test: yes\r\n\r\n" - .to_vec(); - raw_header.extend_from_slice(format!("{:x}\r\n", body.len()).as_bytes()); - raw_header.extend_from_slice(body); - raw_header.extend_from_slice(b"\r\n0\r\nX-Sig: ignored\r\n\r\n"); - - let mut req = L7Request { - action: "POST".to_string(), - target: "/graphql".to_string(), - query_params: HashMap::new(), - raw_header, - body_length: BodyLength::Chunked, - }; - let mut client = tokio::io::empty(); - - let info = inspect_graphql_request(&mut client, &mut req, DEFAULT_MAX_BODY_BYTES) - .await - .expect("chunked body should inspect"); - - assert_eq!(info.error, None); - assert!(matches!( - req.body_length, - BodyLength::ContentLength(len) if len == body.len() as u64 - )); - let forwarded = String::from_utf8_lossy(&req.raw_header); - assert!(forwarded.contains(&format!("Content-Length: {}", body.len()))); - assert!(forwarded.contains("X-Test: yes\r\n")); - assert!(!forwarded.to_ascii_lowercase().contains("transfer-encoding")); - assert!(!forwarded.to_ascii_lowercase().contains("trailer:")); - assert!(req.raw_header.ends_with(body)); - } - - #[tokio::test] - async fn absolute_form_chunked_graphql_post_classifies_after_inspection() { - let body = br#"{"query":"query Viewer { viewer { login } }"}"#; - let mut raw_header = - b"POST http://example.com/graphql HTTP/1.1\r\nHost: example.com\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\nConnection: close\r\n\r\n" - .to_vec(); - raw_header.extend_from_slice(format!("{:x}\r\n", body.len()).as_bytes()); - raw_header.extend_from_slice(body); - raw_header.extend_from_slice(b"\r\n0\r\n\r\n"); - - let mut req = L7Request { - action: "POST".to_string(), - target: "/graphql".to_string(), - query_params: HashMap::new(), - raw_header, - body_length: BodyLength::Chunked, - }; - let mut client = tokio::io::empty(); - - let info = inspect_graphql_request(&mut client, &mut req, DEFAULT_MAX_BODY_BYTES) - .await - .expect("absolute-form chunked body should inspect"); - - assert_eq!(info.error, None); - assert_eq!(info.operations[0].operation_type, "query"); - assert_eq!(info.operations[0].fields, vec!["viewer"]); - } - - #[tokio::test] - async fn absolute_form_chunked_graphql_post_is_allowed_by_field_policy() { - let body = br#"{"query":"query Viewer { viewer { login } }"}"#; - let mut raw_header = - b"POST http://host.openshell.internal:8080/graphql HTTP/1.1\r\nHost: host.openshell.internal:8080\r\nContent-Type: application/json\r\nTransfer-Encoding: chunked\r\nConnection: close\r\n\r\n" - .to_vec(); - raw_header.extend_from_slice(format!("{:x}\r\n", body.len()).as_bytes()); - raw_header.extend_from_slice(body); - raw_header.extend_from_slice(b"\r\n0\r\n\r\n"); - - let mut req = L7Request { - action: "POST".to_string(), - target: "/graphql".to_string(), - query_params: HashMap::new(), - raw_header, - body_length: BodyLength::Chunked, - }; - let mut client = tokio::io::empty(); - let info = inspect_graphql_request(&mut client, &mut req, DEFAULT_MAX_BODY_BYTES) - .await - .expect("chunked body should inspect"); - - let data = r" -network_policies: - test_graphql_l7: - name: test_graphql_l7 - endpoints: - - host: host.openshell.internal - port: 8080 - protocol: graphql - enforcement: enforce - persisted_queries: allow_registered - graphql_persisted_queries: - abc123: - operation_type: query - operation_name: Viewer - fields: [viewer] - rules: - - allow: - operation_type: query - fields: [viewer] - - allow: - operation_type: mutation - fields: [createIssue] - deny_rules: - - operation_type: mutation - fields: [deleteRepository] - binaries: - - { path: /usr/bin/python3 } -"; - let engine = crate::opa::OpaEngine::from_strings( - include_str!("../../data/sandbox-policy.rego"), - data, - ) - .expect("policy should load"); - let ctx = crate::l7::relay::L7EvalContext { - host: "host.openshell.internal".to_string(), - port: 8080, - policy_name: "test_graphql_l7".to_string(), - binary_path: "/usr/bin/python3".to_string(), - ancestors: Vec::new(), - cmdline_paths: Vec::new(), - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - let request_info = crate::l7::L7RequestInfo { - action: req.action, - target: req.target, - query_params: req.query_params, - graphql: Some(info), - jsonrpc: None, - }; - - let tunnel_engine = engine - .clone_engine_for_tunnel(engine.current_generation()) - .expect("tunnel engine should clone"); - let (allowed, reason) = - crate::l7::relay::evaluate_l7_request(&tunnel_engine, &ctx, &request_info) - .expect("evaluation should complete"); - - assert!(allowed, "expected query to be allowed, got {reason}"); - } -} diff --git a/crates/openshell-sandbox/src/l7/mod.rs b/crates/openshell-sandbox/src/l7/mod.rs deleted file mode 100644 index 14fee09da..000000000 --- a/crates/openshell-sandbox/src/l7/mod.rs +++ /dev/null @@ -1,2489 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! L7 protocol-aware inspection for the CONNECT proxy. -//! -//! When an endpoint is configured with a `protocol` field (e.g. `rest`, `sql`), -//! the proxy inspects application-layer traffic within the tunnel instead of -//! doing a raw `copy_bidirectional`. Each request within the tunnel is parsed, -//! evaluated against OPA policy, and either forwarded or denied. - -pub mod graphql; -pub(crate) mod http; -pub mod inference; -pub mod jsonrpc; -pub mod path; -pub mod provider; -pub mod relay; -pub mod rest; -pub mod tls; -pub(crate) mod token_grant_injection; -pub(crate) mod websocket; - -/// Application-layer protocol for L7 inspection. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum L7Protocol { - Rest, - Websocket, - Graphql, - Sql, - JsonRpc, -} - -impl L7Protocol { - pub fn parse(s: &str) -> Option { - match s.to_ascii_lowercase().as_str() { - "rest" => Some(Self::Rest), - "websocket" => Some(Self::Websocket), - "graphql" => Some(Self::Graphql), - "sql" => Some(Self::Sql), - "json-rpc" => Some(Self::JsonRpc), - _ => None, - } - } -} - -/// TLS handling mode for proxy connections. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum TlsMode { - /// Auto-detect TLS by peeking the first bytes. If TLS is detected, - /// terminate it transparently. This is the default for all endpoints. - #[default] - Auto, - /// Explicit opt-out: raw tunnel with no TLS termination and no credential - /// injection. Use for client-cert mTLS to upstream or non-standard protocols. - Skip, -} - -/// Enforcement mode for L7 policy decisions. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -pub enum EnforcementMode { - /// Log violations but allow traffic through (safe migration path). - #[default] - Audit, - /// Deny violations — blocked requests never reach upstream. - Enforce, -} - -/// L7 configuration for an endpoint, extracted from policy data. -#[allow( - clippy::struct_excessive_bools, - reason = "Endpoint config mirrors independent policy schema toggles." -)] -#[derive(Debug, Clone)] -pub struct L7EndpointConfig { - pub protocol: L7Protocol, - /// Optional endpoint-level HTTP path glob used to select between L7 - /// protocols that share the same host:port. - pub path: String, - pub tls: TlsMode, - pub enforcement: EnforcementMode, - /// Maximum GraphQL request body bytes to buffer for inspection. - pub graphql_max_body_bytes: usize, - /// Maximum JSON-RPC request body bytes to buffer for inspection. - pub json_rpc_max_body_bytes: usize, - /// When true, percent-encoded `/` (`%2F`) is preserved in path segments - /// rather than rejected at the parser. Needed by upstreams like GitLab - /// that embed `%2F` in namespaced project paths. Defaults to false. - pub allow_encoded_slash: bool, - /// Opt-in rewrite of credential placeholders in client-to-server - /// WebSocket text messages after an allowed HTTP 101 upgrade. - pub websocket_credential_rewrite: bool, - /// Opt-in rewrite of credential placeholders in supported textual REST - /// request bodies before forwarding upstream. - pub request_body_credential_rewrite: bool, - /// When true, client-to-server GraphQL-over-WebSocket operation messages - /// are classified with the same operation policy used by GraphQL-over-HTTP. - pub websocket_graphql_policy: bool, -} - -/// Result of an L7 policy decision for a single request. -#[derive(Debug, Clone)] -pub struct L7Decision { - pub allowed: bool, - pub reason: String, - pub matched_rule: Option, -} - -/// Parsed L7 request metadata used for policy evaluation and logging. -#[derive(Debug, Clone)] -pub struct L7RequestInfo { - /// Protocol action: HTTP method (GET, POST, ...) or SQL command (SELECT, INSERT, ...). - pub action: String, - /// Target: URL path for REST, or empty for SQL. - pub target: String, - /// Decoded query parameter multimap for REST requests. - pub query_params: std::collections::HashMap>, - /// Parsed GraphQL operation metadata for GraphQL endpoints. - pub graphql: Option, - /// Parsed JSON-RPC request metadata for JSON-RPC endpoints. - pub jsonrpc: Option, -} - -/// Parse an L7 endpoint config from a regorus Value (returned by Rego query). -/// -/// The value is expected to be the raw endpoint object from the Rego data, -/// containing fields: `protocol`, optionally `tls`, `enforcement`. -pub fn parse_l7_config(val: ®orus::Value) -> Option { - let protocol_val = get_object_str(val, "protocol")?; - let protocol = L7Protocol::parse(&protocol_val)?; - - let tls = match get_object_str(val, "tls").as_deref() { - Some("skip") => TlsMode::Skip, - Some("terminate") => { - let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(openshell_ocsf::ActivityId::Other) - .severity(openshell_ocsf::SeverityId::Medium) - .message( - "'tls: terminate' is deprecated; TLS termination is now automatic. \ - Use 'tls: skip' to explicitly disable. This field will be removed in a future version.", - ) - .build(); - openshell_ocsf::ocsf_emit!(event); - TlsMode::Auto - } - Some("passthrough") => { - let event = openshell_ocsf::NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(openshell_ocsf::ActivityId::Other) - .severity(openshell_ocsf::SeverityId::Medium) - .message( - "'tls: passthrough' is deprecated; TLS termination is now automatic. \ - Use 'tls: skip' to explicitly disable. This field will be removed in a future version.", - ) - .build(); - openshell_ocsf::ocsf_emit!(event); - TlsMode::Auto - } - _ => TlsMode::Auto, - }; - - let enforcement = match get_object_str(val, "enforcement").as_deref() { - Some("enforce") => EnforcementMode::Enforce, - _ => EnforcementMode::Audit, - }; - - let allow_encoded_slash = get_object_bool(val, "allow_encoded_slash").unwrap_or(false); - let websocket_credential_rewrite = - get_object_bool(val, "websocket_credential_rewrite").unwrap_or(false); - let request_body_credential_rewrite = - get_object_bool(val, "request_body_credential_rewrite").unwrap_or(false); - let websocket_graphql_policy = - protocol == L7Protocol::Websocket && endpoint_has_graphql_policy(val); - let graphql_max_body_bytes = get_object_u64(val, "graphql_max_body_bytes") - .and_then(|v| usize::try_from(v).ok()) - .filter(|v| *v > 0) - .unwrap_or(graphql::DEFAULT_MAX_BODY_BYTES); - let json_rpc_max_body_bytes = get_object_u64(val, "json_rpc_max_body_bytes") - .and_then(|v| usize::try_from(v).ok()) - .filter(|v| *v > 0) - .unwrap_or(jsonrpc::DEFAULT_MAX_BODY_BYTES); - - Some(L7EndpointConfig { - protocol, - path: get_object_str(val, "path").unwrap_or_default(), - tls, - enforcement, - graphql_max_body_bytes, - json_rpc_max_body_bytes, - allow_encoded_slash, - websocket_credential_rewrite, - request_body_credential_rewrite, - websocket_graphql_policy, - }) -} - -impl L7EndpointConfig { - pub fn matches_path(&self, path: &str) -> bool { - endpoint_path_matches(&self.path, path) - } - - pub fn path_specificity(&self) -> usize { - if self.path.is_empty() { - 0 - } else { - self.path.chars().filter(|c| *c != '*').count() - } - } -} - -pub fn endpoint_path_matches(pattern: &str, path: &str) -> bool { - if pattern.is_empty() || pattern == "**" || pattern == "/**" { - return true; - } - if pattern == path { - return true; - } - if let Some(prefix) = pattern.strip_suffix("/**") { - return path == prefix || path.starts_with(&format!("{prefix}/")); - } - glob::Pattern::new(pattern).is_ok_and(|glob| glob.matches(path)) -} - -/// Parse the `tls` field from an endpoint config, independent of L7 protocol. -/// -/// Used to check for `tls: skip` even on L4-only endpoints (no `protocol` -/// field) that explicitly opt out of TLS auto-detection. -pub fn parse_tls_mode(val: ®orus::Value) -> TlsMode { - match get_object_str(val, "tls").as_deref() { - Some("skip") => TlsMode::Skip, - // "terminate" and "passthrough" are deprecated aliases (logged by parse_l7_config); fall through to Auto. - _ => TlsMode::Auto, - } -} - -/// Extract a bool value from a regorus object. Returns `None` when the key -/// is absent or not a boolean. -fn get_object_bool(val: ®orus::Value, key: &str) -> Option { - let key_val = regorus::Value::String(key.into()); - match val { - regorus::Value::Object(map) => match map.get(&key_val) { - Some(regorus::Value::Bool(b)) => Some(*b), - _ => None, - }, - _ => None, - } -} - -fn get_object_u64(val: ®orus::Value, key: &str) -> Option { - let key_val = regorus::Value::String(key.into()); - match val { - regorus::Value::Object(map) => match map.get(&key_val) { - Some(regorus::Value::Number(n)) => n.as_u64(), - _ => None, - }, - _ => None, - } -} - -/// Extract a string value from a regorus object. -fn get_object_str(val: ®orus::Value, key: &str) -> Option { - let key_val = regorus::Value::String(key.into()); - match val { - regorus::Value::Object(map) => match map.get(&key_val) { - Some(regorus::Value::String(s)) => { - let s = s.to_string(); - if s.is_empty() { None } else { Some(s) } - } - _ => None, - }, - _ => None, - } -} - -fn endpoint_has_graphql_policy(val: ®orus::Value) -> bool { - has_non_empty_object_field(val, "graphql_persisted_queries") - || has_graphql_persisted_query_mode(val) - || rules_have_graphql_policy(val, "rules", true) - || rules_have_graphql_policy(val, "deny_rules", false) -} - -fn rules_have_graphql_policy(val: ®orus::Value, key: &str, allow_wrapped: bool) -> bool { - let Some(regorus::Value::Array(rules)) = get_object_value(val, key) else { - return false; - }; - rules.iter().any(|rule| { - let rule = if allow_wrapped { - get_object_value(rule, "allow").unwrap_or(rule) - } else { - rule - }; - has_graphql_rule_fields(rule) - }) -} - -fn has_graphql_rule_fields(val: ®orus::Value) -> bool { - has_non_empty_string_field(val, "operation_type") - || has_non_empty_string_field(val, "operation_name") - || has_non_empty_array_field(val, "fields") -} - -fn has_non_empty_string_field(val: ®orus::Value, key: &str) -> bool { - matches!(get_object_value(val, key), Some(regorus::Value::String(s)) if !s.is_empty()) -} - -fn has_non_empty_array_field(val: ®orus::Value, key: &str) -> bool { - matches!(get_object_value(val, key), Some(regorus::Value::Array(values)) if !values.is_empty()) -} - -fn has_non_empty_object_field(val: ®orus::Value, key: &str) -> bool { - matches!(get_object_value(val, key), Some(regorus::Value::Object(values)) if !values.is_empty()) -} - -fn has_graphql_persisted_query_mode(val: ®orus::Value) -> bool { - matches!( - get_object_value(val, "persisted_queries"), - Some(regorus::Value::String(mode)) if !mode.is_empty() && mode.as_ref() != "deny" - ) -} - -fn get_object_value<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { - let key_val = regorus::Value::String(key.into()); - match val { - regorus::Value::Object(map) => map.get(&key_val), - _ => None, - } -} - -/// Check a glob pattern for obvious syntax issues. -/// -/// Returns `Some(warning_message)` if the pattern looks malformed. -/// OPA's `glob.match` is forgiving, so these are warnings (not errors) -/// to surface likely typos without blocking policy loading. -fn check_glob_syntax(pattern: &str) -> Option { - let mut bracket_depth: i32 = 0; - for c in pattern.chars() { - match c { - '[' => bracket_depth += 1, - ']' => { - if bracket_depth == 0 { - return Some(format!("glob pattern '{pattern}' has unmatched ']'")); - } - bracket_depth -= 1; - } - _ => {} - } - } - if bracket_depth > 0 { - return Some(format!("glob pattern '{pattern}' has unclosed '['")); - } - - let mut brace_depth: i32 = 0; - for c in pattern.chars() { - match c { - '{' => brace_depth += 1, - '}' => { - if brace_depth == 0 { - return Some(format!("glob pattern '{pattern}' has unmatched '}}'")); - } - brace_depth -= 1; - } - _ => {} - } - } - if brace_depth > 0 { - return Some(format!("glob pattern '{pattern}' has unclosed '{{'")); - } - - None -} - -fn validate_host_wildcard(errors: &mut Vec, loc: &str, host: &str) { - if !host.contains('*') { - return; - } - - if host == "*" || host == "**" { - errors.push(format!( - "{loc}: host wildcard '{host}' matches all hosts; use specific patterns like '*.example.com'" - )); - return; - } - - let labels: Vec<&str> = host.split('.').collect(); - let first_label = labels.first().copied().unwrap_or_default(); - if labels.iter().skip(1).any(|label| label.contains('*')) { - errors.push(format!( - "{loc}: host wildcard may only appear in the first DNS label, got '{host}'" - )); - return; - } - if first_label.contains("**") && first_label != "**" { - errors.push(format!( - "{loc}: recursive host wildcard '**' is only allowed as the entire first DNS label, got '{host}'" - )); - return; - } - - // Reject TLD or single-label wildcards. They are accepted by the policy - // engine but silently fail at the proxy layer (see #787). - if labels.len() <= 2 { - errors.push(format!( - "{loc}: TLD wildcard '{host}' is not allowed; \ - use subdomain wildcards like '*.example.com' instead" - )); - } -} - -fn validate_graphql_operation_type( - errors: &mut Vec, - loc: &str, - value: Option<&str>, - required: bool, -) { - let Some(value) = value.filter(|v| !v.is_empty()) else { - if required { - errors.push(format!( - "{loc}.operation_type: required for GraphQL L7 rules" - )); - } - return; - }; - - let valid = ["query", "mutation", "subscription", "*"]; - if !valid.contains(&value.to_ascii_lowercase().as_str()) { - errors.push(format!( - "{loc}.operation_type: expected query, mutation, subscription, or *, got '{value}'" - )); - } -} - -fn validate_graphql_fields( - errors: &mut Vec, - warnings: &mut Vec, - loc: &str, - fields: Option<&serde_json::Value>, -) { - let Some(fields) = fields else { - return; - }; - let Some(items) = fields.as_array() else { - errors.push(format!( - "{loc}.fields: expected array of GraphQL root field globs" - )); - return; - }; - if items.is_empty() { - errors.push(format!( - "{loc}.fields: list must not be empty; omit fields to match all root fields" - )); - return; - } - for item in items { - let Some(field) = item.as_str() else { - errors.push(format!("{loc}.fields: all values must be strings")); - continue; - }; - if field.is_empty() { - errors.push(format!("{loc}.fields: field glob must not be empty")); - } else if let Some(warning) = check_glob_syntax(field) { - warnings.push(format!("{loc}.fields: {warning}")); - } - } -} - -fn validate_graphql_rule( - errors: &mut Vec, - warnings: &mut Vec, - loc: &str, - rule: &serde_json::Value, - required: bool, -) { - validate_graphql_operation_type( - errors, - loc, - rule.get("operation_type").and_then(|v| v.as_str()), - required, - ); - if let Some(name) = rule.get("operation_name").and_then(|v| v.as_str()) - && !name.is_empty() - && let Some(warning) = check_glob_syntax(name) - { - warnings.push(format!("{loc}.operation_name: {warning}")); - } - validate_graphql_fields(errors, warnings, loc, rule.get("fields")); -} - -fn json_rule_has_graphql_fields(rule: &serde_json::Value) -> bool { - rule.get("operation_type") - .and_then(|v| v.as_str()) - .is_some_and(|v| !v.is_empty()) - || rule - .get("operation_name") - .and_then(|v| v.as_str()) - .is_some_and(|v| !v.is_empty()) - || rule.get("fields").is_some() -} - -fn json_rule_has_transport_fields(rule: &serde_json::Value) -> bool { - rule.get("method").is_some() || rule.get("path").is_some() || rule.get("query").is_some() -} - -fn json_endpoint_has_graphql_policy(ep: &serde_json::Value) -> bool { - ep.get("graphql_persisted_queries") - .and_then(|v| v.as_object()) - .is_some_and(|v| !v.is_empty()) - || ep - .get("persisted_queries") - .and_then(|v| v.as_str()) - .is_some_and(|v| !v.is_empty() && v != "deny") - || ep - .get("rules") - .and_then(|v| v.as_array()) - .is_some_and(|rules| { - rules.iter().any(|rule| { - rule.get("allow") - .or(Some(rule)) - .is_some_and(json_rule_has_graphql_fields) - }) - }) - || ep - .get("deny_rules") - .and_then(|v| v.as_array()) - .is_some_and(|rules| rules.iter().any(json_rule_has_graphql_fields)) -} - -/// Validate L7 policy configuration in the loaded OPA data. -/// -/// Returns a list of errors and warnings. Errors should prevent sandbox startup; -/// warnings are logged but don't block. -pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec) { - let mut errors = Vec::new(); - let mut warnings = Vec::new(); - - let Some(policies) = data_json - .get("network_policies") - .and_then(|v| v.as_object()) - else { - return (errors, warnings); - }; - - for (name, policy) in policies { - let Some(endpoints) = policy.get("endpoints").and_then(|v| v.as_array()) else { - continue; - }; - - for (i, ep) in endpoints.iter().enumerate() { - let protocol = ep.get("protocol").and_then(|v| v.as_str()).unwrap_or(""); - let tls = ep.get("tls").and_then(|v| v.as_str()).unwrap_or(""); - let enforcement = ep.get("enforcement").and_then(|v| v.as_str()).unwrap_or(""); - let access = ep.get("access").and_then(|v| v.as_str()).unwrap_or(""); - let has_rules = ep - .get("rules") - .and_then(|v| v.as_array()) - .is_some_and(|a| !a.is_empty()); - let websocket_has_graphql_policy = - protocol == "websocket" && json_endpoint_has_graphql_policy(ep); - let host = ep.get("host").and_then(|v| v.as_str()).unwrap_or(""); - let endpoint_path = ep.get("path").and_then(|v| v.as_str()).unwrap_or(""); - - // Read ports from either "ports" array or scalar "port". - let ports: Vec = ep.get("ports").and_then(|v| v.as_array()).map_or_else( - || { - ep.get("port") - .and_then(serde_json::Value::as_u64) - .filter(|p| *p > 0) - .into_iter() - .collect() - }, - |arr| arr.iter().filter_map(serde_json::Value::as_u64).collect(), - ); - let loc = format!("{name}.endpoints[{i}]"); - - if !endpoint_path.is_empty() { - if !endpoint_path.starts_with('/') && endpoint_path != "**" { - errors.push(format!( - "{loc}: endpoint path must start with '/' or be '**', got '{endpoint_path}'" - )); - } - if let Some(warning) = check_glob_syntax(endpoint_path) { - warnings.push(format!("{loc}.path: {warning}")); - } - } - - validate_host_wildcard(&mut errors, &loc, host); - - // port + ports mutual exclusion - let has_scalar_port = ep - .get("port") - .and_then(serde_json::Value::as_u64) - .is_some_and(|p| p > 0); - let has_ports_array = ep - .get("ports") - .and_then(|v| v.as_array()) - .is_some_and(|a| !a.is_empty()); - if has_scalar_port && has_ports_array { - errors.push(format!( - "{loc}: port and ports are mutually exclusive; use ports for multiple ports" - )); - } - - // rules + access mutual exclusion - if has_rules && !access.is_empty() { - errors.push(format!("{loc}: rules and access are mutually exclusive")); - } - - // protocol requires rules or access - if !protocol.is_empty() && !has_rules && access.is_empty() { - errors.push(format!( - "{loc}: protocol requires rules or access to define allowed traffic" - )); - } - - if !protocol.is_empty() && L7Protocol::parse(protocol).is_none() { - errors.push(format!( - "{loc}: unknown protocol '{protocol}' (expected rest, websocket, graphql, sql, or json-rpc)" - )); - } - - if let Some(mode) = ep.get("persisted_queries").and_then(|v| v.as_str()) - && !mode.is_empty() - && mode != "deny" - && mode != "allow_registered" - { - errors.push(format!( - "{loc}: persisted_queries must be 'deny' or 'allow_registered', got '{mode}'" - )); - } - - if ep.get("graphql_max_body_bytes").is_some() { - let valid_max = ep - .get("graphql_max_body_bytes") - .and_then(serde_json::Value::as_u64) - .is_some_and(|v| v > 0); - if !valid_max { - errors.push(format!( - "{loc}: graphql_max_body_bytes must be a positive integer" - )); - } - } - - if ep.get("json_rpc_max_body_bytes").is_some() { - let valid_max = ep - .get("json_rpc_max_body_bytes") - .and_then(serde_json::Value::as_u64) - .is_some_and(|v| v > 0); - if !valid_max { - errors.push(format!( - "{loc}: json_rpc_max_body_bytes must be a positive integer" - )); - } - } - - if protocol != "graphql" - && protocol != "websocket" - && (ep.get("persisted_queries").is_some() - || ep.get("graphql_persisted_queries").is_some() - || ep.get("graphql_max_body_bytes").is_some()) - { - warnings.push(format!( - "{loc}: GraphQL-specific endpoint fields are ignored unless protocol is graphql or websocket" - )); - } - - if protocol != "json-rpc" && ep.get("json_rpc_max_body_bytes").is_some() { - warnings.push(format!( - "{loc}: JSON-RPC-specific endpoint fields are ignored unless protocol is json-rpc" - )); - } - - if ep - .get("websocket_credential_rewrite") - .and_then(serde_json::Value::as_bool) - .unwrap_or(false) - && protocol != "rest" - && protocol != "websocket" - { - warnings.push(format!( - "{loc}: websocket_credential_rewrite is ignored unless protocol is rest or websocket" - )); - } - - if ep - .get("request_body_credential_rewrite") - .and_then(serde_json::Value::as_bool) - .unwrap_or(false) - && protocol != "rest" - { - warnings.push(format!( - "{loc}: request_body_credential_rewrite is ignored unless protocol is rest" - )); - } - - if let Some(registry_value) = ep.get("graphql_persisted_queries") { - let Some(registry) = registry_value.as_object() else { - errors.push(format!( - "{loc}: graphql_persisted_queries must be a map keyed by hash or saved-query id" - )); - continue; - }; - for (key, op) in registry { - let registry_loc = format!("{loc}.graphql_persisted_queries[{key}]"); - validate_graphql_rule(&mut errors, &mut warnings, ®istry_loc, op, true); - } - } - - // Deprecated tls values: warn but don't error - if tls == "terminate" || tls == "passthrough" { - warnings.push(format!( - "{loc}: 'tls: {tls}' is deprecated; TLS termination is now automatic. Use 'tls: skip' to disable." - )); - } - - // tls: skip with L7 on port 443 won't work - if tls == "skip" && !protocol.is_empty() && ports.contains(&443) { - warnings.push(format!( - "{loc}: 'tls: skip' with L7 rules on port 443 — L7 inspection cannot work on encrypted traffic" - )); - } - - // sql + enforce blocked in v1 - if protocol == "sql" && enforcement == "enforce" { - errors.push(format!( - "{loc}: SQL enforcement requires full SQL parsing (not available in v1). Use `enforcement: audit`." - )); - } - - // rules with empty list - if ep - .get("rules") - .and_then(|v| v.as_array()) - .is_some_and(Vec::is_empty) - { - errors.push(format!( - "{loc}: rules list cannot be empty (would deny all traffic). Use `access: full` or remove rules." - )); - } - - // port 443 + rest + tls: skip — L7 won't work (already handled above) - // The old warning about missing `tls: terminate` is no longer needed - // because TLS termination is now automatic. - - // Validate deny_rules - let has_deny_rules = ep - .get("deny_rules") - .and_then(|v| v.as_array()) - .is_some_and(|a| !a.is_empty()); - if has_deny_rules { - // deny_rules require L7 inspection - if protocol.is_empty() { - errors.push(format!( - "{loc}: deny_rules require protocol (L7 inspection must be enabled)" - )); - } - - // deny_rules require some allow base (access or rules) - if !has_rules && access.is_empty() { - errors.push(format!( - "{loc}: deny_rules require rules or access to define the base allow set" - )); - } - - if let Some(deny_rules) = ep.get("deny_rules").and_then(|v| v.as_array()) { - for (deny_idx, deny_rule) in deny_rules.iter().enumerate() { - let deny_loc = format!("{loc}.deny_rules[{deny_idx}]"); - - // Validate method - if let Some(method) = deny_rule.get("method").and_then(|m| m.as_str()) - && !method.is_empty() - && (protocol == "rest" || protocol == "websocket") - { - let valid_methods = valid_methods_for_protocol(protocol); - if !valid_methods.contains(&method.to_ascii_uppercase().as_str()) { - warnings.push(format!( - "{deny_loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." - , valid_methods.join(", ") - )); - } - } - - // Validate path glob syntax - if let Some(path) = deny_rule.get("path").and_then(|p| p.as_str()) - && let Some(warning) = check_glob_syntax(path) - { - warnings.push(format!("{deny_loc}.path: {warning}")); - } - - // Validate query matchers — mirrors allow-side validation exactly - if let Some(query) = deny_rule.get("query").filter(|v| !v.is_null()) { - let Some(query_obj) = query.as_object() else { - errors.push(format!( - "{deny_loc}.query: expected map of query matchers" - )); - continue; - }; - - for (param, matcher) in query_obj { - if let Some(glob_str) = matcher.as_str() { - if let Some(warning) = check_glob_syntax(glob_str) { - warnings - .push(format!("{deny_loc}.query.{param}: {warning}")); - } - continue; - } - - let Some(matcher_obj) = matcher.as_object() else { - errors.push(format!( - "{deny_loc}.query.{param}: expected string glob or object with `any`" - )); - continue; - }; - - let has_any = matcher_obj.get("any").is_some(); - let has_glob = matcher_obj.get("glob").is_some(); - let has_unknown = - matcher_obj.keys().any(|k| k != "any" && k != "glob"); - if has_unknown { - errors.push(format!( - "{deny_loc}.query.{param}: unknown matcher keys; only `glob` or `any` are supported" - )); - continue; - } - - if has_glob && has_any { - errors.push(format!( - "{deny_loc}.query.{param}: matcher cannot specify both `glob` and `any`" - )); - continue; - } - - if !has_glob && !has_any { - errors.push(format!( - "{deny_loc}.query.{param}: object matcher requires `glob` string or non-empty `any` list" - )); - continue; - } - - if has_glob { - match matcher_obj.get("glob").and_then(|v| v.as_str()) { - None => { - errors.push(format!( - "{deny_loc}.query.{param}.glob: expected glob string" - )); - } - Some(g) => { - if let Some(warning) = check_glob_syntax(g) { - warnings.push(format!( - "{deny_loc}.query.{param}.glob: {warning}" - )); - } - } - } - continue; - } - - let any = matcher_obj.get("any").and_then(|v| v.as_array()); - let Some(any) = any else { - errors.push(format!( - "{deny_loc}.query.{param}.any: expected array of glob strings" - )); - continue; - }; - - if any.is_empty() { - errors.push(format!( - "{deny_loc}.query.{param}.any: list must not be empty" - )); - continue; - } - - if any.iter().any(|v| v.as_str().is_none()) { - errors.push(format!( - "{deny_loc}.query.{param}.any: all values must be strings" - )); - } - - for item in any.iter().filter_map(|v| v.as_str()) { - if let Some(warning) = check_glob_syntax(item) { - warnings.push(format!( - "{deny_loc}.query.{param}.any: {warning}" - )); - } - } - } - } - - // SQL command validation - if let Some(command) = deny_rule.get("command").and_then(|c| c.as_str()) - && !command.is_empty() - && protocol == "rest" - { - warnings - .push(format!("{deny_loc}: command is for SQL protocol, not REST")); - } - - let deny_has_graphql = json_rule_has_graphql_fields(deny_rule); - if protocol == "websocket" - && deny_has_graphql - && json_rule_has_transport_fields(deny_rule) - { - errors.push(format!( - "{deny_loc}: WebSocket GraphQL deny rules must not combine method/path/query with operation_type/operation_name/fields" - )); - } - - if protocol == "graphql" || (protocol == "websocket" && deny_has_graphql) { - validate_graphql_rule( - &mut errors, - &mut warnings, - &deny_loc, - deny_rule, - true, - ); - } else if deny_has_graphql { - warnings.push(format!( - "{deny_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" - )); - } - } - } - } - - // Empty deny_rules list (explicitly set but empty) - if ep - .get("deny_rules") - .and_then(|v| v.as_array()) - .is_some_and(Vec::is_empty) - { - errors.push(format!( - "{loc}: deny_rules list cannot be empty (would have no effect). Remove it if no denials are needed." - )); - } - - // Validate HTTP methods in rules - if has_rules && (protocol == "rest" || protocol == "websocket") { - let valid_methods = valid_methods_for_protocol(protocol); - if let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { - for (rule_idx, rule) in rules.iter().enumerate() { - if let Some(method) = rule - .get("allow") - .and_then(|a| a.get("method")) - .and_then(|m| m.as_str()) - && !method.is_empty() - && !valid_methods.contains(&method.to_ascii_uppercase().as_str()) - { - warnings.push(format!( - "{loc}: Unknown HTTP/WebSocket method '{method}'. Standard methods: {}." - , valid_methods.join(", ") - )); - } - - let Some(query) = rule - .get("allow") - .and_then(|a| a.get("query")) - .filter(|v| !v.is_null()) - else { - continue; - }; - - let Some(query_obj) = query.as_object() else { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query: expected map of query matchers" - )); - continue; - }; - - for (param, matcher) in query_obj { - if let Some(glob_str) = matcher.as_str() { - if let Some(warning) = check_glob_syntax(glob_str) { - warnings.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}: {warning}" - )); - } - continue; - } - - let Some(matcher_obj) = matcher.as_object() else { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}: expected string glob or object with `any`" - )); - continue; - }; - - let has_any = matcher_obj.get("any").is_some(); - let has_glob = matcher_obj.get("glob").is_some(); - let has_unknown = matcher_obj.keys().any(|k| k != "any" && k != "glob"); - if has_unknown { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}: unknown matcher keys; only `glob` or `any` are supported" - )); - continue; - } - - if has_glob && has_any { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}: matcher cannot specify both `glob` and `any`" - )); - continue; - } - - if !has_glob && !has_any { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}: object matcher requires `glob` string or non-empty `any` list" - )); - continue; - } - - if has_glob { - match matcher_obj.get("glob").and_then(|v| v.as_str()) { - None => { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}.glob: expected glob string" - )); - } - Some(g) => { - if let Some(warning) = check_glob_syntax(g) { - warnings.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}.glob: {warning}" - )); - } - } - } - continue; - } - - let any = matcher_obj.get("any").and_then(|v| v.as_array()); - let Some(any) = any else { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}.any: expected array of glob strings" - )); - continue; - }; - - if any.is_empty() { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}.any: list must not be empty" - )); - continue; - } - - if any.iter().any(|v| v.as_str().is_none()) { - errors.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}.any: all values must be strings" - )); - } - - for item in any.iter().filter_map(|v| v.as_str()) { - if let Some(warning) = check_glob_syntax(item) { - warnings.push(format!( - "{loc}.rules[{rule_idx}].allow.query.{param}.any: {warning}" - )); - } - } - } - } - } - } - - if has_rules && let Some(rules) = ep.get("rules").and_then(|v| v.as_array()) { - for (rule_idx, rule) in rules.iter().enumerate() { - let allow = rule.get("allow").unwrap_or(rule); - let rule_loc = format!("{loc}.rules[{rule_idx}].allow"); - let allow_has_graphql = json_rule_has_graphql_fields(allow); - if websocket_has_graphql_policy - && allow - .get("method") - .and_then(|m| m.as_str()) - .is_some_and(|method| method.eq_ignore_ascii_case("WEBSOCKET_TEXT")) - { - errors.push(format!( - "{rule_loc}: WebSocket endpoints with GraphQL operation policy must use operation_type/operation_name/fields rules for client messages instead of WEBSOCKET_TEXT" - )); - } - if protocol == "websocket" - && allow_has_graphql - && json_rule_has_transport_fields(allow) - { - errors.push(format!( - "{rule_loc}: WebSocket GraphQL allow rules must not combine method/path/query with operation_type/operation_name/fields" - )); - } - if protocol == "graphql" || (protocol == "websocket" && allow_has_graphql) { - validate_graphql_rule(&mut errors, &mut warnings, &rule_loc, allow, true); - } else if allow_has_graphql { - warnings.push(format!( - "{rule_loc}: GraphQL rule fields are ignored unless protocol is graphql or websocket" - )); - } - } - } - } - } - - (errors, warnings) -} - -/// Expand `access` presets into explicit `rules` in the policy data. -/// -/// This preprocesses the JSON data so Rego only needs to handle explicit rules. -pub fn expand_access_presets(data: &mut serde_json::Value) { - let Some(policies) = data - .get_mut("network_policies") - .and_then(|v| v.as_object_mut()) - else { - return; - }; - - for (_name, policy) in policies.iter_mut() { - let Some(endpoints) = policy.get_mut("endpoints").and_then(|v| v.as_array_mut()) else { - continue; - }; - - for ep in endpoints.iter_mut() { - let access = ep - .get("access") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - if access.is_empty() { - continue; - } - - // Don't expand if rules already exist (validation will catch this) - if ep - .get("rules") - .and_then(|v| v.as_array()) - .is_some_and(|a| !a.is_empty()) - { - continue; - } - - let protocol = ep - .get("protocol") - .and_then(|v| v.as_str()) - .unwrap_or("rest"); - let rules = if protocol == "graphql" { - match access.as_str() { - "read-only" => vec![graphql_rule_json("query")], - "read-write" => vec![graphql_rule_json("query"), graphql_rule_json("mutation")], - "full" => vec![graphql_rule_json("*")], - _ => continue, - } - } else if protocol == "websocket" { - match access.as_str() { - "read-only" => vec![rule_json("GET", "**")], - "read-write" => vec![rule_json("GET", "**"), rule_json("WEBSOCKET_TEXT", "**")], - "full" => vec![rule_json("*", "**")], - _ => continue, - } - } else { - match access.as_str() { - "read-only" => vec![ - rule_json("GET", "**"), - rule_json("HEAD", "**"), - rule_json("OPTIONS", "**"), - ], - "read-write" => vec![ - rule_json("GET", "**"), - rule_json("HEAD", "**"), - rule_json("OPTIONS", "**"), - rule_json("POST", "**"), - rule_json("PUT", "**"), - rule_json("PATCH", "**"), - ], - "full" => vec![rule_json("*", "**")], - _ => continue, - } - }; - - ep.as_object_mut() - .unwrap() - .insert("rules".to_string(), serde_json::Value::Array(rules)); - } - } -} - -fn rule_json(method: &str, path: &str) -> serde_json::Value { - serde_json::json!({ - "allow": { - "method": method, - "path": path - } - }) -} - -fn valid_methods_for_protocol(protocol: &str) -> &'static [&'static str] { - match protocol { - "websocket" => &["GET", "WEBSOCKET_TEXT", "*"], - _ => &[ - "GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "*", - ], - } -} - -fn graphql_rule_json(operation_type: &str) -> serde_json::Value { - serde_json::json!({ - "allow": { - "operation_type": operation_type - } - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn parse_l7_config_rest_enforce() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "tls": "terminate", "enforcement": "enforce", "host": "api.example.com", "port": 443}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert_eq!(config.protocol, L7Protocol::Rest); - // "terminate" is deprecated and treated as Auto. - assert_eq!(config.tls, TlsMode::Auto); - assert_eq!(config.enforcement, EnforcementMode::Enforce); - } - - #[test] - fn parse_l7_config_defaults() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "host": "api.example.com", "port": 80}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert_eq!(config.protocol, L7Protocol::Rest); - assert_eq!(config.tls, TlsMode::Auto); - assert_eq!(config.enforcement, EnforcementMode::Audit); - assert_eq!( - config.json_rpc_max_body_bytes, - jsonrpc::DEFAULT_MAX_BODY_BYTES - ); - } - - #[test] - fn parse_l7_config_jsonrpc_max_body_bytes() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "json-rpc", "host": "mcp.example.com", "port": 443, "json_rpc_max_body_bytes": 131072}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert_eq!(config.protocol, L7Protocol::JsonRpc); - assert_eq!(config.json_rpc_max_body_bytes, 131_072); - } - - #[test] - fn validate_jsonrpc_max_body_bytes_must_be_positive() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "mcp.example.com", - "port": 443, - "protocol": "json-rpc", - "access": "full", - "json_rpc_max_body_bytes": 0 - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("json_rpc_max_body_bytes must be a positive integer")), - "should reject non-positive JSON-RPC max body size, got errors: {errors:?}" - ); - } - - #[test] - fn parse_l7_config_websocket_protocol() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert_eq!(config.protocol, L7Protocol::Websocket); - } - - #[test] - fn parse_l7_config_skip() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "tls": "skip", "host": "api.example.com", "port": 443}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert_eq!(config.tls, TlsMode::Skip); - } - - #[test] - fn parse_l7_config_no_protocol() { - let val = - regorus::Value::from_json_str(r#"{"host": "api.example.com", "port": 443}"#).unwrap(); - assert!(parse_l7_config(&val).is_none()); - } - - #[test] - fn parse_l7_config_allow_encoded_slash_defaults_false() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "host": "api.example.com", "port": 443}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert!(!config.allow_encoded_slash); - } - - #[test] - fn parse_l7_config_allow_encoded_slash_opt_in() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "host": "gitlab.example.com", "port": 443, "allow_encoded_slash": true}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert!(config.allow_encoded_slash); - } - - #[test] - fn parse_l7_config_websocket_credential_rewrite_defaults_false() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert!(!config.websocket_credential_rewrite); - } - - #[test] - fn parse_l7_config_websocket_credential_rewrite_opt_in() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "host": "gateway.example.com", "port": 443, "websocket_credential_rewrite": true}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert!(config.websocket_credential_rewrite); - } - - #[test] - fn parse_l7_config_request_body_credential_rewrite_defaults_false() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "host": "slack.com", "port": 443}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert!(!config.request_body_credential_rewrite); - } - - #[test] - fn parse_l7_config_request_body_credential_rewrite_opt_in() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "rest", "host": "slack.com", "port": 443, "request_body_credential_rewrite": true}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert!(config.request_body_credential_rewrite); - } - - #[test] - fn parse_l7_config_websocket_graphql_policy_defaults_false() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}]}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert!(!config.websocket_graphql_policy); - } - - #[test] - fn parse_l7_config_websocket_graphql_policy_detects_operation_rules() { - let val = regorus::Value::from_json_str( - r#"{"protocol": "websocket", "host": "gateway.example.com", "port": 443, "rules": [{"allow": {"method": "GET", "path": "/graphql"}}, {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}}]}"#, - ) - .unwrap(); - let config = parse_l7_config(&val).unwrap(); - assert!(config.websocket_graphql_policy); - } - - #[test] - fn validate_websocket_credential_rewrite_warns_unless_rest_or_websocket() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "gateway.example.com", - "port": 443, - "websocket_credential_rewrite": true - }], - "binaries": [] - } - } - }); - let (_errors, warnings) = validate_l7_policies(&data); - assert!( - warnings - .iter() - .any(|w| w.contains("websocket_credential_rewrite is ignored")), - "expected websocket_credential_rewrite warning: {warnings:?}" - ); - } - - #[test] - fn validate_request_body_credential_rewrite_warns_unless_rest() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "gateway.example.com", - "port": 443, - "protocol": "websocket", - "request_body_credential_rewrite": true - }], - "binaries": [] - } - } - }); - let (_errors, warnings) = validate_l7_policies(&data); - assert!( - warnings - .iter() - .any(|w| w.contains("request_body_credential_rewrite is ignored")), - "expected request_body_credential_rewrite warning: {warnings:?}" - ); - } - - #[test] - fn expand_websocket_read_write_access_includes_text_messages() { - let mut data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "gateway.example.com", - "port": 443, - "protocol": "websocket", - "access": "read-write" - }], - "binaries": [] - } - } - }); - - expand_access_presets(&mut data); - let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] - .as_array() - .unwrap(); - let methods: Vec<&str> = rules - .iter() - .map(|r| r["allow"]["method"].as_str().unwrap()) - .collect(); - assert!(methods.contains(&"GET")); - assert!(methods.contains(&"WEBSOCKET_TEXT")); - } - - #[test] - fn validate_websocket_accepts_graphql_operation_rules() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "gateway.example.com", - "port": 443, - "protocol": "websocket", - "rules": [ - {"allow": {"method": "GET", "path": "/graphql"}}, - {"allow": {"operation_type": "subscription", "fields": ["messageAdded"]}} - ] - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!(errors.is_empty(), "expected no errors: {errors:?}"); - assert!(warnings.is_empty(), "expected no warnings: {warnings:?}"); - } - - #[test] - fn validate_websocket_graphql_rule_requires_operation_type() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "gateway.example.com", - "port": 443, - "protocol": "websocket", - "rules": [ - {"allow": {"method": "GET", "path": "/graphql"}}, - {"allow": {"fields": ["messageAdded"]}} - ] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("operation_type")), - "expected missing operation_type error: {errors:?}" - ); - } - - #[test] - fn validate_websocket_graphql_rule_rejects_mixed_transport_fields() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "gateway.example.com", - "port": 443, - "protocol": "websocket", - "rules": [ - {"allow": {"method": "GET", "path": "/graphql"}}, - {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql", "operation_type": "subscription"}} - ] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("must not combine")), - "expected mixed-field error: {errors:?}" - ); - } - - #[test] - fn validate_websocket_graphql_policy_rejects_raw_text_message_rule() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "gateway.example.com", - "port": 443, - "protocol": "websocket", - "rules": [ - {"allow": {"method": "GET", "path": "/graphql"}}, - {"allow": {"method": "WEBSOCKET_TEXT", "path": "/graphql"}}, - {"allow": {"operation_type": "query"}} - ] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("instead of WEBSOCKET_TEXT")), - "expected raw WEBSOCKET_TEXT rejection: {errors:?}" - ); - } - - #[test] - fn validate_rules_and_access_mutual_exclusion() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest", - "access": "read-only", - "rules": [{"allow": {"method": "GET", "path": "**"}}] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!(errors.iter().any(|e| e.contains("mutually exclusive"))); - } - - #[test] - fn validate_protocol_requires_rules_or_access() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest" - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("requires rules or access")) - ); - } - - #[test] - fn validate_sql_enforce_blocked() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "db.internal", - "port": 5432, - "protocol": "sql", - "enforcement": "enforce", - "rules": [{"allow": {"command": "SELECT"}}] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!(errors.iter().any(|e| e.contains("SQL enforcement"))); - } - - #[test] - fn validate_tls_terminate_deprecated_warning() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "tls": "terminate", - "protocol": "rest", - "access": "full" - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "deprecated tls should not error: {errors:?}" - ); - assert!( - warnings.iter().any(|w| w.contains("deprecated")), - "should warn about deprecated tls: {warnings:?}" - ); - } - - #[test] - fn validate_tls_skip_with_l7_on_443_warns() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "tls": "skip", - "protocol": "rest", - "access": "read-only" - }], - "binaries": [] - } - } - }); - let (_errors, warnings) = validate_l7_policies(&data); - assert!( - warnings.iter().any(|w| w.contains("tls: skip")), - "should warn about skip + L7 on 443: {warnings:?}" - ); - } - - #[test] - fn validate_port_443_rest_no_tls_no_warning() { - // With auto-TLS, no warning is needed for port 443 + rest without - // explicit tls field — TLS will be auto-detected. - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest", - "access": "read-only" - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!(errors.is_empty(), "should have no errors: {errors:?}"); - assert!( - !warnings.iter().any(|w| w.contains("tls")), - "should have no tls warnings with auto-detect: {warnings:?}" - ); - } - - #[test] - fn expand_read_only_preset() { - let mut data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 80, - "protocol": "rest", - "access": "read-only" - }], - "binaries": [] - } - } - }); - expand_access_presets(&mut data); - let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] - .as_array() - .unwrap(); - assert_eq!(rules.len(), 3); - let methods: Vec<&str> = rules - .iter() - .map(|r| r["allow"]["method"].as_str().unwrap()) - .collect(); - assert!(methods.contains(&"GET")); - assert!(methods.contains(&"HEAD")); - assert!(methods.contains(&"OPTIONS")); - } - - #[test] - fn expand_full_preset() { - let mut data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 80, - "protocol": "rest", - "access": "full" - }], - "binaries": [] - } - } - }); - expand_access_presets(&mut data); - let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] - .as_array() - .unwrap(); - assert_eq!(rules.len(), 1); - assert_eq!(rules[0]["allow"]["method"].as_str().unwrap(), "*"); - assert_eq!(rules[0]["allow"]["path"].as_str().unwrap(), "**"); - } - - #[test] - fn expand_graphql_readonly_preset() { - let mut data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "graphql", - "access": "read-only" - }], - "binaries": [] - } - } - }); - expand_access_presets(&mut data); - let rules = data["network_policies"]["test"]["endpoints"][0]["rules"] - .as_array() - .unwrap(); - assert_eq!(rules.len(), 1); - assert_eq!( - rules[0]["allow"]["operation_type"].as_str().unwrap(), - "query" - ); - } - - #[test] - fn validate_graphql_rule_requires_operation_type() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "graphql", - "rules": [{ - "allow": { - "fields": ["viewer"] - } - }] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("operation_type")), - "GraphQL rules should require operation_type: {errors:?}" - ); - } - - #[test] - fn validate_graphql_persisted_query_mode() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "graphql", - "access": "full", - "persisted_queries": "allow_all" - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("persisted_queries")), - "invalid persisted query mode should be rejected: {errors:?}" - ); - } - - #[test] - fn l4_only_endpoint_untouched() { - let mut data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443 - }], - "binaries": [] - } - } - }); - expand_access_presets(&mut data); - assert!( - data["network_policies"]["test"]["endpoints"][0] - .get("rules") - .is_none() - ); - } - - // ---- Host wildcard validation tests ---- - - #[test] - fn validate_wildcard_host_star_only_error() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "*", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("matches all hosts")), - "Bare * host should be rejected, got errors: {errors:?}" - ); - } - - #[test] - fn validate_wildcard_host_double_star_only_error() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "**", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("matches all hosts")), - "Bare ** host should be rejected, got errors: {errors:?}" - ); - } - - #[test] - fn validate_wildcard_host_mid_label_error() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "foo.*.example.com", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("first DNS label")), - "Mid-label wildcard should be rejected, got errors: {errors:?}" - ); - } - - #[test] - fn validate_wildcard_host_single_label_error() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "*com", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("TLD wildcard")), - "Single-label wildcard should be rejected, got errors: {errors:?}" - ); - } - - #[test] - fn validate_wildcard_host_recursive_intra_label_error() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "foo**.example.com", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("recursive host wildcard")), - "Recursive intra-label wildcard should be rejected, got errors: {errors:?}" - ); - } - - #[test] - fn validate_wildcard_host_tld_rejected() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "*.com", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("TLD wildcard")), - "*.com should be rejected as TLD wildcard, got errors: {errors:?}" - ); - } - - #[test] - fn validate_wildcard_host_double_star_tld_rejected() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "**.org", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("TLD wildcard")), - "**.org should be rejected as TLD wildcard, got errors: {errors:?}" - ); - } - - #[test] - fn validate_wildcard_host_valid_no_error() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "*.example.com", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "*.example.com should be valid, got errors: {errors:?}" - ); - assert!( - warnings.is_empty(), - "*.example.com should not warn, got warnings: {warnings:?}" - ); - } - - #[test] - fn validate_wildcard_host_double_star_valid_no_error() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "**.example.com", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "**.example.com should be valid, got errors: {errors:?}" - ); - assert!( - warnings.is_empty(), - "**.example.com should not warn, got warnings: {warnings:?}" - ); - } - - #[test] - fn validate_wildcard_host_intra_label_valid_no_error() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "*-aiplatform.googleapis.com", - "port": 443 - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "*-aiplatform.googleapis.com should be valid, got errors: {errors:?}" - ); - assert!( - warnings.is_empty(), - "*-aiplatform.googleapis.com should not warn, got warnings: {warnings:?}" - ); - } - - #[test] - fn validate_port_and_ports_mutually_exclusive() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "ports": [443, 8443] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("port and ports are mutually exclusive")), - "Should reject both port and ports, got errors: {errors:?}" - ); - } - - #[test] - fn validate_ports_array_rest_443_no_warning() { - // With auto-TLS, no warning needed for ports array containing 443. - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "ports": [443, 8080], - "protocol": "rest", - "access": "read-only" - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!(errors.is_empty(), "should have no errors: {errors:?}"); - assert!( - !warnings.iter().any(|w| w.contains("tls")), - "should have no tls warnings with auto-detect: {warnings:?}" - ); - } - - #[test] - fn validate_query_any_requires_non_empty_array() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 8080, - "protocol": "rest", - "rules": [{ - "allow": { - "method": "GET", - "path": "/download", - "query": { - "tag": { "any": [] } - } - } - }] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("allow.query.tag.any")), - "expected query any validation error, got: {errors:?}" - ); - } - - #[test] - fn validate_query_object_rejects_unknown_keys() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 8080, - "protocol": "rest", - "rules": [{ - "allow": { - "method": "GET", - "path": "/download", - "query": { - "tag": { "mode": "foo-*" } - } - } - }] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.iter().any(|e| e.contains("unknown matcher keys")), - "expected unknown query matcher key error, got: {errors:?}" - ); - } - - #[test] - fn validate_query_glob_warns_on_unclosed_bracket() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 8080, - "protocol": "rest", - "rules": [{ - "allow": { - "method": "GET", - "path": "/download", - "query": { - "tag": "[unclosed" - } - } - }] - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "malformed glob should warn, not error: {errors:?}" - ); - assert!( - warnings - .iter() - .any(|w| w.contains("unclosed '['") && w.contains("allow.query.tag")), - "expected glob syntax warning, got: {warnings:?}" - ); - } - - #[test] - fn validate_query_glob_warns_on_unclosed_brace() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 8080, - "protocol": "rest", - "rules": [{ - "allow": { - "method": "GET", - "path": "/download", - "query": { - "format": { "glob": "{json,xml" } - } - } - }] - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "malformed glob should warn, not error: {errors:?}" - ); - assert!( - warnings - .iter() - .any(|w| w.contains("unclosed '{'") && w.contains("allow.query.format.glob")), - "expected glob syntax warning, got: {warnings:?}" - ); - } - - #[test] - fn validate_query_any_warns_on_malformed_glob_item() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 8080, - "protocol": "rest", - "rules": [{ - "allow": { - "method": "GET", - "path": "/download", - "query": { - "tag": { "any": ["valid-*", "[bad"] } - } - } - }] - }], - "binaries": [] - } - } - }); - let (errors, warnings) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "malformed glob in any should warn, not error: {errors:?}" - ); - assert!( - warnings - .iter() - .any(|w| w.contains("unclosed '['") && w.contains("allow.query.tag.any")), - "expected glob syntax warning for any item, got: {warnings:?}" - ); - } - - #[test] - fn validate_query_string_and_any_matchers_are_accepted() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 8080, - "protocol": "rest", - "rules": [{ - "allow": { - "method": "GET", - "path": "/download", - "query": { - "slug": "my-*", - "tag": { "any": ["foo-*", "bar-*"] }, - "owner": { "glob": "org-*" } - } - } - }] - }], - "binaries": [] - } - } - }); - let (errors, _warnings) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "valid query matcher shapes should not error: {errors:?}" - ); - } - - // --- Deny rules validation tests --- - - #[test] - fn validate_deny_rules_require_protocol() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "deny_rules": [{ "method": "POST", "path": "/admin" }] - }], - "binaries": [] - } - } - }); - let (errors, _) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("deny_rules require protocol")), - "should require protocol for deny_rules: {errors:?}" - ); - } - - #[test] - fn validate_deny_rules_require_allow_base() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest", - "deny_rules": [{ "method": "POST", "path": "/admin" }] - }], - "binaries": [] - } - } - }); - let (errors, _) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("deny_rules require rules or access")), - "should require rules or access for deny_rules: {errors:?}" - ); - } - - #[test] - fn validate_deny_rules_empty_list_rejected() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest", - "access": "full", - "deny_rules": [] - }], - "binaries": [] - } - } - }); - let (errors, _) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("deny_rules list cannot be empty")), - "should reject empty deny_rules: {errors:?}" - ); - } - - #[test] - fn validate_deny_rules_valid_config_accepted() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest", - "access": "read-write", - "deny_rules": [ - { "method": "POST", "path": "/repos/*/pulls/*/reviews" }, - { "method": "PUT", "path": "/repos/*/branches/*/protection" } - ] - }], - "binaries": [] - } - } - }); - let (errors, _) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "valid deny_rules should not error: {errors:?}" - ); - } - - #[test] - fn validate_deny_rules_query_empty_any_rejected() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest", - "access": "full", - "deny_rules": [{ - "method": "POST", - "path": "/admin", - "query": { "type": { "any": [] } } - }] - }], - "binaries": [] - } - } - }); - let (errors, _) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("any: list must not be empty")), - "should reject empty any list in deny query: {errors:?}" - ); - } - - #[test] - fn validate_deny_rules_query_non_string_rejected() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest", - "access": "full", - "deny_rules": [{ - "method": "POST", - "path": "/admin", - "query": { "force": 123 } - }] - }], - "binaries": [] - } - } - }); - let (errors, _) = validate_l7_policies(&data); - assert!( - errors - .iter() - .any(|e| e.contains("expected string glob or object")), - "should reject non-string/non-object matcher in deny query: {errors:?}" - ); - } - - #[test] - fn validate_deny_rules_query_valid_matchers_accepted() { - let data = serde_json::json!({ - "network_policies": { - "test": { - "endpoints": [{ - "host": "api.example.com", - "port": 443, - "protocol": "rest", - "access": "full", - "deny_rules": [{ - "method": "POST", - "path": "/admin/**", - "query": { - "force": "true", - "type": { "any": ["admin-*", "root-*"] }, - "scope": { "glob": "org-*" } - } - }] - }], - "binaries": [] - } - } - }); - let (errors, _) = validate_l7_policies(&data); - assert!( - errors.is_empty(), - "valid deny query matchers should not error: {errors:?}" - ); - } -} diff --git a/crates/openshell-sandbox/src/l7/relay.rs b/crates/openshell-sandbox/src/l7/relay.rs deleted file mode 100644 index 08bce6ff5..000000000 --- a/crates/openshell-sandbox/src/l7/relay.rs +++ /dev/null @@ -1,2977 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! Protocol-aware bidirectional relay with L7 inspection. -//! -//! Replaces `copy_bidirectional` for endpoints with L7 configuration. -//! Parses each request within the tunnel, evaluates it against OPA policy, -//! and either forwards or denies the request. - -use crate::activity_aggregator::{ActivitySender, try_record_activity}; -use crate::l7::provider::{L7Provider, RelayOutcome}; -use crate::l7::rest::WebSocketExtensionMode; -use crate::l7::{EnforcementMode, L7EndpointConfig, L7Protocol, L7RequestInfo}; -use crate::opa::{PolicyGenerationGuard, TunnelPolicyEngine}; -use crate::secrets::{self, SecretResolver}; -use miette::{IntoDiagnostic, Result, miette}; -use openshell_ocsf::{ - ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, - NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, -}; -use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tracing::{debug, warn}; - -/// Context for L7 request policy evaluation. -pub struct L7EvalContext { - /// Host from the CONNECT request. - pub host: String, - /// Port from the CONNECT request. - pub port: u16, - /// Matched policy name from L4 evaluation. - pub policy_name: String, - /// Binary path (for cross-layer Rego evaluation). - pub binary_path: String, - /// Ancestor paths. - pub ancestors: Vec, - /// Cmdline paths. - pub cmdline_paths: Vec, - /// Supervisor-only placeholder resolver for outbound headers. - pub(crate) secret_resolver: Option>, - /// Anonymous activity counter channel. - pub(crate) activity_tx: Option, - /// Dynamic credentials (token grants) keyed by endpoint-bound provider metadata. - pub(crate) dynamic_credentials: Option< - Arc< - std::sync::RwLock< - std::collections::HashMap, - >, - >, - >, - /// Dynamic token grant resolver for endpoint-bound credentials. - pub(crate) token_grant_resolver: - Option>, -} - -#[derive(Default)] -pub(crate) struct UpgradeRelayOptions<'a> { - pub(crate) websocket_request: bool, - pub(crate) websocket: WebSocketUpgradeBehavior, - pub(crate) secret_resolver: Option>, - pub(crate) engine: Option<&'a TunnelPolicyEngine>, - pub(crate) ctx: Option<&'a L7EvalContext>, - pub(crate) enforcement: EnforcementMode, - pub(crate) target: String, - pub(crate) query_params: std::collections::HashMap>, - pub(crate) policy_name: String, -} - -#[derive(Default)] -pub(crate) struct WebSocketUpgradeBehavior { - pub(crate) credential_rewrite: bool, - pub(crate) message_policy: WebSocketMessagePolicy, - pub(crate) permessage_deflate: bool, -} - -#[derive(Clone, Copy, Default, PartialEq, Eq)] -pub(crate) enum WebSocketMessagePolicy { - #[default] - None, - Transport, - Graphql, -} - -impl WebSocketMessagePolicy { - fn inspects_messages(self) -> bool { - self != Self::None - } - - fn is_graphql(self) -> bool { - self == Self::Graphql - } -} - -#[derive(Debug, Clone, Copy)] -enum ParseRejectionMode { - L7Endpoint, - Passthrough, -} - -fn parse_rejection_detail(error: &str, mode: ParseRejectionMode) -> String { - if error.contains("encoded '/' (%2F)") { - match mode { - ParseRejectionMode::L7Endpoint => format!( - "{error}; set allow_encoded_slash: true on this endpoint if the upstream requires encoded slashes" - ), - ParseRejectionMode::Passthrough => format!( - "{error}; passthrough credential relay uses strict path parsing, so configure this endpoint with protocol: rest and allow_encoded_slash: true for encoded-slash APIs, or use tls: skip if HTTP parsing is not needed" - ), - } - } else { - error.to_string() - } -} - -fn emit_parse_rejection(ctx: &L7EvalContext, detail: &str, engine_type: &str) { - let policy_name = if ctx.policy_name.is_empty() { - "-" - } else { - &ctx.policy_name - }; - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .firewall_rule(policy_name, engine_type) - .message(format!( - "HTTP request rejected before policy evaluation for {}:{}", - ctx.host, ctx.port - )) - .status_detail(detail) - .build(); - ocsf_emit!(event); - emit_activity(ctx, true, "l7_parse_rejection"); -} - -/// Run protocol-aware L7 inspection on a tunnel. -/// -/// This replaces `copy_bidirectional` for L7-enabled endpoints. -/// Protocol detection (peek) is the caller's responsibility — this function -/// assumes the streams are already proven to carry the expected protocol. -/// For TLS-terminated connections, ALPN proves HTTP; for plaintext, the -/// caller peeks on the raw `TcpStream` before calling this. -pub async fn relay_with_inspection( - config: &L7EndpointConfig, - engine: TunnelPolicyEngine, - client: &mut C, - upstream: &mut U, - ctx: &L7EvalContext, -) -> Result<()> -where - C: AsyncRead + AsyncWrite + Unpin + Send, - U: AsyncRead + AsyncWrite + Unpin + Send, -{ - match config.protocol { - L7Protocol::Rest | L7Protocol::Websocket => { - relay_rest(config, &engine, client, upstream, ctx).await - } - L7Protocol::Graphql => relay_graphql(config, &engine, client, upstream, ctx).await, - L7Protocol::Sql => { - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - // SQL provider is Phase 3 — fall through to passthrough with warning - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .severity(SeverityId::Low) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .message("SQL L7 provider not yet implemented, falling back to passthrough") - .build(); - ocsf_emit!(event); - } - tokio::io::copy_bidirectional(client, upstream) - .await - .into_diagnostic()?; - Ok(()) - } - L7Protocol::JsonRpc => relay_jsonrpc(config, &engine, client, upstream, ctx).await, - } -} - -/// Run HTTP L7 inspection with per-request protocol selection. -/// -/// This is used when multiple L7 endpoints share a host:port, for example a -/// REST API under `/repos/**` and a GraphQL API under `/graphql`. -pub async fn relay_with_route_selection( - configs: &[L7EndpointConfig], - engine: TunnelPolicyEngine, - client: &mut C, - upstream: &mut U, - ctx: &L7EvalContext, -) -> Result<()> -where - C: AsyncRead + AsyncWrite + Unpin + Send, - U: AsyncRead + AsyncWrite + Unpin + Send, -{ - let provider = - crate::l7::rest::RestProvider::with_options(crate::l7::path::CanonicalizeOptions { - allow_encoded_slash: configs.iter().any(|config| config.allow_encoded_slash), - ..Default::default() - }); - - loop { - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let mut req = match provider.parse_request(client).await { - Ok(Some(req)) => req, - Ok(None) => return Ok(()), - Err(e) => { - if is_benign_connection_error(&e) { - debug!( - host = %ctx.host, - port = ctx.port, - error = %e, - "L7 route-selected connection closed" - ); - } else { - let detail = - parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); - emit_parse_rejection(ctx, &detail, "l7"); - } - return Ok(()); - } - }; - - let Some(config) = select_l7_config_for_path(configs, &req.target) else { - crate::l7::rest::RestProvider::default() - .deny( - &req, - &ctx.policy_name, - "no L7 endpoint path matched request", - client, - ) - .await?; - return Ok(()); - }; - - let graphql_info = if config.protocol == L7Protocol::Graphql { - match crate::l7::graphql::inspect_graphql_request( - client, - &mut req, - config.graphql_max_body_bytes, - ) - .await - { - Ok(info) => Some(info), - Err(e) => { - if is_benign_connection_error(&e) { - debug!( - host = %ctx.host, - port = ctx.port, - error = %e, - "GraphQL L7 connection closed" - ); - } else { - let detail = - parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); - emit_parse_rejection(ctx, &detail, "l7-graphql"); - } - return Ok(()); - } - } - } else { - None - }; - - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let (eval_target, redacted_target) = if let Some(ref resolver) = ctx.secret_resolver { - match secrets::rewrite_target_for_eval(&req.target, resolver) { - Ok(result) => (result.resolved, result.redacted), - Err(e) => { - warn!( - host = %ctx.host, - port = ctx.port, - error = %e, - "credential resolution failed in request target, rejecting" - ); - let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; - client.write_all(response).await.into_diagnostic()?; - client.flush().await.into_diagnostic()?; - return Ok(()); - } - } - } else { - (req.target.clone(), req.target.clone()) - }; - - let request_info = L7RequestInfo { - action: req.action.clone(), - target: redacted_target.clone(), - query_params: req.query_params.clone(), - graphql: graphql_info.clone(), - jsonrpc: None, - }; - let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); - if config.protocol == L7Protocol::Websocket && !websocket_request { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &req, - &ctx.policy_name, - "websocket endpoint requires a valid WebSocket upgrade request", - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - - let parse_error_reason = graphql_info - .as_ref() - .and_then(|info| info.error.as_deref()) - .map(|error| format!("GraphQL request rejected: {error}")); - let force_deny = parse_error_reason.is_some(); - let (allowed, reason) = if let Some(reason) = parse_error_reason { - (false, reason) - } else { - evaluate_l7_request(&engine, ctx, &request_info)? - }; - - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let decision_str = match (allowed, config.enforcement) { - (_, _) if force_deny => "deny", - (true, _) => "allow", - (false, EnforcementMode::Audit) => "audit", - (false, EnforcementMode::Enforce) => "deny", - }; - let engine_type = match config.protocol { - L7Protocol::Graphql => "l7-graphql", - L7Protocol::Websocket => "l7-websocket", - L7Protocol::Rest | L7Protocol::Sql | L7Protocol::JsonRpc => "l7", - }; - emit_l7_request_log( - ctx, - &request_info, - &redacted_target, - decision_str, - engine_type, - &reason, - graphql_info.as_ref(), - ); - - let _ = &eval_target; - - if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let outcome = crate::l7::rest::relay_http_request_with_options_guarded( - &req, - client, - upstream, - crate::l7::rest::RelayRequestOptions { - resolver: ctx.secret_resolver.as_deref(), - generation_guard: Some(engine.generation_guard()), - websocket_extensions: websocket_extension_mode(config), - request_body_credential_rewrite: config.protocol == L7Protocol::Rest - && config.request_body_credential_rewrite, - }, - ) - .await?; - match outcome { - RelayOutcome::Reusable => {} - RelayOutcome::Consumed => return Ok(()), - RelayOutcome::Upgraded { - overflow, - websocket_permessage_deflate, - } => { - let mut options = upgrade_options( - config, - ctx, - websocket_request, - &redacted_target, - &req.query_params, - Some(&engine), - ); - options.websocket.permessage_deflate = websocket_permessage_deflate; - return handle_upgrade( - client, upstream, overflow, &ctx.host, ctx.port, options, - ) - .await; - } - } - } else { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &req, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - } -} - -fn select_l7_config_for_path<'a>( - configs: &'a [L7EndpointConfig], - path: &str, -) -> Option<&'a L7EndpointConfig> { - configs - .iter() - .filter(|config| config.matches_path(path)) - .max_by_key(|config| config.path_specificity()) -} - -fn emit_l7_request_log( - ctx: &L7EvalContext, - request_info: &L7RequestInfo, - redacted_target: &str, - decision_str: &str, - engine_type: &str, - reason: &str, - graphql_info: Option<&crate::l7::graphql::GraphqlRequestInfo>, -) { - let (action_id, disposition_id, severity) = match decision_str { - "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), - "allow" | "audit" => ( - ActionId::Allowed, - DispositionId::Allowed, - SeverityId::Informational, - ), - _ => ( - ActionId::Other, - DispositionId::Other, - SeverityId::Informational, - ), - }; - let summary = graphql_info - .map(|info| format!(" {}", graphql_log_summary(info))) - .unwrap_or_default(); - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(action_id) - .disposition(disposition_id) - .severity(severity) - .http_request(HttpRequest::new( - &request_info.action, - OcsfUrl::new("http", &ctx.host, redacted_target, ctx.port), - )) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .firewall_rule(&ctx.policy_name, engine_type) - .message(format!( - "L7_REQUEST {decision_str} {} {}:{}{}{} reason={}", - request_info.action, ctx.host, ctx.port, redacted_target, summary, reason, - )) - .build(); - ocsf_emit!(event); - emit_activity(ctx, decision_str == "deny", "l7_policy"); -} - -fn emit_activity(ctx: &L7EvalContext, denied: bool, deny_group: &'static str) { - if let Some(tx) = &ctx.activity_tx { - let _ = try_record_activity(tx, denied, deny_group); - } -} - -/// Handle an upgraded connection (101 Switching Protocols). -/// -/// Forwards any overflow bytes from the upgrade response to the client, then -/// either switches to a parsed WebSocket relay for opted-in message policy / -/// credential rewriting or to raw bidirectional TCP copy for other upgrades. -pub(crate) async fn handle_upgrade( - client: &mut C, - upstream: &mut U, - overflow: Vec, - host: &str, - port: u16, - options: UpgradeRelayOptions<'_>, -) -> Result<()> -where - C: AsyncRead + AsyncWrite + Unpin + Send, - U: AsyncRead + AsyncWrite + Unpin + Send, -{ - let use_websocket_relay = options.websocket_request - && (options.websocket.message_policy.inspects_messages() - || options.websocket.permessage_deflate - || (options.websocket.credential_rewrite && options.secret_resolver.is_some())); - let relay_mode = if use_websocket_relay { - "websocket parsed relay" - } else { - "raw bidirectional relay (L7 enforcement no longer active)" - }; - ocsf_emit!( - NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .activity_name("Upgrade") - .severity(SeverityId::Informational) - .dst_endpoint(Endpoint::from_domain(host, port)) - .message(format!( - "101 Switching Protocols — {relay_mode} [host:{host} port:{port} overflow_bytes:{}]", - overflow.len() - )) - .build() - ); - if use_websocket_relay { - let resolver = if options.websocket.credential_rewrite { - options.secret_resolver.as_deref() - } else { - None - }; - let inspector = if options.websocket.message_policy.inspects_messages() { - match (options.engine, options.ctx) { - (Some(engine), Some(ctx)) => Some(crate::l7::websocket::InspectionOptions { - engine, - ctx, - enforcement: options.enforcement, - target: options.target.clone(), - query_params: options.query_params.clone(), - graphql_policy: options.websocket.message_policy.is_graphql(), - }), - _ => { - return Err(miette!( - "websocket message inspection missing policy context" - )); - } - } - } else { - None - }; - let compression = if options.websocket.permessage_deflate { - crate::l7::websocket::WebSocketCompression::PermessageDeflate - } else { - crate::l7::websocket::WebSocketCompression::None - }; - return crate::l7::websocket::relay_with_options( - client, - upstream, - overflow, - host, - port, - crate::l7::websocket::RelayOptions { - policy_name: &options.policy_name, - resolver, - inspector, - compression, - }, - ) - .await; - } - if !overflow.is_empty() { - client.write_all(&overflow).await.into_diagnostic()?; - client.flush().await.into_diagnostic()?; - } - tokio::io::copy_bidirectional(client, upstream) - .await - .into_diagnostic()?; - Ok(()) -} - -pub(crate) fn upgrade_options<'a>( - config: &L7EndpointConfig, - ctx: &'a L7EvalContext, - websocket_request: bool, - target: &str, - query_params: &std::collections::HashMap>, - engine: Option<&'a TunnelPolicyEngine>, -) -> UpgradeRelayOptions<'a> { - let websocket_credential_rewrite = - matches!(config.protocol, L7Protocol::Rest | L7Protocol::Websocket) - && config.websocket_credential_rewrite; - let websocket_message_policy = if config.protocol == L7Protocol::Websocket { - if config.websocket_graphql_policy { - WebSocketMessagePolicy::Graphql - } else { - WebSocketMessagePolicy::Transport - } - } else { - WebSocketMessagePolicy::None - }; - UpgradeRelayOptions { - websocket_request, - websocket: WebSocketUpgradeBehavior { - credential_rewrite: websocket_credential_rewrite, - message_policy: websocket_message_policy, - permessage_deflate: false, - }, - secret_resolver: if websocket_credential_rewrite { - ctx.secret_resolver.clone() - } else { - None - }, - engine, - ctx: engine.map(|_| ctx), - enforcement: config.enforcement, - target: target.to_string(), - query_params: query_params.clone(), - policy_name: ctx.policy_name.clone(), - } -} - -pub(crate) fn websocket_extension_mode(config: &L7EndpointConfig) -> WebSocketExtensionMode { - if config.protocol == L7Protocol::Websocket - || (config.protocol == L7Protocol::Rest && config.websocket_credential_rewrite) - { - WebSocketExtensionMode::PermessageDeflate - } else { - WebSocketExtensionMode::Preserve - } -} - -/// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. -async fn relay_rest( - config: &L7EndpointConfig, - engine: &TunnelPolicyEngine, - client: &mut C, - upstream: &mut U, - ctx: &L7EvalContext, -) -> Result<()> -where - C: AsyncRead + AsyncWrite + Unpin + Send, - U: AsyncRead + AsyncWrite + Unpin + Send, -{ - // Build a provider carrying the per-endpoint canonicalization options so - // request parsing honors the endpoint's `allow_encoded_slash` setting - // (e.g. APIs like GitLab that embed `%2F` in path segments). - let provider = - crate::l7::rest::RestProvider::with_options(crate::l7::path::CanonicalizeOptions { - allow_encoded_slash: config.allow_encoded_slash, - ..Default::default() - }); - loop { - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - // Parse one HTTP request from client - let req = match provider.parse_request(client).await { - Ok(Some(req)) => req, - Ok(None) => return Ok(()), // Client closed connection - Err(e) => { - if is_benign_connection_error(&e) { - debug!( - host = %ctx.host, - port = ctx.port, - error = %e, - "L7 connection closed" - ); - } else { - let detail = - parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); - emit_parse_rejection(ctx, &detail, "l7"); - } - return Ok(()); // Close connection on parse error - } - }; - - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - // Rewrite credential placeholders in the request target BEFORE OPA - // evaluation. OPA sees the redacted path; the resolved path goes only - // to the upstream write. - let (eval_target, redacted_target) = if let Some(ref resolver) = ctx.secret_resolver { - match secrets::rewrite_target_for_eval(&req.target, resolver) { - Ok(result) => (result.resolved, result.redacted), - Err(e) => { - warn!( - host = %ctx.host, - port = ctx.port, - error = %e, - "credential resolution failed in request target, rejecting" - ); - let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; - client.write_all(response).await.into_diagnostic()?; - client.flush().await.into_diagnostic()?; - return Ok(()); - } - } - } else { - (req.target.clone(), req.target.clone()) - }; - - let request_info = L7RequestInfo { - action: req.action.clone(), - target: redacted_target.clone(), - query_params: req.query_params.clone(), - graphql: None, - jsonrpc: None, - }; - let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); - if config.protocol == L7Protocol::Websocket && !websocket_request { - provider - .deny_with_redacted_target( - &req, - &ctx.policy_name, - "websocket endpoint requires a valid WebSocket upgrade request", - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - - // Evaluate L7 policy via Rego (using redacted target) - let (allowed, reason) = evaluate_l7_request(engine, ctx, &request_info)?; - - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - // Check if this is an upgrade request for logging purposes. - let header_end = req - .raw_header - .windows(4) - .position(|w| w == b"\r\n\r\n") - .map_or(req.raw_header.len(), |p| p + 4); - let is_upgrade_request = { - let h = String::from_utf8_lossy(&req.raw_header[..header_end]); - h.lines() - .skip(1) - .any(|l| l.to_ascii_lowercase().starts_with("upgrade:")) - }; - - let decision_str = match (allowed, config.enforcement, is_upgrade_request) { - (true, _, true) => "allow_upgrade", - (true, _, false) => "allow", - (false, EnforcementMode::Audit, _) => "audit", - (false, EnforcementMode::Enforce, _) => "deny", - }; - - // Log every L7 decision as an OCSF HTTP Activity event. - // Uses redacted_target (path only, no query params) to avoid logging secrets. - { - let (action_id, disposition_id, severity) = match decision_str { - "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), - "allow" | "audit" => ( - ActionId::Allowed, - DispositionId::Allowed, - SeverityId::Informational, - ), - _ => ( - ActionId::Other, - DispositionId::Other, - SeverityId::Informational, - ), - }; - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(action_id) - .disposition(disposition_id) - .severity(severity) - .http_request(HttpRequest::new( - &request_info.action, - OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), - )) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .firewall_rule(&ctx.policy_name, "l7") - .message(format!( - "L7_REQUEST {decision_str} {} {}:{}{} reason={}", - request_info.action, ctx.host, ctx.port, redacted_target, reason, - )) - .build(); - ocsf_emit!(event); - } - - // Store the resolved target for the deny response redaction - let _ = &eval_target; - - if allowed || config.enforcement == EnforcementMode::Audit { - let req_with_auth = - match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { - Ok(req) => req, - Err(e) => { - warn!( - host = %ctx.host, - port = ctx.port, - error = %e, - "Token grant failed in L7 relay" - ); - write_bad_gateway_response(client).await?; - return Ok(()); - } - }; - - // Forward request to upstream and relay response - let outcome = crate::l7::rest::relay_http_request_with_options_guarded( - &req_with_auth, - client, - upstream, - crate::l7::rest::RelayRequestOptions { - resolver: ctx.secret_resolver.as_deref(), - generation_guard: Some(engine.generation_guard()), - websocket_extensions: websocket_extension_mode(config), - request_body_credential_rewrite: config.protocol == L7Protocol::Rest - && config.request_body_credential_rewrite, - }, - ) - .await?; - match outcome { - RelayOutcome::Reusable => {} // continue loop - RelayOutcome::Consumed => { - debug!( - host = %ctx.host, - port = ctx.port, - "Upstream connection not reusable, closing L7 relay" - ); - return Ok(()); - } - RelayOutcome::Upgraded { - overflow, - websocket_permessage_deflate, - } => { - let mut options = upgrade_options( - config, - ctx, - websocket_request, - &redacted_target, - &req_with_auth.query_params, - Some(engine), - ); - options.websocket.permessage_deflate = websocket_permessage_deflate; - return handle_upgrade( - client, upstream, overflow, &ctx.host, ctx.port, options, - ) - .await; - } - } - } else { - // Enforce mode: deny with 403 and close connection (use redacted target) - provider - .deny_with_redacted_target( - &req, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - } -} - -fn close_if_stale(guard: &PolicyGenerationGuard, ctx: &L7EvalContext) -> bool { - if !guard.is_stale() { - return false; - } - - ocsf_emit!( - NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .firewall_rule(&ctx.policy_name, "l7") - .message(format!( - "L7 tunnel closed after policy reload [host:{} port:{} captured_generation:{} current_generation:{}]", - ctx.host, - ctx.port, - guard.captured_generation(), - guard.current_generation(), - )) - .build() - ); - true -} - -async fn relay_jsonrpc( - config: &L7EndpointConfig, - engine: &TunnelPolicyEngine, - client: &mut C, - upstream: &mut U, - ctx: &L7EvalContext, -) -> Result<()> -where - C: AsyncRead + AsyncWrite + Unpin + Send, - U: AsyncRead + AsyncWrite + Unpin + Send, -{ - loop { - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let parsed = match crate::l7::jsonrpc::parse_jsonrpc_http_request( - client, - config.json_rpc_max_body_bytes, - crate::l7::path::CanonicalizeOptions { - allow_encoded_slash: config.allow_encoded_slash, - ..Default::default() - }, - ) - .await - { - Ok(Some(parsed)) => parsed, - Ok(None) => return Ok(()), - Err(e) => { - if is_benign_connection_error(&e) { - debug!( - host = %ctx.host, - port = ctx.port, - error = %e, - "JSON-RPC L7 connection closed" - ); - } else { - let detail = - parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); - emit_parse_rejection(ctx, &detail, "l7-jsonrpc"); - } - return Ok(()); - } - }; - - let req = parsed.request; - let jsonrpc_info = parsed.info; - - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let redacted_target = req.target.clone(); - - let request_info = L7RequestInfo { - action: req.action.clone(), - target: redacted_target.clone(), - query_params: req.query_params.clone(), - graphql: None, - jsonrpc: Some(jsonrpc_info.clone()), - }; - - let parse_error_reason = jsonrpc_info - .error - .as_deref() - .map(|e| format!("JSON-RPC request rejected: {e}")); - let force_deny = parse_error_reason.is_some(); - let (allowed, reason, jsonrpc_log_info) = if let Some(reason) = parse_error_reason { - (false, reason, jsonrpc_info.clone()) - } else { - let evaluation = - evaluate_jsonrpc_l7_request_for_log(engine, ctx, &request_info, &jsonrpc_info)?; - (evaluation.allowed, evaluation.reason, evaluation.log_info) - }; - - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let decision_str = match (allowed, config.enforcement) { - (_, _) if force_deny => "deny", - (true, _) => "allow", - (false, EnforcementMode::Audit) => "audit", - (false, EnforcementMode::Enforce) => "deny", - }; - - { - let (action_id, disposition_id, severity) = match decision_str { - "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), - _ => ( - ActionId::Allowed, - DispositionId::Allowed, - SeverityId::Informational, - ), - }; - let endpoint = format!("{}:{}{}", ctx.host, ctx.port, redacted_target); - let params_sha256 = jsonrpc_log_info - .params_sha256() - .unwrap_or_else(|| "".to_string()); - let policy_version = engine.captured_generation(); - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(action_id) - .disposition(disposition_id) - .severity(severity) - .http_request(HttpRequest::new( - &request_info.action, - OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), - )) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .firewall_rule(&ctx.policy_name, "l7-jsonrpc") - .message(jsonrpc_log_message( - decision_str, - &request_info.action, - &endpoint, - &jsonrpc_log_info, - ¶ms_sha256, - policy_version, - &reason, - )) - .build(); - ocsf_emit!(event); - } - - if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( - &req, - client, - upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), - ) - .await?; - match outcome { - RelayOutcome::Reusable => {} - RelayOutcome::Consumed => { - debug!( - host = %ctx.host, - port = ctx.port, - "Upstream connection not reusable, closing JSON-RPC L7 relay" - ); - return Ok(()); - } - RelayOutcome::Upgraded { .. } => { - return Ok(()); - } - } - } else { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &req, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - } -} - -async fn relay_graphql( - config: &L7EndpointConfig, - engine: &TunnelPolicyEngine, - client: &mut C, - upstream: &mut U, - ctx: &L7EvalContext, -) -> Result<()> -where - C: AsyncRead + AsyncWrite + Unpin + Send, - U: AsyncRead + AsyncWrite + Unpin + Send, -{ - loop { - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let parsed = match crate::l7::graphql::parse_graphql_http_request( - client, - config.graphql_max_body_bytes, - crate::l7::path::CanonicalizeOptions { - allow_encoded_slash: config.allow_encoded_slash, - ..Default::default() - }, - ) - .await - { - Ok(Some(parsed)) => parsed, - Ok(None) => return Ok(()), - Err(e) => { - if is_benign_connection_error(&e) { - debug!( - host = %ctx.host, - port = ctx.port, - error = %e, - "GraphQL L7 connection closed" - ); - } else { - let detail = - parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); - emit_parse_rejection(ctx, &detail, "l7-graphql"); - } - return Ok(()); - } - }; - - let req = parsed.request; - let graphql_info = parsed.info; - - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let (eval_target, redacted_target) = if let Some(ref resolver) = ctx.secret_resolver { - match secrets::rewrite_target_for_eval(&req.target, resolver) { - Ok(result) => (result.resolved, result.redacted), - Err(e) => { - warn!( - host = %ctx.host, - port = ctx.port, - error = %e, - "credential resolution failed in GraphQL request target, rejecting" - ); - let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; - client.write_all(response).await.into_diagnostic()?; - client.flush().await.into_diagnostic()?; - return Ok(()); - } - } - } else { - (req.target.clone(), req.target.clone()) - }; - - let request_info = L7RequestInfo { - action: req.action.clone(), - target: redacted_target.clone(), - query_params: req.query_params.clone(), - graphql: Some(graphql_info.clone()), - jsonrpc: None, - }; - - // Malformed or ambiguous GraphQL requests, such as duplicated GET - // control parameters, are rejected before policy evaluation. This - // keeps parser-differential cases fail-closed even if the endpoint is - // otherwise in audit mode. - let parse_error_reason = graphql_info - .error - .as_deref() - .map(|error| format!("GraphQL request rejected: {error}")); - let force_deny = parse_error_reason.is_some(); - let (allowed, reason) = if let Some(reason) = parse_error_reason { - (false, reason) - } else { - evaluate_l7_request(engine, ctx, &request_info)? - }; - - if close_if_stale(engine.generation_guard(), ctx) { - return Ok(()); - } - - let decision_str = match (allowed, config.enforcement) { - (_, _) if force_deny => "deny", - (true, _) => "allow", - (false, EnforcementMode::Audit) => "audit", - (false, EnforcementMode::Enforce) => "deny", - }; - - { - let (action_id, disposition_id, severity) = match decision_str { - "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), - "allow" | "audit" => ( - ActionId::Allowed, - DispositionId::Allowed, - SeverityId::Informational, - ), - _ => ( - ActionId::Other, - DispositionId::Other, - SeverityId::Informational, - ), - }; - let gql_summary = graphql_log_summary(&graphql_info); - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(action_id) - .disposition(disposition_id) - .severity(severity) - .http_request(HttpRequest::new( - &request_info.action, - OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), - )) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .firewall_rule(&ctx.policy_name, "l7-graphql") - .message(format!( - "GRAPHQL_L7_REQUEST {decision_str} {} {}:{}{} {gql_summary} reason={}", - request_info.action, ctx.host, ctx.port, redacted_target, reason, - )) - .build(); - ocsf_emit!(event); - } - - let _ = &eval_target; - - if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( - &req, - client, - upstream, - ctx.secret_resolver.as_deref(), - Some(engine.generation_guard()), - ) - .await?; - match outcome { - RelayOutcome::Reusable => {} - RelayOutcome::Consumed => { - debug!( - host = %ctx.host, - port = ctx.port, - "Upstream connection not reusable, closing GraphQL L7 relay" - ); - return Ok(()); - } - RelayOutcome::Upgraded { - overflow, - websocket_permessage_deflate, - } => { - let options = UpgradeRelayOptions { - websocket: WebSocketUpgradeBehavior { - permessage_deflate: websocket_permessage_deflate, - ..Default::default() - }, - ..Default::default() - }; - return handle_upgrade( - client, upstream, overflow, &ctx.host, ctx.port, options, - ) - .await; - } - } - } else { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &req, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - } -} - -fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String { - if let Some(error) = &info.error { - return format!("graphql_error={error:?}"); - } - let ops: Vec = info - .operations - .iter() - .map(|op| { - let name = op.operation_name.as_deref().unwrap_or("-"); - let fields = if op.fields.is_empty() { - "-".to_string() - } else { - op.fields.join(",") - }; - let persisted = op - .persisted_query_hash - .as_deref() - .or(op.persisted_query_id.as_deref()) - .unwrap_or("-"); - format!( - "type={} name={} fields={} persisted={}", - op.operation_type, name, fields, persisted - ) - }) - .collect(); - format!("graphql_ops={}", ops.join(";")) -} - -pub(crate) fn jsonrpc_log_message( - decision: &str, - http_method: &str, - endpoint: &str, - info: &crate::l7::jsonrpc::JsonRpcRequestInfo, - params_sha256: &str, - policy_version: u64, - reason: &str, -) -> String { - let rpc_methods = jsonrpc_methods_for_log(info); - format!( - "JSONRPC_L7_REQUEST decision={decision} http_method={http_method} endpoint={endpoint} rpc_methods={rpc_methods} params_sha256={params_sha256} policy_version={policy_version} reason={reason}" - ) -} - -pub(crate) fn jsonrpc_methods_for_log(info: &crate::l7::jsonrpc::JsonRpcRequestInfo) -> String { - if info.calls.is_empty() { - return "-".to_string(); - } - info.calls - .iter() - .map(|call| call.method.as_str()) - .collect::>() - .join(",") -} - -struct JsonRpcEvaluation { - allowed: bool, - reason: String, - log_info: crate::l7::jsonrpc::JsonRpcRequestInfo, -} - -/// Check if a miette error represents a benign connection close. -/// -/// TLS handshake EOF, missing `close_notify`, connection resets, and broken -/// pipes are all normal lifecycle events for proxied connections — not worth -/// a WARN that interrupts the user's terminal. -fn is_benign_connection_error(err: &miette::Report) -> bool { - const BENIGN: &[&str] = &[ - "close_notify", - "tls handshake eof", - "connection reset", - "broken pipe", - "unexpected eof", - "client disconnected mid-request", - ]; - let msg = err.to_string().to_ascii_lowercase(); - BENIGN.iter().any(|pat| msg.contains(pat)) -} - -/// Evaluate an L7 request against the OPA engine. -/// -/// Returns `(allowed, deny_reason)`. -pub fn evaluate_l7_request( - engine: &TunnelPolicyEngine, - ctx: &L7EvalContext, - request: &L7RequestInfo, -) -> Result<(bool, String)> { - if let Some(jsonrpc) = &request.jsonrpc - && jsonrpc.is_batch - && !jsonrpc.calls.is_empty() - { - for call in &jsonrpc.calls { - let item_request = jsonrpc_request_for_call(request, call); - let (allowed, reason) = evaluate_l7_request_once(engine, ctx, &item_request)?; - if !allowed { - return Ok((false, reason)); - } - } - return Ok((true, String::new())); - } - - evaluate_l7_request_once(engine, ctx, request) -} - -fn evaluate_jsonrpc_l7_request_for_log( - engine: &TunnelPolicyEngine, - ctx: &L7EvalContext, - request: &L7RequestInfo, - jsonrpc: &crate::l7::jsonrpc::JsonRpcRequestInfo, -) -> Result { - if jsonrpc.is_batch && !jsonrpc.calls.is_empty() { - let mut denied_calls = Vec::new(); - let mut first_denied_reason = None; - for call in &jsonrpc.calls { - let item_request = jsonrpc_request_for_call(request, call); - let (allowed, reason) = evaluate_l7_request_once(engine, ctx, &item_request)?; - if !allowed { - if first_denied_reason.is_none() { - first_denied_reason = Some(reason); - } - denied_calls.push(call.clone()); - } - } - - if denied_calls.is_empty() { - return Ok(JsonRpcEvaluation { - allowed: true, - reason: String::new(), - log_info: jsonrpc.clone(), - }); - } - - return Ok(JsonRpcEvaluation { - allowed: false, - reason: first_denied_reason.unwrap_or_else(|| "request denied by policy".to_string()), - log_info: crate::l7::jsonrpc::JsonRpcRequestInfo { - calls: denied_calls, - is_batch: true, - error: None, - }, - }); - } - - let (allowed, reason) = evaluate_l7_request_once(engine, ctx, request)?; - Ok(JsonRpcEvaluation { - allowed, - reason, - log_info: jsonrpc.clone(), - }) -} - -fn jsonrpc_request_for_call( - request: &L7RequestInfo, - call: &crate::l7::jsonrpc::JsonRpcCallInfo, -) -> L7RequestInfo { - let mut item_request = request.clone(); - item_request.jsonrpc = Some(crate::l7::jsonrpc::JsonRpcRequestInfo { - calls: vec![call.clone()], - is_batch: false, - error: None, - }); - item_request -} - -fn evaluate_l7_request_once( - engine: &TunnelPolicyEngine, - ctx: &L7EvalContext, - request: &L7RequestInfo, -) -> Result<(bool, String)> { - if engine.is_stale() { - return Err(miette!( - "L7 tunnel policy generation is stale [captured_generation:{} current_generation:{}]", - engine.captured_generation(), - engine.current_generation(), - )); - } - - let input_json = serde_json::json!({ - "network": { - "host": ctx.host, - "port": ctx.port, - }, - "exec": { - "path": ctx.binary_path, - "ancestors": ctx.ancestors, - "cmdline_paths": ctx.cmdline_paths, - }, - "request": { - "method": request.action, - "path": request.target, - "query_params": request.query_params.clone(), - "graphql": request.graphql.clone(), - "jsonrpc": request.jsonrpc.as_ref().map(|j| { - let call = if j.is_batch { None } else { j.calls.first() }; - serde_json::json!({ - "method": call.map(|call| call.method.as_str()), - "params": call.map(|call| call.params.clone()).unwrap_or_default(), - "error": j.error, - }) - }), - } - }); - - let mut engine = engine - .engine() - .lock() - .map_err(|_| miette!("OPA engine lock poisoned"))?; - - engine - .set_input_json(&input_json.to_string()) - .map_err(|e| miette!("{e}"))?; - - let allowed = engine - .eval_rule("data.openshell.sandbox.allow_request".into()) - .map_err(|e| miette!("{e}"))?; - let allowed = allowed == regorus::Value::from(true); - - let reason = if allowed { - String::new() - } else { - let val = engine - .eval_rule("data.openshell.sandbox.request_deny_reason".into()) - .map_err(|e| miette!("{e}"))?; - match val { - regorus::Value::String(s) => s.to_string(), - regorus::Value::Undefined => "request denied by policy".to_string(), - other => other.to_string(), - } - }; - - Ok((allowed, reason)) -} - -/// Relay HTTP traffic with credential injection only (no L7 OPA evaluation). -/// -/// Used when TLS is auto-terminated but no L7 policy (`protocol` + `access`/`rules`) -/// is configured. Parses HTTP requests minimally to rewrite credential -/// placeholders and log requests for observability, then forwards everything. -pub async fn relay_passthrough_with_credentials( - client: &mut C, - upstream: &mut U, - ctx: &L7EvalContext, - generation_guard: &PolicyGenerationGuard, -) -> Result<()> -where - C: AsyncRead + AsyncWrite + Unpin + Send, - U: AsyncRead + AsyncWrite + Unpin + Send, -{ - // Passthrough path: no L7 policy is enforced here, so use default - // (strict) canonicalization options. Calls to GitLab-style APIs that - // need `%2F` must be configured as L7 endpoints so the per-endpoint - // `allow_encoded_slash` opt-in applies. - let provider = crate::l7::rest::RestProvider::default(); - let mut request_count: u64 = 0; - let resolver = ctx.secret_resolver.as_deref(); - - loop { - if close_if_stale(generation_guard, ctx) { - return Ok(()); - } - - // Read next request from client. - let req = match provider.parse_request(client).await { - Ok(Some(req)) => req, - Ok(None) => break, // Client closed connection. - Err(e) => { - if is_benign_connection_error(&e) { - break; - } - let detail = - parse_rejection_detail(&e.to_string(), ParseRejectionMode::Passthrough); - emit_parse_rejection(ctx, &detail, "http-parser"); - return Ok(()); - } - }; - - if close_if_stale(generation_guard, ctx) { - return Ok(()); - } - - request_count += 1; - - // Resolve and redact the target for logging. - let redacted_target = if let Some(ref res) = ctx.secret_resolver { - match secrets::rewrite_target_for_eval(&req.target, res) { - Ok(result) => result.redacted, - Err(e) => { - warn!( - host = %ctx.host, - port = ctx.port, - error = %e, - "credential resolution failed in request target, rejecting" - ); - let response = b"HTTP/1.1 500 Internal Server Error\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; - client.write_all(response).await.into_diagnostic()?; - client.flush().await.into_diagnostic()?; - return Ok(()); - } - } - } else { - req.target.clone() - }; - - // Log for observability via OCSF HTTP Activity event. - // Uses redacted_target (path only, no query params) to avoid logging secrets. - let has_creds = resolver.is_some(); - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Allowed) - .disposition(DispositionId::Allowed) - .severity(SeverityId::Informational) - .http_request(HttpRequest::new( - &req.action, - OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), - )) - .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) - .message(format!( - "HTTP_REQUEST {} {}:{}{} credentials_injected={has_creds} request_num={request_count}", - req.action, ctx.host, ctx.port, redacted_target, - )) - .build(); - ocsf_emit!(event); - } - - let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await - { - Ok(req) => req, - Err(e) => { - warn!( - host = %ctx.host, - port = ctx.port, - error = %e, - "Token grant failed in passthrough relay" - ); - write_bad_gateway_response(client).await?; - return Ok(()); - } - }; - - // Forward request with credential rewriting and relay the response. - // relay_http_request_with_resolver handles both directions: it sends - // the request upstream and reads the response back to the client. - let outcome = crate::l7::rest::relay_http_request_with_options_guarded( - &req_with_auth, - client, - upstream, - crate::l7::rest::RelayRequestOptions { - resolver, - generation_guard: Some(generation_guard), - ..Default::default() - }, - ) - .await?; - - match outcome { - RelayOutcome::Reusable => {} // continue loop - RelayOutcome::Consumed => break, - RelayOutcome::Upgraded { overflow, .. } => { - return handle_upgrade( - client, - upstream, - overflow, - &ctx.host, - ctx.port, - UpgradeRelayOptions::default(), - ) - .await; - } - } - } - - debug!( - host = %ctx.host, - port = ctx.port, - total_requests = request_count, - "Credential injection relay completed" - ); - - Ok(()) -} - -async fn write_bad_gateway_response(client: &mut W) -> Result<()> -where - W: AsyncWrite + Unpin, -{ - let response = b"HTTP/1.1 502 Bad Gateway\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; - client.write_all(response).await.into_diagnostic()?; - client.flush().await.into_diagnostic()?; - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::opa::{NetworkInput, OpaEngine}; - use std::path::PathBuf; - use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; - - const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); - - fn rest_token_grant_relay_context( - resolver_response: std::result::Result<&str, &str>, - ) -> ( - L7EndpointConfig, - TunnelPolicyEngine, - L7EvalContext, - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture, - ) { - let data = r#" -network_policies: - rest_api: - name: rest_api - endpoints: - - host: api.example.test - port: 8080 - protocol: rest - enforcement: enforce - rules: - - allow: - method: GET - path: "/v1/**" - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.test".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let (endpoint_config, generation) = engine - .query_endpoint_config_with_generation(&input) - .unwrap(); - let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); - let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); - let provider_key = "api.example.test\t8080\t/v1/**\tprovider:access_token"; - let fixture = match resolver_response { - Ok(token) => { - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::success( - provider_key, - token, - ) - } - Err(error) => { - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::failure( - provider_key, - error, - ) - } - }; - let ctx = L7EvalContext { - host: "api.example.test".into(), - port: 8080, - policy_name: "rest_api".into(), - binary_path: "/usr/bin/curl".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: Some(fixture.dynamic_credentials()), - token_grant_resolver: Some(fixture.resolver()), - }; - - (config, tunnel_engine, ctx, fixture) - } - - fn passthrough_token_grant_relay_context( - resolver_response: std::result::Result<&str, &str>, - ) -> ( - PolicyGenerationGuard, - L7EvalContext, - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture, - ) { - let policy_data = "network_policies: {}\n"; - let engine = OpaEngine::from_strings(TEST_POLICY, policy_data).unwrap(); - let generation_guard = engine - .generation_guard(engine.current_generation()) - .unwrap(); - let provider_key = "api.example.test\t8080\t/v1/**\tprovider:access_token"; - let fixture = match resolver_response { - Ok(token) => { - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::success( - provider_key, - token, - ) - } - Err(error) => { - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::failure( - provider_key, - error, - ) - } - }; - let ctx = L7EvalContext { - host: "api.example.test".into(), - port: 8080, - policy_name: "rest_api".into(), - binary_path: "/usr/bin/curl".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: Some(fixture.dynamic_credentials()), - token_grant_resolver: Some(fixture.resolver()), - }; - - (generation_guard, ctx, fixture) - } - - fn authorization_header_count(headers: &str) -> usize { - headers - .lines() - .filter(|line| { - line.split_once(':') - .is_some_and(|(name, _)| name.eq_ignore_ascii_case("authorization")) - }) - .count() - } - - #[test] - fn parse_rejection_detail_adds_l7_hint_for_encoded_slash() { - let detail = parse_rejection_detail( - "HTTP request-target rejected: request-target contains an encoded '/' (%2F) which is not allowed on this endpoint", - ParseRejectionMode::L7Endpoint, - ); - - assert!(detail.contains("allow_encoded_slash: true")); - assert!(detail.contains("upstream requires encoded slashes")); - } - - #[test] - fn parse_rejection_detail_adds_passthrough_hint_for_encoded_slash() { - let detail = parse_rejection_detail( - "HTTP request-target rejected: request-target contains an encoded '/' (%2F) which is not allowed on this endpoint", - ParseRejectionMode::Passthrough, - ); - - assert!(detail.contains("protocol: rest")); - assert!(detail.contains("allow_encoded_slash: true")); - assert!(detail.contains("tls: skip")); - } - - #[test] - fn parse_rejection_detail_preserves_other_errors() { - let error = "HTTP headers contain invalid UTF-8"; - - assert_eq!( - parse_rejection_detail(error, ParseRejectionMode::L7Endpoint), - error - ); - } - - #[tokio::test] - async fn l7_rest_relay_injects_token_grant_authorization_header() { - let (config, tunnel_engine, ctx, fixture) = - rest_token_grant_relay_context(Ok("grant-token")); - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_with_inspection( - &config, - tunnel_engine, - &mut relay_client, - &mut relay_upstream, - &ctx, - ) - .await - }); - - app.write_all( - b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nAuthorization: Bearer stale-token\r\nConnection: close\r\n\r\n", - ) - .await - .unwrap(); - - let mut upstream_request = [0u8; 1024]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut upstream_request), - ) - .await - .expect("request should reach upstream") - .unwrap(); - let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); - - assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); - assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); - assert!(!upstream_request.contains("stale-token")); - assert_eq!(authorization_header_count(&upstream_request), 1); - - upstream - .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") - .await - .unwrap(); - - let mut client_response = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - app.read(&mut client_response), - ) - .await - .expect("response should reach client") - .unwrap(); - assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); - drop(app); - - tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("relay should finish") - .unwrap() - .unwrap(); - - fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); - } - - #[tokio::test] - async fn l7_rest_relay_token_grant_failure_does_not_forward_request() { - let (config, tunnel_engine, ctx, fixture) = - rest_token_grant_relay_context(Err("oauth unavailable")); - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_with_inspection( - &config, - tunnel_engine, - &mut relay_client, - &mut relay_upstream, - &ctx, - ) - .await - }); - - app.write_all( - b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nConnection: close\r\n\r\n", - ) - .await - .unwrap(); - - tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("relay should finish") - .unwrap() - .unwrap(); - - let mut client_response = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - app.read(&mut client_response), - ) - .await - .expect("bad gateway response should reach client") - .unwrap(); - assert!(String::from_utf8_lossy(&client_response[..n]).contains("502 Bad Gateway")); - - let mut upstream_request = [0u8; 128]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut upstream_request), - ) - .await - .expect("upstream should close without forwarded data") - .unwrap(); - assert_eq!(n, 0, "unauthenticated request must not reach upstream"); - - fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); - } - - #[tokio::test] - async fn passthrough_relay_injects_token_grant_authorization_header() { - let (generation_guard, ctx, fixture) = - passthrough_token_grant_relay_context(Ok("grant-token")); - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_passthrough_with_credentials( - &mut relay_client, - &mut relay_upstream, - &ctx, - &generation_guard, - ) - .await - }); - - app.write_all( - b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nAuthorization: Bearer stale-token\r\nConnection: close\r\n\r\n", - ) - .await - .unwrap(); - - let mut upstream_request = [0u8; 1024]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut upstream_request), - ) - .await - .expect("request should reach upstream") - .unwrap(); - let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); - - assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); - assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); - assert!(!upstream_request.contains("stale-token")); - assert_eq!(authorization_header_count(&upstream_request), 1); - - upstream - .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") - .await - .unwrap(); - - let mut client_response = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - app.read(&mut client_response), - ) - .await - .expect("response should reach client") - .unwrap(); - assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); - drop(app); - - tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("relay should finish") - .unwrap() - .unwrap(); - - fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); - } - - #[tokio::test] - async fn passthrough_relay_token_grant_failure_returns_bad_gateway_without_forwarding() { - let (generation_guard, ctx, fixture) = - passthrough_token_grant_relay_context(Err("oauth unavailable")); - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_passthrough_with_credentials( - &mut relay_client, - &mut relay_upstream, - &ctx, - &generation_guard, - ) - .await - }); - - app.write_all( - b"GET /v1/projects HTTP/1.1\r\nHost: api.example.test\r\nConnection: close\r\n\r\n", - ) - .await - .unwrap(); - - tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("relay should finish") - .unwrap() - .unwrap(); - - let mut client_response = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - app.read(&mut client_response), - ) - .await - .expect("bad gateway response should reach client") - .unwrap(); - assert!(String::from_utf8_lossy(&client_response[..n]).contains("502 Bad Gateway")); - - let mut upstream_request = [0u8; 128]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut upstream_request), - ) - .await - .expect("upstream should close without forwarded data") - .unwrap(); - assert_eq!(n, 0, "unauthenticated request must not reach upstream"); - - fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); - } - - #[test] - fn websocket_text_policy_requires_explicit_message_rule() { - let data = r#" -network_policies: - ws_api: - name: ws_api - endpoints: - - host: gateway.example.test - port: 443 - protocol: websocket - enforcement: enforce - rules: - - allow: - method: GET - path: "/ws" - binaries: - - { path: /usr/bin/node } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "gateway.example.test".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let generation = engine - .evaluate_network_action_with_generation(&input) - .unwrap() - .1; - let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); - let ctx = L7EvalContext { - host: "gateway.example.test".into(), - port: 443, - policy_name: "ws_api".into(), - binary_path: "/usr/bin/node".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - let request = L7RequestInfo { - action: "WEBSOCKET_TEXT".into(), - target: "/ws".into(), - query_params: std::collections::HashMap::new(), - graphql: None, - jsonrpc: None, - }; - - let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); - - assert!(!allowed); - assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); - } - - #[test] - fn jsonrpc_batch_evaluates_each_call() { - let data = r#" -network_policies: - jsonrpc_api: - name: jsonrpc_api - endpoints: - - host: api.example.test - port: 443 - protocol: json-rpc - enforcement: enforce - rules: - - allow: - method: POST - path: "/mcp" - rpc_method: "tools/list" - - allow: - method: POST - path: "/mcp" - rpc_method: "tools/call" - params: - name: read_status - deny_rules: - - rpc_method: "tools/call" - params: - name: blocked_action - - rpc_method: "tools/delete" - binaries: - - { path: /usr/bin/node } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let tunnel_engine = engine - .clone_engine_for_tunnel(engine.current_generation()) - .unwrap(); - let ctx = L7EvalContext { - host: "api.example.test".into(), - port: 443, - policy_name: "jsonrpc_api".into(), - binary_path: "/usr/bin/node".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - let mut request = L7RequestInfo { - action: "POST".into(), - target: "/mcp".into(), - query_params: std::collections::HashMap::new(), - graphql: None, - jsonrpc: Some(crate::l7::jsonrpc::parse_jsonrpc_body( - br#"[ - {"jsonrpc":"2.0","id":1,"method":"tools/list"}, - {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_status"}} - ]"#, - )), - }; - - let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); - assert!(allowed, "{reason}"); - - request.jsonrpc = Some(crate::l7::jsonrpc::parse_jsonrpc_body( - br#"[ - {"jsonrpc":"2.0","id":1,"method":"tools/list"}, - {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"blocked_action"}}, - {"jsonrpc":"2.0","id":3,"method":"tools/delete","params":{"name":"purge_cache"}} - ]"#, - )); - let (allowed, _) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); - assert!(!allowed); - - let jsonrpc = request.jsonrpc.as_ref().expect("jsonrpc request"); - let evaluation = - evaluate_jsonrpc_l7_request_for_log(&tunnel_engine, &ctx, &request, jsonrpc).unwrap(); - assert!(!evaluation.allowed); - assert!(evaluation.log_info.is_batch); - assert_eq!( - jsonrpc_methods_for_log(&evaluation.log_info), - "tools/call,tools/delete" - ); - - let full_params_sha256 = jsonrpc.params_sha256().expect("full batch params digest"); - let log_params_sha256 = evaluation - .log_info - .params_sha256() - .expect("logged batch params digest"); - assert_ne!(full_params_sha256, log_params_sha256); - let message = jsonrpc_log_message( - "deny", - "POST", - "api.example.test:443/mcp", - &evaluation.log_info, - &log_params_sha256, - 42, - &evaluation.reason, - ); - assert!(message.contains("rpc_methods=tools/call,tools/delete")); - assert!(message.contains("params_sha256=")); - assert!(!message.contains("params_sha256=sha256:")); - assert!(message.contains("policy_version=42")); - assert!(!message.contains("tools/list")); - assert!(!message.contains("blocked_action")); - assert!(!message.contains("purge_cache")); - } - - #[test] - fn jsonrpc_log_records_digest_not_args() { - let info = crate::l7::jsonrpc::parse_jsonrpc_body( - br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"delete_resource","arguments":{"scope":"secret-scope"}}}"#, - ); - let params_sha256 = info.params_sha256().expect("params digest"); - let message = jsonrpc_log_message( - "deny", - "POST", - "mcp.example.com:443/mcp", - &info, - ¶ms_sha256, - 42, - "request denied by policy", - ); - - assert!(message.contains("endpoint=mcp.example.com:443/mcp")); - assert!(message.contains("rpc_methods=tools/call")); - assert!(message.contains("params_sha256=")); - assert!(!message.contains("params_sha256=sha256:")); - assert!(message.contains("policy_version=42")); - assert!(!message.contains("delete_resource")); - assert!(!message.contains("secret-scope")); - - let batch = crate::l7::jsonrpc::parse_jsonrpc_body( - br#"[ - {"jsonrpc":"2.0","id":1,"method":"tools/list"}, - {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"delete_resource"}} - ]"#, - ); - let batch_params_sha256 = batch.params_sha256().expect("batch params digest"); - let batch_message = jsonrpc_log_message( - "allow", - "POST", - "mcp.example.com:443/mcp", - &batch, - &batch_params_sha256, - 43, - "", - ); - - assert!(batch_message.starts_with("JSONRPC_L7_REQUEST ")); - assert!(batch_message.contains("rpc_methods=tools/list,tools/call")); - assert!(batch_message.contains("params_sha256=")); - assert!(!batch_message.contains("params_sha256=sha256:")); - assert!(batch_message.contains("policy_version=43")); - assert!(!batch_message.contains("rpc_method=")); - assert!(!batch_message.contains("delete_resource")); - - let no_params = crate::l7::jsonrpc::parse_jsonrpc_body( - br#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#, - ); - let no_params_sha256 = no_params - .params_sha256() - .unwrap_or_else(|| "".to_string()); - let no_params_message = jsonrpc_log_message( - "allow", - "POST", - "mcp.example.com:443/mcp", - &no_params, - &no_params_sha256, - 44, - "", - ); - assert!(no_params_message.contains("rpc_methods=initialize")); - assert!(no_params_message.contains("params_sha256=")); - } - - #[tokio::test] - async fn route_selected_websocket_upgrade_rejects_invalid_accept_without_forwarding_101() { - let data = r#" -network_policies: - route_api: - name: route_api - endpoints: - - host: gateway.example.test - port: 443 - protocol: rest - enforcement: enforce - rules: - - allow: - method: GET - path: "/ws" - binaries: - - { path: /usr/bin/node } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let tunnel_engine = engine - .clone_engine_for_tunnel(engine.current_generation()) - .unwrap(); - let configs = vec![L7EndpointConfig { - protocol: L7Protocol::Rest, - path: "/ws".into(), - tls: crate::l7::TlsMode::Auto, - enforcement: EnforcementMode::Enforce, - graphql_max_body_bytes: 0, - json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, - allow_encoded_slash: false, - websocket_credential_rewrite: true, - request_body_credential_rewrite: false, - websocket_graphql_policy: false, - }]; - let ctx = L7EvalContext { - host: "gateway.example.test".into(), - port: 443, - policy_name: "route_api".into(), - binary_path: "/usr/bin/node".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_with_route_selection( - &configs, - tunnel_engine, - &mut relay_client, - &mut relay_upstream, - &ctx, - ) - .await - }); - - app.write_all( - b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", - ) - .await - .unwrap(); - - let mut forwarded = [0u8; 1024]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut forwarded), - ) - .await - .expect("upgrade request should reach upstream") - .unwrap(); - let forwarded = String::from_utf8_lossy(&forwarded[..n]); - assert!(forwarded.contains("Upgrade: websocket\r\n")); - assert!(forwarded.contains("Connection: Upgrade\r\n")); - - upstream - .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: invalid\r\n\r\n", - ) - .await - .unwrap(); - - let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("relay should fail closed on invalid accept") - .unwrap() - .expect_err("invalid accept must fail the route-selected relay"); - assert!(err.to_string().contains("Sec-WebSocket-Accept")); - - let mut response = [0u8; 1]; - let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) - .await - .expect("client side should close without 101") - .unwrap(); - assert_eq!(n, 0, "invalid response must not forward 101 headers"); - } - - #[tokio::test] - async fn route_selected_websocket_rewrites_text_credentials_after_upgrade() { - let data = r#" -network_policies: - route_api: - name: route_api - endpoints: - - host: gateway.example.test - port: 443 - protocol: websocket - enforcement: enforce - rules: - - allow: - method: GET - path: "/ws" - - allow: - method: WEBSOCKET_TEXT - path: "/ws" - websocket_credential_rewrite: true - binaries: - - { path: /usr/bin/node } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let tunnel_engine = engine - .clone_engine_for_tunnel(engine.current_generation()) - .unwrap(); - let configs = vec![L7EndpointConfig { - protocol: L7Protocol::Websocket, - path: "/ws".into(), - tls: crate::l7::TlsMode::Auto, - enforcement: EnforcementMode::Enforce, - graphql_max_body_bytes: 0, - json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, - allow_encoded_slash: false, - websocket_credential_rewrite: true, - request_body_credential_rewrite: false, - websocket_graphql_policy: false, - }]; - let (child_env, resolver) = SecretResolver::from_provider_env( - std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), - ); - let placeholder = child_env.get("DISCORD_BOT_TOKEN").expect("placeholder env"); - let ctx = L7EvalContext { - host: "gateway.example.test".into(), - port: 443, - policy_name: "route_api".into(), - binary_path: "/usr/bin/node".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: resolver.map(Arc::new), - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_with_route_selection( - &configs, - tunnel_engine, - &mut relay_client, - &mut relay_upstream, - &ctx, - ) - .await - }); - - app.write_all( - b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", - ) - .await - .unwrap(); - - let mut forwarded = [0u8; 1024]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut forwarded), - ) - .await - .expect("upgrade request should reach upstream") - .unwrap(); - let forwarded = String::from_utf8_lossy(&forwarded[..n]); - assert!(forwarded.contains("Upgrade: websocket\r\n")); - - upstream - .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", - ) - .await - .unwrap(); - - let mut response = [0u8; 1024]; - let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) - .await - .expect("client should receive upgrade response") - .unwrap(); - assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); - - let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); - app.write_all(&masked_text_frame(payload.as_bytes())) - .await - .unwrap(); - - let (masked, rewritten) = tokio::time::timeout( - std::time::Duration::from_secs(1), - read_text_frame(&mut upstream), - ) - .await - .expect("rewritten websocket text should reach upstream") - .unwrap(); - assert!(masked, "client-to-server frame must remain masked"); - assert_eq!(rewritten, r#"{"op":2,"d":{"token":"real-token"}}"#); - assert!(!rewritten.contains(placeholder)); - - drop(app); - drop(upstream); - let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; - } - - #[tokio::test] - async fn route_selected_graphql_websocket_rewrites_connection_init_credentials_after_upgrade() { - let data = r#" -network_policies: - route_api: - name: route_api - endpoints: - - host: gateway.example.test - port: 443 - path: "/graphql" - protocol: websocket - enforcement: enforce - rules: - - allow: - method: GET - path: "/graphql" - - allow: - operation_type: query - fields: [viewer] - websocket_credential_rewrite: true - binaries: - - { path: /usr/bin/node } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let tunnel_engine = engine - .clone_engine_for_tunnel(engine.current_generation()) - .unwrap(); - let configs = vec![L7EndpointConfig { - protocol: L7Protocol::Websocket, - path: "/graphql".into(), - tls: crate::l7::TlsMode::Auto, - enforcement: EnforcementMode::Enforce, - graphql_max_body_bytes: 0, - json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, - allow_encoded_slash: false, - websocket_credential_rewrite: true, - request_body_credential_rewrite: false, - websocket_graphql_policy: true, - }]; - let (child_env, resolver) = SecretResolver::from_provider_env( - std::iter::once(("T".to_string(), "real-token".to_string())).collect(), - ); - let placeholder = child_env.get("T").expect("placeholder env"); - let ctx = L7EvalContext { - host: "gateway.example.test".into(), - port: 443, - policy_name: "route_api".into(), - binary_path: "/usr/bin/node".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: resolver.map(Arc::new), - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_with_route_selection( - &configs, - tunnel_engine, - &mut relay_client, - &mut relay_upstream, - &ctx, - ) - .await - }); - - app.write_all( - b"GET /graphql HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", - ) - .await - .unwrap(); - - let mut forwarded = [0u8; 1024]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut forwarded), - ) - .await - .expect("upgrade request should reach upstream") - .unwrap(); - let forwarded = String::from_utf8_lossy(&forwarded[..n]); - assert!(forwarded.contains("GET /graphql HTTP/1.1")); - assert!(forwarded.contains("Upgrade: websocket\r\n")); - - upstream - .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", - ) - .await - .unwrap(); - - let mut response = [0u8; 1024]; - let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) - .await - .expect("client should receive upgrade response") - .unwrap(); - assert!(String::from_utf8_lossy(&response[..n]).contains("101 Switching Protocols")); - - let payload = format!( - r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# - ); - app.write_all(&masked_text_frame(payload.as_bytes())) - .await - .unwrap(); - - let (masked, rewritten) = tokio::time::timeout( - std::time::Duration::from_secs(1), - read_text_frame(&mut upstream), - ) - .await - .expect("rewritten GraphQL WebSocket control message should reach upstream") - .unwrap(); - assert!(masked, "client-to-server frame must remain masked"); - assert_eq!( - rewritten, - r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# - ); - assert!(!rewritten.contains(placeholder)); - - drop(app); - drop(upstream); - let _ = tokio::time::timeout(std::time::Duration::from_secs(1), relay).await; - } - - fn masked_text_frame(payload: &[u8]) -> Vec { - let mask = [0x11, 0x22, 0x33, 0x44]; - assert!( - payload.len() <= 125, - "test helper only supports small frames" - ); - let payload_len = u8::try_from(payload.len()).expect("small frame length"); - let mut frame = vec![0x81, 0x80 | payload_len]; - frame.extend_from_slice(&mask); - frame.extend( - payload - .iter() - .enumerate() - .map(|(idx, byte)| byte ^ mask[idx % 4]), - ); - frame - } - - async fn read_text_frame( - reader: &mut R, - ) -> std::io::Result<(bool, String)> { - let mut header = [0u8; 2]; - reader.read_exact(&mut header).await?; - assert_eq!(header[0] & 0x0f, 0x1, "expected text frame"); - let masked = header[1] & 0x80 != 0; - let payload_len = usize::from(header[1] & 0x7f); - assert!(payload_len <= 125, "test helper only supports small frames"); - let mut mask = [0u8; 4]; - if masked { - reader.read_exact(&mut mask).await?; - } - let mut payload = vec![0u8; payload_len]; - reader.read_exact(&mut payload).await?; - if masked { - for (idx, byte) in payload.iter_mut().enumerate() { - *byte ^= mask[idx % 4]; - } - } - Ok((masked, String::from_utf8(payload).expect("text payload"))) - } - - #[tokio::test] - async fn l7_relay_closes_keep_alive_tunnel_after_policy_generation_change() { - let initial_data = r#" -network_policies: - rest_api: - name: rest_api - endpoints: - - host: api.example.test - port: 8080 - protocol: rest - enforcement: enforce - rules: - - allow: - method: POST - path: "/write" - binaries: - - { path: /usr/bin/curl } -"#; - let reloaded_data = r#" -network_policies: - rest_api: - name: rest_api - endpoints: - - host: api.example.test - port: 8080 - protocol: rest - enforcement: enforce - rules: - - allow: - method: GET - path: "/write" - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, initial_data).unwrap(); - let input = NetworkInput { - host: "api.example.test".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let (endpoint_config, generation) = engine - .query_endpoint_config_with_generation(&input) - .unwrap(); - let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); - let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); - let ctx = L7EvalContext { - host: "api.example.test".into(), - port: 8080, - policy_name: "rest_api".into(), - binary_path: "/usr/bin/curl".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_with_inspection( - &config, - tunnel_engine, - &mut relay_client, - &mut relay_upstream, - &ctx, - ) - .await - }); - - app.write_all( - b"POST /write HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n", - ) - .await - .unwrap(); - - let mut first_upstream = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut first_upstream), - ) - .await - .expect("first request should reach upstream") - .unwrap(); - let first_upstream = String::from_utf8_lossy(&first_upstream[..n]); - assert!(first_upstream.starts_with("POST /write HTTP/1.1")); - - upstream - .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nOK") - .await - .unwrap(); - - let mut first_response = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - app.read(&mut first_response), - ) - .await - .expect("first response should reach client") - .unwrap(); - let first_response = String::from_utf8_lossy(&first_response[..n]); - assert!(first_response.contains("200 OK")); - - engine.reload(TEST_POLICY, reloaded_data).unwrap(); - app.write_all( - b"POST /write HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n", - ) - .await - .unwrap(); - - tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("relay should close stale tunnel") - .unwrap() - .unwrap(); - - let mut second_upstream = [0u8; 128]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut second_upstream), - ) - .await - .expect("upstream side should close") - .unwrap(); - assert_eq!(n, 0, "stale request must not be forwarded upstream"); - } - - #[tokio::test] - async fn passthrough_relay_closes_keep_alive_tunnel_after_policy_generation_change() { - let policy_data = "network_policies: {}\n"; - let engine = OpaEngine::from_strings(TEST_POLICY, policy_data).unwrap(); - let generation_guard = engine - .generation_guard(engine.current_generation()) - .unwrap(); - let ctx = L7EvalContext { - host: "api.example.test".into(), - port: 8080, - policy_name: "rest_api".into(), - binary_path: "/usr/bin/curl".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_passthrough_with_credentials( - &mut relay_client, - &mut relay_upstream, - &ctx, - &generation_guard, - ) - .await - }); - - app.write_all( - b"GET /first HTTP/1.1\r\nHost: api.example.test\r\nConnection: keep-alive\r\n\r\n", - ) - .await - .unwrap(); - - let mut first_upstream = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut first_upstream), - ) - .await - .expect("first passthrough request should reach upstream") - .unwrap(); - let first_upstream = String::from_utf8_lossy(&first_upstream[..n]); - assert!(first_upstream.starts_with("GET /first HTTP/1.1")); - - upstream - .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nOK") - .await - .unwrap(); - - let mut first_response = [0u8; 512]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - app.read(&mut first_response), - ) - .await - .expect("first passthrough response should reach client") - .unwrap(); - let first_response = String::from_utf8_lossy(&first_response[..n]); - assert!(first_response.contains("200 OK")); - - engine.reload(TEST_POLICY, policy_data).unwrap(); - app.write_all( - b"GET /second HTTP/1.1\r\nHost: api.example.test\r\nConnection: keep-alive\r\n\r\n", - ) - .await - .unwrap(); - - tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("passthrough relay should close stale tunnel") - .unwrap() - .unwrap(); - - let mut second_upstream = [0u8; 128]; - let n = tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read(&mut second_upstream), - ) - .await - .expect("upstream side should close") - .unwrap(); - assert_eq!( - n, 0, - "stale passthrough request must not be forwarded upstream" - ); - } - - #[tokio::test] - async fn jsonrpc_relay_denies_method_not_in_allow_list() { - let data = r" -network_policies: - mcp_api: - name: mcp_api - endpoints: - - host: mcp.example.test - port: 8000 - path: /mcp - protocol: json-rpc - enforcement: enforce - rules: - - allow: - rpc_method: initialize - binaries: - - { path: /usr/bin/python3 } -"; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "mcp.example.test".into(), - port: 8000, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let (endpoint_config, generation) = engine - .query_endpoint_config_with_generation(&input) - .unwrap(); - let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); - let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); - let ctx = L7EvalContext { - host: "mcp.example.test".into(), - port: 8000, - policy_name: "mcp_api".into(), - binary_path: "/usr/bin/python3".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - - let (mut app, mut relay_client) = tokio::io::duplex(8192); - let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); - let relay = tokio::spawn(async move { - relay_with_inspection( - &config, - tunnel_engine, - &mut relay_client, - &mut relay_upstream, - &ctx, - ) - .await - }); - - let body = - br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"list_repos"}}"#; - let request = format!( - "POST /mcp HTTP/1.1\r\nHost: mcp.example.test:8000\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", - body.len() - ); - app.write_all(request.as_bytes()).await.unwrap(); - app.write_all(body).await.unwrap(); - - let mut response = [0u8; 512]; - let n = tokio::time::timeout(std::time::Duration::from_secs(2), app.read(&mut response)) - .await - .expect("relay should respond without reaching upstream") - .unwrap(); - let response = String::from_utf8_lossy(&response[..n]); - assert!( - response.contains("403"), - "tools/call not in allow list must be denied with 403, got: {response:?}" - ); - - let mut upstream_buf = [0u8; 128]; - let n = tokio::time::timeout( - std::time::Duration::from_millis(100), - upstream.read(&mut upstream_buf), - ) - .await - .unwrap_or(Ok(0)) - .unwrap_or(0); - assert_eq!(n, 0, "denied request must not be forwarded to upstream"); - - drop(app); - tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("relay should complete") - .unwrap() - .unwrap(); - } -} diff --git a/crates/openshell-sandbox/src/l7/websocket.rs b/crates/openshell-sandbox/src/l7/websocket.rs deleted file mode 100644 index 70c11f330..000000000 --- a/crates/openshell-sandbox/src/l7/websocket.rs +++ /dev/null @@ -1,1943 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! WebSocket relay for opt-in credential placeholder rewriting and message policy. -//! -//! The relay parses only client-to-server frames. Server-to-client bytes stay -//! raw passthrough so inspection and rewriting cannot expose response payloads. - -use crate::l7::relay::{L7EvalContext, evaluate_l7_request}; -use crate::l7::{EnforcementMode, L7RequestInfo}; -use crate::opa::TunnelPolicyEngine; -use crate::secrets::SecretResolver; -use flate2::{Compress, Compression, Decompress, FlushCompress, FlushDecompress, Status}; -use miette::{IntoDiagnostic, Result, miette}; -use openshell_ocsf::{ - ActionId, ActivityId, DispositionId, Endpoint, NetworkActivityBuilder, SeverityId, StatusId, - ocsf_emit, -}; -use std::collections::HashMap; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -const MAX_TEXT_MESSAGE_BYTES: usize = 1024 * 1024; -const MAX_RAW_FRAME_PAYLOAD_BYTES: u64 = 16 * 1024 * 1024; -const COPY_BUF_SIZE: usize = 8192; -const OPCODE_CONTINUATION: u8 = 0x0; -const OPCODE_TEXT: u8 = 0x1; -const OPCODE_BINARY: u8 = 0x2; -const OPCODE_CLOSE: u8 = 0x8; -const OPCODE_PING: u8 = 0x9; -const OPCODE_PONG: u8 = 0xA; - -#[derive(Debug)] -struct FrameHeader { - fin: bool, - rsv: u8, - opcode: u8, - masked: bool, - payload_len: u64, - mask_key: Option<[u8; 4]>, - raw_header: Vec, -} - -#[derive(Debug)] -enum FragmentState { - None, - Text { payload: Vec, compressed: bool }, - Binary, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(super) enum WebSocketCompression { - None, - PermessageDeflate, -} - -pub(super) struct InspectionOptions<'a> { - pub(super) engine: &'a TunnelPolicyEngine, - pub(super) ctx: &'a L7EvalContext, - pub(super) enforcement: EnforcementMode, - pub(super) target: String, - pub(super) query_params: HashMap>, - pub(super) graphql_policy: bool, -} - -pub(super) struct RelayOptions<'a> { - pub(super) policy_name: &'a str, - pub(super) resolver: Option<&'a SecretResolver>, - pub(super) inspector: Option>, - pub(super) compression: WebSocketCompression, -} - -/// Relay an upgraded WebSocket connection with optional client text inspection, -/// credential rewriting, and strict permessage-deflate handling. -pub(super) async fn relay_with_options( - client: &mut C, - upstream: &mut U, - overflow: Vec, - host: &str, - port: u16, - options: RelayOptions<'_>, -) -> Result<()> -where - C: AsyncRead + AsyncWrite + Unpin + Send, - U: AsyncRead + AsyncWrite + Unpin + Send, -{ - let (mut client_read, mut client_write) = tokio::io::split(client); - let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream); - - if !overflow.is_empty() { - client_write.write_all(&overflow).await.into_diagnostic()?; - client_write.flush().await.into_diagnostic()?; - } - - let client_to_server = - relay_client_to_server(&mut client_read, &mut upstream_write, host, port, &options); - let server_to_client = async { - tokio::io::copy(&mut upstream_read, &mut client_write) - .await - .into_diagnostic()?; - client_write.flush().await.into_diagnostic()?; - Ok::<(), miette::Report>(()) - }; - - let result = tokio::select! { - result = client_to_server => result, - result = server_to_client => result, - }; - let _ = upstream_write.shutdown().await; - let _ = client_write.shutdown().await; - result -} - -async fn relay_client_to_server( - reader: &mut R, - writer: &mut W, - host: &str, - port: u16, - options: &RelayOptions<'_>, -) -> Result<()> -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - let mut fragments = FragmentState::None; - let mut close_seen = false; - - loop { - let Some(frame) = read_frame_header(reader).await.inspect_err(|e| { - emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(e)); - })? - else { - writer.shutdown().await.into_diagnostic()?; - return Ok(()); - }; - - if close_seen { - let e = miette!("websocket frame received after close frame"); - emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); - return Err(e); - } - - if let Err(e) = validate_frame_header(&frame, &fragments, options.compression) { - emit_protocol_failure(host, port, options.policy_name, protocol_failure_class(&e)); - return Err(e); - } - - match frame.opcode { - OPCODE_TEXT => { - let payload = read_masked_payload(reader, &frame).await.inspect_err(|e| { - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(e), - ); - })?; - let compressed = frame.rsv == 0x40; - if frame.fin { - relay_text_payload( - writer, &frame, payload, false, compressed, host, port, options, - ) - .await - .inspect_err(|e| { - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(e), - ); - })?; - } else { - fragments = FragmentState::Text { - payload, - compressed, - }; - } - } - OPCODE_CONTINUATION => match &mut fragments { - FragmentState::Text { - payload, - compressed, - } => { - let next = read_masked_payload(reader, &frame).await.inspect_err(|e| { - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(e), - ); - })?; - if let Err(e) = append_text_fragment(payload, next) { - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(&e), - ); - return Err(e); - } - if frame.fin { - let complete = std::mem::take(payload); - let was_compressed = *compressed; - fragments = FragmentState::None; - relay_text_payload( - writer, - &frame, - complete, - true, - was_compressed, - host, - port, - options, - ) - .await - .inspect_err(|e| { - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(e), - ); - })?; - } - } - FragmentState::Binary => { - copy_raw_frame_payload(reader, writer, &frame) - .await - .inspect_err(|e| { - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(e), - ); - })?; - if frame.fin { - fragments = FragmentState::None; - } - } - FragmentState::None => { - let e = - miette!("websocket continuation frame without active fragmented message"); - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(&e), - ); - return Err(e); - } - }, - OPCODE_BINARY => { - if !frame.fin { - fragments = FragmentState::Binary; - } - copy_raw_frame_payload(reader, writer, &frame) - .await - .inspect_err(|e| { - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(e), - ); - })?; - } - OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG => { - relay_control_frame(reader, writer, &frame) - .await - .inspect_err(|e| { - emit_protocol_failure( - host, - port, - options.policy_name, - protocol_failure_class(e), - ); - })?; - if frame.opcode == OPCODE_CLOSE { - close_seen = true; - } - } - _ => unreachable!("validated opcode"), - } - } -} - -async fn read_frame_header(reader: &mut R) -> Result> { - let first = match reader.read_u8().await { - Ok(byte) => byte, - Err(e) - if matches!( - e.kind(), - std::io::ErrorKind::UnexpectedEof - | std::io::ErrorKind::ConnectionReset - | std::io::ErrorKind::BrokenPipe - ) => - { - return Ok(None); - } - Err(e) => return Err(miette!("{e}")), - }; - let second = reader - .read_u8() - .await - .map_err(|e| miette!("malformed websocket frame header: {e}"))?; - - let mut raw_header = vec![first, second]; - let len_code = second & 0x7F; - let payload_len = match len_code { - 0..=125 => u64::from(len_code), - 126 => { - let mut bytes = [0u8; 2]; - reader - .read_exact(&mut bytes) - .await - .map_err(|e| miette!("malformed websocket extended length: {e}"))?; - raw_header.extend_from_slice(&bytes); - let len = u64::from(u16::from_be_bytes(bytes)); - if len < 126 { - return Err(miette!( - "websocket frame uses non-minimal 16-bit extended length" - )); - } - len - } - 127 => { - let mut bytes = [0u8; 8]; - reader - .read_exact(&mut bytes) - .await - .map_err(|e| miette!("malformed websocket extended length: {e}"))?; - if bytes[0] & 0x80 != 0 { - return Err(miette!("websocket frame uses non-canonical 64-bit length")); - } - raw_header.extend_from_slice(&bytes); - let len = u64::from_be_bytes(bytes); - if u16::try_from(len).is_ok() { - return Err(miette!( - "websocket frame uses non-minimal 64-bit extended length" - )); - } - len - } - _ => unreachable!("7-bit length code"), - }; - - let masked = second & 0x80 != 0; - let mask_key = if masked { - let mut key = [0u8; 4]; - reader - .read_exact(&mut key) - .await - .map_err(|e| miette!("malformed websocket mask key: {e}"))?; - raw_header.extend_from_slice(&key); - Some(key) - } else { - None - }; - - Ok(Some(FrameHeader { - fin: first & 0x80 != 0, - rsv: first & 0x70, - opcode: first & 0x0F, - masked, - payload_len, - mask_key, - raw_header, - })) -} - -fn validate_frame_header( - frame: &FrameHeader, - fragments: &FragmentState, - compression: WebSocketCompression, -) -> Result<()> { - if !valid_rsv_bits(frame, fragments, compression) { - return Err(miette!( - "websocket frame has unsupported RSV bits or extension state" - )); - } - if !frame.masked { - return Err(miette!("websocket client frame is not masked")); - } - if !matches!( - frame.opcode, - OPCODE_CONTINUATION - | OPCODE_TEXT - | OPCODE_BINARY - | OPCODE_CLOSE - | OPCODE_PING - | OPCODE_PONG - ) { - return Err(miette!("websocket frame uses reserved opcode")); - } - if matches!(frame.opcode, OPCODE_CLOSE | OPCODE_PING | OPCODE_PONG) { - if !frame.fin { - return Err(miette!("websocket control frame is fragmented")); - } - if frame.payload_len > 125 { - return Err(miette!("websocket control frame exceeds 125 bytes")); - } - } - if matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) - && !matches!(fragments, FragmentState::None) - { - return Err(miette!( - "websocket data frame started before previous fragmented message completed" - )); - } - if matches!(frame.opcode, OPCODE_CONTINUATION) && matches!(fragments, FragmentState::None) { - return Err(miette!( - "websocket continuation frame without active fragmented message" - )); - } - if (frame.opcode == OPCODE_BINARY - || (frame.opcode == OPCODE_CONTINUATION && matches!(fragments, FragmentState::Binary))) - && frame.payload_len > MAX_RAW_FRAME_PAYLOAD_BYTES - { - return Err(miette!( - "websocket binary frame exceeds {MAX_RAW_FRAME_PAYLOAD_BYTES} byte relay limit" - )); - } - Ok(()) -} - -fn valid_rsv_bits( - frame: &FrameHeader, - fragments: &FragmentState, - compression: WebSocketCompression, -) -> bool { - if frame.rsv == 0 { - return true; - } - if compression != WebSocketCompression::PermessageDeflate || frame.rsv != 0x40 { - return false; - } - matches!(fragments, FragmentState::None) && matches!(frame.opcode, OPCODE_TEXT | OPCODE_BINARY) -} - -async fn read_masked_payload( - reader: &mut R, - frame: &FrameHeader, -) -> Result> { - let payload_len = usize::try_from(frame.payload_len) - .map_err(|_| miette!("websocket text frame is too large to buffer"))?; - if payload_len > MAX_TEXT_MESSAGE_BYTES { - return Err(miette!( - "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" - )); - } - let mut payload = vec![0u8; payload_len]; - reader - .read_exact(&mut payload) - .await - .map_err(|e| miette!("malformed websocket payload: {e}"))?; - let mask_key = frame - .mask_key - .ok_or_else(|| miette!("websocket client frame is not masked"))?; - apply_mask(&mut payload, mask_key); - Ok(payload) -} - -fn append_text_fragment(buffer: &mut Vec, next: Vec) -> Result<()> { - let new_len = buffer - .len() - .checked_add(next.len()) - .ok_or_else(|| miette!("websocket text message length overflow"))?; - if new_len > MAX_TEXT_MESSAGE_BYTES { - return Err(miette!( - "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" - )); - } - buffer.extend_from_slice(&next); - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -async fn relay_text_payload( - writer: &mut W, - frame: &FrameHeader, - payload: Vec, - force_reframe: bool, - compressed: bool, - host: &str, - port: u16, - options: &RelayOptions<'_>, -) -> Result<()> { - let message_payload = if compressed { - decompress_permessage_deflate(&payload)? - } else { - payload - }; - let mut text = String::from_utf8(message_payload) - .map_err(|_| miette!("websocket text message is not valid UTF-8"))?; - let replacements = if let Some(resolver) = options.resolver { - resolver - .rewrite_websocket_text_placeholders(&mut text) - .map_err(|_| miette!("websocket credential placeholder resolution failed"))? - } else { - 0 - }; - - if let Some(inspector) = options.inspector.as_ref() { - inspect_websocket_text_message(host, port, options.policy_name, inspector, &text)?; - } - - if replacements == 0 && !force_reframe && !compressed { - writer - .write_all(&frame.raw_header) - .await - .into_diagnostic()?; - let mut payload = text.into_bytes(); - let mask_key = frame - .mask_key - .ok_or_else(|| miette!("websocket client frame is not masked"))?; - apply_mask(&mut payload, mask_key); - writer.write_all(&payload).await.into_diagnostic()?; - writer.flush().await.into_diagnostic()?; - return Ok(()); - } - - if replacements > 0 { - emit_rewrite_event(host, port, options.policy_name, replacements); - } - if compressed { - let compressed_payload = compress_permessage_deflate(text.as_bytes())?; - return write_masked_frame_with_rsv(writer, OPCODE_TEXT, 0x40, &compressed_payload).await; - } - write_masked_frame(writer, OPCODE_TEXT, text.as_bytes()).await -} - -fn inspect_websocket_text_message( - host: &str, - port: u16, - policy_name: &str, - inspector: &InspectionOptions<'_>, - text: &str, -) -> Result<()> { - if inspector.graphql_policy { - return inspect_graphql_websocket_message(host, port, policy_name, inspector, text); - } - - let request_info = L7RequestInfo { - action: "WEBSOCKET_TEXT".to_string(), - target: inspector.target.clone(), - query_params: inspector.query_params.clone(), - graphql: None, - jsonrpc: None, - }; - let (allowed, reason) = evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)?; - let decision = match (allowed, inspector.enforcement) { - (true, _) => "allow", - (false, EnforcementMode::Audit) => "audit", - (false, EnforcementMode::Enforce) => "deny", - }; - emit_websocket_l7_event( - host, - port, - policy_name, - &request_info, - decision, - &reason, - None, - ); - if !allowed && inspector.enforcement == EnforcementMode::Enforce { - return Err(miette!("websocket text message denied by policy")); - } - Ok(()) -} - -fn inspect_graphql_websocket_message( - host: &str, - port: u16, - policy_name: &str, - inspector: &InspectionOptions<'_>, - text: &str, -) -> Result<()> { - match classify_graphql_websocket_message(text) { - GraphqlWebSocketMessage::Control { message_type } => { - let request_info = L7RequestInfo { - action: "WEBSOCKET_CONTROL".to_string(), - target: inspector.target.clone(), - query_params: inspector.query_params.clone(), - graphql: None, - jsonrpc: None, - }; - emit_websocket_l7_event( - host, - port, - policy_name, - &request_info, - "allow", - &format!("GraphQL WebSocket control message {message_type}"), - None, - ); - Ok(()) - } - GraphqlWebSocketMessage::Operation { - message_type, - graphql, - } => { - let request_info = L7RequestInfo { - action: "WEBSOCKET_TEXT".to_string(), - target: inspector.target.clone(), - query_params: inspector.query_params.clone(), - graphql: Some(graphql.clone()), - jsonrpc: None, - }; - let parse_error_reason = graphql - .error - .as_deref() - .map(|error| format!("GraphQL WebSocket message rejected: {error}")); - let force_deny = parse_error_reason.is_some(); - let (allowed, reason) = if let Some(reason) = parse_error_reason { - (false, reason) - } else { - evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)? - }; - let decision = match (allowed, inspector.enforcement) { - (_, _) if force_deny => "deny", - (true, _) => "allow", - (false, EnforcementMode::Audit) => "audit", - (false, EnforcementMode::Enforce) => "deny", - }; - let reason = format!("graphql_ws_type={message_type} {reason}"); - emit_websocket_l7_event( - host, - port, - policy_name, - &request_info, - decision, - &reason, - Some(&graphql), - ); - if (!allowed && inspector.enforcement == EnforcementMode::Enforce) || force_deny { - return Err(miette!("websocket GraphQL message denied by policy")); - } - Ok(()) - } - } -} - -#[derive(Debug)] -enum GraphqlWebSocketMessage { - Control { - message_type: String, - }, - Operation { - message_type: String, - graphql: crate::l7::graphql::GraphqlRequestInfo, - }, -} - -fn classify_graphql_websocket_message(text: &str) -> GraphqlWebSocketMessage { - let value = match serde_json::from_str::(text) { - Ok(value) => value, - Err(err) => { - return GraphqlWebSocketMessage::Operation { - message_type: "unknown".to_string(), - graphql: graphql_error(format!( - "GraphQL WebSocket message is not valid JSON: {err}" - )), - }; - } - }; - let Some(obj) = value.as_object() else { - return GraphqlWebSocketMessage::Operation { - message_type: "unknown".to_string(), - graphql: graphql_error("GraphQL WebSocket message must be a JSON object"), - }; - }; - let Some(message_type) = obj.get("type").and_then(serde_json::Value::as_str) else { - return GraphqlWebSocketMessage::Operation { - message_type: "unknown".to_string(), - graphql: graphql_error("GraphQL WebSocket message missing string type"), - }; - }; - - match message_type { - "subscribe" | "start" => { - if obj - .get("id") - .and_then(serde_json::Value::as_str) - .is_none_or(str::is_empty) - { - return GraphqlWebSocketMessage::Operation { - message_type: message_type.to_string(), - graphql: graphql_error( - "GraphQL WebSocket operation message missing non-empty id", - ), - }; - } - let Some(payload) = obj.get("payload").filter(|value| value.is_object()) else { - return GraphqlWebSocketMessage::Operation { - message_type: message_type.to_string(), - graphql: graphql_error( - "GraphQL WebSocket operation message missing object payload", - ), - }; - }; - GraphqlWebSocketMessage::Operation { - message_type: message_type.to_string(), - graphql: crate::l7::graphql::classify_json_envelope_value(payload), - } - } - "connection_init" | "connection_terminate" | "ping" | "pong" | "complete" | "stop" => { - GraphqlWebSocketMessage::Control { - message_type: message_type.to_string(), - } - } - _ => GraphqlWebSocketMessage::Operation { - message_type: message_type.to_string(), - graphql: graphql_error(format!( - "unsupported GraphQL WebSocket client message type {message_type:?}" - )), - }, - } -} - -fn graphql_error(message: impl Into) -> crate::l7::graphql::GraphqlRequestInfo { - crate::l7::graphql::GraphqlRequestInfo { - operations: Vec::new(), - error: Some(message.into()), - } -} - -async fn relay_control_frame( - reader: &mut R, - writer: &mut W, - frame: &FrameHeader, -) -> Result<()> -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - let raw_payload_len = usize::try_from(frame.payload_len) - .map_err(|_| miette!("websocket control frame payload length overflow"))?; - let mut raw_payload = vec![0u8; raw_payload_len]; - reader - .read_exact(&mut raw_payload) - .await - .map_err(|e| miette!("malformed websocket control payload: {e}"))?; - - if frame.opcode == OPCODE_CLOSE { - let mut payload = raw_payload.clone(); - let mask_key = frame - .mask_key - .ok_or_else(|| miette!("websocket client frame is not masked"))?; - apply_mask(&mut payload, mask_key); - validate_close_payload(&payload)?; - } - - writer - .write_all(&frame.raw_header) - .await - .into_diagnostic()?; - writer.write_all(&raw_payload).await.into_diagnostic()?; - writer.flush().await.into_diagnostic()?; - Ok(()) -} - -fn validate_close_payload(payload: &[u8]) -> Result<()> { - if payload.len() == 1 { - return Err(miette!( - "websocket close frame payload cannot be exactly one byte" - )); - } - if payload.len() < 2 { - return Ok(()); - } - - let code = u16::from_be_bytes([payload[0], payload[1]]); - if !valid_close_code(code) { - return Err(miette!("websocket close frame uses invalid close code")); - } - if std::str::from_utf8(&payload[2..]).is_err() { - return Err(miette!("websocket close frame reason is not valid UTF-8")); - } - Ok(()) -} - -fn valid_close_code(code: u16) -> bool { - (matches!(code, 1000..=1014) && !matches!(code, 1004..=1006)) || (3000..=4999).contains(&code) -} - -async fn copy_raw_frame_payload( - reader: &mut R, - writer: &mut W, - frame: &FrameHeader, -) -> Result<()> -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - writer - .write_all(&frame.raw_header) - .await - .into_diagnostic()?; - let mut remaining = frame.payload_len; - let mut buf = [0u8; COPY_BUF_SIZE]; - while remaining > 0 { - let to_read = usize::try_from(remaining) - .unwrap_or(buf.len()) - .min(buf.len()); - let n = reader.read(&mut buf[..to_read]).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("websocket payload ended before declared length")); - } - writer.write_all(&buf[..n]).await.into_diagnostic()?; - remaining -= n as u64; - } - writer.flush().await.into_diagnostic()?; - Ok(()) -} - -async fn write_masked_frame( - writer: &mut W, - opcode: u8, - payload: &[u8], -) -> Result<()> { - write_masked_frame_with_rsv(writer, opcode, 0, payload).await -} - -async fn write_masked_frame_with_rsv( - writer: &mut W, - opcode: u8, - rsv: u8, - payload: &[u8], -) -> Result<()> { - let mut header = Vec::with_capacity(14); - header.push(0x80 | rsv | opcode); - match payload.len() { - 0..=125 => header.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), - 126..=65_535 => { - header.push(0x80 | 0x7e); - header.extend_from_slice( - &u16::try_from(payload.len()) - .expect("payload <= 65535") - .to_be_bytes(), - ); - } - _ => { - header.push(0x80 | 127); - header.extend_from_slice(&(payload.len() as u64).to_be_bytes()); - } - } - let mask_key = new_mask_key(); - header.extend_from_slice(&mask_key); - - let mut masked = payload.to_vec(); - apply_mask(&mut masked, mask_key); - writer.write_all(&header).await.into_diagnostic()?; - writer.write_all(&masked).await.into_diagnostic()?; - writer.flush().await.into_diagnostic()?; - Ok(()) -} - -fn decompress_permessage_deflate(payload: &[u8]) -> Result> { - let mut decoder = Decompress::new(false); - let mut input = Vec::with_capacity(payload.len() + 4); - input.extend_from_slice(payload); - input.extend_from_slice(&[0x00, 0x00, 0xff, 0xff]); - let mut out = Vec::with_capacity(payload.len().saturating_mul(2).min(MAX_TEXT_MESSAGE_BYTES)); - let mut input_pos = 0usize; - let mut scratch = [0u8; COPY_BUF_SIZE]; - loop { - let before_in = decoder.total_in(); - let before_out = decoder.total_out(); - let status = decoder - .decompress(&input[input_pos..], &mut scratch, FlushDecompress::Sync) - .map_err(|e| miette!("websocket permessage-deflate decompression failed: {e}"))?; - let read = usize::try_from(decoder.total_in() - before_in) - .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; - let written = usize::try_from(decoder.total_out() - before_out) - .map_err(|_| miette!("websocket permessage-deflate output length overflow"))?; - input_pos = input_pos - .checked_add(read) - .ok_or_else(|| miette!("websocket permessage-deflate input length overflow"))?; - if out.len().saturating_add(written) > MAX_TEXT_MESSAGE_BYTES { - return Err(miette!( - "websocket text message exceeds {MAX_TEXT_MESSAGE_BYTES} byte limit" - )); - } - out.extend_from_slice(&scratch[..written]); - if matches!(status, Status::StreamEnd) { - break; - } - if input_pos >= input.len() && written < scratch.len() { - break; - } - if read == 0 && written == 0 { - return Err(miette!( - "websocket permessage-deflate decompression did not make progress" - )); - } - } - Ok(out) -} - -fn compress_permessage_deflate(payload: &[u8]) -> Result> { - let mut compressor = Compress::new(Compression::fast(), false); - let expansion = payload.len() / 16; - let mut out = Vec::with_capacity(payload.len().saturating_add(expansion).saturating_add(128)); - loop { - let consumed = usize::try_from(compressor.total_in()) - .map_err(|_| miette!("websocket permessage-deflate input length overflow"))?; - if consumed >= payload.len() { - break; - } - let before_in = compressor.total_in(); - let before_out = compressor.total_out(); - let status = compressor - .compress_vec(&payload[consumed..], &mut out, FlushCompress::None) - .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; - if matches!(status, Status::BufError) - || (compressor.total_in() == before_in && compressor.total_out() == before_out) - { - out.reserve(out.capacity().max(1024)); - } - } - loop { - out.reserve(64); - let before_out = compressor.total_out(); - compressor - .compress_vec(&[], &mut out, FlushCompress::Sync) - .map_err(|e| miette!("websocket permessage-deflate compression failed: {e}"))?; - if out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { - break; - } - if compressor.total_out() == before_out { - out.reserve(out.capacity().max(1024)); - } - } - if !out.ends_with(&[0x00, 0x00, 0xff, 0xff]) { - return Err(miette!( - "websocket permessage-deflate compression missing sync marker" - )); - } - out.truncate(out.len() - 4); - Ok(out) -} - -fn new_mask_key() -> [u8; 4] { - let bytes = uuid::Uuid::new_v4().into_bytes(); - [bytes[0], bytes[1], bytes[2], bytes[3]] -} - -fn apply_mask(payload: &mut [u8], mask_key: [u8; 4]) { - for (i, byte) in payload.iter_mut().enumerate() { - *byte ^= mask_key[i % 4]; - } -} - -fn emit_rewrite_event(host: &str, port: u16, policy_name: &str, replacements: usize) { - let policy_name = if policy_name.is_empty() { - "-" - } else { - policy_name - }; - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Allowed) - .disposition(DispositionId::Allowed) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .dst_endpoint(Endpoint::from_domain(host, port)) - .firewall_rule(policy_name, "l7-websocket") - .message(rewrite_event_message(host, port, replacements)) - .build(); - ocsf_emit!(event); -} - -fn rewrite_event_message(host: &str, port: u16, replacements: usize) -> String { - format!( - "WEBSOCKET_CREDENTIAL_REWRITE rewrote client text message [host:{host} port:{port} replacements:{replacements}]" - ) -} - -fn emit_websocket_l7_event( - host: &str, - port: u16, - policy_name: &str, - request_info: &L7RequestInfo, - decision: &str, - reason: &str, - graphql: Option<&crate::l7::graphql::GraphqlRequestInfo>, -) { - let policy_name = if policy_name.is_empty() { - "-" - } else { - policy_name - }; - let (action_id, disposition_id, severity) = match decision { - "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), - "allow" | "audit" => ( - ActionId::Allowed, - DispositionId::Allowed, - SeverityId::Informational, - ), - _ => ( - ActionId::Other, - DispositionId::Other, - SeverityId::Informational, - ), - }; - let summary = graphql.map(graphql_log_summary).unwrap_or_default(); - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(action_id) - .disposition(disposition_id) - .severity(severity) - .status(StatusId::Success) - .dst_endpoint(Endpoint::from_domain(host, port)) - .firewall_rule(policy_name, "l7-websocket") - .message(format!( - "WEBSOCKET_L7_REQUEST {decision} {} {host}:{port}{}{} reason={reason}", - request_info.action, request_info.target, summary - )) - .build(); - ocsf_emit!(event); -} - -fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String { - if let Some(error) = info.error.as_deref() { - return format!(" graphql_error={error:?}"); - } - let ops: Vec = info - .operations - .iter() - .map(|op| { - let name = op.operation_name.as_deref().unwrap_or("-"); - let fields = if op.fields.is_empty() { - "-".to_string() - } else { - op.fields.join(",") - }; - let persisted = op - .persisted_query_hash - .as_deref() - .or(op.persisted_query_id.as_deref()) - .unwrap_or("-"); - format!( - "type={} name={} fields={} persisted={}", - op.operation_type, name, fields, persisted - ) - }) - .collect(); - format!(" graphql_ops={}", ops.join(";")) -} - -fn protocol_failure_class(error: &miette::Report) -> &'static str { - let msg = error.to_string().to_ascii_lowercase(); - if msg.contains("credential") { - "credential_resolution_failed" - } else if msg.contains("utf-8") { - "invalid_utf8" - } else if msg.contains("close frame") || msg.contains("after close") { - "invalid_close_frame" - } else if msg.contains("control frame") { - "invalid_control_frame" - } else if msg.contains("length") - || msg.contains("too large") - || msg.contains("exceeds") - || msg.contains("overflow") - { - "invalid_length" - } else if msg.contains("continuation") || msg.contains("fragmented") { - "invalid_fragmentation" - } else if msg.contains("reserved opcode") { - "reserved_opcode" - } else if msg.contains("not masked") { - "unmasked_client_frame" - } else if msg.contains("rsv") { - "rsv_bits" - } else if msg.contains("malformed") { - "malformed_frame" - } else { - "protocol_error" - } -} - -fn emit_protocol_failure(host: &str, port: u16, policy_name: &str, failure_class: &str) { - let policy_name = if policy_name.is_empty() { - "-" - } else { - policy_name - }; - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(host, port)) - .firewall_rule(policy_name, "l7-websocket") - .message(protocol_failure_message(host, port)) - .status_detail(failure_class) - .build(); - ocsf_emit!(event); -} - -fn protocol_failure_message(host: &str, port: u16) -> String { - format!("WEBSOCKET_CREDENTIAL_REWRITE closed ambiguous client frame [host:{host} port:{port}]") -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::l7::relay::L7EvalContext; - use crate::opa::{NetworkInput, OpaEngine}; - use crate::secrets::SecretResolver; - use std::path::PathBuf; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - const TEST_POLICY: &str = include_str!("../../data/sandbox-policy.rego"); - const GRAPHQL_WS_POLICY: &str = r#" -network_policies: - graphql_ws: - name: graphql_ws - endpoints: - - host: realtime.graphql.test - port: 443 - path: "/graphql" - protocol: websocket - enforcement: enforce - rules: - - allow: - method: GET - path: "/graphql" - - allow: - operation_type: query - fields: [viewer] - - allow: - operation_type: subscription - fields: [messageAdded] - binaries: - - { path: /usr/bin/node } -"#; - - fn resolver() -> (HashMap, SecretResolver) { - let (child_env, resolver) = SecretResolver::from_provider_env( - std::iter::once(("DISCORD_BOT_TOKEN".to_string(), "real-token".to_string())).collect(), - ); - (child_env, resolver.expect("resolver")) - } - - fn masked_frame(fin: bool, opcode: u8, payload: &[u8]) -> Vec { - masked_frame_with_rsv(fin, opcode, 0, payload) - } - - fn masked_frame_with_rsv(fin: bool, opcode: u8, rsv: u8, payload: &[u8]) -> Vec { - let mask_key = [0x37, 0xfa, 0x21, 0x3d]; - let mut frame = Vec::new(); - frame.push((if fin { 0x80 } else { 0 }) | rsv | opcode); - match payload.len() { - 0..=125 => frame.push(0x80 | u8::try_from(payload.len()).expect("payload <= 125")), - 126..=65_535 => { - frame.push(0x80 | 0x7e); - frame.extend_from_slice( - &u16::try_from(payload.len()) - .expect("payload <= 65535") - .to_be_bytes(), - ); - } - _ => { - frame.push(0x80 | 127); - frame.extend_from_slice(&(payload.len() as u64).to_be_bytes()); - } - } - frame.extend_from_slice(&mask_key); - for (i, byte) in payload.iter().enumerate() { - frame.push(byte ^ mask_key[i % 4]); - } - frame - } - - fn unmasked_frame(opcode: u8, payload: &[u8]) -> Vec { - let mut frame = Vec::new(); - frame.push(0x80 | opcode); - frame.push(u8::try_from(payload.len()).expect("test payload fits in one byte")); - frame.extend_from_slice(payload); - frame - } - - fn masked_frame_with_declared_len(opcode: u8, declared_len: u64) -> Vec { - let mut frame = Vec::new(); - frame.push(0x80 | opcode); - frame.push(0x80 | 127); - frame.extend_from_slice(&declared_len.to_be_bytes()); - frame.extend_from_slice(&[0x37, 0xfa, 0x21, 0x3d]); - frame - } - - fn masked_frame_with_non_minimal_16_bit_len(opcode: u8, payload: &[u8]) -> Vec { - let mask_key = [0x37, 0xfa, 0x21, 0x3d]; - let mut frame = Vec::new(); - frame.push(0x80 | opcode); - frame.push(0x80 | 0x7e); - frame.extend_from_slice( - &u16::try_from(payload.len()) - .expect("test payload fits u16") - .to_be_bytes(), - ); - frame.extend_from_slice(&mask_key); - for (i, byte) in payload.iter().enumerate() { - frame.push(byte ^ mask_key[i % 4]); - } - frame - } - - fn close_payload(code: u16, reason: &[u8]) -> Vec { - let mut payload = Vec::with_capacity(2 + reason.len()); - payload.extend_from_slice(&code.to_be_bytes()); - payload.extend_from_slice(reason); - payload - } - - async fn run_client_to_server(input: Vec) -> Result> { - let (_, resolver) = resolver(); - let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); - let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); - - client_write.write_all(&input).await.unwrap(); - drop(client_write); - - let options = RelayOptions { - policy_name: "test-policy", - resolver: Some(&resolver), - inspector: None, - compression: WebSocketCompression::None, - }; - let result = relay_client_to_server( - &mut relay_read, - &mut relay_write, - "gateway.example.test", - 443, - &options, - ) - .await; - drop(relay_write); - - let mut output = Vec::new(); - upstream_read.read_to_end(&mut output).await.unwrap(); - result.map(|()| output) - } - - async fn run_client_to_server_with_graphql_policy( - input: Vec, - resolver: Option<&SecretResolver>, - ) -> Result> { - let engine = OpaEngine::from_strings(TEST_POLICY, GRAPHQL_WS_POLICY) - .expect("GraphQL WebSocket policy should load"); - let network_input = NetworkInput { - host: "realtime.graphql.test".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let generation = engine - .evaluate_network_action_with_generation(&network_input) - .expect("network action should evaluate") - .1; - let tunnel_engine = engine - .clone_engine_for_tunnel(generation) - .expect("tunnel engine"); - let ctx = L7EvalContext { - host: "realtime.graphql.test".into(), - port: 443, - policy_name: "graphql_ws".into(), - binary_path: "/usr/bin/node".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); - let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); - - client_write.write_all(&input).await.unwrap(); - drop(client_write); - - let options = RelayOptions { - policy_name: "graphql_ws", - resolver, - inspector: Some(InspectionOptions { - engine: &tunnel_engine, - ctx: &ctx, - enforcement: EnforcementMode::Enforce, - target: "/graphql".to_string(), - query_params: HashMap::new(), - graphql_policy: true, - }), - compression: WebSocketCompression::None, - }; - let result = relay_client_to_server( - &mut relay_read, - &mut relay_write, - "realtime.graphql.test", - 443, - &options, - ) - .await; - drop(relay_write); - - let mut output = Vec::new(); - upstream_read.read_to_end(&mut output).await.unwrap(); - result.map(|()| output) - } - - async fn run_client_to_server_compressed(input: Vec) -> Result> { - let (_, resolver) = resolver(); - let (mut client_write, mut relay_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); - let (mut relay_write, mut upstream_read) = tokio::io::duplex(MAX_TEXT_MESSAGE_BYTES + 1024); - - client_write.write_all(&input).await.unwrap(); - drop(client_write); - - let options = RelayOptions { - policy_name: "test-policy", - resolver: Some(&resolver), - inspector: None, - compression: WebSocketCompression::PermessageDeflate, - }; - let result = relay_client_to_server( - &mut relay_read, - &mut relay_write, - "gateway.example.test", - 443, - &options, - ) - .await; - drop(relay_write); - - let mut output = Vec::new(); - upstream_read.read_to_end(&mut output).await.unwrap(); - result.map(|()| output) - } - - fn decode_masked_text_frame(frame: &[u8]) -> String { - assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); - assert_ne!(frame[1] & 0x80, 0); - String::from_utf8(decode_masked_payload(frame)).unwrap() - } - - fn decode_masked_payload(frame: &[u8]) -> Vec { - assert_ne!(frame[1] & 0x80, 0); - let len_code = frame[1] & 0x7F; - let (payload_len, mask_offset) = match len_code { - 0..=125 => (usize::from(len_code), 2), - 126 => (usize::from(u16::from_be_bytes([frame[2], frame[3]])), 4), - 127 => { - let len = u64::from_be_bytes(frame[2..10].try_into().unwrap()); - (usize::try_from(len).unwrap(), 10) - } - _ => unreachable!(), - }; - let mask_key: [u8; 4] = frame[mask_offset..mask_offset + 4].try_into().unwrap(); - let mut payload = frame[mask_offset + 4..mask_offset + 4 + payload_len].to_vec(); - apply_mask(&mut payload, mask_key); - payload - } - - fn decode_compressed_masked_text_frame(frame: &[u8]) -> String { - assert_eq!(frame[0] & 0x0F, OPCODE_TEXT); - assert_eq!(frame[0] & 0x40, 0x40); - let payload = decode_masked_payload(frame); - String::from_utf8(decompress_permessage_deflate(&payload).unwrap()).unwrap() - } - - async fn read_one_frame(reader: &mut R) -> Vec { - let mut header = [0u8; 2]; - reader.read_exact(&mut header).await.unwrap(); - let len_code = header[1] & 0x7F; - let extended_len = match len_code { - 0..=125 => Vec::new(), - 126 => { - let mut bytes = vec![0u8; 2]; - reader.read_exact(&mut bytes).await.unwrap(); - bytes - } - 127 => { - let mut bytes = vec![0u8; 8]; - reader.read_exact(&mut bytes).await.unwrap(); - bytes - } - _ => unreachable!(), - }; - let payload_len = match len_code { - 0..=125 => usize::from(len_code), - 126 => usize::from(u16::from_be_bytes( - extended_len.as_slice().try_into().unwrap(), - )), - 127 => usize::try_from(u64::from_be_bytes( - extended_len.as_slice().try_into().unwrap(), - )) - .unwrap(), - _ => unreachable!(), - }; - let mask_len = if header[1] & 0x80 != 0 { 4 } else { 0 }; - let mut rest = vec![0u8; extended_len.len() + mask_len + payload_len]; - rest[..extended_len.len()].copy_from_slice(&extended_len); - reader - .read_exact(&mut rest[extended_len.len()..]) - .await - .unwrap(); - - let mut frame = header.to_vec(); - frame.extend_from_slice(&rest); - frame - } - - #[test] - fn classifies_graphql_transport_ws_subscribe_operation() { - let message = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; - - match classify_graphql_websocket_message(message) { - GraphqlWebSocketMessage::Operation { - message_type, - graphql, - } => { - assert_eq!(message_type, "subscribe"); - assert!( - graphql.error.is_none(), - "unexpected error: {:?}", - graphql.error - ); - assert_eq!(graphql.operations.len(), 1); - assert_eq!(graphql.operations[0].operation_type, "subscription"); - assert_eq!( - graphql.operations[0].operation_name.as_deref(), - Some("NewMessages") - ); - assert_eq!(graphql.operations[0].fields, vec!["messageAdded"]); - } - other @ GraphqlWebSocketMessage::Control { .. } => { - panic!("expected operation, got {other:?}") - } - } - } - - #[test] - fn classifies_legacy_graphql_ws_start_operation() { - let message = r#"{"type":"start","id":"1","payload":{"query":"query Viewer { viewer }"}}"#; - - match classify_graphql_websocket_message(message) { - GraphqlWebSocketMessage::Operation { - message_type, - graphql, - } => { - assert_eq!(message_type, "start"); - assert!( - graphql.error.is_none(), - "unexpected error: {:?}", - graphql.error - ); - assert_eq!(graphql.operations[0].operation_type, "query"); - assert_eq!(graphql.operations[0].fields, vec!["viewer"]); - } - other @ GraphqlWebSocketMessage::Control { .. } => { - panic!("expected operation, got {other:?}") - } - } - } - - #[test] - fn classifies_graphql_websocket_control_message_without_payload_logging() { - match classify_graphql_websocket_message( - r#"{"type":"connection_init","payload":{"authorization":"secret"}}"#, - ) { - GraphqlWebSocketMessage::Control { message_type } => { - assert_eq!(message_type, "connection_init"); - } - other @ GraphqlWebSocketMessage::Operation { .. } => { - panic!("expected control message, got {other:?}") - } - } - } - - #[test] - fn unsupported_graphql_websocket_message_type_fails_closed() { - match classify_graphql_websocket_message(r#"{"type":"next","id":"1"}"#) { - GraphqlWebSocketMessage::Operation { graphql, .. } => { - assert!( - graphql - .error - .as_deref() - .is_some_and(|error| error.contains("unsupported")) - ); - } - other @ GraphqlWebSocketMessage::Control { .. } => { - panic!("expected operation error, got {other:?}") - } - } - } - - #[test] - fn graphql_websocket_log_summary_excludes_payload_variables_and_secrets() { - let placeholder = "openshell:resolve:env:T"; - let message = format!( - r#"{{"type":"subscribe","id":"1","payload":{{"query":"query Viewer {{ viewer }}","variables":{{"token":"{placeholder}"}}}}}}"# - ); - let graphql = match classify_graphql_websocket_message(&message) { - GraphqlWebSocketMessage::Operation { graphql, .. } => graphql, - other @ GraphqlWebSocketMessage::Control { .. } => { - panic!("expected operation, got {other:?}") - } - }; - let summary = graphql_log_summary(&graphql); - - assert!(summary.contains("type=query")); - assert!(summary.contains("fields=viewer")); - assert!(!summary.contains(placeholder)); - assert!(!summary.contains("real-token")); - assert!(!summary.contains("variables")); - assert!(!summary.contains("token")); - assert!(!summary.contains("secret_len")); - } - - #[tokio::test] - async fn rewrites_discord_like_identify_text_payload() { - let (child_env, _) = resolver(); - let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); - let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); - - let output = run_client_to_server(masked_frame(true, OPCODE_TEXT, payload.as_bytes())) - .await - .expect("relay should succeed"); - - assert_eq!( - decode_masked_text_frame(&output), - r#"{"op":2,"d":{"token":"real-token"}}"# - ); - } - - #[tokio::test] - async fn upgraded_relay_rewrites_client_text_before_upstream_receives_it() { - let (child_env, resolver) = resolver(); - let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); - let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); - let client_frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); - assert!( - !String::from_utf8_lossy(&client_frame).contains("real-token"), - "client-side fixture must not contain the real token" - ); - - let (mut client_app, mut relay_client) = tokio::io::duplex(4096); - let (mut relay_upstream, mut upstream_app) = tokio::io::duplex(4096); - let relay = tokio::spawn(async move { - relay_with_options( - &mut relay_client, - &mut relay_upstream, - Vec::new(), - "gateway.example.test", - 443, - RelayOptions { - policy_name: "test-policy", - resolver: Some(&resolver), - inspector: None, - compression: WebSocketCompression::None, - }, - ) - .await - }); - - client_app.write_all(&client_frame).await.unwrap(); - client_app.flush().await.unwrap(); - - let upstream_frame = tokio::time::timeout( - std::time::Duration::from_secs(2), - read_one_frame(&mut upstream_app), - ) - .await - .expect("upstream should receive rewritten frame"); - assert_eq!( - decode_masked_text_frame(&upstream_frame), - r#"{"op":2,"d":{"token":"real-token"}}"# - ); - - drop(client_app); - drop(upstream_app); - let _ = tokio::time::timeout(std::time::Duration::from_secs(2), relay).await; - } - - #[tokio::test] - async fn graphql_websocket_policy_allows_subscription_operation() { - let payload = r#"{"type":"subscribe","id":"1","payload":{"query":"subscription NewMessages { messageAdded }"}}"#; - let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); - - let output = run_client_to_server_with_graphql_policy(frame.clone(), None) - .await - .expect("allowed subscription should relay"); - - assert_eq!(output, frame); - assert_eq!(decode_masked_text_frame(&output), payload); - } - - #[tokio::test] - async fn graphql_websocket_policy_denies_unlisted_operation_field() { - let payload = - r#"{"type":"subscribe","id":"1","payload":{"query":"query Admin { adminAuditLog }"}}"#; - let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); - - let err = run_client_to_server_with_graphql_policy(frame, None) - .await - .expect_err("unlisted field should be denied"); - - assert!(err.to_string().contains("websocket GraphQL message denied")); - } - - #[tokio::test] - async fn graphql_websocket_control_message_rewrites_credentials_before_relay() { - let (child_env, resolver) = SecretResolver::from_provider_env( - std::iter::once(("T".to_string(), "real-token".to_string())).collect(), - ); - let resolver = resolver.expect("resolver"); - let placeholder = child_env.get("T").expect("placeholder env"); - let payload = format!( - r#"{{"type":"connection_init","payload":{{"authorization":"{placeholder}"}}}}"# - ); - let frame = masked_frame(true, OPCODE_TEXT, payload.as_bytes()); - - let output = run_client_to_server_with_graphql_policy(frame, Some(&resolver)) - .await - .expect("control message should relay after credential rewrite"); - - let rewritten = decode_masked_text_frame(&output); - assert_eq!( - rewritten, - r#"{"type":"connection_init","payload":{"authorization":"real-token"}}"# - ); - assert!(!rewritten.contains(placeholder)); - } - - #[tokio::test] - async fn text_without_placeholder_passes_semantically_unchanged() { - let frame = masked_frame(true, OPCODE_TEXT, br#"{"op":1,"d":42}"#); - let output = run_client_to_server(frame.clone()) - .await - .expect("relay should succeed"); - - assert_eq!(output, frame); - assert_eq!(decode_masked_text_frame(&output), r#"{"op":1,"d":42}"#); - } - - #[tokio::test] - async fn unknown_placeholder_fails_closed() { - let frame = masked_frame( - true, - OPCODE_TEXT, - br#"{"token":"openshell:resolve:env:UNKNOWN"}"#, - ); - - let err = run_client_to_server(frame) - .await - .expect_err("unknown placeholder should fail"); - - assert!( - err.to_string() - .contains("credential placeholder resolution") - ); - } - - #[tokio::test] - async fn fragmented_text_rewrites_after_final_continuation() { - let (child_env, _) = resolver(); - let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); - let first = format!(r#"{{"token":"{placeholder}"#); - let second = r#""}"#; - let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); - input.extend(masked_frame(true, OPCODE_CONTINUATION, second.as_bytes())); - - let output = run_client_to_server(input) - .await - .expect("relay should succeed"); - - assert_eq!( - decode_masked_text_frame(&output), - r#"{"token":"real-token"}"# - ); - } - - #[tokio::test] - async fn rejects_rsv_bits() { - let mut frame = masked_frame(true, OPCODE_TEXT, b"hello"); - frame[0] |= 0x40; - - let err = run_client_to_server(frame) - .await - .expect_err("RSV frame should fail"); - - assert!(err.to_string().contains("RSV bits")); - } - - #[tokio::test] - async fn rejects_unmasked_client_frame() { - let err = run_client_to_server(unmasked_frame(OPCODE_TEXT, b"hello")) - .await - .expect_err("unmasked frame should fail"); - - assert!(err.to_string().contains("not masked")); - } - - #[tokio::test] - async fn rejects_invalid_utf8_text() { - let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &[0xff])) - .await - .expect_err("invalid UTF-8 should fail"); - - assert!(err.to_string().contains("valid UTF-8")); - } - - #[tokio::test] - async fn rejects_oversize_text_message() { - let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; - let err = run_client_to_server(masked_frame(true, OPCODE_TEXT, &payload)) - .await - .expect_err("oversize text should fail"); - - assert!(err.to_string().contains("exceeds")); - } - - #[tokio::test] - async fn fragmented_text_allows_interleaved_ping_pong_and_rewrites_at_completion() { - let (child_env, _) = resolver(); - let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); - let first = format!(r#"{{"token":"{placeholder}"#); - let first_control_frame = masked_frame(true, OPCODE_PING, b"p"); - let second_control_frame = masked_frame(true, OPCODE_PONG, b"q"); - let mut input = masked_frame(false, OPCODE_TEXT, first.as_bytes()); - input.extend_from_slice(&first_control_frame); - input.extend_from_slice(&second_control_frame); - input.extend(masked_frame(true, OPCODE_CONTINUATION, br#""}"#)); - - let output = run_client_to_server(input) - .await - .expect("relay should allow interleaved control frames"); - - assert!(output.starts_with(&first_control_frame)); - assert_eq!( - &output - [first_control_frame.len()..first_control_frame.len() + second_control_frame.len()], - second_control_frame.as_slice() - ); - assert_eq!( - decode_masked_text_frame( - &output[first_control_frame.len() + second_control_frame.len()..] - ), - r#"{"token":"real-token"}"# - ); - } - - #[tokio::test] - async fn compressed_text_rewrites_with_permessage_deflate() { - let (child_env, _) = resolver(); - let placeholder = child_env.get("DISCORD_BOT_TOKEN").unwrap(); - let payload = format!(r#"{{"token":"{placeholder}"}}"#); - let compressed = compress_permessage_deflate(payload.as_bytes()).unwrap(); - let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); - - let output = run_client_to_server_compressed(input) - .await - .expect("compressed text should relay"); - - assert_eq!( - decode_compressed_masked_text_frame(&output), - r#"{"token":"real-token"}"# - ); - } - - #[tokio::test] - async fn compressed_text_rejects_decompressed_oversize_message() { - let payload = vec![b'a'; MAX_TEXT_MESSAGE_BYTES + 1]; - let compressed = compress_permessage_deflate(&payload).unwrap(); - let input = masked_frame_with_rsv(true, OPCODE_TEXT, 0x40, &compressed); - - let err = run_client_to_server_compressed(input) - .await - .expect_err("oversize decompressed text should fail"); - - assert!(err.to_string().contains("exceeds")); - } - - #[tokio::test] - async fn binary_frame_passes_through_unchanged() { - let frame = masked_frame(true, OPCODE_BINARY, &[0, 1, 2, 3, 255]); - - let output = run_client_to_server(frame.clone()) - .await - .expect("binary frame should pass through"); - - assert_eq!(output, frame); - } - - #[tokio::test] - async fn rejects_reserved_opcode() { - let err = run_client_to_server(masked_frame(true, 0x3, b"reserved")) - .await - .expect_err("reserved opcode should fail"); - - assert!(err.to_string().contains("reserved opcode")); - } - - #[tokio::test] - async fn rejects_continuation_without_active_message() { - let err = run_client_to_server(masked_frame(true, OPCODE_CONTINUATION, b"orphan")) - .await - .expect_err("orphan continuation should fail"); - - assert!(err.to_string().contains("continuation")); - } - - #[tokio::test] - async fn rejects_new_data_frame_before_fragment_completion() { - let mut input = masked_frame(false, OPCODE_TEXT, b"partial"); - input.extend(masked_frame(true, OPCODE_TEXT, b"second")); - - let err = run_client_to_server(input) - .await - .expect_err("new data frame during fragmentation should fail"); - - assert!(err.to_string().contains("previous fragmented message")); - } - - #[tokio::test] - async fn rejects_fragmented_control_frame() { - let err = run_client_to_server(masked_frame(false, OPCODE_PING, b"ping")) - .await - .expect_err("fragmented control frame should fail"); - - assert!(err.to_string().contains("control frame is fragmented")); - } - - #[tokio::test] - async fn rejects_control_frame_over_125_bytes() { - let payload = vec![b'a'; 126]; - let err = run_client_to_server(masked_frame(true, OPCODE_PING, &payload)) - .await - .expect_err("oversize control frame should fail"); - - assert!(err.to_string().contains("control frame exceeds")); - } - - #[tokio::test] - async fn rejects_non_minimal_extended_length() { - let err = run_client_to_server(masked_frame_with_non_minimal_16_bit_len( - OPCODE_TEXT, - b"hello", - )) - .await - .expect_err("non-minimal length should fail"); - - assert!(err.to_string().contains("non-minimal")); - } - - #[tokio::test] - async fn rejects_oversize_binary_frame_before_payload_buffering() { - let err = run_client_to_server(masked_frame_with_declared_len( - OPCODE_BINARY, - MAX_RAW_FRAME_PAYLOAD_BYTES + 1, - )) - .await - .expect_err("oversize binary frame should fail"); - - assert!(err.to_string().contains("binary frame exceeds")); - } - - #[tokio::test] - async fn validates_close_frame_payloads() { - let frame = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); - - let output = run_client_to_server(frame.clone()) - .await - .expect("valid close frame should pass through"); - - assert_eq!(output, frame); - } - - #[tokio::test] - async fn rejects_close_frame_with_one_byte_payload() { - let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &[0x03])) - .await - .expect_err("one-byte close frame should fail"); - - assert!(err.to_string().contains("exactly one byte")); - } - - #[tokio::test] - async fn rejects_reserved_close_code() { - let err = run_client_to_server(masked_frame(true, OPCODE_CLOSE, &close_payload(1005, b""))) - .await - .expect_err("reserved close code should fail"); - - assert!(err.to_string().contains("invalid close code")); - } - - #[tokio::test] - async fn rejects_close_reason_with_invalid_utf8() { - let err = run_client_to_server(masked_frame( - true, - OPCODE_CLOSE, - &close_payload(1000, &[0xff]), - )) - .await - .expect_err("invalid close reason should fail"); - - assert!(err.to_string().contains("valid UTF-8")); - } - - #[tokio::test] - async fn rejects_frames_after_client_close_frame() { - let mut input = masked_frame(true, OPCODE_CLOSE, &close_payload(1000, b"done")); - input.extend(masked_frame(true, OPCODE_TEXT, b"late")); - - let err = run_client_to_server(input) - .await - .expect_err("frames after close should fail"); - - assert!(err.to_string().contains("after close")); - } - - #[test] - fn websocket_ocsf_messages_do_not_include_payload_or_secret_material() { - let placeholder = "openshell:resolve:env:DISCORD_BOT_TOKEN"; - let secret = "real-token"; - let payload = format!(r#"{{"op":2,"d":{{"token":"{placeholder}"}}}}"#); - - let rewrite = rewrite_event_message("gateway.example.test", 443, 1); - let failure = protocol_failure_message("gateway.example.test", 443); - let messages = [rewrite, failure]; - - for message in messages { - assert!(!message.contains(placeholder)); - assert!(!message.contains(secret)); - assert!(!message.contains(&payload)); - assert!(!message.contains("secret_len")); - assert!(!message.contains("payload_len")); - } - } -} diff --git a/crates/openshell-sandbox/src/opa.rs b/crates/openshell-sandbox/src/opa.rs deleted file mode 100644 index ac50340d1..000000000 --- a/crates/openshell-sandbox/src/opa.rs +++ /dev/null @@ -1,5339 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! Embedded OPA policy engine using regorus. -//! -//! Wraps [`regorus::Engine`] to evaluate Rego policies for sandbox network -//! access decisions. The engine is loaded once at sandbox startup and queried -//! on every proxy CONNECT request. - -use crate::policy::{FilesystemPolicy, LandlockCompatibility, LandlockPolicy, ProcessPolicy}; -use miette::Result; -use openshell_core::proto::SandboxPolicy as ProtoSandboxPolicy; -use std::path::{Path, PathBuf}; -use std::sync::{ - Arc, Mutex, - atomic::{AtomicU64, Ordering}, -}; - -/// Baked-in rego rules for OPA policy evaluation. -/// These rules define the network access decision logic and static config -/// passthroughs. They reference `data.sandbox.*` for policy data. -const BAKED_POLICY_RULES: &str = include_str!("../data/sandbox-policy.rego"); - -/// Result of evaluating a network access request against OPA policy. -pub struct PolicyDecision { - pub allowed: bool, - pub reason: String, - pub matched_policy: Option, -} - -/// Network action returned by OPA `network_action` rule. -/// -/// - `Allow`: endpoint + binary explicitly matched in a network policy -/// - `Deny`: no matching policy -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum NetworkAction { - Allow { matched_policy: Option }, - Deny { reason: String }, -} - -/// Input for a network access policy evaluation. -pub struct NetworkInput { - pub host: String, - pub port: u16, - pub binary_path: PathBuf, - pub binary_sha256: String, - /// Ancestor binary paths from process tree walk (parent, grandparent, ...). - pub ancestors: Vec, - /// Absolute paths extracted from `/proc//cmdline` of the socket-owning - /// process and its ancestors. Captures script paths (e.g. `/usr/local/bin/claude`) - /// that don't appear in `/proc//exe` because the interpreter (node) is the exe. - pub cmdline_paths: Vec, -} - -/// Sandbox configuration extracted from OPA data at startup. -pub struct SandboxConfig { - pub filesystem: FilesystemPolicy, - pub landlock: LandlockPolicy, - pub process: ProcessPolicy, -} - -/// Embedded OPA policy engine. -/// -/// Thread-safe: the inner `regorus::Engine` requires `&mut self` for -/// evaluation, so access is serialized via a `Mutex`. This is acceptable -/// because policy evaluation is fast (microseconds) and contention is low -/// (one eval per CONNECT request). -pub struct OpaEngine { - engine: Mutex, - generation: Arc, -} - -/// Generation guard captured when an HTTP tunnel or request path starts. -#[derive(Clone)] -pub struct PolicyGenerationGuard { - captured_generation: u64, - current_generation: Arc, -} - -impl PolicyGenerationGuard { - pub fn captured_generation(&self) -> u64 { - self.captured_generation - } - - pub fn current_generation(&self) -> u64 { - self.current_generation.load(Ordering::Acquire) - } - - pub fn is_stale(&self) -> bool { - self.current_generation() != self.captured_generation - } - - pub fn ensure_current(&self) -> Result<()> { - if self.is_stale() { - return Err(miette::miette!( - "policy generation is stale [captured_generation:{} current_generation:{}]", - self.captured_generation(), - self.current_generation(), - )); - } - Ok(()) - } -} - -/// Per-tunnel L7 policy evaluator bound to the engine generation captured when -/// the tunnel was established. -pub struct TunnelPolicyEngine { - engine: Mutex, - generation_guard: PolicyGenerationGuard, -} - -impl TunnelPolicyEngine { - pub fn captured_generation(&self) -> u64 { - self.generation_guard.captured_generation() - } - - pub fn current_generation(&self) -> u64 { - self.generation_guard.current_generation() - } - - pub fn is_stale(&self) -> bool { - self.generation_guard.is_stale() - } - - pub fn generation_guard(&self) -> &PolicyGenerationGuard { - &self.generation_guard - } - - pub(crate) fn engine(&self) -> &Mutex { - &self.engine - } -} - -impl OpaEngine { - /// Load policy from a `.rego` rules file and data from a YAML file. - /// - /// Preprocesses the YAML data to expand access presets and validate L7 config. - pub fn from_files(policy_path: &Path, data_path: &Path) -> Result { - let yaml_str = std::fs::read_to_string(data_path).map_err(|e| { - miette::miette!("failed to read YAML data from {}: {e}", data_path.display()) - })?; - let mut engine = regorus::Engine::new(); - engine - .add_policy_from_file(policy_path) - .map_err(|e| miette::miette!("{e}"))?; - let data_json = preprocess_yaml_data(&yaml_str)?; - engine - .add_data_json(&data_json) - .map_err(|e| miette::miette!("{e}"))?; - Ok(Self { - engine: Mutex::new(engine), - generation: Arc::new(AtomicU64::new(0)), - }) - } - - /// Load policy rules and data from strings (data is YAML). - /// - /// Preprocesses the YAML data to expand access presets and validate L7 config. - pub fn from_strings(policy: &str, data_yaml: &str) -> Result { - let mut engine = regorus::Engine::new(); - engine - .add_policy("policy.rego".into(), policy.into()) - .map_err(|e| miette::miette!("{e}"))?; - let data_json = preprocess_yaml_data(data_yaml)?; - engine - .add_data_json(&data_json) - .map_err(|e| miette::miette!("{e}"))?; - Ok(Self { - engine: Mutex::new(engine), - generation: Arc::new(AtomicU64::new(0)), - }) - } - - /// Create OPA engine from a typed proto policy. - /// - /// Uses baked-in rego rules and converts the proto's typed fields to JSON - /// data under the `sandbox` key (matching `data.sandbox.*` references in - /// the rego rules). - /// - /// Expands access presets and validates L7 config. - pub fn from_proto(proto: &ProtoSandboxPolicy) -> Result { - Self::from_proto_with_pid(proto, 0) - } - - /// Create OPA engine from a typed proto policy with symlink resolution. - /// - /// When `entrypoint_pid` is non-zero, binary paths in the policy that are - /// symlinks inside the container filesystem are resolved via - /// `/proc//root/` and added as additional entries. This bridges the - /// gap between user-specified symlink paths (e.g., `/usr/bin/python3`) and - /// kernel-resolved canonical paths (e.g., `/usr/bin/python3.11`). - pub fn from_proto_with_pid(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> Result { - let data_json_str = proto_to_opa_data_json(proto, entrypoint_pid); - - // Parse back to Value for preprocessing, then re-serialize - let mut data: serde_json::Value = serde_json::from_str(&data_json_str) - .map_err(|e| miette::miette!("internal: failed to parse proto JSON: {e}"))?; - - // Validate BEFORE expanding presets - let (errors, warnings) = crate::l7::validate_l7_policies(&data); - for w in &warnings { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(openshell_ocsf::SeverityId::Medium) - .status(openshell_ocsf::StatusId::Success) - .state(openshell_ocsf::StateId::Enabled, "validated") - .unmapped("warning", serde_json::json!(w.clone())) - .message(format!("L7 policy validation warning: {w}")) - .build() - ); - } - if !errors.is_empty() { - return Err(miette::miette!( - "L7 policy validation failed:\n{}", - errors.join("\n") - )); - } - - // Expand access presets to explicit rules after validation - crate::l7::expand_access_presets(&mut data); - - let data_json = data.to_string(); - let mut engine = regorus::Engine::new(); - engine - .add_policy("policy.rego".into(), BAKED_POLICY_RULES.into()) - .map_err(|e| miette::miette!("{e}"))?; - engine - .add_data_json(&data_json) - .map_err(|e| miette::miette!("{e}"))?; - Ok(Self { - engine: Mutex::new(engine), - generation: Arc::new(AtomicU64::new(0)), - }) - } - - /// Evaluate a network access request against the loaded policy. - /// - /// Builds an OPA input document from the `NetworkInput`, evaluates the - /// `allow_network` rule, and returns a `PolicyDecision` with the result, - /// deny reason, and matched policy name. - pub fn evaluate_network(&self, input: &NetworkInput) -> Result { - let ancestor_strs: Vec = input - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let cmdline_strs: Vec = input - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let input_json = serde_json::json!({ - "exec": { - "path": input.binary_path.to_string_lossy(), - "ancestors": ancestor_strs, - "cmdline_paths": cmdline_strs, - }, - "network": { - "host": input.host, - "port": input.port, - } - }); - - let mut engine = self - .engine - .lock() - .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - - engine - .set_input_json(&input_json.to_string()) - .map_err(|e| miette::miette!("{e}"))?; - - let allowed = engine - .eval_rule("data.openshell.sandbox.allow_network".into()) - .map_err(|e| miette::miette!("{e}"))?; - let allowed = allowed == regorus::Value::from(true); - - let reason = engine - .eval_rule("data.openshell.sandbox.deny_reason".into()) - .map_err(|e| miette::miette!("{e}"))?; - let reason = value_to_string(&reason); - - let matched = engine - .eval_rule("data.openshell.sandbox.matched_network_policy".into()) - .map_err(|e| miette::miette!("{e}"))?; - let matched_policy = if matched == regorus::Value::Undefined { - None - } else { - Some(value_to_string(&matched)) - }; - - Ok(PolicyDecision { - allowed, - reason, - matched_policy, - }) - } - - /// Evaluate a network access request and return a routing action. - /// - /// Uses the OPA `network_action` rule which returns one of: - /// `"allow"` or `"deny"`. - pub fn evaluate_network_action(&self, input: &NetworkInput) -> Result { - Ok(self.evaluate_network_action_with_generation(input)?.0) - } - - /// Evaluate network action and return the policy generation used for the evaluation. - pub fn evaluate_network_action_with_generation( - &self, - input: &NetworkInput, - ) -> Result<(NetworkAction, u64)> { - let ancestor_strs: Vec = input - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let cmdline_strs: Vec = input - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let input_json = serde_json::json!({ - "exec": { - "path": input.binary_path.to_string_lossy(), - "ancestors": ancestor_strs, - "cmdline_paths": cmdline_strs, - }, - "network": { - "host": input.host, - "port": input.port, - } - }); - - let mut engine = self - .engine - .lock() - .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - let generation = self.current_generation(); - - engine - .set_input_json(&input_json.to_string()) - .map_err(|e| miette::miette!("{e}"))?; - - let action_val = engine - .eval_rule("data.openshell.sandbox.network_action".into()) - .map_err(|e| miette::miette!("{e}"))?; - let action_str = value_to_string(&action_val); - - let matched = engine - .eval_rule("data.openshell.sandbox.matched_network_policy".into()) - .map_err(|e| miette::miette!("{e}"))?; - let matched_policy = if matched == regorus::Value::Undefined { - None - } else { - Some(value_to_string(&matched)) - }; - - if action_str == "allow" { - Ok((NetworkAction::Allow { matched_policy }, generation)) - } else { - let reason_val = engine - .eval_rule("data.openshell.sandbox.deny_reason".into()) - .map_err(|e| miette::miette!("{e}"))?; - let reason = value_to_string(&reason_val); - Ok((NetworkAction::Deny { reason }, generation)) - } - } - - /// Reload policy and data from strings (data is YAML). - /// - /// Designed for future gRPC hot-reload from the openshell gateway. - /// Replaces the entire engine atomically. Routes through the full - /// preprocessing pipeline (port normalization, L7 validation, preset - /// expansion) to maintain consistency with `from_strings()`. - pub fn reload(&self, policy: &str, data_yaml: &str) -> Result<()> { - let new = Self::from_strings(policy, data_yaml)?; - let new_engine = new - .engine - .into_inner() - .map_err(|_| miette::miette!("lock poisoned on new engine"))?; - let mut engine = self - .engine - .lock() - .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - *engine = new_engine; - self.generation.fetch_add(1, Ordering::AcqRel); - Ok(()) - } - - /// Reload policy from a proto `SandboxPolicy` message. - /// - /// Reuses the full `from_proto()` pipeline (proto-to-JSON conversion, L7 - /// validation, access preset expansion) so the reload has identical - /// validation guarantees as initial load. Atomically replaces the inner - /// engine on success; on failure the previous engine is untouched (LKG). - pub fn reload_from_proto(&self, proto: &ProtoSandboxPolicy) -> Result<()> { - self.reload_from_proto_with_pid(proto, 0) - } - - /// Reload policy from a proto with symlink resolution. - /// - /// When `entrypoint_pid` is non-zero, binary paths that are symlinks - /// inside the container filesystem are resolved and added as additional - /// match entries. See [`from_proto_with_pid`] for details. - pub fn reload_from_proto_with_pid( - &self, - proto: &ProtoSandboxPolicy, - entrypoint_pid: u32, - ) -> Result<()> { - // Build a complete new engine through the same validated pipeline. - let new = Self::from_proto_with_pid(proto, entrypoint_pid)?; - let new_engine = new - .engine - .into_inner() - .map_err(|_| miette::miette!("lock poisoned on new engine"))?; - let mut engine = self - .engine - .lock() - .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - *engine = new_engine; - self.generation.fetch_add(1, Ordering::AcqRel); - Ok(()) - } - - /// Current policy generation. Successful reloads increment this value. - pub fn current_generation(&self) -> u64 { - self.generation.load(Ordering::Acquire) - } - - /// Return a guard for a previously captured policy generation. - pub fn generation_guard(&self, expected_generation: u64) -> Result { - let generation = self.current_generation(); - if generation != expected_generation { - return Err(miette::miette!( - "policy changed before HTTP relay started [expected_generation:{expected_generation} current_generation:{generation}]" - )); - } - Ok(PolicyGenerationGuard { - captured_generation: generation, - current_generation: Arc::clone(&self.generation), - }) - } - - /// Query static sandbox configuration from the OPA data module. - /// - /// Extracts `filesystem_policy`, `landlock`, and `process` from the Rego - /// data and converts them into the Rust policy structs used by the sandbox - /// runtime for filesystem preparation, Landlock setup, and privilege dropping. - pub fn query_sandbox_config(&self) -> Result { - let mut engine = self - .engine - .lock() - .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - - // Query filesystem policy - let fs_val = engine - .eval_rule("data.openshell.sandbox.filesystem_policy".into()) - .map_err(|e| miette::miette!("{e}"))?; - let filesystem = parse_filesystem_policy(&fs_val); - - // Query landlock policy - let ll_val = engine - .eval_rule("data.openshell.sandbox.landlock_policy".into()) - .map_err(|e| miette::miette!("{e}"))?; - let landlock = parse_landlock_policy(&ll_val); - - // Query process policy - let proc_val = engine - .eval_rule("data.openshell.sandbox.process_policy".into()) - .map_err(|e| miette::miette!("{e}"))?; - let process = parse_process_policy(&proc_val); - - Ok(SandboxConfig { - filesystem, - landlock, - process, - }) - } - - /// Query the L7 endpoint config for a matched policy and host:port. - /// - /// After L4 evaluation allows a CONNECT, this method queries the Rego data - /// to get the full endpoint object for the matched policy. Returns the raw - /// `regorus::Value` which can be parsed by `l7::parse_l7_config()`. - pub fn query_endpoint_config(&self, input: &NetworkInput) -> Result> { - Ok(self.query_endpoint_config_with_generation(input)?.0) - } - - /// Query L7 endpoint config and return the policy generation used for the query. - pub fn query_endpoint_config_with_generation( - &self, - input: &NetworkInput, - ) -> Result<(Option, u64)> { - let (configs, generation) = self.query_endpoint_configs_with_generation(input)?; - Ok((configs.into_iter().next(), generation)) - } - - /// Query all matching endpoint configs and return the policy generation used for the query. - pub fn query_endpoint_configs_with_generation( - &self, - input: &NetworkInput, - ) -> Result<(Vec, u64)> { - let ancestor_strs: Vec = input - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let cmdline_strs: Vec = input - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let input_json = serde_json::json!({ - "exec": { - "path": input.binary_path.to_string_lossy(), - "ancestors": ancestor_strs, - "cmdline_paths": cmdline_strs, - }, - "network": { - "host": input.host, - "port": input.port, - } - }); - - let mut engine = self - .engine - .lock() - .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - let generation = self.current_generation(); - - engine - .set_input_json(&input_json.to_string()) - .map_err(|e| miette::miette!("{e}"))?; - - let val = engine - .eval_rule("data.openshell.sandbox._matching_endpoint_configs".into()) - .map_err(|e| miette::miette!("{e}"))?; - - match val { - regorus::Value::Undefined => Ok((Vec::new(), generation)), - regorus::Value::Array(values) => Ok((values.to_vec(), generation)), - other => Ok((vec![other], generation)), - } - } - - /// Query `allowed_ips` from the matched endpoint config for a given request. - /// - /// Returns the list of CIDR/IP strings from the endpoint's `allowed_ips` - /// field, or an empty vec if the field is absent or the endpoint has no - /// match. This is used by the proxy to decide between full SSRF blocking - /// and allowlist-based IP validation. - pub fn query_allowed_ips(&self, input: &NetworkInput) -> Result> { - Ok(self - .query_endpoint_config(input)? - .map(|val| get_str_array(&val, "allowed_ips")) - .unwrap_or_default()) - } - - /// Return true when the matched endpoint is an exact declared hostname. - /// - /// This intentionally excludes wildcard and hostless endpoints. The proxy - /// uses this as a narrow signal that the operator explicitly declared the - /// destination hostname, which can safely skip the default private-IP SSRF - /// denial while preserving separate handling for `allowed_ips` and advisor - /// proposals. - pub fn query_exact_declared_endpoint_host(&self, input: &NetworkInput) -> Result { - let ancestor_strs: Vec = input - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let cmdline_strs: Vec = input - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(); - let input_json = serde_json::json!({ - "exec": { - "path": input.binary_path.to_string_lossy(), - "ancestors": ancestor_strs, - "cmdline_paths": cmdline_strs, - }, - "network": { - "host": input.host, - "port": input.port, - } - }); - - let mut engine = self - .engine - .lock() - .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - - engine - .set_input_json(&input_json.to_string()) - .map_err(|e| miette::miette!("{e}"))?; - - let val = engine - .eval_rule("data.openshell.sandbox.exact_declared_endpoint_host".into()) - .map_err(|e| miette::miette!("{e}"))?; - - Ok(val == regorus::Value::from(true)) - } - - /// Clone the inner regorus engine for per-tunnel L7 evaluation. - /// - /// With the `arc` feature enabled, this shares compiled policy via Arc - /// and only duplicates interpreter state (~microseconds). The cloned - /// engine can be used without Mutex contention. - pub fn clone_engine_for_tunnel(&self, expected_generation: u64) -> Result { - let engine = self - .engine - .lock() - .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - let generation = self.current_generation(); - if generation != expected_generation { - return Err(miette::miette!( - "policy changed before L7 tunnel started [expected_generation:{expected_generation} current_generation:{generation}]" - )); - } - Ok(TunnelPolicyEngine { - engine: Mutex::new(engine.clone()), - generation_guard: PolicyGenerationGuard { - captured_generation: generation, - current_generation: Arc::clone(&self.generation), - }, - }) - } -} - -/// Convert a `regorus::Value` to a string, handling various types. -fn value_to_string(val: ®orus::Value) -> String { - match val { - regorus::Value::String(s) => s.to_string(), - regorus::Value::Undefined => String::new(), - other => other.to_string(), - } -} - -/// Extract a string from a `regorus::Value` object field. -fn get_str(val: ®orus::Value, key: &str) -> Option { - let key_val = regorus::Value::String(key.into()); - match val { - regorus::Value::Object(map) => match map.get(&key_val) { - Some(regorus::Value::String(s)) => Some(s.to_string()), - _ => None, - }, - _ => None, - } -} - -/// Extract a bool from a `regorus::Value` object field. -fn get_bool(val: ®orus::Value, key: &str) -> Option { - let key_val = regorus::Value::String(key.into()); - match val { - regorus::Value::Object(map) => match map.get(&key_val) { - Some(regorus::Value::Bool(b)) => Some(*b), - _ => None, - }, - _ => None, - } -} - -/// Extract a string array from a `regorus::Value` object field. -fn get_str_array(val: ®orus::Value, key: &str) -> Vec { - let key_val = regorus::Value::String(key.into()); - match val { - regorus::Value::Object(map) => match map.get(&key_val) { - Some(regorus::Value::Array(arr)) => arr - .iter() - .filter_map(|v| { - if let regorus::Value::String(s) = v { - Some(s.to_string()) - } else { - None - } - }) - .collect(), - _ => vec![], - }, - _ => vec![], - } -} - -fn parse_filesystem_policy(val: ®orus::Value) -> FilesystemPolicy { - FilesystemPolicy { - read_only: get_str_array(val, "read_only") - .into_iter() - .map(PathBuf::from) - .collect(), - read_write: get_str_array(val, "read_write") - .into_iter() - .map(PathBuf::from) - .collect(), - include_workdir: get_bool(val, "include_workdir").unwrap_or(true), - } -} - -fn parse_landlock_policy(val: ®orus::Value) -> LandlockPolicy { - let compat = get_str(val, "compatibility").unwrap_or_default(); - LandlockPolicy { - compatibility: if compat == "hard_requirement" { - LandlockCompatibility::HardRequirement - } else { - LandlockCompatibility::BestEffort - }, - } -} - -fn parse_process_policy(val: ®orus::Value) -> ProcessPolicy { - ProcessPolicy { - run_as_user: get_str(val, "run_as_user"), - run_as_group: get_str(val, "run_as_group"), - } -} - -/// Preprocess YAML policy data: parse, normalize, validate, expand access presets, return JSON. -fn preprocess_yaml_data(yaml_str: &str) -> Result { - let mut data: serde_json::Value = serde_yml::from_str(yaml_str) - .map_err(|e| miette::miette!("failed to parse YAML data: {e}"))?; - - // Normalize port → ports for all endpoints so Rego always sees "ports" array. - normalize_endpoint_ports(&mut data); - - // Validate BEFORE expanding presets (catches user errors like rules+access) - let (errors, warnings) = crate::l7::validate_l7_policies(&data); - for w in &warnings { - openshell_ocsf::ocsf_emit!( - openshell_ocsf::ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(openshell_ocsf::SeverityId::Medium) - .status(openshell_ocsf::StatusId::Success) - .state(openshell_ocsf::StateId::Enabled, "validated") - .unmapped("warning", serde_json::json!(w.clone())) - .message(format!("L7 policy validation warning: {w}")) - .build() - ); - } - if !errors.is_empty() { - return Err(miette::miette!( - "L7 policy validation failed:\n{}", - errors.join("\n") - )); - } - - // Expand access presets to explicit rules after validation - crate::l7::expand_access_presets(&mut data); - - serde_json::to_string(&data).map_err(|e| miette::miette!("failed to serialize data: {e}")) -} - -/// Normalize endpoint port/ports in JSON data. -/// -/// YAML policies may use `port: N` (single) or `ports: [N, M]` (multi). -/// This normalizes all endpoints to have a `ports` array so Rego rules -/// only need to reference `endpoint.ports[_]`. -fn normalize_endpoint_ports(data: &mut serde_json::Value) { - let Some(policies) = data - .get_mut("network_policies") - .and_then(|v| v.as_object_mut()) - else { - return; - }; - - for (_name, policy) in policies.iter_mut() { - let Some(endpoints) = policy.get_mut("endpoints").and_then(|v| v.as_array_mut()) else { - continue; - }; - - for ep in endpoints.iter_mut() { - let Some(ep_obj) = ep.as_object_mut() else { - continue; - }; - - // If "ports" already exists and is non-empty, keep it. - let has_ports = ep_obj - .get("ports") - .and_then(|v| v.as_array()) - .is_some_and(|a| !a.is_empty()); - - if !has_ports { - // Promote scalar "port" to "ports" array. - let port = ep_obj - .get("port") - .and_then(serde_json::Value::as_u64) - .unwrap_or(0); - if port > 0 { - ep_obj.insert( - "ports".to_string(), - serde_json::Value::Array(vec![serde_json::json!(port)]), - ); - } - } - - // Remove scalar "port" — Rego only uses "ports". - ep_obj.remove("port"); - } - } -} - -/// Resolve a policy binary path through the container's root filesystem. -/// -/// On Linux, `/proc//root/` provides access to the container's mount -/// namespace. If the policy path is a symlink inside the container -/// (e.g., `/usr/bin/python3` → `/usr/bin/python3.11`), returns the -/// canonical target path. Returns `None` if: -/// - Not on Linux -/// - `entrypoint_pid` is 0 (container not yet started) -/// - Path contains glob characters -/// - Path is not a symlink -/// - Resolution fails (binary doesn't exist in container) -/// - Resolved path equals the original -/// -/// Normalize a path by resolving `.` and `..` components without touching -/// the filesystem. Only works correctly for absolute paths. -#[cfg(any(target_os = "linux", test))] -fn normalize_path(path: &Path) -> PathBuf { - let mut result = PathBuf::new(); - for component in path.components() { - match component { - std::path::Component::ParentDir => { - result.pop(); - } - std::path::Component::CurDir => {} - other => result.push(other), - } - } - result -} - -#[cfg(target_os = "linux")] -fn resolve_binary_in_container(policy_path: &str, entrypoint_pid: u32) -> Option { - if policy_path.contains('*') || entrypoint_pid == 0 { - return None; - } - - // Walk the symlink chain inside the container filesystem using - // read_link rather than canonicalize. canonicalize resolves - // /proc//root itself (a kernel pseudo-symlink to /) which - // strips the prefix we need. read_link only reads the target of - // the specified symlink, keeping us in the container's namespace. - let mut resolved = PathBuf::from(policy_path); - - // Linux SYMLOOP_MAX is 40; stop before infinite loops - for _ in 0..40 { - let container_path = format!("/proc/{entrypoint_pid}/root{}", resolved.display()); - - tracing::debug!( - "Symlink resolution: probing container_path={container_path} for policy_path={policy_path} pid={entrypoint_pid}" - ); - - let meta = match std::fs::symlink_metadata(&container_path) { - Ok(m) => m, - Err(e) => { - // Only warn on the first iteration (the original policy path). - // On subsequent iterations, the intermediate target may - // legitimately not exist (broken symlink chain). - if resolved.as_os_str() == policy_path { - tracing::warn!( - "Cannot access container filesystem for symlink resolution: \ - path={policy_path} container_path={container_path} pid={entrypoint_pid} \ - error={e}. Binary paths in policy will be matched literally. \ - If this binary is a symlink (e.g., /usr/bin/python3 -> python3.11), \ - use the canonical path instead, or run with CAP_SYS_PTRACE." - ); - } else { - tracing::warn!( - "Symlink chain broken during resolution: \ - original={policy_path} current={} pid={entrypoint_pid} error={e}. \ - Binary will be matched by original path only.", - resolved.display() - ); - } - return None; - } - }; - - if !meta.file_type().is_symlink() { - // Reached a non-symlink — this is the final resolved target - break; - } - - let target = match std::fs::read_link(&container_path) { - Ok(t) => t, - Err(e) => { - tracing::warn!( - "Symlink detected but read_link failed: \ - path={policy_path} current={} pid={entrypoint_pid} error={e}. \ - Binary will be matched by original path only.", - resolved.display() - ); - return None; - } - }; - - if target.is_absolute() { - resolved = target; - } else { - // Relative symlink: resolve against the containing directory - // e.g., /usr/bin/python3 -> python3.11 becomes /usr/bin/python3.11 - if let Some(parent) = resolved.parent() { - resolved = normalize_path(&parent.join(&target)); - } else { - break; - } - } - } - - let resolved_str = resolved.to_string_lossy().into_owned(); - - if resolved_str == policy_path { - None - } else { - tracing::info!( - "Resolved policy binary symlink via container filesystem: \ - original={policy_path} resolved={resolved_str} pid={entrypoint_pid}" - ); - Some(resolved_str) - } -} - -#[cfg(not(target_os = "linux"))] -fn resolve_binary_in_container(_policy_path: &str, _entrypoint_pid: u32) -> Option { - None -} - -fn l7_matchers_to_json( - matchers: &std::collections::HashMap, -) -> serde_json::Map { - matchers - .iter() - .map(|(key, matcher)| { - let mut matcher_json = serde_json::json!({}); - if !matcher.glob.is_empty() { - matcher_json["glob"] = matcher.glob.clone().into(); - } - if !matcher.any.is_empty() { - matcher_json["any"] = matcher.any.clone().into(); - } - (key.clone(), matcher_json) - }) - .collect() -} - -/// Convert typed proto policy fields to JSON suitable for `engine.add_data_json()`. -/// -/// The rego rules reference `data.*` directly, so the JSON structure has -/// top-level keys matching the data expectations: -/// - `data.filesystem_policy` -/// - `data.landlock` -/// - `data.process` -/// - `data.network_policies` -/// -/// When `entrypoint_pid` is non-zero, binary paths that are symlinks inside -/// the container filesystem are resolved via `/proc//root/` and added -/// as additional entries alongside the original path. This ensures that -/// user-specified symlink paths (e.g., `/usr/bin/python3`) match the -/// kernel-resolved canonical paths reported by `/proc//exe` (e.g., -/// `/usr/bin/python3.11`). -fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> String { - let filesystem_policy = proto.filesystem.as_ref().map_or_else( - || { - serde_json::json!({ - "include_workdir": true, - "read_only": [], - "read_write": [], - }) - }, - |fs| { - serde_json::json!({ - "include_workdir": fs.include_workdir, - "read_only": fs.read_only, - "read_write": fs.read_write, - }) - }, - ); - - let landlock = proto.landlock.as_ref().map_or_else( - || serde_json::json!({"compatibility": "best_effort"}), - |ll| serde_json::json!({"compatibility": ll.compatibility}), - ); - - let process = proto.process.as_ref().map_or_else( - || { - serde_json::json!({ - "run_as_user": "sandbox", - "run_as_group": "sandbox", - }) - }, - |p| { - serde_json::json!({ - "run_as_user": p.run_as_user, - "run_as_group": p.run_as_group, - }) - }, - ); - - let network_policies: serde_json::Map = proto - .network_policies - .iter() - .map(|(key, rule)| { - let endpoints: Vec = rule - .endpoints - .iter() - .map(|e| { - // Normalize port/ports: ports takes precedence, then - // single port promoted to array. Rego always sees "ports". - let ports: Vec = if !e.ports.is_empty() { - e.ports.clone() - } else if e.port > 0 { - vec![e.port] - } else { - vec![] - }; - let mut ep = serde_json::json!({"host": e.host, "ports": ports}); - if !e.path.is_empty() { - ep["path"] = e.path.clone().into(); - } - if !e.protocol.is_empty() { - ep["protocol"] = e.protocol.clone().into(); - } - if !e.tls.is_empty() { - ep["tls"] = e.tls.clone().into(); - } - if !e.enforcement.is_empty() { - ep["enforcement"] = e.enforcement.clone().into(); - } - if !e.access.is_empty() { - ep["access"] = e.access.clone().into(); - } - if !e.rules.is_empty() { - let rules: Vec = e - .rules - .iter() - .map(|r| { - let a = r.allow.as_ref(); - let mut allow = serde_json::json!({ - "method": a.map_or("", |a| &a.method), - "path": a.map_or("", |a| &a.path), - "command": a.map_or("", |a| &a.command), - "operation_type": a.map_or("", |a| &a.operation_type), - "operation_name": a.map_or("", |a| &a.operation_name), - "rpc_method": a.map_or("", |a| &a.rpc_method), - }); - if let Some(a) = a - && !a.fields.is_empty() - { - allow["fields"] = a.fields.clone().into(); - } - let query = a.map_or_else(serde_json::Map::new, |allow| { - l7_matchers_to_json(&allow.query) - }); - if !query.is_empty() { - allow["query"] = query.into(); - } - let params = a.map_or_else(serde_json::Map::new, |allow| { - l7_matchers_to_json(&allow.params) - }); - if !params.is_empty() { - allow["params"] = params.into(); - } - serde_json::json!({ "allow": allow }) - }) - .collect(); - ep["rules"] = rules.into(); - } - if !e.allowed_ips.is_empty() { - ep["allowed_ips"] = e.allowed_ips.clone().into(); - } - if e.advisor_proposed { - ep["advisor_proposed"] = true.into(); - } - if !e.deny_rules.is_empty() { - let deny_rules: Vec = e - .deny_rules - .iter() - .map(|d| { - let mut deny = serde_json::json!({}); - if !d.method.is_empty() { - deny["method"] = d.method.clone().into(); - } - if !d.path.is_empty() { - deny["path"] = d.path.clone().into(); - } - if !d.command.is_empty() { - deny["command"] = d.command.clone().into(); - } - if !d.operation_type.is_empty() { - deny["operation_type"] = d.operation_type.clone().into(); - } - if !d.operation_name.is_empty() { - deny["operation_name"] = d.operation_name.clone().into(); - } - if !d.fields.is_empty() { - deny["fields"] = d.fields.clone().into(); - } - if !d.rpc_method.is_empty() { - deny["rpc_method"] = d.rpc_method.clone().into(); - } - let query = l7_matchers_to_json(&d.query); - if !query.is_empty() { - deny["query"] = query.into(); - } - let params = l7_matchers_to_json(&d.params); - if !params.is_empty() { - deny["params"] = params.into(); - } - deny - }) - .collect(); - ep["deny_rules"] = deny_rules.into(); - } - if e.allow_encoded_slash { - ep["allow_encoded_slash"] = true.into(); - } - if e.websocket_credential_rewrite { - ep["websocket_credential_rewrite"] = true.into(); - } - if e.request_body_credential_rewrite { - ep["request_body_credential_rewrite"] = true.into(); - } - if !e.persisted_queries.is_empty() { - ep["persisted_queries"] = e.persisted_queries.clone().into(); - } - if !e.graphql_persisted_queries.is_empty() { - let persisted: serde_json::Map = e - .graphql_persisted_queries - .iter() - .map(|(key, op)| { - ( - key.clone(), - serde_json::json!({ - "operation_type": op.operation_type, - "operation_name": op.operation_name, - "fields": op.fields, - }), - ) - }) - .collect(); - ep["graphql_persisted_queries"] = persisted.into(); - } - if e.graphql_max_body_bytes > 0 { - ep["graphql_max_body_bytes"] = e.graphql_max_body_bytes.into(); - } - if e.json_rpc_max_body_bytes > 0 { - ep["json_rpc_max_body_bytes"] = e.json_rpc_max_body_bytes.into(); - } - ep - }) - .collect(); - let binaries: Vec = rule - .binaries - .iter() - .flat_map(|b| { - // The deprecated harness bit is ignored by policy YAML, but - // advisor-generated proposals use it as internal provenance. - #[allow(deprecated)] - let advisor_proposed = b.harness; - let binary_entry = |path: &str| { - let mut entry = serde_json::json!({"path": path}); - if advisor_proposed { - entry["advisor_proposed"] = true.into(); - } - entry - }; - let mut entries = vec![binary_entry(&b.path)]; - if let Some(resolved) = resolve_binary_in_container(&b.path, entrypoint_pid) { - entries.push(binary_entry(&resolved)); - } - entries - }) - .collect(); - ( - key.clone(), - serde_json::json!({ - "name": rule.name, - "endpoints": endpoints, - "binaries": binaries, - }), - ) - }) - .collect(); - - serde_json::json!({ - "filesystem_policy": filesystem_policy, - "landlock": landlock, - "process": process, - "network_policies": network_policies, - }) - .to_string() -} - -#[cfg(test)] -#[allow( - clippy::needless_raw_string_hashes, - clippy::similar_names, - clippy::doc_markdown, - clippy::match_wildcard_for_single_variants, - reason = "Test code: test fixtures and panic-on-unexpected matches are idiomatic in tests." -)] -mod tests { - use super::*; - - use openshell_core::proto::{ - FilesystemPolicy as ProtoFs, L7Allow, L7QueryMatcher, L7Rule, NetworkBinary, - NetworkEndpoint, NetworkPolicyRule, ProcessPolicy as ProtoProc, - SandboxPolicy as ProtoSandboxPolicy, - }; - - const TEST_POLICY: &str = include_str!("../data/sandbox-policy.rego"); - const TEST_DATA_YAML: &str = include_str!("../testdata/sandbox-policy.yaml"); - - fn test_engine() -> OpaEngine { - OpaEngine::from_strings(TEST_POLICY, TEST_DATA_YAML).expect("Failed to load test policy") - } - - fn test_proto() -> ProtoSandboxPolicy { - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "claude_code".to_string(), - NetworkPolicyRule { - name: "claude_code".to_string(), - endpoints: vec![ - NetworkEndpoint { - host: "api.anthropic.com".to_string(), - port: 443, - ..Default::default() - }, - NetworkEndpoint { - host: "statsig.anthropic.com".to_string(), - port: 443, - ..Default::default() - }, - ], - binaries: vec![NetworkBinary { - path: "/usr/local/bin/claude".to_string(), - ..Default::default() - }], - }, - ); - network_policies.insert( - "gitlab".to_string(), - NetworkPolicyRule { - name: "gitlab".to_string(), - endpoints: vec![NetworkEndpoint { - host: "gitlab.com".to_string(), - port: 443, - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/glab".to_string(), - ..Default::default() - }], - }, - ); - ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec!["/usr".to_string(), "/lib".to_string()], - read_write: vec!["/sandbox".to_string(), "/tmp".to_string()], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - } - } - - #[test] - fn allowed_binary_and_endpoint() { - let engine = test_engine(); - // Simulates Claude Code: exe is /usr/bin/node, script is /usr/local/bin/claude - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Expected allow, got deny: {}", - decision.reason - ); - assert_eq!(decision.matched_policy.as_deref(), Some("claude_code")); - } - - #[test] - fn wrong_binary_denied() { - let engine = test_engine(); - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - assert!( - decision.reason.contains("not allowed"), - "Expected specific deny reason, got: {}", - decision.reason - ); - } - - #[test] - fn wrong_endpoint_denied() { - let engine = test_engine(); - let input = NetworkInput { - host: "evil.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - assert!( - decision.reason.contains("endpoint"), - "Expected endpoint deny reason, got: {}", - decision.reason - ); - } - - #[test] - fn unknown_binary_default_deny() { - let engine = test_engine(); - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/tmp/malicious"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - } - - #[test] - fn github_policy_allows_git() { - let engine = test_engine(); - let input = NetworkInput { - host: "github.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/git"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Expected allow, got deny: {}", - decision.reason - ); - assert_eq!( - decision.matched_policy.as_deref(), - Some("github_ssh_over_https") - ); - } - - #[test] - fn case_insensitive_host_matching() { - let engine = test_engine(); - let input = NetworkInput { - host: "API.ANTHROPIC.COM".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Expected case-insensitive match, got deny: {}", - decision.reason - ); - } - - #[test] - fn wrong_port_denied() { - let engine = test_engine(); - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 80, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - } - - #[test] - fn query_sandbox_config_extracts_filesystem() { - let engine = test_engine(); - let config = engine.query_sandbox_config().unwrap(); - assert!(config.filesystem.include_workdir); - assert!(config.filesystem.read_only.contains(&PathBuf::from("/usr"))); - assert!( - config - .filesystem - .read_write - .contains(&PathBuf::from("/tmp")) - ); - } - - #[test] - fn query_sandbox_config_extracts_process() { - let engine = test_engine(); - let config = engine.query_sandbox_config().unwrap(); - assert_eq!(config.process.run_as_user.as_deref(), Some("sandbox")); - assert_eq!(config.process.run_as_group.as_deref(), Some("sandbox")); - } - - #[test] - fn from_strings_and_from_files_produce_same_results() { - let engine = test_engine(); - - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(decision.allowed); - } - - #[test] - fn reload_replaces_policy() { - let engine = test_engine(); - - // Verify initial policy works - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(decision.allowed); - - // Reload with a policy that has no network policies (deny all) - let empty_data = r" -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -network_policies: {} -"; - engine.reload(TEST_POLICY, empty_data).unwrap(); - - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - !decision.allowed, - "Expected deny after reload with empty policies" - ); - } - - #[test] - fn ancestor_binary_allowed() { - // Use github policy: binary /usr/bin/git is the policy binary. - // If the socket process is /usr/bin/python3 but its ancestor is /usr/bin/git, allow. - let engine = test_engine(); - let input = NetworkInput { - host: "github.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![PathBuf::from("/usr/bin/git")], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Expected allow via ancestor match, got deny: {}", - decision.reason - ); - assert_eq!( - decision.matched_policy.as_deref(), - Some("github_ssh_over_https") - ); - } - - #[test] - fn no_ancestor_match_denied() { - let engine = test_engine(); - let input = NetworkInput { - host: "github.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![PathBuf::from("/usr/bin/bash")], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - assert!( - decision.reason.contains("not allowed"), - "Expected 'not allowed' in deny reason, got: {}", - decision.reason - ); - } - - #[test] - fn deep_ancestor_chain_matches() { - let engine = test_engine(); - let input = NetworkInput { - host: "github.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![PathBuf::from("/usr/bin/sh"), PathBuf::from("/usr/bin/git")], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Expected allow via deep ancestor match, got deny: {}", - decision.reason - ); - } - - #[test] - fn empty_ancestors_falls_back_to_direct() { - let engine = test_engine(); - // Direct binary path match still works with empty ancestors and cmdline - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Direct path match should still work with empty ancestors" - ); - } - - #[test] - fn glob_pattern_matches_binary() { - // Test with a policy that uses glob patterns - let glob_data = r#" -network_policies: - glob_test: - name: glob_test - endpoints: - - { host: example.com, port: 443 } - binaries: - - { path: "/usr/bin/*" } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, glob_data).unwrap(); - let input = NetworkInput { - host: "example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Expected glob pattern to match binary, got deny: {}", - decision.reason - ); - } - - #[test] - fn glob_pattern_matches_ancestor() { - let glob_data = r#" -network_policies: - glob_test: - name: glob_test - endpoints: - - { host: example.com, port: 443 } - binaries: - - { path: "/usr/local/bin/*" } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, glob_data).unwrap(); - let input = NetworkInput { - host: "example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![PathBuf::from("/usr/local/bin/claude")], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Expected glob pattern to match ancestor, got deny: {}", - decision.reason - ); - } - - #[test] - fn glob_pattern_no_cross_segment() { - // * should NOT match across / boundaries - let glob_data = r#" -network_policies: - glob_test: - name: glob_test - endpoints: - - { host: example.com, port: 443 } - binaries: - - { path: "/usr/bin/*" } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, glob_data).unwrap(); - let input = NetworkInput { - host: "example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/subdir/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed, "Glob * should not cross / boundaries"); - } - - #[test] - fn cmdline_path_does_not_grant_access() { - // Simulates: node runs /usr/local/bin/my-tool (a script with shebang). - // exe = /usr/bin/node, cmdline contains /usr/local/bin/my-tool. - // cmdline_paths are attacker-controlled (argv[0] spoofing) and must - // NOT be used as a grant-access signal. - let cmdline_data = r" -network_policies: - script_test: - name: script_test - endpoints: - - { host: example.com, port: 443 } - binaries: - - { path: /usr/local/bin/my-tool } -"; - let engine = OpaEngine::from_strings(TEST_POLICY, cmdline_data).unwrap(); - let input = NetworkInput { - host: "example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![PathBuf::from("/usr/bin/bash")], - cmdline_paths: vec![PathBuf::from("/usr/local/bin/my-tool")], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - !decision.allowed, - "cmdline_paths must not grant network access (argv[0] is spoofable)" - ); - } - - #[test] - fn cmdline_path_no_match_denied() { - let cmdline_data = r" -network_policies: - script_test: - name: script_test - endpoints: - - { host: example.com, port: 443 } - binaries: - - { path: /usr/local/bin/my-tool } -"; - let engine = OpaEngine::from_strings(TEST_POLICY, cmdline_data).unwrap(); - let input = NetworkInput { - host: "example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![PathBuf::from("/usr/bin/bash")], - cmdline_paths: vec![ - PathBuf::from("/usr/bin/node"), - PathBuf::from("/tmp/script.js"), - ], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - } - - #[test] - fn cmdline_glob_pattern_does_not_grant_access() { - let glob_data = r#" -network_policies: - glob_test: - name: glob_test - endpoints: - - { host: example.com, port: 443 } - binaries: - - { path: "/usr/local/bin/*" } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, glob_data).unwrap(); - let input = NetworkInput { - host: "example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![PathBuf::from("/usr/local/bin/claude")], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - !decision.allowed, - "cmdline_paths must not match globs for granting access (argv[0] is spoofable)" - ); - } - - #[test] - fn from_proto_allows_matching_request() { - let proto = test_proto(); - let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Expected allow from proto-based engine, got deny: {}", - decision.reason - ); - assert_eq!(decision.matched_policy.as_deref(), Some("claude_code")); - } - - #[test] - fn from_proto_denies_unmatched_request() { - let proto = test_proto(); - let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); - let input = NetworkInput { - host: "evil.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - } - - #[test] - fn from_proto_extracts_sandbox_config() { - let proto = test_proto(); - let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); - let config = engine.query_sandbox_config().unwrap(); - assert!(config.filesystem.include_workdir); - assert!(config.filesystem.read_only.contains(&PathBuf::from("/usr"))); - assert!( - config - .filesystem - .read_write - .contains(&PathBuf::from("/tmp")) - ); - assert_eq!(config.process.run_as_user.as_deref(), Some("sandbox")); - assert_eq!(config.process.run_as_group.as_deref(), Some("sandbox")); - } - - // ======================================================================== - // L7 request evaluation tests - // ======================================================================== - - const L7_TEST_DATA: &str = r#" -network_policies: - rest_api: - name: rest_api - endpoints: - - host: api.example.com - port: 8080 - protocol: rest - enforcement: enforce - rules: - - allow: - method: GET - path: "/repos/**" - - allow: - method: POST - path: "/repos/*/issues" - binaries: - - { path: /usr/bin/curl } - readonly_api: - name: readonly_api - endpoints: - - host: api.readonly.com - port: 8080 - protocol: rest - enforcement: enforce - access: read-only - binaries: - - { path: /usr/bin/curl } - full_api: - name: full_api - endpoints: - - host: api.full.com - port: 8080 - protocol: rest - enforcement: audit - access: full - binaries: - - { path: /usr/bin/curl } - query_api: - name: query_api - endpoints: - - host: api.query.com - port: 8080 - protocol: rest - enforcement: enforce - rules: - - allow: - method: GET - path: "/download" - query: - tag: "foo-*" - - allow: - method: GET - path: "/search" - query: - tag: - any: ["foo-*", "bar-*"] - binaries: - - { path: /usr/bin/curl } - graphql_api: - name: graphql_api - endpoints: - - host: api.graphql.com - port: 443 - protocol: graphql - enforcement: enforce - persisted_queries: allow_registered - graphql_persisted_queries: - abc123: - operation_type: query - operation_name: Viewer - fields: [viewer] - rules: - - allow: - operation_type: query - fields: [viewer, repository] - - allow: - operation_type: mutation - operation_name: Issue* - fields: [createIssue, deleteRepository] - deny_rules: - - operation_type: mutation - fields: [deleteRepository] - binaries: - - { path: /usr/bin/curl } - graphql_readonly: - name: graphql_readonly - endpoints: - - host: gql.readonly.com - port: 443 - protocol: graphql - enforcement: enforce - access: read-only - binaries: - - { path: /usr/bin/curl } - graphql_ws: - name: graphql_ws - endpoints: - - host: realtime.graphql.com - ports: [443] - path: "/graphql" - protocol: websocket - enforcement: enforce - rules: - - allow: - method: GET - path: "/graphql" - - allow: - operation_type: query - fields: [viewer] - - allow: - operation_type: subscription - fields: [messageAdded] - deny_rules: - - operation_type: mutation - binaries: - - { path: /usr/bin/curl } - l4_only: - name: l4_only - endpoints: - - { host: l4only.example.com, port: 443 } - binaries: - - { path: /usr/bin/curl } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - - fn l7_engine() -> OpaEngine { - OpaEngine::from_strings(TEST_POLICY, L7_TEST_DATA).expect("Failed to load L7 test data") - } - - fn l7_input(host: &str, port: u16, method: &str, path: &str) -> serde_json::Value { - l7_input_with_query(host, port, method, path, serde_json::json!({})) - } - - fn l7_input_with_query( - host: &str, - port: u16, - method: &str, - path: &str, - query_params: serde_json::Value, - ) -> serde_json::Value { - serde_json::json!({ - "network": { "host": host, "port": port }, - "exec": { - "path": "/usr/bin/curl", - "ancestors": [], - "cmdline_paths": [] - }, - "request": { - "method": method, - "path": path, - "query_params": query_params - } - }) - } - - fn l7_jsonrpc_input(host: &str, port: u16, path: &str, rpc_method: &str) -> serde_json::Value { - l7_jsonrpc_input_with_params(host, port, path, rpc_method, serde_json::json!({})) - } - - fn l7_jsonrpc_input_with_params( - host: &str, - port: u16, - path: &str, - rpc_method: &str, - params: serde_json::Value, - ) -> serde_json::Value { - serde_json::json!({ - "network": { "host": host, "port": port }, - "exec": { - "path": "/usr/bin/curl", - "ancestors": [], - "cmdline_paths": [] - }, - "request": { - "method": "POST", - "path": path, - "query_params": {}, - "jsonrpc": { - "method": rpc_method, - "params": params - } - } - }) - } - - fn l7_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { - serde_json::json!({ - "network": { "host": host, "port": 443 }, - "exec": { - "path": "/usr/bin/curl", - "ancestors": [], - "cmdline_paths": [] - }, - "request": { - "method": "POST", - "path": "/graphql", - "query_params": {}, - "graphql": { - "operations": operations - } - } - }) - } - - fn l7_graphql_error_input(host: &str, error: &str) -> serde_json::Value { - serde_json::json!({ - "network": { "host": host, "port": 443 }, - "exec": { - "path": "/usr/bin/curl", - "ancestors": [], - "cmdline_paths": [] - }, - "request": { - "method": "POST", - "path": "/graphql", - "query_params": {}, - "graphql": { - "operations": [], - "error": error - } - } - }) - } - - fn l7_websocket_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { - serde_json::json!({ - "network": { "host": host, "port": 443 }, - "exec": { - "path": "/usr/bin/curl", - "ancestors": [], - "cmdline_paths": [] - }, - "request": { - "method": "WEBSOCKET_TEXT", - "path": "/graphql", - "query_params": {}, - "graphql": { - "operations": operations - } - } - }) - } - - fn eval_l7(engine: &OpaEngine, input: &serde_json::Value) -> bool { - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - val == regorus::Value::from(true) - } - - #[test] - fn l7_get_allowed_by_rules() { - let engine = l7_engine(); - let input = l7_input("api.example.com", 8080, "GET", "/repos/myorg/foo"); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_post_allowed_by_rules() { - let engine = l7_engine(); - let input = l7_input("api.example.com", 8080, "POST", "/repos/myorg/issues"); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_delete_denied_by_rules() { - let engine = l7_engine(); - let input = l7_input("api.example.com", 8080, "DELETE", "/repos/myorg/foo"); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_get_wrong_path_denied() { - let engine = l7_engine(); - let input = l7_input("api.example.com", 8080, "GET", "/admin/settings"); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_readonly_preset_allows_get() { - let engine = l7_engine(); - let input = l7_input("api.readonly.com", 8080, "GET", "/anything"); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_readonly_preset_allows_head() { - let engine = l7_engine(); - let input = l7_input("api.readonly.com", 8080, "HEAD", "/anything"); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_readonly_preset_allows_options() { - let engine = l7_engine(); - let input = l7_input("api.readonly.com", 8080, "OPTIONS", "/anything"); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_readonly_preset_denies_post() { - let engine = l7_engine(); - let input = l7_input("api.readonly.com", 8080, "POST", "/anything"); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_readonly_preset_denies_delete() { - let engine = l7_engine(); - let input = l7_input("api.readonly.com", 8080, "DELETE", "/anything"); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_full_preset_allows_everything() { - let engine = l7_engine(); - for method in &["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"] { - let input = l7_input("api.full.com", 8080, method, "/any/path"); - assert!( - eval_l7(&engine, &input), - "{method} should be allowed with full preset" - ); - } - } - - #[test] - fn l7_graphql_query_allowed_by_field_rule() { - let engine = l7_engine(); - let input = l7_graphql_input( - "api.graphql.com", - serde_json::json!([{ - "operation_type": "query", - "operation_name": "RepoLookup", - "fields": ["repository"], - "persisted_query": false - }]), - ); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_graphql_unlisted_field_denied() { - let engine = l7_engine(); - let input = l7_graphql_input( - "api.graphql.com", - serde_json::json!([{ - "operation_type": "query", - "fields": ["viewer", "adminAuditLog"], - "persisted_query": false - }]), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_graphql_batch_denied_if_any_operation_unallowed() { - let engine = l7_engine(); - let input = l7_graphql_input( - "api.graphql.com", - serde_json::json!([ - { - "operation_type": "query", - "fields": ["viewer"], - "persisted_query": false - }, - { - "operation_type": "mutation", - "operation_name": "DeleteRepo", - "fields": ["deleteRepository"], - "persisted_query": false - } - ]), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_graphql_deny_rule_takes_precedence() { - let engine = l7_engine(); - let input = l7_graphql_input( - "api.graphql.com", - serde_json::json!([{ - "operation_type": "mutation", - "operation_name": "IssueDelete", - "fields": ["deleteRepository"], - "persisted_query": false - }]), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_graphql_registered_hash_only_query_allowed() { - let engine = l7_engine(); - let input = l7_graphql_input( - "api.graphql.com", - serde_json::json!([{ - "operation_type": "", - "operation_name": "Viewer", - "fields": [], - "persisted_query": true, - "persisted_query_hash": "abc123" - }]), - ); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_graphql_unregistered_hash_only_query_denied() { - let engine = l7_engine(); - let input = l7_graphql_input( - "api.graphql.com", - serde_json::json!([{ - "operation_type": "", - "operation_name": "Viewer", - "fields": [], - "persisted_query": true, - "persisted_query_hash": "missing" - }]), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_graphql_unregistered_hash_only_query_has_deny_reason() { - let engine = l7_engine(); - let input = l7_graphql_input( - "api.graphql.com", - serde_json::json!([{ - "operation_type": "", - "operation_name": "Viewer", - "fields": [], - "persisted_query": true, - "persisted_query_hash": "missing" - }]), - ); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.request_deny_reason".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::String("GraphQL persisted query is not registered".into()) - ); - } - - #[test] - fn l7_graphql_parse_error_denied() { - let engine = l7_engine(); - let input = l7_graphql_error_input("api.graphql.com", "GraphQL document parse error"); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_graphql_readonly_access_allows_query_and_denies_mutation() { - let engine = l7_engine(); - let query = l7_graphql_input( - "gql.readonly.com", - serde_json::json!([{ - "operation_type": "query", - "fields": ["viewer"], - "persisted_query": false - }]), - ); - assert!(eval_l7(&engine, &query)); - - let mutation = l7_graphql_input( - "gql.readonly.com", - serde_json::json!([{ - "operation_type": "mutation", - "fields": ["createIssue"], - "persisted_query": false - }]), - ); - assert!(!eval_l7(&engine, &mutation)); - } - - #[test] - fn l7_websocket_graphql_subscription_allowed_by_field_rule() { - let engine = l7_engine(); - let input = l7_websocket_graphql_input( - "realtime.graphql.com", - serde_json::json!([{ - "operation_type": "subscription", - "operation_name": "NewMessages", - "fields": ["messageAdded"], - "persisted_query": false - }]), - ); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_websocket_graphql_unlisted_field_denied() { - let engine = l7_engine(); - let input = l7_websocket_graphql_input( - "realtime.graphql.com", - serde_json::json!([{ - "operation_type": "query", - "fields": ["adminAuditLog"], - "persisted_query": false - }]), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_websocket_graphql_deny_rule_takes_precedence() { - let engine = l7_engine(); - let input = l7_websocket_graphql_input( - "realtime.graphql.com", - serde_json::json!([{ - "operation_type": "mutation", - "operation_name": "DeleteRepo", - "fields": ["deleteRepository"], - "persisted_query": false - }]), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_websocket_graphql_not_bypassed_by_generic_text_rule() { - let data = r#" -network_policies: - graphql_ws: - name: graphql_ws - endpoints: - - host: realtime.graphql.com - ports: [443] - path: "/graphql" - protocol: websocket - enforcement: enforce - rules: - - allow: - method: GET - path: "/graphql" - - allow: - method: WEBSOCKET_TEXT - path: "/graphql" - - allow: - operation_type: query - fields: [viewer] - binaries: - - { path: /usr/bin/curl } -"#; - let data_json: serde_json::Value = - serde_yml::from_str(data).expect("fixture should parse as YAML"); - let mut rego = regorus::Engine::new(); - rego.add_policy("policy.rego".into(), TEST_POLICY.into()) - .expect("policy should load"); - rego.add_data_json(&data_json.to_string()) - .expect("data should load"); - let engine = OpaEngine { - engine: Mutex::new(rego), - generation: Arc::new(AtomicU64::new(0)), - }; - let input = l7_websocket_graphql_input( - "realtime.graphql.com", - serde_json::json!([{ - "operation_type": "query", - "fields": ["adminAuditLog"], - "persisted_query": false - }]), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_endpoint_path_scopes_rest_and_graphql_on_same_host() { - let data = r#" -network_policies: - mixed_api: - name: mixed_api - endpoints: - - host: api.github.test - port: 443 - path: "/repos/**" - protocol: rest - enforcement: enforce - rules: - - allow: - method: "*" - path: "/**" - - host: api.github.test - port: 443 - path: "/graphql" - protocol: graphql - enforcement: enforce - rules: - - allow: - operation_type: query - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - - let rest_write = l7_input("api.github.test", 443, "POST", "/repos/org/repo/issues"); - assert!(eval_l7(&engine, &rest_write)); - - let graphql_query = l7_graphql_input( - "api.github.test", - serde_json::json!([{ - "operation_type": "query", - "fields": ["viewer"], - "persisted_query": false - }]), - ); - assert!(eval_l7(&engine, &graphql_query)); - - let graphql_mutation = l7_graphql_input( - "api.github.test", - serde_json::json!([{ - "operation_type": "mutation", - "fields": ["deleteRepository"], - "persisted_query": false - }]), - ); - assert!( - !eval_l7(&engine, &graphql_mutation), - "REST rules on the same host must not allow a GraphQL mutation" - ); - } - - #[test] - fn l7_method_matching_case_insensitive() { - let engine = l7_engine(); - let input = l7_input("api.example.com", 8080, "get", "/repos/myorg/foo"); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_path_glob_matching() { - let engine = l7_engine(); - // /repos/** should match /repos/org/repo - let input = l7_input("api.example.com", 8080, "GET", "/repos/org/repo"); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_query_glob_allows_matching_duplicate_values() { - let engine = l7_engine(); - let input = l7_input_with_query( - "api.query.com", - 8080, - "GET", - "/download", - serde_json::json!({ - "tag": ["foo-a", "foo-b"], - "extra": ["ignored"], - }), - ); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_query_glob_denies_on_mismatched_duplicate_value() { - let engine = l7_engine(); - let input = l7_input_with_query( - "api.query.com", - 8080, - "GET", - "/download", - serde_json::json!({ - "tag": ["foo-a", "evil"], - }), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_query_any_allows_if_every_value_matches_any_pattern() { - let engine = l7_engine(); - let input = l7_input_with_query( - "api.query.com", - 8080, - "GET", - "/search", - serde_json::json!({ - "tag": ["foo-a", "bar-b"], - }), - ); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_query_missing_required_key_denied() { - let engine = l7_engine(); - let input = l7_input_with_query( - "api.query.com", - 8080, - "GET", - "/download", - serde_json::json!({}), - ); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_query_rules_from_proto_are_enforced() { - let mut query = std::collections::HashMap::new(); - query.insert( - "tag".to_string(), - L7QueryMatcher { - glob: "foo-*".to_string(), - any: vec![], - }, - ); - - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "query_proto".to_string(), - NetworkPolicyRule { - name: "query_proto".to_string(), - endpoints: vec![NetworkEndpoint { - host: "api.proto.com".to_string(), - port: 8080, - protocol: "rest".to_string(), - enforcement: "enforce".to_string(), - rules: vec![L7Rule { - allow: Some(L7Allow { - method: "GET".to_string(), - path: "/download".to_string(), - command: String::new(), - query, - operation_type: String::new(), - operation_name: String::new(), - fields: Vec::new(), - rpc_method: String::new(), - params: std::collections::HashMap::new(), - }), - }], - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/curl".to_string(), - ..Default::default() - }], - }, - ); - - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let allow_input = l7_input_with_query( - "api.proto.com", - 8080, - "GET", - "/download", - serde_json::json!({ "tag": ["foo-a"] }), - ); - assert!(eval_l7(&engine, &allow_input)); - - let deny_input = l7_input_with_query( - "api.proto.com", - 8080, - "GET", - "/download", - serde_json::json!({ "tag": ["evil"] }), - ); - assert!(!eval_l7(&engine, &deny_input)); - } - - #[test] - fn l7_jsonrpc_rpc_method_from_proto_is_enforced() { - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "jsonrpc_proto".to_string(), - NetworkPolicyRule { - name: "jsonrpc_proto".to_string(), - endpoints: vec![NetworkEndpoint { - host: "mcp.proto.com".to_string(), - port: 8000, - path: "/mcp".to_string(), - protocol: "json-rpc".to_string(), - enforcement: "enforce".to_string(), - rules: vec![L7Rule { - allow: Some(L7Allow { - method: String::new(), - path: String::new(), - command: String::new(), - query: std::collections::HashMap::new(), - operation_type: String::new(), - operation_name: String::new(), - fields: Vec::new(), - rpc_method: "initialize".to_string(), - params: std::collections::HashMap::new(), - }), - }], - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/curl".to_string(), - ..Default::default() - }], - }, - ); - - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let allow_input = l7_jsonrpc_input("mcp.proto.com", 8000, "/mcp", "initialize"); - assert!(eval_l7(&engine, &allow_input)); - - let deny_input = l7_jsonrpc_input("mcp.proto.com", 8000, "/mcp", "tools/list"); - assert!(!eval_l7(&engine, &deny_input)); - } - - #[test] - fn l7_jsonrpc_params_rules_filter_tools_call() { - let data = r#" -network_policies: - jsonrpc_params: - name: jsonrpc_params - endpoints: - - host: mcp.params.test - port: 8000 - path: /mcp - protocol: json-rpc - enforcement: enforce - rules: - - allow: - rpc_method: tools/call - params: - name: read_status - - allow: - rpc_method: tools/call - params: - name: submit_report - arguments.scope: workspace/main - deny_rules: - - rpc_method: tools/call - params: - name: blocked_action - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).expect("engine from yaml"); - - let read_status = l7_jsonrpc_input_with_params( - "mcp.params.test", - 8000, - "/mcp", - "tools/call", - serde_json::json!({"name": "read_status"}), - ); - assert!(eval_l7(&engine, &read_status)); - - let submit_report = l7_jsonrpc_input_with_params( - "mcp.params.test", - 8000, - "/mcp", - "tools/call", - serde_json::json!({ - "name": "submit_report", - "arguments.scope": "workspace/main" - }), - ); - assert!(eval_l7(&engine, &submit_report)); - - let blocked_without_args = l7_jsonrpc_input_with_params( - "mcp.params.test", - 8000, - "/mcp", - "tools/call", - serde_json::json!({"name": "blocked_action"}), - ); - assert!(!eval_l7(&engine, &blocked_without_args)); - - let blocked_with_args = l7_jsonrpc_input_with_params( - "mcp.params.test", - 8000, - "/mcp", - "tools/call", - serde_json::json!({ - "name": "blocked_action", - "arguments.reason": "test" - }), - ); - assert!(!eval_l7(&engine, &blocked_with_args)); - } - - #[test] - fn l7_no_request_on_l4_only_endpoint() { - // L4-only endpoint should not match L7 allow_request - let engine = l7_engine(); - let input = l7_input("l4only.example.com", 443, "GET", "/anything"); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_wrong_binary_denied_even_with_matching_rules() { - let engine = l7_engine(); - let input = serde_json::json!({ - "network": { "host": "api.example.com", "port": 8080 }, - "exec": { - "path": "/usr/bin/python3", - "ancestors": [], - "cmdline_paths": [] - }, - "request": { - "method": "GET", - "path": "/repos/myorg/foo" - } - }); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_deny_reason_populated() { - let engine = l7_engine(); - let input = l7_input("api.example.com", 8080, "DELETE", "/repos/myorg/foo"); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.request_deny_reason".into()) - .unwrap(); - let reason = match val { - regorus::Value::String(s) => s.to_string(), - _ => String::new(), - }; - assert!( - reason.contains("not permitted"), - "Expected deny reason, got: {reason}" - ); - } - - #[test] - fn l7_endpoint_config_returned_for_l7_endpoint() { - let engine = l7_engine(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let config = engine.query_endpoint_config(&input).unwrap(); - assert!(config.is_some(), "Expected L7 config for rest endpoint"); - let config = config.unwrap(); - let l7 = crate::l7::parse_l7_config(&config).unwrap(); - assert_eq!(l7.protocol, crate::l7::L7Protocol::Rest); - assert_eq!(l7.enforcement, crate::l7::EnforcementMode::Enforce); - } - - #[test] - fn l7_endpoint_config_preserves_proto_allow_encoded_slash() { - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "npm".to_string(), - NetworkPolicyRule { - name: "npm".to_string(), - endpoints: vec![NetworkEndpoint { - host: "registry.npmjs.org".to_string(), - port: 443, - protocol: "rest".to_string(), - enforcement: "enforce".to_string(), - access: "read-only".to_string(), - allow_encoded_slash: true, - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/node".to_string(), - ..Default::default() - }], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let input = NetworkInput { - host: "registry.npmjs.org".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - let config = engine - .query_endpoint_config(&input) - .unwrap() - .expect("endpoint config"); - let l7 = crate::l7::parse_l7_config(&config).unwrap(); - assert!(l7.allow_encoded_slash); - } - - #[test] - fn l7_endpoint_config_preserves_proto_websocket_credential_rewrite() { - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "gateway".to_string(), - NetworkPolicyRule { - name: "gateway".to_string(), - endpoints: vec![NetworkEndpoint { - host: "gateway.example.com".to_string(), - port: 443, - protocol: "rest".to_string(), - enforcement: "enforce".to_string(), - access: "full".to_string(), - websocket_credential_rewrite: true, - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/node".to_string(), - ..Default::default() - }], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let input = NetworkInput { - host: "gateway.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - let config = engine - .query_endpoint_config(&input) - .unwrap() - .expect("endpoint config"); - let l7 = crate::l7::parse_l7_config(&config).unwrap(); - assert!(l7.websocket_credential_rewrite); - } - - #[test] - fn l7_endpoint_config_preserves_proto_request_body_credential_rewrite() { - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "slack".to_string(), - NetworkPolicyRule { - name: "slack".to_string(), - endpoints: vec![NetworkEndpoint { - host: "slack.com".to_string(), - port: 443, - protocol: "rest".to_string(), - enforcement: "enforce".to_string(), - access: "read-write".to_string(), - request_body_credential_rewrite: true, - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/node".to_string(), - ..Default::default() - }], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let input = NetworkInput { - host: "slack.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/node"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - let config = engine - .query_endpoint_config(&input) - .unwrap() - .expect("endpoint config"); - let l7 = crate::l7::parse_l7_config(&config).unwrap(); - assert!(l7.request_body_credential_rewrite); - } - - #[test] - fn l7_endpoint_config_none_for_l4_only() { - let engine = l7_engine(); - let input = NetworkInput { - host: "l4only.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let config = engine.query_endpoint_config(&input).unwrap(); - assert!( - config.is_none(), - "Expected no L7 config for L4-only endpoint" - ); - } - - #[test] - fn l7_clone_engine_for_tunnel() { - let engine = l7_engine(); - let cloned = engine - .clone_engine_for_tunnel(engine.current_generation()) - .unwrap(); - // Verify the cloned engine can evaluate - let input_json = l7_input("api.example.com", 8080, "GET", "/repos/myorg/foo"); - let mut eng = cloned.engine().lock().unwrap(); - eng.set_input_json(&input_json.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!(val, regorus::Value::from(true)); - } - - #[test] - fn policy_generation_starts_at_zero_and_increments_on_successful_reload() { - let engine = l7_engine(); - assert_eq!(engine.current_generation(), 0); - - engine.reload(TEST_POLICY, L7_TEST_DATA).unwrap(); - - assert_eq!(engine.current_generation(), 1); - } - - #[test] - fn policy_generation_does_not_increment_on_failed_reload() { - let engine = l7_engine(); - engine.reload(TEST_POLICY, L7_TEST_DATA).unwrap(); - assert_eq!(engine.current_generation(), 1); - - let invalid_l7_data = r#" -network_policies: - bad_api: - name: bad_api - endpoints: - - host: api.example.com - port: 8080 - protocol: invalid-protocol - binaries: - - { path: /usr/bin/curl } -"#; - assert!(engine.reload(TEST_POLICY, invalid_l7_data).is_err()); - assert_eq!(engine.current_generation(), 1); - - let input_json = l7_input("api.example.com", 8080, "GET", "/repos/myorg/foo"); - let cloned = engine - .clone_engine_for_tunnel(engine.current_generation()) - .unwrap(); - let mut eng = cloned.engine().lock().unwrap(); - eng.set_input_json(&input_json.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!(val, regorus::Value::from(true)); - } - - #[test] - fn endpoint_config_generation_matches_query_generation() { - let engine = l7_engine(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - let (config, generation) = engine - .query_endpoint_config_with_generation(&input) - .unwrap(); - assert!(config.is_some()); - assert_eq!(generation, engine.current_generation()); - - engine.reload(TEST_POLICY, L7_TEST_DATA).unwrap(); - - let (config, generation) = engine - .query_endpoint_config_with_generation(&input) - .unwrap(); - assert!(config.is_some()); - assert_eq!(generation, engine.current_generation()); - assert_eq!(generation, 1); - } - - #[test] - fn tunnel_clone_rejects_stale_generation() { - let engine = l7_engine(); - let captured_generation = engine.current_generation(); - engine.reload(TEST_POLICY, L7_TEST_DATA).unwrap(); - - assert!(engine.clone_engine_for_tunnel(captured_generation).is_err()); - } - - // ======================================================================== - // Deny rules tests - // ======================================================================== - - const L7_DENY_TEST_DATA: &str = r#" -network_policies: - github_api: - name: github_api - endpoints: - - host: api.github.com - port: 443 - protocol: rest - enforcement: enforce - access: read-write - deny_rules: - - method: POST - path: "/repos/*/pulls/*/reviews" - - method: PUT - path: "/repos/*/branches/*/protection" - - method: "*" - path: "/repos/*/rulesets" - binaries: - - { path: /usr/bin/curl } - deny_with_query: - name: deny_with_query - endpoints: - - host: api.restricted.com - port: 443 - protocol: rest - enforcement: enforce - access: full - deny_rules: - - method: POST - path: "/admin/**" - query: - force: "true" - binaries: - - { path: /usr/bin/curl } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - - fn l7_deny_engine() -> OpaEngine { - OpaEngine::from_strings(TEST_POLICY, L7_DENY_TEST_DATA) - .expect("Failed to load deny test data") - } - - #[test] - fn l7_deny_rule_blocks_allowed_method_path() { - let engine = l7_deny_engine(); - // POST to reviews is allowed by read-write preset but denied by deny rule - let input = l7_input( - "api.github.com", - 443, - "POST", - "/repos/myorg/pulls/123/reviews", - ); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(false), - "deny rule should block POST to reviews" - ); - } - - #[test] - fn l7_deny_rule_allows_non_matching_requests() { - let engine = l7_deny_engine(); - // GET repos/issues is allowed and not denied - let input = l7_input("api.github.com", 443, "GET", "/repos/myorg/issues"); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(true), - "non-denied GET should be allowed" - ); - } - - #[test] - fn l7_deny_rule_allows_same_method_different_path() { - let engine = l7_deny_engine(); - // POST to issues is allowed (deny only targets reviews) - let input = l7_input("api.github.com", 443, "POST", "/repos/myorg/issues"); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(true), - "POST to issues should be allowed" - ); - } - - #[test] - fn l7_deny_rule_blocks_wildcard_method() { - let engine = l7_deny_engine(); - // GET /repos/myorg/rulesets should be denied (method: "*") - let input = l7_input("api.github.com", 443, "GET", "/repos/myorg/rulesets"); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(false), - "wildcard method deny should block GET" - ); - } - - #[test] - fn l7_deny_rule_blocks_put_protection() { - let engine = l7_deny_engine(); - let input = l7_input( - "api.github.com", - 443, - "PUT", - "/repos/myorg/branches/main/protection", - ); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(false), - "PUT to branch protection should be denied" - ); - } - - #[test] - fn l7_deny_reason_populated_when_deny_rule_matches() { - let engine = l7_deny_engine(); - let input = l7_input( - "api.github.com", - 443, - "POST", - "/repos/myorg/pulls/123/reviews", - ); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.request_deny_reason".into()) - .unwrap(); - let reason = match val { - regorus::Value::String(s) => s.to_string(), - _ => String::new(), - }; - assert!( - reason.contains("deny rule"), - "Expected deny rule reason, got: {reason}" - ); - } - - #[test] - fn l7_deny_rule_with_query_blocks_matching_params() { - let engine = l7_deny_engine(); - // POST /admin/settings with force=true should be denied - let input = l7_input_with_query( - "api.restricted.com", - 443, - "POST", - "/admin/settings", - serde_json::json!({"force": ["true"]}), - ); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(false), - "deny with matching query should block" - ); - } - - #[test] - fn l7_deny_rule_with_query_allows_non_matching_params() { - let engine = l7_deny_engine(); - // POST /admin/settings with force=false should be allowed (query doesn't match deny) - let input = l7_input_with_query( - "api.restricted.com", - 443, - "POST", - "/admin/settings", - serde_json::json!({"force": ["false"]}), - ); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(true), - "deny with non-matching query should allow" - ); - } - - #[test] - fn l7_deny_rule_with_query_blocks_when_any_value_matches() { - let engine = l7_deny_engine(); - // POST /admin/settings with force=true&force=false should STILL be denied - // because at least one value ("true") matches the deny rule. - // This is fail-closed: any matching value triggers the deny. - let input = l7_input_with_query( - "api.restricted.com", - 443, - "POST", - "/admin/settings", - serde_json::json!({"force": ["true", "false"]}), - ); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(false), - "deny should fire when ANY value matches, even with mixed values" - ); - } - - #[test] - fn l7_deny_rule_without_matching_query_key_allows() { - let engine = l7_deny_engine(); - // POST /admin/settings with no query params -- deny rule has query.force=true, - // so no match (key not present) and request should be allowed - let input = l7_input("api.restricted.com", 443, "POST", "/admin/settings"); - let mut eng = engine.engine.lock().unwrap(); - eng.set_input_json(&input.to_string()).unwrap(); - let val = eng - .eval_rule("data.openshell.sandbox.allow_request".into()) - .unwrap(); - assert_eq!( - val, - regorus::Value::from(true), - "deny without matching query key should allow" - ); - } - - // ======================================================================== - // Overlapping policies (duplicate host:port) — regression tests - // ======================================================================== - - /// Two network_policies entries covering the same host:port with L7 rules. - /// Before the fix, this caused regorus to fail with - /// "duplicated definition of local variable ep" in allow_request. - const OVERLAPPING_L7_TEST_DATA: &str = r#" -network_policies: - test_server: - name: test_server - endpoints: - - host: 192.168.1.100 - port: 8567 - protocol: rest - enforcement: enforce - rules: - - allow: - method: GET - path: "**" - binaries: - - { path: /usr/bin/curl } - allow_192_168_1_100_8567: - name: allow_192_168_1_100_8567 - endpoints: - - host: 192.168.1.100 - port: 8567 - protocol: rest - enforcement: enforce - allowed_ips: - - 192.168.1.100 - rules: - - allow: - method: GET - path: "**" - binaries: - - { path: /usr/bin/curl } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - - #[test] - fn l7_overlapping_policies_allow_request_does_not_crash() { - let engine = OpaEngine::from_strings(TEST_POLICY, OVERLAPPING_L7_TEST_DATA) - .expect("engine should load overlapping data"); - let input = l7_input("192.168.1.100", 8567, "GET", "/test"); - // Should not panic or error — must evaluate to true. - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_overlapping_policies_deny_request_does_not_crash() { - let engine = OpaEngine::from_strings(TEST_POLICY, OVERLAPPING_L7_TEST_DATA) - .expect("engine should load overlapping data"); - let input = l7_input("192.168.1.100", 8567, "DELETE", "/test"); - // DELETE is not in the rules, so should deny — but must not crash. - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn overlapping_policies_endpoint_config_returns_result() { - let engine = OpaEngine::from_strings(TEST_POLICY, OVERLAPPING_L7_TEST_DATA) - .expect("engine should load overlapping data"); - let input = NetworkInput { - host: "192.168.1.100".into(), - port: 8567, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: String::new(), - ancestors: vec![], - cmdline_paths: vec![], - }; - // Should return config from one of the entries without error. - let config = engine.query_endpoint_config(&input).unwrap(); - assert!( - config.is_some(), - "Expected endpoint config for overlapping policies" - ); - } - - // ======================================================================== - // network_action tests - // ======================================================================== - - const INFERENCE_TEST_DATA: &str = r#" -network_policies: - claude_code: - name: claude_code - endpoints: - - { host: api.anthropic.com, port: 443 } - binaries: - - { path: /usr/local/bin/claude } - gitlab: - name: gitlab - endpoints: - - { host: gitlab.com, port: 443 } - binaries: - - { path: /usr/bin/glab } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - - const NO_INFERENCE_TEST_DATA: &str = r#" -network_policies: - gitlab: - name: gitlab - endpoints: - - { host: gitlab.com, port: 443 } - binaries: - - { path: /usr/bin/glab } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - - fn inference_engine() -> OpaEngine { - OpaEngine::from_strings(TEST_POLICY, INFERENCE_TEST_DATA) - .expect("Failed to load inference test data") - } - - fn no_inference_engine() -> OpaEngine { - OpaEngine::from_strings(TEST_POLICY, NO_INFERENCE_TEST_DATA) - .expect("Failed to load no-inference test data") - } - - #[test] - fn explicitly_allowed_endpoint_binary_returns_allow() { - let engine = inference_engine(); - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - assert_eq!( - action, - NetworkAction::Allow { - matched_policy: Some("claude_code".to_string()) - }, - ); - } - - #[test] - fn unknown_endpoint_returns_deny() { - let engine = inference_engine(); - let input = NetworkInput { - host: "api.openai.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - match &action { - NetworkAction::Deny { .. } => {} - other => panic!("Expected Deny, got: {other:?}"), - } - } - - #[test] - fn unknown_endpoint_without_inference_returns_deny() { - let engine = no_inference_engine(); - let input = NetworkInput { - host: "api.openai.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - match &action { - NetworkAction::Deny { .. } => {} - other => panic!("Expected Deny, got: {other:?}"), - } - } - - #[test] - fn endpoint_in_policy_binary_not_allowed_returns_deny() { - // api.anthropic.com is declared but python3 is not in the binary list. - // With binary allow/deny, this is denied. - let engine = inference_engine(); - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - match &action { - NetworkAction::Deny { .. } => {} - other => panic!("Expected Deny, got: {other:?}"), - } - } - - #[test] - fn endpoint_in_policy_binary_not_allowed_without_inference_returns_deny() { - let engine = no_inference_engine(); - let input = NetworkInput { - host: "gitlab.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - match &action { - NetworkAction::Deny { .. } => {} - other => panic!("Expected Deny, got: {other:?}"), - } - } - - #[test] - fn from_proto_explicitly_allowed_returns_allow() { - let proto = test_proto(); - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - assert_eq!( - action, - NetworkAction::Allow { - matched_policy: Some("claude_code".to_string()) - }, - ); - } - - #[test] - fn from_proto_unknown_endpoint_returns_deny() { - let proto = test_proto(); - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let input = NetworkInput { - host: "api.openai.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - match &action { - NetworkAction::Deny { .. } => {} - other => panic!("Expected Deny, got: {other:?}"), - } - } - - #[test] - fn network_action_with_dev_policy() { - let engine = test_engine(); - // claude direct to api.anthropic.com → allow (explicit match) - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - assert_eq!( - action, - NetworkAction::Allow { - matched_policy: Some("claude_code".to_string()) - }, - ); - - // git to github.com → allow - let input = NetworkInput { - host: "github.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/git"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let action = engine.evaluate_network_action(&input).unwrap(); - assert_eq!( - action, - NetworkAction::Allow { - matched_policy: Some("github_ssh_over_https".to_string()) - }, - ); - } - - // ======================================================================== - // allowed_ips tests - // ======================================================================== - - const ALLOWED_IPS_TEST_DATA: &str = r#" -network_policies: - # Mode 2: host + allowed_ips - internal_api: - name: internal_api - endpoints: - - host: my-service.corp.net - port: 8080 - allowed_ips: ["10.0.5.0/24"] - binaries: - - { path: /usr/bin/curl } - # Mode 3: allowed_ips only (no host) — uses port 9443 to avoid overlap - private_network: - name: private_network - endpoints: - - port: 9443 - allowed_ips: ["172.16.0.0/12", "192.168.1.1"] - binaries: - - { path: /usr/bin/curl } - # Mode 1: host only (no allowed_ips) — standard behavior - public_api: - name: public_api - endpoints: - - { host: api.github.com, port: 443 } - binaries: - - { path: /usr/bin/curl } - # Wildcard host endpoint should not count as an exact declared hostname. - wildcard_api: - name: wildcard_api - endpoints: - - { host: "*.corp.net", port: 443 } - binaries: - - { path: /usr/bin/curl } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - - fn allowed_ips_engine() -> OpaEngine { - OpaEngine::from_strings(TEST_POLICY, ALLOWED_IPS_TEST_DATA) - .expect("Failed to load allowed_ips test data") - } - - #[test] - fn allowed_ips_mode2_host_plus_ips_allows() { - let engine = allowed_ips_engine(); - let input = NetworkInput { - host: "my-service.corp.net".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Mode 2 (host+IPs) should allow: {}", - decision.reason - ); - assert_eq!(decision.matched_policy.as_deref(), Some("internal_api")); - } - - #[test] - fn allowed_ips_mode2_returns_allowed_ips() { - let engine = allowed_ips_engine(); - let input = NetworkInput { - host: "my-service.corp.net".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let ips = engine.query_allowed_ips(&input).unwrap(); - assert_eq!(ips, vec!["10.0.5.0/24"]); - } - - #[test] - fn allowed_ips_mode3_hostless_allows_any_domain() { - let engine = allowed_ips_engine(); - // Any hostname on port 9443 should match the private_network policy - let input = NetworkInput { - host: "anything.example.com".into(), - port: 9443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Mode 3 (IPs only) should allow any domain on matching port: {}", - decision.reason - ); - } - - #[test] - fn allowed_ips_mode3_returns_allowed_ips() { - let engine = allowed_ips_engine(); - let input = NetworkInput { - host: "anything.example.com".into(), - port: 9443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let ips = engine.query_allowed_ips(&input).unwrap(); - assert_eq!(ips, vec!["172.16.0.0/12", "192.168.1.1"]); - } - - #[test] - fn allowed_ips_mode1_no_ips_returns_empty() { - let engine = allowed_ips_engine(); - let input = NetworkInput { - host: "api.github.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let ips = engine.query_allowed_ips(&input).unwrap(); - assert!(ips.is_empty(), "Mode 1 should return no allowed_ips"); - } - - #[test] - fn exact_declared_endpoint_host_true_for_l4_host_only() { - let engine = allowed_ips_engine(); - let input = NetworkInput { - host: "api.github.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - assert!(engine.query_endpoint_config(&input).unwrap().is_none()); - assert!(engine.query_exact_declared_endpoint_host(&input).unwrap()); - } - - #[test] - fn exact_declared_endpoint_host_true_for_host_with_allowed_ips() { - let engine = allowed_ips_engine(); - let input = NetworkInput { - host: "my-service.corp.net".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - assert!(engine.query_exact_declared_endpoint_host(&input).unwrap()); - } - - #[test] - fn exact_declared_endpoint_host_false_for_hostless_allowed_ips() { - let engine = allowed_ips_engine(); - let input = NetworkInput { - host: "anything.example.com".into(), - port: 9443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - assert!(!engine.query_exact_declared_endpoint_host(&input).unwrap()); - } - - #[test] - fn exact_declared_endpoint_host_false_for_wildcard_host() { - let engine = allowed_ips_engine(); - let input = NetworkInput { - host: "api.corp.net".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - let decision = engine.evaluate_network(&input).unwrap(); - assert!(decision.allowed, "wildcard endpoint should still allow"); - assert!(!engine.query_exact_declared_endpoint_host(&input).unwrap()); - } - - #[test] - fn exact_declared_endpoint_host_false_for_advisor_proposed_binary() { - let mut network_policies = std::collections::HashMap::new(); - let mut proposal_binary = NetworkBinary { - path: "/usr/bin/curl".to_string(), - ..Default::default() - }; - #[allow(deprecated)] - { - proposal_binary.harness = true; - } - network_policies.insert( - "allow_mcp_internal_corp_example_com_8443".to_string(), - NetworkPolicyRule { - name: "allow_mcp_internal_corp_example_com_8443".to_string(), - endpoints: vec![NetworkEndpoint { - host: "mcp-internal.corp.example.com".to_string(), - port: 8443, - ..Default::default() - }], - binaries: vec![proposal_binary], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let input = NetworkInput { - host: "mcp-internal.corp.example.com".into(), - port: 8443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "advisor proposal should still allow at OPA L4" - ); - assert!(!engine.query_exact_declared_endpoint_host(&input).unwrap()); - } - - #[test] - fn exact_declared_endpoint_host_false_for_advisor_proposed_endpoint() { - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "app-api".to_string(), - NetworkPolicyRule { - name: "app-api".to_string(), - endpoints: vec![NetworkEndpoint { - host: "internal-admin.local".to_string(), - port: 443, - ports: vec![443], - advisor_proposed: true, - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/python".to_string(), - ..Default::default() - }], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); - let input = NetworkInput { - host: "internal-admin.local".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - let decision = engine.evaluate_network(&input).unwrap(); - assert!(decision.allowed, "policy should still allow at OPA L4"); - assert!( - !engine.query_exact_declared_endpoint_host(&input).unwrap(), - "advisor endpoint provenance should block exact-host SSRF trust" - ); - } - - #[test] - fn allowed_ips_mode3_wrong_port_denied() { - let engine = allowed_ips_engine(); - // Port 12345 doesn't match any policy - let input = NetworkInput { - host: "anything.example.com".into(), - port: 12345, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed, "Mode 3: wrong port should deny"); - } - - #[test] - fn allowed_ips_proto_round_trip() { - // Test that allowed_ips survives proto → OPA data → query - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "internal".to_string(), - NetworkPolicyRule { - name: "internal".to_string(), - endpoints: vec![NetworkEndpoint { - host: "internal.corp.net".to_string(), - port: 8080, - allowed_ips: vec!["10.0.5.0/24".to_string(), "10.0.6.0/24".to_string()], - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/curl".to_string(), - ..Default::default() - }], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); - - let input = NetworkInput { - host: "internal.corp.net".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let ips = engine.query_allowed_ips(&input).unwrap(); - assert_eq!(ips, vec!["10.0.5.0/24", "10.0.6.0/24"]); - } - - // ======================================================================== - // Multi-port endpoint tests - // ======================================================================== - - #[test] - fn multi_port_endpoint_matches_first_port() { - let data = r#" -network_policies: - multi: - name: multi - endpoints: - - { host: api.example.com, ports: [443, 8443] } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "First port in multi-port should match: {}", - decision.reason - ); - } - - #[test] - fn multi_port_endpoint_matches_second_port() { - let data = r#" -network_policies: - multi: - name: multi - endpoints: - - { host: api.example.com, ports: [443, 8443] } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 8443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Second port in multi-port should match: {}", - decision.reason - ); - } - - #[test] - fn multi_port_endpoint_rejects_unlisted_port() { - let data = r#" -network_policies: - multi: - name: multi - endpoints: - - { host: api.example.com, ports: [443, 8443] } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 80, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed, "Unlisted port should be denied"); - } - - #[test] - fn single_port_backwards_compat() { - // Old-style YAML with just `port: 443` should still work - let data = r#" -network_policies: - compat: - name: compat - endpoints: - - { host: api.example.com, port: 443 } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Single port backwards compat: {}", - decision.reason - ); - - // Wrong port should still deny - let input_bad = NetworkInput { - host: "api.example.com".into(), - port: 80, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input_bad).unwrap(); - assert!(!decision.allowed); - } - - #[test] - fn hostless_endpoint_multi_port() { - let data = r#" -network_policies: - private: - name: private - endpoints: - - ports: [80, 443] - allowed_ips: ["10.0.0.0/8"] - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - // Port 80 - let input80 = NetworkInput { - host: "anything.internal".into(), - port: 80, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input80).unwrap(); - assert!( - decision.allowed, - "Hostless multi-port should match port 80: {}", - decision.reason - ); - // Port 443 - let input443 = NetworkInput { - host: "anything.internal".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input443).unwrap(); - assert!( - decision.allowed, - "Hostless multi-port should match port 443: {}", - decision.reason - ); - // Port 8080 should deny - let input_bad = NetworkInput { - host: "anything.internal".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input_bad).unwrap(); - assert!(!decision.allowed); - } - - #[test] - fn from_proto_multi_port_allows_matching() { - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "multi".to_string(), - NetworkPolicyRule { - name: "multi".to_string(), - endpoints: vec![NetworkEndpoint { - host: "api.example.com".to_string(), - port: 443, - ports: vec![443, 8443], - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/curl".to_string(), - ..Default::default() - }], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - let engine = OpaEngine::from_proto(&proto).unwrap(); - // Port 443 - let input443 = NetworkInput { - host: "api.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - assert!(engine.evaluate_network(&input443).unwrap().allowed); - // Port 8443 - let input8443 = NetworkInput { - host: "api.example.com".into(), - port: 8443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - assert!(engine.evaluate_network(&input8443).unwrap().allowed); - // Port 80 denied - let input80 = NetworkInput { - host: "api.example.com".into(), - port: 80, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - assert!(!engine.evaluate_network(&input80).unwrap().allowed); - } - - // ======================================================================== - // Host wildcard tests - // ======================================================================== - - #[test] - fn wildcard_host_matches_subdomain() { - let data = r#" -network_policies: - wildcard: - name: wildcard - endpoints: - - { host: "*.example.com", port: 443 } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "*.example.com should match api.example.com: {}", - decision.reason - ); - } - - #[test] - fn wildcard_host_rejects_deep_subdomain() { - // * should match single DNS label only (does not cross .) - let data = r#" -network_policies: - wildcard: - name: wildcard - endpoints: - - { host: "*.example.com", port: 443 } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "deep.sub.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - !decision.allowed, - "*.example.com should NOT match deep.sub.example.com" - ); - } - - #[test] - fn wildcard_host_rejects_exact_domain() { - let data = r#" -network_policies: - wildcard: - name: wildcard - endpoints: - - { host: "*.example.com", port: 443 } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - !decision.allowed, - "*.example.com should NOT match example.com (requires at least one label)" - ); - } - - #[test] - fn wildcard_host_case_insensitive() { - let data = r#" -network_policies: - wildcard: - name: wildcard - endpoints: - - { host: "*.EXAMPLE.COM", port: 443 } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Host wildcards should be case-insensitive: {}", - decision.reason - ); - } - - #[test] - fn wildcard_host_plus_port() { - let data = r#" -network_policies: - wildcard: - name: wildcard - endpoints: - - { host: "*.example.com", port: 443 } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - // Right host, wrong port - let input = NetworkInput { - host: "api.example.com".into(), - port: 80, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed, "Wildcard host on wrong port should deny"); - } - - #[test] - fn wildcard_host_intra_label_matches() { - // First-label intra-label wildcard: `*` matches the variable prefix - // within a single DNS label. Locks validator/runtime alignment for - // the pattern accepted by `validate_host_wildcard`. - let data = r#" -network_policies: - intra_label: - name: intra_label - endpoints: - - { host: "*-aiplatform.googleapis.com", port: 443 } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "us-central1-aiplatform.googleapis.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "*-aiplatform.googleapis.com should match us-central1-aiplatform.googleapis.com: {}", - decision.reason - ); - } - - #[test] - fn wildcard_host_intra_label_does_not_cross_dot() { - // `glob.match(..., ["."])` treats `.` as a label boundary that `*` - // cannot cross. `*-aiplatform.googleapis.com` must not match a host - // whose first label is `us-central1` and where `aiplatform` is a - // separate label. - let data = r#" -network_policies: - intra_label: - name: intra_label - endpoints: - - { host: "*-aiplatform.googleapis.com", port: 443 } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "us-central1.aiplatform.googleapis.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - !decision.allowed, - "*-aiplatform.googleapis.com must NOT match us-central1.aiplatform.googleapis.com \ - (would cross a `.` boundary)" - ); - } - - #[test] - fn wildcard_host_multi_port() { - let data = r#" -network_policies: - wildcard: - name: wildcard - endpoints: - - { host: "*.example.com", ports: [443, 8443] } - binaries: - - { path: /usr/bin/curl } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 8443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Wildcard host + multi-port should match: {}", - decision.reason - ); - } - - #[test] - fn wildcard_host_l7_rules_apply() { - let data = r#" -network_policies: - wildcard_l7: - name: wildcard_l7 - endpoints: - - host: "*.example.com" - port: 8080 - protocol: rest - enforcement: enforce - tls: terminate - rules: - - allow: - method: GET - path: "/api/**" - binaries: - - { path: /usr/bin/curl } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - // L7 GET to /api/foo — should be allowed - let input = l7_input("api.example.com", 8080, "GET", "/api/foo"); - assert!( - eval_l7(&engine, &input), - "L7 rule should apply to wildcard-matched host" - ); - // L7 DELETE to /api/foo — should be denied by L7 rule - let input_bad = l7_input("api.example.com", 8080, "DELETE", "/api/foo"); - assert!( - !eval_l7(&engine, &input_bad), - "L7 DELETE should be denied even on wildcard host" - ); - } - - #[test] - fn wildcard_host_l7_endpoint_config_returned() { - let data = r#" -network_policies: - wildcard_l7: - name: wildcard_l7 - endpoints: - - host: "*.example.com" - port: 8080 - protocol: rest - enforcement: enforce - tls: terminate - rules: - - allow: - method: GET - path: "**" - binaries: - - { path: /usr/bin/curl } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 8080, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let config = engine.query_endpoint_config(&input).unwrap(); - assert!( - config.is_some(), - "Should return endpoint config for wildcard-matched host" - ); - let config = config.unwrap(); - let l7 = crate::l7::parse_l7_config(&config).unwrap(); - assert_eq!(l7.protocol, crate::l7::L7Protocol::Rest); - assert_eq!(l7.enforcement, crate::l7::EnforcementMode::Enforce); - } - - #[test] - fn l7_multi_port_request_evaluation() { - let data = r#" -network_policies: - multi_l7: - name: multi_l7 - endpoints: - - host: api.example.com - ports: [8080, 9090] - protocol: rest - enforcement: enforce - tls: terminate - rules: - - allow: - method: GET - path: "**" - binaries: - - { path: /usr/bin/curl } -filesystem_policy: - include_workdir: true - read_only: [] - read_write: [] -landlock: - compatibility: best_effort -process: - run_as_user: sandbox - run_as_group: sandbox -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - // GET on port 8080 — allowed - let input1 = l7_input("api.example.com", 8080, "GET", "/anything"); - assert!( - eval_l7(&engine, &input1), - "L7 on first port of multi-port should work" - ); - // GET on port 9090 — allowed - let input2 = l7_input("api.example.com", 9090, "GET", "/anything"); - assert!( - eval_l7(&engine, &input2), - "L7 on second port of multi-port should work" - ); - } - - // ======================================================================== - // Symlink resolution tests (issue #770) - // ======================================================================== - - #[test] - fn normalize_path_resolves_parent_and_current() { - use std::path::{Path, PathBuf}; - assert_eq!( - normalize_path(Path::new("/usr/bin/../lib/python3")), - PathBuf::from("/usr/lib/python3") - ); - assert_eq!( - normalize_path(Path::new("/usr/bin/./python3")), - PathBuf::from("/usr/bin/python3") - ); - assert_eq!( - normalize_path(Path::new("/a/b/c/../../d")), - PathBuf::from("/a/d") - ); - assert_eq!( - normalize_path(Path::new("/usr/bin/python3")), - PathBuf::from("/usr/bin/python3") - ); - } - - #[test] - fn resolve_binary_skips_glob_paths() { - // Glob patterns should never be resolved — they're matched differently - assert!(resolve_binary_in_container("/usr/bin/*", 1).is_none()); - assert!(resolve_binary_in_container("/usr/local/bin/**", 1).is_none()); - } - - #[test] - fn resolve_binary_skips_pid_zero() { - // pid=0 means the container hasn't started yet - assert!(resolve_binary_in_container("/usr/bin/python3", 0).is_none()); - } - - #[test] - fn resolve_binary_returns_none_for_nonexistent_path() { - // A path that doesn't exist in any container should gracefully return None - assert!( - resolve_binary_in_container("/nonexistent/binary/path/that/will/never/exist", 1) - .is_none() - ); - } - - #[test] - fn proto_to_opa_data_json_pid_zero_no_expansion() { - // With pid=0, proto_to_opa_data_json should produce the same output - // as the original (no symlink expansion) - let proto = test_proto(); - let data_no_pid = proto_to_opa_data_json(&proto, 0); - let parsed: serde_json::Value = serde_json::from_str(&data_no_pid).unwrap(); - - // Verify the claude_code policy has exactly 1 binary entry (no expansion) - let binaries = parsed["network_policies"]["claude_code"]["binaries"] - .as_array() - .unwrap(); - assert_eq!( - binaries.len(), - 1, - "With pid=0, should have no expanded binaries" - ); - assert_eq!(binaries[0]["path"], "/usr/local/bin/claude"); - } - - #[test] - fn symlink_expanded_binary_allows_resolved_path() { - // Simulate what happens after symlink resolution: the OPA data - // contains both the original symlink path and the resolved path. - // A request using the resolved path should be allowed. - let data = r#" -network_policies: - python_policy: - name: python_policy - endpoints: - - { host: pypi.org, port: 443 } - binaries: - - { path: /usr/bin/python3 } - - { path: /usr/bin/python3.11 } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - - // Request with the resolved path (what the kernel reports) - let input = NetworkInput { - host: "pypi.org".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3.11"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Resolved symlink path should be allowed: {}", - decision.reason - ); - assert_eq!(decision.matched_policy.as_deref(), Some("python_policy")); - } - - #[test] - fn symlink_expanded_binary_still_allows_original_path() { - // Even with expansion, the original path must still work - let data = r#" -network_policies: - python_policy: - name: python_policy - endpoints: - - { host: pypi.org, port: 443 } - binaries: - - { path: /usr/bin/python3 } - - { path: /usr/bin/python3.11 } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - - // Request with the original symlink path (unlikely at runtime, but must not break) - let input = NetworkInput { - host: "pypi.org".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Original symlink path should still be allowed: {}", - decision.reason - ); - } - - #[test] - fn symlink_expanded_binary_does_not_weaken_security() { - // A binary NOT in the policy should still be denied, even if - // the expanded entries exist for other binaries. - let data = r#" -network_policies: - python_policy: - name: python_policy - endpoints: - - { host: pypi.org, port: 443 } - binaries: - - { path: /usr/bin/python3 } - - { path: /usr/bin/python3.11 } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - - let input = NetworkInput { - host: "pypi.org".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed, "Unrelated binary should still be denied"); - } - - #[test] - fn symlink_expansion_works_with_ancestors() { - // Ancestor binary matching should also work with expanded paths - let data = r#" -network_policies: - python_policy: - name: python_policy - endpoints: - - { host: pypi.org, port: 443 } - binaries: - - { path: /usr/bin/python3 } - - { path: /usr/bin/python3.11 } -"#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - - // The exe is curl, but an ancestor is the resolved python3.11 - let input = NetworkInput { - host: "pypi.org".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: "unused".into(), - ancestors: vec![PathBuf::from("/usr/bin/python3.11")], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Resolved symlink path should match as ancestor: {}", - decision.reason - ); - } - - #[test] - fn symlink_expansion_via_proto_with_pid_zero() { - // from_proto_with_pid(proto, 0) should produce same results as from_proto(proto) - let proto = test_proto(); - let engine_default = OpaEngine::from_proto(&proto).expect("from_proto should succeed"); - let engine_pid0 = OpaEngine::from_proto_with_pid(&proto, 0) - .expect("from_proto_with_pid(0) should succeed"); - - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - - let decision_default = engine_default.evaluate_network(&input).unwrap(); - let decision_pid0 = engine_pid0.evaluate_network(&input).unwrap(); - - assert_eq!( - decision_default.allowed, decision_pid0.allowed, - "from_proto and from_proto_with_pid(0) should produce identical results" - ); - } - - #[test] - fn reload_from_proto_with_pid_zero_works() { - // reload_from_proto_with_pid(proto, 0) should function identically to reload_from_proto - let proto = test_proto(); - let engine = OpaEngine::from_proto(&proto).expect("from_proto should succeed"); - - // Verify initial policy works - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(decision.allowed); - - // Reload with same proto at pid=0 - engine - .reload_from_proto_with_pid(&proto, 0) - .expect("reload_from_proto_with_pid should succeed"); - - // Should still work - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "reload_from_proto_with_pid(0) should preserve behavior" - ); - } - - #[test] - fn hot_reload_preserves_symlink_expansion_behavior() { - // Simulates the hot-reload path: initial load at pid=0, then reload - // with a new proto that would have expanded binaries at a real PID. - // Since we can't mock /proc//root/ in unit tests, we test - // that reload_from_proto_with_pid at pid=0 still works correctly - // and that the engine is properly replaced. - let proto = test_proto(); - let engine = OpaEngine::from_proto(&proto).expect("initial load should succeed"); - - // Verify initial policy allows claude - let claude_input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - assert!(engine.evaluate_network(&claude_input).unwrap().allowed); - - // Create a new proto with an additional policy - let mut new_proto = test_proto(); - new_proto.network_policies.insert( - "python_api".to_string(), - NetworkPolicyRule { - name: "python_api".to_string(), - endpoints: vec![NetworkEndpoint { - host: "pypi.org".to_string(), - port: 443, - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/python3".to_string(), - ..Default::default() - }], - }, - ); - - // Hot-reload with pid=0 - engine - .reload_from_proto_with_pid(&new_proto, 0) - .expect("hot-reload should succeed"); - - // Old policy should still work - assert!( - engine.evaluate_network(&claude_input).unwrap().allowed, - "Old policies should survive hot-reload" - ); - - // New policy should also work - let python_input = NetworkInput { - host: "pypi.org".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - assert!( - engine.evaluate_network(&python_input).unwrap().allowed, - "New policy should be active after hot-reload" - ); - } - - #[test] - fn hot_reload_replaces_engine_atomically() { - // Test that a failed reload preserves the last-known-good engine - let proto = test_proto(); - let engine = OpaEngine::from_proto(&proto).expect("initial load should succeed"); - - let claude_input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - assert!(engine.evaluate_network(&claude_input).unwrap().allowed); - - // Reload with same proto — should succeed and preserve behavior - engine - .reload_from_proto_with_pid(&proto, 0) - .expect("reload should succeed"); - - assert!( - engine.evaluate_network(&claude_input).unwrap().allowed, - "Engine should work after successful reload" - ); - } - - #[test] - fn deny_reason_includes_symlink_hint() { - // Verify the deny reason includes an actionable symlink hint - let engine = test_engine(); - let input = NetworkInput { - host: "api.anthropic.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/python3.11"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - assert!( - decision.reason.contains("SYMLINK HINT"), - "Deny reason should include prominent symlink hint, got: {}", - decision.reason - ); - assert!( - decision.reason.contains("readlink -f"), - "Deny reason should include actionable fix command, got: {}", - decision.reason - ); - } - - #[test] - fn deny_reason_collapses_endpoint_misses() { - let engine = test_engine(); - let input = NetworkInput { - host: "not-configured.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/local/bin/claude"), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!(!decision.allowed); - assert_eq!( - decision.reason, - "endpoint not-configured.example.com:443 is not allowed by any policy" - ); - } - - /// Check if symlink resolution through `/proc//root/` actually works. - /// Creates a real symlink in a tempdir and attempts to resolve it via - /// the procfs root path. This catches environments where the probe path - /// is readable but canonicalization/read_link fails (e.g., containers - /// with restricted ptrace scope, rootless containers). - #[cfg(target_os = "linux")] - fn procfs_root_accessible() -> bool { - use std::os::unix::fs::symlink; - let Ok(dir) = tempfile::tempdir() else { - return false; - }; - let target = dir.path().join("probe_target"); - let link = dir.path().join("probe_link"); - if std::fs::write(&target, b"probe").is_err() { - return false; - } - if symlink(&target, &link).is_err() { - return false; - } - let pid = std::process::id(); - let link_path = link.to_string_lossy().to_string(); - // Actually attempt the same resolution our production code uses - resolve_binary_in_container(&link_path, pid).is_some() - } - - #[cfg(target_os = "linux")] - #[test] - fn resolve_binary_with_real_symlink() { - use std::os::unix::fs::symlink; - - if !procfs_root_accessible() { - eprintln!("Skipping: /proc//root/ not accessible in this environment"); - return; - } - - // Create a real symlink in a temp directory and verify resolution - // works through /proc/self/root (which maps to / on the host) - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("python3.11"); - let link = dir.path().join("python3"); - - // Create the target file - std::fs::write(&target, b"#!/usr/bin/env python3\n").unwrap(); - // Create symlink - symlink(&target, &link).unwrap(); - - // Use our own PID — /proc//root/ points to / - let our_pid = std::process::id(); - let link_path = link.to_string_lossy().to_string(); - let result = resolve_binary_in_container(&link_path, our_pid); - - assert!( - result.is_some(), - "Should resolve symlink via /proc//root/" - ); - let resolved = result.unwrap(); - assert!( - resolved.ends_with("python3.11"), - "Resolved path should point to target: {resolved}" - ); - } - - #[cfg(target_os = "linux")] - #[test] - fn resolve_binary_non_symlink_returns_none() { - use std::io::Write; - - if !procfs_root_accessible() { - eprintln!("Skipping: /proc//root/ not accessible in this environment"); - return; - } - - // A regular file should return None (no expansion needed) - let mut tmp = tempfile::NamedTempFile::new().unwrap(); - tmp.write_all(b"regular file").unwrap(); - tmp.flush().unwrap(); - - let our_pid = std::process::id(); - let path = tmp.path().to_string_lossy().to_string(); - let result = resolve_binary_in_container(&path, our_pid); - - assert!( - result.is_none(), - "Non-symlink file should return None, got: {result:?}" - ); - } - - #[cfg(target_os = "linux")] - #[test] - fn resolve_binary_multi_level_symlink() { - use std::os::unix::fs::symlink; - - if !procfs_root_accessible() { - eprintln!("Skipping: /proc//root/ not accessible in this environment"); - return; - } - - // Test multi-level symlink resolution: python3 -> python3.11 -> cpython3.11 - let dir = tempfile::tempdir().unwrap(); - let final_target = dir.path().join("cpython3.11"); - let mid_link = dir.path().join("python3.11"); - let top_link = dir.path().join("python3"); - - std::fs::write(&final_target, b"final binary").unwrap(); - symlink(&final_target, &mid_link).unwrap(); - symlink(&mid_link, &top_link).unwrap(); - - let our_pid = std::process::id(); - let link_path = top_link.to_string_lossy().to_string(); - let result = resolve_binary_in_container(&link_path, our_pid); - - assert!(result.is_some(), "Should resolve multi-level symlink chain"); - let resolved = result.unwrap(); - assert!( - resolved.ends_with("cpython3.11"), - "Should resolve to final target: {resolved}" - ); - } - - #[cfg(target_os = "linux")] - #[test] - fn from_proto_with_pid_expands_symlinks_in_container() { - use std::os::unix::fs::symlink; - - if !procfs_root_accessible() { - eprintln!("Skipping: /proc//root/ not accessible in this environment"); - return; - } - - // End-to-end test: create a symlink, build engine with our PID, - // verify the resolved path is allowed - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("node22"); - let link = dir.path().join("node"); - - std::fs::write(&target, b"node binary").unwrap(); - symlink(&target, &link).unwrap(); - - let link_path = link.to_string_lossy().to_string(); - let target_path = target.to_string_lossy().to_string(); - - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "test".to_string(), - NetworkPolicyRule { - name: "test".to_string(), - endpoints: vec![NetworkEndpoint { - host: "example.com".to_string(), - port: 443, - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: link_path, - ..Default::default() - }], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - - // Build engine with our PID (symlink resolution will work via /proc/self/root/) - let our_pid = std::process::id(); - let engine = OpaEngine::from_proto_with_pid(&proto, our_pid) - .expect("from_proto_with_pid should succeed"); - - // Request using the resolved target path should be allowed - let input = NetworkInput { - host: "example.com".into(), - port: 443, - binary_path: PathBuf::from(&target_path), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input).unwrap(); - assert!( - decision.allowed, - "Resolved symlink target should be allowed after expansion: {}", - decision.reason - ); - } - - #[cfg(target_os = "linux")] - #[test] - fn reload_from_proto_with_pid_resolves_symlinks() { - use std::os::unix::fs::symlink; - - if !procfs_root_accessible() { - eprintln!("Skipping: /proc//root/ not accessible in this environment"); - return; - } - - // Test hot-reload path: initial engine at pid=0, then reload with - // real PID to trigger symlink resolution - let dir = tempfile::tempdir().unwrap(); - let target = dir.path().join("python3.11"); - let link = dir.path().join("python3"); - - std::fs::write(&target, b"python binary").unwrap(); - symlink(&target, &link).unwrap(); - - let link_path = link.to_string_lossy().to_string(); - let target_path = target.to_string_lossy().to_string(); - - let mut network_policies = std::collections::HashMap::new(); - network_policies.insert( - "python".to_string(), - NetworkPolicyRule { - name: "python".to_string(), - endpoints: vec![NetworkEndpoint { - host: "pypi.org".to_string(), - port: 443, - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: link_path, - ..Default::default() - }], - }, - ); - let proto = ProtoSandboxPolicy { - version: 1, - filesystem: Some(ProtoFs { - include_workdir: true, - read_only: vec![], - read_write: vec![], - }), - landlock: Some(openshell_core::proto::LandlockPolicy { - compatibility: "best_effort".to_string(), - }), - process: Some(ProtoProc { - run_as_user: "sandbox".to_string(), - run_as_group: "sandbox".to_string(), - }), - network_policies, - }; - - // Initial load at pid=0 — no symlink expansion - let engine = OpaEngine::from_proto(&proto).expect("initial load"); - - // Request with resolved path should be DENIED (no expansion yet) - let input_resolved = NetworkInput { - host: "pypi.org".into(), - port: 443, - binary_path: PathBuf::from(&target_path), - binary_sha256: "unused".into(), - ancestors: vec![], - cmdline_paths: vec![], - }; - let decision = engine.evaluate_network(&input_resolved).unwrap(); - assert!( - !decision.allowed, - "Before reload with PID, resolved path should be denied" - ); - - // Hot-reload with real PID — symlinks resolved - let our_pid = std::process::id(); - engine - .reload_from_proto_with_pid(&proto, our_pid) - .expect("reload with PID"); - - // Now the resolved path should be ALLOWED - let decision = engine.evaluate_network(&input_resolved).unwrap(); - assert!( - decision.allowed, - "After reload with PID, resolved path should be allowed: {}", - decision.reason - ); - } - - #[test] - fn l7_head_allowed_where_get_is_allowed() { - let engine = l7_engine(); - let input = l7_input("api.example.com", 8080, "HEAD", "/repos/myorg/foo"); - assert!(eval_l7(&engine, &input)); - } - - #[test] - fn l7_head_denied_when_only_post_allowed() { - let engine = OpaEngine::from_strings( - TEST_POLICY, - "network_policies:\n p:\n name: p\n endpoints:\n - host: h.test\n port: 80\n protocol: rest\n enforcement: enforce\n rules:\n - allow: {method: POST, path: \"/\"}\n binaries:\n - {path: /usr/bin/curl}\n", - ) - .unwrap(); - let input = l7_input("h.test", 80, "HEAD", "/"); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_options_not_implicitly_allowed_by_get() { - let engine = l7_engine(); - let input = l7_input("api.example.com", 8080, "OPTIONS", "/repos/myorg/foo"); - assert!(!eval_l7(&engine, &input)); - } - - #[test] - fn l7_head_blocked_by_deny_rule_targeting_get() { - // deny_rules use method_matches() too; a deny on GET must also block HEAD. - let engine = OpaEngine::from_strings( - TEST_POLICY, - "network_policies:\n p:\n name: p\n endpoints:\n - host: h.test\n port: 80\n protocol: rest\n enforcement: enforce\n access: full\n deny_rules:\n - method: GET\n path: \"/protected\"\n binaries:\n - {path: /usr/bin/curl}\n", - ) - .unwrap(); - let input = l7_input("h.test", 80, "HEAD", "/protected"); - assert!(!eval_l7(&engine, &input)); - } -} diff --git a/crates/openshell-sandbox/src/policy_local.rs b/crates/openshell-sandbox/src/policy_local.rs deleted file mode 100644 index aa270d017..000000000 --- a/crates/openshell-sandbox/src/policy_local.rs +++ /dev/null @@ -1,2030 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! Sandbox-local policy advisor HTTP API. - -use miette::{IntoDiagnostic, Result}; -use openshell_core::proto::{ - L7Allow, L7DenyRule, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, PolicyChunk, - SandboxPolicy as ProtoSandboxPolicy, -}; -use openshell_ocsf::{ConfigStateChangeBuilder, SeverityId, StateId, StatusId, ocsf_emit}; -use serde::Deserialize; -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::sync::Arc; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; -use tokio::sync::RwLock; - -pub const POLICY_LOCAL_HOST: &str = "policy.local"; - -/// Filesystem path of the static agent guidance bundle inside the sandbox. -/// Single source of truth: the skill installer writes here, the L7 deny body -/// references this path in `next_steps`, and the skill's own documentation -/// renders the same path. Changing the location is a one-line update here. -pub const SKILL_PATH: &str = "/etc/openshell/skills/policy_advisor.md"; - -/// Human-readable guidance for agents that are more likely to follow plain -/// instructions than structured next-step JSON alone. -pub const AGENT_GUIDANCE: &str = "OpenShell blocked this request with sandbox policy. If the user's task still needs this network action, do not stop here: read /etc/openshell/skills/policy_advisor.md, submit the narrowest policy proposal to http://policy.local/v1/proposals, wait for approval and `policy_reloaded: true`, then retry the original request."; - -/// Routes served by the in-sandbox policy advisor API. Held in one place so -/// the L7 deny `next_steps` array, the route dispatcher, the skill content, -/// and tests all stay in sync — change the wire path here and every caller -/// follows. See `agent_next_steps()` for the consumer that surfaces these -/// to the agent on a 403. -pub const ROUTE_POLICY_CURRENT: &str = "/v1/policy/current"; -pub const ROUTE_DENIALS: &str = "/v1/denials"; -pub const ROUTE_PROPOSALS: &str = "/v1/proposals"; -/// Per-proposal status and long-poll routes live below this prefix: -/// `GET /v1/proposals/{chunk_id}` — immediate status -/// `GET /v1/proposals/{chunk_id}/wait?timeout` — long-poll until terminal -/// Trailing slash differentiates from the bare `POST /v1/proposals` submit. -const ROUTE_PROPOSALS_PREFIX: &str = "/v1/proposals/"; - -/// Long-poll bounds for `GET /v1/proposals/{id}/wait?timeout=`. The agent -/// re-issues on timeout, so the cap is a hold ceiling, not a hard limit on -/// how long the agent can wait overall. -const PROPOSAL_WAIT_DEFAULT_SECS: u64 = 60; -const PROPOSAL_WAIT_MIN_SECS: u64 = 1; -const PROPOSAL_WAIT_MAX_SECS: u64 = 300; -const PROPOSAL_WAIT_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1); -/// Minimum window the reload-readiness phase gets after a chunk -/// terminalizes, even if the caller's deadline is shorter. Without this, -/// approvals that arrive at T-50ms always return `policy_reloaded=false` -/// and force a re-issue. 500ms is well below typical supervisor poll -/// latency but enough to cover the in-memory coverage check. -const RELOAD_WAIT_MIN_FLOOR: std::time::Duration = std::time::Duration::from_millis(500); - -const MAX_POLICY_LOCAL_BODY_BYTES: usize = 64 * 1024; -/// Hard ceiling on how long a single request body read can stall. Bounds a -/// slowloris-style upload from an in-sandbox process; the proxy listener only -/// accepts loopback connections, so practical impact is limited, but this is -/// cheap defense-in-depth. -const POLICY_LOCAL_BODY_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15); -const DEFAULT_DENIALS_LIMIT: usize = 10; -const MAX_DENIALS_LIMIT: usize = 100; -/// The shorthand rolling appender keeps three files (daily rotation); read the -/// most recent two so a request just past midnight still has yesterday's -/// denials. -const DENIAL_LOG_FILES_TO_SCAN: usize = 2; -const LOG_DIR: &str = "/var/log"; -/// Shorthand log filenames are `openshell.YYYY-MM-DD.log`. The trailing dot in -/// the prefix is intentional: it disambiguates from the OCSF JSONL appender's -/// `openshell-ocsf.YYYY-MM-DD.log`, which we never want to surface here (the -/// JSONL is opt-in via `ocsf_json_enabled` and not the source of truth for -/// `/v1/denials`). -const SHORTHAND_LOG_PREFIX: &str = "openshell."; -/// Defensive cap on per-line length returned to the agent so a pathological -/// log entry (very long URL path, etc.) cannot blow up the response. -const MAX_DENIAL_LINE_BYTES: usize = 4096; - -#[derive(Debug)] -pub struct PolicyLocalContext { - current_policy: Arc>>, - gateway_endpoint: Option, - sandbox_name: Option, - shorthand_log_dir: PathBuf, -} - -impl PolicyLocalContext { - pub fn new( - current_policy: Option, - gateway_endpoint: Option, - sandbox_name: Option, - ) -> Self { - Self::with_log_dir( - current_policy, - gateway_endpoint, - sandbox_name, - PathBuf::from(LOG_DIR), - ) - } - - fn with_log_dir( - current_policy: Option, - gateway_endpoint: Option, - sandbox_name: Option, - shorthand_log_dir: PathBuf, - ) -> Self { - Self { - current_policy: Arc::new(RwLock::new(current_policy)), - gateway_endpoint, - sandbox_name, - shorthand_log_dir, - } - } - - pub async fn set_current_policy(&self, policy: ProtoSandboxPolicy) { - *self.current_policy.write().await = Some(policy); - } -} - -pub async fn handle_forward_request( - ctx: &PolicyLocalContext, - method: &str, - path: &str, - initial_request: &[u8], - client: &mut S, -) -> Result<()> -where - S: AsyncRead + AsyncWrite + Unpin, -{ - let body = read_request_body(initial_request, client).await?; - let (status, payload) = route_request(ctx, method, path, &body).await; - write_json_response(client, status, payload).await -} - -async fn route_request( - ctx: &PolicyLocalContext, - method: &str, - path: &str, - body: &[u8], -) -> (u16, serde_json::Value) { - let (route, query) = path.split_once('?').map_or((path, ""), |(r, q)| (r, q)); - // Gate every route on the feature flag so the agent surface is fully off - // when the flag is off — including the diagnostic `current_policy` and - // `denials` routes. The skill is also not installed in that mode, so a - // disabled sandbox has no entry point into this API at all. - if !crate::agent_proposals_enabled() { - return ( - 404, - serde_json::json!({ - "error": "feature_disabled", - "detail": "agent-driven policy proposals are not enabled in this sandbox; set the `agent_policy_proposals_enabled` setting to true to enable" - }), - ); - } - match (method, route) { - ("GET", ROUTE_POLICY_CURRENT) => current_policy_response(ctx).await, - ("GET", ROUTE_DENIALS) => recent_denials_response(ctx, query).await, - ("POST", ROUTE_PROPOSALS) => submit_proposal(ctx, body).await, - ("GET", path) if path.starts_with(ROUTE_PROPOSALS_PREFIX) => { - proposal_state_route(ctx, path, query).await - } - _ => ( - 404, - serde_json::json!({ - "error": "not_found", - "detail": format!("policy.local route not found: {method} {route}") - }), - ), - } -} - -/// Parse `{chunk_id}` (status) or `{chunk_id}/wait` (long-poll) from the path -/// suffix and dispatch. Empty `chunk_id` or extra segments return 404 so a -/// malformed path cannot trigger a gateway call. -async fn proposal_state_route( - ctx: &PolicyLocalContext, - path: &str, - query: &str, -) -> (u16, serde_json::Value) { - let suffix = path - .strip_prefix(ROUTE_PROPOSALS_PREFIX) - .unwrap_or_default(); - let (chunk_id, wait) = match suffix.split_once('/') { - Some((id, "wait")) => (id, true), - Some(_) => return not_found_payload(path), - None => (suffix, false), - }; - if chunk_id.is_empty() { - return not_found_payload(path); - } - if wait { - proposal_wait_response(ctx, chunk_id, query).await - } else { - proposal_status_response(ctx, chunk_id).await - } -} - -fn not_found_payload(path: &str) -> (u16, serde_json::Value) { - ( - 404, - serde_json::json!({ - "error": "not_found", - "detail": format!("policy.local proposal sub-route not found: {path}") - }), - ) -} - -/// Build the `next_steps` array embedded in the L7 deny body so the agent has -/// machine-readable pointers to this API. Centralizes the shape here to keep -/// the deny body and the actual route table from drifting — adding or -/// renaming a route only requires touching the route constants above. -/// -/// Returns an empty array when `agent_proposals_enabled()` is false so a -/// disabled sandbox doesn't advertise a surface that 404s. The deny body -/// caller still emits the field (with `[]`) so the wire shape is stable. -#[must_use] -pub fn agent_next_steps() -> serde_json::Value { - if !crate::agent_proposals_enabled() { - return serde_json::json!([]); - } - let host = POLICY_LOCAL_HOST; - serde_json::json!([ - { - "action": "read_skill", - "path": SKILL_PATH, - }, - { - "action": "inspect_policy", - "method": "GET", - "url": format!("http://{host}{ROUTE_POLICY_CURRENT}"), - }, - { - "action": "inspect_recent_denials", - "method": "GET", - "url": format!("http://{host}{ROUTE_DENIALS}?last=5"), - }, - { - "action": "submit_proposal", - "method": "POST", - "url": format!("http://{host}{ROUTE_PROPOSALS}"), - "body_type": "PolicyMergeOperation", - }, - ]) -} - -/// Build the optional natural-language guidance embedded in L7 deny bodies. -#[must_use] -pub fn agent_guidance() -> Option<&'static str> { - crate::agent_proposals_enabled().then_some(AGENT_GUIDANCE) -} - -async fn current_policy_response(ctx: &PolicyLocalContext) -> (u16, serde_json::Value) { - let Some(policy) = ctx.current_policy.read().await.clone() else { - return ( - 404, - serde_json::json!({ - "error": "policy_unavailable", - "detail": "no current sandbox policy is loaded" - }), - ); - }; - - match openshell_policy::serialize_sandbox_policy(&policy) { - Ok(policy_yaml) => ( - 200, - serde_json::json!({ - "format": "yaml", - "policy_yaml": policy_yaml - }), - ), - Err(error) => ( - 500, - serde_json::json!({ - "error": "policy_serialize_failed", - "detail": error.to_string() - }), - ), - } -} - -async fn recent_denials_response( - ctx: &PolicyLocalContext, - query: &str, -) -> (u16, serde_json::Value) { - let limit = parse_last_query(query).unwrap_or(DEFAULT_DENIALS_LIMIT); - let log_dir = ctx.shorthand_log_dir.clone(); - - // Distinguish "shorthand log exists and no denials happened" from "no log - // file yet, so we have nothing to read." Without this flag the agent sees - // `[]` in both cases and cannot tell the difference. The shorthand log is - // always-on (no setting gates it), so the only way `log_available=false` - // happens in practice is if the supervisor has not flushed any events to - // disk yet, or `/var/log` is not writable in this image. - let log_available = matches!( - collect_shorthand_log_files(&log_dir, 1), - Ok(files) if !files.is_empty() - ); - - let denials = tokio::task::spawn_blocking(move || read_recent_denial_lines(&log_dir, limit)) - .await - .unwrap_or_default(); - - let mut payload = serde_json::json!({ - "denials": denials, - "log_available": log_available, - }); - if !log_available { - payload["note"] = serde_json::json!( - "no shorthand log file is present yet at /var/log/openshell.YYYY-MM-DD.log; the supervisor may not have emitted any events to disk yet" - ); - } - - (200, payload) -} - -fn parse_last_query(query: &str) -> Option { - if query.is_empty() { - return None; - } - for pair in query.split('&') { - let Some((key, value)) = pair.split_once('=') else { - continue; - }; - if key == "last" { - return value - .parse::() - .ok() - .map(|n| n.clamp(1, MAX_DENIALS_LIMIT)); - } - } - None -} - -/// Walk the shorthand log files (most-recent first) and return up to `limit` -/// raw denial lines in newest-first order. The agent receives the same -/// human-readable text that `openshell logs` displays — no parsing back into -/// structured form. Updating the shorthand format adds fields automatically; -/// no schema rev required. -/// -/// Reads files synchronously and is intended to run inside `spawn_blocking`. -fn read_recent_denial_lines(log_dir: &Path, limit: usize) -> Vec { - let Ok(files) = collect_shorthand_log_files(log_dir, DENIAL_LOG_FILES_TO_SCAN) else { - return Vec::new(); - }; - - let mut lines: Vec = Vec::with_capacity(limit); - for path in files { - let Ok(contents) = std::fs::read_to_string(&path) else { - continue; - }; - // Walk lines newest-first. Within a single file, the last line written - // is the freshest event. - for line in contents.lines().rev() { - if !is_ocsf_denial_line(line) { - continue; - } - // Defense-in-depth: redact query strings before truncation. The - // FORWARD deny path in `proxy.rs` populates the OCSF `message` - // and URL with the raw request path including `?query=...`, which - // the shorthand layer then renders verbatim. Stripping queries - // here means the agent never sees the secret even if an upstream - // emit site forgets to redact (TODO: harden the emit sites in - // proxy.rs FORWARD path so the on-disk shorthand log itself is - // clean — tracked separately). Redact first so truncation cannot - // slice mid-secret. - let redacted = redact_query_strings(line); - let surfaced = truncate_at_char_boundary(&redacted, MAX_DENIAL_LINE_BYTES); - lines.push(surfaced); - if lines.len() >= limit { - return lines; - } - } - } - lines -} - -/// Replace any `?` substring with `?[redacted]` to keep query-string -/// secrets out of the agent's view. Walks per Unicode scalar value so multi-byte -/// content is safe. A query is everything from `?` until the next whitespace or -/// `]` (the shorthand format uses `[...]` for context tags). -fn redact_query_strings(line: &str) -> String { - let mut out = String::with_capacity(line.len()); - let mut chars = line.chars(); - while let Some(c) = chars.next() { - if c == '?' { - out.push('?'); - out.push_str("[redacted]"); - // Consume until whitespace or `]` (preserved as the next token's - // boundary by writing it back out). - for next in chars.by_ref() { - if next.is_whitespace() || next == ']' { - out.push(next); - break; - } - } - } else { - out.push(c); - } - } - out -} - -/// Truncate `s` at the largest UTF-8 char boundary <= `max_bytes`, appending a -/// `...[truncated]` suffix. Returning a `String` (not `&str`) avoids surprising -/// callers about lifetime relationships with `s`. -fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> String { - if s.len() <= max_bytes { - return s.to_string(); - } - let mut end = max_bytes; - while end > 0 && !s.is_char_boundary(end) { - end -= 1; - } - let mut out = String::with_capacity(end + "...[truncated]".len()); - out.push_str(&s[..end]); - out.push_str("...[truncated]"); - out -} - -/// True for OCSF denial events as rendered by the shorthand layer. The format -/// is ` OCSF <[SEV]> ...`. The literal -/// ` OCSF ` substring identifies an OCSF event (vs. a non-OCSF tracing line); -/// ` DENIED ` is the OCSF action label uppercased and surrounded by spaces, so -/// matching it is safe against substring collisions in URLs or hostnames. -fn is_ocsf_denial_line(line: &str) -> bool { - line.contains(" OCSF ") && line.contains(" DENIED ") -} - -fn collect_shorthand_log_files(log_dir: &Path, max_files: usize) -> std::io::Result> { - let mut entries: Vec<(std::time::SystemTime, PathBuf)> = std::fs::read_dir(log_dir)? - .filter_map(std::result::Result::ok) - .filter_map(|entry| { - let path = entry.path(); - let name = entry.file_name(); - let name = name.to_string_lossy(); - // `openshell.YYYY-MM-DD.log` only — the trailing dot in the prefix - // disambiguates from `openshell-ocsf.YYYY-MM-DD.log`. - if !name.starts_with(SHORTHAND_LOG_PREFIX) || !name.ends_with(".log") { - return None; - } - let modified = entry.metadata().and_then(|m| m.modified()).ok()?; - Some((modified, path)) - }) - .collect(); - - entries.sort_by_key(|entry| std::cmp::Reverse(entry.0)); - Ok(entries - .into_iter() - .take(max_files) - .map(|(_, p)| p) - .collect()) -} - -async fn submit_proposal(ctx: &PolicyLocalContext, body: &[u8]) -> (u16, serde_json::Value) { - let Some(endpoint) = ctx.gateway_endpoint.as_deref() else { - return ( - 503, - serde_json::json!({ - "error": "gateway_unavailable", - "detail": "policy proposal submission requires a gateway-connected sandbox" - }), - ); - }; - let Some(sandbox_name) = ctx - .sandbox_name - .as_deref() - .map(str::trim) - .filter(|name| !name.is_empty()) - else { - return ( - 503, - serde_json::json!({ - "error": "sandbox_name_unavailable", - "detail": "policy proposal submission requires a sandbox name" - }), - ); - }; - - let chunks = match proposal_chunks_from_body(body) { - Ok(chunks) => chunks, - Err(error) => return (400, error_payload("invalid_proposal", error)), - }; - - let client = match crate::grpc_client::CachedOpenShellClient::connect(endpoint).await { - Ok(client) => client, - Err(error) => { - return ( - 502, - serde_json::json!({ - "error": "gateway_connect_failed", - "detail": error.to_string() - }), - ); - } - }; - - // Pre-compute the audit summaries before handing `chunks` to the - // gateway client (which consumes the vec). The summaries pair up with - // the gateway's `accepted_chunk_ids` by index for the propose events - // emitted after submit returns. - let audit_summaries: Vec = chunks.iter().map(summarize_chunk_for_audit).collect(); - - let response = match client - .submit_policy_analysis(sandbox_name, vec![], chunks, vec![], "agent_authored") - .await - { - Ok(response) => response, - Err(error) => { - return ( - 502, - serde_json::json!({ - "error": "proposal_submit_failed", - "detail": error.to_string() - }), - ); - } - }; - - // One OCSF event per accepted chunk so the audit trace in - // `openshell logs ` carries the propose beat alongside the - // proxy deny and policy reload that bracket it. - // - // The gateway compresses its `accepted_chunk_ids` by skipping rejected - // chunks (`grpc/policy.rs:1357-1436`); the proto does not promise 1:1 - // ordering against the request. Today client-side validation catches - // both rejection causes (missing rule_name, missing proposed_rule) - // before submit, so the lengths match in practice. If they don't, we - // can't safely pair audit_summaries by index — fall back to a generic - // event per accepted chunk_id rather than mis-attribute a summary. - let pairing_is_safe = response.accepted_chunk_ids.len() == audit_summaries.len(); - for (idx, chunk_id) in response.accepted_chunk_ids.iter().enumerate() { - let summary = if pairing_is_safe { - audit_summaries[idx].as_str() - } else { - "(summary unavailable: gateway partially accepted)" - }; - emit_policy_propose_event(chunk_id, summary); - } - - ( - 202, - serde_json::json!({ - "status": "submitted", - "accepted_chunks": response.accepted_chunks, - "rejected_chunks": response.rejected_chunks, - "rejection_reasons": response.rejection_reasons, - "accepted_chunk_ids": response.accepted_chunk_ids, - }), - ) -} - -/// Emit one CONFIG:PROPOSED audit event for an agent-authored proposal that -/// the gateway just accepted. The message names the `chunk_id`, the binary, -/// and the endpoint the agent is asking to reach — what a developer needs -/// to see in the audit trace to correlate against the inbox card. -fn emit_policy_propose_event(chunk_id: &str, summary: &str) { - ocsf_emit!( - ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .state(StateId::Other, "PROPOSED") - .unmapped("chunk_id", serde_json::json!(chunk_id)) - .message(format!( - "agent_authored proposal chunk:{chunk_id} {summary}" - )) - .build() - ); -} - -/// Emit one CONFIG:APPROVED or CONFIG:REJECTED audit event observed by the -/// `/wait` poll loop. The reviewer's free-form `rejection_reason` (if any) -/// is included verbatim so the audit trace shows what guidance the agent -/// received. -fn emit_policy_decision_event(chunk: &PolicyChunk) { - let summary = summarize_chunk_for_audit(chunk); - match chunk.status.as_str() { - "approved" => ocsf_emit!( - ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .state(StateId::Enabled, "APPROVED") - .unmapped("chunk_id", serde_json::json!(chunk.id)) - .message(format!("chunk:{} approved {summary}", chunk.id)) - .build() - ), - "rejected" => { - // The reviewer's free-form rejection_reason is opaque user - // input. The agent reads the raw text via `GET /v1/proposals/ - // {id}` to redraft; the OCSF surface (which can be shipped to - // external SIEMs per AGENTS.md) gets a sanitized copy — caps - // length and strips control characters so a stray credential - // or escape sequence cannot leak into the audit log. - let sanitized = sanitize_reason_for_audit(&chunk.rejection_reason); - let reason_display = if sanitized.is_empty() { - "(no guidance)".to_string() - } else { - format!("\"{sanitized}\"") - }; - ocsf_emit!( - ConfigStateChangeBuilder::new(crate::ocsf_ctx()) - .severity(SeverityId::Low) - .status(StatusId::Success) - .state(StateId::Disabled, "REJECTED") - .unmapped("chunk_id", serde_json::json!(chunk.id)) - .unmapped("rejection_reason", serde_json::json!(sanitized)) - .message(format!( - "chunk:{} rejected {summary} reason:{reason_display}", - chunk.id - )) - .build() - ); - } - // Caller is gated on `is_terminal_status`, so a non-terminal status - // here is a code change that broke the invariant. Warn loudly so - // the audit gap doesn't go silent. - other => tracing::warn!( - chunk_id = %chunk.id, - status = %other, - "emit_policy_decision_event called on non-terminal status; no audit event emitted" - ), - } -} - -/// Sanitize a free-form reviewer-typed string before it lands in the OCSF -/// audit surface. The agent still reads the raw text via the API — this is -/// audit-side defense only. -fn sanitize_reason_for_audit(raw: &str) -> String { - const MAX_CHARS: usize = 200; - let cleaned: String = raw - .chars() - .filter(|c| !c.is_control() || *c == ' ') - .take(MAX_CHARS) - .collect(); - if raw.chars().count() > MAX_CHARS { - format!("{cleaned}…") - } else { - cleaned - } -} - -/// One-line audit description of a chunk's target: binary, host, port, and -/// L7 method/path if present. Used by both the propose and approve/reject -/// audit events so the trace can be grepped by endpoint without parsing -/// JSON. -fn summarize_chunk_for_audit(chunk: &PolicyChunk) -> String { - let Some(rule) = chunk.proposed_rule.as_ref() else { - return format!("rule_name:{}", chunk.rule_name); - }; - let endpoint = rule.endpoints.first().map_or_else( - || "unknown".to_string(), - |ep| format!("{}:{}", ep.host, ep.port), - ); - let l7 = rule - .endpoints - .first() - .and_then(|ep| ep.rules.first()) - .and_then(|r| r.allow.as_ref()) - .map(|a| format!(" {} {}", a.method, a.path)) - .unwrap_or_default(); - let binary = if chunk.binary.is_empty() { - String::new() - } else { - format!(" by {}", chunk.binary) - }; - format!("on {endpoint}{l7}{binary}") -} - -/// `GET /v1/proposals/{chunk_id}` — immediate state. One gateway call, no loop. -async fn proposal_status_response( - ctx: &PolicyLocalContext, - chunk_id: &str, -) -> (u16, serde_json::Value) { - let session = match open_lookup_session(ctx).await { - Ok(session) => session, - Err(err) => return err, - }; - fetch_chunk_or_404(&session, chunk_id, false).await -} - -/// `GET /v1/proposals/{chunk_id}/wait?timeout=` — block until terminal or -/// timeout. Returns the chunk's current state on a status transition; on -/// timeout, returns the still-pending state with `timed_out: true` so the -/// agent can re-issue without ambiguity. The agent's wait costs zero LLM -/// tokens — the tool call sits in a socket recv until we return. -async fn proposal_wait_response( - ctx: &PolicyLocalContext, - chunk_id: &str, - query: &str, -) -> (u16, serde_json::Value) { - let session = match open_lookup_session(ctx).await { - Ok(session) => session, - Err(err) => return err, - }; - let timeout_secs = parse_timeout_query(query); - let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(timeout_secs); - loop { - match fetch_chunk(&session, chunk_id).await { - Ok(Some(chunk)) if is_terminal_status(&chunk.status) => { - // Audit beat: emit at the moment this sandbox observes the - // decision so the trace correlates with the proxy events - // bracketing the loop. Multiple waiters on the same chunk - // each fire one event — acceptable for a wakeup audit. - emit_policy_decision_event(&chunk); - let policy_reloaded = if chunk.status == "approved" { - // Hold the wait until the local supervisor has loaded a - // policy that semantically contains this chunk's - // proposed rule. Reloads triggered by *other* chunks or - // settings changes do not wake us; a missing - // proposed_rule (defensive) skips the check and - // returns reloaded=false so the agent can decide. - // - // Floor the reload-wait window to RELOAD_WAIT_MIN_FLOOR - // so an approval that arrives at T-50ms still gets a - // realistic shot at seeing the reload. Worst case we - // overshoot the caller's deadline by this floor — - // preferable to returning reloaded=false on every - // short-budget call and forcing the agent to re-issue. - let reload_deadline = std::cmp::max( - deadline, - tokio::time::Instant::now() + RELOAD_WAIT_MIN_FLOOR, - ); - match chunk.proposed_rule.as_ref() { - Some(rule) => { - wait_for_local_policy_to_cover(ctx, rule, reload_deadline).await - } - None => false, - } - } else { - // Rejected: no reload semantics — the agent reads - // rejection_reason and redrafts. - false - }; - return (200, chunk_state_payload(&chunk, false, policy_reloaded)); - } - Ok(Some(chunk)) => { - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - if remaining.is_zero() { - return (200, chunk_state_payload(&chunk, true, false)); - } - let sleep_for = std::cmp::min(remaining, PROPOSAL_WAIT_POLL_INTERVAL); - tokio::time::sleep(sleep_for).await; - } - Ok(None) => return chunk_not_found_payload(chunk_id), - Err(err) => return err, - } - } -} - -fn chunk_not_found_payload(chunk_id: &str) -> (u16, serde_json::Value) { - ( - 404, - error_payload( - "chunk_not_found", - format!("chunk '{chunk_id}' is not present in this sandbox's draft policy"), - ), - ) -} - -async fn fetch_chunk_or_404( - session: &LookupSession<'_>, - chunk_id: &str, - timed_out: bool, -) -> (u16, serde_json::Value) { - match fetch_chunk(session, chunk_id).await { - Ok(Some(chunk)) => (200, chunk_state_payload(&chunk, timed_out, false)), - Ok(None) => chunk_not_found_payload(chunk_id), - Err(err) => err, - } -} - -/// Build the agent-facing response for a chunk. -/// -/// Selection rule: include the fields the agent needs to decide what to do -/// next on the redraft loop — identity (`chunk_id`, `status`), the proposal -/// it submitted (`rule_name`, `binary`), the two feedback signals -/// (`rejection_reason` from the reviewer, `validation_result` from the -/// gateway prover), and (on /wait) `policy_reloaded` so the agent can tell -/// "approved AND the new rule is loaded — safe to retry" from "approved -/// but the supervisor hasn't reloaded yet — re-issue /wait or surface to -/// user". Display-only proto fields (`hit_count`, `confidence`, `stage`, -/// timing) are left off until a concrete agent need surfaces them. -fn chunk_state_payload( - chunk: &PolicyChunk, - timed_out: bool, - policy_reloaded: bool, -) -> serde_json::Value { - let mut payload = serde_json::json!({ - "chunk_id": chunk.id, - "status": chunk.status, - "rule_name": chunk.rule_name, - "binary": chunk.binary, - "rejection_reason": chunk.rejection_reason, - "validation_result": chunk.validation_result, - }); - if timed_out { - payload["timed_out"] = serde_json::json!(true); - } - if chunk.status == "approved" { - payload["policy_reloaded"] = serde_json::json!(policy_reloaded); - } - payload -} - -fn is_terminal_status(status: &str) -> bool { - matches!(status, "approved" | "rejected") -} - -/// After a chunk is approved upstream, wait until the local supervisor has -/// loaded a policy that semantically contains the chunk's proposed rule. -/// Returns `true` if coverage was observed before the deadline, `false` -/// otherwise — the caller reports that bool back to the agent as -/// `policy_reloaded` so it can decide whether to retry immediately or -/// re-issue `/wait`. -/// -/// Why rule-coverage instead of whole-policy diff (as we used to do): -/// -/// 1. **False sleep.** If the agent re-issues `/wait` after a `timed_out` -/// response, the chunk may have approved AND the supervisor may have -/// reloaded between the two `/wait` calls. A diff-based check snapshots -/// the already-updated policy as baseline and then waits forever for -/// another change. The skill tells the agent to re-issue on -/// `timed_out`, so the diff approach is broken on the happy path. -/// 2. **False wakeup.** Any unrelated reload (another agent's approval, -/// settings change) flips a whole-policy diff, but the chunk's actual -/// rule may not be loaded yet. The agent retries, hits another -/// `policy_denied`, and the revise-loop fires with no real signal to -/// revise on. -/// -/// The polling cadence here is faster than `PROPOSAL_WAIT_POLL_INTERVAL` -/// (which paces upstream gateway calls). This loop only reads in-memory -/// state, so 200ms gives a responsive handoff to the agent's retry once -/// the supervisor's own policy poll catches up. -async fn wait_for_local_policy_to_cover( - ctx: &PolicyLocalContext, - proposed_rule: &NetworkPolicyRule, - deadline: tokio::time::Instant, -) -> bool { - const TICK: std::time::Duration = std::time::Duration::from_millis(200); - loop { - // Clone the snapshot out of the RwLock before running coverage — - // otherwise the read guard is held across `policy_covers_rule`'s - // iteration of `network_policies`, serializing a writer (supervisor - // reload) on the very thing we're waiting for. Clone-per-tick on - // a few-KB struct is cheap for the bounded wait window here. - let snapshot = ctx.current_policy.read().await.clone(); - if let Some(policy) = snapshot.as_ref() - && openshell_policy::policy_covers_rule(policy, proposed_rule) - { - return true; - } - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - if remaining.is_zero() { - return false; - } - tokio::time::sleep(std::cmp::min(remaining, TICK)).await; - } -} - -/// Parse `?timeout=` from the query string. Default applies for missing -/// or unparseable values; bounds clamp to keep the agent's hold ceiling -/// sane. Re-issue is the right pattern for longer waits. -fn parse_timeout_query(query: &str) -> u64 { - let raw = query - .split('&') - .filter_map(|kv| kv.split_once('=')) - .find(|(k, _)| *k == "timeout") - .map_or("", |(_, v)| v); - raw.parse::() - .unwrap_or(PROPOSAL_WAIT_DEFAULT_SECS) - .clamp(PROPOSAL_WAIT_MIN_SECS, PROPOSAL_WAIT_MAX_SECS) -} - -/// One connected gateway client + the validated sandbox name. Built once -/// per request and reused for every `fetch_chunk` call in a wait loop so a -/// 60-second wait does one TLS handshake, not sixty. -struct LookupSession<'a> { - client: crate::grpc_client::CachedOpenShellClient, - sandbox_name: &'a str, -} - -/// Validate ctx and open one gateway channel. Failures map to the canonical -/// error payload shape used by both `/proposals/{id}` and `/wait`. -async fn open_lookup_session( - ctx: &PolicyLocalContext, -) -> std::result::Result, (u16, serde_json::Value)> { - let endpoint = ctx.gateway_endpoint.as_deref().ok_or_else(|| { - ( - 503, - error_payload( - "gateway_unavailable", - "proposal state lookup requires a gateway-connected sandbox".to_string(), - ), - ) - })?; - let sandbox_name = ctx - .sandbox_name - .as_deref() - .map(str::trim) - .filter(|name| !name.is_empty()) - .ok_or_else(|| { - ( - 503, - error_payload( - "sandbox_name_unavailable", - "proposal state lookup requires a sandbox name".to_string(), - ), - ) - })?; - let client = crate::grpc_client::CachedOpenShellClient::connect(endpoint) - .await - .map_err(|e| (502, error_payload("gateway_connect_failed", e.to_string())))?; - Ok(LookupSession { - client, - sandbox_name, - }) -} - -/// One gateway call: list the sandbox's draft chunks and find the matching -/// id. Returns `Ok(None)` only when the gateway responded successfully but -/// no chunk in this sandbox matches. -async fn fetch_chunk( - session: &LookupSession<'_>, - chunk_id: &str, -) -> std::result::Result, (u16, serde_json::Value)> { - let chunks = session - .client - .get_draft_policy(session.sandbox_name, "") - .await - .map_err(|e| (502, error_payload("gateway_lookup_failed", e.to_string())))?; - Ok(chunks.into_iter().find(|c| c.id == chunk_id)) -} - -fn proposal_chunks_from_body(body: &[u8]) -> std::result::Result, String> { - let request: ProposalRequest = serde_json::from_slice(body).map_err(|e| e.to_string())?; - if request.operations.is_empty() { - return Err("proposal requires at least one operation".to_string()); - } - - let mut chunks = Vec::new(); - for operation in request.operations { - let Some(add_rule) = operation.get("addRule").cloned() else { - return Err( - "this MVP accepts `addRule` operations; submit a full narrow NetworkPolicyRule" - .to_string(), - ); - }; - let add_rule: AddNetworkRuleJson = - serde_json::from_value(add_rule).map_err(|e| e.to_string())?; - chunks.push(policy_chunk_from_add_rule( - add_rule, - request.intent_summary.as_deref().unwrap_or_default(), - )?); - } - - Ok(chunks) -} - -fn policy_chunk_from_add_rule( - add_rule: AddNetworkRuleJson, - intent_summary: &str, -) -> std::result::Result { - let mut rule = network_rule_from_json(add_rule.rule)?; - let rule_name = add_rule - .rule_name - .as_deref() - .map(str::trim) - .filter(|name| !name.is_empty()) - .map_or_else(|| rule.name.clone(), ToString::to_string); - if rule_name.trim().is_empty() { - return Err("addRule.ruleName or rule.name is required".to_string()); - } - if rule.name.trim().is_empty() { - rule.name.clone_from(&rule_name); - } - - let binary = rule - .binaries - .first() - .map(|binary| binary.path.clone()) - .unwrap_or_default(); - - Ok(PolicyChunk { - id: String::new(), - status: "pending".to_string(), - rule_name, - proposed_rule: Some(rule), - rationale: intent_summary.to_string(), - security_notes: String::new(), - confidence: 0.75, - denial_summary_ids: vec![], - created_at_ms: 0, - decided_at_ms: 0, - stage: "agent".to_string(), - supersedes_chunk_id: String::new(), - hit_count: 1, - first_seen_ms: 0, - last_seen_ms: 0, - binary, - validation_result: String::new(), - rejection_reason: String::new(), - }) -} - -fn network_rule_from_json( - rule: NetworkPolicyRuleJson, -) -> std::result::Result { - if rule.endpoints.is_empty() { - return Err("rule.endpoints must contain at least one endpoint".to_string()); - } - - let endpoints = rule - .endpoints - .into_iter() - .map(|endpoint| { - let mut endpoint = network_endpoint_from_json(endpoint)?; - endpoint.advisor_proposed = true; - Ok::(endpoint) - }) - .collect::, _>>()?; - let binaries = rule - .binaries - .into_iter() - .map(|binary| { - let mut proposal_binary = NetworkBinary { - path: binary.path, - ..Default::default() - }; - // The deprecated harness bit is ignored by policy YAML, but OPA - // maps it to advisor_proposed to preserve the SSRF two-step flow. - #[allow(deprecated)] - { - proposal_binary.harness = true; - } - proposal_binary - }) - .collect(); - - Ok(NetworkPolicyRule { - name: rule.name.unwrap_or_default(), - endpoints, - binaries, - }) -} - -fn network_endpoint_from_json( - endpoint: NetworkEndpointJson, -) -> std::result::Result { - if endpoint.host.trim().is_empty() { - return Err("endpoint.host is required".to_string()); - } - - let mut ports = endpoint.ports; - if ports.is_empty() && endpoint.port > 0 { - ports.push(endpoint.port); - } - if ports.is_empty() { - return Err("endpoint.port or endpoint.ports is required".to_string()); - } - if endpoint - .rules - .iter() - .any(|rule| rule.allow.path.contains('?')) - { - return Err("L7 allow paths must not include query strings".to_string()); - } - - let port = ports.first().copied().unwrap_or_default(); - let rules = endpoint - .rules - .into_iter() - .map(|rule| L7Rule { - allow: Some(L7Allow { - method: rule.allow.method, - path: rule.allow.path, - command: rule.allow.command, - query: HashMap::new(), - // GraphQL fields default empty — agent-authored proposals from - // policy.local target REST/SQL/L4 endpoints; GraphQL operation - // matching is set on the policy server side or via direct YAML. - operation_type: String::new(), - operation_name: String::new(), - fields: Vec::new(), - rpc_method: String::new(), - params: HashMap::new(), - }), - }) - .collect(); - let deny_rules = endpoint - .deny_rules - .into_iter() - .map(|rule| L7DenyRule { - method: rule.method, - path: rule.path, - command: rule.command, - query: HashMap::new(), - operation_type: String::new(), - operation_name: String::new(), - fields: Vec::new(), - rpc_method: String::new(), - params: HashMap::new(), - }) - .collect(); - - Ok(NetworkEndpoint { - host: endpoint.host, - port, - protocol: endpoint.protocol, - tls: endpoint.tls, - enforcement: endpoint.enforcement, - access: endpoint.access, - rules, - allowed_ips: endpoint.allowed_ips, - ports, - deny_rules, - allow_encoded_slash: endpoint.allow_encoded_slash, - websocket_credential_rewrite: false, - request_body_credential_rewrite: false, - advisor_proposed: false, - // GraphQL persisted-query knobs and path scoping default empty — - // agent proposals don't author them today. - persisted_queries: String::new(), - graphql_persisted_queries: HashMap::new(), - graphql_max_body_bytes: 0, - json_rpc_max_body_bytes: 0, - path: String::new(), - }) -} - -async fn read_request_body(initial_request: &[u8], client: &mut S) -> Result> -where - S: AsyncRead + Unpin, -{ - let Some(header_end) = find_header_end(initial_request) else { - return Ok(Vec::new()); - }; - let content_length = parse_content_length(&initial_request[..header_end])?; - if content_length > MAX_POLICY_LOCAL_BODY_BYTES { - return Err(miette::miette!( - "policy.local request body exceeds {MAX_POLICY_LOCAL_BODY_BYTES} bytes" - )); - } - - let mut body = initial_request[header_end..].to_vec(); - if body.len() > content_length { - body.truncate(content_length); - } - let read_loop = async { - while body.len() < content_length { - let remaining = content_length - body.len(); - let mut chunk = vec![0u8; remaining.min(8192)]; - let n = client.read(&mut chunk).await.into_diagnostic()?; - if n == 0 { - return Err(miette::miette!("policy.local request body ended early")); - } - body.extend_from_slice(&chunk[..n]); - } - Ok::<(), miette::Report>(()) - }; - tokio::time::timeout(POLICY_LOCAL_BODY_READ_TIMEOUT, read_loop) - .await - .map_err(|_| miette::miette!("policy.local request body read timed out"))??; - - Ok(body) -} - -fn parse_content_length(headers: &[u8]) -> Result { - let headers = String::from_utf8_lossy(headers); - for line in headers.lines().skip(1) { - if let Some((name, value)) = line.split_once(':') - && name.eq_ignore_ascii_case("content-length") - { - return value - .trim() - .parse::() - .into_diagnostic() - .map_err(|_| miette::miette!("invalid policy.local Content-Length")); - } - } - Ok(0) -} - -fn find_header_end(buf: &[u8]) -> Option { - buf.windows(4) - .position(|window| window == b"\r\n\r\n") - .map(|idx| idx + 4) -} - -async fn write_json_response( - client: &mut S, - status: u16, - payload: serde_json::Value, -) -> Result<()> -where - S: AsyncWrite + Unpin, -{ - let body = payload.to_string(); - let response = format!( - "HTTP/1.1 {status} {}\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - Connection: close\r\n\ - \r\n\ - {}", - status_text(status), - body.len(), - body - ); - client - .write_all(response.as_bytes()) - .await - .into_diagnostic()?; - client.flush().await.into_diagnostic()?; - Ok(()) -} - -fn status_text(status: u16) -> &'static str { - match status { - 202 => "Accepted", - 400 => "Bad Request", - 404 => "Not Found", - 500 => "Internal Server Error", - 502 => "Bad Gateway", - 503 => "Service Unavailable", - _ => "OK", - } -} - -fn error_payload(error: &str, detail: String) -> serde_json::Value { - serde_json::json!({ - "error": error, - "detail": detail - }) -} - -#[derive(Debug, Deserialize)] -struct ProposalRequest { - #[serde(default)] - intent_summary: Option, - #[serde(default)] - operations: Vec, -} - -#[derive(Debug, Deserialize)] -struct AddNetworkRuleJson { - #[serde(default, rename = "ruleName")] - rule_name: Option, - rule: NetworkPolicyRuleJson, -} - -#[derive(Debug, Deserialize)] -struct NetworkPolicyRuleJson { - #[serde(default)] - name: Option, - #[serde(default)] - endpoints: Vec, - #[serde(default)] - binaries: Vec, -} - -#[derive(Debug, Deserialize)] -struct NetworkEndpointJson { - host: String, - #[serde(default)] - port: u32, - #[serde(default)] - ports: Vec, - #[serde(default)] - protocol: String, - #[serde(default)] - tls: String, - #[serde(default)] - enforcement: String, - #[serde(default)] - access: String, - #[serde(default)] - rules: Vec, - #[serde(default)] - allowed_ips: Vec, - #[serde(default)] - deny_rules: Vec, - #[serde(default)] - allow_encoded_slash: bool, -} - -#[derive(Debug, Deserialize)] -struct NetworkBinaryJson { - path: String, -} - -#[derive(Debug, Deserialize)] -struct L7RuleJson { - allow: L7AllowJson, -} - -#[derive(Debug, Deserialize)] -struct L7AllowJson { - #[serde(default)] - method: String, - #[serde(default)] - path: String, - #[serde(default)] - command: String, -} - -#[derive(Debug, Deserialize)] -struct L7DenyRuleJson { - #[serde(default)] - method: String, - #[serde(default)] - path: String, - #[serde(default)] - command: String, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn proposal_chunks_from_body_accepts_add_rule_operation() { - let body = br#"{ - "intent_summary": "Allow gh to create one repo.", - "operations": [ - { - "addRule": { - "ruleName": "github_api_repo_create", - "rule": { - "endpoints": [ - { - "host": "api.github.com", - "port": 443, - "protocol": "rest", - "tls": "terminate", - "enforcement": "enforce", - "rules": [ - { - "allow": { - "method": "POST", - "path": "/user/repos" - } - } - ] - } - ], - "binaries": [ - { - "path": "/usr/bin/gh" - } - ] - } - } - } - ] - }"#; - - let chunks = proposal_chunks_from_body(body).unwrap(); - - assert_eq!(chunks.len(), 1); - assert_eq!(chunks[0].rule_name, "github_api_repo_create"); - assert_eq!(chunks[0].rationale, "Allow gh to create one repo."); - assert_eq!(chunks[0].binary, "/usr/bin/gh"); - let rule = chunks[0].proposed_rule.as_ref().unwrap(); - assert_eq!(rule.name, "github_api_repo_create"); - assert_eq!(rule.endpoints[0].host, "api.github.com"); - assert_eq!(rule.endpoints[0].port, 443); - assert_eq!(rule.endpoints[0].ports, vec![443]); - assert_eq!(rule.endpoints[0].protocol, "rest"); - #[allow(deprecated)] - { - assert!(rule.binaries[0].harness); - } - assert_eq!( - rule.endpoints[0].rules[0].allow.as_ref().unwrap().path, - "/user/repos" - ); - } - - #[test] - fn proposal_chunks_from_body_rejects_query_in_l7_path() { - let body = br#"{ - "operations": [ - { - "addRule": { - "ruleName": "bad", - "rule": { - "endpoints": [ - { - "host": "api.github.com", - "port": 443, - "rules": [ - { - "allow": { - "method": "GET", - "path": "/repos?token=secret" - } - } - ] - } - ] - } - } - } - ] - }"#; - - let error = proposal_chunks_from_body(body).unwrap_err(); - assert!(error.contains("query strings")); - assert!(!error.contains("secret")); - } - - #[test] - fn parse_last_query_clamps_to_max() { - assert_eq!(parse_last_query("last=5"), Some(5)); - assert_eq!(parse_last_query("foo=bar&last=20"), Some(20)); - assert_eq!(parse_last_query("last=999"), Some(MAX_DENIALS_LIMIT)); - assert_eq!(parse_last_query("last=0"), Some(1)); - assert_eq!(parse_last_query(""), None); - assert_eq!(parse_last_query("other=1"), None); - } - - #[test] - fn is_ocsf_denial_line_filters_correctly() { - // OCSF denial — match. - assert!(is_ocsf_denial_line( - "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com:443/x [policy:p engine:l7]" - )); - assert!(is_ocsf_denial_line( - "2026-05-06T17:02:00.000Z OCSF NET:OPEN [MED] DENIED curl(42) -> blocked.com:443 [policy:- engine:opa]" - )); - - // OCSF allowed — must not match. - assert!(!is_ocsf_denial_line( - "2026-05-06T17:02:00.000Z OCSF NET:OPEN [INFO] ALLOWED curl(42) -> api.example.com:443" - )); - - // Non-OCSF tracing line — must not match even if it contains the word DENIED. - assert!(!is_ocsf_denial_line( - "2026-05-06T17:02:00.000Z INFO some::module: request DENIED in upstream" - )); - - // Empty line — must not match. - assert!(!is_ocsf_denial_line("")); - } - - #[tokio::test] - async fn recent_denials_returns_newest_first_from_shorthand_lines() { - let dir = tempfile::tempdir().unwrap(); - let log_path = dir.path().join("openshell.2026-05-06.log"); - // Mixed file: allowed events, non-OCSF info lines, two denials. - // Lines are written in chronological order; reader walks newest-first. - let body = "\ -2026-05-06T17:02:00.000Z OCSF NET:OPEN [INFO] ALLOWED curl(10) -> api.example.com:443 [policy:default engine:opa] -2026-05-06T17:02:01.000Z INFO some::module: routine status check -2026-05-06T17:02:02.000Z OCSF HTTP:GET [MED] DENIED GET http://blocked.example/v1/data [policy:default-deny engine:l7] -2026-05-06T17:02:03.000Z OCSF NET:OPEN [INFO] ALLOWED curl(11) -> api.example.com:443 -2026-05-06T17:02:04.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com:443/repos/x/y/contents/z [policy:gh_readonly engine:l7] -"; - std::fs::write(&log_path, body).unwrap(); - - let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); - let (status, payload) = recent_denials_response(&ctx, "last=10").await; - assert_eq!(status, 200); - assert_eq!(payload["log_available"], true); - let denials = payload["denials"].as_array().unwrap(); - assert_eq!(denials.len(), 2); - // Newest first. - assert!(denials[0].as_str().unwrap().contains("HTTP:PUT")); - assert!( - denials[0] - .as_str() - .unwrap() - .contains("/repos/x/y/contents/z") - ); - assert!(denials[1].as_str().unwrap().contains("HTTP:GET")); - assert!(denials[1].as_str().unwrap().contains("blocked.example")); - } - - #[tokio::test] - async fn recent_denials_skips_jsonl_log_files() { - // The shorthand reader must not surface `openshell-ocsf.*.log` content - // even if a deny-looking line is present, so the response stays - // independent of the JSONL appender's enabled state. - let dir = tempfile::tempdir().unwrap(); - let jsonl = dir.path().join("openshell-ocsf.2026-05-06.log"); - std::fs::write( - &jsonl, - r#"{"class_uid":4002,"action_id":2,"message":"DENIED","time":1}"#, - ) - .unwrap(); - - let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); - let (status, payload) = recent_denials_response(&ctx, "").await; - assert_eq!(status, 200); - assert_eq!(payload["log_available"], false); - assert_eq!(payload["denials"].as_array().unwrap().len(), 0); - } - - #[tokio::test] - async fn recent_denials_signals_when_log_is_missing() { - let dir = tempfile::tempdir().unwrap(); - let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); - let (status, payload) = recent_denials_response(&ctx, "").await; - assert_eq!(status, 200); - assert_eq!(payload["log_available"], false); - assert_eq!(payload["denials"].as_array().unwrap().len(), 0); - assert!( - payload["note"] - .as_str() - .unwrap() - .contains("/var/log/openshell.") - ); - } - - #[test] - fn redact_query_strings_removes_query_from_url_token() { - let line = "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com/x?access_token=secret-token-1234 [policy:p engine:l7]"; - let redacted = redact_query_strings(line); - assert!(!redacted.contains("secret-token-1234")); - assert!(!redacted.contains("access_token")); - assert!(redacted.contains("?[redacted]")); - // Bracketed tag after the URL preserved. - assert!(redacted.contains("[policy:p engine:l7]")); - } - - #[test] - fn redact_query_strings_removes_query_in_reason_tag() { - // The FORWARD deny path's `message` becomes `[reason:...]` and may - // include a path with query string lacking a `://` prefix. - let line = "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://api.github.com/x [policy:p engine:opa] [reason:FORWARD denied PUT api.github.com:443/x?token=secret-456]"; - let redacted = redact_query_strings(line); - assert!(!redacted.contains("secret-456")); - assert!(!redacted.contains("token=secret")); - assert!(redacted.contains("?[redacted]]")); - } - - #[test] - fn redact_query_strings_handles_multibyte_chars() { - let line = "ÜLÅUTF8 ? secret-x [policy:p]"; - // No `?` here, so no redaction — but must not panic. - let _ = redact_query_strings(line); - } - - #[test] - fn truncate_at_char_boundary_does_not_panic_on_multibyte() { - // 4-byte emoji sequence so byte-naive slicing would panic. - let s = "🚀".repeat(2000); // 8000 bytes - let truncated = truncate_at_char_boundary(&s, 4096); - assert!(truncated.len() <= 4096 + "...[truncated]".len()); - assert!(truncated.ends_with("...[truncated]")); - // Result must be valid UTF-8 — implicit if we return without panic. - } - - #[tokio::test] - async fn recent_denials_truncates_pathological_lines() { - let dir = tempfile::tempdir().unwrap(); - let log_path = dir.path().join("openshell.2026-05-06.log"); - // A single OCSF denial line exceeding MAX_DENIAL_LINE_BYTES. - let huge_path = "/".to_string() + &"a".repeat(MAX_DENIAL_LINE_BYTES + 100); - let line = format!( - "2026-05-06T17:02:00.000Z OCSF HTTP:PUT [MED] DENIED PUT http://x{huge_path} [policy:p engine:l7]\n" - ); - std::fs::write(&log_path, line).unwrap(); - - let ctx = PolicyLocalContext::with_log_dir(None, None, None, dir.path().to_path_buf()); - let (_, payload) = recent_denials_response(&ctx, "last=1").await; - let denials = payload["denials"].as_array().unwrap(); - assert_eq!(denials.len(), 1); - let surfaced = denials[0].as_str().unwrap(); - assert!(surfaced.len() <= MAX_DENIAL_LINE_BYTES + "...[truncated]".len()); - assert!(surfaced.ends_with("...[truncated]")); - } - - use crate::test_helpers::ProposalsFlagGuard; - - #[test] - fn agent_next_steps_returns_empty_when_flag_off() { - let _guard = ProposalsFlagGuard::set_blocking(false); - let steps = agent_next_steps(); - let arr = steps.as_array().expect("agent_next_steps is an array"); - assert!( - arr.is_empty(), - "expected empty next_steps when feature is off, got {steps}" - ); - } - - #[test] - fn agent_next_steps_returns_full_array_when_flag_on() { - let _guard = ProposalsFlagGuard::set_blocking(true); - let steps = agent_next_steps(); - let arr = steps.as_array().expect("agent_next_steps is an array"); - assert_eq!(arr.len(), 4, "expected 4 next_steps when feature is on"); - let actions: Vec<&str> = arr - .iter() - .filter_map(|v| v.get("action").and_then(serde_json::Value::as_str)) - .collect(); - assert!(actions.contains(&"read_skill")); - assert!(actions.contains(&"submit_proposal")); - } - - #[test] - fn agent_guidance_is_absent_when_flag_off() { - let _guard = ProposalsFlagGuard::set_blocking(false); - assert!(agent_guidance().is_none()); - } - - #[test] - fn agent_guidance_points_to_policy_advisor_when_flag_on() { - let _guard = ProposalsFlagGuard::set_blocking(true); - let guidance = agent_guidance().expect("guidance when proposals are enabled"); - assert!(guidance.contains("do not stop")); - assert!(guidance.contains("/etc/openshell/skills/policy_advisor.md")); - assert!(guidance.contains("http://policy.local/v1/proposals")); - assert!(guidance.contains("policy_reloaded: true")); - } - - #[tokio::test] - async fn route_request_returns_feature_disabled_when_flag_off() { - let _guard = ProposalsFlagGuard::set(false).await; - let ctx = PolicyLocalContext::new( - Some(ProtoSandboxPolicy { - version: 1, - ..Default::default() - }), - None, - None, - ); - - // Even the otherwise-public `current_policy` route returns 404 with - // a feature_disabled error: when the surface is off it's off - // entirely, not selectively. - let (status, payload) = route_request(&ctx, "GET", ROUTE_POLICY_CURRENT, &[]).await; - assert_eq!(status, 404); - assert_eq!(payload["error"], "feature_disabled"); - assert!( - payload["detail"] - .as_str() - .unwrap() - .contains("agent_policy_proposals_enabled"), - "feature_disabled detail must name the setting key for actionability" - ); - } - - #[tokio::test] - async fn current_policy_route_returns_yaml_envelope() { - let _guard = ProposalsFlagGuard::set(true).await; - let ctx = PolicyLocalContext::new( - Some(ProtoSandboxPolicy { - version: 1, - ..Default::default() - }), - None, - None, - ); - - let (mut client, mut server) = tokio::io::duplex(4096); - let request = - b"GET http://policy.local/v1/policy/current HTTP/1.1\r\nHost: policy.local\r\n\r\n"; - let task = tokio::spawn(async move { - handle_forward_request(&ctx, "GET", "/v1/policy/current", request, &mut server) - .await - .unwrap(); - }); - - let mut received = Vec::new(); - client.read_to_end(&mut received).await.unwrap(); - task.await.unwrap(); - - let response = String::from_utf8(received).unwrap(); - assert!(response.starts_with("HTTP/1.1 200 OK")); - let (_, body) = response.split_once("\r\n\r\n").unwrap(); - let body: serde_json::Value = serde_json::from_str(body).unwrap(); - assert_eq!(body["format"], "yaml"); - assert!(body["policy_yaml"].as_str().unwrap().contains("version: 1")); - } - - #[test] - fn parse_timeout_query_defaults_and_clamps() { - assert_eq!(parse_timeout_query(""), PROPOSAL_WAIT_DEFAULT_SECS); - assert_eq!(parse_timeout_query("timeout="), PROPOSAL_WAIT_DEFAULT_SECS); - assert_eq!( - parse_timeout_query("timeout=abc"), - PROPOSAL_WAIT_DEFAULT_SECS - ); - assert_eq!(parse_timeout_query("timeout=30"), 30); - assert_eq!(parse_timeout_query("foo=1&timeout=45"), 45); - // Below floor clamps up; above ceiling clamps down. - assert_eq!(parse_timeout_query("timeout=0"), PROPOSAL_WAIT_MIN_SECS); - assert_eq!(parse_timeout_query("timeout=9999"), PROPOSAL_WAIT_MAX_SECS); - } - - #[test] - fn is_terminal_status_matches_only_approved_and_rejected() { - assert!(!is_terminal_status("pending")); - assert!(is_terminal_status("approved")); - assert!(is_terminal_status("rejected")); - assert!(!is_terminal_status("")); - } - - #[test] - fn chunk_state_payload_surfaces_loop_fields() { - let chunk = PolicyChunk { - id: "chunk-x".to_string(), - status: "rejected".to_string(), - rule_name: "allow_example".to_string(), - binary: "/usr/bin/curl".to_string(), - rejection_reason: "scope too broad".to_string(), - validation_result: "no exfil paths".to_string(), - ..Default::default() - }; - let pending = chunk_state_payload(&chunk, false, false); - assert_eq!(pending["chunk_id"], "chunk-x"); - assert_eq!(pending["status"], "rejected"); - assert_eq!(pending["rejection_reason"], "scope too broad"); - assert_eq!(pending["validation_result"], "no exfil paths"); - // timed_out and policy_reloaded only appear when relevant. - assert!(pending.get("timed_out").is_none()); - assert!( - pending.get("policy_reloaded").is_none(), - "policy_reloaded is only meaningful for approved chunks" - ); - - let timed = chunk_state_payload(&chunk, true, false); - assert_eq!(timed["timed_out"], true); - } - - #[test] - fn chunk_state_payload_includes_policy_reloaded_when_approved() { - let chunk = PolicyChunk { - id: "chunk-y".to_string(), - status: "approved".to_string(), - rule_name: "allow_github".to_string(), - binary: "/usr/bin/curl".to_string(), - ..Default::default() - }; - let reloaded = chunk_state_payload(&chunk, false, true); - assert_eq!(reloaded["status"], "approved"); - assert_eq!(reloaded["policy_reloaded"], true); - - let not_reloaded = chunk_state_payload(&chunk, false, false); - assert_eq!(not_reloaded["policy_reloaded"], false); - } - - #[tokio::test] - async fn proposal_routes_reject_malformed_paths() { - let _guard = ProposalsFlagGuard::set(true).await; - let ctx = PolicyLocalContext::new(None, None, None); - - // Empty chunk_id after the prefix is 404, not a wildcard list. - let (status, _) = route_request(&ctx, "GET", "/v1/proposals/", &[]).await; - assert_eq!(status, 404); - - // More than one segment after the id (not "/wait") is 404, not a - // partial match. Prevents `/v1/proposals/abc/extra` from silently - // dispatching as a status lookup for "abc/extra". - let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/extra", &[]).await; - assert_eq!(status, 404); - - // Trailing path after `/wait` also 404 — must not match the wait - // arm as a wildcard. - let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/wait/extra", &[]).await; - assert_eq!(status, 404); - } - - #[tokio::test] - async fn proposal_status_route_returns_503_when_no_gateway() { - let _guard = ProposalsFlagGuard::set(true).await; - let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); - - let (status, body) = route_request(&ctx, "GET", "/v1/proposals/chunk-id", &[]).await; - assert_eq!(status, 503); - assert_eq!(body["error"], "gateway_unavailable"); - } - - #[tokio::test] - async fn proposal_wait_route_returns_503_when_no_gateway() { - let _guard = ProposalsFlagGuard::set(true).await; - let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); - - let (status, body) = - route_request(&ctx, "GET", "/v1/proposals/chunk-id/wait?timeout=1", &[]).await; - assert_eq!(status, 503); - assert_eq!(body["error"], "gateway_unavailable"); - } - - #[tokio::test] - async fn proposal_routes_return_feature_disabled_when_flag_off() { - let _guard = ProposalsFlagGuard::set(false).await; - let ctx = PolicyLocalContext::new(None, None, Some("test-sandbox".to_string())); - - let (status, body) = route_request(&ctx, "GET", "/v1/proposals/abc", &[]).await; - assert_eq!(status, 404); - assert_eq!(body["error"], "feature_disabled"); - - let (status, _) = route_request(&ctx, "GET", "/v1/proposals/abc/wait", &[]).await; - assert_eq!(status, 404); - } - - #[test] - fn summarize_chunk_for_audit_includes_endpoint_l7_path_and_binary() { - let chunk = PolicyChunk { - id: "ignored".to_string(), - rule_name: "github_write".to_string(), - binary: "/usr/bin/curl".to_string(), - proposed_rule: Some(NetworkPolicyRule { - name: "github_write".to_string(), - endpoints: vec![NetworkEndpoint { - host: "api.github.com".to_string(), - port: 443, - rules: vec![L7Rule { - allow: Some(L7Allow { - method: "PUT".to_string(), - path: "/repos/foo/bar/contents/x.md".to_string(), - ..Default::default() - }), - }], - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/curl".to_string(), - ..Default::default() - }], - }), - ..Default::default() - }; - let summary = summarize_chunk_for_audit(&chunk); - assert!(summary.contains("api.github.com:443")); - assert!(summary.contains("PUT /repos/foo/bar/contents/x.md")); - assert!(summary.contains("/usr/bin/curl")); - } - - // Helpers — synthetic proposed rule + policy with that rule already - // merged. Both reused across reload-readiness tests. - fn proposed_curl_rule_for_github() -> NetworkPolicyRule { - NetworkPolicyRule { - name: "agent_proposed".to_string(), - endpoints: vec![NetworkEndpoint { - host: "api.github.com".to_string(), - port: 443, - ports: vec![443], - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/curl".to_string(), - ..Default::default() - }], - } - } - - fn policy_with_rule(rule: NetworkPolicyRule) -> ProtoSandboxPolicy { - ProtoSandboxPolicy { - version: 1, - network_policies: HashMap::from([(rule.name.clone(), rule)]), - ..Default::default() - } - } - - #[tokio::test] - async fn wait_returns_reloaded_true_when_rule_already_loaded() { - // John's false-sleep case: the supervisor has already reloaded a - // policy containing the proposed rule before /wait starts. A - // whole-policy diff would never see another change and burn the - // full timeout. Rule-coverage must return immediately. - let proposed = proposed_curl_rule_for_github(); - let ctx = PolicyLocalContext::new(Some(policy_with_rule(proposed.clone())), None, None); - let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2); - - let start = tokio::time::Instant::now(); - let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; - let elapsed = start.elapsed(); - - assert!(reloaded, "should report reloaded=true on coverage"); - assert!( - elapsed < std::time::Duration::from_millis(200), - "should return immediately, not poll-and-wait; took {elapsed:?}" - ); - } - - #[tokio::test] - async fn wait_does_not_wake_on_unrelated_policy_change() { - // John's false-wakeup case: a *different* rule gets added to the - // local policy (other agent's approval, settings change, etc.). - // The agent's specific rule is still not loaded. A diff-based - // check would wake here; coverage must not. - let proposed = proposed_curl_rule_for_github(); - // Start with a policy that does NOT contain the proposed rule. - let initial = ProtoSandboxPolicy { - version: 1, - ..Default::default() - }; - let ctx = PolicyLocalContext::new(Some(initial), None, None); - - // Concurrently, an unrelated rule lands. We must not return. - let unrelated_load = { - let policy = ctx.current_policy.clone(); - tokio::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - *policy.write().await = Some(policy_with_rule(NetworkPolicyRule { - name: "unrelated".to_string(), - endpoints: vec![NetworkEndpoint { - host: "api.example.com".to_string(), - port: 443, - ports: vec![443], - ..Default::default() - }], - binaries: vec![NetworkBinary { - path: "/usr/bin/curl".to_string(), - ..Default::default() - }], - })); - }) - }; - - let deadline = tokio::time::Instant::now() + std::time::Duration::from_millis(400); - let start = tokio::time::Instant::now(); - let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; - unrelated_load.await.unwrap(); - let elapsed = start.elapsed(); - - assert!( - !reloaded, - "must not wake on an unrelated reload; coverage was never satisfied" - ); - assert!( - elapsed >= std::time::Duration::from_millis(350), - "should have held until the deadline; only waited {elapsed:?}" - ); - } - - #[tokio::test] - async fn wait_wakes_when_matching_rule_arrives_mid_flight() { - // Sandbox starts without the rule, then a reload lands containing - // it. /wait should observe coverage and return reloaded=true. - let proposed = proposed_curl_rule_for_github(); - let ctx = PolicyLocalContext::new( - Some(ProtoSandboxPolicy { - version: 1, - ..Default::default() - }), - None, - None, - ); - - let matching_load = { - let policy = ctx.current_policy.clone(); - let target = proposed.clone(); - tokio::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(100)).await; - *policy.write().await = Some(policy_with_rule(target)); - }) - }; - - let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2); - let start = tokio::time::Instant::now(); - let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; - matching_load.await.unwrap(); - let elapsed = start.elapsed(); - - assert!(reloaded, "should report reloaded=true after coverage lands"); - assert!( - elapsed < std::time::Duration::from_millis(800), - "should return shortly after coverage; took {elapsed:?}" - ); - } - - #[tokio::test] - async fn wait_returns_reloaded_false_at_deadline_when_no_coverage() { - // Deadline budget exhausted, the proposed rule never showed up. - // Coverage check returns false — the agent gets policy_reloaded= - // false and decides whether to retry blind or re-issue /wait. - let proposed = proposed_curl_rule_for_github(); - let ctx = PolicyLocalContext::new( - Some(ProtoSandboxPolicy { - version: 1, - ..Default::default() - }), - None, - None, - ); - let deadline = tokio::time::Instant::now() + std::time::Duration::from_millis(300); - let start = tokio::time::Instant::now(); - let reloaded = wait_for_local_policy_to_cover(&ctx, &proposed, deadline).await; - let elapsed = start.elapsed(); - - assert!(!reloaded); - assert!( - elapsed >= std::time::Duration::from_millis(250), - "should wait until ~deadline; only waited {elapsed:?}" - ); - assert!( - elapsed < std::time::Duration::from_millis(800), - "should not extend past deadline by much; took {elapsed:?}" - ); - } - - #[test] - fn sanitize_reason_for_audit_strips_control_chars_and_caps_length() { - // Tabs and newlines are stripped; ordinary printable chars survive; - // multi-byte characters count as one char in the cap. - let raw = "line one\nline\ttwo\u{0001}\u{0007}"; - let cleaned = sanitize_reason_for_audit(raw); - assert!(!cleaned.contains('\n')); - assert!(!cleaned.contains('\t')); - assert!(!cleaned.contains('\u{0001}')); - assert!(cleaned.contains("line one")); - assert!(cleaned.contains("linetwo")); - - // Length cap with ellipsis marker so a downstream reader can tell - // the audit string is truncated. - let long: String = "x".repeat(500); - let capped = sanitize_reason_for_audit(&long); - assert!(capped.chars().count() <= 201); - assert!(capped.ends_with('…')); - - // Empty input maps to empty output (caller renders "(no guidance)"). - assert_eq!(sanitize_reason_for_audit(""), ""); - } - - #[test] - fn summarize_chunk_for_audit_falls_back_to_rule_name_without_rule() { - let chunk = PolicyChunk { - rule_name: "fallback".to_string(), - proposed_rule: None, - ..Default::default() - }; - assert_eq!(summarize_chunk_for_audit(&chunk), "rule_name:fallback"); - } -} diff --git a/crates/openshell-sandbox/src/proxy.rs b/crates/openshell-sandbox/src/proxy.rs deleted file mode 100644 index d6a1807c7..000000000 --- a/crates/openshell-sandbox/src/proxy.rs +++ /dev/null @@ -1,7718 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! HTTP CONNECT proxy with OPA policy evaluation and process-identity binding. - -use crate::activity_aggregator::{ActivitySender, try_record_activity}; -use crate::denial_aggregator::DenialEvent; -use crate::identity::BinaryIdentityCache; -use crate::l7::tls::ProxyTlsState; -use crate::opa::{NetworkAction, OpaEngine, PolicyGenerationGuard}; -use crate::policy::ProxyPolicy; -use crate::policy_local::{POLICY_LOCAL_HOST, PolicyLocalContext}; -use crate::provider_credentials::ProviderCredentialState; -use crate::secrets::{SecretResolver, rewrite_header_line_checked}; -use miette::{IntoDiagnostic, Result}; -use openshell_core::net::{is_always_blocked_ip, is_internal_ip, is_link_local_ip}; -use openshell_ocsf::{ - ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, - NetworkActivityBuilder, Process, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, -}; -use std::net::{IpAddr, SocketAddr}; -use std::path::PathBuf; -use std::sync::Arc; -use std::sync::atomic::{AtomicU32, Ordering}; -use tokio::io::{ - AsyncRead as TokioAsyncRead, AsyncReadExt, AsyncWrite as TokioAsyncWrite, AsyncWriteExt, -}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc; -use tokio::task::JoinHandle; -use tracing::{debug, warn}; - -const MAX_HEADER_BYTES: usize = 8192; -const INFERENCE_LOCAL_HOST: &str = "inference.local"; -const INFERENCE_LOCAL_PORT: u16 = 443; - -/// Hostnames injected by compute drivers as `/etc/hosts` aliases for the host -/// machine. Traffic to these names is eligible for the trusted-gateway SSRF -/// exemption when the resolved IP matches the driver-injected value read from -/// `/etc/hosts` at proxy startup. -const HOST_GATEWAY_ALIASES: &[&str] = &[ - "host.openshell.internal", - "host.containers.internal", - "host.docker.internal", -]; - -/// Cloud instance metadata IPs that are NEVER exempted from SSRF blocking, -/// even when they coincidentally match a host-gateway alias resolution. -/// This list covers the well-known IMDS endpoints across major cloud providers. -const CLOUD_METADATA_IPS: &[IpAddr] = &[ - // AWS / GCP / Azure instance metadata service - IpAddr::V4(std::net::Ipv4Addr::new(169, 254, 169, 254)), -]; - -/// Maximum total bytes for a streaming inference response body (32 MiB). -#[cfg(not(test))] -const MAX_STREAMING_BODY: usize = 32 * 1024 * 1024; -// Keep unit tests deterministic without pushing tens of MiB through loopback. -#[cfg(test)] -const MAX_STREAMING_BODY: usize = 1024; - -/// Idle timeout per chunk when relaying streaming inference responses. -/// -/// Reasoning models (e.g. nemotron-3-super, o1, o3) can pause for 60+ seconds -/// between "thinking" and output phases. 120s provides headroom while still -/// catching genuinely stuck streams. -#[cfg(not(test))] -const CHUNK_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); -// Exercise idle-timeout truncation without slowing the full package test suite. -#[cfg(test)] -const CHUNK_IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100); - -/// Result of a proxy CONNECT policy decision. -struct ConnectDecision { - action: NetworkAction, - /// Policy generation used for the L4 network decision. - generation: u64, - /// Resolved binary path. - binary: Option, - /// PID owning the socket. - binary_pid: Option, - /// Ancestor binary paths from process tree walk. - ancestors: Vec, - /// Cmdline-derived absolute paths (for script detection). - cmdline_paths: Vec, -} - -/// Outcome of an inference interception attempt. -/// -/// Returned by [`handle_inference_interception`] so the call site can emit -/// a structured CONNECT deny log when the connection is not successfully routed. -#[derive(Debug)] -enum InferenceOutcome { - /// At least one request was successfully routed to a local inference backend. - Routed, - /// The connection was denied (TLS failure, non-inference request, etc.). - Denied { reason: String }, -} - -/// Inference routing context for sandbox-local execution. -/// -/// Holds a `Router` (HTTP client) and cached sets of resolved routes. -/// User routes serve `inference.local` traffic; system routes are consumed -/// in-process by the supervisor for platform functions (e.g. agent harness). -pub struct InferenceContext { - pub patterns: Vec, - router: openshell_router::Router, - /// Routes for the user-facing `inference.local` endpoint. - routes: Arc>>, - /// Routes for supervisor-only system inference (`sandbox-system`). - system_routes: Arc>>, -} - -impl InferenceContext { - // `router`/`routes` are intentionally distinct nouns (the router and the - // route list it consumes); both names are clearer than alternatives. - #[allow(clippy::similar_names)] - pub fn new( - patterns: Vec, - router: openshell_router::Router, - routes: Vec, - system_routes: Vec, - ) -> Self { - Self { - patterns, - router, - routes: Arc::new(tokio::sync::RwLock::new(routes)), - system_routes: Arc::new(tokio::sync::RwLock::new(system_routes)), - } - } - - /// Get a handle to the user route cache for background refresh. - pub fn route_cache( - &self, - ) -> Arc>> { - self.routes.clone() - } - - /// Get a handle to the system route cache for background refresh. - pub fn system_route_cache( - &self, - ) -> Arc>> { - self.system_routes.clone() - } - - /// Make an inference call using system routes (supervisor-only). - /// - /// This is the in-process API for platform functions. It bypasses the - /// CONNECT proxy entirely — the supervisor calls the router directly - /// from the host network namespace. - pub async fn system_inference( - &self, - protocol: &str, - method: &str, - path: &str, - headers: Vec<(String, String)>, - body: bytes::Bytes, - ) -> Result { - let routes = self.system_routes.read().await; - self.router - .proxy_with_candidates(protocol, method, path, headers, body, &routes) - .await - } -} - -#[derive(Debug)] -pub struct ProxyHandle { - #[allow(dead_code)] - http_addr: Option, - join: JoinHandle<()>, -} - -impl ProxyHandle { - /// Start the proxy with OPA engine for policy evaluation. - /// - /// The proxy uses OPA for network decisions with process-identity binding - /// via `/proc/net/tcp`. All connections are evaluated through OPA policy. - #[allow(clippy::too_many_arguments)] - pub(crate) async fn start_with_bind_addr( - policy: &ProxyPolicy, - bind_addr: Option, - opa_engine: Arc, - identity_cache: Arc, - entrypoint_pid: Arc, - tls_state: Option>, - inference_ctx: Option>, - provider_credentials: Option, - policy_local_ctx: Option>, - denial_tx: Option>, - activity_tx: Option, - ) -> Result { - // Use override bind_addr, fall back to policy http_addr, then default - // to loopback:3128. The default allows the proxy to function when no - // network namespace is available (e.g. missing CAP_NET_ADMIN) and the - // policy doesn't specify an explicit address. - let default_addr: SocketAddr = ([127, 0, 0, 1], 3128).into(); - let http_addr = bind_addr.or(policy.http_addr).unwrap_or(default_addr); - - // Only enforce loopback restriction when not using network namespace override - if bind_addr.is_none() && !http_addr.ip().is_loopback() { - return Err(miette::miette!( - "Proxy http_addr must be loopback-only: {http_addr}" - )); - } - - let listener = TcpListener::bind(http_addr).await.into_diagnostic()?; - let local_addr = listener.local_addr().into_diagnostic()?; - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Listen) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .dst_endpoint(Endpoint::from_ip(local_addr.ip(), local_addr.port())) - .message(format!("Proxy listening on {local_addr}")) - .build(); - ocsf_emit!(event); - } - - // Detect the trusted host gateway IP from /etc/hosts before user code - // runs. This is read once at startup so later /etc/hosts modifications - // by sandbox workloads cannot influence the stored value. - let trusted_host_gateway: Arc> = Arc::new(detect_trusted_host_gateway()); - if let Some(ref ip) = *trusted_host_gateway { - tracing::info!( - %ip, - "Trusted host gateway detected from /etc/hosts; \ - host-gateway aliases exempt from SSRF always-blocked check" - ); - } - - let join = tokio::spawn(async move { - loop { - match listener.accept().await { - Ok((stream, _addr)) => { - let opa = opa_engine.clone(); - let cache = identity_cache.clone(); - let spid = entrypoint_pid.clone(); - let tls = tls_state.clone(); - let inf = inference_ctx.clone(); - let policy_local = policy_local_ctx.clone(); - let gw = trusted_host_gateway.clone(); - let resolver = provider_credentials - .as_ref() - .and_then(ProviderCredentialState::resolver); - let dynamic_credentials = provider_credentials.as_ref().map(|state| { - Arc::new(std::sync::RwLock::new( - state.snapshot().dynamic_credentials.clone(), - )) - }); - let dtx = denial_tx.clone(); - let atx = activity_tx.clone(); - tokio::spawn(async move { - if let Err(err) = handle_tcp_connection( - stream, - opa, - cache, - spid, - tls, - inf, - policy_local, - gw, - resolver, - dynamic_credentials, - dtx, - atx, - ) - .await - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .message(format!("Proxy connection error: {err}")) - .build(); - ocsf_emit!(event); - } - }); - } - Err(err) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .message(format!("Proxy accept error: {err}")) - .build(); - ocsf_emit!(event); - break; - } - } - } - }); - - Ok(Self { - http_addr: Some(local_addr), - join, - }) - } - - #[allow(dead_code)] - pub const fn http_addr(&self) -> Option { - self.http_addr - } -} - -impl Drop for ProxyHandle { - fn drop(&mut self) { - self.join.abort(); - } -} - -fn emit_activity(tx: &Option, denied: bool, deny_group: &'static str) { - if let Some(tx) = tx { - let _ = try_record_activity(tx, denied, deny_group); - } -} - -fn l7_inspection_active(l7_route: Option<&L7RouteSnapshot>) -> bool { - l7_route.is_some_and(|route| !route.configs.is_empty()) -} - -fn emit_connect_activity_if_l4_only( - tx: &Option, - l7_route: Option<&L7RouteSnapshot>, -) { - if !l7_inspection_active(l7_route) { - emit_activity(tx, false, "unknown"); - } -} - -fn emit_activity_simple(tx: Option<&ActivitySender>, denied: bool, deny_group: &'static str) { - if let Some(tx) = tx { - let _ = try_record_activity(tx, denied, deny_group); - } -} - -fn emit_forward_success_activity(tx: Option<&ActivitySender>, l7_activity_pending: bool) { - emit_activity_simple( - tx, - false, - if l7_activity_pending { - "l7_policy" - } else { - "unknown" - }, - ); -} - -fn l7_parse_error_reason(request_info: &crate::l7::L7RequestInfo) -> Option { - request_info - .graphql - .as_ref() - .and_then(|info| info.error.as_deref()) - .map(|error| format!("GraphQL request rejected: {error}")) - .or_else(|| { - request_info - .jsonrpc - .as_ref() - .and_then(|info| info.error.as_deref()) - .map(|error| format!("JSON-RPC request rejected: {error}")) - }) -} - -/// Emit a denial event to the aggregator channel (if configured). -/// Used by `handle_tcp_connection` which owns `Option`. -fn emit_denial( - tx: &Option>, - host: &str, - port: u16, - binary: &str, - decision: &ConnectDecision, - reason: &str, - stage: &str, -) { - if let Some(tx) = tx { - let _ = tx.send(DenialEvent { - host: host.to_string(), - port, - binary: binary.to_string(), - ancestors: decision - .ancestors - .iter() - .map(|p| p.display().to_string()) - .collect(), - deny_reason: reason.to_string(), - denial_stage: stage.to_string(), - l7_method: None, - l7_path: None, - }); - } -} - -/// Emit a denial event from a borrowed sender reference. -/// Used by `handle_forward_proxy` which borrows `Option<&Sender>`. -fn emit_denial_simple( - tx: Option<&mpsc::UnboundedSender>, - host: &str, - port: u16, - binary: &str, - decision: &ConnectDecision, - reason: &str, - stage: &str, -) { - if let Some(tx) = tx { - let _ = tx.send(DenialEvent { - host: host.to_string(), - port, - binary: binary.to_string(), - ancestors: decision - .ancestors - .iter() - .map(|p| p.display().to_string()) - .collect(), - deny_reason: reason.to_string(), - denial_stage: stage.to_string(), - l7_method: None, - l7_path: None, - }); - } -} - -// Many distinct, non-related context parameters are required for a CONNECT -// dispatch; bundling them into a struct would just shift the noise into call -// sites. -#[allow(clippy::too_many_arguments)] -async fn handle_tcp_connection( - mut client: TcpStream, - opa_engine: Arc, - identity_cache: Arc, - entrypoint_pid: Arc, - tls_state: Option>, - inference_ctx: Option>, - policy_local_ctx: Option>, - trusted_host_gateway: Arc>, - secret_resolver: Option>, - dynamic_credentials: Option< - Arc< - std::sync::RwLock< - std::collections::HashMap, - >, - >, - >, - denial_tx: Option>, - activity_tx: Option, -) -> Result<()> { - let mut buf = vec![0u8; MAX_HEADER_BYTES]; - let mut used = 0usize; - - loop { - if used == buf.len() { - respond( - &mut client, - b"HTTP/1.1 431 Request Header Fields Too Large\r\n\r\n", - ) - .await?; - return Ok(()); - } - - let n = client.read(&mut buf[used..]).await.into_diagnostic()?; - if n == 0 { - return Ok(()); - } - used += n; - - if buf[..used].windows(4).any(|win| win == b"\r\n\r\n") { - break; - } - } - - let request = String::from_utf8_lossy(&buf[..used]); - let mut lines = request.split("\r\n"); - let request_line = lines.next().unwrap_or(""); - let mut parts = request_line.split_whitespace(); - let method = parts.next().unwrap_or(""); - let target = parts.next().unwrap_or(""); - - if method != "CONNECT" { - return handle_forward_proxy( - method, - target, - &buf[..], - used, - &mut client, - opa_engine, - identity_cache, - entrypoint_pid, - policy_local_ctx, - trusted_host_gateway, - secret_resolver, - dynamic_credentials, - denial_tx.as_ref(), - activity_tx.as_ref(), - ) - .await; - } - - let (host, port) = parse_target(target)?; - let host_lc = host.to_ascii_lowercase(); - - if host_lc == INFERENCE_LOCAL_HOST && port == INFERENCE_LOCAL_PORT { - respond(&mut client, b"HTTP/1.1 200 Connection Established\r\n\r\n").await?; - let outcome = handle_inference_interception( - client, - INFERENCE_LOCAL_HOST, - port, - tls_state.as_ref(), - inference_ctx.as_ref(), - ) - .await?; - if let InferenceOutcome::Denied { reason } = outcome { - emit_activity(&activity_tx, true, "forward_policy"); - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, port)) - .message(format!("Inference interception denied: {reason}")) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - return Ok(()); - } - - let peer_addr = client.peer_addr().into_diagnostic()?; - let _local_addr = client.local_addr().into_diagnostic()?; - - // Evaluate OPA policy with process-identity binding. - // Wrapped in spawn_blocking because identity resolution does heavy sync I/O: - // /proc scanning + SHA256 hashing of binaries (e.g. node at 124MB). - let opa_clone = opa_engine.clone(); - let cache_clone = identity_cache.clone(); - let pid_clone = entrypoint_pid.clone(); - let host_clone = host_lc.clone(); - let decision = tokio::task::spawn_blocking(move || { - evaluate_opa_tcp( - peer_addr, - &opa_clone, - &cache_clone, - &pid_clone, - &host_clone, - port, - ) - }) - .await - .map_err(|e| miette::miette!("identity resolution task panicked: {e}"))?; - - // Extract action string and matched policy for logging - let (matched_policy, deny_reason) = match &decision.action { - NetworkAction::Allow { matched_policy } => (matched_policy.clone(), String::new()), - NetworkAction::Deny { reason } => (None, reason.clone()), - }; - - // Build log context fields (shared by deny log below and deferred allow log after L7 check) - let binary_str = decision - .binary - .as_ref() - .map_or_else(|| "-".to_string(), |p| p.display().to_string()); - let pid_str = decision - .binary_pid - .map_or_else(|| "-".to_string(), |p| p.to_string()); - let ancestors_str = if decision.ancestors.is_empty() { - "-".to_string() - } else { - decision - .ancestors - .iter() - .map(|p| p.display().to_string()) - .collect::>() - .join(" -> ") - }; - let cmdline_str = if decision.cmdline_paths.is_empty() { - "-".to_string() - } else { - decision - .cmdline_paths - .iter() - .map(|p| p.display().to_string()) - .collect::>() - .join(", ") - }; - let policy_str = matched_policy.as_deref().unwrap_or("-"); - - // Log denied connections immediately — they never reach L7. - // Allowed connections are logged after the L7 config check (below) - // so we can distinguish CONNECT (L4-only) from CONNECT_L7 (L7 follows). - if matches!(decision.action, NetworkAction::Deny { .. }) { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule("-", "opa") - .message(format!("CONNECT denied {host_lc}:{port}")) - .status_detail(&deny_reason) - .build(); - ocsf_emit!(event); - emit_denial( - &denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &deny_reason, - "connect", - ); - emit_activity(&activity_tx, true, "connect_policy"); - respond( - &mut client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("CONNECT {host_lc}:{port} not permitted by policy"), - ), - ) - .await?; - return Ok(()); - } - - let sandbox_entrypoint_pid = entrypoint_pid.load(Ordering::Acquire); - - // Query allowed_ips from the matched endpoint config (if any). - // When present, the SSRF check validates resolved IPs against this - // allowlist instead of blanket-blocking all private IPs. - // When the policy host is already a literal IP address, treat it as - // implicitly allowed — the user explicitly declared the destination. - // Exact declared hostnames also skip the private-IP blanket block below, - // while keeping loopback/link-local/unspecified addresses denied. - let mut raw_allowed_ips = query_allowed_ips(&opa_engine, &decision, &host_lc, port); - if raw_allowed_ips.is_empty() { - raw_allowed_ips = implicit_allowed_ips_for_ip_host(&host); - } - let exact_declared_endpoint_host = - query_exact_declared_endpoint_host(&opa_engine, &decision, &host_lc, port); - - // Defense-in-depth: resolve DNS and reject connections to internal IPs. - let dns_connect_start = std::time::Instant::now(); - // The "non-empty" branch is the explicit-allowlist path; reading it first - // matches the policy decision narrative. - #[allow(clippy::if_not_else)] - let mut upstream = if is_host_gateway_alias(&host_lc) - && let Some(gw) = *trusted_host_gateway - { - // Trusted host-gateway path. The compute driver injected this hostname - // into /etc/hosts pointing at a known IP (read at proxy startup before - // user code runs). Bypass the normal SSRF tiers so link-local gateway - // addresses (used by rootless Podman with pasta) are not hard-blocked. - // Cloud metadata IPs and control-plane ports are still rejected. - match resolve_and_check_trusted_gateway(&host, port, gw, sandbox_entrypoint_pid).await { - Ok(addrs) => TcpStream::connect(addrs.as_slice()) - .await - .into_diagnostic()?, - Err(reason) => { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule("-", "ssrf") - .message(format!( - "CONNECT blocked: trusted-gateway check failed for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial( - &denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - emit_activity(&activity_tx, true, "ssrf"); - respond( - &mut client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!("CONNECT {host_lc}:{port} blocked: trusted-gateway check failed"), - ), - ) - .await?; - return Ok(()); - } - } - } else if !raw_allowed_ips.is_empty() { - // allowed_ips mode: validate resolved IPs against CIDR allowlist. - // Loopback and link-local are still always blocked. - match parse_allowed_ips(&raw_allowed_ips) { - Ok(nets) => { - match resolve_and_check_allowed_ips(&host, port, &nets, sandbox_entrypoint_pid) - .await - { - Ok(addrs) => TcpStream::connect(addrs.as_slice()) - .await - .into_diagnostic()?, - Err(reason) => { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule("-", "ssrf") - .message(format!( - "CONNECT blocked: allowed_ips check failed for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial( - &denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - emit_activity(&activity_tx, true, "ssrf"); - respond( - &mut client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!( - "CONNECT {host_lc}:{port} blocked: allowed_ips check failed" - ), - ), - ) - .await?; - return Ok(()); - } - } - } - Err(reason) => { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule("-", "ssrf") - .message(format!( - "CONNECT blocked: invalid allowed_ips in policy for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial( - &denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - emit_activity(&activity_tx, true, "ssrf"); - respond( - &mut client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!("CONNECT {host_lc}:{port} blocked: invalid allowed_ips in policy"), - ), - ) - .await?; - return Ok(()); - } - } - } else if exact_declared_endpoint_host { - // Exact declared hostname mode: the operator explicitly allowed this - // host:port, so private IP resolution is permitted without duplicating - // the resolved IP in allowed_ips. Always-blocked addresses and - // control-plane ports remain denied. - match resolve_and_check_declared_endpoint(&host, port, sandbox_entrypoint_pid).await { - Ok(addrs) => TcpStream::connect(addrs.as_slice()) - .await - .into_diagnostic()?, - Err(reason) => { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule("-", "ssrf") - .message(format!( - "CONNECT blocked: declared endpoint check failed for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial( - &denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - respond( - &mut client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!( - "CONNECT {host_lc}:{port} blocked: declared endpoint check failed" - ), - ), - ) - .await?; - return Ok(()); - } - } - } else { - // Default: reject all internal IPs (loopback, RFC 1918, link-local). - match resolve_and_reject_internal(&host, port, sandbox_entrypoint_pid).await { - Ok(addrs) => TcpStream::connect(addrs.as_slice()) - .await - .into_diagnostic()?, - Err(reason) => { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule("-", "ssrf") - .message(format!( - "CONNECT blocked: internal address {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial( - &denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - emit_activity(&activity_tx, true, "ssrf"); - respond( - &mut client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!("CONNECT {host_lc}:{port} blocked: internal address"), - ), - ) - .await?; - return Ok(()); - } - } - }; - - debug!( - "handle_tcp_connection dns_resolve_and_tcp_connect: {}ms host={host_lc}", - dns_connect_start.elapsed().as_millis() - ); - - respond(&mut client, b"HTTP/1.1 200 Connection Established\r\n\r\n").await?; - - // Check if endpoint has L7 config for protocol-aware inspection, and - // retain the generation for HTTP passthrough keep-alive tunnels. - let l7_route = query_l7_route_snapshot(&opa_engine, &decision, &host_lc, port); - let should_inspect_l7 = l7_inspection_active(l7_route.as_ref()); - - // Log the allowed CONNECT — use CONNECT_L7 when L7 inspection follows, - // so log consumers can distinguish L4-only decisions from tunnel lifecycle events. - let connect_msg = if should_inspect_l7 { - "CONNECT_L7" - } else { - "CONNECT" - }; - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Allowed) - .disposition(DispositionId::Allowed) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint_addr(peer_addr.ip(), peer_addr.port()) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "opa") - .message(format!("{connect_msg} allowed {host_lc}:{port}")) - .build(); - ocsf_emit!(event); - } - emit_connect_activity_if_l4_only(&activity_tx, l7_route.as_ref()); - - // Determine effective TLS mode. Check the raw endpoint config for - // `tls: skip` independently of L7 config (which requires `protocol`). - let effective_tls_skip = - query_tls_mode(&opa_engine, &decision, &host_lc, port) == crate::l7::TlsMode::Skip; - - // Build L7 eval context (shared by TLS-terminated and plaintext paths). - let ctx = crate::l7::relay::L7EvalContext { - host: host_lc.clone(), - port, - policy_name: matched_policy.clone().unwrap_or_default(), - binary_path: decision - .binary - .as_ref() - .map(|p| p.to_string_lossy().into_owned()) - .unwrap_or_default(), - ancestors: decision - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - cmdline_paths: decision - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - secret_resolver: secret_resolver.clone(), - activity_tx: activity_tx.clone(), - dynamic_credentials: dynamic_credentials.clone(), - token_grant_resolver: dynamic_credentials - .as_ref() - .map(|_| crate::l7::token_grant_injection::default_resolver()), - }; - - if effective_tls_skip { - // tls: skip — raw tunnel, no termination, no credential injection. - debug!( - host = %host_lc, - port = port, - "tls: skip — bypassing TLS auto-detection, raw tunnel" - ); - let _ = tokio::io::copy_bidirectional(&mut client, &mut upstream) - .await - .into_diagnostic()?; - return Ok(()); - } - - // Auto-detect TLS by peeking the first bytes. - let mut peek_buf = [0u8; 8]; - let n = client.peek(&mut peek_buf).await.into_diagnostic()?; - if n == 0 { - return Ok(()); - } - - let is_tls = crate::l7::tls::looks_like_tls(&peek_buf[..n]); - let is_http = crate::l7::rest::looks_like_http(&peek_buf[..n]); - - if is_tls { - // TLS detected — terminate unconditionally. - if let Some(ref tls) = tls_state { - let tls_result = async { - let mut tls_client = - crate::l7::tls::tls_terminate_client(client, tls, &host_lc).await?; - let mut tls_upstream = - crate::l7::tls::tls_connect_upstream(upstream, &host_lc, tls.upstream_config()) - .await?; - - if let Some(route) = l7_route.as_ref().filter(|route| !route.configs.is_empty()) { - // L7 inspection on terminated TLS traffic. - let tunnel_engine = match opa_engine.clone_engine_for_tunnel(route.generation) { - Ok(engine) => engine, - Err(e) => { - emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); - return Ok(()); - } - }; - if route.configs.len() == 1 { - crate::l7::relay::relay_with_inspection( - &route.configs[0].config, - tunnel_engine, - &mut tls_client, - &mut tls_upstream, - &ctx, - ) - .await - } else { - let configs: Vec = route - .configs - .iter() - .map(|snapshot| snapshot.config.clone()) - .collect(); - crate::l7::relay::relay_with_route_selection( - &configs, - tunnel_engine, - &mut tls_client, - &mut tls_upstream, - &ctx, - ) - .await - } - } else { - // No L7 config — relay with credential injection only. - let generation = l7_route - .as_ref() - .map_or(decision.generation, |route| route.generation); - let generation_guard = match opa_engine.generation_guard(generation) { - Ok(guard) => guard, - Err(e) => { - emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); - return Ok(()); - } - }; - crate::l7::relay::relay_passthrough_with_credentials( - &mut tls_client, - &mut tls_upstream, - &ctx, - &generation_guard, - ) - .await - } - }; - if let Err(e) = tls_result.await { - if is_benign_relay_error(&e) { - debug!( - host = %host_lc, - port = port, - error = %e, - "TLS connection closed" - ); - } else { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!("TLS relay error: {e}")) - .build(); - ocsf_emit!(event); - } - } - } else { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!( - "TLS detected but TLS state not configured for {host_lc}:{port}, falling back to raw tunnel" - )) - .build(); - ocsf_emit!(event); - } - let _ = tokio::io::copy_bidirectional(&mut client, &mut upstream) - .await - .into_diagnostic()?; - } - } else if is_http { - // Plaintext HTTP detected. - if let Some(route) = l7_route.as_ref().filter(|route| !route.configs.is_empty()) { - let tunnel_engine = match opa_engine.clone_engine_for_tunnel(route.generation) { - Ok(engine) => engine, - Err(e) => { - emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); - return Ok(()); - } - }; - let relay_result = if route.configs.len() == 1 { - crate::l7::relay::relay_with_inspection( - &route.configs[0].config, - tunnel_engine, - &mut client, - &mut upstream, - &ctx, - ) - .await - } else { - let configs: Vec = route - .configs - .iter() - .map(|snapshot| snapshot.config.clone()) - .collect(); - crate::l7::relay::relay_with_route_selection( - &configs, - tunnel_engine, - &mut client, - &mut upstream, - &ctx, - ) - .await - }; - if let Err(e) = relay_result { - if is_benign_relay_error(&e) { - debug!(host = %host_lc, port = port, error = %e, "L7 connection closed"); - } else { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!("L7 relay error: {e}")) - .build(); - ocsf_emit!(event); - } - } - } else { - // Plaintext HTTP, no L7 config — relay with credential injection. - let generation = l7_route - .as_ref() - .map_or(decision.generation, |route| route.generation); - let generation_guard = match opa_engine.generation_guard(generation) { - Ok(guard) => guard, - Err(e) => { - emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); - return Ok(()); - } - }; - if let Err(e) = crate::l7::relay::relay_passthrough_with_credentials( - &mut client, - &mut upstream, - &ctx, - &generation_guard, - ) - .await - { - if is_benign_relay_error(&e) { - debug!(host = %host_lc, port = port, error = %e, "HTTP relay closed"); - } else { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!("HTTP relay error: {e}")) - .build(); - ocsf_emit!(event); - } - } - } - } else { - // Neither TLS nor HTTP — raw binary relay. - debug!( - host = %host_lc, - port = port, - "Non-TLS non-HTTP traffic detected, raw tunnel" - ); - let _ = tokio::io::copy_bidirectional(&mut client, &mut upstream) - .await - .into_diagnostic()?; - } - - Ok(()) -} - -/// Resolved process identity for a TCP peer: binary path, PID, ancestor chain, -/// cmdline paths, and the TOFU-verified binary hash. -/// -/// Produced by [`resolve_process_identity`]; consumed by [`evaluate_opa_tcp`] -/// and by the identity-chain regression tests. -#[cfg(target_os = "linux")] -struct ResolvedIdentity { - bin_path: PathBuf, - binary_pid: u32, - ancestors: Vec, - cmdline_paths: Vec, - bin_hash: String, -} - -#[cfg(target_os = "linux")] -#[derive(Debug, Eq, PartialEq)] -struct PolicyIdentityKey { - bin_path: PathBuf, - ancestors: Vec, - cmdline_paths: Vec, - bin_hash: String, -} - -#[cfg(target_os = "linux")] -impl ResolvedIdentity { - fn policy_key(&self) -> PolicyIdentityKey { - PolicyIdentityKey { - bin_path: self.bin_path.clone(), - ancestors: self.ancestors.clone(), - cmdline_paths: self.cmdline_paths.clone(), - bin_hash: self.bin_hash.clone(), - } - } -} - -/// Error from [`resolve_process_identity`]. Carries the deny reason and -/// whatever partial identity data was resolved before the failure so the -/// caller can include it in the [`ConnectDecision`] and OCSF event. -#[cfg(target_os = "linux")] -struct IdentityError { - reason: String, - binary: Option, - binary_pid: Option, - ancestors: Vec, -} - -#[cfg(target_os = "linux")] -fn resolve_owner_identity( - owner_pid: u32, - entrypoint_pid: u32, - identity_cache: &BinaryIdentityCache, -) -> std::result::Result { - let bin_path = - crate::procfs::binary_path(owner_pid.cast_signed()).map_err(|e| IdentityError { - reason: format!("failed to resolve peer binary for PID {owner_pid}: {e}"), - binary: None, - binary_pid: Some(owner_pid), - ancestors: vec![], - })?; - - let bin_hash = identity_cache - .verify_or_cache(&bin_path) - .map_err(|e| IdentityError { - reason: format!("binary integrity check failed: {e}"), - binary: Some(bin_path.clone()), - binary_pid: Some(owner_pid), - ancestors: vec![], - })?; - - let ancestors = crate::procfs::collect_ancestor_binaries(owner_pid, entrypoint_pid); - - for ancestor in &ancestors { - identity_cache - .verify_or_cache(ancestor) - .map_err(|e| IdentityError { - reason: format!( - "ancestor integrity check failed for {}: {e}", - ancestor.display() - ), - binary: Some(bin_path.clone()), - binary_pid: Some(owner_pid), - ancestors: ancestors.clone(), - })?; - } - - let mut exclude = ancestors.clone(); - exclude.push(bin_path.clone()); - let cmdline_paths = crate::procfs::collect_cmdline_paths(owner_pid, entrypoint_pid, &exclude); - - Ok(ResolvedIdentity { - bin_path, - binary_pid: owner_pid, - ancestors, - cmdline_paths, - bin_hash, - }) -} - -/// Resolve the identity of the process owning a TCP peer connection. -/// -/// Walks `/proc//net/tcp` to find the socket inode, locates -/// every owning PID, reads `/proc//exe`, TOFU-verifies each binary hash, -/// walks each ancestor chain verifying every ancestor, and collects -/// cmdline-derived absolute paths for script detection. -/// -/// This is the identity-resolution block of [`evaluate_opa_tcp`] extracted -/// into a standalone helper so it can be exercised by Linux-only regression -/// tests without a full OPA engine. The key invariant under test is that on -/// a hot-swap of the peer binary, the failure mode is -/// `"Binary integrity violation"` (from the identity cache) rather than -/// `"Failed to stat ... (deleted)"` (from the kernel-tainted path). -#[cfg(target_os = "linux")] -fn resolve_process_identity( - entrypoint_pid: u32, - peer_port: u16, - identity_cache: &BinaryIdentityCache, -) -> std::result::Result { - let socket_owners = crate::procfs::resolve_tcp_peer_socket_owners(entrypoint_pid, peer_port) - .map_err(|e| IdentityError { - reason: format!("failed to resolve peer binary: {e}"), - binary: None, - binary_pid: None, - ancestors: vec![], - })?; - - let mut identities = Vec::with_capacity(socket_owners.owners.len()); - for owner in &socket_owners.owners { - identities.push(resolve_owner_identity( - owner.pid, - entrypoint_pid, - identity_cache, - )?); - } - - let Some(first_identity) = identities.first() else { - return Err(IdentityError { - reason: format!( - "failed to resolve peer binary: no process found owning socket inode {}", - socket_owners.inode - ), - binary: None, - binary_pid: None, - ancestors: vec![], - }); - }; - - let first_key = first_identity.policy_key(); - if identities - .iter() - .skip(1) - .any(|identity| identity.policy_key() != first_key) - { - let mut pids: Vec = identities - .iter() - .map(|identity| identity.binary_pid) - .collect(); - pids.sort_unstable(); - return Err(IdentityError { - reason: format!( - "ambiguous shared socket ownership: inode {} is held by PIDs [{}] with different policy identities", - socket_owners.inode, - pids.iter() - .map(u32::to_string) - .collect::>() - .join(", ") - ), - binary: None, - binary_pid: None, - ancestors: vec![], - }); - } - - let mut identity = identities.swap_remove(0); - if let Some(lowest_pid) = socket_owners.owners.iter().map(|owner| owner.pid).min() { - identity.binary_pid = lowest_pid; - } - Ok(identity) -} - -/// Evaluate OPA policy for a TCP connection with identity binding via /proc/net/tcp. -#[cfg(target_os = "linux")] -fn evaluate_opa_tcp( - peer_addr: SocketAddr, - engine: &OpaEngine, - identity_cache: &BinaryIdentityCache, - entrypoint_pid: &AtomicU32, - host: &str, - port: u16, -) -> ConnectDecision { - use crate::opa::NetworkInput; - use std::sync::atomic::Ordering; - - let deny = |reason: String, - binary: Option, - binary_pid: Option, - ancestors: Vec, - cmdline_paths: Vec| - -> ConnectDecision { - ConnectDecision { - action: NetworkAction::Deny { reason }, - generation: engine.current_generation(), - binary, - binary_pid, - ancestors, - cmdline_paths, - } - }; - - let pid = entrypoint_pid.load(Ordering::Acquire); - if pid == 0 { - return deny( - "entrypoint process not yet spawned".into(), - None, - None, - vec![], - vec![], - ); - } - - let total_start = std::time::Instant::now(); - let peer_port = peer_addr.port(); - - let identity = match resolve_process_identity(pid, peer_port, identity_cache) { - Ok(id) => id, - Err(err) => { - return deny( - err.reason, - err.binary, - err.binary_pid, - err.ancestors, - vec![], - ); - } - }; - - let ResolvedIdentity { - bin_path, - binary_pid, - ancestors, - cmdline_paths, - bin_hash, - } = identity; - - let input = NetworkInput { - host: host.to_string(), - port, - binary_path: bin_path.clone(), - binary_sha256: bin_hash, - ancestors: ancestors.clone(), - cmdline_paths: cmdline_paths.clone(), - }; - - let result = match engine.evaluate_network_action_with_generation(&input) { - Ok((action, generation)) => ConnectDecision { - action, - generation, - binary: Some(bin_path), - binary_pid: Some(binary_pid), - ancestors, - cmdline_paths, - }, - Err(e) => deny( - format!("policy evaluation error: {e}"), - Some(bin_path), - Some(binary_pid), - ancestors, - cmdline_paths, - ), - }; - debug!( - "evaluate_opa_tcp TOTAL: {}ms host={host} port={port}", - total_start.elapsed().as_millis() - ); - result -} - -/// Non-Linux stub: OPA identity binding requires /proc. -#[cfg(not(target_os = "linux"))] -fn evaluate_opa_tcp( - _peer_addr: SocketAddr, - engine: &OpaEngine, - _identity_cache: &BinaryIdentityCache, - _entrypoint_pid: &AtomicU32, - _host: &str, - _port: u16, -) -> ConnectDecision { - ConnectDecision { - action: NetworkAction::Deny { - reason: "identity binding unavailable on this platform".into(), - }, - generation: engine.current_generation(), - binary: None, - binary_pid: None, - ancestors: vec![], - cmdline_paths: vec![], - } -} - -/// Maximum buffer size for inference request parsing (10 MiB). -const MAX_INFERENCE_BUF: usize = 10 * 1024 * 1024; - -/// Initial buffer size for inference request parsing (64 KiB). -const INITIAL_INFERENCE_BUF: usize = 65536; - -/// Handle an intercepted connection for inference routing. -/// -/// TLS-terminates the client connection, parses HTTP requests, and executes -/// inference API calls locally via `openshell-router`. -/// Non-inference requests are denied with 403. -/// -/// Returns [`InferenceOutcome::Routed`] if at least one request was successfully -/// routed, or [`InferenceOutcome::Denied`] with a reason for all denial cases. -async fn handle_inference_interception( - client: TcpStream, - host: &str, - port: u16, - tls_state: Option<&Arc>, - inference_ctx: Option<&Arc>, -) -> Result { - let Some(ctx) = inference_ctx else { - return Ok(InferenceOutcome::Denied { - reason: "cluster inference context not configured".to_string(), - }); - }; - - let Some(tls) = tls_state else { - return Ok(InferenceOutcome::Denied { - reason: "missing TLS state".to_string(), - }); - }; - - // TLS-terminate the client side (present a cert for the target host) - let mut tls_client = match crate::l7::tls::tls_terminate_client(client, tls, host).await { - Ok(c) => c, - Err(e) => { - return Ok(InferenceOutcome::Denied { - reason: format!("TLS handshake failed: {e}"), - }); - } - }; - - process_inference_keepalive(&mut tls_client, ctx, port).await -} - -/// Read and process HTTP requests from a TLS-terminated inference connection. -/// -/// Each request is matched against inference patterns and routed locally. -/// Any non-inference request is immediately denied and the connection is closed, -/// even if previous requests on the same keep-alive connection were routed -/// successfully. -async fn process_inference_keepalive( - stream: &mut S, - ctx: &InferenceContext, - port: u16, -) -> Result { - use crate::l7::inference::{ParseResult, format_http_response, try_parse_http_request}; - - let mut buf = vec![0u8; INITIAL_INFERENCE_BUF]; - let mut used = 0usize; - let mut routed_any = false; - - loop { - let n = match stream.read(&mut buf[used..]).await { - Ok(n) => n, - Err(e) => { - if routed_any { - break; - } - return Ok(InferenceOutcome::Denied { - reason: format!("I/O error: {e}"), - }); - } - }; - if n == 0 { - if routed_any { - break; - } - return Ok(InferenceOutcome::Denied { - reason: "client closed connection".to_string(), - }); - } - used += n; - - // Try to parse a complete HTTP request - match try_parse_http_request(&buf[..used]) { - ParseResult::Complete(request, consumed) => { - let was_routed = route_inference_request(&request, ctx, stream).await?; - if was_routed { - routed_any = true; - } else { - // Deny and close: a non-inference request must not be silently - // ignored on a keep-alive connection that previously routed - // inference traffic. - return Ok(InferenceOutcome::Denied { - reason: "connection not allowed by policy".to_string(), - }); - } - - // Shift buffer for next request - buf.copy_within(consumed..used, 0); - used -= consumed; - } - ParseResult::Incomplete => { - // Need more data — grow buffer if full - if used == buf.len() { - if buf.len() >= MAX_INFERENCE_BUF { - let response = format_http_response(413, &[], b"Payload Too Large"); - write_all(stream, &response).await?; - if routed_any { - break; - } - return Ok(InferenceOutcome::Denied { - reason: "payload too large".to_string(), - }); - } - buf.resize((buf.len() * 2).min(MAX_INFERENCE_BUF), 0); - } - } - ParseResult::Invalid(reason) => { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Refuse) - .action(ActionId::Denied) - .disposition(DispositionId::Rejected) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, port)) - .message(format!("Rejecting malformed inference request: {reason}")) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - let response = format_http_response(400, &[], b"Bad Request"); - write_all(stream, &response).await?; - return Ok(InferenceOutcome::Denied { reason }); - } - } - } - - Ok(InferenceOutcome::Routed) -} - -/// Route a parsed inference request locally via the sandbox router, or deny it. -/// -/// Returns `Ok(true)` if the request was routed to an inference backend, -/// `Ok(false)` if it was denied as a non-inference request. -async fn route_inference_request( - request: &crate::l7::inference::ParsedHttpRequest, - ctx: &InferenceContext, - tls_client: &mut (impl tokio::io::AsyncWrite + Unpin), -) -> Result { - use crate::l7::inference::{detect_inference_pattern, format_http_response}; - - let normalized_path = normalize_inference_path(&request.path); - - if let Some(pattern) = - detect_inference_pattern(&request.method, &normalized_path, &ctx.patterns) - { - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Allowed) - .disposition(DispositionId::Detected) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) - .message(format!( - "Intercepted inference request, routing locally: {} {} (protocol={}, kind={})", - request.method, normalized_path, pattern.protocol, pattern.kind - )) - .build(); - ocsf_emit!(event); - } - - let routes = ctx.routes.read().await; - - if routes.is_empty() { - let body = serde_json::json!({ - "error": "cluster inference is not configured", - "hint": "run: openshell cluster inference set --help" - }); - let body_bytes = body.to_string(); - let response = format_http_response( - 503, - &[("content-type".to_string(), "application/json".to_string())], - body_bytes.as_bytes(), - ); - write_all(tls_client, &response).await?; - return Ok(true); - } - - // Buffered protocols (embeddings, model discovery) return a single JSON - // object, not an SSE token stream. Serve them buffered with an accurate - // Content-Length: the streaming path would append an SSE error frame to - // the body on a size-cap or idle-timeout truncation, corrupting a - // payload the client parses as one JSON object. Framing is declared per - // protocol on the matched pattern. - if pattern.is_buffered() { - match ctx - .router - .proxy_with_candidates( - &pattern.protocol, - &request.method, - &normalized_path, - request.headers.clone(), - bytes::Bytes::from(request.body.clone()), - &routes, - ) - .await - { - Ok(resp) => { - let resp_headers = sanitize_inference_response_headers(resp.headers); - let response = format_http_response(resp.status, &resp_headers, &resp.body); - write_all(tls_client, &response).await?; - } - Err(e) => write_inference_router_error(tls_client, &e).await?, - } - return Ok(true); - } - - match ctx - .router - .proxy_with_candidates_streaming( - &pattern.protocol, - &request.method, - &normalized_path, - request.headers.clone(), - bytes::Bytes::from(request.body.clone()), - &routes, - ) - .await - { - Ok(mut resp) => { - use crate::l7::inference::{ - format_chunk, format_chunk_terminator, format_http_response_header, - format_sse_error, - }; - - let resp_headers = sanitize_inference_response_headers( - std::mem::take(&mut resp.headers).into_iter().collect(), - ); - - // Write response headers immediately (chunked TE). - let header_bytes = format_http_response_header(resp.status, &resp_headers); - write_all(tls_client, &header_bytes).await?; - - // Stream body chunks with byte cap and idle timeout. - // - // Each upstream chunk is wrapped in HTTP chunked framing and - // flushed immediately so SSE events reach the client without - // delay. Unlike the previous per-byte write_all+flush, we - // coalesce the framing header + data + trailer into a single - // write_all call, reducing the number of TLS records per chunk - // from 3 to 1 while preserving incremental delivery. - let mut total_bytes: usize = 0; - loop { - match tokio::time::timeout(CHUNK_IDLE_TIMEOUT, resp.next_chunk()).await { - Ok(Ok(Some(chunk))) => { - total_bytes += chunk.len(); - if total_bytes > MAX_STREAMING_BODY { - warn!( - total_bytes = total_bytes, - limit = MAX_STREAMING_BODY, - "streaming response exceeded byte limit, truncating" - ); - let err = format_sse_error( - "response truncated: exceeded maximum streaming body size", - ); - let _ = write_all(tls_client, &format_chunk(&err)).await; - break; - } - let encoded = format_chunk(&chunk); - write_all(tls_client, &encoded).await?; - } - Ok(Ok(None)) => break, - Ok(Err(e)) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) - .message(format!( - "error reading upstream response chunk after \ - {total_bytes} bytes: {e}" - )) - .build(); - ocsf_emit!(event); - let err = format_sse_error("response truncated: upstream read error"); - let _ = write_all(tls_client, &format_chunk(&err)).await; - break; - } - Err(_) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) - .message(format!( - "streaming response chunk idle timeout after \ - {total_bytes} bytes, closing" - )) - .build(); - ocsf_emit!(event); - let err = - format_sse_error("response truncated: chunk idle timeout exceeded"); - let _ = write_all(tls_client, &format_chunk(&err)).await; - break; - } - } - } - - // Terminate the chunked stream. - write_all(tls_client, format_chunk_terminator()).await?; - } - Err(e) => write_inference_router_error(tls_client, &e).await?, - } - Ok(true) - } else { - // Not an inference request — deny - { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) - .message(format!( - "connection not allowed by policy: {} {}", - request.method, normalized_path - )) - .build(); - ocsf_emit!(event); - } - let body = serde_json::json!({"error": "connection not allowed by policy"}); - let body_bytes = body.to_string(); - let response = format_http_response( - 403, - &[("content-type".to_string(), "application/json".to_string())], - body_bytes.as_bytes(), - ); - write_all(tls_client, &response).await?; - Ok(false) - } -} - -/// Emit an OCSF failure event and write a buffered JSON error response for a -/// router error hit while proxying an inference request. -/// -/// Shared by the streaming and buffered routing paths so both surface upstream -/// failures with the same status mapping and the same audit record. -async fn write_inference_router_error( - tls_client: &mut (impl tokio::io::AsyncWrite + Unpin), - err: &openshell_router::RouterError, -) -> Result<()> { - use crate::l7::inference::format_http_response; - - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(INFERENCE_LOCAL_HOST, 443)) - .message(format!( - "inference endpoint detected but upstream service failed: {err}" - )) - .build(); - ocsf_emit!(event); - - let (status, msg) = router_error_to_http(err); - let body = serde_json::json!({ "error": msg }).to_string(); - let response = format_http_response( - status, - &[("content-type".to_string(), "application/json".to_string())], - body.as_bytes(), - ); - write_all(tls_client, &response).await -} - -/// Map router errors to HTTP status codes and sanitized messages. -/// -/// Returns generic, client-safe messages instead of verbatim internal details; -/// the full error is recorded in the OCSF failure event by the caller. -fn router_error_to_http(err: &openshell_router::RouterError) -> (u16, String) { - use openshell_router::RouterError; - match err { - RouterError::RouteNotFound(_) => (400, "no inference route configured".to_string()), - RouterError::NoCompatibleRoute(_) => { - (400, "no compatible inference route available".to_string()) - } - RouterError::Unauthorized(_) => (401, "unauthorized".to_string()), - RouterError::UpstreamUnavailable(_) => (503, "inference service unavailable".to_string()), - RouterError::UpstreamProtocol(_) | RouterError::Internal(_) => { - (502, "inference service error".to_string()) - } - } -} - -fn sanitize_inference_response_headers(headers: Vec<(String, String)>) -> Vec<(String, String)> { - headers - .into_iter() - .filter(|(name, _)| !should_strip_response_header(name)) - .collect() -} - -fn should_strip_response_header(name: &str) -> bool { - let name_lc = name.to_ascii_lowercase(); - matches!(name_lc.as_str(), "content-length") || is_hop_by_hop_header(&name_lc) -} - -fn is_hop_by_hop_header(name: &str) -> bool { - matches!( - name, - "connection" - | "keep-alive" - | "proxy-authenticate" - | "proxy-authorization" - | "proxy-connection" - | "te" - | "trailer" - | "transfer-encoding" - | "upgrade" - ) -} - -/// Write all bytes to an async writer. -async fn write_all(writer: &mut (impl tokio::io::AsyncWrite + Unpin), data: &[u8]) -> Result<()> { - use tokio::io::AsyncWriteExt; - writer.write_all(data).await.into_diagnostic()?; - writer.flush().await.into_diagnostic()?; - Ok(()) -} - -#[derive(Debug, Clone)] -struct L7ConfigSnapshot { - config: crate::l7::L7EndpointConfig, -} - -#[derive(Debug, Clone)] -struct L7RouteSnapshot { - configs: Vec, - generation: u64, -} - -fn emit_l7_tunnel_close_after_policy_change(host: &str, port: u16, error: miette::Report) { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Open) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(host, port)) - .message(format!( - "L7 tunnel closed before inspection because policy changed: {error}" - )) - .build(); - ocsf_emit!(event); -} - -/// Query L7 endpoint config from the OPA engine for a matched CONNECT decision. -/// -/// Returns `Some(L7EndpointConfig)` if the matched endpoint has L7 config (protocol field), -/// `None` for L4-only endpoints. -fn query_l7_route_snapshot( - engine: &OpaEngine, - decision: &ConnectDecision, - host: &str, - port: u16, -) -> Option { - // Only query if action is Allow (not Deny) - let has_policy = match &decision.action { - NetworkAction::Allow { matched_policy } => matched_policy.is_some(), - NetworkAction::Deny { .. } => false, - }; - if !has_policy { - return None; - } - - let input = crate::opa::NetworkInput { - host: host.to_string(), - port, - binary_path: decision.binary.clone().unwrap_or_default(), - binary_sha256: String::new(), - ancestors: decision.ancestors.clone(), - cmdline_paths: decision.cmdline_paths.clone(), - }; - - match engine.query_endpoint_configs_with_generation(&input) { - Ok((vals, generation)) => Some(L7RouteSnapshot { - configs: vals - .into_iter() - .filter_map(|val| crate::l7::parse_l7_config(&val)) - .map(|config| L7ConfigSnapshot { config }) - .collect(), - generation, - }), - Err(e) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(host, port)) - .message(format!("Failed to query L7 endpoint config: {e}")) - .build(); - ocsf_emit!(event); - None - } - } -} - -fn select_l7_config_for_path<'a>( - configs: &'a [L7ConfigSnapshot], - path: &str, -) -> Option<&'a L7ConfigSnapshot> { - configs - .iter() - .filter(|snapshot| snapshot.config.matches_path(path)) - .max_by_key(|snapshot| snapshot.config.path_specificity()) -} - -/// Query the TLS mode for an endpoint, independent of L7 config. -/// -/// This extracts `tls: skip` from the endpoint even when no `protocol` is set. -fn query_tls_mode( - engine: &OpaEngine, - decision: &ConnectDecision, - host: &str, - port: u16, -) -> crate::l7::TlsMode { - let has_policy = match &decision.action { - NetworkAction::Allow { matched_policy } => matched_policy.is_some(), - NetworkAction::Deny { .. } => false, - }; - if !has_policy { - return crate::l7::TlsMode::Auto; - } - - let input = crate::opa::NetworkInput { - host: host.to_string(), - port, - binary_path: decision.binary.clone().unwrap_or_default(), - binary_sha256: String::new(), - ancestors: decision.ancestors.clone(), - cmdline_paths: decision.cmdline_paths.clone(), - }; - - match engine.query_endpoint_config(&input) { - Ok(Some(val)) => crate::l7::parse_tls_mode(&val), - _ => crate::l7::TlsMode::Auto, - } -} - -/// When the policy endpoint host is a literal IP address, the user has -/// explicitly declared intent to allow that destination. Synthesize an -/// `allowed_ips` entry so the existing allowlist-validation path is used -/// instead of the blanket internal-IP rejection. -/// -/// Always-blocked addresses (loopback, link-local, unspecified) are skipped -/// — synthesizing an `allowed_ips` entry for them would be silently -/// un-enforceable at runtime. -fn implicit_allowed_ips_for_ip_host(host: &str) -> Vec { - let lookup_host = normalize_host_lookup_key(host); - if let Ok(ip) = lookup_host.parse::() { - if is_always_blocked_ip(ip) { - warn!( - host, - "Policy host is an always-blocked address; \ - implicit allowed_ips skipped — SSRF hardening prevents \ - traffic to this destination regardless of policy" - ); - return vec![]; - } - vec![lookup_host.to_string()] - } else { - vec![] - } -} - -fn normalize_host_lookup_key(host: &str) -> &str { - host.strip_prefix('[') - .and_then(|trimmed| trimmed.strip_suffix(']')) - .unwrap_or(host) -} - -/// Returns `true` if `host` is one of the well-known driver-injected aliases -/// for the host machine (e.g. `host.openshell.internal`). -fn is_host_gateway_alias(host: &str) -> bool { - let h = normalize_host_lookup_key(host); - HOST_GATEWAY_ALIASES - .iter() - .any(|alias| alias.eq_ignore_ascii_case(h)) -} - -/// Returns `true` if `ip` is a known cloud instance metadata endpoint that -/// must never be exempted from SSRF blocking. -/// -/// IPv4-mapped IPv6 addresses (e.g. `::ffff:169.254.169.254`) are normalized -/// to their embedded IPv4 representation before comparison, so the invariant -/// holds regardless of how the address is represented. -fn is_cloud_metadata_ip(ip: IpAddr) -> bool { - match ip { - IpAddr::V4(_) => CLOUD_METADATA_IPS.contains(&ip), - IpAddr::V6(v6) => v6 - .to_ipv4_mapped() - .is_some_and(|v4| CLOUD_METADATA_IPS.contains(&IpAddr::V4(v4))), - } -} - -/// Read the proxy's own `/etc/hosts` at startup and return the IP mapped to -/// `host.openshell.internal`, if present and safe. -/// -/// This is called once before user code runs, so the returned value is immune -/// to later `/etc/hosts` tampering by sandbox workloads. Returns `None` if no -/// entry exists, the entry cannot be parsed, or the mapped IP is a cloud -/// metadata address. -#[cfg(any(target_os = "linux", test))] -fn detect_trusted_host_gateway() -> Option { - let contents = std::fs::read_to_string("/etc/hosts").ok()?; - let ips = parse_hosts_file_for_host(&contents, "host.openshell.internal"); - - // Multiple distinct IPs for the alias is unexpected — compute drivers - // always inject exactly one. Warn loudly so operators can diagnose the - // inconsistency; we still proceed with the first entry rather than - // disabling the exemption entirely, because the mismatch guard in - // resolve_and_check_trusted_gateway() will reject any runtime resolution - // that returns a different IP. - if ips.len() > 1 { - warn!( - ips = ?ips, - "host.openshell.internal has {} distinct IPs in /etc/hosts; \ - expected exactly one. Using first entry. \ - Connections resolving to any other IP will be rejected.", - ips.len() - ); - } - - let ip = ips.into_iter().next()?; - - if is_cloud_metadata_ip(ip) { - warn!( - %ip, - "host.openshell.internal resolves to a cloud metadata IP; \ - trusted-gateway SSRF exemption disabled" - ); - return None; - } - // The exemption exists solely for link-local IPs used by rootless Podman - // with pasta. Private RFC 1918 addresses (e.g. Docker bridge 172.17.0.1, - // Kubernetes node 192.168.x.x), loopback, unspecified, and all other - // non-link-local addresses are never legitimate candidates for the - // link-local SSRF exemption — they must fall through to the normal - // allowed_ips / resolve_and_reject_internal() enforcement path. - if !is_link_local_ip(ip) { - warn!( - %ip, - "host.openshell.internal maps to a non-link-local IP; \ - trusted-gateway SSRF exemption disabled" - ); - return None; - } - Some(ip) -} - -#[cfg(not(any(target_os = "linux", test)))] -fn detect_trusted_host_gateway() -> Option { - None -} - -/// Resolve `host:port` and validate that every resolved address matches the -/// trusted host gateway IP. -/// -/// This bypasses the normal SSRF tiers (always-blocked and internal-IP) for -/// driver-injected host-gateway aliases, allowing link-local addresses used -/// by rootless Podman with pasta without opening up arbitrary link-local or -/// cloud metadata access. -/// -/// Rejects: -/// - Any resolved IP that is a cloud metadata address (defense-in-depth) -/// - Any resolved IP that does not match `trusted_gw` (prevents /etc/hosts tampering) -/// - Control-plane ports (etcd, K8s API, kubelet) regardless of IP -async fn resolve_and_check_trusted_gateway( - host: &str, - port: u16, - trusted_gw: IpAddr, - entrypoint_pid: u32, -) -> std::result::Result, String> { - if BLOCKED_CONTROL_PLANE_PORTS.contains(&port) { - return Err(format!( - "port {port} is a blocked control-plane port, connection rejected" - )); - } - let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; - if addrs.is_empty() { - return Err(format!( - "DNS resolution returned no addresses for {}", - normalize_host_lookup_key(host) - )); - } - for addr in &addrs { - if is_cloud_metadata_ip(addr.ip()) { - return Err(format!( - "{host} resolves to cloud metadata address {}, connection rejected", - addr.ip() - )); - } - if addr.ip() != trusted_gw { - return Err(format!( - "{host} resolves to {} which does not match trusted host gateway \ - {trusted_gw}, connection rejected", - addr.ip() - )); - } - // Defense-in-depth: even if the resolved IP matches trusted_gw, reject - // any non-link-local address. detect_trusted_host_gateway() already - // enforces this at startup, but we re-check here to guard against any - // unanticipated code path that might admit a private or loopback IP. - if !is_link_local_ip(addr.ip()) { - return Err(format!( - "{host} resolves to non-link-local address {}, \ - connection rejected", - addr.ip() - )); - } - } - Ok(addrs) -} - -fn resolve_ip_literal(host: &str, port: u16) -> Option> { - normalize_host_lookup_key(host) - .parse::() - .ok() - .map(|ip| vec![SocketAddr::new(ip, port)]) -} - -#[cfg(any(target_os = "linux", test))] -fn parse_hosts_file_for_host(contents: &str, host: &str) -> Vec { - let lookup_host = normalize_host_lookup_key(host); - let mut addrs = Vec::new(); - - for raw_line in contents.lines() { - let line = raw_line.split('#').next().unwrap_or("").trim(); - if line.is_empty() { - continue; - } - - let mut fields = line.split_whitespace(); - let Some(ip_str) = fields.next() else { - continue; - }; - let Ok(ip) = ip_str.parse::() else { - continue; - }; - - if fields.any(|alias| alias.eq_ignore_ascii_case(lookup_host)) && !addrs.contains(&ip) { - addrs.push(ip); - } - } - - addrs -} - -#[cfg(any(target_os = "linux", test))] -fn resolve_from_hosts_file_contents(contents: &str, host: &str, port: u16) -> Vec { - parse_hosts_file_for_host(contents, host) - .into_iter() - .map(|ip| SocketAddr::new(ip, port)) - .collect() -} - -#[cfg(target_os = "linux")] -async fn resolve_from_sandbox_hosts( - host: &str, - port: u16, - entrypoint_pid: u32, -) -> Option> { - if entrypoint_pid == 0 { - return None; - } - - let hosts_path = format!("/proc/{entrypoint_pid}/root/etc/hosts"); - let contents = match tokio::fs::read_to_string(&hosts_path).await { - Ok(contents) => contents, - Err(error) => { - debug!( - pid = entrypoint_pid, - path = %hosts_path, - host, - "Falling back to DNS; failed to read sandbox hosts file: {error}" - ); - return None; - } - }; - - let addrs = resolve_from_hosts_file_contents(&contents, host, port); - if addrs.is_empty() { None } else { Some(addrs) } -} - -// Mirrors the Linux signature so call sites can `.await` uniformly across -// platforms; the non-Linux path has nothing to await. -#[cfg(not(target_os = "linux"))] -#[allow(clippy::unused_async)] -async fn resolve_from_sandbox_hosts( - _host: &str, - _port: u16, - _entrypoint_pid: u32, -) -> Option> { - None -} - -async fn resolve_socket_addrs( - host: &str, - port: u16, - entrypoint_pid: u32, -) -> std::result::Result, String> { - if let Some(addrs) = resolve_ip_literal(host, port) { - return Ok(addrs); - } - - if let Some(addrs) = resolve_from_sandbox_hosts(host, port, entrypoint_pid).await { - return Ok(addrs); - } - - let lookup_host = normalize_host_lookup_key(host); - let addrs: Vec = tokio::net::lookup_host((lookup_host, port)) - .await - .map_err(|e| format!("DNS resolution failed for {lookup_host}:{port}: {e}"))? - .collect(); - - if addrs.is_empty() { - return Err(format!( - "DNS resolution returned no addresses for {lookup_host}:{port}" - )); - } - - Ok(addrs) -} - -fn reject_internal_resolved_addrs( - host: &str, - addrs: &[SocketAddr], -) -> std::result::Result<(), String> { - if addrs.is_empty() { - return Err(format!( - "DNS resolution returned no addresses for {}", - normalize_host_lookup_key(host) - )); - } - - for addr in addrs { - if is_internal_ip(addr.ip()) { - return Err(format!( - "{host} resolves to internal address {}, connection rejected", - addr.ip() - )); - } - } - - Ok(()) -} - -fn validate_allowed_ips_for_resolved_addrs( - host: &str, - port: u16, - addrs: &[SocketAddr], - allowed_ips: &[ipnet::IpNet], -) -> std::result::Result<(), String> { - if addrs.is_empty() { - return Err(format!( - "DNS resolution returned no addresses for {}", - normalize_host_lookup_key(host) - )); - } - - // Block control-plane ports regardless of IP match. - if BLOCKED_CONTROL_PLANE_PORTS.contains(&port) { - return Err(format!( - "port {port} is a blocked control-plane port, connection rejected" - )); - } - - for addr in addrs { - // Always block loopback and link-local - if is_always_blocked_ip(addr.ip()) { - return Err(format!( - "{host} resolves to always-blocked address {}, connection rejected", - addr.ip() - )); - } - - // Check resolved IP against the allowlist - let ip_allowed = allowed_ips.iter().any(|net| net.contains(&addr.ip())); - if !ip_allowed { - return Err(format!( - "{host} resolves to {} which is not in allowed_ips, connection rejected", - addr.ip() - )); - } - } - - Ok(()) -} - -fn validate_declared_endpoint_resolved_addrs( - host: &str, - port: u16, - addrs: &[SocketAddr], -) -> std::result::Result<(), String> { - if addrs.is_empty() { - return Err(format!( - "DNS resolution returned no addresses for {}", - normalize_host_lookup_key(host) - )); - } - - if BLOCKED_CONTROL_PLANE_PORTS.contains(&port) { - return Err(format!( - "port {port} is a blocked control-plane port, connection rejected" - )); - } - - for addr in addrs { - if is_always_blocked_ip(addr.ip()) { - return Err(format!( - "{host} resolves to always-blocked address {}, connection rejected", - addr.ip() - )); - } - } - - Ok(()) -} - -/// Resolve a host:port using sandbox `/etc/hosts` first (when available), then -/// reject if any resolved address is internal. -/// -/// Returns the resolved `SocketAddr` list on success. Returns an error string -/// if any resolved IP is in an internal range or if DNS resolution fails. -async fn resolve_and_reject_internal( - host: &str, - port: u16, - entrypoint_pid: u32, -) -> std::result::Result, String> { - let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; - reject_internal_resolved_addrs(host, &addrs)?; - Ok(addrs) -} - -/// Resolve a host:port using sandbox `/etc/hosts` first (when available), then -/// validate resolved addresses against a CIDR/IP allowlist. -/// -/// Rejects loopback and link-local unconditionally. For all other resolved -/// addresses, checks that each one matches at least one entry in `allowed_ips`. -/// Entries can be CIDR notation ("10.0.5.0/24") or exact IPs ("10.0.5.20"). -/// -/// Returns the resolved `SocketAddr` list on success. -async fn resolve_and_check_allowed_ips( - host: &str, - port: u16, - allowed_ips: &[ipnet::IpNet], - entrypoint_pid: u32, -) -> std::result::Result, String> { - let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; - validate_allowed_ips_for_resolved_addrs(host, port, &addrs, allowed_ips)?; - Ok(addrs) -} - -/// Resolve a host:port that was explicitly declared by hostname in policy. -/// -/// Exact declared hostnames are the operator's trust signal, so RFC1918 and -/// other private ranges are allowed without a duplicated `allowed_ips` entry. -/// Loopback, link-local, unspecified, and control-plane ports remain blocked. -async fn resolve_and_check_declared_endpoint( - host: &str, - port: u16, - entrypoint_pid: u32, -) -> std::result::Result, String> { - let addrs = resolve_socket_addrs(host, port, entrypoint_pid).await?; - validate_declared_endpoint_resolved_addrs(host, port, &addrs)?; - Ok(addrs) -} - -/// Minimum CIDR prefix length before logging a breadth warning. -/// CIDRs broader than /16 (65,536+ addresses) may unintentionally expose -/// control-plane services on the same network. -const MIN_SAFE_PREFIX_LEN: u8 = 16; - -/// Ports that are always blocked in `resolve_and_check_allowed_ips`, even -/// when the resolved IP matches an `allowed_ips` entry. These ports belong -/// to control-plane services that should never be reachable from a sandbox. -const BLOCKED_CONTROL_PLANE_PORTS: &[u16] = &[ - 2379, // etcd client - 2380, // etcd peer - 6443, // Kubernetes API server - 10250, // kubelet API - 10255, // kubelet read-only -]; - -/// Parse CIDR/IP strings into `IpNet` values, rejecting invalid entries and -/// entries that overlap always-blocked ranges (loopback, link-local, -/// unspecified). -/// -/// Returns parsed networks on success, or an error describing which entries -/// are invalid or always-blocked. Logs a warning for overly broad CIDRs -/// that are not outright blocked. -fn parse_allowed_ips(raw: &[String]) -> std::result::Result, String> { - use openshell_core::net::is_always_blocked_net; - - let mut nets = Vec::with_capacity(raw.len()); - let mut errors = Vec::new(); - - for entry in raw { - // Try as CIDR first, then as bare IP (convert to /32 or /128) - let parsed = entry.parse::().or_else(|_| { - entry - .parse::() - .map(|ip| match ip { - IpAddr::V4(v4) => ipnet::IpNet::V4(ipnet::Ipv4Net::from(v4)), - IpAddr::V6(v6) => ipnet::IpNet::V6(ipnet::Ipv6Net::from(v6)), - }) - .map_err(|_| ()) - }); - - match parsed { - Ok(n) => { - // Reject entries that overlap always-blocked ranges — these - // would be silently denied at runtime by is_always_blocked_ip - // and cause confusing UX (accepted in policy, never works). - if is_always_blocked_net(n) { - errors.push(format!( - "allowed_ips entry {entry} falls within always-blocked range \ - (loopback/link-local/unspecified); remove this entry — \ - SSRF hardening prevents traffic to these destinations \ - regardless of policy" - )); - continue; - } - - if n.prefix_len() < MIN_SAFE_PREFIX_LEN { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .severity(SeverityId::Medium) - .message(format!( - "allowed_ips entry has a very broad CIDR {n} (/{}) < /{MIN_SAFE_PREFIX_LEN}; \ - this may expose control-plane services on the same network", - n.prefix_len() - )) - .build(); - ocsf_emit!(event); - } - nets.push(n); - } - Err(()) => errors.push(format!("invalid CIDR/IP in allowed_ips: {entry}")), - } - } - - if errors.is_empty() { - Ok(nets) - } else { - Err(errors.join("; ")) - } -} - -/// Query `allowed_ips` from the matched endpoint config for a CONNECT decision. -fn query_allowed_ips( - engine: &OpaEngine, - decision: &ConnectDecision, - host: &str, - port: u16, -) -> Vec { - // Only query if action is Allow with a matched policy - let has_policy = match &decision.action { - NetworkAction::Allow { matched_policy } => matched_policy.is_some(), - NetworkAction::Deny { .. } => false, - }; - if !has_policy { - return vec![]; - } - - let input = crate::opa::NetworkInput { - host: host.to_string(), - port, - binary_path: decision.binary.clone().unwrap_or_default(), - binary_sha256: String::new(), - ancestors: decision.ancestors.clone(), - cmdline_paths: decision.cmdline_paths.clone(), - }; - - match engine.query_allowed_ips(&input) { - Ok(ips) => ips, - Err(e) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(host, port)) - .message(format!( - "Failed to query allowed_ips from endpoint config: {e}" - )) - .build(); - ocsf_emit!(event); - vec![] - } - } -} - -/// Query whether the matched endpoint was declared as this exact hostname. -fn query_exact_declared_endpoint_host( - engine: &OpaEngine, - decision: &ConnectDecision, - host: &str, - port: u16, -) -> bool { - let has_policy = match &decision.action { - NetworkAction::Allow { matched_policy } => matched_policy.is_some(), - NetworkAction::Deny { .. } => false, - }; - if !has_policy { - return false; - } - - let input = crate::opa::NetworkInput { - host: host.to_string(), - port, - binary_path: decision.binary.clone().unwrap_or_default(), - binary_sha256: String::new(), - ancestors: decision.ancestors.clone(), - cmdline_paths: decision.cmdline_paths.clone(), - }; - - match engine.query_exact_declared_endpoint_host(&input) { - Ok(is_exact_declared) => is_exact_declared, - Err(e) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(host, port)) - .message(format!("Failed to query exact declared endpoint host: {e}")) - .build(); - ocsf_emit!(event); - false - } - } -} - -/// Canonicalize the request-target for inference pattern detection. -/// -/// Falls back to the raw path on canonicalization error: the request is then -/// routed through the normal forward path, where `rest.rs::parse_http_request` -/// will reject it properly. Returning the raw path here prevents a crafted -/// target from bypassing inference routing without our detection logic having -/// to implement a second, duplicate error-response surface. -fn normalize_inference_path(path: &str) -> String { - match crate::l7::path::canonicalize_request_target( - path, - &crate::l7::path::CanonicalizeOptions::default(), - ) { - Ok((canon, _)) => canon.path, - Err(_) => path.to_string(), - } -} - -/// Extract the hostname from an absolute-form URI used in plain HTTP proxy requests. -/// -/// For example, `"http://example.com/path"` yields `"example.com"` and -/// `"http://example.com:8080/path"` yields `"example.com"`. Returns `"unknown"` -/// if the URI cannot be parsed. -#[cfg(test)] -fn extract_host_from_uri(uri: &str) -> String { - // Absolute-form URIs look like "http://host[:port]/path" - // Strip the scheme prefix, then extract the authority (host[:port]) before the first '/'. - let after_scheme = uri.find("://").map_or(uri, |i| &uri[i + 3..]); - let authority = after_scheme.split('/').next().unwrap_or(after_scheme); - // Strip port if present (handle IPv6 bracket notation) - let host = if authority.starts_with('[') { - // IPv6: [::1]:port - authority.find(']').map_or(authority, |i| &authority[..=i]) - } else { - authority.split(':').next().unwrap_or(authority) - }; - if host.is_empty() { - "unknown".to_string() - } else { - host.to_string() - } -} - -/// Parse an absolute-form proxy request URI into its components. -/// -/// For example, `"http://10.86.8.223:8000/screenshot/"` yields -/// `("http", "10.86.8.223", 8000, "/screenshot/")`. -/// -/// Handles: -/// - Default port 80 for `http`, 443 for `https` -/// - IPv6 bracket notation (`[::1]`) -/// - Missing path (defaults to `/`) -/// - Query strings (preserved in path) -fn parse_proxy_uri(uri: &str) -> Result<(String, String, u16, String)> { - // Extract scheme - let (scheme, rest) = uri - .split_once("://") - .ok_or_else(|| miette::miette!("Missing scheme in proxy URI: {uri}"))?; - let scheme = scheme.to_ascii_lowercase(); - - // Split authority from path - let (authority, path) = if rest.starts_with('[') { - // IPv6: [::1]:port/path - let bracket_end = rest - .find(']') - .ok_or_else(|| miette::miette!("Unclosed IPv6 bracket in URI: {uri}"))?; - let after_bracket = &rest[bracket_end + 1..]; - after_bracket.find('/').map_or((rest, "/"), |slash_pos| { - ( - &rest[..=bracket_end + slash_pos], - &after_bracket[slash_pos..], - ) - }) - } else if let Some(slash_pos) = rest.find('/') { - (&rest[..slash_pos], &rest[slash_pos..]) - } else { - (rest, "/") - }; - - // Parse host and port from authority - let (host, port) = if authority.starts_with('[') { - // IPv6: [::1]:port or [::1] - let bracket_end = authority - .find(']') - .ok_or_else(|| miette::miette!("Unclosed IPv6 bracket: {uri}"))?; - let host = &authority[1..bracket_end]; // strip brackets - let port_str = &authority[bracket_end + 1..]; - let port = if let Some(port_str) = port_str.strip_prefix(':') { - port_str - .parse::() - .map_err(|_| miette::miette!("Invalid port in URI: {uri}"))? - } else { - match scheme.as_str() { - "https" => 443, - _ => 80, - } - }; - (host.to_string(), port) - } else if let Some((h, p)) = authority.rsplit_once(':') { - let port = p - .parse::() - .map_err(|_| miette::miette!("Invalid port in URI: {uri}"))?; - (h.to_string(), port) - } else { - let port = match scheme.as_str() { - "https" => 443, - _ => 80, - }; - (authority.to_string(), port) - }; - - if host.is_empty() { - return Err(miette::miette!("Empty host in URI: {uri}")); - } - - let path = if path.is_empty() { "/" } else { path }; - - Ok((scheme, host, port, path.to_string())) -} - -/// Rewrite an absolute-form HTTP proxy request to origin-form for upstream. -/// -/// Transforms `GET http://host:port/path HTTP/1.1` into `GET /path HTTP/1.1`, -/// strips proxy hop-by-hop headers, injects `Connection: close` and `Via`. -/// -/// Returns the rewritten request bytes (headers + any overflow body bytes). -fn rewrite_forward_request( - raw: &[u8], - used: usize, - path: &str, - secret_resolver: Option<&SecretResolver>, - request_body_credential_rewrite: bool, -) -> Result, crate::secrets::UnresolvedPlaceholderError> { - let header_end = raw[..used] - .windows(4) - .position(|w| w == b"\r\n\r\n") - .map_or(used, |p| p + 4); - let websocket_upgrade = crate::l7::rest::request_is_websocket_upgrade(&raw[..header_end]); - let upstream_path = match secret_resolver { - Some(resolver) => crate::secrets::rewrite_target_for_eval(path, resolver)?.resolved, - None => path.to_string(), - }; - - let header_str = String::from_utf8_lossy(&raw[..header_end]); - let lines = header_str.split("\r\n").collect::>(); - - // Rebuild headers, stripping hop-by-hop and adding proxy headers - let mut output = Vec::with_capacity(header_end + 128); - let mut has_connection = false; - let mut has_via = false; - - for (i, line) in lines.iter().enumerate() { - if i == 0 { - // Rewrite request line: METHOD absolute-uri HTTP/1.1 → METHOD path HTTP/1.1 - let parts: Vec<&str> = line.splitn(3, ' ').collect(); - if parts.len() == 3 { - output.extend_from_slice(parts[0].as_bytes()); - output.push(b' '); - output.extend_from_slice(upstream_path.as_bytes()); - output.push(b' '); - output.extend_from_slice(parts[2].as_bytes()); - } else { - output.extend_from_slice(line.as_bytes()); - } - output.extend_from_slice(b"\r\n"); - continue; - } - if line.is_empty() { - // End of headers - break; - } - - let lower = line.to_ascii_lowercase(); - - // Strip proxy hop-by-hop headers - if lower.starts_with("proxy-connection:") - || lower.starts_with("proxy-authorization:") - || lower.starts_with("proxy-authenticate:") - { - continue; - } - - // Replace Connection header - if lower.starts_with("connection:") { - has_connection = true; - if websocket_upgrade { - output.extend_from_slice(line.as_bytes()); - output.extend_from_slice(b"\r\n"); - continue; - } - output.extend_from_slice(b"Connection: close\r\n"); - continue; - } - - let rewritten_line = match secret_resolver { - Some(resolver) => rewrite_header_line_checked(line, resolver)?, - None => line.to_string(), - }; - - output.extend_from_slice(rewritten_line.as_bytes()); - output.extend_from_slice(b"\r\n"); - - if lower.starts_with("via:") { - has_via = true; - } - } - - // Inject missing headers - if !has_connection && !websocket_upgrade { - output.extend_from_slice(b"Connection: close\r\n"); - } - if !has_via { - output.extend_from_slice(b"Via: 1.1 openshell-sandbox\r\n"); - } - - // End of headers - output.extend_from_slice(b"\r\n"); - let rewritten_header_end = output.len(); - - // Append any overflow body bytes from the original buffer - if header_end < used { - output.extend_from_slice(&raw[header_end..used]); - } - - // Fail-closed: scan for any remaining unresolved placeholders - if secret_resolver.is_some() { - let scan_end = if request_body_credential_rewrite { - rewritten_header_end - } else { - output.len() - }; - let output_str = String::from_utf8_lossy(&output[..scan_end]); - if output_str.contains(crate::secrets::PLACEHOLDER_PREFIX_PUBLIC) - || output_str.contains(crate::secrets::PROVIDER_ALIAS_MARKER_PUBLIC) - { - return Err(crate::secrets::UnresolvedPlaceholderError { location: "header" }); - } - } - - Ok(output) -} - -struct ForwardRelayOptions<'a> { - generation_guard: &'a PolicyGenerationGuard, - websocket_extensions: crate::l7::rest::WebSocketExtensionMode, - secret_resolver: Option<&'a SecretResolver>, - request_body_credential_rewrite: bool, -} - -async fn relay_rewritten_forward_request( - method: &str, - path: &str, - rewritten: Vec, - client: &mut C, - upstream: &mut U, - options: ForwardRelayOptions<'_>, -) -> Result -where - C: TokioAsyncRead + TokioAsyncWrite + Unpin, - U: TokioAsyncRead + TokioAsyncWrite + Unpin, -{ - let header_end = rewritten - .windows(4) - .position(|w| w == b"\r\n\r\n") - .map_or(rewritten.len(), |p| p + 4); - let header_str = String::from_utf8_lossy(&rewritten[..header_end]); - let body_length = crate::l7::rest::parse_body_length(&header_str)?; - let (_, query_params) = crate::l7::rest::parse_target_query(path)?; - let req = crate::l7::provider::L7Request { - action: method.to_string(), - target: path.to_string(), - query_params, - raw_header: rewritten, - body_length, - }; - - crate::l7::rest::relay_http_request_with_options_guarded( - &req, - client, - upstream, - crate::l7::rest::RelayRequestOptions { - resolver: options.secret_resolver, - generation_guard: Some(options.generation_guard), - websocket_extensions: options.websocket_extensions, - request_body_credential_rewrite: options.request_body_credential_rewrite, - }, - ) - .await -} - -async fn inject_token_grant_for_forward_request( - method: &str, - upstream_target: &str, - forward_request_bytes: Vec, - l7_ctx: &crate::l7::relay::L7EvalContext, -) -> Result> { - let header_end = forward_request_bytes - .windows(4) - .position(|w| w == b"\r\n\r\n") - .map_or(forward_request_bytes.len(), |p| p + 4); - let header_str = std::str::from_utf8(&forward_request_bytes[..header_end]) - .into_diagnostic() - .map_err(|_| miette::miette!("Forward HTTP headers contain invalid UTF-8"))?; - let body_length = crate::l7::rest::parse_body_length(header_str)?; - let forward_request_for_token_grant = crate::l7::provider::L7Request { - action: method.to_string(), - target: upstream_target.to_string(), - query_params: std::collections::HashMap::new(), - raw_header: forward_request_bytes, - body_length, - }; - - crate::l7::token_grant_injection::inject_if_needed(forward_request_for_token_grant, l7_ctx) - .await - .map(|req| req.raw_header) -} - -/// Handle a plain HTTP forward proxy request (non-CONNECT). -/// -/// Public IPs are allowed through when the endpoint passes OPA evaluation. -/// Private IPs require explicit `allowed_ips` on the endpoint config (SSRF -/// override). Rewrites the absolute-form request to origin-form, connects -/// upstream, and relays the request/response using the guarded HTTP relay. -// Many distinct, non-related context parameters are required for forward proxy -// dispatch; bundling them into a struct would just shift the noise into call sites. -#[allow(clippy::too_many_arguments)] -async fn handle_forward_proxy( - method: &str, - target_uri: &str, - buf: &[u8], - used: usize, - client: &mut TcpStream, - opa_engine: Arc, - identity_cache: Arc, - entrypoint_pid: Arc, - policy_local_ctx: Option>, - trusted_host_gateway: Arc>, - secret_resolver: Option>, - dynamic_credentials: Option< - Arc< - std::sync::RwLock< - std::collections::HashMap, - >, - >, - >, - denial_tx: Option<&mpsc::UnboundedSender>, - activity_tx: Option<&ActivitySender>, -) -> Result<()> { - // 1. Parse the absolute-form URI. `path` is marked `mut` so that, when an - // L7 config applies, the canonicalized form produced below replaces it - // in-place — keeping OPA evaluation and the bytes written onto the wire - // in sync. See the L7 block below. - let (scheme, host, port, mut path) = match parse_proxy_uri(target_uri) { - Ok(parsed) => parsed, - Err(e) => { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .message(format!("FORWARD parse error for {target_uri}: {e}")) - .build(); - ocsf_emit!(event); - respond(client, b"HTTP/1.1 400 Bad Request\r\n\r\n").await?; - return Ok(()); - } - }; - let host_lc = host.to_ascii_lowercase(); - - if host_lc == POLICY_LOCAL_HOST { - if scheme != "http" || port != 80 { - respond( - client, - &build_json_error_response( - 400, - "Bad Request", - "invalid_policy_local_scheme", - "Use http://policy.local only", - ), - ) - .await?; - return Ok(()); - } - if let Some(ctx) = policy_local_ctx { - return crate::policy_local::handle_forward_request( - &ctx, - method, - &path, - &buf[..used], - client, - ) - .await; - } - respond( - client, - b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 31\r\n\r\npolicy.local is not configured", - ) - .await?; - return Ok(()); - } - - // 2. Reject HTTPS — must use CONNECT for TLS - if scheme == "https" { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Refuse) - .action(ActionId::Denied) - .disposition(DispositionId::Rejected) - .severity(SeverityId::Informational) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!( - "FORWARD rejected: HTTPS requires CONNECT for {host_lc}:{port}" - )) - .build(); - ocsf_emit!(event); - } - respond( - client, - b"HTTP/1.1 400 Bad Request\r\nContent-Length: 27\r\n\r\nUse CONNECT for HTTPS URLs", - ) - .await?; - return Ok(()); - } - - // 3. Evaluate OPA policy (same identity binding as CONNECT) - let peer_addr = client.peer_addr().into_diagnostic()?; - let _local_addr = client.local_addr().into_diagnostic()?; - - let opa_clone = opa_engine.clone(); - let cache_clone = identity_cache.clone(); - let pid_clone = entrypoint_pid.clone(); - let host_clone = host_lc.clone(); - let decision = tokio::task::spawn_blocking(move || { - evaluate_opa_tcp( - peer_addr, - &opa_clone, - &cache_clone, - &pid_clone, - &host_clone, - port, - ) - }) - .await - .map_err(|e| miette::miette!("identity resolution task panicked: {e}"))?; - - // Build log context - let binary_str = decision - .binary - .as_ref() - .map_or_else(|| "-".to_string(), |p| p.display().to_string()); - let pid_str = decision - .binary_pid - .map_or_else(|| "-".to_string(), |p| p.to_string()); - let ancestors_str = if decision.ancestors.is_empty() { - "-".to_string() - } else { - decision - .ancestors - .iter() - .map(|p| p.display().to_string()) - .collect::>() - .join(" -> ") - }; - let cmdline_str = if decision.cmdline_paths.is_empty() { - "-".to_string() - } else { - decision - .cmdline_paths - .iter() - .map(|p| p.display().to_string()) - .collect::>() - .join(", ") - }; - - // 4. Only proceed on explicit Allow — reject Deny - let matched_policy = match &decision.action { - NetworkAction::Allow { matched_policy } => matched_policy.clone(), - NetworkAction::Deny { reason } => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule("-", "opa") - .message(format!("FORWARD denied {method} {host_lc}:{port}{path}")) - .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - reason, - "forward", - ); - emit_activity_simple(activity_tx, true, "forward_policy"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("{method} {host_lc}:{port}{path} not permitted by policy"), - ), - ) - .await?; - return Ok(()); - } - }; - let policy_str = matched_policy.as_deref().unwrap_or("-"); - let sandbox_entrypoint_pid = entrypoint_pid.load(Ordering::Acquire); - let forward_generation_guard = match opa_engine.generation_guard(decision.generation) { - Ok(guard) => guard, - Err(e) => { - emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); - emit_activity_simple(activity_tx, true, "policy_stale"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("{method} {host_lc}:{port}{path} not permitted by policy"), - ), - ) - .await?; - return Ok(()); - } - }; - let mut forward_request_bytes = buf[..used].to_vec(); - let mut upstream_target = path.clone(); - let mut websocket_extensions = crate::l7::rest::WebSocketExtensionMode::Preserve; - let mut forward_tunnel_engine: Option = None; - let mut forward_upgrade_config: Option = None; - let mut forward_upgrade_target = String::new(); - let mut forward_upgrade_query_params = std::collections::HashMap::new(); - let mut forward_websocket_request = - crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); - let mut request_body_credential_rewrite = false; - let l7_ctx = crate::l7::relay::L7EvalContext { - host: host_lc.clone(), - port, - policy_name: matched_policy.clone().unwrap_or_default(), - binary_path: decision - .binary - .as_ref() - .map(|p| p.to_string_lossy().into_owned()) - .unwrap_or_default(), - ancestors: decision - .ancestors - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - cmdline_paths: decision - .cmdline_paths - .iter() - .map(|p| p.to_string_lossy().into_owned()) - .collect(), - secret_resolver: secret_resolver.clone(), - activity_tx: activity_tx.cloned(), - dynamic_credentials: dynamic_credentials.clone(), - token_grant_resolver: dynamic_credentials - .as_ref() - .map(|_| crate::l7::token_grant_injection::default_resolver()), - }; - let mut l7_activity_pending = false; - - // 4b. If the endpoint has L7 config, evaluate the request against - // L7 policy. The forward proxy handles exactly one request per - // connection (Connection: close), so a single evaluation suffices. - if let Some(route) = query_l7_route_snapshot(&opa_engine, &decision, &host_lc, port) - && !route.configs.is_empty() - { - if route.generation != forward_generation_guard.captured_generation() { - emit_l7_tunnel_close_after_policy_change( - &host_lc, - port, - miette::miette!( - "policy changed before forward L7 evaluation [expected_generation:{} current_generation:{}]", - forward_generation_guard.captured_generation(), - route.generation, - ), - ); - emit_activity_simple(activity_tx, true, "policy_stale"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("{method} {host_lc}:{port}{path} not permitted by policy"), - ), - ) - .await?; - return Ok(()); - } - let tunnel_engine = match opa_engine.clone_engine_for_tunnel(route.generation) { - Ok(engine) => engine, - Err(e) => { - emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); - emit_activity_simple(activity_tx, true, "policy_stale"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("{method} {host_lc}:{port}{path} not permitted by policy"), - ), - ) - .await?; - return Ok(()); - } - }; - - // Canonicalize the request-target. The canonical form is fed to OPA - // AND reassigned to the outer `path` variable so the later call to - // `rewrite_forward_request` writes canonical bytes to the upstream. - // This closes the policy/upstream parser-differential at this site; - // without this reassignment, OPA would evaluate the canonical form - // while the upstream re-normalizes the raw input and dispatches on a - // potentially different path. - let canonicalize_options = crate::l7::path::CanonicalizeOptions { - allow_encoded_slash: route - .configs - .iter() - .any(|snapshot| snapshot.config.allow_encoded_slash), - ..Default::default() - }; - let query_params = - match crate::l7::path::canonicalize_request_target(&path, &canonicalize_options) { - Ok((canon, query)) => { - upstream_target = match query.as_deref() { - Some(raw_query) if !raw_query.is_empty() => { - format!("{}?{raw_query}", canon.path) - } - _ => canon.path.clone(), - }; - let params = query - .as_deref() - .map_or_else(std::collections::HashMap::new, |q| { - crate::l7::rest::parse_query_params(q).unwrap_or_default() - }); - path = canon.path; - params - } - Err(e) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!( - "FORWARD_L7 rejecting non-canonical request-target: {e}" - )) - .build(); - ocsf_emit!(event); - emit_activity_simple(activity_tx, true, "l7_parse_rejection"); - respond( - client, - &build_json_error_response( - 400, - "Bad Request", - "invalid_request_target", - "request-target must be canonical", - ), - ) - .await?; - return Ok(()); - } - }; - let Some(l7_config) = select_l7_config_for_path(&route.configs, &path) else { - emit_activity_simple(activity_tx, true, "l7_policy"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("{method} {host_lc}:{port}{path} did not match an L7 endpoint path"), - ), - ) - .await?; - return Ok(()); - }; - forward_websocket_request = - crate::l7::rest::request_is_websocket_upgrade(&forward_request_bytes); - websocket_extensions = crate::l7::relay::websocket_extension_mode(&l7_config.config); - request_body_credential_rewrite = l7_config.config.protocol == crate::l7::L7Protocol::Rest - && l7_config.config.request_body_credential_rewrite; - forward_upgrade_config = Some(l7_config.config.clone()); - forward_upgrade_target = path.clone(); - forward_upgrade_query_params = query_params.clone(); - let graphql = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { - let header_end = forward_request_bytes - .windows(4) - .position(|w| w == b"\r\n\r\n") - .map_or(forward_request_bytes.len(), |p| p + 4); - let header_str = std::str::from_utf8(&forward_request_bytes[..header_end]) - .map_err(|_| miette::miette!("Forward GraphQL headers contain invalid UTF-8"))?; - let body_length = crate::l7::rest::parse_body_length(header_str)?; - let mut graphql_request = crate::l7::provider::L7Request { - action: method.to_string(), - target: path.clone(), - query_params: query_params.clone(), - raw_header: forward_request_bytes, - body_length, - }; - let info = match crate::l7::graphql::inspect_graphql_request( - client, - &mut graphql_request, - l7_config.config.graphql_max_body_bytes, - ) - .await - { - Ok(info) => info, - Err(e) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!("FORWARD_GRAPHQL_L7 request rejected: {e}")) - .build(); - ocsf_emit!(event); - emit_activity_simple(activity_tx, true, "l7_parse_rejection"); - respond( - client, - &build_json_error_response( - 400, - "Bad Request", - "invalid_graphql_request", - &format!("GraphQL request rejected before policy evaluation: {e}"), - ), - ) - .await?; - return Ok(()); - } - }; - forward_request_bytes = graphql_request.raw_header; - Some(info) - } else { - None - }; - let jsonrpc = if l7_config.config.protocol == crate::l7::L7Protocol::JsonRpc { - let header_end = forward_request_bytes - .windows(4) - .position(|w| w == b"\r\n\r\n") - .map_or(forward_request_bytes.len(), |p| p + 4); - let header_str = std::str::from_utf8(&forward_request_bytes[..header_end]) - .map_err(|_| miette::miette!("Forward JSON-RPC headers contain invalid UTF-8"))?; - let body_length = crate::l7::rest::parse_body_length(header_str)?; - let mut jsonrpc_request = crate::l7::provider::L7Request { - action: method.to_string(), - target: path.clone(), - query_params: query_params.clone(), - raw_header: forward_request_bytes, - body_length, - }; - let body = match crate::l7::http::read_body_for_inspection( - client, - &mut jsonrpc_request, - l7_config.config.json_rpc_max_body_bytes, - ) - .await - { - Ok(body) => body, - Err(e) => { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!("FORWARD_JSONRPC_L7 request rejected: {e}")) - .build(); - ocsf_emit!(event); - emit_activity_simple(activity_tx, true, "l7_parse_rejection"); - respond( - client, - &build_json_error_response( - 400, - "Bad Request", - "invalid_jsonrpc_request", - &format!("JSON-RPC request rejected before policy evaluation: {e}"), - ), - ) - .await?; - return Ok(()); - } - }; - forward_request_bytes = jsonrpc_request.raw_header; - Some(crate::l7::jsonrpc::parse_jsonrpc_body(&body)) - } else { - None - }; - let request_info = crate::l7::L7RequestInfo { - action: method.to_string(), - target: path.clone(), - query_params, - graphql, - jsonrpc, - }; - - let parse_error_reason = l7_parse_error_reason(&request_info); - let force_deny = parse_error_reason.is_some(); - let (allowed, reason) = parse_error_reason.map_or_else( - || { - crate::l7::relay::evaluate_l7_request(&tunnel_engine, &l7_ctx, &request_info) - .unwrap_or_else(|e| { - let event = NetworkActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .message(format!("L7 eval failed, denying request: {e}")) - .build(); - ocsf_emit!(event); - (false, format!("L7 evaluation error: {e}")) - }) - }, - |reason| (false, reason), - ); - - let decision_str = match (allowed, l7_config.config.enforcement) { - (_, _) if force_deny => "deny", - (true, _) => "allow", - (false, crate::l7::EnforcementMode::Audit) => "audit", - (false, crate::l7::EnforcementMode::Enforce) => "deny", - }; - - { - let (action_id, disposition_id, severity) = match decision_str { - "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), - "allow" | "audit" => ( - ActionId::Allowed, - DispositionId::Allowed, - SeverityId::Informational, - ), - _ => ( - ActionId::Other, - DispositionId::Other, - SeverityId::Informational, - ), - }; - let engine_type = match l7_config.config.protocol { - crate::l7::L7Protocol::Graphql => "l7-graphql", - crate::l7::L7Protocol::JsonRpc => "l7-jsonrpc", - _ => "l7", - }; - let log_message = request_info.jsonrpc.as_ref().map_or_else( - || { - let message_prefix = - if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { - "FORWARD_GRAPHQL_L7" - } else { - "FORWARD_L7" - }; - format!( - "{message_prefix} {decision_str} {method} {host_lc}:{port}{path} reason={reason}" - ) - }, - |jsonrpc_info| { - let endpoint = format!("{host_lc}:{port}{path}"); - let params_sha256 = jsonrpc_info - .params_sha256() - .unwrap_or_else(|| "".to_string()); - crate::l7::relay::jsonrpc_log_message( - decision_str, - method, - &endpoint, - jsonrpc_info, - ¶ms_sha256, - tunnel_engine.captured_generation(), - &reason, - ) - }, - ); - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(action_id) - .disposition(disposition_id) - .severity(severity) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, engine_type) - .message(log_message) - .build(); - ocsf_emit!(event); - } - - let effectively_denied = force_deny - || (!allowed && l7_config.config.enforcement == crate::l7::EnforcementMode::Enforce); - - if effectively_denied { - emit_activity_simple(activity_tx, true, "l7_policy"); - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "forward-l7-deny", - ); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("{method} {host_lc}:{port}{path} denied by L7 policy: {reason}"), - ), - ) - .await?; - return Ok(()); - } - l7_activity_pending = true; - forward_tunnel_engine = Some(tunnel_engine); - } - - // 5. DNS resolution + SSRF defence (mirrors the CONNECT path logic). - // - If the host is a driver-injected host-gateway alias: bypass SSRF - // tiers and validate only against the trusted gateway IP. - // - If allowed_ips is set: validate resolved IPs against the allowlist - // (this is the SSRF override for private IP destinations). - // - If the endpoint is an exact declared hostname: allow private IPs, - // but still reject always-blocked addresses and control-plane ports. - // - Otherwise: reject internal IPs, allow public IPs through. - // When the policy host is already a literal IP address, treat it as - // implicitly allowed — the user explicitly declared the destination. - let mut raw_allowed_ips = query_allowed_ips(&opa_engine, &decision, &host_lc, port); - if raw_allowed_ips.is_empty() { - raw_allowed_ips = implicit_allowed_ips_for_ip_host(&host); - } - let exact_declared_endpoint_host = - query_exact_declared_endpoint_host(&opa_engine, &decision, &host_lc, port); - - // The trusted-gateway branch is the first path; reading it before the - // allowed_ips and default branches matches the policy decision narrative. - #[allow(clippy::if_not_else)] - let addrs = if is_host_gateway_alias(&host_lc) - && let Some(gw) = *trusted_host_gateway - { - // Trusted host-gateway path. Mirrors the CONNECT path logic. - match resolve_and_check_trusted_gateway(&host, port, gw, sandbox_entrypoint_pid).await { - Ok(addrs) => addrs, - Err(reason) => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "ssrf") - .message(format!( - "FORWARD blocked: trusted-gateway check failed for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - emit_activity_simple(activity_tx, true, "ssrf"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!("{method} {host_lc}:{port} blocked: trusted-gateway check failed"), - ), - ) - .await?; - return Ok(()); - } - } - } else if !raw_allowed_ips.is_empty() { - // allowed_ips mode: validate resolved IPs against CIDR allowlist. - match parse_allowed_ips(&raw_allowed_ips) { - Ok(nets) => { - match resolve_and_check_allowed_ips(&host, port, &nets, sandbox_entrypoint_pid) - .await - { - Ok(addrs) => addrs, - Err(reason) => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "ssrf") - .message(format!( - "FORWARD blocked: allowed_ips check failed for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - emit_activity_simple(activity_tx, true, "ssrf"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!( - "{method} {host_lc}:{port} blocked: allowed_ips check failed" - ), - ), - ) - .await?; - return Ok(()); - } - } - } - Err(reason) => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "ssrf") - .message(format!( - "FORWARD blocked: invalid allowed_ips in policy for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - emit_activity_simple(activity_tx, true, "ssrf"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!( - "{method} {host_lc}:{port} blocked: invalid allowed_ips in policy" - ), - ), - ) - .await?; - return Ok(()); - } - } - } else if exact_declared_endpoint_host { - // Exact declared hostname mode mirrors CONNECT: private resolved - // addresses are allowed for this operator-declared host:port, while - // always-blocked addresses and control-plane ports remain denied. - match resolve_and_check_declared_endpoint(&host, port, sandbox_entrypoint_pid).await { - Ok(addrs) => addrs, - Err(reason) => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "ssrf") - .message(format!( - "FORWARD blocked: declared endpoint check failed for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!( - "{method} {host_lc}:{port} blocked: declared endpoint check failed" - ), - ), - ) - .await?; - return Ok(()); - } - } - } else { - // No allowed_ips: reject internal IPs, allow public IPs through. - match resolve_and_reject_internal(&host, port, sandbox_entrypoint_pid).await { - Ok(addrs) => addrs, - Err(reason) => { - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Denied) - .disposition(DispositionId::Blocked) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "ssrf") - .message(format!( - "FORWARD blocked: internal IP without allowed_ips for {host_lc}:{port}" - )) - .status_detail(&reason) - .build(); - ocsf_emit!(event); - } - emit_denial_simple( - denial_tx, - &host_lc, - port, - &binary_str, - &decision, - &reason, - "ssrf", - ); - emit_activity_simple(activity_tx, true, "ssrf"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "ssrf_denied", - &format!("{method} {host_lc}:{port} blocked: internal address"), - ), - ) - .await?; - return Ok(()); - } - } - }; - - if let Err(e) = forward_generation_guard.ensure_current() { - emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); - emit_activity_simple(activity_tx, true, "policy_stale"); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("{method} {host_lc}:{port}{path} not permitted by policy"), - ), - ) - .await?; - return Ok(()); - } - - // 6. Connect upstream - let mut upstream = match TcpStream::connect(addrs.as_slice()).await { - Ok(s) => s, - Err(e) => { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Fail) - .severity(SeverityId::Low) - .status(StatusId::Failure) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .message(format!( - "FORWARD upstream connect failed for {host_lc}:{port}: {e}" - )) - .build(); - ocsf_emit!(event); - respond( - client, - &build_json_error_response( - 502, - "Bad Gateway", - "upstream_unreachable", - &format!("connection to {host_lc}:{port} failed"), - ), - ) - .await?; - return Ok(()); - } - }; - - // Log success - { - let event = HttpActivityBuilder::new(crate::ocsf_ctx()) - .activity(ActivityId::Other) - .action(ActionId::Allowed) - .disposition(DispositionId::Allowed) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .http_request(HttpRequest::new( - method, - OcsfUrl::new("http", &host_lc, &path, port), - )) - .dst_endpoint(Endpoint::from_domain(&host_lc, port)) - .src_endpoint(Endpoint::from_ip(peer_addr.ip(), peer_addr.port())) - .actor_process( - Process::from_bypass(&binary_str, &pid_str, &ancestors_str) - .with_cmd_line(&cmdline_str), - ) - .firewall_rule(policy_str, "opa") - .message(format!("FORWARD allowed {method} {host_lc}:{port}{path}")) - .build(); - ocsf_emit!(event); - } - emit_forward_success_activity(activity_tx, l7_activity_pending); - - forward_request_bytes = match inject_token_grant_for_forward_request( - method, - &upstream_target, - forward_request_bytes, - &l7_ctx, - ) - .await - { - Ok(bytes) => bytes, - Err(e) => { - warn!( - dst_host = %host_lc, - dst_port = port, - error = %e, - "token grant failed in forward proxy" - ); - respond( - client, - &build_json_error_response( - 502, - "Bad Gateway", - "token_grant_failed", - "dynamic token grant failed", - ), - ) - .await?; - return Ok(()); - } - }; - - // 9. Rewrite request and forward to upstream - let rewritten = match rewrite_forward_request( - &forward_request_bytes, - forward_request_bytes.len(), - &upstream_target, - secret_resolver.as_deref(), - request_body_credential_rewrite, - ) { - Ok(bytes) => bytes, - Err(e) => { - warn!( - dst_host = %host_lc, - dst_port = port, - error = %e, - "credential injection failed in forward proxy" - ); - respond( - client, - &build_json_error_response( - 500, - "Internal Server Error", - "credential_injection_failed", - "unresolved credential placeholder in request", - ), - ) - .await?; - return Ok(()); - } - }; - if let Err(e) = forward_generation_guard.ensure_current() { - emit_l7_tunnel_close_after_policy_change(&host_lc, port, e); - respond( - client, - &build_json_error_response( - 403, - "Forbidden", - "policy_denied", - &format!("{method} {host_lc}:{port}{path} not permitted by policy"), - ), - ) - .await?; - return Ok(()); - } - let outcome = relay_rewritten_forward_request( - method, - &path, - rewritten, - client, - &mut upstream, - ForwardRelayOptions { - generation_guard: &forward_generation_guard, - websocket_extensions, - secret_resolver: secret_resolver.as_deref(), - request_body_credential_rewrite, - }, - ) - .await?; - if let crate::l7::provider::RelayOutcome::Upgraded { - overflow, - websocket_permessage_deflate, - } = outcome - { - let mut upgrade_options = if let (Some(config), Some(engine)) = ( - forward_upgrade_config.as_ref(), - forward_tunnel_engine.as_ref(), - ) { - crate::l7::relay::upgrade_options( - config, - &l7_ctx, - forward_websocket_request, - &forward_upgrade_target, - &forward_upgrade_query_params, - Some(engine), - ) - } else { - crate::l7::relay::UpgradeRelayOptions { - websocket_request: forward_websocket_request, - ..Default::default() - } - }; - upgrade_options.websocket.permessage_deflate = websocket_permessage_deflate; - crate::l7::relay::handle_upgrade( - client, - &mut upstream, - overflow, - &host_lc, - port, - upgrade_options, - ) - .await?; - } - - Ok(()) -} - -fn parse_target(target: &str) -> Result<(String, u16)> { - let (host, port_str) = target - .split_once(':') - .ok_or_else(|| miette::miette!("CONNECT target missing port: {target}"))?; - let port: u16 = port_str - .parse() - .map_err(|_| miette::miette!("Invalid port in CONNECT target: {target}"))?; - Ok((host.to_string(), port)) -} - -async fn respond(client: &mut TcpStream, bytes: &[u8]) -> Result<()> { - client.write_all(bytes).await.into_diagnostic()?; - Ok(()) -} - -/// Build an HTTP error response with a JSON body. -/// -/// Returns bytes ready to write to the client socket. The body is a JSON -/// object with `error` and `detail` fields, matching the format used by the -/// L7 deny path in `l7/rest.rs`. -fn build_json_error_response(status: u16, status_text: &str, error: &str, detail: &str) -> Vec { - let body = serde_json::json!({ - "error": error, - "detail": detail, - }); - let body_str = body.to_string(); - format!( - "HTTP/1.1 {status} {status_text}\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\ - Connection: close\r\n\ - \r\n\ - {}", - body_str.len(), - body_str, - ) - .into_bytes() -} - -/// Check if a miette error represents a benign connection close. -/// -/// TLS handshake EOF, missing `close_notify`, connection resets, and broken -/// pipes are all normal lifecycle events for proxied connections — not worth -/// a WARN that interrupts the user's terminal. -fn is_benign_relay_error(err: &miette::Report) -> bool { - const BENIGN: &[&str] = &[ - "close_notify", - "tls handshake eof", - "connection reset", - "broken pipe", - "unexpected eof", - ]; - let msg = err.to_string().to_ascii_lowercase(); - BENIGN.iter().any(|pat| msg.contains(pat)) -} - -#[cfg(test)] -#[allow( - clippy::needless_raw_string_hashes, - clippy::iter_on_single_items, - clippy::needless_continue, - reason = "Test code: test fixtures and explicit control-flow markers are idiomatic in tests." -)] -mod tests { - use super::*; - use std::future::Future; - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; - use std::sync::Arc; - use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt}; - use tokio::net::{TcpListener, TcpStream}; - - fn websocket_l7_config( - protocol: crate::l7::L7Protocol, - websocket_credential_rewrite: bool, - ) -> crate::l7::L7EndpointConfig { - crate::l7::L7EndpointConfig { - protocol, - path: "/**".to_string(), - tls: crate::l7::TlsMode::Auto, - enforcement: crate::l7::EnforcementMode::Enforce, - graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, - json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, - allow_encoded_slash: false, - websocket_credential_rewrite, - request_body_credential_rewrite: false, - websocket_graphql_policy: false, - } - } - - #[test] - fn connect_activity_is_skipped_when_l7_will_count_the_request() { - let (tx, mut rx) = mpsc::channel(4); - let activity_tx = Some(tx); - let l7_route = L7RouteSnapshot { - configs: vec![L7ConfigSnapshot { - config: websocket_l7_config(crate::l7::L7Protocol::Rest, false), - }], - generation: 1, - }; - let l4_route = L7RouteSnapshot { - configs: Vec::new(), - generation: 1, - }; - - emit_connect_activity_if_l4_only(&activity_tx, Some(&l7_route)); - assert!( - rx.try_recv().is_err(), - "L7-inspected CONNECT should not emit an extra L4 activity event" - ); - - emit_connect_activity_if_l4_only(&activity_tx, Some(&l4_route)); - let event = rx.try_recv().expect("L4-only CONNECT should emit activity"); - assert!(!event.denied); - assert_eq!(event.deny_group, "unknown"); - - emit_connect_activity_if_l4_only(&activity_tx, None); - let event = rx - .try_recv() - .expect("CONNECT without an L7 route should emit activity"); - assert!(!event.denied); - assert_eq!(event.deny_group, "unknown"); - } - - #[test] - fn l7_parse_error_reason_includes_jsonrpc_errors() { - let request_info = crate::l7::L7RequestInfo { - action: "POST".to_string(), - target: "/mcp".to_string(), - query_params: std::collections::HashMap::new(), - graphql: None, - jsonrpc: Some(crate::l7::jsonrpc::JsonRpcRequestInfo { - calls: Vec::new(), - is_batch: false, - error: Some("ambiguous dotted params key 'arguments.scope'".to_string()), - }), - }; - - let reason = l7_parse_error_reason(&request_info).expect("JSON-RPC parse error"); - - assert_eq!( - reason, - "JSON-RPC request rejected: ambiguous dotted params key 'arguments.scope'" - ); - } - - #[test] - fn forward_l7_allowed_activity_is_deferred_until_after_ssrf() { - let (tx, mut rx) = mpsc::channel(4); - let activity_tx = Some(tx); - - let l7_activity_pending = true; - assert!( - rx.try_recv().is_err(), - "allowed L7 evaluation must not emit activity before SSRF succeeds" - ); - - emit_activity_simple(activity_tx.as_ref(), true, "ssrf"); - let event = rx - .try_recv() - .expect("SSRF denial should emit the request activity"); - assert!(event.denied); - assert_eq!(event.deny_group, "ssrf"); - assert!( - rx.try_recv().is_err(), - "SSRF-denied forward request must not also emit allowed L7 activity" - ); - - emit_forward_success_activity(activity_tx.as_ref(), l7_activity_pending); - let event = rx - .try_recv() - .expect("L7 activity should emit after SSRF succeeds"); - assert!(!event.denied); - assert_eq!(event.deny_group, "l7_policy"); - } - - #[test] - fn forward_success_activity_uses_unknown_without_l7() { - let (tx, mut rx) = mpsc::channel(4); - let activity_tx = Some(tx); - - emit_forward_success_activity(activity_tx.as_ref(), false); - let event = rx - .try_recv() - .expect("non-L7 forward success should emit activity"); - assert!(!event.denied); - assert_eq!(event.deny_group, "unknown"); - } - - fn forward_test_guard() -> PolicyGenerationGuard { - let policy = include_str!("../data/sandbox-policy.rego"); - let policy_data = "network_policies: {}\n"; - let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); - engine - .generation_guard(engine.current_generation()) - .unwrap() - } - - async fn relay_forward_request_and_capture( - method: &str, - path: &str, - raw: &[u8], - resolver: Option<&SecretResolver>, - request_body_credential_rewrite: bool, - ) -> Result { - let guard = forward_test_guard(); - let rewritten = rewrite_forward_request( - raw, - raw.len(), - path, - resolver, - request_body_credential_rewrite, - ) - .map_err(|e| miette::miette!("{e}"))?; - let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - - let upstream_task = tokio::spawn(async move { - let mut buf = vec![0u8; 8192]; - let mut total = 0usize; - let mut expected_total = None; - loop { - let n = upstream_side.read(&mut buf[total..]).await.unwrap(); - if n == 0 { - break; - } - total += n; - if expected_total.is_none() - && let Some(end) = buf[..total].windows(4).position(|w| w == b"\r\n\r\n") - { - let header_end = end + 4; - let headers = String::from_utf8_lossy(&buf[..header_end]); - let len = headers - .lines() - .find_map(|line| { - let (name, value) = line.split_once(':')?; - name.eq_ignore_ascii_case("content-length") - .then(|| value.trim().parse::().ok()) - .flatten() - }) - .unwrap_or(0); - expected_total = Some(header_end + len); - } - if expected_total.is_some_and(|expected| total >= expected) { - break; - } - } - upstream_side - .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") - .await - .unwrap(); - upstream_side.flush().await.unwrap(); - String::from_utf8_lossy(&buf[..total]).to_string() - }); - - relay_rewritten_forward_request( - method, - path, - rewritten, - &mut proxy_to_client, - &mut proxy_to_upstream, - ForwardRelayOptions { - generation_guard: &guard, - websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, - secret_resolver: resolver, - request_body_credential_rewrite, - }, - ) - .await?; - - upstream_task - .await - .map_err(|e| miette::miette!("upstream task failed: {e}")) - } - - fn forward_token_grant_context( - resolver_response: std::result::Result<&str, &str>, - ) -> ( - crate::l7::relay::L7EvalContext, - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture, - ) { - let provider_key = "api.example.test\t8080\t/v1/**\tprovider:access_token"; - let fixture = match resolver_response { - Ok(token) => { - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::success( - provider_key, - token, - ) - } - Err(error) => { - crate::l7::token_grant_injection::test_support::TokenGrantTestFixture::failure( - provider_key, - error, - ) - } - }; - let ctx = crate::l7::relay::L7EvalContext { - host: "api.example.test".into(), - port: 8080, - policy_name: "rest_api".into(), - binary_path: "/usr/bin/curl".into(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: Some(fixture.dynamic_credentials()), - token_grant_resolver: Some(fixture.resolver()), - }; - - (ctx, fixture) - } - - fn authorization_header_count(headers: &str) -> usize { - headers - .lines() - .filter(|line| { - line.split_once(':') - .is_some_and(|(name, _)| name.eq_ignore_ascii_case("authorization")) - }) - .count() - } - - fn forward_websocket_policy_parts( - data: &str, - host: &str, - port: u16, - path: &str, - policy_name: &str, - ) -> ( - crate::l7::L7EndpointConfig, - crate::opa::TunnelPolicyEngine, - crate::l7::relay::L7EvalContext, - ) { - let policy = include_str!("../data/sandbox-policy.rego"); - let engine = OpaEngine::from_strings(policy, data).unwrap(); - let decision = ConnectDecision { - action: NetworkAction::Allow { - matched_policy: Some(policy_name.to_string()), - }, - generation: engine.current_generation(), - binary: Some(PathBuf::from("/usr/bin/node")), - binary_pid: None, - ancestors: vec![], - cmdline_paths: vec![], - }; - let route = - query_l7_route_snapshot(&engine, &decision, host, port).expect("L7 route should match"); - let config = select_l7_config_for_path(&route.configs, path) - .expect("path-specific L7 config should match") - .config - .clone(); - let tunnel_engine = engine - .clone_engine_for_tunnel(route.generation) - .expect("tunnel engine"); - let ctx = crate::l7::relay::L7EvalContext { - host: host.to_string(), - port, - policy_name: policy_name.to_string(), - binary_path: "/usr/bin/node".to_string(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - (config, tunnel_engine, ctx) - } - - async fn read_http_headers(reader: &mut R) -> Vec { - let mut bytes = Vec::new(); - let mut chunk = [0u8; 256]; - loop { - let n = - tokio::time::timeout(std::time::Duration::from_secs(1), reader.read(&mut chunk)) - .await - .expect("HTTP headers should arrive") - .expect("header read should succeed"); - assert!(n > 0, "stream closed before HTTP headers"); - bytes.extend_from_slice(&chunk[..n]); - if bytes.windows(4).any(|w| w == b"\r\n\r\n") { - return bytes; - } - } - } - - fn masked_text_frame(payload: &[u8]) -> Vec { - let mask = [0x11, 0x22, 0x33, 0x44]; - assert!( - payload.len() <= 125, - "test helper only supports small frames" - ); - let payload_len = u8::try_from(payload.len()).expect("small frame length"); - let mut frame = vec![0x81, 0x80 | payload_len]; - frame.extend_from_slice(&mask); - frame.extend( - payload - .iter() - .enumerate() - .map(|(idx, byte)| byte ^ mask[idx % 4]), - ); - frame - } - - async fn forward_websocket_denied_after_upgrade( - config: crate::l7::L7EndpointConfig, - tunnel_engine: crate::opa::TunnelPolicyEngine, - ctx: crate::l7::relay::L7EvalContext, - path: &str, - payload: &str, - ) -> (miette::Report, Vec) { - let host = ctx.host.clone(); - let port = ctx.port; - let raw = format!( - "GET http://{host}{path} HTTP/1.1\r\n\ - Host: {host}\r\n\ - Upgrade: websocket\r\n\ - Connection: Upgrade\r\n\ - Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ - Sec-WebSocket-Version: 13\r\n\r\n" - ); - let rewritten = rewrite_forward_request(raw.as_bytes(), raw.len(), path, None, false) - .expect("forward websocket request should rewrite to origin form"); - let websocket_extensions = crate::l7::relay::websocket_extension_mode(&config); - let target = path.to_string(); - let query_params = std::collections::HashMap::new(); - let (mut proxy_to_upstream, mut upstream) = tokio::io::duplex(8192); - let (mut app, mut proxy_to_client) = tokio::io::duplex(8192); - - let relay = tokio::spawn(async move { - let guard = tunnel_engine.generation_guard(); - let outcome = relay_rewritten_forward_request( - "GET", - &target, - rewritten, - &mut proxy_to_client, - &mut proxy_to_upstream, - ForwardRelayOptions { - generation_guard: guard, - websocket_extensions, - secret_resolver: None, - request_body_credential_rewrite: false, - }, - ) - .await?; - if let crate::l7::provider::RelayOutcome::Upgraded { - overflow, - websocket_permessage_deflate, - } = outcome - { - let mut options = crate::l7::relay::upgrade_options( - &config, - &ctx, - true, - &target, - &query_params, - Some(&tunnel_engine), - ); - options.websocket.permessage_deflate = websocket_permessage_deflate; - crate::l7::relay::handle_upgrade( - &mut proxy_to_client, - &mut proxy_to_upstream, - overflow, - &host, - port, - options, - ) - .await?; - } - Ok::<(), miette::Report>(()) - }); - - let forwarded_headers = read_http_headers(&mut upstream).await; - let forwarded_headers = String::from_utf8_lossy(&forwarded_headers); - assert!(forwarded_headers.starts_with(&format!("GET {path} HTTP/1.1\r\n"))); - assert!(forwarded_headers.contains("Upgrade: websocket\r\n")); - - upstream - .write_all( - b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n\r\n", - ) - .await - .unwrap(); - - let response = read_http_headers(&mut app).await; - assert!(String::from_utf8_lossy(&response).contains("101 Switching Protocols")); - - app.write_all(&masked_text_frame(payload.as_bytes())) - .await - .unwrap(); - - let err = tokio::time::timeout(std::time::Duration::from_secs(1), relay) - .await - .expect("websocket relay should fail closed after denied frame") - .expect("relay task should not panic") - .expect_err("denied websocket frame should fail the forward relay"); - - let mut leaked = Vec::new(); - tokio::time::timeout( - std::time::Duration::from_secs(1), - upstream.read_to_end(&mut leaked), - ) - .await - .expect("upstream side should close") - .expect("upstream read should succeed"); - (err, leaked) - } - - #[test] - fn forward_websocket_upgrade_options_enable_native_policy_context() { - let (_, resolver) = SecretResolver::from_provider_env( - [("DISCORD_BOT_TOKEN".to_string(), "discord-real".to_string())] - .into_iter() - .collect(), - ); - let resolver = resolver.map(Arc::new); - let policy = include_str!("../data/sandbox-policy.rego"); - let policy_data = "network_policies: {}\n"; - let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); - let tunnel_engine = engine - .clone_engine_for_tunnel(engine.current_generation()) - .unwrap(); - let ctx = crate::l7::relay::L7EvalContext { - host: "gateway.example.test".to_string(), - port: 80, - policy_name: "ws_api".to_string(), - binary_path: "/usr/bin/node".to_string(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: resolver, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - let query_params = std::collections::HashMap::new(); - - let extensions = crate::l7::relay::websocket_extension_mode(&websocket_l7_config( - crate::l7::L7Protocol::Websocket, - true, - )); - let options = crate::l7::relay::upgrade_options( - &websocket_l7_config(crate::l7::L7Protocol::Websocket, true), - &ctx, - true, - "/ws", - &query_params, - Some(&tunnel_engine), - ); - - assert_eq!( - extensions, - crate::l7::rest::WebSocketExtensionMode::PermessageDeflate - ); - assert!(options.websocket.credential_rewrite); - assert!(options.secret_resolver.is_some()); - assert!(options.engine.is_some()); - assert!(options.ctx.is_some()); - assert!(matches!( - options.websocket.message_policy, - crate::l7::relay::WebSocketMessagePolicy::Transport - )); - } - - #[test] - fn forward_websocket_upgrade_options_preserve_rest_without_rewrite() { - let ctx = crate::l7::relay::L7EvalContext { - host: "gateway.example.test".to_string(), - port: 80, - policy_name: "rest_api".to_string(), - binary_path: "/usr/bin/node".to_string(), - ancestors: vec![], - cmdline_paths: vec![], - secret_resolver: None, - activity_tx: None, - dynamic_credentials: None, - token_grant_resolver: None, - }; - let query_params = std::collections::HashMap::new(); - let config = websocket_l7_config(crate::l7::L7Protocol::Rest, false); - let extensions = crate::l7::relay::websocket_extension_mode(&config); - let options = - crate::l7::relay::upgrade_options(&config, &ctx, true, "/ws", &query_params, None); - - assert_eq!( - extensions, - crate::l7::rest::WebSocketExtensionMode::Preserve - ); - assert!(!options.websocket.credential_rewrite); - assert!(options.secret_resolver.is_none()); - assert!(options.engine.is_none()); - assert!(options.ctx.is_none()); - assert!(matches!( - options.websocket.message_policy, - crate::l7::relay::WebSocketMessagePolicy::None - )); - } - - #[tokio::test] - async fn forward_websocket_upgrade_blocks_text_frame_by_policy() { - let data = r#" -network_policies: - ws_api: - name: ws_api - endpoints: - - host: gateway.example.test - port: 80 - path: "/ws" - protocol: websocket - enforcement: enforce - rules: - - allow: - method: GET - path: "/ws" - - allow: - method: WEBSOCKET_TEXT - path: "/ws" - deny_rules: - - method: WEBSOCKET_TEXT - path: "/ws" - binaries: - - { path: /usr/bin/node } -"#; - let (config, tunnel_engine, ctx) = - forward_websocket_policy_parts(data, "gateway.example.test", 80, "/ws", "ws_api"); - - let (err, leaked) = forward_websocket_denied_after_upgrade( - config, - tunnel_engine, - ctx, - "/ws", - r#"{"type":"unsafe"}"#, - ) - .await; - - assert!(err.to_string().contains("websocket text message denied")); - assert!( - leaked.is_empty(), - "denied forward-proxy WebSocket text frames must not reach upstream" - ); - } - - #[tokio::test] - async fn forward_graphql_websocket_upgrade_blocks_unallowed_operation() { - let data = r#" -network_policies: - graphql_ws: - name: graphql_ws - endpoints: - - host: gateway.example.test - port: 80 - path: "/graphql" - protocol: websocket - enforcement: enforce - rules: - - allow: - method: GET - path: "/graphql" - - allow: - operation_type: query - fields: [viewer] - deny_rules: - - operation_type: query - fields: [admin] - binaries: - - { path: /usr/bin/node } -"#; - let (config, tunnel_engine, ctx) = forward_websocket_policy_parts( - data, - "gateway.example.test", - 80, - "/graphql", - "graphql_ws", - ); - assert!( - config.websocket_graphql_policy, - "operation rules should enable GraphQL-over-WebSocket inspection" - ); - - let (err, leaked) = forward_websocket_denied_after_upgrade( - config, - tunnel_engine, - ctx, - "/graphql", - r#"{"id":"1","type":"subscribe","payload":{"query":"query { admin }"}}"#, - ) - .await; - - assert!(err.to_string().contains("websocket GraphQL message denied")); - assert!( - leaked.is_empty(), - "denied forward-proxy GraphQL WebSocket operations must not reach upstream" - ); - } - - #[test] - fn l7_route_selection_prefers_path_specific_graphql_endpoint() { - let configs = vec![ - L7ConfigSnapshot { - config: crate::l7::L7EndpointConfig { - protocol: crate::l7::L7Protocol::Rest, - path: "/**".to_string(), - tls: crate::l7::TlsMode::Auto, - enforcement: crate::l7::EnforcementMode::Enforce, - graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, - json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, - allow_encoded_slash: false, - websocket_credential_rewrite: false, - request_body_credential_rewrite: false, - websocket_graphql_policy: false, - }, - }, - L7ConfigSnapshot { - config: crate::l7::L7EndpointConfig { - protocol: crate::l7::L7Protocol::Graphql, - path: "/graphql".to_string(), - tls: crate::l7::TlsMode::Auto, - enforcement: crate::l7::EnforcementMode::Enforce, - graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, - json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, - allow_encoded_slash: false, - websocket_credential_rewrite: false, - request_body_credential_rewrite: false, - websocket_graphql_policy: false, - }, - }, - ]; - - let selected = - select_l7_config_for_path(&configs, "/graphql").expect("expected path-specific route"); - assert_eq!(selected.config.protocol, crate::l7::L7Protocol::Graphql); - - let selected = - select_l7_config_for_path(&configs, "/repos/org/repo").expect("expected REST route"); - assert_eq!(selected.config.protocol, crate::l7::L7Protocol::Rest); - } - - // -- is_internal_ip: IPv4 -- - - #[test] - fn test_rejects_ipv4_loopback() { - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)))); - } - - #[test] - fn test_rejects_ipv4_private_10() { - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(10, 255, 255, 255)))); - } - - #[test] - fn test_rejects_ipv4_private_172_16() { - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(172, 31, 255, 255)))); - } - - #[test] - fn test_rejects_ipv4_private_192_168() { - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new( - 192, 168, 255, 255 - )))); - } - - #[test] - fn test_rejects_ipv4_link_local_metadata() { - // Cloud metadata endpoint - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new( - 169, 254, 169, 254 - )))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(169, 254, 0, 1)))); - } - - #[test] - fn test_rejects_ipv4_unspecified() { - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED))); - } - - #[test] - fn test_rejects_ipv4_cgnat() { - // 100.64.0.0/10 — CGNAT / shared address space (RFC 6598) - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 64, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 100, 50, 3)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new( - 100, 127, 255, 255 - )))); - // Just outside the /10 boundary - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(100, 128, 0, 1)))); - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new( - 100, 63, 255, 255 - )))); - } - - #[test] - fn test_rejects_ipv4_special_use_ranges() { - // 192.0.0.0/24 — IETF protocol assignments - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(192, 0, 0, 1)))); - // 198.18.0.0/15 — benchmarking - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(198, 18, 0, 1)))); - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(198, 19, 255, 255)))); - // 198.51.100.0/24 — TEST-NET-2 - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(198, 51, 100, 1)))); - // 203.0.113.0/24 — TEST-NET-3 - assert!(is_internal_ip(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 1)))); - } - - #[test] - fn test_rejects_ipv6_mapped_cgnat() { - // ::ffff:100.64.0.1 should be caught via IPv4-mapped unwrapping - let v6 = Ipv4Addr::new(100, 64, 0, 1).to_ipv6_mapped(); - assert!(is_internal_ip(IpAddr::V6(v6))); - } - - #[test] - fn test_allows_ipv4_public() { - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)))); - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)))); - } - - #[test] - fn test_allows_ipv4_non_private_172() { - // 172.32.0.0 is outside the 172.16/12 private range - assert!(!is_internal_ip(IpAddr::V4(Ipv4Addr::new(172, 32, 0, 1)))); - } - - // -- is_internal_ip: IPv6 -- - - #[test] - fn test_rejects_ipv6_loopback() { - assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))); - } - - #[test] - fn test_rejects_ipv6_unspecified() { - assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED))); - } - - #[test] - fn test_rejects_ipv6_link_local() { - // fe80::1 - assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::new( - 0xfe80, 0, 0, 0, 0, 0, 0, 1 - )))); - } - - #[test] - fn test_rejects_ipv6_unique_local_address() { - // fdc4:f303:9324::254 - assert!(is_internal_ip(IpAddr::V6(Ipv6Addr::new( - 0xfdc4, 0xf303, 0x9324, 0, 0, 0, 0, 0x0254 - )))); - } - - #[test] - fn test_rejects_ipv4_mapped_ipv6_private() { - // ::ffff:10.0.0.1 - let v6 = Ipv4Addr::new(10, 0, 0, 1).to_ipv6_mapped(); - assert!(is_internal_ip(IpAddr::V6(v6))); - } - - #[test] - fn test_rejects_ipv4_mapped_ipv6_loopback() { - // ::ffff:127.0.0.1 - let v6 = Ipv4Addr::LOCALHOST.to_ipv6_mapped(); - assert!(is_internal_ip(IpAddr::V6(v6))); - } - - #[test] - fn test_rejects_ipv4_mapped_ipv6_link_local() { - // ::ffff:169.254.169.254 - let v6 = Ipv4Addr::new(169, 254, 169, 254).to_ipv6_mapped(); - assert!(is_internal_ip(IpAddr::V6(v6))); - } - - #[test] - fn test_allows_ipv6_public() { - // 2001:4860:4860::8888 (Google DNS) - assert!(!is_internal_ip(IpAddr::V6(Ipv6Addr::new( - 0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888 - )))); - } - - #[test] - fn test_allows_ipv4_mapped_ipv6_public() { - // ::ffff:8.8.8.8 - let v6 = Ipv4Addr::new(8, 8, 8, 8).to_ipv6_mapped(); - assert!(!is_internal_ip(IpAddr::V6(v6))); - } - - // -- resolve_and_reject_internal -- - - #[test] - fn test_parse_hosts_file_for_host_handles_comments_invalid_rows_and_case() { - let contents = r#" - # comment - 192.168.1.105 searxng.local searxng - bad-ip ignored.local - 93.184.216.34 Example.Local # trailing comment - ::1 loopback.local - 192.168.1.105 searxng.local - "#; - - let result = parse_hosts_file_for_host(contents, "SEARXNG.LOCAL"); - assert_eq!(result, vec![IpAddr::V4(Ipv4Addr::new(192, 168, 1, 105))]); - - let public = parse_hosts_file_for_host(contents, "example.local"); - assert_eq!(public, vec![IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34))]); - } - - #[test] - fn test_resolve_from_hosts_file_contents_requires_exact_alias_match() { - let contents = "192.168.1.105 searxng.local\n"; - - assert!( - resolve_from_hosts_file_contents(contents, "searxng", 8080).is_empty(), - "partial alias match should not resolve" - ); - - let result = resolve_from_hosts_file_contents(contents, "searxng.local", 8080); - assert_eq!( - result, - vec![SocketAddr::new( - IpAddr::V4(Ipv4Addr::new(192, 168, 1, 105)), - 8080 - )] - ); - } - - #[test] - fn test_resolve_from_hosts_file_contents_public_ip_passes_default_ssrf_check() { - let addrs = - resolve_from_hosts_file_contents("93.184.216.34 example.local\n", "example.local", 80); - assert!(reject_internal_resolved_addrs("example.local", &addrs).is_ok()); - } - - #[test] - fn test_resolve_from_hosts_file_contents_private_ip_requires_allowed_ips() { - let addrs = resolve_from_hosts_file_contents( - "192.168.1.105 searxng.local\n", - "searxng.local", - 8080, - ); - - let err = reject_internal_resolved_addrs("searxng.local", &addrs).unwrap_err(); - assert!( - err.contains("internal address"), - "expected private hosts-file resolution to remain blocked: {err}" - ); - - let nets = parse_allowed_ips(&["192.168.1.105/32".to_string()]).unwrap(); - assert!( - validate_allowed_ips_for_resolved_addrs("searxng.local", 8080, &addrs, &nets).is_ok() - ); - } - - #[test] - fn test_declared_endpoint_private_hosts_file_resolution_allowed() { - let addrs = resolve_from_hosts_file_contents( - "192.168.1.105 searxng.local\n", - "searxng.local", - 8080, - ); - - assert!(validate_declared_endpoint_resolved_addrs("searxng.local", 8080, &addrs).is_ok()); - } - - #[test] - fn test_declared_endpoint_loopback_stays_blocked() { - let addrs = - resolve_from_hosts_file_contents("127.0.0.1 loopback.local\n", "loopback.local", 80); - - let err = - validate_declared_endpoint_resolved_addrs("loopback.local", 80, &addrs).unwrap_err(); - assert!( - err.contains("always-blocked"), - "expected loopback to stay blocked: {err}" - ); - } - - #[test] - fn test_declared_endpoint_link_local_stays_blocked() { - let addrs = resolve_from_hosts_file_contents( - "169.254.169.254 metadata.local\n", - "metadata.local", - 80, - ); - - let err = - validate_declared_endpoint_resolved_addrs("metadata.local", 80, &addrs).unwrap_err(); - assert!( - err.contains("always-blocked"), - "expected link-local to stay blocked: {err}" - ); - } - - #[test] - fn test_declared_endpoint_blocks_control_plane_ports() { - let addrs = - resolve_from_hosts_file_contents("10.0.0.5 kube-api.local\n", "kube-api.local", 6443); - - let err = - validate_declared_endpoint_resolved_addrs("kube-api.local", 6443, &addrs).unwrap_err(); - assert!( - err.contains("blocked control-plane port"), - "expected control-plane port to stay blocked: {err}" - ); - } - - #[test] - fn test_resolve_from_hosts_file_contents_always_blocked_ip_stays_blocked() { - let addrs = - resolve_from_hosts_file_contents("127.0.0.1 loopback.local\n", "loopback.local", 80); - let nets = vec!["127.0.0.0/8".parse::().unwrap()]; - let err = validate_allowed_ips_for_resolved_addrs("loopback.local", 80, &addrs, &nets) - .unwrap_err(); - assert!( - err.contains("always-blocked"), - "expected always-blocked hosts-file resolution to stay blocked: {err}" - ); - } - - #[test] - fn test_resolve_from_hosts_file_contents_returns_empty_without_match() { - let result = - resolve_from_hosts_file_contents("192.168.1.105 searxng.local\n", "missing.local", 80); - assert!(result.is_empty()); - } - - // -- is_host_gateway_alias -- - - #[test] - fn test_is_host_gateway_alias_recognises_known_aliases() { - assert!(is_host_gateway_alias("host.openshell.internal")); - assert!(is_host_gateway_alias("host.containers.internal")); - assert!(is_host_gateway_alias("host.docker.internal")); - } - - #[test] - fn test_is_host_gateway_alias_is_case_insensitive() { - assert!(is_host_gateway_alias("HOST.OPENSHELL.INTERNAL")); - assert!(is_host_gateway_alias("Host.Containers.Internal")); - assert!(is_host_gateway_alias("HOST.DOCKER.INTERNAL")); - } - - #[test] - fn test_is_host_gateway_alias_rejects_unknown_hosts() { - assert!(!is_host_gateway_alias("api.example.com")); - assert!(!is_host_gateway_alias("host.openshell.internal.evil.com")); - assert!(!is_host_gateway_alias("evil.host.openshell.internal")); - assert!(!is_host_gateway_alias("openshell.internal")); - assert!(!is_host_gateway_alias("")); - } - - // -- is_cloud_metadata_ip -- - - #[test] - fn test_is_cloud_metadata_ip_blocks_known_metadata_ip() { - assert!(is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( - 169, 254, 169, 254 - )))); - } - - #[test] - fn test_is_cloud_metadata_ip_allows_other_link_local() { - // The pasta gateway address on this test host — not a metadata IP. - assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( - 169, 254, 1, 2 - )))); - assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( - 169, 254, 0, 1 - )))); - } - - #[test] - fn test_is_cloud_metadata_ip_allows_private_and_public() { - assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( - 10, 0, 0, 1 - )))); - assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new( - 192, 168, 1, 1 - )))); - assert!(!is_cloud_metadata_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); - } - - #[test] - fn test_is_cloud_metadata_ip_blocks_ipv4_mapped_metadata() { - // ::ffff:169.254.169.254 is the IPv4-mapped IPv6 representation of the - // AWS/GCP/Azure IMDS endpoint. is_link_local_ip() recognizes it as - // link-local, so is_cloud_metadata_ip() must also catch it — otherwise - // the trusted-gateway exemption would be granted to the metadata service. - let mapped = Ipv4Addr::new(169, 254, 169, 254).to_ipv6_mapped(); - assert!( - is_cloud_metadata_ip(IpAddr::V6(mapped)), - "::ffff:169.254.169.254 must be recognized as cloud metadata" - ); - } - - #[test] - fn test_is_cloud_metadata_ip_allows_other_ipv4_mapped_link_local() { - // Other IPv4-mapped link-local addresses are NOT metadata. - let mapped = Ipv4Addr::new(169, 254, 1, 2).to_ipv6_mapped(); - assert!( - !is_cloud_metadata_ip(IpAddr::V6(mapped)), - "::ffff:169.254.1.2 should not be flagged as cloud metadata" - ); - } - - // -- detect_trusted_host_gateway -- - - #[test] - fn test_detect_trusted_host_gateway_returns_ip_from_hosts_content() { - // We test the underlying parser directly since detect_trusted_host_gateway - // reads the real /etc/hosts. The production code composes these same primitives. - let contents = "169.254.1.2\thost.openshell.internal host.containers.internal\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); - } - - #[test] - fn test_detect_trusted_host_gateway_ignores_cloud_metadata_ip() { - // Simulate a /etc/hosts where the driver injected the cloud metadata IP — - // this should be caught and suppressed. - let contents = "169.254.169.254\thost.openshell.internal\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254))]); - // is_cloud_metadata_ip should flag it, preventing the exemption. - assert!(is_cloud_metadata_ip(ips[0])); - } - - #[test] - fn test_detect_trusted_host_gateway_no_entry_returns_empty() { - let contents = "127.0.0.1 localhost\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert!(ips.is_empty()); - } - - #[test] - fn test_detect_trusted_host_gateway_rejects_loopback() { - // Loopback is not link-local — must not receive the SSRF exemption. - let ip = IpAddr::V4(Ipv4Addr::LOCALHOST); - assert!(!is_cloud_metadata_ip(ip)); - assert!(!is_link_local_ip(ip)); - // The guard: !link-local → reject. - assert!(!is_link_local_ip(ip)); - } - - #[test] - fn test_detect_trusted_host_gateway_rejects_unspecified() { - // Unspecified (0.0.0.0) is not link-local — must not be trusted. - let ip = IpAddr::V4(Ipv4Addr::UNSPECIFIED); - assert!(!is_cloud_metadata_ip(ip)); - assert!(!is_link_local_ip(ip)); - assert!(!is_link_local_ip(ip)); - } - - #[test] - fn test_detect_trusted_host_gateway_rejects_loopback_v6() { - let ip = IpAddr::V6(Ipv6Addr::LOCALHOST); - assert!(!is_cloud_metadata_ip(ip)); - assert!(!is_link_local_ip(ip)); - } - - #[test] - fn test_detect_trusted_host_gateway_rejects_private_ip() { - // Docker bridge (172.17.0.1) and K8s host gateway (192.168.x.x) are - // RFC 1918 private addresses — not link-local. Before this fix they - // slipped through the old always-blocked guard and received the SSRF - // exemption. The new guard (!is_link_local_ip) rejects them, so - // connections to these hosts fall through to resolve_and_reject_internal(). - for ip in [ - IpAddr::V4(Ipv4Addr::new(172, 17, 0, 1)), - IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), - IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), - ] { - assert!(!is_cloud_metadata_ip(ip), "{ip} should not be metadata"); - assert!(!is_link_local_ip(ip), "{ip} should not be link-local"); - // Guard fires — exemption disabled. - assert!(!is_link_local_ip(ip), "{ip}: guard must reject"); - } - } - - #[test] - fn test_detect_trusted_host_gateway_allows_link_local_non_metadata() { - // 169.254.1.2 (rootless Podman pasta gateway) IS link-local and is - // not a cloud metadata IP — it is the only address class the exemption - // is designed for. - let ip = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); - assert!(!is_cloud_metadata_ip(ip)); - assert!(is_link_local_ip(ip)); - // Guard does NOT fire — this IP is eligible for the exemption. - assert!(is_link_local_ip(ip)); - } - - // -- parse_hosts_file_for_host: multi-entry / duplicate scenarios -- - - #[test] - fn test_parse_hosts_file_single_entry() { - // Normal driver-injected case: exactly one IP for the alias. - let contents = "169.254.1.2\thost.openshell.internal host.containers.internal\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); - } - - #[test] - fn test_parse_hosts_file_duplicate_same_ip_deduplicated() { - // Same IP on two separate lines for the same alias — deduplicated to one. - let contents = "169.254.1.2\thost.openshell.internal\n\ - 169.254.1.2\thost.openshell.internal\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert_eq!( - ips, - vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))], - "identical IPs across lines must be deduplicated" - ); - } - - #[test] - fn test_parse_hosts_file_multiple_distinct_ips() { - // Two distinct IPs for the same alias — both returned, first entry wins - // in detect_trusted_host_gateway(), second would cause mismatch rejection - // in resolve_and_check_trusted_gateway(). - let contents = "169.254.1.2\thost.openshell.internal\n\ - 169.254.1.3\thost.openshell.internal\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert_eq!(ips.len(), 2, "two distinct IPs must both be returned"); - assert_eq!(ips[0], IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))); - assert_eq!(ips[1], IpAddr::V4(Ipv4Addr::new(169, 254, 1, 3))); - } - - #[test] - fn test_parse_hosts_file_first_entry_wins_on_ambiguity() { - // detect_trusted_host_gateway() pins to the first entry via .next(). - // Verify the ordering guarantee: first line wins. - let contents = "169.254.1.3\thost.openshell.internal\n\ - 169.254.1.2\thost.openshell.internal\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert_eq!( - ips[0], - IpAddr::V4(Ipv4Addr::new(169, 254, 1, 3)), - "first line must be first in the returned vec" - ); - } - - #[test] - fn test_parse_hosts_file_ignores_other_aliases_on_same_line() { - // An entry with multiple aliases — only the matching alias counts. - let contents = - "169.254.1.2\thost.containers.internal host.openshell.internal host.docker.internal\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); - // Non-matching aliases on the same line do not produce extra entries. - let ips2 = parse_hosts_file_for_host(contents, "host.docker.internal"); - assert_eq!(ips2, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); - } - - #[test] - fn test_parse_hosts_file_alias_not_present() { - let contents = "127.0.0.1\tlocalhost\n\ - ::1\t\tlocalhost\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert!(ips.is_empty()); - } - - #[test] - fn test_parse_hosts_file_comment_lines_skipped() { - let contents = "# 169.254.1.2 host.openshell.internal\n\ - 169.254.1.2\thost.openshell.internal\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - // Commented-out line must not produce an entry. - assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); - } - - #[test] - fn test_parse_hosts_file_inline_comment_stripped() { - // Anything after '#' on a data line is treated as a comment. - let contents = "169.254.1.2\thost.openshell.internal # injected by driver\n"; - let ips = parse_hosts_file_for_host(contents, "host.openshell.internal"); - assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2))]); - } - - // -- resolve_and_check_trusted_gateway -- - - #[tokio::test] - async fn test_trusted_gateway_allows_link_local_gateway_ip() { - // Simulate the rootless Podman pasta case: host.openshell.internal - // points to a link-local address which is the only path to the host. - let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); - - // We resolve via /etc/hosts (pid=0 falls back to system), so we - // exercise the trusted_gw mismatch / cloud-metadata guards directly - // against a known resolved address. - let addrs = [SocketAddr::new(trusted_gw, 8080)]; - - // Validate the guard logic inline (mirrors resolve_and_check_trusted_gateway). - assert!(!is_cloud_metadata_ip(trusted_gw)); - assert_eq!(addrs[0].ip(), trusted_gw); - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_cloud_metadata_ip() { - let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); - let metadata_ip = IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)); - - // Simulate resolution returning the metadata IP. - let addrs = [SocketAddr::new(metadata_ip, 80)]; - - // Cloud metadata check must fire before the trusted_gw equality check. - let err: Result<(), String> = if is_cloud_metadata_ip(addrs[0].ip()) { - Err(format!( - "host resolves to cloud metadata address {}, connection rejected", - addrs[0].ip() - )) - } else if addrs[0].ip() != trusted_gw { - Err(format!( - "host resolves to {} which does not match trusted host gateway \ - {trusted_gw}, connection rejected", - addrs[0].ip() - )) - } else { - Ok(()) - }; - - assert!(err.is_err()); - assert!( - err.unwrap_err().contains("cloud metadata"), - "expected cloud-metadata rejection" - ); - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_mismatched_ip() { - let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); - let other_ip = IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)); - - let addrs = [SocketAddr::new(other_ip, 8080)]; - - let err: Result<(), String> = if is_cloud_metadata_ip(addrs[0].ip()) { - Err("cloud metadata".to_string()) - } else if addrs[0].ip() != trusted_gw { - Err(format!( - "{} does not match trusted host gateway {trusted_gw}", - addrs[0].ip() - )) - } else { - Ok(()) - }; - - assert!(err.is_err()); - assert!( - err.unwrap_err() - .contains("does not match trusted host gateway"), - "expected mismatch rejection" - ); - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_control_plane_port() { - // Control-plane port check runs before resolution. - let result = resolve_and_check_trusted_gateway( - "host.openshell.internal", - 6443, - IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)), - 0, - ) - .await; - assert!(result.is_err()); - assert!( - result.unwrap_err().contains("blocked control-plane port"), - "expected control-plane port rejection" - ); - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_all_control_plane_ports() { - let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); - for &port in BLOCKED_CONTROL_PLANE_PORTS { - let result = - resolve_and_check_trusted_gateway("host.openshell.internal", port, trusted_gw, 0) - .await; - assert!( - result.is_err(), - "port {port} should be blocked by control-plane guard" - ); - assert!( - result.unwrap_err().contains("blocked control-plane port"), - "expected control-plane rejection for port {port}" - ); - } - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_loopback_as_trusted_gw() { - // Defense-in-depth: even if detect_trusted_host_gateway somehow admitted - // a loopback IP, resolve_and_check_trusted_gateway must reject it. - // Using an IP literal as the host bypasses DNS and gives a deterministic - // resolved address, allowing us to exercise the actual function. - let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST); - let result = resolve_and_check_trusted_gateway("127.0.0.1", 8080, loopback, 0).await; - assert!(result.is_err(), "loopback must be rejected"); - let err = result.unwrap_err(); - assert!( - err.contains("non-link-local"), - "expected non-link-local rejection, got: {err}" - ); - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_unspecified_as_trusted_gw() { - // Defense-in-depth: 0.0.0.0 as trusted_gw must be rejected. - // IP literal resolves to 0.0.0.0 directly, bypassing DNS. - let unspecified = IpAddr::V4(Ipv4Addr::UNSPECIFIED); - let result = resolve_and_check_trusted_gateway("0.0.0.0", 8080, unspecified, 0).await; - assert!(result.is_err(), "unspecified must be rejected"); - let err = result.unwrap_err(); - assert!( - err.contains("non-link-local"), - "expected non-link-local rejection, got: {err}" - ); - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_ip_literal_mismatch() { - // If the requested IP literal doesn't match trusted_gw, the mismatch - // guard fires. This exercises the full resolution→validation path. - let trusted_gw = IpAddr::V4(Ipv4Addr::new(169, 254, 1, 2)); - let other_ip = "10.0.0.1"; // RFC1918, resolves as a literal - let result = resolve_and_check_trusted_gateway(other_ip, 8080, trusted_gw, 0).await; - assert!(result.is_err(), "IP mismatch must be rejected"); - let err = result.unwrap_err(); - assert!( - err.contains("does not match trusted host gateway"), - "expected mismatch rejection, got: {err}" - ); - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_cloud_metadata_literal() { - // Cloud metadata IP as a literal address — must be rejected even when - // it matches trusted_gw (which detect_trusted_host_gateway prevents, - // but this is the defense-in-depth layer). - let metadata = IpAddr::V4(Ipv4Addr::new(169, 254, 169, 254)); - let result = resolve_and_check_trusted_gateway("169.254.169.254", 80, metadata, 0).await; - assert!(result.is_err(), "cloud metadata IP must be rejected"); - let err = result.unwrap_err(); - assert!( - err.contains("cloud metadata"), - "expected cloud-metadata rejection, got: {err}" - ); - } - - #[tokio::test] - async fn test_trusted_gateway_rejects_private_ip_as_trusted_gw() { - // Defense-in-depth: a private RFC 1918 IP (e.g. Docker bridge 172.17.0.1) - // must be rejected even if it somehow matched trusted_gw. - // detect_trusted_host_gateway() already blocks these via !is_link_local_ip(), - // but resolve_and_check_trusted_gateway() must enforce the same invariant. - let docker_bridge = IpAddr::V4(Ipv4Addr::new(172, 17, 0, 1)); - let result = resolve_and_check_trusted_gateway("172.17.0.1", 8080, docker_bridge, 0).await; - assert!(result.is_err(), "private RFC 1918 IP must be rejected"); - let err = result.unwrap_err(); - assert!( - err.contains("non-link-local"), - "expected non-link-local rejection for private IP, got: {err}" - ); - } - - #[tokio::test] - async fn test_rejects_localhost_resolution() { - let result = resolve_and_reject_internal("localhost", 80, 0).await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - err.contains("internal address"), - "expected 'internal address' in error: {err}" - ); - } - - #[tokio::test] - async fn test_rejects_loopback_ip_literal() { - let result = resolve_and_reject_internal("127.0.0.1", 443, 0).await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - err.contains("internal address"), - "expected 'internal address' in error: {err}" - ); - } - - #[tokio::test] - async fn test_rejects_metadata_ip() { - let result = resolve_and_reject_internal("169.254.169.254", 80, 0).await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - err.contains("internal address"), - "expected 'internal address' in error: {err}" - ); - } - - #[tokio::test] - async fn test_dns_failure_returns_error() { - let result = resolve_and_reject_internal("this-host-does-not-exist.invalid", 80, 0).await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - err.contains("DNS resolution failed"), - "expected 'DNS resolution failed' in error: {err}" - ); - } - - #[tokio::test] - async fn inference_interception_applies_router_header_allowlist() { - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use tokio::net::TcpListener; - - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let upstream_addr = listener.local_addr().unwrap(); - let upstream_task = tokio::spawn(async move { - use crate::l7::inference::{ParseResult, try_parse_http_request}; - - let (mut upstream, _) = listener.accept().await.unwrap(); - let mut buf = Vec::new(); - let mut chunk = [0u8; 4096]; - - loop { - let n = upstream.read(&mut chunk).await.unwrap(); - assert!(n > 0, "upstream request closed before request completed"); - buf.extend_from_slice(&chunk[..n]); - - match try_parse_http_request(&buf) { - ParseResult::Complete(_, consumed) => { - upstream - .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok") - .await - .unwrap(); - return String::from_utf8_lossy(&buf[..consumed]).to_string(); - } - ParseResult::Incomplete => continue, - ParseResult::Invalid(reason) => { - panic!("forwarded request should parse cleanly: {reason}"); - } - } - } - }); - - let router = openshell_router::Router::new().unwrap(); - let patterns = crate::l7::inference::default_patterns(); - let ctx = InferenceContext::new( - patterns, - router, - vec![openshell_router::config::ResolvedRoute { - name: "inference.local".to_string(), - endpoint: format!("http://{upstream_addr}"), - model: "meta/llama-3.1-8b-instruct".to_string(), - api_key: "test-api-key".to_string(), - protocols: vec!["openai_chat_completions".to_string()], - auth: openshell_router::config::AuthHeader::Bearer, - default_headers: vec![], - passthrough_headers: vec![ - "openai-organization".to_string(), - "x-model-id".to_string(), - ], - timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, - model_in_path: false, - request_path_override: None, - }], - vec![], - ); - - let body = r#"{"model":"ignored","messages":[{"role":"user","content":"hi"}]}"#; - let request = format!( - "POST /v1/chat/completions HTTP/1.1\r\n\ - Host: inference.local\r\n\ - Content-Type: application/json\r\n\ - OpenAI-Organization: org_123\r\n\ - Authorization: Bearer client-key\r\n\ - Cookie: session=abc\r\n\ - Content-Length: {}\r\n\r\n{}", - body.len(), - body, - ); - - let (client, mut server) = tokio::io::duplex(65536); - let (mut client_read, mut client_write) = tokio::io::split(client); - - let server_task = - tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); - - client_write.write_all(request.as_bytes()).await.unwrap(); - client_write.shutdown().await.unwrap(); - - let mut response = Vec::new(); - client_read.read_to_end(&mut response).await.unwrap(); - let response_text = String::from_utf8_lossy(&response); - assert!(response_text.starts_with("HTTP/1.1 200")); - - let outcome = server_task.await.unwrap().unwrap(); - assert!( - matches!(outcome, InferenceOutcome::Routed), - "expected Routed outcome, got: {outcome:?}" - ); - - let forwarded = upstream_task.await.unwrap(); - let forwarded_lc = forwarded.to_ascii_lowercase(); - assert!(forwarded_lc.contains("openai-organization: org_123")); - assert!(forwarded_lc.contains("authorization: bearer test-api-key")); - assert!(!forwarded_lc.contains("authorization: bearer client-key")); - assert!(!forwarded_lc.contains("cookie:")); - } - - fn streaming_inference_route(endpoint: String) -> openshell_router::config::ResolvedRoute { - openshell_router::config::ResolvedRoute { - name: "inference.local".to_string(), - endpoint, - model: "meta/llama-3.1-8b-instruct".to_string(), - api_key: "test-api-key".to_string(), - protocols: vec!["openai_chat_completions".to_string()], - auth: openshell_router::config::AuthHeader::Bearer, - default_headers: vec![], - passthrough_headers: vec![], - timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, - model_in_path: false, - request_path_override: None, - } - } - - fn embeddings_inference_route(endpoint: String) -> openshell_router::config::ResolvedRoute { - openshell_router::config::ResolvedRoute { - name: "inference.local".to_string(), - endpoint, - model: "text-embedding-3-small".to_string(), - api_key: "test-api-key".to_string(), - protocols: vec!["openai_embeddings".to_string()], - auth: openshell_router::config::AuthHeader::Bearer, - default_headers: vec![], - passthrough_headers: vec![], - timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, - model_in_path: false, - request_path_override: None, - } - } - - /// Embeddings responses are a single buffered JSON object, not an SSE - /// stream. They must be framed with `Content-Length` and must never be sent - /// through the chunked streaming path, whose truncation handlers would - /// append an SSE `proxy_stream_error` frame into the JSON body. - #[tokio::test] - async fn inference_embeddings_served_buffered_with_content_length() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let upstream_addr = listener.local_addr().unwrap(); - let upstream_body = r#"{"object":"list","data":[{"object":"embedding","index":0,"embedding":[0.1,0.2]}],"model":"text-embedding-3-small"}"#; - let upstream_task = tokio::spawn(async move { - let (mut upstream, _) = listener.accept().await.unwrap(); - read_forwarded_inference_request(&mut upstream).await; - // Buffered upstream response with Content-Length (no chunked TE). - let resp = format!( - "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", - upstream_body.len(), - upstream_body, - ); - upstream.write_all(resp.as_bytes()).await.unwrap(); - }); - - let router = openshell_router::Router::new().unwrap(); - let patterns = crate::l7::inference::default_patterns(); - let ctx = InferenceContext::new( - patterns, - router, - vec![embeddings_inference_route(format!( - "http://{upstream_addr}" - ))], - vec![], - ); - - let body = r#"{"model":"text-embedding-3-small","input":"hello"}"#; - let request = format!( - "POST /v1/embeddings HTTP/1.1\r\n\ - Host: inference.local\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\r\n{}", - body.len(), - body, - ); - - let (client, mut server) = tokio::io::duplex(65536); - let (mut client_read, mut client_write) = tokio::io::split(client); - let server_task = - tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); - - client_write.write_all(request.as_bytes()).await.unwrap(); - client_write.shutdown().await.unwrap(); - - let mut response = Vec::new(); - client_read.read_to_end(&mut response).await.unwrap(); - let response = String::from_utf8(response).unwrap(); - - server_task.await.unwrap().unwrap(); - upstream_task.await.unwrap(); - - assert!( - response.starts_with("HTTP/1.1 200 OK\r\n"), - "expected buffered 200 response, got: {response}" - ); - let lower = response.to_ascii_lowercase(); - assert!( - lower.contains("content-length:"), - "embeddings response must be Content-Length framed, got: {response}" - ); - assert!( - !lower.contains("transfer-encoding: chunked"), - "embeddings response must NOT be chunked, got: {response}" - ); - assert!( - !response.contains("proxy_stream_error"), - "embeddings response must not carry an SSE error frame, got: {response}" - ); - assert!( - response.contains(r#""object":"list""#), - "embeddings JSON body must be forwarded intact, got: {response}" - ); - } - - fn model_discovery_inference_route( - endpoint: String, - ) -> openshell_router::config::ResolvedRoute { - openshell_router::config::ResolvedRoute { - name: "inference.local".to_string(), - endpoint, - model: "text-embedding-3-small".to_string(), - api_key: "test-api-key".to_string(), - protocols: vec!["model_discovery".to_string()], - auth: openshell_router::config::AuthHeader::Bearer, - default_headers: vec![], - passthrough_headers: vec![], - timeout: openshell_router::config::DEFAULT_ROUTE_TIMEOUT, - model_in_path: false, - request_path_override: None, - } - } - - /// `GET /v1/models` (model discovery) returns one JSON object — a model - /// list — exactly like embeddings. It must be served buffered with - /// `Content-Length`, never through the chunked streaming path whose - /// truncation handlers would append an SSE `proxy_stream_error` frame into - /// the JSON body. This guards the framing classification for the protocol. - #[tokio::test] - async fn inference_model_discovery_served_buffered_with_content_length() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let upstream_addr = listener.local_addr().unwrap(); - let upstream_body = - r#"{"object":"list","data":[{"id":"text-embedding-3-small","object":"model"}]}"#; - let upstream_task = tokio::spawn(async move { - let (mut upstream, _) = listener.accept().await.unwrap(); - read_forwarded_inference_request(&mut upstream).await; - let resp = format!( - "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", - upstream_body.len(), - upstream_body, - ); - upstream.write_all(resp.as_bytes()).await.unwrap(); - }); - - let router = openshell_router::Router::new().unwrap(); - let patterns = crate::l7::inference::default_patterns(); - let ctx = InferenceContext::new( - patterns, - router, - vec![model_discovery_inference_route(format!( - "http://{upstream_addr}" - ))], - vec![], - ); - - // GET model discovery carries no request body. - let request = "GET /v1/models HTTP/1.1\r\n\ - Host: inference.local\r\n\ - Content-Length: 0\r\n\r\n" - .to_string(); - - let (client, mut server) = tokio::io::duplex(65536); - let (mut client_read, mut client_write) = tokio::io::split(client); - let server_task = - tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); - - client_write.write_all(request.as_bytes()).await.unwrap(); - client_write.shutdown().await.unwrap(); - - let mut response = Vec::new(); - client_read.read_to_end(&mut response).await.unwrap(); - let response = String::from_utf8(response).unwrap(); - - server_task.await.unwrap().unwrap(); - upstream_task.await.unwrap(); - - assert!( - response.starts_with("HTTP/1.1 200 OK\r\n"), - "expected buffered 200 response, got: {response}" - ); - let lower = response.to_ascii_lowercase(); - assert!( - lower.contains("content-length:"), - "model discovery response must be Content-Length framed, got: {response}" - ); - assert!( - !lower.contains("transfer-encoding: chunked"), - "model discovery response must NOT be chunked, got: {response}" - ); - assert!( - !response.contains("proxy_stream_error"), - "model discovery response must not carry an SSE error frame, got: {response}" - ); - assert!( - response.contains(r#""object":"list""#), - "model discovery JSON body must be forwarded intact, got: {response}" - ); - } - - /// `GET /v1/models/{id}` (model discovery glob) must forward the model id in - /// the path through the buffered path with the id intact, never streamed. - #[tokio::test] - async fn inference_model_discovery_glob_path_served_buffered() { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let upstream_addr = listener.local_addr().unwrap(); - let upstream_body = r#"{"id":"gpt-4.1","object":"model"}"#; - let upstream_task = tokio::spawn(async move { - let (mut upstream, _) = listener.accept().await.unwrap(); - let forwarded = read_forwarded_request_line(&mut upstream).await; - let resp = format!( - "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", - upstream_body.len(), - upstream_body, - ); - upstream.write_all(resp.as_bytes()).await.unwrap(); - forwarded - }); - - let router = openshell_router::Router::new().unwrap(); - let patterns = crate::l7::inference::default_patterns(); - let ctx = InferenceContext::new( - patterns, - router, - vec![model_discovery_inference_route(format!( - "http://{upstream_addr}" - ))], - vec![], - ); - - let request = "GET /v1/models/gpt-4.1 HTTP/1.1\r\n\ - Host: inference.local\r\n\ - Content-Length: 0\r\n\r\n" - .to_string(); - let (client, mut server) = tokio::io::duplex(65536); - let (mut client_read, mut client_write) = tokio::io::split(client); - let server_task = - tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); - client_write.write_all(request.as_bytes()).await.unwrap(); - client_write.shutdown().await.unwrap(); - let mut response = Vec::new(); - client_read.read_to_end(&mut response).await.unwrap(); - let response = String::from_utf8(response).unwrap(); - server_task.await.unwrap().unwrap(); - let (method, forwarded_path) = upstream_task.await.unwrap(); - - assert_eq!(method, "GET"); - assert_eq!( - forwarded_path, "/v1/models/gpt-4.1", - "the model id in the glob path must be forwarded intact" - ); - let lower = response.to_ascii_lowercase(); - assert!( - response.starts_with("HTTP/1.1 200 OK\r\n") - && lower.contains("content-length:") - && !lower.contains("transfer-encoding: chunked") - && !response.contains("proxy_stream_error"), - "glob model discovery must be buffered and Content-Length framed, got: {response}" - ); - } - - /// A failed model-discovery upstream must produce a buffered, Content-Length - /// framed JSON error, never a chunked SSE `proxy_stream_error` frame. - #[tokio::test] - async fn inference_model_discovery_error_served_buffered() { - // A port with no listener so the upstream connection is refused. - let dead_addr = { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - drop(listener); - addr - }; - - let router = openshell_router::Router::new().unwrap(); - let patterns = crate::l7::inference::default_patterns(); - let ctx = InferenceContext::new( - patterns, - router, - vec![model_discovery_inference_route(format!( - "http://{dead_addr}" - ))], - vec![], - ); - - let request = "GET /v1/models HTTP/1.1\r\n\ - Host: inference.local\r\n\ - Content-Length: 0\r\n\r\n" - .to_string(); - let (client, mut server) = tokio::io::duplex(65536); - let (mut client_read, mut client_write) = tokio::io::split(client); - let server_task = - tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); - client_write.write_all(request.as_bytes()).await.unwrap(); - client_write.shutdown().await.unwrap(); - let mut response = Vec::new(); - client_read.read_to_end(&mut response).await.unwrap(); - let response = String::from_utf8(response).unwrap(); - server_task.await.unwrap().unwrap(); - - let lower = response.to_ascii_lowercase(); - assert!( - response.starts_with("HTTP/1.1 5"), - "a refused upstream should yield a 5xx, got: {response}" - ); - assert!( - lower.contains("content-length:") - && !lower.contains("transfer-encoding: chunked") - && !response.contains("proxy_stream_error"), - "buffered model-discovery error must be Content-Length framed JSON, got: {response}" - ); - assert!( - response.contains("error"), - "error response should carry a JSON error body, got: {response}" - ); - } - - async fn read_forwarded_inference_request(stream: &mut S) { - use crate::l7::inference::{ParseResult, try_parse_http_request}; - - let mut buf = Vec::new(); - let mut chunk = [0u8; 4096]; - loop { - let n = stream.read(&mut chunk).await.unwrap(); - assert!(n > 0, "upstream request closed before completion"); - buf.extend_from_slice(&chunk[..n]); - - match try_parse_http_request(&buf) { - ParseResult::Complete(_, _) => return, - ParseResult::Incomplete => continue, - ParseResult::Invalid(reason) => { - panic!("forwarded request should parse cleanly: {reason}"); - } - } - } - } - - /// Like [`read_forwarded_inference_request`] but returns the forwarded - /// request line (method, path) so a test can assert the upstream URL path. - async fn read_forwarded_request_line(stream: &mut S) -> (String, String) { - use crate::l7::inference::{ParseResult, try_parse_http_request}; - - let mut buf = Vec::new(); - let mut chunk = [0u8; 4096]; - loop { - let n = stream.read(&mut chunk).await.unwrap(); - assert!(n > 0, "upstream request closed before completion"); - buf.extend_from_slice(&chunk[..n]); - - match try_parse_http_request(&buf) { - ParseResult::Complete(req, _) => return (req.method, req.path), - ParseResult::Incomplete => continue, - ParseResult::Invalid(reason) => { - panic!("forwarded request should parse cleanly: {reason}"); - } - } - } - } - - async fn run_live_streaming_inference(serve_upstream: F) -> String - where - F: FnOnce(TcpStream) -> Fut + Send + 'static, - Fut: Future + Send + 'static, - { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let upstream_addr = listener.local_addr().unwrap(); - let upstream_task = tokio::spawn(async move { - let (mut upstream, _) = listener.accept().await.unwrap(); - read_forwarded_inference_request(&mut upstream).await; - serve_upstream(upstream).await; - }); - - let router = openshell_router::Router::new().unwrap(); - let patterns = crate::l7::inference::default_patterns(); - let ctx = InferenceContext::new( - patterns, - router, - vec![streaming_inference_route(format!("http://{upstream_addr}"))], - vec![], - ); - - let body = r#"{"model":"ignored","messages":[{"role":"user","content":"hi"}]}"#; - let request = format!( - "POST /v1/chat/completions HTTP/1.1\r\n\ - Host: inference.local\r\n\ - Content-Type: application/json\r\n\ - Accept: text/event-stream\r\n\ - Content-Length: {}\r\n\r\n{}", - body.len(), - body, - ); - - let (client, mut server) = tokio::io::duplex(65536); - let (mut client_read, mut client_write) = tokio::io::split(client); - let server_task = - tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); - - client_write.write_all(request.as_bytes()).await.unwrap(); - client_write.shutdown().await.unwrap(); - - let mut response = Vec::new(); - client_read.read_to_end(&mut response).await.unwrap(); - - let outcome = server_task.await.unwrap().unwrap(); - assert!( - matches!(outcome, InferenceOutcome::Routed), - "expected Routed outcome, got: {outcome:?}" - ); - upstream_task.await.unwrap(); - - String::from_utf8(response).unwrap() - } - - fn assert_streaming_sse_error(response: &str, message: &str) { - assert!( - response.starts_with("HTTP/1.1 200 OK\r\n"), - "expected successful streaming response, got: {response}" - ); - assert!( - response - .to_ascii_lowercase() - .contains("transfer-encoding: chunked"), - "expected chunked streaming response, got: {response}" - ); - assert!( - response.contains("\"type\":\"proxy_stream_error\""), - "expected proxy_stream_error SSE event, got: {response}" - ); - assert!( - response.contains(&format!("\"message\":\"{message}\"")), - "expected SSE message {message:?}, got: {response}" - ); - assert!( - response.ends_with("0\r\n\r\n"), - "streaming response must end with chunked terminator, got: {response}" - ); - } - - #[tokio::test] - async fn inference_stream_byte_limit_injects_sse_error() { - let response = run_live_streaming_inference(|mut upstream| async move { - use crate::l7::inference::{format_chunk, format_chunk_terminator}; - - upstream - .write_all( - b"HTTP/1.1 200 OK\r\n\ - Content-Type: text/event-stream\r\n\ - Transfer-Encoding: chunked\r\n\r\n", - ) - .await - .unwrap(); - let body = vec![b'a'; MAX_STREAMING_BODY + 1]; - let _ = upstream.write_all(&format_chunk(&body)).await; - let _ = upstream.write_all(format_chunk_terminator()).await; - }) - .await; - - assert_streaming_sse_error( - &response, - "response truncated: exceeded maximum streaming body size", - ); - } - - #[tokio::test] - async fn inference_stream_upstream_read_error_injects_sse_error() { - let response = run_live_streaming_inference(|mut upstream| async move { - upstream - .write_all( - b"HTTP/1.1 200 OK\r\n\ - Content-Type: text/event-stream\r\n\ - Content-Length: 64\r\n\r\n\ - partial", - ) - .await - .unwrap(); - }) - .await; - - assert!( - response.contains("partial"), - "expected initial upstream bytes before truncation, got: {response}" - ); - assert_streaming_sse_error(&response, "response truncated: upstream read error"); - } - - #[tokio::test] - async fn inference_stream_idle_timeout_injects_sse_error() { - let response = run_live_streaming_inference(|mut upstream| async move { - upstream - .write_all( - b"HTTP/1.1 200 OK\r\n\ - Content-Type: text/event-stream\r\n\ - Transfer-Encoding: chunked\r\n\r\n", - ) - .await - .unwrap(); - tokio::time::sleep(CHUNK_IDLE_TIMEOUT + std::time::Duration::from_millis(50)).await; - }) - .await; - - assert_streaming_sse_error(&response, "response truncated: chunk idle timeout exceeded"); - } - - // -- router_error_to_http -- - - #[test] - fn router_error_route_not_found_maps_to_400() { - let err = openshell_router::RouterError::RouteNotFound("local".into()); - let (status, msg) = router_error_to_http(&err); - assert_eq!(status, 400); - assert_eq!(msg, "no inference route configured"); - // SEC-008: must NOT leak the route hint to sandboxed code - assert!(!msg.contains("local")); - } - - #[test] - fn router_error_no_compatible_route_maps_to_400() { - let err = openshell_router::RouterError::NoCompatibleRoute("anthropic_messages".into()); - let (status, msg) = router_error_to_http(&err); - assert_eq!(status, 400); - assert_eq!(msg, "no compatible inference route available"); - // SEC-008: must NOT leak the protocol name to sandboxed code - assert!(!msg.contains("anthropic_messages")); - } - - #[test] - fn router_error_unauthorized_maps_to_401() { - let err = - openshell_router::RouterError::Unauthorized("bad token from 10.0.0.5:8080".into()); - let (status, msg) = router_error_to_http(&err); - assert_eq!(status, 401); - assert_eq!(msg, "unauthorized"); - // SEC-008: must NOT leak upstream details to sandboxed code - assert!(!msg.contains("10.0.0.5")); - } - - #[test] - fn router_error_upstream_unavailable_maps_to_503() { - let err = openshell_router::RouterError::UpstreamUnavailable( - "connection refused to 10.0.0.5:8080".into(), - ); - let (status, msg) = router_error_to_http(&err); - assert_eq!(status, 503); - assert_eq!(msg, "inference service unavailable"); - // SEC-008: must NOT leak upstream address to sandboxed code - assert!(!msg.contains("10.0.0.5")); - } - - #[test] - fn router_error_upstream_protocol_maps_to_502() { - let err = openshell_router::RouterError::UpstreamProtocol( - "TLS handshake failed for nim.internal.svc:443".into(), - ); - let (status, msg) = router_error_to_http(&err); - assert_eq!(status, 502); - assert_eq!(msg, "inference service error"); - // SEC-008: must NOT leak internal hostnames to sandboxed code - assert!(!msg.contains("nim.internal")); - } - - #[test] - fn router_error_internal_maps_to_502() { - let err = openshell_router::RouterError::Internal( - "failed to read /etc/openshell/routes.json".into(), - ); - let (status, msg) = router_error_to_http(&err); - assert_eq!(status, 502); - assert_eq!(msg, "inference service error"); - // SEC-008: must NOT leak file paths to sandboxed code - assert!(!msg.contains("/etc/openshell")); - } - - #[test] - fn sanitize_response_headers_strips_hop_by_hop() { - let headers = vec![ - ("transfer-encoding".to_string(), "chunked".to_string()), - ("content-length".to_string(), "128".to_string()), - ("connection".to_string(), "keep-alive".to_string()), - ("content-type".to_string(), "text/event-stream".to_string()), - ("cache-control".to_string(), "no-cache".to_string()), - ]; - - let kept = sanitize_inference_response_headers(headers); - - assert!( - kept.iter() - .all(|(k, _)| !k.eq_ignore_ascii_case("transfer-encoding")), - "transfer-encoding should be stripped" - ); - assert!( - kept.iter() - .all(|(k, _)| !k.eq_ignore_ascii_case("content-length")), - "content-length should be stripped" - ); - assert!( - kept.iter() - .all(|(k, _)| !k.eq_ignore_ascii_case("connection")), - "connection should be stripped" - ); - assert!( - kept.iter() - .any(|(k, _)| k.eq_ignore_ascii_case("content-type")), - "content-type should be preserved" - ); - assert!( - kept.iter() - .any(|(k, _)| k.eq_ignore_ascii_case("cache-control")), - "cache-control should be preserved" - ); - } - - // -- is_always_blocked_ip -- - - #[test] - fn test_always_blocked_loopback_v4() { - assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::LOCALHOST))); - assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( - 127, 0, 0, 2 - )))); - } - - #[test] - fn test_always_blocked_link_local_v4() { - assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( - 169, 254, 169, 254 - )))); - assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( - 169, 254, 0, 1 - )))); - } - - #[test] - fn test_always_blocked_loopback_v6() { - assert!(is_always_blocked_ip(IpAddr::V6(Ipv6Addr::LOCALHOST))); - } - - #[test] - fn test_always_blocked_link_local_v6() { - assert!(is_always_blocked_ip(IpAddr::V6(Ipv6Addr::new( - 0xfe80, 0, 0, 0, 0, 0, 0, 1 - )))); - } - - #[test] - fn test_always_blocked_ipv4_unspecified() { - assert!(is_always_blocked_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED))); - } - - #[test] - fn test_always_blocked_ipv6_unspecified() { - assert!(is_always_blocked_ip(IpAddr::V6(Ipv6Addr::UNSPECIFIED))); - } - - #[test] - fn test_always_blocked_ipv4_mapped_v6_loopback() { - let v6 = Ipv4Addr::LOCALHOST.to_ipv6_mapped(); - assert!(is_always_blocked_ip(IpAddr::V6(v6))); - } - - #[test] - fn test_always_blocked_ipv4_mapped_v6_link_local() { - let v6 = Ipv4Addr::new(169, 254, 169, 254).to_ipv6_mapped(); - assert!(is_always_blocked_ip(IpAddr::V6(v6))); - } - - #[test] - fn test_always_blocked_allows_rfc1918() { - // RFC 1918 addresses should NOT be always-blocked (they're allowed - // when allowed_ips is configured) - assert!(!is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( - 10, 0, 0, 1 - )))); - assert!(!is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( - 172, 16, 0, 1 - )))); - assert!(!is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new( - 192, 168, 0, 1 - )))); - } - - #[test] - fn test_always_blocked_allows_public() { - assert!(!is_always_blocked_ip(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))); - assert!(!is_always_blocked_ip(IpAddr::V6(Ipv6Addr::new( - 0x2001, 0x4860, 0x4860, 0, 0, 0, 0, 0x8888 - )))); - } - - // -- parse_allowed_ips -- - - #[test] - fn test_parse_cidr_notation() { - let raw = vec!["10.0.5.0/24".to_string()]; - let nets = parse_allowed_ips(&raw).unwrap(); - assert_eq!(nets.len(), 1); - assert!(nets[0].contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 5, 1)))); - assert!(!nets[0].contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 6, 1)))); - } - - #[test] - fn test_parse_exact_ip() { - let raw = vec!["10.0.5.20".to_string()]; - let nets = parse_allowed_ips(&raw).unwrap(); - assert_eq!(nets.len(), 1); - assert!(nets[0].contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 5, 20)))); - assert!(!nets[0].contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 5, 21)))); - } - - #[test] - fn test_parse_multiple_entries() { - let raw = vec![ - "10.0.0.0/8".to_string(), - "172.16.0.0/12".to_string(), - "192.168.1.1".to_string(), - ]; - let nets = parse_allowed_ips(&raw).unwrap(); - assert_eq!(nets.len(), 3); - } - - #[test] - fn test_parse_invalid_entry_errors() { - let raw = vec!["not-an-ip".to_string()]; - let result = parse_allowed_ips(&raw); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("invalid CIDR/IP")); - } - - #[test] - fn test_parse_mixed_valid_invalid_errors() { - let raw = vec!["10.0.5.0/24".to_string(), "garbage".to_string()]; - let result = parse_allowed_ips(&raw); - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_resolve_check_allowed_ips_blocks_loopback() { - // Construct nets directly (parse_allowed_ips now rejects always-blocked). - let nets = vec!["127.0.0.0/8".parse::().unwrap()]; - let result = resolve_and_check_allowed_ips("127.0.0.1", 80, &nets, 0).await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - err.contains("always-blocked"), - "expected 'always-blocked' in error: {err}" - ); - } - - #[tokio::test] - async fn test_resolve_check_allowed_ips_blocks_metadata() { - // Construct nets directly (parse_allowed_ips now rejects always-blocked). - let nets = vec!["169.254.0.0/16".parse::().unwrap()]; - let result = resolve_and_check_allowed_ips("169.254.169.254", 80, &nets, 0).await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - err.contains("always-blocked"), - "expected 'always-blocked' in error: {err}" - ); - } - - #[tokio::test] - async fn test_resolve_check_allowed_ips_blocks_unspecified() { - // Construct nets directly (parse_allowed_ips now rejects always-blocked). - let nets = vec!["0.0.0.0/0".parse::().unwrap()]; - let result = resolve_and_check_allowed_ips("0.0.0.0", 80, &nets, 0).await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - err.contains("always-blocked"), - "expected 'always-blocked' in error: {err}" - ); - } - - #[tokio::test] - async fn test_resolve_check_allowed_ips_rejects_outside_allowlist() { - // 8.8.8.8 resolves to a public IP which is NOT in 10.0.0.0/8 - let nets = parse_allowed_ips(&["10.0.0.0/8".to_string()]).unwrap(); - let result = resolve_and_check_allowed_ips("dns.google", 443, &nets, 0).await; - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - err.contains("not in allowed_ips"), - "expected 'not in allowed_ips' in error: {err}" - ); - } - - // --- SEC-005: CIDR breadth warning and control-plane port blocklist --- - - #[tokio::test] - async fn test_resolve_check_allowed_ips_blocks_control_plane_ports() { - // Use a public CIDR (parse_allowed_ips now rejects 0.0.0.0/0). - let nets = parse_allowed_ips(&["8.8.8.0/24".to_string()]).unwrap(); - // K8s API server port - let result = resolve_and_check_allowed_ips("8.8.8.8", 6443, &nets, 0).await; - assert!(result.is_err()); - assert!(result.unwrap_err().contains("blocked control-plane port")); - - // etcd client port - let result = resolve_and_check_allowed_ips("8.8.8.8", 2379, &nets, 0).await; - assert!(result.is_err()); - assert!(result.unwrap_err().contains("blocked control-plane port")); - - // kubelet API port - let result = resolve_and_check_allowed_ips("8.8.8.8", 10250, &nets, 0).await; - assert!(result.is_err()); - assert!(result.unwrap_err().contains("blocked control-plane port")); - } - - #[tokio::test] - async fn test_resolve_check_allowed_ips_allows_non_control_plane_ports() { - // Port 443 should not be blocked by the control-plane port list - let nets = parse_allowed_ips(&["8.8.8.0/24".to_string()]).unwrap(); - let result = resolve_and_check_allowed_ips("8.8.8.8", 443, &nets, 0).await; - assert!(result.is_ok()); - } - - #[test] - fn test_parse_allowed_ips_broad_cidr_is_accepted() { - // Broad CIDRs are accepted (just warned about) -- design trade-off - let result = parse_allowed_ips(&["10.0.0.0/8".to_string()]); - assert!(result.is_ok()); - } - - // --- parse_allowed_ips: always-blocked rejection tests --- - - #[test] - fn test_parse_allowed_ips_rejects_loopback_cidr() { - let result = parse_allowed_ips(&["127.0.0.0/8".to_string()]); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("always-blocked")); - } - - #[test] - fn test_parse_allowed_ips_rejects_link_local_cidr() { - let result = parse_allowed_ips(&["169.254.0.0/16".to_string()]); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("always-blocked")); - } - - #[test] - fn test_parse_allowed_ips_rejects_unspecified() { - let result = parse_allowed_ips(&["0.0.0.0".to_string()]); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("always-blocked")); - } - - #[test] - fn test_parse_allowed_ips_rejects_single_loopback_ip() { - let result = parse_allowed_ips(&["127.0.0.1".to_string()]); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("always-blocked")); - } - - #[test] - fn test_parse_allowed_ips_rejects_single_metadata_ip() { - let result = parse_allowed_ips(&["169.254.169.254".to_string()]); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("always-blocked")); - } - - #[test] - fn test_parse_allowed_ips_rejects_wildcard_cidr() { - let result = parse_allowed_ips(&["0.0.0.0/0".to_string()]); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("always-blocked")); - } - - #[test] - fn test_parse_allowed_ips_mixed_valid_and_blocked() { - // A blocked entry taints the whole batch. - let result = parse_allowed_ips(&["10.0.5.0/24".to_string(), "127.0.0.1".to_string()]); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("always-blocked")); - } - - #[test] - fn test_parse_allowed_ips_accepts_rfc1918() { - let result = parse_allowed_ips(&["10.0.5.0/24".to_string(), "192.168.1.0/24".to_string()]); - assert!(result.is_ok()); - } - - // --- implicit_allowed_ips_for_ip_host: always-blocked skip tests --- - - #[test] - fn test_implicit_allowed_ips_skips_loopback() { - let result = implicit_allowed_ips_for_ip_host("127.0.0.1"); - assert!(result.is_empty()); - } - - #[test] - fn test_implicit_allowed_ips_skips_link_local() { - let result = implicit_allowed_ips_for_ip_host("169.254.169.254"); - assert!(result.is_empty()); - } - - #[test] - fn test_implicit_allowed_ips_skips_unspecified() { - let result = implicit_allowed_ips_for_ip_host("0.0.0.0"); - assert!(result.is_empty()); - } - - #[test] - fn test_implicit_allowed_ips_allows_rfc1918() { - let result = implicit_allowed_ips_for_ip_host("10.0.5.20"); - assert_eq!(result, vec!["10.0.5.20"]); - } - - // --- extract_host_from_uri tests --- - - #[test] - fn test_extract_host_from_http_uri() { - assert_eq!( - extract_host_from_uri("http://example.com/path"), - "example.com" - ); - } - - #[test] - fn test_extract_host_from_https_uri() { - assert_eq!( - extract_host_from_uri("https://api.openai.com/v1/chat/completions"), - "api.openai.com" - ); - } - - #[test] - fn test_extract_host_from_uri_with_port() { - assert_eq!( - extract_host_from_uri("http://example.com:8080/path"), - "example.com" - ); - } - - #[test] - fn test_extract_host_from_uri_ipv6() { - assert_eq!(extract_host_from_uri("http://[::1]:8080/path"), "[::1]"); - } - - #[test] - fn test_extract_host_from_uri_no_path() { - assert_eq!(extract_host_from_uri("http://example.com"), "example.com"); - } - - #[test] - fn test_extract_host_from_uri_empty() { - assert_eq!(extract_host_from_uri(""), "unknown"); - } - - #[test] - fn test_extract_host_from_uri_malformed() { - // Gracefully handles garbage input - let result = extract_host_from_uri("not-a-uri"); - assert!(!result.is_empty()); - } - - // --- parse_proxy_uri tests --- - - #[test] - fn test_parse_proxy_uri_standard() { - let (scheme, host, port, path) = - parse_proxy_uri("http://10.86.8.223:8000/screenshot/").unwrap(); - assert_eq!(scheme, "http"); - assert_eq!(host, "10.86.8.223"); - assert_eq!(port, 8000); - assert_eq!(path, "/screenshot/"); - } - - #[test] - fn test_parse_proxy_uri_default_port() { - let (scheme, host, port, path) = parse_proxy_uri("http://example.com/path").unwrap(); - assert_eq!(scheme, "http"); - assert_eq!(host, "example.com"); - assert_eq!(port, 80); - assert_eq!(path, "/path"); - } - - #[test] - fn test_parse_proxy_uri_https_default_port() { - let (scheme, host, port, path) = - parse_proxy_uri("https://api.example.com/v1/chat").unwrap(); - assert_eq!(scheme, "https"); - assert_eq!(host, "api.example.com"); - assert_eq!(port, 443); - assert_eq!(path, "/v1/chat"); - } - - #[test] - fn test_parse_proxy_uri_missing_path() { - let (_, host, port, path) = parse_proxy_uri("http://10.0.0.1:9090").unwrap(); - assert_eq!(host, "10.0.0.1"); - assert_eq!(port, 9090); - assert_eq!(path, "/"); - } - - #[test] - fn test_parse_proxy_uri_with_query() { - let (_, _, _, path) = parse_proxy_uri("http://host:80/api?key=val&foo=bar").unwrap(); - assert_eq!(path, "/api?key=val&foo=bar"); - } - - #[test] - fn test_parse_proxy_uri_ipv6() { - let (_, host, port, path) = parse_proxy_uri("http://[::1]:8080/test").unwrap(); - assert_eq!(host, "::1"); - assert_eq!(port, 8080); - assert_eq!(path, "/test"); - } - - #[test] - fn test_parse_proxy_uri_ipv6_default_port() { - let (_, host, port, path) = parse_proxy_uri("http://[fe80::1]/path").unwrap(); - assert_eq!(host, "fe80::1"); - assert_eq!(port, 80); - assert_eq!(path, "/path"); - } - - #[test] - fn test_parse_proxy_uri_missing_scheme() { - let result = parse_proxy_uri("example.com/path"); - assert!(result.is_err()); - } - - #[test] - fn test_parse_proxy_uri_empty_host() { - let result = parse_proxy_uri("http:///path"); - assert!(result.is_err()); - } - - // --- rewrite_forward_request tests --- - - #[tokio::test] - async fn forward_proxy_injects_token_grant_before_rewriting_request() { - let (ctx, fixture) = forward_token_grant_context(Ok("grant-token")); - let raw = b"GET http://api.example.test:8080/v1/projects HTTP/1.1\r\nHost: api.example.test:8080\r\nAuthorization: Bearer stale-token\r\nConnection: close\r\n\r\n".to_vec(); - - let with_token = inject_token_grant_for_forward_request("GET", "/v1/projects", raw, &ctx) - .await - .expect("forward token grant should inject"); - let rewritten = - rewrite_forward_request(&with_token, with_token.len(), "/v1/projects", None, false) - .expect("forward request should rewrite"); - let rewritten = String::from_utf8_lossy(&rewritten); - - assert!(rewritten.starts_with("GET /v1/projects HTTP/1.1\r\n")); - assert!(rewritten.contains("Authorization: Bearer grant-token\r\n")); - assert!(!rewritten.contains("stale-token")); - assert_eq!(authorization_header_count(&rewritten), 1); - fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); - } - - #[tokio::test] - async fn forward_proxy_token_grant_failure_returns_error_before_rewrite() { - let (ctx, fixture) = forward_token_grant_context(Err("oauth unavailable")); - let raw = b"GET http://api.example.test:8080/v1/projects HTTP/1.1\r\nHost: api.example.test:8080\r\nConnection: close\r\n\r\n".to_vec(); - - let err = inject_token_grant_for_forward_request("GET", "/v1/projects", raw, &ctx) - .await - .expect_err("forward token grant failure should stop request rewriting"); - - assert!(err.to_string().contains("Token grant failed")); - assert!(err.to_string().contains("oauth unavailable")); - fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); - } - - #[test] - fn test_rewrite_get_request() { - let raw = - b"GET http://10.0.0.1:8000/api HTTP/1.1\r\nHost: 10.0.0.1:8000\r\nAccept: */*\r\n\r\n"; - let result = - rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); - let result_str = String::from_utf8_lossy(&result); - assert!(result_str.starts_with("GET /api HTTP/1.1\r\n")); - assert!(result_str.contains("Host: 10.0.0.1:8000")); - assert!(result_str.contains("Connection: close")); - assert!(result_str.contains("Via: 1.1 openshell-sandbox")); - } - - #[test] - fn test_rewrite_strips_proxy_headers() { - let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nProxy-Authorization: Basic abc\r\nProxy-Connection: keep-alive\r\nAccept: */*\r\n\r\n"; - let result = - rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); - let result_str = String::from_utf8_lossy(&result); - assert!( - !result_str - .to_ascii_lowercase() - .contains("proxy-authorization") - ); - assert!(!result_str.to_ascii_lowercase().contains("proxy-connection")); - assert!(result_str.contains("Accept: */*")); - } - - #[test] - fn test_rewrite_replaces_connection_header() { - let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nConnection: keep-alive\r\n\r\n"; - let result = - rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); - let result_str = String::from_utf8_lossy(&result); - assert!(result_str.contains("Connection: close")); - assert!(!result_str.contains("keep-alive")); - } - - #[test] - fn test_rewrite_preserves_body_overflow() { - let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 13\r\n\r\n{\"key\":\"val\"}"; - let result = - rewrite_forward_request(raw, raw.len(), "/api", None, false).expect("should succeed"); - let result_str = String::from_utf8_lossy(&result); - assert!(result_str.contains("{\"key\":\"val\"}")); - assert!(result_str.contains("POST /api HTTP/1.1")); - } - - #[test] - fn test_rewrite_preserves_existing_via() { - let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nVia: 1.0 upstream\r\n\r\n"; - let result = - rewrite_forward_request(raw, raw.len(), "/p", None, false).expect("should succeed"); - let result_str = String::from_utf8_lossy(&result); - assert!(result_str.contains("Via: 1.0 upstream")); - // Should not add a second Via header - assert!(!result_str.contains("Via: 1.1 openshell-sandbox")); - } - - #[test] - fn test_rewrite_forward_request_uses_canonical_path_on_the_wire() { - // Regression: the forward-proxy caller must canonicalize first and - // then pass the canonical form to rewrite_forward_request so that - // OPA's policy evaluation and the bytes dispatched to the upstream - // agree. Prior to this guarantee, OPA saw the canonical form while - // the upstream re-normalized the raw path independently, re-opening - // the parser-differential this PR closes. - let raw = b"GET http://host/public/../secret HTTP/1.1\r\nHost: host\r\n\r\n"; - let (canon, _) = crate::l7::path::canonicalize_request_target( - "/public/../secret", - &crate::l7::path::CanonicalizeOptions::default(), - ) - .expect("canonicalization should succeed for the attack payload"); - assert_eq!(canon.path, "/secret"); - - let rewritten = rewrite_forward_request(raw, raw.len(), &canon.path, None, false) - .expect("rewrite_forward_request should succeed"); - let rewritten_str = String::from_utf8_lossy(&rewritten); - assert!( - rewritten_str.starts_with("GET /secret HTTP/1.1\r\n"), - "outbound request line must use canonical path, got: {rewritten_str:?}" - ); - assert!( - !rewritten_str.contains(".."), - "outbound bytes must not leak the pre-canonical form, got: {rewritten_str:?}" - ); - } - - #[test] - fn test_rewrite_forward_request_preserves_canonical_query_on_the_wire() { - let raw = b"GET http://host/public/../graphql?query=query+Viewer+%7B+viewer+%7B+login+%7D+%7D HTTP/1.1\r\nHost: host\r\n\r\n"; - let (canon, raw_query) = crate::l7::path::canonicalize_request_target( - "/public/../graphql?query=query+Viewer+%7B+viewer+%7B+login+%7D+%7D", - &crate::l7::path::CanonicalizeOptions::default(), - ) - .expect("canonicalization should preserve query separately"); - let upstream_target = match raw_query.as_deref() { - Some(raw_query) if !raw_query.is_empty() => format!("{}?{raw_query}", canon.path), - _ => canon.path, - }; - - let rewritten = rewrite_forward_request(raw, raw.len(), &upstream_target, None, false) - .expect("rewrite_forward_request should succeed"); - let rewritten_str = String::from_utf8_lossy(&rewritten); - assert!( - rewritten_str.starts_with( - "GET /graphql?query=query+Viewer+%7B+viewer+%7B+login+%7D+%7D HTTP/1.1\r\n" - ), - "outbound request line must preserve canonical query, got: {rewritten_str:?}" - ); - } - - #[test] - fn test_rewrite_resolves_placeholder_auth_headers() { - let (_, resolver) = SecretResolver::from_provider_env( - [("ANTHROPIC_API_KEY".to_string(), "sk-test".to_string())] - .into_iter() - .collect(), - ); - let raw = b"GET http://host/p HTTP/1.1\r\nHost: host\r\nAuthorization: Bearer openshell:resolve:env:ANTHROPIC_API_KEY\r\n\r\n"; - let result = rewrite_forward_request(raw, raw.len(), "/p", resolver.as_ref(), false) - .expect("should succeed"); - let result_str = String::from_utf8_lossy(&result); - assert!(result_str.contains("Authorization: Bearer sk-test")); - assert!(!result_str.contains("openshell:resolve:env:ANTHROPIC_API_KEY")); - } - - #[tokio::test] - async fn forward_relay_rewrites_urlencoded_body_alias_from_initial_read() { - let (_, resolver) = SecretResolver::from_provider_env( - [("API_TOKEN".to_string(), "provider-real-token".to_string())] - .into_iter() - .collect(), - ); - let resolver = resolver.expect("resolver"); - let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; - let body = format!("token={alias}&channel=C123"); - let raw = format!( - "POST http://api.example.com/api/messages HTTP/1.1\r\n\ - Host: api.example.com\r\n\ - Authorization: Bearer {alias}\r\n\ - Content-Type: application/x-www-form-urlencoded\r\n\ - Content-Length: {}\r\n\r\n{}", - body.len(), - body - ); - - let forwarded = relay_forward_request_and_capture( - "POST", - "/api/messages", - raw.as_bytes(), - Some(&resolver), - true, - ) - .await - .expect("forward relay should rewrite credentials"); - - let expected_body = "token=provider-real-token&channel=C123"; - assert!(forwarded.starts_with("POST /api/messages HTTP/1.1\r\n")); - assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); - assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); - assert!(forwarded.ends_with(expected_body)); - assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); - } - - #[tokio::test] - async fn forward_relay_rewrites_urlencoded_canonical_body_from_initial_read() { - let (_, resolver) = SecretResolver::from_provider_env( - [("API_TOKEN".to_string(), "provider-real-token".to_string())] - .into_iter() - .collect(), - ); - let resolver = resolver.expect("resolver"); - let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; - let body = "token=openshell%3Aresolve%3Aenv%3AAPI_TOKEN&channel=C123"; - let raw = format!( - "POST http://api.example.com/api/messages HTTP/1.1\r\n\ - Host: api.example.com\r\n\ - Authorization: Bearer {alias}\r\n\ - Content-Type: application/x-www-form-urlencoded\r\n\ - Content-Length: {}\r\n\r\n{}", - body.len(), - body - ); - - let forwarded = relay_forward_request_and_capture( - "POST", - "/api/messages", - raw.as_bytes(), - Some(&resolver), - true, - ) - .await - .expect("forward relay should rewrite credentials"); - - let expected_body = "token=provider-real-token&channel=C123"; - assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); - assert!(forwarded.contains(&format!("Content-Length: {}\r\n", expected_body.len()))); - assert!(forwarded.ends_with(expected_body)); - assert!(!forwarded.contains("openshell%3Aresolve%3Aenv%3AAPI_TOKEN")); - assert!(!forwarded.contains("openshell:resolve:env:API_TOKEN")); - } - - #[tokio::test] - async fn forward_relay_unresolved_body_placeholder_fails_before_upstream_write() { - let (_, resolver) = SecretResolver::from_provider_env( - [("API_TOKEN".to_string(), "provider-real-token".to_string())] - .into_iter() - .collect(), - ); - let resolver = resolver.expect("resolver"); - let alias = "provider-OPENSHELL-RESOLVE-ENV-API_TOKEN"; - let body = "token=provider-OPENSHELL-RESOLVE-ENV-MISSING_TOKEN"; - let raw = format!( - "POST http://api.example.com/api/messages HTTP/1.1\r\n\ - Host: api.example.com\r\n\ - Authorization: Bearer {alias}\r\n\ - Content-Type: application/x-www-form-urlencoded\r\n\ - Content-Length: {}\r\n\r\n{}", - body.len(), - body - ); - let guard = forward_test_guard(); - let rewritten = rewrite_forward_request( - raw.as_bytes(), - raw.len(), - "/api/messages", - Some(&resolver), - true, - ) - .expect("header rewrite should defer body overflow to body rewriter"); - let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - - let err = relay_rewritten_forward_request( - "POST", - "/api/messages", - rewritten, - &mut proxy_to_client, - &mut proxy_to_upstream, - ForwardRelayOptions { - generation_guard: &guard, - websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, - secret_resolver: Some(&resolver), - request_body_credential_rewrite: true, - }, - ) - .await - .expect_err("unresolved body placeholder should fail closed"); - - assert!(!err.to_string().contains("provider-real-token")); - assert!(!err.to_string().contains("MISSING_TOKEN")); - drop(proxy_to_upstream); - let mut forwarded = Vec::new(); - upstream_side.read_to_end(&mut forwarded).await.unwrap(); - assert!( - forwarded.is_empty(), - "failed forward body rewrite must not reach upstream" - ); - } - - #[test] - fn test_forward_rewrite_preserves_websocket_upgrade_connection_header() { - let raw = "GET http://gateway.example.test/ws HTTP/1.1\r\n\ - Host: gateway.example.test\r\n\ - Upgrade: websocket\r\n\ - Connection: keep-alive, Upgrade\r\n\ - Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\ - Sec-WebSocket-Extensions: permessage-deflate; client_no_context_takeover\r\n\ - Sec-WebSocket-Version: 13\r\n\r\n"; - - let result = rewrite_forward_request(raw.as_bytes(), raw.len(), "/ws", None, false) - .expect("websocket forward rewrite should succeed"); - let result_str = String::from_utf8_lossy(&result); - - assert!(result_str.starts_with("GET /ws HTTP/1.1\r\n")); - assert!(result_str.contains("Connection: keep-alive, Upgrade\r\n")); - assert!( - !result_str.contains("Connection: close\r\n"), - "websocket forward proxy must not strip the upgrade token" - ); - } - - #[tokio::test] - async fn test_forward_relay_guard_blocks_stale_generation_before_upstream_write() { - let policy = include_str!("../data/sandbox-policy.rego"); - let policy_data = "network_policies: {}\n"; - let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); - let guard = engine - .generation_guard(engine.current_generation()) - .unwrap(); - engine.reload(policy, policy_data).unwrap(); - - let raw = b"GET http://host/api HTTP/1.1\r\nHost: host\r\n\r\n"; - let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) - .expect("rewrite should succeed"); - let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - - let result = relay_rewritten_forward_request( - "GET", - "/api", - rewritten, - &mut proxy_to_client, - &mut proxy_to_upstream, - ForwardRelayOptions { - generation_guard: &guard, - websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, - secret_resolver: None, - request_body_credential_rewrite: false, - }, - ) - .await; - assert!( - result.is_err(), - "stale generation must stop forward relay before upstream write" - ); - - drop(proxy_to_upstream); - let mut forwarded = Vec::new(); - upstream_side.read_to_end(&mut forwarded).await.unwrap(); - assert!( - forwarded.is_empty(), - "stale forward request bytes must not reach upstream" - ); - } - - #[tokio::test] - async fn test_forward_relay_rejects_cl_te_smuggling_before_upstream_write() { - let policy = include_str!("../data/sandbox-policy.rego"); - let policy_data = "network_policies: {}\n"; - let engine = OpaEngine::from_strings(policy, policy_data).unwrap(); - let guard = engine - .generation_guard(engine.current_generation()) - .unwrap(); - - let raw = b"POST http://host/api HTTP/1.1\r\nHost: host\r\nContent-Length: 4\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n"; - let rewritten = rewrite_forward_request(raw, raw.len(), "/api", None, false) - .expect("rewrite should succeed"); - let (mut proxy_to_upstream, mut upstream_side) = tokio::io::duplex(8192); - let (mut _app_side, mut proxy_to_client) = tokio::io::duplex(8192); - - let result = relay_rewritten_forward_request( - "POST", - "/api", - rewritten, - &mut proxy_to_client, - &mut proxy_to_upstream, - ForwardRelayOptions { - generation_guard: &guard, - websocket_extensions: crate::l7::rest::WebSocketExtensionMode::Preserve, - secret_resolver: None, - request_body_credential_rewrite: false, - }, - ) - .await; - assert!(result.is_err(), "forward relay must reject CL/TE ambiguity"); - - drop(proxy_to_upstream); - let mut forwarded = Vec::new(); - upstream_side.read_to_end(&mut forwarded).await.unwrap(); - assert!( - forwarded.is_empty(), - "smuggled forward request bytes must not reach upstream" - ); - } - - // --- Forward proxy SSRF defence tests --- - // - // The forward proxy handler uses the same SSRF logic as the CONNECT path: - // - No allowed_ips: resolve_and_reject_internal blocks private IPs, allows public. - // - With allowed_ips: resolve_and_check_allowed_ips validates against allowlist. - // - // These tests document that contract for the forward proxy path specifically. - - #[tokio::test] - async fn test_forward_public_ip_allowed_without_allowed_ips() { - // Public IPs (e.g. dns.google -> 8.8.8.8) should pass through - // resolve_and_reject_internal without needing allowed_ips. - let result = resolve_and_reject_internal("dns.google", 80, 0).await; - assert!( - result.is_ok(), - "Public IP should be allowed without allowed_ips: {result:?}" - ); - let addrs = result.unwrap(); - assert!(!addrs.is_empty(), "Should resolve to at least one address"); - // All resolved addresses should be public. - for addr in &addrs { - assert!( - !is_internal_ip(addr.ip()), - "dns.google should resolve to public IPs, got {}", - addr.ip() - ); - } - } - - #[tokio::test] - async fn test_forward_private_ip_rejected_without_allowed_ips() { - // Private IP literals should be rejected by resolve_and_reject_internal. - let result = resolve_and_reject_internal("10.0.0.1", 80, 0).await; - assert!( - result.is_err(), - "Private IP should be rejected without allowed_ips" - ); - let err = result.unwrap_err(); - assert!( - err.contains("internal address"), - "expected 'internal address' in error: {err}" - ); - } - - #[tokio::test] - async fn test_forward_private_ip_accepted_with_allowed_ips() { - // Private IP with matching allowed_ips should pass through. - let nets = parse_allowed_ips(&["10.0.0.0/8".to_string()]).unwrap(); - let result = resolve_and_check_allowed_ips("10.0.0.1", 80, &nets, 0).await; - assert!( - result.is_ok(), - "Private IP with matching allowed_ips should be accepted: {result:?}" - ); - } - - #[tokio::test] - async fn test_forward_private_ip_rejected_with_wrong_allowed_ips() { - // Private IP not in allowed_ips should be rejected. - let nets = parse_allowed_ips(&["192.168.0.0/16".to_string()]).unwrap(); - let result = resolve_and_check_allowed_ips("10.0.0.1", 80, &nets, 0).await; - assert!( - result.is_err(), - "Private IP not in allowed_ips should be rejected" - ); - let err = result.unwrap_err(); - assert!( - err.contains("not in allowed_ips"), - "expected 'not in allowed_ips' in error: {err}" - ); - } - - #[tokio::test] - async fn test_forward_loopback_always_blocked_even_with_allowed_ips() { - // Loopback addresses are always blocked, even if in allowed_ips. - // Construct nets directly (parse_allowed_ips now rejects always-blocked). - let nets = vec!["127.0.0.0/8".parse::().unwrap()]; - let result = resolve_and_check_allowed_ips("127.0.0.1", 80, &nets, 0).await; - assert!(result.is_err(), "Loopback should be always blocked"); - let err = result.unwrap_err(); - assert!( - err.contains("always-blocked"), - "expected 'always-blocked' in error: {err}" - ); - } - - #[tokio::test] - async fn test_forward_link_local_always_blocked_even_with_allowed_ips() { - // Link-local / cloud metadata addresses are always blocked. - // Construct nets directly (parse_allowed_ips now rejects always-blocked). - let nets = vec!["169.254.0.0/16".parse::().unwrap()]; - let result = resolve_and_check_allowed_ips("169.254.169.254", 80, &nets, 0).await; - assert!(result.is_err(), "Link-local should be always blocked"); - let err = result.unwrap_err(); - assert!( - err.contains("always-blocked"), - "expected 'always-blocked' in error: {err}" - ); - } - - // -- implicit_allowed_ips_for_ip_host -- - - #[test] - fn test_implicit_allowed_ips_returns_ip_for_ipv4_literal() { - let result = implicit_allowed_ips_for_ip_host("192.168.1.100"); - assert_eq!(result, vec!["192.168.1.100"]); - } - - #[test] - fn test_implicit_allowed_ips_skips_ipv6_loopback() { - // ::1 is always-blocked, so implicit allowed_ips should be empty. - let result = implicit_allowed_ips_for_ip_host("::1"); - assert!(result.is_empty()); - } - - #[test] - fn test_implicit_allowed_ips_returns_empty_for_hostname() { - let result = implicit_allowed_ips_for_ip_host("api.github.com"); - assert!(result.is_empty()); - } - - #[test] - fn test_implicit_allowed_ips_returns_empty_for_wildcard() { - let result = implicit_allowed_ips_for_ip_host("*.example.com"); - assert!(result.is_empty()); - } - - /// Regression test: exercises the actual keep-alive interception loop to - /// verify that a non-inference request is denied even after a previous - /// inference request was successfully routed on the same connection. - /// - /// Before the fix, `handle_inference_interception` used - /// `else if !routed_any` which silently dropped denials once `routed_any` - /// was true, allowing non-inference HTTP requests to piggyback on a - /// keep-alive connection that had previously handled inference traffic. - /// Regression test: exercises the actual keep-alive interception loop to - /// verify that a non-inference request is denied even after a previous - /// inference request was successfully routed on the same connection. - /// - /// The server runs in a spawned task with empty routes (the inference - /// request gets a 503 "not configured" but is still recognized as - /// inference and returns Ok(true)). The client sends the inference - /// request, reads the 503 response, then sends a non-inference request - /// on the same connection. The server must return Denied. - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_keepalive_denies_non_inference_after_routed() { - use openshell_router::Router; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - let router = Router::new().unwrap(); - let patterns = crate::l7::inference::default_patterns(); - // Empty routes: inference request gets 503 but returns Ok(true). - let ctx = InferenceContext::new(patterns, router, vec![], vec![]); - - let body = r#"{"model":"test","messages":[{"role":"user","content":"hi"}]}"#; - let inference_req = format!( - "POST /v1/chat/completions HTTP/1.1\r\n\ - Host: inference.local\r\n\ - Content-Type: application/json\r\n\ - Content-Length: {}\r\n\r\n{}", - body.len(), - body, - ); - let non_inference_req = "GET /admin/config HTTP/1.1\r\nHost: inference.local\r\n\r\n"; - - let (client, mut server) = tokio::io::duplex(65536); - let (mut client_read, mut client_write) = tokio::io::split(client); - - // Spawn the server task so it runs concurrently. - let server_task = - tokio::spawn(async move { process_inference_keepalive(&mut server, &ctx, 443).await }); - - // Client: send inference request, read response, send non-inference. - client_write - .write_all(inference_req.as_bytes()) - .await - .unwrap(); - - // Read the 503 response so the server loops back to read. - let mut buf = vec![0u8; 4096]; - let _ = client_read.read(&mut buf).await.unwrap(); - - // Send non-inference request on the same keep-alive connection. - client_write - .write_all(non_inference_req.as_bytes()) - .await - .unwrap(); - drop(client_write); - - // Drain remaining response bytes. - tokio::spawn(async move { - let mut buf = vec![0u8; 4096]; - loop { - match client_read.read(&mut buf).await { - Ok(0) | Err(_) => break, - Ok(_) => continue, - } - } - }); - - let outcome = server_task.await.unwrap().unwrap(); - - assert!( - matches!(outcome, InferenceOutcome::Denied { .. }), - "expected Denied after non-inference request on keep-alive, got: {outcome:?}" - ); - } - - // -- build_json_error_response -- - - #[test] - fn test_json_error_response_403() { - let resp = build_json_error_response( - 403, - "Forbidden", - "policy_denied", - "CONNECT api.example.com:443 not permitted by policy", - ); - let resp_str = String::from_utf8(resp).unwrap(); - - assert!(resp_str.starts_with("HTTP/1.1 403 Forbidden\r\n")); - assert!(resp_str.contains("Content-Type: application/json\r\n")); - assert!(resp_str.contains("Connection: close\r\n")); - - // Extract body after \r\n\r\n - let body_start = resp_str.find("\r\n\r\n").unwrap() + 4; - let body: serde_json::Value = serde_json::from_str(&resp_str[body_start..]).unwrap(); - assert_eq!(body["error"], "policy_denied"); - assert_eq!( - body["detail"], - "CONNECT api.example.com:443 not permitted by policy" - ); - } - - #[test] - fn test_json_error_response_502() { - let resp = build_json_error_response( - 502, - "Bad Gateway", - "upstream_unreachable", - "connection to api.example.com:443 failed", - ); - let resp_str = String::from_utf8(resp).unwrap(); - - assert!(resp_str.starts_with("HTTP/1.1 502 Bad Gateway\r\n")); - - let body_start = resp_str.find("\r\n\r\n").unwrap() + 4; - let body: serde_json::Value = serde_json::from_str(&resp_str[body_start..]).unwrap(); - assert_eq!(body["error"], "upstream_unreachable"); - assert_eq!(body["detail"], "connection to api.example.com:443 failed"); - } - - #[test] - fn test_json_error_response_content_length_matches() { - let resp = build_json_error_response(403, "Forbidden", "test", "detail"); - let resp_str = String::from_utf8(resp).unwrap(); - - // Extract Content-Length value - let cl_line = resp_str - .lines() - .find(|l| l.starts_with("Content-Length:")) - .unwrap(); - let cl: usize = cl_line.split(": ").nth(1).unwrap().trim().parse().unwrap(); - - // Verify body length matches - let body_start = resp_str.find("\r\n\r\n").unwrap() + 4; - assert_eq!(resp_str[body_start..].len(), cl); - } - - /// End-to-end regression for the `docker cp` hot-swap hazard that - /// motivated `binary_path()` stripping the kernel's `" (deleted)"` - /// suffix (PR #844). - /// - /// Before the strip, the identity-resolution chain inside - /// `evaluate_opa_tcp` failed with `"Failed to stat - /// /opt/openshell/bin/openshell-sandbox (deleted)"` because - /// `BinaryIdentityCache::verify_or_cache()` tried to `metadata()` the - /// tainted path. That masked the real security signal: a live process - /// was now bound to a *different* binary on disk than the one that was - /// TOFU-cached. After the strip, `binary_path()` returns a path that - /// stats fine, the cache rehashes the new bytes, and the hash mismatch - /// surfaces as a `Binary integrity violation` error — the contract this - /// PR is trying to establish. - /// - /// Test shape (from the review comment on the initial PR): - /// 1. Start a `TcpListener` in the test process. - /// 2. Copy `/bin/bash` to a temp path we control. - /// 3. Prime `BinaryIdentityCache` with that temp binary's hash. - /// 4. Spawn the temp bash as a child with a `/dev/tcp` one-liner that - /// opens a real TCP connection to the listener and holds it open. - /// 5. Accept the connection on the listener side and capture the peer's - /// ephemeral port — that's what `resolve_process_identity` uses to - /// walk `/proc/net/tcp` back to the child PID. - /// 6. Overwrite the temp bash on disk with different bytes to simulate - /// a `docker cp` hot-swap. The running child is unaffected (it still - /// executes from its in-memory image), but `/proc//exe` will - /// now readlink to `" (deleted)"` OR the overwritten file, depending - /// on whether the filesystem reused the inode. - /// 7. Call `resolve_process_identity` and assert: - /// - the error reason contains `"Binary integrity violation"` (the - /// cache detected the tampered on-disk bytes), and - /// - the error reason does NOT contain `"Failed to stat"` or - /// `"(deleted)"` (the old pre-strip failure mode). - #[cfg(target_os = "linux")] - #[test] - fn resolve_process_identity_surfaces_binary_integrity_violation_on_hot_swap() { - use crate::identity::BinaryIdentityCache; - use std::io::Read; - use std::net::TcpListener; - use std::os::unix::fs::PermissionsExt; - use std::process::{Command, Stdio}; - use std::time::Duration; - - // Skip if /bin/bash is not present (e.g. minimal containers). - if !std::path::Path::new("/bin/bash").exists() { - eprintln!("skipping: /bin/bash not available"); - return; - } - - // 1. Start a listener on loopback. - let listener = TcpListener::bind("127.0.0.1:0").expect("bind"); - let listener_port = listener.local_addr().unwrap().port(); - - // 2. Copy /bin/bash to a temp path. - let tmp = tempfile::TempDir::new().unwrap(); - let bash_v1 = tmp.path().join("hotswap-bash"); - std::fs::copy("/bin/bash", &bash_v1).expect("copy bash"); - std::fs::set_permissions(&bash_v1, std::fs::Permissions::from_mode(0o755)).unwrap(); - - // 3. Prime the cache with the v1 hash of the temp bash. - let cache = BinaryIdentityCache::new(); - let v1_hash = cache - .verify_or_cache(&bash_v1) - .expect("prime cache with v1 bash hash"); - assert!(!v1_hash.is_empty()); - - // 4. Spawn the temp bash with a /dev/tcp one-liner that opens a real - // connection to the listener and sleeps to keep it open. The - // `read -t` blocks on stdin so the shell stays resident. - let script = format!("exec 3<>/dev/tcp/127.0.0.1/{listener_port}; sleep 30 <&3"); - let mut child = Command::new(&bash_v1) - .arg("-c") - .arg(&script) - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .expect("spawn hotswap-bash child"); - - // 5. Accept on the listener side, capture the peer port. - listener.set_nonblocking(false).expect("blocking listener"); - let (mut stream, peer_addr) = match listener.accept() { - Ok(pair) => pair, - Err(e) => { - let _ = child.kill(); - let _ = child.wait(); - panic!("failed to accept child connection: {e}"); - } - }; - let peer_port = peer_addr.port(); - // Drain any spurious data; we just need the socket open. - stream - .set_read_timeout(Some(Duration::from_millis(50))) - .ok(); - let mut buf = [0u8; 16]; - let _ = stream.read(&mut buf); - - // Give the kernel a moment so /proc//net/tcp and - // /proc//fd/ both reflect the ESTABLISHED socket. - std::thread::sleep(Duration::from_millis(50)); - - // 6. Simulate `docker cp`: unlink the running binary and create a - // fresh file with different bytes at the same path. Writing - // in place via O_TRUNC is rejected by the kernel with ETXTBSY - // because the inode is still being executed. Unlink is cheap: - // the inode persists in memory via the child's exec mapping, - // so the child keeps running, but a new inode now lives at - // `bash_v1` with a different SHA-256. - std::fs::remove_file(&bash_v1).expect("unlink running bash_v1"); - let tampered_bytes = b"#!/bin/sh\n# tampered bash v2 from hotswap test\nexit 0\n"; - std::fs::write(&bash_v1, tampered_bytes).expect("write replacement bytes"); - - // 7. Resolve identity through the real helper and assert the - // contract: we want "Binary integrity violation", not - // "Failed to stat ... (deleted)". - let test_pid = std::process::id(); - let result = resolve_process_identity(test_pid, peer_port, &cache); - - // Always clean up the child before asserting so a failure doesn't - // leak a sleeping process across test runs. - let _ = child.kill(); - let _ = child.wait(); - - match result { - Ok(_) => panic!( - "resolve_process_identity unexpectedly succeeded after hot-swap; \ - the cache should have detected the tampered on-disk bytes" - ), - Err(err) => { - assert!( - err.reason.contains("Binary integrity violation"), - "expected 'Binary integrity violation' error, got: {}", - err.reason - ); - assert!( - !err.reason.contains("Failed to stat"), - "pre-PR-#844 failure mode leaked: {}", - err.reason - ); - assert!( - !err.reason.contains("(deleted)"), - "resolved path still contains '(deleted)' suffix: {}", - err.reason - ); - // The binary field should be populated — we did resolve a - // path before failing. - assert!( - err.binary.is_some(), - "expected resolved binary path on integrity failure" - ); - if let Some(path) = &err.binary { - assert!( - !path.to_string_lossy().contains("(deleted)"), - "resolved binary path still tainted: {}", - path.display() - ); - } - } - } - } - - #[cfg(target_os = "linux")] - #[test] - // TODO: exec'ing /bin/sleep (SELinux label bin_t) from a user_home_t test - // binary causes /proc//exe readlink to return ENOENT on - // SELinux-enforcing hosts. Fix by building a test-sleep-helper binary in - // the same crate so it inherits the user_home_t label. - fn resolve_process_identity_denies_fork_exec_shared_socket_ambiguity() { - use crate::identity::BinaryIdentityCache; - use std::ffi::CString; - use std::net::{TcpListener, TcpStream}; - use std::os::fd::AsRawFd; - use std::time::{Duration, Instant}; - - struct ChildGuard(libc::pid_t); - impl Drop for ChildGuard { - fn drop(&mut self) { - #[allow(unsafe_code)] - unsafe { - libc::kill(self.0, libc::SIGKILL); - libc::waitpid(self.0, std::ptr::null_mut(), 0); - } - } - } - - if !std::path::Path::new("/bin/sleep").exists() { - eprintln!("skipping: /bin/sleep not available"); - return; - } - - if std::process::Command::new("getenforce") - .output() - .is_ok_and(|o| String::from_utf8_lossy(&o.stdout).trim() == "Enforcing") - { - eprintln!( - "skipping: SELinux is enforcing — cross-label /proc//exe readlink fails" - ); - return; - } - - let listener = TcpListener::bind("127.0.0.1:0").expect("bind listener"); - let listener_port = listener.local_addr().unwrap().port(); - let stream = TcpStream::connect(("127.0.0.1", listener_port)).expect("connect"); - let peer_port = stream.local_addr().unwrap().port(); - let (_accepted, _) = listener.accept().expect("accept"); - - let fd = stream.as_raw_fd(); - // libc/syscall FFI requires unsafe - #[allow(unsafe_code)] - unsafe { - let flags = libc::fcntl(fd, libc::F_GETFD); - assert!(flags >= 0, "F_GETFD failed"); - assert_eq!( - libc::fcntl(fd, libc::F_SETFD, flags & !libc::FD_CLOEXEC), - 0, - "F_SETFD failed" - ); - } - - let sleep_path = CString::new("/bin/sleep").unwrap(); - let arg0 = CString::new("sleep").unwrap(); - let arg1 = CString::new("30").unwrap(); - // libc/syscall FFI requires unsafe - #[allow(unsafe_code)] - let child_pid = unsafe { libc::fork() }; - assert!(child_pid >= 0, "fork failed"); - if child_pid == 0 { - // libc/syscall FFI requires unsafe - #[allow(unsafe_code)] - unsafe { - libc::execl( - sleep_path.as_ptr(), - arg0.as_ptr(), - arg1.as_ptr(), - std::ptr::null::(), - ); - libc::_exit(127); - } - } - - let _guard = ChildGuard(child_pid); - let entrypoint_pid = std::process::id(); - - let deadline = Instant::now() + Duration::from_secs(5); - loop { - if let Ok(link) = std::fs::read_link(format!("/proc/{child_pid}/exe")) - && link.to_string_lossy().contains("sleep") - { - break; - } - assert!( - Instant::now() < deadline, - "child pid {child_pid} did not exec into sleep within 5s" - ); - std::thread::sleep(Duration::from_millis(20)); - } - - let cache = BinaryIdentityCache::new(); - - let mut result = resolve_process_identity(entrypoint_pid, peer_port, &cache); - for _ in 0..10 { - match &result { - Err(err) - if err.reason.contains("No such file or directory") - || err.reason.contains("os error 2") => - { - // /proc//fd scan transiently failed; give procfs time to settle. - std::thread::sleep(Duration::from_millis(50)); - result = resolve_process_identity(entrypoint_pid, peer_port, &cache); - } - Ok(_) => { - // On arm64 under heavy CI load the /proc fd scan can transiently - // miss the parent process's socket fd, making the scan return only - // the child as owner and yielding a spurious Ok. Retry to give - // both owners time to appear consistently in /proc//fd. - std::thread::sleep(Duration::from_millis(50)); - result = resolve_process_identity(entrypoint_pid, peer_port, &cache); - } - _ => break, - } - } - - match result { - Ok(identity) => panic!( - "resolve_process_identity unexpectedly succeeded for shared socket owned by PID {}", - identity.binary_pid - ), - Err(err) => { - assert!( - err.reason.contains("ambiguous shared socket ownership"), - "expected ambiguous socket ownership error, got: {}", - err.reason - ); - assert!( - err.reason.contains(&entrypoint_pid.to_string()), - "error should include parent PID; got: {}", - err.reason - ); - assert!( - err.reason.contains(&child_pid.to_string()), - "error should include child PID; got: {}", - err.reason - ); - } - } - } -} diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index afcd28863..38070911c 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -274,6 +274,15 @@ request_denied_for_endpoint(request, endpoint) if { command_matches(request.command, deny_rule.command) } +# --- L7 deny rule matching: JSON-RPC method + params --- + +request_denied_for_endpoint(request, endpoint) if { + some deny_rule + deny_rule := endpoint.deny_rules[_] + deny_rule.rpc_method + jsonrpc_rule_matches(request, deny_rule) +} + # --- L7 deny rule matching: GraphQL operation --- request_denied_for_endpoint(request, endpoint) if { @@ -417,6 +426,15 @@ request_allowed_for_endpoint(request, endpoint) if { command_matches(request.command, rule.allow.command) } +# --- L7 rule matching: JSON-RPC method --- + +request_allowed_for_endpoint(request, endpoint) if { + some rule + rule := endpoint.rules[_] + rule.allow.rpc_method + jsonrpc_rule_matches(request, rule.allow) +} + # --- L7 rule matching: GraphQL operation --- request_allowed_for_endpoint(request, endpoint) if { @@ -638,6 +656,35 @@ query_value_matches(value, matcher) if { glob.match(any_patterns[i], [], value) } +# JSON-RPC method and params matching. The sandbox flattens object params into +# dot-separated keys before policy evaluation, e.g. arguments.scope. +jsonrpc_rule_matches(request, rule) if { + jsonrpc := object.get(request, "jsonrpc", {}) + method := object.get(jsonrpc, "method", null) + method != null + glob.match(rule.rpc_method, [], method) + jsonrpc_params_match(jsonrpc, rule) +} + +jsonrpc_params_match(jsonrpc, rule) if { + param_rules := object.get(rule, "params", {}) + not jsonrpc_param_mismatch(jsonrpc, param_rules) +} + +jsonrpc_param_mismatch(jsonrpc, param_rules) if { + some key + matcher := param_rules[key] + not jsonrpc_param_key_matches(jsonrpc, key, matcher) +} + +jsonrpc_param_key_matches(jsonrpc, key, matcher) if { + params := object.get(jsonrpc, "params", {}) + value := object.get(params, key, null) + value != null + is_string(value) + query_value_matches(value, matcher) +} + # SQL command matching: "*" matches any; otherwise case-insensitive. command_matches(_, "*") if true diff --git a/crates/openshell-supervisor-network/src/l7/graphql.rs b/crates/openshell-supervisor-network/src/l7/graphql.rs index 82c35720e..12979f0b1 100644 --- a/crates/openshell-supervisor-network/src/l7/graphql.rs +++ b/crates/openshell-supervisor-network/src/l7/graphql.rs @@ -810,6 +810,7 @@ network_policies: target: req.target, query_params: req.query_params, graphql: Some(info), + jsonrpc: None, }; let tunnel_engine = engine diff --git a/crates/openshell-sandbox/src/l7/http.rs b/crates/openshell-supervisor-network/src/l7/http.rs similarity index 100% rename from crates/openshell-sandbox/src/l7/http.rs rename to crates/openshell-supervisor-network/src/l7/http.rs diff --git a/crates/openshell-sandbox/src/l7/jsonrpc.rs b/crates/openshell-supervisor-network/src/l7/jsonrpc.rs similarity index 100% rename from crates/openshell-sandbox/src/l7/jsonrpc.rs rename to crates/openshell-supervisor-network/src/l7/jsonrpc.rs diff --git a/crates/openshell-supervisor-network/src/l7/mod.rs b/crates/openshell-supervisor-network/src/l7/mod.rs index 802058ec2..563094dd0 100644 --- a/crates/openshell-supervisor-network/src/l7/mod.rs +++ b/crates/openshell-supervisor-network/src/l7/mod.rs @@ -9,7 +9,9 @@ //! evaluated against OPA policy, and either forwarded or denied. pub mod graphql; +pub(crate) mod http; pub mod inference; +pub mod jsonrpc; pub mod path; pub mod provider; pub mod relay; @@ -25,6 +27,7 @@ pub enum L7Protocol { Websocket, Graphql, Sql, + JsonRpc, } impl L7Protocol { @@ -34,6 +37,7 @@ impl L7Protocol { "websocket" => Some(Self::Websocket), "graphql" => Some(Self::Graphql), "sql" => Some(Self::Sql), + "json-rpc" => Some(Self::JsonRpc), _ => None, } } @@ -76,6 +80,8 @@ pub struct L7EndpointConfig { pub enforcement: EnforcementMode, /// Maximum GraphQL request body bytes to buffer for inspection. pub graphql_max_body_bytes: usize, + /// Maximum JSON-RPC request body bytes to buffer for inspection. + pub json_rpc_max_body_bytes: usize, /// When true, percent-encoded `/` (`%2F`) is preserved in path segments /// rather than rejected at the parser. Needed by upstreams like GitLab /// that embed `%2F` in namespaced project paths. Defaults to false. @@ -110,6 +116,8 @@ pub struct L7RequestInfo { pub query_params: std::collections::HashMap>, /// Parsed GraphQL operation metadata for GraphQL endpoints. pub graphql: Option, + /// Parsed JSON-RPC request metadata for JSON-RPC endpoints. + pub jsonrpc: Option, } /// Parse an L7 endpoint config from a regorus Value (returned by Rego query). @@ -165,6 +173,10 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { .and_then(|v| usize::try_from(v).ok()) .filter(|v| *v > 0) .unwrap_or(graphql::DEFAULT_MAX_BODY_BYTES); + let json_rpc_max_body_bytes = get_object_u64(val, "json_rpc_max_body_bytes") + .and_then(|v| usize::try_from(v).ok()) + .filter(|v| *v > 0) + .unwrap_or(jsonrpc::DEFAULT_MAX_BODY_BYTES); Some(L7EndpointConfig { protocol, @@ -172,6 +184,7 @@ pub fn parse_l7_config(val: ®orus::Value) -> Option { tls, enforcement, graphql_max_body_bytes, + json_rpc_max_body_bytes, allow_encoded_slash, websocket_credential_rewrite, request_body_credential_rewrite, @@ -598,7 +611,7 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< if !protocol.is_empty() && L7Protocol::parse(protocol).is_none() { errors.push(format!( - "{loc}: unknown protocol '{protocol}' (expected rest, websocket, graphql, or sql)" + "{loc}: unknown protocol '{protocol}' (expected rest, websocket, graphql, sql, or json-rpc)" )); } @@ -624,6 +637,18 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< } } + if ep.get("json_rpc_max_body_bytes").is_some() { + let valid_max = ep + .get("json_rpc_max_body_bytes") + .and_then(serde_json::Value::as_u64) + .is_some_and(|v| v > 0); + if !valid_max { + errors.push(format!( + "{loc}: json_rpc_max_body_bytes must be a positive integer" + )); + } + } + if protocol != "graphql" && protocol != "websocket" && (ep.get("persisted_queries").is_some() @@ -635,6 +660,12 @@ pub fn validate_l7_policies(data_json: &serde_json::Value) -> (Vec, Vec< )); } + if protocol != "json-rpc" && ep.get("json_rpc_max_body_bytes").is_some() { + warnings.push(format!( + "{loc}: JSON-RPC-specific endpoint fields are ignored unless protocol is json-rpc" + )); + } + if ep .get("websocket_credential_rewrite") .and_then(serde_json::Value::as_bool) diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 3054a4530..1251b8e7b 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -178,6 +178,7 @@ where .into_diagnostic()?; Ok(()) } + L7Protocol::JsonRpc => relay_jsonrpc(config, &engine, client, upstream, ctx).await, } } @@ -297,6 +298,7 @@ where target: redacted_target.clone(), query_params: req.query_params.clone(), graphql: graphql_info.clone(), + jsonrpc: None, }; let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); if config.protocol == L7Protocol::Websocket && !websocket_request { @@ -341,7 +343,7 @@ where let engine_type = match config.protocol { L7Protocol::Graphql => "l7-graphql", L7Protocol::Websocket => "l7-websocket", - L7Protocol::Rest | L7Protocol::Sql => "l7", + L7Protocol::Rest | L7Protocol::Sql | L7Protocol::JsonRpc => "l7", }; emit_l7_request_log( ctx, @@ -694,6 +696,7 @@ where target: redacted_target.clone(), query_params: req.query_params.clone(), graphql: None, + jsonrpc: None, }; let websocket_request = crate::l7::rest::request_is_websocket_upgrade(&req.raw_header); if config.protocol == L7Protocol::Websocket && !websocket_request { @@ -885,6 +888,173 @@ fn close_if_stale(guard: &PolicyGenerationGuard, ctx: &L7EvalContext) -> bool { true } +async fn relay_jsonrpc( + config: &L7EndpointConfig, + engine: &TunnelPolicyEngine, + client: &mut C, + upstream: &mut U, + ctx: &L7EvalContext, +) -> Result<()> +where + C: AsyncRead + AsyncWrite + Unpin + Send, + U: AsyncRead + AsyncWrite + Unpin + Send, +{ + loop { + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let parsed = match crate::l7::jsonrpc::parse_jsonrpc_http_request( + client, + config.json_rpc_max_body_bytes, + crate::l7::path::CanonicalizeOptions { + allow_encoded_slash: config.allow_encoded_slash, + ..Default::default() + }, + ) + .await + { + Ok(Some(parsed)) => parsed, + Ok(None) => return Ok(()), + Err(e) => { + if is_benign_connection_error(&e) { + debug!( + host = %ctx.host, + port = ctx.port, + error = %e, + "JSON-RPC L7 connection closed" + ); + } else { + let detail = + parse_rejection_detail(&e.to_string(), ParseRejectionMode::L7Endpoint); + emit_parse_rejection(ctx, &detail, "l7-jsonrpc"); + } + return Ok(()); + } + }; + + let req = parsed.request; + let jsonrpc_info = parsed.info; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let redacted_target = req.target.clone(); + + let request_info = L7RequestInfo { + action: req.action.clone(), + target: redacted_target.clone(), + query_params: req.query_params.clone(), + graphql: None, + jsonrpc: Some(jsonrpc_info.clone()), + }; + + let parse_error_reason = jsonrpc_info + .error + .as_deref() + .map(|e| format!("JSON-RPC request rejected: {e}")); + let force_deny = parse_error_reason.is_some(); + let (allowed, reason, jsonrpc_log_info) = if let Some(reason) = parse_error_reason { + (false, reason, jsonrpc_info.clone()) + } else { + let evaluation = + evaluate_jsonrpc_l7_request_for_log(engine, ctx, &request_info, &jsonrpc_info)?; + (evaluation.allowed, evaluation.reason, evaluation.log_info) + }; + + if close_if_stale(engine.generation_guard(), ctx) { + return Ok(()); + } + + let decision_str = match (allowed, config.enforcement) { + (_, _) if force_deny => "deny", + (true, _) => "allow", + (false, EnforcementMode::Audit) => "audit", + (false, EnforcementMode::Enforce) => "deny", + }; + + { + let (action_id, disposition_id, severity) = match decision_str { + "deny" => (ActionId::Denied, DispositionId::Blocked, SeverityId::Medium), + _ => ( + ActionId::Allowed, + DispositionId::Allowed, + SeverityId::Informational, + ), + }; + let endpoint = format!("{}:{}{}", ctx.host, ctx.port, redacted_target); + let params_sha256 = jsonrpc_log_info + .params_sha256() + .unwrap_or_else(|| "".to_string()); + let policy_version = engine.captured_generation(); + let event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) + .activity(ActivityId::Other) + .action(action_id) + .disposition(disposition_id) + .severity(severity) + .http_request(HttpRequest::new( + &request_info.action, + OcsfUrl::new("http", &ctx.host, &redacted_target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "l7-jsonrpc") + .message(jsonrpc_log_message( + decision_str, + &request_info.action, + &endpoint, + &jsonrpc_log_info, + ¶ms_sha256, + policy_version, + &reason, + )) + .build(); + ocsf_emit!(event); + } + + if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( + &req, + client, + upstream, + ctx.secret_resolver.as_deref(), + Some(engine.generation_guard()), + ) + .await?; + match outcome { + RelayOutcome::Reusable => {} + RelayOutcome::Consumed => { + debug!( + host = %ctx.host, + port = ctx.port, + "Upstream connection not reusable, closing JSON-RPC L7 relay" + ); + return Ok(()); + } + RelayOutcome::Upgraded { .. } => { + return Ok(()); + } + } + } else { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &req, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + } +} + async fn relay_graphql( config: &L7EndpointConfig, engine: &TunnelPolicyEngine, @@ -962,6 +1132,7 @@ where target: redacted_target.clone(), query_params: req.query_params.clone(), graphql: Some(graphql_info.clone()), + jsonrpc: None, }; // Malformed or ambiguous GraphQL requests, such as duplicated GET @@ -1110,6 +1281,38 @@ fn graphql_log_summary(info: &crate::l7::graphql::GraphqlRequestInfo) -> String format!("graphql_ops={}", ops.join(";")) } +pub(crate) fn jsonrpc_log_message( + decision: &str, + http_method: &str, + endpoint: &str, + info: &crate::l7::jsonrpc::JsonRpcRequestInfo, + params_sha256: &str, + policy_version: u64, + reason: &str, +) -> String { + let rpc_methods = jsonrpc_methods_for_log(info); + format!( + "JSONRPC_L7_REQUEST decision={decision} http_method={http_method} endpoint={endpoint} rpc_methods={rpc_methods} params_sha256={params_sha256} policy_version={policy_version} reason={reason}" + ) +} + +pub(crate) fn jsonrpc_methods_for_log(info: &crate::l7::jsonrpc::JsonRpcRequestInfo) -> String { + if info.calls.is_empty() { + return "-".to_string(); + } + info.calls + .iter() + .map(|call| call.method.as_str()) + .collect::>() + .join(",") +} + +struct JsonRpcEvaluation { + allowed: bool, + reason: String, + log_info: crate::l7::jsonrpc::JsonRpcRequestInfo, +} + /// Check if a miette error represents a benign connection close. /// /// TLS handshake EOF, missing `close_notify`, connection resets, and broken @@ -1135,6 +1338,88 @@ pub fn evaluate_l7_request( engine: &TunnelPolicyEngine, ctx: &L7EvalContext, request: &L7RequestInfo, +) -> Result<(bool, String)> { + if let Some(jsonrpc) = &request.jsonrpc + && jsonrpc.is_batch + && !jsonrpc.calls.is_empty() + { + for call in &jsonrpc.calls { + let item_request = jsonrpc_request_for_call(request, call); + let (allowed, reason) = evaluate_l7_request_once(engine, ctx, &item_request)?; + if !allowed { + return Ok((false, reason)); + } + } + return Ok((true, String::new())); + } + + evaluate_l7_request_once(engine, ctx, request) +} + +fn evaluate_jsonrpc_l7_request_for_log( + engine: &TunnelPolicyEngine, + ctx: &L7EvalContext, + request: &L7RequestInfo, + jsonrpc: &crate::l7::jsonrpc::JsonRpcRequestInfo, +) -> Result { + if jsonrpc.is_batch && !jsonrpc.calls.is_empty() { + let mut denied_calls = Vec::new(); + let mut first_denied_reason = None; + for call in &jsonrpc.calls { + let item_request = jsonrpc_request_for_call(request, call); + let (allowed, reason) = evaluate_l7_request_once(engine, ctx, &item_request)?; + if !allowed { + if first_denied_reason.is_none() { + first_denied_reason = Some(reason); + } + denied_calls.push(call.clone()); + } + } + + if denied_calls.is_empty() { + return Ok(JsonRpcEvaluation { + allowed: true, + reason: String::new(), + log_info: jsonrpc.clone(), + }); + } + + return Ok(JsonRpcEvaluation { + allowed: false, + reason: first_denied_reason.unwrap_or_else(|| "request denied by policy".to_string()), + log_info: crate::l7::jsonrpc::JsonRpcRequestInfo { + calls: denied_calls, + is_batch: true, + error: None, + }, + }); + } + + let (allowed, reason) = evaluate_l7_request_once(engine, ctx, request)?; + Ok(JsonRpcEvaluation { + allowed, + reason, + log_info: jsonrpc.clone(), + }) +} + +fn jsonrpc_request_for_call( + request: &L7RequestInfo, + call: &crate::l7::jsonrpc::JsonRpcCallInfo, +) -> L7RequestInfo { + let mut item_request = request.clone(); + item_request.jsonrpc = Some(crate::l7::jsonrpc::JsonRpcRequestInfo { + calls: vec![call.clone()], + is_batch: false, + error: None, + }); + item_request +} + +fn evaluate_l7_request_once( + engine: &TunnelPolicyEngine, + ctx: &L7EvalContext, + request: &L7RequestInfo, ) -> Result<(bool, String)> { if engine.is_stale() { return Err(miette!( @@ -1159,6 +1444,14 @@ pub fn evaluate_l7_request( "path": request.target, "query_params": request.query_params.clone(), "graphql": request.graphql.clone(), + "jsonrpc": request.jsonrpc.as_ref().map(|j| { + let call = if j.is_batch { None } else { j.calls.first() }; + serde_json::json!({ + "method": call.map(|call| call.method.as_str()), + "params": call.map(|call| call.params.clone()).unwrap_or_default(), + "error": j.error, + }) + }), } }); @@ -1792,6 +2085,7 @@ network_policies: target: "/ws".into(), query_params: std::collections::HashMap::new(), graphql: None, + jsonrpc: None, }; let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); @@ -1800,6 +2094,180 @@ network_policies: assert!(reason.contains("WEBSOCKET_TEXT /ws not permitted")); } + #[test] + fn jsonrpc_batch_evaluates_each_call() { + let data = r#" +network_policies: + jsonrpc_api: + name: jsonrpc_api + endpoints: + - host: api.example.test + port: 443 + protocol: json-rpc + enforcement: enforce + rules: + - allow: + method: POST + path: "/mcp" + rpc_method: "tools/list" + - allow: + method: POST + path: "/mcp" + rpc_method: "tools/call" + params: + name: read_status + deny_rules: + - rpc_method: "tools/call" + params: + name: blocked_action + - rpc_method: "tools/delete" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let tunnel_engine = engine + .clone_engine_for_tunnel(engine.current_generation()) + .unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "jsonrpc_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let mut request = L7RequestInfo { + action: "POST".into(), + target: "/mcp".into(), + query_params: std::collections::HashMap::new(), + graphql: None, + jsonrpc: Some(crate::l7::jsonrpc::parse_jsonrpc_body( + br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"read_status"}} + ]"#, + )), + }; + + let (allowed, reason) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + assert!(allowed, "{reason}"); + + request.jsonrpc = Some(crate::l7::jsonrpc::parse_jsonrpc_body( + br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"blocked_action"}}, + {"jsonrpc":"2.0","id":3,"method":"tools/delete","params":{"name":"purge_cache"}} + ]"#, + )); + let (allowed, _) = evaluate_l7_request(&tunnel_engine, &ctx, &request).unwrap(); + assert!(!allowed); + + let jsonrpc = request.jsonrpc.as_ref().expect("jsonrpc request"); + let evaluation = + evaluate_jsonrpc_l7_request_for_log(&tunnel_engine, &ctx, &request, jsonrpc).unwrap(); + assert!(!evaluation.allowed); + assert!(evaluation.log_info.is_batch); + assert_eq!( + jsonrpc_methods_for_log(&evaluation.log_info), + "tools/call,tools/delete" + ); + + let full_params_sha256 = jsonrpc.params_sha256().expect("full batch params digest"); + let log_params_sha256 = evaluation + .log_info + .params_sha256() + .expect("logged batch params digest"); + assert_ne!(full_params_sha256, log_params_sha256); + let message = jsonrpc_log_message( + "deny", + "POST", + "api.example.test:443/mcp", + &evaluation.log_info, + &log_params_sha256, + 42, + &evaluation.reason, + ); + assert!(message.contains("rpc_methods=tools/call,tools/delete")); + assert!(message.contains("params_sha256=")); + assert!(!message.contains("params_sha256=sha256:")); + assert!(message.contains("policy_version=42")); + assert!(!message.contains("tools/list")); + assert!(!message.contains("blocked_action")); + assert!(!message.contains("purge_cache")); + } + + #[test] + fn jsonrpc_log_records_digest_not_args() { + let info = crate::l7::jsonrpc::parse_jsonrpc_body( + br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"delete_resource","arguments":{"scope":"secret-scope"}}}"#, + ); + let params_sha256 = info.params_sha256().expect("params digest"); + let message = jsonrpc_log_message( + "deny", + "POST", + "mcp.example.com:443/mcp", + &info, + ¶ms_sha256, + 42, + "request denied by policy", + ); + + assert!(message.contains("endpoint=mcp.example.com:443/mcp")); + assert!(message.contains("rpc_methods=tools/call")); + assert!(message.contains("params_sha256=")); + assert!(!message.contains("params_sha256=sha256:")); + assert!(message.contains("policy_version=42")); + assert!(!message.contains("delete_resource")); + assert!(!message.contains("secret-scope")); + + let batch = crate::l7::jsonrpc::parse_jsonrpc_body( + br#"[ + {"jsonrpc":"2.0","id":1,"method":"tools/list"}, + {"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"delete_resource"}} + ]"#, + ); + let batch_params_sha256 = batch.params_sha256().expect("batch params digest"); + let batch_message = jsonrpc_log_message( + "allow", + "POST", + "mcp.example.com:443/mcp", + &batch, + &batch_params_sha256, + 43, + "", + ); + + assert!(batch_message.starts_with("JSONRPC_L7_REQUEST ")); + assert!(batch_message.contains("rpc_methods=tools/list,tools/call")); + assert!(batch_message.contains("params_sha256=")); + assert!(!batch_message.contains("params_sha256=sha256:")); + assert!(batch_message.contains("policy_version=43")); + assert!(!batch_message.contains("rpc_method=")); + assert!(!batch_message.contains("delete_resource")); + + let no_params = crate::l7::jsonrpc::parse_jsonrpc_body( + br#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#, + ); + let no_params_sha256 = no_params + .params_sha256() + .unwrap_or_else(|| "".to_string()); + let no_params_message = jsonrpc_log_message( + "allow", + "POST", + "mcp.example.com:443/mcp", + &no_params, + &no_params_sha256, + 44, + "", + ); + assert!(no_params_message.contains("rpc_methods=initialize")); + assert!(no_params_message.contains("params_sha256=")); + } + #[tokio::test] async fn route_selected_websocket_upgrade_rejects_invalid_accept_without_forwarding_101() { let data = r#" @@ -1828,6 +2296,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: EnforcementMode::Enforce, graphql_max_body_bytes: 0, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: true, request_body_credential_rewrite: false, @@ -1931,6 +2400,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: EnforcementMode::Enforce, graphql_max_body_bytes: 0, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: true, request_body_credential_rewrite: false, @@ -2051,6 +2521,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: EnforcementMode::Enforce, graphql_max_body_bytes: 0, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: true, request_body_credential_rewrite: false, @@ -2407,4 +2878,100 @@ network_policies: "stale passthrough request must not be forwarded upstream" ); } + + #[tokio::test] + async fn jsonrpc_relay_denies_method_not_in_allow_list() { + let data = r" +network_policies: + mcp_api: + name: mcp_api + endpoints: + - host: mcp.example.test + port: 8000 + path: /mcp + protocol: json-rpc + enforcement: enforce + rules: + - allow: + rpc_method: initialize + binaries: + - { path: /usr/bin/python3 } +"; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "mcp.example.test".into(), + port: 8000, + binary_path: PathBuf::from("/usr/bin/python3"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "mcp.example.test".into(), + port: 8000, + policy_name: "mcp_api".into(), + binary_path: "/usr/bin/python3".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = + br#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"list_repos"}}"#; + let request = format!( + "POST /mcp HTTP/1.1\r\nHost: mcp.example.test:8000\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + body.len() + ); + app.write_all(request.as_bytes()).await.unwrap(); + app.write_all(body).await.unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(2), app.read(&mut response)) + .await + .expect("relay should respond without reaching upstream") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!( + response.contains("403"), + "tools/call not in allow list must be denied with 403, got: {response:?}" + ); + + let mut upstream_buf = [0u8; 128]; + let n = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_buf), + ) + .await + .unwrap_or(Ok(0)) + .unwrap_or(0); + assert_eq!(n, 0, "denied request must not be forwarded to upstream"); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should complete") + .unwrap() + .unwrap(); + } } diff --git a/crates/openshell-supervisor-network/src/l7/websocket.rs b/crates/openshell-supervisor-network/src/l7/websocket.rs index 31aa35509..e1f92e6ec 100644 --- a/crates/openshell-supervisor-network/src/l7/websocket.rs +++ b/crates/openshell-supervisor-network/src/l7/websocket.rs @@ -545,6 +545,7 @@ fn inspect_websocket_text_message( target: inspector.target.clone(), query_params: inspector.query_params.clone(), graphql: None, + jsonrpc: None, }; let (allowed, reason) = evaluate_l7_request(inspector.engine, inspector.ctx, &request_info)?; let decision = match (allowed, inspector.enforcement) { @@ -581,6 +582,7 @@ fn inspect_graphql_websocket_message( target: inspector.target.clone(), query_params: inspector.query_params.clone(), graphql: None, + jsonrpc: None, }; emit_websocket_l7_event( host, @@ -602,6 +604,7 @@ fn inspect_graphql_websocket_message( target: inspector.target.clone(), query_params: inspector.query_params.clone(), graphql: Some(graphql.clone()), + jsonrpc: None, }; let parse_error_reason = graphql .error diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 4dd0350ff..14d88a87f 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -925,6 +925,24 @@ fn resolve_binary_in_container(_policy_path: &str, _entrypoint_pid: u32) -> Opti None } +fn l7_matchers_to_json( + matchers: &std::collections::HashMap, +) -> serde_json::Map { + matchers + .iter() + .map(|(key, matcher)| { + let mut matcher_json = serde_json::json!({}); + if !matcher.glob.is_empty() { + matcher_json["glob"] = matcher.glob.clone().into(); + } + if !matcher.any.is_empty() { + matcher_json["any"] = matcher.any.clone().into(); + } + (key.clone(), matcher_json) + }) + .collect() +} + /// Convert typed proto policy fields to JSON suitable for `engine.add_data_json()`. /// /// The rego rules reference `data.*` directly, so the JSON structure has @@ -1023,35 +1041,25 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St "command": a.map_or("", |a| &a.command), "operation_type": a.map_or("", |a| &a.operation_type), "operation_name": a.map_or("", |a| &a.operation_name), + "rpc_method": a.map_or("", |a| &a.rpc_method), }); if let Some(a) = a && !a.fields.is_empty() { allow["fields"] = a.fields.clone().into(); } - let query: serde_json::Map = a - .map(|allow| { - allow - .query - .iter() - .map(|(key, matcher)| { - let mut matcher_json = serde_json::json!({}); - if !matcher.glob.is_empty() { - matcher_json["glob"] = - matcher.glob.clone().into(); - } - if !matcher.any.is_empty() { - matcher_json["any"] = - matcher.any.clone().into(); - } - (key.clone(), matcher_json) - }) - .collect() - }) - .unwrap_or_default(); + let query = a.map_or_else(serde_json::Map::new, |allow| { + l7_matchers_to_json(&allow.query) + }); if !query.is_empty() { allow["query"] = query.into(); } + let params = a.map_or_else(serde_json::Map::new, |allow| { + l7_matchers_to_json(&allow.params) + }); + if !params.is_empty() { + allow["params"] = params.into(); + } serde_json::json!({ "allow": allow }) }) .collect(); @@ -1087,23 +1095,17 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St if !d.fields.is_empty() { deny["fields"] = d.fields.clone().into(); } - let query: serde_json::Map = d - .query - .iter() - .map(|(key, matcher)| { - let mut matcher_json = serde_json::json!({}); - if !matcher.glob.is_empty() { - matcher_json["glob"] = matcher.glob.clone().into(); - } - if !matcher.any.is_empty() { - matcher_json["any"] = matcher.any.clone().into(); - } - (key.clone(), matcher_json) - }) - .collect(); + if !d.rpc_method.is_empty() { + deny["rpc_method"] = d.rpc_method.clone().into(); + } + let query = l7_matchers_to_json(&d.query); if !query.is_empty() { deny["query"] = query.into(); } + let params = l7_matchers_to_json(&d.params); + if !params.is_empty() { + deny["params"] = params.into(); + } deny }) .collect(); @@ -1141,6 +1143,9 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St if e.graphql_max_body_bytes > 0 { ep["graphql_max_body_bytes"] = e.graphql_max_body_bytes.into(); } + if e.json_rpc_max_body_bytes > 0 { + ep["json_rpc_max_body_bytes"] = e.json_rpc_max_body_bytes.into(); + } ep }) .collect(); @@ -1948,6 +1953,36 @@ process: }) } + fn l7_jsonrpc_input(host: &str, port: u16, path: &str, rpc_method: &str) -> serde_json::Value { + l7_jsonrpc_input_with_params(host, port, path, rpc_method, serde_json::json!({})) + } + + fn l7_jsonrpc_input_with_params( + host: &str, + port: u16, + path: &str, + rpc_method: &str, + params: serde_json::Value, + ) -> serde_json::Value { + serde_json::json!({ + "network": { "host": host, "port": port }, + "exec": { + "path": "/usr/bin/curl", + "ancestors": [], + "cmdline_paths": [] + }, + "request": { + "method": "POST", + "path": path, + "query_params": {}, + "jsonrpc": { + "method": rpc_method, + "params": params + } + } + }) + } + fn l7_graphql_input(host: &str, operations: serde_json::Value) -> serde_json::Value { serde_json::json!({ "network": { "host": host, "port": 443 }, @@ -2494,6 +2529,8 @@ network_policies: operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + rpc_method: String::new(), + params: std::collections::HashMap::new(), }), }], ..Default::default() @@ -2542,6 +2579,140 @@ network_policies: assert!(!eval_l7(&engine, &deny_input)); } + #[test] + fn l7_jsonrpc_rpc_method_from_proto_is_enforced() { + let mut network_policies = std::collections::HashMap::new(); + network_policies.insert( + "jsonrpc_proto".to_string(), + NetworkPolicyRule { + name: "jsonrpc_proto".to_string(), + endpoints: vec![NetworkEndpoint { + host: "mcp.proto.com".to_string(), + port: 8000, + path: "/mcp".to_string(), + protocol: "json-rpc".to_string(), + enforcement: "enforce".to_string(), + rules: vec![L7Rule { + allow: Some(L7Allow { + method: String::new(), + path: String::new(), + command: String::new(), + query: std::collections::HashMap::new(), + operation_type: String::new(), + operation_name: String::new(), + fields: Vec::new(), + rpc_method: "initialize".to_string(), + params: std::collections::HashMap::new(), + }), + }], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + ..Default::default() + }], + }, + ); + + let proto = ProtoSandboxPolicy { + version: 1, + filesystem: Some(ProtoFs { + include_workdir: true, + read_only: vec![], + read_write: vec![], + }), + landlock: Some(openshell_core::proto::LandlockPolicy { + compatibility: "best_effort".to_string(), + }), + process: Some(ProtoProc { + run_as_user: "sandbox".to_string(), + run_as_group: "sandbox".to_string(), + }), + network_policies, + }; + + let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); + let allow_input = l7_jsonrpc_input("mcp.proto.com", 8000, "/mcp", "initialize"); + assert!(eval_l7(&engine, &allow_input)); + + let deny_input = l7_jsonrpc_input("mcp.proto.com", 8000, "/mcp", "tools/list"); + assert!(!eval_l7(&engine, &deny_input)); + } + + #[test] + fn l7_jsonrpc_params_rules_filter_tools_call() { + let data = r#" +network_policies: + jsonrpc_params: + name: jsonrpc_params + endpoints: + - host: mcp.params.test + port: 8000 + path: /mcp + protocol: json-rpc + enforcement: enforce + rules: + - allow: + rpc_method: tools/call + params: + name: read_status + - allow: + rpc_method: tools/call + params: + name: submit_report + arguments.scope: workspace/main + deny_rules: + - rpc_method: tools/call + params: + name: blocked_action + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).expect("engine from yaml"); + + let read_status = l7_jsonrpc_input_with_params( + "mcp.params.test", + 8000, + "/mcp", + "tools/call", + serde_json::json!({"name": "read_status"}), + ); + assert!(eval_l7(&engine, &read_status)); + + let submit_report = l7_jsonrpc_input_with_params( + "mcp.params.test", + 8000, + "/mcp", + "tools/call", + serde_json::json!({ + "name": "submit_report", + "arguments.scope": "workspace/main" + }), + ); + assert!(eval_l7(&engine, &submit_report)); + + let blocked_without_args = l7_jsonrpc_input_with_params( + "mcp.params.test", + 8000, + "/mcp", + "tools/call", + serde_json::json!({"name": "blocked_action"}), + ); + assert!(!eval_l7(&engine, &blocked_without_args)); + + let blocked_with_args = l7_jsonrpc_input_with_params( + "mcp.params.test", + 8000, + "/mcp", + "tools/call", + serde_json::json!({ + "name": "blocked_action", + "arguments.reason": "test" + }), + ); + assert!(!eval_l7(&engine, &blocked_with_args)); + } + #[test] fn l7_no_request_on_l4_only_endpoint() { // L4-only endpoint should not match L7 allow_request diff --git a/crates/openshell-supervisor-network/src/policy_local.rs b/crates/openshell-supervisor-network/src/policy_local.rs index 2fce25389..cf783dc5e 100644 --- a/crates/openshell-supervisor-network/src/policy_local.rs +++ b/crates/openshell-supervisor-network/src/policy_local.rs @@ -1088,6 +1088,8 @@ fn network_endpoint_from_json( operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + rpc_method: String::new(), + params: HashMap::new(), }), }) .collect(); @@ -1102,6 +1104,8 @@ fn network_endpoint_from_json( operation_type: String::new(), operation_name: String::new(), fields: Vec::new(), + rpc_method: String::new(), + params: HashMap::new(), }) .collect(); @@ -1125,6 +1129,7 @@ fn network_endpoint_from_json( persisted_queries: String::new(), graphql_persisted_queries: HashMap::new(), graphql_max_body_bytes: 0, + json_rpc_max_body_bytes: 0, path: String::new(), }) } diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index d467b022e..0ee2d6719 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -13,7 +13,7 @@ use openshell_core::denial::DenialEvent; use openshell_core::net::{is_always_blocked_ip, is_internal_ip, is_link_local_ip}; use openshell_core::policy::ProxyPolicy; use openshell_core::provider_credentials::ProviderCredentialState; -use openshell_core::secrets::{SecretResolver, rewrite_header_line_checked}; +use openshell_core::secrets::{self, SecretResolver, rewrite_header_line_checked}; use openshell_ocsf::{ ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, NetworkActivityBuilder, Process, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, @@ -176,7 +176,7 @@ impl ProxyHandle { /// The proxy uses OPA for network decisions with process-identity binding /// via `/proc/net/tcp`. All connections are evaluated through OPA policy. #[allow(clippy::too_many_arguments)] - pub async fn start_with_bind_addr( + pub(crate) async fn start_with_bind_addr( policy: &ProxyPolicy, bind_addr: Option, opa_engine: Arc, @@ -345,6 +345,21 @@ fn emit_forward_success_activity(tx: Option<&ActivitySender>, l7_activity_pendin ); } +fn l7_parse_error_reason(request_info: &crate::l7::L7RequestInfo) -> Option { + request_info + .graphql + .as_ref() + .and_then(|info| info.error.as_deref()) + .map(|error| format!("GraphQL request rejected: {error}")) + .or_else(|| { + request_info + .jsonrpc + .as_ref() + .and_then(|info| info.error.as_deref()) + .map(|error| format!("JSON-RPC request rejected: {error}")) + }) +} + /// Emit a denial event to the aggregator channel (if configured). /// Used by `handle_tcp_connection` which owns `Option`. fn emit_denial( @@ -492,6 +507,7 @@ async fn handle_tcp_connection( ) .await?; if let InferenceOutcome::Denied { reason } = outcome { + emit_activity(&activity_tx, true, "forward_policy"); let event = NetworkActivityBuilder::new(openshell_ocsf::ctx::ctx()) .activity(ActivityId::Open) .action(ActionId::Denied) @@ -2767,16 +2783,14 @@ fn rewrite_forward_request( path: &str, secret_resolver: Option<&SecretResolver>, request_body_credential_rewrite: bool, -) -> Result, openshell_core::secrets::UnresolvedPlaceholderError> { +) -> Result, secrets::UnresolvedPlaceholderError> { let header_end = raw[..used] .windows(4) .position(|w| w == b"\r\n\r\n") .map_or(used, |p| p + 4); let websocket_upgrade = crate::l7::rest::request_is_websocket_upgrade(&raw[..header_end]); let upstream_path = match secret_resolver { - Some(resolver) => { - openshell_core::secrets::rewrite_target_for_eval(path, resolver)?.resolved - } + Some(resolver) => secrets::rewrite_target_for_eval(path, resolver)?.resolved, None => path.to_string(), }; @@ -2869,10 +2883,10 @@ fn rewrite_forward_request( output.len() }; let output_str = String::from_utf8_lossy(&output[..scan_end]); - if output_str.contains(openshell_core::secrets::PLACEHOLDER_PREFIX_PUBLIC) - || output_str.contains(openshell_core::secrets::PROVIDER_ALIAS_MARKER_PUBLIC) + if output_str.contains(secrets::PLACEHOLDER_PREFIX_PUBLIC) + || output_str.contains(secrets::PROVIDER_ALIAS_MARKER_PUBLIC) { - return Err(openshell_core::secrets::UnresolvedPlaceholderError { location: "header" }); + return Err(secrets::UnresolvedPlaceholderError { location: "header" }); } } @@ -3395,18 +3409,66 @@ async fn handle_forward_proxy( } else { None }; + let jsonrpc = if l7_config.config.protocol == crate::l7::L7Protocol::JsonRpc { + let header_end = forward_request_bytes + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(forward_request_bytes.len(), |p| p + 4); + let header_str = std::str::from_utf8(&forward_request_bytes[..header_end]) + .map_err(|_| miette::miette!("Forward JSON-RPC headers contain invalid UTF-8"))?; + let body_length = crate::l7::rest::parse_body_length(header_str)?; + let mut jsonrpc_request = crate::l7::provider::L7Request { + action: method.to_string(), + target: path.clone(), + query_params: query_params.clone(), + raw_header: forward_request_bytes, + body_length, + }; + let body = match crate::l7::http::read_body_for_inspection( + client, + &mut jsonrpc_request, + l7_config.config.json_rpc_max_body_bytes, + ) + .await + { + Ok(body) => body, + Err(e) => { + let event = NetworkActivityBuilder::new(openshell_ocsf::ctx::ctx()) + .activity(ActivityId::Fail) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .dst_endpoint(Endpoint::from_domain(&host_lc, port)) + .message(format!("FORWARD_JSONRPC_L7 request rejected: {e}")) + .build(); + ocsf_emit!(event); + emit_activity_simple(activity_tx, true, "l7_parse_rejection"); + respond( + client, + &build_json_error_response( + 400, + "Bad Request", + "invalid_jsonrpc_request", + &format!("JSON-RPC request rejected before policy evaluation: {e}"), + ), + ) + .await?; + return Ok(()); + } + }; + forward_request_bytes = jsonrpc_request.raw_header; + Some(crate::l7::jsonrpc::parse_jsonrpc_body(&body)) + } else { + None + }; let request_info = crate::l7::L7RequestInfo { action: method.to_string(), target: path.clone(), query_params, graphql, + jsonrpc, }; - let parse_error_reason = request_info - .graphql - .as_ref() - .and_then(|info| info.error.as_deref()) - .map(|error| format!("GraphQL request rejected: {error}")); + let parse_error_reason = l7_parse_error_reason(&request_info); let force_deny = parse_error_reason.is_some(); let (allowed, reason) = parse_error_reason.map_or_else( || { @@ -3447,16 +3509,39 @@ async fn handle_forward_proxy( SeverityId::Informational, ), }; - let engine_type = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { - "l7-graphql" - } else { - "l7" - }; - let message_prefix = if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { - "FORWARD_GRAPHQL_L7" - } else { - "FORWARD_L7" + let engine_type = match l7_config.config.protocol { + crate::l7::L7Protocol::Graphql => "l7-graphql", + crate::l7::L7Protocol::JsonRpc => "l7-jsonrpc", + _ => "l7", }; + let log_message = request_info.jsonrpc.as_ref().map_or_else( + || { + let message_prefix = + if l7_config.config.protocol == crate::l7::L7Protocol::Graphql { + "FORWARD_GRAPHQL_L7" + } else { + "FORWARD_L7" + }; + format!( + "{message_prefix} {decision_str} {method} {host_lc}:{port}{path} reason={reason}" + ) + }, + |jsonrpc_info| { + let endpoint = format!("{host_lc}:{port}{path}"); + let params_sha256 = jsonrpc_info + .params_sha256() + .unwrap_or_else(|| "".to_string()); + crate::l7::relay::jsonrpc_log_message( + decision_str, + method, + &endpoint, + jsonrpc_info, + ¶ms_sha256, + tunnel_engine.captured_generation(), + &reason, + ) + }, + ); let event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) .activity(ActivityId::Other) .action(action_id) @@ -3473,9 +3558,7 @@ async fn handle_forward_proxy( .with_cmd_line(&cmdline_str), ) .firewall_rule(policy_str, engine_type) - .message(format!( - "{message_prefix} {decision_str} {method} {host_lc}:{port}{path} reason={reason}" - )) + .message(log_message) .build(); ocsf_emit!(event); } @@ -4091,6 +4174,7 @@ mod tests { tls: crate::l7::TlsMode::Auto, enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite, request_body_credential_rewrite: false, @@ -4132,6 +4216,28 @@ mod tests { assert_eq!(event.deny_group, "unknown"); } + #[test] + fn l7_parse_error_reason_includes_jsonrpc_errors() { + let request_info = crate::l7::L7RequestInfo { + action: "POST".to_string(), + target: "/mcp".to_string(), + query_params: std::collections::HashMap::new(), + graphql: None, + jsonrpc: Some(crate::l7::jsonrpc::JsonRpcRequestInfo { + calls: Vec::new(), + is_batch: false, + error: Some("ambiguous dotted params key 'arguments.scope'".to_string()), + }), + }; + + let reason = l7_parse_error_reason(&request_info).expect("JSON-RPC parse error"); + + assert_eq!( + reason, + "JSON-RPC request rejected: ambiguous dotted params key 'arguments.scope'" + ); + } + #[test] fn forward_l7_allowed_activity_is_deferred_until_after_ssrf() { let (tx, mut rx) = mpsc::channel(4); @@ -4690,6 +4796,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: false, request_body_credential_rewrite: false, @@ -4703,6 +4810,7 @@ network_policies: tls: crate::l7::TlsMode::Auto, enforcement: crate::l7::EnforcementMode::Enforce, graphql_max_body_bytes: crate::l7::graphql::DEFAULT_MAX_BODY_BYTES, + json_rpc_max_body_bytes: crate::l7::jsonrpc::DEFAULT_MAX_BODY_BYTES, allow_encoded_slash: false, websocket_credential_rewrite: false, request_body_credential_rewrite: false,