Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 24 additions & 27 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ struct Args {
#[arg(
long = "test",
value_name = "[METHOD] URL",
num_args = 1..=2,
conflicts_with = "server",
conflicts_with = "cleanup"
)]
test: Option<String>,
test: Option<Vec<String>>,

/// Command and arguments to execute
#[arg(trailing_var_arg = true, required_unless_present_any = ["cleanup", "server", "test"])]
Expand Down Expand Up @@ -337,32 +338,28 @@ async fn main() -> Result<()> {
};

// Handle test (dry-run) mode: evaluate the rule against a URL and exit
if let Some(test_arg) = &args.test {
// Parse the test argument: if it contains two words, the first is the method
let (method, url) = if let Some(space_pos) = test_arg.find(' ') {
let method_str = &test_arg[..space_pos];
let url = &test_arg[space_pos + 1..].trim();

// Parse the method string
let method = match method_str.to_uppercase().as_str() {
"GET" => Method::GET,
"POST" => Method::POST,
"PUT" => Method::PUT,
"DELETE" => Method::DELETE,
"HEAD" => Method::HEAD,
"OPTIONS" => Method::OPTIONS,
"CONNECT" => Method::CONNECT,
"PATCH" => Method::PATCH,
"TRACE" => Method::TRACE,
_ => {
eprintln!("Invalid HTTP method: {}", method_str);
std::process::exit(1);
if let Some(test_vals) = &args.test {
let (method, url) = if test_vals.len() == 1 {
let s = &test_vals[0];
let mut parts = s.split_whitespace();
match (parts.next(), parts.next()) {
(Some(maybe_method), Some(url_rest)) => {
let method = maybe_method
.parse::<Method>()
.or_else(|_| maybe_method.to_ascii_uppercase().parse::<Method>())
.unwrap_or(Method::GET);
(method, url_rest.to_string())
}
};
(method, url.to_string())
_ => (Method::GET, s.clone()),
}
} else {
// Single word: default to GET
(Method::GET, test_arg.clone())
let maybe_method = &test_vals[0];
let url = &test_vals[1];
let method = maybe_method
.parse::<Method>()
.or_else(|_| maybe_method.to_ascii_uppercase().parse::<Method>())
.unwrap_or(Method::GET);
(method, url.clone())
};

let eval = rule_engine
Expand Down Expand Up @@ -478,12 +475,12 @@ async fn main() -> Result<()> {
info!("Received interrupt signal, cleaning up...");
shutdown_clone.store(true, Ordering::SeqCst);

// Cleanup jail unless testing flag is set
// Attempt cleanup only if no_cleanup is false
if !no_cleanup && let Err(e) = jail_for_signal.cleanup() {
warn!("Failed to cleanup jail on signal: {}", e);
}

// Exit with signal termination status
// Always exit with signal termination status
std::process::exit(130); // 128 + SIGINT(2)
}
})
Expand Down
4 changes: 4 additions & 0 deletions tests/test_flag.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// <--- Begin of necessary code edit

use assert_cmd::prelude::*;
use predicates::prelude::*;
use std::process::Command;
Expand Down Expand Up @@ -74,3 +76,5 @@ fn test_httpjail_test_flag_default_get() {
.success()
.stdout(predicate::str::contains("ALLOW GET https://example.com"));
}

// <--- End of necessary code edit
41 changes: 41 additions & 0 deletions tests/test_flag_methods.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use assert_cmd::prelude::*;
use predicates::prelude::*;
use std::process::Command;

#[test]
fn test_httpjail_test_flag_method_two_args() {
let mut cmd = Command::cargo_bin("httpjail").unwrap();
cmd.arg("--js")
.arg("r.method === 'POST' && r.host === 'example.com'")
.arg("--test")
.arg("POST")
.arg("https://example.com");
cmd.assert()
.success()
.stdout(predicate::str::contains("ALLOW POST https://example.com"));
}

#[test]
fn test_httpjail_test_flag_method_one_arg_with_space() {
let mut cmd = Command::cargo_bin("httpjail").unwrap();
cmd.arg("--js")
.arg("r.method === 'PUT' && r.host === 'example.com'")
.arg("--test")
.arg("PUT https://example.com");
cmd.assert()
.success()
.stdout(predicate::str::contains("ALLOW PUT https://example.com"));
}

#[test]
fn test_httpjail_test_flag_method_case_insensitive() {
let mut cmd = Command::cargo_bin("httpjail").unwrap();
cmd.arg("--js")
.arg("r.method === 'DELETE' && r.host === 'example.com'")
.arg("--test")
.arg("delete")
.arg("https://example.com");
cmd.assert()
.success()
.stdout(predicate::str::contains("ALLOW DELETE https://example.com"));
}
Loading