diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319..e06325f2e 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -104,3 +104,6 @@ def from_error(cls, error: ErrorData) -> UrlElicitationRequiredError: raw_elicitations = cast(list[dict[str, Any]], data.get("elicitations", [])) elicitations = [ElicitRequestURLParams.model_validate(e) for e in raw_elicitations] return cls(elicitations, error.message) + + def __reduce__(self) -> tuple[Any, ...]: + return (self.from_error, (self.error,)) diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py index 9a7466264..d9e1f9384 100644 --- a/tests/shared/test_exceptions.py +++ b/tests/shared/test_exceptions.py @@ -1,5 +1,7 @@ """Tests for MCP exception classes.""" +import pickle + import pytest from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError @@ -162,3 +164,34 @@ def test_url_elicitation_required_error_exception_message() -> None: # The exception's string representation should match the message assert str(error) == "URL elicitation required" + + +def test_mcp_error_pickle_roundtrip() -> None: + """Test that MCPError survives a pickle round-trip.""" + original = MCPError(code=-32600, message="Authentication Required") + restored = pickle.loads(pickle.dumps(original)) + + assert type(restored) is MCPError + assert restored.code == -32600 + assert restored.message == "Authentication Required" + assert restored.error.code == -32600 + assert restored.error.message == "Authentication Required" + + +def test_url_elicitation_required_error_pickle_roundtrip() -> None: + """Test that UrlElicitationRequiredError survives a pickle round-trip.""" + elicitation = ElicitRequestURLParams( + mode="url", + message="Auth required", + url="https://example.com/auth", + elicitation_id="test-123", + ) + original = UrlElicitationRequiredError([elicitation]) + restored = pickle.loads(pickle.dumps(original)) + + assert type(restored) is UrlElicitationRequiredError + assert restored.message == "URL elicitation required" + assert restored.error.code == URL_ELICITATION_REQUIRED + assert len(restored.elicitations) == 1 + assert restored.elicitations[0].elicitation_id == "test-123" + assert restored.elicitations[0].url == "https://example.com/auth"