Skip to content
Merged
7 changes: 5 additions & 2 deletions crates/forge_api/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ pub trait API: Sync + Send {
/// Provides a list of models available in the current environment
async fn get_models(&self) -> Result<Vec<Model>>;

/// Provides models from all configured providers. Providers that fail to
/// return models are silently skipped.
/// Provides models from all configured providers. Providers that
/// successfully return models are included in the result. If every
/// configured provider fails (e.g. due to an invalid API key), the
/// first error is returned so the caller sees the real underlying cause
/// rather than an empty list.
async fn get_all_provider_models(&self) -> Result<Vec<ProviderModels>>;

/// Provides a list of agents available in the current environment
Expand Down
35 changes: 18 additions & 17 deletions crates/forge_app/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,38 +282,39 @@ impl<S: Services + EnvironmentInfra<Config = forge_config::ForgeConfig>> ForgeAp

/// Gets available models from all configured providers concurrently.
///
/// Returns a list of `ProviderModels` for each configured provider.
/// All providers are queried in parallel; providers that fail to
/// return models are silently skipped.
/// Returns a list of `ProviderModels` for each configured provider that
/// successfully returned models. If every configured provider fails (e.g.
/// due to an invalid API key), the first error encountered is returned so
/// the caller receives the real underlying cause rather than an empty list.
pub async fn get_all_provider_models(&self) -> Result<Vec<ProviderModels>> {
let all_providers = self.services.get_all_providers().await?;

// Build one future per configured provider
// Build one future per configured provider, preserving the error on failure.
let futures: Vec<_> = all_providers
.into_iter()
.filter_map(|any_provider| any_provider.into_configured())
.map(|provider| {
let provider_id = provider.id.clone();
let services = self.services.clone();
async move {
let refreshed = services
.provider_auth_service()
.refresh_provider_credential(provider)
.await
.ok()?;
let models = services.models(refreshed).await.ok()?;
Some(ProviderModels { provider_id, models })
let result: Result<ProviderModels> = async {
let refreshed = services
.provider_auth_service()
.refresh_provider_credential(provider)
.await?;
let models = services.models(refreshed).await?;
Ok(ProviderModels { provider_id, models })
}
.await;
result
}
})
.collect();

// Execute all provider fetches concurrently and collect successful results
let results = futures::future::join_all(futures)
// Execute all provider fetches concurrently.
futures::future::join_all(futures)
.await
.into_iter()
.flatten()
.collect();

Ok(results)
.collect::<anyhow::Result<Vec<_>>>()
}
}
6 changes: 2 additions & 4 deletions crates/forge_app/src/services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,8 @@ pub trait AppConfigService: Send + Sync {
/// provider.
///
/// # Errors
/// - Returns `Error::NoDefaultProvider` when no active provider is set and
/// provider_id is None
/// - Returns `Error::NoDefaultModel` when no model is configured for the
/// provider
/// - Returns `Error::NoDefaultSession` when no provider and model are
/// configured.
async fn get_provider_model(
&self,
provider_id: Option<&forge_domain::ProviderId>,
Expand Down
24 changes: 15 additions & 9 deletions crates/forge_config/src/legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,24 @@ impl LegacyConfig {
/// Converts a [`LegacyConfig`] into the fields of [`ForgeConfig`] that it
/// covers, leaving all other fields at their defaults (`None`).
fn into_forge_config(self) -> ForgeConfig {
let session = self.provider.as_deref().map(|provider_id| {
let model_id = self.model.get(provider_id).cloned();
ModelConfig { provider_id: Some(provider_id.to_string()), model_id }
let session = self.provider.as_deref().and_then(|provider_id| {
self.model
.get(provider_id)
.cloned()
.map(|model_id| ModelConfig { provider_id: provider_id.to_string(), model_id })
});

let commit = self
.commit
.map(|c| ModelConfig { provider_id: c.provider, model_id: c.model });
let commit = self.commit.and_then(|c| {
c.provider
.zip(c.model)
.map(|(provider_id, model_id)| ModelConfig { provider_id, model_id })
});

let suggest = self
.suggest
.map(|s| ModelConfig { provider_id: s.provider, model_id: s.model });
let suggest = self.suggest.and_then(|s| {
s.provider
.zip(s.model)
.map(|(provider_id, model_id)| ModelConfig { provider_id, model_id })
});

ForgeConfig { session, commit, suggest, ..Default::default() }
}
Expand Down
17 changes: 11 additions & 6 deletions crates/forge_config/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@ pub type ProviderId = String;
pub type ModelId = String;

/// Pairs a provider and model together for a specific operation.
#[derive(
Default, Debug, Setters, Clone, PartialEq, Serialize, Deserialize, JsonSchema, fake::Dummy,
)]
#[setters(strip_option, into)]
#[derive(Debug, Setters, Clone, PartialEq, Serialize, Deserialize, JsonSchema, fake::Dummy)]
#[setters(into)]
pub struct ModelConfig {
/// The provider to use for this operation.
pub provider_id: Option<String>,
pub provider_id: String,
/// The model to use for this operation.
pub model_id: Option<String>,
pub model_id: String,
}

impl ModelConfig {
/// Creates a new ModelConfig with the given provider and model IDs.
pub fn new(provider_id: impl Into<String>, model_id: impl Into<String>) -> Self {
Self { provider_id: provider_id.into(), model_id: model_id.into() }
}
}
12 changes: 6 additions & 6 deletions crates/forge_config/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ mod tests {
// it on top of the embedded defaults. The default values must survive.
let legacy = ForgeConfig {
session: Some(ModelConfig {
provider_id: Some("anthropic".to_string()),
model_id: Some("claude-3".to_string()),
provider_id: "anthropic".to_string(),
model_id: "claude-3".to_string(),
}),
..Default::default()
};
Expand All @@ -216,8 +216,8 @@ mod tests {
assert_eq!(
actual.session,
Some(ModelConfig {
provider_id: Some("anthropic".to_string()),
model_id: Some("claude-3".to_string()),
provider_id: "anthropic".to_string(),
model_id: "claude-3".to_string(),
})
);

Expand All @@ -243,8 +243,8 @@ mod tests {
.unwrap();

let expected = Some(ModelConfig {
provider_id: Some("fake-provider".to_string()),
model_id: Some("fake-model".to_string()),
provider_id: "fake-provider".to_string(),
model_id: "fake-model".to_string(),
});
assert_eq!(actual.session, expected);
}
Expand Down
13 changes: 0 additions & 13 deletions crates/forge_domain/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,6 @@ use serde::{Deserialize, Serialize};

use crate::{Effort, ModelConfig};

/// Domain-level session configuration pairing a provider with a model.
///
/// Used to represent an active session, decoupled from the on-disk
/// configuration format.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize, Setters)]
#[setters(strip_option, into)]
pub struct SessionConfig {
/// The active provider ID (e.g. `"anthropic"`).
pub provider_id: Option<String>,
/// The model ID to use with this provider.
pub model_id: Option<String>,
}

/// All discrete mutations that can be applied to the application configuration.
///
/// Instead of replacing the entire config, callers describe exactly which field
Expand Down
12 changes: 2 additions & 10 deletions crates/forge_domain/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,8 @@ pub enum Error {
#[error("Failed to sync {count} file(s)")]
SyncFailed { count: usize },

#[error("No default provider set.")]
NoDefaultProvider,

#[error("No default model configured for provider: {0}")]
#[from(skip)]
NoDefaultModel(ProviderId),
#[error("No default provider and model configured.")]
NoDefaultSession,
}

pub type Result<A> = std::result::Result<A, Error>;
Expand Down Expand Up @@ -149,10 +145,6 @@ impl Error {
Self::VertexAiConfiguration { message: message.into() }
}

pub fn no_default_model(provider: ProviderId) -> Self {
Self::NoDefaultModel(provider)
}

pub fn sync_failed(count: usize) -> Self {
Self::SyncFailed { count }
}
Expand Down
54 changes: 18 additions & 36 deletions crates/forge_infra/src/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,18 @@ fn apply_config_op(fc: &mut ForgeConfig, op: ConfigOperation) {
ConfigOperation::SetSessionConfig(mc) => {
let pid_str = mc.provider.as_ref().to_string();
let mid_str = mc.model.to_string();
let session = fc.session.get_or_insert_with(ModelConfig::default);
if session.provider_id.as_deref() == Some(&pid_str) {
session.model_id = Some(mid_str);
} else {
fc.session =
Some(ModelConfig { provider_id: Some(pid_str), model_id: Some(mid_str) });
}
fc.session = Some(ModelConfig { provider_id: pid_str, model_id: mid_str });
}
ConfigOperation::SetCommitConfig(mc) => {
fc.commit = mc.map(|m| ModelConfig {
provider_id: Some(m.provider.as_ref().to_string()),
model_id: Some(m.model.to_string()),
provider_id: m.provider.as_ref().to_string(),
model_id: m.model.to_string(),
});
}
ConfigOperation::SetSuggestConfig(mc) => {
fc.suggest = Some(ModelConfig {
provider_id: Some(mc.provider.as_ref().to_string()),
model_id: Some(mc.model.to_string()),
provider_id: mc.provider.as_ref().to_string(),
model_id: mc.model.to_string(),
});
}
ConfigOperation::SetReasoningEffort(effort) => {
Expand Down Expand Up @@ -236,25 +230,22 @@ mod tests {
)),
);

let actual_provider = fixture
.session
.as_ref()
.and_then(|s| s.provider_id.as_deref());
let actual_model = fixture.session.as_ref().and_then(|s| s.model_id.as_deref());
let actual_provider = fixture.session.as_ref().map(|s| s.provider_id.as_str());
let actual_model = fixture.session.as_ref().map(|s| s.model_id.as_str());

assert_eq!(actual_provider, Some("anthropic"));
assert_eq!(actual_model, Some("claude-3-5-sonnet"));
}

#[test]
fn test_apply_config_op_set_model_matching_provider() {
fn test_apply_config_op_set_session_config_replaces_existing() {
use forge_config::ModelConfig as ForgeCfgModelConfig;
use forge_domain::{ModelConfig as DomainModelConfig, ModelId, ProviderId};

let mut fixture = ForgeConfig {
session: Some(ForgeCfgModelConfig {
provider_id: Some("anthropic".to_string()),
model_id: None,
provider_id: "openai".to_string(),
model_id: "gpt-4".to_string(),
}),
..Default::default()
};
Expand All @@ -267,24 +258,18 @@ mod tests {
)),
);

let actual = fixture.session.as_ref().and_then(|s| s.model_id.as_deref());
let expected = Some("claude-3-5-sonnet-20241022");
let actual_provider = fixture.session.as_ref().map(|s| s.provider_id.as_str());
let actual_model = fixture.session.as_ref().map(|s| s.model_id.as_str());

assert_eq!(actual, expected);
assert_eq!(actual_provider, Some("anthropic"));
assert_eq!(actual_model, Some("claude-3-5-sonnet-20241022"));
}

#[test]
fn test_apply_config_op_set_model_different_provider_replaces_session() {
use forge_config::ModelConfig as ForgeCfgModelConfig;
fn test_apply_config_op_set_session_config_creates_new_session() {
use forge_domain::{ModelConfig as DomainModelConfig, ModelId, ProviderId};

let mut fixture = ForgeConfig {
session: Some(ForgeCfgModelConfig {
provider_id: Some("openai".to_string()),
model_id: Some("gpt-4".to_string()),
}),
..Default::default()
};
let mut fixture = ForgeConfig::default();

apply_config_op(
&mut fixture,
Expand All @@ -294,11 +279,8 @@ mod tests {
)),
);

let actual_provider = fixture
.session
.as_ref()
.and_then(|s| s.provider_id.as_deref());
let actual_model = fixture.session.as_ref().and_then(|s| s.model_id.as_deref());
let actual_provider = fixture.session.as_ref().map(|s| s.provider_id.as_str());
let actual_model = fixture.session.as_ref().map(|s| s.model_id.as_str());

assert_eq!(actual_provider, Some("anthropic"));
assert_eq!(actual_model, Some("claude-3-5-sonnet-20241022"));
Expand Down
Loading
Loading