Skip to content

Commit 7dc1527

Browse files
committed
feat: ServerRunner middleware (two-tier) + _on_notify
ContextMiddleware is a Protocol[L] (contravariant) so Server[L].middleware: list[ContextMiddleware[L]] is properly typed. App-specific middleware sees ctx.lifespan: L; reusable middleware typed ContextMiddleware[object] registers on any Server via contravariance. Context's covariance (previous PR3 commit) makes Context[L, ST] <: Context[L, TransportContext] so the chain composes without casts. dispatch_middleware (DispatchMiddleware list on ServerRunner) wraps the raw _on_request and sees everything including initialize/METHOD_NOT_FOUND. server.middleware (ContextMiddleware) runs inside _on_request after validation/ctx-build and wraps registered handlers only. _on_notify routes notifications/initialized (sets the flag), drops before-init and unknown methods, otherwise builds Context and calls the registered handler. 11 tests over DirectDispatcher + a real lowlevel Server.
1 parent 52d0494 commit 7dc1527

File tree

4 files changed

+201
-13
lines changed

4 files changed

+201
-13
lines changed

src/mcp/server/context.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3+
from collections.abc import Awaitable, Callable
34
from dataclasses import dataclass
4-
from typing import Any, Generic
5+
from typing import Any, Generic, Protocol
56

7+
from pydantic import BaseModel
68
from typing_extensions import TypeVar
79

810
from mcp.server._typed_request import TypedServerRequestMixin
@@ -81,3 +83,35 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
8183
if meta:
8284
params["_meta"] = meta
8385
await self.notify("notifications/message", params)
86+
87+
88+
HandlerResult = BaseModel | dict[str, Any] | None
89+
"""What a request handler (or middleware) may return. `ServerRunner` serializes
90+
all three to a result dict."""
91+
92+
CallNext = Callable[[], Awaitable[HandlerResult]]
93+
94+
_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True)
95+
96+
97+
class ContextMiddleware(Protocol[_MwLifespanT]):
98+
"""Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``.
99+
100+
Runs *inside* `ServerRunner._on_request` after params validation and
101+
`Context` construction. Wraps registered handlers (including ``ping``) but
102+
not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed
103+
outermost-first on `Server.middleware`.
104+
105+
`Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific
106+
middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific
107+
types) can be typed `ContextMiddleware[object]` — `Context` is covariant in
108+
`LifespanT`, so it registers on any `Server[L]`.
109+
"""
110+
111+
async def __call__(
112+
self,
113+
ctx: Context[_MwLifespanT, TransportContext],
114+
method: str,
115+
params: BaseModel,
116+
call_next: CallNext,
117+
) -> HandlerResult: ...

src/mcp/server/lowlevel/server.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def main():
5858
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
5959
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
6060
from mcp.server.auth.settings import AuthSettings
61-
from mcp.server.context import ServerRequestContext
61+
from mcp.server.context import ContextMiddleware, ServerRequestContext
6262
from mcp.server.experimental.request_context import Experimental
6363
from mcp.server.lowlevel.experimental import ExperimentalHandlers
6464
from mcp.server.models import InitializationOptions
@@ -199,6 +199,9 @@ def __init__(
199199
] = {}
200200
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
201201
self._session_manager: StreamableHTTPSessionManager | None = None
202+
# Context-tier middleware consumed by `ServerRunner`. Additive; the
203+
# existing `run()` path ignores it.
204+
self.middleware: list[ContextMiddleware[LifespanResultT]] = []
202205
logger.debug("Initializing server %r", name)
203206

204207
# Populate internal handler dicts from on_* kwargs
@@ -256,11 +259,6 @@ def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]]
256259
"""Return the handler for a notification method, or ``None``."""
257260
return self._notification_handlers.get(method)
258261

259-
@property
260-
def middleware(self) -> list[Any]:
261-
"""Context-tier middleware. Empty until the registry refactor adds registration."""
262-
return []
263-
264262
@property
265263
def connection_lifespan(self) -> None:
266264
"""Per-connection lifespan. ``None`` until the registry refactor adds it."""

src/mcp/server/runner.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,18 @@
1717
from __future__ import annotations
1818

1919
import logging
20-
from collections.abc import Awaitable, Callable, Mapping
20+
from collections.abc import Awaitable, Callable, Mapping, Sequence
2121
from dataclasses import dataclass, field
22+
from functools import partial, reduce
2223
from typing import Any, Generic, Protocol, cast
2324

2425
from pydantic import BaseModel
2526
from typing_extensions import TypeVar
2627

2728
from mcp.server.connection import Connection
28-
from mcp.server.context import Context
29+
from mcp.server.context import CallNext, Context, ContextMiddleware
2930
from mcp.server.lowlevel.server import NotificationOptions
30-
from mcp.shared.dispatcher import DispatchContext, Dispatcher
31+
from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest
3132
from mcp.shared.exceptions import MCPError
3233
from mcp.shared.transport_context import TransportContext
3334
from mcp.types import (
@@ -51,7 +52,7 @@
5152
UnsubscribeRequestParams,
5253
)
5354

54-
__all__ = ["ServerRegistry", "ServerRunner"]
55+
__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner"]
5556

5657
logger = logging.getLogger(__name__)
5758

@@ -64,6 +65,7 @@
6465
`Context`-based handlers both fit during the transition.
6566
"""
6667

68+
6769
_INIT_EXEMPT: frozenset[str] = frozenset({"ping"})
6870

6971
# TODO: remove this lookup once `Server` stores (params_type, handler) in its
@@ -105,6 +107,9 @@ def name(self) -> str: ...
105107
@property
106108
def version(self) -> str | None: ...
107109

110+
@property
111+
def middleware(self) -> Sequence[ContextMiddleware[Any]]: ...
112+
108113
def get_request_handler(self, method: str) -> Handler | None: ...
109114
def get_notification_handler(self, method: str) -> Handler | None: ...
110115
def get_capabilities(
@@ -131,6 +136,7 @@ class ServerRunner(Generic[LifespanT, ServerTransportT]):
131136
lifespan_state: LifespanT
132137
has_standalone_channel: bool
133138
stateless: bool = False
139+
dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware])
134140

135141
connection: Connection = field(init=False)
136142
_initialized: bool = field(init=False)
@@ -139,6 +145,16 @@ def __post_init__(self) -> None:
139145
self._initialized = self.stateless
140146
self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel)
141147

148+
def _compose_on_request(self) -> OnRequest:
149+
"""Wrap `_on_request` in `dispatch_middleware`, outermost-first.
150+
151+
Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict``
152+
and wraps everything — initialize, METHOD_NOT_FOUND, validation
153+
failures included. `run()` calls this once and hands the result to
154+
`dispatcher.run()`.
155+
"""
156+
return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request)
157+
142158
async def _on_request(
143159
self,
144160
dctx: DispatchContext[TransportContext],
@@ -162,8 +178,10 @@ async def _on_request(
162178
# it to INVALID_PARAMS.
163179
typed_params = params_type.model_validate(params or {})
164180
ctx = self._make_context(dctx, typed_params)
165-
result = await handler(ctx, typed_params)
166-
return _dump_result(result)
181+
call: CallNext = partial(handler, ctx, typed_params)
182+
for mw in reversed(self.server.middleware):
183+
call = partial(mw, ctx, method, typed_params, call)
184+
return _dump_result(await call())
167185

168186
async def _on_notify(
169187
self,

tests/server/test_runner.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any
99

1010
import anyio
11+
import anyio.lowlevel
1112
import pytest
1213

1314
from mcp.server.connection import Connection
@@ -134,6 +135,143 @@ async def test_runner_unknown_method_raises_method_not_found(server: SrvT):
134135
tg.cancel_scope.cancel()
135136

136137

138+
@pytest.mark.anyio
139+
async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT):
140+
client, server_d = create_direct_dispatcher_pair()
141+
runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True)
142+
c_req, c_notify = echo_handlers(Recorder())
143+
async with anyio.create_task_group() as tg:
144+
await tg.start(client.run, c_req, c_notify)
145+
await tg.start(server_d.run, runner._on_request, runner._on_notify)
146+
with anyio.fail_after(5):
147+
await client.notify("notifications/initialized", None)
148+
await runner.connection.initialized.wait()
149+
assert runner._initialized is True
150+
tg.cancel_scope.cancel()
151+
152+
153+
@pytest.mark.anyio
154+
async def test_runner_on_notify_routes_to_registered_handler(server: SrvT):
155+
seen: list[tuple[Any, Any]] = []
156+
157+
async def on_roots_changed(ctx: Any, params: Any) -> None:
158+
seen.append((ctx, params))
159+
160+
server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed
161+
client, server_d = create_direct_dispatcher_pair()
162+
runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True)
163+
runner._initialized = True
164+
c_req, c_notify = echo_handlers(Recorder())
165+
async with anyio.create_task_group() as tg:
166+
await tg.start(client.run, c_req, c_notify)
167+
await tg.start(server_d.run, runner._on_request, runner._on_notify)
168+
with anyio.fail_after(5):
169+
await client.notify("notifications/roots/list_changed", None)
170+
# DirectDispatcher delivers synchronously; one yield is enough.
171+
await anyio.lowlevel.checkpoint()
172+
assert len(seen) == 1
173+
assert isinstance(seen[0][0], Context)
174+
tg.cancel_scope.cancel()
175+
176+
177+
@pytest.mark.anyio
178+
async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT):
179+
client, server_d = create_direct_dispatcher_pair()
180+
runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True)
181+
c_req, c_notify = echo_handlers(Recorder())
182+
async with anyio.create_task_group() as tg:
183+
await tg.start(client.run, c_req, c_notify)
184+
await tg.start(server_d.run, runner._on_request, runner._on_notify)
185+
with anyio.fail_after(5):
186+
await client.notify("notifications/roots/list_changed", None) # before init: dropped
187+
await client.notify("notifications/initialized", None)
188+
await client.notify("notifications/unknown", None) # no handler: dropped
189+
# No exception raised; both drops are silent.
190+
tg.cancel_scope.cancel()
191+
192+
193+
@pytest.mark.anyio
194+
async def test_runner_dispatch_middleware_wraps_everything_including_initialize(server: SrvT):
195+
seen_methods: list[str] = []
196+
197+
def trace_mw(next_on_request: Any) -> Any:
198+
async def wrapped(dctx: Any, method: str, params: Any) -> Any:
199+
seen_methods.append(method)
200+
return await next_on_request(dctx, method, params)
201+
202+
return wrapped
203+
204+
client, server_d = create_direct_dispatcher_pair()
205+
runner = ServerRunner(
206+
server=server,
207+
dispatcher=server_d,
208+
lifespan_state=None,
209+
has_standalone_channel=True,
210+
dispatch_middleware=[trace_mw],
211+
)
212+
c_req, c_notify = echo_handlers(Recorder())
213+
on_req = runner._compose_on_request()
214+
async with anyio.create_task_group() as tg:
215+
await tg.start(client.run, c_req, c_notify)
216+
await tg.start(server_d.run, on_req, runner._on_notify)
217+
with anyio.fail_after(5):
218+
await client.send_raw_request("initialize", _initialize_params())
219+
await client.send_raw_request("tools/list", None)
220+
assert seen_methods == ["initialize", "tools/list"]
221+
tg.cancel_scope.cancel()
222+
223+
224+
@pytest.mark.anyio
225+
async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT):
226+
seen_methods: list[str] = []
227+
228+
async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any:
229+
seen_methods.append(method)
230+
return await call_next()
231+
232+
server.middleware.append(ctx_mw)
233+
client, server_d = create_direct_dispatcher_pair()
234+
runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True)
235+
c_req, c_notify = echo_handlers(Recorder())
236+
async with anyio.create_task_group() as tg:
237+
await tg.start(client.run, c_req, c_notify)
238+
await tg.start(server_d.run, runner._on_request, runner._on_notify)
239+
with anyio.fail_after(5):
240+
await client.send_raw_request("initialize", _initialize_params())
241+
await client.send_raw_request("ping", None)
242+
await client.send_raw_request("tools/list", None)
243+
# initialize NOT wrapped; ping and tools/list ARE wrapped.
244+
assert seen_methods == ["ping", "tools/list"]
245+
tg.cancel_scope.cancel()
246+
247+
248+
@pytest.mark.anyio
249+
async def test_runner_server_middleware_runs_outermost_first(server: SrvT):
250+
order: list[str] = []
251+
252+
def make_mw(tag: str) -> Any:
253+
async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any:
254+
order.append(f"{tag}-in")
255+
result = await call_next()
256+
order.append(f"{tag}-out")
257+
return result
258+
259+
return mw
260+
261+
server.middleware.extend([make_mw("a"), make_mw("b")])
262+
client, server_d = create_direct_dispatcher_pair()
263+
runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True)
264+
runner._initialized = True
265+
c_req, c_notify = echo_handlers(Recorder())
266+
async with anyio.create_task_group() as tg:
267+
await tg.start(client.run, c_req, c_notify)
268+
await tg.start(server_d.run, runner._on_request, runner._on_notify)
269+
with anyio.fail_after(5):
270+
await client.send_raw_request("tools/list", None)
271+
assert order == ["a-in", "b-in", "b-out", "a-out"]
272+
tg.cancel_scope.cancel()
273+
274+
137275
@pytest.mark.anyio
138276
async def test_runner_stateless_skips_init_gate(server: SrvT):
139277
client, server_d = create_direct_dispatcher_pair()

0 commit comments

Comments
 (0)