From 52bbd62501fd2830cab852404eef1939d734d3f7 Mon Sep 17 00:00:00 2001 From: Rohith Mahesh Date: Sun, 5 Apr 2026 10:39:39 +0530 Subject: [PATCH 1/3] feat: Add NVIDIA provider support with OpenAI-compatible API - Add NVIDIA provider constant and configuration - Add reasoning_content field to handle NVIDIA's response format - Implement MergeSystemMessages transformer for NVIDIA's system message requirements - Add comprehensive test coverage for NVIDIA deserialization --- crates/forge_app/src/dto/openai/response.rs | 68 ++++++- .../transformers/ensure_system_first.rs | 192 ++++++++++++++++++ .../src/dto/openai/transformers/mod.rs | 1 + .../src/dto/openai/transformers/pipeline.rs | 5 + crates/forge_domain/src/provider.rs | 6 + crates/forge_repo/src/provider/provider.json | 9 + .../forge_repo/src/provider/provider_repo.rs | 20 ++ 7 files changed, 298 insertions(+), 3 deletions(-) create mode 100644 crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs diff --git a/crates/forge_app/src/dto/openai/response.rs b/crates/forge_app/src/dto/openai/response.rs index 90829302f4..297fe164cf 100644 --- a/crates/forge_app/src/dto/openai/response.rs +++ b/crates/forge_app/src/dto/openai/response.rs @@ -139,8 +139,13 @@ pub enum Choice { #[derive(Debug, Deserialize, Serialize, Clone)] pub struct ResponseMessage { pub content: Option, - #[serde(alias = "reasoning_content")] pub reasoning: Option, + /// Some providers (e.g. NVIDIA) send `reasoning_content` instead of + /// `reasoning`, or send both in the same response. We store them + /// separately to avoid serde's "duplicate field" error with aliases, + /// and merge via [`ResponseMessage::merged_reasoning`]. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reasoning_content: Option, pub role: Option, pub tool_calls: Option>, pub refusal: Option, @@ -152,6 +157,14 @@ pub struct ResponseMessage { pub extra_content: Option, } +impl ResponseMessage { + /// Returns the reasoning content, preferring `reasoning` over + /// `reasoning_content` when both are present. + pub fn merged_reasoning(&self) -> Option<&String> { + self.reasoning.as_ref().or(self.reasoning_content.as_ref()) + } +} + impl From for forge_domain::ReasoningDetail { fn from(detail: ReasoningDetail) -> Self { forge_domain::ReasoningDetail { @@ -319,7 +332,7 @@ impl TryFrom for ChatCompletionMessage { .clone() .and_then(|s| FinishReason::from_str(&s).ok()), ); - if let Some(reasoning) = &message.reasoning { + if let Some(reasoning) = message.merged_reasoning() { resp = resp.reasoning(Content::full(reasoning.clone())); } @@ -387,7 +400,7 @@ impl TryFrom for ChatCompletionMessage { .and_then(|s| FinishReason::from_str(&s).ok()), ); - if let Some(reasoning) = &delta.reasoning { + if let Some(reasoning) = delta.merged_reasoning() { resp = resp.reasoning(Content::part(reasoning.clone())); } @@ -632,6 +645,7 @@ mod tests { message: ResponseMessage { content: Some("test content".to_string()), reasoning: None, + reasoning_content: None, role: Some("assistant".to_string()), tool_calls: None, refusal: None, @@ -669,6 +683,7 @@ mod tests { delta: ResponseMessage { content: Some("test content".to_string()), reasoning: None, + reasoning_content: None, role: Some("assistant".to_string()), tool_calls: None, refusal: None, @@ -706,6 +721,7 @@ mod tests { message: ResponseMessage { content: Some("Hello, world!".to_string()), reasoning: None, + reasoning_content: None, role: Some("assistant".to_string()), tool_calls: None, refusal: None, @@ -963,4 +979,50 @@ mod tests { assert!(error_string.contains("Content was filtered")); assert!(error_string.contains("hate")); } + + #[test] + fn test_nvidia_tool_call_streaming_chunk() { + let response_json = r#"{"id":"chatcmpl-994182aa3bf1d873","object":"chat.completion.chunk","created":1775363363,"model":"qwen/qwen3.5-397b-a17b","choices":[{"index":0,"delta":{"content":null,"reasoning":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":"tool_calls","stop_reason":null,"token_ids":null}]}"#; + + let actual = serde_json::from_str::(response_json); + assert!( + actual.is_ok(), + "Should parse NVIDIA tool call streaming chunk: {:?}", + actual.err() + ); + } + + #[test] + fn test_nvidia_tool_call_deserialization() { + // NVIDIA sends tool calls without "id" and "type" fields + let tool_call_json = r#"{"index":1,"function":{"arguments":"}"}}"#; + let actual = serde_json::from_str::(tool_call_json); + assert!( + actual.is_ok(), + "Should parse NVIDIA tool call: {:?}", + actual.err() + ); + } + + #[test] + fn test_nvidia_response_message_deserialization() { + let msg_json = r#"{"content":null,"reasoning":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"}"}}]}"#; + let actual = serde_json::from_str::(msg_json); + assert!( + actual.is_ok(), + "Should parse NVIDIA response message: {:?}", + actual.err() + ); + } + + #[test] + fn test_nvidia_choice_deserialization() { + let choice_json = r#"{"index":0,"delta":{"content":null,"reasoning":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":"tool_calls","stop_reason":null,"token_ids":null}"#; + let actual = serde_json::from_str::(choice_json); + assert!( + actual.is_ok(), + "Should parse NVIDIA choice: {:?}", + actual.err() + ); + } } diff --git a/crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs b/crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs new file mode 100644 index 0000000000..6874f403a5 --- /dev/null +++ b/crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs @@ -0,0 +1,192 @@ +use forge_domain::Transformer; + +use crate::dto::openai::{Message, MessageContent, Request, Role}; + +/// Merges all system messages into a single system message at the beginning of +/// the messages array. +/// +/// Some providers (e.g. NVIDIA) reject requests with multiple system messages +/// or system messages that are not positioned at the start of the conversation. +pub struct MergeSystemMessages; + +impl Transformer for MergeSystemMessages { + type Value = Request; + + fn transform(&mut self, mut request: Self::Value) -> Self::Value { + if let Some(messages) = request.messages.take() { + let (system, rest): (Vec<_>, Vec<_>) = + messages.into_iter().partition(|m| m.role == Role::System); + + let merged = if system.is_empty() { + rest + } else { + let combined_content = system + .iter() + .filter_map(|m| match &m.content { + Some(MessageContent::Text(text)) => Some(text.clone()), + Some(MessageContent::Parts(parts)) => Some( + parts + .iter() + .filter_map(|p| match p { + crate::dto::openai::ContentPart::Text { text, .. } => { + Some(text.clone()) + } + _ => None, + }) + .collect::>() + .join(""), + ), + None => None, + }) + .collect::>() + .join("\n\n"); + + let mut result = vec![Message { + role: Role::System, + content: Some(MessageContent::Text(combined_content)), + name: None, + tool_call_id: None, + tool_calls: None, + reasoning_details: None, + reasoning_text: None, + reasoning_opaque: None, + reasoning_content: None, + extra_content: None, + }]; + result.extend(rest); + result + }; + + request.messages = Some(merged); + } + request + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + use crate::dto::openai::{Message, MessageContent, Role}; + + fn system_msg(content: &str) -> Message { + Message { + role: Role::System, + content: Some(MessageContent::Text(content.to_string())), + name: None, + tool_call_id: None, + tool_calls: None, + reasoning_details: None, + reasoning_text: None, + reasoning_opaque: None, + reasoning_content: None, + extra_content: None, + } + } + + fn user_msg(content: &str) -> Message { + Message { + role: Role::User, + content: Some(MessageContent::Text(content.to_string())), + name: None, + tool_call_id: None, + tool_calls: None, + reasoning_details: None, + reasoning_text: None, + reasoning_opaque: None, + reasoning_content: None, + extra_content: None, + } + } + + fn assistant_msg(content: &str) -> Message { + Message { + role: Role::Assistant, + content: Some(MessageContent::Text(content.to_string())), + name: None, + tool_call_id: None, + tool_calls: None, + reasoning_details: None, + reasoning_text: None, + reasoning_opaque: None, + reasoning_content: None, + extra_content: None, + } + } + + fn get_text_content(msg: &Message) -> Option<&str> { + match msg.content.as_ref() { + Some(MessageContent::Text(text)) => Some(text.as_str()), + _ => None, + } + } + + #[test] + fn test_multiple_system_messages_merged() { + let fixture = Request::default().messages(vec![ + user_msg("hello"), + system_msg("you are helpful"), + assistant_msg("hi"), + system_msg("be concise"), + user_msg("how are you"), + ]); + + let actual = MergeSystemMessages.transform(fixture); + + let messages = actual.messages.unwrap(); + assert_eq!(messages.len(), 4); + assert_eq!(messages[0].role, Role::System); + assert_eq!( + get_text_content(&messages[0]), + Some("you are helpful\n\nbe concise") + ); + assert_eq!(messages[1].role, Role::User); + assert_eq!(messages[2].role, Role::Assistant); + assert_eq!(messages[3].role, Role::User); + } + + #[test] + fn test_single_system_message_unchanged() { + let fixture = Request::default().messages(vec![ + system_msg("you are helpful"), + user_msg("hello"), + assistant_msg("hi"), + ]); + + let actual = MergeSystemMessages.transform(fixture); + + let messages = actual.messages.unwrap(); + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].role, Role::System); + assert_eq!(get_text_content(&messages[0]), Some("you are helpful")); + assert_eq!(messages[1].role, Role::User); + assert_eq!(messages[2].role, Role::Assistant); + } + + #[test] + fn test_no_system_messages_unchanged() { + let fixture = Request::default().messages(vec![ + user_msg("hello"), + assistant_msg("hi"), + user_msg("how are you"), + ]); + + let actual = MergeSystemMessages.transform(fixture); + + let messages = actual.messages.unwrap(); + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].role, Role::User); + assert_eq!(messages[1].role, Role::Assistant); + assert_eq!(messages[2].role, Role::User); + } + + #[test] + fn test_no_messages_unchanged() { + let fixture = Request::default(); + + let actual = MergeSystemMessages.transform(fixture); + + assert!(actual.messages.is_none()); + } +} diff --git a/crates/forge_app/src/dto/openai/transformers/mod.rs b/crates/forge_app/src/dto/openai/transformers/mod.rs index 88dcb30b82..9d38f3b203 100644 --- a/crates/forge_app/src/dto/openai/transformers/mod.rs +++ b/crates/forge_app/src/dto/openai/transformers/mod.rs @@ -1,4 +1,5 @@ mod drop_tool_call; +mod ensure_system_first; mod github_copilot_reasoning; mod kimi_k2_reasoning; mod make_cerebras_compat; diff --git a/crates/forge_app/src/dto/openai/transformers/pipeline.rs b/crates/forge_app/src/dto/openai/transformers/pipeline.rs index 7b1c309c0e..2403d29c90 100644 --- a/crates/forge_app/src/dto/openai/transformers/pipeline.rs +++ b/crates/forge_app/src/dto/openai/transformers/pipeline.rs @@ -2,6 +2,7 @@ use forge_domain::{DefaultTransformation, Provider, ProviderId, Transformer}; use url::Url; use super::drop_tool_call::DropToolCalls; +use super::ensure_system_first::MergeSystemMessages; use super::github_copilot_reasoning::GitHubCopilotReasoning; use super::kimi_k2_reasoning::KimiK2Reasoning; use super::make_cerebras_compat::MakeCerebrasCompat; @@ -67,6 +68,9 @@ impl Transformer for ProviderPipeline<'_> { let cerebras_compat = MakeCerebrasCompat.when(move |_| provider.id == ProviderId::CEREBRAS); + let ensure_system_first = + MergeSystemMessages.when(move |_| provider.id == ProviderId::NVIDIA); + let trim_tool_call_ids = TrimToolCallIds.when(move |_| provider.id == ProviderId::OPENAI); let strict_schema = EnforceStrictToolSchema @@ -83,6 +87,7 @@ impl Transformer for ProviderPipeline<'_> { .pipe(github_copilot_reasoning) .pipe(kimi_k2_reasoning) .pipe(cerebras_compat) + .pipe(ensure_system_first) .pipe(trim_tool_call_ids) .pipe(strict_schema) .pipe(NormalizeToolSchema); diff --git a/crates/forge_domain/src/provider.rs b/crates/forge_domain/src/provider.rs index a65b43e416..e21478e64c 100644 --- a/crates/forge_domain/src/provider.rs +++ b/crates/forge_domain/src/provider.rs @@ -73,6 +73,7 @@ impl ProviderId { pub const FIREWORKS_AI: ProviderId = ProviderId(Cow::Borrowed("fireworks-ai")); pub const NOVITA: ProviderId = ProviderId(Cow::Borrowed("novita")); pub const GOOGLE_AI_STUDIO: ProviderId = ProviderId(Cow::Borrowed("google_ai_studio")); + pub const NVIDIA: ProviderId = ProviderId(Cow::Borrowed("nvidia")); /// Returns all built-in provider IDs /// @@ -106,6 +107,7 @@ impl ProviderId { ProviderId::FIREWORKS_AI, ProviderId::NOVITA, ProviderId::GOOGLE_AI_STUDIO, + ProviderId::NVIDIA, ] } @@ -132,6 +134,7 @@ impl ProviderId { "fireworks-ai" => "FireworksAI".to_string(), "novita" => "Novita".to_string(), "google_ai_studio" => "GoogleAIStudio".to_string(), + "nvidia" => "NVIDIA".to_string(), _ => { // For other providers, use UpperCamelCase conversion use convert_case::{Case, Casing}; @@ -177,6 +180,7 @@ impl std::str::FromStr for ProviderId { "fireworks-ai" => ProviderId::FIREWORKS_AI, "novita" => ProviderId::NOVITA, "google_ai_studio" => ProviderId::GOOGLE_AI_STUDIO, + "nvidia" => ProviderId::NVIDIA, // For custom providers, use Cow::Owned to avoid memory leaks custom => ProviderId(Cow::Owned(custom.to_string())), }; @@ -549,6 +553,7 @@ mod tests { assert_eq!(ProviderId::CODEX.to_string(), "Codex"); assert_eq!(ProviderId::FIREWORKS_AI.to_string(), "FireworksAI"); assert_eq!(ProviderId::GOOGLE_AI_STUDIO.to_string(), "GoogleAIStudio"); + assert_eq!(ProviderId::NVIDIA.to_string(), "NVIDIA"); } #[test] @@ -572,6 +577,7 @@ mod tests { assert!(built_in.contains(&ProviderId::OPENAI_RESPONSES_COMPATIBLE)); assert!(built_in.contains(&ProviderId::FIREWORKS_AI)); assert!(built_in.contains(&ProviderId::GOOGLE_AI_STUDIO)); + assert!(built_in.contains(&ProviderId::NVIDIA)); } #[test] diff --git a/crates/forge_repo/src/provider/provider.json b/crates/forge_repo/src/provider/provider.json index 16dc7899f6..e1ce581030 100644 --- a/crates/forge_repo/src/provider/provider.json +++ b/crates/forge_repo/src/provider/provider.json @@ -3099,5 +3099,14 @@ "input_modalities": ["text"] } ] + }, + { + "id": "nvidia", + "api_key_vars": "NVIDIA_API_KEY", + "url_param_vars": [], + "response_type": "OpenAI", + "url": "https://integrate.api.nvidia.com/v1/chat/completions", + "models": "https://integrate.api.nvidia.com/v1/models", + "auth_methods": ["api_key"] } ] diff --git a/crates/forge_repo/src/provider/provider_repo.rs b/crates/forge_repo/src/provider/provider_repo.rs index 5ff2a23982..a82ac67144 100644 --- a/crates/forge_repo/src/provider/provider_repo.rs +++ b/crates/forge_repo/src/provider/provider_repo.rs @@ -753,6 +753,26 @@ mod tests { "https://api.intelligence.io.solutions/api/v1/chat/completions" ); } + + #[test] + fn test_nvidia_config() { + let configs = get_provider_configs(); + let config = configs + .iter() + .find(|c| c.id == ProviderId::NVIDIA) + .unwrap(); + assert_eq!(config.id, ProviderId::NVIDIA); + assert_eq!( + config.api_key_vars, + Some("NVIDIA_API_KEY".to_string()) + ); + assert!(config.url_param_vars.is_empty()); + assert_eq!(config.response_type, Some(ProviderResponse::OpenAI)); + assert_eq!( + config.url.as_str(), + "https://integrate.api.nvidia.com/v1/chat/completions" + ); + } } #[cfg(test)] From 395edf779765fff7e28db13c1aa4e2e0747c3d09 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Sun, 5 Apr 2026 05:57:28 +0000 Subject: [PATCH 2/3] [autofix.ci] apply automated fixes --- crates/forge_repo/src/provider/provider_repo.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/crates/forge_repo/src/provider/provider_repo.rs b/crates/forge_repo/src/provider/provider_repo.rs index a82ac67144..5df02a47d2 100644 --- a/crates/forge_repo/src/provider/provider_repo.rs +++ b/crates/forge_repo/src/provider/provider_repo.rs @@ -757,15 +757,9 @@ mod tests { #[test] fn test_nvidia_config() { let configs = get_provider_configs(); - let config = configs - .iter() - .find(|c| c.id == ProviderId::NVIDIA) - .unwrap(); + let config = configs.iter().find(|c| c.id == ProviderId::NVIDIA).unwrap(); assert_eq!(config.id, ProviderId::NVIDIA); - assert_eq!( - config.api_key_vars, - Some("NVIDIA_API_KEY".to_string()) - ); + assert_eq!(config.api_key_vars, Some("NVIDIA_API_KEY".to_string())); assert!(config.url_param_vars.is_empty()); assert_eq!(config.response_type, Some(ProviderResponse::OpenAI)); assert_eq!( From 2dd3d57959e64319fd8518c0fd648e8e9304e614 Mon Sep 17 00:00:00 2001 From: Rohith Mahesh <63108911+rohithmahesh3@users.noreply.github.com> Date: Sun, 5 Apr 2026 11:33:19 +0530 Subject: [PATCH 3/3] Update crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com> --- .../transformers/ensure_system_first.rs | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs b/crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs index 6874f403a5..02a21c39c5 100644 --- a/crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs +++ b/crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs @@ -17,6 +17,7 @@ impl Transformer for MergeSystemMessages { let (system, rest): (Vec<_>, Vec<_>) = messages.into_iter().partition(|m| m.role == Role::System); + let merged = if system.is_empty() { rest } else { @@ -41,22 +42,28 @@ impl Transformer for MergeSystemMessages { .collect::>() .join("\n\n"); - let mut result = vec![Message { - role: Role::System, - content: Some(MessageContent::Text(combined_content)), - name: None, - tool_call_id: None, - tool_calls: None, - reasoning_details: None, - reasoning_text: None, - reasoning_opaque: None, - reasoning_content: None, - extra_content: None, - }]; - result.extend(rest); - result + if combined_content.is_empty() { + // All system messages had no content, don't create empty system message + rest + } else { + let mut result = vec![Message { + role: Role::System, + content: Some(MessageContent::Text(combined_content)), + name: None, + tool_call_id: None, + tool_calls: None, + reasoning_details: None, + reasoning_text: None, + reasoning_opaque: None, + reasoning_content: None, + extra_content: None, + }]; + result.extend(rest); + result + } }; + request.messages = Some(merged); } request