diff --git a/crates/forge_app/src/agent_executor.rs b/crates/forge_app/src/agent_executor.rs index fe92b7c7d4..a2f990c299 100644 --- a/crates/forge_app/src/agent_executor.rs +++ b/crates/forge_app/src/agent_executor.rs @@ -111,6 +111,8 @@ impl> AgentEx ChatResponse::ToolCallStart { .. } => ctx.send(message).await?, ChatResponse::ToolCallEnd(_) => ctx.send(message).await?, ChatResponse::RetryAttempt { .. } => ctx.send(message).await?, + ChatResponse::HookError { .. } => ctx.send(message).await?, + ChatResponse::HookWarning { .. } => ctx.send(message).await?, ChatResponse::Interrupt { reason } => { return Err(Error::AgentToolInterrupted(reason)) .context(format!( diff --git a/crates/forge_app/src/app.rs b/crates/forge_app/src/app.rs index c8fec71741..0cc8aa1bea 100644 --- a/crates/forge_app/src/app.rs +++ b/crates/forge_app/src/app.rs @@ -11,11 +11,13 @@ use crate::changed_files::ChangedFiles; use crate::dto::ToolsOverview; use crate::hooks::{ CompactionHandler, DoomLoopDetector, PendingTodosHandler, TitleGenerationHandler, - TracingHandler, + TracingHandler, UserHookHandler, }; use crate::init_conversation_metrics::InitConversationMetrics; use crate::orch::Orchestrator; -use crate::services::{AgentRegistry, CustomInstructionsService, ProviderAuthService}; +use crate::services::{ + AgentRegistry, CustomInstructionsService, ProviderAuthService, UserHookConfigService, +}; use crate::set_conversation_id::SetConversationId; use crate::system_prompt::SystemPrompt; use crate::tool_registry::ToolRegistry; @@ -157,7 +159,7 @@ impl> ForgeAp tracing_handler.clone().and(title_handler.clone()) }; - let hook = Hook::default() + let internal_hook = Hook::default() .on_start(tracing_handler.clone().and(title_handler)) .on_request(tracing_handler.clone().and(DoomLoopDetector::default())) .on_response( @@ -169,6 +171,29 @@ impl> ForgeAp .on_toolcall_end(tracing_handler) .on_end(on_end_hook); + // Load user-configurable hooks from settings files + let user_hook_config = services.get_user_hook_config().await?; + + let hook = if !user_hook_config.is_empty() { + let user_handler = UserHookHandler::new( + services.hook_command_service().clone(), + services.get_env_vars(), + user_hook_config, + environment.cwd.clone(), + conversation.id.to_string(), + ); + let user_hook = Hook::default() + .on_start(user_handler.clone()) + .on_request(user_handler.clone()) + .on_response(user_handler.clone()) + .on_toolcall_start(user_handler.clone()) + .on_toolcall_end(user_handler.clone()) + .on_end(user_handler); + internal_hook.zip(user_hook) + } else { + internal_hook + }; + let orch = Orchestrator::new( services.clone(), conversation, diff --git a/crates/forge_app/src/hooks/compaction.rs b/crates/forge_app/src/hooks/compaction.rs index 76e58df83d..2c2eef848b 100644 --- a/crates/forge_app/src/hooks/compaction.rs +++ b/crates/forge_app/src/hooks/compaction.rs @@ -31,7 +31,7 @@ impl CompactionHandler { impl EventHandle> for CompactionHandler { async fn handle( &self, - _event: &EventData, + _event: &mut EventData, conversation: &mut Conversation, ) -> anyhow::Result<()> { if let Some(context) = &conversation.context { diff --git a/crates/forge_app/src/hooks/doom_loop.rs b/crates/forge_app/src/hooks/doom_loop.rs index 3515b74e7b..3ec3c667cc 100644 --- a/crates/forge_app/src/hooks/doom_loop.rs +++ b/crates/forge_app/src/hooks/doom_loop.rs @@ -222,7 +222,7 @@ impl DoomLoopDetector { impl EventHandle> for DoomLoopDetector { async fn handle( &self, - event: &EventData, + event: &mut EventData, conversation: &mut Conversation, ) -> anyhow::Result<()> { if let Some(consecutive_calls) = self.detect_from_conversation(conversation) { diff --git a/crates/forge_app/src/hooks/mod.rs b/crates/forge_app/src/hooks/mod.rs index 26a43401f2..9dc2c77753 100644 --- a/crates/forge_app/src/hooks/mod.rs +++ b/crates/forge_app/src/hooks/mod.rs @@ -3,9 +3,12 @@ mod doom_loop; mod pending_todos; mod title_generation; mod tracing; +mod user_hook_executor; +mod user_hook_handler; pub use compaction::CompactionHandler; pub use doom_loop::DoomLoopDetector; pub use pending_todos::PendingTodosHandler; pub use title_generation::TitleGenerationHandler; pub use tracing::TracingHandler; +pub use user_hook_handler::UserHookHandler; diff --git a/crates/forge_app/src/hooks/pending_todos.rs b/crates/forge_app/src/hooks/pending_todos.rs index bad2b44fa6..52f43370c8 100644 --- a/crates/forge_app/src/hooks/pending_todos.rs +++ b/crates/forge_app/src/hooks/pending_todos.rs @@ -42,7 +42,7 @@ impl PendingTodosHandler { impl EventHandle> for PendingTodosHandler { async fn handle( &self, - _event: &EventData, + _event: &mut EventData, conversation: &mut Conversation, ) -> anyhow::Result<()> { let pending_todos = conversation.metrics.get_active_todos(); @@ -157,17 +157,21 @@ mod tests { } fn fixture_event() -> EventData { - EventData::new(fixture_agent(), ModelId::new("test-model"), EndPayload) + EventData::new( + fixture_agent(), + ModelId::new("test-model"), + EndPayload::default(), + ) } #[tokio::test] async fn test_no_pending_todos_does_nothing() { let handler = PendingTodosHandler::new(); - let event = fixture_event(); + let mut event = fixture_event(); let mut conversation = fixture_conversation(vec![]); let initial_msg_count = conversation.context.as_ref().unwrap().messages.len(); - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); let actual = conversation.context.as_ref().unwrap().messages.len(); let expected = initial_msg_count; @@ -177,13 +181,13 @@ mod tests { #[tokio::test] async fn test_pending_todos_injects_reminder() { let handler = PendingTodosHandler::new(); - let event = fixture_event(); + let mut event = fixture_event(); let mut conversation = fixture_conversation(vec![ Todo::new("Fix the build").status(TodoStatus::Pending), Todo::new("Write tests").status(TodoStatus::InProgress), ]); - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); let actual = conversation.context.as_ref().unwrap().messages.len(); let expected = 1; @@ -193,13 +197,13 @@ mod tests { #[tokio::test] async fn test_reminder_contains_formatted_list() { let handler = PendingTodosHandler::new(); - let event = fixture_event(); + let mut event = fixture_event(); let mut conversation = fixture_conversation(vec![ Todo::new("Fix the build").status(TodoStatus::Pending), Todo::new("Write tests").status(TodoStatus::InProgress), ]); - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); let entry = &conversation.context.as_ref().unwrap().messages[0]; let actual = entry.message.content().unwrap(); @@ -210,14 +214,14 @@ mod tests { #[tokio::test] async fn test_completed_todos_not_included() { let handler = PendingTodosHandler::new(); - let event = fixture_event(); + let mut event = fixture_event(); let mut conversation = fixture_conversation(vec![ Todo::new("Completed task").status(TodoStatus::Completed), Todo::new("Cancelled task").status(TodoStatus::Cancelled), ]); let initial_msg_count = conversation.context.as_ref().unwrap().messages.len(); - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); let actual = conversation.context.as_ref().unwrap().messages.len(); let expected = initial_msg_count; @@ -227,17 +231,17 @@ mod tests { #[tokio::test] async fn test_reminder_not_duplicated_for_same_todos() { let handler = PendingTodosHandler::new(); - let event = fixture_event(); + let mut event = fixture_event(); let mut conversation = fixture_conversation(vec![Todo::new("Fix the build").status(TodoStatus::Pending)]); // First call should inject a reminder - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); let after_first = conversation.context.as_ref().unwrap().messages.len(); assert_eq!(after_first, 1); // Second call with the same pending todos should NOT add another reminder - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); let after_second = conversation.context.as_ref().unwrap().messages.len(); assert_eq!(after_second, 1); // Still 1, no duplicate } @@ -245,14 +249,14 @@ mod tests { #[tokio::test] async fn test_reminder_added_when_todos_change() { let handler = PendingTodosHandler::new(); - let event = fixture_event(); + let mut event = fixture_event(); let mut conversation = fixture_conversation(vec![ Todo::new("Fix the build").status(TodoStatus::Pending), Todo::new("Write tests").status(TodoStatus::InProgress), ]); // First call should inject a reminder - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); let after_first = conversation.context.as_ref().unwrap().messages.len(); assert_eq!(after_first, 1); @@ -265,7 +269,7 @@ mod tests { ]); // Second call with different pending todos should add a new reminder - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); let after_second = conversation.context.as_ref().unwrap().messages.len(); assert_eq!(after_second, 2); // New reminder added because todos changed } diff --git a/crates/forge_app/src/hooks/title_generation.rs b/crates/forge_app/src/hooks/title_generation.rs index 262b35bdcf..8f9f03e026 100644 --- a/crates/forge_app/src/hooks/title_generation.rs +++ b/crates/forge_app/src/hooks/title_generation.rs @@ -35,7 +35,7 @@ impl TitleGenerationHandler { impl EventHandle> for TitleGenerationHandler { async fn handle( &self, - event: &EventData, + event: &mut EventData, conversation: &mut Conversation, ) -> anyhow::Result<()> { if conversation.title.is_some() { @@ -85,7 +85,7 @@ impl EventHandle> for TitleGenerationHa impl EventHandle> for TitleGenerationHandler { async fn handle( &self, - _event: &EventData, + _event: &mut EventData, conversation: &mut Conversation, ) -> anyhow::Result<()> { if let Some((_, entry)) = self.title_tasks.remove(&conversation.id) { @@ -176,7 +176,7 @@ mod tests { conversation.title = Some("existing".into()); handler - .handle(&event(StartPayload), &mut conversation) + .handle(&mut event(StartPayload), &mut conversation) .await .unwrap(); @@ -195,7 +195,7 @@ mod tests { .insert(conversation.id, TitleGenerationState { rx, handle }); handler - .handle(&event(StartPayload), &mut conversation) + .handle(&mut event(StartPayload), &mut conversation) .await .unwrap(); @@ -215,7 +215,7 @@ mod tests { .insert(conversation.id, TitleGenerationState { rx, handle }); handler - .handle(&event(EndPayload), &mut conversation) + .handle(&mut event(EndPayload::default()), &mut conversation) .await .unwrap(); @@ -237,7 +237,7 @@ mod tests { .insert(conversation.id, TitleGenerationState { rx, handle }); handler - .handle(&event(EndPayload), &mut conversation) + .handle(&mut event(EndPayload::default()), &mut conversation) .await .unwrap(); @@ -262,7 +262,7 @@ mod tests { .insert(conversation.id, TitleGenerationState { rx, handle }); handler - .handle(&event(EndPayload), &mut conversation) + .handle(&mut event(EndPayload::default()), &mut conversation) .await .unwrap(); @@ -290,7 +290,7 @@ mod tests { joins.push(tokio::spawn(async move { barrier.wait().await; handler - .handle(&event(StartPayload), &mut conv) + .handle(&mut event(StartPayload), &mut conv) .await .unwrap(); })); diff --git a/crates/forge_app/src/hooks/tracing.rs b/crates/forge_app/src/hooks/tracing.rs index 94755f2b2c..e533cedc8e 100644 --- a/crates/forge_app/src/hooks/tracing.rs +++ b/crates/forge_app/src/hooks/tracing.rs @@ -28,7 +28,7 @@ impl TracingHandler { impl EventHandle> for TracingHandler { async fn handle( &self, - event: &EventData, + event: &mut EventData, conversation: &mut Conversation, ) -> anyhow::Result<()> { debug!( @@ -46,7 +46,7 @@ impl EventHandle> for TracingHandler { impl EventHandle> for TracingHandler { async fn handle( &self, - _event: &EventData, + _event: &mut EventData, _conversation: &mut Conversation, ) -> anyhow::Result<()> { // Request events are logged but don't need specific logging per request @@ -59,7 +59,7 @@ impl EventHandle> for TracingHandler { impl EventHandle> for TracingHandler { async fn handle( &self, - event: &EventData, + event: &mut EventData, conversation: &mut Conversation, ) -> anyhow::Result<()> { let message = &event.payload.message; @@ -91,7 +91,7 @@ impl EventHandle> for TracingHandler { impl EventHandle> for TracingHandler { async fn handle( &self, - event: &EventData, + event: &mut EventData, _conversation: &mut Conversation, ) -> anyhow::Result<()> { let tool_call = &event.payload.tool_call; @@ -112,7 +112,7 @@ impl EventHandle> for TracingHandler { impl EventHandle> for TracingHandler { async fn handle( &self, - event: &EventData, + event: &mut EventData, _conversation: &mut Conversation, ) -> anyhow::Result<()> { let tool_call = &event.payload.tool_call; @@ -137,7 +137,7 @@ impl EventHandle> for TracingHandler { impl EventHandle> for TracingHandler { async fn handle( &self, - _event: &EventData, + _event: &mut EventData, conversation: &mut Conversation, ) -> anyhow::Result<()> { if let Some(title) = &conversation.title { @@ -176,20 +176,20 @@ mod tests { async fn test_tracing_handler_start() { let handler = TracingHandler::new(); let mut conversation = Conversation::generate(); - let event = EventData::new(test_agent(), test_model_id(), StartPayload); + let mut event = EventData::new(test_agent(), test_model_id(), StartPayload); // Should not panic - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); } #[tokio::test] async fn test_tracing_handler_request() { let handler = TracingHandler::new(); let mut conversation = Conversation::generate(); - let event = EventData::new(test_agent(), test_model_id(), RequestPayload::new(0)); + let mut event = EventData::new(test_agent(), test_model_id(), RequestPayload::new(0)); // Should not panic - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); } #[tokio::test] @@ -206,10 +206,11 @@ mod tests { finish_reason: None, phase: None, }; - let event = EventData::new(test_agent(), test_model_id(), ResponsePayload::new(message)); + let mut event = + EventData::new(test_agent(), test_model_id(), ResponsePayload::new(message)); // Should not panic - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); } #[tokio::test] @@ -225,23 +226,23 @@ mod tests { let result = ToolResult::new(ToolName::from("test-tool")) .call_id(ToolCallId::new("test-id")) .failure(anyhow::anyhow!("Test error")); - let event = EventData::new( + let mut event = EventData::new( test_agent(), test_model_id(), ToolcallEndPayload::new(tool_call, result), ); // Should log warning but not panic - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); } #[tokio::test] async fn test_tracing_handler_end_with_title() { let handler = TracingHandler::new(); let mut conversation = Conversation::generate().title(Some("Test Title".to_string())); - let event = EventData::new(test_agent(), test_model_id(), EndPayload); + let mut event = EventData::new(test_agent(), test_model_id(), EndPayload::default()); // Should log debug message with title - handler.handle(&event, &mut conversation).await.unwrap(); + handler.handle(&mut event, &mut conversation).await.unwrap(); } } diff --git a/crates/forge_app/src/hooks/user_hook_executor.rs b/crates/forge_app/src/hooks/user_hook_executor.rs new file mode 100644 index 0000000000..6e203b932a --- /dev/null +++ b/crates/forge_app/src/hooks/user_hook_executor.rs @@ -0,0 +1,245 @@ +use std::collections::HashMap; +use std::path::Path; +use std::time::Duration; + +use forge_domain::{CommandOutput, HookExecutionResult}; +use tracing::debug; + +use crate::services::HookCommandService; + +/// Executes user hook commands by delegating to a [`HookCommandService`]. +/// +/// Holds the service by value; the service itself is responsible for any +/// internal reference counting (`Arc`). Keeps hook-specific timeout resolution +/// in one place. +#[derive(Clone)] +pub struct UserHookExecutor(S); + +impl UserHookExecutor { + /// Creates a new `UserHookExecutor` backed by the given service. + pub fn new(service: S) -> Self { + Self(service) + } +} + +impl UserHookExecutor { + /// Executes a shell command, piping `input_json` to stdin and capturing + /// stdout/stderr. + /// + /// Applies `timeout_duration` by racing the service call against the + /// deadline. On timeout, returns a `HookExecutionResult` with + /// `exit_code: None` and a descriptive message in `stderr`. + /// + /// # Arguments + /// * `command` - The shell command string to execute. + /// * `input_json` - JSON string to pipe to the command's stdin. + /// * `timeout_duration` - Maximum time to wait for the command. + /// * `cwd` - Working directory for the command. + /// * `env_vars` - Additional environment variables to set. + /// + /// # Errors + /// Returns an error if the process cannot be spawned. + pub async fn execute( + &self, + command: &str, + input_json: &str, + timeout_duration: Duration, + cwd: &Path, + env_vars: &HashMap, + ) -> anyhow::Result { + debug!( + command = command, + cwd = %cwd.display(), + timeout_ms = timeout_duration.as_millis() as u64, + "Executing user hook command" + ); + + let result = tokio::time::timeout( + timeout_duration, + self.0.execute_command_with_input( + command.to_string(), + cwd.to_path_buf(), + input_json.to_string(), + env_vars.clone(), + ), + ) + .await; + + let output = match result { + Ok(Ok(output)) => output, + Ok(Err(e)) => return Err(e), + Err(_) => { + tracing::warn!( + command = command, + timeout_ms = timeout_duration.as_millis() as u64, + "Hook command timed out" + ); + CommandOutput { + command: command.to_string(), + exit_code: None, + stdout: String::new(), + stderr: format!( + "Hook command timed out after {}ms", + timeout_duration.as_millis() + ), + } + } + }; + + debug!( + command = command, + exit_code = ?output.exit_code, + stdout_len = output.stdout.len(), + stderr_len = output.stderr.len(), + "Hook command completed" + ); + + Ok(HookExecutionResult { + exit_code: output.exit_code, + stdout: output.stdout, + stderr: output.stderr, + }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::path::PathBuf; + use std::time::Duration; + + use forge_domain::CommandOutput; + use pretty_assertions::assert_eq; + + use super::*; + + /// A minimal service stub that records calls and returns a fixed result. + #[derive(Clone)] + struct StubInfra { + result: CommandOutput, + } + + impl StubInfra { + fn success(stdout: &str) -> Self { + Self { + result: CommandOutput { + command: String::new(), + exit_code: Some(0), + stdout: stdout.to_string(), + stderr: String::new(), + }, + } + } + + fn exit(code: i32, stderr: &str) -> Self { + Self { + result: CommandOutput { + command: String::new(), + exit_code: Some(code), + stdout: String::new(), + stderr: stderr.to_string(), + }, + } + } + + fn timeout() -> Self { + Self { + result: CommandOutput { + command: String::new(), + exit_code: None, + stdout: String::new(), + stderr: "Hook command timed out after 100ms".to_string(), + }, + } + } + } + + #[async_trait::async_trait] + impl HookCommandService for StubInfra { + async fn execute_command_with_input( + &self, + command: String, + _working_dir: PathBuf, + _stdin_input: String, + _env_vars: HashMap, + ) -> anyhow::Result { + let mut out = self.result.clone(); + out.command = command; + Ok(out) + } + } + + #[tokio::test] + async fn test_execute_success() { + let fixture = UserHookExecutor::new(StubInfra::success("hello")); + let actual = fixture + .execute( + "echo hello", + "{}", + Duration::from_secs(0), + &std::env::current_dir().unwrap(), + &HashMap::new(), + ) + .await + .unwrap(); + + assert_eq!(actual.exit_code, Some(0)); + assert_eq!(actual.stdout, "hello"); + assert!(actual.is_success()); + } + + #[tokio::test] + async fn test_execute_exit_code_2() { + let fixture = UserHookExecutor::new(StubInfra::exit(2, "blocked")); + let actual = fixture + .execute( + "exit 2", + "{}", + Duration::from_secs(0), + &std::env::current_dir().unwrap(), + &HashMap::new(), + ) + .await + .unwrap(); + + assert_eq!(actual.exit_code, Some(2)); + assert!(actual.is_blocking_exit()); + assert!(actual.stderr.contains("blocked")); + } + + #[tokio::test] + async fn test_execute_non_blocking_error() { + let fixture = UserHookExecutor::new(StubInfra::exit(1, "")); + let actual = fixture + .execute( + "exit 1", + "{}", + Duration::from_secs(0), + &std::env::current_dir().unwrap(), + &HashMap::new(), + ) + .await + .unwrap(); + + assert_eq!(actual.exit_code, Some(1)); + assert!(actual.is_non_blocking_error()); + } + + #[tokio::test] + async fn test_execute_timeout() { + let fixture = UserHookExecutor::new(StubInfra::timeout()); + let actual = fixture + .execute( + "sleep 10", + "{}", + Duration::from_millis(100), + &std::env::current_dir().unwrap(), + &HashMap::new(), + ) + .await + .unwrap(); + + assert!(actual.exit_code.is_none()); + assert!(actual.stderr.contains("timed out")); + } +} diff --git a/crates/forge_app/src/hooks/user_hook_handler.rs b/crates/forge_app/src/hooks/user_hook_handler.rs new file mode 100644 index 0000000000..46325313cd --- /dev/null +++ b/crates/forge_app/src/hooks/user_hook_handler.rs @@ -0,0 +1,2374 @@ +use std::collections::{BTreeMap, HashMap}; +use std::path::PathBuf; +use std::time::Duration; + +use async_trait::async_trait; +use forge_config::{UserHookConfig, UserHookEntry, UserHookEventName, UserHookMatcherGroup}; +use forge_domain::{ + ContextMessage, Conversation, EndPayload, EventData, EventHandle, HookEventInput, + HookExecutionResult, HookInput, HookOutput, PromptSuppressed, RequestPayload, ResponsePayload, + Role, StartPayload, ToolCallArguments, ToolcallEndPayload, ToolcallStartPayload, +}; +use regex::Regex; +use serde_json::Value; +use tracing::{debug, warn}; + +use super::user_hook_executor::UserHookExecutor; +use crate::services::HookCommandService; + +/// Default timeout for hook commands (10 minutes). +const DEFAULT_HOOK_TIMEOUT: Duration = Duration::from_secs(600); + +/// EventHandle implementation that bridges user-configured hooks with the +/// existing lifecycle event system. +/// +/// This handler is constructed from a `UserHookConfig` and executes matching +/// hook commands at each lifecycle event point. It wires into the existing +/// `Hook` system via `Hook::zip()`. +#[derive(Clone)] +pub struct UserHookHandler { + executor: UserHookExecutor, + config: UserHookConfig, + cwd: PathBuf, + env_vars: HashMap, + /// Pre-compiled regex cache keyed by the raw pattern string. + /// Built once during construction from the immutable config patterns. + regex_cache: HashMap, +} + +impl UserHookHandler { + /// Creates a new user hook handler from configuration. + /// + /// # Arguments + /// * `service` - The hook command service used to execute hook commands. + /// * `config` - The merged user hook configuration. + /// * `cwd` - Current working directory for command execution. + /// * `project_dir` - Project root directory for `FORGE_PROJECT_DIR` env + /// var. + /// * `session_id` - Current session/conversation ID. + /// * `default_hook_timeout` - Default timeout in milliseconds for hook + /// commands. + pub fn new( + service: I, + mut env_vars: BTreeMap, + config: UserHookConfig, + cwd: PathBuf, + session_id: String, + ) -> Self { + env_vars.insert( + "FORGE_PROJECT_DIR".to_string(), + cwd.to_string_lossy().to_string(), + ); + env_vars.insert("FORGE_SESSION_ID".to_string(), session_id); + env_vars.insert("FORGE_CWD".to_string(), cwd.to_string_lossy().to_string()); + + // Pre-compile all regex patterns from the config into a cache. + let regex_cache = Self::build_regex_cache(&config); + + Self { + executor: UserHookExecutor::new(service), + config, + cwd, + env_vars: env_vars.into_iter().collect(), + regex_cache, + } + } + + /// Pre-compiles all unique, non-empty regex patterns found in the config. + /// + /// Invalid patterns are logged and skipped so that construction never + /// fails. The same warning will fire at match time for any pattern + /// missing from the cache. + fn build_regex_cache(config: &UserHookConfig) -> HashMap { + let mut cache = HashMap::new(); + for groups in config.events.values() { + for group in groups { + if let Some(pattern) = &group.matcher + && !pattern.is_empty() + && !cache.contains_key(pattern) + { + match Regex::new(pattern) { + Ok(re) => { + cache.insert(pattern.clone(), re); + } + Err(e) => { + warn!( + pattern = pattern, + error = %e, + "Invalid regex in hook matcher, will be skipped at match time" + ); + } + } + } + } + } + cache + } + + /// Checks if the config has any hooks for the given event. + fn has_hooks(&self, event: &UserHookEventName) -> bool { + !self.config.get_groups(event).is_empty() + } + + /// Constructs a [`HookInput`] from the common fields stored in this + /// handler, leaving only the event-specific `event_data` to the caller. + fn build_base_input( + &self, + event_name: &UserHookEventName, + event_data: HookEventInput, + ) -> HookInput { + HookInput { + hook_event_name: event_name.to_string(), + cwd: self.cwd.to_string_lossy().to_string(), + session_id: self.env_vars.get("FORGE_SESSION_ID").cloned(), + event_data, + } + } + + /// Finds matching hook entries for an event, filtered by the optional + /// matcher regex against the given subject string. + /// + /// Uses the pre-compiled `regex_cache` to avoid recompiling patterns on + /// every invocation. Patterns that failed compilation during construction + /// are silently skipped (already warned at startup). + fn find_matching_hooks<'a>( + groups: &'a [UserHookMatcherGroup], + subject: Option<&str>, + regex_cache: &HashMap, + ) -> Vec<&'a UserHookEntry> { + let mut matching = Vec::new(); + + for group in groups { + let matches = match (&group.matcher, subject) { + (None, _) => { + // No matcher means unconditional match + true + } + (Some(pattern), _) if pattern.is_empty() => { + // Empty matcher is treated as unconditional (same as None) + true + } + (Some(_), None) => { + // Matcher specified but no subject to match against; skip + false + } + (Some(pattern), Some(subj)) => { + regex_cache.get(pattern).is_some_and(|re| re.is_match(subj)) + } + }; + + if matches { + matching.extend(group.hooks.iter()); + } + } + + matching + } + + /// Executes a list of hook entries and returns their results along with + /// any warnings for commands that failed to execute. + /// Each result is paired with the command string that produced it. + async fn execute_hooks( + &self, + hooks: &[&UserHookEntry], + input: &HookInput, + ) -> (Vec<(String, HookExecutionResult)>, Vec) + where + I: HookCommandService, + { + let input_json = match serde_json::to_string(input) { + Ok(json) => json, + Err(e) => { + warn!(error = %e, "Failed to serialize hook input"); + return ( + Vec::new(), + vec![format!("Hook input serialization failed: {e}")], + ); + } + }; + + let mut results = Vec::new(); + let mut warnings = Vec::new(); + for hook in hooks { + if let Some(command) = &hook.command { + match self + .executor + .execute( + command, + &input_json, + hook.timeout + .map(Duration::from_millis) + .unwrap_or(DEFAULT_HOOK_TIMEOUT), + &self.cwd, + &self.env_vars, + ) + .await + { + Ok(result) => { + // Non-blocking errors (exit code 1, etc.) are warned + if result.is_non_blocking_error() { + let stderr = result.stderr.trim(); + let detail = if stderr.is_empty() { + format!("exit code {:?}", result.exit_code) + } else { + stderr.to_string() + }; + warnings.push(format!( + "Hook command '{command}' returned non-blocking error: {detail}" + )); + } + results.push((command.clone(), result)); + } + Err(e) => { + warn!( + command = command, + error = %e, + "Hook command failed to execute" + ); + warnings.push(format!("Hook command '{command}' failed to execute: {e}")); + } + } + } + } + + (results, warnings) + } + + /// Runs matching hooks for the given event and collects results. + /// + /// This encapsulates the common lifecycle hook pattern: + /// 1. Resolve matcher groups for the event. + /// 2. Find hooks matching the optional subject. + /// 3. Execute matched hooks, collecting results and warnings. + /// 4. Extend event warnings. + /// 5. Collect and inject any `additionalContext` into the conversation. + /// + /// Returns the raw results for event-specific post-processing. + async fn run_hooks_and_collect( + &self, + event_name: &UserHookEventName, + subject: Option<&str>, + input: &HookInput, + warnings: &mut Vec, + conversation: &mut Conversation, + ) -> Vec<(String, HookExecutionResult)> + where + I: HookCommandService, + { + let groups = self.config.get_groups(event_name); + let hooks = Self::find_matching_hooks(groups, subject, &self.regex_cache); + + if hooks.is_empty() { + return Vec::new(); + } + + let (results, exec_warnings) = self.execute_hooks(&hooks, input).await; + warnings.extend(exec_warnings); + + let contexts = Self::collect_additional_context(&results); + Self::inject_additional_context(conversation, &event_name.to_string(), &contexts); + + results + } + + /// Checks a single hook result for blocking signals (exit code 2 or JSON + /// blocking decision). Returns the blocking command and reason if found. + fn check_blocking(command: &str, result: &HookExecutionResult) -> Option<(String, String)> { + if result.is_blocking_exit() { + let message = result + .blocking_message() + .unwrap_or("Hook blocked execution") + .to_string(); + return Some((command.to_string(), message)); + } + + if let Some(output) = result.parse_output() + && output.is_blocking() + { + let reason = output.blocking_reason("Hook blocked execution"); + return Some((command.to_string(), reason)); + } + + None + } + + /// Processes hook results, returning the blocking command and reason if + /// any hook blocked. + fn process_results(results: &[(String, HookExecutionResult)]) -> Option<(String, String)> { + results + .iter() + .find_map(|(cmd, result)| Self::check_blocking(cmd, result)) + } + + /// Collects `additionalContext` strings from all successful hook results, + /// paired with the command that produced them. + fn collect_additional_context( + results: &[(String, HookExecutionResult)], + ) -> Vec<(String, String)> { + let mut contexts = Vec::new(); + for (command, result) in results { + if let Some(output) = result.parse_output() + && let Some(ctx) = &output.additional_context + && !ctx.trim().is_empty() + { + contexts.push((command.clone(), ctx.clone())); + } + } + contexts + } + + /// Injects collected `additionalContext` into the conversation as a plain + /// text user message. The format matches Claude Code's transcript format: + /// ```text + /// {event_name} hook additional context: + /// [{command}]: {context} + /// ``` + /// This avoids XML-like tags that LLMs may treat as prompt injection. + fn inject_additional_context( + conversation: &mut Conversation, + event_name: &str, + contexts: &[(String, String)], + ) { + if contexts.is_empty() { + return; + } + if let Some(ctx) = conversation.context.as_mut() { + let mut lines = vec![format!("{event_name} hook additional context:")]; + for (command, context) in contexts { + lines.push(format!("[{command}]: {context}")); + } + let content = lines.join("\n"); + ctx.messages + .push(ContextMessage::user(content, None).into()); + debug!( + event_name = event_name, + context_count = contexts.len(), + "Injected additional context from hook into conversation" + ); + } + } + + /// Processes PreToolUse results, extracting updated input if present. + fn process_pre_tool_use_output( + results: &[(String, HookExecutionResult)], + ) -> PreToolUseDecision { + for (_command, result) in results { + // Exit code 2 = blocking error + if result.is_blocking_exit() { + let message = result + .blocking_message() + .unwrap_or("Hook blocked tool execution") + .to_string(); + return PreToolUseDecision::Block(message); + } + + // Exit code 0 = check stdout for JSON decisions + if let Some(output) = result.parse_output() { + // Check permission decision + if output.permission_decision.as_deref() == Some("deny") { + let reason = output.blocking_reason("Tool execution denied by hook"); + return PreToolUseDecision::Block(reason); + } + + // Check generic block decision + if output.is_blocking() { + let reason = output.blocking_reason("Hook blocked tool execution"); + return PreToolUseDecision::Block(reason); + } + + // Check for updated input + if output.updated_input.is_some() { + return PreToolUseDecision::AllowWithUpdate(output); + } + } + } + + PreToolUseDecision::Allow + } +} + +/// Decision result from PreToolUse hook processing. +enum PreToolUseDecision { + /// Allow the tool call to proceed. + Allow, + /// Allow but with updated input from the hook output. + AllowWithUpdate(HookOutput), + /// Block the tool call with the given reason. + Block(String), +} + +// --- EventHandle implementations --- + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + event: &mut EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + if !self.has_hooks(&UserHookEventName::SessionStart) { + return Ok(()); + } + + let input = self.build_base_input( + &UserHookEventName::SessionStart, + HookEventInput::SessionStart { source: "startup".to_string() }, + ); + + self.run_hooks_and_collect( + &UserHookEventName::SessionStart, + Some("startup"), + &input, + &mut event.warnings, + conversation, + ) + .await; + + Ok(()) + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + event: &mut EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + // Only fire on the first request of a turn (user-submitted prompt). + // Subsequent iterations are internal LLM retry/tool-call loops and + // should not re-trigger UserPromptSubmit. + if event.payload.request_count != 0 { + return Ok(()); + } + + if !self.has_hooks(&UserHookEventName::UserPromptSubmit) { + return Ok(()); + } + + // Extract the last user message text as the prompt sent to the hook. + let prompt = conversation + .context + .as_ref() + .and_then(|ctx| { + ctx.messages + .iter() + .rev() + .find(|m| m.has_role(Role::User)) + .and_then(|m| m.content()) + .map(|s| s.to_string()) + }) + .unwrap_or_default(); + + let input = self.build_base_input( + &UserHookEventName::UserPromptSubmit, + HookEventInput::UserPromptSubmit { prompt }, + ); + + let results = self + .run_hooks_and_collect( + &UserHookEventName::UserPromptSubmit, + None, + &input, + &mut event.warnings, + conversation, + ) + .await; + + if let Some((command, reason)) = Self::process_results(&results) { + debug!( + command = command.as_str(), + reason = reason.as_str(), + "UserPromptSubmit hook blocked with feedback" + ); + event + .warnings + .push(format!("UserPromptSubmit hook blocked: {reason}")); + // Signal the orchestrator to suppress this prompt entirely. + return Err(anyhow::Error::from(PromptSuppressed(reason))); + } + + Ok(()) + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + _event: &mut EventData, + _conversation: &mut Conversation, + ) -> anyhow::Result<()> { + // FIXME: No user hook events map to Response currently + Ok(()) + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + event: &mut EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + if !self.has_hooks(&UserHookEventName::PreToolUse) { + return Ok(()); + } + + // Use owned String to avoid borrow conflicts when mutating event later. + let tool_name = event.payload.tool_call.name.as_str().to_string(); + // FIXME: Add a tool name transformer to map tool names to Forge + // equivalents (e.g. "Bash" -> "shell") so that hook matchers written + // for other coding assistants work correctly. + + let tool_input = + serde_json::to_value(&event.payload.tool_call.arguments).unwrap_or_default(); + let tool_use_id = event + .payload + .tool_call + .call_id + .as_ref() + .map(|id| id.as_str().to_string()); + + let input = self.build_base_input( + &UserHookEventName::PreToolUse, + HookEventInput::PreToolUse { tool_name: tool_name.clone(), tool_input, tool_use_id }, + ); + + let results = self + .run_hooks_and_collect( + &UserHookEventName::PreToolUse, + Some(tool_name.as_str()), + &input, + &mut event.warnings, + conversation, + ) + .await; + + let decision = Self::process_pre_tool_use_output(&results); + + match decision { + PreToolUseDecision::Allow => Ok(()), + PreToolUseDecision::AllowWithUpdate(output) => { + if let Some(updated_input) = output.updated_input { + event.payload.tool_call.arguments = + ToolCallArguments::Parsed(Value::Object(updated_input)); + debug!( + tool_name = tool_name.as_str(), + "PreToolUse hook updated tool input" + ); + } + Ok(()) + } + PreToolUseDecision::Block(reason) => { + debug!( + tool_name = tool_name.as_str(), + reason = reason.as_str(), + "PreToolUse hook blocked tool call" + ); + // Return an error to signal the orchestrator to skip this tool call. + // The orchestrator converts this into an error ToolResult visible to + // the model. + Err(anyhow::anyhow!( + "Tool call '{}' blocked by PreToolUse hook: {}", + tool_name, + reason + )) + } + } + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + event: &mut EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + let is_error = event.payload.result.is_error(); + let event_name = if is_error { + UserHookEventName::PostToolUseFailure + } else { + UserHookEventName::PostToolUse + }; + + if !self.has_hooks(&event_name) { + return Ok(()); + } + + let tool_name = event.payload.tool_call.name.as_str().to_string(); + + let tool_input = + serde_json::to_value(&event.payload.tool_call.arguments).unwrap_or_default(); + let tool_response = serde_json::to_value(&event.payload.result.output).unwrap_or_default(); + let tool_use_id = event + .payload + .tool_call + .call_id + .as_ref() + .map(|id| id.as_str().to_string()); + + let input = self.build_base_input( + &event_name, + HookEventInput::PostToolUse { + tool_name: tool_name.to_string(), + tool_input, + tool_response, + tool_use_id, + }, + ); + + let results = self + .run_hooks_and_collect( + &event_name, + Some(&tool_name), + &input, + &mut event.warnings, + conversation, + ) + .await; + + // PostToolUse blocking: store the feedback on the event payload. + // The orchestrator reads `hook_feedback` after `append_message` and + // injects it into context at the correct position — after the tool + // result, not before it. This ensures the LLM sees the feedback in + // the right order. + if let Some((command, reason)) = Self::process_results(&results) { + debug!( + tool_name = tool_name.as_str(), + event = %event_name, + command = command.as_str(), + reason = reason.as_str(), + "PostToolUse hook blocked, storing feedback for orchestrator injection" + ); + let content = format!("{event_name}:{tool_name} hook feedback:\n[{command}]: {reason}"); + event.payload.hook_feedback = Some(content.clone()); + event + .warnings + .push(format!("{event_name}:{tool_name} hook blocked: {reason}")); + } + + Ok(()) + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + event: &mut EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + // Fire SessionEnd hooks + if self.has_hooks(&UserHookEventName::SessionEnd) { + let input = + self.build_base_input(&UserHookEventName::SessionEnd, HookEventInput::Empty {}); + self.run_hooks_and_collect( + &UserHookEventName::SessionEnd, + None, + &input, + &mut event.warnings, + conversation, + ) + .await; + } + + // Fire Stop hooks + if !self.has_hooks(&UserHookEventName::Stop) { + return Ok(()); + } + + let stop_hook_active = event.payload.stop_hook_active; + + // Extract the last assistant message text for the Stop hook payload. + let last_assistant_message = conversation.context.as_ref().and_then(|ctx| { + ctx.messages + .iter() + .rev() + .find(|m| m.has_role(Role::Assistant)) + .and_then(|m| m.content()) + .map(|s| s.to_string()) + }); + + let input = self.build_base_input( + &UserHookEventName::Stop, + HookEventInput::Stop { stop_hook_active, last_assistant_message }, + ); + + let results = self + .run_hooks_and_collect( + &UserHookEventName::Stop, + None, + &input, + &mut event.warnings, + conversation, + ) + .await; + + if let Some((command, reason)) = Self::process_results(&results) { + debug!( + command = command.as_str(), + reason = reason.as_str(), + stop_hook_active = stop_hook_active, + "Stop hook blocked, injecting feedback for continuation" + ); + // Inject the blocking reason as a conversation message. The + // orchestrator detects that conversation.len() increased and + // resets should_yield to false, causing another LLM turn. + // This matches Claude Code's stop-hook continuation behavior. + if let Some(ctx) = conversation.context.as_mut() { + let content = format!("Stop hook feedback:\n[{command}]: {reason}"); + ctx.messages + .push(ContextMessage::user(content, None).into()); + } + // Mark the next End invocation as stop_hook_active so hook + // scripts can detect re-entrancy and avoid infinite loops. + event.payload.stop_hook_active = true; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::path::PathBuf; + + use forge_config::{UserHookEntry, UserHookEventName, UserHookMatcherGroup, UserHookType}; + use forge_domain::{CommandOutput, HookExecutionResult}; + use pretty_assertions::assert_eq; + + use super::*; + + /// A no-op service stub for tests that only exercise config/matching logic. + #[derive(Clone)] + struct NullInfra; + + #[async_trait::async_trait] + impl HookCommandService for NullInfra { + async fn execute_command_with_input( + &self, + command: String, + _working_dir: PathBuf, + _stdin_input: String, + _env_vars: HashMap, + ) -> anyhow::Result { + Ok(CommandOutput { + command, + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }) + } + } + + fn null_handler(config: UserHookConfig) -> UserHookHandler { + UserHookHandler::new( + NullInfra, + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-1".to_string(), + ) + } + + /// Configurable stub that returns a fixed `CommandOutput` for every call. + /// Replaces all single-purpose inline stubs (BlockExit2, JsonBlockInfra, + /// ContinueFalseInfra, Exit1Infra, StopBlockInfra, etc.). + #[derive(Clone)] + struct StubInfra { + output: forge_domain::CommandOutput, + } + + impl StubInfra { + fn new(exit_code: Option, stdout: &str, stderr: &str) -> Self { + Self { + output: forge_domain::CommandOutput { + command: String::new(), + exit_code, + stdout: stdout.to_string(), + stderr: stderr.to_string(), + }, + } + } + } + + #[async_trait::async_trait] + impl HookCommandService for StubInfra { + async fn execute_command_with_input( + &self, + command: String, + _working_dir: PathBuf, + _stdin_input: String, + _env_vars: HashMap, + ) -> anyhow::Result { + let mut out = self.output.clone(); + out.command = command; + Ok(out) + } + } + + fn handler_for_event(infra: I, event_json: &str) -> UserHookHandler { + let config: UserHookConfig = serde_json::from_str(event_json).unwrap(); + UserHookHandler::new( + infra, + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-test".to_string(), + ) + } + + fn make_entry(command: &str) -> UserHookEntry { + UserHookEntry { + hook_type: UserHookType::Command, + command: Some(command.to_string()), + timeout: None, + } + } + + fn make_group(matcher: Option<&str>, commands: &[&str]) -> UserHookMatcherGroup { + UserHookMatcherGroup { + matcher: matcher.map(|s| s.to_string()), + hooks: commands.iter().map(|c| make_entry(c)).collect(), + } + } + + /// Builds a regex cache from a slice of matcher groups, mirroring the + /// logic in `UserHookHandler::build_regex_cache` for test use. + fn regex_cache_from_groups(groups: &[UserHookMatcherGroup]) -> HashMap { + let mut cache = HashMap::new(); + for group in groups { + if let Some(pattern) = &group.matcher + && !pattern.is_empty() + && !cache.contains_key(pattern) + && let Ok(re) = Regex::new(pattern) + { + cache.insert(pattern.clone(), re); + } + } + cache + } + + #[test] + fn test_find_matching_hooks_no_matcher_fires_unconditionally() { + let groups = vec![make_group(None, &["echo hi"])]; + let cache = regex_cache_from_groups(&groups); + let actual = + UserHookHandler::::find_matching_hooks(&groups, Some("Bash"), &cache); + assert_eq!(actual.len(), 1); + assert_eq!(actual[0].command, Some("echo hi".to_string())); + } + + #[test] + fn test_find_matching_hooks_no_matcher_fires_without_subject() { + let groups = vec![make_group(None, &["echo hi"])]; + let cache = regex_cache_from_groups(&groups); + let actual = UserHookHandler::::find_matching_hooks(&groups, None, &cache); + assert_eq!(actual.len(), 1); + } + + #[test] + fn test_find_matching_hooks_regex_match() { + let groups = vec![make_group(Some("Bash"), &["block.sh"])]; + let cache = regex_cache_from_groups(&groups); + let actual = + UserHookHandler::::find_matching_hooks(&groups, Some("Bash"), &cache); + assert_eq!(actual.len(), 1); + } + + #[test] + fn test_find_matching_hooks_regex_no_match() { + let groups = vec![make_group(Some("Bash"), &["block.sh"])]; + let cache = regex_cache_from_groups(&groups); + let actual = + UserHookHandler::::find_matching_hooks(&groups, Some("Write"), &cache); + assert!(actual.is_empty()); + } + + #[test] + fn test_find_matching_hooks_regex_partial_match() { + let groups = vec![make_group(Some("Bash|Write"), &["check.sh"])]; + let cache = regex_cache_from_groups(&groups); + let actual = + UserHookHandler::::find_matching_hooks(&groups, Some("Bash"), &cache); + assert_eq!(actual.len(), 1); + } + + #[test] + fn test_find_matching_hooks_matcher_but_no_subject() { + let groups = vec![make_group(Some("Bash"), &["block.sh"])]; + let cache = regex_cache_from_groups(&groups); + let actual = UserHookHandler::::find_matching_hooks(&groups, None, &cache); + assert!(actual.is_empty()); + } + + #[test] + fn test_find_matching_hooks_empty_matcher_fires_without_subject() { + let groups = vec![make_group(Some(""), &["stop-hook.sh"])]; + let cache = regex_cache_from_groups(&groups); + let actual = UserHookHandler::::find_matching_hooks(&groups, None, &cache); + assert_eq!(actual.len(), 1); + assert_eq!(actual[0].command, Some("stop-hook.sh".to_string())); + } + + #[test] + fn test_find_matching_hooks_empty_matcher_fires_with_subject() { + let groups = vec![make_group(Some(""), &["pre-tool.sh"])]; + let cache = regex_cache_from_groups(&groups); + let actual = + UserHookHandler::::find_matching_hooks(&groups, Some("Bash"), &cache); + assert_eq!(actual.len(), 1); + assert_eq!(actual[0].command, Some("pre-tool.sh".to_string())); + } + + #[test] + fn test_find_matching_hooks_invalid_regex_skipped() { + let groups = vec![make_group(Some("[invalid"), &["block.sh"])]; + let cache = regex_cache_from_groups(&groups); + let actual = + UserHookHandler::::find_matching_hooks(&groups, Some("anything"), &cache); + assert!(actual.is_empty()); + } + + #[test] + fn test_find_matching_hooks_multiple_groups() { + let groups = vec![ + make_group(Some("Bash"), &["bash-hook.sh"]), + make_group(Some("Write"), &["write-hook.sh"]), + make_group(None, &["always.sh"]), + ]; + let cache = regex_cache_from_groups(&groups); + let actual = + UserHookHandler::::find_matching_hooks(&groups, Some("Bash"), &cache); + assert_eq!(actual.len(), 2); // Bash match + unconditional + } + + #[test] + fn test_process_pre_tool_use_output_allow_on_success() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Allow)); + } + + #[test] + fn test_process_pre_tool_use_output_block_on_exit_2() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(2), + stdout: String::new(), + stderr: "Blocked: dangerous command".to_string(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!( + matches!(actual, PreToolUseDecision::Block(msg) if msg.contains("dangerous command")) + ); + } + + #[test] + fn test_process_pre_tool_use_output_block_on_deny() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"permissionDecision": "deny", "reason": "Not allowed"}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg == "Not allowed")); + } + + #[test] + fn test_process_pre_tool_use_output_block_on_decision() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "reason": "Blocked by policy"}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg == "Blocked by policy")); + } + + #[test] + fn test_process_pre_tool_use_output_non_blocking_error_allows() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(1), + stdout: String::new(), + stderr: "some error".to_string(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Allow)); + } + + #[test] + fn test_process_results_no_blocking() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_results(&results); + assert!(actual.is_none()); + } + + #[test] + fn test_process_results_blocking_exit_code() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(2), + stdout: String::new(), + stderr: "stop reason".to_string(), + }, + )]; + let actual = UserHookHandler::::process_results(&results); + assert_eq!( + actual, + Some(("test-cmd".to_string(), "stop reason".to_string())) + ); + } + + #[test] + fn test_process_results_blocking_json_decision() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "reason": "keep going"}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_results(&results); + assert_eq!( + actual, + Some(("test-cmd".to_string(), "keep going".to_string())) + ); + } + + #[test] + fn test_has_hooks_returns_false_for_empty_config() { + let config = UserHookConfig::new(); + let handler = null_handler(config); + assert!(!handler.has_hooks(&UserHookEventName::PreToolUse)); + } + + #[test] + fn test_has_hooks_returns_true_when_configured() { + let json = r#"{"PreToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + let handler = null_handler(config); + assert!(handler.has_hooks(&UserHookEventName::PreToolUse)); + assert!(!handler.has_hooks(&UserHookEventName::Stop)); + } + + #[test] + fn test_process_pre_tool_use_output_allow_with_update_detected() { + // A hook that returns updatedInput should produce AllowWithUpdate with the + // correct updated_input value. + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"updatedInput": {"command": "echo safe"}}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + let expected_map = + serde_json::Map::from_iter([("command".to_string(), serde_json::json!("echo safe"))]); + assert!( + matches!(&actual, PreToolUseDecision::AllowWithUpdate(output) if output.updated_input == Some(expected_map)) + ); + } + + #[tokio::test] + async fn test_allow_with_update_modifies_tool_call_arguments() { + // When a PreToolUse hook returns updatedInput, the handler must + // overwrite event.payload.tool_call.arguments with the new value. + use forge_domain::{ + Agent, EventData, ModelId, ProviderId, ToolCallArguments, ToolCallFull, + ToolcallStartPayload, + }; + + let json = r#"{"PreToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + + let handler = UserHookHandler::new( + StubInfra::new(Some(0), r#"{"updatedInput": {"command": "echo safe"}}"#, ""), + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-test".to_string(), + ); + + let agent = Agent::new( + "test-agent", + ProviderId::from("test-provider".to_string()), + ModelId::new("test-model"), + ); + let original_args = ToolCallArguments::from_json(r#"{"command": "rm -rf /"}"#); + let tool_call = ToolCallFull::new("shell").arguments(original_args); + let mut event = EventData::new( + agent, + ModelId::new("test-model"), + ToolcallStartPayload::new(tool_call), + ); + let mut conversation = forge_domain::Conversation::generate(); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + let actual_args = event.payload.tool_call.arguments.parse().unwrap(); + let expected_args = serde_json::json!({"command": "echo safe"}); + assert_eq!(actual_args, expected_args); + } + + #[test] + fn test_allow_with_update_none_updated_input_leaves_args_unchanged() { + // When HookOutput has updated_input = None (e.g. only + // `{"permissionDecision": "allow"}`), AllowWithUpdate should not + // overwrite the original arguments. + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"permissionDecision": "allow"}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + // permissionDecision "allow" with no updatedInput => plain Allow + assert!(matches!(actual, PreToolUseDecision::Allow)); + } + + #[test] + fn test_allow_with_update_empty_object() { + // updatedInput is an empty object — still a valid update. + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"updatedInput": {}}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + let expected_map = serde_json::Map::new(); + assert!( + matches!(&actual, PreToolUseDecision::AllowWithUpdate(output) if output.updated_input == Some(expected_map)) + ); + } + + #[test] + fn test_allow_with_update_complex_nested_input() { + // updatedInput with nested objects and arrays. + let results = vec![("test-cmd".to_string(), HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"updatedInput": {"file_path": "/safe/path", "options": {"recursive": true, "depth": 3}, "tags": ["a", "b"]}}"#.to_string(), + stderr: String::new(), + })]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + let expected_map = serde_json::Map::from_iter([ + ("file_path".to_string(), serde_json::json!("/safe/path")), + ( + "options".to_string(), + serde_json::json!({"recursive": true, "depth": 3}), + ), + ("tags".to_string(), serde_json::json!(["a", "b"])), + ]); + assert!( + matches!(&actual, PreToolUseDecision::AllowWithUpdate(output) if output.updated_input == Some(expected_map)) + ); + } + + #[test] + fn test_block_takes_priority_over_updated_input() { + // If a hook returns both decision=block AND updatedInput, the block + // must win because blocking is checked before updatedInput. + let results = vec![("test-cmd".to_string(), HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "reason": "nope", "updatedInput": {"command": "echo safe"}}"#.to_string(), + stderr: String::new(), + })]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg == "nope")); + } + + #[test] + fn test_deny_takes_priority_over_updated_input() { + // permissionDecision=deny should block even if updatedInput is present. + let results = vec![("test-cmd".to_string(), HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"permissionDecision": "deny", "reason": "forbidden", "updatedInput": {"command": "echo safe"}}"#.to_string(), + stderr: String::new(), + })]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg == "forbidden")); + } + + #[test] + fn test_exit_code_2_blocks_even_with_updated_input_in_stdout() { + // Exit code 2 is a hard block regardless of stdout content. + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(2), + stdout: r#"{"updatedInput": {"command": "echo safe"}}"#.to_string(), + stderr: "hard block".to_string(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg.contains("hard block"))); + } + + #[test] + fn test_multiple_results_first_update_wins() { + // When multiple hooks run and the first returns updatedInput, that + // result is used (iteration stops at first non-Allow decision). + let results = vec![ + ( + "test-cmd-1".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"updatedInput": {"command": "first"}}"#.to_string(), + stderr: String::new(), + }, + ), + ( + "test-cmd-2".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"updatedInput": {"command": "second"}}"#.to_string(), + stderr: String::new(), + }, + ), + ]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + let expected_map = + serde_json::Map::from_iter([("command".to_string(), serde_json::json!("first"))]); + assert!( + matches!(&actual, PreToolUseDecision::AllowWithUpdate(output) if output.updated_input == Some(expected_map)) + ); + } + + #[test] + fn test_multiple_results_block_before_update() { + // A block from an earlier hook prevents a later hook's updatedInput + // from being applied. + let results = vec![ + ( + "test-cmd-1".to_string(), + HookExecutionResult { + exit_code: Some(2), + stdout: String::new(), + stderr: "blocked first".to_string(), + }, + ), + ( + "test-cmd-2".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"updatedInput": {"command": "safe"}}"#.to_string(), + stderr: String::new(), + }, + ), + ]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg.contains("blocked first"))); + } + + #[test] + fn test_non_blocking_error_then_update() { + // A non-blocking error (exit 1) from the first hook is logged but + // doesn't prevent a subsequent hook from returning updatedInput. + let results = vec![ + ( + "test-cmd-1".to_string(), + HookExecutionResult { + exit_code: Some(1), + stdout: String::new(), + stderr: "warning".to_string(), + }, + ), + ( + "test-cmd-2".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"updatedInput": {"command": "safe"}}"#.to_string(), + stderr: String::new(), + }, + ), + ]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + let expected_map = + serde_json::Map::from_iter([("command".to_string(), serde_json::json!("safe"))]); + assert!( + matches!(&actual, PreToolUseDecision::AllowWithUpdate(output) if output.updated_input == Some(expected_map)) + ); + } + + #[tokio::test] + async fn test_allow_with_update_no_updated_input_preserves_original() { + // When the hook returns exit 0 with empty stdout (no updatedInput), + // the original tool call arguments must remain untouched. + use forge_domain::{ + Agent, EventData, ModelId, ProviderId, ToolCallArguments, ToolCallFull, + ToolcallStartPayload, + }; + + let json = r#"{"PreToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + // NullInfra returns exit 0 + empty stdout => Allow + let handler = null_handler(config); + + let agent = Agent::new( + "test-agent", + ProviderId::from("test-provider".to_string()), + ModelId::new("test-model"), + ); + let original_args = ToolCallArguments::from_json(r#"{"command": "ls"}"#); + let tool_call = ToolCallFull::new("shell").arguments(original_args); + let mut event = EventData::new( + agent, + ModelId::new("test-model"), + ToolcallStartPayload::new(tool_call), + ); + let mut conversation = forge_domain::Conversation::generate(); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + // Arguments must still be the original value + let actual_args = event.payload.tool_call.arguments.parse().unwrap(); + let expected_args = serde_json::json!({"command": "ls"}); + assert_eq!(actual_args, expected_args); + } + + #[tokio::test] + async fn test_allow_with_update_replaces_unparsed_with_parsed() { + // Original arguments are Unparsed (raw string from LLM). After + // AllowWithUpdate, they should become Parsed(Value). + use forge_domain::{ + Agent, EventData, ModelId, ProviderId, ToolCallArguments, ToolCallFull, + ToolcallStartPayload, + }; + + let json = r#"{"PreToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + + let handler = UserHookHandler::new( + StubInfra::new( + Some(0), + r#"{"updatedInput": {"file_path": "/safe/file.txt", "content": "hello"}}"#, + "", + ), + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-test2".to_string(), + ); + + let agent = Agent::new( + "test-agent", + ProviderId::from("test-provider".to_string()), + ModelId::new("test-model"), + ); + // Start with Unparsed arguments + let original_args = + ToolCallArguments::from_json(r#"{"file_path": "/etc/passwd", "content": "evil"}"#); + assert!(matches!(original_args, ToolCallArguments::Unparsed(_))); + + let tool_call = ToolCallFull::new("write").arguments(original_args); + let mut event = EventData::new( + agent, + ModelId::new("test-model"), + ToolcallStartPayload::new(tool_call), + ); + let mut conversation = forge_domain::Conversation::generate(); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + // After update, arguments should be Parsed + assert!(matches!( + event.payload.tool_call.arguments, + ToolCallArguments::Parsed(_) + )); + let actual_args = event.payload.tool_call.arguments.parse().unwrap(); + let expected_args = serde_json::json!({"file_path": "/safe/file.txt", "content": "hello"}); + assert_eq!(actual_args, expected_args); + } + + #[tokio::test] + async fn test_block_returns_error_and_preserves_original_args() { + // When a hook blocks, handle() returns Err and the event arguments + // remain unchanged. + use forge_domain::{ + Agent, EventData, ModelId, ProviderId, ToolCallArguments, ToolCallFull, + ToolcallStartPayload, + }; + + let json = r#"{"PreToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + + let handler = UserHookHandler::new( + StubInfra::new(Some(2), "", "dangerous operation"), + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-block".to_string(), + ); + + let agent = Agent::new( + "test-agent", + ProviderId::from("test-provider".to_string()), + ModelId::new("test-model"), + ); + let original_args = ToolCallArguments::from_json(r#"{"command": "rm -rf /"}"#); + let tool_call = ToolCallFull::new("shell").arguments(original_args); + let mut event = EventData::new( + agent, + ModelId::new("test-model"), + ToolcallStartPayload::new(tool_call), + ); + let mut conversation = forge_domain::Conversation::generate(); + + let result = handler.handle(&mut event, &mut conversation).await; + + // Should be an error + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("blocked by PreToolUse hook")); + assert!(err_msg.contains("dangerous operation")); + + // Arguments must still be the original value (not modified) + let actual_args = event.payload.tool_call.arguments.parse().unwrap(); + let expected_args = serde_json::json!({"command": "rm -rf /"}); + assert_eq!(actual_args, expected_args); + } + + #[test] + fn test_process_results_blocking_continue_false() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"continue": false, "stopReason": "task complete"}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_results(&results); + assert_eq!( + actual, + Some(("test-cmd".to_string(), "task complete".to_string())) + ); + } + + #[test] + fn test_process_pre_tool_use_output_block_on_continue_false() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"continue": false, "stopReason": "no more tools"}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg == "no more tools")); + } + + #[test] + fn test_process_results_stop_reason_fallback() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "stopReason": "fallback reason"}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_results(&results); + assert_eq!( + actual, + Some(("test-cmd".to_string(), "fallback reason".to_string())) + ); + } + + #[test] + fn test_process_results_reason_over_stop_reason() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "reason": "primary", "stopReason": "secondary"}"# + .to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::process_results(&results); + assert_eq!( + actual, + Some(("test-cmd".to_string(), "primary".to_string())) + ); + } + + // ========================================================================= + // Tests: UserPromptSubmit blocking must return Err(PromptSuppressed) + // ========================================================================= + + /// Helper: creates a RequestPayload EventData with the given request_count. + fn request_event(request_count: usize) -> EventData { + use forge_domain::{Agent, ModelId, ProviderId}; + let agent = Agent::new( + "test-agent", + ProviderId::from("test-provider".to_string()), + ModelId::new("test-model"), + ); + EventData::new( + agent, + ModelId::new("test-model"), + forge_domain::RequestPayload::new(request_count), + ) + } + + /// Helper: creates a Conversation with a context containing one user + /// message. + fn conversation_with_user_msg(msg: &str) -> forge_domain::Conversation { + let mut conv = forge_domain::Conversation::generate(); + let mut ctx = forge_domain::Context::default(); + ctx.messages + .push(forge_domain::ContextMessage::user(msg.to_string(), None).into()); + conv.context = Some(ctx); + conv + } + + #[tokio::test] + async fn test_user_prompt_submit_block_exit2_returns_error() { + // TC16: exit code 2 must return PromptSuppressed error. + let handler = handler_for_event( + StubInfra::new(Some(2), "", "policy violation"), + r#"{"UserPromptSubmit": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = request_event(0); + let mut conversation = conversation_with_user_msg("hello"); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.downcast_ref::() + .is_some() + ); + assert!(err.to_string().contains("policy violation")); + + // Warning should have been pushed to event.warnings + assert_eq!(event.warnings.len(), 1); + assert!(event.warnings[0].contains("policy violation")); + } + + #[tokio::test] + async fn test_user_prompt_submit_block_json_decision_returns_error() { + // JSON {"decision":"block","reason":"Content policy"} must block. + let handler = handler_for_event( + StubInfra::new( + Some(0), + r#"{"decision":"block","reason":"Content policy"}"#, + "", + ), + r#"{"UserPromptSubmit": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = request_event(0); + let mut conversation = conversation_with_user_msg("test"); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.downcast_ref::() + .is_some() + ); + assert!(err.to_string().contains("Content policy")); + } + + #[tokio::test] + async fn test_user_prompt_submit_block_continue_false_returns_error() { + // {"continue":false,"reason":"Blocked by admin"} must block. + let handler = handler_for_event( + StubInfra::new( + Some(0), + r#"{"continue":false,"reason":"Blocked by admin"}"#, + "", + ), + r#"{"UserPromptSubmit": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = request_event(0); + let mut conversation = conversation_with_user_msg("test"); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_err()); + assert!( + result + .unwrap_err() + .downcast_ref::() + .is_some() + ); + } + + #[tokio::test] + async fn test_user_prompt_submit_allow_returns_ok() { + // Exit 0 + empty stdout => allow, no feedback injected. + let handler = handler_for_event( + NullInfra, + r#"{"UserPromptSubmit": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = request_event(0); + let mut conversation = conversation_with_user_msg("hello"); + let original_msg_count = conversation.context.as_ref().unwrap().messages.len(); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + let actual_msg_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_msg_count, original_msg_count); + } + + #[tokio::test] + async fn test_user_prompt_submit_non_blocking_error_returns_ok() { + // Exit code 1 is a non-blocking error — must NOT block. + let handler = handler_for_event( + StubInfra::new(Some(1), "", "some error"), + r#"{"UserPromptSubmit": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = request_event(0); + let mut conversation = conversation_with_user_msg("hello"); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_user_prompt_submit_skipped_on_subsequent_requests() { + // request_count > 0 means it's a retry, not a user prompt. + let handler = handler_for_event( + NullInfra, + r#"{"UserPromptSubmit": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = request_event(1); // subsequent request + let mut conversation = conversation_with_user_msg("hello"); + let original_msg_count = conversation.context.as_ref().unwrap().messages.len(); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + let actual_msg_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_msg_count, original_msg_count); + } + + // ========================================================================= + // Stop hook tests: Stop hooks fire and inject feedback for continuation + // ========================================================================= + + /// Helper: creates an EndPayload EventData with optional stop_hook_active. + fn end_event() -> EventData { + use forge_domain::{Agent, ModelId, ProviderId}; + let agent = Agent::new( + "test-agent", + ProviderId::from("test-provider".to_string()), + ModelId::new("test-model"), + ); + EventData::new( + agent, + ModelId::new("test-model"), + forge_domain::EndPayload { stop_hook_active: false }, + ) + } + + #[tokio::test] + async fn test_stop_hook_exit_code_2_injects_message_and_sets_active() { + let handler = handler_for_event( + StubInfra::new(Some(2), "", "keep working"), + r#"{"Stop": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = end_event(); + let mut conversation = conversation_with_user_msg("hello"); + let original_msg_count = conversation.context.as_ref().unwrap().messages.len(); + + let result = handler.handle(&mut event, &mut conversation).await; + + // Result is Ok (never errors) + assert!(result.is_ok()); + // A conversation message should have been injected for continuation + let actual_msg_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_msg_count, original_msg_count + 1); + // The injected message should contain the blocking reason + let last_msg = conversation + .context + .as_ref() + .unwrap() + .messages + .last() + .unwrap(); + let content = last_msg.content().unwrap(); + assert!(content.contains("keep working")); + assert!(content.contains("Stop hook feedback")); + // stop_hook_active should be set to true for the next iteration + assert!(event.payload.stop_hook_active); + } + + #[tokio::test] + async fn test_stop_hook_allow_returns_ok() { + let handler = handler_for_event( + NullInfra, + r#"{"Stop": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = end_event(); + let mut conversation = conversation_with_user_msg("hello"); + let original_msg_count = conversation.context.as_ref().unwrap().messages.len(); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + // No continue message should be injected + let actual_msg_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_msg_count, original_msg_count); + } + + #[tokio::test] + async fn test_stop_hook_json_continue_false_injects_message() { + let handler = handler_for_event( + StubInfra::new( + Some(0), + r#"{"continue":false,"stopReason":"keep working"}"#, + "", + ), + r#"{"Stop": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = end_event(); + let mut conversation = conversation_with_user_msg("hello"); + let original_msg_count = conversation.context.as_ref().unwrap().messages.len(); + + let result = handler.handle(&mut event, &mut conversation).await; + + // Result is Ok (never errors) + assert!(result.is_ok()); + // A conversation message should have been injected + let actual_msg_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_msg_count, original_msg_count + 1); + // stop_hook_active should be set to true + assert!(event.payload.stop_hook_active); + } + + #[tokio::test] + async fn test_session_end_and_stop_hooks_both_fire() { + // Both SessionEnd and Stop hooks should execute. Stop hooks inject + // messages for continuation when they block. + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering as AtomicOrdering}; + + #[derive(Clone)] + struct CountingInfra { + call_count: Arc, + } + + #[async_trait::async_trait] + impl HookCommandService for CountingInfra { + async fn execute_command_with_input( + &self, + command: String, + _: PathBuf, + _: String, + _: HashMap, + ) -> anyhow::Result { + self.call_count.fetch_add(1, AtomicOrdering::SeqCst); + // Return exit 2 (blocking) + Ok(forge_domain::CommandOutput { + command, + exit_code: Some(2), + stdout: String::new(), + stderr: "blocked".to_string(), + }) + } + } + + // Config with both SessionEnd and Stop hooks + let json = r#"{ + "SessionEnd": [{"hooks": [{"type": "command", "command": "echo session-end"}]}], + "Stop": [{"hooks": [{"type": "command", "command": "echo stop"}]}] + }"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + let call_count = Arc::new(AtomicU32::new(0)); + let handler = UserHookHandler::new( + CountingInfra { call_count: call_count.clone() }, + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-test".to_string(), + ); + + let mut event = end_event(); + let mut conversation = conversation_with_user_msg("hello"); + + let result = handler.handle(&mut event, &mut conversation).await; + + // Result is Ok + assert!(result.is_ok()); + // Both SessionEnd AND Stop hooks should have been called (2 total) + let actual = call_count.load(AtomicOrdering::SeqCst); + assert_eq!(actual, 2); + // Stop hook blocked, so stop_hook_active should be true + assert!(event.payload.stop_hook_active); + } + + #[tokio::test] + async fn test_stop_hook_active_true_passed_to_hook_input() { + // When stop_hook_active is true (re-entrant call), the hook should + // receive it in its JSON input. + use std::sync::{Arc, Mutex}; + + #[derive(Clone)] + struct CapturingInfra { + captured_input: Arc>>, + } + + #[async_trait::async_trait] + impl HookCommandService for CapturingInfra { + async fn execute_command_with_input( + &self, + command: String, + _: PathBuf, + input: String, + _: HashMap, + ) -> anyhow::Result { + *self.captured_input.lock().unwrap() = Some(input); + Ok(forge_domain::CommandOutput { + command, + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }) + } + } + + let captured = Arc::new(Mutex::new(None)); + let handler = handler_for_event( + CapturingInfra { captured_input: captured.clone() }, + r#"{"Stop": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + // Create event with stop_hook_active = true (simulating re-entrant call) + let mut event = { + use forge_domain::{Agent, ModelId, ProviderId}; + let agent = Agent::new( + "test-agent", + ProviderId::from("test-provider".to_string()), + ModelId::new("test-model"), + ); + EventData::new( + agent, + ModelId::new("test-model"), + forge_domain::EndPayload { stop_hook_active: true }, + ) + }; + let mut conversation = conversation_with_user_msg("hello"); + + let result = handler.handle(&mut event, &mut conversation).await; + assert!(result.is_ok()); + + // Verify the hook received stop_hook_active = true in its JSON input + let input_json = captured.lock().unwrap().clone().unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&input_json).unwrap(); + assert_eq!(parsed["stop_hook_active"], serde_json::Value::Bool(true)); + } + + #[tokio::test] + async fn test_stop_hook_allow_does_not_inject_message() { + // When a Stop hook allows the stop (exit 0, no blocking JSON), no + // message should be injected and stop_hook_active should remain false. + let handler = handler_for_event( + NullInfra, + r#"{"Stop": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = end_event(); + let mut conversation = conversation_with_user_msg("hello"); + let original_msg_count = conversation.context.as_ref().unwrap().messages.len(); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + // No message injected + let actual_msg_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_msg_count, original_msg_count); + // stop_hook_active should remain false + assert!(!event.payload.stop_hook_active); + } + + #[tokio::test] + async fn test_stop_hook_active_false_on_initial_call() { + // On the first call, stop_hook_active should be false in the JSON input. + use std::sync::{Arc, Mutex}; + + #[derive(Clone)] + struct CapturingInfra2 { + captured_input: Arc>>, + } + + #[async_trait::async_trait] + impl HookCommandService for CapturingInfra2 { + async fn execute_command_with_input( + &self, + command: String, + _: PathBuf, + input: String, + _: HashMap, + ) -> anyhow::Result { + *self.captured_input.lock().unwrap() = Some(input); + Ok(forge_domain::CommandOutput { + command, + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }) + } + } + + let captured = Arc::new(Mutex::new(None)); + let handler = handler_for_event( + CapturingInfra2 { captured_input: captured.clone() }, + r#"{"Stop": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = end_event(); // stop_hook_active defaults to false + let mut conversation = conversation_with_user_msg("hello"); + + let result = handler.handle(&mut event, &mut conversation).await; + assert!(result.is_ok()); + + // Verify stop_hook_active is false in the JSON + let input_json = captured.lock().unwrap().clone().unwrap(); + let parsed: serde_json::Value = serde_json::from_str(&input_json).unwrap(); + assert_eq!(parsed["stop_hook_active"], serde_json::Value::Bool(false)); + } + + // ========================================================================= + // BUG-3 Tests: PostToolUse feedback must use wrapper + // ========================================================================= + + /// Helper: creates a ToolcallEndPayload EventData with a successful tool + /// result. + fn toolcall_end_event( + tool_name: &str, + is_error: bool, + ) -> EventData { + use forge_domain::{Agent, ModelId, ProviderId, ToolCallFull, ToolResult}; + let agent = Agent::new( + "test-agent", + ProviderId::from("test-provider".to_string()), + ModelId::new("test-model"), + ); + let tool_call = ToolCallFull::new(tool_name); + let result = if is_error { + ToolResult::new(tool_name).failure(anyhow::anyhow!("tool failed")) + } else { + ToolResult::new(tool_name).success("output data") + }; + EventData::new( + agent, + ModelId::new("test-model"), + forge_domain::ToolcallEndPayload::new(tool_call, result), + ) + } + + #[tokio::test] + async fn test_post_tool_use_block_injects_important_feedback() { + let handler = handler_for_event( + StubInfra::new(Some(2), "", "sensitive data detected"), + r#"{"PostToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = toolcall_end_event("shell", false); + let mut conversation = conversation_with_user_msg("hello"); + + let result = handler.handle(&mut event, &mut conversation).await; + + // PostToolUse always returns Ok + assert!(result.is_ok()); + + // Warning pushed to event.warnings + assert_eq!(event.warnings.len(), 1); + assert!(event.warnings[0].contains("sensitive data detected")); + + // Feedback stored on payload for the orchestrator to inject after + // append_message + let feedback = event.payload.hook_feedback.as_ref().unwrap(); + assert!(feedback.contains("hook feedback")); + assert!(feedback.contains("sensitive data detected")); + } + + #[tokio::test] + async fn test_post_tool_use_block_json_injects_feedback() { + let handler = handler_for_event( + StubInfra::new( + Some(0), + r#"{"decision":"block","reason":"PII detected"}"#, + "", + ), + r#"{"PostToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = toolcall_end_event("shell", false); + let mut conversation = conversation_with_user_msg("hello"); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + assert_eq!(event.warnings.len(), 1); + assert!(event.warnings[0].contains("PII detected")); + + // Feedback stored on payload for the orchestrator to inject after + // append_message + let feedback = event.payload.hook_feedback.as_ref().unwrap(); + assert!(feedback.contains("hook feedback")); + assert!(feedback.contains("PII detected")); + } + + #[tokio::test] + async fn test_post_tool_use_allow_no_feedback() { + let handler = handler_for_event( + NullInfra, + r#"{"PostToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = toolcall_end_event("shell", false); + let mut conversation = conversation_with_user_msg("hello"); + let original_msg_count = conversation.context.as_ref().unwrap().messages.len(); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + let actual_msg_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_msg_count, original_msg_count); + } + + #[tokio::test] + async fn test_post_tool_use_non_blocking_error_no_feedback() { + let handler = handler_for_event( + StubInfra::new(Some(1), "", "non-blocking error"), + r#"{"PostToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = toolcall_end_event("shell", false); + let mut conversation = conversation_with_user_msg("hello"); + let original_msg_count = conversation.context.as_ref().unwrap().messages.len(); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + let actual_msg_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_msg_count, original_msg_count); + } + + #[tokio::test] + async fn test_post_tool_use_failure_event_fires_separately() { + // PostToolUseFailure is a separate event from PostToolUse. + // Configure only PostToolUseFailure hooks and fire with is_error=true. + let handler = handler_for_event( + StubInfra::new(Some(2), "", "error flagged"), + r#"{"PostToolUseFailure": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + + let mut event = toolcall_end_event("shell", true); + let mut conversation = conversation_with_user_msg("hello"); + + let result = handler.handle(&mut event, &mut conversation).await; + + assert!(result.is_ok()); + assert_eq!(event.warnings.len(), 1); + assert!(event.warnings[0].contains("error flagged")); + } + + #[tokio::test] + async fn test_post_tool_use_feedback_contains_tool_name() { + let handler = handler_for_event( + StubInfra::new(Some(2), "", "flagged"), + r#"{"PostToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = toolcall_end_event("shell", false); + let mut conversation = conversation_with_user_msg("hello"); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + // The warning should reference the tool name + assert_eq!(event.warnings.len(), 1); + assert!(event.warnings[0].contains("shell")); + } + + // ========================================================================= + // Tests: additionalContext injection + // ========================================================================= + + #[tokio::test] + async fn test_session_start_injects_additional_context() { + let json = r#"{"SessionStart": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + let handler = UserHookHandler::new( + StubInfra::new( + Some(0), + r#"{"additionalContext": "Remember to follow coding standards"}"#, + "", + ), + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-ctx".to_string(), + ); + + let agent = forge_domain::Agent::new( + "test-agent", + forge_domain::ProviderId::from("test-provider".to_string()), + forge_domain::ModelId::new("test-model"), + ); + let mut event = EventData::new( + agent, + forge_domain::ModelId::new("test-model"), + forge_domain::StartPayload, + ); + let mut conversation = conversation_with_user_msg("hello"); + let original_count = conversation.context.as_ref().unwrap().messages.len(); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + let actual_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_count, original_count + 1); + + let last_msg = conversation + .context + .as_ref() + .unwrap() + .messages + .last() + .unwrap(); + let content = last_msg.content().unwrap(); + assert!(content.contains("SessionStart hook additional context")); + assert!(content.contains("Remember to follow coding standards")); + } + + #[tokio::test] + async fn test_user_prompt_submit_injects_additional_context() { + let handler = UserHookHandler::new( + StubInfra::new( + Some(0), + r#"{"additionalContext": "Remember to follow coding standards"}"#, + "", + ), + BTreeMap::new(), + serde_json::from_str( + r#"{"UserPromptSubmit": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ) + .unwrap(), + PathBuf::from("/tmp"), + "sess-ctx".to_string(), + ); + + let mut event = request_event(0); + let mut conversation = conversation_with_user_msg("test prompt"); + let original_count = conversation.context.as_ref().unwrap().messages.len(); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + let actual_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_count, original_count + 1); + + let last_msg = conversation + .context + .as_ref() + .unwrap() + .messages + .last() + .unwrap(); + let content = last_msg.content().unwrap(); + assert!(content.contains("UserPromptSubmit hook additional context")); + assert!(content.contains("Remember to follow coding standards")); + } + + #[tokio::test] + async fn test_pre_tool_use_injects_additional_context() { + let json = r#"{"PreToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + let handler = UserHookHandler::new( + StubInfra::new( + Some(0), + r#"{"additionalContext": "Remember to follow coding standards"}"#, + "", + ), + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-ctx".to_string(), + ); + + let agent = forge_domain::Agent::new( + "test-agent", + forge_domain::ProviderId::from("test-provider".to_string()), + forge_domain::ModelId::new("test-model"), + ); + let tool_call = forge_domain::ToolCallFull::new("shell").arguments( + forge_domain::ToolCallArguments::from_json(r#"{"command": "ls"}"#), + ); + let mut event = EventData::new( + agent, + forge_domain::ModelId::new("test-model"), + forge_domain::ToolcallStartPayload::new(tool_call), + ); + let mut conversation = conversation_with_user_msg("hello"); + let original_count = conversation.context.as_ref().unwrap().messages.len(); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + let actual_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_count, original_count + 1); + + let last_msg = conversation + .context + .as_ref() + .unwrap() + .messages + .last() + .unwrap(); + let content = last_msg.content().unwrap(); + assert!(content.contains("PreToolUse hook additional context")); + assert!(content.contains("Remember to follow coding standards")); + } + + #[tokio::test] + async fn test_post_tool_use_injects_additional_context() { + let handler = handler_for_event( + StubInfra::new( + Some(0), + r#"{"additionalContext": "Remember to follow coding standards"}"#, + "", + ), + r#"{"PostToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = toolcall_end_event("shell", false); + let mut conversation = conversation_with_user_msg("hello"); + let original_count = conversation.context.as_ref().unwrap().messages.len(); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + let actual_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_count, original_count + 1); + + let last_msg = conversation + .context + .as_ref() + .unwrap() + .messages + .last() + .unwrap(); + let content = last_msg.content().unwrap(); + assert!(content.contains("PostToolUse hook additional context")); + assert!(content.contains("Remember to follow coding standards")); + } + + #[tokio::test] + async fn test_no_additional_context_when_empty() { + // NullInfra returns empty stdout => no additionalContext + let handler = handler_for_event( + NullInfra, + r#"{"PostToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#, + ); + let mut event = toolcall_end_event("shell", false); + let mut conversation = conversation_with_user_msg("hello"); + let original_count = conversation.context.as_ref().unwrap().messages.len(); + + handler.handle(&mut event, &mut conversation).await.unwrap(); + + let actual_count = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(actual_count, original_count); + } + + #[test] + fn test_collect_additional_context_from_results() { + let results = vec![ + ( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"additionalContext": "first context"}"#.to_string(), + stderr: String::new(), + }, + ), + ( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"additionalContext": "second context"}"#.to_string(), + stderr: String::new(), + }, + ), + ]; + let actual = UserHookHandler::::collect_additional_context(&results); + assert_eq!( + actual, + vec![ + ("test-cmd".to_string(), "first context".to_string()), + ("test-cmd".to_string(), "second context".to_string()) + ] + ); + } + + #[test] + fn test_collect_additional_context_skips_empty() { + let results = vec![ + ( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"additionalContext": ""}"#.to_string(), + stderr: String::new(), + }, + ), + ( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"additionalContext": " "}"#.to_string(), + stderr: String::new(), + }, + ), + ]; + let actual = UserHookHandler::::collect_additional_context(&results); + assert!(actual.is_empty()); + } + + #[test] + fn test_collect_additional_context_skips_non_success() { + let results = vec![( + "test-cmd".to_string(), + HookExecutionResult { + exit_code: Some(1), + stdout: r#"{"additionalContext": "should not appear"}"#.to_string(), + stderr: String::new(), + }, + )]; + let actual = UserHookHandler::::collect_additional_context(&results); + assert!(actual.is_empty()); + } +} diff --git a/crates/forge_app/src/infra.rs b/crates/forge_app/src/infra.rs index e4b49bfb66..3cc037dddb 100644 --- a/crates/forge_app/src/infra.rs +++ b/crates/forge_app/src/infra.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::hash::Hash; use std::path::{Path, PathBuf}; @@ -162,6 +162,28 @@ pub trait CommandInfra: Send + Sync { working_dir: PathBuf, env_vars: Option>, ) -> anyhow::Result; + + /// Executes a shell command with stdin input. + /// + /// Pipes `stdin_input` to the process stdin, captures stdout and stderr, + /// and waits for the process to complete. Timeout enforcement is handled + /// by the caller. + /// + /// # Arguments + /// * `command` - Shell command string to execute. + /// * `working_dir` - Working directory for the command. + /// * `stdin_input` - Data to pipe to the process stdin. + /// * `env_vars` - Additional environment variables as key-value pairs. + /// + /// # Errors + /// Returns an error if the process cannot be spawned. + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result; } #[async_trait::async_trait] diff --git a/crates/forge_app/src/orch.rs b/crates/forge_app/src/orch.rs index 86157c24e2..3c45d50d00 100644 --- a/crates/forge_app/src/orch.rs +++ b/crates/forge_app/src/orch.rs @@ -4,7 +4,7 @@ use std::time::Duration; use async_recursion::async_recursion; use derive_setters::Setters; -use forge_domain::{Agent, *}; +use forge_domain::{Agent, PromptSuppressed, *}; use forge_template::Element; use futures::future::join_all; use tokio::sync::Notify; @@ -52,13 +52,15 @@ impl> Orc &self.conversation } - // Helper function to get all tool results from a vector of tool calls + // Returns tool results and any PostToolUse hook feedback messages. + // Feedback messages must be injected into context AFTER append_message + // so the LLM sees them in the correct order (after tool results). #[async_recursion] async fn execute_tool_calls( &mut self, tool_calls: &[ToolCallFull], tool_context: &ToolCallContext, - ) -> anyhow::Result> { + ) -> anyhow::Result<(Vec<(ToolCallFull, ToolResult)>, Vec)> { let task_tool_name = ToolKind::Task.name(); // Use a case-insensitive comparison since the model may send "Task" or "task". @@ -98,6 +100,7 @@ impl> Orc // and hooks). let mut other_results: Vec<(ToolCallFull, ToolResult)> = Vec::with_capacity(other_calls.len()); + let mut hook_feedbacks: Vec = Vec::new(); for tool_call in &other_calls { // Send the start notification for system tools and not agent as a tool let is_system_tool = system_tools.contains(&tool_call.name); @@ -114,38 +117,70 @@ impl> Orc notifier.notified().await; } - // Fire the ToolcallStart lifecycle event - let toolcall_start_event = LifecycleEvent::ToolcallStart(EventData::new( + // Fire the ToolcallStart lifecycle event. + // If a hook returns an error (e.g., PreToolUse hook blocked the + // call), skip execution and record an error result instead. + // A PreToolUse hook may also modify the tool call arguments in-flight + // via the AllowWithUpdate path. + let mut toolcall_start_event = LifecycleEvent::ToolcallStart(EventData::new( self.agent.clone(), self.agent.model.clone(), ToolcallStartPayload::new((*tool_call).clone()), )); - self.hook - .handle(&toolcall_start_event, &mut self.conversation) - .await?; - - // Execute the tool - let tool_result = self - .services - .call(&self.agent, tool_context, (*tool_call).clone()) + let hook_result = self + .hook + .handle(&mut toolcall_start_event, &mut self.conversation) .await; + self.drain_hook_warnings(&mut toolcall_start_event).await?; + + let (effective_tool_call, tool_result) = if let Err(hook_err) = hook_result { + // Hook blocked this tool call — notify the UI and produce an + // error ToolResult so the model sees feedback without aborting. + self.send(ChatResponse::HookError { + tool_name: tool_call.name.clone(), + reason: hook_err.to_string(), + }) + .await?; + let result = ToolResult::from((*tool_call).clone()).failure(hook_err); + ((*tool_call).clone(), result) + } else { + // Extract the (possibly modified) tool call from the event. + // A PreToolUse hook may have updated the tool call arguments. + let effective = match toolcall_start_event { + LifecycleEvent::ToolcallStart(data) => data.payload.tool_call, + _ => unreachable!("ToolcallStart event cannot change variant"), + }; + let result = self + .services + .call(&self.agent, tool_context, effective.clone()) + .await; + (effective, result) + }; // Fire the ToolcallEnd lifecycle event (fires on both success and failure) - let toolcall_end_event = LifecycleEvent::ToolcallEnd(EventData::new( + let mut toolcall_end_event = LifecycleEvent::ToolcallEnd(EventData::new( self.agent.clone(), self.agent.model.clone(), - ToolcallEndPayload::new((*tool_call).clone(), tool_result.clone()), + ToolcallEndPayload::new(effective_tool_call.clone(), tool_result.clone()), )); self.hook - .handle(&toolcall_end_event, &mut self.conversation) + .handle(&mut toolcall_end_event, &mut self.conversation) .await?; + self.drain_hook_warnings(&mut toolcall_end_event).await?; + + // Collect PostToolUse hook feedback to inject after append_message. + if let LifecycleEvent::ToolcallEnd(ref data) = toolcall_end_event + && let Some(feedback) = &data.payload.hook_feedback + { + hook_feedbacks.push(feedback.clone()); + } // Send the end notification for system tools and not agent as a tool if is_system_tool { self.send(ChatResponse::ToolCallEnd(tool_result.clone())) .await?; } - other_results.push(((*tool_call).clone(), tool_result)); + other_results.push((effective_tool_call, tool_result)); } // Reconstruct results in the original order of tool_calls. @@ -162,7 +197,17 @@ impl> Orc }) .collect(); - Ok(tool_call_records) + Ok((tool_call_records, hook_feedbacks)) + } + + /// Drains any hook warnings from a lifecycle event and emits them to the + /// UI as `ChatResponse::HookWarning` messages. + async fn drain_hook_warnings(&self, event: &mut LifecycleEvent) -> anyhow::Result<()> { + let warnings = event.drain_warnings(); + for message in warnings { + self.send(ChatResponse::HookWarning { message }).await?; + } + Ok(()) } async fn send(&self, message: ChatResponse) -> anyhow::Result<()> { @@ -231,14 +276,15 @@ impl> Orc let mut context = self.conversation.context.clone().unwrap_or_default(); // Fire the Start lifecycle event - let start_event = LifecycleEvent::Start(EventData::new( + let mut start_event = LifecycleEvent::Start(EventData::new( self.agent.clone(), model_id.clone(), StartPayload, )); self.hook - .handle(&start_event, &mut self.conversation) + .handle(&mut start_event, &mut self.conversation) .await?; + self.drain_hook_warnings(&mut start_event).await?; // Signals that the loop should suspend (task may or may not be completed) let mut should_yield = false; @@ -248,6 +294,11 @@ impl> Orc let mut request_count = 0; + // Tracks whether a Stop hook forced continuation. Passed to the + // next EndPayload so hook scripts can detect re-entrancy and + // avoid infinite loops (matches Claude Code's `stop_hook_active`). + let mut stop_hook_active = false; + // Retrieve the number of requests allowed per tick. let max_requests_per_turn = self.agent.max_requests_per_turn; let tool_context = @@ -258,14 +309,26 @@ impl> Orc self.conversation.context = Some(context.clone()); self.services.update(self.conversation.clone()).await?; - let request_event = LifecycleEvent::Request(EventData::new( + let mut request_event = LifecycleEvent::Request(EventData::new( self.agent.clone(), model_id.clone(), RequestPayload::new(request_count), )); - self.hook - .handle(&request_event, &mut self.conversation) - .await?; + if let Err(e) = self + .hook + .handle(&mut request_event, &mut self.conversation) + .await + { + self.drain_hook_warnings(&mut request_event).await?; + if e.downcast_ref::().is_some() { + // Prompt was blocked by a UserPromptSubmit hook. + // Persist the conversation and exit cleanly. + self.services.update(self.conversation.clone()).await?; + break; + } + return Err(e); + } + self.drain_hook_warnings(&mut request_event).await?; let message = crate::retry::retry_with_config( &self.config.clone().retry.unwrap_or_default(), @@ -298,14 +361,15 @@ impl> Orc .await?; // Fire the Response lifecycle event - let response_event = LifecycleEvent::Response(EventData::new( + let mut response_event = LifecycleEvent::Response(EventData::new( self.agent.clone(), model_id.clone(), ResponsePayload::new(message.clone()), )); self.hook - .handle(&response_event, &mut self.conversation) + .handle(&mut response_event, &mut self.conversation) .await?; + self.drain_hook_warnings(&mut response_event).await?; // Turn is completed, if finish_reason is 'stop'. Gemini models return stop as // finish reason with tool calls. @@ -320,7 +384,7 @@ impl> Orc .any(|call| ToolCatalog::should_yield(&call.name)); // Process tool calls and update context - let mut tool_call_records = self + let (mut tool_call_records, hook_feedbacks) = self .execute_tool_calls(&message.tool_calls, &tool_context) .await?; @@ -357,6 +421,14 @@ impl> Orc message.phase, ); + // Inject PostToolUse hook feedback AFTER the tool results are appended. + // This ensures the LLM sees: [tool_result] [hook_feedback], not the reverse. + for feedback in hook_feedbacks { + context + .messages + .push(ContextMessage::user(feedback, None).into()); + } + if self.error_tracker.limit_reached() { self.send(ChatResponse::Interrupt { reason: InterruptionReason::MaxToolFailurePerTurnLimitReached { @@ -407,24 +479,31 @@ impl> Orc // it adds messages if should_yield { let end_count_before = self.conversation.len(); + let mut end_event = LifecycleEvent::End(EventData::new( + self.agent.clone(), + model_id.clone(), + EndPayload { stop_hook_active }, + )); self.hook - .handle( - &LifecycleEvent::End(EventData::new( - self.agent.clone(), - model_id.clone(), - EndPayload, - )), - &mut self.conversation, - ) + .handle(&mut end_event, &mut self.conversation) .await?; + self.drain_hook_warnings(&mut end_event).await?; self.services.update(self.conversation.clone()).await?; // Check if End hook added messages - if so, continue the loop if self.conversation.len() > end_count_before { - // End hook added messages, sync context and continue + // End hook added messages, sync context and continue. + // Propagate stop_hook_active from the event payload so the + // next iteration knows a Stop hook caused this continuation. + if let LifecycleEvent::End(ref data) = end_event { + stop_hook_active = data.payload.stop_hook_active; + } if let Some(updated_context) = &self.conversation.context { context = updated_context.clone(); } should_yield = false; + } else { + // No continuation -- reset for next user turn. + stop_hook_active = false; } } } diff --git a/crates/forge_app/src/services.rs b/crates/forge_app/src/services.rs index 4ec2c49809..f2ea8b77f8 100644 --- a/crates/forge_app/src/services.rs +++ b/crates/forge_app/src/services.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::time::Duration; @@ -486,6 +487,12 @@ pub trait CommandLoaderService: Send + Sync { async fn get_commands(&self) -> anyhow::Result>; } +#[async_trait::async_trait] +pub trait UserHookConfigService: Send + Sync { + /// Loads user hook configuration from `.forge.toml`. + async fn get_user_hook_config(&self) -> anyhow::Result; +} + #[async_trait::async_trait] pub trait PolicyService: Send + Sync { /// Check if an operation is allowed and handle user confirmation if needed @@ -541,6 +548,34 @@ pub trait ProviderAuthService: Send + Sync { ) -> anyhow::Result>; } +/// Service for executing hook commands with stdin input and timeout. +/// +/// Abstracts over the underlying process execution so that `UserHookExecutor` +/// depends on a service rather than infrastructure directly. +#[async_trait::async_trait] +pub trait HookCommandService: Send + Sync { + /// Executes a shell command with stdin input. + /// + /// Pipes `stdin_input` to the process stdin and captures stdout/stderr. + /// Timeout enforcement is handled by the caller. + /// + /// # Arguments + /// * `command` - Shell command string to execute. + /// * `working_dir` - Working directory for the command. + /// * `stdin_input` - Data to pipe to the process stdin. + /// * `env_vars` - Additional environment variables as key-value pairs. + /// + /// # Errors + /// Returns an error if the process cannot be spawned. + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result; +} + pub trait Services: Send + Sync + 'static + Clone + EnvironmentInfra { type ProviderService: ProviderService; type AppConfigService: AppConfigService; @@ -565,10 +600,12 @@ pub trait Services: Send + Sync + 'static + Clone + EnvironmentInfra { type AuthService: AuthService; type AgentRegistry: AgentRegistry; type CommandLoaderService: CommandLoaderService; + type UserHookConfigService: UserHookConfigService; type PolicyService: PolicyService; type ProviderAuthService: ProviderAuthService; type WorkspaceService: WorkspaceService; type SkillFetchService: SkillFetchService; + type HookCommandService: HookCommandService + Clone; fn provider_service(&self) -> &Self::ProviderService; fn config_service(&self) -> &Self::AppConfigService; @@ -593,10 +630,12 @@ pub trait Services: Send + Sync + 'static + Clone + EnvironmentInfra { fn auth_service(&self) -> &Self::AuthService; fn agent_registry(&self) -> &Self::AgentRegistry; fn command_loader_service(&self) -> &Self::CommandLoaderService; + fn user_hook_config_service(&self) -> &Self::UserHookConfigService; fn policy_service(&self) -> &Self::PolicyService; fn provider_auth_service(&self) -> &Self::ProviderAuthService; fn workspace_service(&self) -> &Self::WorkspaceService; fn skill_fetch_service(&self) -> &Self::SkillFetchService; + fn hook_command_service(&self) -> &Self::HookCommandService; } #[async_trait::async_trait] @@ -942,6 +981,13 @@ impl CommandLoaderService for I { } } +#[async_trait::async_trait] +impl UserHookConfigService for I { + async fn get_user_hook_config(&self) -> anyhow::Result { + self.user_hook_config_service().get_user_hook_config().await + } +} + #[async_trait::async_trait] impl PolicyService for I { async fn check_operation_permission( diff --git a/crates/forge_config/src/config.rs b/crates/forge_config/src/config.rs index 6b9baaa213..17429b2c5e 100644 --- a/crates/forge_config/src/config.rs +++ b/crates/forge_config/src/config.rs @@ -9,7 +9,8 @@ use serde::{Deserialize, Serialize}; use crate::reader::ConfigReader; use crate::writer::ConfigWriter; use crate::{ - AutoDumpFormat, Compact, Decimal, HttpConfig, ModelConfig, ReasoningConfig, RetryConfig, Update, + AutoDumpFormat, Compact, Decimal, HttpConfig, ModelConfig, ReasoningConfig, RetryConfig, + Update, UserHookConfig, }; /// Wire protocol a provider uses for chat completions. @@ -281,6 +282,13 @@ pub struct ForgeConfig { /// when a task ends and reminds the LLM about them. #[serde(default)] pub verify_todos: bool, + + /// User hook configuration loaded from the `[hooks]` section. + /// + /// Maps lifecycle event names (e.g. `PreToolUse`, `Stop`) to lists of + /// matcher groups that execute shell commands at each event point. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub hooks: Option, } impl ForgeConfig { @@ -353,4 +361,80 @@ mod tests { assert_eq!(actual.temperature, fixture.temperature); } + + #[test] + fn test_hooks_toml_round_trip() { + use crate::{ + UserHookConfig, UserHookEntry, UserHookEventName, UserHookMatcherGroup, UserHookType, + }; + + let mut events = std::collections::HashMap::new(); + events.insert( + UserHookEventName::PreToolUse, + vec![UserHookMatcherGroup { + matcher: Some("Bash".to_string()), + hooks: vec![UserHookEntry { + hook_type: UserHookType::Command, + command: Some("check.sh".to_string()), + timeout: Some(5000), + }], + }], + ); + let fixture = ForgeConfig { hooks: Some(UserHookConfig { events }), ..Default::default() }; + + let toml = toml_edit::ser::to_string_pretty(&fixture).unwrap(); + let actual: ForgeConfig = toml_edit::de::from_str(&toml).unwrap(); + + assert_eq!(actual.hooks, fixture.hooks); + } + + #[test] + fn test_config_without_hooks_parses() { + let toml = "restricted = false\ntool_supported = true\n"; + let actual: ForgeConfig = toml_edit::de::from_str(toml).unwrap(); + assert_eq!(actual.hooks, None); + } + + /// Verifies hooks survive the `config` crate pipeline (which lowercases + /// TOML keys internally). This is the path used by `read_global()`. + #[test] + fn test_hooks_through_config_crate_pipeline() { + use crate::UserHookEventName; + + let toml = include_str!("fixtures/hook_config_pipeline.toml"); + let result = ConfigReader::default().read_toml(toml).build(); + let actual = result.expect("hooks should parse through config crate pipeline"); + + let hooks = actual.hooks.expect("hooks should be Some"); + let groups = hooks.get_groups(&UserHookEventName::PreToolUse); + assert_eq!(groups.len(), 1, "expected 1 PreToolUse matcher group"); + assert_eq!(groups[0].matcher, Some("Bash".to_string())); + assert_eq!(groups[0].hooks.len(), 1); + assert_eq!( + groups[0].hooks[0].command, + Some("echo 'blocked'".to_string()) + ); + } + + /// Verifies hooks survive when layered with defaults via `ConfigReader`. + #[test] + fn test_hooks_layered_with_defaults() { + use crate::UserHookEventName; + + let hooks_toml = include_str!("fixtures/hook_layered_with_defaults.toml"); + let actual = ConfigReader::default() + .read_defaults() + .read_toml(hooks_toml) + .build() + .expect("hooks should parse when layered with defaults"); + + let hooks = actual.hooks.expect("hooks should be Some"); + assert_eq!(hooks.get_groups(&UserHookEventName::PreToolUse).len(), 1); + assert_eq!(hooks.get_groups(&UserHookEventName::Stop).len(), 1); + assert!( + hooks + .get_groups(&UserHookEventName::SessionStart) + .is_empty() + ); + } } diff --git a/crates/forge_config/src/fixtures/hook_config_pipeline.toml b/crates/forge_config/src/fixtures/hook_config_pipeline.toml new file mode 100644 index 0000000000..aff8502b3c --- /dev/null +++ b/crates/forge_config/src/fixtures/hook_config_pipeline.toml @@ -0,0 +1,6 @@ +[[hooks.PreToolUse]] +matcher = "Bash" + + [[hooks.PreToolUse.hooks]] + type = "command" + command = "echo 'blocked'" diff --git a/crates/forge_config/src/fixtures/hook_layered_with_defaults.toml b/crates/forge_config/src/fixtures/hook_layered_with_defaults.toml new file mode 100644 index 0000000000..c2408cd916 --- /dev/null +++ b/crates/forge_config/src/fixtures/hook_layered_with_defaults.toml @@ -0,0 +1,12 @@ +[[hooks.PreToolUse]] +matcher = "Bash" + + [[hooks.PreToolUse.hooks]] + type = "command" + command = "echo 'pre'" + +[[hooks.Stop]] + + [[hooks.Stop.hooks]] + type = "command" + command = "stop.sh" diff --git a/crates/forge_config/src/fixtures/hook_multiple_events.toml b/crates/forge_config/src/fixtures/hook_multiple_events.toml new file mode 100644 index 0000000000..832fb8cb41 --- /dev/null +++ b/crates/forge_config/src/fixtures/hook_multiple_events.toml @@ -0,0 +1,18 @@ +[[PreToolUse]] +matcher = "Bash" + + [[PreToolUse.hooks]] + type = "command" + command = "pre.sh" + +[[PostToolUse]] + + [[PostToolUse.hooks]] + type = "command" + command = "post.sh" + +[[Stop]] + + [[Stop.hooks]] + type = "command" + command = "stop.sh" diff --git a/crates/forge_config/src/fixtures/hook_no_matcher.toml b/crates/forge_config/src/fixtures/hook_no_matcher.toml new file mode 100644 index 0000000000..1bf222177b --- /dev/null +++ b/crates/forge_config/src/fixtures/hook_no_matcher.toml @@ -0,0 +1,5 @@ +[[PostToolUse]] + + [[PostToolUse.hooks]] + type = "command" + command = "always.sh" diff --git a/crates/forge_config/src/fixtures/hook_pre_tool_use.toml b/crates/forge_config/src/fixtures/hook_pre_tool_use.toml new file mode 100644 index 0000000000..159f1ddbe4 --- /dev/null +++ b/crates/forge_config/src/fixtures/hook_pre_tool_use.toml @@ -0,0 +1,6 @@ +[[PreToolUse]] +matcher = "Bash" + + [[PreToolUse.hooks]] + type = "command" + command = "echo 'blocked'" diff --git a/crates/forge_config/src/fixtures/hook_with_timeout.toml b/crates/forge_config/src/fixtures/hook_with_timeout.toml new file mode 100644 index 0000000000..b5f5f5083b --- /dev/null +++ b/crates/forge_config/src/fixtures/hook_with_timeout.toml @@ -0,0 +1,6 @@ +[[PreToolUse]] + + [[PreToolUse.hooks]] + type = "command" + command = "slow.sh" + timeout = 30000 diff --git a/crates/forge_config/src/hooks.rs b/crates/forge_config/src/hooks.rs new file mode 100644 index 0000000000..851a6efe61 --- /dev/null +++ b/crates/forge_config/src/hooks.rs @@ -0,0 +1,203 @@ +use std::collections::HashMap; + +use fake::Dummy; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use strum_macros::Display; + +/// Top-level user hook configuration. +/// +/// Maps hook event names to a list of matcher groups. This is deserialized +/// from the `hooks` section in `.forge.toml`. +/// +/// Example TOML: +/// ```toml +/// [[hooks.PreToolUse]] +/// matcher = "Bash" +/// +/// [[hooks.PreToolUse.hooks]] +/// type = "command" +/// command = "echo hi" +/// ``` +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema, Dummy)] +pub struct UserHookConfig { + /// Map of event name -> list of matcher groups. + #[serde(flatten)] + pub events: HashMap>, +} + +impl UserHookConfig { + /// Creates an empty user hook configuration. + pub fn new() -> Self { + Self { events: HashMap::new() } + } + + /// Returns the matcher groups for a given event name, or an empty slice if + /// none. + pub fn get_groups(&self, event: &UserHookEventName) -> &[UserHookMatcherGroup] { + self.events.get(event).map_or(&[], |v| v.as_slice()) + } + + /// Returns true if no hook events are configured. + pub fn is_empty(&self) -> bool { + self.events.is_empty() + } +} + +/// Supported hook event names that map to lifecycle points in the +/// orchestrator. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Display, JsonSchema, Dummy)] +pub enum UserHookEventName { + /// Fired before a tool call executes. Can block execution. + PreToolUse, + /// Fired after a tool call succeeds. + PostToolUse, + /// Fired after a tool call fails. + PostToolUseFailure, + /// Fired when the agent finishes responding. Can block stop to continue. + Stop, + /// Fired when a session starts or resumes. + SessionStart, + /// Fired when a session ends/terminates. + SessionEnd, + /// Fired when a user prompt is submitted. + UserPromptSubmit, +} + +/// A matcher group pairs an optional regex matcher with a list of hook +/// handlers. +/// +/// When a lifecycle event fires, only matcher groups whose `matcher` regex +/// matches the relevant event context (e.g., tool name) will have their hooks +/// executed. If `matcher` is `None` (or an empty string, which is normalized +/// to `None`), all hooks in this group fire unconditionally. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema, Dummy)] +pub struct UserHookMatcherGroup { + /// Optional regex pattern to match against (e.g., tool name for + /// PreToolUse/PostToolUse). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub matcher: Option, + + /// List of hook handlers to execute when this matcher matches. + #[serde(default)] + pub hooks: Vec, +} + +/// A single hook handler entry that defines what action to take. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema, Dummy)] +pub struct UserHookEntry { + /// The type of hook handler. + #[serde(rename = "type")] + pub hook_type: UserHookType, + + /// The shell command to execute (for `Command` type hooks). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub command: Option, + + /// Timeout in milliseconds for this hook. Defaults to 600000ms (10 + /// minutes). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +/// The type of hook handler to execute. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, JsonSchema, Dummy)] +#[serde(rename_all = "lowercase")] +pub enum UserHookType { + /// Executes a shell command, piping JSON to stdin and reading JSON from + /// stdout. + Command, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_deserialize_empty_config() { + let toml = ""; + let actual: UserHookConfig = toml_edit::de::from_str(toml).unwrap(); + let expected = UserHookConfig::new(); + assert_eq!(actual, expected); + } + + #[test] + fn test_deserialize_pre_tool_use_hook() { + let toml = include_str!("fixtures/hook_pre_tool_use.toml"); + let actual: UserHookConfig = toml_edit::de::from_str(toml).unwrap(); + let groups = actual.get_groups(&UserHookEventName::PreToolUse); + + assert_eq!(groups.len(), 1); + assert_eq!(groups[0].matcher, Some("Bash".to_string())); + assert_eq!(groups[0].hooks.len(), 1); + assert_eq!(groups[0].hooks[0].hook_type, UserHookType::Command); + assert_eq!( + groups[0].hooks[0].command, + Some("echo 'blocked'".to_string()) + ); + } + + #[test] + fn test_deserialize_multiple_events() { + let toml = include_str!("fixtures/hook_multiple_events.toml"); + let actual: UserHookConfig = toml_edit::de::from_str(toml).unwrap(); + + assert_eq!(actual.get_groups(&UserHookEventName::PreToolUse).len(), 1); + assert_eq!(actual.get_groups(&UserHookEventName::PostToolUse).len(), 1); + assert_eq!(actual.get_groups(&UserHookEventName::Stop).len(), 1); + assert!( + actual + .get_groups(&UserHookEventName::SessionStart) + .is_empty() + ); + } + + #[test] + fn test_deserialize_hook_with_timeout() { + let toml = include_str!("fixtures/hook_with_timeout.toml"); + let actual: UserHookConfig = toml_edit::de::from_str(toml).unwrap(); + let groups = actual.get_groups(&UserHookEventName::PreToolUse); + + assert_eq!(groups[0].hooks[0].timeout, Some(30000)); + } + + #[test] + fn test_no_matcher_group_fires_unconditionally() { + let toml = include_str!("fixtures/hook_no_matcher.toml"); + let actual: UserHookConfig = toml_edit::de::from_str(toml).unwrap(); + let groups = actual.get_groups(&UserHookEventName::PostToolUse); + + assert_eq!(groups.len(), 1); + assert_eq!(groups[0].matcher, None); + } + + #[test] + fn test_toml_round_trip() { + let toml_input = r#" +[[PreToolUse]] +matcher = "Bash" + + [[PreToolUse.hooks]] + type = "command" + command = "check.sh" + timeout = 5000 +"#; + let config: UserHookConfig = toml_edit::de::from_str(toml_input).unwrap(); + let serialized = toml_edit::ser::to_string_pretty(&config).unwrap(); + let roundtrip: UserHookConfig = toml_edit::de::from_str(&serialized).unwrap(); + assert_eq!(config, roundtrip); + } + + #[test] + fn test_json_deserialization_still_works() { + let json = r#"{ + "PreToolUse": [ + { "matcher": "Bash", "hooks": [{ "type": "command", "command": "echo hi" }] } + ] + }"#; + let actual: UserHookConfig = serde_json::from_str(json).unwrap(); + assert_eq!(actual.get_groups(&UserHookEventName::PreToolUse).len(), 1); + } +} diff --git a/crates/forge_config/src/lib.rs b/crates/forge_config/src/lib.rs index cc253277e4..5d903b1a22 100644 --- a/crates/forge_config/src/lib.rs +++ b/crates/forge_config/src/lib.rs @@ -3,6 +3,7 @@ mod compact; mod config; mod decimal; mod error; +mod hooks; mod http; mod legacy; mod model; @@ -17,6 +18,7 @@ pub use compact::*; pub use config::*; pub use decimal::*; pub use error::Error; +pub use hooks::*; pub use http::*; pub use model::*; pub use percentage::*; diff --git a/crates/forge_domain/src/chat_response.rs b/crates/forge_domain/src/chat_response.rs index e24cd9d731..60e1db2e3d 100644 --- a/crates/forge_domain/src/chat_response.rs +++ b/crates/forge_domain/src/chat_response.rs @@ -65,6 +65,21 @@ pub enum ChatResponse { notifier: Arc, }, ToolCallEnd(ToolResult), + /// A user-configured hook blocked execution of a tool call. + HookError { + /// Name of the tool that was blocked. + tool_name: ToolName, + /// Human-readable reason provided by the hook (from stderr or JSON + /// output). + reason: String, + }, + /// A user-configured hook encountered an error or produced a warning. + /// Displayed in the UI as a warning regardless of whether the hook + /// blocked execution or not. + HookWarning { + /// Human-readable warning message. + message: String, + }, RetryAttempt { cause: Cause, duration: Duration, diff --git a/crates/forge_domain/src/hook.rs b/crates/forge_domain/src/hook.rs index 47579d7a43..03ccdd1a6c 100644 --- a/crates/forge_domain/src/hook.rs +++ b/crates/forge_domain/src/hook.rs @@ -16,12 +16,15 @@ pub struct EventData { pub model_id: ModelId, /// Event-specific payload data pub payload: P, + /// Transient warnings collected by hook handlers. The orchestrator + /// drains these after each hook invocation and emits them to the UI. + pub warnings: Vec, } impl EventData

{ /// Creates a new event with the given agent, model ID, and payload pub fn new(agent: Agent, model_id: ModelId, payload: P) -> Self { - Self { agent, model_id, payload } + Self { agent, model_id, payload, warnings: Vec::new() } } } @@ -29,9 +32,18 @@ impl EventData

{ #[derive(Debug, PartialEq, Clone, Default)] pub struct StartPayload; -/// Payload for the End event +/// Payload for the End event. +/// +/// Carries `stop_hook_active` to signal whether the current End event was +/// triggered by a Stop hook forcing continuation (matching Claude Code's +/// `stop_hook_active` field). When `true`, hook scripts should allow the +/// agent to stop to prevent infinite loops. #[derive(Debug, PartialEq, Clone, Default)] -pub struct EndPayload; +pub struct EndPayload { + /// Whether a Stop hook caused this continuation. Sent to hook scripts + /// as `stop_hook_active` so they can break re-entrant loops. + pub stop_hook_active: bool, +} /// Payload for the Request event #[derive(Debug, PartialEq, Clone, Setters)] @@ -86,12 +98,17 @@ pub struct ToolcallEndPayload { pub tool_call: ToolCallFull, /// The tool result (success or failure) pub result: ToolResult, + /// Feedback message from a blocking PostToolUse hook, if any. + /// Set by the hook handler and read by the orchestrator after + /// `append_message` so the message is injected in the correct + /// position (after the tool result, not before it). + pub hook_feedback: Option, } impl ToolcallEndPayload { /// Creates a new tool call end payload pub fn new(tool_call: ToolCallFull, result: ToolResult) -> Self { - Self { tool_call, result } + Self { tool_call, result, hook_feedback: None } } } @@ -117,6 +134,20 @@ pub enum LifecycleEvent { ToolcallEnd(EventData), } +impl LifecycleEvent { + /// Drains all warnings from the inner `EventData`, regardless of variant. + pub fn drain_warnings(&mut self) -> Vec { + match self { + LifecycleEvent::Start(data) => data.warnings.drain(..).collect(), + LifecycleEvent::End(data) => data.warnings.drain(..).collect(), + LifecycleEvent::Request(data) => data.warnings.drain(..).collect(), + LifecycleEvent::Response(data) => data.warnings.drain(..).collect(), + LifecycleEvent::ToolcallStart(data) => data.warnings.drain(..).collect(), + LifecycleEvent::ToolcallEnd(data) => data.warnings.drain(..).collect(), + } + } +} + /// Trait for handling lifecycle events /// /// Implementations of this trait can be used to react to different @@ -126,12 +157,13 @@ pub trait EventHandle: Send + Sync { /// Handles a lifecycle event and potentially modifies the conversation /// /// # Arguments - /// * `event` - The lifecycle event that occurred + /// * `event` - The lifecycle event that occurred (mutable to allow + /// in-flight modification) /// * `conversation` - The current conversation state (mutable) /// /// # Errors /// Returns an error if the event handling fails - async fn handle(&self, event: &T, conversation: &mut Conversation) -> anyhow::Result<()>; + async fn handle(&self, event: &mut T, conversation: &mut Conversation) -> anyhow::Result<()>; } /// Extension trait for combining event handlers @@ -166,7 +198,7 @@ impl + 'static> EventHandleExt fo // Implement EventHandle for Box to allow using boxed handlers #[async_trait] impl EventHandle for Box> { - async fn handle(&self, event: &T, conversation: &mut Conversation) -> anyhow::Result<()> { + async fn handle(&self, event: &mut T, conversation: &mut Conversation) -> anyhow::Result<()> { (**self).handle(event, conversation).await } } @@ -326,10 +358,10 @@ impl Hook { impl EventHandle for Hook { async fn handle( &self, - event: &LifecycleEvent, + event: &mut LifecycleEvent, conversation: &mut Conversation, ) -> anyhow::Result<()> { - match &event { + match event { LifecycleEvent::Start(data) => self.on_start.handle(data, conversation).await, LifecycleEvent::End(data) => self.on_end.handle(data, conversation).await, LifecycleEvent::Request(data) => self.on_request.handle(data, conversation).await, @@ -354,7 +386,7 @@ struct CombinedHandler(Box>, Box EventHandle for CombinedHandler { - async fn handle(&self, event: &T, conversation: &mut Conversation) -> anyhow::Result<()> { + async fn handle(&self, event: &mut T, conversation: &mut Conversation) -> anyhow::Result<()> { // Run the first handler self.0.handle(event, conversation).await?; // Run the second handler with the cloned event @@ -371,7 +403,7 @@ pub struct NoOpHandler; #[async_trait] impl EventHandle for NoOpHandler { - async fn handle(&self, _: &T, _: &mut Conversation) -> anyhow::Result<()> { + async fn handle(&self, _: &mut T, _: &mut Conversation) -> anyhow::Result<()> { Ok(()) } } @@ -379,17 +411,17 @@ impl EventHandle for NoOpHandler { #[async_trait] impl EventHandle for F where - F: Fn(&T, &mut Conversation) -> Fut + Send + Sync, + F: Fn(&mut T, &mut Conversation) -> Fut + Send + Sync, Fut: std::future::Future> + Send, { - async fn handle(&self, event: &T, conversation: &mut Conversation) -> anyhow::Result<()> { + async fn handle(&self, event: &mut T, conversation: &mut Conversation) -> anyhow::Result<()> { (self)(event, conversation).await } } impl From for Box> where - F: Fn(&T, &mut Conversation) -> Fut + Send + Sync + 'static, + F: Fn(&mut T, &mut Conversation) -> Fut + Send + Sync + 'static, Fut: std::future::Future> + Send + 'static, { fn from(handler: F) -> Self { @@ -397,6 +429,22 @@ where } } +/// Error indicating a UserPromptSubmit hook blocked the prompt. +/// +/// When a UserPromptSubmit hook exits with code 2 or returns a blocking JSON +/// decision, the handler returns this error to signal the orchestrator that +/// the prompt should be suppressed (the LLM call must not proceed). +#[derive(Debug)] +pub struct PromptSuppressed(pub String); + +impl std::fmt::Display for PromptSuppressed { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Prompt suppressed by hook: {}", self.0) + } +} + +impl std::error::Error for PromptSuppressed {} + #[cfg(test)] mod tests { use pretty_assertions::assert_eq; @@ -432,7 +480,7 @@ mod tests { let events_clone = events.clone(); let hook = Hook::default().on_start( - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events_clone.clone(); let event = event.clone(); async move { @@ -445,7 +493,7 @@ mod tests { let mut conversation = Conversation::generate(); hook.handle( - &LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), + &mut LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), &mut conversation, ) .await @@ -466,7 +514,7 @@ mod tests { let hook = Hook::default() .on_start({ let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::Start(event.clone()); async move { @@ -477,7 +525,7 @@ mod tests { }) .on_end({ let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::End(event.clone()); async move { @@ -488,7 +536,7 @@ mod tests { }) .on_request({ let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::Request(event.clone()); async move { @@ -502,21 +550,25 @@ mod tests { // Test Start event hook.handle( - &LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), + &mut LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), &mut conversation, ) .await .unwrap(); // Test End event hook.handle( - &LifecycleEvent::End(EventData::new(test_agent(), test_model_id(), EndPayload)), + &mut LifecycleEvent::End(EventData::new( + test_agent(), + test_model_id(), + EndPayload::default(), + )), &mut conversation, ) .await .unwrap(); // Test Request event hook.handle( - &LifecycleEvent::Request(EventData::new( + &mut LifecycleEvent::Request(EventData::new( test_agent(), test_model_id(), RequestPayload::new(1), @@ -534,7 +586,11 @@ mod tests { ); assert_eq!( handled[1], - LifecycleEvent::End(EventData::new(test_agent(), test_model_id(), EndPayload)) + LifecycleEvent::End(EventData::new( + test_agent(), + test_model_id(), + EndPayload::default() + )) ); assert_eq!( handled[2], @@ -553,7 +609,7 @@ mod tests { let hook = Hook::new( { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::Start(event.clone()); async move { @@ -564,7 +620,7 @@ mod tests { }, { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::End(event.clone()); async move { @@ -575,7 +631,7 @@ mod tests { }, { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::Request(event.clone()); async move { @@ -586,7 +642,7 @@ mod tests { }, { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::Response(event.clone()); async move { @@ -597,7 +653,8 @@ mod tests { }, { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, + _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::ToolcallStart(event.clone()); async move { @@ -608,7 +665,8 @@ mod tests { }, { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, + _conversation: &mut Conversation| { let events = events.clone(); let event = LifecycleEvent::ToolcallEnd(event.clone()); async move { @@ -623,7 +681,11 @@ mod tests { let all_events = vec![ LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), - LifecycleEvent::End(EventData::new(test_agent(), test_model_id(), EndPayload)), + LifecycleEvent::End(EventData::new( + test_agent(), + test_model_id(), + EndPayload::default(), + )), LifecycleEvent::Request(EventData::new( test_agent(), test_model_id(), @@ -658,8 +720,8 @@ mod tests { )), ]; - for event in all_events { - hook.handle(&event, &mut conversation).await.unwrap(); + for mut event in all_events { + hook.handle(&mut event, &mut conversation).await.unwrap(); } let handled = events.lock().unwrap(); @@ -671,7 +733,7 @@ mod tests { let title = std::sync::Arc::new(std::sync::Mutex::new(None)); let hook = Hook::default().on_start({ let title = title.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let title = title.clone(); async move { *title.lock().unwrap() = Some("Modified title".to_string()); @@ -684,7 +746,7 @@ mod tests { assert!(title.lock().unwrap().is_none()); hook.handle( - &LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), + &mut LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), &mut conversation, ) .await @@ -708,7 +770,7 @@ mod tests { let hook1 = Hook::default().on_start({ let counter = counter1.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let counter = counter.clone(); async move { *counter.lock().unwrap() += 1; @@ -719,7 +781,7 @@ mod tests { let hook2 = Hook::default().on_start({ let counter = counter2.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let counter = counter.clone(); async move { *counter.lock().unwrap() += 1; @@ -732,7 +794,11 @@ mod tests { let mut conversation = Conversation::generate(); combined .handle( - &LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), + &mut LifecycleEvent::Start(EventData::new( + test_agent(), + test_model_id(), + StartPayload, + )), &mut conversation, ) .await @@ -749,7 +815,7 @@ mod tests { let hook1 = Hook::default().on_start({ let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = event.clone(); async move { @@ -761,7 +827,7 @@ mod tests { let hook2 = Hook::default().on_start({ let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = event.clone(); async move { @@ -773,7 +839,7 @@ mod tests { let hook3 = Hook::default().on_start({ let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = event.clone(); async move { @@ -787,7 +853,11 @@ mod tests { let mut conversation = Conversation::generate(); combined .handle( - &LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), + &mut LifecycleEvent::Start(EventData::new( + test_agent(), + test_model_id(), + StartPayload, + )), &mut conversation, ) .await @@ -808,7 +878,7 @@ mod tests { let hook1 = Hook::default() .on_start({ let start_title = start_title.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let start_title = start_title.clone(); async move { *start_title.lock().unwrap() = Some("Start".to_string()); @@ -818,7 +888,7 @@ mod tests { }) .on_end({ let end_title = end_title.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let end_title = end_title.clone(); async move { *end_title.lock().unwrap() = Some("End".to_string()); @@ -835,7 +905,11 @@ mod tests { // Test Start event combined .handle( - &LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), + &mut LifecycleEvent::Start(EventData::new( + test_agent(), + test_model_id(), + StartPayload, + )), &mut conversation, ) .await @@ -845,7 +919,11 @@ mod tests { // Test End event combined .handle( - &LifecycleEvent::End(EventData::new(test_agent(), test_model_id(), EndPayload)), + &mut LifecycleEvent::End(EventData::new( + test_agent(), + test_model_id(), + EndPayload::default(), + )), &mut conversation, ) .await @@ -860,7 +938,7 @@ mod tests { let handler1 = { let counter = counter1.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let counter = counter.clone(); async move { *counter.lock().unwrap() += 1; @@ -871,7 +949,7 @@ mod tests { let handler2 = { let counter = counter2.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let counter = counter.clone(); async move { *counter.lock().unwrap() += 1; @@ -885,7 +963,7 @@ mod tests { let mut conversation = Conversation::generate(); combined .handle( - &EventData::new(test_agent(), test_model_id(), StartPayload), + &mut EventData::new(test_agent(), test_model_id(), StartPayload), &mut conversation, ) .await @@ -903,7 +981,7 @@ mod tests { let handler1 = { let counter = counter1.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let counter = counter.clone(); async move { *counter.lock().unwrap() += 1; @@ -914,7 +992,7 @@ mod tests { let handler2 = { let counter = counter2.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let counter = counter.clone(); async move { *counter.lock().unwrap() += 1; @@ -928,7 +1006,7 @@ mod tests { let mut conversation = Conversation::generate(); combined .handle( - &EventData::new(test_agent(), test_model_id(), StartPayload), + &mut EventData::new(test_agent(), test_model_id(), StartPayload), &mut conversation, ) .await @@ -938,14 +1016,13 @@ mod tests { assert_eq!(*counter1.lock().unwrap(), 1); assert_eq!(*counter2.lock().unwrap(), 1); } - #[tokio::test] async fn test_event_handle_ext_chain() { let events = std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); let handler1 = { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = event.clone(); async move { @@ -957,7 +1034,7 @@ mod tests { let handler2 = { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = event.clone(); async move { @@ -969,7 +1046,7 @@ mod tests { let handler3 = { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = event.clone(); async move { @@ -986,7 +1063,7 @@ mod tests { let mut conversation = Conversation::generate(); combined .handle( - &EventData::new(test_agent(), test_model_id(), StartPayload), + &mut EventData::new(test_agent(), test_model_id(), StartPayload), &mut conversation, ) .await @@ -1006,7 +1083,7 @@ mod tests { let start_handler = { let start_title = start_title.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let start_title = start_title.clone(); async move { *start_title.lock().unwrap() = Some("Started".to_string()); @@ -1017,7 +1094,7 @@ mod tests { let logging_handler = { let events = events.clone(); - move |event: &EventData, _conversation: &mut Conversation| { + move |event: &mut EventData, _conversation: &mut Conversation| { let events = events.clone(); let event = event.clone(); async move { @@ -1035,7 +1112,7 @@ mod tests { let mut conversation = Conversation::generate(); hook.handle( - &LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), + &mut LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), &mut conversation, ) .await @@ -1053,7 +1130,7 @@ mod tests { let hook = Hook::default() .on_start({ let start_title = start_title.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let start_title = start_title.clone(); async move { *start_title.lock().unwrap() = Some("Started".to_string()); @@ -1063,7 +1140,7 @@ mod tests { }) .on_end({ let end_title = end_title.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let end_title = end_title.clone(); async move { *end_title.lock().unwrap() = Some("Ended".to_string()); @@ -1075,7 +1152,7 @@ mod tests { // Test using handle() directly (EventHandle trait) let mut conversation = Conversation::generate(); hook.handle( - &LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), + &mut LifecycleEvent::Start(EventData::new(test_agent(), test_model_id(), StartPayload)), &mut conversation, ) .await @@ -1083,7 +1160,11 @@ mod tests { assert_eq!(*start_title.lock().unwrap(), Some("Started".to_string())); hook.handle( - &LifecycleEvent::End(EventData::new(test_agent(), test_model_id(), EndPayload)), + &mut LifecycleEvent::End(EventData::new( + test_agent(), + test_model_id(), + EndPayload::default(), + )), &mut conversation, ) .await @@ -1098,7 +1179,7 @@ mod tests { let handler1 = { let hook1_title = hook1_title.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let hook1_title = hook1_title.clone(); async move { *hook1_title.lock().unwrap() = Some("Started".to_string()); @@ -1108,7 +1189,7 @@ mod tests { }; let handler2 = { let hook2_title = hook2_title.clone(); - move |_event: &EventData, _conversation: &mut Conversation| { + move |_event: &mut EventData, _conversation: &mut Conversation| { let hook2_title = hook2_title.clone(); async move { *hook2_title.lock().unwrap() = Some("Ended".to_string()); @@ -1123,7 +1204,7 @@ mod tests { let mut conversation = Conversation::generate(); combined .handle( - &EventData::new(test_agent(), test_model_id(), StartPayload), + &mut EventData::new(test_agent(), test_model_id(), StartPayload), &mut conversation, ) .await diff --git a/crates/forge_domain/src/lib.rs b/crates/forge_domain/src/lib.rs index 5db0a8553b..d97bce49d5 100644 --- a/crates/forge_domain/src/lib.rs +++ b/crates/forge_domain/src/lib.rs @@ -50,6 +50,7 @@ mod top_k; mod top_p; mod transformer; mod update; +mod user_hook_io; mod validation; mod workspace; mod xml; @@ -104,6 +105,7 @@ pub use top_k::*; pub use top_p::*; pub use transformer::*; pub use update::*; +pub use user_hook_io::*; pub use validation::*; pub use workspace::*; pub use xml::*; diff --git a/crates/forge_domain/src/user_hook_io.rs b/crates/forge_domain/src/user_hook_io.rs new file mode 100644 index 0000000000..15cde3d703 --- /dev/null +++ b/crates/forge_domain/src/user_hook_io.rs @@ -0,0 +1,523 @@ +use serde::{Deserialize, Serialize}; +use serde_json::{Map, Value}; + +/// Exit code constants for hook script results. +pub mod exit_codes { + /// Hook executed successfully. stdout may contain JSON output. + pub const SUCCESS: i32 = 0; + /// Blocking error. stderr is used as feedback message. + pub const BLOCK: i32 = 2; +} + +/// JSON input sent to hook scripts via stdin. +/// +/// Contains common fields shared across all hook events plus event-specific +/// data in the `event_data` field. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct HookInput { + /// The hook event name (e.g., "PreToolUse", "PostToolUse", "Stop"). + pub hook_event_name: String, + + /// Current working directory. + pub cwd: String, + + /// Session/conversation ID. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub session_id: Option, + + /// Event-specific payload data. + #[serde(flatten)] + pub event_data: HookEventInput, +} + +/// Event-specific input data variants. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum HookEventInput { + /// Input for PreToolUse events. + PreToolUse { + /// Name of the tool being called. + tool_name: String, + /// Tool call arguments as a JSON value. + tool_input: Value, + /// Unique identifier for this tool call. + #[serde(default, skip_serializing_if = "Option::is_none")] + tool_use_id: Option, + }, + /// Input for PostToolUse events. + PostToolUse { + /// Name of the tool that was called. + tool_name: String, + /// Tool call arguments as a JSON value. + tool_input: Value, + /// Tool output/response as a JSON value. + tool_response: Value, + /// Unique identifier for this tool call. + #[serde(default, skip_serializing_if = "Option::is_none")] + tool_use_id: Option, + }, + /// Input for Stop events. + Stop { + /// Whether a previous Stop hook caused this continuation. Hook scripts + /// should check this to prevent infinite loops. + #[serde(default)] + stop_hook_active: bool, + /// The last assistant message text before the stop event. + #[serde(default, skip_serializing_if = "Option::is_none")] + last_assistant_message: Option, + }, + /// Input for SessionStart events. + SessionStart { + /// Source of the session start (e.g., "startup", "resume"). + source: String, + }, + /// Input for UserPromptSubmit events. + UserPromptSubmit { + /// The raw prompt text submitted by the user. + prompt: String, + }, + /// Empty input for events that don't need event-specific data. + Empty {}, +} + +/// JSON output parsed from hook script stdout. +/// +/// Fields are optional; scripts that don't need to control behavior can simply +/// exit 0 with empty stdout. +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct HookOutput { + /// Whether execution should continue. When `false`, prevents the agent's + /// execution loop from continuing. Checked by `is_blocking()` alongside + /// `decision` and `permission_decision`. + #[serde(default, rename = "continue", skip_serializing_if = "Option::is_none")] + pub continue_execution: Option, + + /// Decision for blocking events. `"block"` blocks the operation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decision: Option, + + /// Reason for blocking, used as feedback to the agent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reason: Option, + + /// For PreToolUse: permission decision ("allow", "deny", "ask"). + #[serde( + default, + rename = "permissionDecision", + skip_serializing_if = "Option::is_none" + )] + pub permission_decision: Option, + + /// For PreToolUse: modified tool input to replace the original. + #[serde( + default, + rename = "updatedInput", + skip_serializing_if = "Option::is_none" + )] + pub updated_input: Option>, + + /// Additional context to inject into the conversation. + #[serde( + default, + rename = "additionalContext", + skip_serializing_if = "Option::is_none" + )] + pub additional_context: Option, + + /// Reason for stopping, used as a fallback reason when + /// `continue_execution` is `false`. Consumed by `process_results` and + /// `process_pre_tool_use_output` as a fallback when `reason` is absent. + #[serde( + default, + rename = "stopReason", + skip_serializing_if = "Option::is_none" + )] + pub stop_reason: Option, +} + +impl HookOutput { + /// Attempts to parse stdout as JSON. Falls back to empty output on failure. + pub fn parse(stdout: &str) -> Self { + if stdout.trim().is_empty() { + return Self::default(); + } + serde_json::from_str(stdout).unwrap_or_default() + } + + /// Returns true if this output requests blocking. + pub fn is_blocking(&self) -> bool { + self.decision.as_deref() == Some("block") + || self.permission_decision.as_deref() == Some("deny") + || self.continue_execution == Some(false) + } + + /// Returns the blocking reason, preferring `reason` over `stop_reason`. + pub fn blocking_reason(&self, default: &str) -> String { + self.reason + .clone() + .or_else(|| self.stop_reason.clone()) + .unwrap_or_else(|| default.to_string()) + } +} + +/// Result of executing a hook command. +#[derive(Debug, Clone)] +pub struct HookExecutionResult { + /// Process exit code (None if terminated by signal). + pub exit_code: Option, + /// Captured stdout. + pub stdout: String, + /// Captured stderr. + pub stderr: String, +} + +impl HookExecutionResult { + /// Returns true if the hook exited with the blocking exit code (2). + pub fn is_blocking_exit(&self) -> bool { + self.exit_code == Some(exit_codes::BLOCK) + } + + /// Returns true if the hook exited successfully (0). + pub fn is_success(&self) -> bool { + self.exit_code == Some(exit_codes::SUCCESS) + } + + /// Returns true if the hook exited with a non-blocking error (non-0, + /// non-2). + pub fn is_non_blocking_error(&self) -> bool { + match self.exit_code { + Some(code) => code != exit_codes::SUCCESS && code != exit_codes::BLOCK, + None => true, + } + } + + /// Parses the stdout as a HookOutput if the exit was successful. + pub fn parse_output(&self) -> Option { + if self.is_success() { + Some(HookOutput::parse(&self.stdout)) + } else { + None + } + } + + /// Returns the feedback message for blocking errors (stderr content). + pub fn blocking_message(&self) -> Option<&str> { + if self.is_blocking_exit() { + let msg = self.stderr.trim(); + if msg.is_empty() { None } else { Some(msg) } + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_hook_input_serialization_pre_tool_use() { + let fixture = HookInput { + hook_event_name: "PreToolUse".to_string(), + cwd: "/project".to_string(), + session_id: Some("sess-123".to_string()), + event_data: HookEventInput::PreToolUse { + tool_name: "Bash".to_string(), + tool_input: serde_json::json!({"command": "ls"}), + tool_use_id: None, + }, + }; + + let actual = serde_json::to_value(&fixture).unwrap(); + + assert_eq!(actual["hook_event_name"], "PreToolUse"); + assert_eq!(actual["cwd"], "/project"); + assert_eq!(actual["tool_name"], "Bash"); + assert_eq!(actual["tool_input"]["command"], "ls"); + assert!(actual.get("tool_use_id").is_none()); + } + + #[test] + fn test_hook_input_serialization_pre_tool_use_with_tool_use_id() { + let fixture = HookInput { + hook_event_name: "PreToolUse".to_string(), + cwd: "/project".to_string(), + session_id: Some("sess-123".to_string()), + event_data: HookEventInput::PreToolUse { + tool_name: "Bash".to_string(), + tool_input: serde_json::json!({"command": "ls"}), + tool_use_id: Some("forge_call_id_abc123".to_string()), + }, + }; + + let actual = serde_json::to_value(&fixture).unwrap(); + + assert_eq!(actual["tool_use_id"], "forge_call_id_abc123"); + } + + #[test] + fn test_hook_input_serialization_stop() { + let fixture = HookInput { + hook_event_name: "Stop".to_string(), + cwd: "/project".to_string(), + session_id: None, + event_data: HookEventInput::Stop { + stop_hook_active: false, + last_assistant_message: None, + }, + }; + + let actual = serde_json::to_value(&fixture).unwrap(); + + assert_eq!(actual["hook_event_name"], "Stop"); + assert!(actual.get("last_assistant_message").is_none()); + } + + #[test] + fn test_hook_input_serialization_stop_with_last_assistant_message() { + let fixture = HookInput { + hook_event_name: "Stop".to_string(), + cwd: "/project".to_string(), + session_id: Some("sess-456".to_string()), + event_data: HookEventInput::Stop { + stop_hook_active: false, + last_assistant_message: Some("Here is the result.".to_string()), + }, + }; + + let actual = serde_json::to_value(&fixture).unwrap(); + + assert_eq!(actual["last_assistant_message"], "Here is the result."); + } + + #[test] + fn test_hook_input_serialization_user_prompt_submit() { + let fixture = HookInput { + hook_event_name: "UserPromptSubmit".to_string(), + cwd: "/project".to_string(), + session_id: Some("sess-abc".to_string()), + event_data: HookEventInput::UserPromptSubmit { prompt: "fix the bug".to_string() }, + }; + + let actual = serde_json::to_value(&fixture).unwrap(); + + assert_eq!(actual["hook_event_name"], "UserPromptSubmit"); + assert_eq!(actual["cwd"], "/project"); + assert_eq!(actual["session_id"], "sess-abc"); + assert_eq!(actual["prompt"], "fix the bug"); + // No tool_name or other variant fields present + assert!(actual["tool_name"].is_null()); + } + + #[test] + fn test_hook_output_parse_valid_json() { + let stdout = r#"{"decision": "block", "reason": "unsafe command"}"#; + let actual = HookOutput::parse(stdout); + + assert_eq!(actual.decision, Some("block".to_string())); + assert_eq!(actual.reason, Some("unsafe command".to_string())); + } + + #[test] + fn test_hook_output_parse_empty_string() { + let actual = HookOutput::parse(""); + let expected = HookOutput::default(); + assert_eq!(actual, expected); + } + + #[test] + fn test_hook_output_parse_invalid_json_returns_default() { + let actual = HookOutput::parse("not json at all"); + let expected = HookOutput::default(); + assert_eq!(actual, expected); + } + + #[test] + fn test_hook_output_is_blocking() { + let fixture = HookOutput { decision: Some("block".to_string()), ..Default::default() }; + assert!(fixture.is_blocking()); + + let fixture = HookOutput { + permission_decision: Some("deny".to_string()), + ..Default::default() + }; + assert!(fixture.is_blocking()); + + let fixture = HookOutput::default(); + assert!(!fixture.is_blocking()); + } + + #[test] + fn test_hook_output_is_blocking_continue_false() { + let fixture = HookOutput { continue_execution: Some(false), ..Default::default() }; + assert!(fixture.is_blocking()); + } + + #[test] + fn test_hook_output_is_not_blocking_continue_true() { + let fixture = HookOutput { continue_execution: Some(true), ..Default::default() }; + assert!(!fixture.is_blocking()); + } + + #[test] + fn test_hook_output_is_not_blocking_continue_none() { + let fixture = HookOutput { continue_execution: None, ..Default::default() }; + assert!(!fixture.is_blocking()); + } + + #[test] + fn test_hook_output_continue_false_with_stop_reason_parses_and_blocks() { + let stdout = r#"{"continue": false, "stopReason": "done"}"#; + let actual = HookOutput::parse(stdout); + assert!(actual.is_blocking()); + assert_eq!(actual.continue_execution, Some(false)); + assert_eq!(actual.stop_reason, Some("done".to_string())); + } + + #[test] + fn test_blocking_reason_prefers_reason_over_stop_reason() { + let fixture = HookOutput { + reason: Some("primary".to_string()), + stop_reason: Some("secondary".to_string()), + ..Default::default() + }; + let actual = fixture.blocking_reason("default"); + assert_eq!(actual, "primary"); + } + + #[test] + fn test_blocking_reason_falls_back_to_stop_reason() { + let fixture = HookOutput { + stop_reason: Some("fallback".to_string()), + ..Default::default() + }; + let actual = fixture.blocking_reason("default"); + assert_eq!(actual, "fallback"); + } + + #[test] + fn test_blocking_reason_uses_default_when_both_none() { + let fixture = HookOutput::default(); + let actual = fixture.blocking_reason("default reason"); + assert_eq!(actual, "default reason"); + } + + #[test] + fn test_hook_execution_result_blocking() { + let fixture = HookExecutionResult { + exit_code: Some(2), + stdout: String::new(), + stderr: "Blocked: unsafe command".to_string(), + }; + + assert!(fixture.is_blocking_exit()); + assert!(!fixture.is_success()); + assert!(!fixture.is_non_blocking_error()); + assert_eq!(fixture.blocking_message(), Some("Blocked: unsafe command")); + assert!(fixture.parse_output().is_none()); + } + + #[test] + fn test_hook_execution_result_success() { + let fixture = HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "reason": "test"}"#.to_string(), + stderr: String::new(), + }; + + assert!(fixture.is_success()); + assert!(!fixture.is_blocking_exit()); + assert!(!fixture.is_non_blocking_error()); + let output = fixture.parse_output().unwrap(); + assert!(output.is_blocking()); + } + + #[test] + fn test_hook_execution_result_non_blocking_error() { + let fixture = HookExecutionResult { + exit_code: Some(1), + stdout: String::new(), + stderr: "some error".to_string(), + }; + + assert!(fixture.is_non_blocking_error()); + assert!(!fixture.is_success()); + assert!(!fixture.is_blocking_exit()); + assert!(fixture.blocking_message().is_none()); + } + + // --- Schema validation tests for updatedInput --- + + #[test] + fn test_updated_input_valid_object_parsed() { + let stdout = r#"{"updatedInput": {"command": "echo safe"}}"#; + let actual = HookOutput::parse(stdout); + let expected_map = Map::from_iter([( + "command".to_string(), + Value::String("echo safe".to_string()), + )]); + assert_eq!(actual.updated_input, Some(expected_map)); + } + + #[test] + fn test_updated_input_string_rejected_falls_back_to_default() { + // updatedInput is a string, not an object => serde rejects it, + // entire parse falls back to default (updated_input = None). + let stdout = r#"{"updatedInput": "not an object"}"#; + let actual = HookOutput::parse(stdout); + assert_eq!(actual.updated_input, None); + } + + #[test] + fn test_updated_input_number_rejected_falls_back_to_default() { + let stdout = r#"{"updatedInput": 42}"#; + let actual = HookOutput::parse(stdout); + assert_eq!(actual.updated_input, None); + } + + #[test] + fn test_updated_input_array_rejected_falls_back_to_default() { + let stdout = r#"{"updatedInput": [1, 2, 3]}"#; + let actual = HookOutput::parse(stdout); + assert_eq!(actual.updated_input, None); + } + + #[test] + fn test_updated_input_bool_rejected_falls_back_to_default() { + let stdout = r#"{"updatedInput": true}"#; + let actual = HookOutput::parse(stdout); + assert_eq!(actual.updated_input, None); + } + + #[test] + fn test_updated_input_null_treated_as_none() { + // JSON null for an Option field => None (not Some(empty map)) + let stdout = r#"{"updatedInput": null}"#; + let actual = HookOutput::parse(stdout); + assert_eq!(actual.updated_input, None); + } + + #[test] + fn test_updated_input_nested_object_accepted() { + let stdout = r#"{"updatedInput": {"a": {"b": [1, 2]}, "c": true}}"#; + let actual = HookOutput::parse(stdout); + let expected_map = Map::from_iter([ + ("a".to_string(), serde_json::json!({"b": [1, 2]})), + ("c".to_string(), Value::Bool(true)), + ]); + assert_eq!(actual.updated_input, Some(expected_map)); + } + + #[test] + fn test_malformed_updated_input_preserves_other_fields() { + // When updatedInput is invalid, the entire HookOutput parse fails + // and falls back to default. This means other fields like `decision` + // are also lost. This is the expected behavior — a malformed hook + // output is treated as if the hook returned nothing. + let stdout = r#"{"decision": "block", "updatedInput": "bad"}"#; + let actual = HookOutput::parse(stdout); + assert_eq!(actual, HookOutput::default()); + } +} diff --git a/crates/forge_infra/src/executor.rs b/crates/forge_infra/src/executor.rs index 13f30d8c8d..a795556abc 100644 --- a/crates/forge_infra/src/executor.rs +++ b/crates/forge_infra/src/executor.rs @@ -1,10 +1,11 @@ +use std::collections::HashMap; use std::io::{self, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; use forge_app::CommandInfra; use forge_domain::{CommandOutput, ConsoleWriter as OutputPrinterTrait, Environment}; -use tokio::io::AsyncReadExt; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::Command; use tokio::sync::Mutex; @@ -224,6 +225,43 @@ impl CommandInfra for ForgeCommandExecutorService { Ok(prepared_command.spawn()?.wait().await?) } + + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result { + let mut prepared_command = self.prepare_command(&command, &working_dir, None); + + // Set directly-provided key-value env vars + for (key, value) in &env_vars { + prepared_command.env(key, value); + } + + // Override stdin to piped so we can write to it + prepared_command.stdin(std::process::Stdio::piped()); + + let mut child = prepared_command.spawn()?; + + // Pipe the JSON input to stdin + if let Some(mut stdin) = child.stdin.take() { + let input = stdin_input.clone(); + tokio::spawn(async move { + let _ = stdin.write_all(input.as_bytes()).await; + let _ = stdin.shutdown().await; + }); + } + + let output = child.wait_with_output().await?; + Ok(CommandOutput { + command, + exit_code: output.status.code(), + stdout: String::from_utf8_lossy(&output.stdout).into_owned(), + stderr: String::from_utf8_lossy(&output.stderr).into_owned(), + }) + } } #[cfg(test)] diff --git a/crates/forge_infra/src/forge_infra.rs b/crates/forge_infra/src/forge_infra.rs index 3a3e602d17..80df335cf5 100644 --- a/crates/forge_infra/src/forge_infra.rs +++ b/crates/forge_infra/src/forge_infra.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::path::{Path, PathBuf}; use std::process::ExitStatus; use std::sync::Arc; @@ -250,6 +250,18 @@ impl CommandInfra for ForgeInfra { .execute_command_raw(command, working_dir, env_vars) .await } + + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result { + self.command_executor_service + .execute_command_with_input(command, working_dir, stdin_input, env_vars) + .await + } } #[async_trait::async_trait] diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index fde42605d5..b6a9a8111c 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -3365,6 +3365,20 @@ impl A + Send + Sync> UI self.writeln_title(TitleFormat::error(cause.as_str()))?; } } + ChatResponse::HookError { tool_name, reason } => { + writer.finish()?; + self.spinner.stop(None)?; + self.writeln_title(TitleFormat::error(format!( + "PreToolUse:{tool_name} hook error: {reason}" + )))?; + self.spinner.start(None)?; + } + ChatResponse::HookWarning { message } => { + writer.finish()?; + self.spinner.stop(None)?; + self.writeln_title(TitleFormat::warning(message))?; + self.spinner.start(None)?; + } ChatResponse::Interrupt { reason } => { writer.finish()?; self.spinner.stop(None)?; diff --git a/crates/forge_repo/src/forge_repo.rs b/crates/forge_repo/src/forge_repo.rs index 34d1bb8498..88f5352dd3 100644 --- a/crates/forge_repo/src/forge_repo.rs +++ b/crates/forge_repo/src/forge_repo.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -485,6 +485,18 @@ where .execute_command_raw(command, working_dir, env_vars) .await } + + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result { + self.infra + .execute_command_with_input(command, working_dir, stdin_input, env_vars) + .await + } } #[async_trait::async_trait] diff --git a/crates/forge_services/src/forge_services.rs b/crates/forge_services/src/forge_services.rs index 7ff1d1a2fb..84a3ba32f6 100644 --- a/crates/forge_services/src/forge_services.rs +++ b/crates/forge_services/src/forge_services.rs @@ -26,8 +26,10 @@ use crate::provider_service::ForgeProviderService; use crate::template::ForgeTemplateService; use crate::tool_services::{ ForgeFetch, ForgeFollowup, ForgeFsPatch, ForgeFsRead, ForgeFsRemove, ForgeFsSearch, - ForgeFsUndo, ForgeFsWrite, ForgeImageRead, ForgePlanCreate, ForgeShell, ForgeSkillFetch, + ForgeFsUndo, ForgeFsWrite, ForgeHookCommandService, ForgeImageRead, ForgePlanCreate, + ForgeShell, ForgeSkillFetch, }; +use crate::user_hook_config::ForgeUserHookConfigService; type McpService = ForgeMcpService, F, ::Client>; type AuthService = ForgeAuthService; @@ -78,10 +80,12 @@ pub struct ForgeServices< auth_service: Arc>, agent_registry_service: Arc>, command_loader_service: Arc>, + user_hook_config_service: Arc>, policy_service: ForgePolicyService, provider_auth_service: ForgeProviderAuthService, workspace_service: Arc>>, skill_service: Arc>, + hook_command_service: Arc>, infra: Arc, } @@ -132,6 +136,7 @@ impl< Arc::new(ForgeCustomInstructionsService::new(infra.clone())); let agent_registry_service = Arc::new(ForgeAgentRegistryService::new(infra.clone())); let command_loader_service = Arc::new(ForgeCommandLoaderService::new(infra.clone())); + let user_hook_config_service = Arc::new(ForgeUserHookConfigService::new(infra.clone())); let policy_service = ForgePolicyService::new(infra.clone()); let provider_auth_service = ForgeProviderAuthService::new(infra.clone()); let discovery = Arc::new(FdDefault::new(infra.clone())); @@ -140,6 +145,7 @@ impl< discovery, )); let skill_service = Arc::new(ForgeSkillFetch::new(infra.clone())); + let hook_command_service = Arc::new(ForgeHookCommandService::new(infra.clone())); Self { conversation_service, @@ -164,10 +170,12 @@ impl< config_service, agent_registry_service, command_loader_service, + user_hook_config_service, policy_service, provider_auth_service, workspace_service, skill_service, + hook_command_service, chat_service, infra, } @@ -230,10 +238,12 @@ impl< type AuthService = AuthService; type AgentRegistry = ForgeAgentRegistryService; type CommandLoaderService = ForgeCommandLoaderService; + type UserHookConfigService = ForgeUserHookConfigService; type PolicyService = ForgePolicyService; type ProviderService = ForgeProviderService; type WorkspaceService = crate::context_engine::ForgeWorkspaceService>; type SkillFetchService = ForgeSkillFetch; + type HookCommandService = ForgeHookCommandService; fn config_service(&self) -> &Self::AppConfigService { &self.config_service @@ -319,6 +329,10 @@ impl< &self.command_loader_service } + fn user_hook_config_service(&self) -> &Self::UserHookConfigService { + &self.user_hook_config_service + } + fn policy_service(&self) -> &Self::PolicyService { &self.policy_service } @@ -334,6 +348,10 @@ impl< &self.skill_service } + fn hook_command_service(&self) -> &Self::HookCommandService { + &self.hook_command_service + } + fn provider_service(&self) -> &Self::ProviderService { &self.chat_service } diff --git a/crates/forge_services/src/lib.rs b/crates/forge_services/src/lib.rs index bb102e86c6..36060171e1 100644 --- a/crates/forge_services/src/lib.rs +++ b/crates/forge_services/src/lib.rs @@ -22,6 +22,7 @@ mod range; mod sync; mod template; mod tool_services; +mod user_hook_config; mod utils; pub use app_config::*; diff --git a/crates/forge_services/src/tool_services/hook_command.rs b/crates/forge_services/src/tool_services/hook_command.rs new file mode 100644 index 0000000000..3588bc0fa3 --- /dev/null +++ b/crates/forge_services/src/tool_services/hook_command.rs @@ -0,0 +1,37 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use forge_app::{CommandInfra, HookCommandService}; +use forge_domain::CommandOutput; + +/// Thin wrapper around [`CommandInfra::execute_command_with_input`] that +/// satisfies the [`HookCommandService`] contract. +/// +/// By delegating to the underlying infra this service avoids duplicating +/// process-spawning and stdin-piping logic; those concerns live entirely inside +/// the `CommandInfra` implementation. +#[derive(Clone)] +pub struct ForgeHookCommandService(Arc); + +impl ForgeHookCommandService { + /// Creates a new `ForgeHookCommandService` backed by the given infra. + pub fn new(infra: Arc) -> Self { + Self(infra) + } +} + +#[async_trait::async_trait] +impl HookCommandService for ForgeHookCommandService { + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result { + self.0 + .execute_command_with_input(command, working_dir, stdin_input, env_vars) + .await + } +} diff --git a/crates/forge_services/src/tool_services/mod.rs b/crates/forge_services/src/tool_services/mod.rs index 64a5c6f3c0..75e78f3d7a 100644 --- a/crates/forge_services/src/tool_services/mod.rs +++ b/crates/forge_services/src/tool_services/mod.rs @@ -6,6 +6,7 @@ mod fs_remove; mod fs_search; mod fs_undo; mod fs_write; +mod hook_command; mod image_read; mod plan_create; mod shell; @@ -19,6 +20,7 @@ pub use fs_remove::*; pub use fs_search::*; pub use fs_undo::*; pub use fs_write::*; +pub use hook_command::*; pub use image_read::*; pub use plan_create::*; pub use shell::*; diff --git a/crates/forge_services/src/tool_services/shell.rs b/crates/forge_services/src/tool_services/shell.rs index 05a671de71..a351243cf8 100644 --- a/crates/forge_services/src/tool_services/shell.rs +++ b/crates/forge_services/src/tool_services/shell.rs @@ -108,6 +108,21 @@ mod tests { ) -> anyhow::Result { unimplemented!() } + + async fn execute_command_with_input( + &self, + command: String, + _working_dir: PathBuf, + _stdin_input: String, + _env_vars: std::collections::HashMap, + ) -> anyhow::Result { + Ok(forge_domain::CommandOutput { + command, + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }) + } } impl EnvironmentInfra for MockCommandInfra { diff --git a/crates/forge_services/src/user_hook_config.rs b/crates/forge_services/src/user_hook_config.rs new file mode 100644 index 0000000000..8834063bf7 --- /dev/null +++ b/crates/forge_services/src/user_hook_config.rs @@ -0,0 +1,106 @@ +use std::sync::Arc; + +use forge_app::EnvironmentInfra; +use forge_config::UserHookConfig; + +/// Loads user hook configuration from `.forge.toml` via the config pipeline. +/// +/// Hook configuration is read from the `[hooks]` section of the user's +/// `.forge.toml` file, automatically merged with defaults by the +/// `ConfigReader` layered config system. +pub struct ForgeUserHookConfigService(Arc); + +impl ForgeUserHookConfigService { + /// Creates a new service with the given infrastructure dependency. + pub fn new(infra: Arc) -> Self { + Self(infra) + } +} + +#[async_trait::async_trait] +impl> forge_app::UserHookConfigService + for ForgeUserHookConfigService +{ + async fn get_user_hook_config(&self) -> anyhow::Result { + Ok(self.0.get_config()?.hooks.unwrap_or_default()) + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use fake::Fake; + use forge_app::UserHookConfigService; + use forge_config::UserHookEventName; + use pretty_assertions::assert_eq; + + use super::*; + + #[tokio::test] + async fn test_get_user_hook_config_returns_hooks_from_config() { + let hook_json = r#"{ + "PreToolUse": [ + { "matcher": "Bash", "hooks": [{ "type": "command", "command": "check.sh" }] } + ] + }"#; + let hooks: forge_config::UserHookConfig = serde_json::from_str(hook_json).unwrap(); + let config = forge_config::ForgeConfig { hooks: Some(hooks), ..Default::default() }; + let service = fixture(config); + + let actual = service.get_user_hook_config().await.unwrap(); + + assert!(!actual.is_empty()); + assert_eq!(actual.get_groups(&UserHookEventName::PreToolUse).len(), 1); + } + + #[tokio::test] + async fn test_get_user_hook_config_returns_empty_when_no_hooks() { + let config = forge_config::ForgeConfig::default(); + let service = fixture(config); + + let actual = service.get_user_hook_config().await.unwrap(); + + assert!(actual.is_empty()); + } + + // --- Test helpers --- + + fn fixture(config: forge_config::ForgeConfig) -> ForgeUserHookConfigService { + ForgeUserHookConfigService::new(Arc::new(TestInfra { config })) + } + + struct TestInfra { + config: forge_config::ForgeConfig, + } + + impl EnvironmentInfra for TestInfra { + type Config = forge_config::ForgeConfig; + + fn get_env_var(&self, _key: &str) -> Option { + None + } + + fn get_env_vars(&self) -> std::collections::BTreeMap { + Default::default() + } + + fn get_environment(&self) -> forge_domain::Environment { + let mut env: forge_domain::Environment = fake::Faker.fake(); + env.home = Some(PathBuf::from("/nonexistent/home")); + env.cwd = PathBuf::from("/nonexistent/project"); + env + } + + fn get_config(&self) -> anyhow::Result { + Ok(self.config.clone()) + } + + async fn update_environment( + &self, + _ops: Vec, + ) -> anyhow::Result<()> { + unimplemented!("not needed for tests") + } + } +} diff --git a/forge.schema.json b/forge.schema.json index 43cc190609..2b8391cf4b 100644 --- a/forge.schema.json +++ b/forge.schema.json @@ -66,6 +66,17 @@ "null" ] }, + "hooks": { + "description": "User hook configuration loaded from the `[hooks]` section.\n\nMaps lifecycle event names (e.g. `PreToolUse`, `Stop`) to lists of\nmatcher groups that execute shell commands at each event point.", + "anyOf": [ + { + "$ref": "#/$defs/UserHookConfig" + }, + { + "type": "null" + } + ] + }, "http": { "description": "HTTP client settings including proxy, TLS, and timeout configuration.", "anyOf": [ @@ -887,6 +898,76 @@ "always" ] }, + "UserHookConfig": { + "description": "Top-level user hook configuration.\n\nMaps hook event names to a list of matcher groups. This is deserialized\nfrom the `hooks` section in `.forge.toml`.\n\nExample TOML:\n```toml\n[[hooks.PreToolUse]]\nmatcher = \"Bash\"\n\n [[hooks.PreToolUse.hooks]]\n type = \"command\"\n command = \"echo hi\"\n```", + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "$ref": "#/$defs/UserHookMatcherGroup" + } + } + }, + "UserHookEntry": { + "description": "A single hook handler entry that defines what action to take.", + "type": "object", + "properties": { + "command": { + "description": "The shell command to execute (for `Command` type hooks).", + "type": [ + "string", + "null" + ] + }, + "timeout": { + "description": "Timeout in milliseconds for this hook. Defaults to 600000ms (10\nminutes).", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0 + }, + "type": { + "description": "The type of hook handler.", + "$ref": "#/$defs/UserHookType" + } + }, + "required": [ + "type" + ] + }, + "UserHookMatcherGroup": { + "description": "A matcher group pairs an optional regex matcher with a list of hook\nhandlers.\n\nWhen a lifecycle event fires, only matcher groups whose `matcher` regex\nmatches the relevant event context (e.g., tool name) will have their hooks\nexecuted. If `matcher` is `None` (or an empty string, which is normalized\nto `None`), all hooks in this group fire unconditionally.", + "type": "object", + "properties": { + "hooks": { + "description": "List of hook handlers to execute when this matcher matches.", + "type": "array", + "default": [], + "items": { + "$ref": "#/$defs/UserHookEntry" + } + }, + "matcher": { + "description": "Optional regex pattern to match against (e.g., tool name for\nPreToolUse/PostToolUse).", + "type": [ + "string", + "null" + ] + } + } + }, + "UserHookType": { + "description": "The type of hook handler to execute.", + "oneOf": [ + { + "description": "Executes a shell command, piping JSON to stdin and reading JSON from\nstdout.", + "type": "string", + "const": "command" + } + ] + }, "double": { "type": "number", "format": "double"