Skip to content

Commit d84f82a

Browse files
committed
feat: JSONRPCDispatcher outbound side + parametrized contract tests
Chunk (a) of JSONRPCDispatcher: constructor, _Pending/_InFlight/_JSONRPCDispatchContext, send_request/notify and helpers. run() is stubbed. The Dispatcher contract tests are now parametrized over a pair_factory fixture (direct + jsonrpc). The 9 jsonrpc cases are strict-xfail until run()/ _handle_request land in the next commits; once those pass, strict xfail flips to XPASS and forces removal of the marker. Factories return (client, server, close) so running_pair can shut down any implementation uniformly.
1 parent bfb5a77 commit d84f82a

3 files changed

Lines changed: 421 additions & 64 deletions

File tree

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
"""JSON-RPC `Dispatcher` implementation.
2+
3+
Consumes the existing `SessionMessage`-based stream contract that all current
4+
transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation,
5+
the receive loop, per-request task isolation, cancellation/progress wiring, and
6+
the single exception-to-wire boundary.
7+
8+
The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and
9+
sees only `(ctx, method, params) -> dict`. Transports sit below and see only
10+
`SessionMessage` reads/writes.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import logging
16+
from collections.abc import Callable, Mapping
17+
from dataclasses import dataclass, field
18+
from typing import Any, Generic, Literal, TypeVar, overload
19+
20+
import anyio
21+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
22+
23+
from mcp.shared._stream_protocols import ReadStream, WriteStream
24+
from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT
25+
from mcp.shared.exceptions import MCPError, NoBackChannelError
26+
from mcp.shared.message import (
27+
ClientMessageMetadata,
28+
MessageMetadata,
29+
ServerMessageMetadata,
30+
SessionMessage,
31+
)
32+
from mcp.shared.transport_context import TransportContext
33+
from mcp.types import (
34+
REQUEST_TIMEOUT,
35+
ErrorData,
36+
JSONRPCMessage,
37+
JSONRPCNotification,
38+
JSONRPCRequest,
39+
ProgressToken,
40+
RequestId,
41+
)
42+
43+
__all__ = ["JSONRPCDispatcher"]
44+
45+
logger = logging.getLogger(__name__)
46+
47+
TransportT = TypeVar("TransportT", bound=TransportContext)
48+
49+
PeerCancelMode = Literal["interrupt", "signal"]
50+
"""How inbound ``notifications/cancelled`` is applied to a running handler.
51+
52+
``"interrupt"`` (default) cancels the handler's scope. ``"signal"`` only sets
53+
``ctx.cancel_requested`` and lets the handler observe it cooperatively.
54+
"""
55+
56+
TransportBuilder = Callable[[RequestId | None, MessageMetadata], TransportContext]
57+
"""Builds the per-message `TransportContext` from the inbound JSON-RPC id and
58+
the `SessionMessage.metadata` the transport attached. Defaults to a plain
59+
`TransportContext(kind="jsonrpc", can_send_request=True)` when not supplied."""
60+
61+
62+
@dataclass(slots=True)
63+
class _Pending:
64+
"""An outbound request awaiting its response."""
65+
66+
send: MemoryObjectSendStream[dict[str, Any] | ErrorData]
67+
receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData]
68+
on_progress: ProgressFnT | None = None
69+
70+
71+
@dataclass(slots=True)
72+
class _InFlight(Generic[TransportT]):
73+
"""An inbound request currently being handled."""
74+
75+
scope: anyio.CancelScope
76+
dctx: _JSONRPCDispatchContext[TransportT]
77+
cancelled_by_peer: bool = False
78+
79+
80+
@dataclass
81+
class _JSONRPCDispatchContext(Generic[TransportT]):
82+
"""Concrete `DispatchContext` produced for each inbound JSON-RPC message."""
83+
84+
transport: TransportT
85+
_dispatcher: JSONRPCDispatcher[TransportT]
86+
_request_id: RequestId | None
87+
_progress_token: ProgressToken | None = None
88+
_closed: bool = False
89+
cancel_requested: anyio.Event = field(default_factory=anyio.Event)
90+
91+
@property
92+
def can_send_request(self) -> bool:
93+
return self.transport.can_send_request and not self._closed
94+
95+
async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
96+
await self._dispatcher.notify(method, params, _related_request_id=self._request_id)
97+
98+
async def send_request(
99+
self,
100+
method: str,
101+
params: Mapping[str, Any] | None,
102+
opts: CallOptions | None = None,
103+
) -> dict[str, Any]:
104+
if not self.can_send_request:
105+
raise NoBackChannelError(method)
106+
return await self._dispatcher.send_request(method, params, opts, _related_request_id=self._request_id)
107+
108+
async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
109+
if self._progress_token is None:
110+
return
111+
params: dict[str, Any] = {"progressToken": self._progress_token, "progress": progress}
112+
if total is not None:
113+
params["total"] = total
114+
if message is not None:
115+
params["message"] = message
116+
await self.notify("notifications/progress", params)
117+
118+
def close(self) -> None:
119+
self._closed = True
120+
121+
122+
def _default_transport_builder(_request_id: RequestId | None, _meta: MessageMetadata) -> TransportContext:
123+
return TransportContext(kind="jsonrpc", can_send_request=True)
124+
125+
126+
def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata:
127+
"""Choose the `SessionMessage.metadata` for an outgoing request/notification.
128+
129+
`ServerMessageMetadata` tags a server-to-client message with the inbound
130+
request it belongs to (so streamable-HTTP can route it onto that request's
131+
SSE stream). `ClientMessageMetadata` carries resumption hints to the
132+
client transport. ``None`` is the common case.
133+
"""
134+
if related_request_id is not None:
135+
return ServerMessageMetadata(related_request_id=related_request_id)
136+
if opts:
137+
token = opts.get("resumption_token")
138+
on_token = opts.get("on_resumption_token")
139+
if token is not None or on_token is not None:
140+
return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token)
141+
return None
142+
143+
144+
class JSONRPCDispatcher(Generic[TransportT]):
145+
"""`Dispatcher` over the existing `SessionMessage` stream contract."""
146+
147+
@overload
148+
def __init__(
149+
self: JSONRPCDispatcher[TransportContext],
150+
read_stream: ReadStream[SessionMessage | Exception],
151+
write_stream: WriteStream[SessionMessage],
152+
) -> None: ...
153+
@overload
154+
def __init__(
155+
self,
156+
read_stream: ReadStream[SessionMessage | Exception],
157+
write_stream: WriteStream[SessionMessage],
158+
*,
159+
transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT],
160+
peer_cancel_mode: PeerCancelMode = "interrupt",
161+
raise_handler_exceptions: bool = False,
162+
) -> None: ...
163+
def __init__(
164+
self,
165+
read_stream: ReadStream[SessionMessage | Exception],
166+
write_stream: WriteStream[SessionMessage],
167+
*,
168+
transport_builder: Callable[[RequestId | None, MessageMetadata], TransportT] | None = None,
169+
peer_cancel_mode: PeerCancelMode = "interrupt",
170+
raise_handler_exceptions: bool = False,
171+
) -> None:
172+
self._read_stream = read_stream
173+
self._write_stream = write_stream
174+
self._transport_builder = transport_builder or _default_transport_builder
175+
self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode
176+
self._raise_handler_exceptions = raise_handler_exceptions
177+
178+
self._next_id = 0
179+
self._pending: dict[RequestId, _Pending] = {}
180+
self._in_flight: dict[RequestId, _InFlight[TransportT]] = {}
181+
self._running = False
182+
183+
async def send_request(
184+
self,
185+
method: str,
186+
params: Mapping[str, Any] | None,
187+
opts: CallOptions | None = None,
188+
*,
189+
_related_request_id: RequestId | None = None,
190+
) -> dict[str, Any]:
191+
"""Send a JSON-RPC request and await its response.
192+
193+
``_related_request_id`` is set only by `_JSONRPCDispatchContext` when a
194+
handler makes a server-to-client request mid-flight; it routes the
195+
outgoing message onto the correct per-request SSE stream (SHTTP) via
196+
`ServerMessageMetadata`. Top-level callers leave it ``None``.
197+
198+
Raises:
199+
MCPError: The peer responded with a JSON-RPC error; or
200+
``REQUEST_TIMEOUT`` if ``opts["timeout"]`` elapsed; or
201+
``CONNECTION_CLOSED`` if the dispatcher shut down while
202+
awaiting the response.
203+
RuntimeError: Called before ``run()`` has started or after it has
204+
finished.
205+
"""
206+
if not self._running:
207+
raise RuntimeError("JSONRPCDispatcher.send_request called before run() / after close")
208+
opts = opts or {}
209+
request_id = self._allocate_id()
210+
out_params = dict(params) if params is not None else None
211+
on_progress = opts.get("on_progress")
212+
if on_progress is not None:
213+
# The caller wants progress updates. The spec mechanism is: include
214+
# `_meta.progressToken` on the request; the peer echoes that token on
215+
# any `notifications/progress` it sends. We use the request id as the
216+
# token so the receive loop can find this `_Pending.on_progress` by
217+
# `_pending[token]` without a second lookup table.
218+
meta = dict((out_params or {}).get("_meta") or {})
219+
meta["progressToken"] = request_id
220+
out_params = {**(out_params or {}), "_meta": meta}
221+
222+
send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1)
223+
pending = _Pending(send=send, receive=receive, on_progress=on_progress)
224+
self._pending[request_id] = pending
225+
226+
metadata = _outbound_metadata(_related_request_id, opts)
227+
msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params)
228+
try:
229+
await self._write(msg, metadata)
230+
with anyio.fail_after(opts.get("timeout")):
231+
outcome = await receive.receive()
232+
except TimeoutError:
233+
# Spec-recommended courtesy: tell the peer we've given up so it can
234+
# stop work and free resources. v1's BaseSession.send_request does
235+
# NOT do this; it's new behaviour.
236+
await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s")
237+
raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None
238+
except anyio.get_cancelled_exc_class():
239+
# Our caller's scope was cancelled. We're already inside a cancelled
240+
# scope, so any bare `await` here re-raises immediately — shield to
241+
# let the courtesy cancel notification go out before we propagate.
242+
with anyio.CancelScope(shield=True):
243+
await self._cancel_outbound(request_id, "caller cancelled")
244+
raise
245+
finally:
246+
# Always remove the waiter, even on cancel/timeout, so a late
247+
# response from the peer (race) hits a closed stream and is dropped
248+
# in `_dispatch` rather than leaking.
249+
self._pending.pop(request_id, None)
250+
send.close()
251+
receive.close()
252+
253+
if isinstance(outcome, ErrorData):
254+
raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data)
255+
return outcome
256+
257+
async def notify(
258+
self,
259+
method: str,
260+
params: Mapping[str, Any] | None,
261+
*,
262+
_related_request_id: RequestId | None = None,
263+
) -> None:
264+
msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None)
265+
await self._write(msg, _outbound_metadata(_related_request_id, None))
266+
267+
async def run(self, on_request: OnRequest, on_notify: OnNotify) -> None:
268+
raise NotImplementedError # chunk (b)
269+
270+
def _allocate_id(self) -> int:
271+
self._next_id += 1
272+
return self._next_id
273+
274+
async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None:
275+
await self._write_stream.send(SessionMessage(message=message, metadata=metadata))
276+
277+
async def _cancel_outbound(self, request_id: RequestId, reason: str) -> None:
278+
try:
279+
await self.notify("notifications/cancelled", {"requestId": request_id, "reason": reason})
280+
except anyio.BrokenResourceError:
281+
pass
282+
except anyio.ClosedResourceError:
283+
pass

tests/shared/conftest.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""Shared fixtures for `Dispatcher` contract tests.
2+
3+
The `pair_factory` fixture parametrizes contract tests over every `Dispatcher`
4+
implementation, so the same behavioral assertions run against `DirectDispatcher`
5+
(in-memory) and `JSONRPCDispatcher` (over crossed anyio memory streams).
6+
"""
7+
8+
from collections.abc import Callable
9+
10+
import anyio
11+
import pytest
12+
13+
from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair
14+
from mcp.shared.dispatcher import Dispatcher
15+
from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher
16+
from mcp.shared.message import SessionMessage
17+
from mcp.shared.transport_context import TransportContext
18+
19+
DispatcherTriple = tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Callable[[], None]]
20+
PairFactory = Callable[..., DispatcherTriple]
21+
22+
23+
def direct_pair(*, can_send_request: bool = True) -> DispatcherTriple:
24+
client, server = create_direct_dispatcher_pair(can_send_request=can_send_request)
25+
26+
def close() -> None:
27+
client.close()
28+
server.close()
29+
30+
return client, server, close
31+
32+
33+
def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple:
34+
"""Two `JSONRPCDispatcher`s wired over crossed in-memory streams."""
35+
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
36+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
37+
38+
def builder(_rid: object, _meta: object) -> TransportContext:
39+
return TransportContext(kind="jsonrpc", can_send_request=can_send_request)
40+
41+
client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send, transport_builder=builder)
42+
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, transport_builder=builder)
43+
44+
def close() -> None:
45+
for s in (c2s_send, c2s_recv, s2c_send, s2c_recv):
46+
s.close()
47+
48+
return client, server, close
49+
50+
51+
_JSONRPC_XFAIL = pytest.mark.xfail(
52+
strict=True,
53+
reason="JSONRPCDispatcher.run() not yet implemented (PR2 chunks b/c)",
54+
)
55+
56+
57+
@pytest.fixture(
58+
params=[
59+
pytest.param(direct_pair, id="direct"),
60+
pytest.param(jsonrpc_pair, id="jsonrpc", marks=_JSONRPC_XFAIL),
61+
]
62+
)
63+
def pair_factory(request: pytest.FixtureRequest) -> PairFactory:
64+
return request.param
65+
66+
67+
__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"]

0 commit comments

Comments
 (0)