diff --git a/src/claude_agent_sdk/__init__.py b/src/claude_agent_sdk/__init__.py index 6403415b..6b255b51 100644 --- a/src/claude_agent_sdk/__init__.py +++ b/src/claude_agent_sdk/__init__.py @@ -24,6 +24,7 @@ CLIJSONDecodeError, CLINotFoundError, ProcessError, + RateLimitError, ) from ._internal.session_mutations import ( ForkSessionResult, @@ -604,4 +605,5 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> Any: "CLINotFoundError", "ProcessError", "CLIJSONDecodeError", + "RateLimitError", ] diff --git a/src/claude_agent_sdk/_errors.py b/src/claude_agent_sdk/_errors.py index c86bf235..8d086690 100644 --- a/src/claude_agent_sdk/_errors.py +++ b/src/claude_agent_sdk/_errors.py @@ -54,3 +54,17 @@ class MessageParseError(ClaudeSDKError): def __init__(self, message: str, data: dict[str, Any] | None = None): self.data = data super().__init__(message) + + +class RateLimitError(ClaudeSDKError): + """Raised when the API returns a 429 rate limit error.""" + + def __init__( + self, + message: str, + retry_after: float | None = None, + original_error: Exception | None = None, + ): + self.retry_after = retry_after + self.original_error = original_error + super().__init__(message) diff --git a/src/claude_agent_sdk/_internal/client.py b/src/claude_agent_sdk/_internal/client.py index 76323323..0c6a40df 100644 --- a/src/claude_agent_sdk/_internal/client.py +++ b/src/claude_agent_sdk/_internal/client.py @@ -1,11 +1,16 @@ """Internal client implementation.""" +import asyncio import json +import logging import os +import random +import re from collections.abc import AsyncIterable, AsyncIterator from dataclasses import asdict, replace from typing import Any +from .._errors import ProcessError, RateLimitError from ..types import ( ClaudeAgentOptions, HookEvent, @@ -17,6 +22,36 @@ from .transport import Transport from .transport.subprocess_cli import SubprocessCLITransport +logger = logging.getLogger(__name__) + + +def _is_rate_limit_error(error: Exception) -> tuple[bool, float | None]: + """Detect if an error is a 429 rate limit error.""" + error_str = str(error) + + if "rate_limit_error" in error_str or "429" in error_str: + retry_after: float | None = None + match = re.search(r'"retryAfter"\s*:\s*(\d+(?:\.\d+)?)', error_str) + if match: + retry_after = float(match.group(1)) + else: + match = re.search( + r'Retry-After["\s:]+(\d+(?:\.\d+)?)', error_str, re.IGNORECASE + ) + if match: + retry_after = float(match.group(1)) + return True, retry_after + + if hasattr(error, "stderr") and error.stderr: + stderr_str = str(error.stderr) + if "rate_limit_error" in stderr_str or "429" in stderr_str: + match = re.search(r'"retryAfter"\s*:\s*(\d+(?:\.\d+)?)', stderr_str) + if match: + return True, float(match.group(1)) + return True, None + + return False, None + class InternalClient: """Internal client implementation.""" @@ -48,8 +83,7 @@ async def process_query( options: ClaudeAgentOptions, transport: Transport | None = None, ) -> AsyncIterator[Message]: - """Process a query through transport and Query.""" - + """Process a query through transport and Query with automatic 429 retry.""" # Validate and configure permission settings (matching TypeScript SDK logic) configured_options = options if options.can_use_tool: @@ -70,94 +104,142 @@ async def process_query( # Automatically set permission_prompt_tool_name to "stdio" for control protocol configured_options = replace(options, permission_prompt_tool_name="stdio") - # Use provided transport or create subprocess transport - if transport is not None: - chosen_transport = transport - else: - chosen_transport = SubprocessCLITransport( - prompt=prompt, - options=configured_options, - ) + max_retries = configured_options.rate_limit_max_retries + attempt = 0 + + while True: + is_retry = attempt > 0 + chosen_transport: Transport + query: Query | None = None + + try: + # Use provided transport or create subprocess transport + if transport is not None: + chosen_transport = transport + else: + chosen_transport = SubprocessCLITransport( + prompt=prompt, + options=configured_options, + ) + await chosen_transport.connect() + + # Extract SDK MCP servers from configured options + sdk_mcp_servers = {} + if configured_options.mcp_servers and isinstance( + configured_options.mcp_servers, dict + ): + for name, config in configured_options.mcp_servers.items(): + if isinstance(config, dict) and config.get("type") == "sdk": + sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item] + + # Extract exclude_dynamic_sections from preset system prompt for the + # initialize request (older CLIs ignore unknown initialize fields). + exclude_dynamic_sections: bool | None = None + sp = configured_options.system_prompt + if isinstance(sp, dict) and sp.get("type") == "preset": + eds = sp.get("exclude_dynamic_sections") + if isinstance(eds, bool): + exclude_dynamic_sections = eds + + # Convert agents to dict format for initialize request + agents_dict = None + if configured_options.agents: + agents_dict = { + name: { + k: v for k, v in asdict(agent_def).items() if v is not None + } + for name, agent_def in configured_options.agents.items() + } + + # Match ClaudeSDKClient.connect() — without this, query() ignores the env var + initialize_timeout_ms = int( + os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000") + ) + initialize_timeout = max(initialize_timeout_ms / 1000.0, 60.0) + + # Create Query to handle control protocol + # Always use streaming mode internally (matching TypeScript SDK) + # This ensures agents are always sent via initialize request + query = Query( + transport=chosen_transport, + is_streaming_mode=True, # Always streaming internally + can_use_tool=configured_options.can_use_tool, + hooks=self._convert_hooks_to_internal_format( + configured_options.hooks + ) + if configured_options.hooks + else None, + sdk_mcp_servers=sdk_mcp_servers, + initialize_timeout=initialize_timeout, + agents=agents_dict, + exclude_dynamic_sections=exclude_dynamic_sections, + ) - # Connect transport - await chosen_transport.connect() - - # Extract SDK MCP servers from configured options - sdk_mcp_servers = {} - if configured_options.mcp_servers and isinstance( - configured_options.mcp_servers, dict - ): - for name, config in configured_options.mcp_servers.items(): - if isinstance(config, dict) and config.get("type") == "sdk": - sdk_mcp_servers[name] = config["instance"] # type: ignore[typeddict-item] - - # Extract exclude_dynamic_sections from preset system prompt for the - # initialize request (older CLIs ignore unknown initialize fields). - exclude_dynamic_sections: bool | None = None - sp = configured_options.system_prompt - if isinstance(sp, dict) and sp.get("type") == "preset": - eds = sp.get("exclude_dynamic_sections") - if isinstance(eds, bool): - exclude_dynamic_sections = eds - - # Convert agents to dict format for initialize request - agents_dict = None - if configured_options.agents: - agents_dict = { - name: {k: v for k, v in asdict(agent_def).items() if v is not None} - for name, agent_def in configured_options.agents.items() - } - - # Match ClaudeSDKClient.connect() — without this, query() ignores the env var - initialize_timeout_ms = int( - os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000") - ) - initialize_timeout = max(initialize_timeout_ms / 1000.0, 60.0) - - # Create Query to handle control protocol - # Always use streaming mode internally (matching TypeScript SDK) - # This ensures agents are always sent via initialize request - query = Query( - transport=chosen_transport, - is_streaming_mode=True, # Always streaming internally - can_use_tool=configured_options.can_use_tool, - hooks=self._convert_hooks_to_internal_format(configured_options.hooks) - if configured_options.hooks - else None, - sdk_mcp_servers=sdk_mcp_servers, - initialize_timeout=initialize_timeout, - agents=agents_dict, - exclude_dynamic_sections=exclude_dynamic_sections, - ) - - try: - # Start reading messages - await query.start() - - # Always initialize to send agents via stdin (matching TypeScript SDK) - await query.initialize() - - # Handle prompt input - if isinstance(prompt, str): - # For string prompts, write user message to stdin after initialize - # (matching TypeScript SDK behavior) - user_message = { - "type": "user", - "session_id": "", - "message": {"role": "user", "content": prompt}, - "parent_tool_use_id": None, - } - await chosen_transport.write(json.dumps(user_message) + "\n") - query.spawn_task(query.wait_for_result_and_end_input()) - elif isinstance(prompt, AsyncIterable): - # Stream input in background for async iterables - query.spawn_task(query.stream_input(prompt)) - - # Yield parsed messages, skipping unknown message types - async for data in query.receive_messages(): - message = parse_message(data) - if message is not None: - yield message - - finally: - await query.close() + # Start reading messages + # Start reading messages + await query.start() + + # Always initialize to send agents via stdin (matching TypeScript SDK) + await query.initialize() + + # Handle prompt input + if isinstance(prompt, str): + # For string prompts, write user message to stdin after initialize + # (matching TypeScript SDK behavior) + user_message = { + "type": "user", + "session_id": "", + "message": {"role": "user", "content": prompt}, + "parent_tool_use_id": None, + } + await chosen_transport.write(json.dumps(user_message) + "\n") + query.spawn_task(query.wait_for_result_and_end_input()) + elif isinstance(prompt, AsyncIterable): + # Stream input in background for async iterables + query.spawn_task(query.stream_input(prompt)) + + # Yield parsed messages, skipping unknown message types + async for data in query.receive_messages(): + message = parse_message(data) + if message is not None: + yield message + + return + + except ProcessError as e: + is_rl, retry_after = _is_rate_limit_error(e) + + if is_rl and attempt < max_retries: + attempt += 1 + if retry_after is None: + base_delay = min(2.0 * (2 ** (attempt - 1)), 60.0) + delay = base_delay + random.uniform(0, 1) + else: + delay = retry_after + + logger.warning( + "Rate limit hit (attempt %d/%d). Retrying in %.1fs.", + attempt, + max_retries, + delay, + ) + + if query is not None: + await query.close() + elif chosen_transport is not None: + await chosen_transport.close() + + await asyncio.sleep(delay) + continue + + if is_rl: + raise RateLimitError( + str(e), + retry_after=retry_after, + original_error=e, + ) from e + raise + + finally: + if query is not None: + await query.close() diff --git a/src/claude_agent_sdk/types.py b/src/claude_agent_sdk/types.py index a82a8b9b..f0b5a4d1 100644 --- a/src/claude_agent_sdk/types.py +++ b/src/claude_agent_sdk/types.py @@ -1246,6 +1246,10 @@ class ClaudeAgentOptions: # its remaining token budget so it can pace tool use and wrap up before # the limit. task_budget: TaskBudget | None = None + # Maximum number of automatic retries on 429 rate limit errors (default: 3). + # Set to 0 to disable automatic retries. Each retry waits with exponential + # backoff. When a Retry-After header is present, that value is used instead. + rate_limit_max_retries: int = 3 # SDK Control Protocol diff --git a/tests/test_rate_limit_retry.py b/tests/test_rate_limit_retry.py new file mode 100644 index 00000000..4c9aaf25 --- /dev/null +++ b/tests/test_rate_limit_retry.py @@ -0,0 +1,131 @@ +"""Tests for 429 rate limit retry with exponential backoff.""" + +from claude_agent_sdk import ClaudeAgentOptions +from claude_agent_sdk._errors import ProcessError, RateLimitError + + +class TestIsRateLimitError: + """Test _is_rate_limit_error helper function.""" + + def test_detects_rate_limit_in_error_message(self): + """Error message containing rate_limit_error is detected.""" + from claude_agent_sdk._internal.client import _is_rate_limit_error + + error = Exception( + 'API Error: 429 {"type":"error","error":{"type":"rate_limit_error","message":"Rate limit exceeded"}}' + ) + is_rl, retry_after = _is_rate_limit_error(error) + assert is_rl is True + assert retry_after is None + + def test_detects_429_in_error_message(self): + """Error message containing 429 is detected.""" + from claude_agent_sdk._internal.client import _is_rate_limit_error + + error = Exception("Command failed with exit code 1: 429 rate limit") + is_rl, retry_after = _is_rate_limit_error(error) + assert is_rl is True + assert retry_after is None + + def test_parses_retry_after_from_error_message(self): + """retryAfter field is extracted from error message.""" + from claude_agent_sdk._internal.client import _is_rate_limit_error + + error = Exception('{"error":{"type":"rate_limit_error","retryAfter":45}}') + is_rl, retry_after = _is_rate_limit_error(error) + assert is_rl is True + assert retry_after == 45.0 + + def test_parses_retry_after_float(self): + """retryAfter field handles float values.""" + from claude_agent_sdk._internal.client import _is_rate_limit_error + + error = Exception('{"error":{"type":"rate_limit_error","retryAfter":12.5}}') + is_rl, retry_after = _is_rate_limit_error(error) + assert is_rl is True + assert retry_after == 12.5 + + def test_detects_rate_limit_from_stderr_attribute(self): + """Error with stderr attribute containing rate_limit_error is detected.""" + from claude_agent_sdk._internal.client import _is_rate_limit_error + + class ErrorWithStderr(Exception): + stderr = ( + '{"type":"error","error":{"type":"rate_limit_error","retryAfter":30}}' + ) + + error = ErrorWithStderr("Process failed") + is_rl, retry_after = _is_rate_limit_error(error) + assert is_rl is True + assert retry_after == 30.0 + + def test_non_rate_limit_error_returns_false(self): + """Generic errors without rate limit indicators return False.""" + from claude_agent_sdk._internal.client import _is_rate_limit_error + + error = Exception("Something went wrong") + is_rl, retry_after = _is_rate_limit_error(error) + assert is_rl is False + assert retry_after is None + + def test_process_error_without_rate_limit_returns_false(self): + """ProcessError without rate limit indicators returns False.""" + from claude_agent_sdk._internal.client import _is_rate_limit_error + + error = ProcessError("Process exited with code 1", exit_code=1) + is_rl, retry_after = _is_rate_limit_error(error) + assert is_rl is False + assert retry_after is None + + +class TestRateLimitRetryOptions: + """Test rate_limit_max_retries option.""" + + def test_rate_limit_max_retries_option(self): + """ClaudeAgentOptions accepts rate_limit_max_retries.""" + opts = ClaudeAgentOptions(rate_limit_max_retries=5) + assert opts.rate_limit_max_retries == 5 + + def test_rate_limit_max_retries_default(self): + """rate_limit_max_retries defaults to 3.""" + opts = ClaudeAgentOptions() + assert opts.rate_limit_max_retries == 3 + + def test_rate_limit_max_retries_zero_disables_retry(self): + """rate_limit_max_retries=0 disables automatic retry.""" + opts = ClaudeAgentOptions(rate_limit_max_retries=0) + assert opts.rate_limit_max_retries == 0 + + +class TestRateLimitError: + """Test RateLimitError exception class.""" + + def test_rate_limit_error_attributes(self): + """RateLimitError stores retry_after and original_error.""" + original = ProcessError("429", exit_code=1) + error = RateLimitError( + "Rate limit exceeded", retry_after=30.0, original_error=original + ) + + assert error.retry_after == 30.0 + assert error.original_error is original + assert "Rate limit exceeded" in str(error) + + def test_rate_limit_error_inherits_from_claude_sdk_error(self): + """RateLimitError is a subclass of ClaudeSDKError.""" + from claude_agent_sdk._errors import ClaudeSDKError + + error = RateLimitError("test") + assert isinstance(error, ClaudeSDKError) + + def test_rate_limit_error_without_retry_after(self): + """RateLimitError works without retry_after.""" + error = RateLimitError("Rate limit exceeded") + assert error.retry_after is None + assert error.original_error is None + + def test_rate_limit_error_repr(self): + """RateLimitError message includes original error info.""" + original = ProcessError("429", exit_code=1) + error = RateLimitError("Rate limit exceeded", original_error=original) + assert "Rate limit exceeded" in str(error)