|
2 | 2 |
|
3 | 3 | import anyio |
4 | 4 | import pytest |
| 5 | +from pydantic import BaseModel |
5 | 6 |
|
| 7 | +from mcp import types |
| 8 | +from mcp.client.session import KNOWN_SERVER_REQUEST_METHODS, ClientSession |
6 | 9 | from mcp.server.models import InitializationOptions |
7 | 10 | from mcp.server.session import ServerSession |
8 | 11 | from mcp.shared.message import SessionMessage |
| 12 | +from mcp.shared.session import BaseSession, request_methods_for_union |
9 | 13 | from mcp.types import METHOD_NOT_FOUND, JSONRPCError, JSONRPCRequest, ServerCapabilities |
10 | 14 |
|
11 | 15 |
|
@@ -49,3 +53,64 @@ async def test_invalid_method_returns_method_not_found() -> None: |
49 | 53 | await write_send_stream.aclose() |
50 | 54 | await read_receive_stream.aclose() |
51 | 55 | await write_receive_stream.aclose() |
| 56 | + |
| 57 | + |
| 58 | +class MissingDefaultMethodRequest(BaseModel): |
| 59 | + jsonrpc: str = "2.0" |
| 60 | + id: int = 1 |
| 61 | + method: str |
| 62 | + |
| 63 | + |
| 64 | +def test_request_methods_for_union_ignores_non_literal_defaults() -> None: |
| 65 | + methods = request_methods_for_union(types.ServerRequest | MissingDefaultMethodRequest) |
| 66 | + assert methods == KNOWN_SERVER_REQUEST_METHODS |
| 67 | + |
| 68 | + |
| 69 | +@pytest.mark.anyio |
| 70 | +async def test_client_session_known_request_methods_match_server_request_union() -> None: |
| 71 | + read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) |
| 72 | + write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10) |
| 73 | + |
| 74 | + try: |
| 75 | + session = ClientSession(read_stream=read_receive_stream, write_stream=write_send_stream) |
| 76 | + assert session._known_request_methods == KNOWN_SERVER_REQUEST_METHODS |
| 77 | + finally: # pragma: no cover |
| 78 | + await read_send_stream.aclose() |
| 79 | + await write_send_stream.aclose() |
| 80 | + await read_receive_stream.aclose() |
| 81 | + await write_receive_stream.aclose() |
| 82 | + |
| 83 | + |
| 84 | +class DummyBaseSession( |
| 85 | + BaseSession[ |
| 86 | + types.ClientRequest, |
| 87 | + types.ClientNotification, |
| 88 | + types.ClientResult, |
| 89 | + types.ServerRequest, |
| 90 | + types.ServerNotification, |
| 91 | + ] |
| 92 | +): |
| 93 | + @property |
| 94 | + def _receive_request_adapter(self): |
| 95 | + return types.server_request_adapter |
| 96 | + |
| 97 | + @property |
| 98 | + def _receive_notification_adapter(self): |
| 99 | + return types.server_notification_adapter |
| 100 | + |
| 101 | + |
| 102 | +@pytest.mark.anyio |
| 103 | +async def test_base_session_known_request_methods_default_to_empty() -> None: |
| 104 | + read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) |
| 105 | + write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10) |
| 106 | + |
| 107 | + try: |
| 108 | + session = DummyBaseSession(read_stream=read_receive_stream, write_stream=write_send_stream) |
| 109 | + assert session._known_request_methods == frozenset() |
| 110 | + assert session._receive_request_adapter is types.server_request_adapter |
| 111 | + assert session._receive_notification_adapter is types.server_notification_adapter |
| 112 | + finally: # pragma: no cover |
| 113 | + await read_send_stream.aclose() |
| 114 | + await write_send_stream.aclose() |
| 115 | + await read_receive_stream.aclose() |
| 116 | + await write_receive_stream.aclose() |
0 commit comments