Skip to content

Commit 58ab810

Browse files
fix(server): unwrap ExceptionGroup in transport servers
ROOT CAUSE: Server transports propagate ExceptionGroup wrapping real errors. CHANGES: - Added exception unwrapping in sse_server - Added exception unwrapping in stdio_server - Added exception unwrapping in websocket_server - Added exception unwrapping in streamable_http_server (2 locations) IMPACT: - Callers can catch specific exceptions directly FILES MODIFIED: - src/mcp/server/sse.py - src/mcp/server/stdio.py - src/mcp/server/websocket.py - src/mcp/server/streamable_http.py
1 parent 15f4e8f commit 58ab810

4 files changed

Lines changed: 132 additions & 86 deletions

File tree

src/mcp/server/sse.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -175,24 +175,30 @@ async def sse_writer():
175175
)
176176

177177
async with anyio.create_task_group() as tg:
178+
try:
179+
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
180+
"""The EventSourceResponse returning signals a client close / disconnect.
181+
In this case we close our side of the streams to signal the client that
182+
the connection has been closed.
183+
"""
184+
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
185+
scope, receive, send
186+
)
187+
await read_stream_writer.aclose()
188+
await write_stream_reader.aclose()
189+
logging.debug(f"Client session disconnected {session_id}")
190+
191+
logger.debug("Starting SSE response task")
192+
tg.start_soon(response_wrapper, scope, receive, send)
193+
194+
logger.debug("Yielding read and write streams")
195+
yield (read_stream, write_stream)
196+
except BaseExceptionGroup as e:
197+
from mcp.shared.exceptions import unwrap_task_group_exception
178198

179-
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
180-
"""The EventSourceResponse returning signals a client close / disconnect.
181-
In this case we close our side of the streams to signal the client that
182-
the connection has been closed.
183-
"""
184-
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
185-
scope, receive, send
186-
)
187-
await read_stream_writer.aclose()
188-
await write_stream_reader.aclose()
189-
logging.debug(f"Client session disconnected {session_id}")
190-
191-
logger.debug("Starting SSE response task")
192-
tg.start_soon(response_wrapper, scope, receive, send)
193-
194-
logger.debug("Yielding read and write streams")
195-
yield (read_stream, write_stream)
199+
real_exc = unwrap_task_group_exception(e)
200+
if real_exc is not e:
201+
raise real_exc
196202

197203
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
198204
logger.debug("Handling POST message")

src/mcp/server/stdio.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ async def stdout_writer():
7878
await anyio.lowlevel.checkpoint()
7979

8080
async with anyio.create_task_group() as tg:
81-
tg.start_soon(stdin_reader)
82-
tg.start_soon(stdout_writer)
83-
yield read_stream, write_stream
81+
try:
82+
tg.start_soon(stdin_reader)
83+
tg.start_soon(stdout_writer)
84+
yield read_stream, write_stream
85+
except BaseExceptionGroup as e:
86+
from mcp.shared.exceptions import unwrap_task_group_exception
87+
88+
real_exc = unwrap_task_group_exception(e)
89+
if real_exc is not e:
90+
raise real_exc

src/mcp/server/streamable_http.py

Lines changed: 89 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,17 @@ async def sse_writer(): # pragma: lax no cover
615615
try:
616616
# First send the response to establish the SSE connection
617617
async with anyio.create_task_group() as tg:
618-
tg.start_soon(response, scope, receive, send)
619-
# Then send the message to be processed by the server
620-
session_message = self._create_session_message(message, request, request_id, protocol_version)
621-
await writer.send(session_message)
618+
try:
619+
tg.start_soon(response, scope, receive, send)
620+
# Then send the message to be processed by the server
621+
session_message = self._create_session_message(message, request, request_id, protocol_version)
622+
await writer.send(session_message)
623+
except BaseExceptionGroup as e:
624+
from mcp.shared.exceptions import unwrap_task_group_exception
625+
626+
real_exc = unwrap_task_group_exception(e)
627+
if real_exc is not e:
628+
raise real_exc
622629
except Exception: # pragma: no cover
623630
logger.exception("SSE response error")
624631
await sse_stream_writer.aclose()
@@ -971,67 +978,86 @@ async def connect(
971978

972979
# Start a task group for message routing
973980
async with anyio.create_task_group() as tg:
974-
# Create a message router that distributes messages to request streams
975-
async def message_router():
976-
try:
977-
async for session_message in write_stream_reader: # pragma: no branch
978-
# Determine which request stream(s) should receive this message
979-
message = session_message.message
980-
target_request_id = None
981-
# Check if this is a response with a known request id.
982-
# Null-id errors (e.g., parse errors) fall through to
983-
# the GET stream since they can't be correlated.
984-
if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None:
985-
target_request_id = str(message.id)
986-
# Extract related_request_id from meta if it exists
987-
elif ( # pragma: no cover
988-
session_message.metadata is not None
989-
and isinstance(
990-
session_message.metadata,
991-
ServerMessageMetadata,
992-
)
993-
and session_message.metadata.related_request_id is not None
994-
):
995-
target_request_id = str(session_message.metadata.related_request_id)
996-
997-
request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY
998-
999-
# Store the event if we have an event store,
1000-
# regardless of whether a client is connected
1001-
# messages will be replayed on the re-connect
1002-
event_id = None
1003-
if self._event_store: # pragma: lax no cover
1004-
event_id = await self._event_store.store_event(request_stream_id, message)
1005-
logger.debug(f"Stored {event_id} from {request_stream_id}")
1006-
1007-
if request_stream_id in self._request_streams:
1008-
try:
1009-
# Send both the message and the event ID
1010-
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
1011-
except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover
1012-
# Stream might be closed, remove from registry
1013-
self._request_streams.pop(request_stream_id, None)
1014-
else: # pragma: no cover
1015-
logger.debug(
1016-
f"""Request stream {request_stream_id} not found
1017-
for message. Still processing message as the client
1018-
might reconnect and replay."""
1019-
)
1020-
except anyio.ClosedResourceError:
1021-
if self._terminated:
1022-
logger.debug("Read stream closed by client")
1023-
else:
1024-
logger.exception("Unexpected closure of read stream in message router")
1025-
except Exception: # pragma: lax no cover
1026-
logger.exception("Error in message router")
981+
try:
982+
# Create a message router that distributes messages to request streams
983+
async def message_router():
984+
try:
985+
async for session_message in write_stream_reader: # pragma: no branch
986+
# Determine which request stream(s) should receive this message
987+
message = session_message.message
988+
target_request_id = None
989+
# Check if this is a response with a known request id.
990+
# Null-id errors (e.g., parse errors) fall through to
991+
# the GET stream since they can't be correlated.
992+
if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None:
993+
target_request_id = str(message.id)
994+
# Extract related_request_id from meta if it exists
995+
elif ( # pragma: no cover
996+
session_message.metadata is not None
997+
and isinstance(
998+
session_message.metadata,
999+
ServerMessageMetadata,
1000+
)
1001+
and session_message.metadata.related_request_id is not None
1002+
):
1003+
target_request_id = str(session_message.metadata.related_request_id)
1004+
1005+
request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY
1006+
1007+
# Store the event if we have an event store,
1008+
# regardless of whether a client is connected
1009+
# messages will be replayed on the re-connect
1010+
event_id = None
1011+
if self._event_store: # pragma: lax no cover
1012+
event_id = await self._event_store.store_event(request_stream_id, message)
1013+
logger.debug(f"Stored {event_id} from {request_stream_id}")
1014+
1015+
if request_stream_id in self._request_streams:
1016+
try:
1017+
# Send both the message and the event ID
1018+
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
1019+
except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover
1020+
# Stream might be closed, remove from registry
1021+
self._request_streams.pop(request_stream_id, None)
1022+
else: # pragma: no cover
1023+
logger.debug(
1024+
f"""Request stream {request_stream_id} not found
1025+
for message. Still processing message as the client
1026+
might reconnect and replay."""
1027+
)
1028+
except anyio.ClosedResourceError:
1029+
if self._terminated:
1030+
logger.debug("Read stream closed by client")
1031+
else:
1032+
logger.exception("Unexpected closure of read stream in message router")
1033+
except Exception: # pragma: lax no cover
1034+
logger.exception("Error in message router")
10271035

1028-
# Start the message router
1029-
tg.start_soon(message_router)
1036+
# Start the message router
1037+
tg.start_soon(message_router)
10301038

1031-
try:
1032-
# Yield the streams for the caller to use
1033-
yield read_stream, write_stream
1034-
finally:
1039+
try:
1040+
# Yield the streams for the caller to use
1041+
yield read_stream, write_stream
1042+
finally:
1043+
for stream_id in list(self._request_streams.keys()): # pragma: lax no cover
1044+
await self._clean_up_memory_streams(stream_id)
1045+
self._request_streams.clear()
1046+
1047+
# Clean up the read and write streams
1048+
try:
1049+
await read_stream_writer.aclose()
1050+
await read_stream.aclose()
1051+
await write_stream_reader.aclose()
1052+
await write_stream.aclose()
1053+
except Exception: # pragma: no cover
1054+
logger.exception("Error closing streams")
1055+
except BaseExceptionGroup as e:
1056+
from mcp.shared.exceptions import unwrap_task_group_exception
1057+
1058+
real_exc = unwrap_task_group_exception(e)
1059+
if real_exc is not e:
1060+
raise real_exc
10351061
for stream_id in list(self._request_streams.keys()): # pragma: lax no cover
10361062
await self._clean_up_memory_streams(stream_id)
10371063
self._request_streams.clear()

src/mcp/server/websocket.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ async def ws_writer():
5353
await websocket.close()
5454

5555
async with anyio.create_task_group() as tg:
56-
tg.start_soon(ws_reader)
57-
tg.start_soon(ws_writer)
58-
yield (read_stream, write_stream)
56+
try:
57+
tg.start_soon(ws_reader)
58+
tg.start_soon(ws_writer)
59+
yield (read_stream, write_stream)
60+
except BaseExceptionGroup as e:
61+
from mcp.shared.exceptions import unwrap_task_group_exception
62+
63+
real_exc = unwrap_task_group_exception(e)
64+
if real_exc is not e:
65+
raise real_exc

0 commit comments

Comments
 (0)