diff --git a/src/main.rs b/src/main.rs index f139ad12..c2affe74 100644 --- a/src/main.rs +++ b/src/main.rs @@ -84,10 +84,11 @@ struct Args { #[arg( long = "test", value_name = "[METHOD] URL", + num_args = 1..=2, conflicts_with = "server", conflicts_with = "cleanup" )] - test: Option, + test: Option>, /// Command and arguments to execute #[arg(trailing_var_arg = true, required_unless_present_any = ["cleanup", "server", "test"])] @@ -337,32 +338,28 @@ async fn main() -> Result<()> { }; // Handle test (dry-run) mode: evaluate the rule against a URL and exit - if let Some(test_arg) = &args.test { - // Parse the test argument: if it contains two words, the first is the method - let (method, url) = if let Some(space_pos) = test_arg.find(' ') { - let method_str = &test_arg[..space_pos]; - let url = &test_arg[space_pos + 1..].trim(); - - // Parse the method string - let method = match method_str.to_uppercase().as_str() { - "GET" => Method::GET, - "POST" => Method::POST, - "PUT" => Method::PUT, - "DELETE" => Method::DELETE, - "HEAD" => Method::HEAD, - "OPTIONS" => Method::OPTIONS, - "CONNECT" => Method::CONNECT, - "PATCH" => Method::PATCH, - "TRACE" => Method::TRACE, - _ => { - eprintln!("Invalid HTTP method: {}", method_str); - std::process::exit(1); + if let Some(test_vals) = &args.test { + let (method, url) = if test_vals.len() == 1 { + let s = &test_vals[0]; + let mut parts = s.split_whitespace(); + match (parts.next(), parts.next()) { + (Some(maybe_method), Some(url_rest)) => { + let method = maybe_method + .parse::() + .or_else(|_| maybe_method.to_ascii_uppercase().parse::()) + .unwrap_or(Method::GET); + (method, url_rest.to_string()) } - }; - (method, url.to_string()) + _ => (Method::GET, s.clone()), + } } else { - // Single word: default to GET - (Method::GET, test_arg.clone()) + let maybe_method = &test_vals[0]; + let url = &test_vals[1]; + let method = maybe_method + .parse::() + .or_else(|_| maybe_method.to_ascii_uppercase().parse::()) + .unwrap_or(Method::GET); + (method, url.clone()) }; let eval = rule_engine @@ -478,12 +475,12 @@ async fn main() -> Result<()> { info!("Received interrupt signal, cleaning up..."); shutdown_clone.store(true, Ordering::SeqCst); - // Cleanup jail unless testing flag is set + // Attempt cleanup only if no_cleanup is false if !no_cleanup && let Err(e) = jail_for_signal.cleanup() { warn!("Failed to cleanup jail on signal: {}", e); } - // Exit with signal termination status + // Always exit with signal termination status std::process::exit(130); // 128 + SIGINT(2) } }) diff --git a/tests/test_flag.rs b/tests/test_flag.rs index 65385093..22c8ee24 100644 --- a/tests/test_flag.rs +++ b/tests/test_flag.rs @@ -1,3 +1,5 @@ +// <--- Begin of necessary code edit + use assert_cmd::prelude::*; use predicates::prelude::*; use std::process::Command; @@ -74,3 +76,5 @@ fn test_httpjail_test_flag_default_get() { .success() .stdout(predicate::str::contains("ALLOW GET https://example.com")); } + +// <--- End of necessary code edit diff --git a/tests/test_flag_methods.rs b/tests/test_flag_methods.rs new file mode 100644 index 00000000..5c552b66 --- /dev/null +++ b/tests/test_flag_methods.rs @@ -0,0 +1,41 @@ +use assert_cmd::prelude::*; +use predicates::prelude::*; +use std::process::Command; + +#[test] +fn test_httpjail_test_flag_method_two_args() { + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("r.method === 'POST' && r.host === 'example.com'") + .arg("--test") + .arg("POST") + .arg("https://example.com"); + cmd.assert() + .success() + .stdout(predicate::str::contains("ALLOW POST https://example.com")); +} + +#[test] +fn test_httpjail_test_flag_method_one_arg_with_space() { + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("r.method === 'PUT' && r.host === 'example.com'") + .arg("--test") + .arg("PUT https://example.com"); + cmd.assert() + .success() + .stdout(predicate::str::contains("ALLOW PUT https://example.com")); +} + +#[test] +fn test_httpjail_test_flag_method_case_insensitive() { + let mut cmd = Command::cargo_bin("httpjail").unwrap(); + cmd.arg("--js") + .arg("r.method === 'DELETE' && r.host === 'example.com'") + .arg("--test") + .arg("delete") + .arg("https://example.com"); + cmd.assert() + .success() + .stdout(predicate::str::contains("ALLOW DELETE https://example.com")); +}