Skip to content

Commit 870cb08

Browse files
committed
feat: PeerMixin and Peer wrapper
PeerMixin defines the typed server-to-client request methods (sample with overloads, elicit_form, elicit_url, list_roots, ping) once. Each method constrains `self: Outbound` so any class with send_request/notify can mix it in — pyright checks the host structurally at the call site. The mixin does no capability gating; that's the host's send_request's job. Peer is a trivial standalone wrapper for when you have a bare Outbound (e.g. a dispatcher) and want the typed sugar without writing your own host class. 6 tests over DirectDispatcher, 0.03s.
1 parent eb74b7c commit 870cb08

2 files changed

Lines changed: 322 additions & 0 deletions

File tree

src/mcp/shared/peer.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""Typed MCP request sugar over an `Outbound`.
2+
3+
`PeerMixin` defines the server-to-client request methods (sampling, elicitation,
4+
roots, ping) once. Any class that satisfies `Outbound` (i.e. has `send_request`
5+
and `notify`) can mix it in and get the typed methods for free — `Context`,
6+
`Connection`, `Client`, or the bare `Peer` wrapper below.
7+
8+
The mixin does no capability gating: it builds the params, calls
9+
``self.send_request(method, params)``, and parses the result into the typed
10+
model. Gating (and `NoBackChannelError`) is the host's `send_request`'s job.
11+
"""
12+
13+
from collections.abc import Mapping
14+
from typing import Any, overload
15+
16+
from pydantic import BaseModel
17+
18+
from mcp.shared.dispatcher import CallOptions, Outbound
19+
from mcp.types import (
20+
CreateMessageRequestParams,
21+
CreateMessageResult,
22+
CreateMessageResultWithTools,
23+
ElicitRequestedSchema,
24+
ElicitRequestFormParams,
25+
ElicitRequestURLParams,
26+
ElicitResult,
27+
IncludeContext,
28+
ListRootsResult,
29+
ModelPreferences,
30+
SamplingMessage,
31+
Tool,
32+
ToolChoice,
33+
)
34+
35+
__all__ = ["Peer", "PeerMixin"]
36+
37+
38+
def _dump(model: BaseModel) -> dict[str, Any]:
39+
return model.model_dump(by_alias=True, mode="json", exclude_none=True)
40+
41+
42+
class PeerMixin:
43+
"""Typed server-to-client request methods.
44+
45+
Each method constrains ``self`` to `Outbound` so the mixin can be applied
46+
to anything with ``send_request``/``notify`` — pyright checks the host
47+
class structurally at the call site.
48+
"""
49+
50+
@overload
51+
async def sample(
52+
self: Outbound,
53+
messages: list[SamplingMessage],
54+
*,
55+
max_tokens: int,
56+
system_prompt: str | None = None,
57+
include_context: IncludeContext | None = None,
58+
temperature: float | None = None,
59+
stop_sequences: list[str] | None = None,
60+
metadata: dict[str, Any] | None = None,
61+
model_preferences: ModelPreferences | None = None,
62+
tools: None = None,
63+
tool_choice: ToolChoice | None = None,
64+
opts: CallOptions | None = None,
65+
) -> CreateMessageResult: ...
66+
@overload
67+
async def sample(
68+
self: Outbound,
69+
messages: list[SamplingMessage],
70+
*,
71+
max_tokens: int,
72+
system_prompt: str | None = None,
73+
include_context: IncludeContext | None = None,
74+
temperature: float | None = None,
75+
stop_sequences: list[str] | None = None,
76+
metadata: dict[str, Any] | None = None,
77+
model_preferences: ModelPreferences | None = None,
78+
tools: list[Tool],
79+
tool_choice: ToolChoice | None = None,
80+
opts: CallOptions | None = None,
81+
) -> CreateMessageResultWithTools: ...
82+
async def sample(
83+
self: Outbound,
84+
messages: list[SamplingMessage],
85+
*,
86+
max_tokens: int,
87+
system_prompt: str | None = None,
88+
include_context: IncludeContext | None = None,
89+
temperature: float | None = None,
90+
stop_sequences: list[str] | None = None,
91+
metadata: dict[str, Any] | None = None,
92+
model_preferences: ModelPreferences | None = None,
93+
tools: list[Tool] | None = None,
94+
tool_choice: ToolChoice | None = None,
95+
opts: CallOptions | None = None,
96+
) -> CreateMessageResult | CreateMessageResultWithTools:
97+
"""Send a ``sampling/createMessage`` request to the peer.
98+
99+
Raises:
100+
MCPError: The peer responded with an error.
101+
NoBackChannelError: The host's transport context has no
102+
back-channel for server-initiated requests.
103+
"""
104+
params = CreateMessageRequestParams(
105+
messages=messages,
106+
system_prompt=system_prompt,
107+
include_context=include_context,
108+
temperature=temperature,
109+
max_tokens=max_tokens,
110+
stop_sequences=stop_sequences,
111+
metadata=metadata,
112+
model_preferences=model_preferences,
113+
tools=tools,
114+
tool_choice=tool_choice,
115+
)
116+
result = await self.send_request("sampling/createMessage", _dump(params), opts)
117+
if tools is not None:
118+
return CreateMessageResultWithTools.model_validate(result)
119+
return CreateMessageResult.model_validate(result)
120+
121+
async def elicit_form(
122+
self: Outbound,
123+
message: str,
124+
requested_schema: ElicitRequestedSchema,
125+
opts: CallOptions | None = None,
126+
) -> ElicitResult:
127+
"""Send a form-mode ``elicitation/create`` request.
128+
129+
Raises:
130+
MCPError: The peer responded with an error.
131+
NoBackChannelError: No back-channel for server-initiated requests.
132+
"""
133+
params = ElicitRequestFormParams(message=message, requested_schema=requested_schema)
134+
result = await self.send_request("elicitation/create", _dump(params), opts)
135+
return ElicitResult.model_validate(result)
136+
137+
async def elicit_url(
138+
self: Outbound,
139+
message: str,
140+
url: str,
141+
elicitation_id: str,
142+
opts: CallOptions | None = None,
143+
) -> ElicitResult:
144+
"""Send a URL-mode ``elicitation/create`` request.
145+
146+
Raises:
147+
MCPError: The peer responded with an error.
148+
NoBackChannelError: No back-channel for server-initiated requests.
149+
"""
150+
params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id)
151+
result = await self.send_request("elicitation/create", _dump(params), opts)
152+
return ElicitResult.model_validate(result)
153+
154+
async def list_roots(self: Outbound, opts: CallOptions | None = None) -> ListRootsResult:
155+
"""Send a ``roots/list`` request.
156+
157+
Raises:
158+
MCPError: The peer responded with an error.
159+
NoBackChannelError: No back-channel for server-initiated requests.
160+
"""
161+
result = await self.send_request("roots/list", None, opts)
162+
return ListRootsResult.model_validate(result)
163+
164+
async def ping(self: Outbound, opts: CallOptions | None = None) -> None:
165+
"""Send a ``ping`` request and ignore the result.
166+
167+
Raises:
168+
MCPError: The peer responded with an error.
169+
NoBackChannelError: No back-channel for server-initiated requests.
170+
"""
171+
await self.send_request("ping", None, opts)
172+
173+
174+
class Peer(PeerMixin):
175+
"""Standalone wrapper that gives any `Outbound` the `PeerMixin` sugar.
176+
177+
`Context` and `Connection` mix `PeerMixin` in directly; use `Peer` when
178+
you have a bare dispatcher (or any `Outbound`) and want the typed methods
179+
without writing your own host class.
180+
"""
181+
182+
def __init__(self, outbound: Outbound) -> None:
183+
self._outbound = outbound
184+
185+
async def send_request(
186+
self,
187+
method: str,
188+
params: Mapping[str, Any] | None,
189+
opts: CallOptions | None = None,
190+
) -> dict[str, Any]:
191+
return await self._outbound.send_request(method, params, opts)
192+
193+
async def notify(self, method: str, params: Mapping[str, Any] | None) -> None:
194+
await self._outbound.notify(method, params)

tests/shared/test_peer.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Tests for `PeerMixin` and `Peer`.
2+
3+
Each PeerMixin method is tested by wrapping a `DirectDispatcher` in `Peer`,
4+
calling the typed method, and asserting (a) the right method+params went out
5+
and (b) the return value is the typed result model.
6+
"""
7+
8+
from collections.abc import Mapping
9+
from typing import Any
10+
11+
import anyio
12+
import pytest
13+
14+
from mcp.shared.dispatcher import DispatchContext
15+
from mcp.shared.peer import Peer
16+
from mcp.shared.transport_context import TransportContext
17+
from mcp.types import (
18+
CreateMessageResult,
19+
CreateMessageResultWithTools,
20+
ElicitResult,
21+
ListRootsResult,
22+
SamplingMessage,
23+
TextContent,
24+
Tool,
25+
)
26+
27+
from .conftest import direct_pair
28+
from .test_dispatcher import running_pair
29+
30+
DCtx = DispatchContext[TransportContext]
31+
32+
33+
class _Recorder:
34+
def __init__(self, result: dict[str, Any]) -> None:
35+
self.result = result
36+
self.seen: list[tuple[str, Mapping[str, Any] | None]] = []
37+
38+
async def on_request(self, ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
39+
self.seen.append((method, params))
40+
return self.result
41+
42+
43+
@pytest.mark.anyio
44+
async def test_peer_sample_sends_create_message_and_returns_typed_result():
45+
rec = _Recorder({"role": "assistant", "content": {"type": "text", "text": "hi"}, "model": "m"})
46+
async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_):
47+
peer = Peer(client)
48+
with anyio.fail_after(5):
49+
result = await peer.sample(
50+
[SamplingMessage(role="user", content=TextContent(type="text", text="hello"))],
51+
max_tokens=10,
52+
)
53+
method, params = rec.seen[0]
54+
assert method == "sampling/createMessage"
55+
assert params is not None and params["maxTokens"] == 10
56+
assert isinstance(result, CreateMessageResult)
57+
assert result.model == "m"
58+
59+
60+
@pytest.mark.anyio
61+
async def test_peer_sample_with_tools_returns_with_tools_result():
62+
rec = _Recorder({"role": "assistant", "content": [{"type": "text", "text": "x"}], "model": "m"})
63+
async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_):
64+
peer = Peer(client)
65+
with anyio.fail_after(5):
66+
result = await peer.sample(
67+
[SamplingMessage(role="user", content=TextContent(type="text", text="q"))],
68+
max_tokens=5,
69+
tools=[Tool(name="t", input_schema={"type": "object"})],
70+
)
71+
method, params = rec.seen[0]
72+
assert method == "sampling/createMessage"
73+
assert params is not None and params["tools"][0]["name"] == "t"
74+
assert isinstance(result, CreateMessageResultWithTools)
75+
76+
77+
@pytest.mark.anyio
78+
async def test_peer_elicit_form_sends_elicitation_create_with_form_params():
79+
rec = _Recorder({"action": "accept", "content": {"name": "Max"}})
80+
async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_):
81+
peer = Peer(client)
82+
with anyio.fail_after(5):
83+
result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}})
84+
method, params = rec.seen[0]
85+
assert method == "elicitation/create"
86+
assert params is not None and params["mode"] == "form"
87+
assert params["message"] == "Your name?"
88+
assert isinstance(result, ElicitResult)
89+
90+
91+
@pytest.mark.anyio
92+
async def test_peer_elicit_url_sends_elicitation_create_with_url_params():
93+
rec = _Recorder({"action": "accept"})
94+
async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_):
95+
peer = Peer(client)
96+
with anyio.fail_after(5):
97+
result = await peer.elicit_url("Auth needed", url="https://example.com/auth", elicitation_id="e1")
98+
method, params = rec.seen[0]
99+
assert method == "elicitation/create"
100+
assert params is not None and params["mode"] == "url"
101+
assert params["url"] == "https://example.com/auth"
102+
assert isinstance(result, ElicitResult)
103+
104+
105+
@pytest.mark.anyio
106+
async def test_peer_list_roots_sends_roots_list_and_returns_typed_result():
107+
rec = _Recorder({"roots": [{"uri": "file:///workspace"}]})
108+
async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_):
109+
peer = Peer(client)
110+
with anyio.fail_after(5):
111+
result = await peer.list_roots()
112+
method, _ = rec.seen[0]
113+
assert method == "roots/list"
114+
assert isinstance(result, ListRootsResult)
115+
assert len(result.roots) == 1
116+
assert str(result.roots[0].uri) == "file:///workspace"
117+
118+
119+
@pytest.mark.anyio
120+
async def test_peer_ping_sends_ping_and_returns_none():
121+
rec = _Recorder({})
122+
async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_):
123+
peer = Peer(client)
124+
with anyio.fail_after(5):
125+
result = await peer.ping()
126+
method, _ = rec.seen[0]
127+
assert method == "ping"
128+
assert result is None

0 commit comments

Comments
 (0)