Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 141 additions & 52 deletions src/openhuman/inference/provider/reliable.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::traits::{
ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamOptions, StreamResult,
ChatMessage, ChatRequest, ChatResponse, StreamChunk, StreamError, StreamOptions, StreamResult,
};
use super::Provider;
use async_trait::async_trait;
Expand Down Expand Up @@ -59,6 +59,32 @@ fn is_non_retryable(err: &anyhow::Error) -> bool {
|| msg_lower.contains("invalid"))
}

/// Decide whether a streaming failure is permanent, i.e. whether retrying the
/// same request is pointless.
///
/// Classifies the typed [`StreamError`] directly instead of a rendered
/// message, so for `Http` variants the status carried by the inner
/// `reqwest::Error` is inspected as-is.
///
/// Returns `true` (non-retryable) for:
/// - HTTP client errors (4xx) other than 408 (request timeout) and
///   429 (rate limit);
/// - provider messages mentioning "invalid api key", "unauthorized" or
///   "forbidden", or mentioning "model" together with "not found" or
///   "unsupported" (all matched case-insensitively);
/// - JSON and SSE parse failures.
///
/// Returns `false` (retryable) for HTTP errors that carry no status, every
/// other HTTP status, any other provider message, and I/O errors.
fn is_stream_error_non_retryable(err: &StreamError) -> bool {
    match err {
        StreamError::Http(reqwest_err) => match reqwest_err.status() {
            Some(status) => {
                let code = status.as_u16();
                // Client errors except 429 (rate limit) and 408 (timeout) are non-retryable.
                status.is_client_error() && code != 429 && code != 408
            }
            // No status available on the error: treat it as retryable.
            None => false,
        },
        StreamError::Provider(msg) => {
            let lower = msg.to_lowercase();
            let auth_problem = lower.contains("invalid api key")
                || lower.contains("unauthorized")
                || lower.contains("forbidden");
            // Grouping made explicit: "model" must co-occur with one of the two hints.
            let model_problem = lower.contains("model")
                && (lower.contains("not found") || lower.contains("unsupported"));
            auth_problem || model_problem
        }
        // JSON/SSE parse errors are treated as non-retryable.
        StreamError::Json(_) | StreamError::InvalidSse(_) => true,
        // I/O errors are treated as retryable.
        StreamError::Io(_) => false,
    }
}

fn is_context_window_exceeded(err: &anyhow::Error) -> bool {
let lower = err.to_string().to_lowercase();
let hints = [
Expand Down Expand Up @@ -924,63 +950,126 @@ impl Provider for ReliableProvider {
temperature: f64,
options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
// Try each provider/model combination for streaming
// For streaming, we use the first provider that supports it and has streaming enabled
for (provider_name, provider) in &self.providers {
if !provider.supports_streaming() || !options.enabled {
continue;
if !options.enabled {
return stream::once(async move {
Err(super::traits::StreamError::Provider(
"Streaming disabled".to_string(),
))
})
.boxed();
}

// Collect streaming-capable providers
let streaming_providers: Vec<_> = self
.providers
.iter()
.filter(|(_, p)| p.supports_streaming())
.collect();

if streaming_providers.is_empty() {
return stream::once(async move {
Err(super::traits::StreamError::Provider(
"No provider supports streaming".to_string(),
))
})
.boxed();
}

// Build model chain and provider info for the spawned task
let models = self.model_chain(model);
let model_chain: Vec<String> = models.into_iter().map(|m| m.to_string()).collect();
let base_backoff_ms = self.base_backoff_ms;

// Collect provider streams lazily inside the task — we need owned data
// Provider trait is object-safe, so we call stream_chat_with_system per attempt
// We need to pre-create all possible streams since Provider is behind &self
// Instead, collect the streams for each provider+model combo upfront
let mut candidate_streams: Vec<(
String,
String,
stream::BoxStream<'static, StreamResult<StreamChunk>>,
)> = Vec::new();
for current_model in &model_chain {
for (provider_name, provider) in &streaming_providers {
let s = provider.stream_chat_with_system(
system_prompt,
message,
current_model,
temperature,
options,
);
candidate_streams.push(((*provider_name).clone(), current_model.clone(), s));
}
}

// Clone provider data for the stream
let provider_clone = provider_name.clone();

// Try the first model in the chain for streaming
let current_model = match self.model_chain(model).first() {
Some(m) => m.to_string(),
None => model.to_string(),
};

// For streaming, we attempt once and propagate errors
// The caller can retry the entire request if needed
let stream = provider.stream_chat_with_system(
system_prompt,
message,
&current_model,
temperature,
options,
);

// Use a channel to bridge the stream with logging
let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);

tokio::spawn(async move {
let mut stream = stream;
while let Some(chunk) = stream.next().await {
if let Err(ref e) = chunk {
tracing::warn!(
provider = provider_clone,
model = current_model,
"Streaming error: {e}"
);
}
if tx.send(chunk).await.is_err() {
break; // Receiver dropped
let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
let max_retries = self.max_retries;

tokio::spawn(async move {
for (provider_name, current_model, mut candidate_stream) in candidate_streams {
let mut backoff_ms = base_backoff_ms;
let mut attempts = 0u32;

loop {
match candidate_stream.next().await {
Some(Ok(chunk)) => {
// First chunk succeeded — commit to this stream
if tx.send(Ok(chunk)).await.is_err() {
return;
}
// Forward remaining chunks
while let Some(chunk) = candidate_stream.next().await {
if tx.send(chunk).await.is_err() {
return;
}
}
return; // Done successfully
}
Some(Err(ref e)) => {
let non_retryable = is_stream_error_non_retryable(e);

tracing::warn!(
provider = provider_name,
model = current_model,
attempt = attempts + 1,
error = %e,
"Streaming failed{}", if non_retryable { " (non-retryable)" } else { "" }
);

if non_retryable || attempts >= max_retries {
break; // Move to next candidate
}

attempts += 1;
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
// Continue inner loop — stream may yield more items
}
None => {
// Stream exhausted without success
if attempts == 0 {
tracing::warn!(
provider = provider_name,
model = current_model,
"Stream returned empty"
);
}
break; // Move to next candidate
}
}
}
});
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// Convert channel receiver to stream
return stream::unfold(rx, |mut rx| async move {
rx.recv().await.map(|chunk| (chunk, rx))
})
.boxed();
}
// All providers/models exhausted
let _ = tx
.send(Err(super::traits::StreamError::Provider(
"All streaming providers/models failed".to_string(),
)))
.await;
});

// No streaming support available
stream::once(async move {
Err(super::traits::StreamError::Provider(
"No provider supports streaming".to_string(),
))
stream::unfold(rx, |mut rx| async move {
rx.recv().await.map(|chunk| (chunk, rx))
})
.boxed()
}
Expand Down
Loading