|
4 | 4 | from mcp import Client, types |
5 | 5 | from mcp.client.session import ClientSession |
6 | 6 | from mcp.server import Server, ServerRequestContext |
| 7 | +from mcp.shared._context import RequestContext |
7 | 8 | from mcp.shared.exceptions import MCPError |
8 | 9 | from mcp.shared.memory import create_client_server_memory_streams |
9 | 10 | from mcp.shared.message import SessionMessage |
@@ -416,3 +417,216 @@ async def make_request(client_session: ClientSession): |
416 | 417 | # Pending request completed successfully |
417 | 418 | assert len(result_holder) == 1 |
418 | 419 | assert isinstance(result_holder[0], EmptyResult) |
| 420 | + |
| 421 | + |
@pytest.mark.anyio
async def test_concurrent_server_to_client_requests_run_in_parallel():
    """Regression test for #2489.

    A server tool fans out N concurrent ``ServerSession.create_message`` calls
    via ``anyio.create_task_group``. The client sampling callback records the
    peak number of concurrently-in-flight calls. Before the fix, requests were
    serialized end-to-end by ``BaseSession._receive_loop`` and peak was 1.
    """
    n = 4

    # Shared state mutated only by the sampling callback (single event loop,
    # so no locking is needed): current overlap count and its observed peak.
    inflight = 0
    peak = 0
    started = anyio.Event()  # set once all n callbacks overlap simultaneously
    release = anyio.Event()  # lets the parked callbacks finish and return

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult:
        nonlocal inflight, peak
        inflight += 1
        peak = max(peak, inflight)
        if peak == n:
            started.set()
        try:
            # Park until the test proves parallelism; a serialized dispatch
            # loop would deadlock here, so the timeout turns that into a
            # failure instead of a hang.
            with anyio.fail_after(5):
                await release.wait()
        finally:
            inflight -= 1
        # Echo the request text back so the server side can correlate results.
        msg = params.messages[0].content
        echo = msg.text if isinstance(msg, types.TextContent) else ""
        return types.CreateMessageResult(
            role="assistant",
            content=types.TextContent(type="text", text=f"echo:{echo}"),
            model="test-model",
        )

    async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
        # Fan out n server->client sampling requests concurrently; each slot
        # in `results` is owned by exactly one task.
        results: list[str] = [""] * n

        async def one(i: int) -> None:
            r = await ctx.session.create_message(
                messages=[
                    types.SamplingMessage(
                        role="user",
                        content=types.TextContent(type="text", text=str(i)),
                    )
                ],
                max_tokens=8,
            )
            results[i] = r.content.text if isinstance(r.content, types.TextContent) else ""

        async with anyio.create_task_group() as tg:  # pragma: no branch
            for i in range(n):
                tg.start_soon(one, i)
        return types.CallToolResult(content=[types.TextContent(type="text", text=",".join(results))])

    async def handle_list_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        return types.ListToolsResult(tools=[types.Tool(name="fanout", input_schema={"type": "object"})])

    server = Server(name="fanout", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)

    async with Client(server, sampling_callback=sampling_callback) as client:
        async with anyio.create_task_group() as tg:  # pragma: no branch

            async def call() -> None:
                await client.call_tool("fanout", {})

            # Run the tool call in the background so this task can wait for
            # the parallelism signal and then unblock the callbacks.
            tg.start_soon(call)
            with anyio.fail_after(5):
                await started.wait()
            release.set()

    assert peak == n, f"server->client requests were serialized: peak in-flight={peak}, expected {n}"
| 499 | + |
| 500 | + |
@pytest.mark.anyio
async def test_sampling_callback_exception_returns_error_response():
    """A raising sampling callback must produce a JSON-RPC error response so
    the server-side ``await ctx.session.create_message(...)`` doesn't hang.
    """

    async def sampling_callback(
        context: RequestContext[ClientSession],
        params: types.CreateMessageRequestParams,
    ) -> types.CreateMessageResult:
        raise RuntimeError("boom")

    # Errors observed by the server-side tool handler.
    errors: list[MCPError] = []

    async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
        request = [
            types.SamplingMessage(
                role="user",
                content=types.TextContent(type="text", text="x"),
            )
        ]
        try:
            await ctx.session.create_message(messages=request, max_tokens=8)
        except MCPError as exc:
            errors.append(exc)
        return types.CallToolResult(content=[types.TextContent(type="text", text="ok")])

    async def handle_list_tools(
        ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
    ) -> types.ListToolsResult:
        tool = types.Tool(name="boom", input_schema={"type": "object"})
        return types.ListToolsResult(tools=[tool])

    server = Server(name="raise", on_call_tool=handle_call_tool, on_list_tools=handle_list_tools)

    async with Client(server, sampling_callback=sampling_callback) as client:
        with anyio.fail_after(5):
            await client.call_tool("boom", {})

    # Exactly one MCPError should have surfaced on the server side.
    assert len(errors) == 1
| 542 | + |
| 543 | + |
@pytest.mark.anyio
async def test_double_cancel_does_not_send_second_response():
    """Cancel called twice on the same responder must not emit a second response."""

    class _CountingSession:
        # Stand-in session that just counts outgoing responses.
        def __init__(self) -> None:
            self.calls = 0

        async def _send_response(self, *, request_id: types.RequestId, response: object) -> None:
            self.calls += 1

    session = _CountingSession()
    responder = RequestResponder[types.ServerRequest, types.ClientResult](
        request_id=1,
        request_meta=None,
        request=types.PingRequest(method="ping"),
        session=session,  # type: ignore[arg-type]
        on_complete=lambda _r: None,
    )
    with responder:
        await responder.cancel()
        await responder.cancel()
    # Only the first cancel may send an error response for the request id.
    assert session.calls == 1
| 566 | + |
| 567 | + |
@pytest.mark.anyio
async def test_cancel_before_context_entered_marks_scope_cancelled():
    """Regression: with concurrent dispatch, a CancelledNotification can
    arrive before the handler task has entered ``with responder:``.
    ``cancel()`` must not raise, and the scope entered later must already
    be cancelled.
    """

    class _NullSession:
        # No-op session: cancellation here must not actually send anything
        # we care about.
        async def _send_response(self, *, request_id: types.RequestId, response: object) -> None:
            pass

    responder = RequestResponder[types.ServerRequest, types.ClientResult](
        request_id=7,
        request_meta=None,
        request=types.PingRequest(method="ping"),
        session=_NullSession(),  # type: ignore[arg-type]
        on_complete=lambda _r: None,
    )

    # Cancel before the responder's context manager was ever entered.
    await responder.cancel()

    assert responder.cancelled
    assert responder._cancel_scope.cancel_called
| 591 | + |
| 592 | + |
@pytest.mark.anyio
async def test_handler_that_responds_then_raises_emits_no_duplicate_error():
    """If a request handler completes the response and then raises, the
    dispatch path must not emit a second JSON-RPC error for the same id.
    """

    class _RespondThenExplode(ClientSession):
        # Handler that answers the request successfully, then raises.
        async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
            with responder:
                await responder.respond(types.EmptyResult())
            raise RuntimeError("after respond")

    class _RecordingStream:
        # Captures every outgoing SessionMessage instead of writing it.
        def __init__(self) -> None:
            self.sent: list[SessionMessage] = []

        async def send(self, msg: SessionMessage) -> None:
            self.sent.append(msg)

    async with create_client_server_memory_streams() as (client_streams, _server_streams):
        read_stream, write_stream = client_streams
        session = _RespondThenExplode(read_stream, write_stream)

        # Swap in the recording stream so we can inspect exactly what the
        # dispatch path emits.
        outbox = _RecordingStream()
        session._write_stream = outbox  # type: ignore[assignment]

        responder = RequestResponder[types.ServerRequest, types.ClientResult](
            request_id=99,
            request_meta=None,
            request=types.PingRequest(method="ping"),
            session=session,
            on_complete=lambda r: session._in_flight.pop(r.request_id, None),
        )
        session._in_flight[99] = responder

        await session._dispatch_request(responder)

        # Exactly one message: the success response — no trailing error.
        assert len(outbox.sent) == 1, outbox.sent
        assert isinstance(outbox.sent[0].message, JSONRPCResponse)
        assert outbox.sent[0].message.id == 99
0 commit comments