From 4c0a071baac794483740e053d1ec671113ac9b7b Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Thu, 26 Mar 2026 09:30:03 -0400 Subject: [PATCH] add optional client polling fallback Refs #310 --- docs/guide.md | 15 ++++ src/opencode_a2a/client/client.py | 81 +++++++++++++++++- src/opencode_a2a/client/config.py | 81 ++++++++++++++++++ src/opencode_a2a/client/polling.py | 68 +++++++++++++++ tests/client/test_client_config.py | 20 +++++ tests/client/test_client_facade.py | 127 ++++++++++++++++++++++++++++- tests/client/test_polling.py | 33 ++++++++ 7 files changed, 421 insertions(+), 4 deletions(-) create mode 100644 src/opencode_a2a/client/polling.py create mode 100644 tests/client/test_polling.py diff --git a/docs/guide.md b/docs/guide.md index 33969ee..099be4a 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -119,6 +119,21 @@ variables. Configure `A2A_CLIENT_BEARER_TOKEN` or `A2A_CLIENT_BASIC_AUTH` when the remote agent protects its runtime surface. CLI outbound calls follow the same environment-only model. +`A2AClient.send()` returns the latest response event and keeps the default +stream-first behavior. If a peer returns a non-terminal task snapshot and +expects follow-up `tasks/get` polling, enable the optional facade fallback +with: + +- `A2A_CLIENT_POLLING_FALLBACK_ENABLED=true` +- `A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS` +- `A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS` +- `A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER` +- `A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS` + +The fallback only applies to `send()`, keeps `send_message()` as a thin event +stream wrapper, and stops polling once the task reaches a terminal state or a +caller-intervention state such as `input-required` or `auth-required`. + Execution-boundary metadata is intentionally declarative deployment metadata: it is published through `RuntimeProfile`, Agent Card, OpenAPI, and `/health`, and should not be interpreted as a live per-request privilege snapshot or a diff --git a/src/opencode_a2a/client/client.py b/src/opencode_a2a/client/client.py index 5349159..fbf5c06 100644 --- a/src/opencode_a2a/client/client.py +++ b/src/opencode_a2a/client/client.py @@ -32,8 +32,9 @@ map_agent_card_error, map_operation_error, ) -from .errors import A2AUnsupportedBindingError +from .errors import A2ATimeoutError, A2AUnsupportedBindingError from .payload_text import extract_text as extract_text_from_payload +from .polling import PollingFallbackPolicy from .request_context import build_call_context, build_client_interceptors, split_request_metadata @@ -58,6 +59,13 @@ def __init__( self._lock = asyncio.Lock() self._request_lock = asyncio.Lock() self._active_requests = 0 + self._polling_fallback_policy = PollingFallbackPolicy( + enabled=self._settings.polling_fallback_enabled, + initial_interval_seconds=self._settings.polling_fallback_initial_interval_seconds, + max_interval_seconds=self._settings.polling_fallback_max_interval_seconds, + backoff_multiplier=self._settings.polling_fallback_backoff_multiplier, + timeout_seconds=self._settings.polling_fallback_timeout_seconds, + ) async def close(self) -> None: """Close cached client resources and owned HTTP transport.""" @@ -149,7 +157,12 @@ async def send( metadata: Mapping[str, Any] | None = None, extensions: list[str] | None = None, ) -> Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None: - """Send a message and return the terminal response/event.""" + """Send a message and return the latest response event. + + When polling fallback is enabled, a non-terminal `(Task, None)` result may + be followed by bounded `tasks/get` polling until a terminal task snapshot + is observed. + """ last_event: ( Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None ) = None @@ -162,7 +175,13 @@ async def send( extensions=extensions, ): last_event = event - return last_event + if not self._should_poll_after_send(last_event): + return last_event + terminal_task = await self._poll_task_until_terminal( + self._extract_task_from_client_event(last_event), + metadata=metadata, + ) + return (terminal_task, None) async def get_task( self, @@ -299,6 +318,62 @@ async def _release_operation(self) -> None: if self._active_requests > 0: self._active_requests -= 1 + def _should_poll_after_send( + self, + event: Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None, + ) -> bool: + if not self._polling_fallback_policy.enabled: + return False + if event is None or isinstance(event, Message) or not isinstance(event, tuple): + return False + task, update = event + if update is not None: + return False + return self._polling_fallback_policy.should_poll_state(task.status.state) + + def _extract_task_from_client_event( + self, + event: Message | tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None] | None, + ) -> Task: + task, _update = cast( + tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None], + event, + ) + return task + + async def _poll_task_until_terminal( + self, + task: Task, + *, + metadata: Mapping[str, Any] | None = None, + ) -> Task: + deadline = self._current_time() + self._polling_fallback_policy.timeout_seconds + interval = self._polling_fallback_policy.initial_interval_seconds + current_task = task + + while True: + if self._polling_fallback_policy.is_terminal_state(current_task.status.state): + return current_task + if not self._polling_fallback_policy.should_poll_state(current_task.status.state): + return current_task + + remaining = deadline - self._current_time() + if remaining <= 0: + raise A2ATimeoutError( + "Remote A2A peer did not reach a terminal task state " + "before polling fallback timed out" + ) + + await self._sleep(min(interval, remaining)) + current_task = await self.get_task(current_task.id, metadata=metadata) + interval = self._polling_fallback_policy.next_interval_seconds(interval) + + def _current_time(self) -> float: + return asyncio.get_running_loop().time() + + async def _sleep(self, delay_seconds: float) -> None: + await asyncio.sleep(delay_seconds) + def _build_user_message( self, *, diff --git a/src/opencode_a2a/client/config.py b/src/opencode_a2a/client/config.py index a48f1b6..b202674 100644 --- a/src/opencode_a2a/client/config.py +++ b/src/opencode_a2a/client/config.py @@ -7,6 +7,7 @@ from typing import Any from .auth import validate_basic_auth +from .polling import PollingFallbackPolicy, validate_polling_fallback_policy def _read_setting( @@ -113,6 +114,11 @@ class A2AClientSettings: "JSONRPC", "HTTP+JSON", ) + polling_fallback_enabled: bool = False + polling_fallback_initial_interval_seconds: float = 0.5 + polling_fallback_max_interval_seconds: float = 2.0 + polling_fallback_backoff_multiplier: float = 2.0 + polling_fallback_timeout_seconds: float = 10.0 def load_settings(raw_settings: Any) -> A2AClientSettings: @@ -177,6 +183,76 @@ def load_settings(raw_settings: Any) -> A2AClientSettings: ), default=("JSONRPC", "HTTP+JSON"), ) + polling_fallback_enabled = _coerce_bool( + "A2A_CLIENT_POLLING_FALLBACK_ENABLED", + _read_setting( + raw_settings, + keys=( + "A2A_CLIENT_POLLING_FALLBACK_ENABLED", + "a2a_client_polling_fallback_enabled", + ), + default=False, + ), + default=False, + ) + polling_fallback_initial_interval_seconds = _coerce_float( + "A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS", + _read_setting( + raw_settings, + keys=( + "A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS", + "a2a_client_polling_fallback_initial_interval_seconds", + ), + default=0.5, + ), + default=0.5, + ) + polling_fallback_max_interval_seconds = _coerce_float( + "A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS", + _read_setting( + raw_settings, + keys=( + "A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS", + "a2a_client_polling_fallback_max_interval_seconds", + ), + default=2.0, + ), + default=2.0, + ) + polling_fallback_backoff_multiplier = _coerce_float( + "A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER", + _read_setting( + raw_settings, + keys=( + "A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER", + "a2a_client_polling_fallback_backoff_multiplier", + ), + default=2.0, + ), + default=2.0, + ) + polling_fallback_timeout_seconds = _coerce_float( + "A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS", + _read_setting( + raw_settings, + keys=( + "A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS", + "a2a_client_polling_fallback_timeout_seconds", + ), + default=10.0, + ), + default=10.0, + ) + + validate_polling_fallback_policy( + PollingFallbackPolicy( + enabled=polling_fallback_enabled, + initial_interval_seconds=polling_fallback_initial_interval_seconds, + max_interval_seconds=polling_fallback_max_interval_seconds, + backoff_multiplier=polling_fallback_backoff_multiplier, + timeout_seconds=polling_fallback_timeout_seconds, + ) + ) return A2AClientSettings( default_timeout=default_timeout, @@ -185,6 +261,11 @@ def load_settings(raw_settings: Any) -> A2AClientSettings: bearer_token=bearer_token, basic_auth=basic_auth, supported_transports=supported_transports, + polling_fallback_enabled=polling_fallback_enabled, + polling_fallback_initial_interval_seconds=polling_fallback_initial_interval_seconds, + polling_fallback_max_interval_seconds=polling_fallback_max_interval_seconds, + polling_fallback_backoff_multiplier=polling_fallback_backoff_multiplier, + polling_fallback_timeout_seconds=polling_fallback_timeout_seconds, ) diff --git a/src/opencode_a2a/client/polling.py b/src/opencode_a2a/client/polling.py new file mode 100644 index 0000000..1f38b0c --- /dev/null +++ b/src/opencode_a2a/client/polling.py @@ -0,0 +1,68 @@ +"""Polling fallback policy helpers for the A2A client facade.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from a2a.types import TaskState + +_TERMINAL_TASK_STATES = frozenset( + { + TaskState.completed, + TaskState.canceled, + TaskState.failed, + TaskState.rejected, + } +) +_AUTO_POLLING_TASK_STATES = frozenset( + { + TaskState.submitted, + TaskState.working, + TaskState.unknown, + } +) + + +@dataclass(frozen=True) +class PollingFallbackPolicy: + """Encapsulates polling fallback configuration and task-state rules.""" + + enabled: bool = False + initial_interval_seconds: float = 0.5 + max_interval_seconds: float = 2.0 + backoff_multiplier: float = 2.0 + timeout_seconds: float = 10.0 + + def should_poll_state(self, state: TaskState) -> bool: + return state in _AUTO_POLLING_TASK_STATES + + def is_terminal_state(self, state: TaskState) -> bool: + return state in _TERMINAL_TASK_STATES + + def next_interval_seconds(self, current_interval_seconds: float) -> float: + return min( + max(current_interval_seconds, 0.0) * self.backoff_multiplier, + self.max_interval_seconds, + ) + + +def validate_polling_fallback_policy(policy: PollingFallbackPolicy) -> None: + """Validate polling fallback settings before they are used at runtime.""" + if policy.initial_interval_seconds <= 0: + raise ValueError("A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS must be positive") + if policy.max_interval_seconds <= 0: + raise ValueError("A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS must be positive") + if policy.backoff_multiplier < 1.0: + raise ValueError( + "A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER must be greater than or equal to 1" + ) + if policy.timeout_seconds <= 0: + raise ValueError("A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS must be positive") + if policy.max_interval_seconds < policy.initial_interval_seconds: + raise ValueError( + "A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS must be greater than or " + "equal to A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS" + ) + + +__all__ = ["PollingFallbackPolicy", "validate_polling_fallback_policy"] diff --git a/tests/client/test_client_config.py b/tests/client/test_client_config.py index ef990db..04daaf6 100644 --- a/tests/client/test_client_config.py +++ b/tests/client/test_client_config.py @@ -21,6 +21,11 @@ def test_load_settings_from_mapping() -> None: "A2A_CLIENT_BEARER_TOKEN": "peer-token", "A2A_CLIENT_BASIC_AUTH": "user:pass", "A2A_CLIENT_SUPPORTED_TRANSPORTS": "json-rpc,http-json", + "A2A_CLIENT_POLLING_FALLBACK_ENABLED": "true", + "A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS": "0.75", + "A2A_CLIENT_POLLING_FALLBACK_MAX_INTERVAL_SECONDS": "3", + "A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER": "1.5", + "A2A_CLIENT_POLLING_FALLBACK_TIMEOUT_SECONDS": "12", } settings = load_settings(raw) @@ -31,6 +36,11 @@ def test_load_settings_from_mapping() -> None: assert settings.bearer_token == "peer-token" assert settings.basic_auth == "user:pass" assert settings.supported_transports == ("JSONRPC", "HTTP+JSON") + assert settings.polling_fallback_enabled is True + assert settings.polling_fallback_initial_interval_seconds == 0.75 + assert settings.polling_fallback_max_interval_seconds == 3.0 + assert settings.polling_fallback_backoff_multiplier == 1.5 + assert settings.polling_fallback_timeout_seconds == 12.0 def test_load_settings_invalid_transport_raises() -> None: @@ -59,3 +69,13 @@ def test_load_settings_accepts_base64_basic_auth() -> None: def test_load_settings_invalid_basic_auth_raises() -> None: with pytest.raises(ValueError, match="username:password"): load_settings({"A2A_CLIENT_BASIC_AUTH": "not-basic-auth"}) + + +def test_load_settings_invalid_polling_fallback_interval_raises() -> None: + with pytest.raises(ValueError, match="INITIAL_INTERVAL_SECONDS must be positive"): + load_settings({"A2A_CLIENT_POLLING_FALLBACK_INITIAL_INTERVAL_SECONDS": "0"}) + + +def test_load_settings_invalid_polling_fallback_backoff_raises() -> None: + with pytest.raises(ValueError, match="BACKOFF_MULTIPLIER must be greater than or equal to 1"): + load_settings({"A2A_CLIENT_POLLING_FALLBACK_BACKOFF_MULTIPLIER": "0.5"}) diff --git a/tests/client/test_client_facade.py b/tests/client/test_client_facade.py index 8f5e347..67733a6 100644 --- a/tests/client/test_client_facade.py +++ b/tests/client/test_client_facade.py @@ -8,13 +8,14 @@ import pytest from a2a.client import ClientConfig from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError, A2AClientJSONRPCError -from a2a.types import JSONRPCError, JSONRPCErrorResponse +from a2a.types import JSONRPCError, JSONRPCErrorResponse, Task, TaskState, TaskStatus from opencode_a2a.client import A2AClient from opencode_a2a.client import client as client_module from opencode_a2a.client.config import A2AClientSettings from opencode_a2a.client.errors import ( A2APeerProtocolError, + A2ATimeoutError, A2AUnsupportedOperationError, ) @@ -35,9 +36,13 @@ def __init__( events: list[object] | None = None, *, fail: BaseException | None = None, + task_results: list[object] | None = None, + task_fail: BaseException | None = None, ): self._events = list(events or []) self._fail = fail + self._task_results = list(task_results or []) + self._task_fail = task_fail self.send_message_inputs: list[tuple[object, object, object]] = [] self.task_inputs: list[tuple[object, object]] = [] self.cancel_inputs: list[tuple[object, object]] = [] @@ -52,8 +57,12 @@ async def send_message(self, message, *args: object, **kwargs: object) -> AsyncI async def get_task(self, params, *args: object, **kwargs: object) -> object: self.task_inputs.append((params, kwargs)) + if self._task_fail: + raise self._task_fail if self._fail: raise self._fail + if self._task_results: + return self._task_results.pop(0) return {"task_id": params.id} async def cancel_task(self, params, *args: object, **kwargs: object) -> object: @@ -70,6 +79,14 @@ async def resubscribe(self, params, *args: object, **kwargs: object) -> AsyncIte yield event +def _task(task_id: str, state: TaskState) -> Task: + return Task( + id=task_id, + context_id="ctx-1", + status=TaskStatus(state=state), + ) + + @pytest.mark.asyncio async def test_get_agent_card_cached_and_reused(monkeypatch: pytest.MonkeyPatch) -> None: resolver = _FakeCardResolver("agent-card") @@ -147,6 +164,114 @@ async def test_send_returns_last_event(monkeypatch: pytest.MonkeyPatch) -> None: assert response == "last" +@pytest.mark.asyncio +async def test_send_polling_fallback_returns_terminal_task(monkeypatch: pytest.MonkeyPatch) -> None: + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings( + polling_fallback_enabled=True, + polling_fallback_initial_interval_seconds=0.1, + polling_fallback_max_interval_seconds=0.2, + polling_fallback_backoff_multiplier=2.0, + polling_fallback_timeout_seconds=5.0, + ), + ) + fake_client = _FakeClient( + events=[(_task("task-1", TaskState.working), None)], + task_results=[ + _task("task-1", TaskState.working), + _task("task-1", TaskState.completed), + ], + ) + sleep_calls: list[float] = [] + + async def _fake_sleep(delay: float) -> None: + sleep_calls.append(delay) + + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr(client, "_sleep", _fake_sleep) + + response = await client.send("hello") + + assert response == (_task("task-1", TaskState.completed), None) + assert [params.id for params, _kwargs in fake_client.task_inputs] == ["task-1", "task-1"] + assert sleep_calls == [0.1, 0.2] + + +@pytest.mark.asyncio +async def test_send_polling_fallback_skips_input_required(monkeypatch: pytest.MonkeyPatch) -> None: + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings(polling_fallback_enabled=True), + ) + event = (_task("task-1", TaskState.input_required), None) + fake_client = _FakeClient(events=[event]) + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + + response = await client.send("hello") + + assert response == event + assert fake_client.task_inputs == [] + + +@pytest.mark.asyncio +async def test_send_polling_fallback_timeout_raises(monkeypatch: pytest.MonkeyPatch) -> None: + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings( + polling_fallback_enabled=True, + polling_fallback_initial_interval_seconds=0.1, + polling_fallback_max_interval_seconds=0.2, + polling_fallback_backoff_multiplier=2.0, + polling_fallback_timeout_seconds=0.2, + ), + ) + fake_client = _FakeClient( + events=[(_task("task-1", TaskState.working), None)], + task_results=[_task("task-1", TaskState.working)], + ) + now_values = iter([0.0, 0.0, 0.3]) + + async def _fake_sleep(_delay: float) -> None: + return None + + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr(client, "_sleep", _fake_sleep) + monkeypatch.setattr(client, "_current_time", lambda: next(now_values)) + + with pytest.raises(A2ATimeoutError, match="polling fallback timed out"): + await client.send("hello") + + +@pytest.mark.asyncio +async def test_send_polling_fallback_maps_get_task_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + client = A2AClient( + "http://agent.example.com", + settings=A2AClientSettings( + polling_fallback_enabled=True, + polling_fallback_initial_interval_seconds=0.1, + polling_fallback_max_interval_seconds=0.2, + polling_fallback_backoff_multiplier=2.0, + polling_fallback_timeout_seconds=5.0, + ), + ) + fake_client = _FakeClient( + events=[(_task("task-1", TaskState.working), None)], + task_fail=A2AClientHTTPError(404, "gone"), + ) + + async def _fake_sleep(_delay: float) -> None: + return None + + monkeypatch.setattr(A2AClient, "_build_client", AsyncMock(return_value=fake_client)) + monkeypatch.setattr(client, "_sleep", _fake_sleep) + + with pytest.raises(A2AUnsupportedOperationError, match="does not support tasks/get"): + await client.send("hello") + + @pytest.mark.asyncio async def test_send_message_adds_bearer_token_from_settings( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/client/test_polling.py b/tests/client/test_polling.py new file mode 100644 index 0000000..909d845 --- /dev/null +++ b/tests/client/test_polling.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import pytest +from a2a.types import TaskState + +from opencode_a2a.client.polling import ( + PollingFallbackPolicy, + validate_polling_fallback_policy, +) + + +def test_polling_policy_state_rules_and_backoff() -> None: + policy = PollingFallbackPolicy( + enabled=True, + initial_interval_seconds=0.5, + max_interval_seconds=2.0, + backoff_multiplier=2.0, + timeout_seconds=10.0, + ) + + assert policy.should_poll_state(TaskState.working) is True + assert policy.should_poll_state(TaskState.input_required) is False + assert policy.is_terminal_state(TaskState.completed) is True + assert policy.is_terminal_state(TaskState.working) is False + assert policy.next_interval_seconds(0.5) == 1.0 + assert policy.next_interval_seconds(2.0) == 2.0 + + +def test_validate_polling_policy_rejects_invalid_timeout() -> None: + with pytest.raises(ValueError, match="TIMEOUT_SECONDS must be positive"): + validate_polling_fallback_policy( + PollingFallbackPolicy(timeout_seconds=0.0), + )