|
| 1 | +"""JSON-RPC `Dispatcher` implementation. |
| 2 | +
|
| 3 | +Consumes the existing `SessionMessage`-based stream contract that all current |
| 4 | +transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, |
| 5 | +the receive loop, per-request task isolation, cancellation/progress wiring, and |
| 6 | +the single exception-to-wire boundary. |
| 7 | +
|
| 8 | +The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and |
| 9 | +sees only `(ctx, method, params) -> dict`. Transports sit below and see only |
| 10 | +`SessionMessage` reads/writes. |
| 11 | +""" |
| 12 | + |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +import logging |
| 16 | +from collections.abc import Callable, Mapping |
| 17 | +from dataclasses import dataclass, field |
| 18 | +from typing import Any, Generic, Literal, TypeVar, overload |
| 19 | + |
| 20 | +import anyio |
| 21 | +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
| 22 | + |
| 23 | +from mcp.shared._stream_protocols import ReadStream, WriteStream |
| 24 | +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT |
| 25 | +from mcp.shared.exceptions import MCPError, NoBackChannelError |
| 26 | +from mcp.shared.message import ( |
| 27 | + ClientMessageMetadata, |
| 28 | + MessageMetadata, |
| 29 | + ServerMessageMetadata, |
| 30 | + SessionMessage, |
| 31 | +) |
| 32 | +from mcp.shared.transport_context import TransportContext |
| 33 | +from mcp.types import ( |
| 34 | + REQUEST_TIMEOUT, |
| 35 | + ErrorData, |
| 36 | + JSONRPCMessage, |
| 37 | + JSONRPCNotification, |
| 38 | + JSONRPCRequest, |
| 39 | + ProgressToken, |
| 40 | + RequestId, |
| 41 | +) |
| 42 | + |
| 43 | +__all__ = ["JSONRPCDispatcher"] |
| 44 | + |
| 45 | +logger = logging.getLogger(__name__) |
| 46 | + |
| 47 | +TransportT = TypeVar("TransportT", bound=TransportContext) |
| 48 | + |
| 49 | +PeerCancelMode = Literal["interrupt", "signal"] |
| 50 | +"""How inbound ``notifications/cancelled`` is applied to a running handler. |
| 51 | +
|
| 52 | +``"interrupt"`` (default) cancels the handler's scope. ``"signal"`` only sets |
| 53 | +``ctx.cancel_requested`` and lets the handler observe it cooperatively. |
| 54 | +""" |
| 55 | + |
| 56 | +TransportBuilder = Callable[[RequestId | None, MessageMetadata], TransportContext] |
| 57 | +"""Builds the per-message `TransportContext` from the inbound JSON-RPC id and |
| 58 | +the `SessionMessage.metadata` the transport attached. Defaults to a plain |
| 59 | +`TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied.""" |
| 60 | + |
| 61 | + |
| 62 | +@dataclass(slots=True) |
| 63 | +class _Pending: |
| 64 | + """An outbound request awaiting its response.""" |
| 65 | + |
| 66 | + send: MemoryObjectSendStream[dict[str, Any] | ErrorData] |
| 67 | + receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData] |
| 68 | + on_progress: ProgressFnT | None = None |
| 69 | + |
| 70 | + |
| 71 | +@dataclass(slots=True) |
| 72 | +class _InFlight(Generic[TransportT]): |
| 73 | + """An inbound request currently being handled.""" |
| 74 | + |
| 75 | + scope: anyio.CancelScope |
| 76 | + dctx: _JSONRPCDispatchContext[TransportT] |
| 77 | + cancelled_by_peer: bool = False |
| 78 | + |
| 79 | + |
| 80 | +@dataclass |
| 81 | +class _JSONRPCDispatchContext(Generic[TransportT]): |
| 82 | + """Concrete `DispatchContext` produced for each inbound JSON-RPC message.""" |
| 83 | + |
| 84 | + transport: TransportT |
| 85 | + _dispatcher: JSONRPCDispatcher[TransportT] |
| 86 | + _request_id: RequestId | None |
| 87 | + _progress_token: ProgressToken | None = None |
| 88 | + _closed: bool = False |
| 89 | + cancel_requested: anyio.Event = field(default_factory=anyio.Event) |
| 90 | + |
| 91 | + @property |
| 92 | + def can_send_request(self) -> bool: |
| 93 | + return self.transport.can_send_request and not self._closed |
| 94 | + |
| 95 | + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: |
| 96 | + await self._dispatcher.notify(method, params, _related_request_id=self._request_id) |
| 97 | + |
| 98 | + async def send_request( |
| 99 | + self, |
| 100 | + method: str, |
| 101 | + params: Mapping[str, Any] | None, |
| 102 | + opts: CallOptions | None = None, |
| 103 | + ) -> dict[str, Any]: |
| 104 | + if not self.can_send_request: |
| 105 | + raise NoBackChannelError(method) |
| 106 | + return await self._dispatcher.send_request(method, params, opts, _related_request_id=self._request_id) |
| 107 | + |
| 108 | + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: |
| 109 | + if self._progress_token is None: |
| 110 | + return |
| 111 | + params: dict[str, Any] = {"progressToken": self._progress_token, "progress": progress} |
| 112 | + if total is not None: |
| 113 | + params["total"] = total |
| 114 | + if message is not None: |
| 115 | + params["message"] = message |
| 116 | + await self.notify("notifications/progress", params) |
| 117 | + |
| 118 | + def close(self) -> None: |
| 119 | + self._closed = True |
| 120 | + |
| 121 | + |
| 122 | +def _default_transport_builder(_request_id: RequestId | None, _meta: MessageMetadata) -> TransportContext: |
| 123 | + return TransportContext(kind="jsonrpc", can_send_request=True) |
| 124 | + |
| 125 | + |
| 126 | +def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: |
| 127 | + """Choose the `SessionMessage.metadata` for an outgoing request/notification. |
| 128 | +
|
| 129 | + `ServerMessageMetadata` tags a server-to-client message with the inbound |
| 130 | + request it belongs to (so streamable-HTTP can route it onto that request's |
| 131 | + SSE stream). `ClientMessageMetadata` carries resumption hints to the |
| 132 | + client transport. ``None`` is the common case. |
| 133 | + """ |
| 134 | + if related_request_id is not None: |
| 135 | + return ServerMessageMetadata(related_request_id=related_request_id) |
| 136 | + if opts: |
| 137 | + token = opts.get("resumption_token") |
| 138 | + on_token = opts.get("on_resumption_token") |
| 139 | + if token is not None or on_token is not None: |
| 140 | + return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) |
| 141 | + return None |
| 142 | + |
| 143 | + |
| 144 | +class JSONRPCDispatcher(Generic[TransportT]): |
| 145 | + """`Dispatcher` over the existing `SessionMessage` stream contract.""" |
| 146 | + |
| 147 | + @overload |
| 148 | + def __init__( |
| 149 | + self: JSONRPCDispatcher[TransportContext], |
| 150 | + read_stream: ReadStream[SessionMessage | Exception], |
| 151 | + write_stream: WriteStream[SessionMessage], |
| 152 | + ) -> None: ... |
| 153 | + @overload |
| 154 | + def __init__( |
| 155 | + self, |
| 156 | + read_stream: ReadStream[SessionMessage | Exception], |
| 157 | + write_stream: WriteStream[SessionMessage], |
| 158 | + *, |
| 159 | + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT], |
| 160 | + peer_cancel_mode: PeerCancelMode = "interrupt", |
| 161 | + raise_handler_exceptions: bool = False, |
| 162 | + ) -> None: ... |
| 163 | + def __init__( |
| 164 | + self, |
| 165 | + read_stream: ReadStream[SessionMessage | Exception], |
| 166 | + write_stream: WriteStream[SessionMessage], |
| 167 | + *, |
| 168 | + transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT] | None = None, |
| 169 | + peer_cancel_mode: PeerCancelMode = "interrupt", |
| 170 | + raise_handler_exceptions: bool = False, |
| 171 | + ) -> None: |
| 172 | + self._read_stream = read_stream |
| 173 | + self._write_stream = write_stream |
| 174 | + self._transport_builder = transport_builder or _default_transport_builder |
| 175 | + self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode |
| 176 | + self._raise_handler_exceptions = raise_handler_exceptions |
| 177 | + |
| 178 | + self._next_id = 0 |
| 179 | + self._pending: dict[RequestId, _Pending] = {} |
| 180 | + self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} |
| 181 | + self._running = False |
| 182 | + |
| 183 | + async def send_request( |
| 184 | + self, |
| 185 | + method: str, |
| 186 | + params: Mapping[str, Any] | None, |
| 187 | + opts: CallOptions | None = None, |
| 188 | + *, |
| 189 | + _related_request_id: RequestId | None = None, |
| 190 | + ) -> dict[str, Any]: |
| 191 | + """Send a JSON-RPC request and await its response. |
| 192 | +
|
| 193 | + ``_related_request_id`` is set only by `_JSONRPCDispatchContext` when a |
| 194 | + handler makes a server-to-client request mid-flight; it routes the |
| 195 | + outgoing message onto the correct per-request SSE stream (SHTTP) via |
| 196 | + `ServerMessageMetadata`. Top-level callers leave it ``None``. |
| 197 | +
|
| 198 | + Raises: |
| 199 | + MCPError: The peer responded with a JSON-RPC error; or |
| 200 | + ``REQUEST_TIMEOUT`` if ``opts["timeout"]`` elapsed; or |
| 201 | + ``CONNECTION_CLOSED`` if the dispatcher shut down while |
| 202 | + awaiting the response. |
| 203 | + RuntimeError: Called before ``run()`` has started or after it has |
| 204 | + finished. |
| 205 | + """ |
| 206 | + if not self._running: |
| 207 | + raise RuntimeError("JSONRPCDispatcher.send_request called before run() / after close") |
| 208 | + opts = opts or {} |
| 209 | + request_id = self._allocate_id() |
| 210 | + out_params = dict(params) if params is not None else None |
| 211 | + on_progress = opts.get("on_progress") |
| 212 | + if on_progress is not None: |
| 213 | + # The caller wants progress updates. The spec mechanism is: include |
| 214 | + # `_meta.progressToken` on the request; the peer echoes that token on |
| 215 | + # any `notifications/progress` it sends. We use the request id as the |
| 216 | + # token so the receive loop can find this `_Pending.on_progress` by |
| 217 | + # `_pending[token]` without a second lookup table. |
| 218 | + meta = dict((out_params or {}).get("_meta") or {}) |
| 219 | + meta["progressToken"] = request_id |
| 220 | + out_params = {**(out_params or {}), "_meta": meta} |
| 221 | + |
| 222 | + send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) |
| 223 | + pending = _Pending(send=send, receive=receive, on_progress=on_progress) |
| 224 | + self._pending[request_id] = pending |
| 225 | + |
| 226 | + metadata = _outbound_metadata(_related_request_id, opts) |
| 227 | + msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) |
| 228 | + try: |
| 229 | + await self._write(msg, metadata) |
| 230 | + with anyio.fail_after(opts.get("timeout")): |
| 231 | + outcome = await receive.receive() |
| 232 | + except TimeoutError: |
| 233 | + # Spec-recommended courtesy: tell the peer we've given up so it can |
| 234 | + # stop work and free resources. v1's BaseSession.send_request does |
| 235 | + # NOT do this; it's new behaviour. |
| 236 | + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s") |
| 237 | + raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None |
| 238 | + except anyio.get_cancelled_exc_class(): |
| 239 | + # Our caller's scope was cancelled. We're already inside a cancelled |
| 240 | + # scope, so any bare `await` here re-raises immediately — shield to |
| 241 | + # let the courtesy cancel notification go out before we propagate. |
| 242 | + with anyio.CancelScope(shield=True): |
| 243 | + await self._cancel_outbound(request_id, "caller cancelled") |
| 244 | + raise |
| 245 | + finally: |
| 246 | + # Always remove the waiter, even on cancel/timeout, so a late |
| 247 | + # response from the peer (race) hits a closed stream and is dropped |
| 248 | + # in `_dispatch` rather than leaking. |
| 249 | + self._pending.pop(request_id, None) |
| 250 | + send.close() |
| 251 | + receive.close() |
| 252 | + |
| 253 | + if isinstance(outcome, ErrorData): |
| 254 | + raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data) |
| 255 | + return outcome |
| 256 | + |
| 257 | + async def notify( |
| 258 | + self, |
| 259 | + method: str, |
| 260 | + params: Mapping[str, Any] | None, |
| 261 | + *, |
| 262 | + _related_request_id: RequestId | None = None, |
| 263 | + ) -> None: |
| 264 | + msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) |
| 265 | + await self._write(msg, _outbound_metadata(_related_request_id, None)) |
| 266 | + |
| 267 | + async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None: |
| 268 | + raise NotImplementedError # chunk (b) |
| 269 | + |
| 270 | + def _allocate_id(self) -> int: |
| 271 | + self._next_id += 1 |
| 272 | + return self._next_id |
| 273 | + |
| 274 | + async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: |
| 275 | + await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) |
| 276 | + |
| 277 | + async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None: |
| 278 | + try: |
| 279 | + await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason}) |
| 280 | + except anyio.BrokenResourceError: |
| 281 | + pass |
| 282 | + except anyio.ClosedResourceError: |
| 283 | + pass |
0 commit comments