diff --git a/Cargo.toml b/Cargo.toml index 1d1743f..c700305 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ once_cell = "1" # Lazy statics async-trait = "0.1" # Async trait support dirs = "5" # User directories regex = "1" # Regular expressions +uuid = { version = "1.0", features = ["v4", "serde"] } # UUID generation for streaming # OAuth & Auth oauth2 = "4" # OAuth 2.0 client diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 5a003ba..009f3e2 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -189,6 +189,38 @@ struct ResponsesUsage { output_tokens: u32, } +/// OpenAI Streaming Chunk (for SSE transformation) +#[derive(Debug, Deserialize)] +struct OpenAIStreamChunk { + id: String, + #[serde(default)] + model: String, + choices: Vec, + #[serde(default)] + created: u64, +} + +#[derive(Debug, Deserialize)] +struct OpenAIStreamChoice { + delta: OpenAIStreamDelta, + #[serde(default)] + index: usize, + #[serde(default)] + finish_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct OpenAIStreamDelta { + #[serde(default)] + content: Option, + #[serde(default)] + reasoning: Option, // For GLM/Cerebras models + #[serde(default)] + role: Option, + #[serde(default)] + tool_calls: Option>, +} + /// OpenAI provider implementation pub struct OpenAIProvider { name: String, @@ -757,6 +789,8 @@ impl OpenAIProvider { let choice = response.choices.into_iter().next() .expect("OpenAI response must have at least one choice"); + let mut content_blocks = Vec::new(); + // Extract text from content or reasoning (for GLM models via Cerebras) let text = if let Some(content) = choice.message.content { match content { @@ -781,13 +815,31 @@ impl OpenAIProvider { String::new() }; + // Add text content if present + if !text.is_empty() { + content_blocks.push(ContentBlock::Text { text }); + } + + // Transform tool_calls to Anthropic tool_use format + if let Some(tool_calls) = choice.message.tool_calls { + for tool_call in tool_calls { + // Parse arguments from JSON string + let input = serde_json::from_str(&tool_call.function.arguments) + .unwrap_or(serde_json::json!({})); + + content_blocks.push(ContentBlock::ToolUse { + id: tool_call.id, + name: tool_call.function.name, + input, + }); + } + } + ProviderResponse { id: response.id, r#type: "message".to_string(), role: "assistant".to_string(), - content: vec![ContentBlock::Text { - text, - }], + content: content_blocks, model: response.model, stop_reason: choice.finish_reason, stop_sequence: None, @@ -828,6 +880,184 @@ impl OpenAIProvider { }, } } + + /// Transform OpenAI streaming chunk to Anthropic format + /// Returns raw SSE-formatted bytes (event: type / data: json) + fn transform_openai_chunk_to_anthropic_sse(chunk: &OpenAIStreamChunk, message_id: &str, is_first: &mut bool, has_content_block: &mut bool, stream_ended: &mut bool) -> String { + let mut output = String::new(); + + // First chunk: emit message_start + if *is_first { + *is_first = false; + let message_start = serde_json::json!({ + "type": "message_start", + "message": { + "id": message_id, + "type": "message", + "role": "assistant", + "content": [], + "model": chunk.model, + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 0, + "output_tokens": 0 + } + } + }); + output.push_str(&format!("event: message_start\ndata: {}\n\n", message_start)); + } + + // Process delta content + for choice in &chunk.choices { + // Handle text content (content or reasoning fields) + let text_content = choice.delta.content.as_ref() + .or(choice.delta.reasoning.as_ref()); // Support reasoning field for GLM/Cerebras + + if let Some(text) = text_content { + // Skip empty text + if text.is_empty() { + continue; + } + + // Emit content_block_start if this is the first content + if !*has_content_block { + *has_content_block = true; + let block_start = serde_json::json!({ + "type": "content_block_start", + "index": 0, + "content_block": { + "type": "text", + "text": "" + } + }); + output.push_str(&format!("event: content_block_start\ndata: {}\n\n", block_start)); + } + + // Emit content_block_delta + let delta = serde_json::json!({ + "type": "content_block_delta", + "index": 0, + "delta": { + "type": "text_delta", + "text": text + } + }); + output.push_str(&format!("event: content_block_delta\ndata: {}\n\n", delta)); + } + + // Handle tool calls (OpenAI function calling → Anthropic tool_use) + if let Some(ref tool_calls) = choice.delta.tool_calls { + tracing::debug!("🔧 Transforming {} tool_calls to Anthropic tool_use format", tool_calls.len()); + + // First, close any open text content block + if *has_content_block { + let block_stop = serde_json::json!({ + "type": "content_block_stop", + "index": 0 + }); + output.push_str(&format!("event: content_block_stop\ndata: {}\n\n", block_stop)); + *has_content_block = false; + } + + // Transform each tool call to Anthropic format + for (idx, tool_call) in tool_calls.iter().enumerate() { + // Extract tool info + if let Some(ref function) = tool_call.get("function") { + let default_tool_id = format!("tool_{}", idx); + let tool_id = tool_call.get("id") + .and_then(|v| v.as_str()) + .unwrap_or(&default_tool_id); + let tool_name = function.get("name") + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + let tool_args = function.get("arguments") + .and_then(|v| v.as_str()) + .unwrap_or("{}"); + + tracing::debug!("🔨 Tool: {} (id: {})", tool_name, tool_id); + + // Parse arguments to validate JSON + let tool_input: serde_json::Value = serde_json::from_str(tool_args) + .unwrap_or(serde_json::json!({})); + + // Send tool_use content_block_start + let block_start = serde_json::json!({ + "type": "content_block_start", + "index": idx + 1, // Index after text content + "content_block": { + "type": "tool_use", + "id": tool_id, + "name": tool_name + } + }); + output.push_str(&format!("event: content_block_start\ndata: {}\n\n", block_start)); + + // Send tool input as delta + let input_delta = serde_json::json!({ + "type": "content_block_delta", + "index": idx + 1, + "delta": { + "type": "input_json_delta", + "partial_json": serde_json::to_string(&tool_input).unwrap_or_default() + } + }); + output.push_str(&format!("event: content_block_delta\ndata: {}\n\n", input_delta)); + + // Close tool_use block + let block_stop = serde_json::json!({ + "type": "content_block_stop", + "index": idx + 1 + }); + output.push_str(&format!("event: content_block_stop\ndata: {}\n\n", block_stop)); + } + } + + continue; + } + + // Handle finish_reason (stream end) + if let Some(reason) = &choice.finish_reason { + *stream_ended = true; // Mark that stream ended properly + + if *has_content_block { + // Emit content_block_stop + let block_stop = serde_json::json!({ + "type": "content_block_stop", + "index": 0 + }); + output.push_str(&format!("event: content_block_stop\ndata: {}\n\n", block_stop)); + } + + // Emit message_delta with stop reason + let stop_reason = match reason.as_str() { + "stop" => "end_turn", + "length" => "max_tokens", + "tool_calls" => "end_turn", // Tool calls also end the turn + _ => "end_turn" + }; + let message_delta = serde_json::json!({ + "type": "message_delta", + "delta": { + "stop_reason": stop_reason, + "stop_sequence": null + }, + "usage": { + "output_tokens": 0 + } + }); + output.push_str(&format!("event: message_delta\ndata: {}\n\n", message_delta)); + + // Emit message_stop + let message_stop = serde_json::json!({ + "type": "message_stop" + }); + output.push_str(&format!("event: message_stop\ndata: {}\n\n", message_stop)); + } + } + + output + } } #[async_trait] @@ -1120,11 +1350,77 @@ impl AnthropicProvider for OpenAIProvider { }); } - // TODO: Transform OpenAI SSE format to Anthropic SSE format - // For now, just pass through the stream - let stream = response.bytes_stream().map_err(|e| ProviderError::HttpError(e)); + // Transform OpenAI SSE format to Anthropic SSE format + use futures::stream::StreamExt; + use crate::providers::streaming::SseStream; + use std::sync::{Arc, Mutex}; + + let message_id = format!("msg_{}", uuid::Uuid::new_v4()); + let is_first = Arc::new(Mutex::new(true)); + let has_content_block = Arc::new(Mutex::new(false)); + let stream_ended_properly = Arc::new(Mutex::new(false)); + + // Convert response bytes stream to SSE events + let sse_stream = SseStream::new(response.bytes_stream()); + + // Transform OpenAI SSE events to Anthropic format + let transformed_stream = sse_stream.then(move |result| { + let message_id = message_id.clone(); + let is_first = is_first.clone(); + let has_content_block = has_content_block.clone(); + let stream_ended_properly = stream_ended_properly.clone(); + + async move { + match result { + Ok(sse_event) => { + tracing::debug!("📦 Received SSE chunk: {}", sse_event.data.chars().take(100).collect::()); + + // Skip empty data + if sse_event.data.trim().is_empty() { + tracing::debug!("⏭️ Skipping empty SSE event"); + return Ok(Bytes::new()); + } + + if sse_event.data.trim() == "[DONE]" { + tracing::debug!("✅ Stream finished with [DONE]"); + return Ok(Bytes::new()); + } + + // Parse OpenAI chunk + match serde_json::from_str::(&sse_event.data) { + Ok(chunk) => { + tracing::debug!("✨ Transforming chunk with {} choices", chunk.choices.len()); + + // Transform to Anthropic format (raw SSE bytes) + let sse_output = Self::transform_openai_chunk_to_anthropic_sse( + &chunk, + &message_id, + &mut *is_first.lock().unwrap(), + &mut *has_content_block.lock().unwrap(), + &mut *stream_ended_properly.lock().unwrap() + ); + + tracing::debug!("📤 Sending {} bytes", sse_output.len()); + + // Return as raw bytes (already SSE-formatted) + Ok(Bytes::from(sse_output)) + } + Err(e) => { + tracing::warn!("❌ Failed to parse OpenAI chunk: {} - Data: {}", e, sse_event.data); + Ok(Bytes::new()) + } + } + } + Err(e) => { + tracing::error!("💥 Stream error: {}", e); + Err(ProviderError::HttpError(e)) + } + } + } + }) + .try_filter(|bytes| futures::future::ready(!bytes.is_empty())); - Ok(Box::pin(stream)) + Ok(Box::pin(transformed_stream)) } fn supports_model(&self, model: &str) -> bool {