Skip to content

Commit bd5faee

Browse files
committed
Fix stdio_server closing process stdio
1 parent 3d7b311 commit bd5faee

2 files changed

Lines changed: 72 additions & 9 deletions

File tree

src/mcp/server/stdio.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ async def run_server():
1717
```
1818
"""
1919

20+
import os
2021
import sys
2122
from contextlib import asynccontextmanager
22-
from io import TextIOWrapper
23+
from io import TextIOWrapper, UnsupportedOperation
2324

2425
import anyio
2526
import anyio.lowlevel
@@ -29,6 +30,20 @@ async def run_server():
2930
from mcp.shared.message import SessionMessage
3031

3132

33+
def _wrap_stdio_text_stream(stream: TextIOWrapper, mode: str, errors: str = "strict") -> anyio.AsyncFile[str]:
34+
"""Wrap a stdio text stream without closing the original handle on teardown."""
35+
try:
36+
wrapped_stream = TextIOWrapper(
37+
os.fdopen(os.dup(stream.fileno()), mode, closefd=True),
38+
encoding="utf-8",
39+
errors=errors,
40+
)
41+
except (AttributeError, UnsupportedOperation):
42+
wrapped_stream = TextIOWrapper(stream.buffer, encoding="utf-8", errors=errors)
43+
44+
return anyio.wrap_file(wrapped_stream)
45+
46+
3247
@asynccontextmanager
3348
async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.AsyncFile[str] | None = None):
3449
"""Server transport for stdio: this communicates with an MCP client by reading
@@ -38,10 +53,13 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.
3853
# standard process handles. Encoding of stdin/stdout as text streams on
3954
# python is platform-dependent (Windows is particularly problematic), so we
4055
# re-wrap the underlying binary stream to ensure UTF-8.
56+
close_stdin = stdin is None
57+
close_stdout = stdout is None
58+
4159
if not stdin:
42-
stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace"))
60+
stdin = _wrap_stdio_text_stream(sys.stdin, "rb", errors="replace")
4361
if not stdout:
44-
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
62+
stdout = _wrap_stdio_text_stream(sys.stdout, "wb")
4563

4664
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
4765
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
@@ -71,7 +89,13 @@ async def stdout_writer():
7189
except anyio.ClosedResourceError: # pragma: no cover
7290
await anyio.lowlevel.checkpoint()
7391

74-
async with anyio.create_task_group() as tg:
75-
tg.start_soon(stdin_reader)
76-
tg.start_soon(stdout_writer)
77-
yield read_stream, write_stream
92+
try:
93+
async with anyio.create_task_group() as tg:
94+
tg.start_soon(stdin_reader)
95+
tg.start_soon(stdout_writer)
96+
yield read_stream, write_stream
97+
finally:
98+
if close_stdin:
99+
await stdin.aclose()
100+
if close_stdout:
101+
await stdout.aclose()

tests/server/test_stdio.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import sys
3+
import tempfile
34
from io import TextIOWrapper
45

56
import anyio
@@ -73,12 +74,15 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
7374
"""
7475
# \xff\xfe are invalid UTF-8 start bytes.
7576
valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
76-
raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n")
77+
raw_stdin = tempfile.TemporaryFile("w+b")
78+
raw_stdin.write(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n")
79+
raw_stdin.seek(0)
80+
raw_stdout = tempfile.TemporaryFile("w+b")
7781

7882
# Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that
7983
# stdio_server()'s default path wraps it with errors='replace'.
8084
monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8"))
81-
monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8"))
85+
monkeypatch.setattr(sys, "stdout", TextIOWrapper(raw_stdout, encoding="utf-8"))
8286

8387
with anyio.fail_after(5):
8488
async with stdio_server() as (read_stream, write_stream):
@@ -92,3 +96,38 @@ async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch):
9296
second = await read_stream.receive()
9397
assert isinstance(second, SessionMessage)
9498
assert second.message == valid
99+
100+
sys.stdin.close()
101+
sys.stdout.close()
102+
103+
104+
@pytest.mark.anyio
105+
async def test_stdio_server_does_not_close_process_stdio(monkeypatch: pytest.MonkeyPatch):
106+
"""Default stdio_server() teardown must not close the caller's stdio handles."""
107+
valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
108+
raw_stdin = tempfile.TemporaryFile("w+b")
109+
raw_stdin.write(valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n")
110+
raw_stdin.seek(0)
111+
raw_stdout = tempfile.TemporaryFile("w+b")
112+
113+
monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8"))
114+
monkeypatch.setattr(sys, "stdout", TextIOWrapper(raw_stdout, encoding="utf-8"))
115+
116+
with anyio.fail_after(5):
117+
async with stdio_server() as (read_stream, write_stream):
118+
await write_stream.aclose()
119+
async with read_stream: # pragma: no branch
120+
received = await read_stream.receive()
121+
assert isinstance(received, SessionMessage)
122+
assert received.message == valid
123+
124+
assert not sys.stdin.closed
125+
assert not sys.stdout.closed
126+
127+
sys.stdout.write("still-open")
128+
sys.stdout.flush()
129+
raw_stdout.seek(0)
130+
assert raw_stdout.read() == b"still-open"
131+
132+
sys.stdin.close()
133+
sys.stdout.close()

0 commit comments

Comments
 (0)