diff --git a/databricks-mcp-server/databricks_mcp_server/middleware.py b/databricks-mcp-server/databricks_mcp_server/middleware.py index 71514694..129b26ff 100644 --- a/databricks-mcp-server/databricks_mcp_server/middleware.py +++ b/databricks-mcp-server/databricks_mcp_server/middleware.py @@ -9,6 +9,7 @@ import logging import traceback +from fastmcp.exceptions import ToolError from fastmcp.server.middleware import Middleware, MiddlewareContext, CallNext from fastmcp.tools.tool import ToolResult from mcp.types import CallToolRequestParams, TextContent @@ -70,24 +71,24 @@ async def on_call_tool( # In Python 3.11+, asyncio.TimeoutError is an alias for TimeoutError, # so this single handler catches both logger.warning( - "Tool '%s' timed out. Returning structured result.", + "Tool '%s' timed out. Raising ToolError.", tool_name, ) - # Don't set structured_content for errors - it would be validated against - # the tool's outputSchema and fail (error dict doesn't match expected type) - return ToolResult( - content=[TextContent(type="text", text=json.dumps({ - "error": True, - "error_type": "timeout", - "tool": tool_name, - "message": str(e) or "Operation timed out", - "action_required": ( - "Operation may still be in progress. " - "Do NOT retry the same call. " - "Use the appropriate get/status tool to check current state." - ), - }))] - ) + # Raise ToolError so the MCP SDK sets isError=True on the response, + # which bypasses outputSchema validation. Returning a ToolResult here + # would be treated as a success and fail validation when outputSchema + # is defined (e.g., tools with -> Dict[str, Any] return type). + raise ToolError(json.dumps({ + "error": True, + "error_type": "timeout", + "tool": tool_name, + "message": str(e) or "Operation timed out", + "action_required": ( + "Operation may still be in progress. " + "Do NOT retry the same call. " + "Use the appropriate get/status tool to check current state." + ), + })) from e except anyio.get_cancelled_exc_class(): # Re-raise CancelledError so MCP SDK's handler catches it and skips @@ -110,14 +111,13 @@ async def on_call_tool( traceback.format_exc(), ) - # Return error as text content only - don't set structured_content. - # Setting structured_content would cause MCP SDK to validate it against - # the tool's outputSchema, which fails (error dict doesn't match expected type). - return ToolResult( - content=[TextContent(type="text", text=json.dumps({ - "error": True, - "error_type": type(e).__name__, - "tool": tool_name, - "message": str(e), - }))] - ) + # Raise ToolError so the MCP SDK sets isError=True on the response, + # which bypasses outputSchema validation. Returning a ToolResult here + # would be treated as a success and fail validation when outputSchema + # is defined (e.g., tools with -> Dict[str, Any] return type). + raise ToolError(json.dumps({ + "error": True, + "error_type": type(e).__name__, + "tool": tool_name, + "message": str(e), + })) from e diff --git a/databricks-mcp-server/tests/test_middleware.py b/databricks-mcp-server/tests/test_middleware.py index da6ab3b6..0dabbdff 100644 --- a/databricks-mcp-server/tests/test_middleware.py +++ b/databricks-mcp-server/tests/test_middleware.py @@ -6,6 +6,8 @@ import pytest +from fastmcp.exceptions import ToolError + from databricks_mcp_server.middleware import TimeoutHandlingMiddleware @@ -36,17 +38,15 @@ async def test_normal_call_passes_through(middleware): @pytest.mark.asyncio -async def test_timeout_error_returns_structured_result(middleware): - """TimeoutError is caught and converted to a structured JSON result.""" +async def test_timeout_error_raises_tool_error(middleware): + """TimeoutError is caught and re-raised as ToolError with structured JSON.""" call_next = AsyncMock(side_effect=TimeoutError("Run did not complete within 3600 seconds")) ctx = _make_context(tool_name="wait_for_run") - result = await middleware.on_call_tool(ctx, call_next) - - assert result is not None - assert len(result.content) == 1 + with pytest.raises(ToolError) as exc_info: + await middleware.on_call_tool(ctx, call_next) - payload = json.loads(result.content[0].text) + payload = json.loads(str(exc_info.value)) assert payload["error"] is True assert payload["error_type"] == "timeout" assert payload["tool"] == "wait_for_run" @@ -55,46 +55,40 @@ async def test_timeout_error_returns_structured_result(middleware): @pytest.mark.asyncio -async def test_asyncio_timeout_error_returns_structured_result(middleware): - """asyncio.TimeoutError is caught and converted to a structured JSON result.""" +async def test_asyncio_timeout_error_raises_tool_error(middleware): + """asyncio.TimeoutError is caught and re-raised as ToolError with structured JSON.""" call_next = AsyncMock(side_effect=asyncio.TimeoutError()) ctx = _make_context(tool_name="long_running_tool") - result = await middleware.on_call_tool(ctx, call_next) + with pytest.raises(ToolError) as exc_info: + await middleware.on_call_tool(ctx, call_next) - assert result is not None - payload = json.loads(result.content[0].text) + payload = json.loads(str(exc_info.value)) assert payload["error"] is True assert payload["error_type"] == "timeout" assert payload["tool"] == "long_running_tool" @pytest.mark.asyncio -async def test_cancelled_error_returns_structured_result(middleware): - """asyncio.CancelledError is caught and converted to a structured JSON result.""" +async def test_cancelled_error_is_reraised(middleware): + """asyncio.CancelledError is re-raised to let MCP SDK handle cleanup.""" call_next = AsyncMock(side_effect=asyncio.CancelledError()) ctx = _make_context(tool_name="cancelled_tool") - result = await middleware.on_call_tool(ctx, call_next) - - assert result is not None - payload = json.loads(result.content[0].text) - assert payload["error"] is True - assert payload["error_type"] == "cancelled" - assert payload["tool"] == "cancelled_tool" + with pytest.raises(asyncio.CancelledError): + await middleware.on_call_tool(ctx, call_next) @pytest.mark.asyncio -async def test_generic_exception_returns_structured_result(middleware): - """Generic exceptions are caught and converted to structured JSON results.""" +async def test_generic_exception_raises_tool_error(middleware): + """Generic exceptions are caught and re-raised as ToolError with structured JSON.""" call_next = AsyncMock(side_effect=ValueError("bad input")) ctx = _make_context(tool_name="failing_tool") - result = await middleware.on_call_tool(ctx, call_next) + with pytest.raises(ToolError) as exc_info: + await middleware.on_call_tool(ctx, call_next) - # Should return a ToolResult, not raise - assert result is not None - payload = json.loads(result.content[0].text) + payload = json.loads(str(exc_info.value)) assert payload["error"] is True assert payload["error_type"] == "ValueError" assert payload["tool"] == "failing_tool"