Skip to content

Commit 34c26fa

Browse files
fix(session): unwrap ExceptionGroup in BaseSession.__aexit__
ROOT CAUSE: BaseSession's task group raises ExceptionGroup wrapping real errors with CancelledError from cancelled tasks. CHANGES: - Modified __aexit__ to unwrap ExceptionGroup before propagating - Real errors now propagate cleanly to callers IMPACT: - Callers can catch specific exceptions directly FILES MODIFIED: - src/mcp/shared/session.py: Added exception unwrapping in __aexit - tests/shared/test_session_exception_group.py: Added test
1 parent 2b90c2f commit 34c26fa

2 files changed

Lines changed: 63 additions & 1 deletion

File tree

src/mcp/shared/session.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,23 @@ async def __aexit__(
223223
exc_val: BaseException | None,
224224
exc_tb: TracebackType | None,
225225
) -> bool | None:
226+
from mcp.shared.exceptions import unwrap_task_group_exception
227+
226228
await self._exit_stack.aclose()
227229
# Using BaseSession as a context manager should not block on exit (this
228230
# would be very surprising behavior), so make sure to cancel the tasks
229231
# in the task group.
230232
self._task_group.cancel_scope.cancel()
231-
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
233+
234+
# Exit the task group and unwrap any ExceptionGroup
235+
try:
236+
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
237+
except BaseException as e:
238+
# Unwrap ExceptionGroup to get only the real error
239+
unwrapped = unwrap_task_group_exception(e)
240+
if unwrapped is not e:
241+
raise unwrapped
242+
raise
232243

233244
async def send_request(
234245
self,
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""Test that BaseSession unwraps ExceptionGroups properly."""
2+
from __future__ import annotations
3+
4+
import anyio
5+
import pytest
6+
7+
from mcp.shared.session import BaseSession
8+
9+
10+
class _TestSession(BaseSession):
11+
"""Test implementation of BaseSession."""
12+
13+
@property
14+
def _receive_request_adapter(self):
15+
from pydantic import TypeAdapter
16+
17+
return TypeAdapter(dict)
18+
19+
@property
20+
def _receive_notification_adapter(self):
21+
from pydantic import TypeAdapter
22+
23+
return TypeAdapter(dict)
24+
25+
26+
@pytest.mark.anyio
27+
async def test_session_propagates_real_error_not_exception_group() -> None:
28+
"""Test that real errors propagate unwrapped from session task groups."""
29+
# Create streams
30+
read_sender, read_stream = anyio.create_memory_object_stream()
31+
write_stream, write_receiver = anyio.create_memory_object_stream()
32+
33+
try:
34+
session = _TestSession(
35+
read_stream=read_stream,
36+
write_stream=write_stream,
37+
read_timeout_seconds=None,
38+
)
39+
40+
# The session's receive loop will start in __aenter__
41+
# If it fails with ExceptionGroup, we want only the real error
42+
with pytest.raises(ConnectionError, match="connection failed"):
43+
async with session:
44+
# Raise a connection error to trigger exception group behavior
45+
raise ConnectionError("connection failed")
46+
47+
finally:
48+
await read_sender.aclose()
49+
await read_stream.aclose()
50+
await write_stream.aclose()
51+
await write_receiver.aclose()

0 commit comments

Comments
 (0)