Skip to content

Commit de81288

Browse files
committed
feat: BaseContext
Composition over a DispatchContext: forwards transport/cancel_requested/ send_request/notify/progress and adds meta. Satisfies Outbound so PeerMixin works on it (proven by Peer(bctx).ping() round-tripping). The server Context (next commit) extends this with lifespan/connection; ClientContext will be an alias once ClientSession is reworked.
1 parent 870cb08 commit de81288

File tree

2 files changed

+197
-0
lines changed

2 files changed

+197
-0
lines changed

src/mcp/shared/context.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""`BaseContext` — the user-facing per-request context.
2+
3+
Composition over a `DispatchContext`: forwards the transport metadata, the
4+
back-channel (`send_request`/`notify`), progress reporting, and the cancel
5+
event. Adds `meta` (the inbound request's `_meta` field).
6+
7+
Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context`
8+
mixes that in directly). Shared between client and server: the server's
9+
`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an
10+
alias.
11+
"""
12+
13+
from collections.abc import Mapping
14+
from typing import Any, Generic
15+
16+
import anyio
17+
from typing_extensions import TypeVar
18+
19+
from mcp.shared.dispatcher import CallOptions, DispatchContext
20+
from mcp.shared.transport_context import TransportContext
21+
from mcp.types import RequestParamsMeta
22+
23+
__all__ = ["BaseContext"]
24+
25+
TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext)
26+
27+
28+
class BaseContext(Generic[TransportT]):
29+
"""Per-request context wrapping a `DispatchContext`.
30+
31+
`ServerRunner` (PR4) constructs one per inbound request and passes it to
32+
the user's handler.
33+
"""
34+
35+
def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None:
36+
self._dctx = dctx
37+
self._meta = meta
38+
39+
@property
40+
def transport(self) -> TransportT:
41+
"""Transport-specific metadata for this inbound request."""
42+
return self._dctx.transport
43+
44+
@property
45+
def cancel_requested(self) -> anyio.Event:
46+
"""Set when the peer sends ``notifications/cancelled`` for this request."""
47+
return self._dctx.cancel_requested
48+
49+
@property
50+
def can_send_request(self) -> bool:
51+
"""Whether the back-channel can deliver server-initiated requests."""
52+
return self._dctx.transport.can_send_request
53+
54+
@property
55+
def meta(self) -> RequestParamsMeta | None:
56+
"""The inbound request's ``_meta`` field, if present."""
57+
return self._meta
58+
59+
async def send_request(
60+
self,
61+
method: str,
62+
params: Mapping[str, Any] | None,
63+
opts: CallOptions | None = None,
64+
) -> dict[str, Any]:
65+
"""Send a request to the peer on the back-channel.
66+
67+
Raises:
68+
MCPError: The peer responded with an error.
69+
NoBackChannelError: ``can_send_request`` is ``False``.
70+
"""
71+
return await self._dctx.send_request(method, params, opts)
72+
73+
async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
74+
"""Send a notification to the peer on the back-channel."""
75+
await self._dctx.notify(method, params)
76+
77+
async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None:
78+
"""Report progress for this request, if the peer supplied a progress token.
79+
80+
A no-op when no token was supplied.
81+
"""
82+
await self._dctx.progress(progress, total, message)

tests/shared/test_context.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""Tests for `BaseContext`.
2+
3+
`BaseContext` is composition over a `DispatchContext` — it forwards
4+
``transport``/``cancel_requested``/``send_request``/``notify``/``progress``
5+
and adds ``meta``. It must satisfy `Outbound` so `PeerMixin` works on it.
6+
"""
7+
8+
from collections.abc import Mapping
9+
from typing import Any
10+
11+
import anyio
12+
import pytest
13+
14+
from mcp.shared.context import BaseContext
15+
from mcp.shared.dispatcher import DispatchContext
16+
from mcp.shared.peer import Peer
17+
from mcp.shared.transport_context import TransportContext
18+
19+
from .conftest import direct_pair
20+
from .test_dispatcher import Recorder, echo_handlers, running_pair
21+
22+
DCtx = DispatchContext[TransportContext]
23+
24+
25+
@pytest.mark.anyio
26+
async def test_base_context_forwards_transport_and_cancel_requested():
27+
captured: list[BaseContext[TransportContext]] = []
28+
29+
async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
30+
bctx = BaseContext(ctx)
31+
captured.append(bctx)
32+
return {}
33+
34+
async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_):
35+
with anyio.fail_after(5):
36+
await client.send_request("t", None)
37+
bctx = captured[0]
38+
assert bctx.transport.kind == "direct"
39+
assert isinstance(bctx.cancel_requested, anyio.Event)
40+
assert bctx.can_send_request is True
41+
assert bctx.meta is None
42+
43+
44+
@pytest.mark.anyio
45+
async def test_base_context_send_request_and_notify_forward_to_dispatch_context():
46+
crec = Recorder()
47+
c_req, c_notify = echo_handlers(crec)
48+
49+
async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
50+
bctx = BaseContext(ctx)
51+
sample = await bctx.send_request("sampling/createMessage", {"x": 1})
52+
await bctx.notify("notifications/message", {"level": "info"})
53+
return {"sample": sample}
54+
55+
async with running_pair(
56+
direct_pair,
57+
server_on_request=server_on_request,
58+
client_on_request=c_req,
59+
client_on_notify=c_notify,
60+
) as (client, *_):
61+
with anyio.fail_after(5):
62+
result = await client.send_request("tools/call", None)
63+
await crec.notified.wait()
64+
assert crec.requests == [("sampling/createMessage", {"x": 1})]
65+
assert crec.notifications == [("notifications/message", {"level": "info"})]
66+
assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}}
67+
68+
69+
@pytest.mark.anyio
70+
async def test_base_context_report_progress_invokes_caller_on_progress():
71+
received: list[tuple[float, float | None, str | None]] = []
72+
73+
async def on_progress(progress: float, total: float | None, message: str | None) -> None:
74+
received.append((progress, total, message))
75+
76+
async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
77+
bctx = BaseContext(ctx)
78+
await bctx.report_progress(0.5, total=1.0, message="halfway")
79+
return {}
80+
81+
async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_):
82+
with anyio.fail_after(5):
83+
await client.send_request("t", None, {"on_progress": on_progress})
84+
assert received == [(0.5, 1.0, "halfway")]
85+
86+
87+
@pytest.mark.anyio
88+
async def test_base_context_satisfies_outbound_so_peer_mixin_works():
89+
"""Wrapping a BaseContext in Peer proves it satisfies Outbound structurally."""
90+
91+
async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
92+
bctx = BaseContext(ctx)
93+
await Peer(bctx).ping()
94+
return {}
95+
96+
crec = Recorder()
97+
c_req, c_notify = echo_handlers(crec)
98+
async with running_pair(
99+
direct_pair, server_on_request=server_on_request, client_on_request=c_req, client_on_notify=c_notify
100+
) as (client, *_):
101+
with anyio.fail_after(5):
102+
await client.send_request("t", None)
103+
assert crec.requests == [("ping", None)]
104+
105+
106+
@pytest.mark.anyio
107+
async def test_base_context_meta_holds_supplied_request_params_meta():
108+
async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
109+
bctx = BaseContext(ctx, meta={"progressToken": "abc"})
110+
assert bctx.meta is not None and bctx.meta.get("progressToken") == "abc"
111+
return {}
112+
113+
async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_):
114+
with anyio.fail_after(5):
115+
await client.send_request("t", None)

0 commit comments

Comments
 (0)