Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,10 @@ async def _read_messages(self) -> None:
if request_id not in self.pending_control_results:
self.pending_control_results[request_id] = e
event.set()
# Put error in stream so iterators can handle it
await self._message_send.send({"type": "error", "error": str(e)})
# Put error in stream so iterators can handle the original exception
await self._message_send.send(
{"type": "error", "error": str(e), "exception": e}
)
finally:
# Unblock any waiters (e.g. string-prompt path waiting for first
# result) so they don't stall for the full timeout on early exit.
Expand Down Expand Up @@ -729,6 +731,9 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]:
if message.get("type") == "end":
break
elif message.get("type") == "error":
original_exception = message.get("exception")
if isinstance(original_exception, BaseException):
raise original_exception
raise Exception(message.get("error", "Unknown error"))

yield message
Expand Down
65 changes: 65 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from unittest.mock import AsyncMock, Mock, patch

import anyio
import pytest

from claude_agent_sdk import (
AssistantMessage,
ClaudeAgentOptions,
ProcessError,
ResultMessage,
create_sdk_mcp_server,
query,
Expand Down Expand Up @@ -146,6 +148,37 @@ async def greet_tool(args):
return create_sdk_mcp_server("greeter", tools=[greet_tool])


def test_receive_messages_preserves_process_error_attributes():
"""Query.receive_messages re-raises ProcessError with its metadata intact."""

async def _test():
original_error = ProcessError(
"Command failed with exit code 1",
exit_code=1,
stderr="real stderr output",
)
mock_transport = AsyncMock()

async def mock_receive():
raise original_error
yield # pragma: no cover

mock_transport.read_messages = mock_receive
query_instance = Query(transport=mock_transport, is_streaming_mode=True)

await query_instance._read_messages()

with pytest.raises(ProcessError) as exc_info:
async for _message in query_instance.receive_messages():
pass

assert exc_info.value is original_error
assert exc_info.value.exit_code == 1
assert exc_info.value.stderr == "real stderr output"

anyio.run(_test)


class TestStringPromptWithSdkMcpServers:
"""Test that string prompts keep stdin open for SDK MCP servers."""

Expand Down Expand Up @@ -649,6 +682,38 @@ async def mock_write(data):

asyncio.run(_test())


def test_receive_messages_preserves_process_error_fields():
"""Query.receive_messages re-raises the original ProcessError."""

original_error = ProcessError(
"Command failed with exit code 1",
exit_code=1,
stderr="model alias is invalid",
)

async def read_messages():
if False:
yield {}
raise original_error

async def _test():
mock_transport = AsyncMock()
mock_transport.read_messages = read_messages
q = Query(transport=mock_transport, is_streaming_mode=True)

await q._read_messages()

with pytest.raises(ProcessError) as exc_info:
async for _ in q.receive_messages():
pass

assert exc_info.value is original_error
assert exc_info.value.exit_code == 1
assert exc_info.value.stderr == "model alias is invalid"

anyio.run(_test)

def test_cancel_request_for_unknown_id_is_noop(self):
"""A control_cancel_request for an unknown request_id should not raise."""
import asyncio
Expand Down