From 693675030728d12576decd0e89a85b1b3cc3ca93 Mon Sep 17 00:00:00 2001 From: lawrence3699 Date: Sat, 18 Apr 2026 12:12:58 +1000 Subject: [PATCH] fix: preserve ProcessError context from failed CLI subprocesses --- src/claude_agent_sdk/_internal/query.py | 10 +- .../_internal/transport/subprocess_cli.py | 88 +++++++++++++++--- tests/test_query.py | 35 +++++++ tests/test_subprocess_buffering.py | 91 ++++++++++++++++++- 4 files changed, 207 insertions(+), 17 deletions(-) diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index 30ad2257..b29bb047 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -15,6 +15,7 @@ ListToolsRequest, ) +from .._errors import ClaudeSDKError from ..types import ( PermissionMode, PermissionResultAllow, @@ -261,7 +262,9 @@ async def _read_messages(self) -> None: self.pending_control_results[request_id] = e event.set() # Put error in stream so iterators can handle it - await self._message_send.send({"type": "error", "error": str(e)}) + await self._message_send.send( + {"type": "error", "error": str(e), "exception": e} + ) finally: # Unblock any waiters (e.g. string-prompt path waiting for first # result) so they don't stall for the full timeout on early exit. @@ -737,7 +740,10 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: if message.get("type") == "end": break elif message.get("type") == "error": - raise Exception(message.get("error", "Unknown error")) + exc = message.get("exception") + if isinstance(exc, Exception): + raise exc + raise ClaudeSDKError(message.get("error", "Unknown error")) yield message diff --git a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py index 983eeea1..ae6422d6 100644 --- a/src/claude_agent_sdk/_internal/transport/subprocess_cli.py +++ b/src/claude_agent_sdk/_internal/transport/subprocess_cli.py @@ -1,11 +1,13 @@ """Subprocess transport implementation using Claude Code CLI.""" +import inspect import json import logging import os import platform import re import shutil +from collections import deque from collections.abc import AsyncIterable, AsyncIterator from contextlib import suppress from pathlib import Path @@ -26,6 +28,7 @@ logger = logging.getLogger(__name__) _DEFAULT_MAX_BUFFER_SIZE = 1024 * 1024 # 1MB buffer limit +_STDERR_CAPTURE_LIMIT = 8 * 1024 # Keep the tail of stderr for ProcessError. MINIMUM_CLAUDE_CODE_VERSION = "2.0.0" @@ -50,7 +53,11 @@ def __init__( self._stdout_stream: TextReceiveStream | None = None self._stdin_stream: TextSendStream | None = None self._stderr_stream: TextReceiveStream | None = None + self._stderr_process_stream: Any | None = None self._stderr_task_group: anyio.abc.TaskGroup | None = None + self._stderr_reader_finished: anyio.Event | None = None + self._stderr_buffer: deque[str] = deque() + self._stderr_buffer_size = 0 self._ready = False self._exit_error: Exception | None = None # Track process exit errors self._max_buffer_size = ( @@ -436,20 +443,18 @@ async def connect(self) -> None: if self._cwd: process_env["PWD"] = self._cwd - # Pipe stderr if we have a callback OR debug mode is enabled + # Pipe stderr so ProcessError can include the CLI's actual stderr. + # We only stream it live when callers asked for callback/debug output. should_pipe_stderr = ( self._options.stderr is not None or "debug-to-stderr" in self._options.extra_args ) - # For backward compat: use debug_stderr file object if no callback and debug is on - stderr_dest = PIPE if should_pipe_stderr else None - self._process = await anyio.open_process( cmd, stdin=PIPE, stdout=PIPE, - stderr=stderr_dest, + stderr=PIPE, cwd=self._cwd, env=process_env, user=self._options.user, @@ -458,9 +463,14 @@ async def connect(self) -> None: if self._process.stdout: self._stdout_stream = TextReceiveStream(self._process.stdout) - # Setup stderr stream if piped - if should_pipe_stderr and self._process.stderr: - self._stderr_stream = TextReceiveStream(self._process.stderr) + # Setup stderr stream if available. + if self._process.stderr: + self._stderr_process_stream = self._process.stderr + if self._stderr_process_stream and ( + should_pipe_stderr or self._supports_live_stderr_reading() + ): + self._stderr_stream = TextReceiveStream(self._stderr_process_stream) + self._stderr_reader_finished = anyio.Event() # Start async task to read stderr self._stderr_task_group = anyio.create_task_group() await self._stderr_task_group.__aenter__() @@ -488,6 +498,35 @@ async def connect(self) -> None: self._exit_error = error raise error from e + def _capture_stderr_line(self, line: str) -> None: + """Keep a bounded tail of stderr for ProcessError diagnostics.""" + self._stderr_buffer.append(line) + self._stderr_buffer_size += len(line) + 1 + + while self._stderr_buffer and self._stderr_buffer_size > _STDERR_CAPTURE_LIMIT: + removed = self._stderr_buffer.popleft() + self._stderr_buffer_size -= len(removed) + 1 + + async def _drain_stderr_stream(self, stream: AsyncIterator[str]) -> None: + """Read any remaining stderr lines into the capture buffer.""" + try: + async for line in stream: + line_str = line.rstrip() + if line_str: + self._capture_stderr_line(line_str) + except anyio.ClosedResourceError: + pass + except Exception: + pass + + def _supports_live_stderr_reading(self) -> bool: + """Return True when the stderr stream can be consumed concurrently.""" + if self._stderr_process_stream is None: + return False + + receive = getattr(self._stderr_process_stream, "receive", None) + return inspect.iscoroutinefunction(receive) + async def _handle_stderr(self) -> None: """Handle stderr stream - read and invoke callbacks.""" if not self._stderr_stream: @@ -499,15 +538,15 @@ async def _handle_stderr(self) -> None: if not line_str: continue + self._capture_stderr_line(line_str) + # Call the stderr callback if provided if self._options.stderr: self._options.stderr(line_str) - # For backward compatibility: write to debug_stderr if in debug mode - elif ( - "debug-to-stderr" in self._options.extra_args - and self._options.debug_stderr - ): + # Preserve inherited-stderr behavior by forwarding stderr to + # the configured sink (defaults to sys.stderr). + elif self._options.debug_stderr: self._options.debug_stderr.write(line_str + "\n") if hasattr(self._options.debug_stderr, "flush"): self._options.debug_stderr.flush() @@ -515,6 +554,9 @@ async def _handle_stderr(self) -> None: pass # Stream closed, exit normally except Exception: pass # Ignore other errors during stderr reading + finally: + if self._stderr_reader_finished is not None: + self._stderr_reader_finished.set() async def close(self) -> None: """Close the transport and clean up resources.""" @@ -528,6 +570,7 @@ async def close(self) -> None: self._stderr_task_group.cancel_scope.cancel() await self._stderr_task_group.__aexit__(None, None, None) self._stderr_task_group = None + self._stderr_reader_finished = None # Close stdin stream (acquire lock to prevent race with concurrent writes) async with self._write_lock: @@ -541,6 +584,7 @@ async def close(self) -> None: with suppress(Exception): await self._stderr_stream.aclose() self._stderr_stream = None + self._stderr_process_stream = None # Wait for graceful shutdown after stdin EOF, then terminate if needed. # The subprocess needs time to flush its session file after receiving @@ -568,6 +612,10 @@ async def close(self) -> None: self._stdout_stream = None self._stdin_stream = None self._stderr_stream = None + self._stderr_process_stream = None + self._stderr_reader_finished = None + self._stderr_buffer.clear() + self._stderr_buffer_size = 0 self._exit_error = None async def write(self, data: str) -> None: @@ -678,10 +726,22 @@ async def _read_messages_impl(self) -> AsyncIterator[dict[str, Any]]: # Use exit code for error detection if returncode is not None and returncode != 0: + if self._stderr_task_group is None: + if self._stderr_stream is not None: + await self._drain_stderr_stream(self._stderr_stream) + elif self._stderr_process_stream is not None: + stderr_stream = TextReceiveStream(self._stderr_process_stream) + await self._drain_stderr_stream(stderr_stream) + with suppress(Exception): + await stderr_stream.aclose() + elif self._stderr_reader_finished is not None: + await self._stderr_reader_finished.wait() + + stderr_output = "\n".join(self._stderr_buffer) or None self._exit_error = ProcessError( f"Command failed with exit code {returncode}", exit_code=returncode, - stderr="Check stderr output for details", + stderr=stderr_output, ) raise self._exit_error diff --git a/tests/test_query.py b/tests/test_query.py index 1fcb4dea..2b1a53fd 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -12,6 +12,7 @@ from unittest.mock import AsyncMock, Mock, patch import anyio +import pytest from claude_agent_sdk import ( AssistantMessage, @@ -21,6 +22,7 @@ query, tool, ) +from claude_agent_sdk._errors import ProcessError from claude_agent_sdk._internal.query import Query from claude_agent_sdk.types import HookMatcher @@ -726,3 +728,36 @@ async def _test(): assert "fast_1" not in q._inflight_requests asyncio.run(_test()) + + def test_receive_messages_preserves_process_error(self): + """read task failures should surface the original ProcessError.""" + import asyncio + + async def _test(): + transport = AsyncMock() + transport.is_ready = Mock(return_value=True) + transport.close = AsyncMock() + + async def failing_messages(): + raise ProcessError( + "Command failed with exit code 1", + exit_code=1, + stderr="invalid --model alias", + ) + yield # pragma: no cover + + transport.read_messages = failing_messages + + q = Query(transport=transport, is_streaming_mode=True) + await q.start() + + with pytest.raises(ProcessError) as exc_info: + async for _ in q.receive_messages(): + pass + + await q.close() + + assert exc_info.value.exit_code == 1 + assert exc_info.value.stderr == "invalid --model alias" + + asyncio.run(_test()) diff --git a/tests/test_subprocess_buffering.py b/tests/test_subprocess_buffering.py index dc3ca45d..c0ead4d8 100644 --- a/tests/test_subprocess_buffering.py +++ b/tests/test_subprocess_buffering.py @@ -2,13 +2,16 @@ import json from collections.abc import AsyncIterator +from io import StringIO +from subprocess import PIPE from typing import Any from unittest.mock import AsyncMock, MagicMock import anyio import pytest +from anyio.streams.text import TextReceiveStream -from claude_agent_sdk._errors import CLIJSONDecodeError +from claude_agent_sdk._errors import CLIJSONDecodeError, ProcessError from claude_agent_sdk._internal.transport.subprocess_cli import ( _DEFAULT_MAX_BUFFER_SIZE, SubprocessCLITransport, @@ -383,3 +386,89 @@ async def _test() -> None: assert messages[1]["type"] == "result" anyio.run(_test) + + def test_nonzero_exit_includes_captured_stderr(self) -> None: + """ProcessError should surface the stderr emitted by the CLI.""" + + async def _test() -> None: + transport = SubprocessCLITransport(prompt="test", options=make_options()) + + mock_process = MagicMock() + mock_process.returncode = None + mock_process.wait = AsyncMock(return_value=1) + transport._process = mock_process + transport._stdout_stream = MockTextReceiveStream([]) + transport._stderr_stream = MockTextReceiveStream( + ["error: invalid --model alias", "hint: run claude --help"] + ) + + with pytest.raises(ProcessError) as exc_info: + async for _ in transport.read_messages(): + pass + + assert exc_info.value.exit_code == 1 + assert exc_info.value.stderr == ( + "error: invalid --model alias\nhint: run claude --help" + ) + + anyio.run(_test) + + def test_stderr_is_forwarded_to_sink_while_buffering(self) -> None: + """Captured stderr should still be forwarded to the configured sink.""" + + async def _test() -> None: + sink = StringIO() + transport = SubprocessCLITransport( + prompt="test", options=make_options(debug_stderr=sink) + ) + transport._stderr_stream = MockTextReceiveStream(["warning: deprecated flag"]) + + await transport._handle_stderr() + + assert sink.getvalue() == "warning: deprecated flag\n" + + anyio.run(_test) + + def test_nonzero_exit_waits_for_live_stderr_reader(self) -> None: + """ProcessError should include stderr captured by the background reader.""" + + async def _test() -> None: + import sys + + process = await anyio.open_process( + [ + sys.executable, + "-c", + ( + "import sys; " + "sys.stderr.write('error: invalid --model alias\\n'); " + "sys.stderr.write('hint: run claude --help\\n'); " + "sys.exit(1)" + ), + ], + stdout=PIPE, + stderr=PIPE, + ) + + transport = SubprocessCLITransport(prompt="test", options=make_options()) + transport._process = process + transport._stdout_stream = TextReceiveStream(process.stdout) + transport._stderr_process_stream = process.stderr + transport._stderr_stream = TextReceiveStream(process.stderr) + transport._stderr_reader_finished = anyio.Event() + transport._stderr_task_group = anyio.create_task_group() + await transport._stderr_task_group.__aenter__() + transport._stderr_task_group.start_soon(transport._handle_stderr) + + with pytest.raises(ProcessError) as exc_info: + async for _ in transport.read_messages(): + pass + + assert exc_info.value.exit_code == 1 + assert exc_info.value.stderr == ( + "error: invalid --model alias\nhint: run claude --help" + ) + + await transport.close() + + anyio.run(_test)