Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def __init__(
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
) -> None:
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
# Dispatch incoming server->client requests concurrently so a slow
# sampling/elicitation callback doesn't serialize other in-flight requests.
self._dispatch_requests_concurrently = True
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._sampling_capabilities = sampling_capabilities
Expand Down
55 changes: 45 additions & 10 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def __init__(
def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]:
    """Enter the context manager, enabling request cancellation tracking.

    Re-uses the cancel scope created in ``__init__`` instead of creating a
    fresh one here: with concurrent dispatch, ``cancel()`` may legitimately
    run *before* the handler task enters this context manager, and that
    pre-entry cancellation must target the same scope the handler will later
    run under. Re-assigning ``self._cancel_scope`` here would silently
    discard such an early cancellation.
    """
    self._entered = True
    self._cancel_scope.__enter__()
    return self

Expand Down Expand Up @@ -140,11 +141,12 @@ async def respond(self, response: SendResultT | ErrorData) -> None:
)

async def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered: # pragma: no cover
raise RuntimeError("RequestResponder must be used as a context manager")
if not self._cancel_scope: # pragma: no cover
raise RuntimeError("No active cancel scope")
"""Cancel this request and mark it as completed.

Safe to call before the context manager has been entered.
"""
if self._completed:
return

self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight
Expand All @@ -158,6 +160,10 @@ async def cancel(self) -> None:
def in_flight(self) -> bool:  # pragma: no cover
    """True while the request is still active: neither completed nor cancelled."""
    return not self._completed and not self.cancelled

@property
def completed(self) -> bool:
    """True once a response (or cancellation) has been recorded for this request."""
    return self._completed

@property
def cancelled(self) -> bool:
    """True if ``cancel()`` has been called on this request's cancel scope."""
    return self._cancel_scope.cancel_called
Expand Down Expand Up @@ -185,6 +191,10 @@ class BaseSession(
_progress_callbacks: dict[RequestId, ProgressFnT]
_response_routers: list[ResponseRouter]

# When True, incoming requests are dispatched to the session's task group
# so handlers run concurrently with the receive loop.
_dispatch_requests_concurrently: bool = False

def __init__(
self,
read_stream: ReadStream[SessionMessage | Exception],
Expand Down Expand Up @@ -348,6 +358,29 @@ def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]:
def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
raise NotImplementedError

async def _dispatch_request(
    self,
    responder: RequestResponder[ReceiveRequestT, SendResultT],
) -> None:
    """Run the per-request handler chain for one incoming request.

    Translates handler exceptions into a JSON-RPC error response so a buggy
    handler can't leave the peer waiting forever, and always removes the
    request from the in-flight table — even if sending the error response
    itself fails (e.g. the write stream is already closed).
    """
    request_id = responder.request_id
    try:
        await self._received_request(responder)
        if not responder.completed:
            await self._handle_incoming(responder)
    except Exception:
        logging.warning("Request handler raised an exception", exc_info=True)
        if not responder.completed:
            # NOTE(review): INTERNAL_ERROR (-32603) would describe a handler
            # crash more accurately than INVALID_PARAMS; kept as-is for parity
            # with the receive-loop validation path — confirm before changing.
            error_response = JSONRPCError(
                jsonrpc="2.0",
                id=request_id,
                error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""),
            )
            await self._write_stream.send(SessionMessage(message=error_response))
    finally:
        # Guarantee cleanup: without the finally, a failure while sending the
        # error response would leave a stale entry in self._in_flight.
        self._in_flight.pop(request_id, None)

async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
try:
Expand All @@ -370,10 +403,6 @@ async def _handle_session_message(message: SessionMessage) -> None:
context=sender_context,
)
self._in_flight[responder.request_id] = responder
await self._received_request(responder)

if not responder._completed: # type: ignore[reportPrivateUsage]
await self._handle_incoming(responder)
except Exception:
# For request validation errors, send a proper JSON-RPC error
# response instead of crashing the server
Expand All @@ -386,6 +415,12 @@ async def _handle_session_message(message: SessionMessage) -> None:
)
session_message = SessionMessage(message=error_response)
await self._write_stream.send(session_message)
return

if self._dispatch_requests_concurrently:
self._task_group.start_soon(self._dispatch_request, responder)
else:
await self._dispatch_request(responder)

elif isinstance(message.message, JSONRPCNotification):
try:
Expand Down
214 changes: 214 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mcp import Client, types
from mcp.client.session import ClientSession
from mcp.server import Server, ServerRequestContext
from mcp.shared._context import RequestContext
from mcp.shared.exceptions import MCPError
from mcp.shared.memory import create_client_server_memory_streams
from mcp.shared.message import SessionMessage
Expand Down Expand Up @@ -416,3 +417,216 @@ async def make_request(client_session: ClientSession):
# Pending request completed successfully
assert len(result_holder) == 1
assert isinstance(result_holder[0], EmptyResult)


@pytest.mark.anyio
async def test_concurrent_server_to_client_requests_run_in_parallel():
    """Regression test for #2489.

    A server tool fans out N concurrent ``ServerSession.create_message`` calls
    via ``anyio.create_task_group``. The client sampling callback records the
    peak number of concurrently-in-flight calls. Before the fix, requests were
    serialized end-to-end by ``BaseSession._receive_loop`` and peak was 1.
    """
    n = 4

    # Shared counters mutated by the sampling callback below.
    inflight = 0
    peak = 0
    started = anyio.Event()  # set once all n callbacks are in flight at once
    release = anyio.Event()  # barrier holding every callback open until then

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult:
        nonlocal inflight, peak
        inflight += 1
        peak = max(peak, inflight)
        if peak == n:
            # All n requests are concurrently in flight — unblock the test body.
            started.set()
        try:
            # Hold this callback open so later arrivals can overlap with it;
            # fail_after keeps a serialized (deadlocked) run from hanging forever.
            with anyio.fail_after(5):
                await release.wait()
        finally:
            inflight -= 1
        msg = params.messages[0].content
        echo = msg.text if isinstance(msg, types.TextContent) else ""
        return types.CreateMessageResult(
            role="assistant",
            content=types.TextContent(type="text", text=f"echo:{echo}"),
            model="test-model",
        )

    async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
        results: list[str] = [""] * n

        # One server->client create_message round-trip per slot.
        async def one(i: int) -> None:
            r = await ctx.session.create_message(
                messages=[
                    types.SamplingMessage(
                        role="user",
                        content=types.TextContent(type="text", text=str(i)),
                    )
                ],
                max_tokens=8,
            )
            results[i] = r.content.text if isinstance(r.content, types.TextContent) else ""

        # Fan out all n calls concurrently from inside the tool handler.
        async with anyio.create_task_group() as tg:  # pragma: no branch
            for i in range(n):
                tg.start_soon(one, i)
        return types.CallToolResult(content=[types.TextContent(type="text", text=",".join(results))])

    async def handle_list_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        return types.ListToolsResult(tools=[types.Tool(name="fanout", input_schema={"type": "object"})])

    server = Server(name="fanout", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)

    async with Client(server, sampling_callback=sampling_callback) as client:
        async with anyio.create_task_group() as tg:  # pragma: no branch

            async def call() -> None:
                await client.call_tool("fanout", {})

            tg.start_soon(call)
            # Wait until all n callbacks overlap, then let them all finish.
            with anyio.fail_after(5):
                await started.wait()
            release.set()

    assert peak == n, f"server->client requests were serialized: peak in-flight={peak}, expected {n}"


@pytest.mark.anyio
async def test_sampling_callback_exception_returns_error_response():
    """A raising sampling callback must produce a JSON-RPC error response so
    the server-side ``await ctx.session.create_message(...)`` doesn't hang.
    """
    observed_errors: list[MCPError] = []

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult:
        raise RuntimeError("boom")

    async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
        request_messages = [
            types.SamplingMessage(
                role="user",
                content=types.TextContent(type="text", text="x"),
            )
        ]
        try:
            await ctx.session.create_message(messages=request_messages, max_tokens=8)
        except MCPError as exc:
            # The client-side callback failure must surface here as an error
            # response rather than hanging the await.
            observed_errors.append(exc)
        return types.CallToolResult(content=[types.TextContent(type="text", text="ok")])

    async def handle_list_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        tool = types.Tool(name="boom", input_schema={"type": "object"})
        return types.ListToolsResult(tools=[tool])

    server = Server(name="raise", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)

    async with Client(server, sampling_callback=sampling_callback) as client:
        with anyio.fail_after(5):
            await client.call_tool("boom", {})

    assert len(observed_errors) == 1


@pytest.mark.anyio
async def test_double_cancel_does_not_send_second_response():
    """Cancel called twice on the same responder must not emit a second response."""

    class _CountingSession:
        """Stands in for a session, counting how many responses are sent."""

        def __init__(self) -> None:
            self._send_response_calls = 0

        async def _send_response(self, *, request_id: types.RequestId, response: object) -> None:
            self._send_response_calls += 1

    counting_session = _CountingSession()
    responder = RequestResponder[types.ServerRequest, types.ClientResult](
        request_id=1,
        request_meta=None,
        request=types.PingRequest(method="ping"),
        session=counting_session,  # type: ignore[arg-type]
        on_complete=lambda _r: None,
    )

    with responder:
        # Second cancel must be a no-op: one response total.
        for _ in range(2):
            await responder.cancel()

    assert counting_session._send_response_calls == 1


@pytest.mark.anyio
async def test_cancel_before_context_entered_marks_scope_cancelled():
    """Regression: with concurrent dispatch, a CancelledNotification can
    arrive before the handler task has entered ``with responder:``.
    ``cancel()`` must not raise, and the scope entered later must already
    be cancelled.
    """

    class _NullSession:
        """Session stand-in that swallows any response."""

        async def _send_response(self, *, request_id: types.RequestId, response: object) -> None:
            return None

    never_entered = RequestResponder[types.ServerRequest, types.ClientResult](
        request_id=7,
        request_meta=None,
        request=types.PingRequest(method="ping"),
        session=_NullSession(),  # type: ignore[arg-type]
        on_complete=lambda _r: None,
    )

    # Cancel before (instead of inside) the context manager.
    await never_entered.cancel()

    assert never_entered.cancelled
    assert never_entered._cancel_scope.cancel_called


@pytest.mark.anyio
async def test_handler_that_responds_then_raises_emits_no_duplicate_error():
    """If a request handler completes the response and then raises, the
    dispatch path must not emit a second JSON-RPC error for the same id.
    """

    class _RaiseAfterRespond(ClientSession):
        # Handler that sends a valid result and THEN raises, simulating a
        # bug after the response is already on the wire.
        async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
            with responder:
                await responder.respond(types.EmptyResult())
            raise RuntimeError("after respond")

    class _CapturingWrite:
        # Captures outgoing messages so the test can count them.
        def __init__(self) -> None:
            self.sent: list[SessionMessage] = []

        async def send(self, msg: SessionMessage) -> None:
            self.sent.append(msg)

    async with create_client_server_memory_streams() as (client_streams, _server_streams):
        client_read, client_write = client_streams
        session = _RaiseAfterRespond(client_read, client_write)

        # Swap in the capturing stream so every outbound message is recorded.
        capture = _CapturingWrite()
        session._write_stream = capture  # type: ignore[assignment]

        responder = RequestResponder[types.ServerRequest, types.ClientResult](
            request_id=99,
            request_meta=None,
            request=types.PingRequest(method="ping"),
            session=session,
            on_complete=lambda r: session._in_flight.pop(r.request_id, None),
        )
        session._in_flight[99] = responder

        # Drive the dispatch path directly; the handler responds then raises.
        await session._dispatch_request(responder)

    # Exactly one message: the successful response, and no trailing error.
    assert len(capture.sent) == 1, capture.sent
    assert isinstance(capture.sent[0].message, JSONRPCResponse)
    assert capture.sent[0].message.id == 99
Loading