diff --git a/README.md b/README.md index 4cd441e..dfe9c55 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,7 @@ Then start `opencode-a2a` against that upstream: ```bash A2A_BEARER_TOKEN=dev-token \ OPENCODE_BASE_URL=http://127.0.0.1:4096 \ +A2A_TASK_STORE_DATABASE_URL=sqlite+aiosqlite:///./opencode-a2a.db \ A2A_HOST=127.0.0.1 \ A2A_PORT=8000 \ A2A_PUBLIC_URL=http://127.0.0.1:8000 \ diff --git a/docs/guide.md b/docs/guide.md index 1e43dde..33969ee 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -69,6 +69,10 @@ Key variables to understand protocol behavior: `session.abort` in cancel flow. - `OPENCODE_TIMEOUT` / `OPENCODE_TIMEOUT_STREAM`: upstream request timeout and optional stream timeout override. +- `OPENCODE_MAX_CONCURRENT_REQUESTS`: optional fast-fail concurrency limit for + unary/control upstream calls. `0` disables the limit. +- `OPENCODE_MAX_CONCURRENT_STREAMS`: optional fast-fail concurrency limit for + long-lived upstream `/event` streams. `0` disables the limit. - `A2A_CLIENT_TIMEOUT_SECONDS`: outbound client timeout. Default: `30` seconds. - `A2A_CLIENT_CARD_FETCH_TIMEOUT_SECONDS`: outbound Agent Card fetch timeout. Default: `5` seconds. @@ -110,10 +114,10 @@ Current client facade API: - `A2AClient.cancel_task()` - `A2AClient.resubscribe_task()` -Server-side outbound peer calls use bearer auth only for now. Configure -`A2A_CLIENT_BEARER_TOKEN` when the remote agent protects its runtime surface. -CLI outbound calls may pass `--token` explicitly or use -`A2A_CLIENT_BEARER_TOKEN`. +Server-side outbound peer calls read outbound credentials from environment +variables. Configure `A2A_CLIENT_BEARER_TOKEN` or `A2A_CLIENT_BASIC_AUTH` when +the remote agent protects its runtime surface. CLI outbound calls follow the +same environment-only model. Execution-boundary metadata is intentionally declarative deployment metadata: it is published through `RuntimeProfile`, Agent Card, OpenAPI, and `/health`, diff --git a/src/opencode_a2a/config.py b/src/opencode_a2a/config.py index 367fff2..acc2634 100644 --- a/src/opencode_a2a/config.py +++ b/src/opencode_a2a/config.py @@ -80,6 +80,16 @@ class Settings(BaseSettings): opencode_variant: str | None = Field(default=None, alias="OPENCODE_VARIANT") opencode_timeout: float = Field(default=120.0, alias="OPENCODE_TIMEOUT") opencode_timeout_stream: float | None = Field(default=None, alias="OPENCODE_TIMEOUT_STREAM") + opencode_max_concurrent_requests: int = Field( + default=0, + ge=0, + alias="OPENCODE_MAX_CONCURRENT_REQUESTS", + ) + opencode_max_concurrent_streams: int = Field( + default=0, + ge=0, + alias="OPENCODE_MAX_CONCURRENT_STREAMS", + ) # A2A settings a2a_public_url: str = Field(default="http://127.0.0.1:8000", alias="A2A_PUBLIC_URL") diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index 2de13ea..9150c23 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -28,7 +28,11 @@ TextPart, ) -from ..opencode_upstream_client import OpencodeUpstreamClient, UpstreamContractError +from ..opencode_upstream_client import ( + OpencodeUpstreamClient, + UpstreamConcurrencyLimitError, + UpstreamContractError, +) from ..parts.mapping import ( UnsupportedA2AInputError, extract_text_from_a2a_parts, @@ -274,6 +278,17 @@ async def run(self) -> None: error_type="UPSTREAM_PAYLOAD_ERROR", streaming_request=self._prepared.streaming_request, ) + except UpstreamConcurrencyLimitError as exc: + logger.warning("OpenCode request rejected by concurrency budget: %s", exc) + await self._executor._emit_error( + self._event_queue, + task_id=self._task_id, + context_id=self._context_id, + message=str(exc), + state=TaskState.failed, + error_type="UPSTREAM_BACKPRESSURE", + streaming_request=self._prepared.streaming_request, + ) except Exception as exc: logger.exception("OpenCode request failed") await self._executor._emit_error( diff --git a/src/opencode_a2a/jsonrpc/application.py b/src/opencode_a2a/jsonrpc/application.py index 1d1f047..0b2419e 100644 --- a/src/opencode_a2a/jsonrpc/application.py +++ b/src/opencode_a2a/jsonrpc/application.py @@ -21,7 +21,11 @@ PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES, SESSION_QUERY_ERROR_BUSINESS_CODES, ) -from ..opencode_upstream_client import OpencodeUpstreamClient, UpstreamContractError +from ..opencode_upstream_client import ( + OpencodeUpstreamClient, + UpstreamConcurrencyLimitError, + UpstreamContractError, +) from .error_responses import ( interrupt_not_found_error, invalid_params_error, @@ -344,6 +348,14 @@ async def _handle_session_query_request( base_request.id, upstream_unreachable_error(ERR_UPSTREAM_UNREACHABLE), ) + except UpstreamConcurrencyLimitError as exc: + return self._generate_error_response( + base_request.id, + upstream_unreachable_error( + ERR_UPSTREAM_UNREACHABLE, + detail=str(exc), + ), + ) except Exception as exc: logger.exception("OpenCode session query JSON-RPC method failed") return self._generate_error_response( @@ -470,6 +482,15 @@ async def _handle_provider_discovery_request( method=base_request.method, ), ) + except UpstreamConcurrencyLimitError as exc: + return self._generate_error_response( + base_request.id, + upstream_unreachable_error( + ERR_DISCOVERY_UPSTREAM_UNREACHABLE, + method=base_request.method, + detail=str(exc), + ), + ) except Exception as exc: logger.exception("OpenCode provider discovery JSON-RPC method failed") return self._generate_error_response( @@ -702,6 +723,17 @@ def _log_shell_audit(outcome: str) -> None: session_id=session_id, ), ) + except UpstreamConcurrencyLimitError as exc: + _log_shell_audit("upstream_backpressure") + return self._generate_error_response( + base_request.id, + upstream_unreachable_error( + ERR_UPSTREAM_UNREACHABLE, + method=base_request.method, + session_id=session_id, + detail=str(exc), + ), + ) except UpstreamContractError as exc: _log_shell_audit("upstream_payload_error") return self._generate_error_response( @@ -903,6 +935,15 @@ async def _handle_interrupt_callback_request( request_id=request_id, ), ) + except UpstreamConcurrencyLimitError as exc: + return self._generate_error_response( + base_request.id, + upstream_unreachable_error( + ERR_UPSTREAM_UNREACHABLE, + request_id=request_id, + detail=str(exc), + ), + ) except Exception as exc: logger.exception("OpenCode interrupt callback JSON-RPC method failed") return self._generate_error_response( diff --git a/src/opencode_a2a/jsonrpc/error_responses.py b/src/opencode_a2a/jsonrpc/error_responses.py index 0dd0e11..8b9a3ac 100644 --- a/src/opencode_a2a/jsonrpc/error_responses.py +++ b/src/opencode_a2a/jsonrpc/error_responses.py @@ -93,6 +93,7 @@ def upstream_unreachable_error( method: str | None = None, session_id: str | None = None, request_id: str | None = None, + detail: str | None = None, ) -> JSONRPCError: data: dict[str, Any] = {"type": "UPSTREAM_UNREACHABLE"} if method is not None: @@ -101,6 +102,8 @@ def upstream_unreachable_error( data["session_id"] = session_id if request_id is not None: data["request_id"] = request_id + if detail is not None: + data["detail"] = detail return JSONRPCError(code=code, message="Upstream OpenCode unreachable", data=data) diff --git a/src/opencode_a2a/opencode_upstream_client.py b/src/opencode_a2a/opencode_upstream_client.py index 7b64af5..0af00d5 100644 --- a/src/opencode_a2a/opencode_upstream_client.py +++ b/src/opencode_a2a/opencode_upstream_client.py @@ -5,6 +5,7 @@ import logging import time from collections.abc import AsyncIterator, Mapping, Sequence +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any @@ -23,6 +24,19 @@ class UpstreamContractError(RuntimeError): """Raised when upstream returns a shape/status that violates documented contract.""" +class UpstreamConcurrencyLimitError(RuntimeError): + """Raised when the local upstream concurrency budget is exhausted.""" + + def __init__(self, *, category: str, operation: str, limit: int) -> None: + self.category = category + self.operation = operation + self.limit = limit + super().__init__( + f"OpenCode upstream {category} concurrency limit exceeded " + f"while calling {operation} (limit={limit})" + ) + + @dataclass(frozen=True) class OpencodeMessage: text: str @@ -31,6 +45,49 @@ class OpencodeMessage: raw: dict[str, Any] +class _FastFailConcurrencyBudget: + def __init__(self, *, category: str, limit: int) -> None: + self._category = category + self._limit = max(0, int(limit)) + self._inflight = 0 + self._lock = asyncio.Lock() + + @property + def limit(self) -> int: + return self._limit + + @asynccontextmanager + async def reserve(self, *, operation: str) -> AsyncIterator[None]: + if self._limit <= 0: + yield + return + + async with self._lock: + inflight = self._inflight + if inflight >= self._limit: + logger.warning( + "OpenCode upstream concurrency limit exceeded " + "category=%s operation=%s limit=%s inflight=%s", + self._category, + operation, + self._limit, + inflight, + ) + raise UpstreamConcurrencyLimitError( + category=self._category, + operation=operation, + limit=self._limit, + ) + self._inflight += 1 + + try: + yield + finally: + async with self._lock: + if self._inflight > 0: + self._inflight -= 1 + + class OpencodeUpstreamClient: def __init__( self, @@ -59,6 +116,14 @@ def __init__( clock=self._interrupt_request_clock, ) ) + self._request_budget = _FastFailConcurrencyBudget( + category="request", + limit=settings.opencode_max_concurrent_requests, + ) + self._stream_budget = _FastFailConcurrencyBudget( + category="stream", + limit=settings.opencode_max_concurrent_streams, + ) self._client = self._build_http_client(self._base_url) def _sync_interrupt_clock(self) -> None: @@ -113,9 +178,10 @@ async def _get_json( endpoint: str, params: Mapping[str, Any] | None = None, ) -> Any: - response = await self._client.get(path, params=params) - response.raise_for_status() - return self._decode_json_response(response, endpoint=endpoint) + async with self._request_budget.reserve(operation=endpoint): + response = await self._client.get(path, params=params) + response.raise_for_status() + return self._decode_json_response(response, endpoint=endpoint) async def _post_json( self, @@ -131,13 +197,14 @@ async def _post_json( request_kwargs["json"] = json_body if timeout is not _UNSET: request_kwargs["timeout"] = timeout - response = await self._client.post( - path, - params=params, - **request_kwargs, - ) - response.raise_for_status() - return self._decode_json_response(response, endpoint=endpoint) + async with self._request_budget.reserve(operation=endpoint): + response = await self._client.post( + path, + params=params, + **request_kwargs, + ) + response.raise_for_status() + return self._decode_json_response(response, endpoint=endpoint) async def _post_boolean( self, @@ -263,37 +330,38 @@ async def stream_events( self, stop_event: asyncio.Event | None = None, *, directory: str | None = None ) -> AsyncIterator[dict[str, Any]]: params = self._query_params(directory=directory) - async with self._client.stream( - "GET", - "/event", - params=params, - timeout=None, - headers={"Accept": "text/event-stream"}, - ) as response: - response.raise_for_status() - data_lines: list[str] = [] - async for line in response.aiter_lines(): - if stop_event and stop_event.is_set(): - break - if line.startswith(":"): - continue - if line == "": - if not data_lines: + async with self._stream_budget.reserve(operation="/event"): + async with self._client.stream( + "GET", + "/event", + params=params, + timeout=None, + headers={"Accept": "text/event-stream"}, + ) as response: + response.raise_for_status() + data_lines: list[str] = [] + async for line in response.aiter_lines(): + if stop_event and stop_event.is_set(): + break + if line.startswith(":"): continue - payload = "\n".join(data_lines).strip() - data_lines.clear() - if not payload: + if line == "": + if not data_lines: + continue + payload = "\n".join(data_lines).strip() + data_lines.clear() + if not payload: + continue + try: + event = json.loads(payload) + except json.JSONDecodeError: + continue + if isinstance(event, dict): + yield event continue - try: - event = json.loads(payload) - except json.JSONDecodeError: + if line.startswith("data:"): + data_lines.append(line[5:].lstrip()) continue - if isinstance(event, dict): - yield event - continue - if line.startswith("data:"): - data_lines.append(line[5:].lstrip()) - continue async def create_session( self, title: str | None = None, *, directory: str | None = None @@ -344,17 +412,18 @@ async def session_prompt_async( *, directory: str | None = None, ) -> None: - response = await self._client.post( - f"/session/{session_id}/prompt_async", - params=self._query_params(directory=directory), - json=request, - ) - response.raise_for_status() - if response.status_code != 204: - raise UpstreamContractError( - "OpenCode /session/{sessionID}/prompt_async must return 204; " - f"got {response.status_code}" + endpoint = "/session/{sessionID}/prompt_async" + async with self._request_budget.reserve(operation=endpoint): + response = await self._client.post( + f"/session/{session_id}/prompt_async", + params=self._query_params(directory=directory), + json=request, ) + response.raise_for_status() + if response.status_code != 204: + raise UpstreamContractError( + f"OpenCode {endpoint} must return 204; got {response.status_code}" + ) async def session_command( self, diff --git a/tests/config/test_settings.py b/tests/config/test_settings.py index 892c386..74f9c8f 100644 --- a/tests/config/test_settings.py +++ b/tests/config/test_settings.py @@ -29,6 +29,8 @@ def test_settings_valid(): "A2A_INTERRUPT_REQUEST_TOMBSTONE_TTL_SECONDS": "120", "A2A_CANCEL_ABORT_TIMEOUT_SECONDS": "0.75", "A2A_ENABLE_SESSION_SHELL": "true", + "OPENCODE_MAX_CONCURRENT_REQUESTS": "12", + "OPENCODE_MAX_CONCURRENT_STREAMS": "3", "A2A_SANDBOX_MODE": "danger-full-access", "A2A_SANDBOX_FILESYSTEM_SCOPE": "unrestricted", "A2A_SANDBOX_WRITABLE_ROOTS": "/srv/workspaces/alpha,/tmp/opencode", @@ -49,6 +51,8 @@ def test_settings_valid(): assert settings.a2a_interrupt_request_ttl_seconds == 7200.0 assert settings.a2a_interrupt_request_tombstone_ttl_seconds == 120.0 assert settings.a2a_cancel_abort_timeout_seconds == 0.75 + assert settings.opencode_max_concurrent_requests == 12 + assert settings.opencode_max_concurrent_streams == 3 assert settings.a2a_enable_session_shell is True assert settings.a2a_sandbox_mode == "danger-full-access" assert settings.a2a_sandbox_filesystem_scope == "unrestricted" diff --git a/tests/execution/test_agent_errors.py b/tests/execution/test_agent_errors.py index 61fb2e8..5c1889d 100644 --- a/tests/execution/test_agent_errors.py +++ b/tests/execution/test_agent_errors.py @@ -7,7 +7,11 @@ from a2a.types import Task, TaskArtifactUpdateEvent, TaskState, TaskStatusUpdateEvent from opencode_a2a.execution.executor import OpencodeAgentExecutor -from opencode_a2a.opencode_upstream_client import OpencodeMessage, UpstreamContractError +from opencode_a2a.opencode_upstream_client import ( + OpencodeMessage, + UpstreamConcurrencyLimitError, + UpstreamContractError, +) from tests.support.helpers import configure_mock_client_runtime, make_request_context_mock @@ -194,6 +198,56 @@ async def create_session(title: str | None = None, *, directory: str | None = No assert status.metadata["opencode"]["error"]["upstream_status"] == 429 +@pytest.mark.asyncio +async def test_streaming_execute_upstream_backpressure_emits_status_update_with_metadata() -> None: + client = AsyncMock() + + async def create_session(title: str | None = None, *, directory: str | None = None) -> str: + del title, directory + return "ses-1" + + client.create_session = create_session + client.send_message = AsyncMock( + side_effect=UpstreamConcurrencyLimitError( + category="request", + operation="/session/{sessionID}/message", + limit=1, + ) + ) + configure_mock_client_runtime(client, directory="/tmp/workspace") + client.stream_timeout = None + client.directory = None + + executor = OpencodeAgentExecutor(client, streaming_enabled=True) + executor._should_stream = lambda _context: True # noqa: E731 + context = make_request_context_mock( + task_id="task-stream-backpressure", + context_id="ctx-stream-backpressure", + user_input="hello", + call_context_enabled=False, + ) + event_queue = AsyncMock(spec=EventQueue) + + await executor.execute(context, event_queue) + + status = None + for call in event_queue.enqueue_event.call_args_list: + payload = call.args[0] + if ( + isinstance(payload, TaskStatusUpdateEvent) + and payload.final + and payload.metadata is not None + and payload.metadata.get("opencode", {}).get("error", {}).get("type") + == "UPSTREAM_BACKPRESSURE" + ): + status = payload + break + + assert status is not None + assert status.status.state == TaskState.failed + assert status.metadata["opencode"]["error"]["type"] == "UPSTREAM_BACKPRESSURE" + + @pytest.mark.asyncio async def test_execute_upstream_payload_error_maps_to_task_error_type() -> None: client = AsyncMock() @@ -236,6 +290,49 @@ async def create_session(title: str | None = None, *, directory: str | None = No assert "payload mismatch" in event.status.message.parts[0].root.text +@pytest.mark.asyncio +async def test_execute_upstream_backpressure_maps_to_task_error_type() -> None: + client = AsyncMock() + + async def create_session(title: str | None = None, *, directory: str | None = None) -> str: + del title, directory + return "ses-1" + + client.create_session = create_session + client.send_message = AsyncMock( + side_effect=UpstreamConcurrencyLimitError( + category="request", + operation="/session/{sessionID}/message", + limit=1, + ) + ) + configure_mock_client_runtime(client, directory="/tmp/workspace") + + executor = OpencodeAgentExecutor(client, streaming_enabled=False) + context = make_request_context_mock( + task_id="task-backpressure", + context_id="ctx-backpressure", + user_input="hello", + call_context_enabled=False, + ) + event_queue = AsyncMock(spec=EventQueue) + + await executor.execute(context, event_queue) + + event = None + for call in event_queue.enqueue_event.call_args_list: + payload = call.args[0] + if isinstance(payload, Task): + event = payload + break + + assert event is not None + assert event.status.state == TaskState.failed + assert event.metadata is not None + assert event.metadata["opencode"]["error"]["type"] == "UPSTREAM_BACKPRESSURE" + assert "concurrency limit exceeded" in event.status.message.parts[0].root.text + + @pytest.mark.asyncio async def test_execute_response_info_error_maps_to_task_failed_state() -> None: client = AsyncMock() diff --git a/tests/jsonrpc/test_error_responses.py b/tests/jsonrpc/test_error_responses.py index 8ffa84a..59cdee2 100644 --- a/tests/jsonrpc/test_error_responses.py +++ b/tests/jsonrpc/test_error_responses.py @@ -38,6 +38,9 @@ def test_jsonrpc_error_mapping_helpers_preserve_business_contract_fields() -> No def test_jsonrpc_error_mapping_helpers_build_upstream_envelopes() -> None: + backpressure_detail = ( + "OpenCode upstream request concurrency limit exceeded while calling /session (limit=1)" + ) http_error = upstream_http_error( -32003, upstream_status=503, @@ -51,10 +54,15 @@ def test_jsonrpc_error_mapping_helpers_build_upstream_envelopes() -> None: "session_id": "s-1", } - unreachable = upstream_unreachable_error(-32002, request_id="req-1") + unreachable = upstream_unreachable_error( + -32002, + request_id="req-1", + detail=backpressure_detail, + ) assert unreachable.data == { "type": "UPSTREAM_UNREACHABLE", "request_id": "req-1", + "detail": backpressure_detail, } payload_error = upstream_payload_error( diff --git a/tests/jsonrpc/test_opencode_session_extension_interrupts.py b/tests/jsonrpc/test_opencode_session_extension_interrupts.py index 49a8ca4..93a2227 100644 --- a/tests/jsonrpc/test_opencode_session_extension_interrupts.py +++ b/tests/jsonrpc/test_opencode_session_extension_interrupts.py @@ -2,6 +2,7 @@ import pytest from opencode_a2a.config import Settings +from opencode_a2a.opencode_upstream_client import UpstreamConcurrencyLimitError from tests.support.helpers import ( DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) @@ -452,3 +453,55 @@ class InterruptClient(DummyOpencodeUpstreamClient): payload = resp.json() assert payload["error"]["code"] == -32004 assert payload["error"]["data"]["type"] == "INTERRUPT_REQUEST_NOT_FOUND" + + +@pytest.mark.asyncio +async def test_interrupt_callback_extension_maps_concurrency_limit_to_unreachable(monkeypatch): + import opencode_a2a.server.application as app_module + + class BusyInterruptClient(DummyOpencodeUpstreamClient): + async def permission_reply( + self, + request_id: str, + *, + reply: str, + message: str | None = None, + directory: str | None = None, + ) -> bool: + del request_id, reply, message, directory + raise UpstreamConcurrencyLimitError( + category="request", + operation="/permission/{requestID}/reply", + limit=1, + ) + + dummy = BusyInterruptClient( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + await dummy.remember_interrupt_request( + request_id="perm-busy", + session_id="ses-1", + interrupt_type="permission", + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 19, + "method": "a2a.interrupt.permission.reply", + "params": {"request_id": "perm-busy", "reply": "once"}, + }, + ) + payload = resp.json() + assert payload["error"]["code"] == -32002 + assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" + assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] diff --git a/tests/jsonrpc/test_opencode_session_extension_prompt_async.py b/tests/jsonrpc/test_opencode_session_extension_prompt_async.py index 4bb9796..a3a454b 100644 --- a/tests/jsonrpc/test_opencode_session_extension_prompt_async.py +++ b/tests/jsonrpc/test_opencode_session_extension_prompt_async.py @@ -3,7 +3,10 @@ import httpx import pytest -from opencode_a2a.opencode_upstream_client import UpstreamContractError +from opencode_a2a.opencode_upstream_client import ( + UpstreamConcurrencyLimitError, + UpstreamContractError, +) from tests.support.helpers import ( DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) @@ -487,6 +490,46 @@ async def _release_raises(self: SessionManager, *, identity: str, session_id: st ) +@pytest.mark.asyncio +async def test_session_prompt_async_extension_maps_concurrency_limit_to_unreachable(monkeypatch): + import opencode_a2a.server.application as app_module + + class BusyPromptAsyncClient(DummyOpencodeUpstreamClient): + async def session_prompt_async(self, session_id: str, request: dict, *, directory=None): + del session_id, request, directory + raise UpstreamConcurrencyLimitError( + category="request", + operation="/session/{sessionID}/prompt_async", + limit=1, + ) + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", BusyPromptAsyncClient) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 3082, + "method": "opencode.sessions.prompt_async", + "params": { + "session_id": "s-1", + "request": {"parts": [{"type": "text", "text": "x"}]}, + }, + }, + ) + payload = resp.json() + assert payload["error"]["code"] == -32002 + assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" + assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] + + @pytest.mark.asyncio async def test_session_prompt_async_extension_notification_returns_204(monkeypatch): import opencode_a2a.server.application as app_module diff --git a/tests/jsonrpc/test_opencode_session_extension_queries.py b/tests/jsonrpc/test_opencode_session_extension_queries.py index 5fcb292..c2b97ba 100644 --- a/tests/jsonrpc/test_opencode_session_extension_queries.py +++ b/tests/jsonrpc/test_opencode_session_extension_queries.py @@ -8,6 +8,7 @@ SESSION_QUERY_DEFAULT_LIMIT, SESSION_QUERY_MAX_LIMIT, ) +from opencode_a2a.opencode_upstream_client import UpstreamConcurrencyLimitError from tests.support.helpers import ( DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, ) @@ -339,6 +340,43 @@ async def test_provider_discovery_extension_maps_payload_mismatch(monkeypatch): assert payload["error"]["data"]["type"] == "UPSTREAM_PAYLOAD_ERROR" +@pytest.mark.asyncio +async def test_provider_discovery_extension_maps_concurrency_limit_to_unreachable(monkeypatch): + import opencode_a2a.server.application as app_module + + class BusyDiscoveryClient(DummyOpencodeUpstreamClient): + async def list_provider_catalog(self, *, directory: str | None = None): + del directory + raise UpstreamConcurrencyLimitError( + category="request", + operation="/provider", + limit=1, + ) + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", BusyDiscoveryClient) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 141, + "method": "opencode.providers.list", + "params": {}, + }, + ) + payload = resp.json() + assert payload["error"]["code"] == -32002 + assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" + assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] + + @pytest.mark.asyncio async def test_session_query_extension_rejects_non_array_upstream_payload(monkeypatch): import opencode_a2a.server.application as app_module @@ -372,6 +410,38 @@ def __init__(self, _settings: Settings) -> None: assert payload["error"]["data"]["type"] == "UPSTREAM_PAYLOAD_ERROR" +@pytest.mark.asyncio +async def test_session_query_extension_maps_concurrency_limit_to_unreachable(monkeypatch): + import opencode_a2a.server.application as app_module + + class BusySessionQueryClient(DummyOpencodeUpstreamClient): + async def list_sessions(self, *, params=None): + del params + raise UpstreamConcurrencyLimitError( + category="request", + operation="/session", + limit=1, + ) + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", BusySessionQueryClient) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + resp = await client.post( + "/", + headers=headers, + json={"jsonrpc": "2.0", "id": 15, "method": "opencode.sessions.list", "params": {}}, + ) + payload = resp.json() + assert payload["error"]["code"] == -32002 + assert payload["error"]["data"]["type"] == "UPSTREAM_UNREACHABLE" + assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] + + @pytest.mark.asyncio async def test_session_query_extension_session_title_is_extracted_or_placeholder(monkeypatch): import opencode_a2a.server.application as app_module diff --git a/tests/upstream/test_opencode_upstream_client_params.py b/tests/upstream/test_opencode_upstream_client_params.py index 5b64e31..5f9ce52 100644 --- a/tests/upstream/test_opencode_upstream_client_params.py +++ b/tests/upstream/test_opencode_upstream_client_params.py @@ -1,3 +1,4 @@ +import asyncio import json as json_module import httpx @@ -6,6 +7,7 @@ from opencode_a2a.opencode_upstream_client import ( _UNSET, OpencodeUpstreamClient, + UpstreamConcurrencyLimitError, UpstreamContractError, ) from tests.support.helpers import make_settings @@ -36,6 +38,33 @@ def json(self): return self._payload +class _HoldingStreamResponse: + def __init__(self, started: asyncio.Event, release: asyncio.Event) -> None: + self._started = started + self._release = release + + def raise_for_status(self) -> None: + return None + + async def aiter_lines(self): + self._started.set() + await self._release.wait() + yield 'data: {"kind": "tick"}' + yield "" + + +class _HoldingStreamContext: + def __init__(self, started: asyncio.Event, release: asyncio.Event) -> None: + self._response = _HoldingStreamResponse(started, release) + + async def __aenter__(self) -> _HoldingStreamResponse: + return self._response + + async def __aexit__(self, exc_type, exc, tb) -> bool: # noqa: ANN001 + del exc_type, exc, tb + return False + + @pytest.mark.asyncio async def test_merge_params_does_not_allow_directory_override(monkeypatch): client = OpencodeUpstreamClient( @@ -341,6 +370,45 @@ async def fake_post(path: str, *, params=None, json=None, **_kwargs): await client.close() +@pytest.mark.asyncio +async def test_send_message_raises_concurrency_limit_error_when_request_budget_exhausted( + monkeypatch, +): + client = OpencodeUpstreamClient( + make_settings( + a2a_bearer_token="t-1", + opencode_timeout=1.0, + opencode_max_concurrent_requests=1, + a2a_log_level="DEBUG", + a2a_log_payloads=False, + ) + ) + + started = asyncio.Event() + release = asyncio.Event() + + async def fake_post(path: str, *, params=None, json=None, **_kwargs): + del path, params, json + started.set() + await release.wait() + return _DummyResponse({"info": {"id": "m-1"}, "parts": [{"type": "text", "text": "ok"}]}) + + monkeypatch.setattr(client._client, "post", fake_post) + + first_request = asyncio.create_task(client.send_message("ses-1", "hello")) + await started.wait() + + with pytest.raises( + UpstreamConcurrencyLimitError, + match="request concurrency limit exceeded", + ): + await client.send_message("ses-2", "blocked") + + release.set() + await first_request + await client.close() + + @pytest.mark.asyncio async def test_permission_reply_raises_on_404_without_legacy_fallback(monkeypatch): client = OpencodeUpstreamClient( @@ -616,6 +684,45 @@ async def test_interrupt_request_ttl_defaults_to_three_hours_and_is_configurable await configured_client.close() +@pytest.mark.asyncio +async def test_stream_events_raises_concurrency_limit_error_when_stream_budget_exhausted( + monkeypatch, +) -> None: + client = OpencodeUpstreamClient( + make_settings( + a2a_bearer_token="t-1", + opencode_timeout=1.0, + opencode_max_concurrent_streams=1, + a2a_log_level="DEBUG", + a2a_log_payloads=False, + ) + ) + + started = asyncio.Event() + release = asyncio.Event() + + def fake_stream(method: str, path: str, *, params=None, timeout=None, headers=None): + del method, path, params, timeout, headers + return _HoldingStreamContext(started, release) + + monkeypatch.setattr(client._client, "stream", fake_stream) + + first_stream = client.stream_events() + first_event = asyncio.create_task(anext(first_stream)) + await started.wait() + + with pytest.raises( + UpstreamConcurrencyLimitError, + match="stream concurrency limit exceeded", + ): + await anext(client.stream_events()) + + release.set() + assert await first_event == {"kind": "tick"} + await first_stream.aclose() + await client.close() + + def test_response_body_preview_handles_empty_and_long_payloads() -> None: empty = _DummyResponse(text=" ") assert OpencodeUpstreamClient._response_body_preview(empty) == ""