|
8 | 8 | The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and |
9 | 9 | sees only `(ctx, method, params) -> dict`. Transports sit below and see only |
10 | 10 | `SessionMessage` reads/writes. |
| 11 | +
|
| 12 | +The dispatcher is *mostly* MCP-agnostic — methods/params are opaque strings and |
| 13 | +dicts — but it intercepts ``notifications/cancelled`` and |
| 14 | +``notifications/progress`` because request correlation, cancellation and |
| 15 | +progress are exactly the wiring this layer exists to provide. Those few wire |
| 16 | +shapes are extracted with structural ``match`` patterns (no casts, no |
| 17 | +``mcp.types`` model coupling); a malformed payload simply fails to match and |
| 18 | +the correlation is skipped. |
11 | 19 | """ |
12 | 20 |
|
13 | 21 | from __future__ import annotations |
14 | 22 |
|
| 23 | +import contextvars |
15 | 24 | import logging |
16 | | -from collections.abc import Callable, Mapping |
| 25 | +from collections.abc import Awaitable, Callable, Mapping |
17 | 26 | from dataclasses import dataclass, field |
18 | | -from typing import Any, Generic, Literal, TypeVar, overload |
| 27 | +from typing import Any, Generic, Literal, TypeVar, cast, overload |
19 | 28 |
|
20 | 29 | import anyio |
| 30 | +import anyio.abc |
21 | 31 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
22 | 32 |
|
23 | 33 | from mcp.shared._stream_protocols import ReadStream, WriteStream |
24 | | -from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT |
| 34 | +from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT |
25 | 35 | from mcp.shared.exceptions import MCPError, NoBackChannelError |
26 | 36 | from mcp.shared.message import ( |
27 | 37 | ClientMessageMetadata, |
|
31 | 41 | ) |
32 | 42 | from mcp.shared.transport_context import TransportContext |
33 | 43 | from mcp.types import ( |
| 44 | + CONNECTION_CLOSED, |
34 | 45 | REQUEST_TIMEOUT, |
35 | 46 | ErrorData, |
| 47 | + JSONRPCError, |
36 | 48 | JSONRPCMessage, |
37 | 49 | JSONRPCNotification, |
38 | 50 | JSONRPCRequest, |
| 51 | + JSONRPCResponse, |
39 | 52 | ProgressToken, |
40 | 53 | RequestId, |
41 | 54 | ) |
@@ -141,8 +154,12 @@ def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | |
141 | 154 | return None |
142 | 155 |
|
143 | 156 |
|
144 | | -class JSONRPCDispatcher(Generic[TransportT]): |
145 | | - """`Dispatcher` over the existing `SessionMessage` stream contract.""" |
| 157 | +class JSONRPCDispatcher(Dispatcher[TransportT]): |
| 158 | + """`Dispatcher` over the existing `SessionMessage` stream contract. |
| 159 | +
|
| 160 | + Inherits the `Dispatcher` Protocol explicitly so pyright checks |
| 161 | + conformance at the class definition rather than at first use. |
| 162 | + """ |
146 | 163 |
|
147 | 164 | @overload |
148 | 165 | def __init__( |
@@ -171,13 +188,20 @@ def __init__( |
171 | 188 | ) -> None: |
172 | 189 | self._read_stream = read_stream |
173 | 190 | self._write_stream = write_stream |
174 | | - self._transport_builder = transport_builder or _default_transport_builder |
| 191 | + # The overloads guarantee that when `transport_builder` is omitted, |
| 192 | + # `TransportT` is `TransportContext`, so the default is type-correct; |
| 193 | + # pyright can't see across overloads, hence the cast. |
| 194 | + self._transport_builder = cast( |
| 195 | + "Callable[[RequestId | None, MessageMetadata], TransportT]", |
| 196 | + transport_builder or _default_transport_builder, |
| 197 | + ) |
175 | 198 | self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode |
176 | 199 | self._raise_handler_exceptions = raise_handler_exceptions |
177 | 200 |
|
178 | 201 | self._next_id = 0 |
179 | 202 | self._pending: dict[RequestId, _Pending] = {} |
180 | 203 | self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} |
| 204 | + self._tg: anyio.abc.TaskGroup | None = None |
181 | 205 | self._running = False |
182 | 206 |
|
183 | 207 | async def send_request( |
@@ -219,6 +243,11 @@ async def send_request( |
219 | 243 | meta["progressToken"] = request_id |
220 | 244 | out_params = {**(out_params or {}), "_meta": meta} |
221 | 245 |
|
| 246 | + # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from |
| 247 | + # `_resolve_pending`/`_fan_out_closed` means the waiter already has an |
| 248 | + # outcome and dropping the late/redundant signal is correct. buffer=0 |
| 249 | + # is unsafe — there's a window between registering `_pending[id]` and |
| 250 | + # parking in `receive()` where a close signal would be lost. |
222 | 251 | send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) |
223 | 252 | pending = _Pending(send=send, receive=receive, on_progress=on_progress) |
224 | 253 | self._pending[request_id] = pending |
@@ -264,8 +293,197 @@ async def notify( |
264 | 293 | msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) |
265 | 294 | await self._write(msg, _outbound_metadata(_related_request_id, None)) |
266 | 295 |
|
267 | | - async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: |
268 | | - raise NotImplementedError # chunk (b) |
    async def run(
        self,
        on_request: OnRequest,
        on_notify: OnNotify,
        *,
        task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED,
    ) -> None:
        """Drive the receive loop until the read stream closes.

        Each inbound request is handled in its own task in an internal task
        group; ``task_status.started()`` fires once that group is open, so
        ``await tg.start(dispatcher.run, ...)`` resumes when ``send_request``
        is usable.

        NOTE(review): not reentrant — a second concurrent ``run`` would
        clobber ``self._tg``; callers are assumed to invoke it exactly once.
        Exceptions escaping spawned handler tasks propagate out of the task
        group (the full exception boundary is deferred per `_handle_request`).
        """
        try:
            async with anyio.create_task_group() as tg:
                self._tg = tg
                self._running = True
                task_status.started()
                async with self._read_stream:
                    async for item in self._read_stream:
                        # Duck-typed: `_context_streams.ContextReceiveStream`
                        # exposes `.last_context` (the sender's contextvars
                        # snapshot per message). Plain memory streams don't.
                        sender_ctx: contextvars.Context | None = getattr(self._read_stream, "last_context", None)
                        self._dispatch(item, on_request, on_notify, sender_ctx)
                # Read stream EOF: wake any blocked `send_request` waiters now,
                # *before* the task group joins, so handlers parked in
                # `dctx.send_request()` can unwind and the join doesn't deadlock.
                self._running = False
                self._fan_out_closed()
        finally:
            # Covers the cancel/crash paths where the inline fan-out above is
            # never reached. Idempotent: `_fan_out_closed` clears `_pending`,
            # so a repeat call is a no-op.
            self._running = False
            self._tg = None
            self._fan_out_closed()
| 333 | + |
| 334 | + def _dispatch( |
| 335 | + self, |
| 336 | + item: SessionMessage | Exception, |
| 337 | + on_request: OnRequest, |
| 338 | + on_notify: OnNotify, |
| 339 | + sender_ctx: contextvars.Context | None, |
| 340 | + ) -> None: |
| 341 | + """Route one inbound item. Synchronous: never awaits. |
| 342 | +
|
| 343 | + Everything here is `send_nowait` or `_spawn`. An `await` would let one |
| 344 | + slow message head-of-line block the entire read loop. |
| 345 | + """ |
| 346 | + if isinstance(item, Exception): |
| 347 | + logger.debug("transport yielded exception: %r", item) |
| 348 | + return |
| 349 | + metadata = item.metadata |
| 350 | + msg = item.message |
| 351 | + match msg: |
| 352 | + case JSONRPCRequest(): |
| 353 | + self._dispatch_request(msg, metadata, on_request, sender_ctx) |
| 354 | + case JSONRPCNotification(): |
| 355 | + self._dispatch_notification(msg, metadata, on_notify, sender_ctx) |
| 356 | + case JSONRPCResponse(): |
| 357 | + self._resolve_pending(msg.id, msg.result) |
| 358 | + case JSONRPCError(): |
| 359 | + # `id` may be None per JSON-RPC (parse error before id known). |
| 360 | + self._resolve_pending(msg.id, msg.error) |
| 361 | + |
| 362 | + def _dispatch_request( |
| 363 | + self, |
| 364 | + req: JSONRPCRequest, |
| 365 | + metadata: MessageMetadata, |
| 366 | + on_request: OnRequest, |
| 367 | + sender_ctx: contextvars.Context | None, |
| 368 | + ) -> None: |
| 369 | + progress_token: ProgressToken | None |
| 370 | + match req.params: |
| 371 | + case {"_meta": {"progressToken": str() | int() as progress_token}}: |
| 372 | + pass |
| 373 | + case _: |
| 374 | + progress_token = None |
| 375 | + transport_ctx = self._transport_builder(req.id, metadata) |
| 376 | + dctx = _JSONRPCDispatchContext( |
| 377 | + transport=transport_ctx, |
| 378 | + _dispatcher=self, |
| 379 | + _request_id=req.id, |
| 380 | + _progress_token=progress_token, |
| 381 | + ) |
| 382 | + scope = anyio.CancelScope() |
| 383 | + self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) |
| 384 | + self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) |
| 385 | + |
| 386 | + def _dispatch_notification( |
| 387 | + self, |
| 388 | + msg: JSONRPCNotification, |
| 389 | + metadata: MessageMetadata, |
| 390 | + on_notify: OnNotify, |
| 391 | + sender_ctx: contextvars.Context | None, |
| 392 | + ) -> None: |
| 393 | + if msg.method == "notifications/cancelled": |
| 394 | + match msg.params: |
| 395 | + case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None: |
| 396 | + in_flight.cancelled_by_peer = True |
| 397 | + in_flight.dctx.cancel_requested.set() |
| 398 | + if self._peer_cancel_mode == "interrupt": |
| 399 | + in_flight.scope.cancel() |
| 400 | + case _: |
| 401 | + pass |
| 402 | + return |
| 403 | + if msg.method == "notifications/progress": |
| 404 | + match msg.params: |
| 405 | + case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( |
| 406 | + pending := self._pending.get(token) |
| 407 | + ) is not None and pending.on_progress is not None: |
| 408 | + total = msg.params.get("total") |
| 409 | + message = msg.params.get("message") |
| 410 | + self._spawn( |
| 411 | + pending.on_progress, |
| 412 | + float(progress), |
| 413 | + float(total) if isinstance(total, int | float) else None, |
| 414 | + message if isinstance(message, str) else None, |
| 415 | + sender_ctx=sender_ctx, |
| 416 | + ) |
| 417 | + case _: |
| 418 | + pass |
| 419 | + # fall through: progress is also teed to on_notify |
| 420 | + transport_ctx = self._transport_builder(None, metadata) |
| 421 | + dctx = _JSONRPCDispatchContext(transport=transport_ctx, _dispatcher=self, _request_id=None) |
| 422 | + self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) |
| 423 | + |
| 424 | + def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: |
| 425 | + pending = self._pending.get(request_id) if request_id is not None else None |
| 426 | + if pending is None: |
| 427 | + logger.debug("dropping response for unknown/late request id %r", request_id) |
| 428 | + return |
| 429 | + try: |
| 430 | + pending.send.send_nowait(outcome) |
| 431 | + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): |
| 432 | + logger.debug("waiter for request id %r already gone", request_id) |
| 433 | + |
| 434 | + def _spawn( |
| 435 | + self, |
| 436 | + fn: Callable[..., Awaitable[Any]], |
| 437 | + *args: object, |
| 438 | + sender_ctx: contextvars.Context | None, |
| 439 | + ) -> None: |
| 440 | + """Schedule ``fn(*args)`` in the run() task group, propagating the sender's contextvars. |
| 441 | +
|
| 442 | + ASGI middleware (auth, OTel) sets contextvars on the request task that |
| 443 | + wrote into the read stream. ``Context.run(tg.start_soon, ...)`` makes |
| 444 | + the spawned handler inherit *that* context instead of the receive |
| 445 | + loop's, so ``auth_context_var`` and OTel spans survive. |
| 446 | + """ |
| 447 | + assert self._tg is not None |
| 448 | + if sender_ctx is not None: |
| 449 | + sender_ctx.run(self._tg.start_soon, fn, *args) |
| 450 | + else: |
| 451 | + self._tg.start_soon(fn, *args) |
| 452 | + |
| 453 | + def _fan_out_closed(self) -> None: |
| 454 | + """Wake every pending ``send_request`` waiter with ``CONNECTION_CLOSED``. |
| 455 | +
|
| 456 | + Synchronous (uses ``send_nowait``) because it's called from ``finally`` |
| 457 | + which may be inside a cancelled scope. Idempotent. |
| 458 | + """ |
| 459 | + closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") |
| 460 | + for pending in self._pending.values(): |
| 461 | + try: |
| 462 | + pending.send.send_nowait(closed) |
| 463 | + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): |
| 464 | + pass |
| 465 | + self._pending.clear() |
| 466 | + |
| 467 | + async def _handle_request( |
| 468 | + self, |
| 469 | + req: JSONRPCRequest, |
| 470 | + dctx: _JSONRPCDispatchContext[TransportT], |
| 471 | + scope: anyio.CancelScope, |
| 472 | + on_request: OnRequest, |
| 473 | + ) -> None: |
| 474 | + """Run ``on_request`` for one inbound request and write its response. |
| 475 | +
|
| 476 | + Chunk (b): happy-path only. The full exception-to-wire boundary |
| 477 | + (MCPError, ValidationError, INTERNAL_ERROR scrubbing, peer-cancel |
| 478 | + no-response) lands in chunk (c). |
| 479 | + """ |
| 480 | + try: |
| 481 | + with scope: |
| 482 | + result = await on_request(dctx, req.method, req.params) |
| 483 | + await self._write(JSONRPCResponse(jsonrpc="2.0", id=req.id, result=result)) |
| 484 | + finally: |
| 485 | + self._in_flight.pop(req.id, None) |
| 486 | + dctx.close() |
269 | 487 |
|
270 | 488 | def _allocate_id(self) -> int: |
271 | 489 | self._next_id += 1 |
|
0 commit comments