Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/mcp/server/context.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, Generic
from typing import Any, Generic, Protocol

from pydantic import BaseModel
from typing_extensions import TypeVar

from mcp.server._typed_request import TypedServerRequestMixin
Expand Down Expand Up @@ -81,3 +83,35 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
if meta:
params["_meta"] = meta
await self.notify("notifications/message", params)


# The three shapes a handler may hand back: a pydantic model, a plain dict,
# or nothing at all.
HandlerResult = BaseModel | dict[str, Any] | None
"""What a request handler (or middleware) may return. `ServerRunner` serializes
all three to a result dict."""

# Zero-argument async continuation passed to a middleware as ``call_next``;
# awaiting it yields a `HandlerResult`.
CallNext = Callable[[], Awaitable[HandlerResult]]

# Lifespan type parameter of `ContextMiddleware`, declared contravariant.
_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True)


class ContextMiddleware(Protocol[_MwLifespanT]):
    """Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``.

    Runs *inside* `ServerRunner._on_request` after params validation and
    `Context` construction. Wraps registered handlers (including ``ping``) but
    not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed
    outermost-first on `Server.middleware`.

    `Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific
    middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific
    types) can be typed `ContextMiddleware[object]` — `Context` is covariant in
    `LifespanT`, so it registers on any `Server[L]`.

    Call arguments:
        ctx: The per-request `Context`.
        method: The request method name.
        params: The request params, already validated into a pydantic model.
        call_next: Zero-argument continuation; awaiting it produces a
            `HandlerResult`.
    """

    # Structural signature only: any async callable with this shape satisfies
    # the protocol.
    async def __call__(
        self,
        ctx: Context[_MwLifespanT, TransportContext],
        method: str,
        params: BaseModel,
        call_next: CallNext,
    ) -> HandlerResult: ...
20 changes: 19 additions & 1 deletion src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ async def main():
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
from mcp.server.auth.settings import AuthSettings
from mcp.server.context import ServerRequestContext
from mcp.server.context import ContextMiddleware, ServerRequestContext
from mcp.server.experimental.request_context import Experimental
from mcp.server.lowlevel.experimental import ExperimentalHandlers
from mcp.server.models import InitializationOptions
Expand Down Expand Up @@ -199,6 +199,9 @@ def __init__(
] = {}
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
self._session_manager: StreamableHTTPSessionManager | None = None
# Context-tier middleware consumed by `ServerRunner`. Additive; the
# existing `run()` path ignores it.
self.middleware: list[ContextMiddleware[LifespanResultT]] = []
logger.debug("Initializing server %r", name)

# Populate internal handler dicts from on_* kwargs
Expand Down Expand Up @@ -246,6 +249,21 @@ def _has_handler(self, method: str) -> bool:
"""Check if a handler is registered for the given method."""
return method in self._request_handlers or method in self._notification_handlers

# --- ServerRegistry protocol (consumed by ServerRunner) ------------------

def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
    """Look up the request handler registered under ``method``.

    Returns:
        The registered handler, or ``None`` when nothing is registered for
        that method name.
    """
    registered = self._request_handlers
    return registered.get(method)

def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
    """Look up the notification handler registered under ``method``.

    Returns:
        The registered handler, or ``None`` when nothing is registered for
        that notification name.
    """
    registered = self._notification_handlers
    return registered.get(method)

@property
def connection_lifespan(self) -> None:
    """Per-connection lifespan. ``None`` until the registry refactor adds it."""
    # Placeholder: there is nothing to look up yet, so this is a constant.
    return None

# TODO: Rethink capabilities API. Currently capabilities are derived from registered
# handlers but require NotificationOptions to be passed externally for list_changed
# flags, and experimental_capabilities as a separate dict. Consider deriving capabilities
Expand Down
287 changes: 287 additions & 0 deletions src/mcp/server/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
"""`ServerRunner` — per-connection orchestrator over a `Dispatcher`.

`ServerRunner` is the bridge between the dispatcher layer (`on_request` /
`on_notify`, untyped dicts) and the user's handler layer (typed `Context`,
typed params). One instance per client connection. It:

* handles the ``initialize`` handshake and populates `Connection`
* gates requests until initialized (``ping`` exempt)
* looks up the handler in the server's registry, validates params, builds
`Context`, runs the middleware chain, returns the result dict
* drives ``dispatcher.run()`` (the per-connection lifespan is not wired in yet)

`ServerRunner` consumes any `ServerRegistry` — the lowlevel `Server` satisfies
it via additive methods so the existing ``Server.run()`` path is unaffected.
"""

from __future__ import annotations

import logging
from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import partial, reduce
from typing import Any, Generic, Protocol, cast

import anyio.abc
from opentelemetry.trace import SpanKind, StatusCode
from pydantic import BaseModel
from typing_extensions import TypeVar

from mcp.server.connection import Connection
from mcp.server.context import CallNext, Context, ContextMiddleware
from mcp.server.lowlevel.server import NotificationOptions
from mcp.shared._otel import extract_trace_context, otel_span
from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest
from mcp.shared.exceptions import MCPError
from mcp.shared.transport_context import TransportContext
from mcp.types import (
INVALID_REQUEST,
LATEST_PROTOCOL_VERSION,
METHOD_NOT_FOUND,
CallToolRequestParams,
CompleteRequestParams,
GetPromptRequestParams,
Implementation,
InitializeRequestParams,
InitializeResult,
NotificationParams,
PaginatedRequestParams,
ProgressNotificationParams,
ReadResourceRequestParams,
RequestParams,
ServerCapabilities,
SetLevelRequestParams,
SubscribeRequestParams,
UnsubscribeRequestParams,
)

# Public names of this module. `CallNext` and `ContextMiddleware` are defined
# in `mcp.server.context` and re-exported here.
__all__ = ["CallNext", "ContextMiddleware", "ServerRegistry", "ServerRunner", "otel_middleware"]

logger = logging.getLogger(__name__)

# Type parameters of `ServerRunner`: the lifespan state type and the
# transport-context type (bounded by, and defaulting to, `TransportContext`).
LifespanT = TypeVar("LifespanT", default=Any)
ServerTransportT = TypeVar("ServerTransportT", bound=TransportContext, default=TransportContext)

Handler = Callable[..., Awaitable[Any]]
"""A request/notification handler: ``(ctx, params) -> result``. Typed loosely
so the existing `ServerRequestContext`-based handlers and the new
`Context`-based handlers both fit during the transition.
"""


# Request methods that `ServerRunner._on_request` lets through before the
# connection is initialized.
_INIT_EXEMPT: frozenset[str] = frozenset({"ping"})

# TODO: remove this lookup once `Server` stores (params_type, handler) in its
# registry directly. This is scaffolding so ServerRunner can validate params
# without changing the existing `_request_handlers` dict shape.
# Request methods missing from this table are validated against
# `RequestParams` by `ServerRunner._on_request`.
_PARAMS_FOR_METHOD: dict[str, type[BaseModel]] = {
    "ping": RequestParams,
    "tools/list": PaginatedRequestParams,
    "tools/call": CallToolRequestParams,
    "prompts/list": PaginatedRequestParams,
    "prompts/get": GetPromptRequestParams,
    "resources/list": PaginatedRequestParams,
    "resources/templates/list": PaginatedRequestParams,
    "resources/read": ReadResourceRequestParams,
    "resources/subscribe": SubscribeRequestParams,
    "resources/unsubscribe": UnsubscribeRequestParams,
    "logging/setLevel": SetLevelRequestParams,
    "completion/complete": CompleteRequestParams,
}
"""Spec method → params model. Scaffolding while the lowlevel `Server`'s
`_request_handlers` stores handler-only; the registry refactor should make this
the registry's responsibility (or store params types alongside handlers)."""

# Notification method → params model. Notifications missing from this table
# are validated against `NotificationParams` by `ServerRunner._on_notify`.
# NOTE(review): `_on_notify` returns early for "notifications/initialized",
# so that entry is never consulted by the runner in this file.
_PARAMS_FOR_NOTIFICATION: dict[str, type[BaseModel]] = {
    "notifications/initialized": NotificationParams,
    "notifications/roots/list_changed": NotificationParams,
    "notifications/progress": ProgressNotificationParams,
}


class ServerRegistry(Protocol):
    """The handler registry `ServerRunner` consumes.

    The lowlevel `Server` satisfies this via additive methods.

    `ServerRunner` reads ``name`` and ``version`` to build the server info in
    the ``initialize`` result, ``middleware`` to wrap request handlers, the two
    ``get_*_handler`` lookups to route incoming messages, and
    ``get_capabilities`` to fill in the advertised capabilities.
    """

    # Server name reported in the ``initialize`` result.
    @property
    def name(self) -> str: ...
    # Server version; `ServerRunner` substitutes "0.0.0" when this is falsy.
    @property
    def version(self) -> str | None: ...

    # Context-tier middleware, outermost first.
    @property
    def middleware(self) -> Sequence[ContextMiddleware[Any]]: ...

    # Both lookups return ``None`` when nothing is registered for ``method``.
    def get_request_handler(self, method: str) -> Handler | None: ...
    def get_notification_handler(self, method: str) -> Handler | None: ...
    def get_capabilities(
        self, notification_options: Any, experimental_capabilities: dict[str, dict[str, Any]]
    ) -> ServerCapabilities: ...


def otel_middleware(next_on_request: OnRequest) -> OnRequest:
    """Return ``next_on_request`` wrapped so each request runs in an OpenTelemetry span.

    Dispatch-tier middleware that mirrors the span shape of the existing
    `Server._handle_request`:

    * span name ``"MCP handle <method>"``, with ``" <target>"`` appended when
      ``params["name"]`` is a non-empty string;
    * SERVER span kind and a ``mcp.method.name`` attribute;
    * parent context taken from ``params._meta`` (W3C trace context,
      SEP-414) when that value is a mapping;
    * ERROR status when the wrapped callable raises; the exception is
      re-raised unchanged.
    """

    async def wrapped(
        dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None
    ) -> dict[str, Any]:
        target: str | None = None
        parent: Any | None = None
        if isinstance(params, Mapping):
            # Only a string ``name`` qualifies as the span target.
            candidate = params.get("name")
            if isinstance(candidate, str):
                target = candidate
            # Trace context is only looked for when ``_meta`` is itself a
            # mapping; it is copied into a fresh dict before extraction.
            raw_meta = params.get("_meta")
            if isinstance(raw_meta, Mapping):
                parent = extract_trace_context(dict(raw_meta))

        span_name = f"MCP handle {method}"
        if target:
            span_name = f"{span_name} {target}"

        span_attributes = {"mcp.method.name": method}
        with otel_span(span_name, kind=SpanKind.SERVER, attributes=span_attributes, context=parent) as span:
            try:
                return await next_on_request(dctx, method, params)
            except MCPError as mcp_exc:
                # Protocol errors carry their own message on ``error``.
                span.set_status(StatusCode.ERROR, mcp_exc.error.message)
                raise
            except Exception as exc:
                span.set_status(StatusCode.ERROR, str(exc))
                raise

    return wrapped


def _dump_result(result: Any) -> dict[str, Any]:
if result is None:
return {}
if isinstance(result, BaseModel):
return result.model_dump(by_alias=True, mode="json", exclude_none=True)
if isinstance(result, dict):
return cast(dict[str, Any], result)
raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None")


@dataclass
class ServerRunner(Generic[LifespanT, ServerTransportT]):
    """Orchestrates one client connection on top of a `Dispatcher`.

    Create one instance per connection. The runner answers ``initialize``,
    rejects other requests until the connection is initialized (``ping``
    excepted), validates params, builds a `Context`, and runs the server's
    middleware chain around the registered handler.
    """

    server: ServerRegistry
    dispatcher: Dispatcher[ServerTransportT]
    lifespan_state: LifespanT
    has_standalone_channel: bool
    stateless: bool = False
    dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware])

    # Filled in by `__post_init__`; not constructor arguments.
    connection: Connection = field(init=False)
    _initialized: bool = field(init=False)

    def __post_init__(self) -> None:
        """Create the `Connection` and seed the initialization flag."""
        self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel)
        # A stateless runner starts out already initialized.
        self._initialized = self.stateless

    async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None:
        """Run the dispatcher with this runner's request and notification callbacks.

        The request callback is `_on_request` wrapped in `dispatch_middleware`;
        the notification callback is `_on_notify`. ``task_status`` is passed
        straight through to ``dispatcher.run()`` so callers can
        ``await tg.start(runner.run)``.
        """
        on_request = self._compose_on_request()
        await self.dispatcher.run(on_request, self._on_notify, task_status=task_status)

    def _compose_on_request(self) -> OnRequest:
        """Build the request callback: `_on_request` wrapped by `dispatch_middleware`.

        The first middleware in the list ends up outermost. Dispatch-tier
        middleware sees raw ``(dctx, method, params) -> dict`` and therefore
        wraps everything, including ``initialize``, ``METHOD_NOT_FOUND`` and
        validation failures.
        """
        composed: OnRequest = self._on_request
        # Apply from the end of the list inward so index 0 wraps last.
        for wrap in reversed(self.dispatch_middleware):
            composed = wrap(composed)
        return composed

    async def _on_request(
        self,
        dctx: DispatchContext[TransportContext],
        method: str,
        params: Mapping[str, Any] | None,
    ) -> dict[str, Any]:
        """Handle one incoming request and return its result dict.

        Raises:
            MCPError: ``INVALID_REQUEST`` if the connection is not initialized
                and ``method`` is not exempt; ``METHOD_NOT_FOUND`` if no
                handler is registered.
        """
        if method == "initialize":
            return self._handle_initialize(params)

        if not (self._initialized or method in _INIT_EXEMPT):
            raise MCPError(
                code=INVALID_REQUEST,
                message=f"Received {method!r} before initialization was complete",
            )

        handler = self.server.get_request_handler(method)
        if handler is None:
            raise MCPError(code=METHOD_NOT_FOUND, message=f"Method not found: {method}")

        # TODO: scaffolding — the params model comes from a static lookup until
        # the registry stores it alongside the handler.
        model = _PARAMS_FOR_METHOD.get(method, RequestParams)
        # ValidationError propagates; the dispatcher's exception boundary maps
        # it to INVALID_PARAMS.
        typed_params = model.model_validate(params or {})
        ctx = self._make_context(dctx, typed_params)

        # Fold the middleware list around the handler so that the first
        # middleware listed is the outermost call.
        innermost: CallNext = partial(handler, ctx, typed_params)
        chain: CallNext = reduce(
            lambda nxt, mw: partial(mw, ctx, method, typed_params, nxt),
            reversed(self.server.middleware),
            innermost,
        )
        return _dump_result(await chain())

    async def _on_notify(
        self,
        dctx: DispatchContext[TransportContext],
        method: str,
        params: Mapping[str, Any] | None,
    ) -> None:
        """Handle one incoming notification.

        ``notifications/initialized`` marks the connection initialized and is
        not forwarded to a handler. Other notifications are dropped (with a
        debug log) while uninitialized or when no handler is registered.
        """
        if method == "notifications/initialized":
            self._initialized = True
            self.connection.initialized.set()
        elif not self._initialized:
            logger.debug("dropped %s: received before initialization", method)
        elif (handler := self.server.get_notification_handler(method)) is None:
            logger.debug("no handler for notification %s", method)
        else:
            model = _PARAMS_FOR_NOTIFICATION.get(method, NotificationParams)
            typed_params = model.model_validate(params or {})
            ctx = self._make_context(dctx, typed_params)
            await handler(ctx, typed_params)

    def _make_context(
        self, dctx: DispatchContext[TransportContext], typed_params: BaseModel
    ) -> Context[LifespanT, ServerTransportT]:
        """Build the `Context` passed to handlers and middleware.

        ``meta`` is read from ``typed_params.meta`` when that attribute exists
        and is ``None`` otherwise.
        """
        # `OnRequest` delivers `DispatchContext[TransportContext]`; this
        # ServerRunner instance was constructed for a specific
        # `ServerTransportT`, so the narrow is safe by construction.
        return Context(
            cast(DispatchContext[ServerTransportT], dctx),
            lifespan=self.lifespan_state,
            connection=self.connection,
            meta=getattr(typed_params, "meta", None),
        )

    def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]:
        """Process ``initialize``: record client details and build the result dict.

        Stores the client's info and capabilities on the `Connection`, marks
        the runner initialized, and returns the dumped `InitializeResult`.
        """
        init = InitializeRequestParams.model_validate(params or {})
        conn = self.connection
        conn.client_info = init.client_info
        conn.client_capabilities = init.capabilities

        # TODO: real version negotiation. This always responds with LATEST,
        # which is wrong — the server should pick the highest version both
        # sides support and compute a per-connection feature set from it.
        # See FOLLOWUPS: "Consolidate per-connection mode/negotiation".
        negotiated = init.protocol_version
        if negotiated not in {LATEST_PROTOCOL_VERSION}:
            negotiated = LATEST_PROTOCOL_VERSION
        conn.protocol_version = negotiated

        self._initialized = True
        conn.initialized.set()

        capabilities = self.server.get_capabilities(NotificationOptions(), {})
        server_info = Implementation(name=self.server.name, version=self.server.version or "0.0.0")
        result = InitializeResult(
            protocol_version=conn.protocol_version,
            capabilities=capabilities,
            server_info=server_info,
        )
        return _dump_result(result)
Loading
Loading