Skip to content

Commit 2511333

Browse files
committed
fix: dispatch client request handlers concurrently (#2489)
BaseSession._receive_loop awaited each incoming request handler inline, serializing server->client requests (e.g. concurrent sampling calls via asyncio.gather peaked at one in flight). Add an opt-in '_dispatch_requests_concurrently' flag on BaseSession that spawns each request handler in the session's task group. ClientSession enables it; ServerSession stays serial to preserve the initialize ordering that its state machine relies on. Also fix two RequestResponder races that concurrent dispatch widens: - __enter__ no longer replaces the cancel scope, so a cancel() that arrives before the handler enters the context targets the same scope the handler will later run under. - cancel() is idempotent and safe to call before entry. Handler exceptions are translated into a JSON-RPC error response so a raising handler can't wedge the peer.
1 parent 3d7b311 commit 2511333

3 files changed

Lines changed: 262 additions & 10 deletions

File tree

src/mcp/client/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ def __init__(
123123
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
124124
) -> None:
125125
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
126+
# Dispatch incoming server->client requests concurrently so a slow
127+
# sampling/elicitation callback doesn't serialize other in-flight requests.
128+
self._dispatch_requests_concurrently = True
126129
self._client_info = client_info or DEFAULT_CLIENT_INFO
127130
self._sampling_callback = sampling_callback or _default_sampling_callback
128131
self._sampling_capabilities = sampling_capabilities

src/mcp/shared/session.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def __init__(
9999
def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]:
100100
"""Enter the context manager, enabling request cancellation tracking."""
101101
self._entered = True
102-
self._cancel_scope = anyio.CancelScope()
102+
# Enter the scope created in __init__ so pre-entry cancel() targets
103+
# the same scope the handler will later run under.
103104
self._cancel_scope.__enter__()
104105
return self
105106

@@ -140,11 +141,12 @@ async def respond(self, response: SendResultT | ErrorData) -> None:
140141
)
141142

142143
async def cancel(self) -> None:
143-
"""Cancel this request and mark it as completed."""
144-
if not self._entered: # pragma: no cover
145-
raise RuntimeError("RequestResponder must be used as a context manager")
146-
if not self._cancel_scope: # pragma: no cover
147-
raise RuntimeError("No active cancel scope")
144+
"""Cancel this request and mark it as completed.
145+
146+
Safe to call before the context manager has been entered.
147+
"""
148+
if self._completed:
149+
return
148150

149151
self._cancel_scope.cancel()
150152
self._completed = True # Mark as completed so it's removed from in_flight
@@ -158,6 +160,10 @@ async def cancel(self) -> None:
158160
def in_flight(self) -> bool: # pragma: no cover
159161
return not self._completed and not self.cancelled
160162

163+
@property
164+
def completed(self) -> bool:
165+
return self._completed
166+
161167
@property
162168
def cancelled(self) -> bool:
163169
return self._cancel_scope.cancel_called
@@ -185,6 +191,10 @@ class BaseSession(
185191
_progress_callbacks: dict[RequestId, ProgressFnT]
186192
_response_routers: list[ResponseRouter]
187193

194+
# When True, incoming requests are dispatched to the session's task group
195+
# so handlers run concurrently with the receive loop.
196+
_dispatch_requests_concurrently: bool = False
197+
188198
def __init__(
189199
self,
190200
read_stream: ReadStream[SessionMessage | Exception],
@@ -348,6 +358,29 @@ def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]:
348358
def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
349359
raise NotImplementedError
350360

361+
async def _dispatch_request(
362+
self,
363+
responder: RequestResponder[ReceiveRequestT, SendResultT],
364+
) -> None:
365+
"""Run the per-request handler chain, translating handler exceptions
366+
into a JSON-RPC error response so they can't wedge the peer.
367+
"""
368+
request_id = responder.request_id
369+
try:
370+
await self._received_request(responder)
371+
if not responder.completed:
372+
await self._handle_incoming(responder)
373+
except Exception:
374+
logging.warning("Request handler raised an exception", exc_info=True)
375+
if not responder.completed:
376+
error_response = JSONRPCError(
377+
jsonrpc="2.0",
378+
id=request_id,
379+
error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""),
380+
)
381+
await self._write_stream.send(SessionMessage(message=error_response))
382+
self._in_flight.pop(request_id, None)
383+
351384
async def _receive_loop(self) -> None:
352385
async with self._read_stream, self._write_stream:
353386
try:
@@ -370,10 +403,6 @@ async def _handle_session_message(message: SessionMessage) -> None:
370403
context=sender_context,
371404
)
372405
self._in_flight[responder.request_id] = responder
373-
await self._received_request(responder)
374-
375-
if not responder._completed: # type: ignore[reportPrivateUsage]
376-
await self._handle_incoming(responder)
377406
except Exception:
378407
# For request validation errors, send a proper JSON-RPC error
379408
# response instead of crashing the server
@@ -386,6 +415,12 @@ async def _handle_session_message(message: SessionMessage) -> None:
386415
)
387416
session_message = SessionMessage(message=error_response)
388417
await self._write_stream.send(session_message)
418+
return
419+
420+
if self._dispatch_requests_concurrently:
421+
self._task_group.start_soon(self._dispatch_request, responder)
422+
else:
423+
await self._dispatch_request(responder)
389424

390425
elif isinstance(message.message, JSONRPCNotification):
391426
try:

tests/shared/test_session.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from mcp import Client, types
55
from mcp.client.session import ClientSession
66
from mcp.server import Server, ServerRequestContext
7+
from mcp.shared._context import RequestContext
78
from mcp.shared.exceptions import MCPError
89
from mcp.shared.memory import create_client_server_memory_streams
910
from mcp.shared.message import SessionMessage
@@ -416,3 +417,216 @@ async def make_request(client_session: ClientSession):
416417
# Pending request completed successfully
417418
assert len(result_holder) == 1
418419
assert isinstance(result_holder[0], EmptyResult)
420+
421+
422+
@pytest.mark.anyio
423+
async def test_concurrent_server_to_client_requests_run_in_parallel():
424+
"""Regression test for #2489.
425+
426+
A server tool fans out N concurrent ``ServerSession.create_message`` calls
427+
via ``anyio.create_task_group``. The client sampling callback records the
428+
peak number of concurrently-in-flight calls. Before the fix, requests were
429+
serialized end-to-end by ``BaseSession._receive_loop`` and peak was 1.
430+
"""
431+
n = 4
432+
433+
inflight = 0
434+
peak = 0
435+
started = anyio.Event()
436+
release = anyio.Event()
437+
438+
async def sampling_callback(
439+
context: RequestContext[ClientSession],
440+
params: types.CreateMessageRequestParams,
441+
) -> types.CreateMessageResult:
442+
nonlocal inflight, peak
443+
inflight += 1
444+
peak = max(peak, inflight)
445+
if peak == n:
446+
started.set()
447+
try:
448+
with anyio.fail_after(5):
449+
await release.wait()
450+
finally:
451+
inflight -= 1
452+
msg = params.messages[0].content
453+
echo = msg.text if isinstance(msg, types.TextContent) else ""
454+
return types.CreateMessageResult(
455+
role="assistant",
456+
content=types.TextContent(type="text", text=f"echo:{echo}"),
457+
model="test-model",
458+
)
459+
460+
async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
461+
results: list[str] = [""] * n
462+
463+
async def one(i: int) -> None:
464+
r = await ctx.session.create_message(
465+
messages=[
466+
types.SamplingMessage(
467+
role="user",
468+
content=types.TextContent(type="text", text=str(i)),
469+
)
470+
],
471+
max_tokens=8,
472+
)
473+
results[i] = r.content.text if isinstance(r.content, types.TextContent) else ""
474+
475+
async with anyio.create_task_group() as tg: # pragma: no branch
476+
for i in range(n):
477+
tg.start_soon(one, i)
478+
return types.CallToolResult(content=[types.TextContent(type="text", text=",".join(results))])
479+
480+
async def handle_list_tools(
481+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
482+
) -> types.ListToolsResult:
483+
return types.ListToolsResult(tools=[types.Tool(name="fanout", input_schema={"type": "object"})])
484+
485+
server = Server(name="fanout", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)
486+
487+
async with Client(server, sampling_callback=sampling_callback) as client:
488+
async with anyio.create_task_group() as tg: # pragma: no branch
489+
490+
async def call() -> None:
491+
await client.call_tool("fanout", {})
492+
493+
tg.start_soon(call)
494+
with anyio.fail_after(5):
495+
await started.wait()
496+
release.set()
497+
498+
assert peak == n, f"server->client requests were serialized: peak in-flight={peak}, expected {n}"
499+
500+
501+
@pytest.mark.anyio
502+
async def test_sampling_callback_exception_returns_error_response():
503+
"""A raising sampling callback must produce a JSON-RPC error response so
504+
the server-side ``await ctx.session.create_message(...)`` doesn't hang.
505+
"""
506+
507+
async def sampling_callback(
508+
context: RequestContext[ClientSession],
509+
params: types.CreateMessageRequestParams,
510+
) -> types.CreateMessageResult:
511+
raise RuntimeError("boom")
512+
513+
caught: list[MCPError] = []
514+
515+
async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
516+
try:
517+
await ctx.session.create_message(
518+
messages=[
519+
types.SamplingMessage(
520+
role="user",
521+
content=types.TextContent(type="text", text="x"),
522+
)
523+
],
524+
max_tokens=8,
525+
)
526+
except MCPError as e:
527+
caught.append(e)
528+
return types.CallToolResult(content=[types.TextContent(type="text", text="ok")])
529+
530+
async def handle_list_tools(
531+
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
532+
) -> types.ListToolsResult:
533+
return types.ListToolsResult(tools=[types.Tool(name="boom", input_schema={"type": "object"})])
534+
535+
server = Server(name="raise", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)
536+
537+
async with Client(server, sampling_callback=sampling_callback) as client:
538+
with anyio.fail_after(5):
539+
await client.call_tool("boom", {})
540+
541+
assert len(caught) == 1
542+
543+
544+
@pytest.mark.anyio
545+
async def test_double_cancel_does_not_send_second_response():
546+
"""Cancel called twice on the same responder must not emit a second response."""
547+
548+
class _Dummy:
549+
_send_response_calls = 0
550+
551+
async def _send_response(self, *, request_id: types.RequestId, response: object) -> None:
552+
self._send_response_calls += 1
553+
554+
dummy = _Dummy()
555+
responder = RequestResponder[types.ServerRequest, types.ClientResult](
556+
request_id=1,
557+
request_meta=None,
558+
request=types.PingRequest(method="ping"),
559+
session=dummy, # type: ignore[arg-type]
560+
on_complete=lambda _r: None,
561+
)
562+
with responder:
563+
await responder.cancel()
564+
await responder.cancel()
565+
assert dummy._send_response_calls == 1
566+
567+
568+
@pytest.mark.anyio
569+
async def test_cancel_before_context_entered_marks_scope_cancelled():
570+
"""Regression: with concurrent dispatch, a CancelledNotification can
571+
arrive before the handler task has entered ``with responder:``.
572+
``cancel()`` must not raise, and the scope entered later must already
573+
be cancelled.
574+
"""
575+
576+
class _Dummy:
577+
async def _send_response(self, *, request_id: types.RequestId, response: object) -> None:
578+
pass
579+
580+
responder = RequestResponder[types.ServerRequest, types.ClientResult](
581+
request_id=7,
582+
request_meta=None,
583+
request=types.PingRequest(method="ping"),
584+
session=_Dummy(), # type: ignore[arg-type]
585+
on_complete=lambda _r: None,
586+
)
587+
588+
await responder.cancel()
589+
assert responder.cancelled
590+
assert responder._cancel_scope.cancel_called
591+
592+
593+
@pytest.mark.anyio
594+
async def test_handler_that_responds_then_raises_emits_no_duplicate_error():
595+
"""If a request handler completes the response and then raises, the
596+
dispatch path must not emit a second JSON-RPC error for the same id.
597+
"""
598+
599+
class _RaiseAfterRespond(ClientSession):
600+
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
601+
with responder:
602+
await responder.respond(types.EmptyResult())
603+
raise RuntimeError("after respond")
604+
605+
class _CapturingWrite:
606+
def __init__(self) -> None:
607+
self.sent: list[SessionMessage] = []
608+
609+
async def send(self, msg: SessionMessage) -> None:
610+
self.sent.append(msg)
611+
612+
async with create_client_server_memory_streams() as (client_streams, _server_streams):
613+
client_read, client_write = client_streams
614+
session = _RaiseAfterRespond(client_read, client_write)
615+
616+
capture = _CapturingWrite()
617+
session._write_stream = capture # type: ignore[assignment]
618+
619+
responder = RequestResponder[types.ServerRequest, types.ClientResult](
620+
request_id=99,
621+
request_meta=None,
622+
request=types.PingRequest(method="ping"),
623+
session=session,
624+
on_complete=lambda r: session._in_flight.pop(r.request_id, None),
625+
)
626+
session._in_flight[99] = responder
627+
628+
await session._dispatch_request(responder)
629+
630+
assert len(capture.sent) == 1, capture.sent
631+
assert isinstance(capture.sent[0].message, JSONRPCResponse)
632+
assert capture.sent[0].message.id == 99

0 commit comments

Comments (0)