diff --git a/src/main.rs b/src/main.rs index 29a01ce8..23619e21 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,9 +2,10 @@ use anyhow::{Context, Result}; use clap::Parser; use httpjail::jail::{JailConfig, create_jail}; use httpjail::proxy::ProxyServer; -use httpjail::rules::RuleEngine; use httpjail::rules::script::ScriptRuleEngine; use httpjail::rules::v8_js::V8JsRuleEngine; +use httpjail::rules::{Action, RuleEngine}; +use hyper::Method; use std::fs::OpenOptions; use std::os::unix::process::ExitStatusExt; use std::sync::atomic::{AtomicBool, Ordering}; @@ -79,8 +80,17 @@ struct Args { )] server: bool, + /// Evaluate rule against a URL and exit (dry-run) + #[arg( + long = "test", + value_name = "[METHOD] URL", + conflicts_with = "server", + conflicts_with = "cleanup" + )] + test: Option, + /// Command and arguments to execute - #[arg(trailing_var_arg = true, required_unless_present_any = ["cleanup", "server"])] + #[arg(trailing_var_arg = true, required_unless_present_any = ["cleanup", "server", "test"])] command: Vec, } @@ -323,6 +333,56 @@ async fn main() -> Result<()> { RuleEngine::from_trait(js_engine, request_log) }; + // Handle test (dry-run) mode: evaluate the rule against a URL and exit + if let Some(test_arg) = &args.test { + // Parse the test argument: if it contains two words, the first is the method + let (method, url) = if let Some(space_pos) = test_arg.find(' ') { + let method_str = &test_arg[..space_pos]; + let url = &test_arg[space_pos + 1..].trim(); + + // Parse the method string + let method = match method_str.to_uppercase().as_str() { + "GET" => Method::GET, + "POST" => Method::POST, + "PUT" => Method::PUT, + "DELETE" => Method::DELETE, + "HEAD" => Method::HEAD, + "OPTIONS" => Method::OPTIONS, + "CONNECT" => Method::CONNECT, + "PATCH" => Method::PATCH, + "TRACE" => Method::TRACE, + _ => { + eprintln!("Invalid HTTP method: {}", method_str); + std::process::exit(1); + } + }; + (method, url.to_string()) + } else { + // Single word: default to GET + (Method::GET, test_arg.clone()) + }; + + let eval = rule_engine + .evaluate_with_context(method.clone(), &url) + .await; + match eval.action { + Action::Allow => { + println!("ALLOW {} {}", method, url); + if let Some(ctx) = eval.context { + println!("{}", ctx); + } + std::process::exit(0); + } + Action::Deny => { + println!("DENY {} {}", method, url); + if let Some(ctx) = eval.context { + println!("{}", ctx); + } + std::process::exit(1); + } + } + } + // Parse bind configuration from env vars // Supports both "port" and "ip:port" formats fn parse_bind_config(env_var: &str) -> (Option, Option) { diff --git a/tests/test_flag.rs b/tests/test_flag.rs new file mode 100644 index 00000000..65385093 --- /dev/null +++ b/tests/test_flag.rs @@ -0,0 +1,76 @@ +use assert_cmd::prelude::*; +use predicates::prelude::*; +use std::process::Command; + +#[test] +fn test_httpjail_test_flag_allow() { + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("true") + .arg("--test") + .arg("https://example.com"); + cmd.assert() + .success() + .stdout(predicate::str::contains("ALLOW GET https://example.com")); +} + +#[test] +fn test_httpjail_test_flag_deny() { + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("false") + .arg("--test") + .arg("https://example.com"); + cmd.assert() + .failure() + .stdout(predicate::str::contains("DENY GET https://example.com")); +} + +#[test] +fn test_httpjail_test_flag_with_post_method() { + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("r.method === 'POST'") + .arg("--test") + .arg("POST https://example.com/api"); + cmd.assert().success().stdout(predicate::str::contains( + "ALLOW POST https://example.com/api", + )); +} + +#[test] +fn test_httpjail_test_flag_with_delete_method() { + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("r.method === 'DELETE'") + .arg("--test") + .arg("DELETE https://example.com/resource"); + cmd.assert().success().stdout(predicate::str::contains( + "ALLOW DELETE https://example.com/resource", + )); +} + +#[test] +fn test_httpjail_test_flag_with_method_deny() { + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("r.method === 'GET'") + .arg("--test") + .arg("POST https://example.com"); + cmd.assert() + .failure() + .stdout(predicate::str::contains("DENY POST https://example.com")); +} + +#[test] +fn test_httpjail_test_flag_default_get() { + // When no method is specified, it should default to GET + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("r.method === 'GET'") + .arg("--test") + .arg("https://example.com"); + cmd.assert() + .success() + .stdout(predicate::str::contains("ALLOW GET https://example.com")); +}