Skip to content

Commit 136a22d

Browse files
committed
feat: JSONRPCDispatcher receive loop and dispatch (chunk b)
run() drives the receive loop in a per-request task group; task_status.started() fires once send_request is usable. _dispatch routes each inbound message synchronously (no awaits — send_nowait/_spawn only) to avoid head-of-line blocking. _spawn propagates the sender's contextvars via Context.run(tg.start_soon, ...) so auth/OTel set by ASGI middleware survive. _fan_out_closed wakes pending send_request waiters with CONNECTION_CLOSED on shutdown (called both post-EOF and in finally; idempotent). Wire-param extraction (progressToken, cancelled.requestId, progress fields) uses structural match patterns — runtime narrowing, no casts, no mcp.types model coupling; malformed input fails to match and the correlation is skipped. _handle_request is happy-path only here (run on_request, write response); the exception-to-wire boundary lands in the next commit. Dispatcher.run() Protocol gained a task_status kwarg (it's a contract-level guarantee). DirectDispatcher.run() updated to match. running_pair now uses tg.start so the test body runs only once the dispatcher is ready. 20 contract tests pass; the 2 needing the exception boundary are strict-xfail.
1 parent d84f82a commit 136a22d

5 files changed

Lines changed: 274 additions & 23 deletions

File tree

src/mcp/shared/direct_dispatcher.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Any
2121

2222
import anyio
23+
import anyio.abc
2324

2425
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT
2526
from mcp.shared.exceptions import MCPError, NoBackChannelError
@@ -101,10 +102,17 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
101102
raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()")
102103
await self._peer._dispatch_notify(method, params)
103104

104-
async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None:
105+
async def run(
106+
self,
107+
on_request: OnRequest,
108+
on_notify: OnNotify,
109+
*,
110+
task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
111+
) -> None:
105112
self._on_request = on_request
106113
self._on_notify = on_notify
107114
self._ready.set()
115+
task_status.started()
108116
await self._closed.wait()
109117

110118
def close(self) -> None:

src/mcp/shared/dispatcher.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable
2121

2222
import anyio
23+
import anyio.abc
2324

2425
from mcp.shared.transport_context import TransportContext
2526

@@ -136,11 +137,21 @@ class Dispatcher(Outbound, Protocol[TransportT_co]):
136137
receive loop, per-request concurrency, and cancellation/progress wiring.
137138
"""
138139

139-
async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None:
140+
async def run(
141+
self,
142+
on_request: OnRequest,
143+
on_notify: OnNotify,
144+
*,
145+
task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
146+
) -> None:
140147
"""Drive the receive loop until the underlying channel closes.
141148
142149
Each inbound request is dispatched to ``on_request`` in its own task;
143150
the returned dict (or raised ``MCPError``) is sent back as the response.
144151
Inbound notifications go to ``on_notify``.
152+
153+
``task_status.started()`` is called once the dispatcher is ready to
154+
accept ``send_request``/``notify`` calls, so callers can use
155+
``await tg.start(dispatcher.run, on_request, on_notify)``.
145156
"""
146157
...

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 226 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,30 @@
88
The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and
99
sees only `(ctx, method, params) -> dict`. Transports sit below and see only
1010
`SessionMessage` reads/writes.
11+
12+
The dispatcher is *mostly* MCP-agnostic — methods/params are opaque strings and
13+
dicts — but it intercepts ``notifications/cancelled`` and
14+
``notifications/progress`` because request correlation, cancellation and
15+
progress are exactly the wiring this layer exists to provide. Those few wire
16+
shapes are extracted with structural ``match`` patterns (no casts, no
17+
``mcp.types`` model coupling); a malformed payload simply fails to match and
18+
the correlation is skipped.
1119
"""
1220

1321
from __future__ import annotations
1422

23+
import contextvars
1524
import logging
16-
from collections.abc import Callable, Mapping
25+
from collections.abc import Awaitable, Callable, Mapping
1726
from dataclasses import dataclass, field
18-
from typing import Any, Generic, Literal, TypeVar, overload
27+
from typing import Any, Generic, Literal, TypeVar, cast, overload
1928

2029
import anyio
30+
import anyio.abc
2131
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2232

2333
from mcp.shared._stream_protocols import ReadStream, WriteStream
24-
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT
34+
from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT
2535
from mcp.shared.exceptions import MCPError, NoBackChannelError
2636
from mcp.shared.message import (
2737
ClientMessageMetadata,
@@ -31,11 +41,14 @@
3141
)
3242
from mcp.shared.transport_context import TransportContext
3343
from mcp.types import (
44+
CONNECTION_CLOSED,
3445
REQUEST_TIMEOUT,
3546
ErrorData,
47+
JSONRPCError,
3648
JSONRPCMessage,
3749
JSONRPCNotification,
3850
JSONRPCRequest,
51+
JSONRPCResponse,
3952
ProgressToken,
4053
RequestId,
4154
)
@@ -141,8 +154,12 @@ def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions |
141154
return None
142155

143156

144-
class JSONRPCDispatcher(Generic[TransportT]):
145-
"""`Dispatcher` over the existing `SessionMessage` stream contract."""
157+
class JSONRPCDispatcher(Dispatcher[TransportT]):
158+
"""`Dispatcher` over the existing `SessionMessage` stream contract.
159+
160+
Inherits the `Dispatcher` Protocol explicitly so pyright checks
161+
conformance at the class definition rather than at first use.
162+
"""
146163

147164
@overload
148165
def __init__(
@@ -171,13 +188,20 @@ def __init__(
171188
) -> None:
172189
self._read_stream = read_stream
173190
self._write_stream = write_stream
174-
self._transport_builder = transport_builder or _default_transport_builder
191+
# The overloads guarantee that when `transport_builder` is omitted,
192+
# `TransportT` is `TransportContext`, so the default is type-correct;
193+
# pyright can't see across overloads, hence the cast.
194+
self._transport_builder = cast(
195+
"Callable[[RequestId | None, MessageMetadata], TransportT]",
196+
transport_builder or _default_transport_builder,
197+
)
175198
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
176199
self._raise_handler_exceptions = raise_handler_exceptions
177200

178201
self._next_id = 0
179202
self._pending: dict[RequestId, _Pending] = {}
180203
self._in_flight: dict[RequestId, _InFlight[TransportT]] = {}
204+
self._tg: anyio.abc.TaskGroup | None = None
181205
self._running = False
182206

183207
async def send_request(
@@ -219,6 +243,11 @@ async def send_request(
219243
meta["progressToken"] = request_id
220244
out_params = {**(out_params or {}), "_meta": meta}
221245

246+
# buffer=1: at most one outcome is ever delivered. A `WouldBlock` from
247+
# `_resolve_pending`/`_fan_out_closed` means the waiter already has an
248+
# outcome and dropping the late/redundant signal is correct. buffer=0
249+
# is unsafe — there's a window between registering `_pending[id]` and
250+
# parking in `receive()` where a close signal would be lost.
222251
send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1)
223252
pending = _Pending(send=send, receive=receive, on_progress=on_progress)
224253
self._pending[request_id] = pending
@@ -264,8 +293,197 @@ async def notify(
264293
msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None)
265294
await self._write(msg, _outbound_metadata(_related_request_id, None))
266295

267-
async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None:
268-
raise NotImplementedError # chunk (b)
296+
async def run(
297+
self,
298+
on_request: OnRequest,
299+
on_notify: OnNotify,
300+
*,
301+
task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
302+
) -> None:
303+
"""Drive the receive loop until the read stream closes.
304+
305+
Each inbound request is handled in its own task in an internal task
306+
group; ``task_status.started()`` fires once that group is open, so
307+
``await tg.start(dispatcher.run, ...)`` resumes when ``send_request``
308+
is usable.
309+
"""
310+
try:
311+
async with anyio.create_task_group() as tg:
312+
self._tg = tg
313+
self._running = True
314+
task_status.started()
315+
async with self._read_stream:
316+
async for item in self._read_stream:
317+
# Duck-typed: `_context_streams.ContextReceiveStream`
318+
# exposes `.last_context` (the sender's contextvars
319+
# snapshot per message). Plain memory streams don't.
320+
sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None)
321+
self._dispatch(item, on_request, on_notify, sender_ctx)
322+
# Read stream EOF: wake any blocked `send_request` waiters now,
323+
# *before* the task group joins, so handlers parked in
324+
# `dctx.send_request()` can unwind and the join doesn't deadlock.
325+
self._running = False
326+
self._fan_out_closed()
327+
finally:
328+
# Covers the cancel/crash paths where the inline fan-out above is
329+
# never reached. Idempotent.
330+
self._running = False
331+
self._tg = None
332+
self._fan_out_closed()
333+
334+
def _dispatch(
335+
self,
336+
item: SessionMessage | Exception,
337+
on_request: OnRequest,
338+
on_notify: OnNotify,
339+
sender_ctx: contextvars.Context | None,
340+
) -> None:
341+
"""Route one inbound item. Synchronous: never awaits.
342+
343+
Everything here is `send_nowait` or `_spawn`. An `await` would let one
344+
slow message head-of-line block the entire read loop.
345+
"""
346+
if isinstance(item, Exception):
347+
logger.debug("transport yielded exception: %r", item)
348+
return
349+
metadata = item.metadata
350+
msg = item.message
351+
match msg:
352+
case JSONRPCRequest():
353+
self._dispatch_request(msg, metadata, on_request, sender_ctx)
354+
case JSONRPCNotification():
355+
self._dispatch_notification(msg, metadata, on_notify, sender_ctx)
356+
case JSONRPCResponse():
357+
self._resolve_pending(msg.id, msg.result)
358+
case JSONRPCError():
359+
# `id` may be None per JSON-RPC (parse error before id known).
360+
self._resolve_pending(msg.id, msg.error)
361+
362+
def _dispatch_request(
363+
self,
364+
req: JSONRPCRequest,
365+
metadata: MessageMetadata,
366+
on_request: OnRequest,
367+
sender_ctx: contextvars.Context | None,
368+
) -> None:
369+
progress_token: ProgressToken | None
370+
match req.params:
371+
case {"_meta": {"progressToken": str() | int() as progress_token}}:
372+
pass
373+
case _:
374+
progress_token = None
375+
transport_ctx = self._transport_builder(req.id, metadata)
376+
dctx = _JSONRPCDispatchContext(
377+
transport=transport_ctx,
378+
_dispatcher=self,
379+
_request_id=req.id,
380+
_progress_token=progress_token,
381+
)
382+
scope = anyio.CancelScope()
383+
self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx)
384+
self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx)
385+
386+
def _dispatch_notification(
387+
self,
388+
msg: JSONRPCNotification,
389+
metadata: MessageMetadata,
390+
on_notify: OnNotify,
391+
sender_ctx: contextvars.Context | None,
392+
) -> None:
393+
if msg.method == "notifications/cancelled":
394+
match msg.params:
395+
case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None:
396+
in_flight.cancelled_by_peer = True
397+
in_flight.dctx.cancel_requested.set()
398+
if self._peer_cancel_mode == "interrupt":
399+
in_flight.scope.cancel()
400+
case _:
401+
pass
402+
return
403+
if msg.method == "notifications/progress":
404+
match msg.params:
405+
case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if (
406+
pending := self._pending.get(token)
407+
) is not None and pending.on_progress is not None:
408+
total = msg.params.get("total")
409+
message = msg.params.get("message")
410+
self._spawn(
411+
pending.on_progress,
412+
float(progress),
413+
float(total) if isinstance(total, int | float) else None,
414+
message if isinstance(message, str) else None,
415+
sender_ctx=sender_ctx,
416+
)
417+
case _:
418+
pass
419+
# fall through: progress is also teed to on_notify
420+
transport_ctx = self._transport_builder(None, metadata)
421+
dctx = _JSONRPCDispatchContext(transport=transport_ctx, _dispatcher=self, _request_id=None)
422+
self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx)
423+
424+
def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None:
425+
pending = self._pending.get(request_id) if request_id is not None else None
426+
if pending is None:
427+
logger.debug("dropping response for unknown/late request id %r", request_id)
428+
return
429+
try:
430+
pending.send.send_nowait(outcome)
431+
except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError):
432+
logger.debug("waiter for request id %r already gone", request_id)
433+
434+
def _spawn(
435+
self,
436+
fn: Callable[..., Awaitable[Any]],
437+
*args: object,
438+
sender_ctx: contextvars.Context | None,
439+
) -> None:
440+
"""Schedule ``fn(*args)`` in the run() task group, propagating the sender's contextvars.
441+
442+
ASGI middleware (auth, OTel) sets contextvars on the request task that
443+
wrote into the read stream. ``Context.run(tg.start_soon, ...)`` makes
444+
the spawned handler inherit *that* context instead of the receive
445+
loop's, so ``auth_context_var`` and OTel spans survive.
446+
"""
447+
assert self._tg is not None
448+
if sender_ctx is not None:
449+
sender_ctx.run(self._tg.start_soon, fn, *args)
450+
else:
451+
self._tg.start_soon(fn, *args)
452+
453+
def _fan_out_closed(self) -> None:
454+
"""Wake every pending ``send_request`` waiter with ``CONNECTION_CLOSED``.
455+
456+
Synchronous (uses ``send_nowait``) because it's called from ``finally``
457+
which may be inside a cancelled scope. Idempotent.
458+
"""
459+
closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed")
460+
for pending in self._pending.values():
461+
try:
462+
pending.send.send_nowait(closed)
463+
except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError):
464+
pass
465+
self._pending.clear()
466+
467+
async def _handle_request(
468+
self,
469+
req: JSONRPCRequest,
470+
dctx: _JSONRPCDispatchContext[TransportT],
471+
scope: anyio.CancelScope,
472+
on_request: OnRequest,
473+
) -> None:
474+
"""Run ``on_request`` for one inbound request and write its response.
475+
476+
Chunk (b): happy-path only. The full exception-to-wire boundary
477+
(MCPError, ValidationError, INTERNAL_ERROR scrubbing, peer-cancel
478+
no-response) lands in chunk (c).
479+
"""
480+
try:
481+
with scope:
482+
result = await on_request(dctx, req.method, req.params)
483+
await self._write(JSONRPCResponse(jsonrpc="2.0", id=req.id, result=result))
484+
finally:
485+
self._in_flight.pop(req.id, None)
486+
dctx.close()
269487

270488
def _allocate_id(self) -> int:
271489
self._next_id += 1

tests/shared/conftest.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,26 @@ def close() -> None:
4848
return client, server, close
4949

5050

51-
_JSONRPC_XFAIL = pytest.mark.xfail(
52-
strict=True,
53-
reason="JSONRPCDispatcher.run() not yet implemented (PR2 chunks b/c)",
54-
)
55-
56-
5751
@pytest.fixture(
5852
params=[
5953
pytest.param(direct_pair, id="direct"),
60-
pytest.param(jsonrpc_pair, id="jsonrpc", marks=_JSONRPC_XFAIL),
54+
pytest.param(jsonrpc_pair, id="jsonrpc"),
6155
]
6256
)
6357
def pair_factory(request: pytest.FixtureRequest) -> PairFactory:
6458
return request.param
6559

6660

67-
__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"]
61+
def xfail_jsonrpc_chunk_c(request: pytest.FixtureRequest, factory: PairFactory) -> None:
62+
"""Apply a strict xfail when running against the JSON-RPC dispatcher.
63+
64+
Use for contract tests that require `_handle_request`'s exception boundary
65+
(PR2 chunk c). Remove once that lands.
66+
"""
67+
if factory is jsonrpc_pair:
68+
request.applymarker(
69+
pytest.mark.xfail(strict=True, reason="needs JSONRPCDispatcher._handle_request exception boundary")
70+
)
71+
72+
73+
__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair", "xfail_jsonrpc_chunk_c"]

0 commit comments

Comments
 (0)