diff --git a/python/ray/serve/_private/proxy.py b/python/ray/serve/_private/proxy.py index ce41252cd962..6e8f74b12684 100644 --- a/python/ray/serve/_private/proxy.py +++ b/python/ray/serve/_private/proxy.py @@ -442,9 +442,9 @@ async def proxy_request(self, proxy_request: ProxyRequest) -> ResponseGenerator: if response_handler_info.should_increment_ongoing_requests: self._ongoing_requests_start() + # The final message yielded must always be the `ResponseStatus`. + status: Optional[ResponseStatus] = None try: - # The final message yielded must always be the `ResponseStatus`. - status: Optional[ResponseStatus] = None async for message in response_handler_info.response_generator: if isinstance(message, ResponseStatus): status = message @@ -453,36 +453,55 @@ async def proxy_request(self, proxy_request: ProxyRequest) -> ResponseGenerator: assert status is not None and isinstance(status, ResponseStatus) finally: + if status is None: + if proxy_request.request_type == "websocket": + status = ResponseStatus(code="1006", is_error=True) + elif proxy_request.request_type == "http": + status = ResponseStatus(code="499", is_error=True) + # If anything during the request failed, we still want to ensure the ongoing # request counter is decremented. if response_handler_info.should_increment_ongoing_requests: self._ongoing_requests_end() - latency_ms = (time.time() - start_time) * 1000.0 - if response_handler_info.should_record_access_log: - request_context = ray.serve.context._get_serve_request_context() - self._access_log_context[SERVE_LOG_ROUTE] = request_context.route - self._access_log_context[SERVE_LOG_REQUEST_ID] = request_context.request_id - logger.info( - access_log_msg( + if status is not None: + latency_ms = (time.time() - start_time) * 1000.0 + if response_handler_info.should_record_access_log: + request_context = ray.serve.context._get_serve_request_context() + self._access_log_context[SERVE_LOG_ROUTE] = request_context.route + self._access_log_context[SERVE_LOG_REQUEST_ID] = ( + request_context.request_id + ) + logger.info( + access_log_msg( + method=proxy_request.method, + route=request_context.route, + status=str(status.code), + latency_ms=latency_ms, + client=format_client_address(proxy_request.client), + ), + extra=self._access_log_context, + ) + + self._proxy_metrics.record_request( + route=response_handler_info.metadata.route, method=proxy_request.method, - route=request_context.route, - status=str(status.code), + application=response_handler_info.metadata.application_name, + status_code=str(status.code), latency_ms=latency_ms, - client=format_client_address(proxy_request.client), - ), - extra=self._access_log_context, - ) + is_error=status.is_error, + deployment_name=response_handler_info.metadata.deployment_name, + ) - self._proxy_metrics.record_request( - route=response_handler_info.metadata.route, - method=proxy_request.method, - application=response_handler_info.metadata.application_name, - status_code=str(status.code), - latency_ms=latency_ms, - is_error=status.is_error, - deployment_name=response_handler_info.metadata.deployment_name, - ) + if status is None: + logger.warning( + "Proxy request ended before a response status was available.", + extra={ + "route": response_handler_info.metadata.route, + "method": proxy_request.method, + "request_type": proxy_request.request_type, + }, + ) @abstractmethod def setup_request_context_and_handle( @@ -663,6 +682,7 @@ async def stream_unary( is received. It wraps the request iterator and calls proxy_request. The return value is serialized user defined protobuf bytes. """ + # Create async iterator wrapper for the request stream async def async_request_iterator(): async for request in request_iterator: @@ -701,6 +721,7 @@ async def stream_stream( request is received. It wraps the request iterator and calls proxy_request. The return value is a generator of serialized user defined protobuf bytes. """ + # Create async iterator wrapper for the request stream async def async_request_iterator(): async for request in request_iterator: @@ -862,9 +883,11 @@ def _finalize_grpc_tracing( set_rpc_span_attributes( system=proxy_request.request_type, method=proxy_request.method, - status_code=status.code.name - if isinstance(status.code, grpc.StatusCode) - else grpc.StatusCode.UNKNOWN.name, + status_code=( + status.code.name + if isinstance(status.code, grpc.StatusCode) + else grpc.StatusCode.UNKNOWN.name + ), ) self._finalize_proxy_tracing(status=status, exc=exc) @@ -1345,6 +1368,13 @@ async def send_request_to_replica( yield asgi_message response_started = True + except GeneratorExit: + if status is None: + if proxy_request.request_type == "websocket": + status = ResponseStatus(code="1006", is_error=True) + else: + status = ResponseStatus(code="499", is_error=True) + raise except BaseException as e: error_status = get_http_response_status(e, request_timeout_s, request_id) if status is None: diff --git a/python/ray/serve/tests/unit/test_proxy.py b/python/ray/serve/tests/unit/test_proxy.py index 692d2cfd3901..f2b1f3c89d9f 100644 --- a/python/ray/serve/tests/unit/test_proxy.py +++ b/python/ray/serve/tests/unit/test_proxy.py @@ -2,7 +2,7 @@ import pickle import sys from typing import Dict, List, Tuple -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import grpc import pytest @@ -757,6 +757,90 @@ async def test_websocket_call(self, disconnect: str): # Ensure after calling __call__, send.messages should be expected messages. assert send.messages == expected_messages + @pytest.mark.asyncio + async def test_websocket_client_disconnect_records_proxy_metrics(self): + """A dropped ASGI sender should still count the WebSocket request.""" + expected_messages = [ + {"type": "websocket.accept"}, + {"type": "websocket.send"}, + ] + + http_proxy = self.create_http_proxy() + http_proxy.proxy_router.route = "/ws" + http_proxy.proxy_router.handle = FakeHTTPHandle(messages=expected_messages) + http_proxy.proxy_router.app_is_cross_language = False + http_proxy._proxy_metrics.record_request = Mock() + + proxy_request = ASGIProxyRequest( + scope={ + "type": "websocket", + "path": "/ws", + "root_path": "", + "headers": [(b"x-request-id", b"fake_request_id")], + "client": ("127.0.0.1", 12345), + }, + receive=FakeHttpReceive(), + send=FakeHttpSend(), + ) + + response_generator = http_proxy.proxy_request(proxy_request) + assert await response_generator.__anext__() == {"type": "websocket.accept"} + await response_generator.aclose() + + http_proxy._proxy_metrics.record_request.assert_called_once() + call_kwargs = http_proxy._proxy_metrics.record_request.call_args.kwargs + assert call_kwargs["route"] == "/ws" + assert call_kwargs["method"] == "WS" + assert call_kwargs["application"] == "fake_app_name" + assert call_kwargs["deployment_name"] == "fake_deployment_name" + assert call_kwargs["status_code"] == "1006" + assert call_kwargs["is_error"] is True + + @pytest.mark.asyncio + async def test_http_client_disconnect_after_response_start_records_proxy_metrics( + self, + ): + """A dropped HTTP sender before response completion should still be counted.""" + http_proxy = self.create_http_proxy() + http_proxy.proxy_router.route = "/stream" + http_proxy.proxy_router.handle = FakeHTTPHandle( + messages=[ + {"type": "http.response.start", "status": "200"}, + {"type": "http.response.body"}, + ] + ) + http_proxy.proxy_router.app_is_cross_language = False + http_proxy._proxy_metrics.record_request = Mock() + + proxy_request = ASGIProxyRequest( + scope={ + "type": "http", + "method": "GET", + "path": "/stream", + "root_path": "", + "headers": [(b"x-request-id", b"fake_request_id")], + "client": ("127.0.0.1", 12345), + }, + receive=FakeHttpReceive(), + send=FakeHttpSend(), + ) + + response_generator = http_proxy.proxy_request(proxy_request) + assert await response_generator.__anext__() == { + "type": "http.response.start", + "status": "200", + } + await response_generator.aclose() + + http_proxy._proxy_metrics.record_request.assert_called_once() + call_kwargs = http_proxy._proxy_metrics.record_request.call_args.kwargs + assert call_kwargs["route"] == "/stream" + assert call_kwargs["method"] == "GET" + assert call_kwargs["application"] == "fake_app_name" + assert call_kwargs["deployment_name"] == "fake_deployment_name" + assert call_kwargs["status_code"] == "499" + assert call_kwargs["is_error"] is True + @pytest.mark.parametrize( "header_key", [