diff --git a/probing/cli/src/cli/commands.rs b/probing/cli/src/cli/commands.rs index 520593f..570eb54 100644 --- a/probing/cli/src/cli/commands.rs +++ b/probing/cli/src/cli/commands.rs @@ -171,3 +171,26 @@ pub enum Commands { #[command(subcommand = false, hide = true)] Store(StoreCommand), } + +impl Commands { + /// Determines whether this command should have a timeout applied. + /// Long-running commands like Launch and External should not time out. + pub fn is_timed_command(&self) -> bool { + match self { + // Long-running commands - no timeout + Commands::Repl => false, + Commands::Launch { .. } => false, + Commands::External(_) => false, + // Short-running commands - apply timeout + Commands::List { .. } => true, + Commands::Backtrace { .. } => true, + Commands::Rdma { .. } => true, + Commands::Eval { .. } => true, + Commands::Query { .. } => true, + Commands::Store(_) => true, + Commands::Config { .. } => true, + #[cfg(target_os = "linux")] + Commands::Inject(_) => true, + } + } +} diff --git a/probing/cli/src/cli/mod.rs b/probing/cli/src/cli/mod.rs index 26da682..5dfc7c2 100644 --- a/probing/cli/src/cli/mod.rs +++ b/probing/cli/src/cli/mod.rs @@ -56,6 +56,9 @@ pub struct Cli { } impl Cli { + pub fn should_timeout(&self) -> bool { + self.command.as_ref().map_or(true, |cmd| cmd.is_timed_command()) + } pub async fn run(&mut self) -> Result<()> { // Handle external commands first to avoid target requirement if let Some(Commands::External(args)) = &self.command { diff --git a/probing/cli/src/lib.rs b/probing/cli/src/lib.rs index 2662482..8f0dcf2 100644 --- a/probing/cli/src/lib.rs +++ b/probing/cli/src/lib.rs @@ -7,12 +7,23 @@ pub mod inject; use anyhow::Result; use clap::Parser; use env_logger::Env; +use std::time::Duration; +use tokio::time::timeout; const ENV_PROBING_LOGLEVEL: &str = "PROBING_LOGLEVEL"; /// Main entry point for the CLI, can be called from Python or as a binary -#[tokio::main] pub async fn cli_main(args: Vec) -> Result<()> { let _ = env_logger::try_init_from_env(Env::new().filter(ENV_PROBING_LOGLEVEL)); - cli::Cli::parse_from(args).run().await + + let mut cli = cli::Cli::parse_from(args); + + if cli.should_timeout() { + match timeout(Duration::from_secs(10), cli.run()).await { + Ok(result) => result, + Err(_) => Err(anyhow::anyhow!("Cli Command Timeout reached")), + } + } else { + cli.run().await + } } diff --git a/probing/cli/src/main.rs b/probing/cli/src/main.rs index 7ecc7db..9172ab7 100644 --- a/probing/cli/src/main.rs +++ b/probing/cli/src/main.rs @@ -1,8 +1,9 @@ use anyhow::Result; use probing_cli::cli_main; -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<()> { let args: Vec = std::env::args().collect(); // cli_main already uses #[tokio::main], so it handles async execution internally - cli_main(args) + cli_main(args).await } diff --git a/probing/extensions/python/src/features/python_api.rs b/probing/extensions/python/src/features/python_api.rs index 4555e88..b089fe4 100644 --- a/probing/extensions/python/src/features/python_api.rs +++ b/probing/extensions/python/src/features/python_api.rs @@ -42,7 +42,12 @@ pub fn query_json(_py: Python, sql: String) -> PyResult { #[pyfunction] pub fn cli_main(_py: Python, args: Vec) -> PyResult<()> { - if let Err(e) = cli_main_impl(args) { + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .map_err(|e| PyErr::new::(e.to_string()))?; + + if let Err(e) = runtime.block_on(cli_main_impl(args)) { return Err(PyErr::new::( e.to_string(), ));