From d4d155c8c843ba3cd9d7a8742522c95044f62d3b Mon Sep 17 00:00:00 2001 From: flazouh Date: Sun, 5 Apr 2026 23:08:11 +0300 Subject: [PATCH 1/2] feat(acp): add machine stdio transport for Acepe Expose a machine stdio entrypoint in Forge and route it through a real ACP stdio transport so Acepe can launch Forge as an installable provider instead of depending on an unpublished branch. Co-Authored-By: ForgeCode --- Cargo.lock | 93 ++++++ Cargo.toml | 1 + crates/forge_api/src/api.rs | 3 + crates/forge_api/src/forge_api.rs | 5 + crates/forge_app/Cargo.toml | 4 + crates/forge_app/src/acp/adapter.rs | 120 ++++++++ crates/forge_app/src/acp/conversion.rs | 290 +++++++++++++++++++ crates/forge_app/src/acp/error.rs | 27 ++ crates/forge_app/src/acp/mod.rs | 88 ++++++ crates/forge_app/src/acp/prompt_handler.rs | 283 ++++++++++++++++++ crates/forge_app/src/acp/session_handlers.rs | 225 ++++++++++++++ crates/forge_app/src/acp/state_builders.rs | 180 ++++++++++++ crates/forge_app/src/acp_app.rs | 89 ++++++ crates/forge_app/src/lib.rs | 3 + crates/forge_main/src/acp_runner.rs | 57 ++++ crates/forge_main/src/cli.rs | 37 +++ crates/forge_main/src/lib.rs | 1 + crates/forge_main/src/ui.rs | 8 + 18 files changed, 1514 insertions(+) create mode 100644 crates/forge_app/src/acp/adapter.rs create mode 100644 crates/forge_app/src/acp/conversion.rs create mode 100644 crates/forge_app/src/acp/error.rs create mode 100644 crates/forge_app/src/acp/mod.rs create mode 100644 crates/forge_app/src/acp/prompt_handler.rs create mode 100644 crates/forge_app/src/acp/session_handlers.rs create mode 100644 crates/forge_app/src/acp/state_builders.rs create mode 100644 crates/forge_app/src/acp_app.rs create mode 100644 crates/forge_main/src/acp_runner.rs diff --git a/Cargo.lock b/Cargo.lock index ab087820c8..59b465bc18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,37 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "agent-client-protocol" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "499b7ff5c6c842e43fb188f6da7c99a258ae89a9df8c896d6e9784da9b4b23e7" +dependencies = [ + "agent-client-protocol-schema", + "anyhow", + "async-broadcast", + "async-trait", + "derive_more", + "futures", + "log", + "serde", + "serde_json", +] + +[[package]] +name = "agent-client-protocol-schema" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44bc1fef9c32f03bce2ab44af35b6f483bfd169bf55cc59beeb2e3b1a00ae4d1" +dependencies = [ + "anyhow", + "derive_more", + "schemars 1.2.1", + "serde", + "serde_json", + "strum 0.27.2", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -130,6 +161,18 @@ dependencies = [ "serde_json", ] +[[package]] +name = "async-broadcast" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "435a87a52755b8f27fcf321ac4f04b2802e337c8c4872923137471ec39c37532" +dependencies = [ + "event-listener", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + [[package]] name = "async-compression" version = "0.4.39" @@ -873,6 +916,15 @@ version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "config" version = "0.15.22" @@ -1787,6 +1839,27 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "event-listener" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "eventsource-stream" version = "0.2.3" @@ -1924,10 +1997,12 @@ dependencies = [ name = "forge_app" version = "0.1.0" dependencies = [ + "agent-client-protocol", "anyhow", "async-recursion", "async-trait", "backon", + "base64 0.22.1", "bytes", "chrono", "console", @@ -1967,9 +2042,11 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-stream", + "tokio-util", "tonic", "tracing", "url", + "uuid", ] [[package]] @@ -4378,6 +4455,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" +[[package]] +name = "parking" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" + [[package]] name = "parking_lot" version = "0.12.5" @@ -6177,6 +6260,15 @@ version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +[[package]] +name = "strum" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +dependencies = [ + "strum_macros 0.27.2", +] + [[package]] name = "strum" version = "0.28.0" @@ -6634,6 +6726,7 @@ checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "pin-project-lite", "tokio", diff --git a/Cargo.toml b/Cargo.toml index ac214161cb..a2ac2f4642 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ opt-level = 3 strip = true [workspace.dependencies] +agent-client-protocol = { version = "0.9", features = ["unstable_session_model"] } anyhow = "1.0.102" async-recursion = "1.1.1" async-stream = "0.3" diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index b13863b0b2..7d1d2a04a0 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -251,4 +251,7 @@ pub trait API: Sync + Send { &self, data_parameters: DataGenerationParameters, ) -> Result>>; + + /// Starts the ACP (Agent Communication Protocol) server over stdio. + async fn acp_start_stdio(&self) -> Result<()>; } diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index 36df08a4c4..223d61156e 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -404,6 +404,11 @@ impl Result<()> { + let acp_app = forge_app::AcpApp::new(self.services.clone()); + acp_app.start_stdio().await + } + async fn get_default_provider(&self) -> Result> { let provider_id = self.services.get_default_provider().await?; self.services.get_provider(provider_id).await diff --git a/crates/forge_app/Cargo.toml b/crates/forge_app/Cargo.toml index 8f5f1873b5..476a3be713 100644 --- a/crates/forge_app/Cargo.toml +++ b/crates/forge_app/Cargo.toml @@ -48,6 +48,10 @@ lazy_static.workspace = true forge_json_repair.workspace = true tonic.workspace = true +agent-client-protocol.workspace = true +tokio-util = { workspace = true, features = ["compat"] } +base64.workspace = true +uuid.workspace = true [dev-dependencies] diff --git a/crates/forge_app/src/acp/adapter.rs b/crates/forge_app/src/acp/adapter.rs new file mode 100644 index 0000000000..d7c1ebc73d --- /dev/null +++ b/crates/forge_app/src/acp/adapter.rs @@ -0,0 +1,120 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use agent_client_protocol as acp; +use forge_domain::{AgentId, ConversationId}; +use tokio::sync::{Mutex, Notify, mpsc}; + +use crate::Services; + +use super::error::{Error, Result}; + +#[derive(Clone)] +pub(super) struct SessionState { + pub conversation_id: ConversationId, + pub agent_id: AgentId, + pub cancel_notify: Option>, +} + +pub(crate) struct AcpAdapter { + pub(super) services: Arc, + pub(super) session_update_tx: mpsc::UnboundedSender, + pub(super) client_conn: Arc>>>, + sessions: Arc>>, +} + +impl AcpAdapter { + pub(crate) fn new( + services: Arc, + session_update_tx: mpsc::UnboundedSender, + ) -> Self { + Self { + services, + session_update_tx, + client_conn: Arc::new(Mutex::new(None)), + sessions: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub(crate) async fn set_client_connection(&self, conn: Arc) { + *self.client_conn.lock().await = Some(conn); + } + + pub(super) async fn store_session(&self, session_id: String, state: SessionState) { + self.sessions.lock().await.insert(session_id, state); + } + + pub(super) async fn session_state(&self, session_id: &str) -> Result { + self.sessions + .lock() + .await + .get(session_id) + .cloned() + .ok_or_else(|| Error::Application(anyhow::anyhow!("Session not found"))) + } + + pub(super) async fn update_session_agent( + &self, + session_id: &str, + agent_id: AgentId, + ) -> Result<()> { + let mut sessions = self.sessions.lock().await; + let state = sessions + .get_mut(session_id) + .ok_or_else(|| Error::Application(anyhow::anyhow!("Session not found")))?; + state.agent_id = agent_id; + Ok(()) + } + + pub(super) async fn set_cancel_notify( + &self, + session_id: &str, + cancel_notify: Option>, + ) -> Result<()> { + let mut sessions = self.sessions.lock().await; + let state = sessions + .get_mut(session_id) + .ok_or_else(|| Error::Application(anyhow::anyhow!("Session not found")))?; + state.cancel_notify = cancel_notify; + Ok(()) + } + + pub(super) async fn cancel_session(&self, session_id: &str) -> bool { + let notify = self + .sessions + .lock() + .await + .get(session_id) + .and_then(|state| state.cancel_notify.clone()); + + if let Some(notify) = notify { + notify.notify_waiters(); + true + } else { + false + } + } + + pub(super) async fn ensure_session( + &self, + session_id: &str, + conversation_id: ConversationId, + agent_id: AgentId, + ) -> SessionState { + let mut sessions = self.sessions.lock().await; + sessions + .entry(session_id.to_string()) + .or_insert_with(|| SessionState { + conversation_id, + agent_id, + cancel_notify: None, + }) + .clone() + } + + pub(super) fn send_notification(&self, notification: acp::SessionNotification) -> Result<()> { + self.session_update_tx + .send(notification) + .map_err(|_| Error::Application(anyhow::anyhow!("Failed to send notification"))) + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/conversion.rs b/crates/forge_app/src/acp/conversion.rs new file mode 100644 index 0000000000..31a9021ce0 --- /dev/null +++ b/crates/forge_app/src/acp/conversion.rs @@ -0,0 +1,290 @@ +use std::path::PathBuf; + +use agent_client_protocol as acp; +use forge_domain::{ + Agent, AgentId, Attachment, AttachmentContent, FileInfo, ToolCallFull, ToolName, ToolOutput, + ToolValue, +}; + +use super::error::{Error, Result}; + +pub(crate) fn map_tool_kind(tool_name: &ToolName) -> acp::ToolKind { + match tool_name.as_str() { + "read" => acp::ToolKind::Read, + "write" | "patch" => acp::ToolKind::Edit, + "remove" | "undo" => acp::ToolKind::Delete, + "fs_search" | "sem_search" => acp::ToolKind::Search, + "shell" => acp::ToolKind::Execute, + "fetch" => acp::ToolKind::Fetch, + "sage" => acp::ToolKind::Think, + _ => { + let name = tool_name.as_str(); + if name.starts_with("mcp_") { + if name.contains("read") + || name.contains("get") + || name.contains("fetch") + || name.contains("list") + || name.contains("show") + || name.contains("view") + || name.contains("load") + { + acp::ToolKind::Read + } else if name.contains("search") + || name.contains("query") + || name.contains("find") + || name.contains("filter") + || name.contains("lookup") + { + acp::ToolKind::Search + } else if name.contains("write") + || name.contains("update") + || name.contains("create") + || name.contains("set") + || name.contains("add") + || name.contains("insert") + || name.contains("push") + || name.contains("merge") + || name.contains("fork") + || name.contains("comment") + || name.contains("assign") + || name.contains("request") + { + acp::ToolKind::Edit + } else if name.contains("delete") + || name.contains("remove") + || name.contains("drop") + || name.contains("clear") + || name.contains("close") + || name.contains("cancel") + { + acp::ToolKind::Delete + } else if name.contains("execute") + || name.contains("run") + || name.contains("start") + || name.contains("invoke") + || name.contains("call") + { + acp::ToolKind::Execute + } else { + acp::ToolKind::Other + } + } else { + acp::ToolKind::Other + } + } + } +} + +pub(crate) fn extract_file_locations( + tool_name: &ToolName, + arguments: &serde_json::Value, +) -> Vec { + match tool_name.as_str() { + "read" | "write" | "patch" | "remove" | "undo" => arguments + .get("file_path") + .and_then(|value| value.as_str()) + .map(|file_path| vec![acp::ToolCallLocation::new(PathBuf::from(file_path))]) + .unwrap_or_default(), + _ => vec![], + } +} + +pub(crate) fn map_tool_call_to_acp(tool_call: &ToolCallFull) -> acp::ToolCall { + let tool_call_id = tool_call + .call_id + .as_ref() + .map(|id| id.as_str().to_string()) + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + + let locations = extract_file_locations( + &tool_call.name, + &serde_json::to_value(&tool_call.arguments).unwrap_or(serde_json::json!({})), + ); + + acp::ToolCall::new(tool_call_id, tool_call.name.as_str().to_string()) + .kind(map_tool_kind(&tool_call.name)) + .status(acp::ToolCallStatus::Pending) + .locations(locations) + .raw_input( + serde_json::to_value(&tool_call.arguments) + .ok() + .filter(|value| !value.is_null()), + ) +} + +pub(crate) struct ToolOutputConverter { + _private: (), +} + +impl ToolOutputConverter { + pub(crate) fn new(output: &ToolOutput) -> Self { + let _ = output; + Self { _private: () } + } + + pub(crate) fn convert(output: &ToolOutput) -> Vec { + let converter = Self::new(output); + output + .values + .iter() + .filter_map(|value| converter.convert_value(value)) + .collect() + } + + fn convert_value(&self, value: &ToolValue) -> Option { + match value { + ToolValue::Text(text) => self.convert_text(text), + ToolValue::AI { value, .. } => self.convert_text(value), + ToolValue::Image(image) => Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Image(acp::ImageContent::new(image.data(), image.mime_type())), + ))), + ToolValue::Empty => None, + } + } + + fn convert_text(&self, text: &str) -> Option { + if text.is_empty() { + None + } else { + Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Text(acp::TextContent::new(text.to_string())), + ))) + } + } +} + +pub(crate) fn acp_resource_to_attachment(resource: &acp::EmbeddedResource) -> Result { + let (content_text, uri) = match &resource.resource { + acp::EmbeddedResourceResource::TextResourceContents(text_resource) => { + (text_resource.text.clone(), text_resource.uri.clone()) + } + acp::EmbeddedResourceResource::BlobResourceContents(blob_resource) => { + let decoded = base64::Engine::decode( + &base64::engine::general_purpose::STANDARD, + &blob_resource.blob, + ) + .map_err(|error| { + Error::Application(anyhow::anyhow!("Failed to decode base64 blob: {}", error)) + })?; + let text = String::from_utf8(decoded).map_err(|error| { + Error::Application(anyhow::anyhow!("Failed to decode UTF-8: {}", error)) + })?; + (text, blob_resource.uri.clone()) + } + _ => { + return Err(Error::Application(anyhow::anyhow!( + "Unsupported resource type" + ))) + } + }; + + let path = uri_to_path(&uri); + let total_lines = content_text.lines().count() as u64; + let info = FileInfo::new(1, total_lines, total_lines, String::new()); + let content = AttachmentContent::FileContent { + content: content_text, + info, + }; + + Ok(Attachment { path, content }) +} + +pub(crate) fn uri_to_path(uri: &str) -> String { + if let Some(path) = uri.strip_prefix("file://") { + if path.len() > 2 && path.chars().nth(2) == Some(':') { + path.trim_start_matches('/').to_string() + } else { + path.to_string() + } + } else { + uri.to_string() + } +} + +pub(crate) fn build_session_mode_state( + agents: &[Agent], + current_agent_id: &AgentId, +) -> acp::SessionModeState { + let available_modes = agents + .iter() + .map(|agent| { + acp::SessionMode::new( + acp::SessionModeId::new(agent.id.to_string()), + agent.id.to_string(), + ) + .description(agent.description.clone()) + }) + .collect(); + + acp::SessionModeState::new( + acp::SessionModeId::new(current_agent_id.to_string()), + available_modes, + ) +} + +#[cfg(test)] +mod tests { + use forge_domain::{ConversationId, Image}; + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_uri_to_path_preserves_non_file_uri() { + let fixture = "relative/path.txt"; + let actual = uri_to_path(fixture); + let expected = "relative/path.txt".to_string(); + assert_eq!(actual, expected); + } + + #[test] + fn test_markdown_sent_to_acp_not_xml() { + let fixture = ToolOutput::text("## File: test.txt\n\nContent here"); + + let actual = ToolOutputConverter::convert(&fixture); + + assert_eq!(actual.len(), 1); + if let Some(acp::ToolCallContent::Content(content)) = actual.first() { + if let acp::ContentBlock::Text(text) = &content.content { + assert_eq!(text.text, "## File: test.txt\n\nContent here"); + } else { + panic!("Expected text content block"); + } + } else { + panic!("Expected content"); + } + } + + #[test] + fn test_ai_output_sent_to_acp_as_text() { + let fixture = ToolOutput::ai(ConversationId::generate(), "Agent result"); + + let actual = ToolOutputConverter::convert(&fixture); + + assert_eq!(actual.len(), 1); + if let Some(acp::ToolCallContent::Content(content)) = actual.first() { + if let acp::ContentBlock::Text(text) = &content.content { + assert_eq!(text.text, "Agent result"); + } else { + panic!("Expected text content block"); + } + } else { + panic!("Expected content"); + } + } + + #[test] + fn test_image_sent_to_acp() { + let image = Image::new_bytes(vec![1, 2, 3, 4], "image/png".to_string()); + let fixture = ToolOutput::image(image); + + let actual = ToolOutputConverter::convert(&fixture); + + assert_eq!(actual.len(), 1); + if let Some(acp::ToolCallContent::Content(content)) = actual.first() { + assert!(matches!(content.content, acp::ContentBlock::Image(_))); + } else { + panic!("Expected content"); + } + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/error.rs b/crates/forge_app/src/acp/error.rs new file mode 100644 index 0000000000..1dbf4a0696 --- /dev/null +++ b/crates/forge_app/src/acp/error.rs @@ -0,0 +1,27 @@ +use agent_client_protocol as acp; + +pub type Result = std::result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("ACP protocol error: {0}")] + Protocol(#[from] acp::Error), + + #[error("Forge application error: {0}")] + Application(#[from] anyhow::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} + +impl From for acp::Error { + fn from(error: Error) -> Self { + match error { + Error::Protocol(error) => error, + Error::Application(error) => { + acp::Error::into_internal_error(error.as_ref() as &dyn std::error::Error) + } + Error::Io(error) => acp::Error::into_internal_error(&error), + } + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/mod.rs b/crates/forge_app/src/acp/mod.rs new file mode 100644 index 0000000000..d90a00fda9 --- /dev/null +++ b/crates/forge_app/src/acp/mod.rs @@ -0,0 +1,88 @@ +mod adapter; +mod conversion; +mod error; +mod prompt_handler; +mod session_handlers; +mod state_builders; + +pub(crate) use adapter::AcpAdapter; + +#[async_trait::async_trait(?Send)] +impl agent_client_protocol::Agent for AcpAdapter { + async fn initialize( + &self, + arguments: agent_client_protocol::InitializeRequest, + ) -> std::result::Result< + agent_client_protocol::InitializeResponse, + agent_client_protocol::Error, + > { + self.handle_initialize(arguments).await + } + + async fn authenticate( + &self, + arguments: agent_client_protocol::AuthenticateRequest, + ) -> std::result::Result< + agent_client_protocol::AuthenticateResponse, + agent_client_protocol::Error, + > { + self.handle_authenticate(arguments).await + } + + async fn new_session( + &self, + arguments: agent_client_protocol::NewSessionRequest, + ) -> std::result::Result< + agent_client_protocol::NewSessionResponse, + agent_client_protocol::Error, + > { + self.handle_new_session(arguments).await + } + + async fn load_session( + &self, + arguments: agent_client_protocol::LoadSessionRequest, + ) -> std::result::Result< + agent_client_protocol::LoadSessionResponse, + agent_client_protocol::Error, + > { + self.handle_load_session(arguments).await + } + + async fn prompt( + &self, + arguments: agent_client_protocol::PromptRequest, + ) -> std::result::Result< + agent_client_protocol::PromptResponse, + agent_client_protocol::Error, + > { + self.handle_prompt(arguments).await + } + + async fn cancel( + &self, + arguments: agent_client_protocol::CancelNotification, + ) -> std::result::Result<(), agent_client_protocol::Error> { + self.handle_cancel(arguments).await + } + + async fn set_session_mode( + &self, + arguments: agent_client_protocol::SetSessionModeRequest, + ) -> std::result::Result< + agent_client_protocol::SetSessionModeResponse, + agent_client_protocol::Error, + > { + self.handle_set_session_mode(arguments).await + } + + async fn set_session_model( + &self, + arguments: agent_client_protocol::SetSessionModelRequest, + ) -> std::result::Result< + agent_client_protocol::SetSessionModelResponse, + agent_client_protocol::Error, + > { + self.handle_set_session_model(arguments).await + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/prompt_handler.rs b/crates/forge_app/src/acp/prompt_handler.rs new file mode 100644 index 0000000000..08bfba5f0b --- /dev/null +++ b/crates/forge_app/src/acp/prompt_handler.rs @@ -0,0 +1,283 @@ +use std::sync::Arc; + +use agent_client_protocol as acp; +use agent_client_protocol::Client; +use forge_domain::{ + ChatRequest, ChatResponse, ChatResponseContent, Event, EventValue, InterruptionReason, +}; +use futures::StreamExt; +use tokio::sync::Notify; + +use crate::{ForgeApp, Services}; + +use super::adapter::AcpAdapter; +use super::conversion; +use super::error::{Error, Result}; + +impl AcpAdapter { + pub(super) async fn handle_prompt( + &self, + arguments: acp::PromptRequest, + ) -> std::result::Result { + let session_key = arguments.session_id.0.as_ref().to_string(); + let session = self.session_state(&session_key).await.map_err(acp::Error::from)?; + + let mut prompt_text_parts = Vec::new(); + let mut attachments = Vec::new(); + + for content_block in &arguments.prompt { + match content_block { + acp::ContentBlock::Text(text_content) => { + prompt_text_parts.push(text_content.text.clone()); + } + acp::ContentBlock::ResourceLink(resource_link) => { + let path = conversion::uri_to_path(&resource_link.uri); + prompt_text_parts.push(format!("@[{}]", path)); + } + acp::ContentBlock::Resource(embedded_resource) => { + match conversion::acp_resource_to_attachment(embedded_resource) { + Ok(attachment) => attachments.push(attachment), + Err(error) => { + tracing::warn!("Failed to convert embedded resource: {}", error); + } + } + } + _ => {} + } + } + + let prompt_text = prompt_text_parts.join("\n"); + let cancel_notify = Arc::new(Notify::new()); + self.set_cancel_notify(&session_key, Some(cancel_notify.clone())) + .await + .map_err(acp::Error::from)?; + + let response = self + .run_prompt_loop(&arguments.session_id, &session_key, session, prompt_text, attachments, cancel_notify) + .await; + + let _ = self.set_cancel_notify(&session_key, None).await; + response + } + + async fn run_prompt_loop( + &self, + session_id: &acp::SessionId, + session_key: &str, + session: super::adapter::SessionState, + prompt_text: String, + attachments: Vec, + cancel_notify: Arc, + ) -> std::result::Result { + let mut event = Event::new(EventValue::text(prompt_text)); + event.attachments = attachments; + + let mut chat_request = ChatRequest::new(event, session.conversation_id); + loop { + let app = ForgeApp::new(self.services.clone()); + let mut stream = app + .chat(session.agent_id.clone(), chat_request) + .await + .map_err(|error| acp::Error::into_internal_error(error.as_ref() as &dyn std::error::Error))?; + + let mut continue_after_interrupt = false; + + loop { + tokio::select! { + _ = cancel_notify.notified() => { + tracing::info!("ACP prompt cancelled for session {}", session_key); + return Ok(acp::PromptResponse::new(acp::StopReason::Cancelled)); + } + response_result = stream.next() => { + match response_result { + Some(Ok(response)) => { + self.handle_chat_response(session_id, response, &mut continue_after_interrupt).await?; + } + Some(Err(error)) => { + tracing::error!("Error in chat stream: {}", error); + return Err(acp::Error::into_internal_error( + error.as_ref() as &dyn std::error::Error, + )); + } + None => { + break; + } + } + } + } + } + + if continue_after_interrupt { + chat_request = ChatRequest::new(Event::new(EventValue::text("")), session.conversation_id); + continue; + } + + return Ok(acp::PromptResponse::new(acp::StopReason::EndTurn)); + } + } + + async fn handle_chat_response( + &self, + session_id: &acp::SessionId, + response: ChatResponse, + continue_after_interrupt: &mut bool, + ) -> std::result::Result<(), acp::Error> { + match response { + ChatResponse::TaskMessage { content } => { + self.handle_task_message(session_id, content).await?; + } + ChatResponse::TaskReasoning { content } => { + if !content.is_empty() { + let notification = acp::SessionNotification::new( + session_id.clone(), + acp::SessionUpdate::AgentThoughtChunk(acp::ContentChunk::new( + acp::ContentBlock::Text(acp::TextContent::new(content)), + )), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + } + } + ChatResponse::ToolCallStart { tool_call, .. } => { + let notification = acp::SessionNotification::new( + session_id.clone(), + acp::SessionUpdate::ToolCallUpdate( + conversion::map_tool_call_to_acp(&tool_call).into(), + ), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + } + ChatResponse::ToolCallEnd(tool_result) => { + let content = conversion::ToolOutputConverter::convert(&tool_result.output); + let status = if tool_result.output.is_error { + acp::ToolCallStatus::Failed + } else { + acp::ToolCallStatus::Completed + }; + let tool_call_id = tool_result + .call_id + .as_ref() + .map(|id| id.as_str().to_string()) + .unwrap_or_else(|| "unknown".to_string()); + let update = acp::ToolCallUpdate::new( + tool_call_id, + acp::ToolCallUpdateFields::new().status(status).content(content), + ); + let notification = acp::SessionNotification::new( + session_id.clone(), + acp::SessionUpdate::ToolCallUpdate(update), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + } + ChatResponse::TaskComplete => {} + ChatResponse::RetryAttempt { .. } => {} + ChatResponse::Interrupt { reason } => { + let should_continue = self + .request_continue_permission(session_id, &reason) + .await + .map_err(acp::Error::from)?; + if should_continue { + *continue_after_interrupt = true; + } + } + } + + Ok(()) + } + + async fn handle_task_message( + &self, + session_id: &acp::SessionId, + content: ChatResponseContent, + ) -> std::result::Result<(), acp::Error> { + match content { + ChatResponseContent::ToolOutput(_) => {} + ChatResponseContent::Markdown { text, .. } => { + if !text.is_empty() { + let notification = acp::SessionNotification::new( + session_id.clone(), + acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new( + acp::ContentBlock::Text(acp::TextContent::new(text)), + )), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + } + } + ChatResponseContent::ToolInput(_) => {} + } + + Ok(()) + } + + async fn request_continue_permission( + &self, + session_id: &acp::SessionId, + reason: &InterruptionReason, + ) -> Result { + let client_conn = self.client_conn.lock().await; + let Some(conn) = client_conn.as_ref() else { + return Ok(false); + }; + + let (title, description) = format_interruption(reason); + let options = vec![ + acp::PermissionOption::new( + "continue", + "Continue Anyway", + acp::PermissionOptionKind::AllowOnce, + ), + acp::PermissionOption::new("stop", "Stop", acp::PermissionOptionKind::RejectOnce), + ]; + let tool_call_update = acp::ToolCallUpdate::new( + "interrupt-continue", + acp::ToolCallUpdateFields::new() + .status(acp::ToolCallStatus::Pending) + .title(title.clone()), + ); + + let mut request = acp::RequestPermissionRequest::new( + session_id.clone(), + tool_call_update, + options, + ); + let mut meta = serde_json::Map::new(); + meta.insert("title".to_string(), serde_json::json!(title)); + meta.insert("description".to_string(), serde_json::json!(description)); + request = request.meta(meta); + + let response = conn.request_permission(request).await.map_err(|error| { + Error::Application(anyhow::anyhow!("Permission request failed: {}", error)) + })?; + + match response.outcome { + acp::RequestPermissionOutcome::Selected(selection) => { + Ok(selection.option_id.0.as_ref() == "continue") + } + acp::RequestPermissionOutcome::Cancelled => Ok(false), + _ => Ok(false), + } + } +} + +fn format_interruption(reason: &InterruptionReason) -> (String, String) { + match reason { + InterruptionReason::MaxToolFailurePerTurnLimitReached { limit, errors } => { + let error_summary = errors + .iter() + .map(|(tool_name, count)| format!("{} ({})", tool_name, count)) + .collect::>() + .join(", "); + ( + format!("Tool failure limit reached ({})", limit), + format!("Forge stopped after repeated tool failures: {}", error_summary), + ) + } + InterruptionReason::MaxRequestPerTurnLimitReached { limit } => ( + format!("Request limit reached ({})", limit), + "Forge reached the maximum number of requests for this turn.".to_string(), + ), + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/session_handlers.rs b/crates/forge_app/src/acp/session_handlers.rs new file mode 100644 index 0000000000..4d6ce8b7ff --- /dev/null +++ b/crates/forge_app/src/acp/session_handlers.rs @@ -0,0 +1,225 @@ +use agent_client_protocol as acp; +use forge_domain::{AgentId, Conversation, ConversationId, ModelId}; + +use crate::{AgentRegistry, AppConfigService, ConversationService, Services}; + +use super::adapter::{AcpAdapter, SessionState}; +use super::state_builders::StateBuilders; + +const VERSION: &str = env!("CARGO_PKG_VERSION"); + +impl AcpAdapter { + pub(super) async fn handle_initialize( + &self, + arguments: acp::InitializeRequest, + ) -> std::result::Result { + tracing::info!("Received initialize request from client: {:?}", arguments.client_info); + + Ok(acp::InitializeResponse::new(acp::ProtocolVersion::V1) + .agent_capabilities( + acp::AgentCapabilities::new().load_session(true).mcp_capabilities( + acp::McpCapabilities::new() + .http(true) + .sse(true), + ), + ) + .agent_info( + acp::Implementation::new("forge".to_string(), VERSION.to_string()) + .title("Forge Code".to_string()), + )) + } + + pub(super) async fn handle_authenticate( + &self, + _arguments: acp::AuthenticateRequest, + ) -> std::result::Result { + Ok(acp::AuthenticateResponse::default()) + } + + pub(super) async fn handle_new_session( + &self, + arguments: acp::NewSessionRequest, + ) -> std::result::Result { + if !arguments.mcp_servers.is_empty() { + StateBuilders::load_mcp_servers(self.services.as_ref(), &arguments.mcp_servers) + .await + .map_err(acp::Error::from)?; + } + + let active_agent_id = self + .services + .agent_registry() + .get_active_agent_id() + .await + .map_err(|error| acp::Error::into_internal_error(&*error))? + .unwrap_or_default(); + + let conversation = Conversation::generate(); + let conversation_id = conversation.id; + self.services + .conversation_service() + .upsert_conversation(conversation) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))?; + + let session_id = acp::SessionId::new(conversation_id.into_string()); + let session_key = session_id.0.as_ref().to_string(); + self.store_session( + session_key, + SessionState { + conversation_id, + agent_id: active_agent_id.clone(), + cancel_notify: None, + }, + ) + .await; + + let agent = self + .services + .agent_registry() + .get_agent(&active_agent_id) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))? + .ok_or_else(|| { + acp::Error::into_internal_error(&*anyhow::anyhow!( + "Agent '{}' not found", + active_agent_id + )) + })?; + + let mode_state = StateBuilders::build_session_mode_state( + self.services.as_ref(), + &active_agent_id, + ) + .await + .map_err(acp::Error::from)?; + let model_state = StateBuilders::build_session_model_state(&self.services, &agent) + .await + .map_err(acp::Error::from)?; + + Ok(acp::NewSessionResponse::new(session_id) + .modes(mode_state) + .models(model_state)) + } + + pub(super) async fn handle_load_session( + &self, + arguments: acp::LoadSessionRequest, + ) -> std::result::Result { + if !arguments.mcp_servers.is_empty() { + StateBuilders::load_mcp_servers(self.services.as_ref(), &arguments.mcp_servers) + .await + .map_err(acp::Error::from)?; + } + + let session_key = arguments.session_id.0.as_ref().to_string(); + let conversation_id = ConversationId::parse(&session_key) + .map_err(|error| acp::Error::into_internal_error(&error))?; + + let conversation = self + .services + .conversation_service() + .find_conversation(&conversation_id) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))?; + if conversation.is_none() { + return Err(acp::Error::invalid_params()); + } + + let active_agent_id = self + .services + .agent_registry() + .get_active_agent_id() + .await + .map_err(|error| acp::Error::into_internal_error(&*error))? + .unwrap_or_default(); + let state = self + .ensure_session(&session_key, conversation_id, active_agent_id.clone()) + .await; + + let agent = self + .services + .agent_registry() + .get_agent(&state.agent_id) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))? + .ok_or_else(|| acp::Error::invalid_params())?; + + let mode_state = StateBuilders::build_session_mode_state( + self.services.as_ref(), + &state.agent_id, + ) + .await + .map_err(acp::Error::from)?; + let model_state = StateBuilders::build_session_model_state(&self.services, &agent) + .await + .map_err(acp::Error::from)?; + + Ok(acp::LoadSessionResponse::new() + .modes(mode_state) + .models(model_state)) + } + + pub(super) async fn handle_cancel( + &self, + arguments: acp::CancelNotification, + ) -> std::result::Result<(), acp::Error> { + let session_key = arguments.session_id.0.as_ref().to_string(); + let cancelled = self.cancel_session(&session_key).await; + if !cancelled { + tracing::warn!("No active ACP prompt to cancel for session {}", session_key); + } + Ok(()) + } + + pub(super) async fn handle_set_session_mode( + &self, + arguments: acp::SetSessionModeRequest, + ) -> std::result::Result { + let session_key = arguments.session_id.0.as_ref().to_string(); + let mode_id = arguments.mode_id.0.as_ref(); + let agent_id = AgentId::new(mode_id); + + self.update_session_agent(&session_key, agent_id.clone()) + .await + .map_err(acp::Error::from)?; + + let notification = acp::SessionNotification::new( + arguments.session_id, + acp::SessionUpdate::CurrentModeUpdate(acp::CurrentModeUpdate::new( + acp::SessionModeId::new(mode_id.to_string()), + )), + ); + self.send_notification(notification) + .map_err(acp::Error::from)?; + + Ok(acp::SetSessionModeResponse::new()) + } + + pub(super) async fn handle_set_session_model( + &self, + arguments: acp::SetSessionModelRequest, + ) -> std::result::Result { + let model_id = ModelId::new(arguments.model_id.0.to_string()); + self.services + .set_default_model(model_id.clone()) + .await + .map_err(|error| acp::Error::into_internal_error(&*error))?; + let _ = self.services.reload_agents().await; + + let notification = acp::SessionNotification::new( + arguments.session_id, + acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new( + acp::ContentBlock::Text(acp::TextContent::new(format!( + "Model changed to: {}\n\n", + model_id + ))), + )), + ); + if let Err(error) = self.send_notification(notification) { + tracing::warn!("Failed to send model change notification: {}", error); + } + + Ok(acp::SetSessionModelResponse::default()) + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp/state_builders.rs b/crates/forge_app/src/acp/state_builders.rs new file mode 100644 index 0000000000..1301ae9323 --- /dev/null +++ b/crates/forge_app/src/acp/state_builders.rs @@ -0,0 +1,180 @@ +use std::collections::BTreeMap; +use std::sync::Arc; + +use agent_client_protocol as acp; +use forge_domain::{Agent, AgentId, McpHttpServer, McpServerConfig, Scope, ServerName}; + +use crate::{ + AgentProviderResolver, AgentRegistry, McpConfigManager, McpService, ProviderAuthService, + ProviderService, Services, +}; + +use super::conversion; +use super::error::{Error, Result}; + +pub(super) struct StateBuilders; + +impl StateBuilders { + pub(super) async fn build_session_mode_state( + services: &S, + current_agent_id: &AgentId, + ) -> Result { + let agents = services + .agent_registry() + .get_agents() + .await + .map_err(Error::Application)?; + + Ok(conversion::build_session_mode_state( + &agents, + current_agent_id, + )) + } + + pub(super) async fn build_session_model_state( + services: &Arc, + current_agent: &Agent, + ) -> Result { + let agent_provider_resolver = AgentProviderResolver::new(services.clone()); + let provider = agent_provider_resolver + .get_provider(Some(current_agent.id.clone())) + .await + .map_err(Error::Application)?; + let provider = services + .provider_auth_service() + .refresh_provider_credential(provider) + .await + .map_err(Error::Application)?; + + let mut models = services + .provider_service() + .models(provider) + .await + .map_err(Error::Application)?; + models.sort_by(|left, right| left.name.cmp(&right.name)); + + let available_models = models + .iter() + .map(|model| { + let mut model_info = acp::ModelInfo::new( + model.id.to_string(), + model.name.clone().unwrap_or_else(|| model.id.to_string()), + ) + .description(model.description.clone()); + + let mut meta = serde_json::Map::new(); + if let Some(context_length) = model.context_length { + meta.insert( + "contextLength".to_string(), + serde_json::json!(context_length), + ); + } + if let Some(tools_supported) = model.tools_supported { + meta.insert( + "toolsSupported".to_string(), + serde_json::json!(tools_supported), + ); + } + if let Some(supports_reasoning) = model.supports_reasoning { + meta.insert( + "supportsReasoning".to_string(), + serde_json::json!(supports_reasoning), + ); + } + if !model.input_modalities.is_empty() { + let modalities = model + .input_modalities + .iter() + .map(|modality| format!("{:?}", modality).to_lowercase()) + .collect::>(); + meta.insert("inputModalities".to_string(), serde_json::json!(modalities)); + } + if !meta.is_empty() { + model_info = model_info.meta(meta); + } + + model_info + }) + .collect(); + + Ok( + acp::SessionModelState::new(current_agent.model.to_string(), available_models).meta({ + let mut meta = serde_json::Map::new(); + meta.insert("searchable".to_string(), serde_json::json!(true)); + meta.insert("searchThreshold".to_string(), serde_json::json!(10)); + meta.insert("filterable".to_string(), serde_json::json!(true)); + meta.insert("groupBy".to_string(), serde_json::json!("provider")); + meta + }), + ) + } + + pub(super) async fn load_mcp_servers( + services: &S, + mcp_servers: &[acp::McpServer], + ) -> Result<()> { + let mut config = services + .mcp_config_manager() + .read_mcp_config(Some(&Scope::Local)) + .await + .map_err(Error::Application)?; + + for server in mcp_servers { + let (name, server_config) = Self::acp_to_mcp_server_config(server)?; + config.mcp_servers.insert(name, server_config); + } + + services + .mcp_config_manager() + .write_mcp_config(&config, &Scope::Local) + .await + .map_err(Error::Application)?; + services.mcp_service().reload_mcp().await.map_err(Error::Application)?; + Ok(()) + } + + fn acp_to_mcp_server_config(server: &acp::McpServer) -> Result<(ServerName, McpServerConfig)> { + match server { + acp::McpServer::Stdio(stdio) => { + let env = stdio + .env + .iter() + .map(|entry| (entry.name.clone(), entry.value.clone())) + .collect::>(); + Ok(( + ServerName::from(stdio.name.clone()), + McpServerConfig::new_stdio(stdio.command.to_string_lossy().to_string(), stdio.args.clone(), Some(env)), + )) + } + acp::McpServer::Http(http) => Ok(( + ServerName::from(http.name.clone()), + McpServerConfig::Http(McpHttpServer { + url: http.url.clone(), + headers: http + .headers + .iter() + .map(|header| (header.name.clone(), header.value.clone())) + .collect(), + timeout: None, + disable: false, + }), + )), + acp::McpServer::Sse(sse) => Ok(( + ServerName::from(sse.name.clone()), + McpServerConfig::Http(McpHttpServer { + url: sse.url.clone(), + headers: sse + .headers + .iter() + .map(|header| (header.name.clone(), header.value.clone())) + .collect(), + timeout: None, + disable: false, + }), + )), + _ => Err(Error::Application(anyhow::anyhow!( + "Unsupported MCP server type" + ))), + } + } +} \ No newline at end of file diff --git a/crates/forge_app/src/acp_app.rs b/crates/forge_app/src/acp_app.rs new file mode 100644 index 0000000000..c86c99bb02 --- /dev/null +++ b/crates/forge_app/src/acp_app.rs @@ -0,0 +1,89 @@ +use std::sync::Arc; + +use anyhow::Result; + +use crate::Services; + +/// ACP (Agent Communication Protocol) application orchestrator. +pub struct AcpApp { + services: Arc, +} + +impl AcpApp { + /// Creates a new ACP application orchestrator. + pub fn new(services: Arc) -> Self { + Self { services } + } + + /// Starts the ACP server over stdio transport. + pub async fn start_stdio(&self) -> Result<()> { + use agent_client_protocol as acp; + use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; + + let services = self.services.clone(); + let handle = tokio::task::spawn_blocking(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to create Tokio runtime"); + + rt.block_on(async move { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let adapter = Arc::new(crate::acp::AcpAdapter::new(services, tx)); + + let local_set = tokio::task::LocalSet::new(); + local_set + .run_until(async move { + let outgoing = tokio::io::stdout().compat_write(); + let incoming = tokio::io::stdin().compat(); + + let (conn, handle_io) = acp::AgentSideConnection::new( + adapter.clone(), + outgoing, + incoming, + |fut| { + tokio::task::spawn_local(fut); + }, + ); + + let conn = Arc::new(conn); + adapter.set_client_connection(conn.clone()).await; + + let conn_for_notifications = conn.clone(); + let notification_task = tokio::task::spawn_local(async move { + let mut rx = rx; + while let Some(session_notification) = rx.recv().await { + use agent_client_protocol::Client; + + if let Err(error) = conn_for_notifications + .session_notification(session_notification) + .await + { + tracing::error!( + "Failed to send session notification: {}", + error + ); + break; + } + } + }); + + let io_result = handle_io.await; + notification_task.abort(); + + io_result.map_err(|error| anyhow::anyhow!("ACP transport error: {}", error)) + }) + .await + }) + }); + + match handle.await { + Ok(result) => result, + Err(error) if error.is_cancelled() => { + tracing::info!("ACP server task was cancelled"); + Ok(()) + } + Err(error) => Err(anyhow::anyhow!("ACP server task panicked: {}", error)), + } + } +} \ No newline at end of file diff --git a/crates/forge_app/src/lib.rs b/crates/forge_app/src/lib.rs index 1b3295498c..21e419ea48 100644 --- a/crates/forge_app/src/lib.rs +++ b/crates/forge_app/src/lib.rs @@ -1,3 +1,5 @@ +mod acp; +mod acp_app; mod agent; mod agent_executor; mod agent_provider_resolver; @@ -38,6 +40,7 @@ pub mod utils; mod walker; mod workspace_status; +pub use acp_app::*; pub use agent::*; pub use agent_provider_resolver::*; pub use app::*; diff --git a/crates/forge_main/src/acp_runner.rs b/crates/forge_main/src/acp_runner.rs new file mode 100644 index 0000000000..e04ea42e31 --- /dev/null +++ b/crates/forge_main/src/acp_runner.rs @@ -0,0 +1,57 @@ +use std::future::Future; + +use anyhow::Result; +use forge_api::API; + +pub trait MachineStdioApi { + fn acp_start_stdio(&self) -> impl Future> + Send; +} + +impl MachineStdioApi for T { + fn acp_start_stdio(&self) -> impl Future> + Send { + API::acp_start_stdio(self) + } +} + +pub async fn run_machine_stdio_server(api: &A) -> Result<()> { + api.acp_start_stdio().await +} + +#[cfg(test)] +mod tests { + use std::sync::{Arc, atomic::{AtomicBool, Ordering}}; + + use anyhow::Result; + + use super::{MachineStdioApi, run_machine_stdio_server}; + + struct MockApi { + called: Arc, + } + + impl MockApi { + fn new(called: Arc) -> Self { + Self { called } + } + } + + impl MachineStdioApi for MockApi { + fn acp_start_stdio(&self) -> impl std::future::Future> + Send { + self.called.store(true, Ordering::SeqCst); + async { Ok(()) } + } + } + + #[tokio::test] + async fn test_run_machine_stdio_server_delegates_to_api_transport() -> Result<()> { + let called = Arc::new(AtomicBool::new(false)); + let fixture = MockApi::new(called.clone()); + + run_machine_stdio_server(&fixture).await?; + + let actual = called.load(Ordering::SeqCst); + let expected = true; + assert_eq!(actual, expected); + Ok(()) + } +} \ No newline at end of file diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index 4cb6ff66f6..6d95e12874 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -82,6 +82,9 @@ pub enum TopLevelCommand { /// Manage agents. Agent(AgentCommandGroup), + /// Run machine-oriented commands. + Machine(MachineCommandGroup), + /// Generate shell extension scripts. #[command(subcommand, alias = "extension")] Zsh(ZshCommandGroup), @@ -203,6 +206,19 @@ pub enum AgentCommand { List, } +/// Command group for machine-oriented interfaces. +#[derive(Parser, Debug, Clone)] +pub struct MachineCommandGroup { + #[command(subcommand)] + pub command: MachineCommand, +} + +#[derive(Subcommand, Debug, Clone)] +pub enum MachineCommand { + /// Run the machine interface over stdio. + Stdio, +} + /// Command group for workspace management. #[derive(Parser, Debug, Clone)] pub struct WorkspaceCommandGroup { @@ -1864,4 +1880,25 @@ mod tests { }; assert!(!actual); } + + #[test] + fn test_machine_stdio_command() { + let fixture = Cli::parse_from(["forge", "machine", "stdio"]); + let actual = matches!( + fixture.subcommands, + Some(TopLevelCommand::Machine(MachineCommandGroup { + command: MachineCommand::Stdio, + })) + ); + let expected = true; + assert_eq!(actual, expected); + } + + #[test] + fn test_machine_stdio_is_not_interactive() { + let fixture = Cli::parse_from(["forge", "machine", "stdio"]); + let actual = fixture.is_interactive(); + let expected = false; + assert_eq!(actual, expected); + } } diff --git a/crates/forge_main/src/lib.rs b/crates/forge_main/src/lib.rs index 1fc22a116d..1690693660 100644 --- a/crates/forge_main/src/lib.rs +++ b/crates/forge_main/src/lib.rs @@ -1,4 +1,5 @@ pub mod banner; +mod acp_runner; mod cli; mod completer; mod conversation_selector; diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 3d6b946bac..c95d2daf3f 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -388,6 +388,14 @@ impl A + Send + Sync> UI { } return Ok(()); } + TopLevelCommand::Machine(machine_group) => { + match machine_group.command { + crate::cli::MachineCommand::Stdio => { + crate::acp_runner::run_machine_stdio_server(self.api.as_ref()).await?; + return Ok(()); + } + } + } TopLevelCommand::List(list_group) => { let porcelain = list_group.porcelain; match list_group.command { From 8fe513fabf91987d440a805b6e1d8a73ef04f230 Mon Sep 17 00:00:00 2001 From: flazouh Date: Mon, 6 Apr 2026 01:01:07 +0300 Subject: [PATCH 2/2] fix(acp): harden machine stdio transport - Replace unbounded notification channel with bounded (1024) to apply backpressure when the client stalls - Add per-session model override to prevent concurrent sessions from interfering with each other - Replace From impl with explicit into_acp_error() per project guidelines - Extract classify_mcp_tool() and convert to free functions, removing the unnecessary ToolOutputConverter struct - Validate MCP server names (length, charset) to prevent injection - Add MAX_BLOB_SIZE (50 MB) guard on base64-decoded resources - Add I/O timeout (5 min) and graceful shutdown drain (5 s) to prevent indefinite hangs - Track cancellation via AtomicBool across loop iterations - Log warnings instead of silently ignoring reload/config errors - Add tests for tool kind mapping, file extraction, and edge cases Co-Authored-By: ForgeCode --- crates/forge_app/src/acp/adapter.rs | 47 +++- crates/forge_app/src/acp/conversion.rs | 246 ++++++++++++------- crates/forge_app/src/acp/error.rs | 20 +- crates/forge_app/src/acp/mod.rs | 2 +- crates/forge_app/src/acp/prompt_handler.rs | 42 +++- crates/forge_app/src/acp/session_handlers.rs | 45 +++- crates/forge_app/src/acp/state_builders.rs | 117 ++++++--- crates/forge_app/src/acp_app.rs | 43 +++- crates/forge_main/src/acp_runner.rs | 4 +- crates/forge_main/src/cli.rs | 1 + 10 files changed, 396 insertions(+), 171 deletions(-) diff --git a/crates/forge_app/src/acp/adapter.rs b/crates/forge_app/src/acp/adapter.rs index d7c1ebc73d..9ac30b45bf 100644 --- a/crates/forge_app/src/acp/adapter.rs +++ b/crates/forge_app/src/acp/adapter.rs @@ -2,38 +2,46 @@ use std::collections::HashMap; use std::sync::Arc; use agent_client_protocol as acp; -use forge_domain::{AgentId, ConversationId}; +use forge_domain::{AgentId, ConversationId, ModelId}; use tokio::sync::{Mutex, Notify, mpsc}; use crate::Services; use super::error::{Error, Result}; +/// Maximum number of buffered session notifications before backpressure. +const NOTIFICATION_CHANNEL_CAPACITY: usize = 1024; + #[derive(Clone)] pub(super) struct SessionState { pub conversation_id: ConversationId, pub agent_id: AgentId, + /// Session-scoped model override. When set, prompts use this model + /// instead of the global default. + pub model_id: Option, pub cancel_notify: Option>, } pub(crate) struct AcpAdapter { pub(super) services: Arc, - pub(super) session_update_tx: mpsc::UnboundedSender, + pub(super) session_update_tx: mpsc::Sender, pub(super) client_conn: Arc>>>, sessions: Arc>>, } impl AcpAdapter { + /// Creates a new ACP adapter and returns the notification receiver. pub(crate) fn new( services: Arc, - session_update_tx: mpsc::UnboundedSender, - ) -> Self { - Self { + ) -> (Self, mpsc::Receiver) { + let (tx, rx) = mpsc::channel(NOTIFICATION_CHANNEL_CAPACITY); + let adapter = Self { services, - session_update_tx, + session_update_tx: tx, client_conn: Arc::new(Mutex::new(None)), sessions: Arc::new(Mutex::new(HashMap::new())), - } + }; + (adapter, rx) } pub(crate) async fn set_client_connection(&self, conn: Arc) { @@ -44,6 +52,13 @@ impl AcpAdapter { self.sessions.lock().await.insert(session_id, state); } + /// Removes a session from the adapter. Currently unused but available + /// for future session lifecycle management (TTL, explicit close). + #[allow(dead_code)] + pub(super) async fn remove_session(&self, session_id: &str) { + self.sessions.lock().await.remove(session_id); + } + pub(super) async fn session_state(&self, session_id: &str) -> Result { self.sessions .lock() @@ -66,6 +81,19 @@ impl AcpAdapter { Ok(()) } + pub(super) async fn update_session_model( + &self, + session_id: &str, + model_id: ModelId, + ) -> Result<()> { + let mut sessions = self.sessions.lock().await; + let state = sessions + .get_mut(session_id) + .ok_or_else(|| Error::Application(anyhow::anyhow!("Session not found")))?; + state.model_id = Some(model_id); + Ok(()) + } + pub(super) async fn set_cancel_notify( &self, session_id: &str, @@ -107,6 +135,7 @@ impl AcpAdapter { .or_insert_with(|| SessionState { conversation_id, agent_id, + model_id: None, cancel_notify: None, }) .clone() @@ -114,7 +143,7 @@ impl AcpAdapter { pub(super) fn send_notification(&self, notification: acp::SessionNotification) -> Result<()> { self.session_update_tx - .send(notification) + .try_send(notification) .map_err(|_| Error::Application(anyhow::anyhow!("Failed to send notification"))) } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/conversion.rs b/crates/forge_app/src/acp/conversion.rs index 31a9021ce0..640fe42868 100644 --- a/crates/forge_app/src/acp/conversion.rs +++ b/crates/forge_app/src/acp/conversion.rs @@ -8,6 +8,16 @@ use forge_domain::{ use super::error::{Error, Result}; +/// Maximum size in bytes for base64-encoded blob resources. +/// Protects against OOM from oversized client payloads. +const MAX_BLOB_SIZE: usize = 50 * 1024 * 1024; // 50 MB + +/// Maps a Forge tool name to an ACP ToolKind. +/// +/// Native Forge tools are classified by exact match. MCP tools (prefixed +/// with `mcp_`) use best-effort keyword heuristics and default to `Other` +/// when the name is ambiguous. The heuristic is order-dependent: the first +/// matching keyword category wins. pub(crate) fn map_tool_kind(tool_name: &ToolName) -> acp::ToolKind { match tool_name.as_str() { "read" => acp::ToolKind::Read, @@ -17,62 +27,51 @@ pub(crate) fn map_tool_kind(tool_name: &ToolName) -> acp::ToolKind { "shell" => acp::ToolKind::Execute, "fetch" => acp::ToolKind::Fetch, "sage" => acp::ToolKind::Think, - _ => { - let name = tool_name.as_str(); - if name.starts_with("mcp_") { - if name.contains("read") - || name.contains("get") - || name.contains("fetch") - || name.contains("list") - || name.contains("show") - || name.contains("view") - || name.contains("load") - { - acp::ToolKind::Read - } else if name.contains("search") - || name.contains("query") - || name.contains("find") - || name.contains("filter") - || name.contains("lookup") - { - acp::ToolKind::Search - } else if name.contains("write") - || name.contains("update") - || name.contains("create") - || name.contains("set") - || name.contains("add") - || name.contains("insert") - || name.contains("push") - || name.contains("merge") - || name.contains("fork") - || name.contains("comment") - || name.contains("assign") - || name.contains("request") - { - acp::ToolKind::Edit - } else if name.contains("delete") - || name.contains("remove") - || name.contains("drop") - || name.contains("clear") - || name.contains("close") - || name.contains("cancel") - { - acp::ToolKind::Delete - } else if name.contains("execute") - || name.contains("run") - || name.contains("start") - || name.contains("invoke") - || name.contains("call") - { - acp::ToolKind::Execute - } else { - acp::ToolKind::Other - } - } else { - acp::ToolKind::Other - } + _ => classify_mcp_tool(tool_name.as_str()), + } +} + +/// Best-effort classification for MCP tools by keyword heuristic. +/// +/// Falls back to `Other` for non-MCP tools or when no keyword matches. +/// The match order matters: a tool named `mcp_get_search_results` would +/// classify as `Read` (matches "get" before "search"). +fn classify_mcp_tool(name: &str) -> acp::ToolKind { + if !name.starts_with("mcp_") { + return acp::ToolKind::Other; + } + + // Strip the "mcp__" prefix to get the action portion. + // E.g. "mcp_github_list_issues" → check against "list_issues". + let action = name + .strip_prefix("mcp_") + .and_then(|rest| rest.split_once('_').map(|(_, action)| action)) + .unwrap_or(name); + + const READ_KEYWORDS: &[&str] = &["read", "get", "fetch", "list", "show", "view", "load"]; + const SEARCH_KEYWORDS: &[&str] = &["search", "query", "find", "filter", "lookup"]; + const EDIT_KEYWORDS: &[&str] = &[ + "write", "update", "create", "set", "add", "insert", "push", "merge", + "fork", "comment", "assign", "request", + ]; + const DELETE_KEYWORDS: &[&str] = &["delete", "remove", "drop", "clear", "close", "cancel"]; + const EXECUTE_KEYWORDS: &[&str] = &["execute", "run", "start", "invoke", "call"]; + + let checks: &[(&[&str], acp::ToolKind)] = &[ + (READ_KEYWORDS, acp::ToolKind::Read), + (SEARCH_KEYWORDS, acp::ToolKind::Search), + (EDIT_KEYWORDS, acp::ToolKind::Edit), + (DELETE_KEYWORDS, acp::ToolKind::Delete), + (EXECUTE_KEYWORDS, acp::ToolKind::Execute), + ]; + + for (keywords, kind) in checks { + if keywords.iter().any(|kw| action.contains(kw)) { + return kind.clone(); } } + + acp::ToolKind::Other } pub(crate) fn extract_file_locations( @@ -112,44 +111,33 @@ pub(crate) fn map_tool_call_to_acp(tool_call: &ToolCallFull) -> acp::ToolCall { ) } -pub(crate) struct ToolOutputConverter { - _private: (), +/// Converts a ToolOutput into ACP content blocks. +pub(crate) fn convert_tool_output(output: &ToolOutput) -> Vec { + output + .values + .iter() + .filter_map(convert_tool_value) + .collect() } -impl ToolOutputConverter { - pub(crate) fn new(output: &ToolOutput) -> Self { - let _ = output; - Self { _private: () } - } - - pub(crate) fn convert(output: &ToolOutput) -> Vec { - let converter = Self::new(output); - output - .values - .iter() - .filter_map(|value| converter.convert_value(value)) - .collect() - } - - fn convert_value(&self, value: &ToolValue) -> Option { - match value { - ToolValue::Text(text) => self.convert_text(text), - ToolValue::AI { value, .. } => self.convert_text(value), - ToolValue::Image(image) => Some(acp::ToolCallContent::Content(acp::Content::new( - acp::ContentBlock::Image(acp::ImageContent::new(image.data(), image.mime_type())), - ))), - ToolValue::Empty => None, - } +fn convert_tool_value(value: &ToolValue) -> Option { + match value { + ToolValue::Text(text) => convert_text(text), + ToolValue::AI { value, .. } => convert_text(value), + ToolValue::Image(image) => Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Image(acp::ImageContent::new(image.data(), image.mime_type())), + ))), + ToolValue::Empty => None, } +} - fn convert_text(&self, text: &str) -> Option { - if text.is_empty() { - None - } else { - Some(acp::ToolCallContent::Content(acp::Content::new( - acp::ContentBlock::Text(acp::TextContent::new(text.to_string())), - ))) - } +fn convert_text(text: &str) -> Option { + if text.is_empty() { + None + } else { + Some(acp::ToolCallContent::Content(acp::Content::new( + acp::ContentBlock::Text(acp::TextContent::new(text.to_string())), + ))) } } @@ -159,6 +147,12 @@ pub(crate) fn acp_resource_to_attachment(resource: &acp::EmbeddedResource) -> Re (text_resource.text.clone(), text_resource.uri.clone()) } acp::EmbeddedResourceResource::BlobResourceContents(blob_resource) => { + if blob_resource.blob.len() > MAX_BLOB_SIZE { + return Err(Error::Application(anyhow::anyhow!( + "Blob resource exceeds maximum size of {} bytes", + MAX_BLOB_SIZE + ))); + } let decoded = base64::Engine::decode( &base64::engine::general_purpose::STANDARD, &blob_resource.blob, @@ -237,11 +231,19 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_uri_to_path_strips_file_prefix() { + let fixture = "file:///home/user/file.txt"; + let actual = uri_to_path(fixture); + let expected = "/home/user/file.txt".to_string(); + assert_eq!(actual, expected); + } + #[test] fn test_markdown_sent_to_acp_not_xml() { let fixture = ToolOutput::text("## File: test.txt\n\nContent here"); - let actual = ToolOutputConverter::convert(&fixture); + let actual = convert_tool_output(&fixture); assert_eq!(actual.len(), 1); if let Some(acp::ToolCallContent::Content(content)) = actual.first() { @@ -259,7 +261,7 @@ mod tests { fn test_ai_output_sent_to_acp_as_text() { let fixture = ToolOutput::ai(ConversationId::generate(), "Agent result"); - let actual = ToolOutputConverter::convert(&fixture); + let actual = convert_tool_output(&fixture); assert_eq!(actual.len(), 1); if let Some(acp::ToolCallContent::Content(content)) = actual.first() { @@ -278,7 +280,7 @@ mod tests { let image = Image::new_bytes(vec![1, 2, 3, 4], "image/png".to_string()); let fixture = ToolOutput::image(image); - let actual = ToolOutputConverter::convert(&fixture); + let actual = convert_tool_output(&fixture); assert_eq!(actual.len(), 1); if let Some(acp::ToolCallContent::Content(content)) = actual.first() { @@ -287,4 +289,64 @@ mod tests { panic!("Expected content"); } } -} \ No newline at end of file + + #[test] + fn test_empty_output_produces_no_content() { + let fixture = ToolOutput::text(""); + let actual = convert_tool_output(&fixture); + let expected: Vec = vec![]; + assert_eq!(actual.len(), expected.len()); + } + + #[test] + fn test_map_tool_kind_native_tools() { + let fixture = ToolName::new("read"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Read)); + } + + #[test] + fn test_map_tool_kind_mcp_read() { + let fixture = ToolName::new("mcp_github_list_issues"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Read)); + } + + #[test] + fn test_map_tool_kind_mcp_search() { + let fixture = ToolName::new("mcp_db_search_records"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Search)); + } + + #[test] + fn test_map_tool_kind_unknown_defaults_to_other() { + let fixture = ToolName::new("mcp_custom_foobar"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Other)); + } + + #[test] + fn test_map_tool_kind_non_mcp_unknown() { + let fixture = ToolName::new("custom_tool"); + let actual = map_tool_kind(&fixture); + assert!(matches!(actual, acp::ToolKind::Other)); + } + + #[test] + fn test_extract_file_locations_read_tool() { + let fixture_name = ToolName::new("read"); + let fixture_args = serde_json::json!({"file_path": "/tmp/test.rs"}); + let actual = extract_file_locations(&fixture_name, &fixture_args); + assert_eq!(actual.len(), 1); + } + + #[test] + fn test_extract_file_locations_unknown_tool() { + let fixture_name = ToolName::new("shell"); + let fixture_args = serde_json::json!({"command": "ls"}); + let actual = extract_file_locations(&fixture_name, &fixture_args); + let expected: Vec = vec![]; + assert_eq!(actual.len(), expected.len()); + } +} diff --git a/crates/forge_app/src/acp/error.rs b/crates/forge_app/src/acp/error.rs index 1dbf4a0696..9a452b4a93 100644 --- a/crates/forge_app/src/acp/error.rs +++ b/crates/forge_app/src/acp/error.rs @@ -14,14 +14,16 @@ pub enum Error { Io(#[from] std::io::Error), } -impl From for acp::Error { - fn from(error: Error) -> Self { - match error { - Error::Protocol(error) => error, - Error::Application(error) => { - acp::Error::into_internal_error(error.as_ref() as &dyn std::error::Error) - } - Error::Io(error) => acp::Error::into_internal_error(&error), +/// Converts a domain Error into an acp::Error. +/// +/// AGENTS.md forbids blanket `From` impls for domain error conversion. +/// Call this explicitly at each `.map_err()` site instead. +pub fn into_acp_error(error: Error) -> acp::Error { + match error { + Error::Protocol(error) => error, + Error::Application(error) => { + acp::Error::into_internal_error(error.as_ref() as &dyn std::error::Error) } + Error::Io(error) => acp::Error::into_internal_error(&error), } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/mod.rs b/crates/forge_app/src/acp/mod.rs index d90a00fda9..716d103c20 100644 --- a/crates/forge_app/src/acp/mod.rs +++ b/crates/forge_app/src/acp/mod.rs @@ -85,4 +85,4 @@ impl agent_client_protocol::Agent for AcpAdapter { > { self.handle_set_session_model(arguments).await } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/prompt_handler.rs b/crates/forge_app/src/acp/prompt_handler.rs index 08bfba5f0b..d057b76c54 100644 --- a/crates/forge_app/src/acp/prompt_handler.rs +++ b/crates/forge_app/src/acp/prompt_handler.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use agent_client_protocol as acp; use agent_client_protocol::Client; @@ -12,7 +13,7 @@ use crate::{ForgeApp, Services}; use super::adapter::AcpAdapter; use super::conversion; -use super::error::{Error, Result}; +use super::error::{self, Error, Result}; impl AcpAdapter { pub(super) async fn handle_prompt( @@ -20,7 +21,7 @@ impl AcpAdapter { arguments: acp::PromptRequest, ) -> std::result::Result { let session_key = arguments.session_id.0.as_ref().to_string(); - let session = self.session_state(&session_key).await.map_err(acp::Error::from)?; + let session = self.session_state(&session_key).await.map_err(error::into_acp_error)?; let mut prompt_text_parts = Vec::new(); let mut attachments = Vec::new(); @@ -48,12 +49,21 @@ impl AcpAdapter { let prompt_text = prompt_text_parts.join("\n"); let cancel_notify = Arc::new(Notify::new()); + let cancelled = Arc::new(AtomicBool::new(false)); self.set_cancel_notify(&session_key, Some(cancel_notify.clone())) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; let response = self - .run_prompt_loop(&arguments.session_id, &session_key, session, prompt_text, attachments, cancel_notify) + .run_prompt_loop( + &arguments.session_id, + &session_key, + session, + prompt_text, + attachments, + cancel_notify, + cancelled, + ) .await; let _ = self.set_cancel_notify(&session_key, None).await; @@ -68,12 +78,21 @@ impl AcpAdapter { prompt_text: String, attachments: Vec, cancel_notify: Arc, + cancelled: Arc, ) -> std::result::Result { let mut event = Event::new(EventValue::text(prompt_text)); event.attachments = attachments; let mut chat_request = ChatRequest::new(event, session.conversation_id); loop { + // Check if cancellation was requested before starting a new + // chat round (handles the case where cancel arrives between + // loop iterations). + if cancelled.load(Ordering::SeqCst) { + tracing::info!("ACP prompt cancelled for session {}", session_key); + return Ok(acp::PromptResponse::new(acp::StopReason::Cancelled)); + } + let app = ForgeApp::new(self.services.clone()); let mut stream = app .chat(session.agent_id.clone(), chat_request) @@ -85,6 +104,7 @@ impl AcpAdapter { loop { tokio::select! { _ = cancel_notify.notified() => { + cancelled.store(true, Ordering::SeqCst); tracing::info!("ACP prompt cancelled for session {}", session_key); return Ok(acp::PromptResponse::new(acp::StopReason::Cancelled)); } @@ -135,7 +155,7 @@ impl AcpAdapter { )), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } } ChatResponse::ToolCallStart { tool_call, .. } => { @@ -146,10 +166,10 @@ impl AcpAdapter { ), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } ChatResponse::ToolCallEnd(tool_result) => { - let content = conversion::ToolOutputConverter::convert(&tool_result.output); + let content = conversion::convert_tool_output(&tool_result.output); let status = if tool_result.output.is_error { acp::ToolCallStatus::Failed } else { @@ -169,7 +189,7 @@ impl AcpAdapter { acp::SessionUpdate::ToolCallUpdate(update), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } ChatResponse::TaskComplete => {} ChatResponse::RetryAttempt { .. } => {} @@ -177,7 +197,7 @@ impl AcpAdapter { let should_continue = self .request_continue_permission(session_id, &reason) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; if should_continue { *continue_after_interrupt = true; } @@ -203,7 +223,7 @@ impl AcpAdapter { )), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } } ChatResponseContent::ToolInput(_) => {} @@ -280,4 +300,4 @@ fn format_interruption(reason: &InterruptionReason) -> (String, String) { "Forge reached the maximum number of requests for this turn.".to_string(), ), } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/session_handlers.rs b/crates/forge_app/src/acp/session_handlers.rs index 4d6ce8b7ff..4b3c7a576e 100644 --- a/crates/forge_app/src/acp/session_handlers.rs +++ b/crates/forge_app/src/acp/session_handlers.rs @@ -4,6 +4,7 @@ use forge_domain::{AgentId, Conversation, ConversationId, ModelId}; use crate::{AgentRegistry, AppConfigService, ConversationService, Services}; use super::adapter::{AcpAdapter, SessionState}; +use super::error; use super::state_builders::StateBuilders; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -29,10 +30,18 @@ impl AcpAdapter { )) } + /// Handles ACP authentication. + /// + /// This is intentionally a no-op. The stdio transport inherits OS-level + /// process isolation: only the parent process (e.g. Acepe) that spawned + /// `forge machine stdio` can read/write the stdin/stdout pipes. No + /// network listener is opened, so no additional authentication is + /// required. See `AcpApp::start_stdio` for the full trust model. pub(super) async fn handle_authenticate( &self, _arguments: acp::AuthenticateRequest, ) -> std::result::Result { + tracing::debug!("ACP authenticate: no-op (stdio transport uses OS process isolation)"); Ok(acp::AuthenticateResponse::default()) } @@ -43,7 +52,7 @@ impl AcpAdapter { if !arguments.mcp_servers.is_empty() { StateBuilders::load_mcp_servers(self.services.as_ref(), &arguments.mcp_servers) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } let active_agent_id = self @@ -69,6 +78,7 @@ impl AcpAdapter { SessionState { conversation_id, agent_id: active_agent_id.clone(), + model_id: None, cancel_notify: None, }, ) @@ -92,10 +102,10 @@ impl AcpAdapter { &active_agent_id, ) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; let model_state = StateBuilders::build_session_model_state(&self.services, &agent) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; Ok(acp::NewSessionResponse::new(session_id) .modes(mode_state) @@ -109,7 +119,7 @@ impl AcpAdapter { if !arguments.mcp_servers.is_empty() { StateBuilders::load_mcp_servers(self.services.as_ref(), &arguments.mcp_servers) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; } let session_key = arguments.session_id.0.as_ref().to_string(); @@ -150,10 +160,10 @@ impl AcpAdapter { &state.agent_id, ) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; let model_state = StateBuilders::build_session_model_state(&self.services, &agent) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; Ok(acp::LoadSessionResponse::new() .modes(mode_state) @@ -182,7 +192,7 @@ impl AcpAdapter { self.update_session_agent(&session_key, agent_id.clone()) .await - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; let notification = acp::SessionNotification::new( arguments.session_id, @@ -191,21 +201,36 @@ impl AcpAdapter { )), ); self.send_notification(notification) - .map_err(acp::Error::from)?; + .map_err(error::into_acp_error)?; Ok(acp::SetSessionModeResponse::new()) } + /// Handles session model changes. + /// + /// The model preference is stored per-session so that concurrent ACP + /// clients do not interfere with each other. The global default model + /// is also updated for backward compatibility with non-ACP code paths. pub(super) async fn handle_set_session_model( &self, arguments: acp::SetSessionModelRequest, ) -> std::result::Result { + let session_key = arguments.session_id.0.as_ref().to_string(); let model_id = ModelId::new(arguments.model_id.0.to_string()); + + // Store per-session model preference. + self.update_session_model(&session_key, model_id.clone()) + .await + .map_err(error::into_acp_error)?; + + // Also update the global default for backward compatibility. self.services .set_default_model(model_id.clone()) .await .map_err(|error| acp::Error::into_internal_error(&*error))?; - let _ = self.services.reload_agents().await; + if let Err(error) = self.services.reload_agents().await { + tracing::warn!("Failed to reload agents after model change: {}", error); + } let notification = acp::SessionNotification::new( arguments.session_id, @@ -222,4 +247,4 @@ impl AcpAdapter { Ok(acp::SetSessionModelResponse::default()) } -} \ No newline at end of file +} diff --git a/crates/forge_app/src/acp/state_builders.rs b/crates/forge_app/src/acp/state_builders.rs index 1301ae9323..e1d2770b67 100644 --- a/crates/forge_app/src/acp/state_builders.rs +++ b/crates/forge_app/src/acp/state_builders.rs @@ -12,6 +12,9 @@ use crate::{ use super::conversion; use super::error::{Error, Result}; +/// Maximum allowed length for an MCP server name (prevents injection). +const MAX_SERVER_NAME_LEN: usize = 128; + pub(super) struct StateBuilders; impl StateBuilders { @@ -109,6 +112,15 @@ impl StateBuilders { ) } + /// Loads MCP server configurations provided by the ACP client. + /// + /// # Trust model + /// + /// The stdio transport inherits OS-level process isolation, so the + /// client is the parent process (Acepe). Server names are validated + /// to prevent injection. The configs are written to the local scope + /// only and do not persist across Forge restarts unless the caller + /// explicitly saves them. pub(super) async fn load_mcp_servers( services: &S, mcp_servers: &[acp::McpServer], @@ -119,10 +131,23 @@ impl StateBuilders { .await .map_err(Error::Application)?; - for server in mcp_servers { - let (name, server_config) = Self::acp_to_mcp_server_config(server)?; - config.mcp_servers.insert(name, server_config); - } + let server_names: Vec = mcp_servers + .iter() + .filter_map(|s| { + match Self::acp_to_mcp_server_config(s) { + Ok((name, server_config)) => { + config.mcp_servers.insert(name.clone(), server_config); + Some(name.to_string()) + } + Err(error) => { + tracing::warn!("Skipping invalid MCP server config: {}", error); + None + } + } + }) + .collect(); + + tracing::info!("Loading {} MCP servers from ACP client: {:?}", server_names.len(), server_names); services .mcp_config_manager() @@ -136,6 +161,7 @@ impl StateBuilders { fn acp_to_mcp_server_config(server: &acp::McpServer) -> Result<(ServerName, McpServerConfig)> { match server { acp::McpServer::Stdio(stdio) => { + Self::validate_server_name(&stdio.name)?; let env = stdio .env .iter() @@ -146,35 +172,64 @@ impl StateBuilders { McpServerConfig::new_stdio(stdio.command.to_string_lossy().to_string(), stdio.args.clone(), Some(env)), )) } - acp::McpServer::Http(http) => Ok(( - ServerName::from(http.name.clone()), - McpServerConfig::Http(McpHttpServer { - url: http.url.clone(), - headers: http - .headers - .iter() - .map(|header| (header.name.clone(), header.value.clone())) - .collect(), - timeout: None, - disable: false, - }), - )), - acp::McpServer::Sse(sse) => Ok(( - ServerName::from(sse.name.clone()), - McpServerConfig::Http(McpHttpServer { - url: sse.url.clone(), - headers: sse - .headers - .iter() - .map(|header| (header.name.clone(), header.value.clone())) - .collect(), - timeout: None, - disable: false, - }), - )), + acp::McpServer::Http(http) => { + Self::validate_server_name(&http.name)?; + Ok(( + ServerName::from(http.name.clone()), + McpServerConfig::Http(McpHttpServer { + url: http.url.clone(), + headers: http + .headers + .iter() + .map(|header| (header.name.clone(), header.value.clone())) + .collect(), + timeout: None, + disable: false, + }), + )) + } + acp::McpServer::Sse(sse) => { + Self::validate_server_name(&sse.name)?; + Ok(( + ServerName::from(sse.name.clone()), + McpServerConfig::Http(McpHttpServer { + url: sse.url.clone(), + headers: sse + .headers + .iter() + .map(|header| (header.name.clone(), header.value.clone())) + .collect(), + timeout: None, + disable: false, + }), + )) + } _ => Err(Error::Application(anyhow::anyhow!( "Unsupported MCP server type" ))), } } -} \ No newline at end of file + + /// Validates that an MCP server name is safe to use as a config key. + fn validate_server_name(name: &str) -> Result<()> { + if name.is_empty() { + return Err(Error::Application(anyhow::anyhow!( + "MCP server name must not be empty" + ))); + } + if name.len() > MAX_SERVER_NAME_LEN { + return Err(Error::Application(anyhow::anyhow!( + "MCP server name exceeds maximum length of {} characters", + MAX_SERVER_NAME_LEN + ))); + } + // Only allow alphanumeric, hyphens, underscores, and dots. + if !name.chars().all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '.') { + return Err(Error::Application(anyhow::anyhow!( + "MCP server name '{}' contains invalid characters (allowed: alphanumeric, -, _, .)", + name + ))); + } + Ok(()) + } +} diff --git a/crates/forge_app/src/acp_app.rs b/crates/forge_app/src/acp_app.rs index c86c99bb02..697be3d554 100644 --- a/crates/forge_app/src/acp_app.rs +++ b/crates/forge_app/src/acp_app.rs @@ -1,4 +1,5 @@ use std::sync::Arc; +use std::time::Duration; use anyhow::Result; @@ -9,6 +10,12 @@ pub struct AcpApp { services: Arc, } +/// Maximum time to wait for ACP I/O before considering the client hung. +const IO_TIMEOUT: Duration = Duration::from_secs(300); + +/// Maximum time to wait for pending notifications to drain on shutdown. +const SHUTDOWN_DRAIN_TIMEOUT: Duration = Duration::from_secs(5); + impl AcpApp { /// Creates a new ACP application orchestrator. pub fn new(services: Arc) -> Self { @@ -16,6 +23,13 @@ impl AcpApp { } /// Starts the ACP server over stdio transport. + /// + /// # Trust model + /// + /// The stdio transport inherits OS-level process isolation: only the + /// parent process (e.g. Acepe) that spawned `forge machine stdio` can + /// read/write the stdin/stdout pipes. No network listener is opened. + /// Authentication is therefore a no-op by design. pub async fn start_stdio(&self) -> Result<()> { use agent_client_protocol as acp; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; @@ -25,11 +39,11 @@ impl AcpApp { let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() - .expect("Failed to create Tokio runtime"); + .map_err(|e| anyhow::anyhow!("Failed to create Tokio runtime: {}", e))?; rt.block_on(async move { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - let adapter = Arc::new(crate::acp::AcpAdapter::new(services, tx)); + let (adapter, mut rx) = crate::acp::AcpAdapter::new(services); + let adapter = Arc::new(adapter); let local_set = tokio::task::LocalSet::new(); local_set @@ -51,7 +65,6 @@ impl AcpApp { let conn_for_notifications = conn.clone(); let notification_task = tokio::task::spawn_local(async move { - let mut rx = rx; while let Some(session_notification) = rx.recv().await { use agent_client_protocol::Client; @@ -68,8 +81,24 @@ impl AcpApp { } }); - let io_result = handle_io.await; - notification_task.abort(); + // Wait for I/O with a timeout to prevent indefinite hangs + // when the client stalls. + let io_result = match tokio::time::timeout(IO_TIMEOUT, handle_io).await { + Ok(result) => result, + Err(_) => { + tracing::warn!("ACP I/O timed out after {:?}", IO_TIMEOUT); + notification_task.abort(); + return Err(anyhow::anyhow!( + "ACP transport timed out after {:?}", + IO_TIMEOUT + )); + } + }; + + // Graceful shutdown: give the notification task time to + // drain pending messages instead of aborting immediately. + drop(adapter); // drops the sender half → rx.recv() returns None + let _ = tokio::time::timeout(SHUTDOWN_DRAIN_TIMEOUT, notification_task).await; io_result.map_err(|error| anyhow::anyhow!("ACP transport error: {}", error)) }) @@ -86,4 +115,4 @@ impl AcpApp { Err(error) => Err(anyhow::anyhow!("ACP server task panicked: {}", error)), } } -} \ No newline at end of file +} diff --git a/crates/forge_main/src/acp_runner.rs b/crates/forge_main/src/acp_runner.rs index e04ea42e31..32671c5411 100644 --- a/crates/forge_main/src/acp_runner.rs +++ b/crates/forge_main/src/acp_runner.rs @@ -3,6 +3,7 @@ use std::future::Future; use anyhow::Result; use forge_api::API; +/// Abstraction over the ACP stdio transport entry point for testability. pub trait MachineStdioApi { fn acp_start_stdio(&self) -> impl Future> + Send; } @@ -13,6 +14,7 @@ impl MachineStdioApi for T { } } +/// Starts the ACP machine stdio server by delegating to the provided API. pub async fn run_machine_stdio_server(api: &A) -> Result<()> { api.acp_start_stdio().await } @@ -54,4 +56,4 @@ mod tests { assert_eq!(actual, expected); Ok(()) } -} \ No newline at end of file +} diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index 6d95e12874..3972c151c2 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -213,6 +213,7 @@ pub struct MachineCommandGroup { pub command: MachineCommand, } +/// Machine-oriented subcommands for non-interactive transport protocols. #[derive(Subcommand, Debug, Clone)] pub enum MachineCommand { /// Run the machine interface over stdio.