diff --git a/apps/staged/src-tauri/examples/acp_stream_probe.rs b/apps/staged/src-tauri/examples/acp_stream_probe.rs index 8c5f0bbe..f16ad5fe 100644 --- a/apps/staged/src-tauri/examples/acp_stream_probe.rs +++ b/apps/staged/src-tauri/examples/acp_stream_probe.rs @@ -188,7 +188,7 @@ fn main() -> Result<()> { let result = driver .run( "probe-session", - &prompt, + Some(prompt.as_str()), &[], &workdir, &store, diff --git a/apps/staged/src-tauri/src/agent/mod.rs b/apps/staged/src-tauri/src/agent/mod.rs index 975c7b6d..5a879be2 100644 --- a/apps/staged/src-tauri/src/agent/mod.rs +++ b/apps/staged/src-tauri/src/agent/mod.rs @@ -15,16 +15,6 @@ impl acp_client::Store for Store { self.set_agent_session_id(session_id, agent_session_id) .map_err(|e| e.to_string()) } - - fn get_session_messages(&self, session_id: &str) -> Result, String> { - self.get_session_messages(session_id) - .map(|msgs| { - msgs.into_iter() - .map(|m| (m.role.as_str().to_string(), m.content)) - .collect() - }) - .map_err(|e| e.to_string()) - } } // Re-export writer for backward compatibility diff --git a/apps/staged/src-tauri/src/agent/writer.rs b/apps/staged/src-tauri/src/agent/writer.rs index 7cfddb09..45390de7 100644 --- a/apps/staged/src-tauri/src/agent/writer.rs +++ b/apps/staged/src-tauri/src/agent/writer.rs @@ -25,6 +25,13 @@ use acp_client::strip_code_fences; /// always forces an immediate flush regardless of this interval. const FLUSH_INTERVAL: Duration = Duration::from_millis(150); +/// Replay state: roles of previously persisted messages and a cursor tracking +/// how far through replay we are. +struct ReplayState { + roles: Vec, + cursor: usize, +} + /// Streams agent output into the DB, one session at a time. /// /// All methods are `&self` + async — the struct uses interior mutability @@ -46,6 +53,12 @@ pub struct MessageWriter { /// ACP can send multiple content updates for one tool call; we update /// the same row instead of inserting duplicates. current_tool_result_msg_id: Mutex>, + /// Replay dedup state, loaded from DB on resume. + replay: Mutex, + /// Set to `true` while we are skipping a replayed assistant block. + /// Prevents double-advancing the cursor when `flush_text` is called + /// multiple times for the same block (throttled flush + finalize). + skipping_assistant: Mutex, } /// Strip backticks from agent-provided tool-call titles. @@ -65,7 +78,18 @@ fn format_tool_call_content(title: &str, raw_input: Option<&serde_json::Value>) } impl MessageWriter { - pub fn new(session_id: String, store: Arc) -> Self { + pub fn new(session_id: String, store: Arc, resuming: bool) -> Self { + let replay_roles = if resuming { + store + .get_session_messages(&session_id) + .unwrap_or_default() + .into_iter() + .filter(|m| m.role != MessageRole::User) + .map(|m| m.role) + .collect() + } else { + Vec::new() + }; Self { session_id, store, @@ -74,6 +98,23 @@ impl MessageWriter { last_flush_at: Mutex::new(Instant::now()), tool_call_rows: Mutex::new(HashMap::new()), current_tool_result_msg_id: Mutex::new(None), + replay: Mutex::new(ReplayState { + roles: replay_roles, + cursor: 0, + }), + skipping_assistant: Mutex::new(false), + } + } + + /// Check if the current message matches the next expected replay message. + /// If so, advance the cursor and return `true` (skip the write). + async fn try_skip_replay(&self, role: MessageRole) -> bool { + let mut replay = self.replay.lock().await; + if replay.cursor < replay.roles.len() && replay.roles[replay.cursor] == role { + replay.cursor += 1; + true + } else { + false } } @@ -101,6 +142,7 @@ impl MessageWriter { self.flush_text().await; self.current_assistant_msg_id.lock().await.take(); *self.current_text.lock().await = String::new(); + *self.skipping_assistant.lock().await = false; } // ===================================================================== @@ -130,6 +172,10 @@ impl MessageWriter { return; } + if self.try_skip_replay(MessageRole::ToolCall).await { + return; + } + match self .store .add_session_message(&self.session_id, MessageRole::ToolCall, &content) @@ -175,6 +221,10 @@ impl MessageWriter { return; } + if self.try_skip_replay(MessageRole::ToolResult).await { + return; + } + match self .store .add_session_message(&self.session_id, MessageRole::ToolResult, &content) @@ -199,12 +249,22 @@ impl MessageWriter { if text.is_empty() { return; } + // If we already decided to skip this assistant block during replay, + // don't re-enter try_skip_replay (which would advance the cursor + // past a subsequent message). + if *self.skipping_assistant.lock().await { + return; + } let mut msg_id = self.current_assistant_msg_id.lock().await; match *msg_id { Some(id) => { let _ = self.store.update_message_content(id, &text); } None => { + if self.try_skip_replay(MessageRole::Assistant).await { + *self.skipping_assistant.lock().await = true; + return; + } match self.store.add_session_message( &self.session_id, MessageRole::Assistant, @@ -276,7 +336,7 @@ mod tests { let store = Arc::new(Store::in_memory().expect("in-memory store")); let session = Session::new_running("test prompt", Path::new(".")); store.create_session(&session).expect("create session"); - let writer = MessageWriter::new(session.id.clone(), Arc::clone(&store)); + let writer = MessageWriter::new(session.id.clone(), Arc::clone(&store), false); (store, session.id, writer) } @@ -361,4 +421,46 @@ mod tests { assert_eq!(parsed["name"], "Read file"); assert_eq!(parsed["input"]["path"], "bar.rs"); } + + #[tokio::test] + async fn resume_skips_replayed_messages_without_duplicates() { + let store = Arc::new(Store::in_memory().expect("in-memory store")); + let session = Session::new_running("test prompt", Path::new(".")); + store.create_session(&session).expect("create session"); + + // Simulate a first run: user prompt + assistant + tool call + tool result. + store + .add_session_message(&session.id, MessageRole::User, "test prompt") + .expect("add user msg"); + store + .add_session_message(&session.id, MessageRole::Assistant, "thinking...") + .expect("add assistant msg"); + store + .add_session_message(&session.id, MessageRole::ToolCall, "Run ls") + .expect("add tool_call msg"); + store + .add_session_message(&session.id, MessageRole::ToolResult, "file.txt") + .expect("add tool_result msg"); + + // Create a resuming writer — it should load the 3 non-User roles. + let writer = MessageWriter::new(session.id.clone(), Arc::clone(&store), true); + + // Replay the same sequence the server would send (no User messages). + writer.append_text("thinking...").await; + writer.finalize().await; + writer.record_tool_call("tc-1", "Run ls", None).await; + writer.record_tool_result("file.txt").await; + + // Now send a new message that goes beyond replay. + writer.append_text("new response").await; + writer.finalize().await; + + let messages = store + .get_session_messages(&session.id) + .expect("query messages"); + // Original 4 + 1 new assistant = 5 + assert_eq!(messages.len(), 5); + assert_eq!(messages[4].role, MessageRole::Assistant); + assert_eq!(messages[4].content, "new response"); + } } diff --git a/apps/staged/src-tauri/src/session_runner.rs b/apps/staged/src-tauri/src/session_runner.rs index ca156bed..b13c82fb 100644 --- a/apps/staged/src-tauri/src/session_runner.rs +++ b/apps/staged/src-tauri/src/session_runner.rs @@ -280,6 +280,7 @@ pub fn start_session( let writer = Arc::new(MessageWriter::new( config.session_id.clone(), Arc::clone(&store), + config.agent_session_id.is_some(), )); // Read and base64-encode images for the prompt content blocks. @@ -326,7 +327,7 @@ pub fn start_session( driver .run( &config.session_id, - &config.prompt, + Some(config.prompt.as_str()), &image_data, &config.working_dir, &store_trait, diff --git a/crates/acp-client/src/driver.rs b/crates/acp-client/src/driver.rs index aff529cf..e2318cae 100644 --- a/crates/acp-client/src/driver.rs +++ b/crates/acp-client/src/driver.rs @@ -7,7 +7,6 @@ //! - Remote workspace support via Blox //! - Cancellation support -use std::collections::HashSet; use std::path::{Path, PathBuf}; use std::process::Stdio; use std::sync::Arc; @@ -97,16 +96,6 @@ pub trait MessageWriter: Send + Sync { pub trait Store: Send + Sync { /// Save the agent's session ID for resumption. fn set_agent_session_id(&self, session_id: &str, agent_session_id: &str) -> Result<(), String>; - - /// Retrieve existing session messages as `(role, content)` pairs. - /// - /// Used during session resumption to match replayed notifications - /// against previously persisted messages. The default implementation - /// returns an empty list, which is correct for stores that do not - /// support message persistence (e.g. `NoOpStore`). - fn get_session_messages(&self, _session_id: &str) -> Result, String> { - Ok(vec![]) - } } /// Everything needed to run one turn of an agent. @@ -121,10 +110,13 @@ pub trait AgentDriver { /// /// `images` contains `(base64_data, mime_type)` pairs that are sent as /// `ContentBlock::Image` entries alongside the text prompt. + /// + /// When `prompt` is `None`, the driver sets up / resumes the session but + /// does **not** send a prompt to the agent. async fn run( &self, session_id: &str, - prompt: &str, + prompt: Option<&str>, images: &[(String, String)], working_dir: &Path, store: &Arc, @@ -262,7 +254,7 @@ impl AgentDriver for AcpDriver { async fn run( &self, session_id: &str, - prompt: &str, + prompt: Option<&str>, images: &[(String, String)], working_dir: &Path, store: &Arc, @@ -402,20 +394,7 @@ impl AgentDriver for AcpDriver { }; let stdout_compat = incoming_reader.compat(); - let is_resuming = agent_session_id.is_some(); - let db_messages = if is_resuming { - store.get_session_messages(session_id).unwrap_or_else(|e| { - log::warn!("Failed to load session messages for replay matching: {e}"); - vec![] - }) - } else { - vec![] - }; - let handler = Arc::new(AcpNotificationHandler::new( - Arc::clone(writer), - is_resuming, - db_messages, - )); + let handler = Arc::new(AcpNotificationHandler::new(Arc::clone(writer))); let handler_for_conn = Arc::clone(&handler); let (connection, io_future) = @@ -674,171 +653,16 @@ async fn normalize_local_acp_stdout( } // ============================================================================= -// ACP notification handler — phase-based replay-sync state machine +// ACP notification handler // ============================================================================= -/// The current phase of the notification handler during session resumption. -enum HandlerPhase { - /// Accumulating replay notifications and matching against DB messages. - Replaying(ReplayBuffer), - /// Replay detected as complete; waiting for prompt to be sent. - /// All notifications are dropped; tool-call IDs are recorded. - WaitingForPrompt { - replayed_tool_call_ids: HashSet, - }, - /// Prompt has been sent; forwarding live notifications to the writer. - Live { - replayed_tool_call_ids: HashSet, - }, -} - -/// Accumulates replay notifications and matches them against DB messages. -struct ReplayBuffer { - /// `(role, content)` pairs from the DB, in order. - db_messages: Vec<(String, String)>, - /// Index into `db_messages` of the next message to match. - match_cursor: usize, - /// Index of the last non-user message in `db_messages`. - /// When the cursor passes this, replay is considered complete. - target_index: Option, - /// Text accumulated for the current streaming message. - current_text: String, - /// Role of the current streaming message (`"user"` or `"assistant"`). - current_role: Option, - /// Tool-call IDs observed during replay (used as a safety-net later). - replayed_tool_call_ids: HashSet, - /// Timestamp of the last notification received during replay. - last_notification_at: Instant, - /// Whether at least one notification has been received. - received_any: bool, -} - -impl ReplayBuffer { - fn new(db_messages: Vec<(String, String)>) -> Self { - // Find index of last non-user message. - let target_index = db_messages - .iter() - .enumerate() - .rev() - .find(|(_, (role, _))| role != "user") - .map(|(i, _)| i); - - Self { - db_messages, - match_cursor: 0, - target_index, - current_text: String::new(), - current_role: None, - replayed_tool_call_ids: HashSet::new(), - last_notification_at: Instant::now(), - received_any: false, - } - } - - /// Finalize the current streaming text and try to match it against DB. - /// Called when the role transitions (e.g. from assistant text to tool call). - /// Returns `true` if replay is now considered complete. - fn finalize_current(&mut self) -> bool { - if let Some(role) = self.current_role.take() { - if !self.current_text.is_empty() { - self.current_text.clear(); - return self.try_match(&role); - } - } - false - } - - /// Try to match a role against `db_messages[match_cursor]`. - /// Returns `true` if replay is now considered complete. - fn try_match(&mut self, role: &str) -> bool { - if self.match_cursor >= self.db_messages.len() { - return self.is_complete(); - } - - let (db_role, _) = &self.db_messages[self.match_cursor]; - - if role == db_role { - self.match_cursor += 1; - } - // Don't advance cursor on role mismatch. - - self.is_complete() - } - - /// Returns `true` if the match cursor has passed the target index. - fn is_complete(&self) -> bool { - match self.target_index { - Some(target) => self.match_cursor > target, - None => true, // No non-user messages → complete immediately - } - } -} - struct AcpNotificationHandler { writer: Arc, - phase: Mutex, - /// Signalled when replay matching determines all DB messages have been replayed. - replay_done: tokio::sync::Notify, } impl AcpNotificationHandler { - fn new( - writer: Arc, - replaying: bool, - db_messages: Vec<(String, String)>, - ) -> Self { - let phase = if replaying { - HandlerPhase::Replaying(ReplayBuffer::new(db_messages)) - } else { - HandlerPhase::Live { - replayed_tool_call_ids: HashSet::new(), - } - }; - - Self { - writer, - phase: Mutex::new(phase), - replay_done: tokio::sync::Notify::new(), - } - } - - /// Check whether the replay phase has been idle for at least `timeout`. - /// Returns `false` if not in the Replaying phase or no notification received yet. - async fn is_replay_idle(&self, timeout: Duration) -> bool { - let phase = self.phase.lock().await; - if let HandlerPhase::Replaying(buf) = &*phase { - buf.received_any && buf.last_notification_at.elapsed() >= timeout - } else { - false - } - } - - /// Transition from Replaying to WaitingForPrompt. - /// Extracts the replayed_tool_call_ids from the ReplayBuffer. - async fn transition_to_waiting_for_prompt(&self) { - let mut phase = self.phase.lock().await; - let ids = match &mut *phase { - HandlerPhase::Replaying(buf) => std::mem::take(&mut buf.replayed_tool_call_ids), - HandlerPhase::WaitingForPrompt { .. } | HandlerPhase::Live { .. } => return, - }; - *phase = HandlerPhase::WaitingForPrompt { - replayed_tool_call_ids: ids, - }; - } - - /// Transition from WaitingForPrompt (or Replaying) to Live. - async fn transition_to_live(&self) { - let mut phase = self.phase.lock().await; - let ids = match &mut *phase { - HandlerPhase::WaitingForPrompt { - replayed_tool_call_ids, - } => std::mem::take(replayed_tool_call_ids), - HandlerPhase::Replaying(buf) => std::mem::take(&mut buf.replayed_tool_call_ids), - HandlerPhase::Live { .. } => return, - }; - *phase = HandlerPhase::Live { - replayed_tool_call_ids: ids, - }; + fn new(writer: Arc) -> Self { + Self { writer } } } @@ -863,215 +687,52 @@ impl agent_client_protocol::Client for AcpNotificationHandler { &self, notification: SessionNotification, ) -> agent_client_protocol::Result<()> { - // Session metadata events are forwarded regardless of phase. match ¬ification.update { SessionUpdate::SessionInfoUpdate(info) => { self.writer.on_session_info_update(info).await; - return Ok(()); } SessionUpdate::ConfigOptionUpdate(update) => { self.writer .on_config_option_update(&update.config_options) .await; - return Ok(()); } - _ => {} - } - - // Determine the action to take under the lock, then drop the lock - // before calling into the writer to avoid holding it across await points. - enum LiveAction { - AppendText(String), - RecordToolCall { - id: String, - title: String, - raw_input: Option, - }, - ToolCallUpdate { - id: String, - title: Option, - raw_input: Option, - result: Option, - }, - Ignore, - Drop, - } - - let live_action = { - let mut phase = self.phase.lock().await; - - match &mut *phase { - // ── Replaying phase: accumulate chunks, match against DB ── - HandlerPhase::Replaying(buf) => { - buf.last_notification_at = Instant::now(); - buf.received_any = true; - - // Record tool-call IDs for the safety-net. - if let Some(id) = notification_tool_call_id(¬ification.update) { - buf.replayed_tool_call_ids.insert(id); - } - - let completed = match ¬ification.update { - SessionUpdate::AgentMessageChunk(chunk) => { - if let AcpContentBlock::Text(text) = &chunk.content { - // If switching from non-assistant role, finalize previous. - let mut done = false; - if buf.current_role.as_deref() != Some("assistant") { - done = buf.finalize_current(); - buf.current_role = Some("assistant".to_string()); - } - buf.current_text.push_str(&text.text); - done - } else { - false - } - } - SessionUpdate::UserMessageChunk(chunk) => { - if let AcpContentBlock::Text(text) = &chunk.content { - let mut done = false; - if buf.current_role.as_deref() != Some("user") { - done = buf.finalize_current(); - buf.current_role = Some("user".to_string()); - } - buf.current_text.push_str(&text.text); - done - } else { - false - } - } - SessionUpdate::ToolCall(_tc) => { - buf.finalize_current(); - buf.try_match("tool_call") - } - SessionUpdate::ToolCallUpdate(update) => { - if update.fields.content.is_some() { - buf.finalize_current(); - buf.try_match("tool_result") - } else { - false - } - } - SessionUpdate::AgentThoughtChunk(_) => { - // Thinking is not persisted — ignore. - false - } - _ => false, - }; - - if completed { - self.replay_done.notify_one(); - } - return Ok(()); - } - - // ── WaitingForPrompt phase: drop everything, record tool-call IDs ── - HandlerPhase::WaitingForPrompt { - replayed_tool_call_ids, - } => { - if let Some(id) = notification_tool_call_id(¬ification.update) { - replayed_tool_call_ids.insert(id); - } - return Ok(()); - } - - // ── Live phase: determine action, then release lock ── - HandlerPhase::Live { - replayed_tool_call_ids, - } => { - // Safety net: drop notifications for tool-call IDs seen during replay. - if let Some(id) = notification_tool_call_id(¬ification.update) { - if replayed_tool_call_ids.contains(&id) { - return Ok(()); - } - } - - match ¬ification.update { - SessionUpdate::AgentMessageChunk(chunk) => { - if let AcpContentBlock::Text(text) = &chunk.content { - LiveAction::AppendText(text.text.clone()) - } else { - LiveAction::Drop - } - } - SessionUpdate::ToolCall(tool_call) => LiveAction::RecordToolCall { - id: tool_call.tool_call_id.0.to_string(), - title: tool_call.title.clone(), - raw_input: tool_call.raw_input.clone(), - }, - SessionUpdate::ToolCallUpdate(update) => { - let tc_id = update.tool_call_id.0.to_string(); - let title = update.fields.title.clone(); - let raw_input = update.fields.raw_input.clone(); - let result = update - .fields - .content - .as_ref() - .and_then(|c| extract_content_preview(c)); - if title.is_some() || raw_input.is_some() || result.is_some() { - LiveAction::ToolCallUpdate { - id: tc_id, - title, - raw_input, - result, - } - } else { - LiveAction::Drop - } - } - _ => LiveAction::Ignore, - } + SessionUpdate::AgentMessageChunk(chunk) => { + if let AcpContentBlock::Text(text) = &chunk.content { + self.writer.append_text(&text.text).await; } } - // phase lock is dropped here - }; - - // Execute the live action without holding the phase lock. - match live_action { - LiveAction::AppendText(text) => { - self.writer.append_text(&text).await; - } - LiveAction::RecordToolCall { - id, - title, - raw_input, - } => { + SessionUpdate::ToolCall(tool_call) => { self.writer - .record_tool_call(&id, &title, raw_input.as_ref()) + .record_tool_call( + tool_call.tool_call_id.0.as_ref(), + &tool_call.title, + tool_call.raw_input.as_ref(), + ) .await; } - LiveAction::ToolCallUpdate { - id, - title, - raw_input, - result, - } => { + SessionUpdate::ToolCallUpdate(update) => { + let tc_id = update.tool_call_id.0.as_ref(); + let title = update.fields.title.as_deref(); + let raw_input = update.fields.raw_input.as_ref(); if title.is_some() || raw_input.is_some() { self.writer - .update_tool_call_title(&id, title.as_deref(), raw_input.as_ref()) + .update_tool_call_title(tc_id, title, raw_input) .await; } - if let Some(preview) = result { - self.writer.record_tool_result(&preview).await; + if let Some(ref content) = update.fields.content { + if let Some(preview) = extract_content_preview(content) { + self.writer.record_tool_result(&preview).await; + } } } - LiveAction::Ignore => { + _ => { log::debug!("Ignoring session update: {:?}", notification.update); } - LiveAction::Drop => {} } Ok(()) } } -/// Extract the tool-call ID from a session update, if it carries one. -fn notification_tool_call_id(update: &SessionUpdate) -> Option { - match update { - SessionUpdate::ToolCall(tc) => Some(tc.tool_call_id.0.to_string()), - SessionUpdate::ToolCallUpdate(tcu) => Some(tcu.tool_call_id.0.to_string()), - _ => None, - } -} - // ============================================================================= // Protocol helpers // ============================================================================= @@ -1080,7 +741,7 @@ fn notification_tool_call_id(update: &SessionUpdate) -> Option { async fn run_acp_protocol( connection: &ClientSideConnection, working_dir: &Path, - prompt: &str, + prompt: Option<&str>, images: &[(String, String)], store: &Arc, our_session_id: &str, @@ -1108,31 +769,13 @@ async fn run_acp_protocol( ) })??; - // If resuming, wait for replay to complete (content match OR idle timeout). - // An absolute 10s timeout prevents a hang if the server sends zero replay - // notifications (e.g. the remote session was garbage-collected). - if acp_session_id.is_some() { - let absolute_deadline = tokio::time::Instant::now() + Duration::from_secs(10); - loop { - tokio::select! { - _ = handler.replay_done.notified() => { - break; - } - _ = tokio::time::sleep_until(absolute_deadline) => { - log::warn!("Replay-wait absolute timeout reached (10s) — proceeding"); - break; - } - _ = tokio::time::sleep(Duration::from_millis(100)) => { - if handler.is_replay_idle(Duration::from_secs(1)).await { - break; - } - } - } - } - handler.transition_to_waiting_for_prompt().await; - } + // When prompt is None, skip sending a prompt (session-only setup). + let prompt_text = match prompt { + Some(p) => p, + None => return Ok(()), + }; - let mut content_blocks = vec![AcpContentBlock::Text(TextContent::new(prompt))]; + let mut content_blocks = vec![AcpContentBlock::Text(TextContent::new(prompt_text))]; for (data, mime_type) in images { content_blocks.push(AcpContentBlock::Image(ImageContent::new( data.as_str(), @@ -1141,8 +784,6 @@ async fn run_acp_protocol( } let prompt_request = PromptRequest::new(agent_session_id, content_blocks); - handler.transition_to_live().await; - connection .prompt(prompt_request) .await diff --git a/crates/acp-client/src/simple.rs b/crates/acp-client/src/simple.rs index 7f9f610c..b7084808 100644 --- a/crates/acp-client/src/simple.rs +++ b/crates/acp-client/src/simple.rs @@ -53,7 +53,7 @@ impl AgentDriver for SimpleDriverWrapper { async fn run( &self, session_id: &str, - prompt: &str, + prompt: Option<&str>, images: &[(String, String)], working_dir: &Path, store: &Arc, @@ -251,10 +251,16 @@ impl AgentDriver for SimpleDriverWrapper { handler.transition_to_waiting().await; } + // When prompt is None, skip sending a prompt (session-only setup). + let prompt_text = match prompt { + Some(p) => p, + None => return Ok::<_, String>(()), + }; + // Build and send prompt let prompt_request = PromptRequest::new( agent_session_id, - vec![AcpContentBlock::Text(TextContent::new(prompt))], + vec![AcpContentBlock::Text(TextContent::new(prompt_text))], ); if is_resuming { @@ -317,7 +323,7 @@ pub async fn run_acp_prompt(agent: &AcpAgent, working_dir: &Path, prompt: &str) driver .run( "simple-session", - &prompt, + Some(&prompt), &[], &working_dir, &store,