|
| 1 | +"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`. |
| 2 | +
|
| 3 | +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / |
| 4 | +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, |
| 5 | +typed params). One instance per client connection. It: |
| 6 | +
|
| 7 | +* handles the ``initialize`` handshake and populates `Connection` |
| 8 | +* gates requests until initialized (``ping`` exempt) |
| 9 | +* looks up the handler in the server's registry, validates params, builds |
| 10 | + `Context`, runs the middleware chain, returns the result dict |
| 11 | +* drives ``dispatcher.run()`` and the per-connection lifespan |
| 12 | +
|
| 13 | +`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies |
| 14 | +it via additive methods so the existing ``Server.run()`` path is unaffected. |
| 15 | +""" |
| 16 | + |
| 17 | +from __future__ import annotations |
| 18 | + |
| 19 | +import logging |
| 20 | +from collections.abc import Awaitable, Callable, Mapping |
| 21 | +from dataclasses import dataclass, field |
| 22 | +from typing import Any, Generic, Protocol, cast |
| 23 | + |
| 24 | +from pydantic import BaseModel |
| 25 | +from typing_extensions import TypeVar |
| 26 | + |
| 27 | +from mcp.server.connection import Connection |
| 28 | +from mcp.server.context import Context |
| 29 | +from mcp.server.lowlevel.server import NotificationOptions |
| 30 | +from mcp.shared.dispatcher import DispatchContext, Dispatcher |
| 31 | +from mcp.shared.exceptions import MCPError |
| 32 | +from mcp.shared.transport_context import TransportContext |
| 33 | +from mcp.types import ( |
| 34 | + INVALID_REQUEST, |
| 35 | + LATEST_PROTOCOL_VERSION, |
| 36 | + METHOD_NOT_FOUND, |
| 37 | + CallToolRequestParams, |
| 38 | + CompleteRequestParams, |
| 39 | + GetPromptRequestParams, |
| 40 | + Implementation, |
| 41 | + InitializeRequestParams, |
| 42 | + InitializeResult, |
| 43 | + NotificationParams, |
| 44 | + PaginatedRequestParams, |
| 45 | + ProgressNotificationParams, |
| 46 | + ReadResourceRequestParams, |
| 47 | + RequestParams, |
| 48 | + ServerCapabilities, |
| 49 | + SetLevelRequestParams, |
| 50 | + SubscribeRequestParams, |
| 51 | + UnsubscribeRequestParams, |
| 52 | +) |
| 53 | + |
| 54 | +__all__ = ["ServerRegistry", "ServerRunner"] |
| 55 | + |
| 56 | +logger = logging.getLogger(__name__) |
| 57 | + |
| 58 | +LifespanT = TypeVar("LifespanT", default=Any) |
| 59 | +ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext) |
| 60 | + |
| 61 | +Handler = Callable[..., Awaitable[Any]] |
| 62 | +"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely |
| 63 | +so the existing `ServerRequestContext`-based handlers and the new |
| 64 | +`Context`-based handlers both fit during the transition. |
| 65 | +""" |
| 66 | + |
| 67 | +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) |
| 68 | + |
| 69 | +# TODO: remove this lookup once `Server` stores (params_type, handler) in its |
| 70 | +# registry directly. This is scaffolding so ServerRunner can validate params |
| 71 | +# without changing the existing `_request_handlers` dict shape. |
| 72 | +_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = { |
| 73 | + "ping": RequestParams, |
| 74 | + "tools/list": PaginatedRequestParams, |
| 75 | + "tools/call": CallToolRequestParams, |
| 76 | + "prompts/list": PaginatedRequestParams, |
| 77 | + "prompts/get": GetPromptRequestParams, |
| 78 | + "resources/list": PaginatedRequestParams, |
| 79 | + "resources/templates/list": PaginatedRequestParams, |
| 80 | + "resources/read": ReadResourceRequestParams, |
| 81 | + "resources/subscribe": SubscribeRequestParams, |
| 82 | + "resources/unsubscribe": UnsubscribeRequestParams, |
| 83 | + "logging/setLevel": SetLevelRequestParams, |
| 84 | + "completion/complete": CompleteRequestParams, |
| 85 | +} |
| 86 | +"""Spec method → params model. Scaffolding while the lowlevel `Server`'s |
| 87 | +`_request_handlers` stores handler-only; the registry refactor should make this |
| 88 | +the registry's responsibility (or store params types alongside handlers).""" |
| 89 | + |
| 90 | +_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = { |
| 91 | + "notifications/initialized": NotificationParams, |
| 92 | + "notifications/roots/list_changed": NotificationParams, |
| 93 | + "notifications/progress": ProgressNotificationParams, |
| 94 | +} |
| 95 | + |
| 96 | + |
| 97 | +class ServerRegistry(Protocol): |
| 98 | + """The handler registry `ServerRunner` consumes. |
| 99 | +
|
| 100 | + The lowlevel `Server` satisfies this via additive methods. |
| 101 | + """ |
| 102 | + |
| 103 | + @property |
| 104 | + def name(self) -> str: ... |
| 105 | + @property |
| 106 | + def version(self) -> str | None: ... |
| 107 | + |
| 108 | + def get_request_handler(self, method: str) -> Handler | None: ... |
| 109 | + def get_notification_handler(self, method: str) -> Handler | None: ... |
| 110 | + def get_capabilities( |
| 111 | + self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]] |
| 112 | + ) -> ServerCapabilities: ... |
| 113 | + |
| 114 | + |
| 115 | +def _dump_result(result: Any) -> dict[str, Any]: |
| 116 | + if result is None: |
| 117 | + return {} |
| 118 | + if isinstance(result, BaseModel): |
| 119 | + return result.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 120 | + if isinstance(result, dict): |
| 121 | + return cast(dict[str, Any], result) |
| 122 | + raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") |
| 123 | + |
| 124 | + |
| 125 | +@dataclass |
| 126 | +class ServerRunner(Generic[LifespanT, ServerTransportT]): |
| 127 | + """Per-connection orchestrator. One instance per client connection.""" |
| 128 | + |
| 129 | + server: ServerRegistry |
| 130 | + dispatcher: Dispatcher[ServerTransportT] |
| 131 | + lifespan_state: LifespanT |
| 132 | + has_standalone_channel: bool |
| 133 | + stateless: bool = False |
| 134 | + |
| 135 | + connection: Connection = field(init=False) |
| 136 | + _initialized: bool = field(init=False) |
| 137 | + |
| 138 | + def __post_init__(self) -> None: |
| 139 | + self._initialized = self.stateless |
| 140 | + self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel) |
| 141 | + |
| 142 | + async def _on_request( |
| 143 | + self, |
| 144 | + dctx: DispatchContext[TransportContext], |
| 145 | + method: str, |
| 146 | + params: Mapping[str, Any] | None, |
| 147 | + ) -> dict[str, Any]: |
| 148 | + if method == "initialize": |
| 149 | + return self._handle_initialize(params) |
| 150 | + if not self._initialized and method not in _INIT_EXEMPT: |
| 151 | + raise MCPError( |
| 152 | + code=INVALID_REQUEST, |
| 153 | + message=f"Received {method!r} before initialization was complete", |
| 154 | + ) |
| 155 | + handler = self.server.get_request_handler(method) |
| 156 | + if handler is None: |
| 157 | + raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}") |
| 158 | + # TODO: scaffolding — params_type comes from a static lookup until the |
| 159 | + # registry stores it alongside the handler. |
| 160 | + params_type = _PARAMS_FOR_METHOD.get(method, RequestParams) |
| 161 | + # ValidationError propagates; the dispatcher's exception boundary maps |
| 162 | + # it to INVALID_PARAMS. |
| 163 | + typed_params = params_type.model_validate(params or {}) |
| 164 | + ctx = self._make_context(dctx, typed_params) |
| 165 | + result = await handler(ctx, typed_params) |
| 166 | + return _dump_result(result) |
| 167 | + |
| 168 | + async def _on_notify( |
| 169 | + self, |
| 170 | + dctx: DispatchContext[TransportContext], |
| 171 | + method: str, |
| 172 | + params: Mapping[str, Any] | None, |
| 173 | + ) -> None: |
| 174 | + if method == "notifications/initialized": |
| 175 | + self._initialized = True |
| 176 | + self.connection.initialized.set() |
| 177 | + return |
| 178 | + if not self._initialized: |
| 179 | + logger.debug("dropped %s: received before initialization", method) |
| 180 | + return |
| 181 | + handler = self.server.get_notification_handler(method) |
| 182 | + if handler is None: |
| 183 | + logger.debug("no handler for notification %s", method) |
| 184 | + return |
| 185 | + params_type = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams) |
| 186 | + typed_params = params_type.model_validate(params or {}) |
| 187 | + ctx = self._make_context(dctx, typed_params) |
| 188 | + await handler(ctx, typed_params) |
| 189 | + |
| 190 | + def _make_context( |
| 191 | + self, dctx: DispatchContext[TransportContext], typed_params: BaseModel |
| 192 | + ) -> Context[LifespanT, ServerTransportT]: |
| 193 | + # `OnRequest` delivers `DispatchContext[TransportContext]`; this |
| 194 | + # ServerRunner instance was constructed for a specific |
| 195 | + # `ServerTransportT`, so the narrow is safe by construction. |
| 196 | + narrowed = cast(DispatchContext[ServerTransportT], dctx) |
| 197 | + meta = getattr(typed_params, "meta", None) |
| 198 | + return Context(narrowed, lifespan=self.lifespan_state, connection=self.connection, meta=meta) |
| 199 | + |
| 200 | + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: |
| 201 | + init = InitializeRequestParams.model_validate(params or {}) |
| 202 | + self.connection.client_info = init.client_info |
| 203 | + self.connection.client_capabilities = init.capabilities |
| 204 | + # TODO: real version negotiation. This always responds with LATEST, |
| 205 | + # which is wrong — the server should pick the highest version both |
| 206 | + # sides support and compute a per-connection feature set from it. |
| 207 | + # See FOLLOWUPS: "Consolidate per-connection mode/negotiation". |
| 208 | + self.connection.protocol_version = ( |
| 209 | + init.protocol_version if init.protocol_version in {LATEST_PROTOCOL_VERSION} else LATEST_PROTOCOL_VERSION |
| 210 | + ) |
| 211 | + self._initialized = True |
| 212 | + self.connection.initialized.set() |
| 213 | + result = InitializeResult( |
| 214 | + protocol_version=self.connection.protocol_version, |
| 215 | + capabilities=self.server.get_capabilities(NotificationOptions(), {}), |
| 216 | + server_info=Implementation(name=self.server.name, version=self.server.version or "0.0.0"), |
| 217 | + ) |
| 218 | + return _dump_result(result) |
0 commit comments