Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 77 additions & 18 deletions src/proxy_app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@
print(" → Initializing proxy core...")
with _console.status("[dim]Initializing proxy core...", spinner="dots"):
from rotator_library import RotatingClient
from rotator_library.core.errors import StreamBootstrapError
from rotator_library.credential_manager import CredentialManager
from rotator_library.background_refresher import BackgroundRefresher
from rotator_library.model_info_service import init_model_info_service
Expand Down Expand Up @@ -714,6 +715,29 @@ async def verify_anthropic_api_key(
raise HTTPException(status_code=401, detail="Invalid or missing API Key")


async def _prepend_first_stream_chunk(
first_chunk: str,
stream: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
"""Yield a primed first chunk, then continue streaming remaining chunks."""
yield first_chunk
async for chunk in stream:
yield chunk


async def _prepare_streaming_response(
    response_stream: AsyncGenerator[str, None],
) -> tuple[str, AsyncGenerator[str, None]]:
    """
    Prime a streaming generator so bootstrap errors surface before HTTP 200.

    Awaiting the first chunk forces the upstream generator to execute its
    setup code; any exception it raises during bootstrap (e.g. a
    StreamBootstrapError) propagates to the caller here, while a plain
    response can still be returned instead of a half-committed stream.

    Args:
        response_stream: Upstream async generator of SSE chunk strings.

    Returns:
        (first_chunk, stream): the consumed first chunk, and a generator
        that replays it before yielding the remaining chunks.

    Raises:
        StopAsyncIteration: if the upstream stream ends before any chunk
            (the caller maps this to a 502 response).
        Exception: whatever the upstream generator raises while producing
            its first chunk.
    """
    # Never None: an empty stream raises StopAsyncIteration rather than
    # returning a sentinel, so the first tuple element is a plain `str`
    # (the previous `Optional[str]` annotation was misleading).
    first_chunk = await response_stream.__anext__()
    return first_chunk, _prepend_first_stream_chunk(first_chunk, response_stream)


async def streaming_response_wrapper(
request: Request,
request_data: dict,
Expand Down Expand Up @@ -961,29 +985,64 @@ async def chat_completions(
response_generator = await client.acompletion(
request=request, **request_data
)

try:
_, primed_stream = await _prepare_streaming_response(response_generator)
except StreamBootstrapError as e:
status_code = e.status_code if e.status_code else 500
headers = {}
if e.retry_after is not None and e.retry_after > 0:
headers["Retry-After"] = str(int(e.retry_after))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Retry-After header is only set when retry_after > 0. If retry_after is exactly 0, the header will not be set. Consider whether retry_after >= 0 would be more appropriate, or document why 0 is excluded.

Comment on lines +994 to +995
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Avoid truncating Retry-After to a lower value.

At Line [995], int(e.retry_after) floors fractional values (e.g., 0.5 -> 0), which can cause clients to retry too early.

🔧 Proposed fix
+import math
@@
-                if e.retry_after is not None and e.retry_after > 0:
-                    headers["Retry-After"] = str(int(e.retry_after))
+                if e.retry_after is not None and e.retry_after > 0:
+                    headers["Retry-After"] = str(math.ceil(e.retry_after))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if e.retry_after is not None and e.retry_after > 0:
headers["Retry-After"] = str(int(e.retry_after))
import math
if e.retry_after is not None and e.retry_after > 0:
headers["Retry-After"] = str(math.ceil(e.retry_after))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/proxy_app/main.py` around lines 994 - 995, The code sets
headers["Retry-After"] = str(int(e.retry_after)), which floors fractional
retry_after values and can make clients retry too early; change this to round up
instead (e.g., import math and set headers["Retry-After"] =
str(math.ceil(e.retry_after))) so any fractional seconds are preserved by
rounding up; update the exception handling block that accesses e.retry_after and
ensure math.ceil is used rather than int to compute the header value.

logging.warning(
"Streaming bootstrap failed before first chunk "
f"(status={status_code}, retry_after={e.retry_after})"
)
if raw_logger:
raw_logger.log_final_response(
status_code=status_code,
headers=headers or None,
body=e.error_payload,
)
return JSONResponse(
content=e.error_payload,
status_code=status_code,
headers=headers or None,
)
except StopAsyncIteration:
empty_error = {
"error": {
"message": "Upstream stream ended before first chunk",
"type": "proxy_internal_error",
}
}
logging.error("Streaming bootstrap produced no chunks.")
if raw_logger:
raw_logger.log_final_response(
status_code=502,
headers=None,
body=empty_error,
)
return JSONResponse(content=empty_error, status_code=502)

return StreamingResponse(
streaming_response_wrapper(
request, request_data, response_generator, raw_logger
request, request_data, primed_stream, raw_logger
),
media_type="text/event-stream",
)
else:
response = await client.acompletion(request=request, **request_data)
if raw_logger:
# Assuming response has status_code and headers attributes
# This might need adjustment based on the actual response object
response_headers = (
response.headers if hasattr(response, "headers") else None
)
status_code = (
response.status_code if hasattr(response, "status_code") else 200
)
raw_logger.log_final_response(
status_code=status_code,
headers=response_headers,
body=response.model_dump(),
)
return response

response = await client.acompletion(request=request, **request_data)
if raw_logger:
# Assuming response has status_code and headers attributes
# This might need adjustment based on the actual response object
response_headers = response.headers if hasattr(response, "headers") else None
status_code = response.status_code if hasattr(response, "status_code") else 200
raw_logger.log_final_response(
status_code=status_code,
headers=response_headers,
body=response.model_dump(),
)
Comment on lines +1035 to +1044
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don’t let raw logging break a successful completion response.

At Line [1043], response.model_dump() is unconditional. If a provider/plugin returns a non-Pydantic object, this can raise and turn a successful upstream response into a 500 when raw logging is enabled.

🔧 Proposed fix
         if raw_logger:
             # Assuming response has status_code and headers attributes
             # This might need adjustment based on the actual response object
             response_headers = response.headers if hasattr(response, "headers") else None
             status_code = response.status_code if hasattr(response, "status_code") else 200
+            if hasattr(response, "model_dump"):
+                response_body = response.model_dump()
+            elif isinstance(response, dict):
+                response_body = response
+            else:
+                response_body = {"response": str(response)}
             raw_logger.log_final_response(
                 status_code=status_code,
                 headers=response_headers,
-                body=response.model_dump(),
+                body=response_body,
             )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/proxy_app/main.py` around lines 1035 - 1044, The raw logging calls can
raise when response.model_dump() is not present and thus convert a successful
response into a 500; update the block that calls raw_logger.log_final_response
to guard the model_dump call (e.g., check hasattr(response, "model_dump") or use
a try/except) and pass a safe fallback (like the raw response or its str/repr)
for the body so logging never raises; modify the code referencing raw_logger,
response.model_dump(), response.headers and status_code to safely derive body
without throwing and then call raw_logger.log_final_response with the safe body
value.

return response

except (
litellm.InvalidRequestError,
Expand Down
Loading