diff --git a/crates/forge_api/src/api.rs b/crates/forge_api/src/api.rs index 3ea061cbdc..dfd144cac5 100644 --- a/crates/forge_api/src/api.rs +++ b/crates/forge_api/src/api.rs @@ -23,8 +23,11 @@ pub trait API: Sync + Send { /// Provides a list of models available in the current environment async fn get_models(&self) -> Result>; - /// Provides models from all configured providers. Providers that fail to - /// return models are silently skipped. + /// Provides models from all configured providers. Providers that + /// successfully return models are included in the result. If every + /// configured provider fails (e.g. due to an invalid API key), the + /// first error is returned so the caller sees the real underlying cause + /// rather than an empty list. async fn get_all_provider_models(&self) -> Result>; /// Provides a list of agents available in the current environment diff --git a/crates/forge_app/src/app.rs b/crates/forge_app/src/app.rs index 13304e911d..4b90eed73b 100644 --- a/crates/forge_app/src/app.rs +++ b/crates/forge_app/src/app.rs @@ -282,13 +282,14 @@ impl> ForgeAp /// Gets available models from all configured providers concurrently. /// - /// Returns a list of `ProviderModels` for each configured provider. - /// All providers are queried in parallel; providers that fail to - /// return models are silently skipped. + /// Returns a list of `ProviderModels` for each configured provider that + /// successfully returned models. If every configured provider fails (e.g. + /// due to an invalid API key), the first error encountered is returned so + /// the caller receives the real underlying cause rather than an empty list. pub async fn get_all_provider_models(&self) -> Result> { let all_providers = self.services.get_all_providers().await?; - // Build one future per configured provider + // Build one future per configured provider, preserving the error on failure. let futures: Vec<_> = all_providers .into_iter() .filter_map(|any_provider| any_provider.into_configured()) @@ -296,24 +297,24 @@ impl> ForgeAp let provider_id = provider.id.clone(); let services = self.services.clone(); async move { - let refreshed = services - .provider_auth_service() - .refresh_provider_credential(provider) - .await - .ok()?; - let models = services.models(refreshed).await.ok()?; - Some(ProviderModels { provider_id, models }) + let result: Result = async { + let refreshed = services + .provider_auth_service() + .refresh_provider_credential(provider) + .await?; + let models = services.models(refreshed).await?; + Ok(ProviderModels { provider_id, models }) + } + .await; + result } }) .collect(); - // Execute all provider fetches concurrently and collect successful results - let results = futures::future::join_all(futures) + // Execute all provider fetches concurrently. + futures::future::join_all(futures) .await .into_iter() - .flatten() - .collect(); - - Ok(results) + .collect::>>() } } diff --git a/crates/forge_app/src/services.rs b/crates/forge_app/src/services.rs index 2f85b0920b..ee1919b6d9 100644 --- a/crates/forge_app/src/services.rs +++ b/crates/forge_app/src/services.rs @@ -189,10 +189,8 @@ pub trait AppConfigService: Send + Sync { /// provider. /// /// # Errors - /// - Returns `Error::NoDefaultProvider` when no active provider is set and - /// provider_id is None - /// - Returns `Error::NoDefaultModel` when no model is configured for the - /// provider + /// - Returns `Error::NoDefaultSession` when no provider and model are + /// configured. async fn get_provider_model( &self, provider_id: Option<&forge_domain::ProviderId>, diff --git a/crates/forge_config/src/legacy.rs b/crates/forge_config/src/legacy.rs index 22f35ce52b..e8699ff9a3 100644 --- a/crates/forge_config/src/legacy.rs +++ b/crates/forge_config/src/legacy.rs @@ -57,18 +57,24 @@ impl LegacyConfig { /// Converts a [`LegacyConfig`] into the fields of [`ForgeConfig`] that it /// covers, leaving all other fields at their defaults (`None`). fn into_forge_config(self) -> ForgeConfig { - let session = self.provider.as_deref().map(|provider_id| { - let model_id = self.model.get(provider_id).cloned(); - ModelConfig { provider_id: Some(provider_id.to_string()), model_id } + let session = self.provider.as_deref().and_then(|provider_id| { + self.model + .get(provider_id) + .cloned() + .map(|model_id| ModelConfig { provider_id: provider_id.to_string(), model_id }) }); - let commit = self - .commit - .map(|c| ModelConfig { provider_id: c.provider, model_id: c.model }); + let commit = self.commit.and_then(|c| { + c.provider + .zip(c.model) + .map(|(provider_id, model_id)| ModelConfig { provider_id, model_id }) + }); - let suggest = self - .suggest - .map(|s| ModelConfig { provider_id: s.provider, model_id: s.model }); + let suggest = self.suggest.and_then(|s| { + s.provider + .zip(s.model) + .map(|(provider_id, model_id)| ModelConfig { provider_id, model_id }) + }); ForgeConfig { session, commit, suggest, ..Default::default() } } diff --git a/crates/forge_config/src/model.rs b/crates/forge_config/src/model.rs index c993222700..e48d8e7c4b 100644 --- a/crates/forge_config/src/model.rs +++ b/crates/forge_config/src/model.rs @@ -9,13 +9,18 @@ pub type ProviderId = String; pub type ModelId = String; /// Pairs a provider and model together for a specific operation. -#[derive( - Default, Debug, Setters, Clone, PartialEq, Serialize, Deserialize, JsonSchema, fake::Dummy, -)] -#[setters(strip_option, into)] +#[derive(Debug, Setters, Clone, PartialEq, Serialize, Deserialize, JsonSchema, fake::Dummy)] +#[setters(into)] pub struct ModelConfig { /// The provider to use for this operation. - pub provider_id: Option, + pub provider_id: String, /// The model to use for this operation. - pub model_id: Option, + pub model_id: String, +} + +impl ModelConfig { + /// Creates a new ModelConfig with the given provider and model IDs. + pub fn new(provider_id: impl Into, model_id: impl Into) -> Self { + Self { provider_id: provider_id.into(), model_id: model_id.into() } + } } diff --git a/crates/forge_config/src/reader.rs b/crates/forge_config/src/reader.rs index 54a09375c7..6ce724cbda 100644 --- a/crates/forge_config/src/reader.rs +++ b/crates/forge_config/src/reader.rs @@ -198,8 +198,8 @@ mod tests { // it on top of the embedded defaults. The default values must survive. let legacy = ForgeConfig { session: Some(ModelConfig { - provider_id: Some("anthropic".to_string()), - model_id: Some("claude-3".to_string()), + provider_id: "anthropic".to_string(), + model_id: "claude-3".to_string(), }), ..Default::default() }; @@ -216,8 +216,8 @@ mod tests { assert_eq!( actual.session, Some(ModelConfig { - provider_id: Some("anthropic".to_string()), - model_id: Some("claude-3".to_string()), + provider_id: "anthropic".to_string(), + model_id: "claude-3".to_string(), }) ); @@ -243,8 +243,8 @@ mod tests { .unwrap(); let expected = Some(ModelConfig { - provider_id: Some("fake-provider".to_string()), - model_id: Some("fake-model".to_string()), + provider_id: "fake-provider".to_string(), + model_id: "fake-model".to_string(), }); assert_eq!(actual.session, expected); } diff --git a/crates/forge_domain/src/env.rs b/crates/forge_domain/src/env.rs index adbadd6ca2..324e3f70f5 100644 --- a/crates/forge_domain/src/env.rs +++ b/crates/forge_domain/src/env.rs @@ -7,19 +7,6 @@ use serde::{Deserialize, Serialize}; use crate::{Effort, ModelConfig}; -/// Domain-level session configuration pairing a provider with a model. -/// -/// Used to represent an active session, decoupled from the on-disk -/// configuration format. -#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, Setters)] -#[setters(strip_option, into)] -pub struct SessionConfig { - /// The active provider ID (e.g. `"anthropic"`). - pub provider_id: Option, - /// The model ID to use with this provider. - pub model_id: Option, -} - /// All discrete mutations that can be applied to the application configuration. /// /// Instead of replacing the entire config, callers describe exactly which field diff --git a/crates/forge_domain/src/error.rs b/crates/forge_domain/src/error.rs index 3e34502b84..02d8f60529 100644 --- a/crates/forge_domain/src/error.rs +++ b/crates/forge_domain/src/error.rs @@ -102,12 +102,8 @@ pub enum Error { #[error("Failed to sync {count} file(s)")] SyncFailed { count: usize }, - #[error("No default provider set.")] - NoDefaultProvider, - - #[error("No default model configured for provider: {0}")] - #[from(skip)] - NoDefaultModel(ProviderId), + #[error("No default provider and model configured.")] + NoDefaultSession, } pub type Result = std::result::Result; @@ -149,10 +145,6 @@ impl Error { Self::VertexAiConfiguration { message: message.into() } } - pub fn no_default_model(provider: ProviderId) -> Self { - Self::NoDefaultModel(provider) - } - pub fn sync_failed(count: usize) -> Self { Self::SyncFailed { count } } diff --git a/crates/forge_infra/src/env.rs b/crates/forge_infra/src/env.rs index d609db1775..cb50e27558 100644 --- a/crates/forge_infra/src/env.rs +++ b/crates/forge_infra/src/env.rs @@ -35,24 +35,18 @@ fn apply_config_op(fc: &mut ForgeConfig, op: ConfigOperation) { ConfigOperation::SetSessionConfig(mc) => { let pid_str = mc.provider.as_ref().to_string(); let mid_str = mc.model.to_string(); - let session = fc.session.get_or_insert_with(ModelConfig::default); - if session.provider_id.as_deref() == Some(&pid_str) { - session.model_id = Some(mid_str); - } else { - fc.session = - Some(ModelConfig { provider_id: Some(pid_str), model_id: Some(mid_str) }); - } + fc.session = Some(ModelConfig { provider_id: pid_str, model_id: mid_str }); } ConfigOperation::SetCommitConfig(mc) => { fc.commit = mc.map(|m| ModelConfig { - provider_id: Some(m.provider.as_ref().to_string()), - model_id: Some(m.model.to_string()), + provider_id: m.provider.as_ref().to_string(), + model_id: m.model.to_string(), }); } ConfigOperation::SetSuggestConfig(mc) => { fc.suggest = Some(ModelConfig { - provider_id: Some(mc.provider.as_ref().to_string()), - model_id: Some(mc.model.to_string()), + provider_id: mc.provider.as_ref().to_string(), + model_id: mc.model.to_string(), }); } ConfigOperation::SetReasoningEffort(effort) => { @@ -236,25 +230,22 @@ mod tests { )), ); - let actual_provider = fixture - .session - .as_ref() - .and_then(|s| s.provider_id.as_deref()); - let actual_model = fixture.session.as_ref().and_then(|s| s.model_id.as_deref()); + let actual_provider = fixture.session.as_ref().map(|s| s.provider_id.as_str()); + let actual_model = fixture.session.as_ref().map(|s| s.model_id.as_str()); assert_eq!(actual_provider, Some("anthropic")); assert_eq!(actual_model, Some("claude-3-5-sonnet")); } #[test] - fn test_apply_config_op_set_model_matching_provider() { + fn test_apply_config_op_set_session_config_replaces_existing() { use forge_config::ModelConfig as ForgeCfgModelConfig; use forge_domain::{ModelConfig as DomainModelConfig, ModelId, ProviderId}; let mut fixture = ForgeConfig { session: Some(ForgeCfgModelConfig { - provider_id: Some("anthropic".to_string()), - model_id: None, + provider_id: "openai".to_string(), + model_id: "gpt-4".to_string(), }), ..Default::default() }; @@ -267,24 +258,18 @@ mod tests { )), ); - let actual = fixture.session.as_ref().and_then(|s| s.model_id.as_deref()); - let expected = Some("claude-3-5-sonnet-20241022"); + let actual_provider = fixture.session.as_ref().map(|s| s.provider_id.as_str()); + let actual_model = fixture.session.as_ref().map(|s| s.model_id.as_str()); - assert_eq!(actual, expected); + assert_eq!(actual_provider, Some("anthropic")); + assert_eq!(actual_model, Some("claude-3-5-sonnet-20241022")); } #[test] - fn test_apply_config_op_set_model_different_provider_replaces_session() { - use forge_config::ModelConfig as ForgeCfgModelConfig; + fn test_apply_config_op_set_session_config_creates_new_session() { use forge_domain::{ModelConfig as DomainModelConfig, ModelId, ProviderId}; - let mut fixture = ForgeConfig { - session: Some(ForgeCfgModelConfig { - provider_id: Some("openai".to_string()), - model_id: Some("gpt-4".to_string()), - }), - ..Default::default() - }; + let mut fixture = ForgeConfig::default(); apply_config_op( &mut fixture, @@ -294,11 +279,8 @@ mod tests { )), ); - let actual_provider = fixture - .session - .as_ref() - .and_then(|s| s.provider_id.as_deref()); - let actual_model = fixture.session.as_ref().and_then(|s| s.model_id.as_deref()); + let actual_provider = fixture.session.as_ref().map(|s| s.provider_id.as_str()); + let actual_model = fixture.session.as_ref().map(|s| s.model_id.as_str()); assert_eq!(actual_provider, Some("anthropic")); assert_eq!(actual_model, Some("claude-3-5-sonnet-20241022")); diff --git a/crates/forge_main/src/cli.rs b/crates/forge_main/src/cli.rs index e90049cb08..01d5b56f77 100644 --- a/crates/forge_main/src/cli.rs +++ b/crates/forge_main/src/cli.rs @@ -529,20 +529,12 @@ pub struct ConfigGetArgs { /// Type-safe subcommands for `forge config set`. #[derive(Subcommand, Debug, Clone)] pub enum ConfigSetField { - /// Set the active model. + /// Set the active model and provider atomically. Model { - /// Model ID to set as default. - model: ModelId, - }, - /// Set the active provider. - Provider { /// Provider ID to set as default. provider: ProviderId, - - /// Optional model ID to set simultaneously, skipping interactive model - /// selection. - #[arg(long)] - model: Option, + /// Model ID to set as default. + model: ModelId, }, /// Set the provider and model for commit message generation. Commit { @@ -850,64 +842,21 @@ mod tests { assert_eq!(actual, expected); } - #[test] - fn test_config_set_with_model() { - let fixture = Cli::parse_from([ - "forge", - "config", - "set", - "model", - "anthropic/claude-sonnet-4", - ]); - let actual = match fixture.subcommands { - Some(TopLevelCommand::Config(config)) => match config.command { - ConfigCommand::Set(args) => match args.field { - ConfigSetField::Model { model } => Some(model.as_str().to_string()), - _ => None, - }, - _ => None, - }, - _ => None, - }; - let expected = Some("anthropic/claude-sonnet-4".to_string()); - assert_eq!(actual, expected); - } - - #[test] - fn test_config_set_with_provider() { - let fixture = Cli::parse_from(["forge", "config", "set", "provider", "OpenAI"]); - let actual = match fixture.subcommands { - Some(TopLevelCommand::Config(config)) => match config.command { - ConfigCommand::Set(args) => match args.field { - ConfigSetField::Provider { provider, model } => { - Some((provider.to_string(), model)) - } - _ => None, - }, - _ => None, - }, - _ => None, - }; - let expected = Some(("OpenAi".to_string(), None)); - assert_eq!(actual, expected); - } - #[test] fn test_config_set_with_provider_and_model() { let fixture = Cli::parse_from([ "forge", "config", "set", - "provider", + "model", "anthropic", - "--model", "claude-sonnet-4-20250514", ]); let actual = match fixture.subcommands { Some(TopLevelCommand::Config(config)) => match config.command { ConfigCommand::Set(args) => match args.field { - ConfigSetField::Provider { provider, model } => { - Some((provider.to_string(), model.map(|m| m.as_str().to_string()))) + ConfigSetField::Model { provider, model } => { + Some((provider.to_string(), model.as_str().to_string())) } _ => None, }, @@ -917,7 +866,7 @@ mod tests { }; let expected = Some(( "Anthropic".to_string(), - Some("claude-sonnet-4-20250514".to_string()), + "claude-sonnet-4-20250514".to_string(), )); assert_eq!(actual, expected); } diff --git a/crates/forge_main/src/model.rs b/crates/forge_main/src/model.rs index 61abc4e712..f81e11bd95 100644 --- a/crates/forge_main/src/model.rs +++ b/crates/forge_main/src/model.rs @@ -92,6 +92,7 @@ impl ForgeCommandManager { | "dump" | "model" | "tools" + | "provider" | "login" | "logout" | "retry" @@ -269,10 +270,9 @@ impl ForgeCommandManager { "/sage" => Ok(SlashCommand::Sage), "/help" => Ok(SlashCommand::Help), "/model" => Ok(SlashCommand::Model), - "/provider" => Ok(SlashCommand::Provider), + "/provider" | "/login" => Ok(SlashCommand::Login), "/tools" => Ok(SlashCommand::Tools), "/agent" => Ok(SlashCommand::Agent), - "/login" => Ok(SlashCommand::Login), "/logout" => Ok(SlashCommand::Logout), "/retry" => Ok(SlashCommand::Retry), "/conversation" | "/conversations" => Ok(SlashCommand::Conversations), @@ -391,10 +391,6 @@ pub enum SlashCommand { /// This can be triggered with the '/model' command. #[strum(props(usage = "Switch to a different model"))] Model, - /// Switch or select the active provider - /// This can be triggered with the '/provider' command. - #[strum(props(usage = "Switch to a different provider"))] - Provider, /// List all available tools with their descriptions and schema /// This can be triggered with the '/tools' command. #[strum(props(usage = "List all available tools with their descriptions and schema"))] @@ -469,7 +465,6 @@ impl SlashCommand { SlashCommand::Commit { .. } => "commit", SlashCommand::Dump { .. } => "dump", SlashCommand::Model => "model", - SlashCommand::Provider => "provider", SlashCommand::Tools => "tools", SlashCommand::Custom(event) => &event.name, SlashCommand::Shell(_) => "!shell", diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 618b8f2fba..5ba5de69e4 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -1969,10 +1969,7 @@ impl A + Send + Sync> UI self.on_custom_event(event.into()).await?; } SlashCommand::Model => { - self.on_model_selection(None, None).await?; - } - SlashCommand::Provider => { - self.on_provider_selection().await?; + self.on_model_selection(None).await?; } SlashCommand::Shell(ref command) => { self.api.execute_shell_command_raw(command).await?; @@ -2136,13 +2133,14 @@ impl A + Send + Sync> UI /// selected the model list is scoped to that provider only. /// /// # Returns - /// - `Ok(Some(ModelId))` if a model was selected + /// - `Ok(Some((ModelId, ProviderId)))` if a model was selected, carrying + /// both the model and the provider it belongs to /// - `Ok(None)` if selection was canceled #[async_recursion::async_recursion] async fn select_model( &mut self, provider_filter: Option, - ) -> Result> { + ) -> Result> { // Check if provider is set otherwise first ask to select a provider if provider_filter.is_none() && self.api.get_default_provider().await.is_err() { if !self.on_provider_selection().await? { @@ -2242,21 +2240,22 @@ impl A + Send + Sync> UI return Ok(None); } - // Build a flat list of (ModelId, display_line) for the data rows. + // Build a flat list of (ModelId, ProviderId) for the data rows. // The first line is the header; data rows follow in the same order as // the Info entries (sorted by provider, then model within provider). - let mut model_ids: Vec = Vec::new(); + let mut model_entries: Vec<(ModelId, ProviderId)> = Vec::new(); for pm in &all_provider_models { for model in &pm.models { - model_ids.push(model.id.clone()); + model_entries.push((model.id.clone(), pm.provider_id.clone())); } } // Create display items: header line first, then data lines paired with - // model IDs. + // model and provider IDs. #[derive(Clone)] struct ModelRow { model_id: Option, + provider_id: Option, display: String, } impl std::fmt::Display for ModelRow { @@ -2267,11 +2266,17 @@ impl A + Send + Sync> UI let mut rows: Vec = Vec::with_capacity(all_lines.len()); // Header row (non-selectable via header_lines=1) - rows.push(ModelRow { model_id: None, display: all_lines[0].to_string() }); + rows.push(ModelRow { + model_id: None, + provider_id: None, + display: all_lines[0].to_string(), + }); // Data rows for (i, line) in all_lines.iter().skip(1).enumerate() { + let entry = model_entries.get(i); rows.push(ModelRow { - model_id: model_ids.get(i).cloned(), + model_id: entry.map(|(m, _)| m.clone()), + provider_id: entry.map(|(_, p)| p.clone()), display: line.to_string(), }); } @@ -2284,7 +2289,7 @@ impl A + Send + Sync> UI .await; let starting_cursor = current_model .as_ref() - .and_then(|current| model_ids.iter().position(|id| id == current)) + .and_then(|current| model_entries.iter().position(|(id, _)| id == current)) .unwrap_or(0); match ForgeWidget::select("Model", rows) @@ -2292,7 +2297,7 @@ impl A + Send + Sync> UI .with_header_lines(1) .prompt()? { - Some(row) => Ok(row.model_id), + Some(row) => Ok(row.model_id.zip(row.provider_id)), None => Ok(None), } } @@ -2475,7 +2480,7 @@ impl A + Send + Sync> UI async fn display_credential_success(&mut self, provider_id: ProviderId) -> anyhow::Result<()> { self.writeln_title(TitleFormat::info(format!( - "{provider_id} configured successfully!" + "{provider_id} configured successfully" )))?; Ok(()) @@ -2681,10 +2686,11 @@ impl A + Send + Sync> UI } } + // Verify by fetching the configured provider + let provider = self.api.get_provider(&provider_id).await?; + self.display_credential_success(provider_id.clone()).await?; - // Fetch and return the configured provider - let provider = self.api.get_provider(&provider_id).await?; Ok(provider.into_configured()) } @@ -2807,37 +2813,28 @@ impl A + Send + Sync> UI // Helper method to handle model selection and update the conversation. // When `provider_filter` is `Some`, only models from that provider are shown. + // The model and provider returned by the selector are always set as one + // atomic operation. #[async_recursion::async_recursion] async fn on_model_selection( &mut self, provider_filter: Option, - provider_to_activate: Option, ) -> Result> { - // Select a model - let model_option = self.select_model(provider_filter).await?; + // Select a model; the selector returns both the model and its provider + let selection = self.select_model(provider_filter).await?; // If no model was selected (user canceled), return early - let model = match model_option { - Some(model) => model, + let (model, provider_id) = match selection { + Some(pair) => pair, None => return Ok(None), }; - // If we have a provider to activate, write both atomically - if let Some(provider_id) = provider_to_activate { - self.api - .update_config(vec![ConfigOperation::SetSessionConfig( - forge_domain::ModelConfig::new(provider_id, model.clone()), - )]) - .await?; - } else { - // Resolve the active provider so we can build a SetModel op - let provider_id = self.api.get_default_provider().await?.id; - self.api - .update_config(vec![ConfigOperation::SetSessionConfig( - forge_domain::ModelConfig::new(provider_id, model.clone()), - )]) - .await?; - } + // Set model and provider atomically as a single config operation + self.api + .update_config(vec![ConfigOperation::SetSessionConfig( + forge_domain::ModelConfig::new(provider_id, model.clone()), + )]) + .await?; // Update the UI state with the new model self.update_model(Some(model.clone())); @@ -2946,10 +2943,7 @@ impl A + Send + Sync> UI }; if needs_model_selection { - self.writeln_title(TitleFormat::info("Please select a new model"))?; - let selected = self - .on_model_selection(Some(provider.id.clone()), Some(provider.id.clone())) - .await?; + let selected = self.on_model_selection(Some(provider.id.clone())).await?; if selected.is_none() { // User cancelled — preserve existing config untouched return Ok(()); @@ -3086,7 +3080,7 @@ impl A + Send + Sync> UI let mut operating_model = self.get_agent_model(active_agent.clone()).await; if operating_model.is_none() { // Use the model returned from selection instead of re-fetching - operating_model = self.on_model_selection(None, None).await?; + operating_model = self.on_model_selection(None).await?; } if first { @@ -3610,22 +3604,10 @@ impl A + Send + Sync> UI use crate::cli::ConfigSetField; match args.field { - ConfigSetField::Provider { provider, model } => { + ConfigSetField::Model { provider, model } => { let provider = self.api.get_provider(&provider).await?; - self.activate_provider_with_model(provider, model).await?; - } - ConfigSetField::Model { model } => { - let model_id = self.validate_model(model.as_str(), None).await?; - // Resolve the active provider so we can build a SetModel op - let provider_id = self.api.get_default_provider().await?.id; - self.api - .update_config(vec![ConfigOperation::SetSessionConfig( - forge_domain::ModelConfig::new(provider_id, model_id.clone()), - )]) + self.activate_provider_with_model(provider, Some(model)) .await?; - self.writeln_title( - TitleFormat::action(model_id.as_str()).sub_title("is now the default model"), - )?; } ConfigSetField::Commit { provider, model } => { // Validate provider exists and model belongs to that specific provider diff --git a/crates/forge_services/src/agent_registry.rs b/crates/forge_services/src/agent_registry.rs index 37a801c09f..b9df201ef1 100644 --- a/crates/forge_services/src/agent_registry.rs +++ b/crates/forge_services/src/agent_registry.rs @@ -73,19 +73,9 @@ impl> /// defaults. async fn load_agents(&self) -> anyhow::Result> { let config = self.repository.get_config()?; - let session = config.session.as_ref().ok_or(Error::NoDefaultProvider)?; - let provider_id = session - .provider_id - .as_ref() - .map(|id| ProviderId::from(id.clone())) - .ok_or(Error::NoDefaultProvider)?; - let model_id = session - .model_id - .as_ref() - .map(|id| ModelId::new(id.clone())) - .ok_or_else(|| { - anyhow::anyhow!("No default model configured for provider {}", provider_id) - })?; + let session = config.session.as_ref().ok_or(Error::NoDefaultSession)?; + let provider_id = ProviderId::from(session.provider_id.clone()); + let model_id = ModelId::new(session.model_id.clone()); let agents = self.repository.get_agents(provider_id, model_id).await?; diff --git a/crates/forge_services/src/app_config.rs b/crates/forge_services/src/app_config.rs index 2e908d0f45..45551d63f7 100644 --- a/crates/forge_services/src/app_config.rs +++ b/crates/forge_services/src/app_config.rs @@ -28,9 +28,8 @@ impl id, - None => active_provider - .as_ref() - .ok_or(forge_domain::Error::NoDefaultProvider)?, + // Use the requested provider or the session's active provider + let requested_provider = match provider_id { + Some(id) => id.as_ref(), + None => session.provider_id.as_str(), }; - // Only return the model if the session's provider matches the requested - // provider - if session.provider_id.as_deref() == Some(provider_id.as_ref()) { - session - .model_id - .as_ref() - .map(ModelId::new) - .ok_or_else(|| forge_domain::Error::no_default_model(provider_id.clone()).into()) + // Return the session's model if the provider matches + if session.provider_id == requested_provider { + Ok(ModelId::new(session.model_id.clone())) } else { - Err(forge_domain::Error::no_default_model(provider_id.clone()).into()) + Err(forge_domain::Error::NoDefaultSession.into()) } } async fn get_commit_config(&self) -> anyhow::Result> { let config = self.infra.get_config()?; - Ok(config.commit.clone().and_then(|mc| { - mc.provider_id - .zip(mc.model_id) - .map(|(pid, mid)| ModelConfig { - provider: ProviderId::from(pid), - model: ModelId::new(mid), - }) + Ok(config.commit.clone().map(|mc| ModelConfig { + provider: ProviderId::from(mc.provider_id), + model: ModelId::new(mc.model_id), })) } async fn get_suggest_config(&self) -> anyhow::Result> { let config = self.infra.get_config()?; - Ok(config.suggest.clone().and_then(|mc| { - mc.provider_id - .zip(mc.model_id) - .map(|(pid, mid)| ModelConfig { - provider: ProviderId::from(pid), - model: ModelId::new(mid), - }) + Ok(config.suggest.clone().map(|mc| ModelConfig { + provider: ProviderId::from(mc.provider_id), + model: ModelId::new(mc.model_id), })) } @@ -227,30 +207,21 @@ mod tests { ConfigOperation::SetSessionConfig(mc) => { let pid_str = mc.provider.as_ref().to_string(); let mid_str = mc.model.to_string(); - config.session = Some(match config.session.take() { - Some(existing) - if existing.provider_id.as_deref() == Some(&pid_str) => - { - existing.model_id(mid_str) - } - _ => ModelConfig::default() - .provider_id(pid_str) - .model_id(mid_str), - }); + config.session = Some(ModelConfig::new(pid_str, mid_str)); } ConfigOperation::SetCommitConfig(mc) => { config.commit = mc.map(|m| { - ModelConfig::default() - .provider_id(m.provider.as_ref().to_string()) - .model_id(m.model.to_string()) + ModelConfig::new( + m.provider.as_ref().to_string(), + m.model.to_string(), + ) }); } ConfigOperation::SetSuggestConfig(mc) => { - config.suggest = Some( - ModelConfig::default() - .provider_id(mc.provider.as_ref().to_string()) - .model_id(mc.model.to_string()), - ); + config.suggest = Some(ModelConfig::new( + mc.provider.as_ref().to_string(), + mc.model.to_string(), + )); } ConfigOperation::SetReasoningEffort(_) => { // No-op in tests diff --git a/forge.schema.json b/forge.schema.json index 7420e9f736..0355e81aba 100644 --- a/forge.schema.json +++ b/forge.schema.json @@ -585,19 +585,17 @@ "properties": { "model_id": { "description": "The model to use for this operation.", - "type": [ - "string", - "null" - ] + "type": "string" }, "provider_id": { "description": "The provider to use for this operation.", - "type": [ - "string", - "null" - ] + "type": "string" } - } + }, + "required": [ + "provider_id", + "model_id" + ] }, "ProviderAuthMethod": { "description": "Authentication method supported by a provider.\n\nOnly the simple (non-OAuth) methods are available here; providers that\nrequire OAuth device or authorization-code flows must be configured via the\nfile-based `provider.json` override instead.", diff --git a/shell-plugin/lib/actions/config.zsh b/shell-plugin/lib/actions/config.zsh index 0ddd511005..a6d5f51479 100644 --- a/shell-plugin/lib/actions/config.zsh +++ b/shell-plugin/lib/actions/config.zsh @@ -155,10 +155,10 @@ function _forge_action_model() { # Switch provider first if it differs from the current one # current_provider (fetched above) is the display name, compare against that if [[ -n "$provider_display" && "$provider_display" != "$current_provider" ]]; then - _forge_exec_interactive config set provider "$provider_id" --model "$model_id" + _forge_exec_interactive config set model "$provider_id" "$model_id" return fi - _forge_exec config set model "$model_id" + _forge_exec config set model "$provider_id" "$model_id" fi ) }