Skip to content

Commit 4a7bcc5

Browse files
Varun SharmaCopilot
andcommitted
fix: handle HTTP 4xx/5xx gracefully in StreamableHTTP transport
Replace raise_for_status() calls with graceful status code checks in GET SSE paths, and include HTTP status code in POST error messages for better client-side error handling. Changes: - POST error message now includes HTTP status code (e.g. 'Server returned HTTP 401') instead of generic 'Server returned an error response' - GET SSE listener: check status code and retry instead of crashing - Resumption GET: return JSONRPC error with status code to client - Reconnection GET: return JSONRPC error with status code to client Fixes #1295 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 62575ed commit 4a7bcc5

3 files changed

Lines changed: 133 additions & 5 deletions

File tree

src/mcp/client/streamable_http.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,10 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:
194194
headers[LAST_EVENT_ID] = last_event_id
195195

196196
async with aconnect_sse(client, "GET", self.url, headers=headers) as event_source:
197-
event_source.response.raise_for_status()
197+
if event_source.response.status_code >= 400:
198+
logger.warning(f"GET SSE returned HTTP {event_source.response.status_code}")
199+
attempt += 1
200+
continue
198201
logger.debug("GET SSE connection established")
199202

200203
async for sse in event_source.aiter_sse():
@@ -237,7 +240,16 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
237240
original_request_id = ctx.session_message.message.id
238241

239242
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
240-
event_source.response.raise_for_status()
243+
if event_source.response.status_code >= 400:
244+
logger.warning(f"Resumption GET returned HTTP {event_source.response.status_code}")
245+
if original_request_id is not None:
246+
error_data = ErrorData(
247+
code=INTERNAL_ERROR,
248+
message=f"Server returned HTTP {event_source.response.status_code}",
249+
)
250+
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
251+
await ctx.read_stream_writer.send(error_msg)
252+
return
241253
logger.debug("Resumption GET SSE connection established")
242254

243255
async for sse in event_source.aiter_sse(): # pragma: no branch
@@ -276,7 +288,10 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
276288

277289
if response.status_code >= 400:
278290
if isinstance(message, JSONRPCRequest):
279-
error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response")
291+
error_data = ErrorData(
292+
code=INTERNAL_ERROR,
293+
message=f"Server returned HTTP {response.status_code}",
294+
)
280295
session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
281296
await ctx.read_stream_writer.send(session_message)
282297
return
@@ -398,7 +413,18 @@ async def _handle_reconnection(
398413

399414
try:
400415
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
401-
event_source.response.raise_for_status()
416+
if event_source.response.status_code >= 400:
417+
logger.warning(f"Reconnection GET returned HTTP {event_source.response.status_code}")
418+
if original_request_id is not None:
419+
error_data = ErrorData(
420+
code=INTERNAL_ERROR,
421+
message=f"Server returned HTTP {event_source.response.status_code}",
422+
)
423+
error_msg = SessionMessage(
424+
JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data)
425+
)
426+
await ctx.read_stream_writer.send(error_msg)
427+
return
402428
logger.info("Reconnected to SSE stream")
403429

404430
# Track for potential further reconnection

tests/client/test_notification_response.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ async def test_http_error_status_sends_jsonrpc_error() -> None:
148148
async with ClientSession(read_stream, write_stream) as session: # pragma: no branch
149149
await session.initialize()
150150

151-
with pytest.raises(MCPError, match="Server returned an error response"): # pragma: no branch
151+
with pytest.raises(MCPError, match="Server returned HTTP 500"): # pragma: no branch
152152
await session.list_tools()
153153

154154

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Tests for StreamableHTTP client transport HTTP error handling.
2+
3+
Verifies that HTTP 4xx/5xx responses are handled gracefully
4+
instead of crashing the program.
5+
"""
6+
7+
import json
8+
9+
import httpx
10+
import pytest
11+
from starlette.applications import Starlette
12+
from starlette.requests import Request
13+
from starlette.responses import JSONResponse, Response
14+
from starlette.routing import Route
15+
16+
from mcp import ClientSession, types
17+
from mcp.client.streamable_http import streamable_http_client
18+
from mcp.shared.session import RequestResponder
19+
20+
pytestmark = pytest.mark.anyio
21+
22+
INIT_RESPONSE = {
23+
"serverInfo": {"name": "test-http-error-server", "version": "1.0.0"},
24+
"protocolVersion": "2024-11-05",
25+
"capabilities": {},
26+
}
27+
28+
29+
def _create_401_server_app() -> Starlette:
30+
"""Create a server that returns 401 for non-init requests."""
31+
32+
async def handle_mcp_request(request: Request) -> Response:
33+
body = await request.body()
34+
data = json.loads(body)
35+
36+
if data.get("method") == "initialize":
37+
return JSONResponse({"jsonrpc": "2.0", "id": data["id"], "result": INIT_RESPONSE})
38+
39+
if "id" not in data:
40+
return Response(status_code=202)
41+
42+
return Response(status_code=401)
43+
44+
return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])])
45+
46+
47+
async def test_http_401_returns_jsonrpc_error() -> None:
48+
"""Test that a 401 response returns a JSONRPC error instead of crashing.
49+
50+
Regression test for https://github.com/modelcontextprotocol/python-sdk/issues/1295
51+
"""
52+
returned_exception = None
53+
54+
async def message_handler(
55+
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
56+
) -> None:
57+
nonlocal returned_exception
58+
if isinstance(message, Exception): # pragma: no cover
59+
returned_exception = message
60+
61+
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_401_server_app())) as client:
62+
async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream):
63+
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
64+
await session.initialize()
65+
66+
# list_tools should get a JSONRPC error with HTTP status, not crash
67+
with pytest.raises(Exception) as exc_info:
68+
await session.list_tools()
69+
assert "401" in str(exc_info.value)
70+
71+
if returned_exception: # pragma: no cover
72+
pytest.fail(f"Unexpected exception: {returned_exception}")
73+
74+
75+
def _create_503_server_app() -> Starlette:
76+
"""Create a server that returns 503 for non-init requests."""
77+
78+
async def handle_mcp_request(request: Request) -> Response:
79+
body = await request.body()
80+
data = json.loads(body)
81+
82+
if data.get("method") == "initialize":
83+
return JSONResponse({"jsonrpc": "2.0", "id": data["id"], "result": INIT_RESPONSE})
84+
85+
if "id" not in data:
86+
return Response(status_code=202)
87+
88+
return Response(status_code=503)
89+
90+
return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])])
91+
92+
93+
async def test_http_503_returns_jsonrpc_error() -> None:
94+
"""Test that a 503 response returns a JSONRPC error instead of crashing."""
95+
async with httpx.AsyncClient(transport=httpx.ASGITransport(app=_create_503_server_app())) as client:
96+
async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream):
97+
async with ClientSession(read_stream, write_stream) as session:
98+
await session.initialize()
99+
100+
with pytest.raises(Exception) as exc_info:
101+
await session.list_tools()
102+
assert "503" in str(exc_info.value)

0 commit comments

Comments
 (0)