|
| 1 | +import asyncio |
1 | 2 | import json |
2 | 3 | import logging |
3 | 4 | from dataclasses import dataclass |
|
12 | 13 | from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata |
13 | 14 | from mcp.server.session import ServerSessionT |
14 | 15 | from mcp.shared.context import LifespanContextT, RequestT |
15 | | -from mcp.types import TextContent, ToolAnnotations |
| 16 | +from mcp.shared.exceptions import McpError |
| 17 | +from mcp.types import REQUEST_TIMEOUT, TextContent, ToolAnnotations |
16 | 18 |
|
17 | 19 |
|
18 | 20 | class TestAddTools: |
@@ -920,3 +922,159 @@ def test_func() -> str: # pragma: no cover |
920 | 922 | # Remove with correct case |
921 | 923 | manager.remove_tool("test_func") |
922 | 924 | assert manager.get_tool("test_func") is None |
| 925 | + |
| 926 | + |
| 927 | +class TestToolTimeout: |
| 928 | + """Test timeout behavior for tool execution.""" |
| 929 | + |
| 930 | + @pytest.mark.anyio |
| 931 | + async def test_tool_timeout_exceeded(self): |
| 932 | + """Test that a slow tool times out and raises McpError with REQUEST_TIMEOUT code.""" |
| 933 | + |
| 934 | + async def slow_tool(duration: float) -> str: # pragma: no cover |
| 935 | + """A tool that sleeps for the specified duration.""" |
| 936 | + await asyncio.sleep(duration) |
| 937 | + return "completed" |
| 938 | + |
| 939 | + manager = ToolManager(timeout_seconds=0.1) # 100ms timeout |
| 940 | + manager.add_tool(slow_tool) |
| 941 | + |
| 942 | + # Tool should timeout after 100ms |
| 943 | + with pytest.raises(McpError) as exc_info: |
| 944 | + await manager.call_tool("slow_tool", {"duration": 1.0}) # Try to sleep for 1 second |
| 945 | + |
| 946 | + # Verify the error code is REQUEST_TIMEOUT |
| 947 | + assert exc_info.value.error.code == REQUEST_TIMEOUT |
| 948 | + assert "slow_tool" in exc_info.value.error.message |
| 949 | + assert "exceeded timeout" in exc_info.value.error.message |
| 950 | + |
| 951 | + @pytest.mark.anyio |
| 952 | + async def test_tool_completes_before_timeout(self): |
| 953 | + """Test that a fast tool completes successfully before timeout.""" |
| 954 | + |
| 955 | + async def fast_tool(value: str) -> str: |
| 956 | + """A tool that completes quickly.""" |
| 957 | + await asyncio.sleep(0.01) # 10ms |
| 958 | + return f"processed: {value}" |
| 959 | + |
| 960 | + manager = ToolManager(timeout_seconds=1.0) # 1 second timeout |
| 961 | + manager.add_tool(fast_tool) |
| 962 | + |
| 963 | + # Tool should complete successfully |
| 964 | + result = await manager.call_tool("fast_tool", {"value": "test"}) |
| 965 | + assert result == "processed: test" |
| 966 | + |
| 967 | + @pytest.mark.anyio |
| 968 | + async def test_tool_without_timeout(self): |
| 969 | + """Test that tools work normally when timeout is None.""" |
| 970 | + |
| 971 | + async def slow_tool(duration: float) -> str: |
| 972 | + """A tool that can take any amount of time.""" |
| 973 | + await asyncio.sleep(duration) |
| 974 | + return "completed" |
| 975 | + |
| 976 | + manager = ToolManager(timeout_seconds=None) # No timeout |
| 977 | + manager.add_tool(slow_tool) |
| 978 | + |
| 979 | + # Tool should complete without timeout even if slow |
| 980 | + result = await manager.call_tool("slow_tool", {"duration": 0.2}) |
| 981 | + assert result == "completed" |
| 982 | + |
| 983 | + @pytest.mark.anyio |
| 984 | + async def test_sync_tool_timeout(self): |
| 985 | + """Test that synchronous tools also respect timeout.""" |
| 986 | + import time |
| 987 | + |
| 988 | + def slow_sync_tool(duration: float) -> str: # pragma: no cover |
| 989 | + """A synchronous tool that sleeps.""" |
| 990 | + time.sleep(duration) |
| 991 | + return "completed" |
| 992 | + |
| 993 | + manager = ToolManager(timeout_seconds=0.1) # 100ms timeout |
| 994 | + manager.add_tool(slow_sync_tool) |
| 995 | + |
| 996 | + # Sync tool should also timeout |
| 997 | + with pytest.raises(McpError) as exc_info: |
| 998 | + await manager.call_tool("slow_sync_tool", {"duration": 1.0}) |
| 999 | + |
| 1000 | + assert exc_info.value.error.code == REQUEST_TIMEOUT |
| 1001 | + |
| 1002 | + @pytest.mark.anyio |
| 1003 | + async def test_timeout_with_context_injection(self): |
| 1004 | + """Test that timeout works correctly with context injection.""" |
| 1005 | + |
| 1006 | + async def slow_tool_with_context( |
| 1007 | + duration: float, ctx: Context[ServerSessionT, None] |
| 1008 | + ) -> str: # pragma: no cover |
| 1009 | + """A tool with context that times out.""" |
| 1010 | + await asyncio.sleep(duration) |
| 1011 | + return "completed" |
| 1012 | + |
| 1013 | + manager = ToolManager(timeout_seconds=0.1) |
| 1014 | + manager.add_tool(slow_tool_with_context) |
| 1015 | + |
| 1016 | + mcp = FastMCP() |
| 1017 | + ctx = mcp.get_context() |
| 1018 | + |
| 1019 | + # Tool should timeout even with context injection |
| 1020 | + with pytest.raises(McpError) as exc_info: |
| 1021 | + await manager.call_tool("slow_tool_with_context", {"duration": 1.0}, context=ctx) |
| 1022 | + |
| 1023 | + assert exc_info.value.error.code == REQUEST_TIMEOUT |
| 1024 | + |
| 1025 | + @pytest.mark.anyio |
| 1026 | + async def test_tool_error_not_confused_with_timeout(self): |
| 1027 | + """Test that regular tool errors are not confused with timeout errors.""" |
| 1028 | + |
| 1029 | + async def failing_tool(should_fail: bool) -> str: |
| 1030 | + """A tool that raises an error.""" |
| 1031 | + if should_fail: |
| 1032 | + raise ValueError("Tool failed intentionally") |
| 1033 | + return "success" |
| 1034 | + |
| 1035 | + manager = ToolManager(timeout_seconds=1.0) |
| 1036 | + manager.add_tool(failing_tool) |
| 1037 | + |
| 1038 | + # Regular errors should still be ToolError, not timeout |
| 1039 | + with pytest.raises(ToolError, match="Error executing tool failing_tool"): |
| 1040 | + await manager.call_tool("failing_tool", {"should_fail": True}) |
| 1041 | + |
| 1042 | + @pytest.mark.anyio |
| 1043 | + async def test_fastmcp_timeout_setting(self): |
| 1044 | + """Test that FastMCP passes timeout setting to ToolManager.""" |
| 1045 | + |
| 1046 | + async def slow_tool() -> str: # pragma: no cover |
| 1047 | + """A slow tool.""" |
| 1048 | + await asyncio.sleep(1.0) |
| 1049 | + return "completed" |
| 1050 | + |
| 1051 | + # Create FastMCP with custom timeout |
| 1052 | + app = FastMCP(tool_timeout_seconds=0.1) |
| 1053 | + |
| 1054 | + @app.tool() |
| 1055 | + async def test_tool() -> str: # pragma: no cover |
| 1056 | + """Test tool.""" |
| 1057 | + await asyncio.sleep(1.0) |
| 1058 | + return "completed" |
| 1059 | + |
| 1060 | + # Tool should timeout based on FastMCP setting |
| 1061 | + with pytest.raises(McpError) as exc_info: |
| 1062 | + await app._tool_manager.call_tool("test_tool", {}) |
| 1063 | + |
| 1064 | + assert exc_info.value.error.code == REQUEST_TIMEOUT |
| 1065 | + |
| 1066 | + @pytest.mark.anyio |
| 1067 | + async def test_fastmcp_no_timeout(self): |
| 1068 | + """Test that FastMCP works with timeout disabled.""" |
| 1069 | + |
| 1070 | + app = FastMCP(tool_timeout_seconds=None) |
| 1071 | + |
| 1072 | + @app.tool() |
| 1073 | + async def slow_tool() -> str: |
| 1074 | + """A slow tool.""" |
| 1075 | + await asyncio.sleep(0.2) |
| 1076 | + return "completed" |
| 1077 | + |
| 1078 | + # Tool should complete without timeout |
| 1079 | + result = await app._tool_manager.call_tool("slow_tool", {}) |
| 1080 | + assert result == "completed" |
0 commit comments