diff --git a/crates/forge_app/src/app.rs b/crates/forge_app/src/app.rs index 4b90eed73b..c8fec71741 100644 --- a/crates/forge_app/src/app.rs +++ b/crates/forge_app/src/app.rs @@ -9,7 +9,10 @@ use forge_stream::MpscStream; use crate::apply_tunable_parameters::ApplyTunableParameters; use crate::changed_files::ChangedFiles; use crate::dto::ToolsOverview; -use crate::hooks::{CompactionHandler, DoomLoopDetector, TitleGenerationHandler, TracingHandler}; +use crate::hooks::{ + CompactionHandler, DoomLoopDetector, PendingTodosHandler, TitleGenerationHandler, + TracingHandler, +}; use crate::init_conversation_metrics::InitConversationMetrics; use crate::orch::Orchestrator; use crate::services::{AgentRegistry, CustomInstructionsService, ProviderAuthService}; @@ -142,8 +145,20 @@ impl> ForgeAp // Create the orchestrator with all necessary dependencies let tracing_handler = TracingHandler::new(); let title_handler = TitleGenerationHandler::new(services.clone()); + + // Build the on_end hook, conditionally adding PendingTodosHandler based on + // config + let on_end_hook = if forge_config.verify_todos { + tracing_handler + .clone() + .and(title_handler.clone()) + .and(PendingTodosHandler::new()) + } else { + tracing_handler.clone().and(title_handler.clone()) + }; + let hook = Hook::default() - .on_start(tracing_handler.clone().and(title_handler.clone())) + .on_start(tracing_handler.clone().and(title_handler)) .on_request(tracing_handler.clone().and(DoomLoopDetector::default())) .on_response( tracing_handler @@ -151,8 +166,8 @@ impl> ForgeAp .and(CompactionHandler::new(agent.clone(), environment.clone())), ) .on_toolcall_start(tracing_handler.clone()) - .on_toolcall_end(tracing_handler.clone()) - .on_end(tracing_handler.and(title_handler)); + .on_toolcall_end(tracing_handler) + .on_end(on_end_hook); let orch = Orchestrator::new( services.clone(), diff --git a/crates/forge_app/src/hooks/mod.rs b/crates/forge_app/src/hooks/mod.rs index fb5447a8e6..26a43401f2 100644 --- 
a/crates/forge_app/src/hooks/mod.rs +++ b/crates/forge_app/src/hooks/mod.rs @@ -1,9 +1,11 @@ mod compaction; mod doom_loop; +mod pending_todos; mod title_generation; mod tracing; pub use compaction::CompactionHandler; pub use doom_loop::DoomLoopDetector; +pub use pending_todos::PendingTodosHandler; pub use title_generation::TitleGenerationHandler; pub use tracing::TracingHandler; diff --git a/crates/forge_app/src/hooks/pending_todos.rs b/crates/forge_app/src/hooks/pending_todos.rs new file mode 100644 index 0000000000..bad2b44fa6 --- /dev/null +++ b/crates/forge_app/src/hooks/pending_todos.rs @@ -0,0 +1,272 @@ +use std::collections::HashSet; + +use async_trait::async_trait; +use forge_domain::{ + ContextMessage, Conversation, EndPayload, EventData, EventHandle, Template, TodoStatus, +}; +use forge_template::Element; +use serde::Serialize; + +use crate::TemplateEngine; + +/// A single todo item prepared for template rendering. +#[derive(Serialize)] +struct TodoReminderItem { + status: &'static str, + content: String, +} + +/// Template context for the pending-todos reminder. +#[derive(Serialize)] +struct PendingTodosContext { + todos: Vec, +} + +/// Detects when the LLM signals task completion while there are still +/// pending or in-progress todo items. +/// +/// When triggered, it injects a formatted reminder listing all +/// outstanding todos into the conversation context, preventing the +/// orchestrator from yielding prematurely. 
+#[derive(Debug, Clone, Default)] +pub struct PendingTodosHandler; + +impl PendingTodosHandler { + /// Creates a new pending-todos handler + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl EventHandle> for PendingTodosHandler { + async fn handle( + &self, + _event: &EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + let pending_todos = conversation.metrics.get_active_todos(); + if pending_todos.is_empty() { + return Ok(()); + } + + // Build a set of current pending todo contents for comparison + let current_todo_set: HashSet = pending_todos + .iter() + .map(|todo| todo.content.clone()) + .collect(); + + // Check if we already have a reminder with the exact same set of todos + // This prevents duplicate reminders while still allowing new reminders + // when todos change (e.g., some completed but others still pending) + let should_add_reminder = if let Some(context) = &conversation.context { + // Find the most recent reminder message by matching the opening text of + // the `forge-pending-todos-reminder.md` template (keep the two in sync) + let last_reminder_todos: Option> = context + .messages + .iter() + .rev() + .filter_map(|entry| { + let content = entry.message.content()?; + // Check if this is a pending todos reminder + if content.contains("You have pending todo items") { + // Extract todo items from the reminder message + // Format: "- [STATUS] Content" + let todos: HashSet = content + .lines() + .filter(|line| line.starts_with("- [")) + .map(|line| { + // Extract content after "- [STATUS] " + line.split_once("] ") + .map(|x| x.1) + .map(|s| s.to_string()) + .unwrap_or_default() + }) + .collect(); + Some(todos) + } else { + None + } + }) + .next(); + + match last_reminder_todos { + Some(last_todos) => last_todos != current_todo_set, + None => true, // No previous reminder found, should add + } + } else { + false // No context to append a reminder to, so skip rendering one + }; + + if !should_add_reminder { + return Ok(()); + } + + let todo_items: Vec = pending_todos + .iter() + .filter_map(|todo| { + let 
status = match todo.status { + TodoStatus::Pending => "PENDING", + TodoStatus::InProgress => "IN_PROGRESS", + _ => return None, + }; + Some(TodoReminderItem { status, content: todo.content.clone() }) + }) + .collect(); + + let ctx = PendingTodosContext { todos: todo_items }; + let reminder = TemplateEngine::default().render( + Template::::new("forge-pending-todos-reminder.md"), + &ctx, + )?; + + if let Some(context) = conversation.context.as_mut() { + let content = Element::new("system_reminder").text(reminder); + context + .messages + .push(ContextMessage::user(content, None).into()); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use forge_domain::{ + Agent, Context, Conversation, EndPayload, EventData, EventHandle, Metrics, ModelId, Todo, + TodoStatus, + }; + use pretty_assertions::assert_eq; + + use super::*; + + fn fixture_agent() -> Agent { + Agent::new( + "test-agent", + "test-provider".to_string().into(), + ModelId::new("test-model"), + ) + } + + fn fixture_conversation(todos: Vec) -> Conversation { + let mut conversation = Conversation::generate(); + conversation.context = Some(Context::default()); + conversation.metrics = Metrics::default().todos(todos); + conversation + } + + fn fixture_event() -> EventData { + EventData::new(fixture_agent(), ModelId::new("test-model"), EndPayload) + } + + #[tokio::test] + async fn test_no_pending_todos_does_nothing() { + let handler = PendingTodosHandler::new(); + let event = fixture_event(); + let mut conversation = fixture_conversation(vec![]); + + let initial_msg_count = conversation.context.as_ref().unwrap().messages.len(); + handler.handle(&event, &mut conversation).await.unwrap(); + + let actual = conversation.context.as_ref().unwrap().messages.len(); + let expected = initial_msg_count; + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn test_pending_todos_injects_reminder() { + let handler = PendingTodosHandler::new(); + let event = fixture_event(); + let mut conversation = 
fixture_conversation(vec![ + Todo::new("Fix the build").status(TodoStatus::Pending), + Todo::new("Write tests").status(TodoStatus::InProgress), + ]); + + handler.handle(&event, &mut conversation).await.unwrap(); + + let actual = conversation.context.as_ref().unwrap().messages.len(); + let expected = 1; + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn test_reminder_contains_formatted_list() { + let handler = PendingTodosHandler::new(); + let event = fixture_event(); + let mut conversation = fixture_conversation(vec![ + Todo::new("Fix the build").status(TodoStatus::Pending), + Todo::new("Write tests").status(TodoStatus::InProgress), + ]); + + handler.handle(&event, &mut conversation).await.unwrap(); + + let entry = &conversation.context.as_ref().unwrap().messages[0]; + let actual = entry.message.content().unwrap(); + assert!(actual.contains("- [PENDING] Fix the build")); + assert!(actual.contains("- [IN_PROGRESS] Write tests")); + } + + #[tokio::test] + async fn test_completed_todos_not_included() { + let handler = PendingTodosHandler::new(); + let event = fixture_event(); + let mut conversation = fixture_conversation(vec![ + Todo::new("Completed task").status(TodoStatus::Completed), + Todo::new("Cancelled task").status(TodoStatus::Cancelled), + ]); + + let initial_msg_count = conversation.context.as_ref().unwrap().messages.len(); + handler.handle(&event, &mut conversation).await.unwrap(); + + let actual = conversation.context.as_ref().unwrap().messages.len(); + let expected = initial_msg_count; + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn test_reminder_not_duplicated_for_same_todos() { + let handler = PendingTodosHandler::new(); + let event = fixture_event(); + let mut conversation = + fixture_conversation(vec![Todo::new("Fix the build").status(TodoStatus::Pending)]); + + // First call should inject a reminder + handler.handle(&event, &mut conversation).await.unwrap(); + let after_first = 
conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(after_first, 1); + + // Second call with the same pending todos should NOT add another reminder + handler.handle(&event, &mut conversation).await.unwrap(); + let after_second = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(after_second, 1); // Still 1, no duplicate + } + + #[tokio::test] + async fn test_reminder_added_when_todos_change() { + let handler = PendingTodosHandler::new(); + let event = fixture_event(); + let mut conversation = fixture_conversation(vec![ + Todo::new("Fix the build").status(TodoStatus::Pending), + Todo::new("Write tests").status(TodoStatus::InProgress), + ]); + + // First call should inject a reminder + handler.handle(&event, &mut conversation).await.unwrap(); + let after_first = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(after_first, 1); + + // Simulate LLM completing one todo but leaving another pending + // Update the conversation metrics with different todos + conversation.metrics = conversation.metrics.clone().todos(vec![ + Todo::new("Fix the build").status(TodoStatus::Completed), + Todo::new("Write tests").status(TodoStatus::InProgress), + Todo::new("Add documentation").status(TodoStatus::Pending), + ]); + + // Second call with different pending todos should add a new reminder + handler.handle(&event, &mut conversation).await.unwrap(); + let after_second = conversation.context.as_ref().unwrap().messages.len(); + assert_eq!(after_second, 2); // New reminder added because todos changed + } +} diff --git a/crates/forge_app/src/orch.rs b/crates/forge_app/src/orch.rs index b5435c4afb..86157c24e2 100644 --- a/crates/forge_app/src/orch.rs +++ b/crates/forge_app/src/orch.rs @@ -258,7 +258,6 @@ impl> Orc self.conversation.context = Some(context.clone()); self.services.update(self.conversation.clone()).await?; - // Fire the Request lifecycle event let request_event = LifecycleEvent::Request(EventData::new( self.agent.clone(), 
model_id.clone(), @@ -325,7 +324,7 @@ impl> Orc .execute_tool_calls(&message.tool_calls, &tool_context) .await?; - // Update context from conversation after tool-call hooks run + // Update context from conversation after response / tool-call hooks run if let Some(updated_context) = &self.conversation.context { context = updated_context.clone(); } @@ -403,19 +402,32 @@ impl> Orc tool_context.with_metrics(|metrics| { self.conversation.metrics = metrics.clone(); })?; - } - // Fire the End lifecycle event (title will be set here by the hook) - self.hook - .handle( - &LifecycleEvent::End(EventData::new( - self.agent.clone(), - model_id.clone(), - EndPayload, - )), - &mut self.conversation, - ) - .await?; + // If completing (should_yield is due), fire End hook and check if + // it adds messages + if should_yield { + let end_count_before = self.conversation.len(); + self.hook + .handle( + &LifecycleEvent::End(EventData::new( + self.agent.clone(), + model_id.clone(), + EndPayload, + )), + &mut self.conversation, + ) + .await?; + // Not persisted here: `services.update` runs right after this loop, or + // before the next request if the End hook added messages and we continue + if self.conversation.len() > end_count_before { + // End hook added messages, sync context and continue + if let Some(updated_context) = &self.conversation.context { + context = updated_context.clone(); + } + should_yield = false; + } + } + } self.services.update(self.conversation.clone()).await?; diff --git a/crates/forge_app/src/orch_spec/orch_runner.rs b/crates/forge_app/src/orch_spec/orch_runner.rs index ae02b03b1a..c33c8349b3 100644 --- a/crates/forge_app/src/orch_spec/orch_runner.rs +++ b/crates/forge_app/src/orch_spec/orch_runner.rs @@ -12,7 +12,7 @@ use tokio::sync::Mutex; pub use super::orch_setup::TestContext; use crate::app::build_template_config; use crate::apply_tunable_parameters::ApplyTunableParameters; -use crate::hooks::DoomLoopDetector; +use crate::hooks::{DoomLoopDetector, PendingTodosHandler}; use 
crate::init_conversation_metrics::InitConversationMetrics; use crate::orch::Orchestrator; use crate::set_conversation_id::SetConversationId; @@ -119,6 +119,12 @@ impl Runner { .await?; let conversation = InitConversationMetrics::new(setup.current_time).apply(conversation); + // Apply initial metrics (including todos) if provided by the test + let conversation = if let Some(ref metrics) = setup.initial_metrics { + conversation.metrics(metrics.clone()) + } else { + conversation + }; let conversation = ApplyTunableParameters::new(agent.clone(), system_tools.clone()).apply(conversation); let conversation = SetConversationId.apply(conversation); @@ -127,7 +133,9 @@ impl Runner { .error_tracker(ToolErrorTracker::new(3)) .tool_definitions(system_tools) .hook(Arc::new( - Hook::default().on_request(DoomLoopDetector::default()), + Hook::default() + .on_request(DoomLoopDetector::default()) + .on_end(PendingTodosHandler::new()), )) .sender(tx); diff --git a/crates/forge_app/src/orch_spec/orch_setup.rs b/crates/forge_app/src/orch_spec/orch_setup.rs index ce411dd1f4..5a28d48218 100644 --- a/crates/forge_app/src/orch_spec/orch_setup.rs +++ b/crates/forge_app/src/orch_spec/orch_setup.rs @@ -6,8 +6,8 @@ use derive_setters::Setters; use forge_config::ForgeConfig; use forge_domain::{ Agent, AgentId, Attachment, ChatCompletionMessage, ChatResponse, Conversation, Environment, - Event, File, MessageEntry, ModelId, ProviderId, Role, Template, ToolCallFull, ToolDefinition, - ToolResult, + Event, File, MessageEntry, Metrics, ModelId, ProviderId, Role, Template, ToolCallFull, + ToolDefinition, ToolResult, }; use crate::ShellOutput; @@ -33,6 +33,9 @@ pub struct TestContext { pub model: ModelId, pub attachments: Vec, + // Initial metrics to apply to the conversation + pub initial_metrics: Option, + // Final output of the test is store in the context pub output: TestOutput, pub agent: Agent, @@ -54,6 +57,7 @@ impl Default for TestContext { templates: Default::default(), files: 
Default::default(), attachments: Default::default(), + initial_metrics: None, env: Environment { os: "MacOS".to_string(), cwd: PathBuf::from("/Users/tushar"), diff --git a/crates/forge_app/src/orch_spec/orch_spec.rs b/crates/forge_app/src/orch_spec/orch_spec.rs index 8bf2d0813e..4e5eaec96a 100644 --- a/crates/forge_app/src/orch_spec/orch_spec.rs +++ b/crates/forge_app/src/orch_spec/orch_spec.rs @@ -385,6 +385,10 @@ async fn test_multiple_consecutive_tool_calls() { ChatCompletionMessage::assistant("Reading 3").add_tool_call(tool_call.clone()), ChatCompletionMessage::assistant("Reading 4").add_tool_call(tool_call.clone()), ChatCompletionMessage::assistant("Completing Task").finish_reason(FinishReason::Stop), + // Extra responses for doom loop reminder iterations (detector triggers on each request + // after 4th tool call) + ChatCompletionMessage::assistant("Acknowledged").finish_reason(FinishReason::Stop), + ChatCompletionMessage::assistant("Task complete").finish_reason(FinishReason::Stop), ]); let _ = ctx.run("Read a file").await; @@ -419,6 +423,10 @@ async fn test_doom_loop_detection_adds_user_reminder_after_repeated_calls_on_nex ChatCompletionMessage::assistant("Call 3").add_tool_call(tool_call.clone()), ChatCompletionMessage::assistant("Call 4").add_tool_call(tool_call.clone()), ChatCompletionMessage::assistant("Done").finish_reason(FinishReason::Stop), + // Extra responses for doom loop reminder iterations (detector triggers on each request + // after 4th tool call) + ChatCompletionMessage::assistant("Noted").finish_reason(FinishReason::Stop), + ChatCompletionMessage::assistant("Actually done now").finish_reason(FinishReason::Stop), ]); ctx.run("Test doom loop").await.unwrap(); @@ -592,3 +600,117 @@ async fn test_not_complete_when_stop_with_tool_calls() { "Should have 2 assistant messages, confirming is_complete was false with tool calls" ); } + +#[tokio::test] +async fn test_todo_enforcement_injects_reminder() { + // Test: When the orchestrator receives a Stop 
response but there are pending + // todos, the PendingTodosHandler hook should inject a formatted reminder + // message into the context listing all outstanding items. + // NOTE: The injected reminder makes the End hook add a message, so the + // orchestrator runs one more iteration. On that second pass the handler + // finds an identical reminder already in the context, adds nothing, and + // the run completes, which is why exactly two mock responses are provided. + use forge_domain::{Metrics, Todo, TodoStatus}; + + let mut ctx = TestContext::default() + .mock_assistant_responses(vec![ + // LLM tries to finish but has pending todos - reminder will be injected + ChatCompletionMessage::assistant(Content::full("Task is done")) + .finish_reason(FinishReason::Stop), + // Second response after the first reminder is injected + // Handler won't add duplicate reminder, so this will complete + ChatCompletionMessage::assistant(Content::full( + "I see there are pending todos. Let me continue.", + )) + .finish_reason(FinishReason::Stop), + ]) + .initial_metrics(Metrics::default().todos(vec![ + Todo::new("Pending task 1").status(TodoStatus::Pending), + Todo::new("In progress task").status(TodoStatus::InProgress), + ])); + + // Run the orchestrator - after first reminder, handler won't add duplicates + // so the second response will complete successfully + ctx.run("Complete this task").await.unwrap(); + + let messages = ctx.output.context_messages(); + + // Find the reminder message injected by the PendingTodosHandler hook + let reminder = messages + .iter() + .filter_map(|entry| entry.message.content()) + .find(|content| content.contains("pending todo items")); + + assert!( + reminder.is_some(), + "Should have a reminder message about pending todos" + ); + + let actual = reminder.unwrap(); + assert!( + actual.contains("- [PENDING] Pending task 1"), + "Reminder should list pending items with status" + ); + assert!( + actual.contains("- [IN_PROGRESS] In progress task"), + "Reminder 
should list in-progress items with status" + ); +} +#[tokio::test] +async fn test_complete_when_no_pending_todos() { + // Test: is_complete = true when there are no pending todos (only + // completed/cancelled) + use forge_domain::{Metrics, Todo, TodoStatus}; + + let mut ctx = TestContext::default() + .mock_assistant_responses(vec![ + ChatCompletionMessage::assistant(Content::full("Task is done")) + .finish_reason(FinishReason::Stop), + ]) + .initial_metrics(Metrics::default().todos(vec![ + Todo::new("Completed task").status(TodoStatus::Completed), + ])); + + ctx.run("Complete this task").await.unwrap(); + + // Verify TaskComplete IS sent (no pending todos to block completion) + let has_task_complete = ctx + .output + .chat_responses + .iter() + .filter_map(|r| r.as_ref().ok()) + .any(|response| matches!(response, ChatResponse::TaskComplete)); + + assert!( + has_task_complete, + "Should have TaskComplete when no pending todos exist" + ); +} + +#[tokio::test] +async fn test_complete_when_empty_todos() { + // Test: is_complete = true when there are no todos at all + use forge_domain::Metrics; + + let mut ctx = TestContext::default() + .mock_assistant_responses(vec![ + ChatCompletionMessage::assistant(Content::full("Task is done")) + .finish_reason(FinishReason::Stop), + ]) + .initial_metrics(Metrics::default()); + + ctx.run("Complete this task").await.unwrap(); + + // Verify TaskComplete IS sent (no todos to block completion) + let has_task_complete = ctx + .output + .chat_responses + .iter() + .filter_map(|r| r.as_ref().ok()) + .any(|response| matches!(response, ChatResponse::TaskComplete)); + + assert!( + has_task_complete, + "Should have TaskComplete when no todos exist" + ); +} diff --git a/crates/forge_config/.forge.toml b/crates/forge_config/.forge.toml index 930fae6bb0..dfefb6f2f2 100644 --- a/crates/forge_config/.forge.toml +++ b/crates/forge_config/.forge.toml @@ -26,6 +26,7 @@ tool_supported = true tool_timeout_secs = 300 top_k = 30 top_p = 0.8 
+verify_todos = true [retry] backoff_factor = 2 diff --git a/crates/forge_config/src/config.rs b/crates/forge_config/src/config.rs index f705d308c3..6b9baaa213 100644 --- a/crates/forge_config/src/config.rs +++ b/crates/forge_config/src/config.rs @@ -276,6 +276,11 @@ pub struct ForgeConfig { /// in a local currency. Defaults to `1.0` (no conversion). #[serde(default)] pub currency_conversion_rate: Decimal, + + /// Enables the pending todos hook that checks for incomplete todo items + /// when a task ends and reminds the LLM about them. + #[serde(default)] + pub verify_todos: bool, } impl ForgeConfig { diff --git a/crates/forge_domain/src/conversation.rs b/crates/forge_domain/src/conversation.rs index 94f0300f15..c0bde6e4e8 100644 --- a/crates/forge_domain/src/conversation.rs +++ b/crates/forge_domain/src/conversation.rs @@ -160,6 +160,21 @@ impl Conversation { Some(costs.iter().sum()) } + /// Returns the number of messages in the conversation context. + /// + /// Returns `0` if the context has not been initialized yet. + pub fn len(&self) -> usize { + self.context + .as_ref() + .map(|ctx| ctx.messages.len()) + .unwrap_or(0) + } + + /// Returns `true` if the conversation context has no messages. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Extracts all related conversation IDs from agent tool calls. 
/// /// This method scans through all tool results in the conversation's context diff --git a/forge.schema.json b/forge.schema.json index 0355e81aba..43cc190609 100644 --- a/forge.schema.json +++ b/forge.schema.json @@ -344,6 +344,11 @@ "type": "null" } ] + }, + "verify_todos": { + "description": "Enables the pending todos hook that checks for incomplete todo items\nwhen a task ends and reminds the LLM about them.", + "type": "boolean", + "default": false } }, "$defs": { diff --git a/templates/forge-pending-todos-reminder.md b/templates/forge-pending-todos-reminder.md new file mode 100644 index 0000000000..4cbcc00cc5 --- /dev/null +++ b/templates/forge-pending-todos-reminder.md @@ -0,0 +1,7 @@ +You have pending todo items that must be completed before finishing the task: + +{{#each todos}} +- [{{this.status}}] {{this.content}} +{{/each}} + +Please complete all pending items before finishing. \ No newline at end of file