Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 57 additions & 27 deletions python/ray/serve/_private/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,9 @@ async def proxy_request(self, proxy_request: ProxyRequest) -> ResponseGenerator:
if response_handler_info.should_increment_ongoing_requests:
self._ongoing_requests_start()

# The final message yielded must always be the `ResponseStatus`.
status: Optional[ResponseStatus] = None
try:
# The final message yielded must always be the `ResponseStatus`.
status: Optional[ResponseStatus] = None
async for message in response_handler_info.response_generator:
if isinstance(message, ResponseStatus):
status = message
Expand All @@ -453,36 +453,55 @@ async def proxy_request(self, proxy_request: ProxyRequest) -> ResponseGenerator:

assert status is not None and isinstance(status, ResponseStatus)
finally:
if status is None:
if proxy_request.request_type == "websocket":
status = ResponseStatus(code="1006", is_error=True)
elif proxy_request.request_type == "http":
status = ResponseStatus(code="499", is_error=True)

# If anything during the request failed, we still want to ensure the ongoing
# request counter is decremented.
if response_handler_info.should_increment_ongoing_requests:
self._ongoing_requests_end()

latency_ms = (time.time() - start_time) * 1000.0
if response_handler_info.should_record_access_log:
request_context = ray.serve.context._get_serve_request_context()
self._access_log_context[SERVE_LOG_ROUTE] = request_context.route
self._access_log_context[SERVE_LOG_REQUEST_ID] = request_context.request_id
logger.info(
access_log_msg(
if status is not None:
latency_ms = (time.time() - start_time) * 1000.0
if response_handler_info.should_record_access_log:
request_context = ray.serve.context._get_serve_request_context()
self._access_log_context[SERVE_LOG_ROUTE] = request_context.route
self._access_log_context[SERVE_LOG_REQUEST_ID] = (
request_context.request_id
)
logger.info(
access_log_msg(
method=proxy_request.method,
route=request_context.route,
status=str(status.code),
latency_ms=latency_ms,
client=format_client_address(proxy_request.client),
),
extra=self._access_log_context,
)

self._proxy_metrics.record_request(
route=response_handler_info.metadata.route,
method=proxy_request.method,
route=request_context.route,
status=str(status.code),
application=response_handler_info.metadata.application_name,
status_code=str(status.code),
latency_ms=latency_ms,
client=format_client_address(proxy_request.client),
),
extra=self._access_log_context,
)
is_error=status.is_error,
deployment_name=response_handler_info.metadata.deployment_name,
)

self._proxy_metrics.record_request(
route=response_handler_info.metadata.route,
method=proxy_request.method,
application=response_handler_info.metadata.application_name,
status_code=str(status.code),
latency_ms=latency_ms,
is_error=status.is_error,
deployment_name=response_handler_info.metadata.deployment_name,
)
if status is None:
logger.warning(
"Proxy request ended before a response status was available.",
extra={
"route": response_handler_info.metadata.route,
"method": proxy_request.method,
"request_type": proxy_request.request_type,
},
)

@abstractmethod
def setup_request_context_and_handle(
Expand Down Expand Up @@ -663,6 +682,7 @@ async def stream_unary(
is received. It wraps the request iterator and calls proxy_request.
The return value is serialized user defined protobuf bytes.
"""

# Create async iterator wrapper for the request stream
async def async_request_iterator():
async for request in request_iterator:
Expand Down Expand Up @@ -701,6 +721,7 @@ async def stream_stream(
request is received. It wraps the request iterator and calls proxy_request.
The return value is a generator of serialized user defined protobuf bytes.
"""

# Create async iterator wrapper for the request stream
async def async_request_iterator():
async for request in request_iterator:
Expand Down Expand Up @@ -862,9 +883,11 @@ def _finalize_grpc_tracing(
set_rpc_span_attributes(
system=proxy_request.request_type,
method=proxy_request.method,
status_code=status.code.name
if isinstance(status.code, grpc.StatusCode)
else grpc.StatusCode.UNKNOWN.name,
status_code=(
status.code.name
if isinstance(status.code, grpc.StatusCode)
else grpc.StatusCode.UNKNOWN.name
),
)
self._finalize_proxy_tracing(status=status, exc=exc)

Expand Down Expand Up @@ -1345,6 +1368,13 @@ async def send_request_to_replica(

yield asgi_message
response_started = True
except GeneratorExit:
if status is None:
if proxy_request.request_type == "websocket":
status = ResponseStatus(code="1006", is_error=True)
else:
status = ResponseStatus(code="499", is_error=True)
raise
Comment on lines +1371 to +1377

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Adding `except GeneratorExit: raise` prevents `GeneratorExit` from being caught as an unexpected 500 error, but `status` then remains `None` if the generator is closed before the response starts (e.g., when an HTTP client disconnects early).

In the `finally` block of `send_request_to_replica`, if `status` is `None` and the request type is not `"websocket"`, execution falls through to the `else` block:

else:
    status_code = status.code

This will raise an unhandled `AttributeError: 'NoneType' object has no attribute 'code'`, masking the original `GeneratorExit` and potentially causing resource leaks or broken cleanup.

To fix this, we should defensively set `status` to an appropriate status code (such as `"499"` for a client-closed request) when catching `GeneratorExit`, before re-raising it.

Suggested change
except GeneratorExit:
raise
except GeneratorExit:
if status is None:
status = ResponseStatus(code="499", is_error=True)
raise

Comment thread
cursor[bot] marked this conversation as resolved.
except BaseException as e:
error_status = get_http_response_status(e, request_timeout_s, request_id)
if status is None:
Expand Down
86 changes: 85 additions & 1 deletion python/ray/serve/tests/unit/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pickle
import sys
from typing import Dict, List, Tuple
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, Mock

import grpc
import pytest
Expand Down Expand Up @@ -757,6 +757,90 @@ async def test_websocket_call(self, disconnect: str):
# Ensure after calling __call__, send.messages should be expected messages.
assert send.messages == expected_messages

@pytest.mark.asyncio
async def test_websocket_client_disconnect_records_proxy_metrics(self):
    """Closing a WebSocket response stream early still records proxy metrics.

    The response generator is closed right after the accept message has been
    consumed. `record_request` must then have been called exactly once, as an
    error with status code "1006".
    """
    replica_messages = [
        {"type": "websocket.accept"},
        {"type": "websocket.send"},
    ]

    http_proxy = self.create_http_proxy()
    router = http_proxy.proxy_router
    router.route = "/ws"
    router.handle = FakeHTTPHandle(messages=replica_messages)
    router.app_is_cross_language = False

    # Swap the metrics recorder for a mock so the call can be inspected.
    recorder = Mock()
    http_proxy._proxy_metrics.record_request = recorder

    scope = {
        "type": "websocket",
        "path": "/ws",
        "root_path": "",
        "headers": [(b"x-request-id", b"fake_request_id")],
        "client": ("127.0.0.1", 12345),
    }
    proxy_request = ASGIProxyRequest(
        scope=scope,
        receive=FakeHttpReceive(),
        send=FakeHttpSend(),
    )

    # Consume only the first message, then close the generator before the
    # remaining messages are read.
    response_generator = http_proxy.proxy_request(proxy_request)
    first_message = await response_generator.__anext__()
    assert first_message == {"type": "websocket.accept"}
    await response_generator.aclose()

    recorder.assert_called_once()
    recorded = recorder.call_args.kwargs
    expected_fields = {
        "route": "/ws",
        "method": "WS",
        "application": "fake_app_name",
        "deployment_name": "fake_deployment_name",
        "status_code": "1006",
    }
    for field, expected_value in expected_fields.items():
        assert recorded[field] == expected_value
    assert recorded["is_error"] is True

@pytest.mark.asyncio
async def test_http_client_disconnect_after_response_start_records_proxy_metrics(
    self,
):
    """Closing an HTTP response stream mid-response still records proxy metrics.

    The response generator is closed after `http.response.start` has been
    consumed but before the body. `record_request` must then have been called
    exactly once, as an error with status code "499".
    """
    replica_messages = [
        {"type": "http.response.start", "status": "200"},
        {"type": "http.response.body"},
    ]

    http_proxy = self.create_http_proxy()
    router = http_proxy.proxy_router
    router.route = "/stream"
    router.handle = FakeHTTPHandle(messages=replica_messages)
    router.app_is_cross_language = False

    # Swap the metrics recorder for a mock so the call can be inspected.
    recorder = Mock()
    http_proxy._proxy_metrics.record_request = recorder

    scope = {
        "type": "http",
        "method": "GET",
        "path": "/stream",
        "root_path": "",
        "headers": [(b"x-request-id", b"fake_request_id")],
        "client": ("127.0.0.1", 12345),
    }
    proxy_request = ASGIProxyRequest(
        scope=scope,
        receive=FakeHttpReceive(),
        send=FakeHttpSend(),
    )

    # Consume only the response-start message, then close the generator
    # before the body is read.
    response_generator = http_proxy.proxy_request(proxy_request)
    first_message = await response_generator.__anext__()
    assert first_message == {
        "type": "http.response.start",
        "status": "200",
    }
    await response_generator.aclose()

    recorder.assert_called_once()
    recorded = recorder.call_args.kwargs
    expected_fields = {
        "route": "/stream",
        "method": "GET",
        "application": "fake_app_name",
        "deployment_name": "fake_deployment_name",
        "status_code": "499",
    }
    for field, expected_value in expected_fields.items():
        assert recorded[field] == expected_value
    assert recorded["is_error"] is True

@pytest.mark.parametrize(
"header_key",
[
Expand Down
Loading