diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index 4d35f8a90..1c855ae48 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,8 +1,10 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import Any, Generic +from typing import Any, Generic, Protocol +from pydantic import BaseModel from typing_extensions import TypeVar from mcp.server._typed_request import TypedServerRequestMixin @@ -81,3 +83,35 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * if meta: params["_meta"] = meta await self.notify("notifications/message", params) + + +HandlerResult = BaseModel | dict[str, Any] | None +"""What a request handler (or middleware) may return. `ServerRunner` serializes +all three to a result dict.""" + +CallNext = Callable[[], Awaitable[HandlerResult]] + +_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True) + + +class ContextMiddleware(Protocol[_MwLifespanT]): + """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``. + + Runs *inside* `ServerRunner._on_request` after params validation and + `Context` construction. Wraps registered handlers (including ``ping``) but + not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed + outermost-first on `Server.middleware`. + + `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific + types) can be typed `ContextMiddleware[object]` — `Context` is covariant in + `LifespanT`, so it registers on any `Server[L]`. + """ + + async def __call__( + self, + ctx: Context[_MwLifespanT, TransportContext], + method: str, + params: BaseModel, + call_next: CallNext, + ) -> HandlerResult: ... 
diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 59de0ace4..12d911fe7 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -58,7 +58,7 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ServerRequestContext +from mcp.server.context import ContextMiddleware, ServerRequestContext from mcp.server.experimental.request_context import Experimental from mcp.server.lowlevel.experimental import ExperimentalHandlers from mcp.server.models import InitializationOptions @@ -199,6 +199,9 @@ def __init__( ] = {} self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None self._session_manager: StreamableHTTPSessionManager | None = None + # Context-tier middleware consumed by `ServerRunner`. Additive; the + # existing `run()` path ignores it. 
+ self.middleware: list[ContextMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) # Populate internal handler dicts from on_* kwargs @@ -246,6 +249,21 @@ def _has_handler(self, method: str) -> bool: """Check if a handler is registered for the given method.""" return method in self._request_handlers or method in self._notification_handlers + # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ + + def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a request method, or ``None``.""" + return self._request_handlers.get(method) + + def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None: + """Return the handler for a notification method, or ``None``.""" + return self._notification_handlers.get(method) + + @property + def connection_lifespan(self) -> None: + """Per-connection lifespan. ``None`` until the registry refactor adds it.""" + return None + # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py new file mode 100644 index 000000000..79dfc23e0 --- /dev/null +++ b/src/mcp/server/runner.py @@ -0,0 +1,287 @@ +"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. + +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, +typed params). One instance per client connection. 
It: + +* handles the ``initialize`` handshake and populates `Connection` +* gates requests until initialized (``ping`` exempt) +* looks up the handler in the server's registry, validates params, builds + `Context`, runs the middleware chain, returns the result dict +* drives ``dispatcher.run()`` and the per-connection lifespan + +`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies +it via additive methods so the existing ``Server.run()`` path is unaffected. +""" + +from __future__ import annotations + +import logging +from collections.abc import Awaitable, Callable, Mapping, Sequence +from dataclasses import dataclass, field +from functools import partial, reduce +from typing import Any, Generic, Protocol, cast + +import anyio.abc +from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel +from typing_extensions import TypeVar + +from mcp.server.connection import Connection +from mcp.server.context import CallNext, Context, ContextMiddleware +from mcp.server.lowlevel.server import NotificationOptions +from mcp.shared._otel import extract_trace_context, otel_span +from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + CallToolRequestParams, + CompleteRequestParams, + GetPromptRequestParams, + Implementation, + InitializeRequestParams, + InitializeResult, + NotificationParams, + PaginatedRequestParams, + ProgressNotificationParams, + ReadResourceRequestParams, + RequestParams, + ServerCapabilities, + SetLevelRequestParams, + SubscribeRequestParams, + UnsubscribeRequestParams, +) + +__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"] + +logger = logging.getLogger(__name__) + +LifespanT = TypeVar("LifespanT", default=Any) +ServerTransportT = 
TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) + +Handler = Callable[..., Awaitable[Any]] +"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely +so the existing `ServerRequestContext`-based handlers and the new +`Context`-based handlers both fit during the transition. +""" + + +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) + +# TODO: remove this lookup once `Server` stores (params_type, handler) in its +# registry directly. This is scaffolding so ServerRunner can validate params +# without changing the existing `_request_handlers` dict shape. +_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { + "ping": RequestParams, + "tools/list": PaginatedRequestParams, + "tools/call": CallToolRequestParams, + "prompts/list": PaginatedRequestParams, + "prompts/get": GetPromptRequestParams, + "resources/list": PaginatedRequestParams, + "resources/templates/list": PaginatedRequestParams, + "resources/read": ReadResourceRequestParams, + "resources/subscribe": SubscribeRequestParams, + "resources/unsubscribe": UnsubscribeRequestParams, + "logging/setLevel": SetLevelRequestParams, + "completion/complete": CompleteRequestParams, +} +"""Spec method → params model. Scaffolding while the lowlevel `Server`'s +`_request_handlers` stores handler-only; the registry refactor should make this +the registry's responsibility (or store params types alongside handlers).""" + +_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { + "notifications/initialized": NotificationParams, + "notifications/roots/list_changed": NotificationParams, + "notifications/progress": ProgressNotificationParams, +} + + +class ServerRegistry(Protocol): + """The handler registry `ServerRunner` consumes. + + The lowlevel `Server` satisfies this via additive methods. + """ + + @property + def name(self) -> str: ... + @property + def version(self) -> str | None: ... + + @property + def middleware(self) -> Sequence[ContextMiddleware[Any]]: ... 
+ + def get_request_handler(self, method: str) -> Handler | None: ... + def get_notification_handler(self, method: str) -> Handler | None: ... + def get_capabilities( + self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] + ) -> ServerCapabilities: ... + + +def otel_middleware(next_on_request: OnRequest) -> OnRequest: + """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. + + Mirrors the span shape of the existing `Server._handle_request`: span name + ``"MCP handle <method>[ <target>]"``, ``mcp.method.name`` attribute, W3C + trace context extracted from ``params._meta`` (SEP-414), and an ERROR + status if the handler raises. + """ + + async def wrapped( + dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + target: str | None + match params: + case {"name": str() as target}: + pass + case _: + target = None + parent: Any | None + match params: + case {"_meta": {**meta}}: + parent = extract_trace_context(meta) + case _: + parent = None + span_name = f"MCP handle {method}{f' {target}' if target else ''}" + with otel_span(span_name, kind=SpanKind.SERVER, attributes={"mcp.method.name": method}, context=parent) as span: + try: + return await next_on_request(dctx, method, params) + except MCPError as e: + span.set_status(StatusCode.ERROR, e.error.message) + raise + except Exception as e: + span.set_status(StatusCode.ERROR, str(e)) + raise + + return wrapped + + +def _dump_result(result: Any) -> dict[str, Any]: + if result is None: + return {} + if isinstance(result, BaseModel): + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + if isinstance(result, dict): + return cast(dict[str, Any], result) + raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") + + +@dataclass +class ServerRunner(Generic[LifespanT, ServerTransportT]): + """Per-connection orchestrator. 
One instance per client connection.""" + + server: ServerRegistry + dispatcher: Dispatcher[ServerTransportT] + lifespan_state: LifespanT + has_standalone_channel: bool + stateless: bool = False + dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) + + connection: Connection = field(init=False) + _initialized: bool = field(init=False) + + def __post_init__(self) -> None: + self._initialized = self.stateless + self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) + + async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: + """Drive the dispatcher until the underlying channel closes. + + Composes `dispatch_middleware` over `_on_request` and hands the result + to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers + can ``await tg.start(runner.run)`` and resume once the dispatcher is + ready to accept requests. + """ + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + + def _compose_on_request(self) -> OnRequest: + """Wrap `_on_request` in `dispatch_middleware`, outermost-first. + + Dispatch-tier middleware sees raw ``(dctx, method, params) -> dict`` + and wraps everything — initialize, METHOD_NOT_FOUND, validation + failures included. `run()` calls this once and hands the result to + `dispatcher.run()`. 
+ """ + return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) + + async def _on_request( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> dict[str, Any]: + if method == "initialize": + return self._handle_initialize(params) + if not self._initialized and method not in _INIT_EXEMPT: + raise MCPError( + code=INVALID_REQUEST, + message=f"Received {method!r} before initialization was complete", + ) + handler = self.server.get_request_handler(method) + if handler is None: + raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") + # TODO: scaffolding — params_type comes from a static lookup until the + # registry stores it alongside the handler. + params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) + # ValidationError propagates; the dispatcher's exception boundary maps + # it to INVALID_PARAMS. + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + call: CallNext = partial(handler, ctx, typed_params) + for mw in reversed(self.server.middleware): + call = partial(mw, ctx, method, typed_params, call) + return _dump_result(await call()) + + async def _on_notify( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> None: + if method == "notifications/initialized": + self._initialized = True + self.connection.initialized.set() + return + if not self._initialized: + logger.debug("dropped %s: received before initialization", method) + return + handler = self.server.get_notification_handler(method) + if handler is None: + logger.debug("no handler for notification %s", method) + return + params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) + typed_params = params_type.model_validate(params or {}) + ctx = self._make_context(dctx, typed_params) + await handler(ctx, typed_params) + + def _make_context( + self, dctx: 
DispatchContext[TransportContext], typed_params: BaseModel + ) -> Context[LifespanT, ServerTransportT]: + # `OnRequest` delivers `DispatchContext[TransportContext]`; this + # ServerRunner instance was constructed for a specific + # `ServerTransportT`, so the narrow is safe by construction. + narrowed = cast(DispatchContext[ServerTransportT], dctx) + meta = getattr(typed_params, "meta", None) + return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) + + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: + init = InitializeRequestParams.model_validate(params or {}) + self.connection.client_info = init.client_info + self.connection.client_capabilities = init.capabilities + # TODO: real version negotiation. This always responds with LATEST, + # which is wrong — the server should pick the highest version both + # sides support and compute a per-connection feature set from it. + # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". + self.connection.protocol_version = ( + init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION + ) + self._initialized = True + self.connection.initialized.set() + result = InitializeResult( + protocol_version=self.connection.protocol_version, + capabilities=self.server.get_capabilities(NotificationOptions(), {}), + server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), + ) + return _dump_result(result) diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py new file mode 100644 index 000000000..3d2fd84c0 --- /dev/null +++ b/tests/server/test_runner.py @@ -0,0 +1,361 @@ +"""Tests for `ServerRunner`. + +End-to-end over `DirectDispatcher` with a real lowlevel `Server` as the +registry. Covers `_on_request` routing, the initialize handshake, the +init-gate, and that handlers receive a fully-built `Context`. 
+""" + +from typing import Any + +import anyio +import anyio.lowlevel +import pytest + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.server.lowlevel.server import Server +from mcp.server.runner import ServerRunner, otel_middleware +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + INVALID_REQUEST, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + ClientCapabilities, + Implementation, + InitializeRequestParams, + Tool, +) + +from ..shared.test_dispatcher import Recorder, echo_handlers + + +def _initialize_params() -> dict[str, Any]: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test-client", version="1.0"), + ).model_dump(by_alias=True, exclude_none=True) + + +_seen_ctx: list[Context[Any, TransportContext]] = [] +SrvT = Server[dict[str, Any]] + + +@pytest.fixture +def server() -> SrvT: + """A lowlevel Server with one tools/list handler registered.""" + _seen_ctx.clear() + + async def list_tools(ctx: Any, params: Any) -> Any: + # ctx is typed `Any` because Server's on_list_tools kwarg expects the + # legacy ServerRequestContext shape; ServerRunner passes the new + # `Context`. The transition is intentional — Handler is loosely typed. 
+ _seen_ctx.append(ctx) + return {"tools": [Tool(name="t", input_schema={"type": "object"}).model_dump(by_alias=True)]} + + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_gates_requests_before_initialize(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == INVALID_REQUEST + # ping is exempt + assert await client.send_raw_request("ping", None) == {} + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_routes_to_handler_after_initialize_and_builds_context(server: SrvT): + 
client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, Context) + assert ctx.lifespan is None + assert isinstance(ctx.connection, Connection) + assert ctx.transport.kind == "direct" + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_unknown_method_raises_method_not_found(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True # bypass gate for this test + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with 
anyio.fail_after(5): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): + seen: list[tuple[Any, Any]] = [] + + async def on_roots_changed(ctx: Any, params: Any) -> None: + seen.append((ctx, params)) + + server._notification_handlers["notifications/roots/list_changed"] = on_roots_changed + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.notify("notifications/roots/list_changed", None) + # DirectDispatcher delivers synchronously; one yield is enough. + await anyio.lowlevel.checkpoint() + assert len(seen) == 1 + assert isinstance(seen[0][0], Context) + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + # No exception raised; both drops are silent. 
+ tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_dispatch_middleware_wraps_everything_including_initialize(server: SrvT): + seen_methods: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen_methods.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[trace_mw], + ) + c_req, c_notify = echo_handlers(Recorder()) + on_req = runner._compose_on_request() + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, on_req, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT): + seen_methods: list[str] = [] + + async def ctx_mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + seen_methods.append(method) + return await call_next() + + server.middleware.append(ctx_mw) + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize NOT wrapped; ping and 
tools/list ARE wrapped. + assert seen_methods == ["ping", "tools/list"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_server_middleware_runs_outermost_first(server: SrvT): + order: list[str] = [] + + def make_mw(tag: str) -> Any: + async def mw(ctx: Any, method: str, params: Any, call_next: Any) -> Any: + order.append(f"{tag}-in") + result = await call_next() + order.append(f"{tag}-out") + return result + + return mw + + server.middleware.extend([make_mw("a"), make_mw("b")]) + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + runner._initialized = True + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, runner._on_notify) + with anyio.fail_after(5): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_run_drives_dispatcher_end_to_end(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner(server=server, dispatcher=server_d, lifespan_state=None, has_standalone_channel=True) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + init = await client.send_raw_request("initialize", _initialize_params()) + tools = await client.send_raw_request("tools/list", None) + assert init["serverInfo"]["name"] == "test-server" + assert tools["tools"][0]["name"] == "t" + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_run_applies_dispatch_middleware(server: SrvT): + seen: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) 
-> Any: + seen.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[trace_mw], + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + await client.send_raw_request("ping", None) + assert seen == ["initialize", "ping"] + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_otel_middleware_passes_through_result_and_survives_handler_error(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=True, + dispatch_middleware=[otel_middleware], + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + tools = await client.send_raw_request("tools/list", None) + assert tools["tools"][0]["name"] == "t" + with pytest.raises(MCPError): + await client.send_raw_request("nonexistent/method", None) + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + client, server_d = create_direct_dispatcher_pair() + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state=None, + has_standalone_channel=False, + stateless=True, + ) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(server_d.run, runner._on_request, 
runner._on_notify) + with anyio.fail_after(5): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + tg.cancel_scope.cancel()