diff --git a/Cargo.lock b/Cargo.lock index ab087820c8..e9b0b0b366 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2117,6 +2117,7 @@ dependencies = [ "http 1.4.0", "libsqlite3-sys", "oauth2", + "open", "pretty_assertions", "reqwest 0.12.28", "reqwest-eventsource", @@ -5287,6 +5288,7 @@ dependencies = [ "chrono", "futures", "http 1.4.0", + "oauth2", "pastey", "pin-project-lite", "process-wrap", @@ -5301,6 +5303,7 @@ dependencies = [ "tokio-stream", "tokio-util", "tracing", + "url", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ac214161cb..b12ca183ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -126,6 +126,7 @@ rmcp = { version = "0.10.0", features = [ "transport-sse-client-reqwest", "transport-child-process", "transport-streamable-http-client-reqwest", + "auth", ] } open = "5.3.2" nucleo = "0.5.0" diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index fdc36a67d3..3ea061cbdc 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -238,4 +238,13 @@ pub trait API: Sync + Send { &self, data_parameters: DataGenerationParameters, ) -> Result>>; + + /// Authenticate with an MCP server via OAuth flow + async fn mcp_auth(&self, server_url: &str) -> Result<()>; + + /// Remove stored OAuth credentials for an MCP server (or all servers) + async fn mcp_logout(&self, server_url: Option<&str>) -> Result<()>; + + /// Check the OAuth authentication status of an MCP server + async fn mcp_auth_status(&self, server_url: &str) -> Result; } diff --git a/crates/forge_api/src/forge_api.rs b/crates/forge_api/src/forge_api.rs index 8ed1ab99e2..d53ce3cd59 100644 --- a/crates/forge_api/src/forge_api.rs +++ b/crates/forge_api/src/forge_api.rs @@ -404,6 +404,24 @@ impl< self.services.get_provider(provider_id).await } + async fn mcp_auth(&self, server_url: &str) -> Result<()> { + let env = self.services.get_environment().clone(); + forge_infra::mcp_auth(server_url, &env).await + } + + async fn mcp_logout(&self, server_url: Option<&str>) -> Result<()> { + let env = self.services.get_environment().clone(); + match server_url { + Some(url) => forge_infra::mcp_logout(url, &env).await, + None => forge_infra::mcp_logout_all(&env).await, + } + } + + async fn mcp_auth_status(&self, server_url: &str) -> Result { + let env = self.services.get_environment().clone(); + Ok(forge_infra::mcp_auth_status(server_url, &env).await) + } + fn hydrate_channel(&self) -> Result<()> { self.infra.hydrate(); Ok(()) diff --git a/crates/forge_app/src/infra.rs b/crates/forge_app/src/infra.rs index 371bebcf0a..9659c3eb4c 100644 --- a/crates/forge_app/src/infra.rs +++ b/crates/forge_app/src/infra.rs @@ -216,6 +216,7 @@ pub trait McpServerInfra: Send + Sync + 'static { &self, config: McpServerConfig, env_vars: &BTreeMap, + environment: &Environment, ) -> anyhow::Result; } /// Service for walking filesystem directories diff --git a/crates/forge_app/src/utils.rs b/crates/forge_app/src/utils.rs index 73c7e59be6..67ba32f785 100644 --- a/crates/forge_app/src/utils.rs +++ b/crates/forge_app/src/utils.rs @@ -107,6 +107,199 @@ pub fn compute_hash(content: &str) -> String { hex::encode(hasher.finalize()) } +// Merges strict-mode incompatible `allOf` branches into a single schema object. +fn flatten_all_of_schema(map: &mut serde_json::Map) { + let Some(serde_json::Value::Array(all_of)) = map.remove("allOf") else { + return; + }; + + for sub_schema in all_of { + let serde_json::Value::Object(source) = sub_schema else { + continue; + }; + + merge_schema_object(map, source); + } +} + +fn merge_schema_object( + target: &mut serde_json::Map, + mut source: serde_json::Map, +) { + flatten_all_of_schema(&mut source); + + for (key, value) in source { + match target.get_mut(&key) { + Some(existing) => merge_schema_keyword(existing, value, &key), + None => { + target.insert(key, value); + } + } + } +} + +fn merge_schema_keyword(target: &mut serde_json::Value, source: serde_json::Value, key: &str) { + match (key, target, source) { + ( + "properties" | "$defs" | "definitions" | "patternProperties", + serde_json::Value::Object(target_map), + serde_json::Value::Object(source_map), + ) => merge_named_schema_map(target_map, source_map), + ( + "required", + serde_json::Value::Array(target_values), + serde_json::Value::Array(source_values), + ) => merge_required_arrays(target_values, source_values), + ( + "enum", + serde_json::Value::Array(target_values), + serde_json::Value::Array(source_values), + ) => merge_enum_arrays(target_values, source_values), + (_, serde_json::Value::Object(target_map), serde_json::Value::Object(source_map)) => { + merge_schema_object(target_map, source_map); + } + ("description" | "title", _, _) => {} + (_, target_value, source_value) if *target_value == source_value => {} + _ => {} + } +} + +fn merge_named_schema_map( + target: &mut serde_json::Map, + source: serde_json::Map, +) { + for (key, value) in source { + match target.get_mut(&key) { + Some(existing) => merge_schema_keyword(existing, value, "schema"), + None => { + target.insert(key, value); + } + } + } +} + +fn merge_required_arrays(target: &mut Vec, source: Vec) { + for value in source { + if !target.contains(&value) { + target.push(value); + } + } + + if target.iter().all(|value| value.as_str().is_some()) { + target.sort_by(|left, right| left.as_str().cmp(&right.as_str())); + } +} + +fn merge_enum_arrays(target: &mut Vec, source: Vec) { + target.retain(|value| source.contains(value)); +} + +fn normalize_named_schema_keyword( + map: &mut serde_json::Map, + key: &str, + strict_mode: bool, +) { + let Some(serde_json::Value::Object(named_schemas)) = map.get_mut(key) else { + return; + }; + + for schema in named_schemas.values_mut() { + enforce_strict_schema(schema, strict_mode); + } +} + +fn normalize_schema_keyword( + map: &mut serde_json::Map, + key: &str, + strict_mode: bool, +) { + let Some(schema) = map.get_mut(key) else { + return; + }; + + match schema { + serde_json::Value::Object(_) | serde_json::Value::Array(_) => { + enforce_strict_schema(schema, strict_mode); + } + serde_json::Value::Bool(_) => {} + _ => {} + } +} + +fn normalize_schema_keywords( + map: &mut serde_json::Map, + strict_mode: bool, +) { + for key in ["properties", "$defs", "definitions", "patternProperties"] { + normalize_named_schema_keyword(map, key, strict_mode); + } + + for key in [ + "items", + "contains", + "not", + "if", + "then", + "else", + "additionalProperties", + "additionalItems", + "unevaluatedProperties", + ] { + normalize_schema_keyword(map, key, strict_mode); + } + + for key in ["allOf", "anyOf", "oneOf", "prefixItems"] { + normalize_schema_keyword(map, key, strict_mode); + } +} + +fn is_object_schema(map: &serde_json::Map) -> bool { + map.get("type") + .and_then(|value| value.as_str()) + .is_some_and(|ty| ty == "object") + || map.contains_key("properties") + || map.contains_key("required") + || map.contains_key("additionalProperties") +} + +fn normalize_additional_properties( + map: &mut serde_json::Map, + strict_mode: bool, +) { + match map.get_mut("additionalProperties") { + Some(serde_json::Value::Object(additional_props_map)) => { + let has_combiners = additional_props_map.contains_key("anyOf") + || additional_props_map.contains_key("oneOf") + || additional_props_map.contains_key("allOf"); + + if !additional_props_map.contains_key("type") && !has_combiners { + additional_props_map.insert( + "type".to_string(), + serde_json::Value::String("object".to_string()), + ); + } + + let mut additional_props = + serde_json::Value::Object(std::mem::take(additional_props_map)); + enforce_strict_schema(&mut additional_props, strict_mode); + map.insert("additionalProperties".to_string(), additional_props); + } + Some(serde_json::Value::Bool(_)) => {} + Some(_) => { + map.insert( + "additionalProperties".to_string(), + serde_json::Value::Bool(false), + ); + } + None => { + map.insert( + "additionalProperties".to_string(), + serde_json::Value::Bool(false), + ); + } + } +} + /// Normalizes a JSON schema to meet LLM provider requirements /// /// Many LLM providers (OpenAI, Anthropic) require that all object types in JSON @@ -116,41 +309,34 @@ pub fn compute_hash(content: &str) -> String { /// Additionally, for OpenAI compatibility, it ensures: /// - All objects have a `properties` field (even if empty) /// - All objects have a `required` array with all property keys +/// - `allOf` branches are merged into a single schema object when strict mode +/// is enabled /// /// # Arguments /// * `schema` - The JSON schema to normalize (will be modified in place) -/// * `strict_mode` - If true, adds `properties` and `required` fields for -/// OpenAI compatibility -/// -/// # Example -/// -/// ```rust,ignore -/// use serde_json::json; -/// use forge_app::utils::normalize_json_schema; -/// -/// let mut schema = json!({ -/// "type": "object", -/// "properties": { -/// "name": { "type": "string" } -/// } -/// }); -/// -/// normalize_json_schema(&mut schema, false); -/// -/// assert_eq!(schema["additionalProperties"], json!(false)); -/// ``` +/// * `strict_mode` - If true, adds `properties`, `required`, and `allOf` +/// flattening for OpenAI compatibility pub fn enforce_strict_schema(schema: &mut serde_json::Value, strict_mode: bool) { match schema { serde_json::Value::Object(map) => { - // Check if this is an object type - let is_object = map - .get("type") - .and_then(|value| value.as_str()) - .is_some_and(|ty| ty == "object") - || map.contains_key("properties"); + if strict_mode { + flatten_all_of_schema(map); + // Remove unsupported keywords that OpenAI/Codex doesn't allow + map.remove("propertyNames"); + } + + let is_object = is_object_schema(map); + + // If this looks like an object schema but has no explicit type, add it + // OpenAI requires all schemas to have a type when they represent objects + if is_object && !map.contains_key("type") { + map.insert( + "type".to_string(), + serde_json::Value::String("object".to_string()), + ); + } if is_object { - // OpenAI strict mode: ensure properties field exists if strict_mode && !map.contains_key("properties") { map.insert( "properties".to_string(), @@ -158,13 +344,8 @@ pub fn enforce_strict_schema(schema: &mut serde_json::Value, strict_mode: bool) ); } - // Both OpenAI and Anthropic require this field to be `false` for objects - map.insert( - "additionalProperties".to_string(), - serde_json::Value::Bool(false), - ); + normalize_additional_properties(map, strict_mode); - // OpenAI strict mode: ensure required field exists with all property keys if strict_mode { let required_keys = map .get("properties") @@ -188,13 +369,6 @@ pub fn enforce_strict_schema(schema: &mut serde_json::Value, strict_mode: bool) } } - // OpenAI strict mode: convert "nullable: true" to anyOf with null type. - // OpenAI does not support the "nullable" keyword; instead, nullable - // schemas must be expressed as anyOf: [, {type: "null"}]. - // The description is kept at the top level alongside the anyOf. - // Additionally, schemars' AddNullable transform adds `null` to enum - // arrays for nullable enums, which must be stripped from the non-null - // branch. if strict_mode && map .get("nullable") @@ -203,19 +377,14 @@ pub fn enforce_strict_schema(schema: &mut serde_json::Value, strict_mode: bool) { map.remove("nullable"); - // Remove null from enum array if present (added by AddNullable) if let Some(serde_json::Value::Array(enum_values)) = map.get_mut("enum") { enum_values.retain(|v| !v.is_null()); } - // Extract description to keep at the top level let description = map.remove("description"); - - // Build the non-null branch from remaining keys let non_null_branch = serde_json::Value::Object(std::mem::take(map)); let null_branch = serde_json::json!({"type": "null"}); - // Replace the current map contents with an anyOf wrapper if let Some(desc) = description { map.insert("description".to_string(), desc); } @@ -225,10 +394,7 @@ pub fn enforce_strict_schema(schema: &mut serde_json::Value, strict_mode: bool) ); } - // Recursively normalize nested schemas - for value in map.values_mut() { - enforce_strict_schema(value, strict_mode); - } + normalize_schema_keywords(map, strict_mode); } serde_json::Value::Array(items) => { for value in items { @@ -360,6 +526,154 @@ mod tests { ); } + #[test] + fn test_dynamic_properties_schema_is_preserved_in_strict_mode() { + let mut fixture = json!({ + "type": "object", + "properties": { + "pages": { + "type": "array", + "items": { + "type": "object", + "properties": { + "properties": { + "description": "Dynamic page properties", + "type": "object", + "additionalProperties": { + "anyOf": [ + { "type": "string" }, + { "type": "number" }, + { "type": "null" } + ] + }, + "propertyNames": { + "type": "string" + } + } + }, + "additionalProperties": false + } + } + } + }); + + enforce_strict_schema(&mut fixture, true); + + let expected = json!({ + "type": "object", + "properties": { + "pages": { + "type": "array", + "items": { + "type": "object", + "properties": { + "properties": { + "description": "Dynamic page properties", + "type": "object", + "properties": {}, + "additionalProperties": { + "anyOf": [ + { "type": "string" }, + { "type": "number" }, + { "type": "null" } + ] + }, + "required": [] + } + }, + "additionalProperties": false, + "required": ["properties"] + } + } + }, + "additionalProperties": false, + "required": ["pages"] + }); + + assert_eq!(fixture, expected); + } + + #[test] + fn test_all_of_is_flattened_in_strict_mode() { + let mut fixture = json!({ + "type": "object", + "properties": { + "rich_text": { + "type": "array", + "items": { + "allOf": [ + { + "type": "object", + "properties": { + "text": { "type": "string" } + } + }, + { + "description": "Rich text item" + } + ] + } + } + } + }); + + enforce_strict_schema(&mut fixture, true); + + let expected = json!({ + "type": "object", + "properties": { + "rich_text": { + "type": "array", + "items": { + "type": "object", + "properties": { + "text": { "type": "string" } + }, + "description": "Rich text item", + "additionalProperties": false, + "required": ["text"] + } + } + }, + "additionalProperties": false, + "required": ["rich_text"] + }); + + assert_eq!(fixture, expected); + } + + #[test] + fn test_all_of_is_preserved_in_non_strict_mode() { + let mut fixture = json!({ + "type": "object", + "properties": { + "value": { + "allOf": [ + { "type": "string" }, + { "description": "A value" } + ] + } + } + }); + + enforce_strict_schema(&mut fixture, false); + + let expected = json!({ + "type": "object", + "properties": { + "value": { + "allOf": [ + { "type": "string" }, + { "description": "A value" } + ] + } + }, + "additionalProperties": false + }); + + assert_eq!(fixture, expected); + } + #[test] fn test_nullable_enum_converted_to_any_of_in_strict_mode() { // This matches what schemars AddNullable produces: nullable=true AND @@ -448,4 +762,243 @@ mod tests { assert_eq!(schema["properties"]["output_mode"]["nullable"], json!(true)); assert!(schema["properties"]["output_mode"].get("anyOf").is_none()); } + + #[test] + fn test_schema_valued_additional_properties_is_normalized() { + let mut schema = json!({ + "type": "object", + "properties": { + "metadata": { + "type": "object", + "additionalProperties": { + "type": "object", + "properties": { + "value": { "type": "string" } + } + } + } + } + }); + + enforce_strict_schema(&mut schema, true); + + // The additionalProperties schema should have been normalized + // (additionalProperties: false added to nested schema) + assert_eq!( + schema["properties"]["metadata"]["additionalProperties"], + json!({ + "type": "object", + "properties": { + "value": { "type": "string" } + }, + "additionalProperties": false, + "required": ["value"] + }) + ); + } + + #[test] + fn test_notion_mcp_create_comment_schema() { + // Simulates the actual Notion MCP create_comment schema that was failing + let mut schema = json!({ + "type": "object", + "properties": { + "rich_text": { + "type": "array", + "items": { + "anyOf": [ + { + "type": "object", + "description": "Text content", + "properties": { + "text": { + "type": "object", + "properties": { + "content": { "type": "string" } + } + } + } + }, + { + "type": "object", + "description": "Mention content", + "properties": { + "mention": { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "id": { "type": "string" } + } + } + } + } + } + } + ] + } + }, + "page_id": { + "type": "string" + }, + "discussion_id": { + "type": "string" + } + } + }); + + enforce_strict_schema(&mut schema, true); + + // Verify the schema is now valid for OpenAI + // 1. All objects should have type: "object" + assert_eq!(schema["type"], "object"); + assert_eq!(schema["properties"]["rich_text"]["type"], "array"); + + // 2. Check that the anyOf items have proper types and additionalProperties: + // false + let any_of = schema["properties"]["rich_text"]["items"]["anyOf"] + .as_array() + .unwrap(); + for branch in any_of { + assert_eq!(branch["type"], "object"); + assert_eq!(branch["additionalProperties"], false); + // All nested object properties should also have type and additionalProperties + if let Some(props) = branch["properties"].as_object() { + for (_, prop_schema) in props { + if let Some(obj) = prop_schema.as_object() + && obj.contains_key("properties") + { + assert!( + prop_schema["type"] == "object", + "Nested object should have type: object" + ); + } + } + } + } + + // 3. Verify additionalProperties: false at root level and for objects + assert_eq!(schema["additionalProperties"], false); + // Note: arrays don't get additionalProperties, only objects do + assert_eq!(schema["properties"]["rich_text"]["type"], "array"); + + // 4. Verify required fields are set + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&json!("rich_text"))); + assert!(required.contains(&json!("page_id"))); + assert!(required.contains(&json!("discussion_id"))); + } + + #[test] + fn test_property_names_is_removed_in_strict_mode() { + // This test ensures we don't regress on propertyNames removal + // propertyNames is a JSON Schema keyword that OpenAI/Codex doesn't support + let mut schema = json!({ + "type": "object", + "properties": { + "dynamic": { + "type": "object", + "propertyNames": { + "type": "string", + "pattern": "^[a-z]+$" + }, + "additionalProperties": { + "type": "string" + } + } + } + }); + + enforce_strict_schema(&mut schema, true); + + // propertyNames should be completely removed + assert!( + !schema["properties"]["dynamic"] + .as_object() + .unwrap() + .contains_key("propertyNames"), + "propertyNames must be removed in strict mode for OpenAI/Codex compatibility" + ); + + // The rest of the schema should be preserved + assert_eq!(schema["properties"]["dynamic"]["type"], "object"); + assert_eq!( + schema["properties"]["dynamic"]["additionalProperties"]["type"], + "string" + ); + } + + /// Integration test that simulates the full Notion MCP workflow: + /// 1. Schema arrives from MCP server (with propertyNames) + /// 2. Gets normalized for OpenAI/Codex (propertyNames removed) + /// 3. Serialized to JSON for API request + #[test] + fn test_notion_mcp_create_pages_full_schema() { + // This is a realistic subset of the Notion MCP create_pages schema + // that caused the original error + let notion_mcp_schema = json!({ + "type": "object", + "properties": { + "pages": { + "type": "array", + "items": { + "type": "object", + "properties": { + "properties": { + "description": "Dynamic page properties", + "type": "object", + "propertyNames": { + "type": "string" + }, + "additionalProperties": { + "anyOf": [ + { "type": "string" }, + { "type": "number" }, + { "type": "boolean" } + ] + } + } + }, + "required": ["properties"] + } + } + }, + "required": ["pages"] + }); + + // Step 1: Convert to Schema (like MCP client does) + let schema_str = serde_json::to_string(¬ion_mcp_schema).unwrap(); + let mut schema: serde_json::Value = serde_json::from_str(&schema_str).unwrap(); + + // Step 2: Normalize for OpenAI/Codex strict mode + enforce_strict_schema(&mut schema, true); + + // Step 3: Serialize for API request + let api_request_json = serde_json::to_string(&schema).unwrap(); + + // Verify: propertyNames should NOT be in the final JSON + assert!( + !api_request_json.contains("propertyNames"), + "Final API request JSON must not contain 'propertyNames'. Schema: {}", + api_request_json + ); + + // Verify: Schema structure is preserved + assert_eq!(schema["type"], "object"); + assert_eq!(schema["properties"]["pages"]["type"], "array"); + assert_eq!( + schema["properties"]["pages"]["items"]["properties"]["properties"]["type"], + "object" + ); + + // Verify: additionalProperties is normalized + let additional_props = &schema["properties"]["pages"]["items"]["properties"]["properties"] + ["additionalProperties"]; + assert!(additional_props.is_object() || additional_props.is_boolean()); + + // Verify: Required fields are set + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&json!("pages"))); + } } diff --git a/crates/forge_domain/src/mcp.rs b/crates/forge_domain/src/mcp.rs index 75a15aedc4..53afe53725 100644 --- a/crates/forge_domain/src/mcp.rs +++ b/crates/forge_domain/src/mcp.rs @@ -45,6 +45,7 @@ impl McpServerConfig { headers: BTreeMap::new(), timeout: None, disable: false, + oauth: McpOAuthSetting::AutoDetect, }) } @@ -110,9 +111,154 @@ pub struct McpHttpServer { /// remove it from the config. #[serde(default)] pub disable: bool, + + /// OAuth 2.0 configuration for MCP server authentication. + /// Supports three formats: + /// - Absent/null: OAuth auto-detection via server 401 response + /// - `false`: Explicitly disable OAuth (use API key/headers instead) + /// - `{ ... }`: Explicit OAuth configuration (client_id, scopes, etc.) + #[serde( + default, + skip_serializing_if = "McpOAuthSetting::is_default", + deserialize_with = "McpOAuthSetting::deserialize_flexible", + serialize_with = "McpOAuthSetting::serialize_flexible" + )] + pub oauth: McpOAuthSetting, +} + +impl McpHttpServer { + /// Returns true if OAuth is explicitly disabled for this server. + pub fn is_oauth_disabled(&self) -> bool { + matches!(self.oauth, McpOAuthSetting::Disabled) + } + + /// Returns the OAuth config if OAuth is explicitly configured. + pub fn oauth_config(&self) -> Option<&McpOAuthConfig> { + match &self.oauth { + McpOAuthSetting::Configured(config) => Some(config), + _ => None, + } + } } -impl McpHttpServer {} +/// Represents the OAuth setting for an MCP server. +/// Supports three states: auto-detect (default), explicitly disabled, or +/// explicitly configured. +#[derive(Debug, Clone, PartialEq, Hash, Default)] +pub enum McpOAuthSetting { + /// No explicit OAuth config - auto-detect via server 401 response + #[default] + AutoDetect, + /// OAuth explicitly disabled (`oauth: false`) + Disabled, + /// OAuth explicitly configured with parameters + Configured(McpOAuthConfig), +} + +impl McpOAuthSetting { + /// Returns true if the setting is the default (AutoDetect). + pub fn is_default(&self) -> bool { + matches!(self, Self::AutoDetect) + } + + /// Custom deserializer that accepts: + /// - boolean `false` -> Disabled + /// - boolean `true` -> AutoDetect + /// - null/absent -> AutoDetect + /// - object `{ ... }` -> Configured(McpOAuthConfig) + fn deserialize_flexible<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de; + + struct McpOAuthSettingVisitor; + + impl<'de> de::Visitor<'de> for McpOAuthSettingVisitor { + type Value = McpOAuthSetting; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a boolean or an OAuth config object") + } + + fn visit_bool(self, v: bool) -> Result { + if v { + Ok(McpOAuthSetting::AutoDetect) + } else { + Ok(McpOAuthSetting::Disabled) + } + } + + fn visit_none(self) -> Result { + Ok(McpOAuthSetting::AutoDetect) + } + + fn visit_unit(self) -> Result { + Ok(McpOAuthSetting::AutoDetect) + } + + fn visit_map>(self, map: M) -> Result { + let config = + McpOAuthConfig::deserialize(de::value::MapAccessDeserializer::new(map))?; + Ok(McpOAuthSetting::Configured(config)) + } + } + + deserializer.deserialize_any(McpOAuthSettingVisitor) + } + + /// Custom serializer: + /// - AutoDetect -> skip (handled by skip_serializing_if) + /// - Disabled -> `false` + /// - Configured -> serialize the config object + fn serialize_flexible(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::AutoDetect => serializer.serialize_none(), + Self::Disabled => serializer.serialize_bool(false), + Self::Configured(config) => config.serialize(serializer), + } + } +} + +/// MCP OAuth 2.0 configuration. +/// Supports automatic OAuth configuration discovery from server metadata. +/// When auth_url/token_url are not provided, Forge will automatically +/// discover them using RFC 8414 OAuth 2.0 Authorization Server Metadata. +#[derive(Default, Debug, Clone, Serialize, Deserialize, Setters, PartialEq, Hash)] +#[setters(strip_option, into)] +#[serde(rename_all = "camelCase")] +pub struct McpOAuthConfig { + /// Pre-registered OAuth client ID (optional for dynamic registration). + /// If not provided, dynamic client registration will be attempted. + #[serde(skip_serializing_if = "Option::is_none")] + pub client_id: Option, + + /// Client secret for confidential clients. + #[serde(skip_serializing_if = "Option::is_none")] + pub client_secret: Option, + + /// OAuth scopes to request. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub scopes: Vec, + + /// Authorization endpoint URL. + /// If not provided, discovered automatically from server metadata. + #[serde(skip_serializing_if = "Option::is_none")] + pub auth_url: Option, + + /// Token endpoint URL. + /// If not provided, discovered automatically from server metadata. + #[serde(skip_serializing_if = "Option::is_none")] + pub token_url: Option, + + /// Redirect URI for OAuth callback. + /// Defaults to http://127.0.0.1:8765/callback. + #[serde(skip_serializing_if = "Option::is_none")] + pub redirect_uri: Option, +} #[derive( Clone, Display, Serialize, Deserialize, Debug, PartialEq, Hash, Eq, From, PartialOrd, Ord, Deref, diff --git a/crates/forge_infra/Cargo.toml b/crates/forge_infra/Cargo.toml index 6f3968f3dd..232926dbdb 100644 --- a/crates/forge_infra/Cargo.toml +++ b/crates/forge_infra/Cargo.toml @@ -47,6 +47,7 @@ http.workspace = true url.workspace = true tonic.workspace = true google-cloud-auth.workspace = true +open.workspace = true [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt", "time", "test-util"] } diff --git a/crates/forge_infra/src/auth/mcp_credentials.rs b/crates/forge_infra/src/auth/mcp_credentials.rs new file mode 100644 index 0000000000..a016260146 --- /dev/null +++ b/crates/forge_infra/src/auth/mcp_credentials.rs @@ -0,0 +1,201 @@ +//! MCP OAuth Credential Storage +//! +//! Stores OAuth tokens separately from LLM provider credentials. +//! Credentials are bound to specific MCP server URLs. + +use std::collections::HashMap; + +use anyhow::Result; +use forge_domain::Environment; +use serde::{Deserialize, Serialize}; +use tokio::fs; + +/// MCP OAuth tokens for a single server. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct McpOAuthTokens { + pub access_token: String, + pub refresh_token: Option, + pub expires_at: Option, + pub scope: Option, +} + +/// Client registration info (for dynamic registration) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpClientRegistration { + pub client_id: String, + pub client_secret: Option, + pub client_id_issued_at: Option, + pub client_secret_expires_at: Option, +} + +/// Complete credential entry for an MCP server +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpCredentialEntry { + /// Server URL (credential binding) + pub server_url: String, + /// OAuth tokens + pub tokens: McpOAuthTokens, + /// Client registration (if dynamically registered) + pub client_registration: Option, +} + +/// Credential store for all MCP servers +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct McpCredentialStore { + pub credentials: HashMap, +} + +impl McpCredentialStore { + /// Load the credential store from disk + /// + /// # Arguments + /// * `env` - The environment containing the base path for storage + /// + /// # Returns + /// * `Ok(McpCredentialStore)` - The loaded store, or empty if file doesn't + /// exist + /// * `Err(...)` - If there was an error reading or parsing the file + pub async fn load(env: &Environment) -> Result { + let path = Self::credential_path(env); + if !path.exists() { + return Ok(Self::default()); + } + let content = fs::read_to_string(&path).await?; + Ok(serde_json::from_str(&content)?) + } + + /// Save the credential store to disk + /// + /// Sets file permissions to 0o600 (user read/write only) on Unix systems. + /// + /// # Arguments + /// * `env` - The environment containing the base path for storage + pub async fn save(&self, env: &Environment) -> Result<()> { + let path = Self::credential_path(env); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).await?; + } + let content = serde_json::to_string_pretty(self)?; + fs::write(&path, content).await?; + + // Set permissions to 0o600 (user read/write only) + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let metadata = fs::metadata(&path).await?; + let mut perms = metadata.permissions(); + perms.set_mode(0o600); + fs::set_permissions(&path, perms).await?; + } + Ok(()) + } + + /// Get the path to the credential file + /// + /// The file is stored at `/.mcp-credentials.json` + pub fn credential_path(env: &Environment) -> std::path::PathBuf { + env.base_path.join(".mcp-credentials.json") + } + + /// Get a credential entry by server URL + pub fn get(&self, server_url: &str) -> Option<&McpCredentialEntry> { + self.credentials.get(server_url) + } + + /// Set a credential entry + pub fn set(&mut self, entry: McpCredentialEntry) { + self.credentials.insert(entry.server_url.clone(), entry); + } + + /// Remove a credential entry by server URL + pub fn remove(&mut self, server_url: &str) { + self.credentials.remove(server_url); + } + + /// Check if credentials exist for a server URL + #[allow(dead_code)] + pub fn has_credentials(&self, server_url: &str) -> bool { + self.credentials.contains_key(server_url) + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::*; + + fn test_env() -> Environment { + Environment { + os: "test".to_string(), + cwd: PathBuf::from("/tmp"), + home: Some(PathBuf::from("/home/test")), + shell: "bash".to_string(), + base_path: PathBuf::from("/tmp/test-forge"), + } + } + + #[tokio::test] + async fn test_credential_store_save_load() { + let env = test_env(); + let mut store = McpCredentialStore::default(); + + let entry = McpCredentialEntry { + server_url: "https://api.example.com/mcp".to_string(), + tokens: McpOAuthTokens { + access_token: "test-access-token".to_string(), + refresh_token: Some("test-refresh-token".to_string()), + expires_at: Some(1234567890), + scope: Some("read write".to_string()), + }, + client_registration: None, + }; + + store.set(entry); + store.save(&env).await.unwrap(); + + let loaded = McpCredentialStore::load(&env).await.unwrap(); + assert!(loaded.has_credentials("https://api.example.com/mcp")); + let entry = loaded.get("https://api.example.com/mcp").unwrap(); + assert_eq!(entry.tokens.access_token, "test-access-token"); + assert_eq!( + entry.tokens.refresh_token, + Some("test-refresh-token".to_string()) + ); + + // Cleanup + let path = McpCredentialStore::credential_path(&env); + let _ = fs::remove_file(&path).await; + } + + #[tokio::test] + async fn test_credential_store_remove() { + let env = test_env(); + let mut store = McpCredentialStore::default(); + + store.set(McpCredentialEntry { + server_url: "https://api.example.com/mcp".to_string(), + tokens: McpOAuthTokens { + access_token: "test".to_string(), + refresh_token: None, + expires_at: None, + scope: None, + }, + client_registration: None, + }); + + store.remove("https://api.example.com/mcp"); + assert!(!store.has_credentials("https://api.example.com/mcp")); + + // Cleanup + let path = McpCredentialStore::credential_path(&env); + let _ = fs::remove_file(&path).await; + } + + #[tokio::test] + async fn test_credential_store_empty() { + let env = test_env(); + let store = McpCredentialStore::load(&env).await.unwrap(); + assert!(store.credentials.is_empty()); + } +} diff --git a/crates/forge_infra/src/auth/mcp_token_storage.rs b/crates/forge_infra/src/auth/mcp_token_storage.rs new file mode 100644 index 0000000000..544833fe40 --- /dev/null +++ b/crates/forge_infra/src/auth/mcp_token_storage.rs @@ -0,0 +1,273 @@ +//! Token Storage Adapter for rmcp OAuth +//! +//! Implements rmcp's `CredentialStore` trait to persist credentials +//! using our McpCredentialStore. + +use std::sync::Arc; + +use async_trait::async_trait; +use forge_domain::Environment; +use rmcp::transport::auth::{CredentialStore, StoredCredentials}; +use tokio::sync::Mutex; + +use crate::auth::mcp_credentials::{ + McpClientRegistration, McpCredentialEntry, McpCredentialStore, McpOAuthTokens, +}; + +/// Adapter that implements rmcp's CredentialStore trait +/// using our file-based McpCredentialStore +pub struct McpTokenStorage { + server_url: String, + env: Environment, + store: Arc>>, +} + +impl McpTokenStorage { + /// Create a new token storage adapter for a specific MCP server + /// + /// # Arguments + /// * `server_url` - The URL of the MCP server (used as credential key) + /// * `env` - The environment for file system paths + pub fn new(server_url: String, env: Environment) -> Self { + Self { server_url, env, store: Arc::new(Mutex::new(None)) } + } + + /// Get the credential store, loading it if necessary + async fn get_store(&self) -> anyhow::Result { + let mut guard = self.store.lock().await; + if guard.is_none() { + *guard = Some(McpCredentialStore::load(&self.env).await?); + } + Ok(guard.as_ref().unwrap().clone()) + } + + /// Update the cached store after modifications + async fn update_store(&self, store: McpCredentialStore) { + *self.store.lock().await = Some(store); + } + + /// Load stored credentials for this server + /// + /// Returns the stored credential entry if one exists, or None. + pub async fn load_credentials(&self) -> anyhow::Result> { + let store = self.get_store().await?; + Ok(store.get(&self.server_url).cloned()) + } + + /// Remove stored credentials for this server. + pub async fn remove_credentials(&self) -> anyhow::Result<()> { + let mut store = self.get_store().await?; + store.remove(&self.server_url); + store.save(&self.env).await?; + self.update_store(store).await; + Ok(()) + } + + /// Remove only tokens while keeping client registration. + /// Useful when tokens are expired/invalid but the client registration + /// (from dynamic registration) is still valid. + #[allow(dead_code)] + pub async fn remove_tokens_only(&self) -> anyhow::Result<()> { + let mut store = self.get_store().await?; + if let Some(entry) = store.get(&self.server_url).cloned() { + let updated = McpCredentialEntry { + server_url: entry.server_url, + tokens: McpOAuthTokens::default(), + client_registration: entry.client_registration, + }; + store.set(updated); + store.save(&self.env).await?; + self.update_store(store).await; + } + Ok(()) + } +} + +#[async_trait] +impl CredentialStore for McpTokenStorage { + /// Load credentials from storage + /// + /// Converts our file-based credentials to rmcp's StoredCredentials format. + /// Preserves refresh_token and expiry information for token refresh. + async fn load(&self) -> Result, rmcp::transport::auth::AuthError> { + let store = self + .get_store() + .await + .map_err(|e| rmcp::transport::auth::AuthError::InternalError(e.to_string()))?; + + if let Some(entry) = store.get(&self.server_url) { + use oauth2::basic::BasicTokenType; + use oauth2::{AccessToken, RefreshToken}; + use rmcp::transport::auth::OAuthTokenResponse; + + let access_token = AccessToken::new(entry.tokens.access_token.clone()); + let token_type = BasicTokenType::Bearer; + let extra_fields = oauth2::EmptyExtraTokenFields {}; + + let mut token_response = + OAuthTokenResponse::new(access_token, token_type, extra_fields); + + // Set refresh token if available + if let Some(ref rt) = entry.tokens.refresh_token { + token_response.set_refresh_token(Some(RefreshToken::new(rt.clone()))); + } + + // Set expiry if available + if let Some(expires_at) = entry.tokens.expires_at { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + if expires_at > now { + token_response + .set_expires_in(Some(&std::time::Duration::from_secs(expires_at - now))); + } else { + // Token has expired - set zero duration so rmcp triggers refresh + token_response.set_expires_in(Some(&std::time::Duration::from_secs(0))); + } + } + + Ok(Some(StoredCredentials { + client_id: entry + .client_registration + .as_ref() + .map(|r| r.client_id.clone()) + .unwrap_or_default(), + token_response: Some(token_response), + })) + } else { + Ok(None) + } + } + + /// Save credentials to storage + /// + /// Converts rmcp's StoredCredentials to our file-based format, + /// preserving all token metadata including refresh_token, expiry, + /// and client registration info. + async fn save( + &self, + credentials: StoredCredentials, + ) -> Result<(), rmcp::transport::auth::AuthError> { + use oauth2::TokenResponse; + + let mut store = self + .get_store() + .await + .map_err(|e| rmcp::transport::auth::AuthError::InternalError(e.to_string()))?; + + // Get existing entry to preserve client_secret and registration info + let existing = store.get(&self.server_url).cloned(); + + let tokens = if let Some(ref response) = credentials.token_response { + let access_token = response.access_token().secret().to_string(); + let refresh_token = response.refresh_token().map(|rt| rt.secret().to_string()); + let expires_at = response.expires_in().map(|d| { + (std::time::SystemTime::now() + d) + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + }); + let scope = response.scopes().map(|scopes| { + scopes + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(" ") + }); + + McpOAuthTokens { + access_token, + refresh_token: refresh_token.or_else(|| { + existing + .as_ref() + .and_then(|e| e.tokens.refresh_token.clone()) + }), + expires_at, + scope, + } + } else { + McpOAuthTokens { + access_token: String::new(), + refresh_token: None, + expires_at: None, + scope: None, + } + }; + + // Preserve client_secret from dynamic registration + let client_registration = if credentials.client_id.is_empty() { + existing.and_then(|e| e.client_registration) + } else { + let existing_reg = existing + .as_ref() + .and_then(|e| e.client_registration.as_ref()); + Some(McpClientRegistration { + client_id: credentials.client_id, + client_secret: existing_reg.and_then(|r| r.client_secret.clone()), + client_id_issued_at: existing_reg.and_then(|r| r.client_id_issued_at), + client_secret_expires_at: existing_reg.and_then(|r| r.client_secret_expires_at), + }) + }; + + let entry = McpCredentialEntry { + server_url: self.server_url.clone(), + tokens, + client_registration, + }; + + store.set(entry); + store + .save(&self.env) + .await + .map_err(|e| rmcp::transport::auth::AuthError::InternalError(e.to_string()))?; + + self.update_store(store).await; + Ok(()) + } + + /// Clear credentials from storage + async fn clear(&self) -> Result<(), rmcp::transport::auth::AuthError> { + let mut store = self + .get_store() + .await + .map_err(|e| rmcp::transport::auth::AuthError::InternalError(e.to_string()))?; + + store.remove(&self.server_url); + store + .save(&self.env) + .await + .map_err(|e| rmcp::transport::auth::AuthError::InternalError(e.to_string()))?; + + self.update_store(store).await; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use super::*; + + fn test_env() -> Environment { + Environment { + os: "test".to_string(), + cwd: PathBuf::from("/tmp"), + home: Some(PathBuf::from("/home/test")), + shell: "bash".to_string(), + base_path: PathBuf::from("/tmp/test-forge"), + } + } + + #[tokio::test] + async fn test_token_storage_new() { + let env = test_env(); + let storage = McpTokenStorage::new("https://example.com/mcp".to_string(), env); + + // Should start with no credentials + let result = storage.load().await; + assert!(result.is_ok()); + assert!(result.unwrap().is_none()); + } +} diff --git a/crates/forge_infra/src/auth/mod.rs b/crates/forge_infra/src/auth/mod.rs index e5c1a4f2d5..1edf7f7320 100644 --- a/crates/forge_infra/src/auth/mod.rs +++ b/crates/forge_infra/src/auth/mod.rs @@ -1,6 +1,11 @@ +mod mcp_credentials; +mod mcp_token_storage; + mod error; mod http; mod strategy; mod util; +pub(crate) use mcp_credentials::*; +pub(crate) use mcp_token_storage::*; pub use strategy::*; diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index 70a4711976..2e223c98a4 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -1049,17 +1049,17 @@ impl AuthStrategy for AnyAuthStrategy { } /// Factory for creating authentication strategies -pub struct ForgeAuthStrategyFactory {} +pub struct ForgeAuthStrategyFactory; impl Default for ForgeAuthStrategyFactory { fn default() -> Self { - Self::new() + Self } } impl ForgeAuthStrategyFactory { - pub fn new() -> Self { - Self {} + pub fn new(_environment: forge_domain::Environment) -> Self { + Self } } @@ -1134,7 +1134,7 @@ mod tests { #[test] fn test_create_auth_strategy_api_key() { - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory; let strategy = factory.create_auth_strategy( ProviderId::OPENAI, forge_domain::AuthMethod::ApiKey, @@ -1157,7 +1157,7 @@ mod tests { custom_headers: None, }; - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory; let strategy = factory.create_auth_strategy( ProviderId::OPENAI, forge_domain::AuthMethod::OAuthCode(config), @@ -1180,7 +1180,7 @@ mod tests { custom_headers: None, }; - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory; let strategy = factory.create_auth_strategy( ProviderId::OPENAI, forge_domain::AuthMethod::OAuthDevice(config), @@ -1203,7 +1203,7 @@ mod tests { custom_headers: None, }; - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory; let strategy = factory.create_auth_strategy( ProviderId::GITHUB_COPILOT, forge_domain::AuthMethod::OAuthDevice(config), @@ -1227,7 +1227,7 @@ mod tests { custom_headers: None, }; - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory; let actual = factory.create_auth_strategy( ProviderId::CODEX, forge_domain::AuthMethod::CodexDevice(config), diff --git a/crates/forge_infra/src/forge_infra.rs b/crates/forge_infra/src/forge_infra.rs index 685f0c2e3b..0815066da0 100644 --- a/crates/forge_infra/src/forge_infra.rs +++ b/crates/forge_infra/src/forge_infra.rs @@ -99,7 +99,7 @@ impl ForgeInfra { inquire_service: Arc::new(ForgeInquire::new()), mcp_server: ForgeMcpServer, walker_service: Arc::new(ForgeWalkerService::new()), - strategy_factory: Arc::new(ForgeAuthStrategyFactory::new()), + strategy_factory: Arc::new(ForgeAuthStrategyFactory::new(env.clone())), http_service, grpc_client, output_printer, @@ -282,8 +282,9 @@ impl McpServerInfra for ForgeInfra { &self, config: McpServerConfig, env_vars: &BTreeMap, + environment: &forge_domain::Environment, ) -> anyhow::Result { - self.mcp_server.connect(config, env_vars).await + self.mcp_server.connect(config, env_vars, environment).await } } diff --git a/crates/forge_infra/src/lib.rs b/crates/forge_infra/src/lib.rs index ffc04010e7..a6a726d477 100644 --- a/crates/forge_infra/src/lib.rs +++ b/crates/forge_infra/src/lib.rs @@ -1,9 +1,8 @@ +mod auth; mod console; mod env; -pub mod executor; - -mod auth; mod error; +mod executor; mod forge_infra; mod fs_create_dirs; mod fs_meta; @@ -25,3 +24,4 @@ pub use executor::ForgeCommandExecutorService; pub use forge_infra::*; pub use http::sanitize_headers; pub use kv_storage::CacacheStorage; +pub use mcp_client::*; diff --git a/crates/forge_infra/src/mcp_client.rs b/crates/forge_infra/src/mcp_client.rs index 662ca0ce85..1380733efe 100644 --- a/crates/forge_infra/src/mcp_client.rs +++ b/crates/forge_infra/src/mcp_client.rs @@ -6,7 +6,9 @@ use std::sync::{Arc, OnceLock, RwLock}; use backon::{ExponentialBuilder, Retryable}; use forge_app::McpClientInfra; -use forge_domain::{Image, McpHttpServer, McpServerConfig, ToolDefinition, ToolName, ToolOutput}; +use forge_domain::{ + Environment, Image, McpHttpServer, McpServerConfig, ToolDefinition, ToolName, ToolOutput, +}; use http::{HeaderName, HeaderValue, header}; use rmcp::model::{CallToolRequestParam, ClientInfo, Implementation, InitializeRequestParam}; use rmcp::service::RunningService; @@ -33,15 +35,21 @@ pub struct ForgeMcpClient { client: Arc>>>, config: McpServerConfig, env_vars: BTreeMap, + environment: Environment, resolved_config: Arc>>, } impl ForgeMcpClient { - pub fn new(config: McpServerConfig, env_vars: &BTreeMap) -> Self { + pub fn new( + config: McpServerConfig, + env_vars: &BTreeMap, + environment: Environment, + ) -> Self { Self { client: Default::default(), config, env_vars: env_vars.clone(), + environment, resolved_config: Arc::new(OnceLock::new()), } } @@ -122,33 +130,332 @@ impl ForgeMcpClient { } }); } - self.client_info().serve(transport).await? + Arc::new(self.client_info().serve(transport).await?) } McpServerConfig::Http(http) => { - // Try HTTP first, fall back to SSE if it fails - let client = self.reqwest_client(http)?; - let transport = StreamableHttpClientTransport::with_client( - client.clone(), - StreamableHttpClientTransportConfig::with_uri(http.url.clone()), - ); - match self.client_info().serve(transport).await { - Ok(client) => client, - Err(_e) => { - let transport = SseClientTransport::start_with_client( - client, - SseClientConfig { - sse_endpoint: http.url.clone().into(), - ..Default::default() - }, - ) - .await?; - self.client_info().serve(transport).await? + // Check if OAuth is explicitly disabled + if http.is_oauth_disabled() { + // OAuth explicitly disabled - only try standard connection + Arc::new(self.create_standard_http_connection(http).await?) + } else if let Some(oauth_config) = http.oauth_config() { + // OAuth explicitly configured - use it directly + // Do NOT allow interactive auth during normal connection + self.create_oauth_connection(http, oauth_config, false) + .await? + } else { + // Auto-detect: try standard first, fall back to OAuth on auth errors + match self.create_standard_http_connection(http).await { + Ok(client) => Arc::new(client), + Err(e) => { + let error_str = e.to_string().to_lowercase(); + if error_str.contains("401") + || error_str.contains("unauthorized") + || error_str.contains("authentication required") + || error_str.contains("auth required") + || error_str.contains("oauth") + { + tracing::info!( + "Standard connection failed with auth error for: {}, trying stored credentials", + http.url + ); + // Try OAuth with stored credentials (non-interactive) + // If stored credentials exist, use them; otherwise error + let default_config = forge_domain::McpOAuthConfig::default(); + self.create_oauth_connection(http, &default_config, false) + .await? + } else { + return Err(e); + } + } } } } }; - Ok(Arc::new(client)) + Ok(client) + } + + /// Create a standard HTTP connection without OAuth + async fn create_standard_http_connection( + &self, + http: &McpHttpServer, + ) -> anyhow::Result { + // Try HTTP first, fall back to SSE if it fails + let client = self.reqwest_client(http)?; + let transport = StreamableHttpClientTransport::with_client( + client.clone(), + StreamableHttpClientTransportConfig::with_uri(http.url.clone()), + ); + match self.client_info().serve(transport).await { + Ok(client) => Ok(client), + Err(_e) => { + let transport = SseClientTransport::start_with_client( + client, + SseClientConfig { sse_endpoint: http.url.clone().into(), ..Default::default() }, + ) + .await?; + Ok(self.client_info().serve(transport).await?) + } + } + } + + /// Create an OAuth-enabled connection using rmcp's OAuth support. + /// + /// Uses rmcp's `AuthorizationManager` and `OAuthState` state machine which + /// properly handle: + /// 1. OAuth metadata discovery via RFC 8414 + /// 2. Dynamic client registration via RFC 7591 + /// 3. PKCE challenge/verifier generation and validation + /// 4. CSRF state parameter generation and validation + /// 5. Authorization code exchange for tokens + /// 6. Token refresh via refresh_token grant + /// 7. Token persistence via `CredentialStore` trait + /// + /// # Arguments + /// * `allow_interactive` - If true, will open browser for user + /// authentication if no stored credentials exist. If false, returns an + /// error instead. + async fn create_oauth_connection( + &self, + http: &McpHttpServer, + oauth_config: &forge_domain::McpOAuthConfig, + allow_interactive: bool, + ) -> anyhow::Result> { + use rmcp::transport::auth::{AuthorizationManager, OAuthState}; + + use crate::auth::McpTokenStorage; + + let credential_store = McpTokenStorage::new(http.url.clone(), self.environment.clone()); + + // First, try to use cached credentials with auto-refresh + let mut auth_manager = AuthorizationManager::new(&http.url) + .await + .map_err(|e| anyhow::anyhow!("Failed to create OAuth manager: {}", e))?; + + auth_manager.set_credential_store(credential_store); + + // Try to load and use stored credentials (with automatic token refresh) + match auth_manager.initialize_from_store().await { + Ok(true) => { + // Stored credentials loaded. Try to get a valid access token + // (this auto-refreshes if expired and refresh_token is available) + match auth_manager.get_access_token().await { + Ok(token) => { + tracing::debug!("Using stored/refreshed OAuth token for: {}", http.url); + return self.connect_with_token(http, &token).await; + } + Err(e) => { + tracing::warn!( + "Stored token invalid for {}: {}, re-authenticating", + http.url, + e + ); + } + } + } + Ok(false) => { + tracing::info!("No stored credentials for: {}", http.url); + } + Err(e) => { + tracing::warn!("Failed to load stored credentials for {}: {}", http.url, e); + } + } + + // No valid cached credentials + if !allow_interactive { + // Interactive auth not allowed - return error with instructions + return Err(anyhow::anyhow!( + "MCP server '{}' requires authentication. Run 'mcp login ' to authenticate.", + http.url + )); + } + + // Interactive auth allowed - start full OAuth authorization flow + // Create a fresh OAuthState to run the browser-based flow + let mut oauth_state = OAuthState::new(&http.url, None) + .await + .map_err(|e| anyhow::anyhow!("Failed to initialize OAuth state: {}", e))?; + + let redirect_uri = oauth_config + .redirect_uri + .clone() + .unwrap_or_else(|| "http://127.0.0.1:8765/callback".to_string()); + + let scopes: Vec<&str> = oauth_config.scopes.iter().map(|s| s.as_str()).collect(); + + // start_authorization discovers metadata, registers client, generates PKCE + + // CSRF state + oauth_state + .start_authorization(&scopes, &redirect_uri, Some("Forge")) + .await + .map_err(|e| anyhow::anyhow!("OAuth authorization flow failed: {}", e))?; + + // Get the authorization URL (includes PKCE challenge and CSRF state) + let auth_url = oauth_state + .get_authorization_url() + .await + .map_err(|e| anyhow::anyhow!("Failed to get authorization URL: {}", e))?; + + tracing::info!("Starting OAuth authentication for MCP server: {}", http.url); + + // Parse redirect URI to get port for callback server + let redirect_url: url::Url = redirect_uri + .parse() + .map_err(|e| anyhow::anyhow!("Invalid redirect URI: {}", e))?; + let port = redirect_url.port().unwrap_or(8765); + + // Start local callback server, open browser, wait for redirect + let (code, state) = self.run_oauth_callback_server(port, &auth_url).await?; + + // Exchange authorization code for tokens (validates CSRF state internally) + // rmcp's OAuthState handles PKCE verifier inclusion in the token request + oauth_state + .handle_callback(&code, &state) + .await + .map_err(|e| anyhow::anyhow!("Failed to exchange authorization code: {}", e))?; + + // Get the access token from the completed OAuth flow + let access_token = oauth_state + .get_access_token() + .await + .map_err(|e| anyhow::anyhow!("Failed to get access token after OAuth: {}", e))?; + + // Save credentials for future use via our persistent store + let credentials = oauth_state + .get_credentials() + .await + .map_err(|e| anyhow::anyhow!("Failed to get credentials: {}", e))?; + + { + use rmcp::transport::auth::CredentialStore; + let save_store = McpTokenStorage::new(http.url.clone(), self.environment.clone()); + let stored = rmcp::transport::auth::StoredCredentials { + client_id: credentials.0, + token_response: credentials.1, + }; + save_store + .save(stored) + .await + .map_err(|e| anyhow::anyhow!("Failed to save credentials: {}", e))?; + } + + tracing::info!( + "OAuth authentication successful for MCP server: {}", + http.url + ); + + self.connect_with_token(http, &access_token).await + } + + /// Connect to an MCP server using a bearer token. + /// + /// Uses StreamableHTTP transport only - does NOT fall back to SSE + /// since SSE transport doesn't support auth headers in the same way. + /// Auth errors are transport-independent so falling back to SSE + /// with the same auth issue would be pointless. + async fn connect_with_token( + &self, + http: &McpHttpServer, + token: &str, + ) -> anyhow::Result> { + let client = self.reqwest_client(http)?; + let transport = StreamableHttpClientTransport::with_client( + client, + StreamableHttpClientTransportConfig::with_uri(http.url.clone()).auth_header(token), + ); + + Ok(Arc::new(self.client_info().serve(transport).await?)) + } + + /// Runs a local HTTP server to receive the OAuth callback, opens the + /// browser, and returns the authorization code and state. + async fn run_oauth_callback_server( + &self, + port: u16, + auth_url: &str, + ) -> anyhow::Result<(String, String)> { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + + let listener = TcpListener::bind(format!("127.0.0.1:{}", port)) + .await + .map_err(|e| { + anyhow::anyhow!( + "Failed to start OAuth callback server on port {}: {}. \ + Is another process using this port?", + port, + e + ) + })?; + + tracing::info!("OAuth callback server listening on port {}", port); + + // Open browser + if let Err(e) = open::that(auth_url) { + tracing::warn!( + "Failed to open browser: {}. Please open this URL manually:\n{}", + e, + auth_url + ); + eprintln!( + "\nPlease open this URL in your browser to authenticate:\n{}\n", + auth_url + ); + } else { + eprintln!("\nOpening browser for OAuth authentication...\n"); + } + + // Wait for callback with timeout + let timeout = tokio::time::Duration::from_secs(300); // 5 minutes + let (mut stream, _addr) = tokio::time::timeout(timeout, listener.accept()) + .await + .map_err(|_| anyhow::anyhow!("OAuth callback timed out after 5 minutes"))? + .map_err(|e| anyhow::anyhow!("Failed to accept OAuth callback: {}", e))?; + + // Read the HTTP request + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf).await?; + let request = String::from_utf8_lossy(&buf[..n]); + + // Parse the request line to extract query parameters + let first_line = request.lines().next().unwrap_or(""); + let path = first_line.split_whitespace().nth(1).unwrap_or("/"); + + // Parse query parameters + let query_start = path.find('?').unwrap_or(path.len()); + let query_string = &path[query_start..]; + let params: std::collections::HashMap = + url::form_urlencoded::parse(query_string.trim_start_matches('?').as_bytes()) + .into_owned() + .collect(); + + let code = params + .get("code") + .ok_or_else(|| { + let error = params.get("error").map(|e| e.as_str()).unwrap_or("unknown"); + let desc = params + .get("error_description") + .map(|d| d.as_str()) + .unwrap_or("No description"); + anyhow::anyhow!("OAuth error: {} - {}", error, desc) + })? + .clone(); + + let state = params + .get("state") + .ok_or_else(|| anyhow::anyhow!("Missing state parameter in OAuth callback"))? + .clone(); + + // Send styled success response with auto-close + let response_body = r#"Forge - Authorization Successful

Authorization Successful

You can close this window and return to Forge.

"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + response_body.len(), + response_body + ); + let _ = stream.write_all(response.as_bytes()).await; + + Ok((code, state)) } fn reqwest_client(&self, config: &McpHttpServer) -> anyhow::Result { @@ -285,6 +592,180 @@ fn resolve_http_templates( Ok(http) } +/// Trigger OAuth authentication for a specific MCP server URL. +/// +/// Runs the full OAuth flow: metadata discovery, dynamic registration, +/// browser-based authorization, and token persistence. +/// +/// # Arguments +/// * `server_url` - The URL of the MCP server to authenticate with +/// * `env` - The environment for file system paths +pub async fn mcp_auth(server_url: &str, env: &Environment) -> anyhow::Result<()> { + use rmcp::transport::auth::{CredentialStore, OAuthState}; + + use crate::auth::McpTokenStorage; + + // Start fresh OAuth flow via OAuthState + let mut oauth_state = OAuthState::new(server_url, None) + .await + .map_err(|e| anyhow::anyhow!("Failed to initialize OAuth state: {}", e))?; + + let redirect_uri = "http://127.0.0.1:8765/callback"; + + oauth_state + .start_authorization(&[], redirect_uri, Some("Forge")) + .await + .map_err(|e| anyhow::anyhow!("OAuth authorization flow failed: {}", e))?; + + let auth_url = oauth_state + .get_authorization_url() + .await + .map_err(|e| anyhow::anyhow!("Failed to get authorization URL: {}", e))?; + + // Start callback server and open browser + let listener = tokio::net::TcpListener::bind("127.0.0.1:8765") + .await + .map_err(|e| anyhow::anyhow!("Failed to start OAuth callback server: {}", e))?; + + if let Err(e) = open::that(&auth_url) { + tracing::warn!("Failed to open browser: {}", e); + eprintln!( + "\nPlease open this URL in your browser to authenticate:\n{}\n", + auth_url + ); + } else { + eprintln!("\nOpening browser for OAuth authentication...\n"); + } + + let timeout = tokio::time::Duration::from_secs(300); + let (mut stream, _) = tokio::time::timeout(timeout, listener.accept()) + .await + .map_err(|_| anyhow::anyhow!("OAuth callback timed out after 5 minutes"))? + .map_err(|e| anyhow::anyhow!("Failed to accept OAuth callback: {}", e))?; + + // Read HTTP request and parse callback params + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + let mut buf = vec![0u8; 4096]; + let n = stream.read(&mut buf).await?; + let request = String::from_utf8_lossy(&buf[..n]); + let first_line = request.lines().next().unwrap_or(""); + let path = first_line.split_whitespace().nth(1).unwrap_or("/"); + let query_start = path.find('?').unwrap_or(path.len()); + let params: std::collections::HashMap = + url::form_urlencoded::parse(path[query_start..].trim_start_matches('?').as_bytes()) + .into_owned() + .collect(); + + let code = params + .get("code") + .ok_or_else(|| { + let error = params.get("error").map(|e| e.as_str()).unwrap_or("unknown"); + let desc = params + .get("error_description") + .map(|d| d.as_str()) + .unwrap_or("No description"); + anyhow::anyhow!("OAuth error: {} - {}", error, desc) + })? + .clone(); + + let state = params + .get("state") + .ok_or_else(|| anyhow::anyhow!("Missing state parameter in OAuth callback"))? + .clone(); + + // Send styled response + let body = r#"Forge - Authorization Successful

Authorization Successful

You can close this window and return to Forge.

"#; + let resp = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + let _ = stream.write_all(resp.as_bytes()).await; + + // Exchange code for tokens + oauth_state + .handle_callback(&code, &state) + .await + .map_err(|e| anyhow::anyhow!("Failed to exchange authorization code: {}", e))?; + + // Save credentials + let credentials = oauth_state + .get_credentials() + .await + .map_err(|e| anyhow::anyhow!("Failed to get credentials: {}", e))?; + + let save_store = McpTokenStorage::new(server_url.to_string(), env.clone()); + let stored = rmcp::transport::auth::StoredCredentials { + client_id: credentials.0, + token_response: credentials.1, + }; + save_store + .save(stored) + .await + .map_err(|e| anyhow::anyhow!("Failed to save credentials: {}", e))?; + + Ok(()) +} + +/// Remove stored OAuth credentials for a specific MCP server. +/// +/// # Arguments +/// * `server_url` - The URL of the MCP server to remove credentials for +/// * `env` - The environment for file system paths +pub async fn mcp_logout(server_url: &str, env: &Environment) -> anyhow::Result<()> { + use crate::auth::McpTokenStorage; + let storage = McpTokenStorage::new(server_url.to_string(), env.clone()); + storage.remove_credentials().await +} + +/// Remove all stored MCP OAuth credentials. +/// +/// # Arguments +/// * `env` - The environment for file system paths +pub async fn mcp_logout_all(env: &Environment) -> anyhow::Result<()> { + use crate::auth::McpCredentialStore; + let path = McpCredentialStore::credential_path(env); + if path.exists() { + tokio::fs::remove_file(&path).await?; + } + Ok(()) +} + +/// Get the auth status for a specific MCP server. +/// +/// Returns one of: "authenticated", "expired", "not_authenticated" +/// +/// # Arguments +/// * `server_url` - The URL of the MCP server +/// * `env` - The environment for file system paths +pub async fn mcp_auth_status(server_url: &str, env: &Environment) -> String { + use crate::auth::McpTokenStorage; + let storage = McpTokenStorage::new(server_url.to_string(), env.clone()); + match storage.load_credentials().await { + Ok(Some(entry)) => { + if let Some(expires_at) = entry.tokens.expires_at { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + if expires_at <= now { + if entry.tokens.refresh_token.is_some() { + "expired (has refresh token)".to_string() + } else { + "expired".to_string() + } + } else { + "authenticated".to_string() + } + } else { + "authenticated".to_string() + } + } + Ok(None) => "not authenticated".to_string(), + Err(_) => "unknown (error reading credentials)".to_string(), + } +} + #[cfg(test)] mod tests { use pretty_assertions::assert_eq; @@ -310,6 +791,7 @@ mod tests { ]), timeout: None, disable: false, + oauth: Default::default(), }; let resolved = resolve_http_templates(http, &env_vars).unwrap(); @@ -340,6 +822,7 @@ mod tests { )]), timeout: None, disable: false, + oauth: Default::default(), }; let resolved = resolve_http_templates(http, &env_vars).unwrap(); @@ -360,6 +843,7 @@ mod tests { headers: BTreeMap::from([("Auth".to_string(), "{{env.TOKEN}}".to_string())]), timeout: None, disable: true, + oauth: Default::default(), }; let resolved = resolve_http_templates(http, &env_vars).unwrap(); diff --git a/crates/forge_infra/src/mcp_server.rs b/crates/forge_infra/src/mcp_server.rs index 1725313767..b69a7672f7 100644 --- a/crates/forge_infra/src/mcp_server.rs +++ b/crates/forge_infra/src/mcp_server.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use forge_app::McpServerInfra; -use forge_domain::McpServerConfig; +use forge_domain::{Environment, McpServerConfig}; use crate::mcp_client::ForgeMcpClient; @@ -16,7 +16,8 @@ impl McpServerInfra for ForgeMcpServer { &self, config: McpServerConfig, env_vars: &BTreeMap, + environment: &Environment, ) -> anyhow::Result { - Ok(ForgeMcpClient::new(config, env_vars)) + Ok(ForgeMcpClient::new(config, env_vars, environment.clone())) } } diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index 4cb6ff66f6..de3285fdaa 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -415,6 +415,12 @@ pub enum McpCommand { /// Reload servers and rebuild caches. Reload, + + /// Authenticate with an OAuth-enabled MCP server. + Login(McpAuthArgs), + + /// Remove stored OAuth credentials for an MCP server. + Logout(McpLogoutArgs), } #[derive(Parser, Debug, Clone)] @@ -444,6 +450,19 @@ pub struct McpShowArgs { pub name: String, } +#[derive(Parser, Debug, Clone)] +pub struct McpAuthArgs { + /// Name of the MCP server to authenticate with. + pub name: String, +} + +#[derive(Parser, Debug, Clone)] +pub struct McpLogoutArgs { + /// Name of the MCP server to remove credentials for, or "all" to + /// remove all MCP OAuth credentials. + pub name: String, +} + /// Configuration scope for settings. #[derive(Copy, Clone, Debug, ValueEnum, Default)] pub enum Scope { diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 1980ff11c8..08aa510f53 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -538,6 +538,12 @@ impl A + Send + Sync> UI self.api.reload_mcp().await?; self.writeln_title(TitleFormat::info("MCP reloaded"))?; } + McpCommand::Login(args) => { + self.handle_mcp_login(&args.name).await?; + } + McpCommand::Logout(args) => { + self.handle_mcp_logout(&args.name).await?; + } }, TopLevelCommand::Info { porcelain, conversation_id } => { // Only initialize state (agent/provider/model resolution). @@ -825,6 +831,122 @@ impl A + Send + Sync> UI Ok(()) } + /// Handle `mcp login ` command. + /// + /// Triggers the OAuth authentication flow for the specified MCP server. + /// Uses the API layer which delegates to rmcp's OAuth state machine + /// for metadata discovery, dynamic registration, PKCE, and token exchange. + async fn handle_mcp_login(&mut self, name: &str) -> anyhow::Result<()> { + let server_name = forge_api::ServerName::from(name.to_string()); + let config = self.api.read_mcp_config(None).await?; + let server = config.mcp_servers.get(&server_name); + + match server { + Some(forge_domain::McpServerConfig::Http(http)) => { + // Check auth status first + let status = self.api.mcp_auth_status(&http.url).await?; + if status == "authenticated" { + self.writeln_title(TitleFormat::info( + format!("MCP server '{}' is already authenticated. Use 'mcp logout {}' first to re-authenticate.", name, name) + ))?; + return Ok(()); + } + + // Force re-auth by removing any stale credentials + let _ = self.api.mcp_logout(Some(&http.url)).await; + + // Run the OAuth flow (opens browser, waits for callback) + match self.api.mcp_auth(&http.url).await { + Ok(()) => { + self.writeln_title(TitleFormat::info(format!( + "Successfully authenticated with MCP server '{}'", + name + )))?; + // Reload MCP to reconnect with new credentials + self.spinner.start(Some("Reloading MCPs"))?; + match self.api.reload_mcp().await { + Ok(()) => { + self.writeln_title(TitleFormat::info("MCP reloaded"))?; + } + Err(e) => { + self.writeln_title(TitleFormat::error(format!( + "MCP reload failed: {}", + e + )))?; + } + } + } + Err(e) => { + self.writeln_title(TitleFormat::error(format!( + "Authentication with MCP server '{}' failed: {}", + name, e + )))?; + } + } + } + Some(_) => { + self.writeln_title(TitleFormat::error(format!( + "MCP server '{}' is not an HTTP server (OAuth only applies to HTTP servers)", + name + )))?; + } + None => { + self.writeln_title(TitleFormat::error(format!( + "MCP server '{}' not found. Use 'mcp list' to see available servers.", + name + )))?; + } + } + Ok(()) + } + + /// Handle `mcp logout ` command. + /// + /// Removes stored OAuth credentials for the specified MCP server + /// or all servers if "all" is specified. + /// Automatically reloads MCPs after logout to reflect auth state change. + async fn handle_mcp_logout(&mut self, name: &str) -> anyhow::Result<()> { + if name == "all" { + self.api.mcp_logout(None).await?; + self.writeln_title(TitleFormat::info("Removed all MCP OAuth credentials"))?; + } else { + let server_name = forge_api::ServerName::from(name.to_string()); + let config = self.api.read_mcp_config(None).await?; + let server = config.mcp_servers.get(&server_name); + + match server { + Some(forge_domain::McpServerConfig::Http(http)) => { + self.api.mcp_logout(Some(&http.url)).await?; + self.writeln_title(TitleFormat::info(format!( + "Removed OAuth credentials for MCP server '{}'", + name + )))?; + } + Some(_) => { + self.writeln_title(TitleFormat::error(format!( + "MCP server '{}' is not an HTTP server", + name + )))?; + return Ok(()); + } + None => { + self.writeln_title(TitleFormat::error(format!( + "MCP server '{}' not found. Use 'mcp list' to see available servers.", + name + )))?; + return Ok(()); + } + } + } + + // Reload MCPs to reflect auth state change + self.spinner.start(Some("Reloading MCPs"))?; + self.api.reload_mcp().await?; + self.writeln_title(TitleFormat::info("MCP reloaded"))?; + + Ok(()) + } + async fn handle_provider_command( &mut self, provider_group: crate::cli::ProviderCommandGroup, diff --git a/crates/forge_repo/src/forge_repo.rs b/crates/forge_repo/src/forge_repo.rs index be07f80bf7..229989738d 100644 --- a/crates/forge_repo/src/forge_repo.rs +++ b/crates/forge_repo/src/forge_repo.rs @@ -451,8 +451,9 @@ where &self, config: McpServerConfig, env_vars: &BTreeMap, + environment: &Environment, ) -> anyhow::Result { - self.infra.connect(config, env_vars).await + self.infra.connect(config, env_vars, environment).await } } diff --git a/crates/forge_services/src/mcp/service.rs b/crates/forge_services/src/mcp/service.rs index f03725edf0..dcae396d82 100644 --- a/crates/forge_services/src/mcp/service.rs +++ b/crates/forge_services/src/mcp/service.rs @@ -86,7 +86,8 @@ where config: McpServerConfig, ) -> anyhow::Result<()> { let env_vars = self.infra.get_env_vars(); - let client = self.infra.connect(config, &env_vars).await?; + let environment = self.infra.get_environment(); + let client = self.infra.connect(config, &env_vars, &environment).await?; let client = Arc::new(C::from(client)); self.insert_clients(server_name, client).await?; @@ -179,11 +180,16 @@ where tool.executable.call_tool(call.arguments.parse()?).await } - /// Refresh the MCP cache by fetching fresh data + /// Refresh the MCP cache by clearing cached data. + /// Does NOT eagerly connect to servers - connections happen lazily + /// when list() or call() is invoked, avoiding interactive OAuth during + /// reload. async fn refresh_cache(&self) -> anyhow::Result<()> { - // Fetch fresh tools by calling list() which connects to MCPs + // Clear the infra cache and reset config hash to force re-init on next access self.infra.cache_clear().await?; - let _ = self.get_mcp_servers().await?; + *self.previous_config_hash.lock().await = Default::default(); + self.clear_tools().await; + self.failed_servers.write().await.clear(); Ok(()) } }