Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions crates/forge_app/src/dto/openai/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,13 @@ pub enum Choice {
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ResponseMessage {
pub content: Option<String>,
#[serde(alias = "reasoning_content")]
pub reasoning: Option<String>,
/// Some providers (e.g. NVIDIA) send `reasoning_content` instead of
/// `reasoning`, or send both in the same response. We store them
/// separately to avoid serde's "duplicate field" error with aliases,
/// and merge via [`ResponseMessage::merged_reasoning`].
#[serde(default, skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
pub role: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
pub refusal: Option<String>,
Expand All @@ -152,6 +157,14 @@ pub struct ResponseMessage {
pub extra_content: Option<ExtraContent>,
}

impl ResponseMessage {
    /// Returns the reasoning content, preferring `reasoning` over
    /// `reasoning_content` when both are present.
    pub fn merged_reasoning(&self) -> Option<&String> {
        match &self.reasoning {
            Some(primary) => Some(primary),
            None => self.reasoning_content.as_ref(),
        }
    }
}

impl From<ReasoningDetail> for forge_domain::ReasoningDetail {
fn from(detail: ReasoningDetail) -> Self {
forge_domain::ReasoningDetail {
Expand Down Expand Up @@ -319,7 +332,7 @@ impl TryFrom<Response> for ChatCompletionMessage {
.clone()
.and_then(|s| FinishReason::from_str(&s).ok()),
);
if let Some(reasoning) = &message.reasoning {
if let Some(reasoning) = message.merged_reasoning() {
resp = resp.reasoning(Content::full(reasoning.clone()));
}

Expand Down Expand Up @@ -387,7 +400,7 @@ impl TryFrom<Response> for ChatCompletionMessage {
.and_then(|s| FinishReason::from_str(&s).ok()),
);

if let Some(reasoning) = &delta.reasoning {
if let Some(reasoning) = delta.merged_reasoning() {
resp = resp.reasoning(Content::part(reasoning.clone()));
}

Expand Down Expand Up @@ -632,6 +645,7 @@ mod tests {
message: ResponseMessage {
content: Some("test content".to_string()),
reasoning: None,
reasoning_content: None,
role: Some("assistant".to_string()),
tool_calls: None,
refusal: None,
Expand Down Expand Up @@ -669,6 +683,7 @@ mod tests {
delta: ResponseMessage {
content: Some("test content".to_string()),
reasoning: None,
reasoning_content: None,
role: Some("assistant".to_string()),
tool_calls: None,
refusal: None,
Expand Down Expand Up @@ -706,6 +721,7 @@ mod tests {
message: ResponseMessage {
content: Some("Hello, world!".to_string()),
reasoning: None,
reasoning_content: None,
role: Some("assistant".to_string()),
tool_calls: None,
refusal: None,
Expand Down Expand Up @@ -963,4 +979,50 @@ mod tests {
assert!(error_string.contains("Content was filtered"));
assert!(error_string.contains("hate"));
}

#[test]
fn test_nvidia_tool_call_streaming_chunk() {
    // Raw streaming chunk from NVIDIA's OpenAI-compatible endpoint; it carries
    // both `reasoning` and `reasoning_content` keys in the same delta.
    let chunk = r#"{"id":"chatcmpl-994182aa3bf1d873","object":"chat.completion.chunk","created":1775363363,"model":"qwen/qwen3.5-397b-a17b","choices":[{"index":0,"delta":{"content":null,"reasoning":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":"tool_calls","stop_reason":null,"token_ids":null}]}"#;

    let parsed: Result<Response, _> = serde_json::from_str(chunk);
    assert!(
        parsed.is_ok(),
        "Should parse NVIDIA tool call streaming chunk: {:?}",
        parsed.err()
    );
}

#[test]
fn test_nvidia_tool_call_deserialization() {
    // NVIDIA sends tool calls without "id" and "type" fields
    let fixture = r#"{"index":1,"function":{"arguments":"}"}}"#;

    let parsed: Result<ToolCall, _> = serde_json::from_str(fixture);
    assert!(
        parsed.is_ok(),
        "Should parse NVIDIA tool call: {:?}",
        parsed.err()
    );
}

#[test]
fn test_nvidia_response_message_deserialization() {
    // A message containing both `reasoning` and `reasoning_content` must not
    // trip serde (they are distinct fields, not aliases of one another).
    let fixture = r#"{"content":null,"reasoning":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"}"}}]}"#;

    let parsed: Result<ResponseMessage, _> = serde_json::from_str(fixture);
    assert!(
        parsed.is_ok(),
        "Should parse NVIDIA response message: {:?}",
        parsed.err()
    );
}

#[test]
fn test_nvidia_choice_deserialization() {
    // Streaming choice shape as produced by NVIDIA, including the extra
    // `stop_reason` and `token_ids` fields.
    let fixture = r#"{"index":0,"delta":{"content":null,"reasoning":null,"reasoning_content":null,"tool_calls":[{"index":1,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":"tool_calls","stop_reason":null,"token_ids":null}"#;

    let parsed: Result<Choice, _> = serde_json::from_str(fixture);
    assert!(
        parsed.is_ok(),
        "Should parse NVIDIA choice: {:?}",
        parsed.err()
    );
}
}
199 changes: 199 additions & 0 deletions crates/forge_app/src/dto/openai/transformers/ensure_system_first.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
use forge_domain::Transformer;

use crate::dto::openai::{Message, MessageContent, Request, Role};

/// Merges all system messages into a single system message at the beginning of
/// the messages array, preserving the relative order of all non-system
/// messages.
///
/// Some providers (e.g. NVIDIA) reject requests with multiple system messages
/// or system messages that are not positioned at the start of the conversation.
pub struct MergeSystemMessages;

impl Transformer for MergeSystemMessages {
    type Value = Request;

    /// Splits the messages into system and non-system partitions, joins the
    /// text of all system messages with blank-line separators, and re-emits it
    /// as a single system message at the head of the conversation.
    ///
    /// A `None` message list is left untouched. System messages with no
    /// usable text (no content, or `Parts` with no text parts) are dropped
    /// rather than contributing empty segments; if every system message is
    /// empty, no system message is emitted at all.
    fn transform(&mut self, mut request: Self::Value) -> Self::Value {
        if let Some(messages) = request.messages.take() {
            // `partition` preserves relative order within each partition.
            let (system, rest): (Vec<_>, Vec<_>) =
                messages.into_iter().partition(|m| m.role == Role::System);

            let merged = if system.is_empty() {
                rest
            } else {
                let combined_content = system
                    .iter()
                    .filter_map(|m| match &m.content {
                        Some(MessageContent::Text(text)) => Some(text.clone()),
                        Some(MessageContent::Parts(parts)) => Some(
                            parts
                                .iter()
                                .filter_map(|p| match p {
                                    crate::dto::openai::ContentPart::Text { text, .. } => {
                                        Some(text.clone())
                                    }
                                    _ => None,
                                })
                                .collect::<Vec<_>>()
                                .join(""),
                        ),
                        None => None,
                    })
                    // Drop empty segments so content-less system messages do
                    // not inject stray blank-line separators into the merge.
                    .filter(|segment| !segment.is_empty())
                    .collect::<Vec<_>>()
                    .join("\n\n");

                if combined_content.is_empty() {
                    // All system messages had no content, don't create empty system message
                    rest
                } else {
                    let mut result = vec![Message {
                        role: Role::System,
                        content: Some(MessageContent::Text(combined_content)),
                        name: None,
                        tool_call_id: None,
                        tool_calls: None,
                        reasoning_details: None,
                        reasoning_text: None,
                        reasoning_opaque: None,
                        reasoning_content: None,
                        extra_content: None,
                    }];
                    result.extend(rest);
                    result
                }
            };

            request.messages = Some(merged);
        }
        request
    }
}

#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;
    use crate::dto::openai::{Message, MessageContent, Role};

    /// Builds a plain-text message with the given role; every optional field
    /// is unset. Shared by the role-specific helpers below to avoid
    /// duplicating the full struct literal three times.
    fn msg(role: Role, content: &str) -> Message {
        Message {
            role,
            content: Some(MessageContent::Text(content.to_string())),
            name: None,
            tool_call_id: None,
            tool_calls: None,
            reasoning_details: None,
            reasoning_text: None,
            reasoning_opaque: None,
            reasoning_content: None,
            extra_content: None,
        }
    }

    fn system_msg(content: &str) -> Message {
        msg(Role::System, content)
    }

    fn user_msg(content: &str) -> Message {
        msg(Role::User, content)
    }

    fn assistant_msg(content: &str) -> Message {
        msg(Role::Assistant, content)
    }

    /// Extracts the plain-text content of a message, if any.
    fn get_text_content(msg: &Message) -> Option<&str> {
        match msg.content.as_ref() {
            Some(MessageContent::Text(text)) => Some(text.as_str()),
            _ => None,
        }
    }

    #[test]
    fn test_multiple_system_messages_merged() {
        let fixture = Request::default().messages(vec![
            user_msg("hello"),
            system_msg("you are helpful"),
            assistant_msg("hi"),
            system_msg("be concise"),
            user_msg("how are you"),
        ]);

        let actual = MergeSystemMessages.transform(fixture);

        let messages = actual.messages.unwrap();
        assert_eq!(messages.len(), 4);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(
            get_text_content(&messages[0]),
            Some("you are helpful\n\nbe concise")
        );
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[2].role, Role::Assistant);
        assert_eq!(messages[3].role, Role::User);
    }

    #[test]
    fn test_single_system_message_unchanged() {
        let fixture = Request::default().messages(vec![
            system_msg("you are helpful"),
            user_msg("hello"),
            assistant_msg("hi"),
        ]);

        let actual = MergeSystemMessages.transform(fixture);

        let messages = actual.messages.unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::System);
        assert_eq!(get_text_content(&messages[0]), Some("you are helpful"));
        assert_eq!(messages[1].role, Role::User);
        assert_eq!(messages[2].role, Role::Assistant);
    }

    #[test]
    fn test_no_system_messages_unchanged() {
        let fixture = Request::default().messages(vec![
            user_msg("hello"),
            assistant_msg("hi"),
            user_msg("how are you"),
        ]);

        let actual = MergeSystemMessages.transform(fixture);

        let messages = actual.messages.unwrap();
        assert_eq!(messages.len(), 3);
        assert_eq!(messages[0].role, Role::User);
        assert_eq!(messages[1].role, Role::Assistant);
        assert_eq!(messages[2].role, Role::User);
    }

    #[test]
    fn test_no_messages_unchanged() {
        let fixture = Request::default();

        let actual = MergeSystemMessages.transform(fixture);

        assert!(actual.messages.is_none());
    }
}
1 change: 1 addition & 0 deletions crates/forge_app/src/dto/openai/transformers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod drop_tool_call;
mod ensure_system_first;
mod github_copilot_reasoning;
mod kimi_k2_reasoning;
mod make_cerebras_compat;
Expand Down
5 changes: 5 additions & 0 deletions crates/forge_app/src/dto/openai/transformers/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use forge_domain::{DefaultTransformation, Provider, ProviderId, Transformer};
use url::Url;

use super::drop_tool_call::DropToolCalls;
use super::ensure_system_first::MergeSystemMessages;
use super::github_copilot_reasoning::GitHubCopilotReasoning;
use super::kimi_k2_reasoning::KimiK2Reasoning;
use super::make_cerebras_compat::MakeCerebrasCompat;
Expand Down Expand Up @@ -67,6 +68,9 @@ impl Transformer for ProviderPipeline<'_> {

let cerebras_compat = MakeCerebrasCompat.when(move |_| provider.id == ProviderId::CEREBRAS);

let ensure_system_first =
MergeSystemMessages.when(move |_| provider.id == ProviderId::NVIDIA);

let trim_tool_call_ids = TrimToolCallIds.when(move |_| provider.id == ProviderId::OPENAI);

let strict_schema = EnforceStrictToolSchema
Expand All @@ -83,6 +87,7 @@ impl Transformer for ProviderPipeline<'_> {
.pipe(github_copilot_reasoning)
.pipe(kimi_k2_reasoning)
.pipe(cerebras_compat)
.pipe(ensure_system_first)
.pipe(trim_tool_call_ids)
.pipe(strict_schema)
.pipe(NormalizeToolSchema);
Expand Down
Loading
Loading