Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/staged/src-tauri/examples/acp_stream_probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ fn main() -> Result<()> {
let result = driver
.run(
"probe-session",
&prompt,
Some(prompt.as_str()),
&[],
&workdir,
&store,
Expand Down
10 changes: 0 additions & 10 deletions apps/staged/src-tauri/src/agent/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,6 @@ impl acp_client::Store for Store {
self.set_agent_session_id(session_id, agent_session_id)
.map_err(|e| e.to_string())
}

fn get_session_messages(&self, session_id: &str) -> Result<Vec<(String, String)>, String> {
self.get_session_messages(session_id)
.map(|msgs| {
msgs.into_iter()
.map(|m| (m.role.as_str().to_string(), m.content))
.collect()
})
.map_err(|e| e.to_string())
}
}

// Re-export writer for backward compatibility
Expand Down
106 changes: 104 additions & 2 deletions apps/staged/src-tauri/src/agent/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ use acp_client::strip_code_fences;
/// always forces an immediate flush regardless of this interval.
const FLUSH_INTERVAL: Duration = Duration::from_millis(150);

/// Replay state: roles of previously persisted messages and a cursor tracking
/// how far through replay we are.
struct ReplayState {
roles: Vec<MessageRole>,
cursor: usize,
}

/// Streams agent output into the DB, one session at a time.
///
/// All methods are `&self` + async — the struct uses interior mutability
Expand All @@ -46,6 +53,12 @@ pub struct MessageWriter {
/// ACP can send multiple content updates for one tool call; we update
/// the same row instead of inserting duplicates.
current_tool_result_msg_id: Mutex<Option<i64>>,
/// Replay dedup state, loaded from DB on resume.
replay: Mutex<ReplayState>,
/// Set to `true` while we are skipping a replayed assistant block.
/// Prevents double-advancing the cursor when `flush_text` is called
/// multiple times for the same block (throttled flush + finalize).
skipping_assistant: Mutex<bool>,
}

/// Strip backticks from agent-provided tool-call titles.
Expand All @@ -65,7 +78,18 @@ fn format_tool_call_content(title: &str, raw_input: Option<&serde_json::Value>)
}

impl MessageWriter {
pub fn new(session_id: String, store: Arc<Store>) -> Self {
pub fn new(session_id: String, store: Arc<Store>, resuming: bool) -> Self {
let replay_roles = if resuming {
store
.get_session_messages(&session_id)
.unwrap_or_default()
.into_iter()
.filter(|m| m.role != MessageRole::User)
.map(|m| m.role)
.collect()
} else {
Vec::new()
};
Self {
session_id,
store,
Expand All @@ -74,6 +98,23 @@ impl MessageWriter {
last_flush_at: Mutex::new(Instant::now()),
tool_call_rows: Mutex::new(HashMap::new()),
current_tool_result_msg_id: Mutex::new(None),
replay: Mutex::new(ReplayState {
roles: replay_roles,
cursor: 0,
}),
skipping_assistant: Mutex::new(false),
}
}

/// Check if the current message matches the next expected replay message.
/// If so, advance the cursor and return `true` (skip the write).
async fn try_skip_replay(&self, role: MessageRole) -> bool {
let mut replay = self.replay.lock().await;
if replay.cursor < replay.roles.len() && replay.roles[replay.cursor] == role {
replay.cursor += 1;
true
} else {
false
}
}

Expand Down Expand Up @@ -101,6 +142,7 @@ impl MessageWriter {
self.flush_text().await;
self.current_assistant_msg_id.lock().await.take();
*self.current_text.lock().await = String::new();
*self.skipping_assistant.lock().await = false;
}

// =====================================================================
Expand Down Expand Up @@ -130,6 +172,10 @@ impl MessageWriter {
return;
}

if self.try_skip_replay(MessageRole::ToolCall).await {
return;
}

match self
.store
.add_session_message(&self.session_id, MessageRole::ToolCall, &content)
Expand Down Expand Up @@ -175,6 +221,10 @@ impl MessageWriter {
return;
}

if self.try_skip_replay(MessageRole::ToolResult).await {
return;
}

match self
.store
.add_session_message(&self.session_id, MessageRole::ToolResult, &content)
Expand All @@ -199,12 +249,22 @@ impl MessageWriter {
if text.is_empty() {
return;
}
// If we already decided to skip this assistant block during replay,
// don't re-enter try_skip_replay (which would advance the cursor
// past a subsequent message).
if *self.skipping_assistant.lock().await {
return;
}
let mut msg_id = self.current_assistant_msg_id.lock().await;
match *msg_id {
Some(id) => {
let _ = self.store.update_message_content(id, &text);
}
None => {
if self.try_skip_replay(MessageRole::Assistant).await {
*self.skipping_assistant.lock().await = true;
return;
}
match self.store.add_session_message(
&self.session_id,
MessageRole::Assistant,
Expand Down Expand Up @@ -276,7 +336,7 @@ mod tests {
let store = Arc::new(Store::in_memory().expect("in-memory store"));
let session = Session::new_running("test prompt", Path::new("."));
store.create_session(&session).expect("create session");
let writer = MessageWriter::new(session.id.clone(), Arc::clone(&store));
let writer = MessageWriter::new(session.id.clone(), Arc::clone(&store), false);
(store, session.id, writer)
}

Expand Down Expand Up @@ -361,4 +421,46 @@ mod tests {
assert_eq!(parsed["name"], "Read file");
assert_eq!(parsed["input"]["path"], "bar.rs");
}

#[tokio::test]
async fn resume_skips_replayed_messages_without_duplicates() {
let store = Arc::new(Store::in_memory().expect("in-memory store"));
let session = Session::new_running("test prompt", Path::new("."));
store.create_session(&session).expect("create session");

// Simulate a first run: user prompt + assistant + tool call + tool result.
store
.add_session_message(&session.id, MessageRole::User, "test prompt")
.expect("add user msg");
store
.add_session_message(&session.id, MessageRole::Assistant, "thinking...")
.expect("add assistant msg");
store
.add_session_message(&session.id, MessageRole::ToolCall, "Run ls")
.expect("add tool_call msg");
store
.add_session_message(&session.id, MessageRole::ToolResult, "file.txt")
.expect("add tool_result msg");

// Create a resuming writer — it should load the 3 non-User roles.
let writer = MessageWriter::new(session.id.clone(), Arc::clone(&store), true);

// Replay the same sequence the server would send (no User messages).
writer.append_text("thinking...").await;
writer.finalize().await;
writer.record_tool_call("tc-1", "Run ls", None).await;
writer.record_tool_result("file.txt").await;

// Now send a new message that goes beyond replay.
writer.append_text("new response").await;
writer.finalize().await;

let messages = store
.get_session_messages(&session.id)
.expect("query messages");
// Original 4 + 1 new assistant = 5
assert_eq!(messages.len(), 5);
assert_eq!(messages[4].role, MessageRole::Assistant);
assert_eq!(messages[4].content, "new response");
}
}
3 changes: 2 additions & 1 deletion apps/staged/src-tauri/src/session_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ pub fn start_session(
let writer = Arc::new(MessageWriter::new(
config.session_id.clone(),
Arc::clone(&store),
config.agent_session_id.is_some(),
));

// Read and base64-encode images for the prompt content blocks.
Expand Down Expand Up @@ -326,7 +327,7 @@ pub fn start_session(
driver
.run(
&config.session_id,
&config.prompt,
Some(config.prompt.as_str()),
&image_data,
&config.working_dir,
&store_trait,
Expand Down
Loading