diff --git a/.claude/commands/ci.md b/.claude/commands/ci.md new file mode 100644 index 00000000..2e9d2f6f --- /dev/null +++ b/.claude/commands/ci.md @@ -0,0 +1,7 @@ +Commit working changes and push them to CI + +Run `cargo clippy -- -D warning` before pushing changes. + +Enter loop where you wait for CI to complete, resolve issues, +and return to user once CI is green or a major decision is needed +to resolve it. diff --git a/.claude/settings.json b/.claude/settings.json index bbb9636a..aaedab0a 100644 --- a/.claude/settings.json +++ b/.claude/settings.json @@ -7,10 +7,6 @@ { "type": "command", "command": "cargo fmt" - }, - { - "type": "command", - "command": "cargo clippy --all-targets -- -D warnings" } ] } diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 866f873a..6eeee3d0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -41,6 +41,9 @@ jobs: - name: Run smoke tests run: cargo nextest run --profile ci --test smoke_test --verbose + - name: Run script integration tests + run: cargo nextest run --profile ci --test script_integration --verbose + - name: Run weak mode integration tests run: | # On macOS, we only support weak mode due to PF limitations @@ -112,6 +115,11 @@ jobs: source ~/.cargo/env cargo nextest run --profile ci --test smoke_test --verbose + - name: Run script integration tests + run: | + source ~/.cargo/env + cargo nextest run --profile ci --test script_integration --verbose + - name: Run Linux jail integration tests run: | source ~/.cargo/env @@ -148,6 +156,9 @@ jobs: - name: Build run: cargo build --verbose + - name: Run script integration tests + run: cargo nextest run --profile ci --test script_integration --verbose + - name: Run weak mode integration tests run: cargo nextest run --profile ci --test weak_integration --verbose diff --git a/CLAUDE.md b/CLAUDE.md index b7f0304d..f9f184ac 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -51,3 +51,7 @@ on both targets. ## Formatting After modifying code, run `cargo fmt` to ensure consistent formatting before committing changes. + +## Logging + +In regular operation of the CLI-only jail (non-server mode), info and warn logs are not permitted as they would interfere with the underlying process output. Only use debug level logs for normal operation and error logs for actual errors. The server mode (`--server`) may use info/warn logs as appropriate since it has no underlying process. diff --git a/Cargo.lock b/Cargo.lock index 4ada8594..b1b0f767 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -119,6 +119,17 @@ dependencies = [ "wait-timeout", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -221,12 +232,6 @@ version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - [[package]] name = "bytes" version = "1.10.1" @@ -421,7 +426,7 @@ dependencies = [ "libc", "option-ext", "redox_users", - "windows-sys 0.60.2", + "windows-sys 0.61.0", ] [[package]] @@ -430,6 +435,17 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b" +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] + [[package]] name = "doc-comment" version = "0.3.3" @@ -503,6 +519,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + [[package]] name = "fs_extra" version = "1.3.0" @@ -651,16 +676,6 @@ dependencies = [ "foldhash", ] -[[package]] -name = "hdrhistogram" -version = "7.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" -dependencies = [ - "byteorder", - "num-traits", -] - [[package]] name = "heck" version = "0.5.0" @@ -728,6 +743,7 @@ version = "0.1.2" dependencies = [ "anyhow", "assert_cmd", + "async-trait", "bytes", "camino", "chrono", @@ -747,18 +763,14 @@ dependencies = [ "rcgen", "regex", "rustls", - "rustls-pemfile", - "serde", - "serde_json", - "serde_yaml", "serial_test", "tempfile", "tls-parser", "tokio", "tokio-rustls", - "tower", "tracing", "tracing-subscriber", + "url", "webpki-roots 0.26.11", ] @@ -853,6 +865,113 @@ dependencies = [ "cc", ] +[[package]] +name = "icu_collections" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" +dependencies = [ + "displaydoc", + "potential_utf", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" + +[[package]] +name = "icu_properties" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" +dependencies = [ + "displaydoc", + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "potential_utf", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" + +[[package]] +name = "icu_provider" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" +dependencies = [ + "displaydoc", + "icu_locale_core", + "stable_deref_trait", + "tinystr", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + [[package]] name = "indexmap" version = "2.10.0" @@ -972,6 +1091,12 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" +[[package]] +name = "litemap" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" + [[package]] name = "lock_api" version = "0.4.13" @@ -1213,9 +1338,9 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.3.1" +version = "2.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "phf" @@ -1267,6 +1392,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "potential_utf" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "84df19adbe5b5a0782edcab45899906947ab039ccf4573713735ee7de1e6b08a" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1535,15 +1669,6 @@ dependencies = [ "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" -dependencies = [ - "rustls-pki-types", -] - [[package]] name = "rustls-pki-types" version = "1.12.0" @@ -1571,12 +1696,6 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" -[[package]] -name = "ryu" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" - [[package]] name = "scc" version = "2.4.0" @@ -1650,31 +1769,6 @@ dependencies = [ "syn 2.0.106", ] -[[package]] -name = "serde_json" -version = "1.0.143" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" -dependencies = [ - "itoa", - "memchr", - "ryu", - "serde", -] - -[[package]] -name = "serde_yaml" -version = "0.9.34+deprecated" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" -dependencies = [ - "indexmap", - "itoa", - "ryu", - "serde", - "unsafe-libyaml", -] - [[package]] name = "serial_test" version = "3.2.0" @@ -1752,6 +1846,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "strsim" version = "0.11.1" @@ -1787,10 +1887,15 @@ dependencies = [ ] [[package]] -name = "sync_wrapper" -version = "1.0.2" +name = "synstructure" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] [[package]] name = "system-configuration" @@ -1880,6 +1985,16 @@ version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" +[[package]] +name = "tinystr" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "tls-parser" version = "0.12.2" @@ -1965,32 +2080,6 @@ dependencies = [ "winnow", ] -[[package]] -name = "tower" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" -dependencies = [ - "futures-core", - "futures-util", - "hdrhistogram", - "indexmap", - "pin-project-lite", - "slab", - "sync_wrapper", - "tokio", - "tokio-util", - "tower-layer", - "tower-service", - "tracing", -] - -[[package]] -name = "tower-layer" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" - [[package]] name = "tower-service" version = "0.3.3" @@ -2071,18 +2160,30 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" -[[package]] -name = "unsafe-libyaml" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" - [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "url" +version = "2.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -2475,6 +2576,12 @@ dependencies = [ "bitflags", ] +[[package]] +name = "writeable" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" + [[package]] name = "yasna" version = "0.5.2" @@ -2484,6 +2591,30 @@ dependencies = [ "time", ] +[[package]] +name = "yoke" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.8.26" @@ -2504,8 +2635,62 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "zerofrom" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", + "synstructure", +] + [[package]] name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" + +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7aa2bd55086f1ab526693ecbe444205da57e25f4489879da80635a46d90e73b" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.106", +] diff --git a/Cargo.toml b/Cargo.toml index ee76eede..e135a05f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ categories = ["command-line-utilities", "network-programming", "development-tool isolated-cleanup-tests = [] [dependencies] +async-trait = "0.1" clap = { version = "4.5", features = ["derive"] } regex = "1.10" tokio = { version = "1.35", features = ["full"] } @@ -36,6 +37,7 @@ tls-parser = "0.12.2" camino = "1.1.11" filetime = "0.2" ctrlc = "3.4" +url = "2.5" [target.'cfg(target_os = "macos")'.dependencies] nix = { version = "0.27", features = ["user"] } diff --git a/README.md b/README.md index b6c8e5c5..8563ea5e 100644 --- a/README.md +++ b/README.md @@ -13,20 +13,18 @@ cargo install httpjail ## Features +> [!WARNING] +> httpjail is experimental and offers no API or CLI compatibility guarantees. + - 🔒 **Process-level network isolation** - Isolate processes in restricted network environments - 🌐 **HTTP/HTTPS interception** - Transparent proxy with TLS certificate injection - 🎯 **Regex-based filtering** - Flexible allow/deny rules with regex patterns +- 🔧 **Script-based evaluation** - Custom request evaluation logic via external scripts - 📝 **Request logging** - Monitor and log all HTTP/HTTPS requests - ⛔ **Default deny** - Requests are blocked unless explicitly allowed - 🖥️ **Cross-platform** - Native support for Linux and macOS - ⚡ **Zero configuration** - Works out of the box with sensible defaults -## MVP TODO - -- [ ] Update README to be more reflective of AI agent restrictions -- [x] Add a `--server` mode that runs the proxy server but doesn't execute the command -- [ ] Expand test cases to include WebSockets - ## Quick Start > By default, httpjail denies all network requests. Add `allow:` rules to permit traffic. @@ -48,6 +46,11 @@ httpjail -r "allow-get: api\.github\.com" -r "deny: .*" -- git pull # Use config file for complex rules httpjail --config rules.txt -- python script.py +# Use custom script for request evaluation +httpjail --script /path/to/check.sh -- ./my-app +# Script receives: HTTPJAIL_URL, HTTPJAIL_METHOD, HTTPJAIL_HOST, HTTPJAIL_SCHEME, HTTPJAIL_PATH +# Exit 0 to allow, non-zero to block. stdout becomes additional context in 403 response. + # Run as standalone proxy server (no command execution) httpjail --server -r "allow: .*" # Server defaults to ports 8080 (HTTP) and 8443 (HTTPS) @@ -173,6 +176,49 @@ Use the config: httpjail --config rules.txt -- ./my-application ``` +### Script-Based Evaluation + +Instead of regex rules, you can use a custom script to evaluate each request. The script receives environment variables for each request and returns an exit code to allow (0) or block (non-zero) the request. Any output to stdout becomes additional context in the 403 response. + +```bash +# Simple script example +cat > check_request.sh << 'EOF' +#!/bin/bash +# Allow only GitHub and reject everything else +if [[ "$HTTPJAIL_HOST" == "github.com" ]]; then + exit 0 +else + echo "Access denied: $HTTPJAIL_HOST is not on the allowlist" + exit 1 +fi +EOF +chmod +x check_request.sh + +# Use the script +httpjail --script ./check_request.sh -- curl https://github.com + +# Inline script (with spaces, executed via shell) +httpjail --script '[ "$HTTPJAIL_HOST" = "github.com" ] && exit 0 || exit 1' -- git pull +``` + +**Environment variables provided to the script:** + +- `HTTPJAIL_URL` - Full URL being requested +- `HTTPJAIL_METHOD` - HTTP method (GET, POST, etc.) +- `HTTPJAIL_HOST` - Hostname from the URL +- `HTTPJAIL_SCHEME` - URL scheme (http or https) +- `HTTPJAIL_PATH` - Path component of the URL + +**Script requirements:** + +- Exit code 0 allows the request +- Any non-zero exit code blocks the request +- stdout is captured and included in 403 responses as additional context +- stderr is logged for debugging but not sent to the client + +> [!TIP] +> Script-based evaluation can also be used for custom logging! Your script can log requests to a database, send metrics to a monitoring service, or implement complex audit trails before returning the allow/deny decision. + ### Advanced Options ```bash diff --git a/src/main.rs b/src/main.rs index 48295da7..d25b064e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use anyhow::{Context, Result}; use clap::Parser; use httpjail::jail::{JailConfig, create_jail}; use httpjail::proxy::ProxyServer; +use httpjail::rules::script::ScriptRuleEngine; use httpjail::rules::{Action, Rule, RuleEngine}; use std::fs::OpenOptions; use std::os::unix::process::ExitStatusExt; @@ -22,21 +23,41 @@ struct Args { /// -r "allow-get: .*" /// Actions: allow, deny /// Methods (optional): get, post, put, delete, head, options, connect, trace, patch - #[arg(short = 'r', long = "rule", value_name = "RULE")] + #[arg( + short = 'r', + long = "rule", + value_name = "RULE", + conflicts_with = "script" + )] rules: Vec, + /// Use script for evaluating requests + /// The script receives environment variables: + /// HTTPJAIL_URL, HTTPJAIL_METHOD, HTTPJAIL_HOST, HTTPJAIL_SCHEME, HTTPJAIL_PATH + /// Exit code 0 allows the request, non-zero blocks it + /// stdout becomes additional context in the 403 response + #[arg( + short = 's', + long = "script", + value_name = "PROG", + conflicts_with = "rules", + conflicts_with = "config" + )] + script: Option, + /// Use configuration file - #[arg(short = 'c', long = "config", value_name = "FILE")] + #[arg( + short = 'c', + long = "config", + value_name = "FILE", + conflicts_with = "script" + )] config: Option, /// Append requests to a log file #[arg(long = "request-log", value_name = "FILE")] request_log: Option, - /// Interactive approval mode - #[arg(long = "interactive")] - interactive: bool, - /// Use weak mode (environment variables only, no system isolation) #[arg(long = "weak")] weak: bool, @@ -311,8 +332,7 @@ async fn main() -> Result<()> { info!("Starting httpjail in server mode"); } - // Build rules from command line arguments - let rules = build_rules(&args)?; + // Build rule engine based on script or rules let request_log = if let Some(path) = &args.request_log { Some(Arc::new(Mutex::new( OpenOptions::new() @@ -324,7 +344,15 @@ async fn main() -> Result<()> { } else { None }; - let rule_engine = RuleEngine::new(rules, request_log); + + let rule_engine = if let Some(script) = &args.script { + info!("Using script-based rule evaluation: {}", script); + let script_engine = Box::new(ScriptRuleEngine::new(script.clone())); + RuleEngine::from_trait(script_engine, request_log) + } else { + let rules = build_rules(&args)?; + RuleEngine::new(rules, request_log) + }; // Parse bind configuration from env vars // Supports both "port" and "ip:port" formats @@ -522,13 +550,13 @@ mod tests { use super::*; use hyper::Method; - #[test] - fn test_build_rules_no_rules_default_deny() { + #[tokio::test] + async fn test_build_rules_no_rules_default_deny() { let args = Args { rules: vec![], + script: None, config: None, request_log: None, - interactive: false, weak: false, verbose: 0, timeout: None, @@ -544,7 +572,7 @@ mod tests { // Rule engine should deny requests when no rules are specified let engine = RuleEngine::new(rules, None); assert!(matches!( - engine.evaluate(Method::GET, "https://example.com"), + engine.evaluate(Method::GET, "https://example.com").await, Action::Deny )); } @@ -563,9 +591,9 @@ mod tests { let args = Args { rules: vec![], + script: None, config: Some(file.path().to_str().unwrap().to_string()), request_log: None, - interactive: false, weak: false, verbose: 0, timeout: None, diff --git a/src/proxy.rs b/src/proxy.rs index 27b0d5fe..06cc2530 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -23,13 +23,29 @@ use tracing::{debug, error, info, warn}; pub const HTTPJAIL_HEADER: &str = "HTTPJAIL"; pub const HTTPJAIL_HEADER_VALUE: &str = "true"; -pub const BLOCKED_MESSAGE: &str = "Request blocked by httpjail\n"; +pub const BLOCKED_MESSAGE: &str = "Request blocked by httpjail"; /// Create a raw HTTP/1.1 403 Forbidden response for CONNECT tunnels pub fn create_connect_403_response() -> &'static [u8] { b"HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nContent-Length: 27\r\n\r\nRequest blocked by httpjail" } +/// Create a raw HTTP/1.1 403 Forbidden response for CONNECT tunnels with context +pub fn create_connect_403_response_with_context(context: Option) -> Vec { + let message = if let Some(ctx) = context { + format!("{}\n{}", BLOCKED_MESSAGE, ctx) + } else { + BLOCKED_MESSAGE.to_string() + }; + + let response = format!( + "HTTP/1.1 403 Forbidden\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}", + message.len(), + message + ); + response.into_bytes() +} + // Shared HTTP/HTTPS client for upstream requests static HTTPS_CLIENT: OnceLock< Client< @@ -351,10 +367,11 @@ pub async fn handle_http_request( format!("http://{}{}", host, path) }; - info!("Proxying HTTP request: {} {}", method, full_url); + debug!("Proxying HTTP request: {} {}", method, full_url); // Evaluate rules with method - match rule_engine.evaluate(method, &full_url) { + let evaluation = rule_engine.evaluate_with_context(method, &full_url).await; + match evaluation.action { Action::Allow => { debug!("Request allowed: {}", full_url); match proxy_request(req, &full_url).await { @@ -366,8 +383,8 @@ pub async fn handle_http_request( } } Action::Deny => { - warn!("Request denied: {}", full_url); - create_forbidden_response() + debug!("Request denied: {}", full_url); + create_forbidden_response(evaluation.context) } } } @@ -427,10 +444,16 @@ async fn proxy_request( Ok(Response::from_parts(parts, boxed_body)) } -/// Create a 403 Forbidden error response -pub fn create_forbidden_response() --> Result>, std::convert::Infallible> { - create_error_response(StatusCode::FORBIDDEN, BLOCKED_MESSAGE) +/// Create a 403 Forbidden error response with optional context +pub fn create_forbidden_response( + context: Option, +) -> Result>, std::convert::Infallible> { + let message = if let Some(ctx) = context { + format!("{}\n{}", BLOCKED_MESSAGE, ctx) + } else { + BLOCKED_MESSAGE.to_string() + }; + create_error_response(StatusCode::FORBIDDEN, &message) } pub fn create_error_response( diff --git a/src/proxy_tls.rs b/src/proxy_tls.rs index 90cb1380..45ab1ac0 100644 --- a/src/proxy_tls.rs +++ b/src/proxy_tls.rs @@ -1,5 +1,6 @@ use crate::proxy::{ - HTTPJAIL_HEADER, HTTPJAIL_HEADER_VALUE, create_connect_403_response, create_forbidden_response, + HTTPJAIL_HEADER, HTTPJAIL_HEADER_VALUE, create_connect_403_response_with_context, + create_forbidden_response, }; use crate::rules::{Action, RuleEngine}; use crate::tls::CertificateManager; @@ -304,7 +305,10 @@ async fn handle_connect_tunnel( // Check if this host is allowed let full_url = format!("https://{}", target); - match rule_engine.evaluate(Method::GET, &full_url) { + let evaluation = rule_engine + .evaluate_with_context(Method::GET, &full_url) + .await; + match evaluation.action { Action::Allow => { debug!("CONNECT allowed to: {}", host); @@ -341,10 +345,10 @@ async fn handle_connect_tunnel( // Get the underlying stream back let mut stream = reader.into_inner(); - // Send 403 Forbidden response - let response = create_connect_403_response(); + // Send 403 Forbidden response with context + let response = create_connect_403_response_with_context(evaluation.context); match timeout(WRITE_TIMEOUT, async { - stream.write_all(response).await?; + stream.write_all(&response).await?; stream.flush().await }) .await @@ -451,10 +455,13 @@ async fn handle_decrypted_https_request( let path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/"); let full_url = format!("https://{}{}", host, path); - info!("Proxying HTTPS request: {} {}", method, full_url); + debug!("Proxying HTTPS request: {} {}", method, full_url); // Evaluate rules with method - match rule_engine.evaluate(method.clone(), &full_url) { + let evaluation = rule_engine + .evaluate_with_context(method.clone(), &full_url) + .await; + match evaluation.action { Action::Allow => { debug!("Request allowed: {}", full_url); match proxy_https_request(req, &host).await { @@ -466,8 +473,8 @@ async fn handle_decrypted_https_request( } } Action::Deny => { - warn!("Request denied: {}", full_url); - create_forbidden_response() + debug!("Request denied: {}", full_url); + create_forbidden_response(evaluation.context) } } } diff --git a/src/rules.rs b/src/rules.rs index 698077a0..dbea7f0b 100644 --- a/src/rules.rs +++ b/src/rules.rs @@ -1,12 +1,14 @@ -use anyhow::Result; +pub mod pattern; +pub mod script; + +use async_trait::async_trait; use chrono::{SecondsFormat, Utc}; use hyper::Method; -use regex::Regex; -use std::collections::HashSet; +pub use pattern::{PatternRuleEngine, Rule}; use std::fs::File; use std::io::Write; use std::sync::{Arc, Mutex}; -use tracing::{info, warn}; +use tracing::warn; #[derive(Debug, Clone)] pub enum Action { @@ -15,92 +17,63 @@ pub enum Action { } #[derive(Debug, Clone)] -pub struct Rule { +pub struct EvaluationResult { pub action: Action, - pub pattern: Regex, - pub methods: Option>, // None means all methods + pub context: Option, } -impl Rule { - pub fn new(action: Action, pattern: &str) -> Result { - Ok(Rule { - action, - pattern: Regex::new(pattern)?, - methods: None, // Default to matching all methods - }) +impl EvaluationResult { + pub fn allow() -> Self { + Self { + action: Action::Allow, + context: None, + } + } + + pub fn deny() -> Self { + Self { + action: Action::Deny, + context: None, + } } - pub fn with_methods(mut self, methods: Vec) -> Self { - self.methods = Some(methods.into_iter().collect()); + pub fn with_context(mut self, context: String) -> Self { + self.context = Some(context); self } +} - pub fn matches(&self, method: Method, url: &str) -> bool { - // Check if URL matches - if !self.pattern.is_match(url) { - return false; - } +#[async_trait] +pub trait RuleEngineTrait: Send + Sync { + async fn evaluate(&self, method: Method, url: &str) -> EvaluationResult; - // Check if method matches (if methods are specified) - match &self.methods { - None => true, // No method filter means match all methods - Some(methods) => methods.contains(&method), - } - } + fn name(&self) -> &str; } -#[derive(Clone)] -pub struct RuleEngine { - pub rules: Vec, - pub request_log: Option>>, +pub struct LoggingRuleEngine { + engine: Box, + request_log: Option>>, } -impl RuleEngine { - pub fn new(rules: Vec, request_log: Option>>) -> Self { - RuleEngine { rules, request_log } - } - - pub fn evaluate(&self, method: Method, url: &str) -> Action { - let mut action = Action::Deny; - let mut matched = false; - - for rule in &self.rules { - if rule.matches(method.clone(), url) { - matched = true; - match &rule.action { - Action::Allow => { - info!( - "ALLOW: {} {} (matched: {:?})", - method, - url, - rule.pattern.as_str() - ); - action = Action::Allow; - } - Action::Deny => { - warn!( - "DENY: {} {} (matched: {:?})", - method, - url, - rule.pattern.as_str() - ); - action = Action::Deny; - } - } - break; - } +impl LoggingRuleEngine { + pub fn new(engine: Box, request_log: Option>>) -> Self { + Self { + engine, + request_log, } + } +} - if !matched { - warn!("DENY: {} {} (no matching rules)", method, url); - action = Action::Deny; - } +#[async_trait] +impl RuleEngineTrait for LoggingRuleEngine { + async fn evaluate(&self, method: Method, url: &str) -> EvaluationResult { + let result = self.engine.evaluate(method.clone(), url).await; if let Some(log) = &self.request_log && let Ok(mut file) = log.lock() { let timestamp = Utc::now().to_rfc3339_opts(SecondsFormat::Millis, true); - let status = match &action { + let status = match &result.action { Action::Allow => '+', Action::Deny => '-', }; @@ -109,7 +82,54 @@ impl RuleEngine { } } - action + result + } + + fn name(&self) -> &str { + self.engine.name() + } +} + +#[derive(Clone)] +pub struct RuleEngine { + inner: Arc, +} + +impl RuleEngine { + pub fn new(rules: Vec, request_log: Option>>) -> Self { + let pattern_engine = Box::new(PatternRuleEngine::new(rules)); + let engine: Box = if request_log.is_some() { + Box::new(LoggingRuleEngine::new(pattern_engine, request_log)) + } else { + pattern_engine + }; + + RuleEngine { + inner: Arc::from(engine), + } + } + + pub fn from_trait( + engine: Box, + request_log: Option>>, + ) -> Self { + let engine: Box = if request_log.is_some() { + Box::new(LoggingRuleEngine::new(engine, request_log)) + } else { + engine + }; + + RuleEngine { + inner: Arc::from(engine), + } + } + + pub async fn evaluate(&self, method: Method, url: &str) -> Action { + self.inner.evaluate(method, url).await.action + } + + pub async fn evaluate_with_context(&self, method: Method, url: &str) -> EvaluationResult { + self.inner.evaluate(method, url).await } } @@ -138,8 +158,8 @@ mod tests { assert!(!rule.matches(Method::DELETE, "https://api.example.com/users")); } - #[test] - fn test_rule_engine() { + #[tokio::test] + async fn test_rule_engine() { let rules = vec![ Rule::new(Action::Allow, r"github\.com").unwrap(), Rule::new(Action::Deny, r"telemetry").unwrap(), @@ -148,27 +168,26 @@ mod tests { let engine = RuleEngine::new(rules, None); - // Test allow rule assert!(matches!( - engine.evaluate(Method::GET, "https://github.com/api"), + engine.evaluate(Method::GET, "https://github.com/api").await, Action::Allow )); - // Test deny rule assert!(matches!( - engine.evaluate(Method::POST, "https://telemetry.example.com"), + engine + .evaluate(Method::POST, "https://telemetry.example.com") + .await, Action::Deny )); - // Test default deny assert!(matches!( - engine.evaluate(Method::GET, "https://example.com"), + engine.evaluate(Method::GET, "https://example.com").await, Action::Deny )); } - #[test] - fn test_method_specific_rules() { + #[tokio::test] + async fn test_method_specific_rules() { let rules = vec![ Rule::new(Action::Allow, r"api\.example\.com") .unwrap() @@ -178,21 +197,23 @@ mod tests { let engine = RuleEngine::new(rules, None); - // GET should be allowed assert!(matches!( - engine.evaluate(Method::GET, "https://api.example.com/data"), + engine + .evaluate(Method::GET, "https://api.example.com/data") + .await, Action::Allow )); - // POST should be denied (doesn't match method filter) assert!(matches!( - engine.evaluate(Method::POST, "https://api.example.com/data"), + engine + .evaluate(Method::POST, "https://api.example.com/data") + .await, Action::Deny )); } - #[test] - fn test_request_logging() { + #[tokio::test] + async fn test_request_logging() { use std::fs::OpenOptions; let rules = vec![Rule::new(Action::Allow, r".*").unwrap()]; @@ -203,14 +224,14 @@ mod tests { .unwrap(); let engine = RuleEngine::new(rules, Some(Arc::new(Mutex::new(file)))); - engine.evaluate(Method::GET, "https://example.com"); + engine.evaluate(Method::GET, "https://example.com").await; let contents = std::fs::read_to_string(log_file.path()).unwrap(); assert!(contents.contains("+ GET https://example.com")); } - #[test] - fn test_request_logging_denied() { + #[tokio::test] + async fn test_request_logging_denied() { use std::fs::OpenOptions; let rules = vec![Rule::new(Action::Deny, r".*").unwrap()]; @@ -221,18 +242,18 @@ mod tests { .unwrap(); let engine = RuleEngine::new(rules, Some(Arc::new(Mutex::new(file)))); - engine.evaluate(Method::GET, "https://blocked.com"); + engine.evaluate(Method::GET, "https://blocked.com").await; let contents = std::fs::read_to_string(log_file.path()).unwrap(); assert!(contents.contains("- GET https://blocked.com")); } - #[test] - fn test_default_deny_with_no_rules() { + #[tokio::test] + async fn test_default_deny_with_no_rules() { let engine = RuleEngine::new(vec![], None); assert!(matches!( - engine.evaluate(Method::GET, "https://example.com"), + engine.evaluate(Method::GET, "https://example.com").await, Action::Deny )); } diff --git a/src/rules/pattern.rs b/src/rules/pattern.rs new file mode 100644 index 00000000..e908fbaa --- /dev/null +++ b/src/rules/pattern.rs @@ -0,0 +1,191 @@ +use super::{Action, EvaluationResult, RuleEngineTrait}; +use anyhow::Result; +use async_trait::async_trait; +use hyper::Method; +use regex::Regex; +use std::collections::HashSet; +use tracing::debug; + +#[derive(Debug, Clone)] +pub struct Rule { + pub action: Action, + pub pattern: Regex, + pub methods: Option>, +} + +impl Rule { + pub fn new(action: Action, pattern: &str) -> Result { + Ok(Rule { + action, + pattern: Regex::new(pattern)?, + methods: None, + }) + } + + pub fn with_methods(mut self, methods: Vec) -> Self { + self.methods = Some(methods.into_iter().collect()); + self + } + + pub fn matches(&self, method: Method, url: &str) -> bool { + if !self.pattern.is_match(url) { + return false; + } + + match &self.methods { + None => true, + Some(methods) => methods.contains(&method), + } + } +} + +#[derive(Clone)] +pub struct PatternRuleEngine { + pub rules: Vec, +} + +impl PatternRuleEngine { + pub fn new(rules: Vec) -> Self { + PatternRuleEngine { rules } + } +} + +#[async_trait] +impl RuleEngineTrait for PatternRuleEngine { + async fn evaluate(&self, method: Method, url: &str) -> EvaluationResult { + for rule in &self.rules { + if rule.matches(method.clone(), url) { + match &rule.action { + Action::Allow => { + debug!( + "ALLOW: {} {} (matched: {:?})", + method, + url, + rule.pattern.as_str() + ); + return EvaluationResult::allow() + .with_context(format!("Matched pattern: {}", rule.pattern.as_str())); + } + Action::Deny => { + debug!( + "DENY: {} {} (matched: {:?})", + method, + url, + rule.pattern.as_str() + ); + return EvaluationResult::deny() + .with_context(format!("Matched pattern: {}", rule.pattern.as_str())); + } + } + } + } + + debug!("DENY: {} {} (no matching rules)", method, url); + EvaluationResult::deny().with_context("No matching rules".to_string()) + } + + fn name(&self) -> &str { + "pattern" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rule_matching() { + let rule = Rule::new(Action::Allow, r"github\.com").unwrap(); + assert!(rule.matches(Method::GET, "https://github.com/user/repo")); + assert!(rule.matches(Method::POST, "http://api.github.com/v3/repos")); + assert!(!rule.matches(Method::GET, "https://gitlab.com/user/repo")); + } + + #[test] + fn test_rule_with_methods() { + let rule = Rule::new(Action::Allow, r"api\.example\.com") + .unwrap() + .with_methods(vec![Method::GET, Method::HEAD]); + + assert!(rule.matches(Method::GET, "https://api.example.com/users")); + assert!(rule.matches(Method::HEAD, "https://api.example.com/users")); + assert!(!rule.matches(Method::POST, "https://api.example.com/users")); + assert!(!rule.matches(Method::DELETE, "https://api.example.com/users")); + } + + #[tokio::test] + async fn test_pattern_engine() { + let rules = vec![ + Rule::new(Action::Allow, r"github\.com").unwrap(), + Rule::new(Action::Deny, r"telemetry").unwrap(), + Rule::new(Action::Deny, r".*").unwrap(), + ]; + + let engine = PatternRuleEngine::new(rules); + + assert!(matches!( + engine + .evaluate(Method::GET, "https://github.com/api") + .await + .action, + Action::Allow + )); + + assert!(matches!( + engine + .evaluate(Method::POST, "https://telemetry.example.com") + .await + .action, + Action::Deny + )); + + assert!(matches!( + engine + .evaluate(Method::GET, "https://example.com") + .await + .action, + Action::Deny + )); + } + + #[tokio::test] + async fn test_method_specific_rules() { + let rules = vec![ + Rule::new(Action::Allow, r"api\.example\.com") + .unwrap() + .with_methods(vec![Method::GET]), + Rule::new(Action::Deny, r".*").unwrap(), + ]; + + let engine = PatternRuleEngine::new(rules); + + assert!(matches!( + engine + .evaluate(Method::GET, "https://api.example.com/data") + .await + .action, + Action::Allow + )); + + assert!(matches!( + engine + .evaluate(Method::POST, "https://api.example.com/data") + .await + .action, + Action::Deny + )); + } + + #[tokio::test] + async fn test_default_deny_with_no_rules() { + let engine = PatternRuleEngine::new(vec![]); + + assert!(matches!( + engine + .evaluate(Method::GET, "https://example.com") + .await + .action, + Action::Deny + )); + } +} diff --git a/src/rules/script.rs b/src/rules/script.rs new file mode 100644 index 00000000..c2dea76f --- /dev/null +++ b/src/rules/script.rs @@ -0,0 +1,280 @@ +use super::{EvaluationResult, RuleEngineTrait}; +use async_trait::async_trait; +use hyper::Method; +use std::time::Duration; +use tracing::debug; +use url::Url; + +#[derive(Clone)] +pub struct ScriptRuleEngine { + script: String, +} + +impl ScriptRuleEngine { + pub fn new(script: String) -> Self { + ScriptRuleEngine { script } + } + + async fn execute_script(&self, method: Method, url: &str) -> (bool, String) { + let parsed_url = match Url::parse(url) { + Ok(u) => u, + Err(e) => { + debug!("Failed to parse URL '{}': {}", url, e); + return (false, format!("Failed to parse URL: {}", e)); + } + }; + + let scheme = parsed_url.scheme(); + let host = parsed_url.host_str().unwrap_or(""); + let path = parsed_url.path(); + + debug!( + "Executing script for {} {} (host: {}, path: {})", + method, url, host, path + ); + + // Build the command + let mut cmd = if self.script.contains(' ') { + let shell = std::env::var("SHELL").unwrap_or_else(|_| "/bin/sh".to_string()); + let mut cmd = tokio::process::Command::new(&shell); + cmd.arg("-c").arg(&self.script); + cmd + } else { + tokio::process::Command::new(&self.script) + }; + + cmd.env("HTTPJAIL_URL", url) + .env("HTTPJAIL_METHOD", method.as_str()) + .env("HTTPJAIL_SCHEME", scheme) + .env("HTTPJAIL_HOST", host) + .env("HTTPJAIL_PATH", path) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true); // Ensure child is killed if dropped + + // Spawn the child process + let child = match cmd.spawn() { + Ok(child) => child, + Err(e) => { + debug!("Failed to spawn script: {}", e); + return (false, format!("Script execution failed: {}", e)); + } + }; + + // Wait for completion with timeout + let timeout = Duration::from_secs(5); + match tokio::time::timeout(timeout, child.wait_with_output()).await { + Ok(Ok(output)) => { + let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); + + if !stderr.is_empty() { + debug!("Script stderr: {}", stderr); + } + + let allowed = output.status.success(); + + debug!( + "Script returned {} for {} {} (exit code: {:?})", + if allowed { "ALLOW" } else { "DENY" }, + method, + url, + output.status.code() + ); + + (allowed, stdout) + } + Ok(Err(e)) => { + debug!("Error waiting for script: {}", e); + (false, format!("Script execution error: {}", e)) + } + Err(_) => { + // Timeout elapsed - process will be killed automatically due to kill_on_drop + debug!("Script execution timed out after {:?}", timeout); + (false, "Script execution timed out".to_string()) + } + } + } +} + +#[async_trait] +impl RuleEngineTrait for ScriptRuleEngine { + async fn evaluate(&self, method: Method, url: &str) -> EvaluationResult { + let (allowed, context) = self.execute_script(method.clone(), url).await; + + if allowed { + debug!("ALLOW: {} {} (script allowed)", method, url); + let mut result = EvaluationResult::allow(); + if !context.is_empty() { + result = result.with_context(context); + } + result + } else { + debug!("DENY: {} {} (script denied)", method, url); + let mut result = EvaluationResult::deny(); + if !context.is_empty() { + result = result.with_context(context); + } + result + } + } + + fn name(&self) -> &str { + "script" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::rules::Action; + use std::fs; + use tempfile::NamedTempFile; + + #[tokio::test] + async fn test_script_allow() { + let mut script_file = NamedTempFile::new().unwrap(); + let script = r#"#!/bin/sh +exit 0 +"#; + use std::io::Write; + script_file.write_all(script.as_bytes()).unwrap(); + script_file.flush().unwrap(); + + let script_path = script_file.into_temp_path(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + } + + let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); + let result = engine + .evaluate(Method::GET, "https://example.com/test") + .await; + + assert!(matches!(result.action, Action::Allow)); + drop(script_path); + } + + #[tokio::test] + async fn test_script_deny() { + let mut script_file = NamedTempFile::new().unwrap(); + let script = r#"#!/bin/sh +exit 1 +"#; + use std::io::Write; + script_file.write_all(script.as_bytes()).unwrap(); + script_file.flush().unwrap(); + + let script_path = script_file.into_temp_path(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + } + + let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); + let result = engine + .evaluate(Method::GET, "https://example.com/test") + .await; + + assert!(matches!(result.action, Action::Deny)); + drop(script_path); + } + + #[tokio::test] + async fn test_script_with_context() { + let mut script_file = NamedTempFile::new().unwrap(); + let script = r#"#!/bin/sh +echo "Blocked by policy" +exit 1 +"#; + use std::io::Write; + script_file.write_all(script.as_bytes()).unwrap(); + script_file.flush().unwrap(); + + let script_path = script_file.into_temp_path(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + } + + let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); + let result = engine + .evaluate(Method::GET, "https://example.com/test") + .await; + + assert!(matches!(result.action, Action::Deny)); + assert_eq!(result.context, Some("Blocked by policy".to_string())); + drop(script_path); + } + + #[tokio::test] + async fn test_script_environment_variables() { + let mut script_file = NamedTempFile::new().unwrap(); + let script = r#"#!/bin/sh +if [ "$HTTPJAIL_HOST" = "allowed.com" ]; then + exit 0 +else + echo "Host $HTTPJAIL_HOST not allowed" + exit 1 +fi +"#; + use std::io::Write; + script_file.write_all(script.as_bytes()).unwrap(); + script_file.flush().unwrap(); + + let script_path = script_file.into_temp_path(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + } + + let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); + + let result = engine + .evaluate(Method::GET, "https://allowed.com/test") + .await; + assert!(matches!(result.action, Action::Allow)); + + let result = engine + .evaluate(Method::GET, "https://blocked.com/test") + .await; + assert!(matches!(result.action, Action::Deny)); + assert_eq!( + result.context, + Some("Host blocked.com not allowed".to_string()) + ); + drop(script_path); + } + + #[tokio::test] + async fn test_inline_script() { + let engine = ScriptRuleEngine::new("test \"$HTTPJAIL_HOST\" = \"github.com\"".to_string()); + + let result = engine + .evaluate(Method::GET, "https://github.com/test") + .await; + assert!(matches!(result.action, Action::Allow)); + + let result = engine + .evaluate(Method::GET, "https://example.com/test") + .await; + assert!(matches!(result.action, Action::Deny)); + } +} diff --git a/tests/script_integration.rs b/tests/script_integration.rs new file mode 100644 index 00000000..ac21e73d --- /dev/null +++ b/tests/script_integration.rs @@ -0,0 +1,188 @@ +use httpjail::rules::script::ScriptRuleEngine; +use httpjail::rules::{Action, RuleEngineTrait}; +use hyper::Method; +use std::fs; +use tempfile::NamedTempFile; + +#[tokio::test] +async fn test_script_allows_github() { + let mut script_file = NamedTempFile::new().unwrap(); + let script = r#"#!/bin/sh +if [ "$HTTPJAIL_HOST" = "github.com" ]; then + exit 0 +else + echo "Only github.com is allowed" + exit 1 +fi +"#; + use std::io::Write; + script_file.write_all(script.as_bytes()).unwrap(); + script_file.flush().unwrap(); + + // Convert to TempPath to close file handle (fixes "Text file busy" on Linux) + let script_path = script_file.into_temp_path(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + } + + let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); + + // Test allowed request + let result = engine + .evaluate(Method::GET, "https://github.com/user/repo") + .await; + assert!(matches!(result.action, Action::Allow)); + + // Test denied request with context + let result = engine + .evaluate(Method::POST, "https://example.com/api") + .await; + assert!(matches!(result.action, Action::Deny)); + assert_eq!( + result.context, + Some("Only github.com is allowed".to_string()) + ); + + // TempPath will be automatically deleted when it goes out of scope + drop(script_path); +} + +#[tokio::test] +async fn test_script_with_method_filtering() { + let mut script_file = NamedTempFile::new().unwrap(); + let script = r#"#!/bin/sh +if [ "$HTTPJAIL_METHOD" = "GET" ] || [ "$HTTPJAIL_METHOD" = "HEAD" ]; then + exit 0 +else + echo "Method $HTTPJAIL_METHOD not allowed" + exit 1 +fi +"#; + use std::io::Write; + script_file.write_all(script.as_bytes()).unwrap(); + script_file.flush().unwrap(); + + // Convert to TempPath to close file handle (fixes "Text file busy" on Linux) + let script_path = script_file.into_temp_path(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + } + + let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); + + // Test allowed methods + let result = engine + .evaluate(Method::GET, "https://example.com/api") + .await; + assert!(matches!(result.action, Action::Allow)); + + let result = engine + .evaluate(Method::HEAD, "https://example.com/api") + .await; + assert!(matches!(result.action, Action::Allow)); + + // Test denied method with context + let result = engine + .evaluate(Method::POST, "https://example.com/api") + .await; + assert!(matches!(result.action, Action::Deny)); + assert_eq!(result.context, Some("Method POST not allowed".to_string())); + + // TempPath will be automatically deleted when it goes out of scope + drop(script_path); +} + +#[tokio::test] +async fn test_inline_script_evaluation() { + // Test inline script (with spaces, executed via shell) + let engine = ScriptRuleEngine::new( + r#"[ "$HTTPJAIL_PATH" = "/api/v1/health" ] && exit 0 || exit 1"#.to_string(), + ); + + let result = engine + .evaluate(Method::GET, "https://example.com/api/v1/health") + .await; + assert!(matches!(result.action, Action::Allow)); + + let result = engine + .evaluate(Method::GET, "https://example.com/api/v2/users") + .await; + assert!(matches!(result.action, Action::Deny)); +} + +#[tokio::test] +async fn test_script_with_complex_logic() { + let mut script_file = NamedTempFile::new().unwrap(); + let script = r#"#!/bin/sh +# Complex logic: allow GET to github.com, POST to api.example.com, deny everything else + +if [ "$HTTPJAIL_METHOD" = "GET" ] && [ "$HTTPJAIL_HOST" = "github.com" ]; then + echo "GitHub read access allowed" + exit 0 +elif [ "$HTTPJAIL_METHOD" = "POST" ] && [ "$HTTPJAIL_HOST" = "api.example.com" ]; then + echo "API write access allowed" + exit 0 +else + echo "Request blocked by security policy: $HTTPJAIL_METHOD to $HTTPJAIL_HOST" + exit 1 +fi +"#; + use std::io::Write; + script_file.write_all(script.as_bytes()).unwrap(); + script_file.flush().unwrap(); + + // Convert to TempPath to close file handle (fixes "Text file busy" on Linux) + let script_path = script_file.into_temp_path(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + } + + let engine = ScriptRuleEngine::new(script_path.to_str().unwrap().to_string()); + + // Test allowed GitHub GET + let result = engine + .evaluate(Method::GET, "https://github.com/user/repo") + .await; + assert!(matches!(result.action, Action::Allow)); + assert_eq!( + result.context, + Some("GitHub read access allowed".to_string()) + ); + + // Test allowed API POST + let result = engine + .evaluate(Method::POST, "https://api.example.com/users") + .await; + assert!(matches!(result.action, Action::Allow)); + assert_eq!(result.context, Some("API write access allowed".to_string())); + + // Test denied request + let result = engine + .evaluate(Method::POST, "https://github.com/user/repo") + .await; + assert!(matches!(result.action, Action::Deny)); + assert!( + result + .context + .unwrap() + .contains("Request blocked by security policy") + ); + + // TempPath will be automatically deleted when it goes out of scope + drop(script_path); +}