Skip to content

Commit 0ec94b4

Browse files
committed
feat: graceful SSE drain on session manager shutdown
Terminate all active transports before cancelling the task group during StreamableHTTPSessionManager shutdown. This closes their in-memory streams, allowing EventSourceResponse to send a final `more_body=False` chunk — a clean HTTP close instead of a connection reset. Without this, reverse proxies like nginx see "upstream prematurely closed connection" and return 502 to clients during rolling deploys.

Changes:
- Track in-flight stateless transports in the `_stateless_transports` set.
- In the `finally` block of `run()`, call `terminate()` on all stateful and stateless transports before `tg.cancel_scope.cancel()`.
- Add E2E tests for both stateless and stateful modes that verify the SSE stream closes cleanly when the manager shuts down while a tool call is in flight.
1 parent 0fe16dd commit 0ec94b4

2 files changed

Lines changed: 229 additions & 4 deletions

File tree

src/mcp/server/streamable_http_manager.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(
9090
self._session_creation_lock = anyio.Lock()
9191
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
9292

93+
# Track in-flight stateless transports for graceful shutdown
94+
self._stateless_transports: set[StreamableHTTPServerTransport] = set()
95+
9396
# The task group will be set during lifespan
9497
self._task_group = None
9598
# Thread-safe tracking of run() calls
@@ -130,11 +133,28 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
130133
yield # Let the application run
131134
finally:
132135
logger.info("StreamableHTTP session manager shutting down")
136+
137+
# Terminate all active transports before cancelling the task
138+
# group. This closes their in-memory streams, which lets
139+
# EventSourceResponse send a final ``more_body=False`` chunk
140+
# — a clean HTTP close instead of a connection reset.
141+
for transport in list(self._server_instances.values()):
142+
try:
143+
await transport.terminate()
144+
except Exception:
145+
logger.debug("Error terminating transport during shutdown", exc_info=True)
146+
for transport in list(self._stateless_transports):
147+
try:
148+
await transport.terminate()
149+
except Exception:
150+
logger.debug("Error terminating stateless transport during shutdown", exc_info=True)
151+
133152
# Cancel task group to stop all spawned tasks
134153
tg.cancel_scope.cancel()
135154
self._task_group = None
136155
# Clear any remaining server instances
137156
self._server_instances.clear()
157+
self._stateless_transports.clear()
138158

139159
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140160
"""Process ASGI request with proper session handling and transport setup.
@@ -161,6 +181,9 @@ async def _handle_stateless_request(self, scope: Scope, receive: Receive, send:
161181
security_settings=self.security_settings,
162182
)
163183

184+
# Track for graceful shutdown
185+
self._stateless_transports.add(http_transport)
186+
164187
# Start server in a new task
165188
async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED):
166189
async with http_transport.connect() as streams:
@@ -181,8 +204,11 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
181204
# Start the server task
182205
await self._task_group.start(run_stateless_server)
183206

184-
# Handle the HTTP request and return the response
185-
await http_transport.handle_request(scope, receive, send)
207+
try:
208+
# Handle the HTTP request and return the response
209+
await http_transport.handle_request(scope, receive, send)
210+
finally:
211+
self._stateless_transports.discard(http_transport)
186212

187213
# Terminate the transport after the request is handled
188214
await http_transport.terminate()

tests/server/test_streamable_http_manager.py

Lines changed: 201 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
import pytest
1010
from starlette.types import Message
1111

12-
from mcp import Client
12+
from mcp import Client, types
1313
from mcp.client.streamable_http import streamable_http_client
1414
from mcp.server import Server, ServerRequestContext, streamable_http_manager
1515
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport
16-
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
16+
from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager
1717
from mcp.types import INVALID_REQUEST, ListToolsResult, PaginatedRequestParams
1818

1919

@@ -410,3 +410,202 @@ def test_session_idle_timeout_rejects_non_positive():
410410
def test_session_idle_timeout_rejects_stateless():
    """Constructing a stateless manager with ``session_idle_timeout`` must raise."""
    expected_message = "not supported in stateless"
    with pytest.raises(RuntimeError, match=expected_message):
        StreamableHTTPSessionManager(
            app=Server("test"),
            session_idle_timeout=30,
            stateless=True,
        )
413+
414+
415+
# Headers sent with every request in the graceful-shutdown tests below.
MCP_HEADERS = {
    "Accept": "application/json, text/event-stream",
    "Content-Type": "application/json",
}

# JSON-RPC ``initialize`` request (id 1) sent first by each test.
_INITIALIZE_REQUEST = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "initialize",
    "params": {
        "protocolVersion": "2025-03-26",
        "capabilities": {},
        "clientInfo": {"name": "test", "version": "0.1"},
    },
}

# JSON-RPC notification (no ``id``) sent after the initialize response.
_INITIALIZED_NOTIFICATION = {
    "jsonrpc": "2.0",
    "method": "notifications/initialized",
}

# JSON-RPC ``tools/call`` request (id 2) that invokes ``slow_tool``.
_TOOL_CALL_REQUEST = {
    "jsonrpc": "2.0",
    "id": 2,
    "method": "tools/call",
    "params": {"name": "slow_tool", "arguments": {"message": "hello"}},
}
442+
443+
444+
def _make_slow_tool_server() -> tuple[Server, anyio.Event]:
    """Build an MCP server whose only tool never completes.

    Returns:
        A ``(server, tool_started)`` pair. ``tool_started`` is set as soon as
        the tool handler begins executing, so a caller can wait until a tool
        call is actually in flight.
    """
    started = anyio.Event()

    async def _call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
        # Announce that the call has begun, then park until cancelled.
        started.set()
        await anyio.sleep_forever()
        return types.CallToolResult(  # pragma: no cover
            content=[types.TextContent(type="text", text="never reached")]
        )

    async def _list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | None) -> ListToolsResult:
        # Advertise the single blocking tool.
        slow_tool = types.Tool(
            name="slow_tool",
            description="A tool that blocks forever",
            inputSchema={"type": "object", "properties": {"message": {"type": "string"}}},
        )
        return ListToolsResult(tools=[slow_tool])

    server = Server("test-graceful-shutdown", on_call_tool=_call_tool, on_list_tools=_list_tools)
    return server, started
470+
471+
@pytest.mark.anyio
async def test_graceful_shutdown_terminates_active_stateless_transports():
    """Verify that shutting down the session manager terminates in-flight
    stateless transports so SSE streams close cleanly (``more_body=False``)
    instead of being abruptly cancelled.

    This prevents "upstream prematurely closed connection" errors at reverse
    proxies like nginx.

    Flow: one task runs the manager and exits ``manager.run()`` once the slow
    tool has started; a second task drives the HTTP requests and records
    whether the tool-call stream ended cleanly or with a protocol error.
    """
    app, tool_started = _make_slow_tool_server()
    manager = StreamableHTTPSessionManager(app=app, stateless=True)

    mcp_app = StreamableHTTPASGIApp(manager)

    manager_ready = anyio.Event()
    shutdown_event = anyio.Event()
    # Set when the client task has finished (successfully or not), so the
    # test body does not cancel it before it records the stream outcome.
    requests_done = anyio.Event()
    stream_outcome: str | None = None

    async with anyio.create_task_group() as tg:

        async def run_lifespan_and_shutdown():
            async with manager.run():
                manager_ready.set()
                with anyio.fail_after(5):
                    await tool_started.wait()
            # Leaving ``manager.run()`` above is what triggers the shutdown.
            shutdown_event.set()

        async def make_requests():
            nonlocal stream_outcome
            try:
                with anyio.fail_after(5):
                    await manager_ready.wait()
                async with (
                    httpx.ASGITransport(mcp_app) as transport,
                    httpx.AsyncClient(transport=transport, base_url="http://testserver") as client,
                ):
                    # Initialize
                    resp = await client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS)
                    resp.raise_for_status()

                    # Send initialized notification
                    resp = await client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=MCP_HEADERS)
                    assert resp.status_code == 202

                    # Send slow tool call — this returns an SSE stream
                    try:
                        async with client.stream(
                            "POST",
                            "/mcp/",
                            json=_TOOL_CALL_REQUEST,
                            headers=MCP_HEADERS,
                            timeout=httpx.Timeout(10, connect=5),
                        ) as stream:
                            stream.raise_for_status()
                            async for _chunk in stream.aiter_bytes():
                                pass
                        stream_outcome = "clean"
                    except httpx.RemoteProtocolError:
                        stream_outcome = "reset"
            finally:
                requests_done.set()

        tg.start_soon(run_lifespan_and_shutdown)
        tg.start_soon(make_requests)

        with anyio.fail_after(10):
            await shutdown_event.wait()

        # Give the client task a bounded window to observe the closed stream.
        # Cancelling immediately would race with it and could leave
        # ``stream_outcome`` unset even though the stream closed cleanly.
        with anyio.move_on_after(5):
            await requests_done.wait()

        # Safety net: stop anything still running so the test cannot hang.
        tg.cancel_scope.cancel()

    assert stream_outcome == "clean", f"Expected clean HTTP close, got {stream_outcome}"
540+
541+
@pytest.mark.anyio
async def test_graceful_shutdown_terminates_active_stateful_transports():
    """Verify that shutting down the session manager terminates in-flight
    stateful transports so SSE streams close cleanly.

    Flow: one task runs the manager and exits ``manager.run()`` once the slow
    tool has started; a second task initializes a session, issues the tool
    call with the session header, and records whether the stream ended
    cleanly or with a protocol error.
    """
    app, tool_started = _make_slow_tool_server()
    manager = StreamableHTTPSessionManager(app=app, stateless=False)

    mcp_app = StreamableHTTPASGIApp(manager)

    manager_ready = anyio.Event()
    shutdown_event = anyio.Event()
    # Set when the client task has finished (successfully or not), so the
    # test body does not cancel it before it records the stream outcome.
    requests_done = anyio.Event()
    stream_outcome: str | None = None

    async with anyio.create_task_group() as tg:

        async def run_lifespan_and_shutdown():
            async with manager.run():
                manager_ready.set()
                with anyio.fail_after(5):
                    await tool_started.wait()
            # Leaving ``manager.run()`` above is what triggers the shutdown.
            shutdown_event.set()

        async def make_requests():
            nonlocal stream_outcome
            try:
                with anyio.fail_after(5):
                    await manager_ready.wait()
                async with (
                    httpx.ASGITransport(mcp_app) as transport,
                    httpx.AsyncClient(transport=transport, base_url="http://testserver") as client,
                ):
                    # Initialize (creates a session)
                    resp = await client.post("/mcp/", json=_INITIALIZE_REQUEST, headers=MCP_HEADERS)
                    resp.raise_for_status()
                    session_id = resp.headers.get(MCP_SESSION_ID_HEADER)
                    assert session_id is not None

                    session_headers = {
                        **MCP_HEADERS,
                        MCP_SESSION_ID_HEADER: session_id,
                        "mcp-protocol-version": "2025-03-26",
                    }

                    # Send initialized notification
                    resp = await client.post("/mcp/", json=_INITIALIZED_NOTIFICATION, headers=session_headers)
                    assert resp.status_code == 202

                    # Send slow tool call
                    try:
                        async with client.stream(
                            "POST",
                            "/mcp/",
                            json=_TOOL_CALL_REQUEST,
                            headers=session_headers,
                            timeout=httpx.Timeout(10, connect=5),
                        ) as stream:
                            stream.raise_for_status()
                            async for _chunk in stream.aiter_bytes():
                                pass
                        stream_outcome = "clean"
                    except httpx.RemoteProtocolError:
                        stream_outcome = "reset"
            finally:
                requests_done.set()

        tg.start_soon(run_lifespan_and_shutdown)
        tg.start_soon(make_requests)

        with anyio.fail_after(10):
            await shutdown_event.wait()

        # Give the client task a bounded window to observe the closed stream.
        # Cancelling immediately would race with it and could leave
        # ``stream_outcome`` unset even though the stream closed cleanly.
        with anyio.move_on_after(5):
            await requests_done.wait()

        # Safety net: stop anything still running so the test cannot hang.
        tg.cancel_scope.cancel()

    assert stream_outcome == "clean", f"Expected clean HTTP close, got {stream_outcome}"

0 commit comments

Comments
 (0)