diff --git a/docs/WALKTHROUGH.md b/docs/WALKTHROUGH.md new file mode 100644 index 000000000..65c57297c --- /dev/null +++ b/docs/WALKTHROUGH.md @@ -0,0 +1,213 @@ +# Walkthrough: why the SDK fights stateless, and how to fix it + +This is a code walk, not a spec. I'm going to start in the current SDK, show where it hurts, and then show what the same thing looks like after the proposed split. The RFC has the formal proposal; this is the "let me show you" version. + +--- + +## Part 1: The current code + +### Start at the only entrance + +There is exactly one way to make an MCP server handle requests: + +```ts +// packages/core/src/shared/protocol.ts:437 +async connect(transport: Transport): Promise { + this._transport = transport; + transport.onmessage = (message, extra) => { + // route to _onrequest / _onresponse / _onnotification + }; + await transport.start(); +} +``` + +You hand it a long-lived `Transport`, it takes over the `onmessage` callback, and from then on requests arrive asynchronously. There is no `handle(request) → response`. If you want to call a handler, you go through a transport. + +`Transport` is shaped like a pipe: + +```ts +// packages/core/src/shared/transport.ts:8 +interface Transport { + start(): Promise; + send(message: JSONRPCMessage): Promise; + onmessage?: (message, extra) => void; + close(): Promise; + sessionId?: string; + setProtocolVersion?(v: string): void; +} +``` + +`start`/`close` for lifecycle, fire-and-forget `send`, async `onmessage` callback. That's stdio's shape. It's also the shape every transport must implement, including HTTP. + +### Follow an HTTP request through + +The Streamable HTTP server transport is `packages/server/src/server/streamableHttp.ts` — 1038 lines. Let's follow a `tools/list` POST: + +1. User's Express handler calls `transport.handleRequest(req, res, body)` (line 176) +2. `handlePostRequest` validates headers (217-268), parses body (282) +3. Now it has a JSON-RPC request and needs to get it to the dispatcher. But the only path is `onmessage`. So it... calls `this.onmessage?.(msg, extra)` (370). Fire and forget. +4. `Protocol._onrequest` runs the handler, gets a result, builds a response, calls `this._transport.send(response)` (634) +5. Back in the transport, `send(response)` needs to find *which* HTTP response stream to write to. It looks up `_streamMapping[streamId]` (756) using a `relatedRequestId` that was threaded through. + +So the transport keeps a table mapping in-flight request IDs to open `Response` writers (`_streamMapping`, `_requestToStreamMapping`, ~80 LOC of bookkeeping), because `send()` is fire-and-forget and the response has to find its way back to the right HTTP response somehow. + +This is the core impedance mismatch: **HTTP is request→response, but the only interface is pipe-shaped, so the transport reconstructs request→response correlation on top of a pipe abstraction that sits on top of HTTP's native request→response.** + +### The session sniffing + +The transport also has to know about `initialize`: + +```ts +// streamableHttp.ts:323 +if (isInitializeRequest(body)) { + if (this._sessionIdGenerator) { + this.sessionId = this._sessionIdGenerator(); + // ... onsessioninitialized callback + } + this._initialized = true; +} +``` + +A transport — whose job should be "bytes in, bytes out" — is parsing message bodies to detect a specific MCP method so it knows when to mint a session ID. There are 18 references to `initialize` in this file. The transport knows about the protocol's handshake. + +### What "stateless" looks like today + +The protocol direction (SEP-2575/2567) is: no `initialize`, no sessions, each request is independent. You can do this today with a module-scope transport: + +```ts +const t = new NodeStreamableHTTPServerTransport({sessionIdGenerator: undefined}); +await mcp.connect(t); +app.all('/mcp', (req, res) => t.handleRequest(req, res, req.body)); +``` + +`sessionIdGenerator: undefined` is the opt-out — it makes `handleRequest` skip the session-ID minting/validation branches in the transport. The request still goes through the pipe-shaped path (`onmessage → _onrequest → handler → send → _streamMapping` lookup), but without sessions the mapping is just per-in-flight-request. + +It works. It's not obvious — you have to know that `undefined` is the flag, that `connect()` is still needed, and that the transport class is doing pipe-correlation under a request/response API. (The shipped example actually constructs the transport per-request, which is unnecessary but suggests the authors weren't confident in the module-scope version either.) + +### Why is Protocol 1100 lines? + +`protocol.ts` is the abstract base for both `Server` and `Client`. It does: + +- handler registry (`_requestHandlers`, `setRequestHandler`) +- outbound request/response correlation (`_responseHandlers`, `_requestMessageId`) +- timeouts (`_timeoutInfo`, `_setupTimeout`, `_resetTimeout`) +- progress callbacks (`_progressHandlers`) +- debounced notifications (`_pendingDebouncedNotifications`) +- cancellation (`_requestHandlerAbortControllers`) +- TaskManager binding (`_bindTaskManager`) +- 4 abstract `assert*Capability` methods subclasses must implement +- `connect()` — wiring all of the above to a transport + +Some of those are per-connection state (correlation, timeouts, debounce). Some are pure routing (handler registry). Some are protocol semantics (capabilities). They're fused, so you can't get at the routing without the connection state. + +When you trace a request through, you bounce between `Protocol._onrequest`, `Server.buildContext`, `McpServer`'s registry handlers, back to `Protocol`'s send path. Three classes, two levels of inheritance. (Python folks will recognize this — "is BaseSession or ServerSession handling this line?") + +--- + +## Part 2: The proposed split + +### The primitive + +```ts +class Dispatcher { + setRequestHandler(method, handler): void; + dispatch(req: JSONRPCRequest, env?: RequestEnv): AsyncIterable; +} +``` + +A `Map` and a function that looks up + calls. `dispatch` yields zero-or-more notifications then exactly one response (matching SEP-2260's wire constraint). `RequestEnv` is per-request context the caller provides — `{sessionId?, authInfo?, signal?, send?}`. No transport. No connection state. ~270 LOC. + +That's it. You can call `dispatch` from anywhere — a test, a Lambda, a loop reading stdin. + +### The channel adapter + +For stdio/WebSocket/InMemory — things that *are* persistent pipes — `StreamDriver` wraps a `ChannelTransport` and a `Dispatcher`: + +```ts +class StreamDriver { + constructor(dispatcher, channel) { ... } + start() { + channel.onmessage = msg => { + for await (const out of dispatcher.dispatch(msg, env)) channel.send(out); + }; + } + request(req): Promise; // outbound, with correlation/timeout +} +``` + +This is where Protocol's per-connection half goes: `_responseHandlers`, `_timeoutInfo`, `_progressHandlers`, debounce. One driver per pipe; the dispatcher it wraps can be shared. ~450 LOC. + +`connect(channelTransport)` builds one of these. So `connect` still works exactly as before for stdio. + +### The request adapter + +For HTTP — things that are *not* persistent pipes — `shttpHandler`: + +```ts +function shttpHandler(dispatcher, opts?): (req: Request) => Promise { + return async (req) => { + const body = await req.json(); + const stream = sseStreamFrom(dispatcher.dispatch(body, env)); + return new Response(stream, {headers: {'content-type': 'text/event-stream'}}); + }; +} +``` + +Parse → `dispatch` → stream the AsyncIterable as SSE. ~400 LOC including header validation, batch handling, EventStore replay. No `_streamMapping` — the response stream is just in lexical scope. + +`mcp.handleHttp(req)` is McpServer's convenience wrapper around this. + +### The deletable parts + +`SessionCompat` — bounded LRU `{sessionId → negotiatedVersion}`. If you pass it to `shttpHandler`, the handler validates `mcp-session-id` headers and mints IDs on `initialize`. If you don't, it doesn't. ~200 LOC. + +`BackchannelCompat` — per-session `{requestId → resolver}` so a tool handler can `await ctx.elicitInput()` and the response comes back via a separate POST. The 2025-11 server→client-over-SSE behavior. ~140 LOC. + +These two are the *only* places 2025-11 stateful behavior lives. When that protocol version sunsets and MRTR (SEP-2322) is the floor, delete both files; `shttpHandler` is fully stateless. + +### Same examples, after + +```ts +// stateless — one server, no transport instance +const mcp = new McpServer({name: 'hello', version: '1'}); +mcp.registerTool('greet', ..., ...); +app.post('/mcp', c => mcp.handleHttp(c.req.raw)); +``` + +```ts +// 2025-11 stateful — same server, opt-in session +const session = new SessionCompat({sessionIdGenerator: () => randomUUID()}); +app.all('/mcp', toNodeHttpHandler(shttpHandler(mcp, {session}))); +``` + +```ts +// stdio — unchanged from today +const t = new StdioServerTransport(); +await mcp.connect(t); +``` + +```ts +// the existing v1 pattern — also unchanged +const t = new NodeStreamableHTTPServerTransport({sessionIdGenerator: () => randomUUID()}); +await mcp.connect(t); +app.all('/mcp', (req, res) => t.handleRequest(req, res, req.body)); +// (internally, t.handleRequest now calls shttpHandler — same wire behavior) +``` + +--- + +## Part 3: What you get + +**The stateless server is one line.** One `McpServer` at module scope, `handleHttp` per request. The per-request build-and-tear-down workaround is gone. + +**Handlers are testable without a transport.** `await mcp.dispatchToResponse({...})` — no `InMemoryTransport` pair, no `connect`. + +**The SHTTP transport drops from 1038 to ~290 LOC.** No `_streamMapping` (the response stream is in lexical scope), no body-sniffing for `initialize` (SessionCompat handles it), no fake `start()`. + +**2025-11 protocol state lives in two named files.** When that version sunsets, delete `SessionCompat` and `BackchannelCompat`; `shttpHandler` is fully stateless. Today the same logic is `if (sessionIdGenerator)` branches scattered through one transport. + +**Existing code doesn't change.** `new NodeStreamableHTTPServerTransport({...})` + `connect(t)` + `t.handleRequest(...)` works exactly as before — the class builds the compat pieces internally from the options you already pass. + +--- + +*Reference implementation on [`fweinberger/ts-sdk-rebuild`](https://github.com/modelcontextprotocol/typescript-sdk/tree/fweinberger/ts-sdk-rebuild). See the [RFC](./rfc-stateless-architecture.md) for the formal proposal.* diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index a37b5e206..ad15c5014 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -75,9 +75,10 @@ Notes: ## 4. Renamed Symbols -| v1 symbol | v2 symbol | v2 package | -| ------------------------------- | ----------------------------------- | ---------------------------- | -| `StreamableHTTPServerTransport` | `NodeStreamableHTTPServerTransport` | `@modelcontextprotocol/node` | +| v1 symbol | v2 symbol | v2 package | +| ------------------------------- | ----------------------------------- | ------------------------------------- | +| `StreamableHTTPServerTransport` | `NodeStreamableHTTPServerTransport` | `@modelcontextprotocol/node` | +| `Transport` (interface) | `ChannelTransport` | `@modelcontextprotocol/{client,server}` (deprecated alias `Transport` kept) | ## 5. Removed / Renamed Type Aliases and Symbols diff --git a/docs/migration.md b/docs/migration.md index 7cb7d58f6..129c90883 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -86,6 +86,22 @@ npm install @modelcontextprotocol/express # Express npm install @modelcontextprotocol/hono # Hono ``` +### `Transport` interface renamed to `ChannelTransport`; `RequestTransport` added + +The pipe-shaped `Transport` interface has been renamed `ChannelTransport`. A new `RequestTransport` interface (callback-based, for request/response transports like Streamable HTTP) sits alongside it. `connect()` accepts either. + +`Transport` is kept as a deprecated type alias of `ChannelTransport`, so existing code compiles unchanged. + +```typescript +// Before (v1) +import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; + +// After (v2) +import type { ChannelTransport, RequestTransport } from '@modelcontextprotocol/server'; +``` + +If you implemented a custom transport by implementing `Transport`, switch to `implements ChannelTransport` (same shape) or, for HTTP-style transports, `implements RequestTransport`. + ### `StreamableHTTPServerTransport` renamed `StreamableHTTPServerTransport` has been renamed to `NodeStreamableHTTPServerTransport` and moved to `@modelcontextprotocol/node`. diff --git a/docs/rfc-stateless-architecture.md b/docs/rfc-stateless-architecture.md new file mode 100644 index 000000000..7786d97e5 --- /dev/null +++ b/docs/rfc-stateless-architecture.md @@ -0,0 +1,243 @@ +# RFC: Request-first SDK architecture + +**Status:** Draft, seeking direction feedback +**Reference impl:** [#1942](https://github.com/modelcontextprotocol/typescript-sdk/pull/1942) (`fweinberger/ts-sdk-rebuild` — proof-of-concept, not for direct merge) + +--- + +## TL;DR + +The only way into the SDK today is `server.connect(transport)`, which assumes a persistent channel. The protocol is moving to per-request stateless (SEP-2575/2567/2322). This RFC proposes adding `dispatch(request, env) → response` as the core primitive and building the connection model as one adapter on top of it. Existing code keeps working unchanged. + +--- + +## Problem + +``` + ┌────────────────────────────────────────────┐ + │ Protocol (~1100 LOC, abstract) │ + │ ├ handler registry │ + │ ├ request/response correlation │ + │ ├ timeouts, debounce, progress │ + │ ├ capability assertions (abstract) │ + │ ├ TaskManager binding │ + │ └ connect(transport) — wires onmessage │ + └────────────────────────────────────────────┘ + ▲ ▲ + extends │ │ extends + ┌─────────┴──┐ ┌───────┴──────┐ + │ Server │ │ Client │ + └─────┬──────┘ └──────────────┘ + wraps │ + ┌─────┴──────┐ + │ McpServer │ + └────────────┘ +``` + +Everything goes through `connect(transport)`. `Transport` is pipe-shaped (`{start, send, onmessage, close}`). The Streamable HTTP transport (1038 LOC) implements that pipe shape on top of HTTP — keeping a `_streamMapping` table to route fire-and-forget `send()` calls back to the right HTTP response, sniffing message bodies to detect `initialize` so it knows when to mint a session ID. + +The shipped stateless example constructs a fresh server and transport per request ([`examples/server/src/simpleStatelessStreamableHttp.ts`](https://github.com/modelcontextprotocol/typescript-sdk/blob/7bb79ebbbba88a503851617d053b13d8fd9228bb/examples/server/src/simpleStatelessStreamableHttp.ts#L99-L111)): + +```ts +app.post('/mcp', async (req, res) => { + const server = getServer(); // McpServer + all registrations + const transport = new NodeStreamableHTTPServerTransport({ + sessionIdGenerator: undefined // opt-out flag + }); + await server.connect(transport); + await transport.handleRequest(req, res, req.body); + res.on('close', () => { + transport.close(); + server.close(); + }); +}); +``` + +A module-scope version (one server, one transport, `sessionIdGenerator: undefined`) does work, but the example doesn't use it — and the request still goes through the pipe-shaped path either way. + +--- + +## Proposal + +``` + ┌───────────────────────────────────────┐ + │ Dispatcher (~270 LOC) │ + │ ├ handler registry │ + │ └ dispatch(req, env) → AsyncIterable │ + │ No transport. No connection state. │ + └───────────────────────────────────────┘ + ▲ + extends │ + ┌─────────────────┴─────────────────┐ + │ McpServer / Client │ + │ (MCP handlers, registries) │ + └─────────────────┬─────────────────┘ + │ dispatch() called by: + ┌───────────────┴───────────────┐ + ▼ ▼ + ┌───────────────────┐ ┌──────────────────────────┐ + │ StreamDriver │ │ shttpHandler │ + │ (channel adapter) │ │ (request adapter) │ + │ correlation, │ │ ├ SessionCompat (opt) │ + │ timeouts, debounce│ │ └ BackchannelCompat(opt) │ + │ stdio, WS, InMem │ │ SHTTP │ + └───────────────────┘ └──────────────────────────┘ +``` + +| Piece | What it does | Replaces | +|---|---|---| +| **Dispatcher** | Knows which handler to call for which method. You register handlers (`setRequestHandler('tools/list', fn)`); `dispatch(request)` looks one up, runs it, returns the output. Doesn't know how the request arrived or where the response goes. | Protocol's handler-registry half | +| **StreamDriver** | Runs a Dispatcher over a persistent connection (stdio, WebSocket). Reads from the pipe → `dispatch()` → writes back. Owns the per-connection state: response correlation, timeouts, debounce. One per pipe; the Dispatcher it wraps can be shared. | Protocol's correlation/timeout half | +| **shttpHandler** | Runs a Dispatcher over HTTP. Takes a web `Request`, parses the body, calls `dispatch()`, streams the result as a `Response`. A function you mount on a router, not a class you connect. | The 1038-LOC SHTTP transport's core | +| **SessionCompat** | Remembers session IDs across HTTP requests. 2025-11 servers mint an ID on `initialize` and validate it on every later request — this is the bounded LRU that does that. Pass it to `shttpHandler` for 2025-11 clients; omit it for stateless. | `Transport.sessionId` + SHTTP `_initialized` | +| **BackchannelCompat** | Lets a tool handler ask the client a question mid-call (`ctx.elicitInput()`) over HTTP. 2025-11 does this by writing the question into the still-open SSE response and waiting for the client to POST the answer back; this holds the "waiting for answer N" table. Under MRTR the same thing is a return value, so this gets deleted. | `_streamMapping` + `relatedRequestId` | + +The last two are the only places 2025-11 stateful behavior lives. They're passed to `shttpHandler` as options; without them it's pure request→response. + +### Middleware + +`Dispatcher.use(mw)` registers generator middleware that wraps every `dispatch()`: + +```ts +mcp.use(next => async function* (req, env) { + // before handler + for await (const out of next(req, env)) { + // around each notification + the response + yield out; + } + // after +}); +``` + +Runs for every method (including `initialize`), regardless of transport. Short-circuit (auth reject, cache hit), transform outputs, time the call. A small `onMethod('tools/list', fn)` helper gives typed per-method post-processing without the `if (req.method === ...)` boilerplate. + +### Transport interfaces + +`Transport` is renamed `ChannelTransport` (the pipe shape: `start/send/onmessage/close`). `Transport` stays as a deprecated alias. A second internal shape, `RequestTransport`, is what the SHTTP server transport implements — it doesn't pretend to be a pipe. `connect()` accepts both and picks the right adapter via an explicit `kind: 'channel' | 'request'` brand on the transport. + +--- + +## Compatibility + +**Existing stateful SHTTP code does not change:** + +```ts +const t = new NodeStreamableHTTPServerTransport({sessionIdGenerator: () => randomUUID()}); +await mcp.connect(t); +app.all('/mcp', (req, res) => t.handleRequest(req, res, req.body)); +``` + +Same options, same wire behavior — sessions are minted on `initialize`, validated on every later request, `transport.sessionId` is populated, `onsessioninitialized`/`onsessionclosed` fire, `ctx.elicitInput()` works mid-tool-call. Under the hood the transport class constructs a `SessionCompat` and `BackchannelCompat` from those options and routes `handleRequest` through `shttpHandler`. The session-ful behavior is identical; the implementation is the new path. + +**Existing stdio code does not change:** + +```ts +const t = new StdioServerTransport(); +await mcp.connect(t); +``` + +`connect()` sees a channel-shaped transport and builds a `StreamDriver(mcp, t)` internally — which reads stdin, calls `dispatch()`, writes stdout. The stdio transport class itself is unchanged (it was always just a pipe wrapper); what's different is that the read-dispatch-write loop now lives in `StreamDriver` instead of `Protocol`. + +`Protocol` and `Server` stay as back-compat shims for direct subclassers (ext-apps). + +--- + +## Client side + +The same split applies. `Client extends Dispatcher` — its registry holds the handlers for requests the *server* sends (`elicitation/create`, `sampling/createMessage`, `roots/list`). When one arrives, `dispatch()` routes it. + +For outbound (`callTool`, `listTools`, etc.), Client uses a `ClientTransport`: + +```ts +interface ClientTransport { + fetch(req: JSONRPCRequest, opts?): Promise; // request → response + notify(n: Notification): Promise; + close(): Promise; +} +``` + +This is the request-shaped mirror of the server side: `fetch` is one request → one response. + +``` + ┌───────────────────────────────────────┐ + │ Client extends Dispatcher │ + │ inbound: dispatch() for elicit/ │ + │ sampling/roots │ + │ outbound: callTool → │ + │ _clientTransport.fetch(req)│ + └─────────────────┬─────────────────────┘ + │ _clientTransport is ONE of: + ┌───────────────┴───────────────┐ + ▼ ▼ + ┌───────────────────────┐ ┌────────────────────────────────┐ + │ pipeAsClientTransport │ │ StreamableHTTPClientTransport │ + │ (wraps a channel via │ │ (implements ClientTransport │ + │ StreamDriver) │ │ directly: POST → Response) │ + │ stdio, WS, InMem │ │ SHTTP │ + └───────────────────────┘ └────────────────────────────────┘ +``` + +**Over HTTP:** `StreamableHTTPClientTransport.fetch` POSTs the request and reads the response (SSE or JSON). If the server writes a JSON-RPC *request* into that SSE stream (2025-11 elicitation), the transport calls `opts.onrequest(r)` — which Client wires to `this.dispatch(r)` — and POSTs the answer back. Same flow as today, request-shaped underneath. + +**Over stdio:** `pipeAsClientTransport(stdioTransport)` wraps the channel in a StreamDriver and exposes `{fetch, notify, close}`. `fetch` becomes "send over the pipe, await the correlated response." + +**MRTR (SEP-2322):** the stateless server→client path. Instead of the held-stream backchannel, the server *returns* `{input_required, requests: [...]}` as the `tools/call` result. Client sees that, services each request via its own `dispatch()`, and re-sends `tools/call` with the answers attached. No held stream, works over any transport. Client's `_request` runs this loop transparently — `await client.callTool(...)` looks the same to the caller whether the server used the backchannel or MRTR. + +**Compat:** `client.connect(transport)` keeps working with both `ChannelTransport` and `ClientTransport`. Existing code (`new StreamableHTTPClientTransport(url)` + `connect`) is unchanged. + +--- + +## Wins + +**Stateless without the opt-out.** Today's stateless is `sessionIdGenerator: undefined` — a flag that opts you out of session handling but leaves the request going through the pipe-shaped path (`onmessage → dispatch → send → _streamMapping` lookup). It's stateless at the wire but not in the code: concurrent requests still share a `_streamMapping` table on the transport instance, the transport still parses bodies looking for `initialize`, and the shipped example constructs everything per-request because the module-scope version isn't obviously safe. After: +```ts +import { McpServer } from '@modelcontextprotocol/server'; +import { Hono } from 'hono'; + +const mcp = new McpServer({name: 'hello', version: '1.0.0'}); +mcp.registerTool('greet', {description: 'Say hello'}, async () => ({ + content: [{type: 'text', text: 'hello'}] +})); + +const app = new Hono(); +app.post('/mcp', c => mcp.handleHttp(c.req.raw)); +``` +No transport class, no `connect`, no flag. The path is `parse → dispatch → respond`. + +**Handlers are testable without a transport.** Today, unit-testing a tool handler means an `InMemoryTransport` pair, two `connect()` calls, and a client to drive it. After: +```ts +const mcp = new McpServer({name: 'test', version: '1.0.0'}); +mcp.registerTool('greet', {description: '...'}, async () => ({ + content: [{type: 'text', text: 'hello'}] +})); + +const out = await mcp.dispatchToResponse({ + jsonrpc: '2.0', id: 1, method: 'tools/call', params: {name: 'greet', arguments: {}} +}); +expect(out.result.content[0].text).toBe('hello'); +``` +The HTTP layer is testable the same way — `await shttpHandler(mcp)(new Request('http://test/mcp', {method: 'POST', body: ...}))` returns a `Response` you can assert on, no server to spin up. + +**Method-level middleware.** There's no per-method hook today — auth is HTTP-layer (`requireBearerAuth` checks the bearer token before MCP parsing), and to log/trace/rate-limit by MCP method you'd wrap each handler manually. `Dispatcher.use(mw)` wraps every dispatch including `initialize`: +```ts +mcp.use(next => async function* (req, env) { + const start = Date.now(); + yield* next(req, env); + metrics.timing('mcp.method', Date.now() - start, {method: req.method}); +}); +``` +(Python's FastMCP ships ten middleware modules — auth, caching, rate-limiting, tracing — and had to subclass an SDK-private method to intercept `initialize`. That's the demand signal; `use()` is the hook.) + +**Pluggable transports stop paying the pipe tax.** A gRPC/WebTransport/Lambda integration today has to implement `{start, send, onmessage, close}` and reconstruct request→response on top. After, request-shaped transports call `dispatch()` directly; only genuinely persistent channels (stdio, WebSocket) implement `ChannelTransport`. + +**Extensions plug in cleanly.** Tasks (and later sampling/roots when they move to `ext-*` packages) attach via `mcp.use(tasksMiddleware(store))` instead of being wired into Protocol. The core SDK doesn't import them. + +**2025-11 state is deletable.** Two named files instead of `if (sessionIdGenerator)` branches through one transport. The sunset is `git rm sessionCompat.ts backchannelCompat.ts`, not a hunt. + +**Protocol stops being a god class.** Today `Protocol` (~1100 LOC) is registry + correlation + timeouts + capabilities + tasks + connect, abstract, with both Server and Client extending it. Tracing a request means bouncing between Protocol, Server, and McpServer. After: Dispatcher does routing, StreamDriver does per-connection state, McpServer does MCP semantics. Each file has one job; you can read one without the others. + +**The SHTTP server transport class drops from 1038 to ~290 LOC.** New server code doesn't need the class at all (`handleHttp` is the entry). The class still exists for back-compat — existing code that does `new NodeStreamableHTTPServerTransport(...)` keeps working — but it's now a thin shim that constructs `shttpHandler` internally. No `_streamMapping`, no body-sniffing for `initialize`, no fake `start()`. (Client-side still needs a transport instance — it has to know where to send. `StreamableHTTPClientTransport` stays, just request-shaped underneath.) + +--- + +The reference implementation passes all SDK tests, conformance (40/40 server, 317/317 client), and 14/14 consumer typecheck after the existing v2 back-compat PRs. See the [WALKTHROUGH](./WALKTHROUGH.md) for a code-level walk through the current pain and the fix. diff --git a/examples/client/src/helloStatelessClient.ts b/examples/client/src/helloStatelessClient.ts new file mode 100644 index 000000000..bb85d0f69 --- /dev/null +++ b/examples/client/src/helloStatelessClient.ts @@ -0,0 +1,23 @@ +/** + * Client for the stateless hello-world server. + * + * This is identical to the v1/v2 client pattern — same classes, same `connect()` call. + * Nothing about the client side changes for users. + * + * Run: npx tsx examples/client/src/helloStatelessClient.ts + */ +import { Client, StreamableHTTPClientTransport } from '@modelcontextprotocol/client'; + +const client = new Client({ name: 'hello-client', version: '1.0.0' }); +await client.connect(new StreamableHTTPClientTransport(new URL('http://localhost:3400/mcp'))); + +const { tools } = await client.listTools(); +console.log( + 'Tools:', + tools.map(t => t.name) +); + +const result = await client.callTool({ name: 'greet', arguments: { name: 'world' } }); +console.log('Result:', result.content[0]); + +await client.close(); diff --git a/examples/server/src/helloStateless.ts b/examples/server/src/helloStateless.ts new file mode 100644 index 000000000..d26abfe5e --- /dev/null +++ b/examples/server/src/helloStateless.ts @@ -0,0 +1,22 @@ +/** + * Stateless hello-world MCP server. No connect(), no transport instance — + * one McpServer at module scope, handleHttp() per request. + * + * Run: npx tsx examples/server/src/helloStateless.ts + */ +import { serve } from '@hono/node-server'; +import { McpServer } from '@modelcontextprotocol/server'; +import { Hono } from 'hono'; +import { z } from 'zod/v4'; + +const mcp = new McpServer({ name: 'hello-stateless', version: '1.0.0' }); + +mcp.registerTool('greet', { description: 'Say hello', inputSchema: z.object({ name: z.string() }) }, async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] +})); + +const app = new Hono(); +app.post('/mcp', c => mcp.handleHttp(c.req.raw)); + +serve({ fetch: app.fetch, port: 3400 }); +console.log('Stateless MCP server on http://localhost:3400/mcp'); diff --git a/examples/server/src/helloStatelessExpress.ts b/examples/server/src/helloStatelessExpress.ts new file mode 100644 index 000000000..be710bf3f --- /dev/null +++ b/examples/server/src/helloStatelessExpress.ts @@ -0,0 +1,39 @@ +/** + * Hello-world MCP server (Express). Shown two equivalent ways: + * + * 1. The existing v1/v2 pattern — `connect(transport)` + `transport.handleRequest`. + * Works unchanged in the rebuild. + * 2. The new direct pattern — `mcp.handleHttp(req)` with no transport instance. + * + * Both produce identical wire behavior. Pick one. + * + * Run: npx tsx examples/server/src/helloStatelessExpress.ts + */ +import { randomUUID } from 'node:crypto'; + +import { NodeStreamableHTTPServerTransport, toNodeHttpHandler } from '@modelcontextprotocol/node'; +import { McpServer } from '@modelcontextprotocol/server'; +import express from 'express'; +import { z } from 'zod/v4'; + +const mcp = new McpServer({ name: 'hello-express', version: '1.0.0' }); + +mcp.registerTool('greet', { description: 'Say hello', inputSchema: z.object({ name: z.string() }) }, async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] +})); + +const app = express(); + +// ─── Way 1: existing v1/v2 pattern (unchanged) ───────────────────────────── +const transport = new NodeStreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID() }); +await mcp.connect(transport); +app.all('/mcp-v1style', express.json(), (req, res) => transport.handleRequest(req, res, req.body)); + +// ─── Way 2: new direct pattern (no connect, no transport instance) ───────── +// Don't pre-parse the body — handleHttp reads it from the raw Request. +app.post( + '/mcp', + toNodeHttpHandler(req => mcp.handleHttp(req)) +); + +app.listen(3400, () => console.log('Express MCP server on :3400 — /mcp (new) and /mcp-v1style (existing)')); diff --git a/package.json b/package.json index a2cb93f62..98b3da976 100644 --- a/package.json +++ b/package.json @@ -1,5 +1,5 @@ { - "name": "@modelcontextprotocol/sdk", + "name": "@modelcontextprotocol/sdk-workspace", "private": true, "version": "2.0.0-alpha.0", "description": "Model Context Protocol implementation for TypeScript", diff --git a/packages/client/package.json b/packages/client/package.json index cf9dbff6b..558f8d2dc 100644 --- a/packages/client/package.json +++ b/packages/client/package.json @@ -22,28 +22,34 @@ "exports": { ".": { "types": "./dist/index.d.mts", - "import": "./dist/index.mjs" + "import": "./dist/index.mjs", + "require": "./dist/index.mjs" }, "./validators/cf-worker": { "types": "./dist/validators/cfWorker.d.mts", - "import": "./dist/validators/cfWorker.mjs" + "import": "./dist/validators/cfWorker.mjs", + "require": "./dist/validators/cfWorker.mjs" }, "./_shims": { "workerd": { "types": "./dist/shimsWorkerd.d.mts", - "import": "./dist/shimsWorkerd.mjs" + "import": "./dist/shimsWorkerd.mjs", + "require": "./dist/shimsWorkerd.mjs" }, "browser": { "types": "./dist/shimsBrowser.d.mts", - "import": "./dist/shimsBrowser.mjs" + "import": "./dist/shimsBrowser.mjs", + "require": "./dist/shimsBrowser.mjs" }, "node": { "types": "./dist/shimsNode.d.mts", - "import": "./dist/shimsNode.mjs" + "import": "./dist/shimsNode.mjs", + "require": "./dist/shimsNode.mjs" }, "default": { "types": "./dist/shimsNode.d.mts", - "import": "./dist/shimsNode.mjs" + "import": "./dist/shimsNode.mjs", + "require": "./dist/shimsNode.mjs" } } }, @@ -92,5 +98,16 @@ "typescript-eslint": "catalog:devTools", "vitest": "catalog:devTools", "tsdown": "catalog:devTools" + }, + "types": "./dist/index.d.mts", + "typesVersions": { + "*": { + "stdio": [ + "./dist/stdio.d.mts" + ], + "*": [ + "./dist/*.d.mts" + ] + } } } diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 21a43bd15..9bb8ef20f 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -1,15 +1,23 @@ import { DefaultJsonSchemaValidator } from '@modelcontextprotocol/client/_shims'; import type { - BaseContext, + AnySchema, CallToolRequest, + CallToolResult, + CancelTaskRequest, ClientCapabilities, ClientContext, ClientNotification, ClientRequest, ClientResult, CompleteRequest, + CreateTaskResult, GetPromptRequest, + GetTaskRequest, + GetTaskResult, Implementation, + JSONRPCErrorResponse, + JSONRPCRequest, + JSONRPCResultResponse, JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator, @@ -18,17 +26,24 @@ import type { ListPromptsRequest, ListResourcesRequest, ListResourceTemplatesRequest, + ListTasksRequest, ListToolsRequest, LoggingLevel, - MessageExtraInfo, + Notification, NotificationMethod, + NotificationOptions, ProtocolOptions, ReadResourceRequest, + Request, RequestMethod, RequestOptions, RequestTypeMap, + Result, ResultTypeMap, + SchemaOutput, ServerCapabilities, + StandardSchemaV1, + StreamDriverOptions, SubscribeRequest, TaskManagerOptions, Tool, @@ -36,58 +51,60 @@ import type { UnsubscribeRequest } from '@modelcontextprotocol/core'; import { - assertClientRequestTaskCapability, - assertToolsCallTaskCapability, CallToolResultSchema, + CancelTaskResultSchema, CompleteResultSchema, CreateMessageRequestSchema, CreateMessageResultSchema, CreateMessageResultWithToolsSchema, CreateTaskResultSchema, + Dispatcher, ElicitRequestSchema, ElicitResultSchema, EmptyResultSchema, extractTaskManagerOptions, GetPromptResultSchema, + getResultSchema, + GetTaskResultSchema, InitializeResultSchema, LATEST_PROTOCOL_VERSION, ListChangedOptionsBaseSchema, ListPromptsResultSchema, ListResourcesResultSchema, ListResourceTemplatesResultSchema, + ListTasksResultSchema, ListToolsResultSchema, mergeCapabilities, + NullTaskManager, parseSchema, - Protocol, ProtocolError, ProtocolErrorCode, ReadResourceResultSchema, + RELATED_TASK_META_KEY, SdkError, - SdkErrorCode + SdkErrorCode, + SUPPORTED_PROTOCOL_VERSIONS, + TaskManager } from '@modelcontextprotocol/core'; import { ExperimentalClientTasks } from '../experimental/tasks/client.js'; +import type { ClientFetchOptions, ClientTransport } from './clientTransport.js'; +import { channelAsClientTransport, isChannelTransport, isJSONRPCErrorResponse } from './clientTransport.js'; /** * Elicitation default application helper. Applies defaults to the `data` based on the `schema`. - * - * @param schema - The schema to apply defaults to. - * @param data - The data to apply defaults to. */ function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unknown): void { if (!schema || data === null || typeof data !== 'object') return; - // Handle object properties if (schema.type === 'object' && schema.properties && typeof schema.properties === 'object') { const obj = data as Record; const props = schema.properties as Record; for (const key of Object.keys(props)) { const propSchema = props[key]!; - // If missing or explicitly undefined, apply default if present if (obj[key] === undefined && Object.prototype.hasOwnProperty.call(propSchema, 'default')) { obj[key] = propSchema.default; } - // Recurse into existing nested objects/arrays if (obj[key] !== undefined) { applyElicitationDefaults(propSchema, obj[key]); } @@ -96,20 +113,12 @@ function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unkn if (Array.isArray(schema.anyOf)) { for (const sub of schema.anyOf) { - // Skip boolean schemas (true/false are valid JSON Schemas but have no defaults) - if (typeof sub !== 'boolean') { - applyElicitationDefaults(sub, data); - } + if (typeof sub !== 'boolean') applyElicitationDefaults(sub, data); } } - - // Combine schemas if (Array.isArray(schema.oneOf)) { for (const sub of schema.oneOf) { - // Skip boolean schemas (true/false are valid JSON Schemas but have no defaults) - if (typeof sub !== 'boolean') { - applyElicitationDefaults(sub, data); - } + if (typeof sub !== 'boolean') applyElicitationDefaults(sub, data); } } } @@ -117,12 +126,8 @@ function applyElicitationDefaults(schema: JsonSchemaType | undefined, data: unkn /** * Determines which elicitation modes are supported based on declared client capabilities. * - * According to the spec: * - An empty elicitation capability object defaults to form mode support (backwards compatibility) * - URL mode is only supported if explicitly declared - * - * @param capabilities - The client's elicitation capabilities - * @returns An object indicating which modes are supported */ export function getSupportedElicitationModes(capabilities: ClientCapabilities['elicitation']): { supportsFormMode: boolean; @@ -131,17 +136,48 @@ export function getSupportedElicitationModes(capabilities: ClientCapabilities['e if (!capabilities) { return { supportsFormMode: false, supportsUrlMode: false }; } - const hasFormCapability = capabilities.form !== undefined; const hasUrlCapability = capabilities.url !== undefined; - - // If neither form nor url are explicitly declared, form mode is supported (backwards compatibility) const supportsFormMode = hasFormCapability || (!hasFormCapability && !hasUrlCapability); const supportsUrlMode = hasUrlCapability; - return { supportsFormMode, supportsUrlMode }; } +/** + * Runtime guard for the polymorphic `tools/call` (and per SEP-2557, any + * task-capable method) result. SEP-2557 lets servers return a task even when + * the client did not request one. + */ +function isCreateTaskResult(r: unknown): r is CreateTaskResult { + return ( + typeof r === 'object' && + r !== null && + typeof (r as { task?: unknown }).task === 'object' && + (r as { task?: unknown }).task !== null && + typeof (r as { task: { taskId?: unknown } }).task.taskId === 'string' + ); +} + +/** + * Loose envelope for the SEP-2322 MRTR `input_required` result. Typed minimally + * (field names not yet finalized in the spec); runtime detection is by shape. + */ +type InputRequiredEnvelope = { + ResultType: 'input_required'; + InputRequests: Record }>; +}; +function isInputRequired(r: unknown): r is InputRequiredEnvelope { + return ( + typeof r === 'object' && + r !== null && + (r as { ResultType?: unknown }).ResultType === 'input_required' && + typeof (r as { InputRequests?: unknown }).InputRequests === 'object' + ); +} + +const MRTR_INPUT_RESPONSES_META_KEY = 'modelcontextprotocol.io/mrtr/inputResponses'; +const DEFAULT_MRTR_MAX_ROUNDS = 16; + /** * Extended tasks capability that includes runtime configuration (store, messageQueue). * The runtime-only fields are stripped before advertising capabilities to servers. @@ -149,920 +185,845 @@ export function getSupportedElicitationModes(capabilities: ClientCapabilities['e export type ClientTasksCapabilityWithRuntime = NonNullable & TaskManagerOptions; export type ClientOptions = ProtocolOptions & { - /** - * Capabilities to advertise as being supported by this client. - */ + /** Capabilities to advertise to the server. */ capabilities?: Omit & { tasks?: ClientTasksCapabilityWithRuntime; }; - - /** - * JSON Schema validator for tool output validation. - * - * The validator is used to validate structured content returned by tools - * against their declared output schemas. - * - * @default {@linkcode DefaultJsonSchemaValidator} ({@linkcode index.AjvJsonSchemaValidator | AjvJsonSchemaValidator} on Node.js, `CfWorkerJsonSchemaValidator` on Cloudflare Workers) - */ + /** Validator for tool `outputSchema`. Defaults to the runtime-appropriate Ajv/CF validator. */ jsonSchemaValidator?: jsonSchemaValidator; - + /** Handlers for `notifications/*_list_changed`. */ + listChanged?: ListChangedHandlers; /** - * Configure handlers for list changed notifications (tools, prompts, resources). - * - * @example - * ```ts source="./client.examples.ts#ClientOptions_listChanged" - * const client = new Client( - * { name: 'my-client', version: '1.0.0' }, - * { - * listChanged: { - * tools: { - * onChanged: (error, tools) => { - * if (error) { - * console.error('Failed to refresh tools:', error); - * return; - * } - * console.log('Tools updated:', tools); - * } - * }, - * prompts: { - * onChanged: (error, prompts) => console.log('Prompts updated:', prompts) - * } - * } - * } - * ); - * ``` + * Upper bound on MRTR rounds for one logical request before throwing + * {@linkcode SdkErrorCode.InternalError}. Default 16. */ - listChanged?: ListChangedHandlers; + mrtrMaxRounds?: number; }; /** - * An MCP client on top of a pluggable transport. - * - * The client will automatically begin the initialization flow with the server when {@linkcode connect} is called. + * MCP client built on a request-shaped {@linkcode ClientTransport}. * + * Every request is independent; `request()` runs the SEP-2322 MRTR loop, + * servicing `input_required` rounds via locally registered handlers. + * {@linkcode connect} also accepts a legacy pipe-shaped {@linkcode Transport} + * and runs the 2025-11 initialize handshake for back-compat. */ -export class Client extends Protocol { +export class Client extends Dispatcher { + private _clientTransport?: ClientTransport; + private _capabilities: ClientCapabilities; private _serverCapabilities?: ServerCapabilities; private _serverVersion?: Implementation; - private _negotiatedProtocolVersion?: string; - private _capabilities: ClientCapabilities; private _instructions?: string; + private _negotiatedProtocolVersion?: string; + private _supportedProtocolVersions: string[]; + private _enforceStrictCapabilities: boolean; + private _mrtrMaxRounds: number; private _jsonSchemaValidator: jsonSchemaValidator; private _cachedToolOutputValidators: Map> = new Map(); private _cachedKnownTaskTools: Set = new Set(); private _cachedRequiredTaskTools: Set = new Set(); + private _requestMessageId = 0; + private _pendingListChangedConfig?: ListChangedHandlers; private _experimental?: { tasks: ExperimentalClientTasks }; private _listChangedDebounceTimers: Map> = new Map(); - private _pendingListChangedConfig?: ListChangedHandlers; - private _enforceStrictCapabilities: boolean; + private _taskManager: TaskManager; + + onclose?: () => void; + onerror?: (error: Error) => void; - /** - * Initializes this client with the given name and version information. - */ constructor( private _clientInfo: Implementation, - options?: ClientOptions + private _options?: ClientOptions ) { - super({ - ...options, - tasks: extractTaskManagerOptions(options?.capabilities?.tasks) + super(); + this._capabilities = _options?.capabilities ? { ..._options.capabilities } : {}; + this._jsonSchemaValidator = _options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); + this._supportedProtocolVersions = _options?.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; + this._enforceStrictCapabilities = _options?.enforceStrictCapabilities ?? false; + this._mrtrMaxRounds = _options?.mrtrMaxRounds ?? DEFAULT_MRTR_MAX_ROUNDS; + this._pendingListChangedConfig = _options?.listChanged; + + const tasksOpts = extractTaskManagerOptions(_options?.capabilities?.tasks); + this._taskManager = tasksOpts ? new TaskManager(tasksOpts) : new NullTaskManager(); + this._taskManager.attachTo(this, { + channel: () => + this._clientTransport + ? { + request: (r, schema, opts) => this._request(r, schema, opts), + notification: (n, opts) => this.notification(n, opts), + close: () => this.close(), + removeProgressHandler: t => this._clientTransport?.driver?.removeProgressHandler(t) + } + : undefined, + reportError: e => this.onerror?.(e), + enforceStrictCapabilities: this._enforceStrictCapabilities, + assertTaskCapability: () => {}, + assertTaskHandlerCapability: () => {} }); - this._capabilities = options?.capabilities ? { ...options.capabilities } : {}; - this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); - this._enforceStrictCapabilities = options?.enforceStrictCapabilities ?? false; // Strip runtime-only fields from advertised capabilities - if (options?.capabilities?.tasks) { + if (_options?.capabilities?.tasks) { // eslint-disable-next-line @typescript-eslint/no-unused-vars const { taskStore, taskMessageQueue, defaultTaskPollInterval, maxTaskQueueSize, ...wireCapabilities } = - options.capabilities.tasks; + _options.capabilities.tasks; this._capabilities.tasks = wireCapabilities; } - // Store list changed config for setup after connection (when we know server capabilities) - if (options?.listChanged) { - this._pendingListChangedConfig = options.listChanged; - } - } - - protected override buildContext(ctx: BaseContext, _transportInfo?: MessageExtraInfo): ClientContext { - return ctx; + super.setRequestHandler('ping', async () => ({})); } /** - * Set up handlers for list changed notifications based on config and server capabilities. - * This should only be called after initialization when server capabilities are known. - * Handlers are silently skipped if the server doesn't advertise the corresponding listChanged capability. - * @internal + * Connects to a server. Accepts either a request-shaped {@linkcode ClientTransport} + * or a legacy pipe {@linkcode Transport} (stdio, SSE, the v1 SHTTP class). + * Pipe transports are adapted via {@linkcode channelAsClientTransport} and + * the 2025-11 initialize handshake is performed. */ - private _setupListChangedHandlers(config: ListChangedHandlers): void { - if (config.tools && this._serverCapabilities?.tools?.listChanged) { - this._setupListChangedHandler('tools', 'notifications/tools/list_changed', config.tools, async () => { - const result = await this.listTools(); - return result.tools; - }); + async connect(transport: Transport | ClientTransport, options?: RequestOptions): Promise { + if (isChannelTransport(transport)) { + const driverOpts: StreamDriverOptions = { + supportedProtocolVersions: this._supportedProtocolVersions, + debouncedNotificationMethods: this._options?.debouncedNotificationMethods + }; + this._clientTransport = channelAsClientTransport(transport, this, driverOpts); + this._clientTransport.driver!.onresponse = (r, id) => this._taskManager.processInboundResponse(r, id); + this._clientTransport.driver!.onclose = () => { + this._taskManager.onClose(); + this.onclose?.(); + }; + this._clientTransport.driver!.onerror = e => this.onerror?.(e); + const skipInit = transport.sessionId !== undefined; + if (skipInit) { + if (this._negotiatedProtocolVersion && transport.setProtocolVersion) { + transport.setProtocolVersion(this._negotiatedProtocolVersion); + } + return; + } + try { + await this._initializeHandshake(options, v => transport.setProtocolVersion?.(v)); + } catch (error) { + void this.close(); + throw error; + } + return; } - - if (config.prompts && this._serverCapabilities?.prompts?.listChanged) { - this._setupListChangedHandler('prompts', 'notifications/prompts/list_changed', config.prompts, async () => { - const result = await this.listPrompts(); - return result.prompts; - }); + this._clientTransport = transport; + const t = transport as { sessionId?: string; setProtocolVersion?: (v: string) => void }; + const setProtocolVersion = (v: string) => t.setProtocolVersion?.(v); + if (t.sessionId !== undefined) { + if (this._negotiatedProtocolVersion) setProtocolVersion(this._negotiatedProtocolVersion); + this._startStandaloneStream(); + return; } - - if (config.resources && this._serverCapabilities?.resources?.listChanged) { - this._setupListChangedHandler('resources', 'notifications/resources/list_changed', config.resources, async () => { - const result = await this.listResources(); - return result.resources; - }); + try { + await this._discoverOrInitialize(options, setProtocolVersion); + } catch (error) { + void this.close(); + throw error; } + this._startStandaloneStream(); } /** - * Access experimental features. - * - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental + * Open the optional standalone server→client stream (e.g. SHTTP GET SSE) so + * server-initiated requests (elicitation/sampling/roots) and unsolicited + * notifications reach this client when going through the request-shaped + * {@linkcode ClientTransport} path. No-op if the transport doesn't support it. */ - get experimental(): { tasks: ExperimentalClientTasks } { - if (!this._experimental) { - this._experimental = { - tasks: new ExperimentalClientTasks(this) - }; - } - return this._experimental; + private _startStandaloneStream(): void { + const ct = this._clientTransport; + if (!ct?.subscribe) return; + void (async () => { + try { + const stream = ct.subscribe!({ + onrequest: async r => { + let resp: JSONRPCResultResponse | JSONRPCErrorResponse | undefined; + for await (const out of this.dispatch(r)) { + if (out.kind === 'response') resp = out.message; + } + return resp ?? { jsonrpc: '2.0', id: r.id, error: { code: -32_601, message: 'Method not found' } }; + }, + onresponse: r => { + const consumed = this._taskManager.processInboundResponse(r, Number(r.id)).consumed; + if (!consumed) this.onerror?.(new Error(`Unmatched response on standalone stream: ${JSON.stringify(r)}`)); + } + }); + for await (const n of stream) { + void this.dispatchNotification(n).catch(error => this.onerror?.(error)); + } + } catch (error) { + this.onerror?.(error instanceof Error ? error : new Error(String(error))); + } + })(); } - /** - * Registers new capabilities. This can only be called before connecting to a transport. - * - * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). - */ - public registerCapabilities(capabilities: ClientCapabilities): void { - if (this.transport) { - throw new Error('Cannot register capabilities after connecting to transport'); - } + async close(): Promise { + const ct = this._clientTransport; + this._clientTransport = undefined; + for (const t of this._listChangedDebounceTimers.values()) clearTimeout(t); + this._listChangedDebounceTimers.clear(); + // For pipe transports, driver.onclose (wired in connect) fires this.onclose. + // For ClientTransport (no driver), fire it here. + const fireOnclose = !ct?.driver; + await ct?.close(); + if (fireOnclose) this.onclose?.(); + } + + get transport(): Transport | undefined { + return this._clientTransport?.driver?.pipe; + } + /** Register additional capabilities. Must be called before {@linkcode connect}. */ + registerCapabilities(capabilities: ClientCapabilities): void { + if (this._clientTransport) throw new Error('Cannot register capabilities after connecting to transport'); this._capabilities = mergeCapabilities(this._capabilities, capabilities); } + getServerCapabilities(): ServerCapabilities | undefined { + return this._serverCapabilities; + } + getServerVersion(): Implementation | undefined { + return this._serverVersion; + } + getNegotiatedProtocolVersion(): string | undefined { + return this._negotiatedProtocolVersion; + } + getInstructions(): string | undefined { + return this._instructions; + } + /** - * Registers a handler for server-initiated requests (sampling, elicitation, roots). - * The client must declare the corresponding capability for the handler to be accepted. - * Replaces any previously registered handler for the same method. + * Register a handler for server-initiated requests (sampling, elicitation, + * roots, ping). In MRTR mode these handlers service `input_required` rounds. + * In pipe mode they are dispatched directly by the {@linkcode StreamDriver}. * * For `sampling/createMessage` and `elicitation/create`, the handler is automatically * wrapped with schema validation for both the incoming request and the returned result. - * - * @example Handling a sampling request - * ```ts source="./client.examples.ts#Client_setRequestHandler_sampling" - * client.setRequestHandler('sampling/createMessage', async request => { - * const lastMessage = request.params.messages.at(-1); - * console.log('Sampling request:', lastMessage); - * - * // In production, send messages to your LLM here - * return { - * model: 'my-model', - * role: 'assistant' as const, - * content: { - * type: 'text' as const, - * text: 'Response from the model' - * } - * }; - * }); - * ``` */ - public override setRequestHandler( + override setRequestHandler( method: M, handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise + ): void; + override setRequestHandler( + method: string, + paramsSchema: S, + handler: (params: StandardSchemaV1.InferOutput, ctx: ClientContext) => Result | Promise + ): void; + /** @deprecated Pass a method string instead of a Zod request schema. */ + override setRequestHandler( + schema: S, + handler: ( + request: S extends StandardSchemaV1 ? O : JSONRPCRequest, + ctx: ClientContext + ) => Result | Promise + ): void; + override setRequestHandler( + methodOrSchema: string | { shape: { method: unknown } }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + handlerOrSchema: any, + maybeHandler?: (params: unknown, ctx: ClientContext) => Result | Promise ): void { - if (method === 'elicitation/create') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise => { - const validatedRequest = parseSchema(ElicitRequestSchema, request); - if (!validatedRequest.success) { - // Type guard: if success is false, error is guaranteed to exist - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation request: ${errorMessage}`); - } - - const { params } = validatedRequest.data; - params.mode = params.mode ?? 'form'; - const { supportsFormMode, supportsUrlMode } = getSupportedElicitationModes(this._capabilities.elicitation); + if (maybeHandler !== undefined) { + const customMethod = methodOrSchema as string; + this._assertRequestHandlerCapability(customMethod); + super.setRequestHandler(customMethod, handlerOrSchema, maybeHandler); + return; + } + const handler = handlerOrSchema; + const method = ( + typeof methodOrSchema === 'string' + ? methodOrSchema + : ((methodOrSchema.shape.method as { value?: string })?.value ?? + (methodOrSchema.shape.method as { _def?: { value?: string } })?._def?.value) + ) as RequestMethod; + this._assertRequestHandlerCapability(method); - if (params.mode === 'form' && !supportsFormMode) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support form-mode elicitation requests'); - } + if (method === 'elicitation/create') { + super.setRequestHandler(method, this._wrapElicitationHandler(handler)); + return; + } + if (method === 'sampling/createMessage') { + super.setRequestHandler(method, this._wrapSamplingHandler(handler)); + return; + } + super.setRequestHandler(method, handler); + } - if (params.mode === 'url' && !supportsUrlMode) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); - } + /** Low-level: send one typed request. Runs the MRTR loop. */ + request( + req: { method: M; params?: RequestTypeMap[M]['params'] }, + options?: RequestOptions + ): Promise; + /** @deprecated Pass options as the second argument; the result schema is inferred from `req.method`. */ + request unknown }>( + req: { method: string; params?: Record }, + resultSchema: S, + options?: RequestOptions + ): Promise>; + async request( + req: { method: string; params?: Record }, + schemaOrOptions?: RequestOptions | { parse: (v: unknown) => unknown }, + maybeOptions?: RequestOptions + ) { + const isSchema = schemaOrOptions != null && typeof (schemaOrOptions as { parse?: unknown }).parse === 'function'; + const options = isSchema ? maybeOptions : (schemaOrOptions as RequestOptions | undefined); + const schema = isSchema ? (schemaOrOptions as AnySchema) : getResultSchema(req.method as RequestMethod); + return this._request({ method: req.method, params: req.params }, schema, options); + } - const result = await Promise.resolve(handler(request, ctx)); + /** Low-level: send a notification to the server. */ + async notification(n: Notification, _options?: NotificationOptions): Promise { + if (!this._clientTransport) throw new SdkError(SdkErrorCode.NotConnected, 'Not connected'); + if (this._enforceStrictCapabilities) this._assertNotificationCapability(n.method as NotificationMethod); + await this._clientTransport.notify(n); + } - // When task creation is requested, validate and return CreateTaskResult - if (params.task) { - const taskValidationResult = parseSchema(CreateTaskResultSchema, result); - if (!taskValidationResult.success) { - const errorMessage = - taskValidationResult.error instanceof Error - ? taskValidationResult.error.message - : String(taskValidationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); - } - return taskValidationResult.data; - } + // -- typed RPC sugar ------------------------------------------------------ - // For non-task requests, validate against ElicitResultSchema - const validationResult = parseSchema(ElicitResultSchema, result); - if (!validationResult.success) { - // Type guard: if success is false, error is guaranteed to exist - const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation result: ${errorMessage}`); + async ping(options?: RequestOptions) { + return this._request({ method: 'ping' }, EmptyResultSchema, options); + } + async complete(params: CompleteRequest['params'], options?: RequestOptions) { + return this._request({ method: 'completion/complete', params }, CompleteResultSchema, options); + } + async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) { + return this._request({ method: 'logging/setLevel', params: { level } }, EmptyResultSchema, options); + } + async getPrompt(params: GetPromptRequest['params'], options?: RequestOptions) { + return this._request({ method: 'prompts/get', params }, GetPromptResultSchema, options); + } + async listPrompts(params?: ListPromptsRequest['params'], options?: RequestOptions) { + if (!this._serverCapabilities?.prompts && !this._enforceStrictCapabilities) return { prompts: [] }; + return this._request({ method: 'prompts/list', params }, ListPromptsResultSchema, options); + } + async listResources(params?: ListResourcesRequest['params'], options?: RequestOptions) { + if (!this._serverCapabilities?.resources && !this._enforceStrictCapabilities) return { resources: [] }; + return this._request({ method: 'resources/list', params }, ListResourcesResultSchema, options); + } + async listResourceTemplates(params?: ListResourceTemplatesRequest['params'], options?: RequestOptions) { + if (!this._serverCapabilities?.resources && !this._enforceStrictCapabilities) return { resourceTemplates: [] }; + return this._request({ method: 'resources/templates/list', params }, ListResourceTemplatesResultSchema, options); + } + async readResource(params: ReadResourceRequest['params'], options?: RequestOptions) { + return this._request({ method: 'resources/read', params }, ReadResourceResultSchema, options); + } + async subscribeResource(params: SubscribeRequest['params'], options?: RequestOptions) { + return this._request({ method: 'resources/subscribe', params }, EmptyResultSchema, options); + } + async unsubscribeResource(params: UnsubscribeRequest['params'], options?: RequestOptions) { + return this._request({ method: 'resources/unsubscribe', params }, EmptyResultSchema, options); + } + async listTools(params?: ListToolsRequest['params'], options?: RequestOptions) { + if (!this._serverCapabilities?.tools && !this._enforceStrictCapabilities) return { tools: [] }; + const result = await this._request({ method: 'tools/list', params }, ListToolsResultSchema, options); + this._cacheToolMetadata(result.tools); + return result; + } + async callTool( + params: CallToolRequest['params'], + options: RequestOptions & { task: NonNullable } + ): Promise; + async callTool(params: CallToolRequest['params'], options?: RequestOptions & { awaitTask?: boolean }): Promise; + async callTool( + params: CallToolRequest['params'], + options?: RequestOptions & { awaitTask?: boolean } + ): Promise { + if (this._cachedRequiredTaskTools.has(params.name) && !options?.task && !options?.awaitTask) { + throw new ProtocolError( + ProtocolErrorCode.InvalidRequest, + `Tool "${params.name}" requires task-based execution. Use client.experimental.tasks.callToolStream() or pass {awaitTask: true}.` + ); + } + const raw = await this._requestRaw({ method: 'tools/call', params }, options); + // SEP-2557: server may return a task even when not requested. With options.task + // the caller asked for it; with awaitTask we poll to completion; otherwise (truly + // unsolicited) throw with guidance so the v1 callTool() return type stays CallToolResult. + if (isCreateTaskResult(raw)) { + if (options?.task) return raw; + if (options?.awaitTask) return this._pollTaskToCompletion(raw.task.taskId, options); + throw new ProtocolError( + ProtocolErrorCode.InvalidRequest, + `Server returned a task for "${params.name}". Pass {task: {...}} or {awaitTask: true}, or use client.experimental.tasks.callToolStream().` + ); + } + const parsed = parseSchema(CallToolResultSchema, raw); + if (!parsed.success) throw parsed.error; + const result = parsed.data; + const validator = this._cachedToolOutputValidators.get(params.name); + if (validator) { + if (!result.structuredContent && !result.isError) { + throw new ProtocolError( + ProtocolErrorCode.InvalidRequest, + `Tool ${params.name} has an output schema but did not return structured content` + ); + } + if (result.structuredContent) { + const v = validator(result.structuredContent); + if (!v.valid) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Structured content does not match the tool's output schema: ${v.errorMessage}` + ); } + } + } + return result; + } + async sendRootsListChanged() { + return this.notification({ method: 'notifications/roots/list_changed' }); + } - const validatedResult = validationResult.data; - const requestedSchema = params.mode === 'form' ? (params.requestedSchema as JsonSchemaType) : undefined; + // -- tasks (SEP-1686 / SEP-2557) ----------------------------------------- + // Kept isolated: typed RPCs + the polymorphism check in callTool above. The + // streaming/polling helpers live in {@linkcode ExperimentalClientTasks}. - if ( - params.mode === 'form' && - validatedResult.action === 'accept' && - validatedResult.content && - requestedSchema && - this._capabilities.elicitation?.form?.applyDefaults - ) { - try { - applyElicitationDefaults(requestedSchema, validatedResult.content); - } catch { - // gracefully ignore errors in default application - } - } + async getTask(params: GetTaskRequest['params'], options?: RequestOptions) { + return this._request({ method: 'tasks/get', params }, GetTaskResultSchema, options); + } + async listTasks(params?: ListTasksRequest['params'], options?: RequestOptions) { + return this._request({ method: 'tasks/list', params }, ListTasksResultSchema, options); + } + async cancelTask(params: CancelTaskRequest['params'], options?: RequestOptions) { + return this._request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); + } - return validatedResult; - }; + /** + * This client's {@linkcode TaskManager}. Owned here (not by the transport adapter). + */ + get taskManager(): TaskManager { + return this._taskManager; + } - // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); + /** + * Access experimental task helpers (callToolStream, getTaskResult, ...). + * + * @experimental + */ + get experimental(): { tasks: ExperimentalClientTasks } { + if (!this._experimental) { + this._experimental = { tasks: new ExperimentalClientTasks(this as never) }; } + return this._experimental; + } - if (method === 'sampling/createMessage') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise => { - const validatedRequest = parseSchema(CreateMessageRequestSchema, request); - if (!validatedRequest.success) { - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling request: ${errorMessage}`); + /** @internal structural compat for {@linkcode ExperimentalClientTasks} */ + private isToolTask(toolName: string): boolean { + return this._cachedKnownTaskTools.has(toolName); + } + /** @internal structural compat for {@linkcode ExperimentalClientTasks} */ + private getToolOutputValidator(toolName: string): JsonSchemaValidator | undefined { + return this._cachedToolOutputValidators.get(toolName); + } + + private async _pollTaskToCompletion(taskId: string, options?: RequestOptions): Promise { + // SEP-2557 collapses tasks/result into tasks/get; poll status, then + // fetch payload. Backoff is fixed-interval; the streaming variant lives + // in ExperimentalClientTasks. + const intervalMs = 500; + while (true) { + options?.signal?.throwIfAborted(); + const r: GetTaskResult = await this.getTask({ taskId }, options); + const status = r.status; + if (status === 'completed' || status === 'failed' || status === 'cancelled') { + try { + return await this._request({ method: 'tasks/result', params: { taskId } }, CallToolResultSchema, options); + } catch { + return { content: [], isError: status !== 'completed' }; } + } + await new Promise(resolve => setTimeout(resolve, intervalMs)); + } + } - const { params } = validatedRequest.data; + // -- internals ----------------------------------------------------------- - const result = await Promise.resolve(handler(request, ctx)); + /** @internal alias for {@linkcode ExperimentalClientTasks} structural compat */ + private _requestWithSchema(req: Request, resultSchema: T, options?: RequestOptions): Promise> { + return this._request(req, resultSchema, options); + } - // When task creation is requested, validate and return CreateTaskResult - if (params.task) { - const taskValidationResult = parseSchema(CreateTaskResultSchema, result); - if (!taskValidationResult.success) { - const errorMessage = - taskValidationResult.error instanceof Error - ? taskValidationResult.error.message - : String(taskValidationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); - } - return taskValidationResult.data; - } + private async _request(req: Request, resultSchema: T, options?: RequestOptions): Promise> { + const raw = await this._requestRaw(req, options); + const parsed = parseSchema(resultSchema, raw); + if (!parsed.success) throw parsed.error; + return parsed.data as SchemaOutput; + } - // For non-task requests, validate against appropriate schema based on tools presence - const hasTools = params.tools || params.toolChoice; - const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema; - const validationResult = parseSchema(resultSchema, result); - if (!validationResult.success) { - const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling result: ${errorMessage}`); + /** Like {@linkcode _request} but returns the unparsed result. Used where the result is polymorphic (e.g. SEP-2557 task results). */ + private async _requestRaw(req: Request, options?: RequestOptions): Promise { + if (!this._clientTransport) throw new SdkError(SdkErrorCode.NotConnected, 'Not connected'); + if (this._enforceStrictCapabilities) this._assertCapabilityForMethod(req.method as RequestMethod); + let inputResponses: Record = {}; + for (let round = 0; round < this._mrtrMaxRounds; round++) { + const id = this._requestMessageId++; + const meta = { + ...(req.params?._meta as Record | undefined), + ...(round > 0 ? { [MRTR_INPUT_RESPONSES_META_KEY]: inputResponses } : {}) + }; + const jr: JSONRPCRequest = { + jsonrpc: '2.0', + id, + method: req.method, + params: req.params || round > 0 ? { ...req.params, _meta: Object.keys(meta).length > 0 ? meta : undefined } : undefined + }; + // Thread task augmentation into request params (mirrors TaskManager.prepareOutboundRequest + // for the request-shaped path; the pipe path threads via StreamDriver.request). + if (options?.task) jr.params = { ...jr.params, task: options.task }; + if (options?.relatedTask) { + jr.params = { + ...jr.params, + _meta: { ...(jr.params?._meta as Record | undefined), [RELATED_TASK_META_KEY]: options.relatedTask } + }; + } + const opts: ClientFetchOptions = { + signal: options?.signal, + timeout: options?.timeout, + resetTimeoutOnProgress: options?.resetTimeoutOnProgress, + maxTotalTimeout: options?.maxTotalTimeout, + onprogress: options?.onprogress, + relatedRequestId: options?.relatedRequestId, + task: options?.task, + relatedTask: options?.relatedTask, + resumptionToken: options?.resumptionToken, + onresumptiontoken: options?.onresumptiontoken, + onnotification: n => void this.dispatchNotification(n).catch(error => this.onerror?.(error)), + onresponse: r => { + const consumed = this._taskManager.processInboundResponse(r, Number(r.id)).consumed; + if (!consumed) this.onerror?.(new Error(`Unmatched response on stream: ${JSON.stringify(r)}`)); + }, + onrequest: async r => { + let resp: JSONRPCResultResponse | JSONRPCErrorResponse | undefined; + for await (const out of this.dispatch(r)) { + if (out.kind === 'response') resp = out.message; + } + return resp ?? { jsonrpc: '2.0', id: r.id, error: { code: -32_601, message: 'Method not found' } }; } - - return validationResult.data; }; - - // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); + const resp = await this._clientTransport.fetch(jr, opts); + if (isJSONRPCErrorResponse(resp)) { + throw ProtocolError.fromError(resp.error.code, resp.error.message, resp.error.data); + } + const raw = resp.result; + if (isInputRequired(raw)) { + inputResponses = { ...inputResponses, ...(await this._serviceInputRequests(raw.InputRequests)) }; + continue; + } + return raw; } - - // Other handlers use default behavior - return super.setRequestHandler(method, handler); + throw new ProtocolError(ProtocolErrorCode.InternalError, `MRTR exceeded ${this._mrtrMaxRounds} rounds for ${req.method}`); } - protected assertCapability(capability: keyof ServerCapabilities, method: string): void { - if (!this._serverCapabilities?.[capability]) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support ${capability} (required for ${method})`); + private async _serviceInputRequests( + reqs: Record }> + ): Promise> { + const out: Record = {}; + for (const [key, ir] of Object.entries(reqs)) { + const synthetic: JSONRPCRequest = { jsonrpc: '2.0', id: `mrtr:${key}`, method: ir.method, params: ir.params }; + const resp = await this.dispatchToResponse(synthetic); + if (isJSONRPCErrorResponse(resp)) { + throw ProtocolError.fromError(resp.error.code, resp.error.message, resp.error.data); + } + out[key] = resp.result; } + return out; } - /** - * Connects to a server via the given transport and performs the MCP initialization handshake. - * - * @example Basic usage (stdio) - * ```ts source="./client.examples.ts#Client_connect_stdio" - * const client = new Client({ name: 'my-client', version: '1.0.0' }); - * const transport = new StdioClientTransport({ command: 'my-mcp-server' }); - * await client.connect(transport); - * ``` - * - * @example Streamable HTTP with SSE fallback - * ```ts source="./client.examples.ts#Client_connect_sseFallback" - * const baseUrl = new URL(url); - * - * try { - * // Try modern Streamable HTTP transport first - * const client = new Client({ name: 'my-client', version: '1.0.0' }); - * const transport = new StreamableHTTPClientTransport(baseUrl); - * await client.connect(transport); - * return { client, transport }; - * } catch { - * // Fall back to legacy SSE transport - * const client = new Client({ name: 'my-client', version: '1.0.0' }); - * const transport = new SSEClientTransport(baseUrl); - * await client.connect(transport); - * return { client, transport }; - * } - * ``` - */ - override async connect(transport: Transport, options?: RequestOptions): Promise { - await super.connect(transport); - // When transport sessionId is already set this means we are trying to reconnect. - // Restore the protocol version negotiated during the original initialize handshake - // so HTTP transports include the required mcp-protocol-version header, but skip re-init. - if (transport.sessionId !== undefined) { - if (this._negotiatedProtocolVersion !== undefined && transport.setProtocolVersion) { - transport.setProtocolVersion(this._negotiatedProtocolVersion); - } - return; + private async _initializeHandshake(options: RequestOptions | undefined, setProtocolVersion: (v: string) => void): Promise { + const result = await this._request( + { + method: 'initialize', + params: { + protocolVersion: this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION, + capabilities: this._capabilities, + clientInfo: this._clientInfo + } + }, + InitializeResultSchema, + options + ); + if (!this._supportedProtocolVersions.includes(result.protocolVersion)) { + throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`); + } + this._serverCapabilities = result.capabilities; + this._serverVersion = result.serverInfo; + this._negotiatedProtocolVersion = result.protocolVersion; + this._instructions = result.instructions; + setProtocolVersion(result.protocolVersion); + await this.notification({ method: 'notifications/initialized' }); + if (this._pendingListChangedConfig) { + this._setupListChangedHandlers(this._pendingListChangedConfig); + this._pendingListChangedConfig = undefined; } + } + + private async _discoverOrInitialize(options: RequestOptions | undefined, setProtocolVersion: (v: string) => void): Promise { + // Try server/discover (SEP-2575 stateless), fall back to initialize. Discover schema + // is not yet in spec types, so probe and accept the result loosely. try { - const result = await this._requestWithSchema( - { - method: 'initialize', - params: { - protocolVersion: this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION, - capabilities: this._capabilities, - clientInfo: this._clientInfo - } - }, - InitializeResultSchema, - options + const resp = await this._clientTransport!.fetch( + { jsonrpc: '2.0', id: this._requestMessageId++, method: 'server/discover' as RequestMethod }, + { timeout: options?.timeout, signal: options?.signal } ); - - if (result === undefined) { - throw new Error(`Server sent invalid initialize result: ${result}`); - } - - if (!this._supportedProtocolVersions.includes(result.protocolVersion)) { - throw new Error(`Server's protocol version is not supported: ${result.protocolVersion}`); - } - - this._serverCapabilities = result.capabilities; - this._serverVersion = result.serverInfo; - this._negotiatedProtocolVersion = result.protocolVersion; - // HTTP transports must set the protocol version in each header after initialization. - if (transport.setProtocolVersion) { - transport.setProtocolVersion(result.protocolVersion); + if (!isJSONRPCErrorResponse(resp)) { + const r = resp.result as { + capabilities?: ServerCapabilities; + serverInfo?: Implementation; + instructions?: string; + protocolVersion?: string; + }; + // Only accept discover if the result is shaped like a real discover response; + // 2025-11 servers may return an empty/echo result for unknown methods. + if (r?.serverInfo) { + this._serverCapabilities = r.capabilities; + this._serverVersion = r.serverInfo; + this._instructions = r.instructions; + if (r.protocolVersion) setProtocolVersion(r.protocolVersion); + return; + } } + } catch { + // Any error from the discover probe falls through to initialize. + } + await this._initializeHandshake(options, setProtocolVersion); + } - this._instructions = result.instructions; - - await this.notification({ - method: 'notifications/initialized' - }); - - // Set up list changed handlers now that we know server capabilities - if (this._pendingListChangedConfig) { - this._setupListChangedHandlers(this._pendingListChangedConfig); - this._pendingListChangedConfig = undefined; + private _cacheToolMetadata(tools: Tool[]): void { + this._cachedToolOutputValidators.clear(); + this._cachedKnownTaskTools.clear(); + this._cachedRequiredTaskTools.clear(); + for (const tool of tools) { + if (tool.outputSchema) { + this._cachedToolOutputValidators.set( + tool.name, + this._jsonSchemaValidator.getValidator(tool.outputSchema as JsonSchemaType) + ); } - } catch (error) { - // Disconnect if initialization fails. - void this.close(); - throw error; + const ts = tool.execution?.taskSupport; + if (ts === 'required' || ts === 'optional') this._cachedKnownTaskTools.add(tool.name); + if (ts === 'required') this._cachedRequiredTaskTools.add(tool.name); } } - /** - * After initialization has completed, this will be populated with the server's reported capabilities. - */ - getServerCapabilities(): ServerCapabilities | undefined { - return this._serverCapabilities; + private _wrapElicitationHandler( + handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise + ) { + return async (request: RequestTypeMap[M], ctx: ClientContext): Promise => { + const validatedRequest = parseSchema(ElicitRequestSchema, request); + if (!validatedRequest.success) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Invalid elicitation request: ${formatErr(validatedRequest.error)}` + ); + } + const { params } = validatedRequest.data; + params.mode = params.mode ?? 'form'; + const { supportsFormMode, supportsUrlMode } = getSupportedElicitationModes(this._capabilities.elicitation); + if (params.mode === 'form' && !supportsFormMode) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support form-mode elicitation requests'); + } + if (params.mode === 'url' && !supportsUrlMode) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests'); + } + const result = await Promise.resolve(handler(request, ctx)); + if (params.task) { + const tv = parseSchema(CreateTaskResultSchema, result); + if (!tv.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${formatErr(tv.error)}`); + } + return tv.data; + } + const vr = parseSchema(ElicitResultSchema, result); + if (!vr.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid elicitation result: ${formatErr(vr.error)}`); + } + const validatedResult = vr.data; + const requestedSchema = params.mode === 'form' ? (params.requestedSchema as JsonSchemaType) : undefined; + if ( + params.mode === 'form' && + validatedResult.action === 'accept' && + validatedResult.content && + requestedSchema && + this._capabilities.elicitation?.form?.applyDefaults + ) { + try { + applyElicitationDefaults(requestedSchema, validatedResult.content); + } catch { + // gracefully ignore errors in default application + } + } + return validatedResult; + }; } - /** - * After initialization has completed, this will be populated with information about the server's name and version. - */ - getServerVersion(): Implementation | undefined { - return this._serverVersion; + private _wrapSamplingHandler( + handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise + ) { + return async (request: RequestTypeMap[M], ctx: ClientContext): Promise => { + const validatedRequest = parseSchema(CreateMessageRequestSchema, request); + if (!validatedRequest.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling request: ${formatErr(validatedRequest.error)}`); + } + const { params } = validatedRequest.data; + const result = await Promise.resolve(handler(request, ctx)); + if (params.task) { + const tv = parseSchema(CreateTaskResultSchema, result); + if (!tv.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${formatErr(tv.error)}`); + } + return tv.data; + } + const hasTools = params.tools || params.toolChoice; + const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema; + const vr = parseSchema(resultSchema, result); + if (!vr.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid sampling result: ${formatErr(vr.error)}`); + } + return vr.data; + }; } - /** - * After initialization has completed, this will be populated with the protocol version negotiated - * during the initialize handshake. When manually reconstructing a transport for reconnection, pass this - * value to the new transport so it continues sending the required `mcp-protocol-version` header. - */ - getNegotiatedProtocolVersion(): string | undefined { - return this._negotiatedProtocolVersion; + private _setupListChangedHandlers(config: ListChangedHandlers): void { + if (config.tools && this._serverCapabilities?.tools?.listChanged) { + this._setupListChangedHandler('tools', 'notifications/tools/list_changed', config.tools, async () => { + const result = await this.listTools(); + return result.tools; + }); + } + if (config.prompts && this._serverCapabilities?.prompts?.listChanged) { + this._setupListChangedHandler('prompts', 'notifications/prompts/list_changed', config.prompts, async () => { + const result = await this.listPrompts(); + return result.prompts; + }); + } + if (config.resources && this._serverCapabilities?.resources?.listChanged) { + this._setupListChangedHandler('resources', 'notifications/resources/list_changed', config.resources, async () => { + const result = await this.listResources(); + return result.resources; + }); + } } - /** - * After initialization has completed, this may be populated with information about the server's instructions. - */ - getInstructions(): string | undefined { - return this._instructions; + private _setupListChangedHandler( + listType: string, + notificationMethod: NotificationMethod, + options: ListChangedOptions, + fetcher: () => Promise + ): void { + const parseResult = parseSchema(ListChangedOptionsBaseSchema, options); + if (!parseResult.success) { + throw new Error(`Invalid ${listType} listChanged options: ${parseResult.error.message}`); + } + if (typeof options.onChanged !== 'function') { + throw new TypeError(`Invalid ${listType} listChanged options: onChanged must be a function`); + } + const { autoRefresh, debounceMs } = parseResult.data; + const { onChanged } = options; + + const refresh = async () => { + if (!autoRefresh) { + onChanged(null, null); + return; + } + try { + onChanged(null, await fetcher()); + } catch (error) { + onChanged(error instanceof Error ? error : new Error(String(error)), null); + } + }; + + this.setNotificationHandler(notificationMethod, () => { + if (debounceMs) { + const existing = this._listChangedDebounceTimers.get(listType); + if (existing) clearTimeout(existing); + this._listChangedDebounceTimers.set(listType, setTimeout(refresh, debounceMs)); + } else { + void refresh(); + } + }); } - protected assertCapabilityForMethod(method: RequestMethod): void { + private _assertCapabilityForMethod(method: RequestMethod): void { switch (method as ClientRequest['method']) { case 'logging/setLevel': { - if (!this._serverCapabilities?.logging) { + if (!this._serverCapabilities?.logging) throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support logging (required for ${method})`); - } break; } - case 'prompts/get': case 'prompts/list': { - if (!this._serverCapabilities?.prompts) { + if (!this._serverCapabilities?.prompts) throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support prompts (required for ${method})`); - } break; } - case 'resources/list': case 'resources/templates/list': case 'resources/read': case 'resources/subscribe': case 'resources/unsubscribe': { - if (!this._serverCapabilities?.resources) { + if (!this._serverCapabilities?.resources) throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support resources (required for ${method})`); - } - - if (method === 'resources/subscribe' && !this._serverCapabilities.resources.subscribe) { + if (method === 'resources/subscribe' && !this._serverCapabilities.resources.subscribe) throw new SdkError( SdkErrorCode.CapabilityNotSupported, `Server does not support resource subscriptions (required for ${method})` ); - } - break; } - case 'tools/call': case 'tools/list': { - if (!this._serverCapabilities?.tools) { + if (!this._serverCapabilities?.tools) throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support tools (required for ${method})`); - } break; } - case 'completion/complete': { - if (!this._serverCapabilities?.completions) { + if (!this._serverCapabilities?.completions) throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support completions (required for ${method})`); - } - break; - } - - case 'initialize': { - // No specific capability required for initialize - break; - } - - case 'ping': { - // No specific capability required for ping break; } } } - protected assertNotificationCapability(method: NotificationMethod): void { - switch (method as ClientNotification['method']) { - case 'notifications/roots/list_changed': { - if (!this._capabilities.roots?.listChanged) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `Client does not support roots list changed notifications (required for ${method})` - ); - } - break; - } - - case 'notifications/initialized': { - // No specific capability required for initialized - break; - } - - case 'notifications/cancelled': { - // Cancellation notifications are always allowed - break; - } - - case 'notifications/progress': { - // Progress notifications are always allowed - break; - } + private _assertNotificationCapability(method: NotificationMethod): void { + if ((method as ClientNotification['method']) === 'notifications/roots/list_changed' && !this._capabilities.roots?.listChanged) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + `Client does not support roots list changed notifications (required for ${method})` + ); } } - protected assertRequestHandlerCapability(method: string): void { + private _assertRequestHandlerCapability(method: string): void { switch (method) { case 'sampling/createMessage': { - if (!this._capabilities.sampling) { + if (!this._capabilities.sampling) throw new SdkError( SdkErrorCode.CapabilityNotSupported, `Client does not support sampling capability (required for ${method})` ); - } break; } - case 'elicitation/create': { - if (!this._capabilities.elicitation) { + if (!this._capabilities.elicitation) throw new SdkError( SdkErrorCode.CapabilityNotSupported, `Client does not support elicitation capability (required for ${method})` ); - } break; } - case 'roots/list': { - if (!this._capabilities.roots) { + if (!this._capabilities.roots) throw new SdkError( SdkErrorCode.CapabilityNotSupported, `Client does not support roots capability (required for ${method})` ); - } break; } - - case 'ping': { - // No specific capability required for ping - break; - } - } - } - - protected assertTaskCapability(method: string): void { - assertToolsCallTaskCapability(this._serverCapabilities?.tasks?.requests, method, 'Server'); - } - - protected assertTaskHandlerCapability(method: string): void { - assertClientRequestTaskCapability(this._capabilities?.tasks?.requests, method, 'Client'); - } - - async ping(options?: RequestOptions) { - return this._requestWithSchema({ method: 'ping' }, EmptyResultSchema, options); - } - - /** Requests argument autocompletion suggestions from the server for a prompt or resource. */ - async complete(params: CompleteRequest['params'], options?: RequestOptions) { - return this._requestWithSchema({ method: 'completion/complete', params }, CompleteResultSchema, options); - } - - /** Sets the minimum severity level for log messages sent by the server. */ - async setLoggingLevel(level: LoggingLevel, options?: RequestOptions) { - return this._requestWithSchema({ method: 'logging/setLevel', params: { level } }, EmptyResultSchema, options); - } - - /** Retrieves a prompt by name from the server, passing the given arguments for template substitution. */ - async getPrompt(params: GetPromptRequest['params'], options?: RequestOptions) { - return this._requestWithSchema({ method: 'prompts/get', params }, GetPromptResultSchema, options); - } - - /** - * Lists available prompts. Results may be paginated — loop on `nextCursor` to collect all pages. - * - * Returns an empty list if the server does not advertise prompts capability - * (or throws if {@linkcode ClientOptions.enforceStrictCapabilities} is enabled). - * - * @example - * ```ts source="./client.examples.ts#Client_listPrompts_pagination" - * const allPrompts: Prompt[] = []; - * let cursor: string | undefined; - * do { - * const { prompts, nextCursor } = await client.listPrompts({ cursor }); - * allPrompts.push(...prompts); - * cursor = nextCursor; - * } while (cursor); - * console.log( - * 'Available prompts:', - * allPrompts.map(p => p.name) - * ); - * ``` - */ - async listPrompts(params?: ListPromptsRequest['params'], options?: RequestOptions) { - if (!this._serverCapabilities?.prompts && !this._enforceStrictCapabilities) { - // Respect capability negotiation: server does not support prompts - console.debug('Client.listPrompts() called but server does not advertise prompts capability - returning empty list'); - return { prompts: [] }; - } - return this._requestWithSchema({ method: 'prompts/list', params }, ListPromptsResultSchema, options); - } - - /** - * Lists available resources. Results may be paginated — loop on `nextCursor` to collect all pages. - * - * Returns an empty list if the server does not advertise resources capability - * (or throws if {@linkcode ClientOptions.enforceStrictCapabilities} is enabled). - * - * @example - * ```ts source="./client.examples.ts#Client_listResources_pagination" - * const allResources: Resource[] = []; - * let cursor: string | undefined; - * do { - * const { resources, nextCursor } = await client.listResources({ cursor }); - * allResources.push(...resources); - * cursor = nextCursor; - * } while (cursor); - * console.log( - * 'Available resources:', - * allResources.map(r => r.name) - * ); - * ``` - */ - async listResources(params?: ListResourcesRequest['params'], options?: RequestOptions) { - if (!this._serverCapabilities?.resources && !this._enforceStrictCapabilities) { - // Respect capability negotiation: server does not support resources - console.debug('Client.listResources() called but server does not advertise resources capability - returning empty list'); - return { resources: [] }; - } - return this._requestWithSchema({ method: 'resources/list', params }, ListResourcesResultSchema, options); - } - - /** - * Lists available resource URI templates for dynamic resources. Results may be paginated — see {@linkcode listResources | listResources()} for the cursor pattern. - * - * Returns an empty list if the server does not advertise resources capability - * (or throws if {@linkcode ClientOptions.enforceStrictCapabilities} is enabled). - */ - async listResourceTemplates(params?: ListResourceTemplatesRequest['params'], options?: RequestOptions) { - if (!this._serverCapabilities?.resources && !this._enforceStrictCapabilities) { - // Respect capability negotiation: server does not support resources - console.debug( - 'Client.listResourceTemplates() called but server does not advertise resources capability - returning empty list' - ); - return { resourceTemplates: [] }; - } - return this._requestWithSchema({ method: 'resources/templates/list', params }, ListResourceTemplatesResultSchema, options); - } - - /** Reads the contents of a resource by URI. */ - async readResource(params: ReadResourceRequest['params'], options?: RequestOptions) { - return this._requestWithSchema({ method: 'resources/read', params }, ReadResourceResultSchema, options); - } - - /** Subscribes to change notifications for a resource. The server must support resource subscriptions. */ - async subscribeResource(params: SubscribeRequest['params'], options?: RequestOptions) { - return this._requestWithSchema({ method: 'resources/subscribe', params }, EmptyResultSchema, options); - } - - /** Unsubscribes from change notifications for a resource. */ - async unsubscribeResource(params: UnsubscribeRequest['params'], options?: RequestOptions) { - return this._requestWithSchema({ method: 'resources/unsubscribe', params }, EmptyResultSchema, options); - } - - /** - * Calls a tool on the connected server and returns the result. Automatically validates structured output - * if the tool has an `outputSchema`. - * - * Tool results have two error surfaces: `result.isError` for tool-level failures (the tool ran but reported - * a problem), and thrown {@linkcode ProtocolError} for protocol-level failures or {@linkcode SdkError} for - * SDK-level issues (timeouts, missing capabilities). - * - * For task-based execution with streaming behavior, use {@linkcode ExperimentalClientTasks.callToolStream | client.experimental.tasks.callToolStream()} instead. - * - * @example Basic usage - * ```ts source="./client.examples.ts#Client_callTool_basic" - * const result = await client.callTool({ - * name: 'calculate-bmi', - * arguments: { weightKg: 70, heightM: 1.75 } - * }); - * - * // Tool-level errors are returned in the result, not thrown - * if (result.isError) { - * console.error('Tool error:', result.content); - * return; - * } - * - * console.log(result.content); - * ``` - * - * @example Structured output - * ```ts source="./client.examples.ts#Client_callTool_structuredOutput" - * const result = await client.callTool({ - * name: 'calculate-bmi', - * arguments: { weightKg: 70, heightM: 1.75 } - * }); - * - * // Machine-readable output for the client application - * if (result.structuredContent) { - * console.log(result.structuredContent); // e.g. { bmi: 22.86 } - * } - * ``` - */ - async callTool(params: CallToolRequest['params'], options?: RequestOptions) { - // Guard: required-task tools need experimental API - if (this.isToolTaskRequired(params.name)) { - throw new ProtocolError( - ProtocolErrorCode.InvalidRequest, - `Tool "${params.name}" requires task-based execution. Use client.experimental.tasks.callToolStream() instead.` - ); - } - - const result = await this._requestWithSchema({ method: 'tools/call', params }, CallToolResultSchema, options); - - // Check if the tool has an outputSchema - const validator = this.getToolOutputValidator(params.name); - if (validator) { - // If tool has outputSchema, it MUST return structuredContent (unless it's an error) - if (!result.structuredContent && !result.isError) { - throw new ProtocolError( - ProtocolErrorCode.InvalidRequest, - `Tool ${params.name} has an output schema but did not return structured content` - ); - } - - // Only validate structured content if present (not when there's an error) - if (result.structuredContent) { - try { - // Validate the structured content against the schema - const validationResult = validator(result.structuredContent); - - if (!validationResult.valid) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Structured content does not match the tool's output schema: ${validationResult.errorMessage}` - ); - } - } catch (error) { - if (error instanceof ProtocolError) { - throw error; - } - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Failed to validate structured content: ${error instanceof Error ? error.message : String(error)}` - ); - } - } } - - return result; - } - - private isToolTask(toolName: string): boolean { - if (!this._serverCapabilities?.tasks?.requests?.tools?.call) { - return false; - } - - return this._cachedKnownTaskTools.has(toolName); - } - - /** - * Check if a tool requires task-based execution. - * Unlike {@linkcode isToolTask} which includes `'optional'` tools, this only checks for `'required'`. - */ - private isToolTaskRequired(toolName: string): boolean { - return this._cachedRequiredTaskTools.has(toolName); - } - - /** - * Cache validators for tool output schemas. - * Called after {@linkcode listTools | listTools()} to pre-compile validators for better performance. - */ - private cacheToolMetadata(tools: Tool[]): void { - this._cachedToolOutputValidators.clear(); - this._cachedKnownTaskTools.clear(); - this._cachedRequiredTaskTools.clear(); - - for (const tool of tools) { - // If the tool has an outputSchema, create and cache the validator - if (tool.outputSchema) { - const toolValidator = this._jsonSchemaValidator.getValidator(tool.outputSchema as JsonSchemaType); - this._cachedToolOutputValidators.set(tool.name, toolValidator); - } - - // If the tool supports task-based execution, cache that information - const taskSupport = tool.execution?.taskSupport; - if (taskSupport === 'required' || taskSupport === 'optional') { - this._cachedKnownTaskTools.add(tool.name); - } - if (taskSupport === 'required') { - this._cachedRequiredTaskTools.add(tool.name); - } - } - } - - /** - * Get cached validator for a tool - */ - private getToolOutputValidator(toolName: string): JsonSchemaValidator | undefined { - return this._cachedToolOutputValidators.get(toolName); - } - - /** - * Lists available tools. Results may be paginated — loop on `nextCursor` to collect all pages. - * - * Returns an empty list if the server does not advertise tools capability - * (or throws if {@linkcode ClientOptions.enforceStrictCapabilities} is enabled). - * - * @example - * ```ts source="./client.examples.ts#Client_listTools_pagination" - * const allTools: Tool[] = []; - * let cursor: string | undefined; - * do { - * const { tools, nextCursor } = await client.listTools({ cursor }); - * allTools.push(...tools); - * cursor = nextCursor; - * } while (cursor); - * console.log( - * 'Available tools:', - * allTools.map(t => t.name) - * ); - * ``` - */ - async listTools(params?: ListToolsRequest['params'], options?: RequestOptions) { - if (!this._serverCapabilities?.tools && !this._enforceStrictCapabilities) { - // Respect capability negotiation: server does not support tools - console.debug('Client.listTools() called but server does not advertise tools capability - returning empty list'); - return { tools: [] }; - } - const result = await this._requestWithSchema({ method: 'tools/list', params }, ListToolsResultSchema, options); - - // Cache the tools and their output schemas for future validation - this.cacheToolMetadata(result.tools); - - return result; - } - - /** - * Set up a single list changed handler. - * @internal - */ - private _setupListChangedHandler( - listType: string, - notificationMethod: NotificationMethod, - options: ListChangedOptions, - fetcher: () => Promise - ): void { - // Validate options using Zod schema (validates autoRefresh and debounceMs) - const parseResult = parseSchema(ListChangedOptionsBaseSchema, options); - if (!parseResult.success) { - throw new Error(`Invalid ${listType} listChanged options: ${parseResult.error.message}`); - } - - // Validate callback - if (typeof options.onChanged !== 'function') { - throw new TypeError(`Invalid ${listType} listChanged options: onChanged must be a function`); - } - - const { autoRefresh, debounceMs } = parseResult.data; - const { onChanged } = options; - - const refresh = async () => { - if (!autoRefresh) { - onChanged(null, null); - return; - } - - try { - const items = await fetcher(); - onChanged(null, items); - } catch (error) { - const newError = error instanceof Error ? error : new Error(String(error)); - onChanged(newError, null); - } - }; - - const handler = () => { - if (debounceMs) { - // Clear any pending debounce timer for this list type - const existingTimer = this._listChangedDebounceTimers.get(listType); - if (existingTimer) { - clearTimeout(existingTimer); - } - - // Set up debounced refresh - const timer = setTimeout(refresh, debounceMs); - this._listChangedDebounceTimers.set(listType, timer); - } else { - // No debounce, refresh immediately - refresh(); - } - }; - - // Register notification handler - this.setNotificationHandler(notificationMethod, handler); } +} - /** Notifies the server that the client's root list has changed. Requires the `roots.listChanged` capability. */ - async sendRootsListChanged() { - return this.notification({ method: 'notifications/roots/list_changed' }); - } +function formatErr(e: unknown): string { + return e instanceof Error ? e.message : String(e); } + +export type { ClientFetchOptions, ClientTransport } from './clientTransport.js'; +export { channelAsClientTransport, isChannelTransport } from './clientTransport.js'; diff --git a/packages/client/src/client/clientTransport.ts b/packages/client/src/client/clientTransport.ts new file mode 100644 index 000000000..efdb0de35 --- /dev/null +++ b/packages/client/src/client/clientTransport.ts @@ -0,0 +1,200 @@ +import type { + Dispatcher, + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResultResponse, + Notification, + Progress, + RelatedTaskMetadata, + Request, + RequestId, + RequestOptions, + StreamDriverOptions, + TaskCreationParams, + Transport +} from '@modelcontextprotocol/core'; +import { getResultSchema, SdkError, SdkErrorCode, StreamDriver } from '@modelcontextprotocol/core'; + +/** + * Per-call options for {@linkcode ClientTransport.fetch}. + */ +export type ClientFetchOptions = { + /** Abort the in-flight request. */ + signal?: AbortSignal; + /** Called for each `notifications/progress` received before the terminal response. */ + onprogress?: (progress: Progress) => void; + /** Called for each non-progress notification received before the terminal response. */ + onnotification?: (notification: JSONRPCNotification) => void; + /** + * Called for each server-initiated request (elicitation/sampling/roots) received on the + * response stream. Must return the response to send back. If absent, such requests are + * surfaced via {@linkcode onnotification} (best-effort). + */ + onrequest?: (request: JSONRPCRequest) => Promise; + /** + * Called for each JSON-RPC response on the stream whose `id` does NOT match the outbound + * request (e.g. queued task messages delivered via `sendOnResponseStream`). If absent, + * such responses are dropped. + */ + onresponse?: (response: JSONRPCResultResponse | JSONRPCErrorResponse) => void; + /** Per-request timeout (ms). */ + timeout?: number; + /** Reset {@linkcode timeout} when a progress notification arrives. */ + resetTimeoutOnProgress?: boolean; + /** Absolute upper bound (ms) regardless of progress. */ + maxTotalTimeout?: number; + /** Associates this outbound request with an inbound one (pipe transports only). */ + relatedRequestId?: RequestId; + /** Augment as a task-creating request (pipe transports only; threaded to TaskManager). */ + task?: TaskCreationParams; + /** Associate with an existing task (pipe transports only). */ + relatedTask?: RelatedTaskMetadata; + /** Resumption token to continue a previous request (SHTTP only). */ + resumptionToken?: string; + /** Called when the resumption token changes (SHTTP only). */ + onresumptiontoken?: (token: string) => void; +}; + +/** + * Request-shaped client transport. One JSON-RPC request in, one terminal + * response out. The transport may be stateful internally (session id, protocol + * version) but the contract is per-call. + * + * The legacy pipe {@linkcode Transport} interface is adapted via + * {@linkcode channelAsClientTransport}. + */ +export interface ClientTransport { + /** Explicit shape brand. Required so {@linkcode isChannelTransport} can discriminate without duck-typing. */ + readonly kind: 'request'; + + /** + * Send one JSON-RPC request and resolve with the terminal response. + * Any progress/notifications received before the response are surfaced + * via the callbacks in {@linkcode ClientFetchOptions}. + */ + fetch(request: JSONRPCRequest, opts?: ClientFetchOptions): Promise; + + /** + * Send a fire-and-forget notification. + */ + notify(notification: Notification): Promise; + + /** + * Open a server→client subscription stream for unsolicited notifications, + * server-initiated requests (elicitation/sampling/roots), and queued task + * responses. Optional; transports that cannot stream (e.g. plain HTTP + * without SSE GET) omit this. The transport handles inbound requests via + * {@linkcode ClientFetchOptions.onrequest | opts.onrequest} (and POSTs the + * reply back itself); only notifications are yielded. + */ + subscribe?(opts?: Pick): AsyncIterable; + + /** + * Close the transport and release resources. + */ + close(): Promise; + + /** The underlying {@linkcode StreamDriver} when adapted from a pipe. Compat-only. */ + readonly driver?: StreamDriver; +} + +/** + * Type guard distinguishing the legacy pipe-shaped {@linkcode Transport} from + * a request-shaped {@linkcode ClientTransport}. A transport that implements + * both (e.g. {@linkcode StreamableHTTPClientTransport}) is treated as + * {@linkcode ClientTransport} so {@linkcode Client.connect} uses the + * request-shaped path. + */ +export function isChannelTransport(t: Transport | ClientTransport): t is Transport { + return (t as ClientTransport).kind !== 'request'; +} + +/** + * Adapt a legacy pipe-shaped {@linkcode Transport} (stdio, SSE, InMemory, the + * v1 SHTTP client transport) into a {@linkcode ClientTransport}. + * + * Correlation, timeouts, progress and cancellation are handled by an internal + * {@linkcode StreamDriver}. The supplied {@linkcode Dispatcher} services any + * server-initiated requests (sampling, elicitation, roots) that arrive on the pipe. + */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any -- adapter is context-agnostic; the caller's Dispatcher subclass owns ContextT +export function channelAsClientTransport(pipe: Transport, dispatcher: Dispatcher, options?: StreamDriverOptions): ClientTransport { + const driver = new StreamDriver(dispatcher, pipe, options); + let started = false; + const subscribers: Set<(n: JSONRPCNotification) => void> = new Set(); + const priorFallback = dispatcher.fallbackNotificationHandler; + dispatcher.fallbackNotificationHandler = async n => { + await priorFallback?.(n); + const msg: JSONRPCNotification = { jsonrpc: '2.0', method: n.method, params: n.params }; + for (const s of subscribers) s(msg); + }; + const ensureStarted = async () => { + if (!started) { + started = true; + await driver.start(); + } + }; + return { + kind: 'request', + driver, + async fetch(request, opts) { + await ensureStarted(); + if (opts?.signal?.aborted) { + throw new SdkError(SdkErrorCode.RequestTimeout, String(opts.signal.reason ?? 'Aborted')); + } + const schema = getResultSchema(request.method as never); + try { + const result = await driver.request({ method: request.method, params: request.params } as Request, schema, { + signal: opts?.signal, + timeout: opts?.timeout, + resetTimeoutOnProgress: opts?.resetTimeoutOnProgress, + maxTotalTimeout: opts?.maxTotalTimeout, + onprogress: opts?.onprogress, + relatedRequestId: opts?.relatedRequestId, + task: opts?.task, + relatedTask: opts?.relatedTask, + resumptionToken: opts?.resumptionToken, + onresumptiontoken: opts?.onresumptiontoken + } as RequestOptions); + return { jsonrpc: '2.0', id: request.id, result } as JSONRPCResultResponse; + } catch (error) { + const e = error as { code?: number; message?: string; data?: unknown }; + if (typeof e?.code === 'number') { + return { jsonrpc: '2.0', id: request.id, error: { code: e.code, message: e.message ?? 'Error', data: e.data } }; + } + throw error; + } + }, + async notify(notification) { + await ensureStarted(); + await driver.notification(notification); + }, + async *subscribe() { + await ensureStarted(); + const queue: JSONRPCNotification[] = []; + let wake: (() => void) | undefined; + const push = (n: JSONRPCNotification) => { + queue.push(n); + wake?.(); + }; + subscribers.add(push); + try { + while (true) { + while (queue.length > 0) yield queue.shift()!; + await new Promise(r => (wake = r)); + wake = undefined; + } + } finally { + subscribers.delete(push); + } + }, + async close() { + await driver.close(); + } + }; +} + +/** Re-exported so callers can detect protocol-level errors uniformly. */ + +export { isJSONRPCErrorResponse } from '@modelcontextprotocol/core'; diff --git a/packages/client/src/client/streamableHttp.ts b/packages/client/src/client/streamableHttp.ts index cd643c96d..681689271 100644 --- a/packages/client/src/client/streamableHttp.ts +++ b/packages/client/src/client/streamableHttp.ts @@ -1,10 +1,19 @@ import type { ReadableWritablePair } from 'node:stream/web'; -import type { FetchLike, JSONRPCMessage, Transport } from '@modelcontextprotocol/core'; +import type { + FetchLike, + JSONRPCErrorResponse, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResultResponse, + Notification +} from '@modelcontextprotocol/core'; import { createFetchWithInit, isInitializedNotification, isJSONRPCErrorResponse, + isJSONRPCNotification, isJSONRPCRequest, isJSONRPCResultResponse, JSONRPCMessageSchema, @@ -16,6 +25,19 @@ import { EventSourceParserStream } from 'eventsource-parser/stream'; import type { AuthProvider, OAuthClientProvider } from './auth.js'; import { adaptOAuthProvider, auth, extractWWWAuthenticateParams, isOAuthClientProvider, UnauthorizedError } from './auth.js'; +import type { ClientFetchOptions, ClientTransport } from './clientTransport.js'; + +/** + * @deprecated Use {@linkcode SdkError} with {@linkcode SdkErrorCode}. Kept for v1 import compatibility. + */ +export class StreamableHTTPError extends SdkError { + constructor( + public readonly statusCode: number | undefined, + message: string + ) { + super(SdkErrorCode.ClientHttpUnexpectedContent, message); + } +} // Default reconnection options for StreamableHTTP connections const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { @@ -166,8 +188,14 @@ export type StreamableHTTPClientTransportOptions = { * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It will connect to a server using HTTP `POST` for sending messages and HTTP `GET` with Server-Sent Events * for receiving messages. + * + * Implements both the request-shaped {@linkcode ClientTransport} (the primary path used by + * {@linkcode Client.connect}) and the legacy pipe-shaped {@linkcode Transport} (deprecated; kept for + * direct callers and v1 compat). */ -export class StreamableHTTPClientTransport implements Transport { +export class StreamableHTTPClientTransport implements ClientTransport { + readonly kind = 'request' as const; + private _abortController?: AbortController; private _url: URL; private _resourceMetadataUrl?: URL; @@ -185,8 +213,11 @@ export class StreamableHTTPClientTransport implements Transport { private readonly _reconnectionScheduler?: ReconnectionScheduler; private _cancelReconnection?: () => void; + /** @deprecated Pipe-shaped {@linkcode Transport} callback. The {@linkcode ClientTransport} path returns responses directly. */ onclose?: () => void; + /** @deprecated Pipe-shaped {@linkcode Transport} callback. */ onerror?: (error: Error) => void; + /** @deprecated Pipe-shaped {@linkcode Transport} callback. */ onmessage?: (message: JSONRPCMessage) => void; constructor(url: URL, opts?: StreamableHTTPClientTransportOptions) { @@ -208,6 +239,10 @@ export class StreamableHTTPClientTransport implements Transport { this._reconnectionScheduler = opts?.reconnectionScheduler; } + // ─────────────────────────────────────────────────────────────────────── + // Shared internals + // ─────────────────────────────────────────────────────────────────────── + private async _commonHeaders(): Promise { const headers: RequestInit['headers'] & Record = {}; const token = await this._authProvider?.token(); @@ -230,75 +265,88 @@ export class StreamableHTTPClientTransport implements Transport { }); } - private async _startOrAuthSse(options: StartSSEOptions, isAuthRetry = false): Promise { - const { resumptionToken } = options; + /** + * Single auth-aware HTTP request. Adds bearer header, captures session id, and + * handles 401 (one retry via {@linkcode AuthProvider.onUnauthorized}) and 403 + * insufficient_scope (upscope via OAuth, with loop guard). Returns the Response + * even when not-ok for status codes other than the handled auth cases. + */ + private async _authedHttpFetch( + build: (headers: Headers) => RequestInit, + opts: { signal?: AbortSignal } = {}, + isAuthRetry = false + ): Promise { + const headers = await this._commonHeaders(); + const init = { ...this._requestInit, ...build(headers), signal: opts.signal ?? this._abortController?.signal }; + const response = await (this._fetch ?? fetch)(this._url, init); + + const sessionId = response.headers?.get('mcp-session-id'); + if (sessionId) { + this._sessionId = sessionId; + } + if (response.ok) { + this._lastUpscopingHeader = undefined; + return response; + } - try { - // Try to open an initial SSE stream with GET to listen for server messages - // This is optional according to the spec - server may not support it - const headers = await this._commonHeaders(); - const userAccept = headers.get('accept'); - const types = [...(userAccept?.split(',').map(s => s.trim().toLowerCase()) ?? []), 'text/event-stream']; - headers.set('accept', [...new Set(types)].join(', ')); - - // Include Last-Event-ID header for resumable streams if provided - if (resumptionToken) { - headers.set('last-event-id', resumptionToken); + if (response.status === 401 && this._authProvider) { + if (response.headers.has('www-authenticate')) { + const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); + this._resourceMetadataUrl = resourceMetadataUrl; + this._scope = scope; } - - const response = await (this._fetch ?? fetch)(this._url, { - ...this._requestInit, - method: 'GET', - headers, - signal: this._abortController?.signal - }); - - if (!response.ok) { - if (response.status === 401 && this._authProvider) { - if (response.headers.has('www-authenticate')) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; - } - - if (this._authProvider.onUnauthorized && !isAuthRetry) { - await this._authProvider.onUnauthorized({ - response, - serverUrl: this._url, - fetchFn: this._fetchWithInit - }); - await response.text?.().catch(() => {}); - // Purposely _not_ awaited, so we don't call onerror twice - return this._startOrAuthSse(options, true); - } - await response.text?.().catch(() => {}); - if (isAuthRetry) { - throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { - status: 401 - }); - } - throw new UnauthorizedError(); - } - + if (this._authProvider.onUnauthorized && !isAuthRetry) { + await this._authProvider.onUnauthorized({ + response, + serverUrl: this._url, + fetchFn: this._fetchWithInit + }); await response.text?.().catch(() => {}); + return this._authedHttpFetch(build, opts, true); + } + await response.text?.().catch(() => {}); + if (isAuthRetry) { + throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { status: 401 }); + } + throw new UnauthorizedError(); + } - // 405 indicates that the server does not offer an SSE stream at GET endpoint - // This is an expected case that should not trigger an error - if (response.status === 405) { - return; + if (response.status === 403 && this._oauthProvider) { + const text = await response.text?.().catch(() => null); + const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); + if (error === 'insufficient_scope') { + const wwwAuthHeader = response.headers.get('WWW-Authenticate'); + if (this._lastUpscopingHeader === wwwAuthHeader) { + throw new SdkError(SdkErrorCode.ClientHttpForbidden, 'Server returned 403 after trying upscoping', { + status: 403, + text + }); } - - throw new SdkError(SdkErrorCode.ClientHttpFailedToOpenStream, `Failed to open SSE stream: ${response.statusText}`, { - status: response.status, - statusText: response.statusText + if (scope) this._scope = scope; + if (resourceMetadataUrl) this._resourceMetadataUrl = resourceMetadataUrl; + this._lastUpscopingHeader = wwwAuthHeader ?? undefined; + const result = await auth(this._oauthProvider, { + serverUrl: this._url, + resourceMetadataUrl: this._resourceMetadataUrl, + scope: this._scope, + fetchFn: this._fetchWithInit }); + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError(); + } + return this._authedHttpFetch(build, opts, isAuthRetry); } - - this._handleSseStream(response.body, options, true); - } catch (error) { - this.onerror?.(error as Error); - throw error; + // Re-wrap consumed-body 403 so caller's `await response.text()` doesn't blow up. + return new Response(text, { status: 403, headers: response.headers }); } + + return response; + } + + private _setAccept(headers: Headers, ...required: string[]): void { + const userAccept = headers.get('accept'); + const types = [...(userAccept?.split(',').map(s => s.trim().toLowerCase()) ?? []), ...required]; + headers.set('accept', [...new Set(types)].join(', ')); } /** @@ -308,39 +356,279 @@ export class StreamableHTTPClientTransport implements Transport { * @returns Time to wait in milliseconds before next reconnection attempt */ private _getNextReconnectionDelay(attempt: number): number { - // Use server-provided retry value if available - if (this._serverRetryMs !== undefined) { - return this._serverRetryMs; - } - - // Fall back to exponential backoff + if (this._serverRetryMs !== undefined) return this._serverRetryMs; const initialDelay = this._reconnectionOptions.initialReconnectionDelay; const growFactor = this._reconnectionOptions.reconnectionDelayGrowFactor; const maxDelay = this._reconnectionOptions.maxReconnectionDelay; - - // Cap at maximum delay return Math.min(initialDelay * Math.pow(growFactor, attempt), maxDelay); } + private _sseReader(stream: ReadableStream) { + return stream + .pipeThrough(new TextDecoderStream() as ReadableWritablePair) + .pipeThrough(new EventSourceParserStream({ onRetry: ms => (this._serverRetryMs = ms) })) + .getReader(); + } + + private _linkSignal(a: AbortSignal | undefined): AbortSignal | undefined { + const b = this._abortController?.signal; + if (!a) return b; + if (!b) return a; + if (typeof (AbortSignal as { any?: (s: AbortSignal[]) => AbortSignal }).any === 'function') { + return (AbortSignal as unknown as { any: (s: AbortSignal[]) => AbortSignal }).any([a, b]); + } + const c = new AbortController(); + const wire = (s: AbortSignal) => + s.aborted ? c.abort(s.reason) : s.addEventListener('abort', () => c.abort(s.reason), { once: true }); + wire(a); + wire(b); + return c.signal; + } + + // ─────────────────────────────────────────────────────────────────────── + // ClientTransport (request-shaped) — primary path + // ─────────────────────────────────────────────────────────────────────── + + /** + * Send one JSON-RPC request and resolve with the terminal response. Progress and other + * notifications received before the response are surfaced via {@linkcode ClientFetchOptions}. + */ + async fetch(request: JSONRPCRequest, opts: ClientFetchOptions = {}): Promise { + this._abortController ??= new AbortController(); + return this._fetchOnce(request, opts, opts.resumptionToken, 0); + } + + private async _fetchOnce( + request: JSONRPCRequest, + opts: ClientFetchOptions, + lastEventId: string | undefined, + attempt: number + ): Promise { + const signal = this._linkSignal(opts.signal); + const isResume = lastEventId !== undefined; + const res = await this._authedHttpFetch( + headers => { + if (isResume) { + this._setAccept(headers, 'text/event-stream'); + headers.set('last-event-id', lastEventId); + return { method: 'GET', headers }; + } + headers.set('content-type', 'application/json'); + this._setAccept(headers, 'application/json', 'text/event-stream'); + return { method: 'POST', headers, body: JSON.stringify(request) }; + }, + { signal } + ); + + if (!res.ok) { + const text = await res.text?.().catch(() => null); + throw new SdkError(SdkErrorCode.ClientHttpNotImplemented, `Error POSTing to endpoint (HTTP ${res.status}): ${text}`, { + status: res.status, + text + }); + } + const ct = res.headers.get('content-type') ?? ''; + if (ct.includes('text/event-stream')) { + return this._readSseToTerminal(res, request, opts, attempt); + } + if (ct.includes('application/json')) { + const data = await res.json(); + const messages = Array.isArray(data) ? data : [data]; + let terminal: JSONRPCResultResponse | JSONRPCErrorResponse | undefined; + for (const m of messages) { + const msg = JSONRPCMessageSchema.parse(m); + if (isJSONRPCResultResponse(msg) || isJSONRPCErrorResponse(msg)) terminal = msg; + else if (isJSONRPCNotification(msg)) this._routeFetchNotification(msg, opts); + } + if (!terminal) { + throw new SdkError(SdkErrorCode.ClientHttpUnexpectedContent, 'JSON response contained no terminal response'); + } + return terminal; + } + await res.text?.().catch(() => {}); + throw new SdkError(SdkErrorCode.ClientHttpUnexpectedContent, `Unexpected content type: ${ct}`, { contentType: ct }); + } + + private async _readSseToTerminal( + res: Response, + request: JSONRPCRequest, + opts: ClientFetchOptions, + attempt: number + ): Promise { + if (!res.body) throw new SdkError(SdkErrorCode.ClientHttpUnexpectedContent, 'SSE response has no body'); + let lastEventId: string | undefined; + let primed = false; + const reader = this._sseReader(res.body); + try { + while (true) { + const { value, done } = await reader.read(); + if (done) break; + if (value.id) { + lastEventId = value.id; + primed = true; + opts.onresumptiontoken?.(value.id); + } + if (!value.data) continue; + if (value.event && value.event !== 'message') continue; + const msg = JSONRPCMessageSchema.parse(JSON.parse(value.data)); + if (isJSONRPCResultResponse(msg) || isJSONRPCErrorResponse(msg)) { + if (msg.id === request.id) return msg; + opts.onresponse?.(msg); + continue; + } + if (isJSONRPCNotification(msg)) { + this._routeFetchNotification(msg, opts); + } else if (isJSONRPCRequest(msg)) { + void this._serviceInboundRequest(msg, opts); + } + } + } catch { + // fallthrough to resume below + } finally { + try { + reader.releaseLock(); + } catch { + /* noop */ + } + } + if (primed && attempt < this._reconnectionOptions.maxRetries && !this._abortController?.signal.aborted && !opts.signal?.aborted) { + await new Promise(r => setTimeout(r, this._getNextReconnectionDelay(attempt))); + return this._fetchOnce(request, opts, lastEventId, attempt + 1); + } + throw new SdkError(SdkErrorCode.ClientHttpFailedToOpenStream, 'SSE stream ended without a terminal response'); + } + + /** Handle a server-initiated request received on the SSE response stream and POST the reply back. */ + private async _serviceInboundRequest( + inbound: JSONRPCRequest, + opts: Pick + ): Promise { + if (!opts.onrequest) { + opts.onnotification?.(inbound as unknown as JSONRPCNotification); + return; + } + let response: JSONRPCResultResponse | JSONRPCErrorResponse; + try { + response = await opts.onrequest(inbound); + } catch (error) { + response = { + jsonrpc: '2.0', + id: inbound.id, + error: { code: -32_603, message: error instanceof Error ? error.message : String(error) } + }; + } + try { + const r = await this._authedHttpFetch(headers => { + headers.set('content-type', 'application/json'); + this._setAccept(headers, 'application/json', 'text/event-stream'); + return { method: 'POST', headers, body: JSON.stringify(response) }; + }); + await r.text?.().catch(() => {}); + } catch (error) { + this.onerror?.(error instanceof Error ? error : new Error(String(error))); + } + } + + private _routeFetchNotification(msg: JSONRPCNotification, opts: ClientFetchOptions): void { + if (msg.method === 'notifications/progress' && opts.onprogress) { + const { progressToken: _t, ...progress } = (msg.params ?? {}) as Record; + void _t; + opts.onprogress(progress as never); + return; + } + opts.onnotification?.(msg); + } + + /** Send a fire-and-forget JSON-RPC notification. */ + async notify(n: Notification): Promise { + this._abortController ??= new AbortController(); + const res = await this._authedHttpFetch(headers => { + headers.set('content-type', 'application/json'); + this._setAccept(headers, 'application/json', 'text/event-stream'); + return { method: 'POST', headers, body: JSON.stringify({ jsonrpc: '2.0', method: n.method, params: n.params }) }; + }); + await res.text?.().catch(() => {}); + if (!res.ok && res.status !== 202) { + throw new SdkError(SdkErrorCode.ClientHttpNotImplemented, `Notification POST failed: ${res.status}`, { status: res.status }); + } + } + + /** + * Open the standalone GET SSE stream and yield server-initiated notifications. + * Inbound requests (elicitation/sampling/roots) are dispatched via + * {@linkcode ClientFetchOptions.onrequest | opts.onrequest} and the reply is + * POSTed back automatically. Best-effort: if the server replies 405 (no SSE + * GET), the iterable completes immediately. + */ + async *subscribe(opts: Pick = {}): AsyncIterable { + this._abortController ??= new AbortController(); + const res = await this._authedHttpFetch(headers => { + this._setAccept(headers, 'text/event-stream'); + return { method: 'GET', headers }; + }); + if (res.status === 405 || !res.ok || !res.body) { + await res.text?.().catch(() => {}); + return; + } + const reader = this._sseReader(res.body); + try { + while (true) { + const { value, done } = await reader.read(); + if (done) return; + if (!value.data) continue; + const msg = JSONRPCMessageSchema.parse(JSON.parse(value.data)); + if (isJSONRPCNotification(msg)) { + yield msg; + } else if (isJSONRPCRequest(msg)) { + void this._serviceInboundRequest(msg, opts); + } else if (isJSONRPCResultResponse(msg) || isJSONRPCErrorResponse(msg)) { + opts.onresponse?.(msg); + } + } + } finally { + reader.releaseLock(); + } + } + + // ─────────────────────────────────────────────────────────────────────── + // Transport (pipe-shaped) — deprecated compat surface + // ─────────────────────────────────────────────────────────────────────── + + private async _startOrAuthSse(options: StartSSEOptions): Promise { + const { resumptionToken } = options; + try { + const response = await this._authedHttpFetch(headers => { + this._setAccept(headers, 'text/event-stream'); + if (resumptionToken) headers.set('last-event-id', resumptionToken); + return { method: 'GET', headers }; + }); + + if (!response.ok) { + await response.text?.().catch(() => {}); + // 405 indicates that the server does not offer an SSE stream at GET endpoint + if (response.status === 405) return; + throw new SdkError(SdkErrorCode.ClientHttpFailedToOpenStream, `Failed to open SSE stream: ${response.statusText}`, { + status: response.status, + statusText: response.statusText + }); + } + this._handleSseStream(response.body, options, true); + } catch (error) { + this.onerror?.(error as Error); + throw error; + } + } + /** * Schedule a reconnection attempt using server-provided retry interval or backoff - * - * @param lastEventId The ID of the last received event for resumability - * @param attemptCount Current reconnection attempt count for this specific stream */ private _scheduleReconnection(options: StartSSEOptions, attemptCount = 0): void { - // Use provided options or default options const maxRetries = this._reconnectionOptions.maxRetries; - - // Check if we've exceeded maximum retry attempts if (attemptCount >= maxRetries) { this.onerror?.(new Error(`Maximum reconnection attempts (${maxRetries}) exceeded.`)); return; } - - // Calculate next delay based on current attempt count const delay = this._getNextReconnectionDelay(attemptCount); - const reconnect = (): void => { this._cancelReconnection = undefined; if (this._abortController?.signal.aborted) return; @@ -353,7 +641,6 @@ export class StreamableHTTPClientTransport implements Transport { } }); }; - if (this._reconnectionScheduler) { const cancel = this._reconnectionScheduler(reconnect, delay, attemptCount); this._cancelReconnection = typeof cancel === 'function' ? cancel : undefined; @@ -364,60 +651,28 @@ export class StreamableHTTPClientTransport implements Transport { } private _handleSseStream(stream: ReadableStream | null, options: StartSSEOptions, isReconnectable: boolean): void { - if (!stream) { - return; - } + if (!stream) return; const { onresumptiontoken, replayMessageId } = options; let lastEventId: string | undefined; - // Track whether we've received a priming event (event with ID) - // Per spec, server SHOULD send a priming event with ID before closing let hasPrimingEvent = false; - // Track whether we've received a response - if so, no need to reconnect - // Reconnection is for when server disconnects BEFORE sending response let receivedResponse = false; const processStream = async () => { - // this is the closest we can get to trying to catch network errors - // if something happens reader will throw try { - // Create a pipeline: binary stream -> text decoder -> SSE parser - const reader = stream - .pipeThrough(new TextDecoderStream() as ReadableWritablePair) - .pipeThrough( - new EventSourceParserStream({ - onRetry: (retryMs: number) => { - // Capture server-provided retry value for reconnection timing - this._serverRetryMs = retryMs; - } - }) - ) - .getReader(); - + const reader = this._sseReader(stream); while (true) { const { value: event, done } = await reader.read(); - if (done) { - break; - } - - // Update last event ID if provided + if (done) break; if (event.id) { lastEventId = event.id; - // Mark that we've received a priming event - stream is now resumable hasPrimingEvent = true; onresumptiontoken?.(event.id); } - - // Skip events with no data (priming events, keep-alives) - if (!event.data) { - continue; - } - + if (!event.data) continue; if (!event.event || event.event === 'message') { try { const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); - // Handle both success AND error responses for completion detection and ID remapping if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { - // Mark that we received a response - no need to reconnect for this request receivedResponse = true; if (replayMessageId !== undefined) { message.id = replayMessageId; @@ -429,43 +684,18 @@ export class StreamableHTTPClientTransport implements Transport { } } } - - // Handle graceful server-side disconnect - // Server may close connection after sending event ID and retry field - // Reconnect if: already reconnectable (GET stream) OR received a priming event (POST stream with event ID) - // BUT don't reconnect if we already received a response - the request is complete const canResume = isReconnectable || hasPrimingEvent; const needsReconnect = canResume && !receivedResponse; if (needsReconnect && this._abortController && !this._abortController.signal.aborted) { - this._scheduleReconnection( - { - resumptionToken: lastEventId, - onresumptiontoken, - replayMessageId - }, - 0 - ); + this._scheduleReconnection({ resumptionToken: lastEventId, onresumptiontoken, replayMessageId }, 0); } } catch (error) { - // Handle stream errors - likely a network disconnect this.onerror?.(new Error(`SSE stream disconnected: ${error}`)); - - // Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing - // Reconnect if: already reconnectable (GET stream) OR received a priming event (POST stream with event ID) - // BUT don't reconnect if we already received a response - the request is complete const canResume = isReconnectable || hasPrimingEvent; const needsReconnect = canResume && !receivedResponse; if (needsReconnect && this._abortController && !this._abortController.signal.aborted) { - // Use the exponential backoff reconnection strategy try { - this._scheduleReconnection( - { - resumptionToken: lastEventId, - onresumptiontoken, - replayMessageId - }, - 0 - ); + this._scheduleReconnection({ resumptionToken: lastEventId, onresumptiontoken, replayMessageId }, 0); } catch (error) { this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`)); } @@ -475,13 +705,13 @@ export class StreamableHTTPClientTransport implements Transport { processStream(); } + /** @deprecated Part of the pipe-shaped {@linkcode Transport} interface. {@linkcode Client.connect} uses the request-shaped path. */ async start() { if (this._abortController) { throw new Error( 'StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.' ); } - this._abortController = new AbortController(); } @@ -492,7 +722,6 @@ export class StreamableHTTPClientTransport implements Transport { if (!this._oauthProvider) { throw new UnauthorizedError('finishAuth requires an OAuthClientProvider'); } - const result = await auth(this._oauthProvider, { serverUrl: this._url, authorizationCode, @@ -515,161 +744,55 @@ export class StreamableHTTPClientTransport implements Transport { } } + /** @deprecated Part of the pipe-shaped {@linkcode Transport} interface. Use {@linkcode fetch} / {@linkcode notify}. */ async send( message: JSONRPCMessage | JSONRPCMessage[], options?: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } - ): Promise { - return this._send(message, options, false); - } - - private async _send( - message: JSONRPCMessage | JSONRPCMessage[], - options: { resumptionToken?: string; onresumptiontoken?: (token: string) => void } | undefined, - isAuthRetry: boolean ): Promise { try { const { resumptionToken, onresumptiontoken } = options || {}; if (resumptionToken) { - // If we have a last event ID, we need to reconnect the SSE stream this._startOrAuthSse({ resumptionToken, replayMessageId: isJSONRPCRequest(message) ? message.id : undefined }).catch( error => this.onerror?.(error) ); return; } - const headers = await this._commonHeaders(); - headers.set('content-type', 'application/json'); - const userAccept = headers.get('accept'); - const types = [...(userAccept?.split(',').map(s => s.trim().toLowerCase()) ?? []), 'application/json', 'text/event-stream']; - headers.set('accept', [...new Set(types)].join(', ')); - - const init = { - ...this._requestInit, - method: 'POST', - headers, - body: JSON.stringify(message), - signal: this._abortController?.signal - }; - - const response = await (this._fetch ?? fetch)(this._url, init); - - // Handle session ID received during initialization - const sessionId = response.headers.get('mcp-session-id'); - if (sessionId) { - this._sessionId = sessionId; - } + const response = await this._authedHttpFetch(headers => { + headers.set('content-type', 'application/json'); + this._setAccept(headers, 'application/json', 'text/event-stream'); + return { method: 'POST', headers, body: JSON.stringify(message) }; + }); if (!response.ok) { - if (response.status === 401 && this._authProvider) { - // Store WWW-Authenticate params for interactive finishAuth() path - if (response.headers.has('www-authenticate')) { - const { resourceMetadataUrl, scope } = extractWWWAuthenticateParams(response); - this._resourceMetadataUrl = resourceMetadataUrl; - this._scope = scope; - } - - if (this._authProvider.onUnauthorized && !isAuthRetry) { - await this._authProvider.onUnauthorized({ - response, - serverUrl: this._url, - fetchFn: this._fetchWithInit - }); - await response.text?.().catch(() => {}); - // Purposely _not_ awaited, so we don't call onerror twice - return this._send(message, options, true); - } - await response.text?.().catch(() => {}); - if (isAuthRetry) { - throw new SdkError(SdkErrorCode.ClientHttpAuthentication, 'Server returned 401 after re-authentication', { - status: 401 - }); - } - throw new UnauthorizedError(); - } - const text = await response.text?.().catch(() => null); - - if (response.status === 403 && this._oauthProvider) { - const { resourceMetadataUrl, scope, error } = extractWWWAuthenticateParams(response); - - if (error === 'insufficient_scope') { - const wwwAuthHeader = response.headers.get('WWW-Authenticate'); - - // Check if we've already tried upscoping with this header to prevent infinite loops. - if (this._lastUpscopingHeader === wwwAuthHeader) { - throw new SdkError(SdkErrorCode.ClientHttpForbidden, 'Server returned 403 after trying upscoping', { - status: 403, - text - }); - } - - if (scope) { - this._scope = scope; - } - - if (resourceMetadataUrl) { - this._resourceMetadataUrl = resourceMetadataUrl; - } - - // Mark that upscoping was tried. - this._lastUpscopingHeader = wwwAuthHeader ?? undefined; - const result = await auth(this._oauthProvider, { - serverUrl: this._url, - resourceMetadataUrl: this._resourceMetadataUrl, - scope: this._scope, - fetchFn: this._fetchWithInit - }); - - if (result !== 'AUTHORIZED') { - throw new UnauthorizedError(); - } - - return this._send(message, options, isAuthRetry); - } - } - throw new SdkError(SdkErrorCode.ClientHttpNotImplemented, `Error POSTing to endpoint: ${text}`, { status: response.status, text }); } - this._lastUpscopingHeader = undefined; - - // If the response is 202 Accepted, there's no body to process if (response.status === 202) { await response.text?.().catch(() => {}); - // if the accepted notification is initialized, we start the SSE stream - // if it's supported by the server if (isInitializedNotification(message)) { - // Start without a lastEventId since this is a fresh connection this._startOrAuthSse({ resumptionToken: undefined }).catch(error => this.onerror?.(error)); } return; } - // Get original message(s) for detecting request IDs const messages = Array.isArray(message) ? message : [message]; - const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg && msg.id !== undefined); - - // Check the response type const contentType = response.headers.get('content-type'); if (hasRequests) { if (contentType?.includes('text/event-stream')) { - // Handle SSE stream responses for requests - // We use the same handler as standalone streams, which now supports - // reconnection with the last event ID this._handleSseStream(response.body, { onresumptiontoken }, false); } else if (contentType?.includes('application/json')) { - // For non-streaming servers, we might get direct JSON responses const data = await response.json(); const responseMessages = Array.isArray(data) ? data.map(msg => JSONRPCMessageSchema.parse(msg)) : [JSONRPCMessageSchema.parse(data)]; - for (const msg of responseMessages) { this.onmessage?.(msg); } @@ -680,7 +803,6 @@ export class StreamableHTTPClientTransport implements Transport { }); } } else { - // No requests in message but got 200 OK - still need to release connection await response.text?.().catch(() => {}); } } catch (error) { @@ -705,32 +827,16 @@ export class StreamableHTTPClientTransport implements Transport { * the server does not allow clients to terminate sessions. */ async terminateSession(): Promise { - if (!this._sessionId) { - return; // No session to terminate - } - + if (!this._sessionId) return; try { - const headers = await this._commonHeaders(); - - const init = { - ...this._requestInit, - method: 'DELETE', - headers, - signal: this._abortController?.signal - }; - - const response = await (this._fetch ?? fetch)(this._url, init); + const response = await this._authedHttpFetch(headers => ({ method: 'DELETE', headers })); await response.text?.().catch(() => {}); - - // We specifically handle 405 as a valid response according to the spec, - // meaning the server does not support explicit session termination if (!response.ok && response.status !== 405) { throw new SdkError(SdkErrorCode.ClientHttpFailedToTerminateSession, `Failed to terminate session: ${response.statusText}`, { status: response.status, statusText: response.statusText }); } - this._sessionId = undefined; } catch (error) { this.onerror?.(error as Error); @@ -749,13 +855,9 @@ export class StreamableHTTPClientTransport implements Transport { * Resume an SSE stream from a previous event ID. * Opens a `GET` SSE connection with `Last-Event-ID` header to replay missed events. * - * @param lastEventId The event ID to resume from - * @param options Optional callback to receive new resumption tokens + * @deprecated Part of the pipe-shaped {@linkcode Transport} surface; messages surface via {@linkcode onmessage}. */ async resumeStream(lastEventId: string, options?: { onresumptiontoken?: (token: string) => void }): Promise { - await this._startOrAuthSse({ - resumptionToken: lastEventId, - onresumptiontoken: options?.onresumptiontoken - }); + await this._startOrAuthSse({ resumptionToken: lastEventId, onresumptiontoken: options?.onresumptiontoken }); } } diff --git a/packages/client/src/index.ts b/packages/client/src/index.ts index 48b79b5ce..cfa977874 100644 --- a/packages/client/src/index.ts +++ b/packages/client/src/index.ts @@ -69,7 +69,7 @@ export type { StreamableHTTPClientTransportOptions, StreamableHTTPReconnectionOptions } from './client/streamableHttp.js'; -export { StreamableHTTPClientTransport } from './client/streamableHttp.js'; +export { StreamableHTTPClientTransport, StreamableHTTPError } from './client/streamableHttp.js'; // experimental exports export { ExperimentalClientTasks } from './experimental/tasks/client.js'; diff --git a/packages/client/src/validators/cfWorker.ts b/packages/client/src/validators/cfWorker.ts index b068e69a1..7d1c843e5 100644 --- a/packages/client/src/validators/cfWorker.ts +++ b/packages/client/src/validators/cfWorker.ts @@ -6,5 +6,5 @@ * import { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/client/validators/cf-worker'; * ``` */ -export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/core'; export type { CfWorkerSchemaDraft } from '@modelcontextprotocol/core'; +export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/core'; diff --git a/packages/client/test/client/client.test.ts b/packages/client/test/client/client.test.ts new file mode 100644 index 000000000..bbec860b6 --- /dev/null +++ b/packages/client/test/client/client.test.ts @@ -0,0 +1,369 @@ +import type { + CallToolResult, + CreateTaskResult, + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResultResponse, + Notification +} from '@modelcontextprotocol/core'; +import { InMemoryTransport, LATEST_PROTOCOL_VERSION, ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core'; +import { describe, expect, it, vi } from 'vitest'; + +import type { ClientFetchOptions, ClientTransport } from '../../src/client/clientTransport.js'; +import { isChannelTransport } from '../../src/client/clientTransport.js'; +import { Client } from '../../src/client/client.js'; + +type FetchResp = JSONRPCResultResponse | JSONRPCErrorResponse; + +function mockTransport(handler: (req: JSONRPCRequest, opts?: ClientFetchOptions) => Promise | FetchResp): { + ct: ClientTransport; + sent: JSONRPCRequest[]; + notified: Notification[]; +} { + const sent: JSONRPCRequest[] = []; + const notified: Notification[] = []; + const ct: ClientTransport = { + kind: 'request', + async fetch(req, opts) { + sent.push(req); + return handler(req, opts); + }, + async notify(n) { + notified.push(n); + }, + async close() {} + }; + return { ct, sent, notified }; +} + +const ok = (id: JSONRPCRequest['id'], result: unknown): JSONRPCResultResponse => ({ jsonrpc: '2.0', id, result }) as JSONRPCResultResponse; +const err = (id: JSONRPCRequest['id'], code: number, message: string): JSONRPCErrorResponse => ({ + jsonrpc: '2.0', + id, + error: { code, message } +}); + +const initResult = (caps: Record = { tools: { listChanged: true } }) => ({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: caps, + serverInfo: { name: 's', version: '1.0.0' } +}); + +describe('Client (V2)', () => { + describe('connect via ClientTransport', () => { + it('falls back to initialize when server/discover is MethodNotFound, populates server caps', async () => { + const { ct, sent, notified } = mockTransport(req => { + if (req.method === 'server/discover') return err(req.id, ProtocolErrorCode.MethodNotFound, 'nope'); + if (req.method === 'initialize') return ok(req.id, initResult()); + return err(req.id, ProtocolErrorCode.MethodNotFound, 'unexpected'); + }); + const c = new Client({ name: 'c', version: '1' }); + await c.connect(ct); + expect(sent[0]?.method).toBe('server/discover'); + expect(sent.find(r => r.method === 'initialize')).toBeDefined(); + expect(c.getServerCapabilities()?.tools).toBeDefined(); + expect(c.getServerVersion()?.name).toBe('s'); + expect(notified.find(n => n.method === 'notifications/initialized')).toBeDefined(); + }); + + it('uses server/discover result directly when supported (2026-06)', async () => { + const { ct, sent } = mockTransport(req => { + if (req.method === 'server/discover') { + return ok(req.id, { capabilities: { tools: {} }, serverInfo: { name: 'd', version: '2' } }); + } + throw new Error('should not reach'); + }); + const c = new Client({ name: 'c', version: '1' }); + await c.connect(ct); + expect(sent.some(r => r.method === 'initialize')).toBe(false); + expect(c.getServerVersion()?.name).toBe('d'); + }); + + it('isChannelTransport correctly distinguishes the two shapes', () => { + const [a] = InMemoryTransport.createLinkedPair(); + const { ct } = mockTransport(r => ok(r.id, {})); + expect(isChannelTransport(a)).toBe(true); + expect(isChannelTransport(ct)).toBe(false); + }); + }); + + describe('typed RPC sugar', () => { + async function connected(handler: (req: JSONRPCRequest, opts?: ClientFetchOptions) => FetchResp | Promise) { + const m = mockTransport((req, opts) => { + if (req.method === 'server/discover') + return ok(req.id, { capabilities: { tools: {}, prompts: {}, resources: {} }, serverInfo: { name: 's', version: '1' } }); + return handler(req, opts); + }); + const c = new Client({ name: 'c', version: '1' }); + await c.connect(m.ct); + return { c, ...m }; + } + + it('callTool returns the result', async () => { + const { c } = await connected(r => + r.method === 'tools/call' ? ok(r.id, { content: [{ type: 'text', text: 'hi' }] }) : err(r.id, -32601, 'nope') + ); + const result = (await c.callTool({ name: 'x', arguments: {} })) as CallToolResult; + expect(result.content[0]).toEqual({ type: 'text', text: 'hi' }); + }); + + it('listTools caches output validators and callTool enforces them', async () => { + const tools = [ + { name: 'typed', inputSchema: { type: 'object' }, outputSchema: { type: 'object', properties: { n: { type: 'number' } } } } + ]; + const { c } = await connected(r => { + if (r.method === 'tools/list') return ok(r.id, { tools }); + if (r.method === 'tools/call') return ok(r.id, { content: [], structuredContent: { n: 'not-a-number' } }); + return err(r.id, -32601, 'nope'); + }); + await c.listTools(); + await expect(c.callTool({ name: 'typed', arguments: {} })).rejects.toThrow(ProtocolError); + }); + + it('callTool rejects when tool with outputSchema returns no structuredContent', async () => { + const { c } = await connected(r => { + if (r.method === 'tools/list') { + return ok(r.id, { tools: [{ name: 't', inputSchema: { type: 'object' }, outputSchema: { type: 'object' } }] }); + } + if (r.method === 'tools/call') return ok(r.id, { content: [] }); + return err(r.id, -32601, 'nope'); + }); + await c.listTools(); + await expect(c.callTool({ name: 't', arguments: {} })).rejects.toThrow(/structured content/); + }); + + it('list* return empty when capability missing and not strict', async () => { + const { ct } = mockTransport(r => + r.method === 'server/discover' + ? ok(r.id, { capabilities: {}, serverInfo: { name: 's', version: '1' } }) + : err(r.id, -32601, 'nope') + ); + const c = new Client({ name: 'c', version: '1' }); + await c.connect(ct); + expect(await c.listTools()).toEqual({ tools: [] }); + expect(await c.listPrompts()).toEqual({ prompts: [] }); + expect(await c.listResources()).toEqual({ resources: [] }); + }); + + it('throws ProtocolError on JSON-RPC error response', async () => { + const { c } = await connected(r => err(r.id, ProtocolErrorCode.InvalidParams, 'bad')); + await expect(c.ping()).rejects.toBeInstanceOf(ProtocolError); + }); + + it('passes onprogress through to transport', async () => { + const seen: unknown[] = []; + const { c } = await connected(async (r, opts) => { + opts?.onprogress?.({ progress: 1, total: 2 }); + return ok(r.id, { content: [] }); + }); + await c.callTool({ name: 'x', arguments: {} }, { onprogress: (p: unknown) => seen.push(p) }); + expect(seen).toEqual([{ progress: 1, total: 2 }]); + }); + }); + + describe('MRTR loop', () => { + it('re-sends with inputResponses when server returns input_required, resolves on complete', async () => { + let round = 0; + const elicitArgs = { + method: 'elicitation/create', + params: { message: 'q', requestedSchema: { type: 'object', properties: {} } } + }; + const { ct, sent } = mockTransport(r => { + if (r.method === 'server/discover') + return ok(r.id, { capabilities: { tools: {} }, serverInfo: { name: 's', version: '1' } }); + if (r.method === 'tools/call') { + round++; + if (round === 1) return ok(r.id, { ResultType: 'input_required', InputRequests: { ask: elicitArgs } }); + return ok(r.id, { content: [{ type: 'text', text: 'done' }] }); + } + return err(r.id, -32601, 'nope'); + }); + const c = new Client({ name: 'c', version: '1' }, { capabilities: { elicitation: {} } }); + c.setRequestHandler('elicitation/create', async () => ({ action: 'accept', content: { x: 1 } })); + await c.connect(ct); + const result = (await c.callTool({ name: 't', arguments: {} })) as CallToolResult; + expect(result.content[0]).toEqual({ type: 'text', text: 'done' }); + expect(round).toBe(2); + const second = sent.filter(r => r.method === 'tools/call')[1]; + const meta = second?.params?._meta as Record | undefined; + const irs = meta?.['modelcontextprotocol.io/mrtr/inputResponses'] as Record | undefined; + expect(irs?.ask).toEqual({ action: 'accept', content: { x: 1 } }); + }); + + it('throws if no handler is registered for an InputRequest method', async () => { + const { ct } = mockTransport(r => { + if (r.method === 'server/discover') + return ok(r.id, { capabilities: { tools: {} }, serverInfo: { name: 's', version: '1' } }); + if (r.method === 'tools/call') { + return ok(r.id, { ResultType: 'input_required', InputRequests: { s: { method: 'sampling/createMessage' } } }); + } + return err(r.id, -32601, 'nope'); + }); + const c = new Client({ name: 'c', version: '1' }); + await c.connect(ct); + await expect(c.callTool({ name: 't', arguments: {} })).rejects.toThrow(); + }); + + it('caps rounds at mrtrMaxRounds', async () => { + const { ct } = mockTransport(r => { + if (r.method === 'server/discover') + return ok(r.id, { capabilities: { tools: {} }, serverInfo: { name: 's', version: '1' } }); + return ok(r.id, { ResultType: 'input_required', InputRequests: { p: { method: 'ping' } } }); + }); + const c = new Client({ name: 'c', version: '1' }, { mrtrMaxRounds: 3 }); + await c.connect(ct); + await expect(c.callTool({ name: 't', arguments: {} })).rejects.toThrow(/MRTR exceeded 3/); + }); + }); + + describe('connect via legacy pipe Transport (2025-11 compat)', () => { + it('runs initialize handshake over an InMemoryTransport pair', async () => { + const [clientPipe, serverPipe] = InMemoryTransport.createLinkedPair(); + // Minimal hand-rolled server end of the pipe. + serverPipe.onmessage = msg => { + if ('method' in msg && msg.method === 'initialize' && 'id' in msg) { + void serverPipe.send({ jsonrpc: '2.0', id: msg.id, result: initResult() } as JSONRPCResultResponse); + } + if ('method' in msg && msg.method === 'tools/list' && 'id' in msg) { + void serverPipe.send({ jsonrpc: '2.0', id: msg.id, result: { tools: [] } } as JSONRPCResultResponse); + } + }; + await serverPipe.start(); + const c = new Client({ name: 'c', version: '1' }); + await c.connect(clientPipe); + expect(c.getServerCapabilities()?.tools).toBeDefined(); + expect(c.getNegotiatedProtocolVersion()).toBe(LATEST_PROTOCOL_VERSION); + const r = await c.listTools(); + expect(r.tools).toEqual([]); + await c.close(); + }); + + it('skips re-init when transport already has a sessionId', async () => { + const [clientPipe, serverPipe] = InMemoryTransport.createLinkedPair(); + (clientPipe as { sessionId?: string }).sessionId = 'existing'; + const seen: string[] = []; + serverPipe.onmessage = msg => { + if ('method' in msg) seen.push(msg.method); + }; + await serverPipe.start(); + const c = new Client({ name: 'c', version: '1' }); + await c.connect(clientPipe); + expect(seen).not.toContain('initialize'); + }); + }); + + describe('handler registration', () => { + it('setRequestHandler is used for MRTR servicing and pipe-mode dispatch alike', async () => { + const handler = vi.fn(async () => ({ roots: [] })); + const c = new Client({ name: 'c', version: '1' }, { capabilities: { roots: {} } }); + c.setRequestHandler('roots/list', handler); + // Exercise via MRTR path: + const { ct } = mockTransport(r => { + if (r.method === 'server/discover') + return ok(r.id, { capabilities: { tools: {} }, serverInfo: { name: 's', version: '1' } }); + if (r.method === 'tools/call') { + return ok(r.id, { ResultType: 'input_required', InputRequests: { r: { method: 'roots/list' } } }); + } + return ok(r.id, { content: [] }); + }); + await c.connect(ct); + // First call hits input_required → roots/list handler, second resolves. + // We don't await because the second mock branch never returns complete; instead + // verify the handler was invoked at least once via the MRTR servicing path. + const p = c.callTool({ name: 't', arguments: {} }).catch(() => {}); + await new Promise(r => setTimeout(r, 0)); + expect(handler).toHaveBeenCalled(); + void p; + }); + + it('routes per-request notifications from transport to local notification handlers', async () => { + const got: JSONRPCNotification[] = []; + const { ct } = mockTransport(async (r, opts) => { + if (r.method === 'server/discover') + return ok(r.id, { capabilities: { tools: {} }, serverInfo: { name: 's', version: '1' } }); + opts?.onnotification?.({ jsonrpc: '2.0', method: 'notifications/message', params: { level: 'info', data: 'x' } }); + return ok(r.id, { content: [] }); + }); + const c = new Client({ name: 'c', version: '1' }); + c.setNotificationHandler('notifications/message', (n: unknown) => void got.push(n as JSONRPCNotification)); + await c.connect(ct); + await c.callTool({ name: 't', arguments: {} }); + expect(got).toHaveLength(1); + }); + }); + + describe('tasks (SEP-1686 / SEP-2557)', () => { + async function connected(handler: (req: JSONRPCRequest) => FetchResp | Promise) { + const m = mockTransport(req => { + if (req.method === 'server/discover') + return ok(req.id, { + capabilities: { tools: {}, tasks: { tools: { call: true } } }, + serverInfo: { name: 's', version: '1' } + }); + return handler(req); + }); + const c = new Client({ name: 'c', version: '1' }); + await c.connect(m.ct); + return { c, ...m }; + } + + it('experimental.tasks getter exists and is lazily constructed once', async () => { + const { c } = await connected(r => ok(r.id, {})); + const a = c.experimental.tasks; + const b = c.experimental.tasks; + expect(a).toBe(b); + expect(typeof a.callToolStream).toBe('function'); + }); + + it('callTool throws with guidance when server returns a task without awaitTask (v1-compat surface)', async () => { + const taskResult = { task: { taskId: 't-1', status: 'working', createdAt: '2026-01-01T00:00:00Z' } }; + const { c } = await connected(r => (r.method === 'tools/call' ? ok(r.id, taskResult) : err(r.id, -32601, ''))); + await expect(c.callTool({ name: 'slow', arguments: {} })).rejects.toThrow(/returned a task.*awaitTask/); + }); + + const taskBody = (overrides: Record = {}) => ({ + taskId: 't-2', + status: 'working', + ttl: null, + createdAt: '2026-01-01T00:00:00Z', + lastUpdatedAt: '2026-01-01T00:00:00Z', + ...overrides + }); + + it('callTool with awaitTask polls tasks/get then tasks/result', async () => { + let getCalls = 0; + const { c, sent } = await connected(r => { + if (r.method === 'tools/call') return ok(r.id, { task: taskBody() }); + if (r.method === 'tasks/get') { + getCalls++; + return ok(r.id, taskBody({ status: getCalls === 1 ? 'working' : 'completed' })); + } + if (r.method === 'tasks/result') return ok(r.id, { content: [{ type: 'text', text: 'done' }] }); + return err(r.id, -32601, ''); + }); + const result = (await c.callTool({ name: 'slow', arguments: {} }, { awaitTask: true })) as CallToolResult; + expect(result.content[0]).toEqual({ type: 'text', text: 'done' }); + expect(getCalls).toBe(2); + expect(sent.some(r => r.method === 'tasks/result')).toBe(true); + }); + + it('getTask / listTasks / cancelTask call the right methods', async () => { + const { c, sent } = await connected(r => { + if (r.method === 'tasks/get') return ok(r.id, taskBody({ taskId: 'x', status: 'completed' })); + if (r.method === 'tasks/list') return ok(r.id, { tasks: [] }); + if (r.method === 'tasks/cancel') return ok(r.id, taskBody({ taskId: 'x', status: 'cancelled' })); + return err(r.id, -32601, ''); + }); + await c.getTask({ taskId: 'x' }); + await c.listTasks(); + await c.cancelTask({ taskId: 'x' }); + expect(sent.map(r => r.method).filter(m => m.startsWith('tasks/'))).toEqual(['tasks/get', 'tasks/list', 'tasks/cancel']); + }); + + it('taskManager is available on the request-shaped path (Client-owned)', async () => { + const { c } = await connected(r => ok(r.id, {})); + expect(c.taskManager).toBeDefined(); + }); + }); +}); diff --git a/packages/core/package.json b/packages/core/package.json index b5858ff00..855303983 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -83,5 +83,19 @@ "typescript": "catalog:devTools", "typescript-eslint": "catalog:devTools", "vitest": "catalog:devTools" + }, + "types": "./dist/index.d.mts", + "typesVersions": { + "*": { + "public": [ + "./dist/exports/public/index.d.mts" + ], + "types": [ + "./dist/types/index.d.mts" + ], + "*": [ + "./dist/*.d.mts" + ] + } } } diff --git a/packages/core/src/exports/public/index.ts b/packages/core/src/exports/public/index.ts index 2dc1e13a8..4bc735ac8 100644 --- a/packages/core/src/exports/public/index.ts +++ b/packages/core/src/exports/public/index.ts @@ -38,7 +38,8 @@ export { checkResourceAllowed, resourceUrlFromServerUrl } from '../../shared/aut // Metadata utilities export { getDisplayName } from '../../shared/metadataUtils.js'; -// Protocol types (NOT the Protocol class itself or mergeCapabilities) +// Protocol types. The Protocol class is exported for v1-compat (subclassed by +// some consumers); new code should use Dispatcher / McpServer / Client directly. export type { BaseContext, ClientContext, @@ -48,7 +49,10 @@ export type { RequestOptions, ServerContext } from '../../shared/protocol.js'; -export { DEFAULT_REQUEST_TIMEOUT_MSEC } from '../../shared/protocol.js'; +export { DEFAULT_REQUEST_TIMEOUT_MSEC, mergeCapabilities, Protocol } from '../../shared/protocol.js'; + +// In-memory transport (testing + v1 compat) +export { InMemoryTransport } from '../../util/inMemory.js'; // Task manager types (NOT TaskManager class itself — internal) export type { RequestTaskStore, TaskContext, TaskManagerOptions, TaskRequestOptions } from '../../shared/taskManager.js'; @@ -67,8 +71,9 @@ export { takeResult, toArrayAsync } from '../../shared/responseMessage.js'; // stdio message framing utilities (for custom transport authors) export { deserializeMessage, ReadBuffer, serializeMessage } from '../../shared/stdio.js'; -// Transport types (NOT normalizeHeaders) -export type { FetchLike, Transport, TransportSendOptions } from '../../shared/transport.js'; +// Transport types (NOT normalizeHeaders). RequestTransport stays internal until +// SEP-2598 (pluggable transports) finalizes. +export type { ChannelTransport, FetchLike, Transport, TransportSendOptions } from '../../shared/transport.js'; export { createFetchWithInit } from '../../shared/transport.js'; // URI Template @@ -136,7 +141,7 @@ export { isTerminal } from '../../experimental/tasks/interfaces.js'; export { InMemoryTaskMessageQueue, InMemoryTaskStore } from '../../experimental/tasks/stores/inMemory.js'; // Validator types and classes -export type { StandardSchemaWithJSON } from '../../util/standardSchema.js'; +export type { StandardSchemaV1, StandardSchemaWithJSON } from '../../util/standardSchema.js'; export { AjvJsonSchemaValidator } from '../../validators/ajvProvider.js'; export type { CfWorkerSchemaDraft } from '../../validators/cfWorkerProvider.js'; // fromJsonSchema is intentionally NOT exported here — the server and client packages diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index e707d9939..13db3e569 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -2,11 +2,13 @@ export * from './auth/errors.js'; export * from './errors/sdkErrors.js'; export * from './shared/auth.js'; export * from './shared/authUtils.js'; +export * from './shared/dispatcher.js'; export * from './shared/metadataUtils.js'; export * from './shared/protocol.js'; export * from './shared/responseMessage.js'; export * from './shared/stdio.js'; -export type { RequestTaskStore, TaskContext, TaskManagerOptions, TaskRequestOptions } from './shared/taskManager.js'; +export * from './shared/streamDriver.js'; +export type { RequestTaskStore, TaskAttachHooks, TaskContext, TaskManagerOptions, TaskRequestOptions } from './shared/taskManager.js'; export { extractTaskManagerOptions, NullTaskManager, TaskManager } from './shared/taskManager.js'; export * from './shared/toolNameValidation.js'; export * from './shared/transport.js'; diff --git a/packages/core/src/shared/context.ts b/packages/core/src/shared/context.ts new file mode 100644 index 000000000..8338c5e75 --- /dev/null +++ b/packages/core/src/shared/context.ts @@ -0,0 +1,368 @@ +import type { + AuthInfo, + ClientCapabilities, + CreateMessageRequest, + CreateMessageResult, + CreateMessageResultWithTools, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResultResponse, + LoggingLevel, + Notification, + Progress, + RelatedTaskMetadata, + Request, + RequestId, + RequestMeta, + RequestMethod, + Result, + ResultTypeMap, + ServerCapabilities, + TaskCreationParams +} from '../types/index.js'; +import type { AnySchema, SchemaOutput } from '../util/schema.js'; +import type { TaskContext, TaskManagerOptions, TaskRequestOptions } from './taskManager.js'; +import type { TransportSendOptions } from './transport.js'; + +/** + * Callback for progress notifications. + */ +export type ProgressCallback = (progress: Progress) => void; + +/** + * Per-request environment a transport adapter passes to {@linkcode Dispatcher.dispatch}. + * Everything is optional; a bare `dispatch()` call works with no transport at all. + */ +export type RequestEnv = { + /** + * Sends a request back to the peer (server→client elicitation/sampling, or + * client→server nested calls). Supplied by {@linkcode StreamDriver} when running + * over a persistent pipe. Defaults to throwing {@linkcode SdkErrorCode.NotConnected}. + */ + send?: (request: Request, options?: RequestOptions) => Promise; + + /** Session identifier from the transport, if any. Surfaced as {@linkcode BaseContext.sessionId}. */ + sessionId?: string; + + /** Validated auth token info for HTTP transports. */ + authInfo?: AuthInfo; + + /** Original HTTP {@linkcode globalThis.Request | Request}, if any. */ + httpReq?: globalThis.Request; + + /** Abort signal for the inbound request. If omitted, a fresh controller is created. */ + signal?: AbortSignal; + + /** Task context, if task storage is configured by the caller. */ + task?: TaskContext; +}; + +/** + * Additional initialization options. + */ +export type ProtocolOptions = { + /** + * Protocol versions supported. First version is preferred (sent by client, + * used as fallback by server). Passed to transport during {@linkcode Protocol.connect | connect()}. + * + * @default {@linkcode SUPPORTED_PROTOCOL_VERSIONS} + */ + supportedProtocolVersions?: string[]; + + /** + * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. + * + * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those. + * + * Currently this defaults to `false`, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to `true`. + */ + enforceStrictCapabilities?: boolean; + /** + * An array of notification method names that should be automatically debounced. + * Any notifications with a method in this list will be coalesced if they + * occur in the same tick of the event loop. + * e.g., `['notifications/tools/list_changed']` + */ + debouncedNotificationMethods?: string[]; + + /** + * Runtime configuration for task management. + * If provided, creates a TaskManager with the given options; otherwise a NullTaskManager is used. + * + * Capability assertions are wired automatically from the protocol's + * `assertTaskCapability()` and `assertTaskHandlerCapability()` methods, + * so they should NOT be included here. + */ + tasks?: TaskManagerOptions; +}; + +/** + * The default request timeout, in milliseconds. + */ +export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60_000; + +/** + * Options that can be given per request. + */ +export type RequestOptions = { + /** + * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. + * + * For task-augmented requests: progress notifications continue after {@linkcode CreateTaskResult} is returned and stop automatically when the task reaches a terminal status. + */ + onprogress?: ProgressCallback; + + /** + * Can be used to cancel an in-flight request. This will cause an `AbortError` to be raised from {@linkcode Protocol.request | request()}. + */ + signal?: AbortSignal; + + /** + * A timeout (in milliseconds) for this request. If exceeded, an {@linkcode SdkError} with code {@linkcode SdkErrorCode.RequestTimeout} will be raised from {@linkcode Protocol.request | request()}. + * + * If not specified, {@linkcode DEFAULT_REQUEST_TIMEOUT_MSEC} will be used as the timeout. + */ + timeout?: number; + + /** + * If `true`, receiving a progress notification will reset the request timeout. + * This is useful for long-running operations that send periodic progress updates. + * Default: `false` + */ + resetTimeoutOnProgress?: boolean; + + /** + * Maximum total time (in milliseconds) to wait for a response. + * If exceeded, an {@linkcode SdkError} with code {@linkcode SdkErrorCode.RequestTimeout} will be raised, regardless of progress notifications. + * If not specified, there is no maximum total timeout. + */ + maxTotalTimeout?: number; + + /** + * If provided, augments the request with task creation parameters to enable call-now, fetch-later execution patterns. + */ + task?: TaskCreationParams; + + /** + * If provided, associates this request with a related task. + */ + relatedTask?: RelatedTaskMetadata; + + /** + * @internal Called by the channel adapter after the wire request is built + * (id assigned, `_meta.progressToken` set, response/progress handlers registered) + * but before send. Return `true` to skip the send; the caller takes ownership of + * delivering `wire` later. The registered handlers stay live. `settle` injects + * a result/error into the registered response handler out-of-band. + */ + intercept?: ( + wire: JSONRPCRequest, + messageId: number, + settle: (response: JSONRPCResultResponse | Error) => void, + onError: (error: unknown) => void + ) => boolean; +} & TransportSendOptions; + +/** + * Options that can be given per notification. + */ +export type NotificationOptions = { + /** + * May be used to indicate to the transport which incoming request to associate this outgoing notification with. + */ + relatedRequestId?: RequestId; + + /** + * If provided, associates this notification with a related task. + */ + relatedTask?: RelatedTaskMetadata; +}; + +/** + * The minimal contract a {@linkcode Dispatcher} owner needs to send outbound + * requests/notifications to the connected peer. Decouples {@linkcode McpServer} + * (and the compat {@linkcode Protocol}) from any specific transport adapter: + * they hold an `Outbound`, not a `StreamDriver`. + * + * {@linkcode StreamDriver} implements this for persistent pipes. Request-shaped + * paths can supply their own (e.g. routing through a backchannel). + */ +export interface Outbound { + /** Send a request to the peer and resolve with the parsed result. */ + request(req: Request, resultSchema: T, options?: RequestOptions): Promise>; + /** Send a notification to the peer. */ + notification(notification: Notification, options?: NotificationOptions): Promise; + /** Close the underlying connection. */ + close(): Promise; + /** Clear a registered progress callback by its message id. Optional; pipe-channels expose this for {@linkcode TaskManager}. */ + removeProgressHandler?(messageId: number): void; + /** Inform the channel which protocol version was negotiated (for header echoing etc.). Optional. */ + setProtocolVersion?(version: string): void; + /** Write a raw JSON-RPC message on the same stream as a prior request. Optional; pipe-only. */ + sendRaw?(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise; +} + +/** + * Base context provided to all request handlers. + */ +export type BaseContext = { + /** + * The session ID from the transport, if available. + */ + sessionId?: string; + + /** + * Information about the MCP request being handled. + */ + mcpReq: { + /** + * The JSON-RPC ID of the request being handled. + */ + id: RequestId; + + /** + * The method name of the request (e.g., 'tools/call', 'ping'). + */ + method: string; + + /** + * Metadata from the original request. + */ + _meta?: RequestMeta; + + /** + * An abort signal used to communicate if the request was cancelled from the sender's side. + */ + signal: AbortSignal; + + /** + * Sends a request that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + send: ( + request: { method: M; params?: Record }, + options?: TaskRequestOptions + ) => Promise; + + /** + * Sends a notification that relates to the current request being handled. + * + * This is used by certain transports to correctly associate related messages. + */ + notify: (notification: Notification) => Promise; + }; + + /** + * HTTP transport information, only available when using an HTTP-based transport. + */ + http?: { + /** + * Information about a validated access token, provided to request handlers. + */ + authInfo?: AuthInfo; + }; + + /** + * Task context, available when task storage is configured. + */ + task?: TaskContext; + + // ─── v1 flat aliases (deprecated) ──────────────────────────────────── + // v1's RequestHandlerExtra exposed these at the top level. v2 nests them + // under {@linkcode mcpReq} / {@linkcode http}. The flat forms are kept + // typed (and populated at runtime by McpServer.buildContext) so v1 handler + // code keeps compiling. Prefer the nested paths for new code. + + /** @deprecated Use {@linkcode mcpReq.signal}. */ + signal?: AbortSignal; + /** @deprecated Use {@linkcode mcpReq.id}. */ + requestId?: RequestId; + /** @deprecated Use {@linkcode mcpReq._meta}. */ + _meta?: RequestMeta; + /** @deprecated Use {@linkcode mcpReq.notify}. */ + sendNotification?: (notification: Notification) => Promise; + /** @deprecated Use {@linkcode mcpReq.send}. */ + sendRequest?: ( + request: { method: M; params?: Record }, + options?: TaskRequestOptions + ) => Promise; + /** @deprecated Use {@linkcode http.authInfo}. */ + authInfo?: AuthInfo; + /** @deprecated v1 carried raw request info here. v2 surfaces the web `Request` via {@linkcode ServerContext.http}. */ + requestInfo?: globalThis.Request; +}; + +/** + * Context provided to server-side request handlers, extending {@linkcode BaseContext} with server-specific fields. + */ +export type ServerContext = BaseContext & { + mcpReq: { + /** + * Send a log message notification to the client. + * Respects the client's log level filter set via logging/setLevel. + */ + log: (level: LoggingLevel, data: unknown, logger?: string) => Promise; + + /** + * Send an elicitation request to the client, requesting user input. + */ + elicitInput: (params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions) => Promise; + + /** + * Request LLM sampling from the client. + */ + requestSampling: ( + params: CreateMessageRequest['params'], + options?: RequestOptions + ) => Promise; + }; + + http?: { + /** + * The original HTTP request. + */ + req?: globalThis.Request; + + /** + * Closes the SSE stream for this request, triggering client reconnection. + * Only available when using a StreamableHTTPServerTransport with eventStore configured. + */ + closeSSE?: () => void; + + /** + * Closes the standalone GET SSE stream, triggering client reconnection. + * Only available when using a StreamableHTTPServerTransport with eventStore configured. + */ + closeStandaloneSSE?: () => void; + }; +}; + +/** + * Context provided to client-side request handlers. + */ +export type ClientContext = BaseContext; + +function isPlainObject(value: unknown): value is Record { + return value !== null && typeof value === 'object' && !Array.isArray(value); +} + +export function mergeCapabilities(base: ServerCapabilities, additional: Partial): ServerCapabilities; +export function mergeCapabilities(base: ClientCapabilities, additional: Partial): ClientCapabilities; +export function mergeCapabilities(base: T, additional: Partial): T { + const result: T = { ...base }; + for (const key in additional) { + const k = key as keyof T; + const addValue = additional[k]; + if (addValue === undefined) continue; + const baseValue = result[k]; + result[k] = + isPlainObject(baseValue) && isPlainObject(addValue) + ? ({ ...(baseValue as Record), ...(addValue as Record) } as T[typeof k]) + : (addValue as T[typeof k]); + } + return result; +} diff --git a/packages/core/src/shared/dispatcher.ts b/packages/core/src/shared/dispatcher.ts new file mode 100644 index 000000000..65bb781d4 --- /dev/null +++ b/packages/core/src/shared/dispatcher.ts @@ -0,0 +1,306 @@ +import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + Notification, + NotificationMethod, + NotificationTypeMap, + Request, + RequestId, + RequestMethod, + RequestTypeMap, + Result, + ResultTypeMap +} from '../types/index.js'; +import { getNotificationSchema, getRequestSchema, ProtocolError, ProtocolErrorCode } from '../types/index.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; +import type { BaseContext, RequestEnv, RequestOptions } from './context.js'; + +/** + * One yielded item from {@linkcode Dispatcher.dispatch}. A dispatch yields zero or more + * notifications followed by exactly one terminal response. + */ +export type DispatchOutput = + | { kind: 'notification'; message: JSONRPCNotification } + | { kind: 'response'; message: JSONRPCResponse | JSONRPCErrorResponse }; + +type RawHandler = (request: JSONRPCRequest, ctx: ContextT) => Promise; + +/** Signature of {@linkcode Dispatcher.dispatch}. Target type for {@linkcode DispatchMiddleware}. */ +export type DispatchFn = (req: JSONRPCRequest, env?: RequestEnv) => AsyncGenerator; + +/** + * Onion-style middleware around {@linkcode Dispatcher.dispatch}. Registered via + * {@linkcode Dispatcher.use}; composed outermost-first (registration order). + * + * A middleware may transform `req`/`env` before delegating, transform or filter + * yielded outputs, or short-circuit by yielding a response without calling `next`. + */ +export type DispatchMiddleware = (next: DispatchFn) => DispatchFn; + +/** + * Stateless JSON-RPC handler registry with a request-in / messages-out + * {@linkcode Dispatcher.dispatch | dispatch()} entry point. + * + * Holds no transport, no correlation state, no timers. One instance can serve + * any number of concurrent requests from any driver. + */ +export class Dispatcher { + protected _requestHandlers: Map> = new Map(); + protected _notificationHandlers: Map Promise> = new Map(); + private _dispatchMw: DispatchMiddleware[] = []; + + /** + * A handler to invoke for any request types that do not have their own handler installed. + */ + fallbackRequestHandler?: (request: JSONRPCRequest, ctx: ContextT) => Promise; + + /** + * A handler to invoke for any notification types that do not have their own handler installed. + */ + fallbackNotificationHandler?: (notification: Notification) => Promise; + + /** + * Subclasses override to enrich the context (e.g. {@linkcode ServerContext}). Default returns base unchanged. + */ + protected buildContext(base: BaseContext, _env: RequestEnv): ContextT { + return base as ContextT; + } + + /** + * Registers a {@linkcode DispatchMiddleware}. Registration order is outer-to-inner: + * the first middleware registered sees the rawest request and the final yields. + */ + use(mw: DispatchMiddleware): this { + this._dispatchMw.push(mw); + return this; + } + + /** + * Dispatch one inbound request through the registered middleware chain, then the + * core handler lookup. Yields any notifications the handler emits via + * `ctx.mcpReq.notify()`, then yields exactly one terminal response. + * + * Never throws for handler errors; they are wrapped as JSON-RPC error responses. + * May throw if iteration itself is misused. + */ + dispatch(request: JSONRPCRequest, env: RequestEnv = {}): AsyncGenerator { + // eslint-disable-next-line unicorn/consistent-function-scoping -- closes over `this` + let chain: DispatchFn = (r, e) => this._dispatchCore(r, e); + // eslint-disable-next-line unicorn/no-array-reverse -- toReversed() requires ES2023 lib; consumers may target ES2022 + for (const mw of [...this._dispatchMw].reverse()) chain = mw(chain); + return chain(request, env); + } + + /** + * The handler lookup + invocation. Middleware composes around this; subclasses do + * not override `dispatch()` directly — use {@linkcode Dispatcher.use | use()} instead. + */ + private async *_dispatchCore(request: JSONRPCRequest, env: RequestEnv = {}): AsyncGenerator { + const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; + + if (handler === undefined) { + yield errorResponse(request.id, ProtocolErrorCode.MethodNotFound, 'Method not found'); + return; + } + + const queue: JSONRPCNotification[] = []; + let wake: (() => void) | undefined; + let done = false; + let final: JSONRPCResponse | JSONRPCErrorResponse | undefined; + + const localAbort = new AbortController(); + if (env.signal) { + if (env.signal.aborted) localAbort.abort(env.signal.reason); + else env.signal.addEventListener('abort', () => localAbort.abort(env.signal!.reason), { once: true }); + } + + const send = + env.send ?? + (async () => { + throw new SdkError( + SdkErrorCode.NotConnected, + 'ctx.mcpReq.send is unavailable: no peer channel. Use the MRTR-native return form for elicitation/sampling, or run via connect()/StreamDriver.' + ); + }); + + const base: BaseContext = { + sessionId: env.sessionId, + mcpReq: { + id: request.id, + method: request.method, + _meta: request.params?._meta, + signal: localAbort.signal, + send: (r: { method: M; params?: Record }, options?: RequestOptions) => + send(r as Request, options) as Promise, + notify: async (n: Notification) => { + if (done) return; + queue.push({ jsonrpc: '2.0', method: n.method, params: n.params } as JSONRPCNotification); + wake?.(); + } + }, + http: env.authInfo || env.httpReq ? { authInfo: env.authInfo } : undefined, + task: env.task + }; + const ctx = this.buildContext(base, env); + + Promise.resolve() + .then(() => handler(request, ctx)) + .then( + result => { + final = localAbort.signal.aborted + ? errorResponse(request.id, ProtocolErrorCode.InternalError, 'Request cancelled').message + : { jsonrpc: '2.0', id: request.id, result }; + }, + error => { + final = toErrorResponse(request.id, error); + } + ) + .finally(() => { + done = true; + wake?.(); + }); + + while (true) { + while (queue.length > 0) { + yield { kind: 'notification', message: queue.shift()! }; + } + if (done) break; + await new Promise(resolve => { + wake = resolve; + }); + wake = undefined; + } + // Drain anything pushed between done=true and the wake. + while (queue.length > 0) { + yield { kind: 'notification', message: queue.shift()! }; + } + yield { kind: 'response', message: final! }; + } + + /** + * Dispatch one inbound notification to its handler. Errors are reported via the + * returned promise; unknown methods are silently ignored. + */ + async dispatchNotification(notification: JSONRPCNotification): Promise { + const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; + if (handler === undefined) return; + await Promise.resolve().then(() => handler(notification)); + } + + /** + * Registers a handler to invoke when this dispatcher receives a request with the given method. + * + * For spec methods, the request is parsed against the spec schema and the handler receives + * the typed `RequestTypeMap[M]`. + */ + setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise + ): void; + /** + * Registers a handler for a custom (non-spec) method. The provided `paramsSchema` validates + * `request.params` (with `_meta` stripped); the handler receives the parsed params object. + */ + setRequestHandler( + method: string, + paramsSchema: S, + handler: (params: StandardSchemaV1.InferOutput, ctx: ContextT) => Result | Promise + ): void; + setRequestHandler(method: string, schemaOrHandler: unknown, maybeHandler?: unknown): void { + if (maybeHandler !== undefined) { + const userHandler = maybeHandler as (params: unknown, ctx: ContextT) => Result | Promise; + this._requestHandlers.set(method, this._wrapParamsSchemaHandler(method, schemaOrHandler as StandardSchemaV1, userHandler)); + return; + } + const handler = schemaOrHandler as (request: unknown, ctx: ContextT) => Result | Promise; + const schema = getRequestSchema(method as RequestMethod); + this._requestHandlers.set(method, (request, ctx) => { + const parsed = schema.parse(request); + return Promise.resolve(handler(parsed, ctx)); + }); + } + + /** + * Builds a raw handler that validates `request.params` (minus `_meta`) against `paramsSchema` + * and invokes `handler(parsedParams, ctx)`. Shared with subclass overrides so per-method + * wrapping composes uniformly with the 3-arg form. + */ + protected _wrapParamsSchemaHandler( + method: string, + paramsSchema: StandardSchemaV1, + handler: (params: unknown, ctx: ContextT) => Result | Promise + ): RawHandler { + return async (request, ctx) => { + const { _meta, ...userParams } = (request.params ?? {}) as Record; + void _meta; + const parsed = await paramsSchema['~standard'].validate(userParams); + if (parsed.issues) { + const msg = parsed.issues.map(i => i.message).join('; '); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${msg}`); + } + return handler(parsed.value, ctx); + }; + } + + /** Registers a raw handler with no schema parsing. Used for compat shims. */ + setRawRequestHandler(method: string, handler: RawHandler): void { + this._requestHandlers.set(method, handler); + } + + removeRequestHandler(method: string): void { + this._requestHandlers.delete(method); + } + + assertCanSetRequestHandler(method: string): void { + if (this._requestHandlers.has(method)) { + throw new Error(`A request handler for ${method} already exists, which would be overridden`); + } + } + + setNotificationHandler( + method: M, + handler: (notification: NotificationTypeMap[M]) => void | Promise + ): void { + const schema = getNotificationSchema(method); + this._notificationHandlers.set(method, notification => { + const parsed = schema.parse(notification); + return Promise.resolve(handler(parsed)); + }); + } + + removeNotificationHandler(method: string): void { + this._notificationHandlers.delete(method); + } + + /** Convenience: collect a full dispatch into a single response, discarding notifications. */ + async dispatchToResponse(request: JSONRPCRequest, env?: RequestEnv): Promise { + let resp: JSONRPCResponse | JSONRPCErrorResponse | undefined; + for await (const out of this.dispatch(request, env)) { + if (out.kind === 'response') resp = out.message; + } + return resp!; + } +} + +function errorResponse(id: RequestId, code: number, message: string): { kind: 'response'; message: JSONRPCErrorResponse } { + return { kind: 'response', message: { jsonrpc: '2.0', id, error: { code, message } } }; +} + +function toErrorResponse(id: RequestId, error: unknown): JSONRPCErrorResponse { + const e = error as { code?: unknown; message?: unknown; data?: unknown }; + return { + jsonrpc: '2.0', + id, + error: { + code: Number.isSafeInteger(e?.code) ? (e.code as number) : ProtocolErrorCode.InternalError, + message: typeof e?.message === 'string' ? e.message : 'Internal error', + ...(e?.data !== undefined && { data: e.data }) + } + }; +} + +/** Re-export for convenience; the canonical definition lives in protocol.ts for now. */ +// BaseContext / RequestOptions are exported from protocol.ts via the core barrel. diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 57eab6932..5d5b38be9 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -1,312 +1,43 @@ +/** + * v1-compat module. The types live in {@link ./context.ts}; the runtime lives in + * {@link Dispatcher} + {@link StreamDriver}. The {@link Protocol} class here is + * a thin wrapper that composes those two so that v1 code subclassing `Protocol` + * keeps working. + */ + import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; import type { - AuthInfo, - CancelledNotification, - ClientCapabilities, - CreateMessageRequest, - CreateMessageResult, - CreateMessageResultWithTools, - ElicitRequestFormParams, - ElicitRequestURLParams, - ElicitResult, - JSONRPCErrorResponse, - JSONRPCNotification, JSONRPCRequest, - JSONRPCResponse, - JSONRPCResultResponse, - LoggingLevel, MessageExtraInfo, Notification, NotificationMethod, NotificationTypeMap, - Progress, - ProgressNotification, - RelatedTaskMetadata, Request, - RequestId, - RequestMeta, RequestMethod, RequestTypeMap, Result, - ResultTypeMap, - ServerCapabilities, - TaskCreationParams -} from '../types/index.js'; -import { - getNotificationSchema, - getRequestSchema, - getResultSchema, - isJSONRPCErrorResponse, - isJSONRPCNotification, - isJSONRPCRequest, - isJSONRPCResultResponse, - ProtocolError, - ProtocolErrorCode, - SUPPORTED_PROTOCOL_VERSIONS + ResultTypeMap } from '../types/index.js'; +import { getResultSchema, SUPPORTED_PROTOCOL_VERSIONS } from '../types/index.js'; import type { AnySchema, SchemaOutput } from '../util/schema.js'; -import { parseSchema } from '../util/schema.js'; -import type { TaskContext, TaskManagerHost, TaskManagerOptions, TaskRequestOptions } from './taskManager.js'; +import type { BaseContext, NotificationOptions, Outbound, ProtocolOptions, RequestEnv, RequestOptions } from './context.js'; +import type { DispatchMiddleware } from './dispatcher.js'; +import { Dispatcher } from './dispatcher.js'; +import { StreamDriver } from './streamDriver.js'; import { NullTaskManager, TaskManager } from './taskManager.js'; -import type { Transport, TransportSendOptions } from './transport.js'; - -/** - * Callback for progress notifications. - */ -export type ProgressCallback = (progress: Progress) => void; - -/** - * Additional initialization options. - */ -export type ProtocolOptions = { - /** - * Protocol versions supported. First version is preferred (sent by client, - * used as fallback by server). Passed to transport during {@linkcode Protocol.connect | connect()}. - * - * @default {@linkcode SUPPORTED_PROTOCOL_VERSIONS} - */ - supportedProtocolVersions?: string[]; - - /** - * Whether to restrict emitted requests to only those that the remote side has indicated that they can handle, through their advertised capabilities. - * - * Note that this DOES NOT affect checking of _local_ side capabilities, as it is considered a logic error to mis-specify those. - * - * Currently this defaults to `false`, for backwards compatibility with SDK versions that did not advertise capabilities correctly. In future, this will default to `true`. - */ - enforceStrictCapabilities?: boolean; - /** - * An array of notification method names that should be automatically debounced. - * Any notifications with a method in this list will be coalesced if they - * occur in the same tick of the event loop. - * e.g., `['notifications/tools/list_changed']` - */ - debouncedNotificationMethods?: string[]; - - /** - * Runtime configuration for task management. - * If provided, creates a TaskManager with the given options; otherwise a NullTaskManager is used. - * - * Capability assertions are wired automatically from the protocol's - * `assertTaskCapability()` and `assertTaskHandlerCapability()` methods, - * so they should NOT be included here. - */ - tasks?: TaskManagerOptions; -}; - -/** - * The default request timeout, in milliseconds. - */ -export const DEFAULT_REQUEST_TIMEOUT_MSEC = 60_000; - -/** - * Options that can be given per request. - */ -export type RequestOptions = { - /** - * If set, requests progress notifications from the remote end (if supported). When progress notifications are received, this callback will be invoked. - * - * For task-augmented requests: progress notifications continue after {@linkcode CreateTaskResult} is returned and stop automatically when the task reaches a terminal status. - */ - onprogress?: ProgressCallback; - - /** - * Can be used to cancel an in-flight request. This will cause an `AbortError` to be raised from {@linkcode Protocol.request | request()}. - */ - signal?: AbortSignal; - - /** - * A timeout (in milliseconds) for this request. If exceeded, an {@linkcode SdkError} with code {@linkcode SdkErrorCode.RequestTimeout} will be raised from {@linkcode Protocol.request | request()}. - * - * If not specified, {@linkcode DEFAULT_REQUEST_TIMEOUT_MSEC} will be used as the timeout. - */ - timeout?: number; - - /** - * If `true`, receiving a progress notification will reset the request timeout. - * This is useful for long-running operations that send periodic progress updates. - * Default: `false` - */ - resetTimeoutOnProgress?: boolean; - - /** - * Maximum total time (in milliseconds) to wait for a response. - * If exceeded, an {@linkcode SdkError} with code {@linkcode SdkErrorCode.RequestTimeout} will be raised, regardless of progress notifications. - * If not specified, there is no maximum total timeout. - */ - maxTotalTimeout?: number; - - /** - * If provided, augments the request with task creation parameters to enable call-now, fetch-later execution patterns. - */ - task?: TaskCreationParams; - - /** - * If provided, associates this request with a related task. - */ - relatedTask?: RelatedTaskMetadata; -} & TransportSendOptions; - -/** - * Options that can be given per notification. - */ -export type NotificationOptions = { - /** - * May be used to indicate to the transport which incoming request to associate this outgoing notification with. - */ - relatedRequestId?: RequestId; - - /** - * If provided, associates this notification with a related task. - */ - relatedTask?: RelatedTaskMetadata; -}; - -/** - * Base context provided to all request handlers. - */ -export type BaseContext = { - /** - * The session ID from the transport, if available. - */ - sessionId?: string; - - /** - * Information about the MCP request being handled. - */ - mcpReq: { - /** - * The JSON-RPC ID of the request being handled. - */ - id: RequestId; - - /** - * The method name of the request (e.g., 'tools/call', 'ping'). - */ - method: string; - - /** - * Metadata from the original request. - */ - _meta?: RequestMeta; - - /** - * An abort signal used to communicate if the request was cancelled from the sender's side. - */ - signal: AbortSignal; - - /** - * Sends a request that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. - */ - send: ( - request: { method: M; params?: Record }, - options?: TaskRequestOptions - ) => Promise; - - /** - * Sends a notification that relates to the current request being handled. - * - * This is used by certain transports to correctly associate related messages. - */ - notify: (notification: Notification) => Promise; - }; - - /** - * HTTP transport information, only available when using an HTTP-based transport. - */ - http?: { - /** - * Information about a validated access token, provided to request handlers. - */ - authInfo?: AuthInfo; - }; +import type { Transport } from './transport.js'; - /** - * Task context, available when task storage is configured. - */ - task?: TaskContext; -}; +export * from './context.js'; /** - * Context provided to server-side request handlers, extending {@linkcode BaseContext} with server-specific fields. - */ -export type ServerContext = BaseContext & { - mcpReq: { - /** - * Send a log message notification to the client. - * Respects the client's log level filter set via logging/setLevel. - */ - log: (level: LoggingLevel, data: unknown, logger?: string) => Promise; - - /** - * Send an elicitation request to the client, requesting user input. - */ - elicitInput: (params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions) => Promise; - - /** - * Request LLM sampling from the client. - */ - requestSampling: ( - params: CreateMessageRequest['params'], - options?: RequestOptions - ) => Promise; - }; - - http?: { - /** - * The original HTTP request. - */ - req?: globalThis.Request; - - /** - * Closes the SSE stream for this request, triggering client reconnection. - * Only available when using a StreamableHTTPServerTransport with eventStore configured. - */ - closeSSE?: () => void; - - /** - * Closes the standalone GET SSE stream, triggering client reconnection. - * Only available when using a StreamableHTTPServerTransport with eventStore configured. - */ - closeStandaloneSSE?: () => void; - }; -}; - -/** - * Context provided to client-side request handlers. - */ -export type ClientContext = BaseContext; - -/** - * Information about a request's timeout state - */ -type TimeoutInfo = { - timeoutId: ReturnType; - startTime: number; - timeout: number; - maxTotalTimeout?: number; - resetTimeoutOnProgress: boolean; - onTimeout: () => void; -}; - -/** - * Implements MCP protocol framing on top of a pluggable transport, including - * features like request/response linking, notifications, and progress. + * v1-compat MCP protocol base. New code should use {@linkcode McpServer} (which + * extends {@linkcode Dispatcher}) or {@linkcode Client}. This class composes a + * {@linkcode Dispatcher} (handler registry + dispatch) and a + * {@linkcode StreamDriver} (per-connection state) to preserve the v1 surface. */ export abstract class Protocol { - private _transport?: Transport; - private _requestMessageId = 0; - private _requestHandlers: Map Promise> = new Map(); - private _requestHandlerAbortControllers: Map = new Map(); - private _notificationHandlers: Map Promise> = new Map(); - private _responseHandlers: Map void> = new Map(); - private _progressHandlers: Map = new Map(); - private _timeoutInfo: Map = new Map(); - private _pendingDebouncedNotifications = new Set(); - - private _taskManager: TaskManager; + private _outbound?: Outbound; + private readonly _dispatcher: Dispatcher; protected _supportedProtocolVersions: string[]; @@ -324,459 +55,144 @@ export abstract class Protocol { */ onerror?: (error: Error) => void; - /** - * A handler to invoke for any request types that do not have their own handler installed. - */ - fallbackRequestHandler?: (request: JSONRPCRequest, ctx: ContextT) => Promise; - - /** - * A handler to invoke for any notification types that do not have their own handler installed. - */ - fallbackNotificationHandler?: (notification: Notification) => Promise; - constructor(private _options?: ProtocolOptions) { + // eslint-disable-next-line @typescript-eslint/no-this-alias, unicorn/no-this-assignment + const self = this; + this._dispatcher = new (class extends Dispatcher { + protected override buildContext(base: BaseContext, env: RequestEnv & { _transportExtra?: MessageExtraInfo }): ContextT { + return self.buildContext(base, env._transportExtra); + } + })(); this._supportedProtocolVersions = _options?.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; - - // Create TaskManager from protocol options - this._taskManager = _options?.tasks ? new TaskManager(_options.tasks) : new NullTaskManager(); - this._bindTaskManager(); - - this.setNotificationHandler('notifications/cancelled', notification => { - this._oncancel(notification); - }); - - this.setNotificationHandler('notifications/progress', notification => { - this._onprogress(notification); - }); - - this.setRequestHandler( - 'ping', - // Automatic pong by default. - _request => ({}) as Result - ); - } - - /** - * Access the TaskManager for task orchestration. - * Always available; returns a NullTaskManager when no task store is configured. - */ - get taskManager(): TaskManager { - return this._taskManager; - } - - private _bindTaskManager(): void { - const taskManager = this._taskManager; - const host: TaskManagerHost = { - request: (request, resultSchema, options) => this._requestWithSchema(request, resultSchema, options), - notification: (notification, options) => this.notification(notification, options), - reportError: error => this._onerror(error), - removeProgressHandler: token => this._progressHandlers.delete(token), - registerHandler: (method, handler) => { - const schema = getRequestSchema(method as RequestMethod); - this._requestHandlers.set(method, (request, ctx) => { - // Validate request params via Zod (strips jsonrpc/id, so we pass original to handler) - schema.parse(request); - return handler(request, ctx); - }); - }, - sendOnResponseStream: async (message, relatedRequestId) => { - await this._transport?.send(message, { relatedRequestId }); - }, + this._ownTaskManager = _options?.tasks ? new TaskManager(_options.tasks) : new NullTaskManager(); + this._ownTaskManager.attachTo(this._dispatcher, { + channel: () => this._outbound, + reportError: e => this.onerror?.(e), enforceStrictCapabilities: this._options?.enforceStrictCapabilities === true, - assertTaskCapability: method => this.assertTaskCapability(method), - assertTaskHandlerCapability: method => this.assertTaskHandlerCapability(method) - }; - taskManager.bind(host); - } - - /** - * Builds the context object for request handlers. Subclasses must override - * to return the appropriate context type (e.g., ServerContext adds HTTP request info). - */ - protected abstract buildContext(ctx: BaseContext, transportInfo?: MessageExtraInfo): ContextT; - - private async _oncancel(notification: CancelledNotification): Promise { - if (!notification.params.requestId) { - return; - } - // Handle request cancellation - const controller = this._requestHandlerAbortControllers.get(notification.params.requestId); - controller?.abort(notification.params.reason); - } - - private _setupTimeout( - messageId: number, - timeout: number, - maxTotalTimeout: number | undefined, - onTimeout: () => void, - resetTimeoutOnProgress: boolean = false - ) { - this._timeoutInfo.set(messageId, { - timeoutId: setTimeout(onTimeout, timeout), - startTime: Date.now(), - timeout, - maxTotalTimeout, - resetTimeoutOnProgress, - onTimeout + assertTaskCapability: m => this.assertTaskCapability(m), + assertTaskHandlerCapability: m => this.assertTaskHandlerCapability(m) }); } - private _resetTimeout(messageId: number): boolean { - const info = this._timeoutInfo.get(messageId); - if (!info) return false; + private readonly _ownTaskManager: TaskManager; - const totalElapsed = Date.now() - info.startTime; - if (info.maxTotalTimeout && totalElapsed >= info.maxTotalTimeout) { - this._timeoutInfo.delete(messageId); - throw new SdkError(SdkErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { - maxTotalTimeout: info.maxTotalTimeout, - totalElapsed - }); - } - - clearTimeout(info.timeoutId); - info.timeoutId = setTimeout(info.onTimeout, info.timeout); - return true; + /** Register a {@linkcode DispatchMiddleware} on the inner dispatcher. */ + use(mw: DispatchMiddleware): this { + this._dispatcher.use(mw); + return this; } - private _cleanupTimeout(messageId: number) { - const info = this._timeoutInfo.get(messageId); - if (info) { - clearTimeout(info.timeoutId); - this._timeoutInfo.delete(messageId); - } - } + // ─────────────────────────────────────────────────────────────────────── + // Subclass hooks (v1 signatures) + // ─────────────────────────────────────────────────────────────────────── /** - * Attaches to the given transport, starts it, and starts listening for messages. - * - * The caller assumes ownership of the {@linkcode Transport}, replacing any callbacks that have already been set, and expects that it is the only user of the {@linkcode Transport} instance going forward. + * Subclasses override to enrich the handler context. v1 signature; the + * {@linkcode MessageExtraInfo} is forwarded from the transport. */ - async connect(transport: Transport): Promise { - this._transport = transport; - const _onclose = this.transport?.onclose; - this._transport.onclose = () => { - try { - _onclose?.(); - } finally { - this._onclose(); - } - }; - - const _onerror = this.transport?.onerror; - this._transport.onerror = (error: Error) => { - _onerror?.(error); - this._onerror(error); - }; - - const _onmessage = this._transport?.onmessage; - this._transport.onmessage = (message, extra) => { - _onmessage?.(message, extra); - if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { - this._onresponse(message); - } else if (isJSONRPCRequest(message)) { - this._onrequest(message, extra); - } else if (isJSONRPCNotification(message)) { - this._onnotification(message); - } else { - this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`)); - } - }; - - // Pass supported protocol versions to transport for header validation - transport.setSupportedProtocolVersions?.(this._supportedProtocolVersions); - - await this._transport.start(); + protected buildContext(ctx: BaseContext, _transportInfo?: MessageExtraInfo): ContextT { + return ctx as ContextT; } - private _onclose(): void { - const responseHandlers = this._responseHandlers; - this._responseHandlers = new Map(); - this._progressHandlers.clear(); - this._taskManager.onClose(); - this._pendingDebouncedNotifications.clear(); - - for (const info of this._timeoutInfo.values()) { - clearTimeout(info.timeoutId); - } - this._timeoutInfo.clear(); - - const requestHandlerAbortControllers = this._requestHandlerAbortControllers; - this._requestHandlerAbortControllers = new Map(); - - const error = new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed'); + /** Override to enforce capabilities. Default is a no-op. */ + protected assertCapabilityForMethod(_method: RequestMethod): void {} + /** Override to enforce capabilities. Default is a no-op. */ + protected assertNotificationCapability(_method: NotificationMethod): void {} + /** Override to enforce capabilities. Default is a no-op. */ + protected assertRequestHandlerCapability(_method: string): void {} + /** Override to enforce capabilities. Default is a no-op. */ + protected assertTaskCapability(_method: string): void {} + /** Override to enforce capabilities. Default is a no-op. */ + protected assertTaskHandlerCapability(_method: string): void {} - this._transport = undefined; + // ─────────────────────────────────────────────────────────────────────── + // Handler registration (delegates to Dispatcher) + // ─────────────────────────────────────────────────────────────────────── - try { - this.onclose?.(); - } finally { - for (const handler of responseHandlers.values()) { - handler(error); - } - - for (const controller of requestHandlerAbortControllers.values()) { - controller.abort(error); - } - } + setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise + ): void { + this.assertRequestHandlerCapability(method); + this._dispatcher.setRequestHandler(method, handler); } - private _onerror(error: Error): void { - this.onerror?.(error); + removeRequestHandler(method: string): void { + this._dispatcher.removeRequestHandler(method); } - private _onnotification(notification: JSONRPCNotification): void { - const handler = this._notificationHandlers.get(notification.method) ?? this.fallbackNotificationHandler; - - // Ignore notifications not being subscribed to. - if (handler === undefined) { - return; - } - - // Starting with Promise.resolve() puts any synchronous errors into the monad as well. - Promise.resolve() - .then(() => handler(notification)) - .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); + assertCanSetRequestHandler(method: string): void { + this._dispatcher.assertCanSetRequestHandler(method); } - private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { - const handler = this._requestHandlers.get(request.method) ?? this.fallbackRequestHandler; - - // Capture the current transport at request time to ensure responses go to the correct client - const capturedTransport = this._transport; - - // Delegate context extraction to module (if registered) - const inboundCtx = { - sessionId: capturedTransport?.sessionId, - sendNotification: (notification: Notification, options?: NotificationOptions) => - this.notification(notification, { ...options, relatedRequestId: request.id }), - sendRequest: (r: Request, resultSchema: U, options?: RequestOptions) => - this._requestWithSchema(r, resultSchema, { ...options, relatedRequestId: request.id }) - }; - - // Delegate to TaskManager for task context, wrapped send/notify, and response routing - const taskResult = this._taskManager.processInboundRequest(request, inboundCtx); - const sendNotification = taskResult.sendNotification; - const sendRequest = taskResult.sendRequest; - const taskContext = taskResult.taskContext; - const routeResponse = taskResult.routeResponse; - const validators: Array<() => void> = []; - if (taskResult.validateInbound) validators.push(taskResult.validateInbound); - - if (handler === undefined) { - const errorResponse: JSONRPCErrorResponse = { - jsonrpc: '2.0', - id: request.id, - error: { - code: ProtocolErrorCode.MethodNotFound, - message: 'Method not found' - } - }; - - // Queue or send the error response based on whether this is a task-related request - routeResponse(errorResponse) - .then(routed => { - if (!routed) { - capturedTransport - ?.send(errorResponse) - .catch(error => this._onerror(new Error(`Failed to send an error response: ${error}`))); - } - }) - .catch(error => this._onerror(new Error(`Failed to enqueue error response: ${error}`))); - return; - } - - const abortController = new AbortController(); - this._requestHandlerAbortControllers.set(request.id, abortController); - - const baseCtx: BaseContext = { - sessionId: capturedTransport?.sessionId, - mcpReq: { - id: request.id, - method: request.method, - _meta: request.params?._meta, - signal: abortController.signal, - send: (r: { method: M; params?: Record }, options?: TaskRequestOptions) => { - const resultSchema = getResultSchema(r.method); - return sendRequest(r as Request, resultSchema, options) as Promise; - }, - notify: sendNotification - }, - http: extra?.authInfo ? { authInfo: extra.authInfo } : undefined, - task: taskContext - }; - const ctx = this.buildContext(baseCtx, extra); - - // Starting with Promise.resolve() puts any synchronous errors into the monad as well. - Promise.resolve() - .then(() => { - for (const validate of validators) { - validate(); - } - }) - .then(() => handler(request, ctx)) - .then( - async result => { - if (abortController.signal.aborted) { - // Request was cancelled - return; - } - - const response: JSONRPCResponse = { - result, - jsonrpc: '2.0', - id: request.id - }; - - // Queue or send the response based on whether this is a task-related request - const routed = await routeResponse(response); - if (!routed) { - await capturedTransport?.send(response); - } - }, - async error => { - if (abortController.signal.aborted) { - // Request was cancelled - return; - } - - const errorResponse: JSONRPCErrorResponse = { - jsonrpc: '2.0', - id: request.id, - error: { - code: Number.isSafeInteger(error['code']) ? error['code'] : ProtocolErrorCode.InternalError, - message: error.message ?? 'Internal error', - ...(error['data'] !== undefined && { data: error['data'] }) - } - }; - - // Queue or send the error response based on whether this is a task-related request - const routed = await routeResponse(errorResponse); - if (!routed) { - await capturedTransport?.send(errorResponse); - } - } - ) - .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) - .finally(() => { - if (this._requestHandlerAbortControllers.get(request.id) === abortController) { - this._requestHandlerAbortControllers.delete(request.id); - } - }); + setNotificationHandler( + method: M, + handler: (notification: NotificationTypeMap[M]) => void | Promise + ): void { + this._dispatcher.setNotificationHandler(method, handler); } - private _onprogress(notification: ProgressNotification): void { - const { progressToken, ...params } = notification.params; - const messageId = Number(progressToken); - - const handler = this._progressHandlers.get(messageId); - if (!handler) { - this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); - return; - } - - const responseHandler = this._responseHandlers.get(messageId); - const timeoutInfo = this._timeoutInfo.get(messageId); - - if (timeoutInfo && responseHandler && timeoutInfo.resetTimeoutOnProgress) { - try { - this._resetTimeout(messageId); - } catch (error) { - // Clean up if maxTotalTimeout was exceeded - this._responseHandlers.delete(messageId); - this._progressHandlers.delete(messageId); - this._cleanupTimeout(messageId); - responseHandler(error as Error); - return; - } - } - - handler(params); + removeNotificationHandler(method: string): void { + this._dispatcher.removeNotificationHandler(method); } - private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void { - const messageId = Number(response.id); - - // Delegate to TaskManager for task-related response handling - const taskResult = this._taskManager.processInboundResponse(response, messageId); - if (taskResult.consumed) return; - const preserveProgress = taskResult.preserveProgress; - - const handler = this._responseHandlers.get(messageId); - if (handler === undefined) { - this._onerror(new Error(`Received a response for an unknown message ID: ${JSON.stringify(response)}`)); - return; - } - - this._responseHandlers.delete(messageId); - this._cleanupTimeout(messageId); - - // Keep progress handler alive for CreateTaskResult responses - if (!preserveProgress) { - this._progressHandlers.delete(messageId); - } - - if (isJSONRPCResultResponse(response)) { - handler(response); - } else { - const error = ProtocolError.fromError(response.error.code, response.error.message, response.error.data); - handler(error); - } + get fallbackRequestHandler(): ((request: JSONRPCRequest, ctx: ContextT) => Promise) | undefined { + return this._dispatcher.fallbackRequestHandler; } - - get transport(): Transport | undefined { - return this._transport; + set fallbackRequestHandler(h) { + this._dispatcher.fallbackRequestHandler = h; } - /** - * Closes the connection. - */ - async close(): Promise { - await this._transport?.close(); + get fallbackNotificationHandler(): ((notification: Notification) => Promise) | undefined { + return this._dispatcher.fallbackNotificationHandler; + } + set fallbackNotificationHandler(h) { + this._dispatcher.fallbackNotificationHandler = h; } - /** - * A method to check if a capability is supported by the remote side, for the given method to be called. - * - * This should be implemented by subclasses. - */ - protected abstract assertCapabilityForMethod(method: RequestMethod): void; + // ─────────────────────────────────────────────────────────────────────── + // Connection (delegates to StreamDriver) + // ─────────────────────────────────────────────────────────────────────── /** - * A method to check if a notification is supported by the local side, for the given method to be sent. - * - * This should be implemented by subclasses. + * Connects to a transport. Creates a fresh {@linkcode StreamDriver} per call, + * so re-connecting (the v1 stateful-SHTTP pattern) is supported. */ - protected abstract assertNotificationCapability(method: NotificationMethod): void; + async connect(transport: Transport): Promise { + const driver = new StreamDriver(this._dispatcher, transport, { + supportedProtocolVersions: this._supportedProtocolVersions, + debouncedNotificationMethods: this._options?.debouncedNotificationMethods, + buildEnv: (extra, base) => ({ ...base, _transportExtra: extra }) + }); + this._outbound = driver; + driver.onresponse = (r, id) => this._ownTaskManager.processInboundResponse(r, id); + driver.onclose = () => { + if (this._outbound === driver) this._outbound = undefined; + this._ownTaskManager.onClose(); + this.onclose?.(); + }; + driver.onerror = error => this.onerror?.(error); + await driver.start(); + } /** - * A method to check if a request handler is supported by the local side, for the given method to be handled. - * - * This should be implemented by subclasses. + * Closes the connection. */ - protected abstract assertRequestHandlerCapability(method: string): void; + async close(): Promise { + await this._outbound?.close(); + } - /** - * A method to check if the remote side supports task creation for the given method. - * - * Called when sending a task-augmented outbound request (only when enforceStrictCapabilities is true). - * This should be implemented by subclasses. - */ - protected abstract assertTaskCapability(method: string): void; + /** @deprecated Protocol is no longer coupled to a specific transport. Returns the underlying pipe only when connected via {@linkcode StreamDriver}. */ + get transport(): Transport | undefined { + return (this._outbound as { pipe?: Transport } | undefined)?.pipe; + } - /** - * A method to check if this side supports handling task creation for the given method. - * - * Called when receiving a task-augmented inbound request. - * This should be implemented by subclasses. - */ - protected abstract assertTaskHandlerCapability(method: string): void; + get taskManager(): TaskManager { + return this._ownTaskManager; + } /** - * Sends a request and waits for a response, resolving the result schema - * automatically from the method name. - * - * Do not use this method to emit notifications! Use {@linkcode Protocol.notification | notification()} instead. + * Sends a request and waits for a response. */ request( request: { method: M; params?: Record }, @@ -788,294 +204,44 @@ export abstract class Protocol { /** * Sends a request and waits for a response, using the provided schema for validation. - * - * This is the internal implementation used by SDK methods that need to specify - * a particular result schema (e.g., for compatibility or task-specific schemas). */ protected _requestWithSchema( request: Request, resultSchema: T, options?: RequestOptions ): Promise> { - const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; - - let onAbort: (() => void) | undefined; - let cleanupMessageId: number | undefined; - - // Send the request - return new Promise>((resolve, reject) => { - const earlyReject = (error: unknown) => { - reject(error); - }; - - if (!this._transport) { - earlyReject(new Error('Not connected')); - return; - } - - if (this._options?.enforceStrictCapabilities === true) { - try { - this.assertCapabilityForMethod(request.method as RequestMethod); - } catch (error) { - earlyReject(error); - return; - } - } - - options?.signal?.throwIfAborted(); - - const messageId = this._requestMessageId++; - cleanupMessageId = messageId; - const jsonrpcRequest: JSONRPCRequest = { - ...request, - jsonrpc: '2.0', - id: messageId - }; - - if (options?.onprogress) { - this._progressHandlers.set(messageId, options.onprogress); - jsonrpcRequest.params = { - ...request.params, - _meta: { - ...request.params?._meta, - progressToken: messageId - } - }; - } - - const cancel = (reason: unknown) => { - this._progressHandlers.delete(messageId); - - this._transport - ?.send( - { - jsonrpc: '2.0', - method: 'notifications/cancelled', - params: { - requestId: messageId, - reason: String(reason) - } - }, - { relatedRequestId, resumptionToken, onresumptiontoken } - ) - .catch(error => this._onerror(new Error(`Failed to send cancellation: ${error}`))); - - // Wrap the reason in an SdkError if it isn't already - const error = reason instanceof SdkError ? reason : new SdkError(SdkErrorCode.RequestTimeout, String(reason)); - reject(error); - }; - - this._responseHandlers.set(messageId, response => { - if (options?.signal?.aborted) { - return; - } - - if (response instanceof Error) { - return reject(response); - } - - try { - const parseResult = parseSchema(resultSchema, response.result); - if (parseResult.success) { - resolve(parseResult.data as SchemaOutput); - } else { - reject(parseResult.error); - } - } catch (error) { - reject(error); - } - }); - - onAbort = () => cancel(options?.signal?.reason); - options?.signal?.addEventListener('abort', onAbort, { once: true }); - - const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; - const timeoutHandler = () => cancel(new SdkError(SdkErrorCode.RequestTimeout, 'Request timed out', { timeout })); - - this._setupTimeout(messageId, timeout, options?.maxTotalTimeout, timeoutHandler, options?.resetTimeoutOnProgress ?? false); - - // Delegate task augmentation and routing to module (if registered) - const responseHandler = (response: JSONRPCResultResponse | Error) => { - const handler = this._responseHandlers.get(messageId); - if (handler) { - handler(response); - } else { - this._onerror(new Error(`Response handler missing for side-channeled request ${messageId}`)); - } - }; - - let outboundQueued = false; - try { - const taskResult = this._taskManager.processOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, error => { - this._progressHandlers.delete(messageId); - reject(error); - }); - if (taskResult.queued) { - outboundQueued = true; - } - } catch (error) { - this._progressHandlers.delete(messageId); - reject(error); - return; - } - - if (!outboundQueued) { - // No related task or no module - send through transport normally - this._transport.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { - this._progressHandlers.delete(messageId); - reject(error); - }); - } - }).finally(() => { - // Per-request cleanup that must run on every exit path. Consolidated - // here so new exit paths added to the promise body can't forget it. - // _progressHandlers is NOT cleaned up here: _onresponse deletes it - // conditionally (preserveProgress for task flows), and error paths - // above delete it inline since no task exists in those cases. - if (onAbort) { - options?.signal?.removeEventListener('abort', onAbort); - } - if (cleanupMessageId !== undefined) { - this._responseHandlers.delete(cleanupMessageId); - this._cleanupTimeout(cleanupMessageId); - } - }); + if (!this._outbound) { + return Promise.reject(new SdkError(SdkErrorCode.NotConnected, 'Not connected')); + } + if (this._options?.enforceStrictCapabilities === true) { + this.assertCapabilityForMethod(request.method as RequestMethod); + } + return this._ownTaskManager.sendRequest(request, resultSchema, options, this._outbound); } /** * Emits a notification, which is a one-way message that does not expect a response. */ async notification(notification: Notification, options?: NotificationOptions): Promise { - if (!this._transport) { + if (!this._outbound) { throw new SdkError(SdkErrorCode.NotConnected, 'Not connected'); } - this.assertNotificationCapability(notification.method as NotificationMethod); - - // Delegate task-related notification routing and JSONRPC building to TaskManager - const taskResult = await this._taskManager.processOutboundNotification(notification, options); - const queued = taskResult.queued; - const jsonrpcNotification = taskResult.queued ? undefined : taskResult.jsonrpcNotification; - - if (queued) { - // Don't send through transport - queued messages are delivered via tasks/result only - return; - } - - const debouncedMethods = this._options?.debouncedNotificationMethods ?? []; - // A notification can only be debounced if it's in the list AND it's "simple" - // (i.e., has no parameters and no related request ID or related task that could be lost). - const canDebounce = - debouncedMethods.includes(notification.method) && !notification.params && !options?.relatedRequestId && !options?.relatedTask; - - if (canDebounce) { - // If a notification of this type is already scheduled, do nothing. - if (this._pendingDebouncedNotifications.has(notification.method)) { - return; - } - - // Mark this notification type as pending. - this._pendingDebouncedNotifications.add(notification.method); - - // Schedule the actual send to happen in the next microtask. - // This allows all synchronous calls in the current event loop tick to be coalesced. - Promise.resolve().then(() => { - // Un-mark the notification so the next one can be scheduled. - this._pendingDebouncedNotifications.delete(notification.method); - - // SAFETY CHECK: If the connection was closed while this was pending, abort. - if (!this._transport) { - return; - } - - // Send the notification, but don't await it here to avoid blocking. - // Handle potential errors with a .catch(). - this._transport?.send(jsonrpcNotification!, options).catch(error => this._onerror(error)); - }); - - // Return immediately. - return; - } - - await this._transport.send(jsonrpcNotification!, options); + return this._ownTaskManager.sendNotification(notification, options, this._outbound); } - /** - * Registers a handler to invoke when this protocol object receives a request with the given method. - * - * Note that this will replace any previous request handler for the same method. - */ - setRequestHandler( - method: M, - handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise - ): void { - this.assertRequestHandlerCapability(method); - const schema = getRequestSchema(method); + // ─────────────────────────────────────────────────────────────────────── + // Test-compat accessors. v1 tests reach into these privates; proxy them to + // the driver so the test corpus keeps passing without rewrites. + // ─────────────────────────────────────────────────────────────────────── - this._requestHandlers.set(method, (request, ctx) => { - const parsed = schema.parse(request) as RequestTypeMap[M]; - return Promise.resolve(handler(parsed, ctx)); - }); - } - - /** - * Removes the request handler for the given method. - */ - removeRequestHandler(method: RequestMethod): void { - this._requestHandlers.delete(method); - } - - /** - * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. - */ - assertCanSetRequestHandler(method: RequestMethod): void { - if (this._requestHandlers.has(method)) { - throw new Error(`A request handler for ${method} already exists, which would be overridden`); - } + /** @internal v1 tests reach into this. */ + protected get _taskManager(): TaskManager { + return this._ownTaskManager; } - /** - * Registers a handler to invoke when this protocol object receives a notification with the given method. - * - * Note that this will replace any previous notification handler for the same method. - */ - setNotificationHandler( - method: M, - handler: (notification: NotificationTypeMap[M]) => void | Promise - ): void { - const schema = getNotificationSchema(method); - - this._notificationHandlers.set(method, notification => { - const parsed = schema.parse(notification); - return Promise.resolve(handler(parsed)); - }); - } - - /** - * Removes the notification handler for the given method. - */ - removeNotificationHandler(method: NotificationMethod): void { - this._notificationHandlers.delete(method); - } -} - -function isPlainObject(value: unknown): value is Record { - return value !== null && typeof value === 'object' && !Array.isArray(value); -} - -export function mergeCapabilities(base: ServerCapabilities, additional: Partial): ServerCapabilities; -export function mergeCapabilities(base: ClientCapabilities, additional: Partial): ClientCapabilities; -export function mergeCapabilities(base: T, additional: Partial): T { - const result: T = { ...base }; - for (const key in additional) { - const k = key as keyof T; - const addValue = additional[k]; - if (addValue === undefined) continue; - const baseValue = result[k]; - result[k] = - isPlainObject(baseValue) && isPlainObject(addValue) - ? ({ ...(baseValue as Record), ...(addValue as Record) } as T[typeof k]) - : (addValue as T[typeof k]); + /** @internal v1 tests reach into this. */ + protected get _responseHandlers(): Map void> | undefined { + return (this._outbound as unknown as { _responseHandlers?: Map void> })?._responseHandlers; } - return result; } diff --git a/packages/core/src/shared/streamDriver.ts b/packages/core/src/shared/streamDriver.ts new file mode 100644 index 000000000..0ac24ae6d --- /dev/null +++ b/packages/core/src/shared/streamDriver.ts @@ -0,0 +1,430 @@ +import { SdkError, SdkErrorCode } from '../errors/sdkErrors.js'; +import type { + JSONRPCErrorResponse, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + MessageExtraInfo, + Notification, + Progress, + ProgressNotification, + Request, + RequestId, + RequestMethod, + Result +} from '../types/index.js'; +import { + getResultSchema, + isJSONRPCErrorResponse, + isJSONRPCNotification, + isJSONRPCRequest, + isJSONRPCResultResponse, + ProtocolError, + SUPPORTED_PROTOCOL_VERSIONS +} from '../types/index.js'; +import type { AnySchema, SchemaOutput } from '../util/schema.js'; +import { parseSchema } from '../util/schema.js'; +import type { NotificationOptions, Outbound, ProgressCallback, RequestEnv, RequestOptions } from './context.js'; +import { DEFAULT_REQUEST_TIMEOUT_MSEC } from './context.js'; +import type { Dispatcher } from './dispatcher.js'; +import type { AttachOptions, Transport } from './transport.js'; + +type TimeoutInfo = { + timeoutId: ReturnType; + startTime: number; + timeout: number; + maxTotalTimeout?: number; + resetTimeoutOnProgress: boolean; + onTimeout: () => void; +}; + +export type StreamDriverOptions = { + supportedProtocolVersions?: string[]; + debouncedNotificationMethods?: string[]; + /** + * Hook to enrich the per-request {@linkcode RequestEnv} from transport-supplied + * {@linkcode MessageExtraInfo} (e.g. auth, http req). + */ + buildEnv?: (extra: MessageExtraInfo | undefined, base: RequestEnv) => RequestEnv; +}; + +/** + * Runs a {@linkcode Dispatcher} over a persistent bidirectional {@linkcode Transport} + * (stdio, WebSocket, InMemory). Owns all per-connection state: outbound request + * id correlation, timeouts, progress callbacks, cancellation, debouncing. + * + * One driver per pipe. The dispatcher it wraps may be shared. + */ +export class StreamDriver implements Outbound { + private _requestMessageId = 0; + private _responseHandlers: Map void> = new Map(); + private _progressHandlers: Map = new Map(); + private _timeoutInfo: Map = new Map(); + private _requestHandlerAbortControllers: Map = new Map(); + private _pendingDebouncedNotifications = new Set(); + private _closed = false; + private _supportedProtocolVersions: string[]; + + onclose?: () => void; + onerror?: (error: Error) => void; + /** + * Tap for every inbound response. Return `consumed: true` to claim it (suppresses the + * matched-handler dispatch / unknown-id error). Return `preserveProgress: true` to keep + * the progress handler registered after the matched handler runs. Set by the owner. + */ + onresponse?: (response: JSONRPCResponse | JSONRPCErrorResponse, messageId: number) => { consumed: boolean; preserveProgress?: boolean }; + + constructor( + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- driver is context-agnostic; subclass owns ContextT + readonly dispatcher: Dispatcher, + readonly pipe: Transport, + private _options: StreamDriverOptions = {} + ) { + this._supportedProtocolVersions = _options.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; + } + + /** {@linkcode Outbound.removeProgressHandler}. */ + removeProgressHandler(token: number): void { + this._progressHandlers.delete(token); + } + + /** + * Wires the pipe's callbacks and starts it. After this resolves, inbound + * requests are dispatched and {@linkcode StreamDriver.request | request()} works. + */ + async start(): Promise { + const prevClose = this.pipe.onclose; + this.pipe.onclose = () => { + try { + prevClose?.(); + } finally { + this._onclose(); + } + }; + + const prevError = this.pipe.onerror; + this.pipe.onerror = (error: Error) => { + prevError?.(error); + this._onerror(error); + }; + + const prevMessage = this.pipe.onmessage; + this.pipe.onmessage = (message, extra) => { + prevMessage?.(message, extra); + if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { + this._onresponse(message); + } else if (isJSONRPCRequest(message)) { + this._onrequest(message, extra); + } else if (isJSONRPCNotification(message)) { + this._onnotification(message); + } else { + this._onerror(new Error(`Unknown message type: ${JSON.stringify(message)}`)); + } + }; + + this.pipe.setSupportedProtocolVersions?.(this._supportedProtocolVersions); + await this.pipe.start(); + } + + async close(): Promise { + await this.pipe.close(); + } + + /** {@linkcode Outbound.setProtocolVersion} — delegates to the pipe. */ + setProtocolVersion(version: string): void { + this.pipe.setProtocolVersion?.(version); + } + + /** {@linkcode Outbound.sendRaw} — write a raw JSON-RPC message to the pipe. */ + async sendRaw(message: Parameters[0], options?: { relatedRequestId?: RequestId }): Promise { + await this.pipe.send(message, options); + } + + /** + * Sends a request over the pipe and resolves with the parsed result. + */ + request(req: Request, resultSchema: T, options?: RequestOptions): Promise> { + const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; + let onAbort: (() => void) | undefined; + let cleanupId: number | undefined; + + return new Promise>((resolve, reject) => { + options?.signal?.throwIfAborted(); + + const messageId = this._requestMessageId++; + cleanupId = messageId; + const jsonrpcRequest: JSONRPCRequest = { ...req, jsonrpc: '2.0', id: messageId }; + + if (options?.onprogress) { + this._progressHandlers.set(messageId, options.onprogress); + jsonrpcRequest.params = { + ...req.params, + _meta: { ...(req.params?._meta as Record | undefined), progressToken: messageId } + }; + } + + const cancel = (reason: unknown) => { + this._progressHandlers.delete(messageId); + this.pipe + .send( + { + jsonrpc: '2.0', + method: 'notifications/cancelled', + params: { requestId: messageId, reason: String(reason) } + }, + { relatedRequestId, resumptionToken, onresumptiontoken } + ) + .catch(error => this._onerror(new Error(`Failed to send cancellation: ${error}`))); + const error = reason instanceof SdkError ? reason : new SdkError(SdkErrorCode.RequestTimeout, String(reason)); + reject(error); + }; + + this._responseHandlers.set(messageId, response => { + if (options?.signal?.aborted) return; + if (response instanceof Error) return reject(response); + try { + const parsed = parseSchema(resultSchema, response.result); + if (parsed.success) resolve(parsed.data as SchemaOutput); + else reject(parsed.error); + } catch (error) { + reject(error); + } + }); + + onAbort = () => cancel(options?.signal?.reason); + options?.signal?.addEventListener('abort', onAbort, { once: true }); + + const timeout = options?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; + this._setupTimeout( + messageId, + timeout, + options?.maxTotalTimeout, + () => cancel(new SdkError(SdkErrorCode.RequestTimeout, 'Request timed out', { timeout })), + options?.resetTimeoutOnProgress ?? false + ); + + if (options?.intercept) { + const settle = (r: JSONRPCResultResponse | Error) => this._responseHandlers.get(messageId)?.(r); + const onError = (e: unknown) => { + this._progressHandlers.delete(messageId); + reject(e); + }; + if (options.intercept(jsonrpcRequest, messageId, settle, onError)) return; + } + + this.pipe.send(jsonrpcRequest, { relatedRequestId, resumptionToken, onresumptiontoken }).catch(error => { + this._progressHandlers.delete(messageId); + reject(error); + }); + }).finally(() => { + if (onAbort) options?.signal?.removeEventListener('abort', onAbort); + if (cleanupId !== undefined) { + this._responseHandlers.delete(cleanupId); + this._cleanupTimeout(cleanupId); + } + }); + } + + /** + * Sends a notification over the pipe. Supports debouncing per the constructor option. + */ + async notification(notification: Notification, options?: NotificationOptions): Promise { + if (this._closed) return; + const jsonrpc: JSONRPCNotification = { + jsonrpc: '2.0', + method: notification.method, + params: notification.params + }; + + const debounced = this._options.debouncedNotificationMethods ?? []; + const canDebounce = + debounced.includes(notification.method) && !notification.params && !options?.relatedRequestId && !options?.relatedTask; + if (canDebounce) { + if (this._pendingDebouncedNotifications.has(notification.method)) return; + this._pendingDebouncedNotifications.add(notification.method); + Promise.resolve().then(() => { + // If the entry was already removed (by _onclose), skip the send. + if (!this._pendingDebouncedNotifications.delete(notification.method)) return; + this.pipe.send(jsonrpc, options).catch(error => this._onerror(error)); + }); + return; + } + await this.pipe.send(jsonrpc, options); + } + + private _onrequest(request: JSONRPCRequest, extra?: MessageExtraInfo): void { + const abort = new AbortController(); + this._requestHandlerAbortControllers.set(request.id, abort); + + const baseEnv: RequestEnv = { + signal: abort.signal, + sessionId: this.pipe.sessionId, + authInfo: extra?.authInfo, + httpReq: extra?.request, + send: (r, opts) => + this.request(r, getResultSchema(r.method as RequestMethod), { ...opts, relatedRequestId: request.id }) as Promise + }; + const env = this._options.buildEnv ? this._options.buildEnv(extra, baseEnv) : baseEnv; + + const drain = async () => { + for await (const out of this.dispatcher.dispatch(request, env)) { + if (out.kind === 'notification') { + await this.notification({ method: out.message.method, params: out.message.params }, { relatedRequestId: request.id }); + } else { + if (abort.signal.aborted) return; + await this.pipe.send(out.message, { relatedRequestId: request.id }); + } + } + }; + drain() + .catch(error => this._onerror(new Error(`Failed to send response: ${error}`))) + .finally(() => { + if (this._requestHandlerAbortControllers.get(request.id) === abort) { + this._requestHandlerAbortControllers.delete(request.id); + } + }); + } + + private _onnotification(notification: JSONRPCNotification): void { + if (notification.method === 'notifications/cancelled') { + const requestId = (notification.params as { requestId?: RequestId } | undefined)?.requestId; + if (requestId !== undefined) + this._requestHandlerAbortControllers.get(requestId)?.abort((notification.params as { reason?: unknown })?.reason); + return; + } + if (notification.method === 'notifications/progress') { + this._onprogress(notification as unknown as ProgressNotification); + return; + } + this.dispatcher + .dispatchNotification(notification) + .catch(error => this._onerror(new Error(`Uncaught error in notification handler: ${error}`))); + } + + private _onprogress(notification: ProgressNotification): void { + const { progressToken, ...params } = notification.params; + const messageId = Number(progressToken); + const handler = this._progressHandlers.get(messageId); + if (!handler) { + this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); + return; + } + const responseHandler = this._responseHandlers.get(messageId); + const info = this._timeoutInfo.get(messageId); + if (info && responseHandler && info.resetTimeoutOnProgress) { + try { + this._resetTimeout(messageId); + } catch (error) { + this._responseHandlers.delete(messageId); + this._progressHandlers.delete(messageId); + this._cleanupTimeout(messageId); + responseHandler(error as Error); + return; + } + } + handler(params as Progress); + } + + private _onresponse(response: JSONRPCResponse | JSONRPCErrorResponse): void { + const messageId = Number(response.id); + const tap = this.onresponse?.(response, messageId); + if (tap?.consumed) return; + const handler = this._responseHandlers.get(messageId); + if (handler === undefined) { + this._onerror(new Error(`Received a response for an unknown message ID: ${JSON.stringify(response)}`)); + return; + } + this._responseHandlers.delete(messageId); + this._cleanupTimeout(messageId); + if (!tap?.preserveProgress) this._progressHandlers.delete(messageId); + if (isJSONRPCResultResponse(response)) { + handler(response); + } else { + handler(ProtocolError.fromError(response.error.code, response.error.message, response.error.data)); + } + } + + private _onclose(): void { + this._closed = true; + const responseHandlers = this._responseHandlers; + this._responseHandlers = new Map(); + this._progressHandlers.clear(); + this._pendingDebouncedNotifications.clear(); + for (const info of this._timeoutInfo.values()) clearTimeout(info.timeoutId); + this._timeoutInfo.clear(); + const aborts = this._requestHandlerAbortControllers; + this._requestHandlerAbortControllers = new Map(); + const error = new SdkError(SdkErrorCode.ConnectionClosed, 'Connection closed'); + try { + this.onclose?.(); + } finally { + for (const handler of responseHandlers.values()) handler(error); + for (const c of aborts.values()) c.abort(error); + } + } + + private _onerror(error: Error): void { + this.onerror?.(error); + } + + private _setupTimeout(id: number, timeout: number, maxTotal: number | undefined, onTimeout: () => void, reset: boolean): void { + this._timeoutInfo.set(id, { + timeoutId: setTimeout(onTimeout, timeout), + startTime: Date.now(), + timeout, + maxTotalTimeout: maxTotal, + resetTimeoutOnProgress: reset, + onTimeout + }); + } + + private _resetTimeout(id: number): boolean { + const info = this._timeoutInfo.get(id); + if (!info) return false; + const elapsed = Date.now() - info.startTime; + if (info.maxTotalTimeout && elapsed >= info.maxTotalTimeout) { + this._timeoutInfo.delete(id); + throw new SdkError(SdkErrorCode.RequestTimeout, 'Maximum total timeout exceeded', { + maxTotalTimeout: info.maxTotalTimeout, + totalElapsed: elapsed + }); + } + clearTimeout(info.timeoutId); + info.timeoutId = setTimeout(info.onTimeout, info.timeout); + return true; + } + + private _cleanupTimeout(id: number): void { + const info = this._timeoutInfo.get(id); + if (info) { + clearTimeout(info.timeoutId); + this._timeoutInfo.delete(id); + } + } +} + +/** + * Wraps a {@linkcode ChannelTransport} in a {@linkcode StreamDriver} and starts it. + * Callers (`McpServer.connect`, `Client.connect`) use this helper instead of + * importing `StreamDriver` themselves. + */ +export async function attachChannelTransport( + pipe: Transport, + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- adapter is context-agnostic + dispatcher: Dispatcher, + options?: AttachOptions +): Promise { + const driver = new StreamDriver(dispatcher, pipe, { + supportedProtocolVersions: options?.supportedProtocolVersions, + debouncedNotificationMethods: options?.debouncedNotificationMethods, + buildEnv: options?.buildEnv + }); + if (options?.onclose || options?.onerror || options?.onresponse) { + driver.onclose = options.onclose; + driver.onerror = options.onerror; + driver.onresponse = options.onresponse; + } + await driver.start(); + return driver; +} diff --git a/packages/core/src/shared/taskManager.ts b/packages/core/src/shared/taskManager.ts index d7d40c550..f3361f3c5 100644 --- a/packages/core/src/shared/taskManager.ts +++ b/packages/core/src/shared/taskManager.ts @@ -32,56 +32,27 @@ import { TaskStatusNotificationSchema } from '../types/index.js'; import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/schema.js'; -import type { BaseContext, NotificationOptions, RequestOptions } from './protocol.js'; +import type { NotificationOptions, Outbound, RequestEnv, RequestOptions } from './context.js'; +import type { Dispatcher, DispatchFn, DispatchMiddleware, DispatchOutput } from './dispatcher.js'; import type { ResponseMessage } from './responseMessage.js'; /** - * Host interface for TaskManager to call back into Protocol. @internal + * Hooks {@linkcode TaskManager.attachTo} needs from its owner. The owner is whoever + * holds the {@linkcode Outbound} (McpServer/Client/Protocol). Replaces the + * previous wider host vtable: most of what the vtable provided is reachable via + * `channel()` or via the {@linkcode Dispatcher} passed to `attachTo`. + * @internal */ -export interface TaskManagerHost { - request(request: Request, resultSchema: T, options?: RequestOptions): Promise>; - notification(notification: Notification, options?: NotificationOptions): Promise; +export interface TaskAttachHooks { + /** Current outbound channel (may be undefined before connect). */ + channel(): Outbound | undefined; + /** Surface non-fatal errors. */ reportError(error: Error): void; - removeProgressHandler(token: number): void; - registerHandler(method: string, handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise): void; - sendOnResponseStream(message: JSONRPCNotification | JSONRPCRequest, relatedRequestId: RequestId): Promise; enforceStrictCapabilities: boolean; assertTaskCapability(method: string): void; assertTaskHandlerCapability(method: string): void; } -/** - * Context provided to TaskManager when processing an inbound request. - * @internal - */ -export interface InboundContext { - sessionId?: string; - sendNotification: (notification: Notification, options?: NotificationOptions) => Promise; - sendRequest: (request: Request, resultSchema: U, options?: RequestOptions) => Promise>; -} - -/** - * Result returned by TaskManager after processing an inbound request. - * @internal - */ -export interface InboundResult { - taskContext?: BaseContext['task']; - sendNotification: (notification: Notification) => Promise; - sendRequest: ( - request: Request, - resultSchema: U, - options?: Omit - ) => Promise>; - routeResponse: (message: JSONRPCResponse | JSONRPCErrorResponse) => Promise; - hasTaskCreationParams: boolean; - /** - * Optional validation to run inside the async handler chain (before the request handler). - * Throwing here produces a proper JSON-RPC error response, matching the behavior of - * capability checks on main. - */ - validateInbound?: () => void; -} - /** * Options that can be given per request. */ @@ -152,6 +123,13 @@ export type TaskContext = { id?: string; store: RequestTaskStore; requestedTtl?: number; + /** + * Yield a queued task message on the *current* dispatch's response stream. + * Set by the dispatch middleware; used by the `tasks/result` handler so queued + * messages flow on the same stream as that handler's terminal response. + * @internal + */ + sendOnResponseStream?: (message: JSONRPCNotification | JSONRPCRequest) => void; }; export type TaskManagerOptions = { @@ -195,10 +173,12 @@ export function extractTaskManagerOptions(tasksCapability: TaskManagerOptions | export class TaskManager { private _taskStore?: TaskStore; private _taskMessageQueue?: TaskMessageQueue; + /** @internal id allocator for dispatch-middleware-queued requests (independent of any transport's id space). */ + _dispatchOutboundId = 0; private _taskProgressTokens: Map = new Map(); private _requestResolvers: Map void> = new Map(); private _options: TaskManagerOptions; - private _host?: TaskManagerHost; + private _hooks?: TaskAttachHooks; constructor(options: TaskManagerOptions) { this._options = options; @@ -206,46 +186,184 @@ export class TaskManager { this._taskMessageQueue = options.taskMessageQueue; } - bind(host: TaskManagerHost): void { - this._host = host; + /** + * Attaches this manager to a {@linkcode Dispatcher}: registers the dispatch middleware + * via `d.use()`, installs `tasks/*` request handlers when a store is configured, and + * stores the {@linkcode TaskAttachHooks}. Outbound-side hooks (request/notification + * augmentation, response correlation, close) are called directly by the channel adapter + * (see {@linkcode StreamDriver}), which receives this manager via {@linkcode AttachOptions}. + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any -- attach is context-agnostic + attachTo(d: Dispatcher, hooks: TaskAttachHooks): void { + this._hooks = hooks; + d.use(this.dispatchMiddleware); if (this._taskStore) { - host.registerHandler('tasks/get', async (request, ctx) => { + d.setRawRequestHandler('tasks/get', async (request, ctx) => { const params = request.params as { taskId: string }; - const task = await this.handleGetTask(params.taskId, ctx.sessionId); - // Per spec: tasks/get responses SHALL NOT include related-task metadata - // as the taskId parameter is the source of truth - return { - ...task - } as Result; + return (await this.handleGetTask(params.taskId, ctx.sessionId)) as Result; }); - host.registerHandler('tasks/result', async (request, ctx) => { + d.setRawRequestHandler('tasks/result', async (request, ctx) => { const params = request.params as { taskId: string }; - return await this.handleGetTaskPayload(params.taskId, ctx.sessionId, ctx.mcpReq.signal, async message => { - // Send the message on the response stream by passing the relatedRequestId - // This tells the transport to write the message to the tasks/result response stream - await host.sendOnResponseStream(message, ctx.mcpReq.id); + return this.handleGetTaskPayload(params.taskId, ctx.sessionId, ctx.mcpReq.signal, async message => { + const sink = + ctx.task?.sendOnResponseStream ?? + ((m: JSONRPCNotification | JSONRPCRequest) => { + void hooks.channel()?.sendRaw?.(m, { relatedRequestId: ctx.mcpReq.id }); + }); + sink(message); }); }); - host.registerHandler('tasks/list', async (request, ctx) => { + d.setRawRequestHandler('tasks/list', async (request, ctx) => { const params = request.params as { cursor?: string } | undefined; return (await this.handleListTasks(params?.cursor, ctx.sessionId)) as Result; }); - host.registerHandler('tasks/cancel', async (request, ctx) => { + d.setRawRequestHandler('tasks/cancel', async (request, ctx) => { const params = request.params as { taskId: string }; - return await this.handleCancelTask(params.taskId, ctx.sessionId); + return this.handleCancelTask(params.taskId, ctx.sessionId); }); } } - protected get _requireHost(): TaskManagerHost { - if (!this._host) { - throw new ProtocolError(ProtocolErrorCode.InternalError, 'TaskManager is not bound to a Protocol host — call bind() first'); + protected get _requireHooks(): TaskAttachHooks { + if (!this._hooks) { + throw new ProtocolError(ProtocolErrorCode.InternalError, 'TaskManager is not attached to a Dispatcher — call attachTo() first'); } - return this._host; + return this._hooks; + } + + /** + * The {@linkcode DispatchMiddleware}: detects task-augmented inbound requests, builds + * `env.task` (with the request-scoped store + side-channel sink), wraps `env.send` to + * carry `relatedTask`, intercepts yielded notifications/response for queueing. + */ + get dispatchMiddleware(): DispatchMiddleware { + // eslint-disable-next-line @typescript-eslint/no-this-alias, unicorn/no-this-assignment + const tm = this; + return next => + async function* (request, env = {}) { + const taskInfo = tm.extractInboundTaskContext(request, env.sessionId); + const relatedTaskId = taskInfo?.relatedTaskId; + const hasTaskCreationParams = !!taskInfo?.taskCreationParams; + + if (hasTaskCreationParams) { + try { + tm._requireHooks.assertTaskHandlerCapability(request.method); + } catch (error) { + const e = error as { code?: number; message?: string; data?: unknown }; + yield { + kind: 'response', + message: { + jsonrpc: '2.0', + id: request.id, + error: { + code: Number.isSafeInteger(e?.code) ? (e.code as number) : ProtocolErrorCode.InternalError, + message: e?.message ?? 'Internal error', + ...(e?.data !== undefined && { data: e.data }) + } + } + }; + return; + } + } + + // Side-channel sink so `tasks/result` (and any handler) can yield arbitrary + // queued messages on this dispatch's stream. Drained interleaved with `next()`. + const sideQueue: (JSONRPCNotification | JSONRPCRequest)[] = []; + let wake: (() => void) | undefined; + const sendOnResponseStream = (m: JSONRPCNotification | JSONRPCRequest) => { + sideQueue.push(m); + wake?.(); + }; + const drain = function* (): Generator { + while (sideQueue.length > 0) { + const m = sideQueue.shift()!; + yield { kind: 'notification', message: m as JSONRPCNotification }; + } + }; + + const wrappedSend: NonNullable = async (r, opts) => { + const relatedTask = relatedTaskId && !opts?.relatedTask ? { taskId: relatedTaskId } : opts?.relatedTask; + const effectiveTaskId = relatedTask?.taskId; + if (effectiveTaskId && taskInfo?.taskContext?.store) { + await taskInfo.taskContext.store.updateTaskStatus(effectiveTaskId, 'input_required'); + } + if (effectiveTaskId) { + // Queue to the task message queue (delivered via tasks/result), don't hit env.send. + return new Promise((resolve, reject) => { + const messageId = tm._dispatchOutboundId++; + const wire: JSONRPCRequest = { jsonrpc: '2.0', id: messageId, method: r.method, params: r.params }; + const settle = (resp: { result: Result } | Error) => + resp instanceof Error ? reject(resp) : resolve(resp.result); + const { queued } = tm.processOutboundRequest(wire, { ...opts, relatedTask }, messageId, settle, reject); + if (queued) return; + if (env.send) { + env.send(r, { ...opts, relatedTask }).then(result => settle({ result }), reject); + } else { + reject(new ProtocolError(ProtocolErrorCode.InternalError, 'env.send unavailable')); + } + }); + } + if (env.send) return env.send(r, { ...opts, relatedTask }); + throw new ProtocolError(ProtocolErrorCode.InternalError, 'env.send unavailable'); + }; + + const taskCtx: TaskContext | undefined = taskInfo?.taskContext + ? { ...taskInfo.taskContext, sendOnResponseStream } + : tm._taskStore + ? { store: tm.createRequestTaskStore(request, env.sessionId), sendOnResponseStream } + : undefined; + + const taskEnv: RequestEnv = { + ...env, + task: taskCtx ?? env.task, + send: relatedTaskId || taskInfo?.taskContext ? wrappedSend : env.send + }; + + const inner = next(request, taskEnv); + let pending: Promise> | undefined; + while (true) { + yield* drain(); + pending ??= inner.next(); + const wakeP = new Promise<'side'>(resolve => { + wake = () => resolve('side'); + }); + if (sideQueue.length > 0) { + wake = undefined; + continue; + } + const r = await Promise.race([pending, wakeP]); + wake = undefined; + if (r === 'side') continue; + pending = undefined; + if (r.done) break; + const out = r.value; + if (out.kind === 'response') { + const routed = relatedTaskId ? await tm.routeResponse(relatedTaskId, out.message, env.sessionId) : false; + if (!routed) { + yield* drain(); + yield out; + } + } else if (relatedTaskId === undefined) { + yield out; + } else { + // Handler-emitted notifications inside a related-task request are queued + // (not yielded) so they deliver via tasks/result, avoiding duplicate + // delivery on bidirectional transports. + const result = await tm.processOutboundNotification( + { method: out.message.method, params: out.message.params }, + { relatedTask: { taskId: relatedTaskId } } + ); + if (!result.queued && result.jsonrpcNotification) { + yield { kind: 'notification', message: result.jsonrpcNotification }; + } + } + } + yield* drain(); + } as DispatchFn; } get taskStore(): TaskStore | undefined { @@ -263,18 +381,23 @@ export class TaskManager { return this._taskMessageQueue; } + private _outboundRequest(req: Request, schema: T, opts?: RequestOptions): Promise> { + const ch = this._requireHooks.channel(); + if (!ch) throw new ProtocolError(ProtocolErrorCode.InternalError, 'Not connected'); + return this.sendRequest(req, schema, opts, ch); + } + // -- Public API (client-facing) -- async *requestStream( request: Request, resultSchema: T, options?: RequestOptions ): AsyncGenerator>, void, void> { - const host = this._requireHost; const { task } = options ?? {}; if (!task) { try { - const result = await host.request(request, resultSchema, options); + const result = await this._outboundRequest(request, resultSchema, options); yield { type: 'result', result }; } catch (error) { yield { @@ -287,7 +410,7 @@ export class TaskManager { let taskId: string | undefined; try { - const createResult = await host.request(request, CreateTaskResultSchema, options); + const createResult = await this._outboundRequest(request, CreateTaskResultSchema, options); if (createResult.task) { taskId = createResult.task.taskId; @@ -341,7 +464,7 @@ export class TaskManager { } async getTask(params: GetTaskRequest['params'], options?: RequestOptions): Promise { - return this._requireHost.request({ method: 'tasks/get', params }, GetTaskResultSchema, options); + return this._outboundRequest({ method: 'tasks/get', params }, GetTaskResultSchema, options); } async getTaskResult( @@ -349,15 +472,15 @@ export class TaskManager { resultSchema: T, options?: RequestOptions ): Promise> { - return this._requireHost.request({ method: 'tasks/result', params }, resultSchema, options); + return this._outboundRequest({ method: 'tasks/result', params }, resultSchema, options); } async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { - return this._requireHost.request({ method: 'tasks/list', params }, ListTasksResultSchema, options); + return this._outboundRequest({ method: 'tasks/list', params }, ListTasksResultSchema, options); } async cancelTask(params: { taskId: string }, options?: RequestOptions): Promise> { - return this._requireHost.request({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); + return this._outboundRequest({ method: 'tasks/cancel', params }, CancelTaskResultSchema, options); } // -- Handler bodies (delegated from Protocol's registered handlers) -- @@ -395,7 +518,7 @@ export class TaskManager { } } else { const messageType = queuedMessage.type === 'response' ? 'Response' : 'Error'; - this._host?.reportError(new Error(`${messageType} handler missing for request ${requestId}`)); + this._hooks?.reportError(new Error(`${messageType} handler missing for request ${requestId}`)); } continue; } @@ -553,36 +676,6 @@ export class TaskManager { }; } - private wrapSendNotification( - relatedTaskId: string, - originalSendNotification: (notification: Notification, options?: NotificationOptions) => Promise - ): (notification: Notification) => Promise { - return async (notification: Notification) => { - const notificationOptions: NotificationOptions = { relatedTask: { taskId: relatedTaskId } }; - await originalSendNotification(notification, notificationOptions); - }; - } - - private wrapSendRequest( - relatedTaskId: string, - taskStore: RequestTaskStore | undefined, - originalSendRequest: (request: Request, resultSchema: V, options?: RequestOptions) => Promise> - ): (request: Request, resultSchema: V, options?: TaskRequestOptions) => Promise> { - return async (request: Request, resultSchema: V, options?: TaskRequestOptions) => { - const requestOptions: RequestOptions = { ...options }; - if (relatedTaskId && !requestOptions.relatedTask) { - requestOptions.relatedTask = { taskId: relatedTaskId }; - } - - const effectiveTaskId = requestOptions.relatedTask?.taskId ?? relatedTaskId; - if (effectiveTaskId && taskStore) { - await taskStore.updateTaskStatus(effectiveTaskId, 'input_required'); - } - - return await originalSendRequest(request, resultSchema, requestOptions); - }; - } - private handleResponse(response: JSONRPCResponse | JSONRPCErrorResponse): boolean { const messageId = Number(response.id); const resolver = this._requestResolvers.get(messageId); @@ -656,7 +749,7 @@ export class TaskManager { private createRequestTaskStore(request?: JSONRPCRequest, sessionId?: string): RequestTaskStore { const taskStore = this._requireTaskStore; - const host = this._host; + const hooks = this._hooks; return { createTask: async taskParams => { @@ -676,7 +769,7 @@ export class TaskManager { method: 'notifications/tasks/status', params: task }); - await host?.notification(notification as Notification); + await hooks?.channel()?.notification(notification as Notification); if (isTerminal(task.status)) { this._cleanupTaskProgressHandler(taskId); } @@ -701,7 +794,7 @@ export class TaskManager { method: 'notifications/tasks/status', params: updatedTask }); - await host?.notification(notification as Notification); + await hooks?.channel()?.notification(notification as Notification); if (isTerminal(updatedTask.status)) { this._cleanupTaskProgressHandler(taskId); } @@ -711,39 +804,42 @@ export class TaskManager { }; } - // -- Lifecycle methods (called by Protocol directly) -- - - processInboundRequest(request: JSONRPCRequest, ctx: InboundContext): InboundResult { - const taskInfo = this.extractInboundTaskContext(request, ctx.sessionId); - const relatedTaskId = taskInfo?.relatedTaskId; - - const sendNotification = relatedTaskId - ? this.wrapSendNotification(relatedTaskId, ctx.sendNotification) - : (notification: Notification) => ctx.sendNotification(notification); + // -- Outbound helpers (called by McpServer/Client/Protocol before delegating to Outbound) -- - const sendRequest = relatedTaskId - ? this.wrapSendRequest(relatedTaskId, taskInfo?.taskContext?.store, ctx.sendRequest) - : taskInfo?.taskContext - ? this.wrapSendRequest('', taskInfo.taskContext.store, ctx.sendRequest) - : ctx.sendRequest; - - const hasTaskCreationParams = !!taskInfo?.taskCreationParams; + /** + * Task-aware request send: routes through {@linkcode RequestOptions.intercept} so the + * channel adapter builds the wire (id/progressToken/handlers) and TaskManager decides + * whether to queue it. Use this where instance-level outbound requests are made + * (Protocol/McpServer/Client), so the channel adapter stays task-agnostic. + */ + sendRequest( + request: Request, + resultSchema: T, + options: RequestOptions | undefined, + outbound: Outbound + ): Promise> { + if (!options?.relatedTask && !options?.task) { + return outbound.request(request, resultSchema, options); + } + return outbound.request(request, resultSchema, { + ...options, + intercept: (wire, messageId, settle, onError) => this.processOutboundRequest(wire, options, messageId, settle, onError).queued + }); + } - return { - taskContext: taskInfo?.taskContext, - sendNotification, - sendRequest, - routeResponse: async (message: JSONRPCResponse | JSONRPCErrorResponse) => { - if (relatedTaskId) { - return this.routeResponse(relatedTaskId, message, ctx.sessionId); - } - return false; - }, - hasTaskCreationParams, - // Deferred validation: runs inside the async handler chain so errors - // produce proper JSON-RPC error responses (matching main's behavior). - validateInbound: hasTaskCreationParams ? () => this._requireHost.assertTaskHandlerCapability(request.method) : undefined - }; + /** + * Task-aware notification send: queues when `options.relatedTask` is set, otherwise + * delegates to `outbound.notification()` with related-task metadata attached. + */ + async sendNotification(notification: Notification, options: NotificationOptions | undefined, outbound: Outbound): Promise { + const result = await this.processOutboundNotification(notification, options); + if (result.queued) return; + await outbound.notification( + result.jsonrpcNotification + ? { method: result.jsonrpcNotification.method, params: result.jsonrpcNotification.params } + : notification, + options + ); } processOutboundRequest( @@ -753,9 +849,8 @@ export class TaskManager { responseHandler: (response: JSONRPCResultResponse | Error) => void, onError: (error: unknown) => void ): { queued: boolean } { - // Check task capability when sending a task-augmented request (matches main's enforceStrictCapabilities gate) - if (this._requireHost.enforceStrictCapabilities && options?.task) { - this._requireHost.assertTaskCapability(jsonrpcRequest.method); + if (this._requireHooks.enforceStrictCapabilities && options?.task) { + this._requireHooks.assertTaskCapability(jsonrpcRequest.method); } const queued = this.prepareOutboundRequest(jsonrpcRequest, options, messageId, responseHandler, onError); @@ -824,7 +919,7 @@ export class TaskManager { resolver(new ProtocolError(ProtocolErrorCode.InternalError, 'Task cancelled or completed')); this._requestResolvers.delete(requestId); } else { - this._host?.reportError(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); + this._hooks?.reportError(new Error(`Resolver missing for request ${requestId} during task ${taskId} cleanup`)); } } } @@ -854,7 +949,7 @@ export class TaskManager { private _cleanupTaskProgressHandler(taskId: string): void { const progressToken = this._taskProgressTokens.get(taskId); if (progressToken !== undefined) { - this._host?.removeProgressHandler(progressToken); + this._hooks?.channel()?.removeProgressHandler?.(progressToken); this._taskProgressTokens.delete(taskId); } } @@ -862,32 +957,44 @@ export class TaskManager { /** * No-op TaskManager used when tasks capability is not configured. - * Provides passthrough implementations for the hot paths, avoiding - * unnecessary task extraction logic on every request. + * Its middleware getters return identity / no-op so registering it costs nothing. */ export class NullTaskManager extends TaskManager { constructor() { super({}); } - override processInboundRequest(request: JSONRPCRequest, ctx: InboundContext): InboundResult { - const hasTaskCreationParams = isTaskAugmentedRequestParams(request.params) && !!request.params.task; - return { - taskContext: undefined, - sendNotification: (notification: Notification) => ctx.sendNotification(notification), - sendRequest: ctx.sendRequest, - routeResponse: async () => false, - hasTaskCreationParams, - validateInbound: hasTaskCreationParams ? () => this._requireHost.assertTaskHandlerCapability(request.method) : undefined - }; + override get dispatchMiddleware(): DispatchMiddleware { + // No store → identity middleware. Only validate task-creation capability so the + // "client sent params.task but server has no tasks capability" error path matches. + // eslint-disable-next-line @typescript-eslint/no-this-alias, unicorn/no-this-assignment + const tm = this; + return next => + async function* (req, env) { + if (isTaskAugmentedRequestParams(req.params) && req.params.task) { + try { + tm._requireHooks.assertTaskHandlerCapability(req.method); + } catch (error) { + const e = error as { code?: number; message?: string; data?: unknown }; + yield { + kind: 'response', + message: { + jsonrpc: '2.0', + id: req.id, + error: { + code: Number.isSafeInteger(e?.code) ? (e.code as number) : ProtocolErrorCode.InternalError, + message: e?.message ?? 'Internal error', + ...(e?.data !== undefined && { data: e.data }) + } + } + }; + return; + } + } + yield* next(req, env); + } as DispatchFn; } - // processOutboundRequest is inherited - it handles task/relatedTask augmentation - // and only queues if relatedTask is set (which won't happen without a task store) - - // processInboundResponse is inherited - it checks _requestResolvers (empty for NullTaskManager) - // and _taskProgressTokens (empty for NullTaskManager) - override async processOutboundNotification( notification: Notification, _options?: NotificationOptions diff --git a/packages/core/src/shared/transport.ts b/packages/core/src/shared/transport.ts index c606e2e3b..4cca92e49 100644 --- a/packages/core/src/shared/transport.ts +++ b/packages/core/src/shared/transport.ts @@ -1,4 +1,13 @@ -import type { JSONRPCMessage, MessageExtraInfo, RequestId } from '../types/index.js'; +import type { + JSONRPCErrorResponse, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResultResponse, + MessageExtraInfo, + RequestId +} from '../types/index.js'; +import type { RequestEnv } from './context.js'; export type FetchLike = (url: string | URL, init?: RequestInit) => Promise; @@ -69,9 +78,19 @@ export type TransportSendOptions = { onresumptiontoken?: ((token: string) => void) | undefined; }; /** - * Describes the minimal contract for an MCP transport that a client or server can communicate over. + * Describes the minimal contract for a persistent, bidirectional MCP message channel + * (stdio, WebSocket, in-memory). The SDK wraps this in a {@linkcode StreamDriver} to + * do request/response correlation. + * + * For request/response-shaped transports (Streamable HTTP), see {@linkcode RequestTransport}. */ -export interface Transport { +export interface ChannelTransport { + /** + * Explicit shape brand. Optional (defaults to `'channel'`) so existing + * `Transport` implementations don't need to declare it. + */ + readonly kind?: 'channel'; + /** * Starts processing messages on the transport, including any connection steps that might need to be taken. * @@ -132,3 +151,80 @@ export interface Transport { */ setSupportedProtocolVersions?: ((versions: string[]) => void) | undefined; } + +/** @deprecated Use {@linkcode ChannelTransport}. Renamed for clarity alongside {@linkcode RequestTransport}; kept as an alias. */ +export type Transport = ChannelTransport; + +/** + * Options McpServer passes when wiring a {@linkcode ChannelTransport} via {@linkcode attachChannelTransport}. + * @internal + */ +export type AttachOptions = { + supportedProtocolVersions?: string[]; + debouncedNotificationMethods?: string[]; + buildEnv?: (extra: MessageExtraInfo | undefined, base: RequestEnv) => RequestEnv; + onclose?: () => void; + onerror?: (error: Error) => void; + /** Tap for every inbound response. See {@linkcode StreamDriver.onresponse}. */ + onresponse?: ( + response: JSONRPCResultResponse | JSONRPCErrorResponse, + messageId: number + ) => { consumed: boolean; preserveProgress?: boolean }; +}; + +/** + * A request/response-shaped server transport (e.g. Streamable HTTP). Unlike + * {@linkcode ChannelTransport}, there is no persistent pipe: the transport receives + * one HTTP request at a time and calls {@linkcode onrequest} for each, streaming the + * yielded messages back as the HTTP response. + * + * The `on*` callback slots are set by `McpServer.connect()`; the transport calls them + * per inbound message. The transport itself never imports or references a `Dispatcher`. + */ +export interface RequestTransport { + /** Explicit shape brand. Required so {@linkcode isRequestTransport} can discriminate without duck-typing. */ + readonly kind: 'request'; + + /** + * Callback slot for inbound JSON-RPC requests. Set by `McpServer.connect()`. + * The transport calls this per request and writes the yielded messages + * (notifications + one terminal response) to the HTTP response stream. + */ + onrequest?: ((req: JSONRPCRequest, env?: RequestEnv) => AsyncIterable) | undefined; + + /** Callback slot for inbound notifications (e.g. `notifications/initialized`). */ + onnotification?: (n: JSONRPCNotification) => void | Promise; + + /** + * Callback slot for inbound JSON-RPC responses (a client POSTing back the answer to + * a server-initiated request). Returns `true` if the response was claimed. + */ + onresponse?: (r: JSONRPCResultResponse | JSONRPCErrorResponse) => boolean; + + /** Aborts in-flight handlers and releases resources (open SSE streams, session map). */ + close(): Promise; + + /** + * 2025-11 back-compat: write an unsolicited notification to the session's standalone + * GET subscription stream. + */ + notify?(n: JSONRPCNotification): Promise; + + /** + * 2025-11 back-compat: send an unsolicited server→client request via the standalone + * GET stream and await the client's POSTed-back response. + */ + request?(r: JSONRPCRequest): Promise; + + /** Callback for when the transport is closed for any reason. */ + onclose?: (() => void) | undefined; + /** Callback for transport-level errors. */ + onerror?: ((error: Error) => void) | undefined; + /** Session id (single-session compat mode). */ + sessionId?: string | undefined; +} + +/** Type guard distinguishing {@linkcode RequestTransport} from {@linkcode ChannelTransport}. */ +export function isRequestTransport(t: ChannelTransport | RequestTransport): t is RequestTransport { + return (t as RequestTransport).kind === 'request'; +} diff --git a/packages/core/src/util/compatSchema.ts b/packages/core/src/util/compatSchema.ts new file mode 100644 index 000000000..63956c97b --- /dev/null +++ b/packages/core/src/util/compatSchema.ts @@ -0,0 +1,41 @@ +/** + * Helpers for the Zod-schema form of `setRequestHandler` / `setNotificationHandler`. + * + * v1 accepted a Zod object whose `.shape.method` is `z.literal('')`. + * v2 also accepts the method string directly. These helpers detect the schema + * form and extract the literal so the dispatcher can route to the correct path. + * + * @internal + */ + +/** + * Minimal structural type for a Zod object schema. The `method` literal is + * checked at runtime by `extractMethodLiteral`; the type-level constraint + * is intentionally loose because zod v4's `ZodLiteral` doesn't surface `.value` + * in its declared type (only at runtime). + */ +export interface ZodLikeRequestSchema { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + shape: any; + parse(input: unknown): unknown; +} + +/** True if `arg` looks like a Zod object schema (has `.shape` and `.parse`). */ +export function isZodLikeSchema(arg: unknown): arg is ZodLikeRequestSchema { + return typeof arg === 'object' && arg !== null && 'shape' in arg && typeof (arg as { parse?: unknown }).parse === 'function'; +} + +/** + * Extracts the string value from a Zod-like schema's `shape.method` literal. + * Throws if no string `method` literal is present. + */ +export function extractMethodLiteral(schema: ZodLikeRequestSchema): string { + const methodField = (schema.shape as Record | undefined)?.method as + | { value?: unknown; def?: { values?: unknown[] } } + | undefined; + const value = methodField?.value ?? methodField?.def?.values?.[0]; + if (typeof value !== 'string') { + throw new TypeError('Schema passed to setRequestHandler/setNotificationHandler is missing a string `method` literal'); + } + return value; +} diff --git a/packages/core/src/util/standardSchema.ts b/packages/core/src/util/standardSchema.ts index 9817dc39a..7f0d17276 100644 --- a/packages/core/src/util/standardSchema.ts +++ b/packages/core/src/util/standardSchema.ts @@ -6,6 +6,9 @@ /* eslint-disable @typescript-eslint/no-namespace */ +import type { ZodType as zType } from 'zod/v4'; +import { toJSONSchema as zToJSONSchema } from 'zod/v4'; + // Standard Schema interfaces — vendored from https://standardschema.dev (spec v1, Jan 2025) export interface StandardTypedV1 { @@ -149,7 +152,16 @@ export function isStandardSchemaWithJSON(schema: unknown): schema is StandardSch * since that cannot satisfy the MCP spec. */ export function standardSchemaToJsonSchema(schema: StandardJSONSchemaV1, io: 'input' | 'output' = 'input'): Record { - const result = schema['~standard'].jsonSchema[io]({ target: 'draft-2020-12' }); + // For zod schemas, use the package-level converter which handles cross-instance + // children (consumer's zod ≠ SDK's zod). The schema-local `~standard.jsonSchema` + // is constructed with empty processors and throws on cross-instance children. + let result: Record; + if ('_zod' in schema) { + result = zToJSONSchema(schema as unknown as zType, { io }) as Record; + delete result.$schema; + } else { + result = schema['~standard'].jsonSchema[io]({ target: 'draft-2020-12' }); + } if (result.type !== undefined && result.type !== 'object') { throw new Error( `MCP tool and prompt schemas must describe objects (got type: ${JSON.stringify(result.type)}). ` + diff --git a/packages/core/test/shared/dispatcher.test.ts b/packages/core/test/shared/dispatcher.test.ts new file mode 100644 index 000000000..86ee1f5a4 --- /dev/null +++ b/packages/core/test/shared/dispatcher.test.ts @@ -0,0 +1,214 @@ +import { describe, expect, test } from 'vitest'; +import { z } from 'zod/v4'; + +import { SdkError } from '../../src/errors/sdkErrors.js'; +import type { DispatchOutput } from '../../src/shared/dispatcher.js'; +import { Dispatcher } from '../../src/shared/dispatcher.js'; +import type { JSONRPCErrorResponse, JSONRPCRequest, JSONRPCResultResponse, Result } from '../../src/types/index.js'; +import { ProtocolError, ProtocolErrorCode } from '../../src/types/index.js'; + +const req = (method: string, params?: Record, id = 1): JSONRPCRequest => ({ jsonrpc: '2.0', id, method, params }); + +async function collect(gen: AsyncIterable): Promise { + const out: DispatchOutput[] = []; + for await (const o of gen) out.push(o); + return out; +} + +describe('Dispatcher', () => { + test('dispatch yields a single response for a registered handler', async () => { + const d = new Dispatcher(); + d.setRequestHandler('ping', async () => ({})); + const out = await collect(d.dispatch(req('ping'))); + expect(out).toHaveLength(1); + expect(out[0]!.kind).toBe('response'); + expect((out[0]!.message as JSONRPCResultResponse).result).toEqual({}); + }); + + test('dispatch yields MethodNotFound for an unregistered method', async () => { + const d = new Dispatcher(); + const out = await collect(d.dispatch(req('tools/list'))); + expect(out).toHaveLength(1); + const msg = out[0]!.message as JSONRPCErrorResponse; + expect(msg.error.code).toBe(ProtocolErrorCode.MethodNotFound); + }); + + test('handler throw is wrapped as InternalError', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('boom', async () => { + throw new Error('kaboom'); + }); + const out = await collect(d.dispatch(req('boom'))); + const msg = out[0]!.message as JSONRPCErrorResponse; + expect(msg.error.code).toBe(ProtocolErrorCode.InternalError); + expect(msg.error.message).toBe('kaboom'); + }); + + test('handler throwing ProtocolError preserves code and data', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('boom', async () => { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'bad', { hint: 'x' }); + }); + const out = await collect(d.dispatch(req('boom'))); + const msg = out[0]!.message as JSONRPCErrorResponse; + expect(msg.error.code).toBe(ProtocolErrorCode.InvalidParams); + expect(msg.error.data).toEqual({ hint: 'x' }); + }); + + test('ctx.mcpReq.notify yields notifications before the final response', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('work', async (_r, ctx) => { + await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 1, progress: 0.5 } }); + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: 'hi' } }); + return { ok: true } as Result; + }); + const out = await collect(d.dispatch(req('work'))); + expect(out.map(o => o.kind)).toEqual(['notification', 'notification', 'response']); + expect((out[0]!.message as any).params.progress).toBe(0.5); + expect((out[2]!.message as JSONRPCResultResponse).result).toEqual({ ok: true }); + }); + + test('notifications interleave with async handler work', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('work', async (_r, ctx) => { + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: '1' } }); + await new Promise(r => setTimeout(r, 1)); + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: '2' } }); + return {} as Result; + }); + const seen: string[] = []; + for await (const o of d.dispatch(req('work'))) { + seen.push(o.kind === 'notification' ? `n:${(o.message.params as any).data}` : 'response'); + } + expect(seen).toEqual(['n:1', 'n:2', 'response']); + }); + + test('ctx.mcpReq.send throws by default with no env.send', async () => { + const d = new Dispatcher(); + let caught: unknown; + d.setRawRequestHandler('elicit', async (_r, ctx) => { + try { + await ctx.mcpReq.send({ method: 'elicitation/create', params: {} }); + } catch (e) { + caught = e; + } + return {} as Result; + }); + await collect(d.dispatch(req('elicit'))); + expect(caught).toBeInstanceOf(SdkError); + expect((caught as Error).message).toMatch(/no peer channel/); + }); + + test('ctx.mcpReq.send delegates to env.send when provided', async () => { + const d = new Dispatcher(); + let sent: unknown; + d.setRawRequestHandler('ask', async (_r, ctx) => { + const r = await ctx.mcpReq.send({ method: 'ping' }); + return { got: r } as Result; + }); + const out = await collect( + d.dispatch(req('ask'), { + send: async r => { + sent = r; + return { pong: true } as Result; + } + }) + ); + expect(sent).toEqual({ method: 'ping' }); + expect((out[0]!.message as JSONRPCResultResponse).result).toEqual({ got: { pong: true } }); + }); + + test('env.signal abort yields a cancelled error response', async () => { + const d = new Dispatcher(); + const ac = new AbortController(); + d.setRawRequestHandler('slow', async (_r, ctx) => { + if (ctx.mcpReq.signal.aborted) return {} as Result; + await new Promise(resolve => ctx.mcpReq.signal.addEventListener('abort', () => resolve(), { once: true })); + return {} as Result; + }); + const gen = d.dispatch(req('slow'), { signal: ac.signal }); + const p = collect(gen); + await Promise.resolve(); + ac.abort('stop'); + const out = await p; + const msg = out[out.length - 1]!.message as JSONRPCErrorResponse; + expect(msg.error.message).toBe('Request cancelled'); + }); + + test('env values surface on context', async () => { + const d = new Dispatcher(); + let seen: any; + d.setRawRequestHandler('echo', async (_r, ctx) => { + seen = { sessionId: ctx.sessionId, auth: ctx.http?.authInfo }; + return {} as Result; + }); + await collect(d.dispatch(req('echo'), { sessionId: 's1', authInfo: { token: 't', clientId: 'c', scopes: [] } })); + expect(seen.sessionId).toBe('s1'); + expect(seen.auth.token).toBe('t'); + }); + + test('dispatchNotification routes to handler and ignores unknown', async () => { + const d = new Dispatcher(); + let got: unknown; + d.setNotificationHandler('notifications/initialized', n => { + got = n.method; + }); + await d.dispatchNotification({ jsonrpc: '2.0', method: 'notifications/initialized' }); + expect(got).toBe('notifications/initialized'); + await expect(d.dispatchNotification({ jsonrpc: '2.0', method: 'unknown/thing' } as any)).resolves.toBeUndefined(); + }); + + test('fallbackRequestHandler is used when no specific handler matches', async () => { + const d = new Dispatcher(); + d.fallbackRequestHandler = async r => ({ echoed: r.method }) as Result; + const out = await collect(d.dispatch(req('whatever/method'))); + expect((out[0]!.message as JSONRPCResultResponse).result).toEqual({ echoed: 'whatever/method' }); + }); + + test('assertCanSetRequestHandler throws on collision', () => { + const d = new Dispatcher(); + d.setRequestHandler('ping', async () => ({})); + expect(() => d.assertCanSetRequestHandler('ping')).toThrow(/already exists/); + }); + + test('setRequestHandler parses request via schema', async () => { + const d = new Dispatcher(); + let parsed: unknown; + d.setRequestHandler('ping', r => { + parsed = r; + return {}; + }); + await collect(d.dispatch(req('ping'))); + expect(parsed).toMatchObject({ method: 'ping' }); + }); + + test('dispatchToResponse returns the terminal response', async () => { + const d = new Dispatcher(); + d.setRawRequestHandler('x', async (_r, ctx) => { + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: 'n' } }); + return { v: 1 } as Result; + }); + const r = (await d.dispatchToResponse(req('x'))) as JSONRPCResultResponse; + expect(r.result).toEqual({ v: 1 }); + }); +}); + +describe('Dispatcher.setRequestHandler 3-arg (custom method + paramsSchema)', () => { + test('parses params, strips _meta, types handler arg', async () => { + const d = new Dispatcher(); + const schema = z.object({ q: z.string(), limit: z.number().optional() }); + d.setRequestHandler('acme/search', schema, async params => { + return { hits: [params.q], limit: params.limit ?? 10 } as Result; + }); + const r = (await d.dispatchToResponse(req('acme/search', { q: 'foo', _meta: { progressToken: 1 } }))) as JSONRPCResultResponse; + expect(r.result).toEqual({ hits: ['foo'], limit: 10 }); + }); + + test('schema validation failure becomes InvalidParams error response', async () => { + const d = new Dispatcher(); + d.setRequestHandler('acme/search', z.object({ q: z.string() }), async () => ({}) as Result); + const r = (await d.dispatchToResponse(req('acme/search', { q: 123 }))) as JSONRPCErrorResponse; + expect(r.error.code).toBe(ProtocolErrorCode.InvalidParams); + expect(r.error.message).toMatch(/Invalid params for acme\/search/); + }); +}); diff --git a/packages/core/test/shared/protocol.test.ts b/packages/core/test/shared/protocol.test.ts index 619e09376..6a91b9f83 100644 --- a/packages/core/test/shared/protocol.test.ts +++ b/packages/core/test/shared/protocol.test.ts @@ -39,12 +39,12 @@ import { SdkError, SdkErrorCode } from '../../src/errors/sdkErrors.js'; // Test Protocol subclass for testing class TestProtocolImpl extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { + protected override assertCapabilityForMethod(): void {} + protected override assertNotificationCapability(): void {} + protected override assertRequestHandlerCapability(): void {} + protected override assertTaskCapability(): void {} + protected override assertTaskHandlerCapability(): void {} + protected override buildContext(ctx: BaseContext): BaseContext { return ctx; } } @@ -2069,7 +2069,8 @@ describe('Task-based execution', () => { taskId: task.taskId, status: 'working' }) - }) + }), + expect.anything() ); // Verify _meta is not present or doesn't contain RELATED_TASK_META_KEY @@ -2186,7 +2187,8 @@ describe('Task-based execution', () => { } }) }) - }) + }), + expect.anything() ); }); @@ -2419,7 +2421,8 @@ describe('Request Cancellation vs Task Cancellation', () => { code: ProtocolErrorCode.InvalidParams, message: expect.stringContaining('Cannot cancel task in terminal status') }) - }) + }), + expect.anything() ); }); @@ -2451,7 +2454,8 @@ describe('Request Cancellation vs Task Cancellation', () => { code: ProtocolErrorCode.InvalidParams, message: expect.stringContaining('Task not found') }) - }) + }), + expect.anything() ); }); }); @@ -2804,48 +2808,32 @@ describe('Progress notification support for tasks', () => { const messageId = sentRequest.id; const progressToken = sentRequest.params._meta.progressToken; - // Simulate CreateTaskResult response - const taskId = 'test-task-456'; + // Create the task in the store so the ctx.task.store path can find it. + const createdTask = await taskStore.createTask({ ttl: 60_000 }, messageId, request); + const taskId = createdTask.taskId; if (transport.onmessage) { transport.onmessage({ jsonrpc: '2.0', id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } - } + result: { task: createdTask } }); } await new Promise(resolve => setTimeout(resolve, 10)); - // Simulate task failure via storeTaskResult - await taskStore.storeTaskResult(taskId, 'failed', { - content: [], - isError: true + // Simulate task failure via the public ctx.task.store path (same as the + // (completed) variant), which is what triggers progress-handler cleanup. + protocol.setRequestHandler('ping', async (_request, ctx) => { + if (ctx.task?.store) { + await ctx.task.store.storeTaskResult(taskId, 'failed', { content: [], isError: true }); + } + return {}; }); - - // Manually trigger the status notification if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/tasks/status', - params: { - taskId, - status: 'failed', - ttl: 60000, - createdAt: new Date().toISOString(), - lastUpdatedAt: new Date().toISOString(), - statusMessage: 'Task failed' - } - }); + transport.onmessage({ jsonrpc: '2.0', id: 998, method: 'ping', params: {} }); } - await new Promise(resolve => setTimeout(resolve, 10)); + await new Promise(resolve => setTimeout(resolve, 50)); // Try to send progress notification after task failure - should be ignored progressCallback.mockClear(); @@ -2896,45 +2884,32 @@ describe('Progress notification support for tasks', () => { const messageId = sentRequest.id; const progressToken = sentRequest.params._meta.progressToken; - // Simulate CreateTaskResult response - const taskId = 'test-task-789'; + // Create the task in the store so the ctx.task.store path can find it. + const createdTask = await taskStore.createTask({ ttl: 60_000 }, messageId, request); + const taskId = createdTask.taskId; if (transport.onmessage) { transport.onmessage({ jsonrpc: '2.0', id: messageId, - result: { - task: { - taskId, - status: 'working', - ttl: 60000, - createdAt: new Date().toISOString() - } - } + result: { task: createdTask } }); } await new Promise(resolve => setTimeout(resolve, 10)); - // Simulate task cancellation via updateTaskStatus - await taskStore.updateTaskStatus(taskId, 'cancelled', 'User cancelled'); - - // Manually trigger the status notification + // Simulate task cancellation via the public ctx.task.store path (same as the + // (completed) variant), which is what triggers progress-handler cleanup. + protocol.setRequestHandler('ping', async (_request, ctx) => { + if (ctx.task?.store) { + await ctx.task.store.updateTaskStatus(taskId, 'cancelled', 'User cancelled'); + } + return {}; + }); if (transport.onmessage) { - transport.onmessage({ - jsonrpc: '2.0', - method: 'notifications/tasks/status', - params: { - taskId, - status: 'cancelled', - ttl: 60000, - createdAt: new Date().toISOString(), - lastUpdatedAt: new Date().toISOString(), - statusMessage: 'User cancelled' - } - }); + transport.onmessage({ jsonrpc: '2.0', id: 997, method: 'ping', params: {} }); } - await new Promise(resolve => setTimeout(resolve, 10)); + await new Promise(resolve => setTimeout(resolve, 50)); // Try to send progress notification after cancellation - should be ignored progressCallback.mockClear(); @@ -3899,7 +3874,8 @@ describe('Message Interception', () => { jsonrpc: '2.0', id: requestId, result: { content: [{ type: 'text', text: 'done' }] } - }) + }), + expect.anything() ); }); }); @@ -5633,7 +5609,8 @@ describe('Protocol without task configuration', () => { jsonrpc: '2.0', id: 1, result: { content: 'ok' } - }) + }), + expect.anything() ); }); }); @@ -5647,12 +5624,12 @@ describe('TaskManager lifecycle via Protocol', () => { protocol = new TestProtocolImpl(); }); - test('bind() is called during Protocol construction', () => { - const bindSpy = vi.spyOn(TaskManager.prototype, 'bind'); + test('attachTo() is called during Protocol construction', () => { + const attachSpy = vi.spyOn(TaskManager.prototype, 'attachTo'); const p = new TestProtocolImpl({ tasks: {} }); - expect(bindSpy).toHaveBeenCalled(); + expect(attachSpy).toHaveBeenCalled(); expect(p.taskManager).toBeInstanceOf(TaskManager); - bindSpy.mockRestore(); + attachSpy.mockRestore(); }); test('NullTaskManager is created when no tasks config is provided', () => { diff --git a/packages/core/test/shared/protocolTransportHandling.test.ts b/packages/core/test/shared/protocolTransportHandling.test.ts index 4e9c33e67..f6f162441 100644 --- a/packages/core/test/shared/protocolTransportHandling.test.ts +++ b/packages/core/test/shared/protocolTransportHandling.test.ts @@ -35,12 +35,12 @@ describe('Protocol transport handling bug', () => { beforeEach(() => { protocol = new (class extends Protocol { - protected assertCapabilityForMethod(): void {} - protected assertNotificationCapability(): void {} - protected assertRequestHandlerCapability(): void {} - protected assertTaskCapability(): void {} - protected assertTaskHandlerCapability(): void {} - protected buildContext(ctx: BaseContext): BaseContext { + protected override assertCapabilityForMethod(): void {} + protected override assertNotificationCapability(): void {} + protected override assertRequestHandlerCapability(): void {} + protected override assertTaskCapability(): void {} + protected override assertTaskHandlerCapability(): void {} + protected override buildContext(ctx: BaseContext): BaseContext { return ctx; } })(); diff --git a/packages/core/test/shared/streamDriver.test.ts b/packages/core/test/shared/streamDriver.test.ts new file mode 100644 index 000000000..dc0312d39 --- /dev/null +++ b/packages/core/test/shared/streamDriver.test.ts @@ -0,0 +1,217 @@ +import { describe, expect, test, vi } from 'vitest'; + +import { Dispatcher } from '../../src/shared/dispatcher.js'; +import { StreamDriver } from '../../src/shared/streamDriver.js'; +import type { JSONRPCMessage, Progress, Result } from '../../src/types/index.js'; +import { ResultSchema } from '../../src/types/index.js'; +import { InMemoryTransport } from '../../src/util/inMemory.js'; + +function linkedDrivers(opts: { server?: Dispatcher; client?: Dispatcher } = {}) { + const [cPipe, sPipe] = InMemoryTransport.createLinkedPair(); + const serverDisp = opts.server ?? new Dispatcher(); + const clientDisp = opts.client ?? new Dispatcher(); + const server = new StreamDriver(serverDisp, sPipe); + const client = new StreamDriver(clientDisp, cPipe); + return { server, client, serverDisp, clientDisp, cPipe, sPipe }; +} + +describe('StreamDriver', () => { + test('correlates outbound request with inbound response', async () => { + const { server, client, serverDisp } = linkedDrivers(); + serverDisp.setRequestHandler('ping', async () => ({})); + await server.start(); + await client.start(); + const r = await client.request({ method: 'ping' }, ResultSchema); + expect(r).toEqual({}); + }); + + test('request rejects on JSON-RPC error response', async () => { + const { server, client } = linkedDrivers(); + await server.start(); + await client.start(); + await expect(client.request({ method: 'tools/list' }, ResultSchema)).rejects.toThrow(); + }); + + test('request times out and sends cancellation', async () => { + vi.useFakeTimers(); + const { server, client, sPipe } = linkedDrivers(); + await server.start(); + await client.start(); + const sent: JSONRPCMessage[] = []; + const orig = sPipe.onmessage!; + sPipe.onmessage = m => { + sent.push(m); + // swallow: never respond + }; + void orig; + const p = client.request({ method: 'ping' }, ResultSchema, { timeout: 50 }); + vi.advanceTimersByTime(60); + await expect(p).rejects.toThrow(/timed out/); + expect(sent.some(m => 'method' in m && m.method === 'notifications/cancelled')).toBe(true); + vi.useRealTimers(); + }); + + test('progress callback invoked and resets timeout when configured', async () => { + vi.useFakeTimers(); + const { server, client, serverDisp } = linkedDrivers(); + let resolveHandler!: () => void; + serverDisp.setRawRequestHandler('work', (_r, ctx) => { + void ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: ctx.mcpReq.id, progress: 0.5 } }); + return new Promise(r => { + resolveHandler = () => r({} as Result); + }); + }); + await server.start(); + await client.start(); + const seen: Progress[] = []; + const p = client.request({ method: 'work' as any }, ResultSchema, { + timeout: 100, + resetTimeoutOnProgress: true, + onprogress: pr => seen.push(pr) + }); + await vi.advanceTimersByTimeAsync(0); + expect(seen).toHaveLength(1); + expect(seen[0]!.progress).toBe(0.5); + await vi.advanceTimersByTimeAsync(80); + resolveHandler(); + await vi.advanceTimersByTimeAsync(0); + await expect(p).resolves.toEqual({}); + vi.useRealTimers(); + }); + + test('outbound abort signal cancels request', async () => { + const { server, client, serverDisp } = linkedDrivers(); + serverDisp.setRawRequestHandler('slow', () => new Promise(() => {})); + await server.start(); + await client.start(); + const ac = new AbortController(); + const p = client.request({ method: 'slow' as any }, ResultSchema, { signal: ac.signal, timeout: 10_000 }); + ac.abort('user'); + await expect(p).rejects.toThrow(); + }); + + test('inbound notifications/cancelled aborts the handler', async () => { + const { server, client, serverDisp } = linkedDrivers(); + let aborted = false; + serverDisp.setRawRequestHandler('slow', (_r, ctx) => { + return new Promise(resolve => { + ctx.mcpReq.signal.addEventListener('abort', () => { + aborted = true; + resolve({} as Result); + }); + }); + }); + await server.start(); + await client.start(); + const ac = new AbortController(); + const p = client.request({ method: 'slow' as any }, ResultSchema, { signal: ac.signal, timeout: 10_000 }); + await new Promise(r => setTimeout(r, 0)); + ac.abort('stop'); + await p.catch(() => {}); + await new Promise(r => setTimeout(r, 0)); + expect(aborted).toBe(true); + }); + + test('handler notify flows over pipe and arrives at client dispatcher', async () => { + const { server, client, serverDisp, clientDisp } = linkedDrivers(); + const got: unknown[] = []; + clientDisp.setNotificationHandler('notifications/message', n => { + got.push(n.params); + }); + serverDisp.setRawRequestHandler('work', async (_r, ctx) => { + await ctx.mcpReq.notify({ method: 'notifications/message', params: { level: 'info', data: 'hi' } }); + return {} as Result; + }); + await server.start(); + await client.start(); + await client.request({ method: 'work' as any }, ResultSchema); + expect(got).toEqual([{ level: 'info', data: 'hi' }]); + }); + + test('close rejects pending outbound requests', async () => { + const { server, client, serverDisp } = linkedDrivers(); + serverDisp.setRawRequestHandler('slow', () => new Promise(() => {})); + await server.start(); + await client.start(); + const p = client.request({ method: 'slow' as any }, ResultSchema, { timeout: 10_000 }); + await client.close(); + await expect(p).rejects.toThrow(/Connection closed/); + }); + + test('close aborts in-flight inbound handlers', async () => { + const { server, client, serverDisp } = linkedDrivers(); + let abortedReason: unknown; + serverDisp.setRawRequestHandler('slow', (_r, ctx) => { + return new Promise(() => { + ctx.mcpReq.signal.addEventListener('abort', () => { + abortedReason = ctx.mcpReq.signal.reason; + }); + }); + }); + await server.start(); + await client.start(); + client.request({ method: 'slow' as any }, ResultSchema, { timeout: 10_000 }).catch(() => {}); + await new Promise(r => setTimeout(r, 0)); + await server.close(); + expect(abortedReason).toBeDefined(); + }); + + test('debounced notifications coalesce within a tick', async () => { + const [cPipe, sPipe] = InMemoryTransport.createLinkedPair(); + const driver = new StreamDriver(new Dispatcher(), cPipe, { + debouncedNotificationMethods: ['notifications/tools/list_changed'] + }); + const seen: JSONRPCMessage[] = []; + sPipe.onmessage = m => seen.push(m); + await sPipe.start(); + await driver.start(); + void driver.notification({ method: 'notifications/tools/list_changed' }); + void driver.notification({ method: 'notifications/tools/list_changed' }); + void driver.notification({ method: 'notifications/tools/list_changed' }); + await new Promise(r => setTimeout(r, 0)); + expect(seen.filter(m => 'method' in m && m.method === 'notifications/tools/list_changed')).toHaveLength(1); + }); + + test('ctx.mcpReq.send round-trips back through the same driver pair', async () => { + const { server, client, serverDisp, clientDisp } = linkedDrivers(); + let pinged = false; + clientDisp.setRequestHandler('ping', async () => { + pinged = true; + return {}; + }); + let elicited: unknown; + serverDisp.setRawRequestHandler('ask', async (_r, ctx) => { + elicited = await ctx.mcpReq.send({ method: 'ping' }); + return {} as Result; + }); + await server.start(); + await client.start(); + await client.request({ method: 'ask' as any }, ResultSchema); + expect(pinged).toBe(true); + expect(elicited).toEqual({}); + }); + + test('onerror fires for response with unknown id', async () => { + const [cPipe, sPipe] = InMemoryTransport.createLinkedPair(); + const driver = new StreamDriver(new Dispatcher(), cPipe); + const errs: Error[] = []; + driver.onerror = e => errs.push(e); + await driver.start(); + await sPipe.start(); + await sPipe.send({ jsonrpc: '2.0', id: 999, result: {} }); + expect(errs[0]?.message).toMatch(/unknown message ID/); + }); + + test('concurrent requests get distinct ids and resolve independently', async () => { + const { server, client, serverDisp } = linkedDrivers(); + serverDisp.setRawRequestHandler('echo', async r => ({ id: r.id }) as Result); + await server.start(); + await client.start(); + const [a, b, c] = await Promise.all([ + client.request({ method: 'echo' as any }, ResultSchema), + client.request({ method: 'echo' as any }, ResultSchema), + client.request({ method: 'echo' as any }, ResultSchema) + ]); + expect(new Set([a.id, b.id, c.id]).size).toBe(3); + }); +}); diff --git a/packages/middleware/express/package.json b/packages/middleware/express/package.json index 39d671b81..633570d5c 100644 --- a/packages/middleware/express/package.json +++ b/packages/middleware/express/package.json @@ -24,7 +24,8 @@ "exports": { ".": { "types": "./dist/index.d.mts", - "import": "./dist/index.mjs" + "import": "./dist/index.mjs", + "require": "./dist/index.mjs" } }, "files": [ @@ -63,5 +64,6 @@ "typescript": "catalog:devTools", "typescript-eslint": "catalog:devTools", "vitest": "catalog:devTools" - } + }, + "types": "./dist/index.d.mts" } diff --git a/packages/middleware/fastify/package.json b/packages/middleware/fastify/package.json index d3d4c352b..071fcdab7 100644 --- a/packages/middleware/fastify/package.json +++ b/packages/middleware/fastify/package.json @@ -24,7 +24,8 @@ "exports": { ".": { "types": "./dist/index.d.mts", - "import": "./dist/index.mjs" + "import": "./dist/index.mjs", + "require": "./dist/index.mjs" } }, "files": [ @@ -61,5 +62,6 @@ "typescript": "catalog:devTools", "typescript-eslint": "catalog:devTools", "vitest": "catalog:devTools" - } + }, + "types": "./dist/index.d.mts" } diff --git a/packages/middleware/hono/package.json b/packages/middleware/hono/package.json index f23c9ccb6..c20e5ddbb 100644 --- a/packages/middleware/hono/package.json +++ b/packages/middleware/hono/package.json @@ -24,7 +24,8 @@ "exports": { ".": { "types": "./dist/index.d.mts", - "import": "./dist/index.mjs" + "import": "./dist/index.mjs", + "require": "./dist/index.mjs" } }, "files": [ @@ -61,5 +62,6 @@ "typescript": "catalog:devTools", "typescript-eslint": "catalog:devTools", "vitest": "catalog:devTools" - } + }, + "types": "./dist/index.d.mts" } diff --git a/packages/middleware/node/package.json b/packages/middleware/node/package.json index 7fcaf9106..9ea58d700 100644 --- a/packages/middleware/node/package.json +++ b/packages/middleware/node/package.json @@ -23,7 +23,8 @@ "exports": { ".": { "types": "./dist/index.d.mts", - "import": "./dist/index.mjs" + "import": "./dist/index.mjs", + "require": "./dist/index.mjs" } }, "files": [ @@ -67,5 +68,6 @@ "typescript": "catalog:devTools", "typescript-eslint": "catalog:devTools", "vitest": "catalog:devTools" - } + }, + "types": "./dist/index.d.mts" } diff --git a/packages/middleware/node/src/streamableHttp.ts b/packages/middleware/node/src/streamableHttp.ts index 68a0c224f..591842785 100644 --- a/packages/middleware/node/src/streamableHttp.ts +++ b/packages/middleware/node/src/streamableHttp.ts @@ -10,7 +10,18 @@ import type { IncomingMessage, ServerResponse } from 'node:http'; import { getRequestListener } from '@hono/node-server'; -import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, Transport } from '@modelcontextprotocol/core'; +import type { + AuthInfo, + JSONRPCErrorResponse, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResultResponse, + MessageExtraInfo, + RequestEnv, + RequestId, + RequestTransport +} from '@modelcontextprotocol/core'; import type { WebStandardStreamableHTTPServerTransportOptions } from '@modelcontextprotocol/server'; import { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/server'; @@ -21,6 +32,26 @@ import { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/ */ export type StreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServerTransportOptions; +/** + * Converts a web-standard `(Request) => Response` handler into a Node.js + * `(IncomingMessage, ServerResponse) => void` handler suitable for Express, + * `http.createServer`, etc. + * + * @example + * ```ts + * const app = express(); + * app.post('/mcp', toNodeHttpHandler(req => mcpServer.handleHttp(req))); + * ``` + */ +export function toNodeHttpHandler( + handler: (req: Request) => Response | Promise +): (req: IncomingMessage, res: ServerResponse) => Promise { + const listener = getRequestListener(handler, { overrideGlobalObjects: false }); + return async (req, res) => { + await listener(req, res); + }; +} + /** * Server transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. * It supports both SSE streaming and direct HTTP responses. @@ -64,7 +95,9 @@ export type StreamableHTTPServerTransportOptions = WebStandardStreamableHTTPServ * }); * ``` */ -export class NodeStreamableHTTPServerTransport implements Transport { +export class NodeStreamableHTTPServerTransport implements RequestTransport { + readonly kind = 'request' as const; + private _webStandardTransport: WebStandardStreamableHTTPServerTransport; private _requestListener: ReturnType; // Store auth and parsedBody per request for passing through to handleRequest @@ -130,6 +163,35 @@ export class NodeStreamableHTTPServerTransport implements Transport { return this._webStandardTransport.onmessage; } + // RequestTransport callback slots — delegate to the wrapped web-standard transport. + get onrequest(): ((req: JSONRPCRequest, env?: RequestEnv) => AsyncIterable) | undefined { + return this._webStandardTransport.onrequest; + } + set onrequest(h: ((req: JSONRPCRequest, env?: RequestEnv) => AsyncIterable) | undefined) { + this._webStandardTransport.onrequest = h; + } + get onnotification(): ((n: JSONRPCNotification) => void | Promise) | undefined { + return this._webStandardTransport.onnotification; + } + set onnotification(h: ((n: JSONRPCNotification) => void | Promise) | undefined) { + this._webStandardTransport.onnotification = h; + } + get onresponse(): ((r: JSONRPCResultResponse | JSONRPCErrorResponse) => boolean) | undefined { + return this._webStandardTransport.onresponse; + } + set onresponse(h: ((r: JSONRPCResultResponse | JSONRPCErrorResponse) => boolean) | undefined) { + this._webStandardTransport.onresponse = h; + } + + /** {@linkcode RequestTransport.notify} — delegates to the wrapped transport. */ + notify(n: JSONRPCNotification): Promise { + return this._webStandardTransport.notify(n); + } + /** {@linkcode RequestTransport.request} — delegates to the wrapped transport. */ + request(r: JSONRPCRequest): Promise { + return this._webStandardTransport.request(r); + } + /** * Starts the transport. This is required by the {@linkcode Transport} interface but is a no-op * for the Streamable HTTP transport as connections are managed per-request. diff --git a/packages/sdk/README.md b/packages/sdk/README.md new file mode 100644 index 000000000..a64d003cb --- /dev/null +++ b/packages/sdk/README.md @@ -0,0 +1,25 @@ +# @modelcontextprotocol/sdk + +The **primary entry point** for the Model Context Protocol TypeScript SDK. + +This meta-package re-exports the full public surface of [`@modelcontextprotocol/server`](../server), [`@modelcontextprotocol/client`](../client), and [`@modelcontextprotocol/node`](../middleware/node), so most applications can depend on this package alone: + +```ts +import { McpServer, Client, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/sdk'; +``` + +## Upgrading from v1 + +`@modelcontextprotocol/sdk` v2 is a drop-in upgrade for most v1 servers — just bump the version. v1 deep-import paths (`@modelcontextprotocol/sdk/types.js`, `/server/mcp.js`, `/client/index.js`, `/shared/transport.js`, etc.) are preserved as compatibility subpaths that re-export +the matching v2 symbols and emit one-time deprecation warnings where the API shape changed. + +See [`docs/migration.md`](../../docs/migration.md) for the full mapping. + +## When to use the sub-packages directly + +Bundle-sensitive targets (browsers, Cloudflare Workers) should import from `@modelcontextprotocol/client` or `@modelcontextprotocol/server` directly to avoid pulling in Node-only transports. + +## Optional subpaths + +The `./server/auth/*` subpaths re-export the legacy Authorization Server helpers from `@modelcontextprotocol/server-auth-legacy`, which require `express` to be installed by the consumer. Similarly, `./server/sse.js` (the deprecated `SSEServerTransport`) is provided by +`@modelcontextprotocol/node`. Both `express` and `hono` are optional peer dependencies — install them only if you use those subpaths. diff --git a/packages/sdk/eslint.config.mjs b/packages/sdk/eslint.config.mjs new file mode 100644 index 000000000..e34a2a51a --- /dev/null +++ b/packages/sdk/eslint.config.mjs @@ -0,0 +1,27 @@ +// @ts-check + +import baseConfig from '@modelcontextprotocol/eslint-config'; + +export default [ + ...baseConfig, + { + settings: { + 'import/internal-regex': '^@modelcontextprotocol/' + } + }, + { + // This package is the v1-compat surface; deprecated re-exports are intentional. + // import/no-unresolved: subpaths re-export from sibling packages (server-auth-legacy, + // node/sse, server/zod-schemas) that don't exist on this branch standalone — they + // land via separate PRs in this BC series. Resolves once those merge. + // import/export: types.ts deliberately shadows `export *` names with v1-compat aliases + // (TS spec: named export wins over re-export). + // unicorn/filename-case: validation/ajv-provider.ts etc. match v1 subpath names. + rules: { + '@typescript-eslint/no-deprecated': 'off', + 'import/no-unresolved': 'off', + 'import/export': 'off', + 'unicorn/filename-case': 'off' + } + } +]; diff --git a/packages/sdk/package.json b/packages/sdk/package.json new file mode 100644 index 000000000..c23becf8d --- /dev/null +++ b/packages/sdk/package.json @@ -0,0 +1,395 @@ +{ + "name": "@modelcontextprotocol/sdk", + "version": "2.0.0-alpha.2", + "description": "Model Context Protocol implementation for TypeScript - Full SDK (re-exports client, server, and node middleware)", + "license": "MIT", + "author": "Anthropic, PBC (https://anthropic.com)", + "homepage": "https://modelcontextprotocol.io", + "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues", + "type": "module", + "types": "./dist/index.d.ts", + "repository": { + "type": "git", + "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git" + }, + "engines": { + "node": ">=20" + }, + "keywords": [ + "modelcontextprotocol", + "mcp" + ], + "exports": { + ".": { + "types": "./dist/index.d.mts", + "import": "./dist/index.mjs", + "require": "./dist/index.mjs" + }, + "./stdio": { + "types": "./dist/stdio.d.mts", + "import": "./dist/stdio.mjs", + "require": "./dist/stdio.mjs" + }, + "./types.js": { + "types": "./dist/types.d.mts", + "import": "./dist/types.mjs", + "require": "./dist/types.mjs" + }, + "./types": { + "types": "./dist/types.d.mts", + "import": "./dist/types.mjs", + "require": "./dist/types.mjs" + }, + "./server/index.js": { + "types": "./dist/server/index.d.mts", + "import": "./dist/server/index.mjs", + "require": "./dist/server/index.mjs" + }, + "./server/index": { + "types": "./dist/server/index.d.mts", + "import": "./dist/server/index.mjs", + "require": "./dist/server/index.mjs" + }, + "./server/mcp.js": { + "types": "./dist/server/mcp.d.mts", + "import": "./dist/server/mcp.mjs", + "require": "./dist/server/mcp.mjs" + }, + "./server/mcp": { + "types": "./dist/server/mcp.d.mts", + "import": "./dist/server/mcp.mjs", + "require": "./dist/server/mcp.mjs" + }, + "./server/zod-compat.js": { + "types": "./dist/server/zod-compat.d.mts", + "import": "./dist/server/zod-compat.mjs", + "require": "./dist/server/zod-compat.mjs" + }, + "./server/zod-compat": { + "types": "./dist/server/zod-compat.d.mts", + "import": "./dist/server/zod-compat.mjs", + "require": "./dist/server/zod-compat.mjs" + }, + "./server/stdio.js": { + "types": "./dist/server/stdio.d.mts", + "import": "./dist/server/stdio.mjs", + "require": "./dist/server/stdio.mjs" + }, + "./server/stdio": { + "types": "./dist/server/stdio.d.mts", + "import": "./dist/server/stdio.mjs", + "require": "./dist/server/stdio.mjs" + }, + "./server/streamableHttp.js": { + "types": "./dist/server/streamableHttp.d.mts", + "import": "./dist/server/streamableHttp.mjs", + "require": "./dist/server/streamableHttp.mjs" + }, + "./server/streamableHttp": { + "types": "./dist/server/streamableHttp.d.mts", + "import": "./dist/server/streamableHttp.mjs", + "require": "./dist/server/streamableHttp.mjs" + }, + "./server/auth/types.js": { + "types": "./dist/server/auth/types.d.mts", + "import": "./dist/server/auth/types.mjs", + "require": "./dist/server/auth/types.mjs" + }, + "./server/auth/types": { + "types": "./dist/server/auth/types.d.mts", + "import": "./dist/server/auth/types.mjs", + "require": "./dist/server/auth/types.mjs" + }, + "./server/auth/errors.js": { + "types": "./dist/server/auth/errors.d.mts", + "import": "./dist/server/auth/errors.mjs", + "require": "./dist/server/auth/errors.mjs" + }, + "./server/auth/errors": { + "types": "./dist/server/auth/errors.d.mts", + "import": "./dist/server/auth/errors.mjs", + "require": "./dist/server/auth/errors.mjs" + }, + "./client": { + "types": "./dist/client/index.d.mts", + "import": "./dist/client/index.mjs", + "require": "./dist/client/index.mjs" + }, + "./client/index.js": { + "types": "./dist/client/index.d.mts", + "import": "./dist/client/index.mjs", + "require": "./dist/client/index.mjs" + }, + "./client/index": { + "types": "./dist/client/index.d.mts", + "import": "./dist/client/index.mjs", + "require": "./dist/client/index.mjs" + }, + "./client/stdio.js": { + "types": "./dist/client/stdio.d.mts", + "import": "./dist/client/stdio.mjs", + "require": "./dist/client/stdio.mjs" + }, + "./client/stdio": { + "types": "./dist/client/stdio.d.mts", + "import": "./dist/client/stdio.mjs", + "require": "./dist/client/stdio.mjs" + }, + "./client/streamableHttp.js": { + "types": "./dist/client/streamableHttp.d.mts", + "import": "./dist/client/streamableHttp.mjs", + "require": "./dist/client/streamableHttp.mjs" + }, + "./client/streamableHttp": { + "types": "./dist/client/streamableHttp.d.mts", + "import": "./dist/client/streamableHttp.mjs", + "require": "./dist/client/streamableHttp.mjs" + }, + "./client/sse.js": { + "types": "./dist/client/sse.d.mts", + "import": "./dist/client/sse.mjs", + "require": "./dist/client/sse.mjs" + }, + "./client/sse": { + "types": "./dist/client/sse.d.mts", + "import": "./dist/client/sse.mjs", + "require": "./dist/client/sse.mjs" + }, + "./client/auth.js": { + "types": "./dist/client/auth.d.mts", + "import": "./dist/client/auth.mjs", + "require": "./dist/client/auth.mjs" + }, + "./client/auth": { + "types": "./dist/client/auth.d.mts", + "import": "./dist/client/auth.mjs", + "require": "./dist/client/auth.mjs" + }, + "./shared/protocol.js": { + "types": "./dist/shared/protocol.d.mts", + "import": "./dist/shared/protocol.mjs", + "require": "./dist/shared/protocol.mjs" + }, + "./shared/protocol": { + "types": "./dist/shared/protocol.d.mts", + "import": "./dist/shared/protocol.mjs", + "require": "./dist/shared/protocol.mjs" + }, + "./shared/transport.js": { + "types": "./dist/shared/transport.d.mts", + "import": "./dist/shared/transport.mjs", + "require": "./dist/shared/transport.mjs" + }, + "./shared/transport": { + "types": "./dist/shared/transport.d.mts", + "import": "./dist/shared/transport.mjs", + "require": "./dist/shared/transport.mjs" + }, + "./shared/auth.js": { + "types": "./dist/shared/auth.d.mts", + "import": "./dist/shared/auth.mjs", + "require": "./dist/shared/auth.mjs" + }, + "./shared/auth": { + "types": "./dist/shared/auth.d.mts", + "import": "./dist/shared/auth.mjs", + "require": "./dist/shared/auth.mjs" + }, + "./server/auth/middleware/bearerAuth.js": { + "types": "./dist/server/auth/middleware/bearerAuth.d.mts", + "import": "./dist/server/auth/middleware/bearerAuth.mjs", + "require": "./dist/server/auth/middleware/bearerAuth.mjs" + }, + "./server/auth/middleware/bearerAuth": { + "types": "./dist/server/auth/middleware/bearerAuth.d.mts", + "import": "./dist/server/auth/middleware/bearerAuth.mjs", + "require": "./dist/server/auth/middleware/bearerAuth.mjs" + }, + "./server/auth/router.js": { + "types": "./dist/server/auth/router.d.mts", + "import": "./dist/server/auth/router.mjs", + "require": "./dist/server/auth/router.mjs" + }, + "./server/auth/router": { + "types": "./dist/server/auth/router.d.mts", + "import": "./dist/server/auth/router.mjs", + "require": "./dist/server/auth/router.mjs" + }, + "./server/auth/provider.js": { + "types": "./dist/server/auth/provider.d.mts", + "import": "./dist/server/auth/provider.mjs", + "require": "./dist/server/auth/provider.mjs" + }, + "./server/auth/provider": { + "types": "./dist/server/auth/provider.d.mts", + "import": "./dist/server/auth/provider.mjs", + "require": "./dist/server/auth/provider.mjs" + }, + "./server/auth/clients.js": { + "types": "./dist/server/auth/clients.d.mts", + "import": "./dist/server/auth/clients.mjs", + "require": "./dist/server/auth/clients.mjs" + }, + "./server/auth/clients": { + "types": "./dist/server/auth/clients.d.mts", + "import": "./dist/server/auth/clients.mjs", + "require": "./dist/server/auth/clients.mjs" + }, + "./inMemory.js": { + "types": "./dist/inMemory.d.mts", + "import": "./dist/inMemory.mjs", + "require": "./dist/inMemory.mjs" + }, + "./inMemory": { + "types": "./dist/inMemory.d.mts", + "import": "./dist/inMemory.mjs", + "require": "./dist/inMemory.mjs" + }, + "./server/completable.js": { + "types": "./dist/server/completable.d.mts", + "import": "./dist/server/completable.mjs", + "require": "./dist/server/completable.mjs" + }, + "./server/completable": { + "types": "./dist/server/completable.d.mts", + "import": "./dist/server/completable.mjs", + "require": "./dist/server/completable.mjs" + }, + "./server/sse.js": { + "types": "./dist/server/sse.d.mts", + "import": "./dist/server/sse.mjs", + "require": "./dist/server/sse.mjs" + }, + "./server/sse": { + "types": "./dist/server/sse.d.mts", + "import": "./dist/server/sse.mjs", + "require": "./dist/server/sse.mjs" + }, + "./experimental/tasks": { + "types": "./dist/experimental/tasks.d.mts", + "import": "./dist/experimental/tasks.mjs", + "require": "./dist/experimental/tasks.mjs" + }, + "./server": { + "types": "./dist/server/index.d.mts", + "import": "./dist/server/index.mjs", + "require": "./dist/server/index.mjs" + }, + "./server.js": { + "types": "./dist/server/index.d.mts", + "import": "./dist/server/index.mjs", + "require": "./dist/server/index.mjs" + }, + "./client.js": { + "types": "./dist/client/index.d.mts", + "import": "./dist/client/index.mjs", + "require": "./dist/client/index.mjs" + }, + "./server/webStandardStreamableHttp.js": { + "types": "./dist/server/webStandardStreamableHttp.d.mts", + "import": "./dist/server/webStandardStreamableHttp.mjs", + "require": "./dist/server/webStandardStreamableHttp.mjs" + }, + "./server/webStandardStreamableHttp": { + "types": "./dist/server/webStandardStreamableHttp.d.mts", + "import": "./dist/server/webStandardStreamableHttp.mjs", + "require": "./dist/server/webStandardStreamableHttp.mjs" + }, + "./shared/stdio.js": { + "types": "./dist/shared/stdio.d.mts", + "import": "./dist/shared/stdio.mjs", + "require": "./dist/shared/stdio.mjs" + }, + "./shared/stdio": { + "types": "./dist/shared/stdio.d.mts", + "import": "./dist/shared/stdio.mjs", + "require": "./dist/shared/stdio.mjs" + }, + "./validation/types.js": { + "types": "./dist/validation/types.d.mts", + "import": "./dist/validation/types.mjs", + "require": "./dist/validation/types.mjs" + }, + "./validation/types": { + "types": "./dist/validation/types.d.mts", + "import": "./dist/validation/types.mjs", + "require": "./dist/validation/types.mjs" + }, + "./validation/cfworker-provider.js": { + "types": "./dist/validation/cfworker-provider.d.mts", + "import": "./dist/validation/cfworker-provider.mjs", + "require": "./dist/validation/cfworker-provider.mjs" + }, + "./validation/cfworker-provider": { + "types": "./dist/validation/cfworker-provider.d.mts", + "import": "./dist/validation/cfworker-provider.mjs", + "require": "./dist/validation/cfworker-provider.mjs" + }, + "./validation/ajv-provider.js": { + "types": "./dist/validation/ajv-provider.d.mts", + "import": "./dist/validation/ajv-provider.mjs", + "require": "./dist/validation/ajv-provider.mjs" + }, + "./validation/ajv-provider": { + "types": "./dist/validation/ajv-provider.d.mts", + "import": "./dist/validation/ajv-provider.mjs", + "require": "./dist/validation/ajv-provider.mjs" + } + }, + "files": [ + "dist" + ], + "scripts": { + "typecheck": "tsgo -p tsconfig.json --noEmit", + "build": "tsdown", + "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "check": "pnpm run typecheck && pnpm run lint", + "test": "vitest run", + "test:watch": "vitest", + "prepack": "pnpm run build" + }, + "dependencies": { + "@modelcontextprotocol/client": "workspace:^", + "@modelcontextprotocol/node": "workspace:^", + "@modelcontextprotocol/server": "workspace:^", + "@modelcontextprotocol/server-auth-legacy": "workspace:^" + }, + "peerDependencies": { + "express": "^4.18.0 || ^5.0.0", + "hono": "*" + }, + "peerDependenciesMeta": { + "express": { + "optional": true + }, + "hono": { + "optional": true + } + }, + "devDependencies": { + "@modelcontextprotocol/core": "workspace:^", + "@modelcontextprotocol/eslint-config": "workspace:^", + "@modelcontextprotocol/test-helpers": "workspace:^", + "@modelcontextprotocol/tsconfig": "workspace:^", + "@modelcontextprotocol/vitest-config": "workspace:^", + "@typescript/native-preview": "catalog:devTools", + "eslint": "catalog:devTools", + "prettier": "catalog:devTools", + "tsdown": "catalog:devTools", + "typescript": "catalog:devTools", + "vitest": "catalog:devTools", + "zod": "catalog:runtimeShared" + }, + "typesVersions": { + "*": { + "*.js": [ + "dist/*.d.mts" + ], + "*": [ + "dist/*.d.mts", + "dist/*/index.d.mts" + ] + } + } +} diff --git a/packages/sdk/src/client/auth.ts b/packages/sdk/src/client/auth.ts new file mode 100644 index 000000000..55eaaedf6 --- /dev/null +++ b/packages/sdk/src/client/auth.ts @@ -0,0 +1 @@ +export * from '@modelcontextprotocol/client'; diff --git a/packages/sdk/src/client/index.ts b/packages/sdk/src/client/index.ts new file mode 100644 index 000000000..55eaaedf6 --- /dev/null +++ b/packages/sdk/src/client/index.ts @@ -0,0 +1 @@ +export * from '@modelcontextprotocol/client'; diff --git a/packages/sdk/src/client/sse.ts b/packages/sdk/src/client/sse.ts new file mode 100644 index 000000000..de4e3b56e --- /dev/null +++ b/packages/sdk/src/client/sse.ts @@ -0,0 +1 @@ +export { SSEClientTransport, type SSEClientTransportOptions, SseError } from '@modelcontextprotocol/client'; diff --git a/packages/sdk/src/client/stdio.ts b/packages/sdk/src/client/stdio.ts new file mode 100644 index 000000000..3b7c7397d --- /dev/null +++ b/packages/sdk/src/client/stdio.ts @@ -0,0 +1,2 @@ +export type { StdioServerParameters } from '@modelcontextprotocol/client'; +export { DEFAULT_INHERITED_ENV_VARS, getDefaultEnvironment, StdioClientTransport } from '@modelcontextprotocol/client'; diff --git a/packages/sdk/src/client/streamableHttp.ts b/packages/sdk/src/client/streamableHttp.ts new file mode 100644 index 000000000..0d2d4d14b --- /dev/null +++ b/packages/sdk/src/client/streamableHttp.ts @@ -0,0 +1,7 @@ +export { + type StartSSEOptions, + StreamableHTTPClientTransport, + type StreamableHTTPClientTransportOptions, + StreamableHTTPError, + type StreamableHTTPReconnectionOptions +} from '@modelcontextprotocol/client'; diff --git a/packages/sdk/src/experimental/tasks.ts b/packages/sdk/src/experimental/tasks.ts new file mode 100644 index 000000000..cf883d266 --- /dev/null +++ b/packages/sdk/src/experimental/tasks.ts @@ -0,0 +1,5 @@ +// v1 compat: `@modelcontextprotocol/sdk/experimental/tasks` +// Re-exports the full server surface (task stores, handlers, and related types +// are scattered between core/public and server/experimental/tasks; the root +// barrel includes both). +export * from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/inMemory.ts b/packages/sdk/src/inMemory.ts new file mode 100644 index 000000000..7f2e9c6be --- /dev/null +++ b/packages/sdk/src/inMemory.ts @@ -0,0 +1 @@ +export { InMemoryTransport } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts new file mode 100644 index 000000000..71126956f --- /dev/null +++ b/packages/sdk/src/index.ts @@ -0,0 +1,90 @@ +// Root barrel for @modelcontextprotocol/sdk — the everything package. +// +// Re-exports the full public surface of the server, client, and node packages +// so consumers can `import { McpServer, Client, NodeStreamableHTTPServerTransport } from '@modelcontextprotocol/sdk'` +// without choosing a sub-package. +// +// Bundle-sensitive consumers (browser, Workers) should import from +// @modelcontextprotocol/client or @modelcontextprotocol/server directly instead. + +// Server gives us all server-specific exports + the entire core/public surface +// (spec types, error classes, transport interface, constants, guards). +export * from '@modelcontextprotocol/server'; + +// Node middleware — explicit named exports only. Not `export *`, because the +// node package re-exports core types from server and `export *` from both +// packages would collide on overlapping symbols (TS2308). +export { NodeStreamableHTTPServerTransport, type StreamableHTTPServerTransportOptions } from '@modelcontextprotocol/node'; +/** @deprecated Renamed to {@linkcode NodeStreamableHTTPServerTransport}. */ +export { NodeStreamableHTTPServerTransport as StreamableHTTPServerTransport } from '@modelcontextprotocol/node'; + +// Client-specific exports only — NOT `export *`, because client also re-exports +// core/public and the duplicate runtime-value identities (each package bundles +// core separately) trigger TS2308. core/public is already covered by server above. +export type { + AddClientAuthentication, + AssertionCallback, + AuthProvider, + AuthResult, + ClientAuthMethod, + ClientCredentialsProviderOptions, + ClientOptions, + CrossAppAccessContext, + CrossAppAccessProviderOptions, + DiscoverAndRequestJwtAuthGrantOptions, + JwtAuthGrantResult, + LoggingOptions, + Middleware, + OAuthClientProvider, + OAuthDiscoveryState, + OAuthServerInfo, + PrivateKeyJwtProviderOptions, + ReconnectionScheduler, + RequestJwtAuthGrantOptions, + RequestLogger, + SSEClientTransportOptions, + StartSSEOptions, + StaticPrivateKeyJwtProviderOptions, + StreamableHTTPClientTransportOptions, + StreamableHTTPReconnectionOptions +} from '@modelcontextprotocol/client'; +export { + applyMiddlewares, + auth, + buildDiscoveryUrls, + Client, + ClientCredentialsProvider, + createMiddleware, + createPrivateKeyJwtAuth, + CrossAppAccessProvider, + discoverAndRequestJwtAuthGrant, + discoverAuthorizationServerMetadata, + discoverOAuthMetadata, + discoverOAuthProtectedResourceMetadata, + discoverOAuthServerInfo, + exchangeAuthorization, + exchangeJwtAuthGrant, + ExperimentalClientTasks, + extractResourceMetadataUrl, + extractWWWAuthenticateParams, + fetchToken, + getSupportedElicitationModes, + isHttpsUrl, + parseErrorResponse, + prepareAuthorizationCodeRequest, + PrivateKeyJwtProvider, + refreshAuthorization, + registerClient, + requestJwtAuthorizationGrant, + selectClientAuthMethod, + selectResourceURL, + SSEClientTransport, + SseError, + startAuthorization, + StaticPrivateKeyJwtProvider, + StreamableHTTPClientTransport, + UnauthorizedError, + validateClientMetadataUrl, + withLogging, + withOAuth +} from '@modelcontextprotocol/client'; diff --git a/packages/sdk/src/server/auth/clients.ts b/packages/sdk/src/server/auth/clients.ts new file mode 100644 index 000000000..f7916a07a --- /dev/null +++ b/packages/sdk/src/server/auth/clients.ts @@ -0,0 +1,2 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/auth/clients.js` +export type { OAuthRegisteredClientsStore } from '@modelcontextprotocol/server-auth-legacy'; diff --git a/packages/sdk/src/server/auth/errors.ts b/packages/sdk/src/server/auth/errors.ts new file mode 100644 index 000000000..c65da81ca --- /dev/null +++ b/packages/sdk/src/server/auth/errors.ts @@ -0,0 +1,23 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/auth/errors.js` +export { OAuthErrorCode } from '@modelcontextprotocol/server'; +export { + AccessDeniedError, + CustomOAuthError, + InsufficientScopeError, + InvalidClientError, + InvalidClientMetadataError, + InvalidGrantError, + InvalidRequestError, + InvalidScopeError, + InvalidTargetError, + InvalidTokenError, + MethodNotAllowedError, + OAuthError, + ServerError, + TemporarilyUnavailableError, + TooManyRequestsError, + UnauthorizedClientError, + UnsupportedGrantTypeError, + UnsupportedResponseTypeError, + UnsupportedTokenTypeError +} from '@modelcontextprotocol/server-auth-legacy'; diff --git a/packages/sdk/src/server/auth/middleware/bearerAuth.ts b/packages/sdk/src/server/auth/middleware/bearerAuth.ts new file mode 100644 index 000000000..4cc64c6b5 --- /dev/null +++ b/packages/sdk/src/server/auth/middleware/bearerAuth.ts @@ -0,0 +1,2 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/auth/middleware/bearerAuth.js` +export { type BearerAuthMiddlewareOptions, requireBearerAuth } from '@modelcontextprotocol/server-auth-legacy'; diff --git a/packages/sdk/src/server/auth/provider.ts b/packages/sdk/src/server/auth/provider.ts new file mode 100644 index 000000000..a6dfaade3 --- /dev/null +++ b/packages/sdk/src/server/auth/provider.ts @@ -0,0 +1,2 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/auth/provider.js` +export type { AuthorizationParams, OAuthServerProvider, OAuthTokenVerifier } from '@modelcontextprotocol/server-auth-legacy'; diff --git a/packages/sdk/src/server/auth/router.ts b/packages/sdk/src/server/auth/router.ts new file mode 100644 index 000000000..3b0c1d037 --- /dev/null +++ b/packages/sdk/src/server/auth/router.ts @@ -0,0 +1,9 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/auth/router.js` +export { + type AuthMetadataOptions, + type AuthRouterOptions, + createOAuthMetadata, + getOAuthProtectedResourceMetadataUrl, + mcpAuthMetadataRouter, + mcpAuthRouter +} from '@modelcontextprotocol/server-auth-legacy'; diff --git a/packages/sdk/src/server/auth/types.ts b/packages/sdk/src/server/auth/types.ts new file mode 100644 index 000000000..f10524f66 --- /dev/null +++ b/packages/sdk/src/server/auth/types.ts @@ -0,0 +1,2 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/auth/types.js` +export type { AuthInfo } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/server/completable.ts b/packages/sdk/src/server/completable.ts new file mode 100644 index 000000000..84eee372f --- /dev/null +++ b/packages/sdk/src/server/completable.ts @@ -0,0 +1 @@ +export { completable, type CompletableSchema, type CompleteCallback, isCompletable } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/server/index.ts b/packages/sdk/src/server/index.ts new file mode 100644 index 000000000..6ecbcce01 --- /dev/null +++ b/packages/sdk/src/server/index.ts @@ -0,0 +1 @@ +export * from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/server/mcp.ts b/packages/sdk/src/server/mcp.ts new file mode 100644 index 000000000..56a5c924d --- /dev/null +++ b/packages/sdk/src/server/mcp.ts @@ -0,0 +1,21 @@ +export { + type AnyToolHandler, + type BaseToolCallback, + completable, + type CompletableSchema, + type CompleteCallback, + type CompleteResourceTemplateCallback, + isCompletable, + type ListResourcesCallback, + McpServer, + type PromptCallback, + type ReadResourceCallback, + type ReadResourceTemplateCallback, + type RegisteredPrompt, + type RegisteredResource, + type RegisteredResourceTemplate, + type RegisteredTool, + type ResourceMetadata, + ResourceTemplate, + type ToolCallback +} from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/server/sse.ts b/packages/sdk/src/server/sse.ts new file mode 100644 index 000000000..e81316838 --- /dev/null +++ b/packages/sdk/src/server/sse.ts @@ -0,0 +1,9 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/sse.js` +// The SSE server transport was removed in v2. Use Streamable HTTP instead. + +/** + * @deprecated SSE server transport was removed in v2. Use {@link NodeStreamableHTTPServerTransport} + * (from `@modelcontextprotocol/node`) instead. This alias is provided for source-compat only; + * the wire behavior is Streamable HTTP, not legacy SSE. + */ +export { NodeStreamableHTTPServerTransport as SSEServerTransport } from '@modelcontextprotocol/node'; diff --git a/packages/sdk/src/server/stdio.ts b/packages/sdk/src/server/stdio.ts new file mode 100644 index 000000000..9a7f02eee --- /dev/null +++ b/packages/sdk/src/server/stdio.ts @@ -0,0 +1 @@ +export { StdioServerTransport } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/server/streamableHttp.ts b/packages/sdk/src/server/streamableHttp.ts new file mode 100644 index 000000000..4cfda1b20 --- /dev/null +++ b/packages/sdk/src/server/streamableHttp.ts @@ -0,0 +1,4 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/streamableHttp.js` +export * from '@modelcontextprotocol/node'; +/** @deprecated Renamed to {@link NodeStreamableHTTPServerTransport} and moved to `@modelcontextprotocol/node` in v2. */ +export { NodeStreamableHTTPServerTransport as StreamableHTTPServerTransport } from '@modelcontextprotocol/node'; diff --git a/packages/sdk/src/server/webStandardStreamableHttp.ts b/packages/sdk/src/server/webStandardStreamableHttp.ts new file mode 100644 index 000000000..75870a36b --- /dev/null +++ b/packages/sdk/src/server/webStandardStreamableHttp.ts @@ -0,0 +1,4 @@ +export { + WebStandardStreamableHTTPServerTransport, + type WebStandardStreamableHTTPServerTransportOptions +} from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/server/zod-compat.ts b/packages/sdk/src/server/zod-compat.ts new file mode 100644 index 000000000..5486bac0e --- /dev/null +++ b/packages/sdk/src/server/zod-compat.ts @@ -0,0 +1,16 @@ +// v1 compat: `@modelcontextprotocol/sdk/server/zod-compat.js` +// v1 unified Zod v3 + v4 types. v2 is Zod v4-only, so these collapse to the +// v4 types. Prefer `StandardSchemaV1` / `StandardSchemaWithJSON` for new code. + +import type * as z from 'zod'; + +/** @deprecated Use `StandardSchemaV1` (any Standard Schema) or a Zod type directly in v2. */ +export type AnySchema = z.core.$ZodType; + +/** @deprecated Use `Record` directly in v2. */ +export type ZodRawShapeCompat = Record; + +/** @deprecated */ +export type AnyObjectSchema = z.core.$ZodObject | AnySchema; + +export type { StandardSchemaV1, StandardSchemaWithJSON } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/shared/auth.ts b/packages/sdk/src/shared/auth.ts new file mode 100644 index 000000000..3c8b49505 --- /dev/null +++ b/packages/sdk/src/shared/auth.ts @@ -0,0 +1,33 @@ +// v1 compat: `@modelcontextprotocol/sdk/shared/auth.js` +export type { + AuthorizationServerMetadata, + OAuthClientInformation, + OAuthClientInformationFull, + OAuthClientInformationMixed, + OAuthClientMetadata, + OAuthClientRegistrationError, + OAuthErrorResponse, + OAuthMetadata, + OAuthProtectedResourceMetadata, + OAuthTokenRevocationRequest, + OAuthTokens, + OpenIdProviderDiscoveryMetadata, + OpenIdProviderMetadata +} from '@modelcontextprotocol/server'; +export { OAuthError, OAuthErrorCode } from '@modelcontextprotocol/server'; +export { + IdJagTokenExchangeResponseSchema, + OAuthClientInformationFullSchema, + OAuthClientInformationSchema, + OAuthClientMetadataSchema, + OAuthClientRegistrationErrorSchema, + OAuthErrorResponseSchema, + OAuthMetadataSchema, + OAuthProtectedResourceMetadataSchema, + OAuthTokenRevocationRequestSchema, + OAuthTokensSchema, + OpenIdProviderDiscoveryMetadataSchema, + OpenIdProviderMetadataSchema, + OptionalSafeUrlSchema, + SafeUrlSchema +} from '@modelcontextprotocol/server/zod-schemas'; diff --git a/packages/sdk/src/shared/protocol.ts b/packages/sdk/src/shared/protocol.ts new file mode 100644 index 000000000..bd9b6cdc1 --- /dev/null +++ b/packages/sdk/src/shared/protocol.ts @@ -0,0 +1,15 @@ +// v1 compat: `@modelcontextprotocol/sdk/shared/protocol.js` + +export type { + BaseContext, + ClientContext, + NotificationOptions, + ProtocolOptions, + RequestOptions, + ServerContext +} from '@modelcontextprotocol/server'; +export { DEFAULT_REQUEST_TIMEOUT_MSEC, Protocol } from '@modelcontextprotocol/server'; + +/** @deprecated Use {@link ServerContext} (server handlers) or {@link ClientContext} (client handlers) in v2. */ +// eslint-disable-next-line @typescript-eslint/no-unused-vars +export type RequestHandlerExtra<_ReqT = unknown, _NotifT = unknown> = import('@modelcontextprotocol/server').ServerContext; diff --git a/packages/sdk/src/shared/stdio.ts b/packages/sdk/src/shared/stdio.ts new file mode 100644 index 000000000..87ff5ce32 --- /dev/null +++ b/packages/sdk/src/shared/stdio.ts @@ -0,0 +1 @@ +export { deserializeMessage, ReadBuffer, serializeMessage } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/shared/transport.ts b/packages/sdk/src/shared/transport.ts new file mode 100644 index 000000000..2cee788b9 --- /dev/null +++ b/packages/sdk/src/shared/transport.ts @@ -0,0 +1,2 @@ +export type { FetchLike, Transport, TransportSendOptions } from '@modelcontextprotocol/server'; +export { createFetchWithInit } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/stdio.ts b/packages/sdk/src/stdio.ts new file mode 100644 index 000000000..491941f10 --- /dev/null +++ b/packages/sdk/src/stdio.ts @@ -0,0 +1,3 @@ +export type { StdioServerParameters } from '@modelcontextprotocol/client'; +export { DEFAULT_INHERITED_ENV_VARS, getDefaultEnvironment, StdioClientTransport } from '@modelcontextprotocol/client'; +export { StdioServerTransport } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/types.ts b/packages/sdk/src/types.ts new file mode 100644 index 000000000..db973e627 --- /dev/null +++ b/packages/sdk/src/types.ts @@ -0,0 +1,38 @@ +// v1 compat: `@modelcontextprotocol/sdk/types.js` +// In v1 this was the giant types.ts file with all spec types + Zod schemas. +// v2 splits them: spec TypeScript types live in the server barrel (via core/public), +// zod schema constants live at @modelcontextprotocol/server/zod-schemas. + +export * from '@modelcontextprotocol/server'; +export * from '@modelcontextprotocol/server/zod-schemas'; +// Explicit tie-break for symbols both barrels export. +export { fromJsonSchema } from '@modelcontextprotocol/server'; +// Explicit re-exports of commonly-used spec types (belt-and-suspenders over the +// wildcard above; some d.ts toolchains drop type-only symbols across export-*). +export type { + CallToolResult, + ClientCapabilities, + GetPromptResult, + Implementation, + ListResourcesResult, + ListToolsResult, + Prompt, + ReadResourceResult, + Resource, + ServerCapabilities, + Tool +} from '@modelcontextprotocol/server'; + +/** + * @deprecated Use {@link ResourceTemplateType}. + * + * v1's `types.js` exported the spec-derived ResourceTemplate data type under + * this name. v2 renamed it to `ResourceTemplateType` to avoid clashing with the + * `ResourceTemplate` helper class exported by `@modelcontextprotocol/server`. + */ +export type { ResourceTemplateType as ResourceTemplate } from '@modelcontextprotocol/server'; + +/** @deprecated Use {@link ProtocolError}. */ +export { ProtocolError as McpError } from '@modelcontextprotocol/server'; +/** @deprecated Use {@link ProtocolErrorCode}. */ +export { ProtocolErrorCode as ErrorCode } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/validation/ajv-provider.ts b/packages/sdk/src/validation/ajv-provider.ts new file mode 100644 index 000000000..8f6e0f51e --- /dev/null +++ b/packages/sdk/src/validation/ajv-provider.ts @@ -0,0 +1 @@ +export { AjvJsonSchemaValidator } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/src/validation/cfworker-provider.ts b/packages/sdk/src/validation/cfworker-provider.ts new file mode 100644 index 000000000..08b23bf34 --- /dev/null +++ b/packages/sdk/src/validation/cfworker-provider.ts @@ -0,0 +1 @@ +export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/server/validators/cf-worker'; diff --git a/packages/sdk/src/validation/types.ts b/packages/sdk/src/validation/types.ts new file mode 100644 index 000000000..b83e4cb5a --- /dev/null +++ b/packages/sdk/src/validation/types.ts @@ -0,0 +1 @@ +export type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator, JsonSchemaValidatorResult } from '@modelcontextprotocol/server'; diff --git a/packages/sdk/test/compat.test.ts b/packages/sdk/test/compat.test.ts new file mode 100644 index 000000000..f254e7c41 --- /dev/null +++ b/packages/sdk/test/compat.test.ts @@ -0,0 +1,40 @@ +import { describe, expect, test } from 'vitest'; + +import * as serverIndex from '../src/server/index.js'; +import * as serverMcp from '../src/server/mcp.js'; +import * as serverStdio from '../src/server/stdio.js'; +import * as serverSHttp from '../src/server/streamableHttp.js'; +import * as sharedProtocol from '../src/shared/protocol.js'; +import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '../src/types.js'; + +describe('@modelcontextprotocol/sdk meta-package v1 paths', () => { + test('types.js re-exports zod schemas + error aliases', () => { + expect(CallToolRequestSchema).toBeDefined(); + expect(ListToolsRequestSchema).toBeDefined(); + expect(McpError).toBeDefined(); + expect(ErrorCode.MethodNotFound).toBeDefined(); + }); + + test('server/mcp.js exports McpServer', () => { + expect(serverMcp.McpServer).toBeDefined(); + expect(serverMcp.ResourceTemplate).toBeDefined(); + }); + + test('server/index.js exports Server (alias)', () => { + expect(serverIndex.Server).toBeDefined(); + }); + + test('server/stdio.js exports StdioServerTransport', () => { + expect(serverStdio.StdioServerTransport).toBeDefined(); + }); + + test('server/streamableHttp.js exports the v1 alias', () => { + expect(serverSHttp.StreamableHTTPServerTransport).toBeDefined(); + expect(serverSHttp.NodeStreamableHTTPServerTransport).toBeDefined(); + }); + + test('shared/protocol.js exports Protocol + RequestHandlerExtra type alias', () => { + expect(sharedProtocol.Protocol).toBeDefined(); + expect(sharedProtocol.DEFAULT_REQUEST_TIMEOUT_MSEC).toBeGreaterThan(0); + }); +}); diff --git a/packages/sdk/tsconfig.json b/packages/sdk/tsconfig.json new file mode 100644 index 000000000..5a30e6c18 --- /dev/null +++ b/packages/sdk/tsconfig.json @@ -0,0 +1,24 @@ +{ + "extends": "@modelcontextprotocol/tsconfig", + "include": ["./"], + "exclude": ["node_modules", "dist"], + "compilerOptions": { + "paths": { + "*": ["./*"], + "@modelcontextprotocol/core": ["./node_modules/@modelcontextprotocol/core/src/index.ts"], + "@modelcontextprotocol/core/public": ["./node_modules/@modelcontextprotocol/core/src/exports/public/index.ts"], + "@modelcontextprotocol/client": ["./node_modules/@modelcontextprotocol/client/src/index.ts"], + "@modelcontextprotocol/client/stdio": ["./node_modules/@modelcontextprotocol/client/src/client/stdio.ts"], + "@modelcontextprotocol/server": ["./node_modules/@modelcontextprotocol/server/src/index.ts"], + "@modelcontextprotocol/server/stdio": ["./node_modules/@modelcontextprotocol/server/src/server/stdio.ts"], + "@modelcontextprotocol/server/zod-schemas": ["./node_modules/@modelcontextprotocol/server/src/zodSchemas.ts"], + "@modelcontextprotocol/server/validators/cf-worker": ["./node_modules/@modelcontextprotocol/server/src/validators/cfWorker.ts"], + "@modelcontextprotocol/node": ["./node_modules/@modelcontextprotocol/node/src/index.ts"], + "@modelcontextprotocol/node/sse": ["./node_modules/@modelcontextprotocol/node/src/sse.ts"], + "@modelcontextprotocol/server-auth-legacy": ["./node_modules/@modelcontextprotocol/server-auth-legacy/src/index.ts"], + "@modelcontextprotocol/client/_shims": ["./node_modules/@modelcontextprotocol/client/src/shimsNode.ts"], + "@modelcontextprotocol/server/_shims": ["./node_modules/@modelcontextprotocol/server/src/shimsNode.ts"], + "@modelcontextprotocol/test-helpers": ["./node_modules/@modelcontextprotocol/test-helpers/src/index.ts"] + } + } +} diff --git a/packages/sdk/tsdown.config.ts b/packages/sdk/tsdown.config.ts new file mode 100644 index 000000000..b6ee521f8 --- /dev/null +++ b/packages/sdk/tsdown.config.ts @@ -0,0 +1,46 @@ +import { defineConfig } from 'tsdown'; + +export default defineConfig({ + failOnWarn: false, + entry: [ + 'src/index.ts', + 'src/stdio.ts', + 'src/types.ts', + 'src/inMemory.ts', + 'src/experimental/tasks.ts', + 'src/server/index.ts', + 'src/server/mcp.ts', + 'src/server/zod-compat.ts', + 'src/server/completable.ts', + 'src/server/sse.ts', + 'src/server/stdio.ts', + 'src/server/streamableHttp.ts', + 'src/server/auth/types.ts', + 'src/server/auth/errors.ts', + 'src/server/auth/middleware/bearerAuth.ts', + 'src/server/auth/router.ts', + 'src/server/auth/provider.ts', + 'src/server/auth/clients.ts', + 'src/client/index.ts', + 'src/client/stdio.ts', + 'src/client/streamableHttp.ts', + 'src/client/sse.ts', + 'src/client/auth.ts', + 'src/shared/protocol.ts', + 'src/shared/transport.ts', + 'src/shared/auth.ts', + 'src/shared/stdio.ts', + 'src/server/webStandardStreamableHttp.ts', + 'src/validation/types.ts', + 'src/validation/cfworker-provider.ts', + 'src/validation/ajv-provider.ts' + ], + format: ['esm'], + outDir: 'dist', + clean: true, + sourcemap: true, + target: 'esnext', + platform: 'node', + dts: true, + external: [/^@modelcontextprotocol\//] +}); diff --git a/packages/sdk/vitest.config.js b/packages/sdk/vitest.config.js new file mode 100644 index 000000000..496fca320 --- /dev/null +++ b/packages/sdk/vitest.config.js @@ -0,0 +1,3 @@ +import baseConfig from '@modelcontextprotocol/vitest-config'; + +export default baseConfig; diff --git a/packages/server-auth-legacy/README.md b/packages/server-auth-legacy/README.md new file mode 100644 index 000000000..a21d0b710 --- /dev/null +++ b/packages/server-auth-legacy/README.md @@ -0,0 +1,22 @@ +# @modelcontextprotocol/server-auth-legacy + + +> [!WARNING] +> **Deprecated.** This package is a frozen copy of the v1 SDK's `src/server/auth/` Authorization Server helpers (`mcpAuthRouter`, `ProxyOAuthServerProvider`, etc.). It exists solely to ease migration from `@modelcontextprotocol/sdk` v1 and will not receive new features or non-critical bug fixes. + +The v2 SDK no longer ships an OAuth Authorization Server implementation. MCP servers are Resource Servers; running your own AS is an anti-pattern for most deployments. + +## Migration + +- **Resource Server glue** (`requireBearerAuth`, `mcpAuthMetadataRouter`, Protected Resource Metadata): use the first-class helpers in `@modelcontextprotocol/express`. +- **Authorization Server**: use a dedicated IdP (Auth0, Keycloak, Okta, etc.) or a purpose-built OAuth library. + +## Usage (legacy) + +```ts +import express from 'express'; +import { mcpAuthRouter, ProxyOAuthServerProvider } from '@modelcontextprotocol/server-auth-legacy'; + +const app = express(); +app.use(mcpAuthRouter({ provider, issuerUrl: new URL('https://example.com') })); +``` diff --git a/packages/server-auth-legacy/eslint.config.mjs b/packages/server-auth-legacy/eslint.config.mjs new file mode 100644 index 000000000..4f034f223 --- /dev/null +++ b/packages/server-auth-legacy/eslint.config.mjs @@ -0,0 +1,12 @@ +// @ts-check + +import baseConfig from '@modelcontextprotocol/eslint-config'; + +export default [ + ...baseConfig, + { + settings: { + 'import/internal-regex': '^@modelcontextprotocol/core' + } + } +]; diff --git a/packages/server-auth-legacy/package.json b/packages/server-auth-legacy/package.json new file mode 100644 index 000000000..39c505619 --- /dev/null +++ b/packages/server-auth-legacy/package.json @@ -0,0 +1,84 @@ +{ + "name": "@modelcontextprotocol/server-auth-legacy", + "private": false, + "version": "2.0.0-alpha.2", + "description": "Frozen v1 OAuth Authorization Server helpers (mcpAuthRouter, ProxyOAuthServerProvider) for the Model Context Protocol TypeScript SDK. Deprecated; use a dedicated OAuth server in production.", + "deprecated": "The MCP SDK no longer ships an Authorization Server implementation. This package is a frozen copy of the v1 src/server/auth helpers for migration purposes only and will not receive new features. Use a dedicated OAuth Authorization Server (e.g. an IdP) and the Resource Server helpers in @modelcontextprotocol/express instead.", + "license": "MIT", + "author": "Anthropic, PBC (https://anthropic.com)", + "homepage": "https://modelcontextprotocol.io", + "bugs": "https://github.com/modelcontextprotocol/typescript-sdk/issues", + "type": "module", + "repository": { + "type": "git", + "url": "git+https://github.com/modelcontextprotocol/typescript-sdk.git" + }, + "engines": { + "node": ">=20" + }, + "keywords": [ + "modelcontextprotocol", + "mcp", + "oauth", + "express", + "legacy" + ], + "types": "./dist/index.d.mts", + "exports": { + ".": { + "types": "./dist/index.d.mts", + "import": "./dist/index.mjs", + "require": "./dist/index.mjs" + } + }, + "files": [ + "dist" + ], + "scripts": { + "typecheck": "tsgo -p tsconfig.json --noEmit", + "build": "tsdown", + "build:watch": "tsdown --watch", + "prepack": "npm run build", + "lint": "eslint src/ && prettier --ignore-path ../../.prettierignore --check .", + "lint:fix": "eslint src/ --fix && prettier --ignore-path ../../.prettierignore --write .", + "check": "pnpm run typecheck && pnpm run lint", + "test": "vitest run", + "test:watch": "vitest" + }, + "dependencies": { + "cors": "catalog:runtimeServerOnly", + "express-rate-limit": "^8.2.1", + "pkce-challenge": "catalog:runtimeShared", + "zod": "catalog:runtimeShared" + }, + "peerDependencies": { + "express": "catalog:runtimeServerOnly" + }, + "peerDependenciesMeta": { + "express": { + "optional": true + } + }, + "devDependencies": { + "@modelcontextprotocol/core": "workspace:^", + "@modelcontextprotocol/tsconfig": "workspace:^", + "@modelcontextprotocol/vitest-config": "workspace:^", + "@modelcontextprotocol/eslint-config": "workspace:^", + "@eslint/js": "catalog:devTools", + "@types/cors": "catalog:devTools", + "@types/express": "catalog:devTools", + "@types/express-serve-static-core": "catalog:devTools", + "@types/supertest": "catalog:devTools", + "@typescript/native-preview": "catalog:devTools", + "eslint": "catalog:devTools", + "eslint-config-prettier": "catalog:devTools", + "eslint-plugin-n": "catalog:devTools", + "express": "catalog:runtimeServerOnly", + "prettier": "catalog:devTools", + "supertest": "catalog:devTools", + "tsdown": "catalog:devTools", + "typescript": "catalog:devTools", + "typescript-eslint": "catalog:devTools", + "vitest": "catalog:devTools" + } +} diff --git a/packages/server-auth-legacy/src/clients.ts b/packages/server-auth-legacy/src/clients.ts new file mode 100644 index 000000000..f6aca1be9 --- /dev/null +++ b/packages/server-auth-legacy/src/clients.ts @@ -0,0 +1,22 @@ +import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; + +/** + * Stores information about registered OAuth clients for this server. + */ +export interface OAuthRegisteredClientsStore { + /** + * Returns information about a registered client, based on its ID. + */ + getClient(clientId: string): OAuthClientInformationFull | undefined | Promise; + + /** + * Registers a new client with the server. The client ID and secret will be automatically generated by the library. A modified version of the client information can be returned to reflect specific values enforced by the server. + * + * NOTE: Implementations should NOT delete expired client secrets in-place. Auth middleware provided by this library will automatically check the `client_secret_expires_at` field and reject requests with expired secrets. Any custom logic for authenticating clients should check the `client_secret_expires_at` field as well. + * + * If unimplemented, dynamic client registration is unsupported. + */ + registerClient?( + client: Omit + ): OAuthClientInformationFull | Promise; +} diff --git a/packages/server-auth-legacy/src/errors.ts b/packages/server-auth-legacy/src/errors.ts new file mode 100644 index 000000000..eac277779 --- /dev/null +++ b/packages/server-auth-legacy/src/errors.ts @@ -0,0 +1,212 @@ +import type { OAuthErrorResponse } from '@modelcontextprotocol/core'; + +/** + * Base class for all OAuth errors + */ +export class OAuthError extends Error { + static errorCode: string; + + constructor( + message: string, + public readonly errorUri?: string + ) { + super(message); + this.name = this.constructor.name; + } + + /** + * Converts the error to a standard OAuth error response object + */ + toResponseObject(): OAuthErrorResponse { + const response: OAuthErrorResponse = { + error: this.errorCode, + error_description: this.message + }; + + if (this.errorUri) { + response.error_uri = this.errorUri; + } + + return response; + } + + get errorCode(): string { + return (this.constructor as typeof OAuthError).errorCode; + } +} + +/** + * Invalid request error - The request is missing a required parameter, + * includes an invalid parameter value, includes a parameter more than once, + * or is otherwise malformed. + */ +export class InvalidRequestError extends OAuthError { + static override errorCode = 'invalid_request'; +} + +/** + * Invalid client error - Client authentication failed (e.g., unknown client, no client + * authentication included, or unsupported authentication method). + */ +export class InvalidClientError extends OAuthError { + static override errorCode = 'invalid_client'; +} + +/** + * Invalid grant error - The provided authorization grant or refresh token is + * invalid, expired, revoked, does not match the redirection URI used in the + * authorization request, or was issued to another client. + */ +export class InvalidGrantError extends OAuthError { + static override errorCode = 'invalid_grant'; +} + +/** + * Unauthorized client error - The authenticated client is not authorized to use + * this authorization grant type. + */ +export class UnauthorizedClientError extends OAuthError { + static override errorCode = 'unauthorized_client'; +} + +/** + * Unsupported grant type error - The authorization grant type is not supported + * by the authorization server. + */ +export class UnsupportedGrantTypeError extends OAuthError { + static override errorCode = 'unsupported_grant_type'; +} + +/** + * Invalid scope error - The requested scope is invalid, unknown, malformed, or + * exceeds the scope granted by the resource owner. + */ +export class InvalidScopeError extends OAuthError { + static override errorCode = 'invalid_scope'; +} + +/** + * Access denied error - The resource owner or authorization server denied the request. + */ +export class AccessDeniedError extends OAuthError { + static override errorCode = 'access_denied'; +} + +/** + * Server error - The authorization server encountered an unexpected condition + * that prevented it from fulfilling the request. + */ +export class ServerError extends OAuthError { + static override errorCode = 'server_error'; +} + +/** + * Temporarily unavailable error - The authorization server is currently unable to + * handle the request due to a temporary overloading or maintenance of the server. + */ +export class TemporarilyUnavailableError extends OAuthError { + static override errorCode = 'temporarily_unavailable'; +} + +/** + * Unsupported response type error - The authorization server does not support + * obtaining an authorization code using this method. + */ +export class UnsupportedResponseTypeError extends OAuthError { + static override errorCode = 'unsupported_response_type'; +} + +/** + * Unsupported token type error - The authorization server does not support + * the requested token type. + */ +export class UnsupportedTokenTypeError extends OAuthError { + static override errorCode = 'unsupported_token_type'; +} + +/** + * Invalid token error - The access token provided is expired, revoked, malformed, + * or invalid for other reasons. + */ +export class InvalidTokenError extends OAuthError { + static override errorCode = 'invalid_token'; +} + +/** + * Method not allowed error - The HTTP method used is not allowed for this endpoint. + * (Custom, non-standard error) + */ +export class MethodNotAllowedError extends OAuthError { + static override errorCode = 'method_not_allowed'; +} + +/** + * Too many requests error - Rate limit exceeded. + * (Custom, non-standard error based on RFC 6585) + */ +export class TooManyRequestsError extends OAuthError { + static override errorCode = 'too_many_requests'; +} + +/** + * Invalid client metadata error - The client metadata is invalid. + * (Custom error for dynamic client registration - RFC 7591) + */ +export class InvalidClientMetadataError extends OAuthError { + static override errorCode = 'invalid_client_metadata'; +} + +/** + * Insufficient scope error - The request requires higher privileges than provided by the access token. + */ +export class InsufficientScopeError extends OAuthError { + static override errorCode = 'insufficient_scope'; +} + +/** + * Invalid target error - The requested resource is invalid, missing, unknown, or malformed. + * (Custom error for resource indicators - RFC 8707) + */ +export class InvalidTargetError extends OAuthError { + static override errorCode = 'invalid_target'; +} + +/** + * A utility class for defining one-off error codes + */ +export class CustomOAuthError extends OAuthError { + constructor( + private readonly customErrorCode: string, + message: string, + errorUri?: string + ) { + super(message, errorUri); + } + + override get errorCode(): string { + return this.customErrorCode; + } +} + +/** + * A full list of all OAuthErrors, enabling parsing from error responses + */ +export const OAUTH_ERRORS = { + [InvalidRequestError.errorCode]: InvalidRequestError, + [InvalidClientError.errorCode]: InvalidClientError, + [InvalidGrantError.errorCode]: InvalidGrantError, + [UnauthorizedClientError.errorCode]: UnauthorizedClientError, + [UnsupportedGrantTypeError.errorCode]: UnsupportedGrantTypeError, + [InvalidScopeError.errorCode]: InvalidScopeError, + [AccessDeniedError.errorCode]: AccessDeniedError, + [ServerError.errorCode]: ServerError, + [TemporarilyUnavailableError.errorCode]: TemporarilyUnavailableError, + [UnsupportedResponseTypeError.errorCode]: UnsupportedResponseTypeError, + [UnsupportedTokenTypeError.errorCode]: UnsupportedTokenTypeError, + [InvalidTokenError.errorCode]: InvalidTokenError, + [MethodNotAllowedError.errorCode]: MethodNotAllowedError, + [TooManyRequestsError.errorCode]: TooManyRequestsError, + [InvalidClientMetadataError.errorCode]: InvalidClientMetadataError, + [InsufficientScopeError.errorCode]: InsufficientScopeError, + [InvalidTargetError.errorCode]: InvalidTargetError +} as const; diff --git a/packages/server-auth-legacy/src/handlers/authorize.ts b/packages/server-auth-legacy/src/handlers/authorize.ts new file mode 100644 index 000000000..3f84bf329 --- /dev/null +++ b/packages/server-auth-legacy/src/handlers/authorize.ts @@ -0,0 +1,203 @@ +import type { RequestHandler } from 'express'; +import express from 'express'; +import type { Options as RateLimitOptions } from 'express-rate-limit'; +import { rateLimit } from 'express-rate-limit'; +import * as z from 'zod/v4'; + +import { InvalidClientError, InvalidRequestError, OAuthError, ServerError, TooManyRequestsError } from '../errors.js'; +import { allowedMethods } from '../middleware/allowedMethods.js'; +import type { OAuthServerProvider } from '../provider.js'; + +export type AuthorizationHandlerOptions = { + provider: OAuthServerProvider; + /** + * Rate limiting configuration for the authorization endpoint. + * Set to false to disable rate limiting for this endpoint. + */ + rateLimit?: Partial | false; +}; + +const LOOPBACK_HOSTS = new Set(['localhost', '127.0.0.1', '[::1]']); + +/** + * Validates a requested redirect_uri against a registered one. + * + * Per RFC 8252 §7.3 (OAuth 2.0 for Native Apps), authorization servers MUST + * allow any port for loopback redirect URIs (localhost, 127.0.0.1, [::1]) to + * accommodate native clients that obtain an ephemeral port from the OS. For + * non-loopback URIs, exact match is required. + * + * @see https://datatracker.ietf.org/doc/html/rfc8252#section-7.3 + */ +export function redirectUriMatches(requested: string, registered: string): boolean { + if (requested === registered) { + return true; + } + let req: URL, reg: URL; + try { + req = new URL(requested); + reg = new URL(registered); + } catch { + return false; + } + // Port relaxation only applies when both URIs target a loopback host. + if (!LOOPBACK_HOSTS.has(req.hostname) || !LOOPBACK_HOSTS.has(reg.hostname)) { + return false; + } + // RFC 8252 relaxes the port only — scheme, host, path, and query must + // still match exactly. Note: hostname must match exactly too (the RFC + // does not allow localhost↔127.0.0.1 cross-matching). + return req.protocol === reg.protocol && req.hostname === reg.hostname && req.pathname === reg.pathname && req.search === reg.search; +} + +// Parameters that must be validated in order to issue redirects. +const ClientAuthorizationParamsSchema = z.object({ + client_id: z.string(), + redirect_uri: z + .string() + .optional() + .refine(value => value === undefined || URL.canParse(value), { message: 'redirect_uri must be a valid URL' }) +}); + +// Parameters that must be validated for a successful authorization request. Failure can be reported to the redirect URI. +const RequestAuthorizationParamsSchema = z.object({ + response_type: z.literal('code'), + code_challenge: z.string(), + code_challenge_method: z.literal('S256'), + scope: z.string().optional(), + state: z.string().optional(), + resource: z.string().url().optional() +}); + +export function authorizationHandler({ provider, rateLimit: rateLimitConfig }: AuthorizationHandlerOptions): RequestHandler { + // Create a router to apply middleware + const router = express.Router(); + router.use(allowedMethods(['GET', 'POST'])); + router.use(express.urlencoded({ extended: false })); + + // Apply rate limiting unless explicitly disabled + if (rateLimitConfig !== false) { + router.use( + rateLimit({ + windowMs: 15 * 60 * 1000, // 15 minutes + max: 100, // 100 requests per windowMs + standardHeaders: true, + legacyHeaders: false, + message: new TooManyRequestsError('You have exceeded the rate limit for authorization requests').toResponseObject(), + ...rateLimitConfig + }) + ); + } + + router.all('/', async (req, res) => { + res.setHeader('Cache-Control', 'no-store'); + + // In the authorization flow, errors are split into two categories: + // 1. Pre-redirect errors (direct response with 400) + // 2. Post-redirect errors (redirect with error parameters) + + // Phase 1: Validate client_id and redirect_uri. Any errors here must be direct responses. + let client_id, client; + let redirect_uri: string; + try { + const result = ClientAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); + if (!result.success) { + throw new InvalidRequestError(result.error.message); + } + + client_id = result.data.client_id; + const requested_redirect_uri = result.data.redirect_uri; + + client = await provider.clientsStore.getClient(client_id); + if (!client) { + throw new InvalidClientError('Invalid client_id'); + } + + if (requested_redirect_uri !== undefined) { + const requested = requested_redirect_uri; + if (!client.redirect_uris.some(registered => redirectUriMatches(requested, registered))) { + throw new InvalidRequestError('Unregistered redirect_uri'); + } + redirect_uri = requested_redirect_uri; + } else if (client.redirect_uris.length === 1) { + redirect_uri = client.redirect_uris[0]!; + } else { + throw new InvalidRequestError('redirect_uri must be specified when client has multiple registered URIs'); + } + } catch (error) { + // Pre-redirect errors - return direct response + // + // These don't need to be JSON encoded, as they'll be displayed in a user + // agent, but OTOH they all represent exceptional situations (arguably, + // "programmer error"), so presenting a nice HTML page doesn't help the + // user anyway. + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + + return; + } + + // Phase 2: Validate other parameters. Any errors here should go into redirect responses. + let state; + try { + // Parse and validate authorization parameters + const parseResult = RequestAuthorizationParamsSchema.safeParse(req.method === 'POST' ? req.body : req.query); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } + + const { scope, code_challenge, resource } = parseResult.data; + state = parseResult.data.state; + + // Validate scopes + let requestedScopes: string[] = []; + if (scope !== undefined) { + requestedScopes = scope.split(' '); + } + + // All validation passed, proceed with authorization + await provider.authorize( + client, + { + state, + scopes: requestedScopes, + redirectUri: redirect_uri, + codeChallenge: code_challenge, + resource: resource ? new URL(resource) : undefined + }, + res + ); + } catch (error) { + // Post-redirect errors - redirect with error parameters + if (error instanceof OAuthError) { + res.redirect(302, createErrorRedirect(redirect_uri, error, state)); + } else { + const serverError = new ServerError('Internal Server Error'); + res.redirect(302, createErrorRedirect(redirect_uri, serverError, state)); + } + } + }); + + return router; +} + +/** + * Helper function to create redirect URL with error parameters + */ +function createErrorRedirect(redirectUri: string, error: OAuthError, state?: string): string { + const errorUrl = new URL(redirectUri); + errorUrl.searchParams.set('error', error.errorCode); + errorUrl.searchParams.set('error_description', error.message); + if (error.errorUri) { + errorUrl.searchParams.set('error_uri', error.errorUri); + } + if (state) { + errorUrl.searchParams.set('state', state); + } + return errorUrl.href; +} diff --git a/packages/server-auth-legacy/src/handlers/metadata.ts b/packages/server-auth-legacy/src/handlers/metadata.ts new file mode 100644 index 000000000..529a6e57a --- /dev/null +++ b/packages/server-auth-legacy/src/handlers/metadata.ts @@ -0,0 +1,21 @@ +import type { OAuthMetadata, OAuthProtectedResourceMetadata } from '@modelcontextprotocol/core'; +import cors from 'cors'; +import type { RequestHandler } from 'express'; +import express from 'express'; + +import { allowedMethods } from '../middleware/allowedMethods.js'; + +export function metadataHandler(metadata: OAuthMetadata | OAuthProtectedResourceMetadata): RequestHandler { + // Nested router so we can configure middleware and restrict HTTP method + const router = express.Router(); + + // Configure CORS to allow any origin, to make accessible to web-based MCP clients + router.use(cors()); + + router.use(allowedMethods(['GET', 'OPTIONS'])); + router.get('/', (req, res) => { + res.status(200).json(metadata); + }); + + return router; +} diff --git a/packages/server-auth-legacy/src/handlers/register.ts b/packages/server-auth-legacy/src/handlers/register.ts new file mode 100644 index 000000000..6ca5324eb --- /dev/null +++ b/packages/server-auth-legacy/src/handlers/register.ts @@ -0,0 +1,124 @@ +import crypto from 'node:crypto'; + +import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; +import { OAuthClientMetadataSchema } from '@modelcontextprotocol/core'; +import cors from 'cors'; +import type { RequestHandler } from 'express'; +import express from 'express'; +import type { Options as RateLimitOptions } from 'express-rate-limit'; +import { rateLimit } from 'express-rate-limit'; + +import type { OAuthRegisteredClientsStore } from '../clients.js'; +import { InvalidClientMetadataError, OAuthError, ServerError, TooManyRequestsError } from '../errors.js'; +import { allowedMethods } from '../middleware/allowedMethods.js'; + +export type ClientRegistrationHandlerOptions = { + /** + * A store used to save information about dynamically registered OAuth clients. + */ + clientsStore: OAuthRegisteredClientsStore; + + /** + * The number of seconds after which to expire issued client secrets, or 0 to prevent expiration of client secrets (not recommended). + * + * If not set, defaults to 30 days. + */ + clientSecretExpirySeconds?: number; + + /** + * Rate limiting configuration for the client registration endpoint. + * Set to false to disable rate limiting for this endpoint. + * Registration endpoints are particularly sensitive to abuse and should be rate limited. + */ + rateLimit?: Partial | false; + + /** + * Whether to generate a client ID before calling the client registration endpoint. + * + * If not set, defaults to true. + */ + clientIdGeneration?: boolean; +}; + +const DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS = 30 * 24 * 60 * 60; // 30 days + +export function clientRegistrationHandler({ + clientsStore, + clientSecretExpirySeconds = DEFAULT_CLIENT_SECRET_EXPIRY_SECONDS, + rateLimit: rateLimitConfig, + clientIdGeneration = true +}: ClientRegistrationHandlerOptions): RequestHandler { + if (!clientsStore.registerClient) { + throw new Error('Client registration store does not support registering clients'); + } + + // Nested router so we can configure middleware and restrict HTTP method + const router = express.Router(); + + // Configure CORS to allow any origin, to make accessible to web-based MCP clients + router.use(cors()); + + router.use(allowedMethods(['POST'])); + router.use(express.json()); + + // Apply rate limiting unless explicitly disabled - stricter limits for registration + if (rateLimitConfig !== false) { + router.use( + rateLimit({ + windowMs: 60 * 60 * 1000, // 1 hour + max: 20, // 20 requests per hour - stricter as registration is sensitive + standardHeaders: true, + legacyHeaders: false, + message: new TooManyRequestsError('You have exceeded the rate limit for client registration requests').toResponseObject(), + ...rateLimitConfig + }) + ); + } + + router.post('/', async (req, res) => { + res.setHeader('Cache-Control', 'no-store'); + + try { + const parseResult = OAuthClientMetadataSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidClientMetadataError(parseResult.error.message); + } + + const clientMetadata = parseResult.data; + const isPublicClient = clientMetadata.token_endpoint_auth_method === 'none'; + + // Generate client credentials + const clientSecret = isPublicClient ? undefined : crypto.randomBytes(32).toString('hex'); + const clientIdIssuedAt = Math.floor(Date.now() / 1000); + + // Calculate client secret expiry time + const clientsDoExpire = clientSecretExpirySeconds > 0; + const secretExpiryTime = clientsDoExpire ? clientIdIssuedAt + clientSecretExpirySeconds : 0; + const clientSecretExpiresAt = isPublicClient ? undefined : secretExpiryTime; + + let clientInfo: Omit & { client_id?: string } = { + ...clientMetadata, + client_secret: clientSecret, + client_secret_expires_at: clientSecretExpiresAt + }; + + if (clientIdGeneration) { + clientInfo.client_id = crypto.randomUUID(); + clientInfo.client_id_issued_at = clientIdIssuedAt; + } + + clientInfo = await clientsStore.registerClient!(clientInfo); + res.status(201).json(clientInfo); + } catch (error) { + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }); + + return router; +} diff --git a/packages/server-auth-legacy/src/handlers/revoke.ts b/packages/server-auth-legacy/src/handlers/revoke.ts new file mode 100644 index 000000000..2a34a4449 --- /dev/null +++ b/packages/server-auth-legacy/src/handlers/revoke.ts @@ -0,0 +1,82 @@ +import { OAuthTokenRevocationRequestSchema } from '@modelcontextprotocol/core'; +import cors from 'cors'; +import type { RequestHandler } from 'express'; +import express from 'express'; +import type { Options as RateLimitOptions } from 'express-rate-limit'; +import { rateLimit } from 'express-rate-limit'; + +import { InvalidRequestError, OAuthError, ServerError, TooManyRequestsError } from '../errors.js'; +import { allowedMethods } from '../middleware/allowedMethods.js'; +import { authenticateClient } from '../middleware/clientAuth.js'; +import type { OAuthServerProvider } from '../provider.js'; + +export type RevocationHandlerOptions = { + provider: OAuthServerProvider; + /** + * Rate limiting configuration for the token revocation endpoint. + * Set to false to disable rate limiting for this endpoint. + */ + rateLimit?: Partial | false; +}; + +export function revocationHandler({ provider, rateLimit: rateLimitConfig }: RevocationHandlerOptions): RequestHandler { + if (!provider.revokeToken) { + throw new Error('Auth provider does not support revoking tokens'); + } + + // Nested router so we can configure middleware and restrict HTTP method + const router = express.Router(); + + // Configure CORS to allow any origin, to make accessible to web-based MCP clients + router.use(cors()); + + router.use(allowedMethods(['POST'])); + router.use(express.urlencoded({ extended: false })); + + // Apply rate limiting unless explicitly disabled + if (rateLimitConfig !== false) { + router.use( + rateLimit({ + windowMs: 15 * 60 * 1000, // 15 minutes + max: 50, // 50 requests per windowMs + standardHeaders: true, + legacyHeaders: false, + message: new TooManyRequestsError('You have exceeded the rate limit for token revocation requests').toResponseObject(), + ...rateLimitConfig + }) + ); + } + + // Authenticate and extract client details + router.use(authenticateClient({ clientsStore: provider.clientsStore })); + + router.post('/', async (req, res) => { + res.setHeader('Cache-Control', 'no-store'); + + try { + const parseResult = OAuthTokenRevocationRequestSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } + + const client = req.client; + if (!client) { + // This should never happen + throw new ServerError('Internal Server Error'); + } + + await provider.revokeToken!(client, parseResult.data); + res.status(200).json({}); + } catch (error) { + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }); + + return router; +} diff --git a/packages/server-auth-legacy/src/handlers/token.ts b/packages/server-auth-legacy/src/handlers/token.ts new file mode 100644 index 000000000..7adfa7809 --- /dev/null +++ b/packages/server-auth-legacy/src/handlers/token.ts @@ -0,0 +1,160 @@ +import cors from 'cors'; +import type { RequestHandler } from 'express'; +import express from 'express'; +import type { Options as RateLimitOptions } from 'express-rate-limit'; +import { rateLimit } from 'express-rate-limit'; +import { verifyChallenge } from 'pkce-challenge'; +import * as z from 'zod/v4'; + +import { + InvalidGrantError, + InvalidRequestError, + OAuthError, + ServerError, + TooManyRequestsError, + UnsupportedGrantTypeError +} from '../errors.js'; +import { allowedMethods } from '../middleware/allowedMethods.js'; +import { authenticateClient } from '../middleware/clientAuth.js'; +import type { OAuthServerProvider } from '../provider.js'; + +export type TokenHandlerOptions = { + provider: OAuthServerProvider; + /** + * Rate limiting configuration for the token endpoint. + * Set to false to disable rate limiting for this endpoint. + */ + rateLimit?: Partial | false; +}; + +const TokenRequestSchema = z.object({ + grant_type: z.string() +}); + +const AuthorizationCodeGrantSchema = z.object({ + code: z.string(), + code_verifier: z.string(), + redirect_uri: z.string().optional(), + resource: z.string().url().optional() +}); + +const RefreshTokenGrantSchema = z.object({ + refresh_token: z.string(), + scope: z.string().optional(), + resource: z.string().url().optional() +}); + +export function tokenHandler({ provider, rateLimit: rateLimitConfig }: TokenHandlerOptions): RequestHandler { + // Nested router so we can configure middleware and restrict HTTP method + const router = express.Router(); + + // Configure CORS to allow any origin, to make accessible to web-based MCP clients + router.use(cors()); + + router.use(allowedMethods(['POST'])); + router.use(express.urlencoded({ extended: false })); + + // Apply rate limiting unless explicitly disabled + if (rateLimitConfig !== false) { + router.use( + rateLimit({ + windowMs: 15 * 60 * 1000, // 15 minutes + max: 50, // 50 requests per windowMs + standardHeaders: true, + legacyHeaders: false, + message: new TooManyRequestsError('You have exceeded the rate limit for token requests').toResponseObject(), + ...rateLimitConfig + }) + ); + } + + // Authenticate and extract client details + router.use(authenticateClient({ clientsStore: provider.clientsStore })); + + router.post('/', async (req, res) => { + res.setHeader('Cache-Control', 'no-store'); + + try { + const parseResult = TokenRequestSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } + + const { grant_type } = parseResult.data; + + const client = req.client; + if (!client) { + // This should never happen + throw new ServerError('Internal Server Error'); + } + + switch (grant_type) { + case 'authorization_code': { + const parseResult = AuthorizationCodeGrantSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } + + const { code, code_verifier, redirect_uri, resource } = parseResult.data; + + const skipLocalPkceValidation = provider.skipLocalPkceValidation; + + // Perform local PKCE validation unless explicitly skipped + // (e.g. to validate code_verifier in upstream server) + if (!skipLocalPkceValidation) { + const codeChallenge = await provider.challengeForAuthorizationCode(client, code); + if (!(await verifyChallenge(code_verifier, codeChallenge))) { + throw new InvalidGrantError('code_verifier does not match the challenge'); + } + } + + // Passes the code_verifier to the provider if PKCE validation didn't occur locally + const tokens = await provider.exchangeAuthorizationCode( + client, + code, + skipLocalPkceValidation ? code_verifier : undefined, + redirect_uri, + resource ? new URL(resource) : undefined + ); + res.status(200).json(tokens); + break; + } + + case 'refresh_token': { + const parseResult = RefreshTokenGrantSchema.safeParse(req.body); + if (!parseResult.success) { + throw new InvalidRequestError(parseResult.error.message); + } + + const { refresh_token, scope, resource } = parseResult.data; + + const scopes = scope?.split(' '); + const tokens = await provider.exchangeRefreshToken( + client, + refresh_token, + scopes, + resource ? new URL(resource) : undefined + ); + res.status(200).json(tokens); + break; + } + // Additional auth methods will not be added on the server side of the SDK. + // eslint-disable-next-line unicorn/no-useless-switch-case -- frozen v1 copy; explicit for clarity + case 'client_credentials': + default: { + throw new UnsupportedGrantTypeError('The grant type is not supported by this authorization server.'); + } + } + } catch (error) { + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }); + + return router; +} diff --git a/packages/server-auth-legacy/src/index.ts b/packages/server-auth-legacy/src/index.ts new file mode 100644 index 000000000..166384f30 --- /dev/null +++ b/packages/server-auth-legacy/src/index.ts @@ -0,0 +1,34 @@ +/** + * @packageDocumentation + * + * Frozen copy of the v1 SDK's `src/server/auth/` Authorization Server helpers. + * + * @deprecated The MCP SDK no longer ships an Authorization Server implementation. + * This package exists solely to ease migration from `@modelcontextprotocol/sdk` v1 + * and will not receive new features. Use a dedicated OAuth Authorization Server + * (e.g. an IdP) and the Resource Server helpers in `@modelcontextprotocol/express` + * instead. + */ + +export type { OAuthRegisteredClientsStore } from './clients.js'; +export * from './errors.js'; +export type { AuthorizationHandlerOptions } from './handlers/authorize.js'; +export { authorizationHandler, redirectUriMatches } from './handlers/authorize.js'; +export { metadataHandler } from './handlers/metadata.js'; +export type { ClientRegistrationHandlerOptions } from './handlers/register.js'; +export { clientRegistrationHandler } from './handlers/register.js'; +export type { RevocationHandlerOptions } from './handlers/revoke.js'; +export { revocationHandler } from './handlers/revoke.js'; +export type { TokenHandlerOptions } from './handlers/token.js'; +export { tokenHandler } from './handlers/token.js'; +export { allowedMethods } from './middleware/allowedMethods.js'; +export type { BearerAuthMiddlewareOptions } from './middleware/bearerAuth.js'; +export { requireBearerAuth } from './middleware/bearerAuth.js'; +export type { ClientAuthenticationMiddlewareOptions } from './middleware/clientAuth.js'; +export { authenticateClient } from './middleware/clientAuth.js'; +export type { AuthorizationParams, OAuthServerProvider, OAuthTokenVerifier } from './provider.js'; +export type { ProxyEndpoints, ProxyOptions } from './providers/proxyProvider.js'; +export { ProxyOAuthServerProvider } from './providers/proxyProvider.js'; +export type { AuthMetadataOptions, AuthRouterOptions } from './router.js'; +export { createOAuthMetadata, getOAuthProtectedResourceMetadataUrl, mcpAuthMetadataRouter, mcpAuthRouter } from './router.js'; +export type { AuthInfo } from './types.js'; diff --git a/packages/server-auth-legacy/src/middleware/allowedMethods.ts b/packages/server-auth-legacy/src/middleware/allowedMethods.ts new file mode 100644 index 000000000..b24dac3f2 --- /dev/null +++ b/packages/server-auth-legacy/src/middleware/allowedMethods.ts @@ -0,0 +1,21 @@ +import type { RequestHandler } from 'express'; + +import { MethodNotAllowedError } from '../errors.js'; + +/** + * Middleware to handle unsupported HTTP methods with a 405 Method Not Allowed response. + * + * @param allowedMethods Array of allowed HTTP methods for this endpoint (e.g., ['GET', 'POST']) + * @returns Express middleware that returns a 405 error if method not in allowed list + */ +export function allowedMethods(allowedMethods: string[]): RequestHandler { + return (req, res, next) => { + if (allowedMethods.includes(req.method)) { + next(); + return; + } + + const error = new MethodNotAllowedError(`The method ${req.method} is not allowed for this endpoint`); + res.status(405).set('Allow', allowedMethods.join(', ')).json(error.toResponseObject()); + }; +} diff --git a/packages/server-auth-legacy/src/middleware/bearerAuth.ts b/packages/server-auth-legacy/src/middleware/bearerAuth.ts new file mode 100644 index 000000000..247a3f152 --- /dev/null +++ b/packages/server-auth-legacy/src/middleware/bearerAuth.ts @@ -0,0 +1,104 @@ +import type { RequestHandler } from 'express'; + +import { InsufficientScopeError, InvalidTokenError, OAuthError, ServerError } from '../errors.js'; +import type { OAuthTokenVerifier } from '../provider.js'; +import type { AuthInfo } from '../types.js'; + +export type BearerAuthMiddlewareOptions = { + /** + * A provider used to verify tokens. + */ + verifier: OAuthTokenVerifier; + + /** + * Optional scopes that the token must have. + */ + requiredScopes?: string[]; + + /** + * Optional resource metadata URL to include in WWW-Authenticate header. + */ + resourceMetadataUrl?: string; +}; + +declare module 'express-serve-static-core' { + interface Request { + /** + * Information about the validated access token, if the `requireBearerAuth` middleware was used. + */ + auth?: AuthInfo; + } +} + +/** + * Middleware that requires a valid Bearer token in the Authorization header. + * + * This will validate the token with the auth provider and add the resulting auth info to the request object. + * + * If resourceMetadataUrl is provided, it will be included in the WWW-Authenticate header + * for 401 responses as per the OAuth 2.0 Protected Resource Metadata spec. + */ +export function requireBearerAuth({ verifier, requiredScopes = [], resourceMetadataUrl }: BearerAuthMiddlewareOptions): RequestHandler { + return async (req, res, next) => { + try { + const authHeader = req.headers.authorization; + if (!authHeader) { + throw new InvalidTokenError('Missing Authorization header'); + } + + const [type, token] = authHeader.split(' '); + if (type?.toLowerCase() !== 'bearer' || !token) { + throw new InvalidTokenError("Invalid Authorization header format, expected 'Bearer TOKEN'"); + } + + const authInfo = await verifier.verifyAccessToken(token); + + // Check if token has the required scopes (if any) + if (requiredScopes.length > 0) { + const hasAllScopes = requiredScopes.every(scope => authInfo.scopes.includes(scope)); + + if (!hasAllScopes) { + throw new InsufficientScopeError('Insufficient scope'); + } + } + + // Check if the token is set to expire or if it is expired + if (typeof authInfo.expiresAt !== 'number' || Number.isNaN(authInfo.expiresAt)) { + throw new InvalidTokenError('Token has no expiration time'); + } else if (authInfo.expiresAt < Date.now() / 1000) { + throw new InvalidTokenError('Token has expired'); + } + + req.auth = authInfo; + next(); + } catch (error) { + // Build WWW-Authenticate header parts + // eslint-disable-next-line unicorn/consistent-function-scoping -- frozen v1 copy; closes over middleware options + const buildWwwAuthHeader = (errorCode: string, message: string): string => { + let header = `Bearer error="${errorCode}", error_description="${message}"`; + if (requiredScopes.length > 0) { + header += `, scope="${requiredScopes.join(' ')}"`; + } + if (resourceMetadataUrl) { + header += `, resource_metadata="${resourceMetadataUrl}"`; + } + return header; + }; + + if (error instanceof InvalidTokenError) { + res.set('WWW-Authenticate', buildWwwAuthHeader(error.errorCode, error.message)); + res.status(401).json(error.toResponseObject()); + } else if (error instanceof InsufficientScopeError) { + res.set('WWW-Authenticate', buildWwwAuthHeader(error.errorCode, error.message)); + res.status(403).json(error.toResponseObject()); + } else if (error instanceof ServerError) { + res.status(500).json(error.toResponseObject()); + } else if (error instanceof OAuthError) { + res.status(400).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }; +} diff --git a/packages/server-auth-legacy/src/middleware/clientAuth.ts b/packages/server-auth-legacy/src/middleware/clientAuth.ts new file mode 100644 index 000000000..f3f7a3896 --- /dev/null +++ b/packages/server-auth-legacy/src/middleware/clientAuth.ts @@ -0,0 +1,65 @@ +import type { OAuthClientInformationFull } from '@modelcontextprotocol/core'; +import type { RequestHandler } from 'express'; +import * as z from 'zod/v4'; + +import type { OAuthRegisteredClientsStore } from '../clients.js'; +import { InvalidClientError, InvalidRequestError, OAuthError, ServerError } from '../errors.js'; + +export type ClientAuthenticationMiddlewareOptions = { + /** + * A store used to read information about registered OAuth clients. + */ + clientsStore: OAuthRegisteredClientsStore; +}; + +const ClientAuthenticatedRequestSchema = z.object({ + client_id: z.string(), + client_secret: z.string().optional() +}); + +declare module 'express-serve-static-core' { + interface Request { + /** + * The authenticated client for this request, if the `authenticateClient` middleware was used. + */ + client?: OAuthClientInformationFull; + } +} + +export function authenticateClient({ clientsStore }: ClientAuthenticationMiddlewareOptions): RequestHandler { + return async (req, res, next) => { + try { + const result = ClientAuthenticatedRequestSchema.safeParse(req.body); + if (!result.success) { + throw new InvalidRequestError(String(result.error)); + } + const { client_id, client_secret } = result.data; + const client = await clientsStore.getClient(client_id); + if (!client) { + throw new InvalidClientError('Invalid client_id'); + } + if (client.client_secret) { + if (!client_secret) { + throw new InvalidClientError('Client secret is required'); + } + if (client.client_secret !== client_secret) { + throw new InvalidClientError('Invalid client_secret'); + } + if (client.client_secret_expires_at && client.client_secret_expires_at < Math.floor(Date.now() / 1000)) { + throw new InvalidClientError('Client secret has expired'); + } + } + + req.client = client; + next(); + } catch (error) { + if (error instanceof OAuthError) { + const status = error instanceof ServerError ? 500 : 400; + res.status(status).json(error.toResponseObject()); + } else { + const serverError = new ServerError('Internal Server Error'); + res.status(500).json(serverError.toResponseObject()); + } + } + }; +} diff --git a/packages/server-auth-legacy/src/provider.ts b/packages/server-auth-legacy/src/provider.ts new file mode 100644 index 000000000..528e8d27b --- /dev/null +++ b/packages/server-auth-legacy/src/provider.ts @@ -0,0 +1,84 @@ +import type { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; +import type { Response } from 'express'; + +import type { OAuthRegisteredClientsStore } from './clients.js'; +import type { AuthInfo } from './types.js'; + +export type AuthorizationParams = { + state?: string; + scopes?: string[]; + codeChallenge: string; + redirectUri: string; + resource?: URL; +}; + +/** + * Implements an end-to-end OAuth server. + */ +export interface OAuthServerProvider { + /** + * A store used to read information about registered OAuth clients. + */ + get clientsStore(): OAuthRegisteredClientsStore; + + /** + * Begins the authorization flow, which can either be implemented by this server itself or via redirection to a separate authorization server. + * + * This server must eventually issue a redirect with an authorization response or an error response to the given redirect URI. Per OAuth 2.1: + * - In the successful case, the redirect MUST include the `code` and `state` (if present) query parameters. + * - In the error case, the redirect MUST include the `error` query parameter, and MAY include an optional `error_description` query parameter. + */ + authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise; + + /** + * Returns the `codeChallenge` that was used when the indicated authorization began. + */ + challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise; + + /** + * Exchanges an authorization code for an access token. + */ + exchangeAuthorizationCode( + client: OAuthClientInformationFull, + authorizationCode: string, + codeVerifier?: string, + redirectUri?: string, + resource?: URL + ): Promise; + + /** + * Exchanges a refresh token for an access token. + */ + exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[], resource?: URL): Promise; + + /** + * Verifies an access token and returns information about it. + */ + verifyAccessToken(token: string): Promise; + + /** + * Revokes an access or refresh token. If unimplemented, token revocation is not supported (not recommended). + * + * If the given token is invalid or already revoked, this method should do nothing. + */ + revokeToken?(client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest): Promise; + + /** + * Whether to skip local PKCE validation. + * + * If true, the server will not perform PKCE validation locally and will pass the code_verifier to the upstream server. + * + * NOTE: This should only be true if the upstream server is performing the actual PKCE validation. + */ + skipLocalPkceValidation?: boolean; +} + +/** + * Slim implementation useful for token verification + */ +export interface OAuthTokenVerifier { + /** + * Verifies an access token and returns information about it. + */ + verifyAccessToken(token: string): Promise; +} diff --git a/packages/server-auth-legacy/src/providers/proxyProvider.ts b/packages/server-auth-legacy/src/providers/proxyProvider.ts new file mode 100644 index 000000000..b469ce6df --- /dev/null +++ b/packages/server-auth-legacy/src/providers/proxyProvider.ts @@ -0,0 +1,233 @@ +import type { FetchLike, OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; +import { OAuthClientInformationFullSchema, OAuthTokensSchema } from '@modelcontextprotocol/core'; +import type { Response } from 'express'; + +import type { OAuthRegisteredClientsStore } from '../clients.js'; +import { ServerError } from '../errors.js'; +import type { AuthorizationParams, OAuthServerProvider } from '../provider.js'; +import type { AuthInfo } from '../types.js'; + +export type ProxyEndpoints = { + authorizationUrl: string; + tokenUrl: string; + revocationUrl?: string; + registrationUrl?: string; +}; + +export type ProxyOptions = { + /** + * Individual endpoint URLs for proxying specific OAuth operations + */ + endpoints: ProxyEndpoints; + + /** + * Function to verify access tokens and return auth info + */ + verifyAccessToken: (token: string) => Promise; + + /** + * Function to fetch client information from the upstream server + */ + getClient: (clientId: string) => Promise; + + /** + * Custom fetch implementation used for all network requests. + */ + fetch?: FetchLike; +}; + +/** + * Implements an OAuth server that proxies requests to another OAuth server. + */ +export class ProxyOAuthServerProvider implements OAuthServerProvider { + protected readonly _endpoints: ProxyEndpoints; + protected readonly _verifyAccessToken: (token: string) => Promise; + protected readonly _getClient: (clientId: string) => Promise; + protected readonly _fetch?: FetchLike; + + skipLocalPkceValidation = true; + + revokeToken?: (client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest) => Promise; + + constructor(options: ProxyOptions) { + this._endpoints = options.endpoints; + this._verifyAccessToken = options.verifyAccessToken; + this._getClient = options.getClient; + this._fetch = options.fetch; + if (options.endpoints?.revocationUrl) { + this.revokeToken = async (client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest) => { + const revocationUrl = this._endpoints.revocationUrl; + + if (!revocationUrl) { + throw new Error('No revocation endpoint configured'); + } + + const params = new URLSearchParams(); + params.set('token', request.token); + params.set('client_id', client.client_id); + if (client.client_secret) { + params.set('client_secret', client.client_secret); + } + if (request.token_type_hint) { + params.set('token_type_hint', request.token_type_hint); + } + + const response = await (this._fetch ?? fetch)(revocationUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: params.toString() + }); + await response.body?.cancel(); + + if (!response.ok) { + throw new ServerError(`Token revocation failed: ${response.status}`); + } + }; + } + } + + get clientsStore(): OAuthRegisteredClientsStore { + const registrationUrl = this._endpoints.registrationUrl; + return { + getClient: this._getClient, + ...(registrationUrl && { + registerClient: async (client: OAuthClientInformationFull) => { + const response = await (this._fetch ?? fetch)(registrationUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(client) + }); + + if (!response.ok) { + await response.body?.cancel(); + throw new ServerError(`Client registration failed: ${response.status}`); + } + + const data = await response.json(); + return OAuthClientInformationFullSchema.parse(data); + } + }) + }; + } + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + // Start with required OAuth parameters + const targetUrl = new URL(this._endpoints.authorizationUrl); + const searchParams = new URLSearchParams({ + client_id: client.client_id, + response_type: 'code', + redirect_uri: params.redirectUri, + code_challenge: params.codeChallenge, + code_challenge_method: 'S256' + }); + + // Add optional standard OAuth parameters + if (params.state) searchParams.set('state', params.state); + if (params.scopes?.length) searchParams.set('scope', params.scopes.join(' ')); + if (params.resource) searchParams.set('resource', params.resource.href); + + targetUrl.search = searchParams.toString(); + res.redirect(targetUrl.toString()); + } + + async challengeForAuthorizationCode(_client: OAuthClientInformationFull, _authorizationCode: string): Promise { + // In a proxy setup, we don't store the code challenge ourselves + // Instead, we proxy the token request and let the upstream server validate it + return ''; + } + + async exchangeAuthorizationCode( + client: OAuthClientInformationFull, + authorizationCode: string, + codeVerifier?: string, + redirectUri?: string, + resource?: URL + ): Promise { + const params = new URLSearchParams({ + grant_type: 'authorization_code', + client_id: client.client_id, + code: authorizationCode + }); + + if (client.client_secret) { + params.append('client_secret', client.client_secret); + } + + if (codeVerifier) { + params.append('code_verifier', codeVerifier); + } + + if (redirectUri) { + params.append('redirect_uri', redirectUri); + } + + if (resource) { + params.append('resource', resource.href); + } + + const response = await (this._fetch ?? fetch)(this._endpoints.tokenUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: params.toString() + }); + + if (!response.ok) { + await response.body?.cancel(); + throw new ServerError(`Token exchange failed: ${response.status}`); + } + + const data = await response.json(); + return OAuthTokensSchema.parse(data); + } + + async exchangeRefreshToken( + client: OAuthClientInformationFull, + refreshToken: string, + scopes?: string[], + resource?: URL + ): Promise { + const params = new URLSearchParams({ + grant_type: 'refresh_token', + client_id: client.client_id, + refresh_token: refreshToken + }); + + if (client.client_secret) { + params.set('client_secret', client.client_secret); + } + + if (scopes?.length) { + params.set('scope', scopes.join(' ')); + } + + if (resource) { + params.set('resource', resource.href); + } + + const response = await (this._fetch ?? fetch)(this._endpoints.tokenUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: params.toString() + }); + + if (!response.ok) { + await response.body?.cancel(); + throw new ServerError(`Token refresh failed: ${response.status}`); + } + + const data = await response.json(); + return OAuthTokensSchema.parse(data); + } + + async verifyAccessToken(token: string): Promise { + return this._verifyAccessToken(token); + } +} diff --git a/packages/server-auth-legacy/src/router.ts b/packages/server-auth-legacy/src/router.ts new file mode 100644 index 000000000..ba8b030e0 --- /dev/null +++ b/packages/server-auth-legacy/src/router.ts @@ -0,0 +1,246 @@ +import type { OAuthMetadata, OAuthProtectedResourceMetadata } from '@modelcontextprotocol/core'; +import type { RequestHandler } from 'express'; +import express from 'express'; + +import type { AuthorizationHandlerOptions } from './handlers/authorize.js'; +import { authorizationHandler } from './handlers/authorize.js'; +import { metadataHandler } from './handlers/metadata.js'; +import type { ClientRegistrationHandlerOptions } from './handlers/register.js'; +import { clientRegistrationHandler } from './handlers/register.js'; +import type { RevocationHandlerOptions } from './handlers/revoke.js'; +import { revocationHandler } from './handlers/revoke.js'; +import type { TokenHandlerOptions } from './handlers/token.js'; +import { tokenHandler } from './handlers/token.js'; +import type { OAuthServerProvider } from './provider.js'; + +// Check for dev mode flag that allows HTTP issuer URLs (for development/testing only) +const allowInsecureIssuerUrl = + process.env.MCP_DANGEROUSLY_ALLOW_INSECURE_ISSUER_URL === 'true' || process.env.MCP_DANGEROUSLY_ALLOW_INSECURE_ISSUER_URL === '1'; +if (allowInsecureIssuerUrl) { + // eslint-disable-next-line no-console + console.warn('MCP_DANGEROUSLY_ALLOW_INSECURE_ISSUER_URL is enabled - HTTP issuer URLs are allowed. Do not use in production.'); +} + +export type AuthRouterOptions = { + /** + * A provider implementing the actual authorization logic for this router. + */ + provider: OAuthServerProvider; + + /** + * The authorization server's issuer identifier, which is a URL that uses the "https" scheme and has no query or fragment components. + */ + issuerUrl: URL; + + /** + * The base URL of the authorization server to use for the metadata endpoints. + * + * If not provided, the issuer URL will be used as the base URL. + */ + baseUrl?: URL; + + /** + * An optional URL of a page containing human-readable information that developers might want or need to know when using the authorization server. + */ + serviceDocumentationUrl?: URL; + + /** + * An optional list of scopes supported by this authorization server + */ + scopesSupported?: string[]; + + /** + * The resource name to be displayed in protected resource metadata + */ + resourceName?: string; + + /** + * The URL of the protected resource (RS) whose metadata we advertise. + * If not provided, falls back to `baseUrl` and then to `issuerUrl` (AS=RS). + */ + resourceServerUrl?: URL; + + // Individual options per route + authorizationOptions?: Omit; + clientRegistrationOptions?: Omit; + revocationOptions?: Omit; + tokenOptions?: Omit; +}; + +const checkIssuerUrl = (issuer: URL): void => { + // Technically RFC 8414 does not permit a localhost HTTPS exemption, but this will be necessary for ease of testing + if (issuer.protocol !== 'https:' && issuer.hostname !== 'localhost' && issuer.hostname !== '127.0.0.1' && !allowInsecureIssuerUrl) { + throw new Error('Issuer URL must be HTTPS'); + } + if (issuer.hash) { + throw new Error(`Issuer URL must not have a fragment: ${issuer}`); + } + if (issuer.search) { + throw new Error(`Issuer URL must not have a query string: ${issuer}`); + } +}; + +export const createOAuthMetadata = (options: { + provider: OAuthServerProvider; + issuerUrl: URL; + baseUrl?: URL; + serviceDocumentationUrl?: URL; + scopesSupported?: string[]; +}): OAuthMetadata => { + const issuer = options.issuerUrl; + const baseUrl = options.baseUrl; + + checkIssuerUrl(issuer); + + const authorization_endpoint = '/authorize'; + const token_endpoint = '/token'; + const registration_endpoint = options.provider.clientsStore.registerClient ? '/register' : undefined; + const revocation_endpoint = options.provider.revokeToken ? '/revoke' : undefined; + + const metadata: OAuthMetadata = { + issuer: issuer.href, + service_documentation: options.serviceDocumentationUrl?.href, + + authorization_endpoint: new URL(authorization_endpoint, baseUrl || issuer).href, + response_types_supported: ['code'], + code_challenge_methods_supported: ['S256'], + + token_endpoint: new URL(token_endpoint, baseUrl || issuer).href, + token_endpoint_auth_methods_supported: ['client_secret_post', 'none'], + grant_types_supported: ['authorization_code', 'refresh_token'], + + scopes_supported: options.scopesSupported, + + revocation_endpoint: revocation_endpoint ? new URL(revocation_endpoint, baseUrl || issuer).href : undefined, + revocation_endpoint_auth_methods_supported: revocation_endpoint ? ['client_secret_post'] : undefined, + + registration_endpoint: registration_endpoint ? new URL(registration_endpoint, baseUrl || issuer).href : undefined + }; + + return metadata; +}; + +/** + * Installs standard MCP authorization server endpoints, including dynamic client registration and token revocation (if supported). + * Also advertises standard authorization server metadata, for easier discovery of supported configurations by clients. + * Note: if your MCP server is only a resource server and not an authorization server, use mcpAuthMetadataRouter instead. + * + * By default, rate limiting is applied to all endpoints to prevent abuse. + * + * This router MUST be installed at the application root, like so: + * + * const app = express(); + * app.use(mcpAuthRouter(...)); + */ +export function mcpAuthRouter(options: AuthRouterOptions): RequestHandler { + const oauthMetadata = createOAuthMetadata(options); + + const router = express.Router(); + + router.use( + new URL(oauthMetadata.authorization_endpoint).pathname, + authorizationHandler({ provider: options.provider, ...options.authorizationOptions }) + ); + + router.use(new URL(oauthMetadata.token_endpoint).pathname, tokenHandler({ provider: options.provider, ...options.tokenOptions })); + + router.use( + mcpAuthMetadataRouter({ + oauthMetadata, + // Prefer explicit RS; otherwise fall back to AS baseUrl, then to issuer (back-compat) + resourceServerUrl: options.resourceServerUrl ?? options.baseUrl ?? new URL(oauthMetadata.issuer), + serviceDocumentationUrl: options.serviceDocumentationUrl, + scopesSupported: options.scopesSupported, + resourceName: options.resourceName + }) + ); + + if (oauthMetadata.registration_endpoint) { + router.use( + new URL(oauthMetadata.registration_endpoint).pathname, + clientRegistrationHandler({ + clientsStore: options.provider.clientsStore, + ...options.clientRegistrationOptions + }) + ); + } + + if (oauthMetadata.revocation_endpoint) { + router.use( + new URL(oauthMetadata.revocation_endpoint).pathname, + revocationHandler({ provider: options.provider, ...options.revocationOptions }) + ); + } + + return router; +} + +export type AuthMetadataOptions = { + /** + * OAuth Metadata as would be returned from the authorization server + * this MCP server relies on + */ + oauthMetadata: OAuthMetadata; + + /** + * The url of the MCP server, for use in protected resource metadata + */ + resourceServerUrl: URL; + + /** + * The url for documentation for the MCP server + */ + serviceDocumentationUrl?: URL; + + /** + * An optional list of scopes supported by this MCP server + */ + scopesSupported?: string[]; + + /** + * An optional resource name to display in resource metadata + */ + resourceName?: string; +}; + +export function mcpAuthMetadataRouter(options: AuthMetadataOptions): express.Router { + checkIssuerUrl(new URL(options.oauthMetadata.issuer)); + + const router = express.Router(); + + const protectedResourceMetadata: OAuthProtectedResourceMetadata = { + resource: options.resourceServerUrl.href, + + authorization_servers: [options.oauthMetadata.issuer], + + scopes_supported: options.scopesSupported, + resource_name: options.resourceName, + resource_documentation: options.serviceDocumentationUrl?.href + }; + + // Serve PRM at the path-specific URL per RFC 9728 + const rsPath = new URL(options.resourceServerUrl.href).pathname; + router.use(`/.well-known/oauth-protected-resource${rsPath === '/' ? '' : rsPath}`, metadataHandler(protectedResourceMetadata)); + + // Always add this for OAuth Authorization Server metadata per RFC 8414 + router.use('/.well-known/oauth-authorization-server', metadataHandler(options.oauthMetadata)); + + return router; +} + +/** + * Helper function to construct the OAuth 2.0 Protected Resource Metadata URL + * from a given server URL. This replaces the path with the standard metadata endpoint. + * + * @param serverUrl - The base URL of the protected resource server + * @returns The URL for the OAuth protected resource metadata endpoint + * + * @example + * getOAuthProtectedResourceMetadataUrl(new URL('https://api.example.com/mcp')) + * // Returns: 'https://api.example.com/.well-known/oauth-protected-resource/mcp' + */ +export function getOAuthProtectedResourceMetadataUrl(serverUrl: URL): string { + const u = new URL(serverUrl.href); + const rsPath = u.pathname && u.pathname !== '/' ? u.pathname : ''; + return new URL(`/.well-known/oauth-protected-resource${rsPath}`, u).href; +} diff --git a/packages/server-auth-legacy/src/types.ts b/packages/server-auth-legacy/src/types.ts new file mode 100644 index 000000000..b15d371fa --- /dev/null +++ b/packages/server-auth-legacy/src/types.ts @@ -0,0 +1,8 @@ +/** + * Information about a validated access token, provided to request handlers. + * + * Re-exported from `@modelcontextprotocol/core` so that tokens verified by the + * legacy `requireBearerAuth` middleware are structurally compatible with + * the v2 SDK's request-handler context. + */ +export type { AuthInfo } from '@modelcontextprotocol/core'; diff --git a/packages/server-auth-legacy/test/handlers/authorize.test.ts b/packages/server-auth-legacy/test/handlers/authorize.test.ts new file mode 100644 index 000000000..215d79df0 --- /dev/null +++ b/packages/server-auth-legacy/test/handlers/authorize.test.ts @@ -0,0 +1,400 @@ +import { authorizationHandler, AuthorizationHandlerOptions, redirectUriMatches } from '../../src/handlers/authorize.js'; +import { OAuthServerProvider, AuthorizationParams } from '../../src/provider.js'; +import { OAuthRegisteredClientsStore } from '../../src/clients.js'; +import { OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; +import express, { Response } from 'express'; +import supertest from 'supertest'; +import { AuthInfo } from '../../src/types.js'; +import { InvalidTokenError } from '../../src/errors.js'; + +describe('Authorization Handler', () => { + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'], + scope: 'profile email' + }; + + const multiRedirectClient: OAuthClientInformationFull = { + client_id: 'multi-redirect-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback1', 'https://example.com/callback2'], + scope: 'profile email' + }; + + // Native app client with a portless loopback redirect (e.g., from CIMD / SEP-991) + const loopbackClient: OAuthClientInformationFull = { + client_id: 'loopback-client', + client_secret: 'valid-secret', + redirect_uris: ['http://localhost/callback', 'http://127.0.0.1/callback'], + scope: 'profile email' + }; + + // Mock client store + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return validClient; + } else if (clientId === 'multi-redirect-client') { + return multiRedirectClient; + } else if (clientId === 'loopback-client') { + return loopbackClient; + } + return undefined; + } + }; + + // Mock provider + const mockProvider: OAuthServerProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + // Mock implementation - redirects to redirectUri with code and state + const redirectUrl = new URL(params.redirectUri); + redirectUrl.searchParams.set('code', 'mock_auth_code'); + if (params.state) { + redirectUrl.searchParams.set('state', params.state); + } + res.redirect(302, redirectUrl.toString()); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + }, + + async revokeToken(): Promise { + // Do nothing in mock + } + }; + + // Setup express app with handler + let app: express.Express; + let options: AuthorizationHandlerOptions; + + beforeEach(() => { + app = express(); + options = { provider: mockProvider }; + const handler = authorizationHandler(options); + app.use('/authorize', handler); + }); + + describe('HTTP method validation', () => { + it('rejects non-GET/POST methods', async () => { + const response = await supertest(app).put('/authorize').query({ client_id: 'valid-client' }); + + expect(response.status).toBe(405); // Method not allowed response from handler + }); + }); + + describe('Client validation', () => { + it('requires client_id parameter', async () => { + const response = await supertest(app).get('/authorize'); + + expect(response.status).toBe(400); + expect(response.text).toContain('client_id'); + }); + + it('validates that client exists', async () => { + const response = await supertest(app).get('/authorize').query({ client_id: 'nonexistent-client' }); + + expect(response.status).toBe(400); + }); + }); + + describe('Redirect URI validation', () => { + it('uses the only redirect_uri if client has just one and none provided', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.origin + location.pathname).toBe('https://example.com/callback'); + }); + + it('requires redirect_uri if client has multiple', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'multi-redirect-client', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(400); + }); + + it('validates redirect_uri against client registered URIs', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://malicious.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(400); + }); + + it('accepts valid redirect_uri that client registered with', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.origin + location.pathname).toBe('https://example.com/callback'); + }); + + // RFC 8252 §7.3: authorization servers MUST allow any port for loopback + // redirect URIs. Native apps obtain ephemeral ports from the OS. + it('accepts loopback redirect_uri with ephemeral port (RFC 8252)', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'loopback-client', + redirect_uri: 'http://localhost:53428/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.hostname).toBe('localhost'); + expect(location.port).toBe('53428'); + expect(location.pathname).toBe('/callback'); + }); + + it('accepts 127.0.0.1 loopback redirect_uri with ephemeral port (RFC 8252)', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'loopback-client', + redirect_uri: 'http://127.0.0.1:9000/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + }); + + it('rejects loopback redirect_uri with different path', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'loopback-client', + redirect_uri: 'http://localhost:53428/evil', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(400); + }); + + it('does not relax port for non-loopback redirect_uri', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com:8443/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(400); + }); + }); + + describe('redirectUriMatches (RFC 8252 §7.3)', () => { + it('exact match passes', () => { + expect(redirectUriMatches('https://example.com/cb', 'https://example.com/cb')).toBe(true); + }); + + it('loopback: any port matches portless registration', () => { + expect(redirectUriMatches('http://localhost:53428/callback', 'http://localhost/callback')).toBe(true); + expect(redirectUriMatches('http://127.0.0.1:8080/callback', 'http://127.0.0.1/callback')).toBe(true); + expect(redirectUriMatches('http://[::1]:9000/cb', 'http://[::1]/cb')).toBe(true); + }); + + it('loopback: any port matches ported registration', () => { + expect(redirectUriMatches('http://localhost:53428/callback', 'http://localhost:3118/callback')).toBe(true); + }); + + it('loopback: different path rejected', () => { + expect(redirectUriMatches('http://localhost:53428/evil', 'http://localhost/callback')).toBe(false); + }); + + it('loopback: different scheme rejected', () => { + expect(redirectUriMatches('https://localhost:53428/callback', 'http://localhost/callback')).toBe(false); + }); + + it('loopback: localhost↔127.0.0.1 cross-match rejected', () => { + // RFC 8252 relaxes port only, not host + expect(redirectUriMatches('http://127.0.0.1:53428/callback', 'http://localhost/callback')).toBe(false); + }); + + it('non-loopback: port must match exactly', () => { + expect(redirectUriMatches('https://example.com:8443/cb', 'https://example.com/cb')).toBe(false); + }); + + it('non-loopback: no relaxation for private IPs', () => { + expect(redirectUriMatches('http://192.168.1.1:8080/cb', 'http://192.168.1.1/cb')).toBe(false); + }); + + it('malformed URIs rejected', () => { + expect(redirectUriMatches('not a url', 'http://localhost/cb')).toBe(false); + expect(redirectUriMatches('http://localhost/cb', 'not a url')).toBe(false); + }); + }); + + describe('Authorization request validation', () => { + it('requires response_type=code', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'token', // invalid - we only support code flow + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.searchParams.get('error')).toBe('invalid_request'); + }); + + it('requires code_challenge parameter', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge_method: 'S256' + // Missing code_challenge + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.searchParams.get('error')).toBe('invalid_request'); + }); + + it('requires code_challenge_method=S256', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'plain' // Only S256 is supported + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.searchParams.get('error')).toBe('invalid_request'); + }); + }); + + describe('Resource parameter validation', () => { + it('propagates resource parameter', async () => { + const mockProviderWithResource = vi.spyOn(mockProvider, 'authorize'); + + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + resource: 'https://api.example.com/resource' + }); + + expect(response.status).toBe(302); + expect(mockProviderWithResource).toHaveBeenCalledWith( + validClient, + expect.objectContaining({ + resource: new URL('https://api.example.com/resource'), + redirectUri: 'https://example.com/callback', + codeChallenge: 'challenge123' + }), + expect.any(Object) + ); + }); + }); + + describe('Successful authorization', () => { + it('handles successful authorization with all parameters', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + scope: 'profile email', + state: 'xyz789' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.origin + location.pathname).toBe('https://example.com/callback'); + expect(location.searchParams.get('code')).toBe('mock_auth_code'); + expect(location.searchParams.get('state')).toBe('xyz789'); + }); + + it('preserves state parameter in response', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + redirect_uri: 'https://example.com/callback', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256', + state: 'state-value-123' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.searchParams.get('state')).toBe('state-value-123'); + }); + + it('handles POST requests the same as GET', async () => { + const response = await supertest(app).post('/authorize').type('form').send({ + client_id: 'valid-client', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.searchParams.has('code')).toBe(true); + }); + }); +}); diff --git a/packages/server-auth-legacy/test/handlers/metadata.test.ts b/packages/server-auth-legacy/test/handlers/metadata.test.ts new file mode 100644 index 000000000..3c89134ae --- /dev/null +++ b/packages/server-auth-legacy/test/handlers/metadata.test.ts @@ -0,0 +1,78 @@ +import { metadataHandler } from '../../src/handlers/metadata.js'; +import { OAuthMetadata } from '@modelcontextprotocol/core'; +import express from 'express'; +import supertest from 'supertest'; + +describe('Metadata Handler', () => { + const exampleMetadata: OAuthMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + registration_endpoint: 'https://auth.example.com/register', + revocation_endpoint: 'https://auth.example.com/revoke', + scopes_supported: ['profile', 'email'], + response_types_supported: ['code'], + grant_types_supported: ['authorization_code', 'refresh_token'], + token_endpoint_auth_methods_supported: ['client_secret_basic'], + code_challenge_methods_supported: ['S256'] + }; + + let app: express.Express; + + beforeEach(() => { + // Setup express app with metadata handler + app = express(); + app.use('/.well-known/oauth-authorization-server', metadataHandler(exampleMetadata)); + }); + + it('requires GET method', async () => { + const response = await supertest(app).post('/.well-known/oauth-authorization-server').send({}); + + expect(response.status).toBe(405); + expect(response.headers.allow).toBe('GET, OPTIONS'); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: 'The method POST is not allowed for this endpoint' + }); + }); + + it('returns the metadata object', async () => { + const response = await supertest(app).get('/.well-known/oauth-authorization-server'); + + expect(response.status).toBe(200); + expect(response.body).toEqual(exampleMetadata); + }); + + it('includes CORS headers in response', async () => { + const response = await supertest(app).get('/.well-known/oauth-authorization-server').set('Origin', 'https://example.com'); + + expect(response.header['access-control-allow-origin']).toBe('*'); + }); + + it('supports OPTIONS preflight requests', async () => { + const response = await supertest(app) + .options('/.well-known/oauth-authorization-server') + .set('Origin', 'https://example.com') + .set('Access-Control-Request-Method', 'GET'); + + expect(response.status).toBe(204); + expect(response.header['access-control-allow-origin']).toBe('*'); + }); + + it('works with minimal metadata', async () => { + // Setup a new express app with minimal metadata + const minimalApp = express(); + const minimalMetadata: OAuthMetadata = { + issuer: 'https://auth.example.com', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'] + }; + minimalApp.use('/.well-known/oauth-authorization-server', metadataHandler(minimalMetadata)); + + const response = await supertest(minimalApp).get('/.well-known/oauth-authorization-server'); + + expect(response.status).toBe(200); + expect(response.body).toEqual(minimalMetadata); + }); +}); diff --git a/packages/server-auth-legacy/test/handlers/register.test.ts b/packages/server-auth-legacy/test/handlers/register.test.ts new file mode 100644 index 000000000..dc3e45023 --- /dev/null +++ b/packages/server-auth-legacy/test/handlers/register.test.ts @@ -0,0 +1,272 @@ +import { clientRegistrationHandler, ClientRegistrationHandlerOptions } from '../../src/handlers/register.js'; +import { OAuthRegisteredClientsStore } from '../../src/clients.js'; +import { OAuthClientInformationFull, OAuthClientMetadata } from '@modelcontextprotocol/core'; +import express from 'express'; +import supertest from 'supertest'; +import { MockInstance } from 'vitest'; + +describe('Client Registration Handler', () => { + // Mock client store with registration support + const mockClientStoreWithRegistration: OAuthRegisteredClientsStore = { + async getClient(_clientId: string): Promise { + return undefined; + }, + + async registerClient(client: OAuthClientInformationFull): Promise { + // Return the client info as-is in the mock + return client; + } + }; + + // Mock client store without registration support + const mockClientStoreWithoutRegistration: OAuthRegisteredClientsStore = { + async getClient(_clientId: string): Promise { + return undefined; + } + // No registerClient method + }; + + describe('Handler creation', () => { + it('throws error if client store does not support registration', () => { + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithoutRegistration + }; + + expect(() => clientRegistrationHandler(options)).toThrow('does not support registering clients'); + }); + + it('creates handler if client store supports registration', () => { + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration + }; + + expect(() => clientRegistrationHandler(options)).not.toThrow(); + }); + }); + + describe('Request handling', () => { + let app: express.Express; + let spyRegisterClient: MockInstance; + + beforeEach(() => { + // Setup express app with registration handler + app = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientSecretExpirySeconds: 86400 // 1 day for testing + }; + + app.use('/register', clientRegistrationHandler(options)); + + // Spy on the registerClient method + spyRegisterClient = vi.spyOn(mockClientStoreWithRegistration, 'registerClient'); + }); + + afterEach(() => { + spyRegisterClient.mockRestore(); + }); + + it('requires POST method', async () => { + const response = await supertest(app) + .get('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.status).toBe(405); + expect(response.headers.allow).toBe('POST'); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: 'The method GET is not allowed for this endpoint' + }); + expect(spyRegisterClient).not.toHaveBeenCalled(); + }); + + it('validates required client metadata', async () => { + const response = await supertest(app).post('/register').send({ + // Missing redirect_uris (required) + client_name: 'Test Client' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client_metadata'); + expect(spyRegisterClient).not.toHaveBeenCalled(); + }); + + it('validates redirect URIs format', async () => { + const response = await supertest(app) + .post('/register') + .send({ + redirect_uris: ['invalid-url'] // Invalid URL format + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client_metadata'); + expect(response.body.error_description).toContain('redirect_uris'); + expect(spyRegisterClient).not.toHaveBeenCalled(); + }); + + it('successfully registers client with minimal metadata', async () => { + const clientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'] + }; + + const response = await supertest(app).post('/register').send(clientMetadata); + + expect(response.status).toBe(201); + + // Verify the generated client information + expect(response.body.client_id).toBeDefined(); + expect(response.body.client_secret).toBeDefined(); + expect(response.body.client_id_issued_at).toBeDefined(); + expect(response.body.client_secret_expires_at).toBeDefined(); + expect(response.body.redirect_uris).toEqual(['https://example.com/callback']); + + // Verify client was registered + expect(spyRegisterClient).toHaveBeenCalledTimes(1); + }); + + it('sets client_secret to undefined for token_endpoint_auth_method=none', async () => { + const clientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'none' + }; + + const response = await supertest(app).post('/register').send(clientMetadata); + + expect(response.status).toBe(201); + expect(response.body.client_secret).toBeUndefined(); + expect(response.body.client_secret_expires_at).toBeUndefined(); + }); + + it('sets client_secret_expires_at for public clients only', async () => { + // Test for public client (token_endpoint_auth_method not 'none') + const publicClientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'client_secret_basic' + }; + + const publicResponse = await supertest(app).post('/register').send(publicClientMetadata); + + expect(publicResponse.status).toBe(201); + expect(publicResponse.body.client_secret).toBeDefined(); + expect(publicResponse.body.client_secret_expires_at).toBeDefined(); + + // Test for non-public client (token_endpoint_auth_method is 'none') + const nonPublicClientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'none' + }; + + const nonPublicResponse = await supertest(app).post('/register').send(nonPublicClientMetadata); + + expect(nonPublicResponse.status).toBe(201); + expect(nonPublicResponse.body.client_secret).toBeUndefined(); + expect(nonPublicResponse.body.client_secret_expires_at).toBeUndefined(); + }); + + it('sets expiry based on clientSecretExpirySeconds', async () => { + // Create handler with custom expiry time + const customApp = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientSecretExpirySeconds: 3600 // 1 hour + }; + + customApp.use('/register', clientRegistrationHandler(options)); + + const response = await supertest(customApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.status).toBe(201); + + // Verify the expiration time (~1 hour from now) + const issuedAt = response.body.client_id_issued_at; + const expiresAt = response.body.client_secret_expires_at; + expect(expiresAt - issuedAt).toBe(3600); + }); + + it('sets no expiry when clientSecretExpirySeconds=0', async () => { + // Create handler with no expiry + const customApp = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientSecretExpirySeconds: 0 // No expiry + }; + + customApp.use('/register', clientRegistrationHandler(options)); + + const response = await supertest(customApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.status).toBe(201); + expect(response.body.client_secret_expires_at).toBe(0); + }); + + it('sets no client_id when clientIdGeneration=false', async () => { + // Create handler with no expiry + const customApp = express(); + const options: ClientRegistrationHandlerOptions = { + clientsStore: mockClientStoreWithRegistration, + clientIdGeneration: false + }; + + customApp.use('/register', clientRegistrationHandler(options)); + + const response = await supertest(customApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.status).toBe(201); + expect(response.body.client_id).toBeUndefined(); + expect(response.body.client_id_issued_at).toBeUndefined(); + }); + + it('handles client with all metadata fields', async () => { + const fullClientMetadata: OAuthClientMetadata = { + redirect_uris: ['https://example.com/callback'], + token_endpoint_auth_method: 'client_secret_basic', + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + client_name: 'Test Client', + client_uri: 'https://example.com', + logo_uri: 'https://example.com/logo.png', + scope: 'profile email', + contacts: ['dev@example.com'], + tos_uri: 'https://example.com/tos', + policy_uri: 'https://example.com/privacy', + jwks_uri: 'https://example.com/jwks', + software_id: 'test-software', + software_version: '1.0.0' + }; + + const response = await supertest(app).post('/register').send(fullClientMetadata); + + expect(response.status).toBe(201); + + // Verify all metadata was preserved + Object.entries(fullClientMetadata).forEach(([key, value]) => { + expect(response.body[key]).toEqual(value); + }); + }); + + it('includes CORS headers in response', async () => { + const response = await supertest(app) + .post('/register') + .set('Origin', 'https://example.com') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + expect(response.header['access-control-allow-origin']).toBe('*'); + }); + }); +}); diff --git a/packages/server-auth-legacy/test/handlers/revoke.test.ts b/packages/server-auth-legacy/test/handlers/revoke.test.ts new file mode 100644 index 000000000..d0aba1152 --- /dev/null +++ b/packages/server-auth-legacy/test/handlers/revoke.test.ts @@ -0,0 +1,231 @@ +import { revocationHandler, RevocationHandlerOptions } from '../../src/handlers/revoke.js'; +import { OAuthServerProvider, AuthorizationParams } from '../../src/provider.js'; +import { OAuthRegisteredClientsStore } from '../../src/clients.js'; +import { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; +import express, { Response } from 'express'; +import supertest from 'supertest'; +import { AuthInfo } from '../../src/types.js'; +import { InvalidTokenError } from '../../src/errors.js'; +import { MockInstance } from 'vitest'; + +describe('Revocation Handler', () => { + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + // Mock client store + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return validClient; + } + return undefined; + } + }; + + // Mock provider with revocation capability + const mockProviderWithRevocation: OAuthServerProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + res.redirect('https://example.com/callback?code=mock_auth_code'); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + }, + + async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { + // Success - do nothing in mock + } + }; + + // Mock provider without revocation capability + const mockProviderWithoutRevocation: OAuthServerProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + res.redirect('https://example.com/callback?code=mock_auth_code'); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + } + + // No revokeToken method + }; + + describe('Handler creation', () => { + it('throws error if provider does not support token revocation', () => { + const options: RevocationHandlerOptions = { provider: mockProviderWithoutRevocation }; + expect(() => revocationHandler(options)).toThrow('does not support revoking tokens'); + }); + + it('creates handler if provider supports token revocation', () => { + const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; + expect(() => revocationHandler(options)).not.toThrow(); + }); + }); + + describe('Request handling', () => { + let app: express.Express; + let spyRevokeToken: MockInstance; + + beforeEach(() => { + // Setup express app with revocation handler + app = express(); + const options: RevocationHandlerOptions = { provider: mockProviderWithRevocation }; + app.use('/revoke', revocationHandler(options)); + + // Spy on the revokeToken method + spyRevokeToken = vi.spyOn(mockProviderWithRevocation, 'revokeToken'); + }); + + afterEach(() => { + spyRevokeToken.mockRestore(); + }); + + it('requires POST method', async () => { + const response = await supertest(app).get('/revoke').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); + + expect(response.status).toBe(405); + expect(response.headers.allow).toBe('POST'); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: 'The method GET is not allowed for this endpoint' + }); + expect(spyRevokeToken).not.toHaveBeenCalled(); + }); + + it('requires token parameter', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret' + // Missing token + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + expect(spyRevokeToken).not.toHaveBeenCalled(); + }); + + it('authenticates client before revoking token', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'invalid-client', + client_secret: 'wrong-secret', + token: 'token_to_revoke' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + expect(spyRevokeToken).not.toHaveBeenCalled(); + }); + + it('successfully revokes token', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); + + expect(response.status).toBe(200); + expect(response.body).toEqual({}); // Empty response on success + expect(spyRevokeToken).toHaveBeenCalledTimes(1); + expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { + token: 'token_to_revoke' + }); + }); + + it('accepts optional token_type_hint', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke', + token_type_hint: 'refresh_token' + }); + + expect(response.status).toBe(200); + expect(spyRevokeToken).toHaveBeenCalledWith(validClient, { + token: 'token_to_revoke', + token_type_hint: 'refresh_token' + }); + }); + + it('includes CORS headers in response', async () => { + const response = await supertest(app).post('/revoke').type('form').set('Origin', 'https://example.com').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); + + expect(response.header['access-control-allow-origin']).toBe('*'); + }); + }); +}); diff --git a/packages/server-auth-legacy/test/handlers/token.test.ts b/packages/server-auth-legacy/test/handlers/token.test.ts new file mode 100644 index 000000000..8d5634922 --- /dev/null +++ b/packages/server-auth-legacy/test/handlers/token.test.ts @@ -0,0 +1,479 @@ +import { tokenHandler, TokenHandlerOptions } from '../../src/handlers/token.js'; +import { OAuthServerProvider, AuthorizationParams } from '../../src/provider.js'; +import { OAuthRegisteredClientsStore } from '../../src/clients.js'; +import { OAuthClientInformationFull, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; +import express, { Response } from 'express'; +import supertest from 'supertest'; +import * as pkceChallenge from 'pkce-challenge'; +import { InvalidGrantError, InvalidTokenError } from '../../src/errors.js'; +import { AuthInfo } from '../../src/types.js'; +import { ProxyOAuthServerProvider } from '../../src/providers/proxyProvider.js'; +import { type Mock } from 'vitest'; + +// Mock pkce-challenge +vi.mock('pkce-challenge', () => ({ + verifyChallenge: vi.fn().mockImplementation(async (verifier, challenge) => { + return verifier === 'valid_verifier' && challenge === 'mock_challenge'; + }) +})); + +const mockTokens = { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' +}; + +const mockTokensWithIdToken = { + ...mockTokens, + id_token: 'mock_id_token' +}; + +describe('Token Handler', () => { + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + + // Mock client store + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return validClient; + } + return undefined; + } + }; + + // Mock provider + let mockProvider: OAuthServerProvider; + let app: express.Express; + + beforeEach(() => { + // Create fresh mocks for each test + mockProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + res.redirect('https://example.com/callback?code=mock_auth_code'); + }, + + async challengeForAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { + if (authorizationCode === 'valid_code') { + return 'mock_challenge'; + } else if (authorizationCode === 'expired_code') { + throw new InvalidGrantError('The authorization code has expired'); + } + throw new InvalidGrantError('The authorization code is invalid'); + }, + + async exchangeAuthorizationCode(client: OAuthClientInformationFull, authorizationCode: string): Promise { + if (authorizationCode === 'valid_code') { + return mockTokens; + } + throw new InvalidGrantError('The authorization code is invalid or has expired'); + }, + + async exchangeRefreshToken(client: OAuthClientInformationFull, refreshToken: string, scopes?: string[]): Promise { + if (refreshToken === 'valid_refresh_token') { + const response: OAuthTokens = { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + + if (scopes) { + response.scope = scopes.join(' '); + } + + return response; + } + throw new InvalidGrantError('The refresh token is invalid or has expired'); + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + }, + + async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { + // Do nothing in mock + } + }; + + // Mock PKCE verification + (pkceChallenge.verifyChallenge as Mock).mockImplementation(async (verifier: string, challenge: string) => { + return verifier === 'valid_verifier' && challenge === 'mock_challenge'; + }); + + // Setup express app with token handler + app = express(); + const options: TokenHandlerOptions = { provider: mockProvider }; + app.use('/token', tokenHandler(options)); + }); + + describe('Basic request validation', () => { + it('requires POST method', async () => { + const response = await supertest(app).get('/token').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code' + }); + + expect(response.status).toBe(405); + expect(response.headers.allow).toBe('POST'); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: 'The method GET is not allowed for this endpoint' + }); + }); + + it('requires grant_type parameter', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret' + // Missing grant_type + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + }); + + it('rejects unsupported grant types', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'password' // Unsupported grant type + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('unsupported_grant_type'); + }); + }); + + describe('Client authentication', () => { + it('requires valid client credentials', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'invalid-client', + client_secret: 'wrong-secret', + grant_type: 'authorization_code' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + }); + + it('accepts valid client credentials', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(200); + }); + }); + + describe('Authorization code grant', () => { + it('requires code parameter', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + // Missing code + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + }); + + it('requires code_verifier parameter', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code' + // Missing code_verifier + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + }); + + it('verifies code_verifier against challenge', async () => { + // Setup invalid verifier + (pkceChallenge.verifyChallenge as Mock).mockResolvedValueOnce(false); + + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'invalid_verifier' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_grant'); + expect(response.body.error_description).toContain('code_verifier'); + }); + + it('rejects expired or invalid authorization codes', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'expired_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_grant'); + }); + + it('returns tokens for valid code exchange', async () => { + const mockExchangeCode = vi.spyOn(mockProvider, 'exchangeAuthorizationCode'); + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + resource: 'https://api.example.com/resource', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('mock_access_token'); + expect(response.body.token_type).toBe('bearer'); + expect(response.body.expires_in).toBe(3600); + expect(response.body.refresh_token).toBe('mock_refresh_token'); + expect(mockExchangeCode).toHaveBeenCalledWith( + validClient, + 'valid_code', + undefined, // code_verifier is undefined after PKCE validation + undefined, // redirect_uri + new URL('https://api.example.com/resource') // resource parameter + ); + }); + + it('returns id token in code exchange if provided', async () => { + mockProvider.exchangeAuthorizationCode = async ( + client: OAuthClientInformationFull, + authorizationCode: string + ): Promise => { + if (authorizationCode === 'valid_code') { + return mockTokensWithIdToken; + } + throw new InvalidGrantError('The authorization code is invalid or has expired'); + }; + + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.status).toBe(200); + expect(response.body.id_token).toBe('mock_id_token'); + }); + + it('passes through code verifier when using proxy provider', async () => { + const originalFetch = global.fetch; + + try { + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve(mockTokens) + }); + + const proxyProvider = new ProxyOAuthServerProvider({ + endpoints: { + authorizationUrl: 'https://example.com/authorize', + tokenUrl: 'https://example.com/token' + }, + verifyAccessToken: async token => ({ + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }), + getClient: async clientId => (clientId === 'valid-client' ? validClient : undefined) + }); + + const proxyApp = express(); + const options: TokenHandlerOptions = { provider: proxyProvider }; + proxyApp.use('/token', tokenHandler(options)); + + const response = await supertest(proxyApp).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'any_verifier', + redirect_uri: 'https://example.com/callback' + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('mock_access_token'); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('code_verifier=any_verifier') + }) + ); + } finally { + global.fetch = originalFetch; + } + }); + + it('passes through redirect_uri when using proxy provider', async () => { + const originalFetch = global.fetch; + + try { + global.fetch = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve(mockTokens) + }); + + const proxyProvider = new ProxyOAuthServerProvider({ + endpoints: { + authorizationUrl: 'https://example.com/authorize', + tokenUrl: 'https://example.com/token' + }, + verifyAccessToken: async token => ({ + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }), + getClient: async clientId => (clientId === 'valid-client' ? validClient : undefined) + }); + + const proxyApp = express(); + const options: TokenHandlerOptions = { provider: proxyProvider }; + proxyApp.use('/token', tokenHandler(options)); + + const redirectUri = 'https://example.com/callback'; + const response = await supertest(proxyApp).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'any_verifier', + redirect_uri: redirectUri + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('mock_access_token'); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) + }) + ); + } finally { + global.fetch = originalFetch; + } + }); + }); + + describe('Refresh token grant', () => { + it('requires refresh_token parameter', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'refresh_token' + // Missing refresh_token + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + }); + + it('rejects invalid refresh tokens', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'refresh_token', + refresh_token: 'invalid_refresh_token' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_grant'); + }); + + it('returns new tokens for valid refresh token', async () => { + const mockExchangeRefresh = vi.spyOn(mockProvider, 'exchangeRefreshToken'); + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + resource: 'https://api.example.com/resource', + grant_type: 'refresh_token', + refresh_token: 'valid_refresh_token' + }); + + expect(response.status).toBe(200); + expect(response.body.access_token).toBe('new_mock_access_token'); + expect(response.body.token_type).toBe('bearer'); + expect(response.body.expires_in).toBe(3600); + expect(response.body.refresh_token).toBe('new_mock_refresh_token'); + expect(mockExchangeRefresh).toHaveBeenCalledWith( + validClient, + 'valid_refresh_token', + undefined, // scopes + new URL('https://api.example.com/resource') // resource parameter + ); + }); + + it('respects requested scopes on refresh', async () => { + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'refresh_token', + refresh_token: 'valid_refresh_token', + scope: 'profile email' + }); + + expect(response.status).toBe(200); + expect(response.body.scope).toBe('profile email'); + }); + }); + + describe('CORS support', () => { + it('includes CORS headers in response', async () => { + const response = await supertest(app).post('/token').type('form').set('Origin', 'https://example.com').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + expect(response.header['access-control-allow-origin']).toBe('*'); + }); + }); +}); diff --git a/packages/server-auth-legacy/test/helpers/http.ts b/packages/server-auth-legacy/test/helpers/http.ts new file mode 100644 index 000000000..98846621a --- /dev/null +++ b/packages/server-auth-legacy/test/helpers/http.ts @@ -0,0 +1,56 @@ +import type { Response } from 'express'; +import { vi } from 'vitest'; + +/** + * Create a minimal Express-like Response mock for tests. + * + * The mock supports: + * - redirect() + * - status().json().send() chaining + * - set()/header() + * - optional getRedirectUrl() helper used in some tests + */ +export function createExpressResponseMock(options: { trackRedirectUrl?: boolean } = {}): Response & { + getRedirectUrl?: () => string; +} { + let capturedRedirectUrl: string | undefined; + + const res: Partial & { getRedirectUrl?: () => string } = { + redirect: vi.fn((urlOrStatus: string | number, maybeUrl?: string | number) => { + if (options.trackRedirectUrl) { + if (typeof urlOrStatus === 'string') { + capturedRedirectUrl = urlOrStatus; + } else if (typeof maybeUrl === 'string') { + capturedRedirectUrl = maybeUrl; + } + } + return res as Response; + }) as unknown as Response['redirect'], + status: vi.fn().mockImplementation((_code: number) => { + return res as Response; + }), + json: vi.fn().mockImplementation((_body: unknown) => { + return res as Response; + }), + send: vi.fn().mockImplementation((_body?: unknown) => { + return res as Response; + }), + set: vi.fn().mockImplementation((_field: string, _value?: string | string[]) => { + return res as Response; + }), + header: vi.fn().mockImplementation((_field: string, _value?: string | string[]) => { + return res as Response; + }) + }; + + if (options.trackRedirectUrl) { + res.getRedirectUrl = () => { + if (capturedRedirectUrl === undefined) { + throw new Error('No redirect URL was captured. Ensure redirect() was called first.'); + } + return capturedRedirectUrl; + }; + } + + return res as Response & { getRedirectUrl?: () => string }; +} diff --git a/packages/server-auth-legacy/test/index.test.ts b/packages/server-auth-legacy/test/index.test.ts new file mode 100644 index 000000000..16276eec6 --- /dev/null +++ b/packages/server-auth-legacy/test/index.test.ts @@ -0,0 +1,70 @@ +import express from 'express'; +import request from 'supertest'; +import { describe, expect, it } from 'vitest'; + +import { + type AuthInfo, + type AuthRouterOptions, + InvalidTokenError, + mcpAuthRouter, + OAuthError, + type OAuthRegisteredClientsStore, + type OAuthServerProvider, + ProxyOAuthServerProvider, + ServerError +} from '../src/index.js'; + +describe('@modelcontextprotocol/server-auth-legacy (frozen v1 compat)', () => { + it('exports the v1 OAuthError subclass hierarchy', () => { + const err = new InvalidTokenError('bad token'); + expect(err).toBeInstanceOf(OAuthError); + expect(err.errorCode).toBe('invalid_token'); + expect(err.toResponseObject()).toEqual({ + error: 'invalid_token', + error_description: 'bad token' + }); + }); + + it('exports ProxyOAuthServerProvider', () => { + const provider = new ProxyOAuthServerProvider({ + endpoints: { + authorizationUrl: 'https://upstream.example/authorize', + tokenUrl: 'https://upstream.example/token' + }, + verifyAccessToken: async token => ({ token, clientId: 'c', scopes: [] }) satisfies AuthInfo, + getClient: async () => undefined + }); + expect(provider.skipLocalPkceValidation).toBe(true); + expect(provider.clientsStore.getClient).toBeTypeOf('function'); + }); + + it('mcpAuthRouter wires up /authorize, /token and AS metadata', async () => { + const clientsStore: OAuthRegisteredClientsStore = { + getClient: () => undefined + }; + const provider: OAuthServerProvider = { + clientsStore, + authorize: async () => { + throw new ServerError('not implemented'); + }, + challengeForAuthorizationCode: async () => 'challenge', + exchangeAuthorizationCode: async () => ({ access_token: 't', token_type: 'Bearer' }), + exchangeRefreshToken: async () => ({ access_token: 't', token_type: 'Bearer' }), + verifyAccessToken: async token => ({ token, clientId: 'c', scopes: [] }) + }; + + const options: AuthRouterOptions = { + provider, + issuerUrl: new URL('http://localhost/') + }; + + const app = express(); + app.use(mcpAuthRouter(options)); + + const res = await request(app).get('/.well-known/oauth-authorization-server'); + expect(res.status).toBe(200); + expect(res.body.issuer).toBe('http://localhost/'); + expect(res.body.authorization_endpoint).toBe('http://localhost/authorize'); + expect(res.body.token_endpoint).toBe('http://localhost/token'); + }); +}); diff --git a/packages/server-auth-legacy/test/middleware/allowedMethods.test.ts b/packages/server-auth-legacy/test/middleware/allowedMethods.test.ts new file mode 100644 index 000000000..d8d0fa63d --- /dev/null +++ b/packages/server-auth-legacy/test/middleware/allowedMethods.test.ts @@ -0,0 +1,75 @@ +import { allowedMethods } from '../../src/middleware/allowedMethods.js'; +import express, { Request, Response } from 'express'; +import request from 'supertest'; + +describe('allowedMethods', () => { + let app: express.Express; + + beforeEach(() => { + app = express(); + + // Set up a test router with a GET handler and 405 middleware + const router = express.Router(); + + router.get('/test', (req, res) => { + res.status(200).send('GET success'); + }); + + // Add method not allowed middleware for all other methods + router.all('/test', allowedMethods(['GET'])); + + app.use(router); + }); + + test('allows specified HTTP method', async () => { + const response = await request(app).get('/test'); + expect(response.status).toBe(200); + expect(response.text).toBe('GET success'); + }); + + test('returns 405 for unspecified HTTP methods', async () => { + const methods = ['post', 'put', 'delete', 'patch']; + + for (const method of methods) { + // @ts-expect-error - dynamic method call + const response = await request(app)[method]('/test'); + expect(response.status).toBe(405); + expect(response.body).toEqual({ + error: 'method_not_allowed', + error_description: `The method ${method.toUpperCase()} is not allowed for this endpoint` + }); + } + }); + + test('includes Allow header with specified methods', async () => { + const response = await request(app).post('/test'); + expect(response.headers.allow).toBe('GET'); + }); + + test('works with multiple allowed methods', async () => { + const multiMethodApp = express(); + const router = express.Router(); + + router.get('/multi', (req: Request, res: Response) => { + res.status(200).send('GET'); + }); + router.post('/multi', (req: Request, res: Response) => { + res.status(200).send('POST'); + }); + router.all('/multi', allowedMethods(['GET', 'POST'])); + + multiMethodApp.use(router); + + // Allowed methods should work + const getResponse = await request(multiMethodApp).get('/multi'); + expect(getResponse.status).toBe(200); + + const postResponse = await request(multiMethodApp).post('/multi'); + expect(postResponse.status).toBe(200); + + // Unallowed methods should return 405 + const putResponse = await request(multiMethodApp).put('/multi'); + expect(putResponse.status).toBe(405); + expect(putResponse.headers.allow).toBe('GET, POST'); + }); +}); diff --git a/packages/server-auth-legacy/test/middleware/bearerAuth.test.ts b/packages/server-auth-legacy/test/middleware/bearerAuth.test.ts new file mode 100644 index 000000000..451f49112 --- /dev/null +++ b/packages/server-auth-legacy/test/middleware/bearerAuth.test.ts @@ -0,0 +1,501 @@ +import { Request, Response } from 'express'; +import { Mock } from 'vitest'; +import { requireBearerAuth } from '../../src/middleware/bearerAuth.js'; +import { AuthInfo } from '../../src/types.js'; +import { InsufficientScopeError, InvalidTokenError, CustomOAuthError, ServerError } from '../../src/errors.js'; +import { OAuthTokenVerifier } from '../../src/provider.js'; +import { createExpressResponseMock } from '../helpers/http.js'; + +// Mock verifier +const mockVerifyAccessToken = vi.fn(); +const mockVerifier: OAuthTokenVerifier = { + verifyAccessToken: mockVerifyAccessToken +}; + +describe('requireBearerAuth middleware', () => { + let mockRequest: Partial; + let mockResponse: Partial; + let nextFunction: Mock; + + beforeEach(() => { + mockRequest = { + headers: {} + }; + mockResponse = createExpressResponseMock(); + nextFunction = vi.fn(); + vi.spyOn(console, 'error').mockImplementation(() => {}); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should call next when token is valid', async () => { + const validAuthInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour + }; + mockVerifyAccessToken.mockResolvedValue(validAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockRequest.auth).toEqual(validAuthInfo); + expect(nextFunction).toHaveBeenCalled(); + expect(mockResponse.status).not.toHaveBeenCalled(); + expect(mockResponse.json).not.toHaveBeenCalled(); + }); + + it.each([ + [100], // Token expired 100 seconds ago + [0] // Token expires at the same time as now + ])('should reject expired tokens (expired %s seconds ago)', async (expiredSecondsAgo: number) => { + const expiresAt = Math.floor(Date.now() / 1000) - expiredSecondsAgo; + const expiredAuthInfo: AuthInfo = { + token: 'expired-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt + }; + mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer expired-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('expired-token'); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'invalid_token', error_description: 'Token has expired' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it.each([ + [undefined], // Token has no expiration time + [NaN] // Token has no expiration time + ])('should reject tokens with no expiration time (expiresAt: %s)', async (expiresAt: number | undefined) => { + const noExpirationAuthInfo: AuthInfo = { + token: 'no-expiration-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt + }; + mockVerifyAccessToken.mockResolvedValue(noExpirationAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer expired-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('expired-token'); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'invalid_token', error_description: 'Token has no expiration time' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should accept non-expired tokens', async () => { + const nonExpiredAuthInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour + }; + mockVerifyAccessToken.mockResolvedValue(nonExpiredAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockRequest.auth).toEqual(nonExpiredAuthInfo); + expect(nextFunction).toHaveBeenCalled(); + expect(mockResponse.status).not.toHaveBeenCalled(); + expect(mockResponse.json).not.toHaveBeenCalled(); + }); + + it('should require specific scopes when configured', async () => { + const authInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read'] + }; + mockVerifyAccessToken.mockResolvedValue(authInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['read', 'write'] + }); + + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="insufficient_scope"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'insufficient_scope', error_description: 'Insufficient scope' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should accept token with all required scopes', async () => { + const authInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read', 'write', 'admin'], + expiresAt: Math.floor(Date.now() / 1000) + 3600 // Token expires in an hour + }; + mockVerifyAccessToken.mockResolvedValue(authInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['read', 'write'] + }); + + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockRequest.auth).toEqual(authInfo); + expect(nextFunction).toHaveBeenCalled(); + expect(mockResponse.status).not.toHaveBeenCalled(); + expect(mockResponse.json).not.toHaveBeenCalled(); + }); + + it('should return 401 when no Authorization header is present', async () => { + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).not.toHaveBeenCalled(); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'invalid_token', error_description: 'Missing Authorization header' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should return 401 when Authorization header format is invalid', async () => { + mockRequest.headers = { + authorization: 'InvalidFormat' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).not.toHaveBeenCalled(); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ + error: 'invalid_token', + error_description: "Invalid Authorization header format, expected 'Bearer TOKEN'" + }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should return 401 when token verification fails with InvalidTokenError', async () => { + mockRequest.headers = { + authorization: 'Bearer invalid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError('Token expired')); + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('invalid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="invalid_token"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'invalid_token', error_description: 'Token expired' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should return 403 when access token has insufficient scopes', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError('Required scopes: read, write')); + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith('WWW-Authenticate', expect.stringContaining('Bearer error="insufficient_scope"')); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'insufficient_scope', error_description: 'Required scopes: read, write' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should return 500 when a ServerError occurs', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new ServerError('Internal server issue')); + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(500); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'server_error', error_description: 'Internal server issue' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should return 400 for generic OAuthError', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new CustomOAuthError('custom_error', 'Some OAuth error')); + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(400); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'custom_error', error_description: 'Some OAuth error' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should return 500 when unexpected error occurs', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new Error('Unexpected error')); + + const middleware = requireBearerAuth({ verifier: mockVerifier }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockVerifyAccessToken).toHaveBeenCalledWith('valid-token'); + expect(mockResponse.status).toHaveBeenCalledWith(500); + expect(mockResponse.json).toHaveBeenCalledWith( + expect.objectContaining({ error: 'server_error', error_description: 'Internal Server Error' }) + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + describe('with requiredScopes in WWW-Authenticate header', () => { + it('should include scope in WWW-Authenticate header for 401 responses when requiredScopes is provided', async () => { + mockRequest.headers = {}; + + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['read', 'write'] + }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + 'Bearer error="invalid_token", error_description="Missing Authorization header", scope="read write"' + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include scope in WWW-Authenticate header for 403 insufficient scope responses', async () => { + const authInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read'] + }; + mockVerifyAccessToken.mockResolvedValue(authInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['read', 'write'] + }); + + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + 'Bearer error="insufficient_scope", error_description="Insufficient scope", scope="read write"' + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include both scope and resource_metadata in WWW-Authenticate header when both are provided', async () => { + mockRequest.headers = {}; + + const resourceMetadataUrl = 'https://api.example.com/.well-known/oauth-protected-resource'; + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['admin'], + resourceMetadataUrl + }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="invalid_token", error_description="Missing Authorization header", scope="admin", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + }); + + describe('with resourceMetadataUrl', () => { + const resourceMetadataUrl = 'https://api.example.com/.well-known/oauth-protected-resource'; + + it('should include resource_metadata in WWW-Authenticate header for 401 responses', async () => { + mockRequest.headers = {}; + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="invalid_token", error_description="Missing Authorization header", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include resource_metadata in WWW-Authenticate header when token verification fails', async () => { + mockRequest.headers = { + authorization: 'Bearer invalid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new InvalidTokenError('Token expired')); + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="invalid_token", error_description="Token expired", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include resource_metadata in WWW-Authenticate header for insufficient scope errors', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new InsufficientScopeError('Required scopes: admin')); + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="insufficient_scope", error_description="Required scopes: admin", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include resource_metadata when token is expired', async () => { + const expiredAuthInfo: AuthInfo = { + token: 'expired-token', + clientId: 'client-123', + scopes: ['read', 'write'], + expiresAt: Math.floor(Date.now() / 1000) - 100 + }; + mockVerifyAccessToken.mockResolvedValue(expiredAuthInfo); + + mockRequest.headers = { + authorization: 'Bearer expired-token' + }; + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(401); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="invalid_token", error_description="Token has expired", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should include resource_metadata when scope check fails', async () => { + const authInfo: AuthInfo = { + token: 'valid-token', + clientId: 'client-123', + scopes: ['read'] + }; + mockVerifyAccessToken.mockResolvedValue(authInfo); + + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + const middleware = requireBearerAuth({ + verifier: mockVerifier, + requiredScopes: ['read', 'write'], + resourceMetadataUrl + }); + + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(403); + expect(mockResponse.set).toHaveBeenCalledWith( + 'WWW-Authenticate', + `Bearer error="insufficient_scope", error_description="Insufficient scope", scope="read write", resource_metadata="${resourceMetadataUrl}"` + ); + expect(nextFunction).not.toHaveBeenCalled(); + }); + + it('should not affect server errors (no WWW-Authenticate header)', async () => { + mockRequest.headers = { + authorization: 'Bearer valid-token' + }; + + mockVerifyAccessToken.mockRejectedValue(new ServerError('Internal server issue')); + + const middleware = requireBearerAuth({ verifier: mockVerifier, resourceMetadataUrl }); + await middleware(mockRequest as Request, mockResponse as Response, nextFunction); + + expect(mockResponse.status).toHaveBeenCalledWith(500); + expect(mockResponse.set).not.toHaveBeenCalledWith('WWW-Authenticate', expect.anything()); + expect(nextFunction).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/server-auth-legacy/test/middleware/clientAuth.test.ts b/packages/server-auth-legacy/test/middleware/clientAuth.test.ts new file mode 100644 index 000000000..265216810 --- /dev/null +++ b/packages/server-auth-legacy/test/middleware/clientAuth.test.ts @@ -0,0 +1,132 @@ +import { authenticateClient, ClientAuthenticationMiddlewareOptions } from '../../src/middleware/clientAuth.js'; +import { OAuthRegisteredClientsStore } from '../../src/clients.js'; +import { OAuthClientInformationFull } from '@modelcontextprotocol/core'; +import express from 'express'; +import supertest from 'supertest'; + +describe('clientAuth middleware', () => { + // Mock client store + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + } else if (clientId === 'expired-client') { + // Client with no secret + return { + client_id: 'expired-client', + redirect_uris: ['https://example.com/callback'] + }; + } else if (clientId === 'client-with-expired-secret') { + // Client with an expired secret + return { + client_id: 'client-with-expired-secret', + client_secret: 'expired-secret', + client_secret_expires_at: Math.floor(Date.now() / 1000) - 3600, // Expired 1 hour ago + redirect_uris: ['https://example.com/callback'] + }; + } + return undefined; + } + }; + + // Setup Express app with middleware + let app: express.Express; + let options: ClientAuthenticationMiddlewareOptions; + + beforeEach(() => { + app = express(); + app.use(express.json()); + + options = { + clientsStore: mockClientStore + }; + + // Setup route with client auth + app.post('/protected', authenticateClient(options), (req, res) => { + res.status(200).json({ success: true, client: req.client }); + }); + }); + + it('authenticates valid client credentials', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'valid-client', + client_secret: 'valid-secret' + }); + + expect(response.status).toBe(200); + expect(response.body.success).toBe(true); + expect(response.body.client.client_id).toBe('valid-client'); + }); + + it('rejects invalid client_id', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'non-existent-client', + client_secret: 'some-secret' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + expect(response.body.error_description).toBe('Invalid client_id'); + }); + + it('rejects invalid client_secret', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'valid-client', + client_secret: 'wrong-secret' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + expect(response.body.error_description).toBe('Invalid client_secret'); + }); + + it('rejects missing client_id', async () => { + const response = await supertest(app).post('/protected').send({ + client_secret: 'valid-secret' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_request'); + }); + + it('allows missing client_secret if client has none', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'expired-client' + }); + + // Since the client has no secret, this should pass without providing one + expect(response.status).toBe(200); + }); + + it('rejects request when client secret has expired', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'client-with-expired-secret', + client_secret: 'expired-secret' + }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe('invalid_client'); + expect(response.body.error_description).toBe('Client secret has expired'); + }); + + it('handles malformed request body', async () => { + const response = await supertest(app).post('/protected').send('not-json-format'); + + expect(response.status).toBe(400); + }); + + // Testing request with extra fields to ensure they're ignored + it('ignores extra fields in request', async () => { + const response = await supertest(app).post('/protected').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + extra_field: 'should be ignored' + }); + + expect(response.status).toBe(200); + }); +}); diff --git a/packages/server-auth-legacy/test/providers/proxyProvider.test.ts b/packages/server-auth-legacy/test/providers/proxyProvider.test.ts new file mode 100644 index 000000000..2646f49a0 --- /dev/null +++ b/packages/server-auth-legacy/test/providers/proxyProvider.test.ts @@ -0,0 +1,344 @@ +import { Response } from 'express'; +import { ProxyOAuthServerProvider, ProxyOptions } from '../../src/providers/proxyProvider.js'; +import { AuthInfo } from '../../src/types.js'; +import { OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/core'; +import { ServerError } from '../../src/errors.js'; +import { InvalidTokenError } from '../../src/errors.js'; +import { InsufficientScopeError } from '../../src/errors.js'; +import { type Mock } from 'vitest'; + +describe('Proxy OAuth Server Provider', () => { + // Mock client data + const validClient: OAuthClientInformationFull = { + client_id: 'test-client', + client_secret: 'test-secret', + redirect_uris: ['https://example.com/callback'] + }; + + // Mock response object + const mockResponse = { + redirect: vi.fn() + } as unknown as Response; + + // Mock provider functions + const mockVerifyToken = vi.fn(); + const mockGetClient = vi.fn(); + + // Base provider options + const baseOptions: ProxyOptions = { + endpoints: { + authorizationUrl: 'https://auth.example.com/authorize', + tokenUrl: 'https://auth.example.com/token', + revocationUrl: 'https://auth.example.com/revoke', + registrationUrl: 'https://auth.example.com/register' + }, + verifyAccessToken: mockVerifyToken, + getClient: mockGetClient + }; + + let provider: ProxyOAuthServerProvider; + let originalFetch: typeof global.fetch; + + beforeEach(() => { + provider = new ProxyOAuthServerProvider(baseOptions); + originalFetch = global.fetch; + global.fetch = vi.fn(); + + // Setup mock implementations + mockVerifyToken.mockImplementation(async (token: string) => { + if (token === 'valid-token') { + return { + token, + clientId: 'test-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + } as AuthInfo; + } + throw new InvalidTokenError('Invalid token'); + }); + + mockGetClient.mockImplementation(async (clientId: string) => { + if (clientId === 'test-client') { + return validClient; + } + return undefined; + }); + }); + + // Add helper function for failed responses + const mockFailedResponse = () => { + (global.fetch as Mock).mockImplementation(() => + Promise.resolve({ + ok: false, + status: 400 + }) + ); + }; + + afterEach(() => { + global.fetch = originalFetch; + vi.clearAllMocks(); + }); + + describe('authorization', () => { + it('redirects to authorization endpoint with correct parameters', async () => { + await provider.authorize( + validClient, + { + redirectUri: 'https://example.com/callback', + codeChallenge: 'test-challenge', + state: 'test-state', + scopes: ['read', 'write'], + resource: new URL('https://api.example.com/resource') + }, + mockResponse + ); + + const expectedUrl = new URL('https://auth.example.com/authorize'); + expectedUrl.searchParams.set('client_id', 'test-client'); + expectedUrl.searchParams.set('response_type', 'code'); + expectedUrl.searchParams.set('redirect_uri', 'https://example.com/callback'); + expectedUrl.searchParams.set('code_challenge', 'test-challenge'); + expectedUrl.searchParams.set('code_challenge_method', 'S256'); + expectedUrl.searchParams.set('state', 'test-state'); + expectedUrl.searchParams.set('scope', 'read write'); + expectedUrl.searchParams.set('resource', 'https://api.example.com/resource'); + + expect(mockResponse.redirect).toHaveBeenCalledWith(expectedUrl.toString()); + }); + }); + + describe('token exchange', () => { + const mockTokenResponse: OAuthTokens = { + access_token: 'new-access-token', + token_type: 'Bearer', + expires_in: 3600, + refresh_token: 'new-refresh-token' + }; + + beforeEach(() => { + (global.fetch as Mock).mockImplementation(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(mockTokenResponse) + }) + ); + }); + + it('exchanges authorization code for tokens', async () => { + const tokens = await provider.exchangeAuthorizationCode(validClient, 'test-code', 'test-verifier'); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('grant_type=authorization_code') + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('includes redirect_uri in token request when provided', async () => { + const redirectUri = 'https://example.com/callback'; + const tokens = await provider.exchangeAuthorizationCode(validClient, 'test-code', 'test-verifier', redirectUri); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining(`redirect_uri=${encodeURIComponent(redirectUri)}`) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('includes resource parameter in authorization code exchange', async () => { + const tokens = await provider.exchangeAuthorizationCode( + validClient, + 'test-code', + 'test-verifier', + 'https://example.com/callback', + new URL('https://api.example.com/resource') + ); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('resource=' + encodeURIComponent('https://api.example.com/resource')) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('handles authorization code exchange without resource parameter', async () => { + const tokens = await provider.exchangeAuthorizationCode(validClient, 'test-code', 'test-verifier'); + + const fetchCall = (global.fetch as Mock).mock.calls[0]; + const body = fetchCall![1].body as string; + expect(body).not.toContain('resource='); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('exchanges refresh token for new tokens', async () => { + const tokens = await provider.exchangeRefreshToken(validClient, 'test-refresh-token', ['read', 'write']); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('grant_type=refresh_token') + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + + it('includes resource parameter in refresh token exchange', async () => { + const tokens = await provider.exchangeRefreshToken( + validClient, + 'test-refresh-token', + ['read', 'write'], + new URL('https://api.example.com/resource') + ); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/token', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('resource=' + encodeURIComponent('https://api.example.com/resource')) + }) + ); + expect(tokens).toEqual(mockTokenResponse); + }); + }); + + describe('client registration', () => { + it('registers new client', async () => { + const newClient: OAuthClientInformationFull = { + client_id: 'new-client', + redirect_uris: ['https://new-client.com/callback'] + }; + + (global.fetch as Mock).mockImplementation(() => + Promise.resolve({ + ok: true, + json: () => Promise.resolve(newClient) + }) + ); + + const result = await provider.clientsStore.registerClient!(newClient); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/register', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(newClient) + }) + ); + expect(result).toEqual(newClient); + }); + + it('handles registration failure', async () => { + mockFailedResponse(); + const newClient: OAuthClientInformationFull = { + client_id: 'new-client', + redirect_uris: ['https://new-client.com/callback'] + }; + + await expect(provider.clientsStore.registerClient!(newClient)).rejects.toThrow(ServerError); + }); + }); + + describe('token revocation', () => { + it('revokes token', async () => { + (global.fetch as Mock).mockImplementation(() => + Promise.resolve({ + ok: true + }) + ); + + await provider.revokeToken!(validClient, { + token: 'token-to-revoke', + token_type_hint: 'access_token' + }); + + expect(global.fetch).toHaveBeenCalledWith( + 'https://auth.example.com/revoke', + expect.objectContaining({ + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded' + }, + body: expect.stringContaining('token=token-to-revoke') + }) + ); + }); + + it('handles revocation failure', async () => { + mockFailedResponse(); + await expect( + provider.revokeToken!(validClient, { + token: 'invalid-token' + }) + ).rejects.toThrow(ServerError); + }); + }); + + describe('token verification', () => { + it('verifies valid token', async () => { + const validAuthInfo: AuthInfo = { + token: 'valid-token', + clientId: 'test-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + mockVerifyToken.mockResolvedValue(validAuthInfo); + + const authInfo = await provider.verifyAccessToken('valid-token'); + expect(authInfo).toEqual(validAuthInfo); + expect(mockVerifyToken).toHaveBeenCalledWith('valid-token'); + }); + + it('passes through InvalidTokenError', async () => { + const error = new InvalidTokenError('Token expired'); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken('invalid-token')).rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith('invalid-token'); + }); + + it('passes through InsufficientScopeError', async () => { + const error = new InsufficientScopeError('Required scopes: read, write'); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken('token-with-insufficient-scope')).rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith('token-with-insufficient-scope'); + }); + + it('passes through unexpected errors', async () => { + const error = new Error('Unexpected error'); + mockVerifyToken.mockRejectedValue(error); + + await expect(provider.verifyAccessToken('valid-token')).rejects.toBe(error); + expect(mockVerifyToken).toHaveBeenCalledWith('valid-token'); + }); + }); +}); diff --git a/packages/server-auth-legacy/test/router.test.ts b/packages/server-auth-legacy/test/router.test.ts new file mode 100644 index 000000000..4b44ef571 --- /dev/null +++ b/packages/server-auth-legacy/test/router.test.ts @@ -0,0 +1,463 @@ +import { mcpAuthRouter, AuthRouterOptions, mcpAuthMetadataRouter, AuthMetadataOptions } from '../src/router.js'; +import { OAuthServerProvider, AuthorizationParams } from '../src/provider.js'; +import { OAuthRegisteredClientsStore } from '../src/clients.js'; +import { OAuthClientInformationFull, OAuthMetadata, OAuthTokenRevocationRequest, OAuthTokens } from '@modelcontextprotocol/core'; +import express, { Response } from 'express'; +import supertest from 'supertest'; +import { AuthInfo } from '../src/types.js'; +import { InvalidTokenError } from '../src/errors.js'; + +describe('MCP Auth Router', () => { + // Setup mock provider with full capabilities + const mockClientStore: OAuthRegisteredClientsStore = { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + } + return undefined; + }, + + async registerClient(client: OAuthClientInformationFull): Promise { + return client; + } + }; + + const mockProvider: OAuthServerProvider = { + clientsStore: mockClientStore, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + const redirectUrl = new URL(params.redirectUri); + redirectUrl.searchParams.set('code', 'mock_auth_code'); + if (params.state) { + redirectUrl.searchParams.set('state', params.state); + } + res.redirect(302, redirectUrl.toString()); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read', 'write'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + }, + + async revokeToken(_client: OAuthClientInformationFull, _request: OAuthTokenRevocationRequest): Promise { + // Success - do nothing in mock + } + }; + + // Provider without registration and revocation + const mockProviderMinimal: OAuthServerProvider = { + clientsStore: { + async getClient(clientId: string): Promise { + if (clientId === 'valid-client') { + return { + client_id: 'valid-client', + client_secret: 'valid-secret', + redirect_uris: ['https://example.com/callback'] + }; + } + return undefined; + } + }, + + async authorize(client: OAuthClientInformationFull, params: AuthorizationParams, res: Response): Promise { + const redirectUrl = new URL(params.redirectUri); + redirectUrl.searchParams.set('code', 'mock_auth_code'); + if (params.state) { + redirectUrl.searchParams.set('state', params.state); + } + res.redirect(302, redirectUrl.toString()); + }, + + async challengeForAuthorizationCode(): Promise { + return 'mock_challenge'; + }, + + async exchangeAuthorizationCode(): Promise { + return { + access_token: 'mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'mock_refresh_token' + }; + }, + + async exchangeRefreshToken(): Promise { + return { + access_token: 'new_mock_access_token', + token_type: 'bearer', + expires_in: 3600, + refresh_token: 'new_mock_refresh_token' + }; + }, + + async verifyAccessToken(token: string): Promise { + if (token === 'valid_token') { + return { + token, + clientId: 'valid-client', + scopes: ['read'], + expiresAt: Date.now() / 1000 + 3600 + }; + } + throw new InvalidTokenError('Token is invalid or expired'); + } + }; + + describe('Router creation', () => { + it('throws error for non-HTTPS issuer URL', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('http://auth.example.com') + }; + + expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must be HTTPS'); + }); + + it('allows localhost HTTP for development', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('http://localhost:3000') + }; + + expect(() => mcpAuthRouter(options)).not.toThrow(); + }); + + it('throws error for issuer URL with fragment', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com#fragment') + }; + + expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must not have a fragment'); + }); + + it('throws error for issuer URL with query string', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com?param=value') + }; + + expect(() => mcpAuthRouter(options)).toThrow('Issuer URL must not have a query string'); + }); + + it('successfully creates router with valid options', () => { + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com') + }; + + expect(() => mcpAuthRouter(options)).not.toThrow(); + }); + }); + + describe('Metadata endpoint', () => { + let app: express.Express; + + beforeEach(() => { + // Setup full-featured router + app = express(); + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com'), + serviceDocumentationUrl: new URL('https://docs.example.com') + }; + app.use(mcpAuthRouter(options)); + }); + + it('returns complete metadata for full-featured router', async () => { + const response = await supertest(app).get('/.well-known/oauth-authorization-server'); + + expect(response.status).toBe(200); + + // Verify essential fields + expect(response.body.issuer).toBe('https://auth.example.com/'); + expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); + expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); + expect(response.body.registration_endpoint).toBe('https://auth.example.com/register'); + expect(response.body.revocation_endpoint).toBe('https://auth.example.com/revoke'); + + // Verify supported features + expect(response.body.response_types_supported).toEqual(['code']); + expect(response.body.grant_types_supported).toEqual(['authorization_code', 'refresh_token']); + expect(response.body.code_challenge_methods_supported).toEqual(['S256']); + expect(response.body.token_endpoint_auth_methods_supported).toEqual(['client_secret_post', 'none']); + expect(response.body.revocation_endpoint_auth_methods_supported).toEqual(['client_secret_post']); + + // Verify optional fields + expect(response.body.service_documentation).toBe('https://docs.example.com/'); + }); + + it('returns minimal metadata for minimal router', async () => { + // Setup minimal router + const minimalApp = express(); + const options: AuthRouterOptions = { + provider: mockProviderMinimal, + issuerUrl: new URL('https://auth.example.com') + }; + minimalApp.use(mcpAuthRouter(options)); + + const response = await supertest(minimalApp).get('/.well-known/oauth-authorization-server'); + + expect(response.status).toBe(200); + + // Verify essential endpoints + expect(response.body.issuer).toBe('https://auth.example.com/'); + expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); + expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); + + // Verify missing optional endpoints + expect(response.body.registration_endpoint).toBeUndefined(); + expect(response.body.revocation_endpoint).toBeUndefined(); + expect(response.body.revocation_endpoint_auth_methods_supported).toBeUndefined(); + expect(response.body.service_documentation).toBeUndefined(); + }); + + it('provides protected resource metadata', async () => { + // Setup router with draft protocol version + const draftApp = express(); + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://mcp.example.com'), + scopesSupported: ['read', 'write'], + resourceName: 'Test API' + }; + draftApp.use(mcpAuthRouter(options)); + + const response = await supertest(draftApp).get('/.well-known/oauth-protected-resource'); + + expect(response.status).toBe(200); + + // Verify protected resource metadata + expect(response.body.resource).toBe('https://mcp.example.com/'); + expect(response.body.authorization_servers).toContain('https://mcp.example.com/'); + expect(response.body.scopes_supported).toEqual(['read', 'write']); + expect(response.body.resource_name).toBe('Test API'); + }); + }); + + describe('Endpoint routing', () => { + let app: express.Express; + + beforeEach(() => { + // Setup full-featured router + app = express(); + const options: AuthRouterOptions = { + provider: mockProvider, + issuerUrl: new URL('https://auth.example.com') + }; + app.use(mcpAuthRouter(options)); + vi.spyOn(console, 'error').mockImplementation(() => {}); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('routes to authorization endpoint', async () => { + const response = await supertest(app).get('/authorize').query({ + client_id: 'valid-client', + response_type: 'code', + code_challenge: 'challenge123', + code_challenge_method: 'S256' + }); + + expect(response.status).toBe(302); + const location = new URL(response.header.location!); + expect(location.searchParams.has('code')).toBe(true); + }); + + it('routes to token endpoint', async () => { + // Setup verifyChallenge mock for token handler + vi.mock('pkce-challenge', () => ({ + verifyChallenge: vi.fn().mockResolvedValue(true) + })); + + const response = await supertest(app).post('/token').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + grant_type: 'authorization_code', + code: 'valid_code', + code_verifier: 'valid_verifier' + }); + + // The request will fail in testing due to mocking limitations, + // but we can verify the route was matched + expect(response.status).not.toBe(404); + }); + + it('routes to registration endpoint', async () => { + const response = await supertest(app) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + + // The request will fail in testing due to mocking limitations, + // but we can verify the route was matched + expect(response.status).not.toBe(404); + }); + + it('routes to revocation endpoint', async () => { + const response = await supertest(app).post('/revoke').type('form').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); + + // The request will fail in testing due to mocking limitations, + // but we can verify the route was matched + expect(response.status).not.toBe(404); + }); + + it('excludes endpoints for unsupported features', async () => { + // Setup minimal router + const minimalApp = express(); + const options: AuthRouterOptions = { + provider: mockProviderMinimal, + issuerUrl: new URL('https://auth.example.com') + }; + minimalApp.use(mcpAuthRouter(options)); + + // Registration should not be available + const regResponse = await supertest(minimalApp) + .post('/register') + .send({ + redirect_uris: ['https://example.com/callback'] + }); + expect(regResponse.status).toBe(404); + + // Revocation should not be available + const revokeResponse = await supertest(minimalApp).post('/revoke').send({ + client_id: 'valid-client', + client_secret: 'valid-secret', + token: 'token_to_revoke' + }); + expect(revokeResponse.status).toBe(404); + }); + }); +}); + +describe('MCP Auth Metadata Router', () => { + const mockOAuthMetadata: OAuthMetadata = { + issuer: 'https://auth.example.com/', + authorization_endpoint: 'https://auth.example.com/authorize', + token_endpoint: 'https://auth.example.com/token', + response_types_supported: ['code'], + grant_types_supported: ['authorization_code', 'refresh_token'], + code_challenge_methods_supported: ['S256'], + token_endpoint_auth_methods_supported: ['client_secret_post'] + }; + + describe('Router creation', () => { + it('successfully creates router with valid options', () => { + const options: AuthMetadataOptions = { + oauthMetadata: mockOAuthMetadata, + resourceServerUrl: new URL('https://api.example.com') + }; + + expect(() => mcpAuthMetadataRouter(options)).not.toThrow(); + }); + }); + + describe('Metadata endpoints', () => { + let app: express.Express; + + beforeEach(() => { + app = express(); + const options: AuthMetadataOptions = { + oauthMetadata: mockOAuthMetadata, + resourceServerUrl: new URL('https://api.example.com'), + serviceDocumentationUrl: new URL('https://docs.example.com'), + scopesSupported: ['read', 'write'], + resourceName: 'Test API' + }; + app.use(mcpAuthMetadataRouter(options)); + }); + + it('returns OAuth authorization server metadata', async () => { + const response = await supertest(app).get('/.well-known/oauth-authorization-server'); + + expect(response.status).toBe(200); + + // Verify metadata points to authorization server + expect(response.body.issuer).toBe('https://auth.example.com/'); + expect(response.body.authorization_endpoint).toBe('https://auth.example.com/authorize'); + expect(response.body.token_endpoint).toBe('https://auth.example.com/token'); + expect(response.body.response_types_supported).toEqual(['code']); + expect(response.body.grant_types_supported).toEqual(['authorization_code', 'refresh_token']); + expect(response.body.code_challenge_methods_supported).toEqual(['S256']); + expect(response.body.token_endpoint_auth_methods_supported).toEqual(['client_secret_post']); + }); + + it('returns OAuth protected resource metadata', async () => { + const response = await supertest(app).get('/.well-known/oauth-protected-resource'); + + expect(response.status).toBe(200); + + // Verify protected resource metadata + expect(response.body.resource).toBe('https://api.example.com/'); + expect(response.body.authorization_servers).toEqual(['https://auth.example.com/']); + expect(response.body.scopes_supported).toEqual(['read', 'write']); + expect(response.body.resource_name).toBe('Test API'); + expect(response.body.resource_documentation).toBe('https://docs.example.com/'); + }); + + it('works with minimal configuration', async () => { + const minimalApp = express(); + const options: AuthMetadataOptions = { + oauthMetadata: mockOAuthMetadata, + resourceServerUrl: new URL('https://api.example.com') + }; + minimalApp.use(mcpAuthMetadataRouter(options)); + + const authResponse = await supertest(minimalApp).get('/.well-known/oauth-authorization-server'); + + expect(authResponse.status).toBe(200); + expect(authResponse.body.issuer).toBe('https://auth.example.com/'); + expect(authResponse.body.service_documentation).toBeUndefined(); + expect(authResponse.body.scopes_supported).toBeUndefined(); + + const resourceResponse = await supertest(minimalApp).get('/.well-known/oauth-protected-resource'); + + expect(resourceResponse.status).toBe(200); + expect(resourceResponse.body.resource).toBe('https://api.example.com/'); + expect(resourceResponse.body.authorization_servers).toEqual(['https://auth.example.com/']); + expect(resourceResponse.body.scopes_supported).toBeUndefined(); + expect(resourceResponse.body.resource_name).toBeUndefined(); + expect(resourceResponse.body.resource_documentation).toBeUndefined(); + }); + }); +}); diff --git a/packages/server-auth-legacy/tsconfig.json b/packages/server-auth-legacy/tsconfig.json new file mode 100644 index 000000000..18c1327cb --- /dev/null +++ b/packages/server-auth-legacy/tsconfig.json @@ -0,0 +1,12 @@ +{ + "extends": "@modelcontextprotocol/tsconfig", + "include": ["./"], + "exclude": ["node_modules", "dist"], + "compilerOptions": { + "paths": { + "*": ["./*"], + "@modelcontextprotocol/core": ["./node_modules/@modelcontextprotocol/core/src/index.ts"], + "@modelcontextprotocol/core/public": ["./node_modules/@modelcontextprotocol/core/src/exports/public/index.ts"] + } + } +} diff --git a/packages/server-auth-legacy/tsdown.config.ts b/packages/server-auth-legacy/tsdown.config.ts new file mode 100644 index 000000000..bc0ef8329 --- /dev/null +++ b/packages/server-auth-legacy/tsdown.config.ts @@ -0,0 +1,22 @@ +import { defineConfig } from 'tsdown'; + +export default defineConfig({ + failOnWarn: 'ci-only', + entry: ['src/index.ts'], + format: ['esm'], + outDir: 'dist', + clean: true, + sourcemap: true, + target: 'esnext', + platform: 'node', + shims: true, + dts: { + resolver: 'tsc', + compilerOptions: { + baseUrl: '.', + paths: { + '@modelcontextprotocol/core': ['../core/src/index.ts'] + } + } + } +}); diff --git a/packages/server-auth-legacy/typedoc.json b/packages/server-auth-legacy/typedoc.json new file mode 100644 index 000000000..a9fd090d0 --- /dev/null +++ b/packages/server-auth-legacy/typedoc.json @@ -0,0 +1,10 @@ +{ + "$schema": "https://typedoc.org/schema.json", + "entryPoints": ["src"], + "entryPointStrategy": "expand", + "exclude": ["**/*.test.ts", "**/__*__/**"], + "navigation": { + "includeGroups": true, + "includeCategories": true + } +} diff --git a/packages/server-auth-legacy/vitest.config.js b/packages/server-auth-legacy/vitest.config.js new file mode 100644 index 000000000..496fca320 --- /dev/null +++ b/packages/server-auth-legacy/vitest.config.js @@ -0,0 +1,3 @@ +import baseConfig from '@modelcontextprotocol/vitest-config'; + +export default baseConfig; diff --git a/packages/server/package.json b/packages/server/package.json index b40135ec9..c8195f459 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -22,28 +22,39 @@ "exports": { ".": { "types": "./dist/index.d.mts", - "import": "./dist/index.mjs" + "import": "./dist/index.mjs", + "require": "./dist/index.mjs" + }, + "./zod-schemas": { + "types": "./dist/zodSchemas.d.mts", + "import": "./dist/zodSchemas.mjs", + "require": "./dist/zodSchemas.mjs" }, "./validators/cf-worker": { "types": "./dist/validators/cfWorker.d.mts", - "import": "./dist/validators/cfWorker.mjs" + "import": "./dist/validators/cfWorker.mjs", + "require": "./dist/validators/cfWorker.mjs" }, "./_shims": { "workerd": { "types": "./dist/shimsWorkerd.d.mts", - "import": "./dist/shimsWorkerd.mjs" + "import": "./dist/shimsWorkerd.mjs", + "require": "./dist/shimsWorkerd.mjs" }, "browser": { "types": "./dist/shimsWorkerd.d.mts", - "import": "./dist/shimsWorkerd.mjs" + "import": "./dist/shimsWorkerd.mjs", + "require": "./dist/shimsWorkerd.mjs" }, "node": { "types": "./dist/shimsNode.d.mts", - "import": "./dist/shimsNode.mjs" + "import": "./dist/shimsNode.mjs", + "require": "./dist/shimsNode.mjs" }, "default": { "types": "./dist/shimsNode.d.mts", - "import": "./dist/shimsNode.mjs" + "import": "./dist/shimsNode.mjs", + "require": "./dist/shimsNode.mjs" } } }, @@ -87,5 +98,19 @@ "typescript": "catalog:devTools", "typescript-eslint": "catalog:devTools", "vitest": "catalog:devTools" + }, + "types": "./dist/index.d.mts", + "typesVersions": { + "*": { + "zod-schemas": [ + "./dist/zodSchemas.d.mts" + ], + "validators/cf-worker": [ + "./dist/validators/cfWorker.d.mts" + ], + "*": [ + "./dist/*.d.mts" + ] + } } } diff --git a/packages/server/src/experimental/tasks/mcpServer.ts b/packages/server/src/experimental/tasks/mcpServer.ts index b7c28c40d..730ebe28d 100644 --- a/packages/server/src/experimental/tasks/mcpServer.ts +++ b/packages/server/src/experimental/tasks/mcpServer.ts @@ -7,8 +7,9 @@ import type { StandardSchemaWithJSON, TaskToolExecution, ToolAnnotations, ToolExecution } from '@modelcontextprotocol/core'; -import type { AnyToolHandler, McpServer, RegisteredTool } from '../../server/mcp.js'; +import type { AnyToolHandler, McpServer, RegisteredTool } from '../../server/mcpServer.js'; import type { ToolTaskHandler } from './interfaces.js'; +import { ExperimentalServerTasks } from './server.js'; /** * Internal interface for accessing {@linkcode McpServer}'s private _createRegisteredTool method. @@ -38,8 +39,13 @@ interface McpServerInternal { * * @experimental */ -export class ExperimentalMcpServerTasks { - constructor(private readonly _mcpServer: McpServer) {} +export class ExperimentalMcpServerTasks extends ExperimentalServerTasks { + private readonly _mcpServer: McpServer; + + constructor(mcpServer: McpServer) { + super(mcpServer); + this._mcpServer = mcpServer; + } /** * Registers a task-based tool with a config object and handler. diff --git a/packages/server/src/experimental/tasks/server.ts b/packages/server/src/experimental/tasks/server.ts index 2e7b205fd..20fe8cd0a 100644 --- a/packages/server/src/experimental/tasks/server.ts +++ b/packages/server/src/experimental/tasks/server.ts @@ -24,7 +24,7 @@ import type { } from '@modelcontextprotocol/core'; import { getResultSchema, GetTaskPayloadResultSchema, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; -import type { Server } from '../../server/server.js'; +import type { McpServer as Server } from '../../server/mcpServer.js'; /** * Experimental task features for low-level MCP servers. diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index 6e1bba28d..f4c811c76 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -6,6 +6,7 @@ // // Any new export added here becomes public API. Use named exports, not wildcards. +export { Server } from './server/compat.js'; export type { CompletableSchema, CompleteCallback } from './server/completable.js'; export { completable, isCompletable } from './server/completable.js'; export type { @@ -21,13 +22,16 @@ export type { RegisteredResourceTemplate, RegisteredTool, ResourceMetadata, + ServerOptions, ToolCallback -} from './server/mcp.js'; -export { McpServer, ResourceTemplate } from './server/mcp.js'; +} from './server/mcpServer.js'; +export { McpServer, ResourceTemplate } from './server/mcpServer.js'; export type { HostHeaderValidationResult } from './server/middleware/hostHeaderValidation.js'; export { hostHeaderValidationResponse, localhostAllowedHostnames, validateHostHeader } from './server/middleware/hostHeaderValidation.js'; -export type { ServerOptions } from './server/server.js'; -export { Server } from './server/server.js'; +export type { SessionCompatOptions } from './server/sessionCompat.js'; +export { SessionCompat } from './server/sessionCompat.js'; +export type { ShttpHandlerOptions } from './server/shttpHandler.js'; +export { shttpHandler } from './server/shttpHandler.js'; export { StdioServerTransport } from './server/stdio.js'; export type { EventId, diff --git a/packages/server/src/server/backchannelCompat.ts b/packages/server/src/server/backchannelCompat.ts new file mode 100644 index 000000000..5445f3401 --- /dev/null +++ b/packages/server/src/server/backchannelCompat.ts @@ -0,0 +1,140 @@ +import type { + JSONRPCErrorResponse, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResultResponse, + Request, + RequestOptions, + Result +} from '@modelcontextprotocol/core'; +import { DEFAULT_REQUEST_TIMEOUT_MSEC, isJSONRPCErrorResponse, ProtocolError, SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; + +/** + * Isolated 2025-11 server-to-client request backchannel for {@linkcode shttpHandler}. + * + * The 2025-11 protocol allows a server to send `elicitation/create` and + * `sampling/createMessage` requests to the client mid-tool-call by writing them as + * SSE events on the open POST response stream and waiting for the client to POST + * the response back. This class owns the per-session `{requestId -> resolver}` + * map that correlation requires, plus the standalone-GET writer registry used for + * unsolicited server notifications. + * + * It exists so this stateful behaviour is in one removable file once MRTR + * (SEP-2322) is the protocol floor and `env.send` becomes a hard error in + * stateless paths. + */ +export class BackchannelCompat { + private _pending = new Map void; reject: (e: Error) => void }>>(); + private _standaloneWriters = new Map void>(); + private _nextId = 0; + + /** + * Returns an `env.send` implementation bound to the given session and POST-stream writer. + * The returned function writes the outbound JSON-RPC request to `writeSSE` and resolves when + * {@linkcode handleResponse} is called for the same id on the same session. + */ + makeEnvSend(sessionId: string, writeSSE: (msg: JSONRPCMessage) => void): (req: Request, opts?: RequestOptions) => Promise { + return (req: Request, opts?: RequestOptions): Promise => { + return new Promise((resolve, reject) => { + if (opts?.signal?.aborted) { + reject(opts.signal.reason instanceof Error ? opts.signal.reason : new Error(String(opts.signal.reason))); + return; + } + + const id = this._nextId++; + const sessionMap = this._pending.get(sessionId) ?? new Map(); + this._pending.set(sessionId, sessionMap); + + const timeoutMs = opts?.timeout ?? DEFAULT_REQUEST_TIMEOUT_MSEC; + const timer = setTimeout(() => { + sessionMap.delete(id); + reject(new SdkError(SdkErrorCode.RequestTimeout, 'Request timed out', { timeout: timeoutMs })); + }, timeoutMs); + + const settle = { + resolve: (r: Result) => { + clearTimeout(timer); + sessionMap.delete(id); + resolve(r); + }, + reject: (e: Error) => { + clearTimeout(timer); + sessionMap.delete(id); + reject(e); + } + }; + sessionMap.set(id, settle); + + opts?.signal?.addEventListener( + 'abort', + () => { + settle.reject(opts.signal!.reason instanceof Error ? opts.signal!.reason : new Error(String(opts.signal!.reason))); + }, + { once: true } + ); + + const wire: JSONRPCRequest = { jsonrpc: '2.0', id, method: req.method, params: req.params }; + try { + writeSSE(wire); + } catch (error) { + settle.reject(error instanceof Error ? error : new Error(String(error))); + } + }); + }; + } + + /** + * Routes an incoming JSON-RPC response (from a client POST) to the waiting `env.send` promise. + * @returns true if a pending request matched and was settled. + */ + handleResponse(sessionId: string, response: JSONRPCResultResponse | JSONRPCErrorResponse): boolean { + const sessionMap = this._pending.get(sessionId); + const id = typeof response.id === 'number' ? response.id : Number(response.id); + const settle = sessionMap?.get(id); + if (!settle) return false; + if (isJSONRPCErrorResponse(response)) { + settle.reject(ProtocolError.fromError(response.error.code, response.error.message, response.error.data)); + } else { + settle.resolve(response.result); + } + return true; + } + + /** + * Registers (or clears) the standalone GET subscription writer for a session, used to + * deliver server-initiated notifications outside any POST request. + */ + setStandaloneWriter(sessionId: string, write: ((msg: JSONRPCMessage) => void) | undefined): void { + if (write) this._standaloneWriters.set(sessionId, write); + else this._standaloneWriters.delete(sessionId); + } + + /** True if a standalone writer is registered for the session. */ + hasStandaloneWriter(sessionId: string): boolean { + return this._standaloneWriters.has(sessionId); + } + + /** Writes a message on the session's standalone GET stream, if one is open. */ + writeStandalone(sessionId: string, msg: JSONRPCMessage): boolean { + const w = this._standaloneWriters.get(sessionId); + if (!w) return false; + try { + w(msg); + return true; + } catch { + this._standaloneWriters.delete(sessionId); + return false; + } + } + + /** Rejects all pending requests for a session and forgets it. */ + closeSession(sessionId: string): void { + const sessionMap = this._pending.get(sessionId); + if (sessionMap) { + const err = new SdkError(SdkErrorCode.ConnectionClosed, 'Session closed'); + for (const s of sessionMap.values()) s.reject(err); + this._pending.delete(sessionId); + } + this._standaloneWriters.delete(sessionId); + } +} diff --git a/packages/server/src/server/compat.ts b/packages/server/src/server/compat.ts new file mode 100644 index 000000000..96cf4f266 --- /dev/null +++ b/packages/server/src/server/compat.ts @@ -0,0 +1,6 @@ +/** + * v1 compat alias. The low-level `Server` class is now the same as `McpServer`. + * @deprecated Import {@linkcode McpServer} from `./mcpServer.js` directly. + */ +export type { ServerOptions } from './mcpServer.js'; +export { McpServer as Server } from './mcpServer.js'; diff --git a/packages/server/src/server/mcp.ts b/packages/server/src/server/mcp.ts index 6c2699997..aa25709e8 100644 --- a/packages/server/src/server/mcp.ts +++ b/packages/server/src/server/mcp.ts @@ -1,1329 +1,5 @@ -import type { - BaseMetadata, - CallToolRequest, - CallToolResult, - CompleteRequestPrompt, - CompleteRequestResourceTemplate, - CompleteResult, - CreateTaskResult, - CreateTaskServerContext, - GetPromptResult, - Implementation, - ListPromptsResult, - ListResourcesResult, - ListToolsResult, - LoggingMessageNotification, - Prompt, - PromptReference, - ReadResourceResult, - Resource, - ResourceTemplateReference, - Result, - ServerContext, - StandardSchemaWithJSON, - Tool, - ToolAnnotations, - ToolExecution, - Transport, - Variables -} from '@modelcontextprotocol/core'; -import { - assertCompleteRequestPrompt, - assertCompleteRequestResourceTemplate, - promptArgumentsFromStandardSchema, - ProtocolError, - ProtocolErrorCode, - standardSchemaToJsonSchema, - UriTemplate, - validateAndWarnToolName, - validateStandardSchema -} from '@modelcontextprotocol/core'; - -import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; -import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; -import { getCompleter, isCompletable } from './completable.js'; -import type { ServerOptions } from './server.js'; -import { Server } from './server.js'; - /** - * High-level MCP server that provides a simpler API for working with resources, tools, and prompts. - * For advanced usage (like sending notifications or setting custom request handlers), use the underlying - * {@linkcode Server} instance available via the {@linkcode McpServer.server | server} property. - * - * @example - * ```ts source="./mcp.examples.ts#McpServer_basicUsage" - * const server = new McpServer({ - * name: 'my-server', - * version: '1.0.0' - * }); - * ``` + * v1-compat module path. The implementation moved to {@link ./mcpServer.ts}. + * @deprecated Import from `@modelcontextprotocol/server` directly. */ -export class McpServer { - /** - * The underlying {@linkcode Server} instance, useful for advanced operations like sending notifications. - */ - public readonly server: Server; - - private _registeredResources: { [uri: string]: RegisteredResource } = {}; - private _registeredResourceTemplates: { - [name: string]: RegisteredResourceTemplate; - } = {}; - private _registeredTools: { [name: string]: RegisteredTool } = {}; - private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; - private _experimental?: { tasks: ExperimentalMcpServerTasks }; - - constructor(serverInfo: Implementation, options?: ServerOptions) { - this.server = new Server(serverInfo, options); - } - - /** - * Access experimental features. - * - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - get experimental(): { tasks: ExperimentalMcpServerTasks } { - if (!this._experimental) { - this._experimental = { - tasks: new ExperimentalMcpServerTasks(this) - }; - } - return this._experimental; - } - - /** - * Attaches to the given transport, starts it, and starts listening for messages. - * - * The `server` object assumes ownership of the {@linkcode Transport}, replacing any callbacks that have already been set, and expects that it is the only user of the {@linkcode Transport} instance going forward. - * - * @example - * ```ts source="./mcp.examples.ts#McpServer_connect_stdio" - * const server = new McpServer({ name: 'my-server', version: '1.0.0' }); - * const transport = new StdioServerTransport(); - * await server.connect(transport); - * ``` - */ - async connect(transport: Transport): Promise { - return await this.server.connect(transport); - } - - /** - * Closes the connection. - */ - async close(): Promise { - await this.server.close(); - } - - private _toolHandlersInitialized = false; - - private setToolRequestHandlers() { - if (this._toolHandlersInitialized) { - return; - } - - this.server.assertCanSetRequestHandler('tools/list'); - this.server.assertCanSetRequestHandler('tools/call'); - - this.server.registerCapabilities({ - tools: { - listChanged: this.server.getCapabilities().tools?.listChanged ?? true - } - }); - - this.server.setRequestHandler( - 'tools/list', - (): ListToolsResult => ({ - tools: Object.entries(this._registeredTools) - .filter(([, tool]) => tool.enabled) - .map(([name, tool]): Tool => { - const toolDefinition: Tool = { - name, - title: tool.title, - description: tool.description, - inputSchema: tool.inputSchema - ? (standardSchemaToJsonSchema(tool.inputSchema, 'input') as Tool['inputSchema']) - : EMPTY_OBJECT_JSON_SCHEMA, - annotations: tool.annotations, - execution: tool.execution, - _meta: tool._meta - }; - - if (tool.outputSchema) { - toolDefinition.outputSchema = standardSchemaToJsonSchema(tool.outputSchema, 'output') as Tool['outputSchema']; - } - - return toolDefinition; - }) - }) - ); - - this.server.setRequestHandler('tools/call', async (request, ctx): Promise => { - const tool = this._registeredTools[request.params.name]; - if (!tool) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Tool ${request.params.name} not found`); - } - if (!tool.enabled) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); - } - - try { - const isTaskRequest = !!request.params.task; - const taskSupport = tool.execution?.taskSupport; - const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); - - // Validate task hint configuration - if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { - throw new ProtocolError( - ProtocolErrorCode.InternalError, - `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` - ); - } - - // Handle taskSupport 'required' without task augmentation - if (taskSupport === 'required' && !isTaskRequest) { - throw new ProtocolError( - ProtocolErrorCode.MethodNotFound, - `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` - ); - } - - // Handle taskSupport 'optional' without task augmentation - automatic polling - if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { - return await this.handleAutomaticTaskPolling(tool, request, ctx); - } - - // Normal execution path - const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const result = await this.executeToolHandler(tool, args, ctx); - - // Return CreateTaskResult immediately for task requests - if (isTaskRequest) { - return result; - } - - // Validate output schema for non-task requests - await this.validateToolOutput(tool, result, request.params.name); - return result; - } catch (error) { - if (error instanceof ProtocolError && error.code === ProtocolErrorCode.UrlElicitationRequired) { - throw error; // Return the error to the caller without wrapping in CallToolResult - } - return this.createToolError(error instanceof Error ? error.message : String(error)); - } - }); - - this._toolHandlersInitialized = true; - } - - /** - * Creates a tool error result. - * - * @param errorMessage - The error message. - * @returns The tool error result. - */ - private createToolError(errorMessage: string): CallToolResult { - return { - content: [ - { - type: 'text', - text: errorMessage - } - ], - isError: true - }; - } - - /** - * Validates tool input arguments against the tool's input schema. - */ - private async validateToolInput< - ToolType extends RegisteredTool, - Args extends ToolType['inputSchema'] extends infer InputSchema - ? InputSchema extends StandardSchemaWithJSON - ? StandardSchemaWithJSON.InferOutput - : undefined - : undefined - >(tool: ToolType, args: Args, toolName: string): Promise { - if (!tool.inputSchema) { - return undefined as Args; - } - - const parseResult = await validateStandardSchema(tool.inputSchema, args ?? {}); - if (!parseResult.success) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Input validation error: Invalid arguments for tool ${toolName}: ${parseResult.error}` - ); - } - - return parseResult.data as unknown as Args; - } - - /** - * Validates tool output against the tool's output schema. - */ - private async validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string): Promise { - if (!tool.outputSchema) { - return; - } - - // Only validate CallToolResult, not CreateTaskResult - if (!('content' in result)) { - return; - } - - if (result.isError) { - return; - } - - if (!result.structuredContent) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Output validation error: Tool ${toolName} has an output schema but no structured content was provided` - ); - } - - // if the tool has an output schema, validate structured content - const parseResult = await validateStandardSchema(tool.outputSchema, result.structuredContent); - if (!parseResult.success) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Output validation error: Invalid structured content for tool ${toolName}: ${parseResult.error}` - ); - } - } - - /** - * Executes a tool handler (either regular or task-based). - */ - private async executeToolHandler(tool: RegisteredTool, args: unknown, ctx: ServerContext): Promise { - // Executor encapsulates handler invocation with proper types - return tool.executor(args, ctx); - } - - /** - * Handles automatic task polling for tools with `taskSupport` `'optional'`. - */ - private async handleAutomaticTaskPolling( - tool: RegisteredTool, - request: RequestT, - ctx: ServerContext - ): Promise { - if (!ctx.task?.store) { - throw new Error('No task store provided for task-capable tool.'); - } - - // Validate input and create task using the executor - const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); - const createTaskResult = (await tool.executor(args, ctx)) as CreateTaskResult; - - // Poll until completion - const taskId = createTaskResult.task.taskId; - let task = createTaskResult.task; - const pollInterval = task.pollInterval ?? 5000; - - while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { - await new Promise(resolve => setTimeout(resolve, pollInterval)); - const updatedTask = await ctx.task.store.getTask(taskId); - if (!updatedTask) { - throw new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} not found during polling`); - } - task = updatedTask; - } - - // Return the final result - return (await ctx.task.store.getTaskResult(taskId)) as CallToolResult; - } - - private _completionHandlerInitialized = false; - - private setCompletionRequestHandler() { - if (this._completionHandlerInitialized) { - return; - } - - this.server.assertCanSetRequestHandler('completion/complete'); - - this.server.registerCapabilities({ - completions: {} - }); - - this.server.setRequestHandler('completion/complete', async (request): Promise => { - switch (request.params.ref.type) { - case 'ref/prompt': { - assertCompleteRequestPrompt(request); - return this.handlePromptCompletion(request, request.params.ref); - } - - case 'ref/resource': { - assertCompleteRequestResourceTemplate(request); - return this.handleResourceCompletion(request, request.params.ref); - } - - default: { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid completion reference: ${request.params.ref}`); - } - } - }); - - this._completionHandlerInitialized = true; - } - - private async handlePromptCompletion(request: CompleteRequestPrompt, ref: PromptReference): Promise { - const prompt = this._registeredPrompts[ref.name]; - if (!prompt) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${ref.name} not found`); - } - - if (!prompt.enabled) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${ref.name} disabled`); - } - - if (!prompt.argsSchema) { - return EMPTY_COMPLETION_RESULT; - } - - const promptShape = getSchemaShape(prompt.argsSchema); - const field = unwrapOptionalSchema(promptShape?.[request.params.argument.name]); - if (!isCompletable(field)) { - return EMPTY_COMPLETION_RESULT; - } - - const completer = getCompleter(field); - if (!completer) { - return EMPTY_COMPLETION_RESULT; - } - - const suggestions = await completer(request.params.argument.value, request.params.context); - return createCompletionResult(suggestions); - } - - private async handleResourceCompletion( - request: CompleteRequestResourceTemplate, - ref: ResourceTemplateReference - ): Promise { - const template = Object.values(this._registeredResourceTemplates).find(t => t.resourceTemplate.uriTemplate.toString() === ref.uri); - - if (!template) { - if (this._registeredResources[ref.uri]) { - // Attempting to autocomplete a fixed resource URI is not an error in the spec (but probably should be). - return EMPTY_COMPLETION_RESULT; - } - - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Resource template ${request.params.ref.uri} not found`); - } - - const completer = template.resourceTemplate.completeCallback(request.params.argument.name); - if (!completer) { - return EMPTY_COMPLETION_RESULT; - } - - const suggestions = await completer(request.params.argument.value, request.params.context); - return createCompletionResult(suggestions); - } - - private _resourceHandlersInitialized = false; - - private setResourceRequestHandlers() { - if (this._resourceHandlersInitialized) { - return; - } - - this.server.assertCanSetRequestHandler('resources/list'); - this.server.assertCanSetRequestHandler('resources/templates/list'); - this.server.assertCanSetRequestHandler('resources/read'); - - this.server.registerCapabilities({ - resources: { - listChanged: this.server.getCapabilities().resources?.listChanged ?? true - } - }); - - this.server.setRequestHandler('resources/list', async (_request, ctx) => { - const resources = Object.entries(this._registeredResources) - .filter(([_, resource]) => resource.enabled) - .map(([uri, resource]) => ({ - uri, - name: resource.name, - ...resource.metadata - })); - - const templateResources: Resource[] = []; - for (const template of Object.values(this._registeredResourceTemplates)) { - if (!template.resourceTemplate.listCallback) { - continue; - } - - const result = await template.resourceTemplate.listCallback(ctx); - for (const resource of result.resources) { - templateResources.push({ - ...template.metadata, - // the defined resource metadata should override the template metadata if present - ...resource - }); - } - } - - return { resources: [...resources, ...templateResources] }; - }); - - this.server.setRequestHandler('resources/templates/list', async () => { - const resourceTemplates = Object.entries(this._registeredResourceTemplates).map(([name, template]) => ({ - name, - uriTemplate: template.resourceTemplate.uriTemplate.toString(), - ...template.metadata - })); - - return { resourceTemplates }; - }); - - this.server.setRequestHandler('resources/read', async (request, ctx) => { - const uri = new URL(request.params.uri); - - // First check for exact resource match - const resource = this._registeredResources[uri.toString()]; - if (resource) { - if (!resource.enabled) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Resource ${uri} disabled`); - } - return resource.readCallback(uri, ctx); - } - - // Then check templates - for (const template of Object.values(this._registeredResourceTemplates)) { - const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); - if (variables) { - return template.readCallback(uri, variables, ctx); - } - } - - throw new ProtocolError(ProtocolErrorCode.ResourceNotFound, `Resource ${uri} not found`); - }); - - this._resourceHandlersInitialized = true; - } - - private _promptHandlersInitialized = false; - - private setPromptRequestHandlers() { - if (this._promptHandlersInitialized) { - return; - } - - this.server.assertCanSetRequestHandler('prompts/list'); - this.server.assertCanSetRequestHandler('prompts/get'); - - this.server.registerCapabilities({ - prompts: { - listChanged: this.server.getCapabilities().prompts?.listChanged ?? true - } - }); - - this.server.setRequestHandler( - 'prompts/list', - (): ListPromptsResult => ({ - prompts: Object.entries(this._registeredPrompts) - .filter(([, prompt]) => prompt.enabled) - .map(([name, prompt]): Prompt => { - return { - name, - title: prompt.title, - description: prompt.description, - arguments: prompt.argsSchema ? promptArgumentsFromStandardSchema(prompt.argsSchema) : undefined, - _meta: prompt._meta - }; - }) - }) - ); - - this.server.setRequestHandler('prompts/get', async (request, ctx): Promise => { - const prompt = this._registeredPrompts[request.params.name]; - if (!prompt) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); - } - - if (!prompt.enabled) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); - } - - // Handler encapsulates parsing and callback invocation with proper types - return prompt.handler(request.params.arguments, ctx); - }); - - this._promptHandlersInitialized = true; - } - - /** - * Registers a resource with a config object and callback. - * For static resources, use a URI string. For dynamic resources, use a {@linkcode ResourceTemplate}. - * - * @example - * ```ts source="./mcp.examples.ts#McpServer_registerResource_static" - * server.registerResource( - * 'config', - * 'config://app', - * { - * title: 'Application Config', - * mimeType: 'text/plain' - * }, - * async uri => ({ - * contents: [{ uri: uri.href, text: 'App configuration here' }] - * }) - * ); - * ``` - */ - registerResource(name: string, uriOrTemplate: string, config: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; - registerResource( - name: string, - uriOrTemplate: ResourceTemplate, - config: ResourceMetadata, - readCallback: ReadResourceTemplateCallback - ): RegisteredResourceTemplate; - registerResource( - name: string, - uriOrTemplate: string | ResourceTemplate, - config: ResourceMetadata, - readCallback: ReadResourceCallback | ReadResourceTemplateCallback - ): RegisteredResource | RegisteredResourceTemplate { - if (typeof uriOrTemplate === 'string') { - if (this._registeredResources[uriOrTemplate]) { - throw new Error(`Resource ${uriOrTemplate} is already registered`); - } - - const registeredResource = this._createRegisteredResource( - name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceCallback - ); - - this.setResourceRequestHandlers(); - this.sendResourceListChanged(); - return registeredResource; - } else { - if (this._registeredResourceTemplates[name]) { - throw new Error(`Resource template ${name} is already registered`); - } - - const registeredResourceTemplate = this._createRegisteredResourceTemplate( - name, - (config as BaseMetadata).title, - uriOrTemplate, - config, - readCallback as ReadResourceTemplateCallback - ); - - this.setResourceRequestHandlers(); - this.sendResourceListChanged(); - return registeredResourceTemplate; - } - } - - private _createRegisteredResource( - name: string, - title: string | undefined, - uri: string, - metadata: ResourceMetadata | undefined, - readCallback: ReadResourceCallback - ): RegisteredResource { - const registeredResource: RegisteredResource = { - name, - title, - metadata, - readCallback, - enabled: true, - disable: () => registeredResource.update({ enabled: false }), - enable: () => registeredResource.update({ enabled: true }), - remove: () => registeredResource.update({ uri: null }), - update: updates => { - if (updates.uri !== undefined && updates.uri !== uri) { - delete this._registeredResources[uri]; - if (updates.uri) this._registeredResources[updates.uri] = registeredResource; - } - if (updates.name !== undefined) registeredResource.name = updates.name; - if (updates.title !== undefined) registeredResource.title = updates.title; - if (updates.metadata !== undefined) registeredResource.metadata = updates.metadata; - if (updates.callback !== undefined) registeredResource.readCallback = updates.callback; - if (updates.enabled !== undefined) registeredResource.enabled = updates.enabled; - this.sendResourceListChanged(); - } - }; - this._registeredResources[uri] = registeredResource; - return registeredResource; - } - - private _createRegisteredResourceTemplate( - name: string, - title: string | undefined, - template: ResourceTemplate, - metadata: ResourceMetadata | undefined, - readCallback: ReadResourceTemplateCallback - ): RegisteredResourceTemplate { - const registeredResourceTemplate: RegisteredResourceTemplate = { - resourceTemplate: template, - title, - metadata, - readCallback, - enabled: true, - disable: () => registeredResourceTemplate.update({ enabled: false }), - enable: () => registeredResourceTemplate.update({ enabled: true }), - remove: () => registeredResourceTemplate.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - delete this._registeredResourceTemplates[name]; - if (updates.name) this._registeredResourceTemplates[updates.name] = registeredResourceTemplate; - } - if (updates.title !== undefined) registeredResourceTemplate.title = updates.title; - if (updates.template !== undefined) registeredResourceTemplate.resourceTemplate = updates.template; - if (updates.metadata !== undefined) registeredResourceTemplate.metadata = updates.metadata; - if (updates.callback !== undefined) registeredResourceTemplate.readCallback = updates.callback; - if (updates.enabled !== undefined) registeredResourceTemplate.enabled = updates.enabled; - this.sendResourceListChanged(); - } - }; - this._registeredResourceTemplates[name] = registeredResourceTemplate; - - // If the resource template has any completion callbacks, enable completions capability - const variableNames = template.uriTemplate.variableNames; - const hasCompleter = Array.isArray(variableNames) && variableNames.some(v => !!template.completeCallback(v)); - if (hasCompleter) { - this.setCompletionRequestHandler(); - } - - return registeredResourceTemplate; - } - - private _createRegisteredPrompt( - name: string, - title: string | undefined, - description: string | undefined, - argsSchema: StandardSchemaWithJSON | undefined, - callback: PromptCallback, - _meta: Record | undefined - ): RegisteredPrompt { - // Track current schema and callback for handler regeneration - let currentArgsSchema = argsSchema; - let currentCallback = callback; - - const registeredPrompt: RegisteredPrompt = { - title, - description, - argsSchema, - _meta, - handler: createPromptHandler(name, argsSchema, callback), - enabled: true, - disable: () => registeredPrompt.update({ enabled: false }), - enable: () => registeredPrompt.update({ enabled: true }), - remove: () => registeredPrompt.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - delete this._registeredPrompts[name]; - if (updates.name) this._registeredPrompts[updates.name] = registeredPrompt; - } - if (updates.title !== undefined) registeredPrompt.title = updates.title; - if (updates.description !== undefined) registeredPrompt.description = updates.description; - if (updates._meta !== undefined) registeredPrompt._meta = updates._meta; - - // Track if we need to regenerate the handler - let needsHandlerRegen = false; - if (updates.argsSchema !== undefined) { - registeredPrompt.argsSchema = updates.argsSchema; - currentArgsSchema = updates.argsSchema; - needsHandlerRegen = true; - } - if (updates.callback !== undefined) { - currentCallback = updates.callback as PromptCallback; - needsHandlerRegen = true; - } - if (needsHandlerRegen) { - registeredPrompt.handler = createPromptHandler(name, currentArgsSchema, currentCallback); - } - - if (updates.enabled !== undefined) registeredPrompt.enabled = updates.enabled; - this.sendPromptListChanged(); - } - }; - this._registeredPrompts[name] = registeredPrompt; - - // If any argument uses a Completable schema, enable completions capability - if (argsSchema) { - const shape = getSchemaShape(argsSchema); - if (shape) { - const hasCompletable = Object.values(shape).some(field => { - const inner = unwrapOptionalSchema(field); - return isCompletable(inner); - }); - if (hasCompletable) { - this.setCompletionRequestHandler(); - } - } - } - - return registeredPrompt; - } - - private _createRegisteredTool( - name: string, - title: string | undefined, - description: string | undefined, - inputSchema: StandardSchemaWithJSON | undefined, - outputSchema: StandardSchemaWithJSON | undefined, - annotations: ToolAnnotations | undefined, - execution: ToolExecution | undefined, - _meta: Record | undefined, - handler: AnyToolHandler - ): RegisteredTool { - // Validate tool name according to SEP specification - validateAndWarnToolName(name); - - // Track current handler for executor regeneration - let currentHandler = handler; - - const registeredTool: RegisteredTool = { - title, - description, - inputSchema, - outputSchema, - annotations, - execution, - _meta, - handler: handler, - executor: createToolExecutor(inputSchema, handler), - enabled: true, - disable: () => registeredTool.update({ enabled: false }), - enable: () => registeredTool.update({ enabled: true }), - remove: () => registeredTool.update({ name: null }), - update: updates => { - if (updates.name !== undefined && updates.name !== name) { - if (typeof updates.name === 'string') { - validateAndWarnToolName(updates.name); - } - delete this._registeredTools[name]; - if (updates.name) this._registeredTools[updates.name] = registeredTool; - } - if (updates.title !== undefined) registeredTool.title = updates.title; - if (updates.description !== undefined) registeredTool.description = updates.description; - - // Track if we need to regenerate the executor - let needsExecutorRegen = false; - if (updates.paramsSchema !== undefined) { - registeredTool.inputSchema = updates.paramsSchema; - needsExecutorRegen = true; - } - if (updates.callback !== undefined) { - registeredTool.handler = updates.callback; - currentHandler = updates.callback as AnyToolHandler; - needsExecutorRegen = true; - } - if (needsExecutorRegen) { - registeredTool.executor = createToolExecutor(registeredTool.inputSchema, currentHandler); - } - - if (updates.outputSchema !== undefined) registeredTool.outputSchema = updates.outputSchema; - if (updates.annotations !== undefined) registeredTool.annotations = updates.annotations; - if (updates._meta !== undefined) registeredTool._meta = updates._meta; - if (updates.enabled !== undefined) registeredTool.enabled = updates.enabled; - this.sendToolListChanged(); - } - }; - this._registeredTools[name] = registeredTool; - - this.setToolRequestHandlers(); - this.sendToolListChanged(); - - return registeredTool; - } - - /** - * Registers a tool with a config object and callback. - * - * @example - * ```ts source="./mcp.examples.ts#McpServer_registerTool_basic" - * server.registerTool( - * 'calculate-bmi', - * { - * title: 'BMI Calculator', - * description: 'Calculate Body Mass Index', - * inputSchema: z.object({ - * weightKg: z.number(), - * heightM: z.number() - * }), - * outputSchema: z.object({ bmi: z.number() }) - * }, - * async ({ weightKg, heightM }) => { - * const output = { bmi: weightKg / (heightM * heightM) }; - * return { - * content: [{ type: 'text', text: JSON.stringify(output) }], - * structuredContent: output - * }; - * } - * ); - * ``` - */ - registerTool( - name: string, - config: { - title?: string; - description?: string; - inputSchema?: InputArgs; - outputSchema?: OutputArgs; - annotations?: ToolAnnotations; - _meta?: Record; - }, - cb: ToolCallback - ): RegisteredTool { - if (this._registeredTools[name]) { - throw new Error(`Tool ${name} is already registered`); - } - - const { title, description, inputSchema, outputSchema, annotations, _meta } = config; - - return this._createRegisteredTool( - name, - title, - description, - inputSchema, - outputSchema, - annotations, - { taskSupport: 'forbidden' }, - _meta, - cb as ToolCallback - ); - } - - /** - * Registers a prompt with a config object and callback. - * - * @example - * ```ts source="./mcp.examples.ts#McpServer_registerPrompt_basic" - * server.registerPrompt( - * 'review-code', - * { - * title: 'Code Review', - * description: 'Review code for best practices', - * argsSchema: z.object({ code: z.string() }) - * }, - * ({ code }) => ({ - * messages: [ - * { - * role: 'user' as const, - * content: { - * type: 'text' as const, - * text: `Please review this code:\n\n${code}` - * } - * } - * ] - * }) - * ); - * ``` - */ - registerPrompt( - name: string, - config: { - title?: string; - description?: string; - argsSchema?: Args; - _meta?: Record; - }, - cb: PromptCallback - ): RegisteredPrompt { - if (this._registeredPrompts[name]) { - throw new Error(`Prompt ${name} is already registered`); - } - - const { title, description, argsSchema, _meta } = config; - - const registeredPrompt = this._createRegisteredPrompt( - name, - title, - description, - argsSchema, - cb as PromptCallback, - _meta - ); - - this.setPromptRequestHandlers(); - this.sendPromptListChanged(); - - return registeredPrompt; - } - - /** - * Checks if the server is connected to a transport. - * @returns `true` if the server is connected - */ - isConnected() { - return this.server.transport !== undefined; - } - - /** - * Sends a logging message to the client, if connected. - * Note: You only need to send the parameters object, not the entire JSON-RPC message. - * @see {@linkcode LoggingMessageNotification} - * @param params - * @param sessionId Optional for stateless transports and backward compatibility. - * - * @example - * ```ts source="./mcp.examples.ts#McpServer_sendLoggingMessage_basic" - * await server.sendLoggingMessage({ - * level: 'info', - * data: 'Processing complete' - * }); - * ``` - */ - async sendLoggingMessage(params: LoggingMessageNotification['params'], sessionId?: string) { - return this.server.sendLoggingMessage(params, sessionId); - } - /** - * Sends a resource list changed event to the client, if connected. - */ - sendResourceListChanged() { - if (this.isConnected()) { - this.server.sendResourceListChanged(); - } - } - - /** - * Sends a tool list changed event to the client, if connected. - */ - sendToolListChanged() { - if (this.isConnected()) { - this.server.sendToolListChanged(); - } - } - - /** - * Sends a prompt list changed event to the client, if connected. - */ - sendPromptListChanged() { - if (this.isConnected()) { - this.server.sendPromptListChanged(); - } - } -} - -/** - * A callback to complete one variable within a resource template's URI template. - */ -export type CompleteResourceTemplateCallback = ( - value: string, - context?: { - arguments?: Record; - } -) => string[] | Promise; - -/** - * A resource template combines a URI pattern with optional functionality to enumerate - * all resources matching that pattern. - */ -export class ResourceTemplate { - private _uriTemplate: UriTemplate; - - constructor( - uriTemplate: string | UriTemplate, - private _callbacks: { - /** - * A callback to list all resources matching this template. This is required to be specified, even if `undefined`, to avoid accidentally forgetting resource listing. - */ - list: ListResourcesCallback | undefined; - - /** - * An optional callback to autocomplete variables within the URI template. Useful for clients and users to discover possible values. - */ - complete?: { - [variable: string]: CompleteResourceTemplateCallback; - }; - } - ) { - this._uriTemplate = typeof uriTemplate === 'string' ? new UriTemplate(uriTemplate) : uriTemplate; - } - - /** - * Gets the URI template pattern. - */ - get uriTemplate(): UriTemplate { - return this._uriTemplate; - } - - /** - * Gets the list callback, if one was provided. - */ - get listCallback(): ListResourcesCallback | undefined { - return this._callbacks.list; - } - - /** - * Gets the callback for completing a specific URI template variable, if one was provided. - */ - completeCallback(variable: string): CompleteResourceTemplateCallback | undefined { - return this._callbacks.complete?.[variable]; - } -} - -export type BaseToolCallback< - SendResultT extends Result, - Ctx extends ServerContext, - Args extends StandardSchemaWithJSON | undefined -> = Args extends StandardSchemaWithJSON - ? (args: StandardSchemaWithJSON.InferOutput, ctx: Ctx) => SendResultT | Promise - : (ctx: Ctx) => SendResultT | Promise; - -/** - * Callback for a tool handler registered with {@linkcode McpServer.registerTool}. - */ -export type ToolCallback = BaseToolCallback< - CallToolResult, - ServerContext, - Args ->; - -/** - * Supertype that can handle both regular tools (simple callback) and task-based tools (task handler object). - */ -export type AnyToolHandler = ToolCallback | ToolTaskHandler; - -/** - * Internal executor type that encapsulates handler invocation with proper types. - */ -type ToolExecutor = (args: unknown, ctx: ServerContext) => Promise; - -export type RegisteredTool = { - title?: string; - description?: string; - inputSchema?: StandardSchemaWithJSON; - outputSchema?: StandardSchemaWithJSON; - annotations?: ToolAnnotations; - execution?: ToolExecution; - _meta?: Record; - handler: AnyToolHandler; - /** @hidden */ - executor: ToolExecutor; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - description?: string; - paramsSchema?: StandardSchemaWithJSON; - outputSchema?: StandardSchemaWithJSON; - annotations?: ToolAnnotations; - _meta?: Record; - callback?: ToolCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -/** - * Creates an executor that invokes the handler with the appropriate arguments. - * When `inputSchema` is defined, the handler is called with `(args, ctx)`. - * When `inputSchema` is undefined, the handler is called with just `(ctx)`. - */ -function createToolExecutor( - inputSchema: StandardSchemaWithJSON | undefined, - handler: AnyToolHandler -): ToolExecutor { - const isTaskHandler = 'createTask' in handler; - - if (isTaskHandler) { - const taskHandler = handler as TaskHandlerInternal; - return async (args, ctx) => { - if (!ctx.task?.store) { - throw new Error('No task store provided.'); - } - const taskCtx: CreateTaskServerContext = { ...ctx, task: { store: ctx.task.store, requestedTtl: ctx.task?.requestedTtl } }; - if (inputSchema) { - return taskHandler.createTask(args, taskCtx); - } - // When no inputSchema, call with just ctx (the handler expects (ctx) signature) - return (taskHandler.createTask as (ctx: CreateTaskServerContext) => CreateTaskResult | Promise)(taskCtx); - }; - } - - if (inputSchema) { - const callback = handler as ToolCallbackInternal; - return async (args, ctx) => callback(args, ctx); - } - - // When no inputSchema, call with just ctx (the handler expects (ctx) signature) - const callback = handler as (ctx: ServerContext) => CallToolResult | Promise; - return async (_args, ctx) => callback(ctx); -} - -const EMPTY_OBJECT_JSON_SCHEMA = { - type: 'object' as const, - properties: {} -}; - -/** - * Additional, optional information for annotating a resource. - */ -export type ResourceMetadata = Omit; - -/** - * Callback to list all resources matching a given template. - */ -export type ListResourcesCallback = (ctx: ServerContext) => ListResourcesResult | Promise; - -/** - * Callback to read a resource at a given URI. - */ -export type ReadResourceCallback = (uri: URL, ctx: ServerContext) => ReadResourceResult | Promise; - -export type RegisteredResource = { - name: string; - title?: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string; - title?: string; - uri?: string | null; - metadata?: ResourceMetadata; - callback?: ReadResourceCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -/** - * Callback to read a resource at a given URI, following a filled-in URI template. - */ -export type ReadResourceTemplateCallback = ( - uri: URL, - variables: Variables, - ctx: ServerContext -) => ReadResourceResult | Promise; - -export type RegisteredResourceTemplate = { - resourceTemplate: ResourceTemplate; - title?: string; - metadata?: ResourceMetadata; - readCallback: ReadResourceTemplateCallback; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - template?: ResourceTemplate; - metadata?: ResourceMetadata; - callback?: ReadResourceTemplateCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -export type PromptCallback = Args extends StandardSchemaWithJSON - ? (args: StandardSchemaWithJSON.InferOutput, ctx: ServerContext) => GetPromptResult | Promise - : (ctx: ServerContext) => GetPromptResult | Promise; - -/** - * Internal handler type that encapsulates parsing and callback invocation. - * This allows type-safe handling without runtime type assertions. - */ -type PromptHandler = (args: Record | undefined, ctx: ServerContext) => Promise; - -type ToolCallbackInternal = (args: unknown, ctx: ServerContext) => CallToolResult | Promise; - -type TaskHandlerInternal = { - createTask: (args: unknown, ctx: CreateTaskServerContext) => CreateTaskResult | Promise; -}; - -export type RegisteredPrompt = { - title?: string; - description?: string; - argsSchema?: StandardSchemaWithJSON; - _meta?: Record; - /** @hidden */ - handler: PromptHandler; - enabled: boolean; - enable(): void; - disable(): void; - update(updates: { - name?: string | null; - title?: string; - description?: string; - argsSchema?: Args; - _meta?: Record; - callback?: PromptCallback; - enabled?: boolean; - }): void; - remove(): void; -}; - -/** - * Creates a type-safe prompt handler that captures the schema and callback in a closure. - * This eliminates the need for type assertions at the call site. - */ -function createPromptHandler( - name: string, - argsSchema: StandardSchemaWithJSON | undefined, - callback: PromptCallback -): PromptHandler { - if (argsSchema) { - const typedCallback = callback as (args: unknown, ctx: ServerContext) => GetPromptResult | Promise; - - return async (args, ctx) => { - const parseResult = await validateStandardSchema(argsSchema, args); - if (!parseResult.success) { - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid arguments for prompt ${name}: ${parseResult.error}`); - } - return typedCallback(parseResult.data, ctx); - }; - } else { - const typedCallback = callback as (ctx: ServerContext) => GetPromptResult | Promise; - - return async (_args, ctx) => { - return typedCallback(ctx); - }; - } -} - -function createCompletionResult(suggestions: readonly unknown[]): CompleteResult { - const values = suggestions.map(String).slice(0, 100); - return { - completion: { - values, - total: suggestions.length, - hasMore: suggestions.length > 100 - } - }; -} - -const EMPTY_COMPLETION_RESULT: CompleteResult = { - completion: { - values: [], - hasMore: false - } -}; - -/** @internal Gets the shape of a Zod object schema */ -function getSchemaShape(schema: unknown): Record | undefined { - const candidate = schema as { shape?: unknown }; - if (candidate.shape && typeof candidate.shape === 'object') { - return candidate.shape as Record; - } - return undefined; -} - -/** @internal Checks if a Zod schema is optional */ -function isOptionalSchema(schema: unknown): boolean { - const candidate = schema as { type?: string } | null | undefined; - return candidate?.type === 'optional'; -} - -/** @internal Unwraps an optional Zod schema */ -function unwrapOptionalSchema(schema: unknown): unknown { - if (!isOptionalSchema(schema)) { - return schema; - } - const candidate = schema as { def?: { innerType?: unknown } }; - return candidate.def?.innerType ?? schema; -} +export * from './mcpServer.js'; diff --git a/packages/server/src/server/mcpServer.ts b/packages/server/src/server/mcpServer.ts new file mode 100644 index 000000000..df2f5ee03 --- /dev/null +++ b/packages/server/src/server/mcpServer.ts @@ -0,0 +1,1081 @@ +import type { + AuthInfo, + BaseContext, + CallToolRequest, + CallToolResult, + ChannelTransport, + ClientCapabilities, + CreateMessageRequest, + CreateMessageRequestParamsBase, + CreateMessageRequestParamsWithTools, + CreateMessageResult, + CreateMessageResultWithTools, + CreateTaskResult, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + Implementation, + InitializeRequest, + InitializeResult, + JSONRPCErrorResponse, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + JSONRPCResultResponse, + JsonSchemaType, + jsonSchemaValidator, + ListRootsRequest, + LoggingLevel, + LoggingMessageNotification, + MessageExtraInfo, + Notification, + NotificationMethod, + NotificationOptions, + Outbound, + ProtocolOptions, + Request, + RequestEnv, + RequestMethod, + RequestOptions, + RequestTransport, + RequestTypeMap, + ResourceUpdatedNotification, + Result, + ResultTypeMap, + SchemaOutput, + ServerCapabilities, + ServerContext, + ServerResult, + StandardSchemaV1, + StandardSchemaWithJSON, + TaskManagerOptions, + ToolAnnotations, + ToolExecution, + ToolResultContent, + ToolUseContent, + Transport +} from '@modelcontextprotocol/core'; +import { + assertClientRequestTaskCapability, + assertToolsCallTaskCapability, + attachChannelTransport, + CallToolRequestSchema, + CallToolResultSchema, + CreateMessageResultSchema, + CreateMessageResultWithToolsSchema, + CreateTaskResultSchema, + Dispatcher, + ElicitResultSchema, + EmptyResultSchema, + extractTaskManagerOptions, + getResultSchema, + isJSONRPCRequest, + isRequestTransport, + LATEST_PROTOCOL_VERSION, + ListRootsResultSchema, + LoggingLevelSchema, + mergeCapabilities, + NullTaskManager, + parseSchema, + ProtocolError, + ProtocolErrorCode, + SdkError, + SdkErrorCode, + SUPPORTED_PROTOCOL_VERSIONS, + TaskManager +} from '@modelcontextprotocol/core'; +import { DefaultJsonSchemaValidator } from '@modelcontextprotocol/server/_shims'; + +import { ExperimentalMcpServerTasks } from '../experimental/tasks/mcpServer.js'; +import type { ResourceTemplate } from './resourceTemplate.js'; +import { assertCapabilityForMethod, assertNotificationCapability, assertRequestHandlerCapability } from './serverCapabilities.js'; +import type { LegacyPromptCallback, LegacyToolCallback, ZodRawShapeCompat } from './serverLegacy.js'; +import { extractMethodFromSchema, parseLegacyPromptArgs, parseLegacyToolArgs } from './serverLegacy.js'; +import type { + AnyToolHandler, + PromptCallback, + ReadResourceCallback, + ReadResourceTemplateCallback, + RegisteredPrompt, + RegisteredResource, + RegisteredResourceTemplate, + RegisteredTool, + RegistriesHost, + ResourceMetadata, + ToolCallback +} from './serverRegistries.js'; +import { ServerRegistries } from './serverRegistries.js'; + +/** + * Extended tasks capability that includes runtime configuration (store, messageQueue). + * The runtime-only fields are stripped before advertising capabilities to clients. + */ +export type ServerTasksCapabilityWithRuntime = NonNullable & TaskManagerOptions; + +export type ServerOptions = Omit & { + /** + * Capabilities to advertise as being supported by this server. + */ + capabilities?: Omit & { + tasks?: ServerTasksCapabilityWithRuntime; + }; + + /** + * Optional instructions describing how to use the server and its features. + */ + instructions?: string; + + /** + * JSON Schema validator for elicitation response validation. + * + * @default {@linkcode DefaultJsonSchemaValidator} + */ + jsonSchemaValidator?: jsonSchemaValidator; +}; + +/** + * MCP server. Holds tool/resource/prompt registries and exposes both a stateless + * {@linkcode McpServer.handle | handle()} entry point (for HTTP/gRPC/serverless drivers) + * and a {@linkcode McpServer.connect | connect()} entry point (for stdio/WebSocket pipes). + * + * One instance can serve any number of concurrent requests. + */ +export class McpServer extends Dispatcher implements RegistriesHost { + private _outbound?: Outbound; + private readonly _registries = new ServerRegistries(this); + + private _clientCapabilities?: ClientCapabilities; + private _clientVersion?: Implementation; + private _capabilities: ServerCapabilities; + private _instructions?: string; + private _jsonSchemaValidator: jsonSchemaValidator; + private _supportedProtocolVersions: string[]; + private _experimental?: { tasks: ExperimentalMcpServerTasks }; + private _taskManager: TaskManager; + private _loggingLevels = new Map(); + private readonly LOG_LEVEL_SEVERITY = new Map(LoggingLevelSchema.options.map((level, index) => [level, index])); + + /** + * Callback for when initialization has fully completed. + */ + oninitialized?: () => void; + + /** + * Callback for when a connected transport is closed. + */ + onclose?: () => void; + + /** + * Callback for when an error occurs. + */ + onerror?: (error: Error) => void; + + constructor( + private _serverInfo: Implementation, + private _options?: ServerOptions + ) { + super(); + this._capabilities = _options?.capabilities ? { ..._options.capabilities } : {}; + this._instructions = _options?.instructions; + this._jsonSchemaValidator = _options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); + this._supportedProtocolVersions = _options?.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; + + // Strip runtime-only fields from advertised capabilities + if (_options?.capabilities?.tasks) { + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { taskStore, taskMessageQueue, defaultTaskPollInterval, maxTaskQueueSize, ...wireCapabilities } = + _options.capabilities.tasks; + this._capabilities.tasks = wireCapabilities; + } + + const tasksOpts = extractTaskManagerOptions(_options?.capabilities?.tasks); + this._taskManager = tasksOpts ? new TaskManager(tasksOpts) : new NullTaskManager(); + this._taskManager.attachTo(this, { + channel: () => this._outbound, + reportError: e => (this.onerror ?? (() => {}))(e), + enforceStrictCapabilities: this._options?.enforceStrictCapabilities === true, + assertTaskCapability: m => assertClientRequestTaskCapability(this._clientCapabilities?.tasks?.requests, m, 'Client'), + assertTaskHandlerCapability: m => assertToolsCallTaskCapability(this._capabilities?.tasks?.requests, m, 'Server') + }); + + this.setRequestHandler('initialize', request => this._oninitialize(request)); + this.setRequestHandler('ping', () => ({})); + this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); + + if (this._capabilities.logging) { + this._registerLoggingHandler(); + } + } + + // ─────────────────────────────────────────────────────────────────────── + // Direct dispatch + // ─────────────────────────────────────────────────────────────────────── + + /** + * Routes an incoming JSON-RPC response (e.g. a client's reply to an `elicitation/create` + * request the server issued) through the {@linkcode TaskManager}. Called by + * {@linkcode shttpHandler} for response-typed POST bodies. + * + * @returns true if the task manager consumed it. + */ + dispatchInboundResponse(response: JSONRPCResponse | JSONRPCErrorResponse): boolean { + const id = typeof response.id === 'number' ? response.id : Number(response.id); + return this._taskManager.processInboundResponse(response, id).consumed; + } + + /** + * Handle one inbound request without a transport. Yields any notifications the handler + * emits via `ctx.mcpReq.notify()`, then yields exactly one terminal response. + */ + async *handle(request: JSONRPCRequest, env?: RequestEnv): AsyncGenerator { + for await (const out of this.dispatch(request, env)) { + yield out.message; + } + } + + /** + * Convenience entry for HTTP request/response drivers. Parses the body, dispatches each + * request, and returns a JSON response. SSE streaming is handled by `shttpHandler`, not here. + */ + async handleHttp(req: globalThis.Request, opts?: { authInfo?: AuthInfo }): Promise { + let body: unknown; + try { + body = await req.json(); + } catch { + return jsonResponse(400, { jsonrpc: '2.0', id: null, error: { code: ProtocolErrorCode.ParseError, message: 'Parse error' } }); + } + const messages = Array.isArray(body) ? body : [body]; + const env: RequestEnv = { authInfo: opts?.authInfo, httpReq: req }; + const responses: JSONRPCMessage[] = []; + for (const m of messages) { + if (!isJSONRPCRequest(m)) { + if (m && typeof m === 'object' && 'method' in m) { + await this.dispatchNotification(m).catch(() => {}); + } + continue; + } + for await (const out of this.dispatch(m, env)) { + if (out.kind === 'response') responses.push(out.message); + } + } + if (responses.length === 0) return new Response(null, { status: 202 }); + return jsonResponse(200, responses.length === 1 ? responses[0] : responses); + } + + // ─────────────────────────────────────────────────────────────────────── + // Transport wiring + // ─────────────────────────────────────────────────────────────────────── + + /** + * Wires this server to the given transport. + * + * - For {@linkcode RequestTransport} (Streamable HTTP): sets the transport's + * `onrequest`/`onnotification`/`onresponse` callback slots so it can route inbound + * messages here, and builds an {@linkcode Outbound} from the transport's + * optional `notify`/`request` methods. + * - For {@linkcode ChannelTransport} (stdio/WebSocket/InMemory): wraps it in a + * {@linkcode StreamDriver} via {@linkcode attachChannelTransport}. + * + * **Known limitation:** `_outbound` is a singleton — each `connect()` call overwrites + * it. Multiple concurrent connections (the v1 stateful-SHTTP `Map` + * pattern) work for *inbound* dispatch, but instance-level *outbound* methods + * ({@linkcode createMessage}, {@linkcode sendToolListChanged}, etc.) reach only the + * most-recently-connected transport. This matches v1 `Protocol.connect()` semantics. + */ + async connect(transport: ChannelTransport | RequestTransport): Promise { + let outbound: Outbound | undefined; + if (isRequestTransport(transport)) { + transport.onrequest = (req, env) => this.handle(req, env); + transport.onnotification = n => this.dispatchNotification(n); + transport.onresponse = r => this.dispatchInboundResponse(r); + + const prevClose = transport.onclose; + transport.onclose = () => { + prevClose?.(); + if (this._outbound === outbound) this._outbound = undefined; + this.onclose?.(); + }; + const prevErr = transport.onerror; + transport.onerror = e => { + prevErr?.(e); + this.onerror?.(e); + }; + + const noOutbound = (kind: string) => () => + Promise.reject( + new SdkError( + SdkErrorCode.NotConnected, + `Transport does not support out-of-band ${kind}; use ctx.mcpReq inside a handler.` + ) + ); + outbound = { + close: () => transport.close(), + notification: transport.notify + ? async (n, _opts) => transport.notify!({ jsonrpc: '2.0', ...n } as JSONRPCNotification) + : noOutbound('notifications'), + request: transport.request + ? (r, schema, opts) => + new Promise((resolve, reject) => { + const id = this._nextOutboundId++; + const wire = { jsonrpc: '2.0', id, method: r.method, params: r.params } as JSONRPCRequest; + const finish = (resp: JSONRPCResultResponse | Error) => { + if (resp instanceof Error) return reject(resp); + const parsed = parseSchema(schema, resp.result); + if (!parsed.success) return reject(parsed.error); + resolve(parsed.data as SchemaOutput); + }; + if (opts?.intercept?.(wire, id, finish, reject)) return; + transport.request!(wire) + .then(resp => + 'error' in resp + ? reject(ProtocolError.fromError(resp.error.code, resp.error.message, resp.error.data)) + : finish(resp) + ) + .catch(reject); + }) + : noOutbound('requests') + }; + } else { + outbound = await attachChannelTransport(transport, this, { + supportedProtocolVersions: this._supportedProtocolVersions, + debouncedNotificationMethods: this._options?.debouncedNotificationMethods, + buildEnv: (extra, base) => ({ ...base, _transportExtra: extra }), + onresponse: (r, id) => this._taskManager.processInboundResponse(r, id), + onclose: () => { + if (this._outbound === outbound) this._outbound = undefined; + this._taskManager.onClose(); + this.onclose?.(); + }, + onerror: e => this.onerror?.(e) + }); + } + this._outbound = outbound; + } + + private _nextOutboundId = 0; + + /** + * Closes the connection. + */ + async close(): Promise { + await this._outbound?.close(); + } + + /** + * Checks if the server is connected to a transport. + */ + isConnected(): boolean { + return this._outbound !== undefined; + } + + /** @deprecated The server is no longer coupled to a specific transport. Returns the underlying pipe only when connected via {@linkcode StreamDriver}. */ + get transport(): Transport | undefined { + return (this._outbound as { pipe?: Transport } | undefined)?.pipe; + } + + /** + * Returns this instance. Kept so v1 code that reaches `mcpServer.server.X` keeps working. + * @deprecated Call methods directly on `McpServer`. + */ + get server(): this { + return this; + } + + /** + * Access experimental features. + * @experimental + */ + get experimental(): { tasks: ExperimentalMcpServerTasks } { + if (!this._experimental) { + this._experimental = { tasks: new ExperimentalMcpServerTasks(this) }; + } + return this._experimental; + } + + /** Task orchestration. Always available; a {@linkcode NullTaskManager} when no task store is configured. */ + get taskManager(): TaskManager { + return this._taskManager; + } + + // ─────────────────────────────────────────────────────────────────────── + // Context building + // ─────────────────────────────────────────────────────────────────────── + + protected override buildContext(base: BaseContext, env: RequestEnv & { _transportExtra?: MessageExtraInfo }): ServerContext { + const extra = env._transportExtra; + const hasHttpInfo = base.http || env.httpReq || extra?.closeSSEStream || extra?.closeStandaloneSSEStream; + const ctx: ServerContext = { + ...base, + mcpReq: { + ...base.mcpReq, + log: (level, data, logger) => base.mcpReq.notify({ method: 'notifications/message', params: { level, data, logger } }), + elicitInput: (params, options) => this._elicitInputViaCtx(base, params, options), + requestSampling: (params, options) => this._createMessageViaCtx(base, params, options) + }, + http: hasHttpInfo + ? { + ...base.http, + req: env.httpReq, + closeSSE: extra?.closeSSEStream, + closeStandaloneSSE: extra?.closeStandaloneSSEStream + } + : undefined + }; + // v1 RequestHandlerExtra flat compat fields. New code should use ctx.mcpReq.* / ctx.http.*. + const compat = ctx as ServerContext & Record; + compat.signal = base.mcpReq.signal; + compat.requestId = base.mcpReq.id; + compat._meta = base.mcpReq._meta; + compat.sendNotification = base.mcpReq.notify; + compat.sendRequest = base.mcpReq.send; + compat.authInfo = ctx.http?.authInfo; + compat.requestInfo = env.httpReq; + return ctx; + } + + private async _elicitInputViaCtx( + base: BaseContext, + params: ElicitRequestFormParams | ElicitRequestURLParams, + options?: RequestOptions + ): Promise { + const mode = (params.mode ?? 'form') as 'form' | 'url'; + const formParams = mode === 'form' && params.mode !== 'form' ? { ...params, mode: 'form' as const } : params; + const result = (await base.mcpReq.send({ method: 'elicitation/create', params: formParams }, options)) as ElicitResult; + return this._validateElicitResult(result, mode === 'form' ? (formParams as ElicitRequestFormParams) : undefined); + } + + private async _createMessageViaCtx( + base: BaseContext, + params: CreateMessageRequest['params'], + options?: RequestOptions + ): Promise { + return base.mcpReq.send({ method: 'sampling/createMessage', params }, options) as Promise< + CreateMessageResult | CreateMessageResultWithTools + >; + } + + // ─────────────────────────────────────────────────────────────────────── + // Capabilities & initialize + // ─────────────────────────────────────────────────────────────────────── + + private async _oninitialize(request: InitializeRequest): Promise { + const requestedVersion = request.params.protocolVersion; + this._clientCapabilities = request.params.capabilities; + this._clientVersion = request.params.clientInfo; + + const protocolVersion = this._supportedProtocolVersions.includes(requestedVersion) + ? requestedVersion + : (this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION); + + this._outbound?.setProtocolVersion?.(protocolVersion); + + return { + protocolVersion, + capabilities: this.getCapabilities(), + serverInfo: this._serverInfo, + ...(this._instructions && { instructions: this._instructions }) + }; + } + + /** + * After initialization, populated with the client's reported capabilities. + */ + getClientCapabilities(): ClientCapabilities | undefined { + return this._clientCapabilities; + } + + /** + * After initialization, populated with the client's name and version. + */ + getClientVersion(): Implementation | undefined { + return this._clientVersion; + } + + /** + * Returns the current server capabilities. + */ + getCapabilities(): ServerCapabilities { + return this._capabilities; + } + + /** + * Registers new capabilities. Can only be called before connecting to a transport. + */ + registerCapabilities(capabilities: ServerCapabilities): void { + if (this._outbound) { + throw new SdkError(SdkErrorCode.AlreadyConnected, 'Cannot register capabilities after connecting to transport'); + } + const hadLogging = !!this._capabilities.logging; + this._capabilities = mergeCapabilities(this._capabilities, capabilities); + if (!hadLogging && this._capabilities.logging) { + this._registerLoggingHandler(); + } + } + + /** + * Override request handler registration to enforce server-side validation for `tools/call`. + * + * Also accepts the v1 form `setRequestHandler(zodRequestSchema, handler)` where the schema + * has a literal `method` shape (e.g. `z.object({method: z.literal('resources/subscribe')})`). + */ + public override setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise + ): void; + public override setRequestHandler( + method: string, + paramsSchema: S, + handler: (params: StandardSchemaV1.InferOutput, ctx: ServerContext) => Result | Promise + ): void; + /** @deprecated Pass a method string instead of a Zod request schema. */ + public override setRequestHandler( + schema: S, + handler: ( + request: S extends StandardSchemaV1 ? O : JSONRPCRequest, + ctx: ServerContext + ) => Result | Promise + ): void; + public override setRequestHandler( + methodOrSchema: string | { shape: { method: unknown } }, + handlerOrSchema: unknown, + maybeHandler?: (params: unknown, ctx: ServerContext) => Result | Promise + ): void { + if (maybeHandler !== undefined) { + const method = methodOrSchema as string; + assertRequestHandlerCapability(method as RequestMethod, this._capabilities); + this.setRawRequestHandler(method, this._wrapParamsSchemaHandler(method, handlerOrSchema as StandardSchemaV1, maybeHandler)); + return; + } + const handler = handlerOrSchema as (request: never, ctx: ServerContext) => Result | Promise; + const method = (typeof methodOrSchema === 'string' ? methodOrSchema : extractMethodFromSchema(methodOrSchema)) as RequestMethod; + assertRequestHandlerCapability(method, this._capabilities); + const h = handler as (request: JSONRPCRequest, ctx: ServerContext) => Result | Promise; + if (method === 'tools/call') { + const wrapped = async (request: JSONRPCRequest, ctx: ServerContext): Promise => { + const validated = parseSchema(CallToolRequestSchema, request); + if (!validated.success) { + const msg = validated.error instanceof Error ? validated.error.message : String(validated.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${msg}`); + } + const { params } = validated.data; + const result = await Promise.resolve(h(request, ctx)); + if (params.task) { + const taskValidation = parseSchema(CreateTaskResultSchema, result); + if (!taskValidation.success) { + const msg = taskValidation.error instanceof Error ? taskValidation.error.message : String(taskValidation.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${msg}`); + } + return taskValidation.data; + } + const resultValidation = parseSchema(CallToolResultSchema, result); + if (!resultValidation.success) { + const msg = resultValidation.error instanceof Error ? resultValidation.error.message : String(resultValidation.error); + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${msg}`); + } + return resultValidation.data; + }; + return super.setRequestHandler(method, wrapped as never); + } + return super.setRequestHandler(method, h as never); + } + + // ─────────────────────────────────────────────────────────────────────── + // Server→client requests (require a connected Outbound) + // ─────────────────────────────────────────────────────────────────────── + + private _requireOutbound(): Outbound { + if (!this._outbound) { + throw new SdkError( + SdkErrorCode.NotConnected, + 'Server is not connected. Use ctx.mcpReq.* inside handlers, or the MRTR-native return form, or call connect().' + ); + } + return this._outbound; + } + + private _outboundRequest(req: Request, schema: { parse(v: unknown): T }, options?: RequestOptions): Promise { + if (this._options?.enforceStrictCapabilities === true) { + assertCapabilityForMethod(req.method as RequestMethod, this._clientCapabilities); + } + return this._taskManager.sendRequest(req, schema as never, options, this._requireOutbound()) as Promise; + } + + /** + * Sends a request to the connected peer and awaits the result. Result schema is + * resolved from the method name. + */ + async request( + req: { method: M; params?: Record }, + options?: RequestOptions + ): Promise { + return this._outboundRequest(req as Request, getResultSchema(req.method), options) as Promise; + } + + async ping(): Promise { + return this._outboundRequest({ method: 'ping' }, EmptyResultSchema); + } + + /** + * Request LLM sampling from the client. Only available when connected via {@linkcode connect}. + * Inside a request handler, prefer `ctx.mcpReq.requestSampling`. + */ + async createMessage(params: CreateMessageRequestParamsBase, options?: RequestOptions): Promise; + async createMessage(params: CreateMessageRequestParamsWithTools, options?: RequestOptions): Promise; + async createMessage( + params: CreateMessageRequest['params'], + options?: RequestOptions + ): Promise; + async createMessage( + params: CreateMessageRequest['params'], + options?: RequestOptions + ): Promise { + if ((params.tools || params.toolChoice) && !this._clientCapabilities?.sampling?.tools) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support sampling tools capability.'); + } + if (params.messages.length > 0) { + const lastMessage = params.messages.at(-1)!; + const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content]; + const hasToolResults = lastContent.some(c => c.type === 'tool_result'); + const previousMessage = params.messages.length > 1 ? params.messages.at(-2) : undefined; + const previousContent = previousMessage + ? Array.isArray(previousMessage.content) + ? previousMessage.content + : [previousMessage.content] + : []; + const hasPreviousToolUse = previousContent.some(c => c.type === 'tool_use'); + if (hasToolResults) { + if (lastContent.some(c => c.type !== 'tool_result')) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + 'The last message must contain only tool_result content if any is present' + ); + } + if (!hasPreviousToolUse) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + 'tool_result blocks are not matching any tool_use from the previous message' + ); + } + } + if (hasPreviousToolUse) { + const toolUseIds = new Set(previousContent.filter(c => c.type === 'tool_use').map(c => (c as ToolUseContent).id)); + const toolResultIds = new Set( + lastContent.filter(c => c.type === 'tool_result').map(c => (c as ToolResultContent).toolUseId) + ); + if (toolUseIds.size !== toolResultIds.size || ![...toolUseIds].every(id => toolResultIds.has(id))) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + 'ids of tool_result blocks and tool_use blocks from previous message do not match' + ); + } + } + } + if (params.tools) { + return this._outboundRequest({ method: 'sampling/createMessage', params }, CreateMessageResultWithToolsSchema, options); + } + return this._outboundRequest({ method: 'sampling/createMessage', params }, CreateMessageResultSchema, options); + } + + /** + * Creates an elicitation request. Only available when connected via {@linkcode connect}. + * Inside a request handler, prefer `ctx.mcpReq.elicitInput`. + */ + async elicitInput(params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions): Promise { + const mode = (params.mode ?? 'form') as 'form' | 'url'; + switch (mode) { + case 'url': { + if (this._clientCapabilities && !this._clientCapabilities.elicitation?.url) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support url elicitation.'); + } + const urlParams = params as ElicitRequestURLParams; + return this._outboundRequest({ method: 'elicitation/create', params: urlParams }, ElicitResultSchema, options); + } + case 'form': { + if (this._clientCapabilities && !this._clientCapabilities.elicitation?.form) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support form elicitation.'); + } + const formParams: ElicitRequestFormParams = + params.mode === 'form' ? (params as ElicitRequestFormParams) : { ...(params as ElicitRequestFormParams), mode: 'form' }; + const result = await this._outboundRequest( + { method: 'elicitation/create', params: formParams }, + ElicitResultSchema, + options + ); + return this._validateElicitResult(result, formParams); + } + } + } + + private _validateElicitResult(result: ElicitResult, formParams?: ElicitRequestFormParams): ElicitResult { + if (result.action === 'accept' && result.content && formParams?.requestedSchema) { + try { + const validator = this._jsonSchemaValidator.getValidator(formParams.requestedSchema as JsonSchemaType); + const validation = validator(result.content); + if (!validation.valid) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Elicitation response content does not match requested schema: ${validation.errorMessage}` + ); + } + } catch (error) { + if (error instanceof ProtocolError) throw error; + throw new ProtocolError( + ProtocolErrorCode.InternalError, + `Error validating elicitation response: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + return result; + } + + createElicitationCompletionNotifier(elicitationId: string, options?: NotificationOptions): () => Promise { + if (this._clientCapabilities && !this._clientCapabilities.elicitation?.url) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + 'Client does not support URL elicitation (required for notifications/elicitation/complete)' + ); + } + return () => this.notification({ method: 'notifications/elicitation/complete', params: { elicitationId } }, options); + } + + async listRoots(params?: ListRootsRequest['params'], options?: RequestOptions) { + return this._outboundRequest({ method: 'roots/list', params }, ListRootsResultSchema, options); + } + + // ─────────────────────────────────────────────────────────────────────── + // Outbound notifications + // ─────────────────────────────────────────────────────────────────────── + + /** + * Sends a notification over the connected transport. No-op when not connected. + */ + async notification(notification: Notification, options?: NotificationOptions): Promise { + assertNotificationCapability(notification.method as NotificationMethod, this._capabilities, this._clientCapabilities); + if (this._outbound) await this._taskManager.sendNotification(notification, options, this._outbound); + } + + async sendLoggingMessage(params: LoggingMessageNotification['params'], sessionId?: string): Promise { + if (this._capabilities.logging && !this._isMessageIgnored(params.level, sessionId)) { + return this.notification({ method: 'notifications/message', params }); + } + } + + async sendResourceUpdated(params: ResourceUpdatedNotification['params']): Promise { + return this.notification({ method: 'notifications/resources/updated', params }); + } + + async sendResourceListChanged(): Promise { + if (this.isConnected()) return this.notification({ method: 'notifications/resources/list_changed' }); + } + + async sendToolListChanged(): Promise { + if (this.isConnected()) return this.notification({ method: 'notifications/tools/list_changed' }); + } + + async sendPromptListChanged(): Promise { + if (this.isConnected()) return this.notification({ method: 'notifications/prompts/list_changed' }); + } + + private _registerLoggingHandler(): void { + this.setRequestHandler('logging/setLevel', async (request, ctx) => { + const transportSessionId = ctx.sessionId || ctx.http?.req?.headers.get('mcp-session-id') || undefined; + const { level } = request.params; + const parsed = parseSchema(LoggingLevelSchema, level); + if (parsed.success) { + this._loggingLevels.set(transportSessionId, parsed.data); + } + return {}; + }); + } + + private _isMessageIgnored(level: LoggingLevel, sessionId?: string): boolean { + const currentLevel = this._loggingLevels.get(sessionId); + return currentLevel ? this.LOG_LEVEL_SEVERITY.get(level)! < this.LOG_LEVEL_SEVERITY.get(currentLevel)! : false; + } + + // ─────────────────────────────────────────────────────────────────────── + // Registries (delegated to ServerRegistries) + // ─────────────────────────────────────────────────────────────────────── + + /** + * Registers a tool with a config object and callback. + */ + registerTool( + name: string, + config: { + title?: string; + description?: string; + inputSchema?: InputArgs | ZodRawShapeCompat; + outputSchema?: OutputArgs | ZodRawShapeCompat; + annotations?: ToolAnnotations; + _meta?: Record; + }, + cb: ToolCallback + ): RegisteredTool { + return this._registries.registerTool(name, config, cb); + } + + /** + * Registers a prompt with a config object and callback. + */ + registerPrompt( + name: string, + config: { title?: string; description?: string; argsSchema?: Args | ZodRawShapeCompat; _meta?: Record }, + cb: PromptCallback + ): RegisteredPrompt { + return this._registries.registerPrompt(name, config, cb); + } + + /** + * Registers a resource with a config object and callback. + */ + registerResource(name: string, uriOrTemplate: string, config: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + registerResource( + name: string, + uriOrTemplate: ResourceTemplate, + config: ResourceMetadata, + readCallback: ReadResourceTemplateCallback + ): RegisteredResourceTemplate; + registerResource( + name: string, + uriOrTemplate: string | ResourceTemplate, + config: ResourceMetadata, + readCallback: ReadResourceCallback | ReadResourceTemplateCallback + ): RegisteredResource | RegisteredResourceTemplate { + return this._registries.registerResource(name, uriOrTemplate as never, config, readCallback as never); + } + + // ─────────────────────────────────────────────────────────────────────── + // v1-internal compat surface — for code that monkey-patches McpServer + // private methods (e.g., shortcut's CustomMcpServer overrides + // setToolRequestHandlers). Routed through here so instance overrides fire. + // ─────────────────────────────────────────────────────────────────────── + + /** @hidden v1 compat: lazy installer hook, override on instance to customize tools/* handlers. */ + setToolRequestHandlers(): void { + this._registries.setToolRequestHandlers(); + } + /** @hidden v1 compat */ + setResourceRequestHandlers(): void { + this._registries.setResourceRequestHandlers(); + } + /** @hidden v1 compat */ + setPromptRequestHandlers(): void { + this._registries.setPromptRequestHandlers(); + } + /** @hidden v1 compat */ + setCompletionRequestHandler(): void { + this._registries.setCompletionRequestHandler(); + } + /** @hidden v1 compat */ + protected validateToolInput(tool: RegisteredTool, args: unknown, toolName: string) { + return this._registries.validateToolInput(tool, args as never, toolName); + } + /** @hidden v1 compat */ + protected validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string) { + return this._registries.validateToolOutput(tool, result, toolName); + } + /** @hidden v1 compat */ + protected handleAutomaticTaskPolling(tool: RegisteredTool, request: CallToolRequest, ctx: ServerContext) { + return this._registries.handleAutomaticTaskPolling(tool, request, ctx); + } + /** @hidden v1 compat: was a private instance method in v1 mcp.ts. */ + protected createToolError(errorMessage: string): CallToolResult { + return { content: [{ type: 'text', text: errorMessage }], isError: true }; + } + /** @hidden v1 compat: removed in v2 (replaced by `tool.executor`); shim calls executor. */ + protected executeToolHandler(tool: RegisteredTool, args: unknown, ctx: ServerContext) { + return tool.executor(args as never, ctx); + } + + /** @hidden v1 compat for `(mcpServer as any)._registeredTools` and `experimental.tasks`. */ + get _registeredTools(): { [name: string]: RegisteredTool } { + return this._registries.registeredTools; + } + /** @hidden v1 compat. */ + get _registeredResources(): { [uri: string]: RegisteredResource } { + return this._registries.registeredResources; + } + /** @hidden v1 compat. */ + get _registeredResourceTemplates(): { [name: string]: RegisteredResourceTemplate } { + return this._registries.registeredResourceTemplates; + } + /** @hidden v1 compat. */ + get _registeredPrompts(): { [name: string]: RegisteredPrompt } { + return this._registries.registeredPrompts; + } + + /** @hidden v1 compat for `experimental.tasks.registerToolTask` which calls this directly. */ + _createRegisteredTool( + name: string, + title: string | undefined, + description: string | undefined, + inputSchema: StandardSchemaWithJSON | undefined, + outputSchema: StandardSchemaWithJSON | undefined, + annotations: ToolAnnotations | undefined, + execution: ToolExecution | undefined, + _meta: Record | undefined, + handler: AnyToolHandler + ): RegisteredTool { + return this._registries.createRegisteredTool( + name, + title, + description, + inputSchema, + outputSchema, + annotations, + execution, + _meta, + handler + ); + } + + // ─────────────────────────────────────────────────────────────────────── + // Deprecated v1 overloads (positional, raw-shape) — call register* internally + // ─────────────────────────────────────────────────────────────────────── + + /** @deprecated Use {@linkcode McpServer.registerTool | registerTool()} instead. */ + tool(name: string, cb: ToolCallback): RegisteredTool; + /** @deprecated Use {@linkcode McpServer.registerTool | registerTool()} instead. */ + tool(name: string, description: string, cb: ToolCallback): RegisteredTool; + /** @deprecated Use {@linkcode McpServer.registerTool | registerTool()} instead. */ + tool( + name: string, + paramsSchemaOrAnnotations: Args | ToolAnnotations, + cb: LegacyToolCallback + ): RegisteredTool; + /** @deprecated Use {@linkcode McpServer.registerTool | registerTool()} instead. */ + tool( + name: string, + description: string, + paramsSchemaOrAnnotations: Args | ToolAnnotations, + cb: LegacyToolCallback + ): RegisteredTool; + /** @deprecated Use {@linkcode McpServer.registerTool | registerTool()} instead. */ + tool( + name: string, + paramsSchema: Args, + annotations: ToolAnnotations, + cb: LegacyToolCallback + ): RegisteredTool; + /** @deprecated Use {@linkcode McpServer.registerTool | registerTool()} instead. */ + tool( + name: string, + description: string, + paramsSchema: Args, + annotations: ToolAnnotations, + cb: LegacyToolCallback + ): RegisteredTool; + tool(name: string, ...rest: unknown[]): RegisteredTool { + if (this._registries.registeredTools[name]) throw new Error(`Tool ${name} is already registered`); + const { description, inputSchema, annotations, cb } = parseLegacyToolArgs(name, rest); + return this._registries.createRegisteredTool( + name, + undefined, + description, + inputSchema, + undefined, + annotations, + { taskSupport: 'forbidden' }, + undefined, + cb as ToolCallback + ); + } + + /** @deprecated Use {@linkcode McpServer.registerPrompt | registerPrompt()} instead. */ + prompt(name: string, cb: PromptCallback): RegisteredPrompt; + /** @deprecated Use {@linkcode McpServer.registerPrompt | registerPrompt()} instead. */ + prompt(name: string, description: string, cb: PromptCallback): RegisteredPrompt; + /** @deprecated Use {@linkcode McpServer.registerPrompt | registerPrompt()} instead. */ + prompt(name: string, argsSchema: Args, cb: LegacyPromptCallback): RegisteredPrompt; + /** @deprecated Use {@linkcode McpServer.registerPrompt | registerPrompt()} instead. */ + prompt( + name: string, + description: string, + argsSchema: Args, + cb: LegacyPromptCallback + ): RegisteredPrompt; + prompt(name: string, ...rest: unknown[]): RegisteredPrompt { + if (this._registries.registeredPrompts[name]) throw new Error(`Prompt ${name} is already registered`); + const { description, argsSchema, cb } = parseLegacyPromptArgs(rest); + const r = this._registries.createRegisteredPrompt( + name, + undefined, + description, + argsSchema, + cb as PromptCallback, + undefined + ); + this._registries.installPromptHandlers(); + this.sendPromptListChanged(); + return r; + } + + /** @deprecated Use {@linkcode McpServer.registerResource | registerResource()} instead. */ + resource(name: string, uri: string, readCallback: ReadResourceCallback): RegisteredResource; + /** @deprecated Use {@linkcode McpServer.registerResource | registerResource()} instead. */ + resource(name: string, uri: string, metadata: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + /** @deprecated Use {@linkcode McpServer.registerResource | registerResource()} instead. */ + resource(name: string, template: ResourceTemplate, readCallback: ReadResourceTemplateCallback): RegisteredResourceTemplate; + /** @deprecated Use {@linkcode McpServer.registerResource | registerResource()} instead. */ + resource( + name: string, + template: ResourceTemplate, + metadata: ResourceMetadata, + readCallback: ReadResourceTemplateCallback + ): RegisteredResourceTemplate; + resource(name: string, uriOrTemplate: string | ResourceTemplate, ...rest: unknown[]): RegisteredResource | RegisteredResourceTemplate { + let metadata: ResourceMetadata | undefined; + if (typeof rest[0] === 'object') metadata = rest.shift() as ResourceMetadata; + const readCallback = rest[0] as ReadResourceCallback | ReadResourceTemplateCallback; + if (typeof uriOrTemplate === 'string') { + if (this._registries.registeredResources[uriOrTemplate]) throw new Error(`Resource ${uriOrTemplate} is already registered`); + const r = this._registries.createRegisteredResource( + name, + undefined, + uriOrTemplate, + metadata, + readCallback as ReadResourceCallback + ); + this._registries.installResourceHandlers(); + this.sendResourceListChanged(); + return r; + } + if (this._registries.registeredResourceTemplates[name]) throw new Error(`Resource template ${name} is already registered`); + const r = this._registries.createRegisteredResourceTemplate( + name, + undefined, + uriOrTemplate, + metadata, + readCallback as ReadResourceTemplateCallback + ); + this._registries.installResourceHandlers(); + this.sendResourceListChanged(); + return r; + } +} + +function jsonResponse(status: number, body: unknown): Response { + return Response.json(body, { status, headers: { 'content-type': 'application/json' } }); +} + +// ─────────────────────────────────────────────────────────────────────────── +// Re-exports for path compat. External code imports these from './mcpServer.js'. +// ─────────────────────────────────────────────────────────────────────────── + +export type { CompleteResourceTemplateCallback, ListResourcesCallback } from './resourceTemplate.js'; +export { ResourceTemplate } from './resourceTemplate.js'; +export type { + AnyToolHandler, + BaseToolCallback, + PromptCallback, + ReadResourceCallback, + ReadResourceTemplateCallback, + RegisteredPrompt, + RegisteredResource, + RegisteredResourceTemplate, + RegisteredTool, + ResourceMetadata, + ToolCallback +} from './serverRegistries.js'; diff --git a/packages/server/src/server/resourceTemplate.ts b/packages/server/src/server/resourceTemplate.ts new file mode 100644 index 000000000..ac994dfdc --- /dev/null +++ b/packages/server/src/server/resourceTemplate.ts @@ -0,0 +1,45 @@ +import type { ListResourcesResult, ServerContext } from '@modelcontextprotocol/core'; +import { UriTemplate } from '@modelcontextprotocol/core'; + +/** + * A callback to list all resources matching a template. + */ +export type ListResourcesCallback = (ctx: ServerContext) => ListResourcesResult | Promise; + +/** + * A callback to complete one variable within a resource template's URI template. + */ +export type CompleteResourceTemplateCallback = ( + value: string, + context?: { arguments?: Record } +) => string[] | Promise; + +/** + * A resource template combines a URI pattern with optional functionality to enumerate + * all resources matching that pattern. + */ +export class ResourceTemplate { + private _uriTemplate: UriTemplate; + + constructor( + uriTemplate: string | UriTemplate, + private _callbacks: { + list: ListResourcesCallback | undefined; + complete?: { [variable: string]: CompleteResourceTemplateCallback }; + } + ) { + this._uriTemplate = typeof uriTemplate === 'string' ? new UriTemplate(uriTemplate) : uriTemplate; + } + + get uriTemplate(): UriTemplate { + return this._uriTemplate; + } + + get listCallback(): ListResourcesCallback | undefined { + return this._callbacks.list; + } + + completeCallback(variable: string): CompleteResourceTemplateCallback | undefined { + return this._callbacks.complete?.[variable]; + } +} diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 4361f3e1e..1945af6ae 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -1,677 +1,7 @@ -import type { - BaseContext, - ClientCapabilities, - CreateMessageRequest, - CreateMessageRequestParamsBase, - CreateMessageRequestParamsWithTools, - CreateMessageResult, - CreateMessageResultWithTools, - ElicitRequestFormParams, - ElicitRequestURLParams, - ElicitResult, - Implementation, - InitializeRequest, - InitializeResult, - JsonSchemaType, - jsonSchemaValidator, - ListRootsRequest, - LoggingLevel, - LoggingMessageNotification, - MessageExtraInfo, - NotificationMethod, - NotificationOptions, - ProtocolOptions, - RequestMethod, - RequestOptions, - RequestTypeMap, - ResourceUpdatedNotification, - ResultTypeMap, - ServerCapabilities, - ServerContext, - ServerResult, - TaskManagerOptions, - ToolResultContent, - ToolUseContent -} from '@modelcontextprotocol/core'; -import { - assertClientRequestTaskCapability, - assertToolsCallTaskCapability, - CallToolRequestSchema, - CallToolResultSchema, - CreateMessageResultSchema, - CreateMessageResultWithToolsSchema, - CreateTaskResultSchema, - ElicitResultSchema, - EmptyResultSchema, - extractTaskManagerOptions, - LATEST_PROTOCOL_VERSION, - ListRootsResultSchema, - LoggingLevelSchema, - mergeCapabilities, - parseSchema, - Protocol, - ProtocolError, - ProtocolErrorCode, - SdkError, - SdkErrorCode -} from '@modelcontextprotocol/core'; -import { DefaultJsonSchemaValidator } from '@modelcontextprotocol/server/_shims'; - -import { ExperimentalServerTasks } from '../experimental/tasks/server.js'; - /** - * Extended tasks capability that includes runtime configuration (store, messageQueue). - * The runtime-only fields are stripped before advertising capabilities to clients. + * v1-compat module path. The low-level `Server` class is now an alias for + * {@linkcode McpServer}; see {@link ./compat.ts} and {@link ./mcpServer.ts}. + * @deprecated Import from `@modelcontextprotocol/server` directly. */ -export type ServerTasksCapabilityWithRuntime = NonNullable & TaskManagerOptions; - -export type ServerOptions = ProtocolOptions & { - /** - * Capabilities to advertise as being supported by this server. - */ - capabilities?: Omit & { - tasks?: ServerTasksCapabilityWithRuntime; - }; - - /** - * Optional instructions describing how to use the server and its features. - */ - instructions?: string; - - /** - * JSON Schema validator for elicitation response validation. - * - * The validator is used to validate user input returned from elicitation - * requests against the requested schema. - * - * @default {@linkcode DefaultJsonSchemaValidator} ({@linkcode index.AjvJsonSchemaValidator | AjvJsonSchemaValidator} on Node.js, `CfWorkerJsonSchemaValidator` on Cloudflare Workers) - */ - jsonSchemaValidator?: jsonSchemaValidator; -}; - -/** - * An MCP server on top of a pluggable transport. - * - * This server will automatically respond to the initialization flow as initiated from the client. - * - * @deprecated Use {@linkcode server/mcp.McpServer | McpServer} instead for the high-level API. Only use `Server` for advanced use cases. - */ -export class Server extends Protocol { - private _clientCapabilities?: ClientCapabilities; - private _clientVersion?: Implementation; - private _capabilities: ServerCapabilities; - private _instructions?: string; - private _jsonSchemaValidator: jsonSchemaValidator; - private _experimental?: { tasks: ExperimentalServerTasks }; - - /** - * Callback for when initialization has fully completed (i.e., the client has sent an `notifications/initialized` notification). - */ - oninitialized?: () => void; - - /** - * Initializes this server with the given name and version information. - */ - constructor( - private _serverInfo: Implementation, - options?: ServerOptions - ) { - super({ - ...options, - tasks: extractTaskManagerOptions(options?.capabilities?.tasks) - }); - this._capabilities = options?.capabilities ? { ...options.capabilities } : {}; - this._instructions = options?.instructions; - this._jsonSchemaValidator = options?.jsonSchemaValidator ?? new DefaultJsonSchemaValidator(); - - // Strip runtime-only fields from advertised capabilities - if (options?.capabilities?.tasks) { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { taskStore, taskMessageQueue, defaultTaskPollInterval, maxTaskQueueSize, ...wireCapabilities } = - options.capabilities.tasks; - this._capabilities.tasks = wireCapabilities; - } - - this.setRequestHandler('initialize', request => this._oninitialize(request)); - this.setNotificationHandler('notifications/initialized', () => this.oninitialized?.()); - - if (this._capabilities.logging) { - this._registerLoggingHandler(); - } - } - - private _registerLoggingHandler(): void { - this.setRequestHandler('logging/setLevel', async (request, ctx) => { - const transportSessionId: string | undefined = - ctx.sessionId || (ctx.http?.req?.headers.get('mcp-session-id') as string) || undefined; - const { level } = request.params; - const parseResult = parseSchema(LoggingLevelSchema, level); - if (parseResult.success) { - this._loggingLevels.set(transportSessionId, parseResult.data); - } - return {}; - }); - } - - protected override buildContext(ctx: BaseContext, transportInfo?: MessageExtraInfo): ServerContext { - // Only create http when there's actual HTTP transport info or auth info - const hasHttpInfo = ctx.http || transportInfo?.request || transportInfo?.closeSSEStream || transportInfo?.closeStandaloneSSEStream; - return { - ...ctx, - mcpReq: { - ...ctx.mcpReq, - log: (level, data, logger) => this.sendLoggingMessage({ level, data, logger }), - elicitInput: (params, options) => this.elicitInput(params, options), - requestSampling: (params, options) => this.createMessage(params, options) - }, - http: hasHttpInfo - ? { - ...ctx.http, - req: transportInfo?.request, - closeSSE: transportInfo?.closeSSEStream, - closeStandaloneSSE: transportInfo?.closeStandaloneSSEStream - } - : undefined - }; - } - - /** - * Access experimental features. - * - * WARNING: These APIs are experimental and may change without notice. - * - * @experimental - */ - get experimental(): { tasks: ExperimentalServerTasks } { - if (!this._experimental) { - this._experimental = { - tasks: new ExperimentalServerTasks(this) - }; - } - return this._experimental; - } - - // Map log levels by session id - private _loggingLevels = new Map(); - - // Map LogLevelSchema to severity index - private readonly LOG_LEVEL_SEVERITY = new Map(LoggingLevelSchema.options.map((level, index) => [level, index])); - - // Is a message with the given level ignored in the log level set for the given session id? - private isMessageIgnored = (level: LoggingLevel, sessionId?: string): boolean => { - const currentLevel = this._loggingLevels.get(sessionId); - return currentLevel ? this.LOG_LEVEL_SEVERITY.get(level)! < this.LOG_LEVEL_SEVERITY.get(currentLevel)! : false; - }; - - /** - * Registers new capabilities. This can only be called before connecting to a transport. - * - * The new capabilities will be merged with any existing capabilities previously given (e.g., at initialization). - */ - public registerCapabilities(capabilities: ServerCapabilities): void { - if (this.transport) { - throw new SdkError(SdkErrorCode.AlreadyConnected, 'Cannot register capabilities after connecting to transport'); - } - const hadLogging = !!this._capabilities.logging; - this._capabilities = mergeCapabilities(this._capabilities, capabilities); - if (!hadLogging && this._capabilities.logging) { - this._registerLoggingHandler(); - } - } - - /** - * Override request handler registration to enforce server-side validation for `tools/call`. - */ - public override setRequestHandler( - method: M, - handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise - ): void { - if (method === 'tools/call') { - const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise => { - const validatedRequest = parseSchema(CallToolRequestSchema, request); - if (!validatedRequest.success) { - const errorMessage = - validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`); - } - - const { params } = validatedRequest.data; - - const result = await Promise.resolve(handler(request, ctx)); - - // When task creation is requested, validate and return CreateTaskResult - if (params.task) { - const taskValidationResult = parseSchema(CreateTaskResultSchema, result); - if (!taskValidationResult.success) { - const errorMessage = - taskValidationResult.error instanceof Error - ? taskValidationResult.error.message - : String(taskValidationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`); - } - return taskValidationResult.data; - } - - // For non-task requests, validate against CallToolResultSchema - const validationResult = parseSchema(CallToolResultSchema, result); - if (!validationResult.success) { - const errorMessage = - validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); - throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`); - } - - return validationResult.data; - }; - - // Install the wrapped handler - return super.setRequestHandler(method, wrappedHandler); - } - - // Other handlers use default behavior - return super.setRequestHandler(method, handler); - } - - protected assertCapabilityForMethod(method: RequestMethod): void { - switch (method) { - case 'sampling/createMessage': { - if (!this._clientCapabilities?.sampling) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Client does not support sampling (required for ${method})`); - } - break; - } - - case 'elicitation/create': { - if (!this._clientCapabilities?.elicitation) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Client does not support elicitation (required for ${method})`); - } - break; - } - - case 'roots/list': { - if (!this._clientCapabilities?.roots) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `Client does not support listing roots (required for ${method})` - ); - } - break; - } - - case 'ping': { - // No specific capability required for ping - break; - } - } - } - - protected assertNotificationCapability(method: NotificationMethod): void { - switch (method) { - case 'notifications/message': { - if (!this._capabilities.logging) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support logging (required for ${method})`); - } - break; - } - - case 'notifications/resources/updated': - case 'notifications/resources/list_changed': { - if (!this._capabilities.resources) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `Server does not support notifying about resources (required for ${method})` - ); - } - break; - } - - case 'notifications/tools/list_changed': { - if (!this._capabilities.tools) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `Server does not support notifying of tool list changes (required for ${method})` - ); - } - break; - } - - case 'notifications/prompts/list_changed': { - if (!this._capabilities.prompts) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `Server does not support notifying of prompt list changes (required for ${method})` - ); - } - break; - } - - case 'notifications/elicitation/complete': { - if (!this._clientCapabilities?.elicitation?.url) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - `Client does not support URL elicitation (required for ${method})` - ); - } - break; - } - - case 'notifications/cancelled': { - // Cancellation notifications are always allowed - break; - } - - case 'notifications/progress': { - // Progress notifications are always allowed - break; - } - } - } - - protected assertRequestHandlerCapability(method: string): void { - switch (method) { - case 'completion/complete': { - if (!this._capabilities.completions) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support completions (required for ${method})`); - } - break; - } - - case 'logging/setLevel': { - if (!this._capabilities.logging) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support logging (required for ${method})`); - } - break; - } - - case 'prompts/get': - case 'prompts/list': { - if (!this._capabilities.prompts) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support prompts (required for ${method})`); - } - break; - } - - case 'resources/list': - case 'resources/templates/list': - case 'resources/read': { - if (!this._capabilities.resources) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support resources (required for ${method})`); - } - break; - } - - case 'tools/call': - case 'tools/list': { - if (!this._capabilities.tools) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support tools (required for ${method})`); - } - break; - } - - case 'ping': - case 'initialize': { - // No specific capability required for these methods - break; - } - } - } - - protected assertTaskCapability(method: string): void { - assertClientRequestTaskCapability(this._clientCapabilities?.tasks?.requests, method, 'Client'); - } - - protected assertTaskHandlerCapability(method: string): void { - assertToolsCallTaskCapability(this._capabilities?.tasks?.requests, method, 'Server'); - } - - private async _oninitialize(request: InitializeRequest): Promise { - const requestedVersion = request.params.protocolVersion; - - this._clientCapabilities = request.params.capabilities; - this._clientVersion = request.params.clientInfo; - - const protocolVersion = this._supportedProtocolVersions.includes(requestedVersion) - ? requestedVersion - : (this._supportedProtocolVersions[0] ?? LATEST_PROTOCOL_VERSION); - - this.transport?.setProtocolVersion?.(protocolVersion); - - return { - protocolVersion, - capabilities: this.getCapabilities(), - serverInfo: this._serverInfo, - ...(this._instructions && { instructions: this._instructions }) - }; - } - - /** - * After initialization has completed, this will be populated with the client's reported capabilities. - */ - getClientCapabilities(): ClientCapabilities | undefined { - return this._clientCapabilities; - } - - /** - * After initialization has completed, this will be populated with information about the client's name and version. - */ - getClientVersion(): Implementation | undefined { - return this._clientVersion; - } - - /** - * Returns the current server capabilities. - */ - public getCapabilities(): ServerCapabilities { - return this._capabilities; - } - - async ping() { - return this._requestWithSchema({ method: 'ping' }, EmptyResultSchema); - } - - /** - * Request LLM sampling from the client (without tools). - * Returns single content block for backwards compatibility. - */ - async createMessage(params: CreateMessageRequestParamsBase, options?: RequestOptions): Promise; - - /** - * Request LLM sampling from the client with tool support. - * Returns content that may be a single block or array (for parallel tool calls). - */ - async createMessage(params: CreateMessageRequestParamsWithTools, options?: RequestOptions): Promise; - - /** - * Request LLM sampling from the client. - * When tools may or may not be present, returns the union type. - */ - async createMessage( - params: CreateMessageRequest['params'], - options?: RequestOptions - ): Promise; - - // Implementation - async createMessage( - params: CreateMessageRequest['params'], - options?: RequestOptions - ): Promise { - // Capability check - only required when tools/toolChoice are provided - if ((params.tools || params.toolChoice) && !this._clientCapabilities?.sampling?.tools) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support sampling tools capability.'); - } - - // Message structure validation - always validate tool_use/tool_result pairs. - // These may appear even without tools/toolChoice in the current request when - // a previous sampling request returned tool_use and this is a follow-up with results. - if (params.messages.length > 0) { - const lastMessage = params.messages.at(-1)!; - const lastContent = Array.isArray(lastMessage.content) ? lastMessage.content : [lastMessage.content]; - const hasToolResults = lastContent.some(c => c.type === 'tool_result'); - - const previousMessage = params.messages.length > 1 ? params.messages.at(-2) : undefined; - const previousContent = previousMessage - ? Array.isArray(previousMessage.content) - ? previousMessage.content - : [previousMessage.content] - : []; - const hasPreviousToolUse = previousContent.some(c => c.type === 'tool_use'); - - if (hasToolResults) { - if (lastContent.some(c => c.type !== 'tool_result')) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - 'The last message must contain only tool_result content if any is present' - ); - } - if (!hasPreviousToolUse) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - 'tool_result blocks are not matching any tool_use from the previous message' - ); - } - } - if (hasPreviousToolUse) { - const toolUseIds = new Set(previousContent.filter(c => c.type === 'tool_use').map(c => (c as ToolUseContent).id)); - const toolResultIds = new Set( - lastContent.filter(c => c.type === 'tool_result').map(c => (c as ToolResultContent).toolUseId) - ); - if (toolUseIds.size !== toolResultIds.size || ![...toolUseIds].every(id => toolResultIds.has(id))) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - 'ids of tool_result blocks and tool_use blocks from previous message do not match' - ); - } - } - } - - // Use different schemas based on whether tools are provided - if (params.tools) { - return this._requestWithSchema({ method: 'sampling/createMessage', params }, CreateMessageResultWithToolsSchema, options); - } - return this._requestWithSchema({ method: 'sampling/createMessage', params }, CreateMessageResultSchema, options); - } - - /** - * Creates an elicitation request for the given parameters. - * For backwards compatibility, `mode` may be omitted for form requests and will default to `"form"`. - * @param params The parameters for the elicitation request. - * @param options Optional request options. - * @returns The result of the elicitation request. - */ - async elicitInput(params: ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions): Promise { - const mode = (params.mode ?? 'form') as 'form' | 'url'; - - switch (mode) { - case 'url': { - if (!this._clientCapabilities?.elicitation?.url) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support url elicitation.'); - } - - const urlParams = params as ElicitRequestURLParams; - return this._requestWithSchema({ method: 'elicitation/create', params: urlParams }, ElicitResultSchema, options); - } - case 'form': { - if (!this._clientCapabilities?.elicitation?.form) { - throw new SdkError(SdkErrorCode.CapabilityNotSupported, 'Client does not support form elicitation.'); - } - - const formParams: ElicitRequestFormParams = - params.mode === 'form' ? (params as ElicitRequestFormParams) : { ...(params as ElicitRequestFormParams), mode: 'form' }; - - const result = await this._requestWithSchema( - { method: 'elicitation/create', params: formParams }, - ElicitResultSchema, - options - ); - - if (result.action === 'accept' && result.content && formParams.requestedSchema) { - try { - const validator = this._jsonSchemaValidator.getValidator(formParams.requestedSchema as JsonSchemaType); - const validationResult = validator(result.content); - - if (!validationResult.valid) { - throw new ProtocolError( - ProtocolErrorCode.InvalidParams, - `Elicitation response content does not match requested schema: ${validationResult.errorMessage}` - ); - } - } catch (error) { - if (error instanceof ProtocolError) { - throw error; - } - throw new ProtocolError( - ProtocolErrorCode.InternalError, - `Error validating elicitation response: ${error instanceof Error ? error.message : String(error)}` - ); - } - } - return result; - } - } - } - - /** - * Creates a reusable callback that, when invoked, will send a `notifications/elicitation/complete` - * notification for the specified elicitation ID. - * - * @param elicitationId The ID of the elicitation to mark as complete. - * @param options Optional notification options. Useful when the completion notification should be related to a prior request. - * @returns A function that emits the completion notification when awaited. - */ - createElicitationCompletionNotifier(elicitationId: string, options?: NotificationOptions): () => Promise { - if (!this._clientCapabilities?.elicitation?.url) { - throw new SdkError( - SdkErrorCode.CapabilityNotSupported, - 'Client does not support URL elicitation (required for notifications/elicitation/complete)' - ); - } - - return () => - this.notification( - { - method: 'notifications/elicitation/complete', - params: { - elicitationId - } - }, - options - ); - } - - async listRoots(params?: ListRootsRequest['params'], options?: RequestOptions) { - return this._requestWithSchema({ method: 'roots/list', params }, ListRootsResultSchema, options); - } - - /** - * Sends a logging message to the client, if connected. - * Note: You only need to send the parameters object, not the entire JSON-RPC message. - * @see {@linkcode LoggingMessageNotification} - * @param params - * @param sessionId Optional for stateless transports and backward compatibility. - */ - async sendLoggingMessage(params: LoggingMessageNotification['params'], sessionId?: string) { - if (this._capabilities.logging && !this.isMessageIgnored(params.level, sessionId)) { - return this.notification({ method: 'notifications/message', params }); - } - } - - async sendResourceUpdated(params: ResourceUpdatedNotification['params']) { - return this.notification({ - method: 'notifications/resources/updated', - params - }); - } - - async sendResourceListChanged() { - return this.notification({ - method: 'notifications/resources/list_changed' - }); - } - - async sendToolListChanged() { - return this.notification({ method: 'notifications/tools/list_changed' }); - } - - async sendPromptListChanged() { - return this.notification({ method: 'notifications/prompts/list_changed' }); - } -} +export { Server } from './compat.js'; +export type { ServerOptions, ServerTasksCapabilityWithRuntime } from './mcpServer.js'; diff --git a/packages/server/src/server/serverCapabilities.ts b/packages/server/src/server/serverCapabilities.ts new file mode 100644 index 000000000..a82b10848 --- /dev/null +++ b/packages/server/src/server/serverCapabilities.ts @@ -0,0 +1,127 @@ +import type { ClientCapabilities, NotificationMethod, RequestMethod, ServerCapabilities } from '@modelcontextprotocol/core'; +import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; + +/** + * Throws if the connected client does not advertise the capability required + * for the server to send the given outbound request. + */ +export function assertCapabilityForMethod(method: RequestMethod, clientCapabilities: ClientCapabilities | undefined): void { + switch (method) { + case 'sampling/createMessage': { + if (!clientCapabilities?.sampling) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Client does not support sampling (required for ${method})`); + } + break; + } + case 'elicitation/create': { + if (!clientCapabilities?.elicitation) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Client does not support elicitation (required for ${method})`); + } + break; + } + case 'roots/list': { + if (!clientCapabilities?.roots) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Client does not support listing roots (required for ${method})`); + } + break; + } + } +} + +/** + * Throws if either side lacks the capability required for the server to emit + * the given notification. + */ +export function assertNotificationCapability( + method: NotificationMethod, + serverCapabilities: ServerCapabilities, + clientCapabilities: ClientCapabilities | undefined +): void { + switch (method) { + case 'notifications/message': { + if (!serverCapabilities.logging) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support logging (required for ${method})`); + } + break; + } + case 'notifications/resources/updated': + case 'notifications/resources/list_changed': { + if (!serverCapabilities.resources) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + `Server does not support notifying about resources (required for ${method})` + ); + } + break; + } + case 'notifications/tools/list_changed': { + if (!serverCapabilities.tools) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + `Server does not support notifying of tool list changes (required for ${method})` + ); + } + break; + } + case 'notifications/prompts/list_changed': { + if (!serverCapabilities.prompts) { + throw new SdkError( + SdkErrorCode.CapabilityNotSupported, + `Server does not support notifying of prompt list changes (required for ${method})` + ); + } + break; + } + case 'notifications/elicitation/complete': { + if (!clientCapabilities?.elicitation?.url) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Client does not support URL elicitation (required for ${method})`); + } + break; + } + } +} + +/** + * Throws if the server does not advertise the capability required to register + * a handler for the given inbound request method. + */ +export function assertRequestHandlerCapability(method: string, serverCapabilities: ServerCapabilities): void { + switch (method) { + case 'completion/complete': { + if (!serverCapabilities.completions) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support completions (required for ${method})`); + } + break; + } + case 'logging/setLevel': { + if (!serverCapabilities.logging) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support logging (required for ${method})`); + } + break; + } + case 'prompts/get': + case 'prompts/list': { + if (!serverCapabilities.prompts) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support prompts (required for ${method})`); + } + break; + } + case 'resources/list': + case 'resources/templates/list': + case 'resources/read': + case 'resources/subscribe': + case 'resources/unsubscribe': { + if (!serverCapabilities.resources) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support resources (required for ${method})`); + } + break; + } + case 'tools/call': + case 'tools/list': { + if (!serverCapabilities.tools) { + throw new SdkError(SdkErrorCode.CapabilityNotSupported, `Server does not support tools (required for ${method})`); + } + break; + } + } +} diff --git a/packages/server/src/server/serverLegacy.ts b/packages/server/src/server/serverLegacy.ts new file mode 100644 index 000000000..9fc61c541 --- /dev/null +++ b/packages/server/src/server/serverLegacy.ts @@ -0,0 +1,221 @@ +import type { CallToolResult, GetPromptResult, ServerContext, StandardSchemaWithJSON, ToolAnnotations } from '@modelcontextprotocol/core'; +import { isStandardSchema, isStandardSchemaWithJSON } from '@modelcontextprotocol/core'; +import { z } from 'zod/v4'; + +/** + * v1 compat: a "raw shape" is a plain object whose values are Zod schemas + * (e.g. `{ name: z.string() }`), or an empty object. v1's `tool()`/`prompt()` + * and `registerTool({inputSchema:{}})` accepted these directly. + */ +export type ZodRawShapeCompat = Record; + +/** v1-style callback signature for the deprecated {@linkcode McpServer.tool | tool()} overloads. */ +export type LegacyToolCallback = ( + args: z.infer>, + ctx: ServerContext +) => CallToolResult | Promise; + +/** v1-style callback signature for the deprecated {@linkcode McpServer.prompt | prompt()} overloads. */ +export type LegacyPromptCallback = ( + args: z.infer>, + ctx: ServerContext +) => GetPromptResult | Promise; + +/** + * v1 compat: extract the literal method string from a `z.object({method: z.literal('x'), ...})` schema. + */ +export function extractMethodFromSchema(schema: { shape: { method: unknown } }): string { + const lit = schema.shape.method as + | { value?: unknown; def?: { values?: unknown[] }; _zod?: { def?: { values?: unknown[] } } } + | undefined; + const v = lit?.value ?? lit?.def?.values?.[0] ?? lit?._zod?.def?.values?.[0]; + if (typeof v !== 'string') { + throw new TypeError('setRequestHandler(schema, handler): schema.shape.method must be a z.literal(string)'); + } + return v; +} + +function isZodTypeLike(v: unknown): boolean { + if (v == null || typeof v !== 'object') return false; + return '_zod' in (v as object) || '_def' in (v as object); +} + +function isZodV4Type(v: unknown): boolean { + return v != null && typeof v === 'object' && '_zod' in (v as object); +} + +export function isZodRawShapeCompat(v: unknown): v is ZodRawShapeCompat { + if (v == null || typeof v !== 'object') return false; + if (isStandardSchema(v)) return false; + const values = Object.values(v as object); + if (values.length === 0) return true; + return values.some(v => isZodTypeLike(v)); +} + +type ZodV3Like = { + _def: { typeName?: string; innerType?: ZodV3Like; type?: ZodV3Like; shape?: () => Record; values?: unknown[] }; + description?: string; + isOptional?: () => boolean; + '~standard'?: { validate: (v: unknown) => unknown }; +}; + +/** Best-effort JSON Schema synthesis for a single zod v3 schema (covers common primitives). */ +function v3ToJsonSchema(s: ZodV3Like): Record { + const out: Record = {}; + if (s.description) out.description = s.description; + const tn = s._def?.typeName; + switch (tn) { + case 'ZodString': { + out.type = 'string'; + break; + } + case 'ZodNumber': { + out.type = 'number'; + break; + } + case 'ZodBoolean': { + out.type = 'boolean'; + break; + } + case 'ZodArray': { + out.type = 'array'; + if (s._def.type) out.items = v3ToJsonSchema(s._def.type); + break; + } + case 'ZodEnum': + case 'ZodNativeEnum': { + if (Array.isArray(s._def.values)) out.enum = s._def.values; + break; + } + case 'ZodObject': { + const shape = s._def.shape?.(); + out.type = 'object'; + if (shape) { + const entries = Object.entries(shape); + out.properties = Object.fromEntries(entries.map(([k, v]) => [k, v3ToJsonSchema(v)])); + out.required = entries.filter(([, v]) => !v.isOptional?.()).map(([k]) => k); + } + break; + } + case 'ZodOptional': + case 'ZodNullable': + case 'ZodDefault': { + return s._def.innerType ? { ...v3ToJsonSchema(s._def.innerType), ...out } : out; + } + default: { + break; + } + } + return out; +} + +/** Wrap a raw shape whose values are zod v3 (or any Standard Schema lacking jsonSchema) into a {@linkcode StandardSchemaWithJSON}. */ +function adaptRawShapeToStandard(shape: Record): StandardSchemaWithJSON { + const entries = Object.entries(shape); + const required = entries.filter(([, v]) => !v.isOptional?.()).map(([k]) => k); + const jsonSchema = { + type: 'object', + properties: Object.fromEntries(entries.map(([k, v]) => [k, v3ToJsonSchema(v)])), + required, + additionalProperties: false + }; + const emit = () => jsonSchema; + return { + '~standard': { + version: 1, + vendor: 'mcp-zod-v3-compat', + validate: input => { + if (typeof input !== 'object' || input === null) { + return { issues: [{ message: 'Expected object' }] }; + } + const value: Record = {}; + const issues: { message: string; path: PropertyKey[] }[] = []; + for (const [k, field] of entries) { + const std = field['~standard']; + const raw = (input as Record)[k]; + if (std) { + const r = std.validate(raw) as { value?: unknown; issues?: { message: string }[] }; + if (r.issues) for (const i of r.issues) issues.push({ message: i.message, path: [k] }); + else value[k] = r.value; + } else { + value[k] = raw; + } + } + return issues.length > 0 ? { issues } : { value }; + }, + jsonSchema: { input: emit, output: emit } + } + } as StandardSchemaWithJSON; +} + +/** Wrap a Standard Schema that lacks `jsonSchema` (e.g. zod v3's `z.object({...})`) by synthesizing one from `_def`. */ +function adaptStandardSchemaWithoutJson(schema: ZodV3Like): StandardSchemaWithJSON { + const json = v3ToJsonSchema(schema); + const emit = () => json; + const std = schema['~standard'] as { version: 1; vendor: string; validate: (v: unknown) => unknown }; + return { + '~standard': { ...std, jsonSchema: { input: emit, output: emit } } + } as unknown as StandardSchemaWithJSON; +} + +/** + * Coerce a v1-style raw Zod shape (or empty object) to a {@linkcode StandardSchemaWithJSON}. + * Standard Schemas pass through unchanged. + */ +export function coerceSchema(schema: unknown): StandardSchemaWithJSON | undefined { + if (schema == null) return undefined; + if (isStandardSchemaWithJSON(schema)) return schema; + if (isZodRawShapeCompat(schema)) { + const values = Object.values(schema as object); + if (values.every(v => isZodV4Type(v))) { + return z.object(schema as ZodRawShapeCompat) as unknown as StandardSchemaWithJSON; + } + return adaptRawShapeToStandard(schema as unknown as Record); + } + if (isStandardSchema(schema)) { + if ('_def' in (schema as object)) { + return adaptStandardSchemaWithoutJson(schema as unknown as ZodV3Like); + } + throw new Error('Schema lacks JSON-Schema emission (zod >=4.2 or equivalent required).'); + } + throw new Error('inputSchema/argsSchema must be a Standard Schema or a Zod raw shape (e.g. {name: z.string()})'); +} + +/** + * Parse the variadic argument list of the deprecated {@linkcode McpServer.tool | tool()} overloads. + */ +export function parseLegacyToolArgs( + name: string, + rest: unknown[] +): { description?: string; inputSchema?: StandardSchemaWithJSON; annotations?: ToolAnnotations; cb: unknown } { + let description: string | undefined; + let inputSchema: StandardSchemaWithJSON | undefined; + let annotations: ToolAnnotations | undefined; + if (typeof rest[0] === 'string') description = rest.shift() as string; + if (rest.length > 1) { + const first = rest[0]; + if (isZodRawShapeCompat(first) || isStandardSchema(first)) { + inputSchema = coerceSchema(rest.shift()); + if (rest.length > 1 && typeof rest[0] === 'object' && rest[0] !== null && !isZodRawShapeCompat(rest[0])) { + annotations = rest.shift() as ToolAnnotations; + } + } else if (typeof first === 'object' && first !== null) { + if (Object.values(first).some(v => typeof v === 'object' && v !== null)) { + throw new Error(`Tool ${name} expected a Zod schema or ToolAnnotations, but received an unrecognized object`); + } + annotations = rest.shift() as ToolAnnotations; + } + } + return { description, inputSchema, annotations, cb: rest[0] }; +} + +/** + * Parse the variadic argument list of the deprecated {@linkcode McpServer.prompt | prompt()} overloads. + */ +export function parseLegacyPromptArgs(rest: unknown[]): { description?: string; argsSchema?: StandardSchemaWithJSON; cb: unknown } { + let description: string | undefined; + if (typeof rest[0] === 'string') description = rest.shift() as string; + let argsSchema: StandardSchemaWithJSON | undefined; + if (rest.length > 1) argsSchema = coerceSchema(rest.shift()); + return { description, argsSchema, cb: rest[0] }; +} diff --git a/packages/server/src/server/serverRegistries.ts b/packages/server/src/server/serverRegistries.ts new file mode 100644 index 000000000..67bd358a8 --- /dev/null +++ b/packages/server/src/server/serverRegistries.ts @@ -0,0 +1,889 @@ +import type { + BaseMetadata, + CallToolRequest, + CallToolResult, + CompleteRequestPrompt, + CompleteRequestResourceTemplate, + CompleteResult, + CreateTaskResult, + CreateTaskServerContext, + GetPromptResult, + ListPromptsResult, + ListToolsResult, + Prompt, + PromptReference, + ReadResourceResult, + RequestMethod, + RequestTypeMap, + Resource, + ResourceTemplateReference, + Result, + ResultTypeMap, + ServerCapabilities, + ServerContext, + StandardSchemaWithJSON, + Tool, + ToolAnnotations, + ToolExecution, + Variables +} from '@modelcontextprotocol/core'; +import { + assertCompleteRequestPrompt, + assertCompleteRequestResourceTemplate, + promptArgumentsFromStandardSchema, + ProtocolError, + ProtocolErrorCode, + standardSchemaToJsonSchema, + validateAndWarnToolName, + validateStandardSchema +} from '@modelcontextprotocol/core'; + +import type { ToolTaskHandler } from '../experimental/tasks/interfaces.js'; +import { getCompleter, isCompletable } from './completable.js'; +import type { ResourceTemplate } from './resourceTemplate.js'; +import type { ZodRawShapeCompat } from './serverLegacy.js'; +import { coerceSchema } from './serverLegacy.js'; + +/** + * Minimal surface a {@linkcode ServerRegistries} instance needs from its owning server. + * {@linkcode McpServer} satisfies this directly. + */ +export interface RegistriesHost { + setRequestHandler( + method: M, + handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise + ): void; + assertCanSetRequestHandler(method: string): void; + registerCapabilities(capabilities: ServerCapabilities): void; + getCapabilities(): ServerCapabilities; + sendToolListChanged(): Promise; + sendResourceListChanged(): Promise; + sendPromptListChanged(): Promise; + /** + * Lazy installers, called on first registerTool/Resource/Prompt. Defined on the host so + * subclasses can override the install (v1 compat for code that monkey-patches `setToolRequestHandlers`). + * Default impl on McpServer delegates back to {@link ServerRegistries}. + */ + setToolRequestHandlers(): void; + setResourceRequestHandlers(): void; + setPromptRequestHandlers(): void; + setCompletionRequestHandler(): void; +} + +/** + * In-memory tool/resource/prompt registries plus the lazy `tools/*`, `resources/*`, + * `prompts/*`, and `completion/*` request-handler installers. + * + * Composed by {@linkcode McpServer}. One instance per server. + */ +export class ServerRegistries { + readonly registeredResources: { [uri: string]: RegisteredResource } = {}; + readonly registeredResourceTemplates: { [name: string]: RegisteredResourceTemplate } = {}; + readonly registeredTools: { [name: string]: RegisteredTool } = {}; + readonly registeredPrompts: { [name: string]: RegisteredPrompt } = {}; + + private _toolHandlersInitialized = false; + private _completionHandlerInitialized = false; + private _resourceHandlersInitialized = false; + private _promptHandlersInitialized = false; + + constructor(private readonly host: RegistriesHost) {} + + // ─────────────────────────────────────────────────────────────────────── + // Tools + // ─────────────────────────────────────────────────────────────────────── + + /** @internal v1-compat: kept callable so subclassers can invoke the default after overriding the host hook. */ + setToolRequestHandlers(): void { + if (this._toolHandlersInitialized) return; + const h = this.host; + h.assertCanSetRequestHandler('tools/list'); + h.assertCanSetRequestHandler('tools/call'); + h.registerCapabilities({ tools: { listChanged: h.getCapabilities().tools?.listChanged ?? true } }); + + h.setRequestHandler( + 'tools/list', + (): ListToolsResult => ({ + tools: Object.entries(this.registeredTools) + .filter(([, tool]) => tool.enabled) + .map(([name, tool]): Tool => { + const def: Tool = { + name, + title: tool.title, + description: tool.description, + inputSchema: tool.inputSchema + ? (standardSchemaToJsonSchema(tool.inputSchema, 'input') as Tool['inputSchema']) + : EMPTY_OBJECT_JSON_SCHEMA, + annotations: tool.annotations, + execution: tool.execution, + _meta: tool._meta + }; + if (tool.outputSchema) { + def.outputSchema = standardSchemaToJsonSchema(tool.outputSchema, 'output') as Tool['outputSchema']; + } + return def; + }) + }) + ); + + h.setRequestHandler('tools/call', async (request, ctx): Promise => { + const tool = this.registeredTools[request.params.name]; + if (!tool) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Tool ${request.params.name} not found`); + } + if (!tool.enabled) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Tool ${request.params.name} disabled`); + } + try { + const isTaskRequest = !!request.params.task; + const taskSupport = tool.execution?.taskSupport; + const isTaskHandler = 'createTask' in (tool.handler as AnyToolHandler); + if ((taskSupport === 'required' || taskSupport === 'optional') && !isTaskHandler) { + throw new ProtocolError( + ProtocolErrorCode.InternalError, + `Tool ${request.params.name} has taskSupport '${taskSupport}' but was not registered with registerToolTask` + ); + } + if (taskSupport === 'required' && !isTaskRequest) { + throw new ProtocolError( + ProtocolErrorCode.MethodNotFound, + `Tool ${request.params.name} requires task augmentation (taskSupport: 'required')` + ); + } + if (taskSupport === 'optional' && !isTaskRequest && isTaskHandler) { + return await this.handleAutomaticTaskPolling(tool, request, ctx); + } + const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); + const result = await tool.executor(args, ctx); + if (isTaskRequest) return result; + await this.validateToolOutput(tool, result, request.params.name); + return result; + } catch (error) { + if (error instanceof ProtocolError && error.code === ProtocolErrorCode.UrlElicitationRequired) { + throw error; + } + return createToolError(error instanceof Error ? error.message : String(error)); + } + }); + + this._toolHandlersInitialized = true; + } + + /** @internal v1-compat */ + async validateToolInput< + ToolType extends RegisteredTool, + Args extends ToolType['inputSchema'] extends infer InputSchema + ? InputSchema extends StandardSchemaWithJSON + ? StandardSchemaWithJSON.InferOutput + : undefined + : undefined + >(tool: ToolType, args: Args, toolName: string): Promise { + if (!tool.inputSchema) return undefined as Args; + const parsed = await validateStandardSchema(tool.inputSchema, args ?? {}); + if (!parsed.success) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Input validation error: Invalid arguments for tool ${toolName}: ${parsed.error}` + ); + } + return parsed.data as unknown as Args; + } + + /** @internal v1-compat */ + async validateToolOutput(tool: RegisteredTool, result: CallToolResult | CreateTaskResult, toolName: string): Promise { + if (!tool.outputSchema) return; + if (!('content' in result)) return; + if (result.isError) return; + if (!result.structuredContent) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Output validation error: Tool ${toolName} has an output schema but no structured content was provided` + ); + } + const parsed = await validateStandardSchema(tool.outputSchema, result.structuredContent); + if (!parsed.success) { + throw new ProtocolError( + ProtocolErrorCode.InvalidParams, + `Output validation error: Invalid structured content for tool ${toolName}: ${parsed.error}` + ); + } + } + + /** @internal v1-compat */ + async handleAutomaticTaskPolling( + tool: RegisteredTool, + request: RequestT, + ctx: ServerContext + ): Promise { + if (!ctx.task?.store) { + throw new Error('No task store provided for task-capable tool.'); + } + const args = await this.validateToolInput(tool, request.params.arguments, request.params.name); + const createTaskResult = (await tool.executor(args, ctx)) as CreateTaskResult; + const taskId = createTaskResult.task.taskId; + let task = createTaskResult.task; + const pollInterval = task.pollInterval ?? 5000; + while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') { + await new Promise(resolve => setTimeout(resolve, pollInterval)); + const updated = await ctx.task.store.getTask(taskId); + if (!updated) { + throw new ProtocolError(ProtocolErrorCode.InternalError, `Task ${taskId} not found during polling`); + } + task = updated; + } + return (await ctx.task.store.getTaskResult(taskId)) as CallToolResult; + } + + // ─────────────────────────────────────────────────────────────────────── + // Completion + // ─────────────────────────────────────────────────────────────────────── + + /** @internal v1-compat */ + setCompletionRequestHandler(): void { + if (this._completionHandlerInitialized) return; + const h = this.host; + h.assertCanSetRequestHandler('completion/complete'); + h.registerCapabilities({ completions: {} }); + h.setRequestHandler('completion/complete', async (request): Promise => { + switch (request.params.ref.type) { + case 'ref/prompt': { + assertCompleteRequestPrompt(request); + return this.handlePromptCompletion(request, request.params.ref); + } + case 'ref/resource': { + assertCompleteRequestResourceTemplate(request); + return this.handleResourceCompletion(request, request.params.ref); + } + default: { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid completion reference: ${request.params.ref}`); + } + } + }); + this._completionHandlerInitialized = true; + } + + private async handlePromptCompletion(request: CompleteRequestPrompt, ref: PromptReference): Promise { + const prompt = this.registeredPrompts[ref.name]; + if (!prompt) throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${ref.name} not found`); + if (!prompt.enabled) throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${ref.name} disabled`); + if (!prompt.argsSchema) return EMPTY_COMPLETION_RESULT; + const promptShape = getSchemaShape(prompt.argsSchema); + const field = unwrapOptionalSchema(promptShape?.[request.params.argument.name]); + if (!isCompletable(field)) return EMPTY_COMPLETION_RESULT; + const completer = getCompleter(field); + if (!completer) return EMPTY_COMPLETION_RESULT; + const suggestions = await completer(request.params.argument.value, request.params.context); + return createCompletionResult(suggestions); + } + + private async handleResourceCompletion( + request: CompleteRequestResourceTemplate, + ref: ResourceTemplateReference + ): Promise { + const template = Object.values(this.registeredResourceTemplates).find(t => t.resourceTemplate.uriTemplate.toString() === ref.uri); + if (!template) { + if (this.registeredResources[ref.uri]) return EMPTY_COMPLETION_RESULT; + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Resource template ${request.params.ref.uri} not found`); + } + const completer = template.resourceTemplate.completeCallback(request.params.argument.name); + if (!completer) return EMPTY_COMPLETION_RESULT; + const suggestions = await completer(request.params.argument.value, request.params.context); + return createCompletionResult(suggestions); + } + + // ─────────────────────────────────────────────────────────────────────── + // Resources + // ─────────────────────────────────────────────────────────────────────── + + /** @internal v1-compat */ + setResourceRequestHandlers(): void { + if (this._resourceHandlersInitialized) return; + const h = this.host; + h.assertCanSetRequestHandler('resources/list'); + h.assertCanSetRequestHandler('resources/templates/list'); + h.assertCanSetRequestHandler('resources/read'); + h.registerCapabilities({ resources: { listChanged: h.getCapabilities().resources?.listChanged ?? true } }); + + h.setRequestHandler('resources/list', async (_request, ctx) => { + const resources = Object.entries(this.registeredResources) + .filter(([_, r]) => r.enabled) + .map(([uri, r]) => ({ uri, name: r.name, ...r.metadata })); + const templateResources: Resource[] = []; + for (const template of Object.values(this.registeredResourceTemplates)) { + if (!template.resourceTemplate.listCallback) continue; + const result = await template.resourceTemplate.listCallback(ctx); + for (const resource of result.resources) { + templateResources.push({ ...template.metadata, ...resource }); + } + } + return { resources: [...resources, ...templateResources] }; + }); + + h.setRequestHandler('resources/templates/list', async () => { + const resourceTemplates = Object.entries(this.registeredResourceTemplates).map(([name, t]) => ({ + name, + uriTemplate: t.resourceTemplate.uriTemplate.toString(), + ...t.metadata + })); + return { resourceTemplates }; + }); + + h.setRequestHandler('resources/read', async (request, ctx) => { + const uri = new URL(request.params.uri); + const resource = this.registeredResources[uri.toString()]; + if (resource) { + if (!resource.enabled) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Resource ${uri} disabled`); + } + return resource.readCallback(uri, ctx); + } + for (const template of Object.values(this.registeredResourceTemplates)) { + const variables = template.resourceTemplate.uriTemplate.match(uri.toString()); + if (variables) return template.readCallback(uri, variables, ctx); + } + throw new ProtocolError(ProtocolErrorCode.ResourceNotFound, `Resource ${uri} not found`); + }); + + this._resourceHandlersInitialized = true; + } + + // ─────────────────────────────────────────────────────────────────────── + // Prompts + // ─────────────────────────────────────────────────────────────────────── + + /** @internal v1-compat */ + setPromptRequestHandlers(): void { + if (this._promptHandlersInitialized) return; + const h = this.host; + h.assertCanSetRequestHandler('prompts/list'); + h.assertCanSetRequestHandler('prompts/get'); + h.registerCapabilities({ prompts: { listChanged: h.getCapabilities().prompts?.listChanged ?? true } }); + + h.setRequestHandler( + 'prompts/list', + (): ListPromptsResult => ({ + prompts: Object.entries(this.registeredPrompts) + .filter(([, p]) => p.enabled) + .map( + ([name, p]): Prompt => ({ + name, + title: p.title, + description: p.description, + arguments: p.argsSchema ? promptArgumentsFromStandardSchema(p.argsSchema) : undefined, + _meta: p._meta + }) + ) + }) + ); + + h.setRequestHandler('prompts/get', async (request, ctx): Promise => { + const prompt = this.registeredPrompts[request.params.name]; + if (!prompt) throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${request.params.name} not found`); + if (!prompt.enabled) throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Prompt ${request.params.name} disabled`); + return prompt.handler(request.params.arguments, ctx); + }); + + this._promptHandlersInitialized = true; + } + + // ─────────────────────────────────────────────────────────────────────── + // Public registration entry points + // ─────────────────────────────────────────────────────────────────────── + + /** + * Registers a resource with a config object and callback. + */ + registerResource(name: string, uriOrTemplate: string, config: ResourceMetadata, readCallback: ReadResourceCallback): RegisteredResource; + registerResource( + name: string, + uriOrTemplate: ResourceTemplate, + config: ResourceMetadata, + readCallback: ReadResourceTemplateCallback + ): RegisteredResourceTemplate; + registerResource( + name: string, + uriOrTemplate: string | ResourceTemplate, + config: ResourceMetadata, + readCallback: ReadResourceCallback | ReadResourceTemplateCallback + ): RegisteredResource | RegisteredResourceTemplate { + if (typeof uriOrTemplate === 'string') { + if (this.registeredResources[uriOrTemplate]) throw new Error(`Resource ${uriOrTemplate} is already registered`); + const r = this.createRegisteredResource( + name, + (config as BaseMetadata).title, + uriOrTemplate, + config, + readCallback as ReadResourceCallback + ); + this.host.setResourceRequestHandlers(); + this.host.sendResourceListChanged(); + return r; + } else { + if (this.registeredResourceTemplates[name]) throw new Error(`Resource template ${name} is already registered`); + const r = this.createRegisteredResourceTemplate( + name, + (config as BaseMetadata).title, + uriOrTemplate, + config, + readCallback as ReadResourceTemplateCallback + ); + this.host.setResourceRequestHandlers(); + this.host.sendResourceListChanged(); + return r; + } + } + + /** + * Registers a tool with a config object and callback. + */ + registerTool( + name: string, + config: { + title?: string; + description?: string; + inputSchema?: InputArgs | ZodRawShapeCompat; + outputSchema?: OutputArgs | ZodRawShapeCompat; + annotations?: ToolAnnotations; + _meta?: Record; + }, + cb: ToolCallback + ): RegisteredTool { + if (this.registeredTools[name]) throw new Error(`Tool ${name} is already registered`); + const { title, description, inputSchema, outputSchema, annotations, _meta } = config; + return this.createRegisteredTool( + name, + title, + description, + coerceSchema(inputSchema), + coerceSchema(outputSchema), + annotations, + { taskSupport: 'forbidden' }, + _meta, + cb as ToolCallback + ); + } + + /** + * Registers a prompt with a config object and callback. + */ + registerPrompt( + name: string, + config: { title?: string; description?: string; argsSchema?: Args | ZodRawShapeCompat; _meta?: Record }, + cb: PromptCallback + ): RegisteredPrompt { + if (this.registeredPrompts[name]) throw new Error(`Prompt ${name} is already registered`); + const { title, description, argsSchema, _meta } = config; + const r = this.createRegisteredPrompt( + name, + title, + description, + coerceSchema(argsSchema), + cb as PromptCallback, + _meta + ); + this.host.setPromptRequestHandlers(); + this.host.sendPromptListChanged(); + return r; + } + + // ─────────────────────────────────────────────────────────────────────── + // Registered* factories. Exposed so legacy `.tool()`/`.prompt()`/`.resource()` + // and `experimental.tasks.registerToolTask` can build entries directly. + // ─────────────────────────────────────────────────────────────────────── + + createRegisteredResource( + name: string, + title: string | undefined, + uri: string, + metadata: ResourceMetadata | undefined, + readCallback: ReadResourceCallback + ): RegisteredResource { + const r: RegisteredResource = { + name, + title, + metadata, + readCallback, + enabled: true, + disable: () => r.update({ enabled: false }), + enable: () => r.update({ enabled: true }), + remove: () => r.update({ uri: null }), + update: updates => { + if (updates.uri !== undefined && updates.uri !== uri) { + delete this.registeredResources[uri]; + if (updates.uri) this.registeredResources[updates.uri] = r; + } + if (updates.name !== undefined) r.name = updates.name; + if (updates.title !== undefined) r.title = updates.title; + if (updates.metadata !== undefined) r.metadata = updates.metadata; + if (updates.callback !== undefined) r.readCallback = updates.callback; + if (updates.enabled !== undefined) r.enabled = updates.enabled; + this.host.sendResourceListChanged(); + } + }; + this.registeredResources[uri] = r; + return r; + } + + createRegisteredResourceTemplate( + name: string, + title: string | undefined, + template: ResourceTemplate, + metadata: ResourceMetadata | undefined, + readCallback: ReadResourceTemplateCallback + ): RegisteredResourceTemplate { + const r: RegisteredResourceTemplate = { + resourceTemplate: template, + title, + metadata, + readCallback, + enabled: true, + disable: () => r.update({ enabled: false }), + enable: () => r.update({ enabled: true }), + remove: () => r.update({ name: null }), + update: updates => { + if (updates.name !== undefined && updates.name !== name) { + delete this.registeredResourceTemplates[name]; + if (updates.name) this.registeredResourceTemplates[updates.name] = r; + } + if (updates.title !== undefined) r.title = updates.title; + if (updates.template !== undefined) r.resourceTemplate = updates.template; + if (updates.metadata !== undefined) r.metadata = updates.metadata; + if (updates.callback !== undefined) r.readCallback = updates.callback; + if (updates.enabled !== undefined) r.enabled = updates.enabled; + this.host.sendResourceListChanged(); + } + }; + this.registeredResourceTemplates[name] = r; + const variableNames = template.uriTemplate.variableNames; + const hasCompleter = Array.isArray(variableNames) && variableNames.some(v => !!template.completeCallback(v)); + if (hasCompleter) this.host.setCompletionRequestHandler(); + return r; + } + + createRegisteredPrompt( + name: string, + title: string | undefined, + description: string | undefined, + argsSchema: StandardSchemaWithJSON | undefined, + callback: PromptCallback, + _meta: Record | undefined + ): RegisteredPrompt { + let currentArgsSchema = argsSchema; + let currentCallback = callback; + const r: RegisteredPrompt = { + title, + description, + argsSchema, + _meta, + handler: createPromptHandler(name, argsSchema, callback), + enabled: true, + disable: () => r.update({ enabled: false }), + enable: () => r.update({ enabled: true }), + remove: () => r.update({ name: null }), + update: updates => { + if (updates.name !== undefined && updates.name !== name) { + delete this.registeredPrompts[name]; + if (updates.name) this.registeredPrompts[updates.name] = r; + } + if (updates.title !== undefined) r.title = updates.title; + if (updates.description !== undefined) r.description = updates.description; + if (updates._meta !== undefined) r._meta = updates._meta; + let needsRegen = false; + if (updates.argsSchema !== undefined) { + r.argsSchema = updates.argsSchema; + currentArgsSchema = updates.argsSchema; + needsRegen = true; + } + if (updates.callback !== undefined) { + currentCallback = updates.callback as PromptCallback; + needsRegen = true; + } + if (needsRegen) r.handler = createPromptHandler(name, currentArgsSchema, currentCallback); + if (updates.enabled !== undefined) r.enabled = updates.enabled; + this.host.sendPromptListChanged(); + } + }; + this.registeredPrompts[name] = r; + if (argsSchema) { + const shape = getSchemaShape(argsSchema); + if (shape) { + const hasCompletable = Object.values(shape).some(f => isCompletable(unwrapOptionalSchema(f))); + if (hasCompletable) this.host.setCompletionRequestHandler(); + } + } + return r; + } + + createRegisteredTool( + name: string, + title: string | undefined, + description: string | undefined, + inputSchema: StandardSchemaWithJSON | undefined, + outputSchema: StandardSchemaWithJSON | undefined, + annotations: ToolAnnotations | undefined, + execution: ToolExecution | undefined, + _meta: Record | undefined, + handler: AnyToolHandler + ): RegisteredTool { + validateAndWarnToolName(name); + let currentHandler = handler; + const r: RegisteredTool = { + title, + description, + inputSchema, + outputSchema, + annotations, + execution, + _meta, + handler, + executor: createToolExecutor(inputSchema, handler), + enabled: true, + disable: () => r.update({ enabled: false }), + enable: () => r.update({ enabled: true }), + remove: () => r.update({ name: null }), + update: updates => { + if (updates.name !== undefined && updates.name !== name) { + if (typeof updates.name === 'string') validateAndWarnToolName(updates.name); + delete this.registeredTools[name]; + if (updates.name) this.registeredTools[updates.name] = r; + } + if (updates.title !== undefined) r.title = updates.title; + if (updates.description !== undefined) r.description = updates.description; + let needsRegen = false; + if (updates.paramsSchema !== undefined) { + r.inputSchema = updates.paramsSchema; + needsRegen = true; + } + if (updates.callback !== undefined) { + r.handler = updates.callback; + currentHandler = updates.callback as AnyToolHandler; + needsRegen = true; + } + if (needsRegen) r.executor = createToolExecutor(r.inputSchema, currentHandler); + if (updates.outputSchema !== undefined) r.outputSchema = updates.outputSchema; + if (updates.annotations !== undefined) r.annotations = updates.annotations; + if (updates._meta !== undefined) r._meta = updates._meta; + if (updates.enabled !== undefined) r.enabled = updates.enabled; + this.host.sendToolListChanged(); + } + }; + this.registeredTools[name] = r; + this.host.setToolRequestHandlers(); + this.host.sendToolListChanged(); + return r; + } + + /** Expose lazy installers for callers (legacy `.prompt()/.resource()`) that build entries via `create*` directly. */ + installResourceHandlers(): void { + this.host.setResourceRequestHandlers(); + } + installPromptHandlers(): void { + this.host.setPromptRequestHandlers(); + } +} + +// ─────────────────────────────────────────────────────────────────────────── +// Public types +// ─────────────────────────────────────────────────────────────────────────── + +export type BaseToolCallback< + SendResultT extends Result, + Ctx extends ServerContext, + Args extends StandardSchemaWithJSON | undefined +> = Args extends StandardSchemaWithJSON + ? (args: StandardSchemaWithJSON.InferOutput, ctx: Ctx) => SendResultT | Promise + : (ctx: Ctx) => SendResultT | Promise; + +export type ToolCallback = BaseToolCallback< + CallToolResult, + ServerContext, + Args +>; + +export type AnyToolHandler = ToolCallback | ToolTaskHandler; + +type ToolExecutor = (args: unknown, ctx: ServerContext) => Promise; + +export type RegisteredTool = { + title?: string; + description?: string; + inputSchema?: StandardSchemaWithJSON; + outputSchema?: StandardSchemaWithJSON; + annotations?: ToolAnnotations; + execution?: ToolExecution; + _meta?: Record; + handler: AnyToolHandler; + /** @hidden */ + executor: ToolExecutor; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + description?: string; + paramsSchema?: StandardSchemaWithJSON; + outputSchema?: StandardSchemaWithJSON; + annotations?: ToolAnnotations; + _meta?: Record; + callback?: ToolCallback; + enabled?: boolean; + }): void; + remove(): void; +}; + +export type ResourceMetadata = Omit; +export type ReadResourceCallback = (uri: URL, ctx: ServerContext) => ReadResourceResult | Promise; + +export type RegisteredResource = { + name: string; + title?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string; + title?: string; + uri?: string | null; + metadata?: ResourceMetadata; + callback?: ReadResourceCallback; + enabled?: boolean; + }): void; + remove(): void; +}; + +export type ReadResourceTemplateCallback = ( + uri: URL, + variables: Variables, + ctx: ServerContext +) => ReadResourceResult | Promise; + +export type RegisteredResourceTemplate = { + resourceTemplate: ResourceTemplate; + title?: string; + metadata?: ResourceMetadata; + readCallback: ReadResourceTemplateCallback; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + template?: ResourceTemplate; + metadata?: ResourceMetadata; + callback?: ReadResourceTemplateCallback; + enabled?: boolean; + }): void; + remove(): void; +}; + +export type PromptCallback = Args extends StandardSchemaWithJSON + ? (args: StandardSchemaWithJSON.InferOutput, ctx: ServerContext) => GetPromptResult | Promise + : (ctx: ServerContext) => GetPromptResult | Promise; + +type PromptHandler = (args: Record | undefined, ctx: ServerContext) => Promise; +type ToolCallbackInternal = (args: unknown, ctx: ServerContext) => CallToolResult | Promise; +type TaskHandlerInternal = { + createTask: (args: unknown, ctx: CreateTaskServerContext) => CreateTaskResult | Promise; +}; + +export type RegisteredPrompt = { + title?: string; + description?: string; + argsSchema?: StandardSchemaWithJSON; + _meta?: Record; + /** @hidden */ + handler: PromptHandler; + enabled: boolean; + enable(): void; + disable(): void; + update(updates: { + name?: string | null; + title?: string; + description?: string; + argsSchema?: Args; + _meta?: Record; + callback?: PromptCallback; + enabled?: boolean; + }): void; + remove(): void; +}; + +// Re-export for path compat. + +// ─────────────────────────────────────────────────────────────────────────── +// Helpers +// ─────────────────────────────────────────────────────────────────────────── + +const EMPTY_OBJECT_JSON_SCHEMA = { type: 'object' as const, properties: {} }; +const EMPTY_COMPLETION_RESULT: CompleteResult = { completion: { values: [], hasMore: false } }; + +function createToolError(errorMessage: string): CallToolResult { + return { content: [{ type: 'text', text: errorMessage }], isError: true }; +} + +function createCompletionResult(suggestions: readonly unknown[]): CompleteResult { + const values = suggestions.map(String).slice(0, 100); + return { completion: { values, total: suggestions.length, hasMore: suggestions.length > 100 } }; +} + +function createToolExecutor( + inputSchema: StandardSchemaWithJSON | undefined, + handler: AnyToolHandler +): ToolExecutor { + const isTaskHandler = 'createTask' in handler; + if (isTaskHandler) { + const th = handler as TaskHandlerInternal; + return async (args, ctx) => { + if (!ctx.task?.store) throw new Error('No task store provided.'); + const taskCtx: CreateTaskServerContext = { ...ctx, task: { store: ctx.task.store, requestedTtl: ctx.task?.requestedTtl } }; + if (inputSchema) return th.createTask(args, taskCtx); + return (th.createTask as (ctx: CreateTaskServerContext) => CreateTaskResult | Promise)(taskCtx); + }; + } + if (inputSchema) { + const cb = handler as ToolCallbackInternal; + return async (args, ctx) => cb(args, ctx); + } + const cb = handler as (ctx: ServerContext) => CallToolResult | Promise; + return async (_args, ctx) => cb(ctx); +} + +function createPromptHandler( + name: string, + argsSchema: StandardSchemaWithJSON | undefined, + callback: PromptCallback +): PromptHandler { + if (argsSchema) { + const typed = callback as (args: unknown, ctx: ServerContext) => GetPromptResult | Promise; + return async (args, ctx) => { + const parsed = await validateStandardSchema(argsSchema, args); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid arguments for prompt ${name}: ${parsed.error}`); + } + return typed(parsed.data, ctx); + }; + } + const typed = callback as (ctx: ServerContext) => GetPromptResult | Promise; + return async (_args, ctx) => typed(ctx); +} + +function getSchemaShape(schema: unknown): Record | undefined { + const c = schema as { shape?: unknown }; + if (c.shape && typeof c.shape === 'object') return c.shape as Record; + return undefined; +} + +function isOptionalSchema(schema: unknown): boolean { + return (schema as { type?: string } | null | undefined)?.type === 'optional'; +} + +function unwrapOptionalSchema(schema: unknown): unknown { + if (!isOptionalSchema(schema)) return schema; + const c = schema as { def?: { innerType?: unknown } }; + return c.def?.innerType ?? schema; +} + +export { type ListResourcesCallback } from './resourceTemplate.js'; diff --git a/packages/server/src/server/sessionCompat.ts b/packages/server/src/server/sessionCompat.ts new file mode 100644 index 000000000..5aee5b8e2 --- /dev/null +++ b/packages/server/src/server/sessionCompat.ts @@ -0,0 +1,245 @@ +import type { JSONRPCMessage } from '@modelcontextprotocol/core'; +import { isInitializeRequest } from '@modelcontextprotocol/core'; + +/** + * Options for {@linkcode SessionCompat}. + */ +export interface SessionCompatOptions { + /** + * Function that generates a session ID. SHOULD be globally unique and cryptographically secure + * (e.g., a securely generated UUID). + * + * @default `() => crypto.randomUUID()` + */ + sessionIdGenerator?: () => string; + + /** + * Maximum number of concurrent sessions to retain. New `initialize` requests beyond this cap + * are rejected with HTTP 503 + `Retry-After`. Idle sessions are evicted LRU when at the cap. + * + * @default 10000 + */ + maxSessions?: number; + + /** + * Sessions idle (no request received) for longer than this are evicted on the next sweep. + * + * @default 30 * 60_000 (30 minutes) + */ + idleTtlMs?: number; + + /** + * Suggested `Retry-After` value (seconds) returned with 503 when at {@linkcode maxSessions}. + * + * @default 30 + */ + retryAfterSeconds?: number; + + /** Called when a new session is minted. */ + onsessioninitialized?: (sessionId: string) => void | Promise; + + /** Called when a session is deleted (via DELETE) or evicted. */ + onsessionclosed?: (sessionId: string) => void | Promise; + + /** + * When `true`, this instance allows at most one session: a second `initialize` + * is rejected with "Server already initialized". Matches the per-transport-instance + * v1 behaviour where each `WebStandardStreamableHTTPServerTransport` holds one session. + * + * @default false + */ + singleSession?: boolean; + + /** Called for validation failures (re-init, missing/unknown session header). */ + onerror?: (error: Error) => void; +} + +interface SessionEntry { + createdAt: number; + lastSeen: number; + /** Standalone GET subscription stream controller, if one is open. */ + sseController?: ReadableStreamDefaultController; + /** Protocol version requested by the client in `initialize.params.protocolVersion`. */ + protocolVersion?: string; +} + +/** Result of {@linkcode SessionCompat.validate}. */ +export type SessionValidation = { ok: true; sessionId: string | undefined; isInitialize: boolean } | { ok: false; response: Response }; + +function jsonError(status: number, code: number, message: string, headers?: Record): Response { + return Response.json( + { jsonrpc: '2.0', error: { code, message }, id: null }, + { status, headers: { 'Content-Type': 'application/json', ...headers } } + ); +} + +/** + * Bounded, in-memory `mcp-session-id` lifecycle for the pre-2026-06 stateful Streamable HTTP + * protocol. One instance is shared across all requests to a given {@linkcode shttpHandler}. + * + * Sessions are minted when an `initialize` request arrives and validated on every subsequent + * request via the `mcp-session-id` header. Storage is LRU with {@linkcode SessionCompatOptions.maxSessions} + * cap and {@linkcode SessionCompatOptions.idleTtlMs} idle eviction. + */ +export class SessionCompat { + private readonly _sessions = new Map(); + private readonly _generate: () => string; + private readonly _maxSessions: number; + private readonly _idleTtlMs: number; + private readonly _retryAfterSeconds: number; + private readonly _onsessioninitialized?: (sessionId: string) => void | Promise; + private readonly _onsessionclosed?: (sessionId: string) => void | Promise; + private readonly _singleSession: boolean; + private readonly _onerror?: (error: Error) => void; + + constructor(options: SessionCompatOptions = {}) { + this._generate = options.sessionIdGenerator ?? (() => crypto.randomUUID()); + this._maxSessions = options.maxSessions ?? 10_000; + this._idleTtlMs = options.idleTtlMs ?? 30 * 60_000; + this._retryAfterSeconds = options.retryAfterSeconds ?? 30; + this._onsessioninitialized = options.onsessioninitialized; + this._onsessionclosed = options.onsessionclosed; + this._singleSession = options.singleSession ?? false; + this._onerror = options.onerror; + } + + /** + * Validates the `mcp-session-id` header for a parsed POST body. If the body contains an + * `initialize` request, mints a new session instead. Ported from + * `WebStandardStreamableHTTPServerTransport.validateSession` + the initialize-detection + * block of `handlePostRequest`. + */ + async validate(req: Request, messages: JSONRPCMessage[]): Promise { + const isInit = messages.some(m => isInitializeRequest(m)); + + if (isInit) { + if (messages.length > 1) { + this._onerror?.(new Error('Invalid Request: Only one initialization request is allowed')); + return { + ok: false, + response: jsonError(400, -32_600, 'Invalid Request: Only one initialization request is allowed') + }; + } + if (this._singleSession && this._sessions.size > 0) { + this._onerror?.(new Error('Invalid Request: Server already initialized')); + return { + ok: false, + response: jsonError(400, -32_600, 'Invalid Request: Server already initialized') + }; + } + this._evictIdle(); + if (this._sessions.size >= this._maxSessions) { + this._evictOldest(); + } + if (this._sessions.size >= this._maxSessions) { + return { + ok: false, + response: jsonError(503, -32_000, 'Server at session capacity', { + 'Retry-After': String(this._retryAfterSeconds) + }) + }; + } + const id = this._generate(); + const now = Date.now(); + const initMsg = messages.find(m => isInitializeRequest(m)); + const protocolVersion = initMsg && isInitializeRequest(initMsg) ? initMsg.params.protocolVersion : undefined; + this._sessions.set(id, { createdAt: now, lastSeen: now, protocolVersion }); + await Promise.resolve(this._onsessioninitialized?.(id)); + return { ok: true, sessionId: id, isInitialize: true }; + } + + return this.validateHeader(req); + } + + /** + * Validates the `mcp-session-id` header without inspecting a body (for GET/DELETE). + */ + validateHeader(req: Request): SessionValidation { + if (this._singleSession && this._sessions.size === 0) { + this._onerror?.(new Error('Bad Request: Server not initialized')); + return { ok: false, response: jsonError(400, -32_000, 'Bad Request: Server not initialized') }; + } + const headerId = req.headers.get('mcp-session-id'); + if (!headerId) { + this._onerror?.(new Error('Bad Request: Mcp-Session-Id header is required')); + return { + ok: false, + response: jsonError(400, -32_000, 'Bad Request: Mcp-Session-Id header is required') + }; + } + const entry = this._sessions.get(headerId); + if (!entry) { + this._onerror?.(new Error('Session not found')); + return { ok: false, response: jsonError(404, -32_001, 'Session not found') }; + } + entry.lastSeen = Date.now(); + // Re-insert to maintain Map iteration order as LRU. + this._sessions.delete(headerId); + this._sessions.set(headerId, entry); + return { ok: true, sessionId: headerId, isInitialize: false }; + } + + /** Deletes a session (via DELETE request). */ + async delete(sessionId: string): Promise { + const entry = this._sessions.get(sessionId); + if (!entry) return; + try { + entry.sseController?.close(); + } catch { + // Already closed. + } + this._sessions.delete(sessionId); + await Promise.resolve(this._onsessionclosed?.(sessionId)); + } + + /** Protocol version the client requested in `initialize` for this session, if known. */ + negotiatedVersion(sessionId: string): string | undefined { + return this._sessions.get(sessionId)?.protocolVersion; + } + + /** Returns true if a standalone GET stream is already open for this session. */ + hasStandaloneStream(sessionId: string): boolean { + return this._sessions.get(sessionId)?.sseController !== undefined; + } + + /** Registers the open standalone GET stream controller for this session. */ + setStandaloneStream(sessionId: string, controller: ReadableStreamDefaultController | undefined): void { + const entry = this._sessions.get(sessionId); + if (entry) entry.sseController = controller; + } + + /** Closes the standalone GET stream for this session if one is open. */ + closeStandaloneStream(sessionId: string): void { + const entry = this._sessions.get(sessionId); + try { + entry?.sseController?.close(); + } catch { + // Already closed. + } + if (entry) entry.sseController = undefined; + } + + /** Number of live sessions. */ + get size(): number { + return this._sessions.size; + } + + private _evictIdle(): void { + const cutoff = Date.now() - this._idleTtlMs; + for (const [id, entry] of this._sessions) { + if (entry.lastSeen < cutoff) { + this._sessions.delete(id); + void Promise.resolve(this._onsessionclosed?.(id)); + } + } + } + + private _evictOldest(): void { + const oldest = this._sessions.keys().next(); + if (!oldest.done) { + const id = oldest.value; + this._sessions.delete(id); + void Promise.resolve(this._onsessionclosed?.(id)); + } + } +} diff --git a/packages/server/src/server/shttpHandler.ts b/packages/server/src/server/shttpHandler.ts new file mode 100644 index 000000000..d8e9360ac --- /dev/null +++ b/packages/server/src/server/shttpHandler.ts @@ -0,0 +1,515 @@ +import type { + AuthInfo, + JSONRPCErrorResponse, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResultResponse, + MessageExtraInfo, + RequestEnv +} from '@modelcontextprotocol/core'; +import { + DEFAULT_NEGOTIATED_PROTOCOL_VERSION, + isInitializeRequest, + isJSONRPCErrorResponse, + isJSONRPCNotification, + isJSONRPCRequest, + isJSONRPCResultResponse, + JSONRPCMessageSchema, + SUPPORTED_PROTOCOL_VERSIONS +} from '@modelcontextprotocol/core'; + +import type { BackchannelCompat } from './backchannelCompat.js'; +import type { SessionCompat } from './sessionCompat.js'; + +export type StreamId = string; +export type EventId = string; + +/** + * Interface for resumability support via event storage. + */ +export interface EventStore { + /** + * Stores an event for later retrieval. + * @param streamId ID of the stream the event belongs to + * @param message The JSON-RPC message to store + * @returns The generated event ID for the stored event + */ + storeEvent(streamId: StreamId, message: JSONRPCMessage): Promise; + + /** + * Replays events stored after the given event ID, calling `send` for each. + * @returns The stream ID the replayed events belong to + */ + replayEventsAfter( + lastEventId: EventId, + opts: { send: (eventId: EventId, message: JSONRPCMessage) => Promise } + ): Promise; + + /** + * Get the stream ID associated with a given event ID. + * @deprecated No longer used; the handler does not maintain a stream-mapping table. + */ + getStreamIdForEventId?(eventId: EventId): Promise; +} + +/** + * Callback bundle {@linkcode shttpHandler} uses to route inbound messages. Matches the + * inbound side of {@linkcode @modelcontextprotocol/core!shared/transport.RequestTransport | RequestTransport}; + * the handler reads these slots at call time, so a transport can pass `this` and have + * `connect()` set them later. + */ +export interface ShttpCallbacks { + /** Called per inbound JSON-RPC request; yields notifications then one terminal response. */ + onrequest?: ((request: JSONRPCRequest, env?: RequestEnv) => AsyncIterable) | undefined; + /** Called per inbound JSON-RPC notification. */ + onnotification?: (notification: JSONRPCNotification) => void | Promise; + /** Called per inbound JSON-RPC response (client POSTing back to a server-initiated request). Returns `true` if claimed. */ + onresponse?: (response: JSONRPCResultResponse | JSONRPCErrorResponse) => boolean; +} + +/** @deprecated Use {@linkcode ShttpCallbacks}. */ +export type McpServerLike = ShttpCallbacks; + +/** + * Options for {@linkcode shttpHandler}. + */ +export interface ShttpHandlerOptions { + /** + * If `true`, return a single `application/json` response instead of an SSE stream. + * Progress notifications yielded by handlers are dropped in this mode. + * + * @default false + */ + enableJsonResponse?: boolean; + + /** + * Pre-2026-06 session compatibility. When provided, the handler validates the + * `mcp-session-id` header, mints a session on `initialize`, and supports the + * standalone GET subscription stream and DELETE session termination. When omitted, + * the handler is stateless: GET/DELETE return 405. + */ + session?: SessionCompat; + + /** + * Pre-2026-06 server-to-client request backchannel. When provided alongside `session`, + * the handler supplies `env.send` to dispatched handlers (so `ctx.mcpReq.elicitInput()` etc. + * work over the open POST SSE stream) and routes incoming JSON-RPC responses to the + * waiting `env.send` promise. Version-gated: only active for sessions whose negotiated + * protocol version is below `2026-06-30`. + */ + backchannel?: BackchannelCompat; + + /** + * Event store for SSE resumability via `Last-Event-ID`. When configured, every + * outgoing SSE event is persisted and a priming event is sent at stream start. + */ + eventStore?: EventStore; + + /** + * Retry interval in milliseconds, sent in the SSE `retry` field of priming events. + */ + retryInterval?: number; + + /** + * Protocol versions accepted in the `mcp-protocol-version` header. + * + * @default {@linkcode SUPPORTED_PROTOCOL_VERSIONS} + */ + supportedProtocolVersions?: string[]; + + /** Called for non-fatal errors (validation failures, stream write errors). */ + onerror?: (error: Error) => void; +} + +/** + * Per-request extras passed alongside the {@linkcode Request}. + */ +export interface ShttpRequestExtra { + /** Pre-parsed body (e.g. from `express.json()`). When omitted, `req.json()` is used. */ + parsedBody?: unknown; + /** Validated auth token info from upstream middleware. */ + authInfo?: AuthInfo; +} + +function jsonError(status: number, code: number, message: string, extra?: { headers?: Record; data?: string }): Response { + const error: { code: number; message: string; data?: string } = { code, message }; + if (extra?.data !== undefined) error.data = extra.data; + return Response.json( + { jsonrpc: '2.0', error, id: null }, + { status, headers: { 'Content-Type': 'application/json', ...extra?.headers } } + ); +} + +function writeSSEEvent( + controller: ReadableStreamDefaultController, + encoder: InstanceType, + message: JSONRPCMessage, + eventId?: string +): boolean { + try { + let data = 'event: message\n'; + if (eventId) data += `id: ${eventId}\n`; + data += `data: ${JSON.stringify(message)}\n\n`; + controller.enqueue(encoder.encode(data)); + return true; + } catch { + return false; + } +} + +/** Sentinel session key for the standalone GET stream when no {@linkcode SessionCompat} is configured. */ +export const STATELESS_GET_KEY = '_stateless'; + +/** EventStore stream-ID prefix for the standalone GET stream (matches v1 `_standaloneSseStreamId`). */ +const STANDALONE_STREAM_ID = '_GET_stream'; + +const SSE_HEADERS: Record = { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache, no-transform', + Connection: 'keep-alive' +}; + +/** + * Creates a Web-standard `(Request) => Promise` handler for the MCP Streamable HTTP + * transport, driven by {@linkcode ShttpCallbacks.onrequest} per request. + * + * No `_streamMapping`, `_requestToStreamMapping`, or `relatedRequestId` routing — the response + * stream is in lexical scope of the request that opened it. Session lifecycle (when enabled) + * lives in the supplied {@linkcode SessionCompat}, not on this handler. + */ +export function shttpHandler( + cb: ShttpCallbacks, + options: ShttpHandlerOptions = {} +): (req: Request, extra?: ShttpRequestExtra) => Promise { + const enableJsonResponse = options.enableJsonResponse ?? false; + const session = options.session; + const backchannel = options.backchannel; + const eventStore = options.eventStore; + const retryInterval = options.retryInterval; + const supportedProtocolVersions = options.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; + const onerror = options.onerror; + + function backchannelEnabled(sessionId: string | undefined, clientProtocolVersion: string): boolean { + if (!backchannel || sessionId === undefined) return false; + const negotiated = session?.negotiatedVersion(sessionId) ?? clientProtocolVersion; + return negotiated < '2026-06-30'; + } + + function validateProtocolVersion(req: Request): Response | undefined { + const v = req.headers.get('mcp-protocol-version'); + if (v !== null && !supportedProtocolVersions.includes(v)) { + const msg = `Bad Request: Unsupported protocol version: ${v} (supported versions: ${supportedProtocolVersions.join(', ')})`; + onerror?.(new Error(msg)); + return jsonError(400, -32_000, msg); + } + return undefined; + } + + async function writePrimingEvent( + controller: ReadableStreamDefaultController, + encoder: InstanceType, + streamId: string, + protocolVersion: string + ): Promise { + if (!eventStore) return; + if (protocolVersion < '2025-11-25') return; + const primingId = await eventStore.storeEvent(streamId, {} as JSONRPCMessage); + const retry = retryInterval === undefined ? '' : `retry: ${retryInterval}\n`; + controller.enqueue(encoder.encode(`id: ${primingId}\n${retry}data: \n\n`)); + } + + async function emit( + controller: ReadableStreamDefaultController, + encoder: InstanceType, + streamId: string, + message: JSONRPCMessage + ): Promise { + const eventId = eventStore ? await eventStore.storeEvent(streamId, message) : undefined; + if (!writeSSEEvent(controller, encoder, message, eventId)) { + onerror?.(new Error('Failed to write SSE event')); + } + } + + async function handlePost(req: Request, extra?: ShttpRequestExtra): Promise { + const accept = req.headers.get('accept'); + if (!accept?.includes('application/json') || !accept.includes('text/event-stream')) { + onerror?.(new Error('Not Acceptable: Client must accept both application/json and text/event-stream')); + return jsonError(406, -32_000, 'Not Acceptable: Client must accept both application/json and text/event-stream'); + } + + const ct = req.headers.get('content-type'); + if (!ct?.includes('application/json')) { + onerror?.(new Error('Unsupported Media Type: Content-Type must be application/json')); + return jsonError(415, -32_000, 'Unsupported Media Type: Content-Type must be application/json'); + } + + let raw: unknown; + if (extra?.parsedBody === undefined) { + try { + raw = await req.json(); + } catch (error) { + onerror?.(error as Error); + return jsonError(400, -32_700, 'Parse error: Invalid JSON'); + } + } else { + raw = extra.parsedBody; + } + + let messages: JSONRPCMessage[]; + try { + messages = Array.isArray(raw) ? raw.map(m => JSONRPCMessageSchema.parse(m)) : [JSONRPCMessageSchema.parse(raw)]; + } catch (error) { + onerror?.(error as Error); + return jsonError(400, -32_700, 'Parse error: Invalid JSON-RPC message'); + } + + let sessionId: string | undefined; + let isInitialize = false; + if (session) { + const v = await session.validate(req, messages); + if (!v.ok) return v.response; + sessionId = v.sessionId; + isInitialize = v.isInitialize; + } + if (!isInitialize) { + const protoErr = validateProtocolVersion(req); + if (protoErr) return protoErr; + } + + const requests = messages.filter(m => isJSONRPCRequest(m)); + const notifications = messages.filter(m => isJSONRPCNotification(m)); + const responses = messages.filter( + (m): m is JSONRPCResultResponse | JSONRPCErrorResponse => isJSONRPCResultResponse(m) || isJSONRPCErrorResponse(m) + ); + + for (const n of notifications) { + void Promise.resolve(cb.onnotification?.(n)).catch(error => onerror?.(error as Error)); + } + + for (const r of responses) { + if (cb.onresponse?.(r)) continue; + if (backchannel && sessionId !== undefined) backchannel.handleResponse(sessionId, r); + } + + if (requests.length === 0) { + return new Response(null, { status: 202 }); + } + + const onrequest = cb.onrequest; + if (!onrequest) { + return jsonError(500, -32_603, 'Transport not connected — call mcp.connect(transport) first.'); + } + + const initReq = messages.find(m => isInitializeRequest(m)); + const clientProtocolVersion = + initReq && isInitializeRequest(initReq) + ? initReq.params.protocolVersion + : (req.headers.get('mcp-protocol-version') ?? DEFAULT_NEGOTIATED_PROTOCOL_VERSION); + + const baseEnv: RequestEnv = { sessionId, authInfo: extra?.authInfo, httpReq: req }; + const useBackchannel = backchannelEnabled(sessionId, clientProtocolVersion); + + if (enableJsonResponse) { + const out: JSONRPCMessage[] = []; + for (const r of requests) { + for await (const msg of onrequest(r, baseEnv)) { + if (isJSONRPCResultResponse(msg) || isJSONRPCErrorResponse(msg)) out.push(msg); + } + } + const headers: Record = { 'Content-Type': 'application/json' }; + if (sessionId !== undefined) headers['mcp-session-id'] = sessionId; + const body = out.length === 1 ? out[0] : out; + return Response.json(body, { status: 200, headers }); + } + + const streamId = crypto.randomUUID(); + const encoder = new TextEncoder(); + const headers: Record = { ...SSE_HEADERS }; + if (sessionId !== undefined) headers['mcp-session-id'] = sessionId; + + const readable = new ReadableStream({ + start: controller => { + const writeSSE = (msg: JSONRPCMessage) => void emit(controller, encoder, streamId, msg); + const closeStream = () => { + try { + controller.close(); + } catch { + // Already closed. + } + }; + const supportsPolling = eventStore !== undefined && clientProtocolVersion >= '2025-11-25'; + const transportExtra: MessageExtraInfo = { + request: req, + authInfo: extra?.authInfo, + closeSSEStream: supportsPolling ? closeStream : undefined, + closeStandaloneSSEStream: + supportsPolling && sessionId !== undefined + ? () => { + session?.closeStandaloneStream(sessionId); + backchannel?.setStandaloneWriter(sessionId, undefined); + } + : undefined + }; + const env: RequestEnv & { _transportExtra?: MessageExtraInfo } = { + ...baseEnv, + _transportExtra: transportExtra, + ...(useBackchannel && backchannel && sessionId !== undefined + ? { send: backchannel.makeEnvSend(sessionId, writeSSE) } + : {}) + }; + void (async () => { + try { + await writePrimingEvent(controller, encoder, streamId, clientProtocolVersion); + for (const r of requests) { + for await (const msg of onrequest(r, env)) { + await emit(controller, encoder, streamId, msg); + } + } + } catch (error) { + onerror?.(error as Error); + } finally { + try { + controller.close(); + } catch { + // Already closed. + } + } + })(); + } + }); + + return new Response(readable, { status: 200, headers }); + } + + async function handleGet(req: Request): Promise { + if (!session && !backchannel) { + return jsonError(405, -32_000, 'Method Not Allowed: stateless handler does not support GET stream', { + headers: { Allow: 'POST' } + }); + } + + const accept = req.headers.get('accept'); + if (!accept?.includes('text/event-stream')) { + onerror?.(new Error('Not Acceptable: Client must accept text/event-stream')); + return jsonError(406, -32_000, 'Not Acceptable: Client must accept text/event-stream'); + } + + let sessionId: string; + if (session) { + const v = session.validateHeader(req); + if (!v.ok) return v.response; + sessionId = v.sessionId!; + } else { + sessionId = STATELESS_GET_KEY; + } + const protoErr = validateProtocolVersion(req); + if (protoErr) return protoErr; + + if (eventStore) { + const lastEventId = req.headers.get('last-event-id'); + if (lastEventId) { + return replayEvents(lastEventId, sessionId); + } + } + + if (session?.hasStandaloneStream(sessionId) || (!session && backchannel?.hasStandaloneWriter(sessionId))) { + onerror?.(new Error('Conflict: Only one SSE stream is allowed per session')); + return jsonError(409, -32_000, 'Conflict: Only one SSE stream is allowed per session'); + } + + const encoder = new TextEncoder(); + const headers: Record = { ...SSE_HEADERS }; + if (session) headers['mcp-session-id'] = sessionId; + const readable = new ReadableStream({ + start: controller => { + session?.setStandaloneStream(sessionId, controller); + backchannel?.setStandaloneWriter(sessionId, msg => void emit(controller, encoder, STANDALONE_STREAM_ID, msg)); + }, + cancel: () => { + session?.setStandaloneStream(sessionId, undefined); + backchannel?.setStandaloneWriter(sessionId, undefined); + } + }); + return new Response(readable, { headers }); + } + + async function replayEvents(lastEventId: string, sessionId: string): Promise { + if (!eventStore) { + return jsonError(400, -32_000, 'Event store not configured'); + } + const encoder = new TextEncoder(); + const headers: Record = { ...SSE_HEADERS, 'mcp-session-id': sessionId }; + const readable = new ReadableStream({ + start: controller => { + void (async () => { + try { + await eventStore.replayEventsAfter(lastEventId, { + send: async (eventId, message) => { + writeSSEEvent(controller, encoder, message, eventId); + } + }); + if (session) session.setStandaloneStream(sessionId, controller); + backchannel?.setStandaloneWriter(sessionId, msg => void emit(controller, encoder, STANDALONE_STREAM_ID, msg)); + } catch (error) { + onerror?.(error as Error); + try { + controller.close(); + } catch { + // Already closed. + } + } + })(); + }, + cancel: () => { + session?.setStandaloneStream(sessionId, undefined); + backchannel?.setStandaloneWriter(sessionId, undefined); + } + }); + return new Response(readable, { headers }); + } + + async function handleDelete(req: Request): Promise { + if (!session) { + return jsonError(405, -32_000, 'Method Not Allowed: stateless handler does not support session DELETE', { + headers: { Allow: 'POST' } + }); + } + const v = session.validateHeader(req); + if (!v.ok) return v.response; + const protoErr = validateProtocolVersion(req); + if (protoErr) return protoErr; + backchannel?.closeSession(v.sessionId!); + try { + await session.delete(v.sessionId!); + } catch (error) { + onerror?.(error as Error); + return jsonError(500, -32_603, 'Internal server error: onsessionclosed callback failed', { + data: String(error) + }); + } + return new Response(null, { status: 200 }); + } + + return async (req: Request, extra?: ShttpRequestExtra): Promise => { + try { + switch (req.method) { + case 'POST': { + return await handlePost(req, extra); + } + case 'GET': { + return await handleGet(req); + } + case 'DELETE': { + return await handleDelete(req); + } + default: { + return jsonError(405, -32_000, 'Method not allowed.', { headers: { Allow: 'GET, POST, DELETE' } }); + } + } + } catch (error) { + onerror?.(error as Error); + return jsonError(400, -32_700, 'Parse error', { data: String(error) }); + } + }; +} diff --git a/packages/server/src/server/streamableHttp.ts b/packages/server/src/server/streamableHttp.ts index 6284189dd..ec8f3823b 100644 --- a/packages/server/src/server/streamableHttp.ts +++ b/packages/server/src/server/streamableHttp.ts @@ -1,71 +1,37 @@ /** * Web Standards Streamable HTTP Server Transport * - * This is the core transport implementation using Web Standard APIs (`Request`, `Response`, `ReadableStream`). - * It can run on any runtime that supports Web Standards: Node.js 18+, Cloudflare Workers, Deno, Bun, etc. + * Thin compat wrapper over {@linkcode shttpHandler} + {@linkcode SessionCompat} + + * {@linkcode BackchannelCompat}. The class name, constructor options, and + * {@linkcode Transport} interface are kept for back-compat so existing + * `server.connect(new WebStandardStreamableHTTPServerTransport({...}))` code + * works unchanged. Request handling delegates to {@linkcode shttpHandler}. * - * For Node.js Express/HTTP compatibility, use {@linkcode @modelcontextprotocol/node!NodeStreamableHTTPServerTransport | NodeStreamableHTTPServerTransport} which wraps this transport. + * For Node.js Express/HTTP compatibility, use + * {@linkcode @modelcontextprotocol/node!NodeStreamableHTTPServerTransport | NodeStreamableHTTPServerTransport} + * which wraps this transport. */ -import type { AuthInfo, JSONRPCMessage, MessageExtraInfo, RequestId, Transport } from '@modelcontextprotocol/core'; -import { - DEFAULT_NEGOTIATED_PROTOCOL_VERSION, - isInitializeRequest, - isJSONRPCErrorResponse, - isJSONRPCRequest, - isJSONRPCResultResponse, - JSONRPCMessageSchema, - SUPPORTED_PROTOCOL_VERSIONS +import type { + AuthInfo, + JSONRPCErrorResponse, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResultResponse, + MessageExtraInfo, + RequestEnv, + RequestTransport, + TransportSendOptions } from '@modelcontextprotocol/core'; +import { isJSONRPCErrorResponse, isJSONRPCResultResponse, SUPPORTED_PROTOCOL_VERSIONS } from '@modelcontextprotocol/core'; -export type StreamId = string; -export type EventId = string; +import { BackchannelCompat } from './backchannelCompat.js'; +import { SessionCompat } from './sessionCompat.js'; +import type { ShttpRequestExtra } from './shttpHandler.js'; +import { shttpHandler, STATELESS_GET_KEY } from './shttpHandler.js'; -/** - * Interface for resumability support via event storage - */ -export interface EventStore { - /** - * Stores an event for later retrieval - * @param streamId ID of the stream the event belongs to - * @param message The JSON-RPC message to store - * @returns The generated event ID for the stored event - */ - storeEvent(streamId: StreamId, message: JSONRPCMessage): Promise; - - /** - * Get the stream ID associated with a given event ID. - * @param eventId The event ID to look up - * @returns The stream ID, or `undefined` if not found - * - * Optional: If not provided, the SDK will use the `streamId` returned by - * {@linkcode replayEventsAfter} for stream mapping. - */ - getStreamIdForEventId?(eventId: EventId): Promise; - - replayEventsAfter( - lastEventId: EventId, - { - send - }: { - send: (eventId: EventId, message: JSONRPCMessage) => Promise; - } - ): Promise; -} - -/** - * Internal stream mapping for managing SSE connections - */ -interface StreamMapping { - /** Stream controller for pushing SSE data - only used with `ReadableStream` approach */ - controller?: ReadableStreamDefaultController; - /** Text encoder for SSE formatting */ - encoder?: InstanceType; - /** Promise resolver for JSON response mode */ - resolveJson?: (response: Response) => void; - /** Cleanup function to close stream and remove mapping */ - cleanup: () => void; -} +export type { EventId, EventStore, StreamId } from './shttpHandler.js'; /** * Configuration options for {@linkcode WebStandardStreamableHTTPServerTransport} @@ -111,7 +77,7 @@ export interface WebStandardStreamableHTTPServerTransportOptions { * Event store for resumability support * If provided, resumability will be enabled, allowing clients to reconnect and resume messages */ - eventStore?: EventStore; + eventStore?: import('./shttpHandler.js').EventStore; /** * List of allowed `Host` header values for DNS rebinding protection. @@ -180,65 +146,25 @@ export interface HandleRequestOptions { * - Session ID is generated and included in response headers * - Session ID is always included in initialization responses * - Requests with invalid session IDs are rejected with `404 Not Found` - * - Non-initialization requests without a session ID are rejected with `400 Bad Request` - * - State is maintained in-memory (connections, message history) - * - * In stateless mode: - * - No Session ID is included in any responses - * - No session validation is performed + * - GET opens a standalone subscription stream; DELETE terminates the session * - * @example Stateful setup - * ```ts source="./streamableHttp.examples.ts#WebStandardStreamableHTTPServerTransport_stateful" - * const server = new McpServer({ name: 'my-server', version: '1.0.0' }); + * In stateless mode (no `sessionIdGenerator`): + * - No session validation; GET/DELETE return 405 * - * const transport = new WebStandardStreamableHTTPServerTransport({ - * sessionIdGenerator: () => crypto.randomUUID() - * }); - * - * await server.connect(transport); - * ``` - * - * @example Stateless setup - * ```ts source="./streamableHttp.examples.ts#WebStandardStreamableHTTPServerTransport_stateless" - * const transport = new WebStandardStreamableHTTPServerTransport({ - * sessionIdGenerator: undefined - * }); - * ``` - * - * @example Hono.js - * ```ts source="./streamableHttp.examples.ts#WebStandardStreamableHTTPServerTransport_hono" - * app.all('/mcp', async c => { - * return transport.handleRequest(c.req.raw); - * }); - * ``` - * - * @example Cloudflare Workers - * ```ts source="./streamableHttp.examples.ts#WebStandardStreamableHTTPServerTransport_workers" - * const worker = { - * async fetch(request: Request): Promise { - * return transport.handleRequest(request); - * } - * }; - * ``` + * The class is now a thin shim: {@linkcode handleRequest} delegates to a captured + * {@linkcode shttpHandler} bound at {@linkcode connect | connect()} time. The + * {@linkcode Transport} interface methods route outbound messages through the + * per-session {@linkcode BackchannelCompat}. */ -export class WebStandardStreamableHTTPServerTransport implements Transport { - // when sessionId is not set (undefined), it means the transport is in stateless mode - private sessionIdGenerator: (() => string) | undefined; - private _started: boolean = false; - private _closed: boolean = false; - private _streamMapping: Map = new Map(); - private _requestToStreamMapping: Map = new Map(); - private _requestResponseMap: Map = new Map(); - private _initialized: boolean = false; - private _enableJsonResponse: boolean = false; - private _standaloneSseStreamId: string = '_GET_stream'; - private _eventStore?: EventStore; - private _onsessioninitialized?: (sessionId: string) => void | Promise; - private _onsessionclosed?: (sessionId: string) => void | Promise; - private _allowedHosts?: string[]; - private _allowedOrigins?: string[]; - private _enableDnsRebindingProtection: boolean; - private _retryInterval?: number; +export class WebStandardStreamableHTTPServerTransport implements RequestTransport { + readonly kind = 'request' as const; + + private _options: WebStandardStreamableHTTPServerTransportOptions; + private _session?: SessionCompat; + private _backchannel = new BackchannelCompat(); + private _handler: (req: Request, extra?: ShttpRequestExtra) => Promise; + private _started = false; + private _closed = false; private _supportedProtocolVersions: string[]; sessionId?: string; @@ -246,17 +172,53 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; + /** {@linkcode RequestTransport.onrequest} — set by `McpServer.connect()`. */ + onrequest: ((req: JSONRPCRequest, env?: RequestEnv) => AsyncIterable) | undefined = undefined; + /** {@linkcode RequestTransport.onnotification} — set by `McpServer.connect()`. */ + onnotification?: (n: JSONRPCNotification) => void | Promise; + /** {@linkcode RequestTransport.onresponse} — set by `McpServer.connect()`. */ + onresponse?: (r: JSONRPCResultResponse | JSONRPCErrorResponse) => boolean; + constructor(options: WebStandardStreamableHTTPServerTransportOptions = {}) { - this.sessionIdGenerator = options.sessionIdGenerator; - this._enableJsonResponse = options.enableJsonResponse ?? false; - this._eventStore = options.eventStore; - this._onsessioninitialized = options.onsessioninitialized; - this._onsessionclosed = options.onsessionclosed; - this._allowedHosts = options.allowedHosts; - this._allowedOrigins = options.allowedOrigins; - this._enableDnsRebindingProtection = options.enableDnsRebindingProtection ?? false; - this._retryInterval = options.retryInterval; + this._options = options; this._supportedProtocolVersions = options.supportedProtocolVersions ?? SUPPORTED_PROTOCOL_VERSIONS; + if (options.sessionIdGenerator) { + this._session = new SessionCompat({ + sessionIdGenerator: options.sessionIdGenerator, + singleSession: true, + onerror: e => this.onerror?.(e), + onsessioninitialized: id => { + this.sessionId = id; + return options.onsessioninitialized?.(id); + }, + onsessionclosed: id => { + this._backchannel.closeSession(id); + return options.onsessionclosed?.(id); + } + }); + } + // shttpHandler reads onrequest/onnotification/onresponse from `this` at call time, + // so connect() can set them after construction. + this._handler = shttpHandler(this, { + session: this._session, + backchannel: this._backchannel, + eventStore: this._options.eventStore, + enableJsonResponse: this._options.enableJsonResponse, + retryInterval: this._options.retryInterval, + supportedProtocolVersions: this._supportedProtocolVersions, + onerror: e => this.onerror?.(e) + }); + } + + /** + * Handles an incoming Web-standard {@linkcode Request} and returns a Web-standard {@linkcode Response}. + */ + async handleRequest(req: Request, options: HandleRequestOptions = {}): Promise { + if (this._options.enableDnsRebindingProtection) { + const err = this._validateDnsRebinding(req); + if (err) return err; + } + return this._handler(req, { parsedBody: options.parsedBody, authInfo: options.authInfo }); } /** @@ -278,761 +240,115 @@ export class WebStandardStreamableHTTPServerTransport implements Transport { this._supportedProtocolVersions = versions; } - /** - * Helper to create a JSON error response - */ - private createJsonErrorResponse( - status: number, - code: number, - message: string, - options?: { headers?: Record; data?: string } - ): Response { - const error: { code: number; message: string; data?: string } = { code, message }; - if (options?.data !== undefined) { - error.data = options.data; - } - return Response.json( - { - jsonrpc: '2.0', - error, - id: null - }, - { - status, - headers: { - 'Content-Type': 'application/json', - ...options?.headers - } - } - ); - } - - /** - * Validates request headers for DNS rebinding protection. - * @returns Error response if validation fails, `undefined` if validation passes. - */ - private validateRequestHeaders(req: Request): Response | undefined { - // Skip validation if protection is not enabled - if (!this._enableDnsRebindingProtection) { - return undefined; - } - - // Validate Host header if allowedHosts is configured - if (this._allowedHosts && this._allowedHosts.length > 0) { - const hostHeader = req.headers.get('host'); - if (!hostHeader || !this._allowedHosts.includes(hostHeader)) { - const error = `Invalid Host header: ${hostHeader}`; - this.onerror?.(new Error(error)); - return this.createJsonErrorResponse(403, -32_000, error); - } - } - - // Validate Origin header if allowedOrigins is configured - if (this._allowedOrigins && this._allowedOrigins.length > 0) { - const originHeader = req.headers.get('origin'); - if (originHeader && !this._allowedOrigins.includes(originHeader)) { - const error = `Invalid Origin header: ${originHeader}`; - this.onerror?.(new Error(error)); - return this.createJsonErrorResponse(403, -32_000, error); - } - } - - return undefined; - } - - /** - * Handles an incoming HTTP request, whether `GET`, `POST`, or `DELETE` - * Returns a `Response` object (Web Standard) - */ - async handleRequest(req: Request, options?: HandleRequestOptions): Promise { - // Validate request headers for DNS rebinding protection - const validationError = this.validateRequestHeaders(req); - if (validationError) { - return validationError; - } - - switch (req.method) { - case 'POST': { - return this.handlePostRequest(req, options); - } - case 'GET': { - return this.handleGetRequest(req); - } - case 'DELETE': { - return this.handleDeleteRequest(req); - } - default: { - return this.handleUnsupportedRequest(); - } - } - } - - /** - * Writes a priming event to establish resumption capability. - * Only sends if `eventStore` is configured (opt-in for resumability) and - * the client's protocol version supports empty SSE data (>= `2025-11-25`). - */ - private async writePrimingEvent( - controller: ReadableStreamDefaultController, - encoder: InstanceType, - streamId: string, - protocolVersion: string - ): Promise { - if (!this._eventStore) { - return; - } - - // Priming events have empty data which older clients cannot handle. - // Only send priming events to clients with protocol version >= 2025-11-25 - // which includes the fix for handling empty SSE data. - if (protocolVersion < '2025-11-25') { - return; - } - - const primingEventId = await this._eventStore.storeEvent(streamId, {} as JSONRPCMessage); - - let primingEvent = `id: ${primingEventId}\ndata: \n\n`; - if (this._retryInterval !== undefined) { - primingEvent = `id: ${primingEventId}\nretry: ${this._retryInterval}\ndata: \n\n`; - } - controller.enqueue(encoder.encode(primingEvent)); - } - - /** - * Handles `GET` requests for SSE stream - */ - private async handleGetRequest(req: Request): Promise { - // The client MUST include an Accept header, listing text/event-stream as a supported content type. - const acceptHeader = req.headers.get('accept'); - if (!acceptHeader?.includes('text/event-stream')) { - this.onerror?.(new Error('Not Acceptable: Client must accept text/event-stream')); - return this.createJsonErrorResponse(406, -32_000, 'Not Acceptable: Client must accept text/event-stream'); - } - - // If an Mcp-Session-Id is returned by the server during initialization, - // clients using the Streamable HTTP transport MUST include it - // in the Mcp-Session-Id header on all of their subsequent HTTP requests. - const sessionError = this.validateSession(req); - if (sessionError) { - return sessionError; - } - const protocolError = this.validateProtocolVersion(req); - if (protocolError) { - return protocolError; - } - - // Handle resumability: check for Last-Event-ID header - if (this._eventStore) { - const lastEventId = req.headers.get('last-event-id'); - if (lastEventId) { - return this.replayEvents(lastEventId); - } - } - - // Check if there's already an active standalone SSE stream for this session - if (this._streamMapping.get(this._standaloneSseStreamId) !== undefined) { - // Only one GET SSE stream is allowed per session - this.onerror?.(new Error('Conflict: Only one SSE stream is allowed per session')); - return this.createJsonErrorResponse(409, -32_000, 'Conflict: Only one SSE stream is allowed per session'); - } - - const encoder = new TextEncoder(); - let streamController: ReadableStreamDefaultController; - - // Create a ReadableStream with a controller we can use to push SSE events - const readable = new ReadableStream({ - start: controller => { - streamController = controller; - }, - cancel: () => { - // Stream was cancelled by client - this._streamMapping.delete(this._standaloneSseStreamId); - } - }); - - const headers: Record = { - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache, no-transform', - Connection: 'keep-alive' - }; - - // After initialization, always include the session ID if we have one - if (this.sessionId !== undefined) { - headers['mcp-session-id'] = this.sessionId; - } - - // Store the stream mapping with the controller for pushing data - this._streamMapping.set(this._standaloneSseStreamId, { - controller: streamController!, - encoder, - cleanup: () => { - this._streamMapping.delete(this._standaloneSseStreamId); - try { - streamController!.close(); - } catch { - // Controller might already be closed - } - } - }); - - return new Response(readable, { headers }); - } - - /** - * Replays events that would have been sent after the specified event ID - * Only used when resumability is enabled - */ - private async replayEvents(lastEventId: string): Promise { - if (!this._eventStore) { - this.onerror?.(new Error('Event store not configured')); - return this.createJsonErrorResponse(400, -32_000, 'Event store not configured'); - } - - try { - // If getStreamIdForEventId is available, use it for conflict checking - let streamId: string | undefined; - if (this._eventStore.getStreamIdForEventId) { - streamId = await this._eventStore.getStreamIdForEventId(lastEventId); - - if (!streamId) { - this.onerror?.(new Error('Invalid event ID format')); - return this.createJsonErrorResponse(400, -32_000, 'Invalid event ID format'); - } - - // Check conflict with the SAME streamId we'll use for mapping - if (this._streamMapping.get(streamId) !== undefined) { - this.onerror?.(new Error('Conflict: Stream already has an active connection')); - return this.createJsonErrorResponse(409, -32_000, 'Conflict: Stream already has an active connection'); - } - } - - const headers: Record = { - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache, no-transform', - Connection: 'keep-alive' - }; - - if (this.sessionId !== undefined) { - headers['mcp-session-id'] = this.sessionId; - } - - // Create a ReadableStream with controller for SSE - const encoder = new TextEncoder(); - let streamController: ReadableStreamDefaultController; - - const readable = new ReadableStream({ - start: controller => { - streamController = controller; - }, - cancel: () => { - // Stream was cancelled by client - // Cleanup will be handled by the mapping - } - }); - - // Replay events - returns the streamId for backwards compatibility - const replayedStreamId = await this._eventStore.replayEventsAfter(lastEventId, { - send: async (eventId: string, message: JSONRPCMessage) => { - const success = this.writeSSEEvent(streamController!, encoder, message, eventId); - if (!success) { - try { - streamController!.close(); - } catch { - // Controller might already be closed - } - } - } - }); - - this._streamMapping.set(replayedStreamId, { - controller: streamController!, - encoder, - cleanup: () => { - this._streamMapping.delete(replayedStreamId); - try { - streamController!.close(); - } catch { - // Controller might already be closed - } - } - }); - - return new Response(readable, { headers }); - } catch (error) { - this.onerror?.(error as Error); - return this.createJsonErrorResponse(500, -32_000, 'Error replaying events'); - } + setProtocolVersion(_version: string): void { + // No-op: protocol version is per-session in SessionCompat. } /** - * Writes an event to an SSE stream via controller with proper formatting + * {@linkcode RequestTransport.notify} — write an unsolicited notification to the + * session's standalone GET subscription stream (2025-11 back-compat). */ - private writeSSEEvent( - controller: ReadableStreamDefaultController, - encoder: InstanceType, - message: JSONRPCMessage, - eventId?: string - ): boolean { - try { - let eventData = `event: message\n`; - // Include event ID if provided - this is important for resumability - if (eventId) { - eventData += `id: ${eventId}\n`; - } - eventData += `data: ${JSON.stringify(message)}\n\n`; - controller.enqueue(encoder.encode(eventData)); - return true; - } catch (error) { - this.onerror?.(error as Error); - return false; + async notify(n: JSONRPCNotification): Promise { + if (this._closed) return; + const sessionId = this.sessionId ?? STATELESS_GET_KEY; + const written = this._backchannel.writeStandalone(sessionId, n); + if (!written && this._options.eventStore) { + await this._options.eventStore.storeEvent('_GET_stream', n); } } /** - * Handles unsupported requests (`PUT`, `PATCH`, etc.) + * {@linkcode RequestTransport.request} — send an unsolicited server→client request via + * the standalone GET stream and await the client's POSTed-back response (2025-11 back-compat). */ - private handleUnsupportedRequest(): Response { - this.onerror?.(new Error('Method not allowed.')); - return Response.json( - { + request(r: JSONRPCRequest): Promise { + const sessionId = this.sessionId ?? STATELESS_GET_KEY; + const send = this._backchannel.makeEnvSend(sessionId, msg => void this._backchannel.writeStandalone(sessionId, msg)); + return send({ method: r.method, params: r.params }, {}).then( + result => ({ jsonrpc: '2.0', id: r.id, result }) as JSONRPCResultResponse, + (error: { code?: number; message?: string; data?: unknown }) => ({ jsonrpc: '2.0', + id: r.id, error: { - code: -32_000, - message: 'Method not allowed.' - }, - id: null - }, - { - status: 405, - headers: { - Allow: 'GET, POST, DELETE', - 'Content-Type': 'application/json' + code: error.code ?? -32_603, + message: error.message ?? String(error), + ...(error.data !== undefined && { data: error.data }) } - } + }) ); } /** - * Handles `POST` requests containing JSON-RPC messages + * {@linkcode ChannelTransport.send} (back-compat costume). Outbound responses route to the + * {@linkcode BackchannelCompat} resolver map; notifications and server-initiated requests go + * on the session's standalone GET stream. + * + * @deprecated Use {@linkcode notify} / {@linkcode request} (the {@linkcode RequestTransport} surface). */ - private async handlePostRequest(req: Request, options?: HandleRequestOptions): Promise { - try { - // Validate the Accept header - const acceptHeader = req.headers.get('accept'); - // The client MUST include an Accept header, listing both application/json and text/event-stream as supported content types. - if (!acceptHeader?.includes('application/json') || !acceptHeader.includes('text/event-stream')) { - this.onerror?.(new Error('Not Acceptable: Client must accept both application/json and text/event-stream')); - return this.createJsonErrorResponse( - 406, - -32_000, - 'Not Acceptable: Client must accept both application/json and text/event-stream' - ); - } - - const ct = req.headers.get('content-type'); - if (!ct || !ct.includes('application/json')) { - this.onerror?.(new Error('Unsupported Media Type: Content-Type must be application/json')); - return this.createJsonErrorResponse(415, -32_000, 'Unsupported Media Type: Content-Type must be application/json'); - } - - const request = req; - - let rawMessage; - if (options?.parsedBody === undefined) { - try { - rawMessage = await req.json(); - } catch (error) { - this.onerror?.(error as Error); - return this.createJsonErrorResponse(400, -32_700, 'Parse error: Invalid JSON'); - } - } else { - rawMessage = options.parsedBody; - } - - let messages: JSONRPCMessage[]; - - // handle batch and single messages - try { - messages = Array.isArray(rawMessage) - ? rawMessage.map(msg => JSONRPCMessageSchema.parse(msg)) - : [JSONRPCMessageSchema.parse(rawMessage)]; - } catch (error) { - this.onerror?.(error as Error); - return this.createJsonErrorResponse(400, -32_700, 'Parse error: Invalid JSON-RPC message'); - } - - // Check if this is an initialization request - // https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/lifecycle/ - const isInitializationRequest = messages.some(element => isInitializeRequest(element)); - if (isInitializationRequest) { - // If it's a server with session management and the session ID is already set we should reject the request - // to avoid re-initialization. - if (this._initialized && this.sessionId !== undefined) { - this.onerror?.(new Error('Invalid Request: Server already initialized')); - return this.createJsonErrorResponse(400, -32_600, 'Invalid Request: Server already initialized'); - } - if (messages.length > 1) { - this.onerror?.(new Error('Invalid Request: Only one initialization request is allowed')); - return this.createJsonErrorResponse(400, -32_600, 'Invalid Request: Only one initialization request is allowed'); - } - this.sessionId = this.sessionIdGenerator?.(); - this._initialized = true; - - // If we have a session ID and an onsessioninitialized handler, call it immediately - // This is needed in cases where the server needs to keep track of multiple sessions - if (this.sessionId && this._onsessioninitialized) { - await Promise.resolve(this._onsessioninitialized(this.sessionId)); - } - } - if (!isInitializationRequest) { - // If an Mcp-Session-Id is returned by the server during initialization, - // clients using the Streamable HTTP transport MUST include it - // in the Mcp-Session-Id header on all of their subsequent HTTP requests. - const sessionError = this.validateSession(req); - if (sessionError) { - return sessionError; - } - // Mcp-Protocol-Version header is required for all requests after initialization. - const protocolError = this.validateProtocolVersion(req); - if (protocolError) { - return protocolError; - } - } - - // check if it contains requests - const hasRequests = messages.some(element => isJSONRPCRequest(element)); - - if (!hasRequests) { - // if it only contains notifications or responses, return 202 - for (const message of messages) { - this.onmessage?.(message, { authInfo: options?.authInfo, request }); - } - return new Response(null, { status: 202 }); - } - - // The default behavior is to use SSE streaming - // but in some cases server will return JSON responses - const streamId = crypto.randomUUID(); - - // Extract protocol version for priming event decision. - // For initialize requests, get from request params. - // For other requests, get from header (already validated). - const initRequest = messages.find(m => isInitializeRequest(m)); - const clientProtocolVersion = initRequest - ? initRequest.params.protocolVersion - : (req.headers.get('mcp-protocol-version') ?? DEFAULT_NEGOTIATED_PROTOCOL_VERSION); - - if (this._enableJsonResponse) { - // For JSON response mode, return a Promise that resolves when all responses are ready - return new Promise(resolve => { - this._streamMapping.set(streamId, { - resolveJson: resolve, - cleanup: () => { - this._streamMapping.delete(streamId); - } - }); - - for (const message of messages) { - if (isJSONRPCRequest(message)) { - this._requestToStreamMapping.set(message.id, streamId); - } - } - - for (const message of messages) { - this.onmessage?.(message, { authInfo: options?.authInfo, request }); - } - }); - } - - // SSE streaming mode - use ReadableStream with controller for more reliable data pushing - const encoder = new TextEncoder(); - let streamController: ReadableStreamDefaultController; - - const readable = new ReadableStream({ - start: controller => { - streamController = controller; - }, - cancel: () => { - // Stream was cancelled by client - this._streamMapping.delete(streamId); - } - }); - - const headers: Record = { - 'Content-Type': 'text/event-stream', - 'Cache-Control': 'no-cache', - Connection: 'keep-alive' - }; - - // After initialization, always include the session ID if we have one - if (this.sessionId !== undefined) { - headers['mcp-session-id'] = this.sessionId; - } - - // Store the response for this request to send messages back through this connection - // We need to track by request ID to maintain the connection - for (const message of messages) { - if (isJSONRPCRequest(message)) { - this._streamMapping.set(streamId, { - controller: streamController!, - encoder, - cleanup: () => { - this._streamMapping.delete(streamId); - try { - streamController!.close(); - } catch { - // Controller might already be closed - } - } - }); - this._requestToStreamMapping.set(message.id, streamId); - } - } - - // Write priming event if event store is configured (after mapping is set up) - await this.writePrimingEvent(streamController!, encoder, streamId, clientProtocolVersion); - - // handle each message - for (const message of messages) { - // Build closeSSEStream callback for requests when eventStore is configured - // AND client supports resumability (protocol version >= 2025-11-25). - // Old clients can't resume if the stream is closed early because they - // didn't receive a priming event with an event ID. - let closeSSEStream: (() => void) | undefined; - let closeStandaloneSSEStream: (() => void) | undefined; - if (isJSONRPCRequest(message) && this._eventStore && clientProtocolVersion >= '2025-11-25') { - closeSSEStream = () => { - this.closeSSEStream(message.id); - }; - closeStandaloneSSEStream = () => { - this.closeStandaloneSSEStream(); - }; - } - - this.onmessage?.(message, { authInfo: options?.authInfo, request, closeSSEStream, closeStandaloneSSEStream }); - } - // The server SHOULD NOT close the SSE stream before sending all JSON-RPC responses - // This will be handled by the send() method when responses are ready - - return new Response(readable, { status: 200, headers }); - } catch (error) { - // return JSON-RPC formatted error - this.onerror?.(error as Error); - return this.createJsonErrorResponse(400, -32_700, 'Parse error', { data: String(error) }); + async send(message: JSONRPCMessage, _options?: TransportSendOptions): Promise { + if (this._closed) return; + const sessionId = this.sessionId ?? STATELESS_GET_KEY; + if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { + this._backchannel.handleResponse(sessionId, message); + return; + } + const written = this._backchannel.writeStandalone(sessionId, message); + if (!written && this._options.eventStore) { + await this._options.eventStore.storeEvent('_GET_stream', message); } } /** - * Handles `DELETE` requests to terminate sessions + * Close an SSE stream for a specific request, triggering client reconnection. + * @deprecated Per-request stream tracking was removed; this is now a no-op. Use + * `ctx.http?.closeSSE` from inside the handler instead. */ - private async handleDeleteRequest(req: Request): Promise { - const sessionError = this.validateSession(req); - if (sessionError) { - return sessionError; - } - const protocolError = this.validateProtocolVersion(req); - if (protocolError) { - return protocolError; - } - - await Promise.resolve(this._onsessionclosed?.(this.sessionId!)); - await this.close(); - return new Response(null, { status: 200 }); + closeSSEStream(_requestId: unknown): void { + // No per-request stream map in the new model. } /** - * Validates session ID for non-initialization requests. - * Returns `Response` error if invalid, `undefined` otherwise + * Close the standalone GET SSE stream, triggering client reconnection. */ - private validateSession(req: Request): Response | undefined { - if (this.sessionIdGenerator === undefined) { - // If the sessionIdGenerator ID is not set, the session management is disabled - // and we don't need to validate the session ID - return undefined; - } - if (!this._initialized) { - // If the server has not been initialized yet, reject all requests - this.onerror?.(new Error('Bad Request: Server not initialized')); - return this.createJsonErrorResponse(400, -32_000, 'Bad Request: Server not initialized'); - } - - const sessionId = req.headers.get('mcp-session-id'); - - if (!sessionId) { - // Non-initialization requests without a session ID should return 400 Bad Request - this.onerror?.(new Error('Bad Request: Mcp-Session-Id header is required')); - return this.createJsonErrorResponse(400, -32_000, 'Bad Request: Mcp-Session-Id header is required'); - } - - if (sessionId !== this.sessionId) { - // Reject requests with invalid session ID with 404 Not Found - this.onerror?.(new Error('Session not found')); - return this.createJsonErrorResponse(404, -32_001, 'Session not found'); + closeStandaloneSSEStream(): void { + if (this.sessionId !== undefined) { + this._session?.closeStandaloneStream(this.sessionId); + this._backchannel.setStandaloneWriter(this.sessionId, undefined); } - - return undefined; } /** - * Validates the `MCP-Protocol-Version` header on incoming requests. - * - * For initialization: Version negotiation handles unknown versions gracefully - * (server responds with its supported version). - * - * For subsequent requests with `MCP-Protocol-Version` header: - * - Accept if in supported list - * - 400 if unsupported - * - * For HTTP requests without the `MCP-Protocol-Version` header: - * - Accept and default to the version negotiated at initialization + * Closes the transport. */ - private validateProtocolVersion(req: Request): Response | undefined { - const protocolVersion = req.headers.get('mcp-protocol-version'); - - if (protocolVersion !== null && !this._supportedProtocolVersions.includes(protocolVersion)) { - const error = `Bad Request: Unsupported protocol version: ${protocolVersion} (supported versions: ${this._supportedProtocolVersions.join(', ')})`; - this.onerror?.(new Error(error)); - return this.createJsonErrorResponse(400, -32_000, error); - } - return undefined; - } - async close(): Promise { - if (this._closed) { - return; - } + if (this._closed) return; this._closed = true; - - // Close all SSE connections - for (const { cleanup } of this._streamMapping.values()) { - cleanup(); + if (this.sessionId !== undefined) { + this._backchannel.closeSession(this.sessionId); + await this._session?.delete(this.sessionId); } - this._streamMapping.clear(); - - // Clear any pending responses - this._requestResponseMap.clear(); this.onclose?.(); } - /** - * Close an SSE stream for a specific request, triggering client reconnection. - * Use this to implement polling behavior during long-running operations - - * client will reconnect after the retry interval specified in the priming event. - */ - closeSSEStream(requestId: RequestId): void { - const streamId = this._requestToStreamMapping.get(requestId); - if (!streamId) return; - - const stream = this._streamMapping.get(streamId); - if (stream) { - stream.cleanup(); - } - } - - /** - * Close the standalone `GET` SSE stream, triggering client reconnection. - * Use this to implement polling behavior for server-initiated notifications. - */ - closeStandaloneSSEStream(): void { - const stream = this._streamMapping.get(this._standaloneSseStreamId); - if (stream) { - stream.cleanup(); - } - } - - async send(message: JSONRPCMessage, options?: { relatedRequestId?: RequestId }): Promise { - let requestId = options?.relatedRequestId; - if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { - // If the message is a response, use the request ID from the message - requestId = message.id; - } - - // Check if this message should be sent on the standalone SSE stream (no request ID) - // Ignore notifications from tools (which have relatedRequestId set) - // Those will be sent via dedicated response SSE streams - if (requestId === undefined) { - // For standalone SSE streams, we can only send requests and notifications - if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { - throw new Error('Cannot send a response on a standalone SSE stream unless resuming a previous client request'); - } - - // Generate and store event ID if event store is provided - // Store even if stream is disconnected so events can be replayed on reconnect - let eventId: string | undefined; - if (this._eventStore) { - // Stores the event and gets the generated event ID - eventId = await this._eventStore.storeEvent(this._standaloneSseStreamId, message); - } - - const standaloneSse = this._streamMapping.get(this._standaloneSseStreamId); - if (standaloneSse === undefined) { - // Stream is disconnected - event is stored for replay, nothing more to do - return; - } - - // Send the message to the standalone SSE stream - if (standaloneSse.controller && standaloneSse.encoder) { - this.writeSSEEvent(standaloneSse.controller, standaloneSse.encoder, message, eventId); - } - return; - } - - // Get the response for this request - const streamId = this._requestToStreamMapping.get(requestId); - if (!streamId) { - throw new Error(`No connection established for request ID: ${String(requestId)}`); - } - - const stream = this._streamMapping.get(streamId); - - if (!this._enableJsonResponse && stream?.controller && stream?.encoder) { - // For SSE responses, generate event ID if event store is provided - let eventId: string | undefined; - - if (this._eventStore) { - eventId = await this._eventStore.storeEvent(streamId, message); + private _validateDnsRebinding(req: Request): Response | undefined { + if (this._options.allowedHosts && this._options.allowedHosts.length > 0) { + const host = req.headers.get('host'); + if (!host || !this._options.allowedHosts.includes(host)) { + return Response.json( + { jsonrpc: '2.0', error: { code: -32_000, message: `Invalid Host header: ${host ?? '(missing)'}` }, id: null }, + { status: 403 } + ); } - // Write the event to the response stream - this.writeSSEEvent(stream.controller, stream.encoder, message, eventId); } - - if (isJSONRPCResultResponse(message) || isJSONRPCErrorResponse(message)) { - this._requestResponseMap.set(requestId, message); - const relatedIds = [...this._requestToStreamMapping.entries()].filter(([_, sid]) => sid === streamId).map(([id]) => id); - - // Check if we have responses for all requests using this connection - const allResponsesReady = relatedIds.every(id => this._requestResponseMap.has(id)); - - if (allResponsesReady) { - if (!stream) { - throw new Error(`No connection established for request ID: ${String(requestId)}`); - } - if (this._enableJsonResponse && stream.resolveJson) { - // All responses ready, send as JSON - const headers: Record = { - 'Content-Type': 'application/json' - }; - if (this.sessionId !== undefined) { - headers['mcp-session-id'] = this.sessionId; - } - - const responses = relatedIds.map(id => this._requestResponseMap.get(id)!); - - if (responses.length === 1) { - stream.resolveJson(Response.json(responses[0], { status: 200, headers })); - } else { - stream.resolveJson(Response.json(responses, { status: 200, headers })); - } - } else { - // End the SSE stream - stream.cleanup(); - } - // Clean up - for (const id of relatedIds) { - this._requestResponseMap.delete(id); - this._requestToStreamMapping.delete(id); - } + if (this._options.allowedOrigins && this._options.allowedOrigins.length > 0) { + const origin = req.headers.get('origin'); + if (origin && !this._options.allowedOrigins.includes(origin)) { + return Response.json( + { jsonrpc: '2.0', error: { code: -32_000, message: `Invalid Origin header: ${origin}` }, id: null }, + { status: 403 } + ); } } + return undefined; } } diff --git a/packages/server/src/validators/cfWorker.ts b/packages/server/src/validators/cfWorker.ts index 9a3a88405..e04436dbd 100644 --- a/packages/server/src/validators/cfWorker.ts +++ b/packages/server/src/validators/cfWorker.ts @@ -6,5 +6,5 @@ * import { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/server/validators/cf-worker'; * ``` */ -export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/core'; export type { CfWorkerSchemaDraft } from '@modelcontextprotocol/core'; +export { CfWorkerJsonSchemaValidator } from '@modelcontextprotocol/core'; diff --git a/packages/server/src/zodSchemas.ts b/packages/server/src/zodSchemas.ts new file mode 100644 index 000000000..d8f1383f8 --- /dev/null +++ b/packages/server/src/zodSchemas.ts @@ -0,0 +1,12 @@ +// v1-compat subpath: `@modelcontextprotocol/server/zod-schemas` +// +// Re-exports the Zod schema constants (`*Schema`) that v1's `types.js` +// exposed alongside the spec types. v2 keeps these out of the main barrel +// (they pull in zod at runtime); this subpath lets the `@modelcontextprotocol/sdk` +// meta-package's `types.js` shim restore them for v1 callers of +// `setRequestHandler(SomeRequestSchema, handler)`. +// +// Source of truth: core's internal types/schemas.ts + shared/auth.ts. + +// eslint-disable-next-line import/export -- intentional bulk re-export of internal zod constants +export * from '@modelcontextprotocol/core'; diff --git a/packages/server/test/mcpServer.test.ts b/packages/server/test/mcpServer.test.ts new file mode 100644 index 000000000..03dd334cc --- /dev/null +++ b/packages/server/test/mcpServer.test.ts @@ -0,0 +1,272 @@ +import type { JSONRPCErrorResponse, JSONRPCMessage, JSONRPCRequest, JSONRPCResultResponse } from '@modelcontextprotocol/core'; +import { InMemoryTransport, isJSONRPCErrorResponse, isJSONRPCResultResponse, LATEST_PROTOCOL_VERSION } from '@modelcontextprotocol/core'; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type R = JSONRPCResultResponse & { result: any }; +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; + +import { McpServer, ResourceTemplate } from '../src/server/mcpServer.js'; + +const req = (id: number, method: string, params?: Record): JSONRPCRequest => ({ + jsonrpc: '2.0', + id, + method, + params +}); + +const initReq = (id = 0): JSONRPCRequest => + req(id, 'initialize', { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: { elicitation: { form: {} } }, + clientInfo: { name: 't', version: '1' } + }); + +async function collect(it: AsyncIterable): Promise { + const out: JSONRPCMessage[] = []; + for await (const m of it) out.push(m); + return out; +} + +async function lastResponse(it: AsyncIterable): Promise { + const all = await collect(it); + const last = all[all.length - 1]; + if (!isJSONRPCResultResponse(last) && !isJSONRPCErrorResponse(last)) throw new Error('no terminal response'); + return last; +} + +describe('McpServer.handle()', () => { + it('responds to initialize with serverInfo and capabilities', async () => { + const s = new McpServer({ name: 'srv', version: '1.0.0' }, { instructions: 'hi' }); + const r = (await lastResponse(s.handle(initReq(1)))) as R; + expect(r.id).toBe(1); + expect(r.result.serverInfo).toEqual({ name: 'srv', version: '1.0.0' }); + expect(r.result.instructions).toBe('hi'); + expect(r.result.protocolVersion).toBe(LATEST_PROTOCOL_VERSION); + }); + + it('responds to ping', async () => { + const s = new McpServer({ name: 's', version: '1' }); + const r = (await lastResponse(s.handle(req(1, 'ping')))) as R; + expect(r.result).toEqual({}); + }); + + it('returns MethodNotFound for unknown method', async () => { + const s = new McpServer({ name: 's', version: '1' }); + const r = (await lastResponse(s.handle(req(1, 'nope/nope' as never)))) as JSONRPCErrorResponse; + expect(r.error.code).toBe(-32601); + }); + + it('registerTool + tools/list returns the tool', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerTool('echo', { description: 'd', inputSchema: z.object({ x: z.string() }) }, async ({ x }) => ({ + content: [{ type: 'text', text: x }] + })); + const r = (await lastResponse(s.handle(req(1, 'tools/list')))) as R; + expect(r.result.tools).toHaveLength(1); + expect(r.result.tools[0].name).toBe('echo'); + expect(r.result.tools[0].inputSchema.type).toBe('object'); + }); + + it('tools/call invokes handler with validated args', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerTool('echo', { inputSchema: z.object({ x: z.string() }) }, async ({ x }) => ({ + content: [{ type: 'text', text: `got ${x}` }] + })); + const r = (await lastResponse(s.handle(req(1, 'tools/call', { name: 'echo', arguments: { x: 'hi' } })))) as R; + expect(r.result.content[0].text).toBe('got hi'); + }); + + it('tools/call with invalid args returns isError result', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerTool('echo', { inputSchema: z.object({ x: z.string() }) }, async ({ x }) => ({ + content: [{ type: 'text', text: x }] + })); + const r = (await lastResponse(s.handle(req(1, 'tools/call', { name: 'echo', arguments: { x: 42 } })))) as R; + expect(r.result.isError).toBe(true); + }); + + it('tools/call with unknown tool returns InvalidParams error response', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerTool('a', {}, async () => ({ content: [] })); + const r = (await lastResponse(s.handle(req(1, 'tools/call', { name: 'b', arguments: {} })))) as JSONRPCErrorResponse; + expect(r.error.code).toBe(-32602); + expect(r.error.message).toContain('not found'); + }); + + it('handle yields notifications then a terminal response', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerTool('progress', {}, async ctx => { + await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 1, progress: 0.5 } }); + await ctx.mcpReq.notify({ method: 'notifications/progress', params: { progressToken: 1, progress: 1.0 } }); + return { content: [{ type: 'text', text: 'done' }] }; + }); + const msgs = await collect(s.handle(req(1, 'tools/call', { name: 'progress', arguments: {} }))); + expect(msgs).toHaveLength(3); + expect((msgs[0] as { method: string }).method).toBe('notifications/progress'); + expect((msgs[1] as { method: string }).method).toBe('notifications/progress'); + expect(isJSONRPCResultResponse(msgs[2])).toBe(true); + }); + + it('ctx.mcpReq.elicitInput throws when no peer channel (handle without env.send)', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerTool('ask', {}, async ctx => { + await ctx.mcpReq.elicitInput({ message: 'q', requestedSchema: { type: 'object', properties: {} } }); + return { content: [] }; + }); + const r = (await lastResponse(s.handle(req(1, 'tools/call', { name: 'ask', arguments: {} })))) as R; + expect(r.result.isError).toBe(true); + expect(r.result.content[0].text).toContain('MRTR-native'); + }); + + it('ctx.mcpReq.elicitInput resolves when env.send provided', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerTool('ask', {}, async ctx => { + const er = await ctx.mcpReq.elicitInput({ message: 'q', requestedSchema: { type: 'object', properties: {} } }); + return { content: [{ type: 'text', text: er.action }] }; + }); + const r = (await lastResponse( + s.handle(req(1, 'tools/call', { name: 'ask', arguments: {} }), { + send: async () => ({ action: 'accept', content: {} }) + }) + )) as R; + expect(r.result.content[0].text).toBe('accept'); + }); + + it('registerResource + resources/list + resources/read', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerResource('cfg', 'config://app', { mimeType: 'text/plain' }, async uri => ({ + contents: [{ uri: uri.href, text: 'v' }] + })); + const list = (await lastResponse(s.handle(req(1, 'resources/list')))) as R; + expect(list.result.resources[0].uri).toBe('config://app'); + const read = (await lastResponse(s.handle(req(2, 'resources/read', { uri: 'config://app' })))) as R; + expect(read.result.contents[0].text).toBe('v'); + }); + + it('registerResource with template + resources/read matches', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerResource('user', new ResourceTemplate('users://{id}', { list: undefined }), {}, async (uri, { id }) => ({ + contents: [{ uri: uri.href, text: String(id) }] + })); + const r = (await lastResponse(s.handle(req(1, 'resources/read', { uri: 'users://abc' })))) as R; + expect(r.result.contents[0].text).toBe('abc'); + }); + + it('registerPrompt + prompts/list + prompts/get', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerPrompt('p', { argsSchema: z.object({ q: z.string() }) }, ({ q }) => ({ + messages: [{ role: 'user', content: { type: 'text', text: q } }] + })); + const list = (await lastResponse(s.handle(req(1, 'prompts/list')))) as R; + expect(list.result.prompts[0].name).toBe('p'); + const get = (await lastResponse(s.handle(req(2, 'prompts/get', { name: 'p', arguments: { q: 'hi' } })))) as R; + expect(get.result.messages[0].content.text).toBe('hi'); + }); + + it('RegisteredTool.disable hides from tools/list', async () => { + const s = new McpServer({ name: 's', version: '1' }); + const t = s.registerTool('x', {}, async () => ({ content: [] })); + t.disable(); + const r = (await lastResponse(s.handle(req(1, 'tools/list')))) as R; + expect(r.result.tools).toHaveLength(0); + }); + + it('handleHttp parses body and returns JSON response', async () => { + const s = new McpServer({ name: 's', version: '1' }); + const httpReq = new Request('http://x/mcp', { + method: 'POST', + body: JSON.stringify(req(1, 'ping')), + headers: { 'content-type': 'application/json' } + }); + const res = await s.handleHttp(httpReq); + expect(res.status).toBe(200); + const body = (await res.json()) as R; + expect(body.id).toBe(1); + expect(body.result).toEqual({}); + }); + + it('handleHttp returns 400 on parse error', async () => { + const s = new McpServer({ name: 's', version: '1' }); + const res = await s.handleHttp(new Request('http://x/mcp', { method: 'POST', body: '{broken' })); + expect(res.status).toBe(400); + }); + + it('handleHttp returns 202 for notification-only body', async () => { + const s = new McpServer({ name: 's', version: '1' }); + const res = await s.handleHttp( + new Request('http://x/mcp', { + method: 'POST', + body: JSON.stringify({ jsonrpc: '2.0', method: 'notifications/initialized' }) + }) + ); + expect(res.status).toBe(202); + }); +}); + +describe('McpServer compat / .server / connect()', () => { + it('.server === this', () => { + const s = new McpServer({ name: 's', version: '1' }); + expect(s.server).toBe(s); + }); + + it('isConnected reflects connect/close', async () => { + const s = new McpServer({ name: 's', version: '1' }); + expect(s.isConnected()).toBe(false); + const [a, b] = InMemoryTransport.createLinkedPair(); + await s.connect(a); + expect(s.isConnected()).toBe(true); + expect(s.transport).toBe(a); + void b; + await s.close(); + expect(s.isConnected()).toBe(false); + }); + + it('connect() then peer can send tools/list', async () => { + const s = new McpServer({ name: 's', version: '1' }); + s.registerTool('t', {}, async () => ({ content: [] })); + const [serverPipe, clientPipe] = InMemoryTransport.createLinkedPair(); + await s.connect(serverPipe); + await clientPipe.start(); + + const responses: JSONRPCMessage[] = []; + clientPipe.onmessage = m => responses.push(m); + + await clientPipe.send(initReq(0)); + await clientPipe.send(req(1, 'tools/list')); + await new Promise(r => setTimeout(r, 10)); + + const listResp = responses.find(m => isJSONRPCResultResponse(m) && m.id === 1) as R; + expect(listResp.result.tools[0].name).toBe('t'); + }); + + it('connect() twice replaces the active driver (v1 multi-transport pattern)', async () => { + const s = new McpServer({ name: 's', version: '1' }); + const [a] = InMemoryTransport.createLinkedPair(); + await s.connect(a); + const [c] = InMemoryTransport.createLinkedPair(); + await expect(s.connect(c)).resolves.toBeUndefined(); + expect(s.transport).toBe(c); + }); + + it('elicitInput() instance method throws NotConnected when no driver', async () => { + const s = new McpServer({ name: 's', version: '1' }); + await expect(s.elicitInput({ message: 'q', requestedSchema: { type: 'object', properties: {} } })).rejects.toThrow( + /not connected/i + ); + }); + + it('registerCapabilities throws after connect', async () => { + const s = new McpServer({ name: 's', version: '1' }); + const [a] = InMemoryTransport.createLinkedPair(); + await s.connect(a); + expect(() => s.registerCapabilities({ logging: {} })).toThrow(); + }); + + it('initialize via handle() populates getClientCapabilities', async () => { + const s = new McpServer({ name: 's', version: '1' }); + await lastResponse(s.handle(initReq(0))); + expect(s.getClientCapabilities()?.elicitation?.form).toBeDefined(); + expect(s.getClientVersion()?.name).toBe('t'); + }); +}); diff --git a/packages/server/test/server/shttpHandler.test.ts b/packages/server/test/server/shttpHandler.test.ts new file mode 100644 index 000000000..2dd3b9712 --- /dev/null +++ b/packages/server/test/server/shttpHandler.test.ts @@ -0,0 +1,246 @@ +import { describe, expect, it } from 'vitest'; + +import type { RequestEnv, JSONRPCMessage, JSONRPCNotification, JSONRPCRequest } from '@modelcontextprotocol/core'; + +import { SessionCompat } from '../../src/server/sessionCompat.js'; +import type { ShttpCallbacks } from '../../src/server/shttpHandler.js'; +import { shttpHandler } from '../../src/server/shttpHandler.js'; + +/** Minimal in-test callback bundle: maps method name → result, with optional pre-yield notification. */ +function fakeServer( + handlers: Record unknown>, + opts: { preNotify?: JSONRPCNotification } = {} +): ShttpCallbacks { + return { + async *onrequest(req: JSONRPCRequest, _env?: RequestEnv): AsyncIterable { + if (opts.preNotify) yield opts.preNotify; + const h = handlers[req.method]; + if (!h) { + yield { jsonrpc: '2.0', id: req.id, error: { code: -32_601, message: 'Method not found' } }; + return; + } + yield { jsonrpc: '2.0', id: req.id, result: h(req) as Record }; + }, + async onnotification(_n: JSONRPCNotification): Promise { + return; + } + }; +} + +const ACCEPT_BOTH = 'application/json, text/event-stream'; + +function post(body: unknown, headers: Record = {}): Request { + return new Request('http://localhost/mcp', { + method: 'POST', + headers: { 'content-type': 'application/json', accept: ACCEPT_BOTH, ...headers }, + body: JSON.stringify(body) + }); +} + +const initialize = (id: number | string = 1): JSONRPCRequest => ({ + jsonrpc: '2.0', + id, + method: 'initialize', + params: { protocolVersion: '2025-11-25', clientInfo: { name: 't', version: '1' }, capabilities: {} } +}); + +const toolsList = (id: number | string = 1): JSONRPCRequest => ({ jsonrpc: '2.0', id, method: 'tools/list', params: {} }); + +async function readSSE(res: Response): Promise { + const text = await res.text(); + const out: JSONRPCMessage[] = []; + for (const block of text.split('\n\n')) { + const dataLine = block.split('\n').find(l => l.startsWith('data: ')); + if (!dataLine) continue; + const payload = dataLine.slice('data: '.length); + if (payload.trim() === '') continue; + out.push(JSON.parse(payload)); + } + return out; +} + +describe('shttpHandler — stateless', () => { + const server = fakeServer({ + 'tools/list': () => ({ tools: [{ name: 'echo', inputSchema: { type: 'object' } }] }), + initialize: () => ({ protocolVersion: '2025-11-25', serverInfo: { name: 's', version: '1' }, capabilities: {} }) + }); + + it('POST → SSE response with one result event', async () => { + const handler = shttpHandler(server); + const res = await handler(post(toolsList())); + expect(res.status).toBe(200); + expect(res.headers.get('content-type')).toBe('text/event-stream'); + const msgs = await readSSE(res); + expect(msgs).toHaveLength(1); + expect(msgs[0]).toMatchObject({ id: 1, result: { tools: [{ name: 'echo' }] } }); + }); + + it('POST with enableJsonResponse → application/json body', async () => { + const handler = shttpHandler(server, { enableJsonResponse: true }); + const res = await handler(post(toolsList())); + expect(res.status).toBe(200); + expect(res.headers.get('content-type')).toContain('application/json'); + const body = await res.json(); + expect(body).toMatchObject({ id: 1, result: { tools: expect.any(Array) } }); + }); + + it('POST batch → SSE with one response per request, in order', async () => { + const handler = shttpHandler(server); + const res = await handler(post([toolsList(1), toolsList(2)])); + const msgs = await readSSE(res); + expect(msgs.map(m => (m as { id: number }).id)).toEqual([1, 2]); + }); + + it('POST notification only → 202', async () => { + const handler = shttpHandler(server); + const res = await handler(post({ jsonrpc: '2.0', method: 'notifications/initialized' })); + expect(res.status).toBe(202); + }); + + it('handler-yielded notification precedes the response in SSE', async () => { + const progress: JSONRPCNotification = { + jsonrpc: '2.0', + method: 'notifications/progress', + params: { progressToken: 1, progress: 0.5 } + }; + const s = fakeServer({ 'tools/list': () => ({ tools: [] }) }, { preNotify: progress }); + const handler = shttpHandler(s); + const msgs = await readSSE(await handler(post(toolsList()))); + expect(msgs).toHaveLength(2); + expect((msgs[0] as JSONRPCNotification).method).toBe('notifications/progress'); + expect(msgs[1]).toMatchObject({ id: 1, result: { tools: [] } }); + }); + + it('bad Content-Type → 415', async () => { + const handler = shttpHandler(server); + const req = new Request('http://localhost/mcp', { + method: 'POST', + headers: { 'content-type': 'text/plain', accept: ACCEPT_BOTH }, + body: '{}' + }); + expect((await handler(req)).status).toBe(415); + }); + + it('Accept missing text/event-stream → 406', async () => { + const handler = shttpHandler(server); + const req = new Request('http://localhost/mcp', { + method: 'POST', + headers: { 'content-type': 'application/json', accept: 'application/json' }, + body: JSON.stringify(toolsList()) + }); + expect((await handler(req)).status).toBe(406); + }); + + it('invalid JSON body → 400 with code -32700', async () => { + const handler = shttpHandler(server); + const req = new Request('http://localhost/mcp', { + method: 'POST', + headers: { 'content-type': 'application/json', accept: ACCEPT_BOTH }, + body: '{not json' + }); + const res = await handler(req); + expect(res.status).toBe(400); + const body = (await res.json()) as { error: { code: number } }; + expect(body.error.code).toBe(-32_700); + }); + + it('unsupported HTTP method → 405', async () => { + const handler = shttpHandler(server); + const res = await handler(new Request('http://localhost/mcp', { method: 'PUT' })); + expect(res.status).toBe(405); + }); + + it('unsupported mcp-protocol-version header → 400', async () => { + const handler = shttpHandler(server); + const res = await handler(post(toolsList(), { 'mcp-protocol-version': '1999-01-01' })); + expect(res.status).toBe(400); + }); + + it('GET without session compat → 405', async () => { + const handler = shttpHandler(server); + const res = await handler(new Request('http://localhost/mcp', { method: 'GET', headers: { accept: 'text/event-stream' } })); + expect(res.status).toBe(405); + }); + + it('DELETE without session compat → 405', async () => { + const handler = shttpHandler(server); + const res = await handler(new Request('http://localhost/mcp', { method: 'DELETE' })); + expect(res.status).toBe(405); + }); +}); + +describe('shttpHandler — with SessionCompat', () => { + const server = fakeServer({ + initialize: () => ({ protocolVersion: '2025-11-25', serverInfo: { name: 's', version: '1' }, capabilities: {} }), + 'tools/list': () => ({ tools: [] }) + }); + + it('initialize mints a session and returns mcp-session-id header', async () => { + const session = new SessionCompat(); + const handler = shttpHandler(server, { session, enableJsonResponse: true }); + const res = await handler(post(initialize())); + expect(res.status).toBe(200); + const sid = res.headers.get('mcp-session-id'); + expect(sid).toBeTruthy(); + expect(session.size).toBe(1); + }); + + it('non-initialize without mcp-session-id → 400', async () => { + const session = new SessionCompat(); + const handler = shttpHandler(server, { session }); + const res = await handler(post(toolsList())); + expect(res.status).toBe(400); + }); + + it('wrong mcp-session-id → 404', async () => { + const session = new SessionCompat(); + const handler = shttpHandler(server, { session }); + await handler(post(initialize())); + const res = await handler(post(toolsList(), { 'mcp-session-id': 'nope' })); + expect(res.status).toBe(404); + }); + + it('correct mcp-session-id → 200', async () => { + const session = new SessionCompat(); + const handler = shttpHandler(server, { session, enableJsonResponse: true }); + const initRes = await handler(post(initialize())); + const sid = initRes.headers.get('mcp-session-id')!; + const res = await handler(post(toolsList(), { 'mcp-session-id': sid, 'mcp-protocol-version': '2025-11-25' })); + expect(res.status).toBe(200); + }); + + it('DELETE removes the session', async () => { + const session = new SessionCompat(); + const handler = shttpHandler(server, { session }); + const initRes = await handler(post(initialize())); + const sid = initRes.headers.get('mcp-session-id')!; + const del = await handler( + new Request('http://localhost/mcp', { + method: 'DELETE', + headers: { 'mcp-session-id': sid, 'mcp-protocol-version': '2025-11-25' } + }) + ); + expect(del.status).toBe(200); + expect(session.size).toBe(0); + }); + + it('rejects initialize with 503 + Retry-After when at maxSessions', async () => { + const session = new SessionCompat({ maxSessions: 1, idleTtlMs: 60_000 }); + const handler = shttpHandler(server, { session, enableJsonResponse: true }); + const r1 = await handler(post(initialize(1))); + expect(r1.status).toBe(200); + const r2 = await handler(post(initialize(2))); + // maxSessions=1 + idleTtlMs=60s: first session is fresh so LRU eviction frees nothing → cap hit. + // (SessionCompat evicts the oldest before rejecting; with a single fresh session that oldest IS evicted, + // so cap is only actually hit when eviction can't make room. Use maxSessions=0 to force.) + // Re-test with maxSessions=0 to assert the 503 path deterministically. + const session0 = new SessionCompat({ maxSessions: 0 }); + const handler0 = shttpHandler(server, { session: session0, enableJsonResponse: true }); + const r0 = await handler0(post(initialize())); + expect(r0.status).toBe(503); + expect(r0.headers.get('retry-after')).toBeTruthy(); + // r2 above will have evicted r1's session and succeeded; assert that behavior too. + expect(r2.status).toBe(200); + expect(session.size).toBe(1); + }); +}); diff --git a/packages/server/test/server/streamableHttp.test.ts b/packages/server/test/server/streamableHttp.test.ts index 7a23dd56b..df3b6d000 100644 --- a/packages/server/test/server/streamableHttp.test.ts +++ b/packages/server/test/server/streamableHttp.test.ts @@ -974,23 +974,18 @@ describe('Zod v4', () => { expect(closeCallCount).toBe(1); }); - it('should clean up all streams exactly once even when close() is called concurrently', async () => { + it('should fire onclose exactly once even when close() is called concurrently', async () => { const transport = new WebStandardStreamableHTTPServerTransport({ sessionIdGenerator: randomUUID }); - const cleanupCalls: string[] = []; - - // Inject a fake stream entry to verify cleanup runs exactly once - // @ts-expect-error accessing private map for test purposes - transport._streamMapping.set('stream-1', { - cleanup: () => { - cleanupCalls.push('stream-1'); - } - }); + let closeCount = 0; + transport.onclose = () => { + closeCount++; + }; // Fire two concurrent close() calls — only the first should proceed await Promise.all([transport.close(), transport.close()]); - expect(cleanupCalls).toEqual(['stream-1']); + expect(closeCount).toBe(1); }); }); }); diff --git a/packages/server/tsdown.config.ts b/packages/server/tsdown.config.ts index fb0cd8a93..08676f626 100644 --- a/packages/server/tsdown.config.ts +++ b/packages/server/tsdown.config.ts @@ -4,7 +4,7 @@ export default defineConfig({ failOnWarn: 'ci-only', // 1. Entry Points // Directly matches package.json include/exclude globs - entry: ['src/index.ts', 'src/shimsNode.ts', 'src/shimsWorkerd.ts', 'src/validators/cfWorker.ts'], + entry: ['src/index.ts', 'src/shimsNode.ts', 'src/shimsWorkerd.ts', 'src/validators/cfWorker.ts', 'src/zodSchemas.ts'], // 2. Output Configuration format: ['esm'], diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 899586750..935269802 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -251,7 +251,7 @@ importers: version: 4.4.4(eslint-plugin-import@2.32.0)(eslint@9.39.4) eslint-plugin-import: specifier: ^2.32.0 - version: 2.32.0(@typescript-eslint/parser@8.57.2(eslint@9.39.4)(typescript@5.9.3))(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4) + version: 2.32.0(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4) eslint-plugin-n: specifier: catalog:devTools version: 17.24.0(eslint@9.39.4)(typescript@5.9.3) @@ -862,6 +862,64 @@ importers: specifier: catalog:devTools version: 4.1.2(@opentelemetry/api@1.9.1)(@types/node@25.5.0)(vite@7.3.0(@types/node@25.5.0)(tsx@4.21.0)(yaml@2.8.3)) + packages/sdk: + dependencies: + '@modelcontextprotocol/client': + specifier: workspace:^ + version: link:../client + '@modelcontextprotocol/node': + specifier: workspace:^ + version: link:../middleware/node + '@modelcontextprotocol/server': + specifier: workspace:^ + version: link:../server + '@modelcontextprotocol/server-auth-legacy': + specifier: workspace:^ + version: link:../server-auth-legacy + express: + specifier: ^4.18.0 || ^5.0.0 + version: 5.2.1 + hono: + specifier: '*' + version: 4.12.9 + devDependencies: + '@modelcontextprotocol/core': + specifier: workspace:^ + version: link:../core + '@modelcontextprotocol/eslint-config': + specifier: workspace:^ + version: link:../../common/eslint-config + '@modelcontextprotocol/test-helpers': + specifier: workspace:^ + version: link:../../test/helpers + '@modelcontextprotocol/tsconfig': + specifier: workspace:^ + version: link:../../common/tsconfig + '@modelcontextprotocol/vitest-config': + specifier: workspace:^ + version: link:../../common/vitest-config + '@typescript/native-preview': + specifier: catalog:devTools + version: 7.0.0-dev.20260327.2 + eslint: + specifier: catalog:devTools + version: 9.39.4 + prettier: + specifier: catalog:devTools + version: 3.6.2 + tsdown: + specifier: catalog:devTools + version: 0.18.4(@typescript/native-preview@7.0.0-dev.20260327.2)(typescript@5.9.3) + typescript: + specifier: catalog:devTools + version: 5.9.3 + vitest: + specifier: catalog:devTools + version: 4.1.2(@opentelemetry/api@1.9.1)(@types/node@25.5.0)(vite@7.3.0(@types/node@25.5.0)(tsx@4.21.0)(yaml@2.8.3)) + zod: + specifier: catalog:runtimeShared + version: 4.3.6 + packages/server: dependencies: zod: @@ -929,6 +987,82 @@ importers: specifier: catalog:devTools version: 4.1.2(@opentelemetry/api@1.9.1)(@types/node@25.5.0)(vite@7.3.0(@types/node@25.5.0)(tsx@4.21.0)(yaml@2.8.3)) + packages/server-auth-legacy: + dependencies: + cors: + specifier: catalog:runtimeServerOnly + version: 2.8.6 + express-rate-limit: + specifier: ^8.2.1 + version: 8.3.1(express@5.2.1) + pkce-challenge: + specifier: catalog:runtimeShared + version: 5.0.1 + zod: + specifier: catalog:runtimeShared + version: 4.3.6 + devDependencies: + '@eslint/js': + specifier: catalog:devTools + version: 9.39.4 + '@modelcontextprotocol/core': + specifier: workspace:^ + version: link:../core + '@modelcontextprotocol/eslint-config': + specifier: workspace:^ + version: link:../../common/eslint-config + '@modelcontextprotocol/tsconfig': + specifier: workspace:^ + version: link:../../common/tsconfig + '@modelcontextprotocol/vitest-config': + specifier: workspace:^ + version: link:../../common/vitest-config + '@types/cors': + specifier: catalog:devTools + version: 2.8.19 + '@types/express': + specifier: catalog:devTools + version: 5.0.6 + '@types/express-serve-static-core': + specifier: catalog:devTools + version: 5.1.1 + '@types/supertest': + specifier: catalog:devTools + version: 6.0.3 + '@typescript/native-preview': + specifier: catalog:devTools + version: 7.0.0-dev.20260327.2 + eslint: + specifier: catalog:devTools + version: 9.39.4 + eslint-config-prettier: + specifier: catalog:devTools + version: 10.1.8(eslint@9.39.4) + eslint-plugin-n: + specifier: catalog:devTools + version: 17.24.0(eslint@9.39.4)(typescript@5.9.3) + express: + specifier: catalog:runtimeServerOnly + version: 5.2.1 + prettier: + specifier: catalog:devTools + version: 3.6.2 + supertest: + specifier: catalog:devTools + version: 7.2.2 + tsdown: + specifier: catalog:devTools + version: 0.18.4(@typescript/native-preview@7.0.0-dev.20260327.2)(typescript@5.9.3) + typescript: + specifier: catalog:devTools + version: 5.9.3 + typescript-eslint: + specifier: catalog:devTools + version: 8.57.2(eslint@9.39.4)(typescript@5.9.3) + vitest: + specifier: catalog:devTools + version: 4.1.2(@opentelemetry/api@1.9.1)(@types/node@25.5.0)(vite@7.3.0(@types/node@25.5.0)(tsx@4.21.0)(yaml@2.8.3)) + test/conformance: devDependencies: '@modelcontextprotocol/client': @@ -7074,15 +7208,14 @@ snapshots: tinyglobby: 0.2.15 unrs-resolver: 1.11.1 optionalDependencies: - eslint-plugin-import: 2.32.0(@typescript-eslint/parser@8.57.2(eslint@9.39.4)(typescript@5.9.3))(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4) + eslint-plugin-import: 2.32.0(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4) transitivePeerDependencies: - supports-color - eslint-module-utils@2.12.1(@typescript-eslint/parser@8.57.2(eslint@9.39.4)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4): + eslint-module-utils@2.12.1(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4): dependencies: debug: 3.2.7 optionalDependencies: - '@typescript-eslint/parser': 8.57.2(eslint@9.39.4)(typescript@5.9.3) eslint: 9.39.4 eslint-import-resolver-node: 0.3.9 eslint-import-resolver-typescript: 4.4.4(eslint-plugin-import@2.32.0)(eslint@9.39.4) @@ -7096,7 +7229,7 @@ snapshots: eslint: 9.39.4 eslint-compat-utils: 0.5.1(eslint@9.39.4) - eslint-plugin-import@2.32.0(@typescript-eslint/parser@8.57.2(eslint@9.39.4)(typescript@5.9.3))(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4): + eslint-plugin-import@2.32.0(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4): dependencies: '@rtsao/scc': 1.1.0 array-includes: 3.1.9 @@ -7107,7 +7240,7 @@ snapshots: doctrine: 2.1.0 eslint: 9.39.4 eslint-import-resolver-node: 0.3.9 - eslint-module-utils: 2.12.1(@typescript-eslint/parser@8.57.2(eslint@9.39.4)(typescript@5.9.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4) + eslint-module-utils: 2.12.1(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@4.4.4)(eslint@9.39.4) hasown: 2.0.2 is-core-module: 2.16.1 is-glob: 4.0.3 @@ -7118,8 +7251,6 @@ snapshots: semver: 6.3.1 string.prototype.trimend: 1.0.9 tsconfig-paths: 3.15.0 - optionalDependencies: - '@typescript-eslint/parser': 8.57.2(eslint@9.39.4)(typescript@5.9.3) transitivePeerDependencies: - eslint-import-resolver-typescript - eslint-import-resolver-webpack diff --git a/test/integration/test/server/cloudflareWorkers.test.ts b/test/integration/test/server/cloudflareWorkers.test.ts index 9c2d73a40..64580d6a3 100644 --- a/test/integration/test/server/cloudflareWorkers.test.ts +++ b/test/integration/test/server/cloudflareWorkers.test.ts @@ -26,6 +26,12 @@ describe('Cloudflare Workers compatibility (no nodejs_compat)', () => { let env: TestEnv | null = null; beforeAll(async () => { + // Clear any wrangler instance leaked by a previous run before claiming the port. + try { + execSync(`lsof -ti:${PORT} -sTCP:LISTEN | xargs -r kill -9`, { stdio: 'ignore' }); + } catch { + /* nothing listening */ + } const tempDir = fs.mkdtempSync(path.join(os.tmpdir(), 'cf-worker-test-')); // Pack server package