diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319..461df3899 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -5,6 +5,20 @@ from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError +def _restore_mcp_error(exc_type: type[MCPError], error: ErrorData) -> MCPError: + """Reconstruct a pickled MCPError or subclass from ErrorData.""" + if exc_type is UrlElicitationRequiredError: + return exc_type.from_error(error) + + if hasattr(exc_type, "from_error_data"): + return exc_type.from_error_data(error) + + restored = exc_type.__new__(exc_type) + Exception.__init__(restored, error.code, error.message, error.data) + restored.error = error + return restored + + class MCPError(Exception): """Exception type raised when an error arrives over an MCP connection.""" @@ -40,6 +54,9 @@ def from_error_data(cls, error: ErrorData) -> MCPError: def __str__(self) -> str: return self.message + def __reduce__(self) -> tuple[Any, tuple[type[MCPError], ErrorData]]: + return (_restore_mcp_error, (type(self), self.error)) + class StatelessModeNotSupported(RuntimeError): """Raised when attempting to use a method that is not supported in stateless mode. diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py index 9a7466264..a4801000e 100644 --- a/tests/shared/test_exceptions.py +++ b/tests/shared/test_exceptions.py @@ -1,5 +1,7 @@ """Tests for MCP exception classes.""" +import pickle + import pytest from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError @@ -162,3 +164,41 @@ def test_url_elicitation_required_error_exception_message() -> None: # The exception's string representation should match the message assert str(error) == "URL elicitation required" + + +def test_mcp_error_pickle_roundtrip() -> None: + """Test that MCPError survives a normal pickle round-trip.""" + original = MCPError( + code=-32600, + message="Authentication Required", + data={"scope": "files.read"}, + ) + + restored = pickle.loads(pickle.dumps(original)) + + assert isinstance(restored, MCPError) + assert restored.code == -32600 + assert restored.message == "Authentication Required" + assert restored.data == {"scope": "files.read"} + assert str(restored) == "Authentication Required" + + +def test_url_elicitation_required_error_pickle_roundtrip() -> None: + """Test that specialized MCPError subclasses survive pickle too.""" + original = UrlElicitationRequiredError( + [ + ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitation_id="test-123", + ) + ] + ) + + restored = pickle.loads(pickle.dumps(original)) + + assert isinstance(restored, UrlElicitationRequiredError) + assert restored.elicitations[0].elicitation_id == "test-123" + assert restored.elicitations[0].url == "https://example.com/auth" + assert restored.message == "URL elicitation required"