async def _prepend_first_stream_chunk(
    first_chunk: str,
    stream: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
    """Replay an already-consumed first chunk, then relay the rest of ``stream``.

    Args:
        first_chunk: The chunk that was pulled from ``stream`` while priming.
        stream: The partially consumed upstream generator.

    Yields:
        ``first_chunk`` followed by every remaining chunk of ``stream``.

    Once this wrapper has started, ``stream`` is closed when the wrapper
    finishes, fails, or is closed early by its consumer.
    """
    try:
        yield first_chunk
        async for chunk in stream:
            yield chunk
    finally:
        # ``async for`` never closes the iterator it drives. Without this, an
        # early ``aclose()`` of the wrapper would leave the upstream generator
        # suspended until garbage collection, delaying its ``finally`` /
        # context-manager cleanup. Closing an already-finished generator is a
        # no-op, so this is safe on the normal-completion path too.
        aclose = getattr(stream, "aclose", None)
        if aclose is not None:
            await aclose()


async def _prepare_streaming_response(
    response_stream: AsyncGenerator[str, None],
) -> tuple[Optional[str], AsyncGenerator[str, None]]:
    """Pull the first chunk so pre-stream failures surface before HTTP 200.

    Awaiting the first chunk here means any exception the generator raises
    before its first ``yield`` propagates to the caller, which can still
    choose a non-200 response because no body has been sent yet.

    Args:
        response_stream: The not-yet-started upstream generator.

    Returns:
        ``(first_chunk, stream_with_first_chunk_replayed)``. The second item
        yields ``first_chunk`` again followed by the remaining chunks.

    Raises:
        StopAsyncIteration: If ``response_stream`` ends without yielding.
    """
    first_chunk = await response_stream.__anext__()
    return first_chunk, _prepend_first_stream_chunk(first_chunk, response_stream)
response.status_code if hasattr(response, "status_code") else 200 - ) - raw_logger.log_final_response( - status_code=status_code, - headers=response_headers, - body=response.model_dump(), - ) - return response + + response = await client.acompletion(request=request, **request_data) + if raw_logger: + # Assuming response has status_code and headers attributes + # This might need adjustment based on the actual response object + response_headers = response.headers if hasattr(response, "headers") else None + status_code = response.status_code if hasattr(response, "status_code") else 200 + raw_logger.log_final_response( + status_code=status_code, + headers=response_headers, + body=response.model_dump(), + ) + return response except ( litellm.InvalidRequestError, diff --git a/src/rotator_library/client/executor.py b/src/rotator_library/client/executor.py index 8f50bacb..6b7579e2 100644 --- a/src/rotator_library/client/executor.py +++ b/src/rotator_library/client/executor.py @@ -44,6 +44,7 @@ NoAvailableKeysError, PreRequestCallbackError, StreamedAPIError, + StreamBootstrapError, ClassifiedError, RequestErrorAccumulator, classify_error, @@ -668,18 +669,18 @@ async def _execute_streaming( """ Execute streaming request with retry/rotation. - This is an async generator that yields SSE-formatted strings. - - Args: - context: RequestContext with all request details - - Yields: - SSE-formatted strings + Retries/rotation are only allowed before the first chunk is emitted. + After first-byte commit, failures are terminated in-band as SSE error + [DONE] + to avoid duplicated/corrupted client output. 
""" provider = context.provider model = context.model deadline = context.deadline + stream_committed = False + prestream_status_codes: List[int] = [] + prestream_retry_after_values: List[int] = [] + try: ( usage_manager, @@ -695,9 +696,11 @@ async def _execute_streaming( "type": "proxy_error", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - return + raise StreamBootstrapError( + message=str(exc), + status_code=503, + error_payload=error_data, + ) from exc error_accumulator = RequestErrorAccumulator() error_accumulator.model = model @@ -820,9 +823,21 @@ async def _execute_streaming( context.transaction_logger, context.kwargs, ): + if not stream_committed: + stream_committed = True + lib_logger.info( + "Stream committed after first chunk; " + "disabling further retry/rotation for this request" + ) yield chunk else: async for chunk in base_stream: + if not stream_committed: + stream_committed = True + lib_logger.info( + "Stream committed after first chunk; " + "disabling further retry/rotation for this request" + ) yield chunk return @@ -840,6 +855,22 @@ async def _execute_streaming( error_accumulator.record_error( cred, classified, str(original)[:150] ) + self._record_prestream_failure( + stream_committed, + classified, + prestream_status_codes, + prestream_retry_after_values, + ) + + if stream_committed: + lib_logger.warning( + "Streaming error after first chunk; ending stream without retry/rotation" + ) + yield self._build_stream_error_chunk( + classified, str(original) + ) + yield "data: [DONE]\n\n" + return # Track consecutive quota failures if classified.error_type == "quota_exceeded": @@ -856,9 +887,12 @@ async def _execute_streaming( "type": "quota_exhausted", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - return + raise self._build_stream_bootstrap_error( + error_data, + prestream_status_codes, + prestream_retry_after_values, + original, + ) else: retry_state.reset_quota_failures() @@ -867,7 
+901,7 @@ async def _execute_streaming( raise cred_context.mark_failure(classified) - break # Rotate + break except (RateLimitError, httpx.HTTPStatusError) as e: last_exception = e @@ -882,6 +916,22 @@ async def _execute_streaming( error_accumulator.record_error( cred, classified, str(e)[:150] ) + self._record_prestream_failure( + stream_committed, + classified, + prestream_status_codes, + prestream_retry_after_values, + ) + + if stream_committed: + lib_logger.warning( + "Rate-limit/server status error after first chunk; ending stream without retry/rotation" + ) + yield self._build_stream_error_chunk( + classified, str(e) + ) + yield "data: [DONE]\n\n" + return # Track consecutive quota failures if classified.error_type == "quota_exceeded": @@ -898,9 +948,12 @@ async def _execute_streaming( "type": "quota_exhausted", } } - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - return + raise self._build_stream_bootstrap_error( + error_data, + prestream_status_codes, + prestream_retry_after_values, + e, + ) else: retry_state.reset_quota_failures() @@ -917,9 +970,7 @@ async def _execute_streaming( ) if ( classified.retry_after is not None - and 0 - < classified.retry_after - < small_cooldown_threshold + and 0 < classified.retry_after < small_cooldown_threshold and attempt < self._max_retries - 1 ): remaining = deadline - time.time() @@ -929,10 +980,10 @@ async def _execute_streaming( f"(small cooldown {classified.retry_after}s < {small_cooldown_threshold}s threshold)" ) await asyncio.sleep(classified.retry_after) - continue # Retry same key + continue cred_context.mark_failure(classified) - break # Rotate + break except ( APIConnectionError, @@ -948,13 +999,29 @@ async def _execute_streaming( error=e, request_headers=request_headers, ) + self._record_prestream_failure( + stream_committed, + classified, + prestream_status_codes, + prestream_retry_after_values, + ) + + if stream_committed: + lib_logger.warning( + "Upstream connection/server error 
after first chunk; ending stream without retry/rotation" + ) + yield self._build_stream_error_chunk( + classified, str(e) + ) + yield "data: [DONE]\n\n" + return if attempt >= self._max_retries - 1: error_accumulator.record_error( cred, classified, str(e)[:150] ) cred_context.mark_failure(classified) - break # Rotate + break # Calculate wait time wait_time = classified.retry_after or ( @@ -962,10 +1029,10 @@ async def _execute_streaming( ) + random.uniform(0, 1) remaining = deadline - time.time() if wait_time > remaining: - break # No time to wait + break await asyncio.sleep(wait_time) - continue # Retry + continue except Exception as e: last_exception = e @@ -980,18 +1047,34 @@ async def _execute_streaming( error_accumulator.record_error( cred, classified, str(e)[:150] ) + self._record_prestream_failure( + stream_committed, + classified, + prestream_status_codes, + prestream_retry_after_values, + ) + + if stream_committed: + lib_logger.warning( + "Unhandled stream exception after first chunk; ending stream without retry/rotation" + ) + yield self._build_stream_error_chunk( + classified, str(e) + ) + yield "data: [DONE]\n\n" + return if not should_rotate_on_error(classified): cred_context.mark_failure(classified) raise cred_context.mark_failure(classified) - break # Rotate + break + # Let context manager handle cleanup except PreRequestCallbackError: raise except Exception: - # Let context manager handle cleanup pass except NoAvailableKeysError: @@ -999,21 +1082,134 @@ async def _execute_streaming( # All credentials exhausted or timeout error_accumulator.timeout_occurred = time.time() >= deadline + + if stream_committed: + error_data = error_accumulator.build_client_error_response() + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + + if last_exception and not error_accumulator.has_errors(): + raise last_exception + error_data = error_accumulator.build_client_error_response() - yield f"data: {json.dumps(error_data)}\n\n" - yield 
"data: [DONE]\n\n" + raise self._build_stream_bootstrap_error( + error_data, + prestream_status_codes, + prestream_retry_after_values, + last_exception, + ) except NoAvailableKeysError as e: lib_logger.error(f"No keys available: {e}") error_data = {"error": {"message": str(e), "type": "proxy_busy"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + if stream_committed: + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + raise StreamBootstrapError( + message=str(e), + status_code=503, + error_payload=error_data, + ) from e + + except StreamBootstrapError: + raise except Exception as e: lib_logger.error(f"Unhandled exception in streaming: {e}", exc_info=True) - error_data = {"error": {"message": str(e), "type": "proxy_internal_error"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" + error_data = { + "error": { + "message": str(e), + "type": "proxy_internal_error", + } + } + if stream_committed: + yield f"data: {json.dumps(error_data)}\n\n" + yield "data: [DONE]\n\n" + return + raise StreamBootstrapError( + message=str(e), + status_code=500, + error_payload=error_data, + ) from e + + def _record_prestream_failure( + self, + stream_committed: bool, + classified: ClassifiedError, + prestream_status_codes: List[int], + prestream_retry_after_values: List[int], + ) -> None: + """Track pre-stream failure metadata for final HTTP status selection.""" + if stream_committed: + return + + if classified.status_code is not None: + prestream_status_codes.append(classified.status_code) + if classified.retry_after is not None and classified.retry_after > 0: + prestream_retry_after_values.append(classified.retry_after) + + def _build_stream_error_chunk( + self, + classified: ClassifiedError, + fallback_message: str, + ) -> str: + """Build a terminal SSE error chunk for post-commit stream failures.""" + error_payload: Dict[str, Any] = { + "error": { + "message": fallback_message, + "type": 
"stream_error", + "error_type": classified.error_type, + "code": classified.status_code, + } + } + if classified.retry_after is not None and classified.retry_after > 0: + error_payload["error"]["retry_after"] = classified.retry_after + return f"data: {json.dumps(error_payload)}\n\n" + + def _build_stream_bootstrap_error( + self, + error_payload: Dict[str, Any], + prestream_status_codes: List[int], + prestream_retry_after_values: List[int], + last_exception: Optional[Exception], + ) -> StreamBootstrapError: + """Build a structured bootstrap error for HTTP-layer mapping.""" + status_code = 500 + + if prestream_status_codes: + if all(code == 429 for code in prestream_status_codes): + status_code = 429 + else: + status_code = prestream_status_codes[-1] + elif last_exception is not None: + classified_last = classify_error(last_exception) + if classified_last.status_code is not None: + status_code = classified_last.status_code + + retry_after = ( + min(prestream_retry_after_values) + if prestream_retry_after_values + else None + ) + + message = ( + error_payload.get("error", {}).get("message") + or "Streaming request failed before first chunk" + ) + + if status_code == 429: + lib_logger.warning( + "Terminal pre-stream 429 encountered; returning HTTP 429 instead of SSE error stream" + ) + + return StreamBootstrapError( + message=message, + status_code=status_code, + error_payload=error_payload, + retry_after=retry_after, + ) def _apply_litellm_provider_params( self, provider: str, kwargs: Dict[str, Any] diff --git a/src/rotator_library/core/errors.py b/src/rotator_library/core/errors.py index 5acd9fc7..cd52d9b2 100644 --- a/src/rotator_library/core/errors.py +++ b/src/rotator_library/core/errors.py @@ -12,6 +12,8 @@ compatibility. This module provides a cleaner import path. 
""" +from typing import Any, Dict, Optional + # Re-export everything from error_handler from ..error_handler import ( # Exception classes @@ -62,6 +64,32 @@ def __init__(self, message: str, data=None): self.data = data +class StreamBootstrapError(Exception): + """ + Raised when a streaming request fails before the first chunk is emitted. + + This lets the HTTP layer return a proper non-200 response (e.g. HTTP 429) + instead of committing an SSE stream and sending an in-band error payload. + """ + + def __init__( + self, + message: str, + status_code: int = 500, + error_payload: Optional[Dict[str, Any]] = None, + retry_after: Optional[int] = None, + ): + super().__init__(message) + self.status_code = status_code + self.error_payload = error_payload or { + "error": { + "message": message, + "type": "proxy_error", + } + } + self.retry_after = retry_after + + __all__ = [ # Exception classes "NoAvailableKeysError", @@ -70,6 +98,7 @@ def __init__(self, message: str, data=None): "EmptyResponseError", "TransientQuotaError", "StreamedAPIError", + "StreamBootstrapError", # Error classification "ClassifiedError", "RequestErrorAccumulator",