diff --git a/docs/guide.md b/docs/guide.md index 49678f3..bb1a76a 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -649,10 +649,20 @@ No extra custom REST endpoint is introduced. - session list => `Task` with `status.state=completed` - message history => `Message` - limit pagination defaults to `20`; requests above `100` are rejected + - `opencode.sessions.messages.list` also returns `result.next_cursor` + when older messages are available - `contextId` is an A2A context key derived by the adapter (format: `ctx:opencode-session:`, not raw OpenCode session ID) - OpenCode session identity is exposed explicitly at `metadata.shared.session.id` - session title is available at `metadata.shared.session.title` +- Session list filters: + - optional `directory`, `roots`, `start`, `search`, `limit` + - `directory` is normalized through the same workspace-boundary rules used by + other OpenCode directory overrides before reaching upstream +- Session message history filters: + - optional `limit`, `before` + - `before` is an opaque cursor for loading older messages and is only + supported on `opencode.sessions.messages.list` ### Session List (`opencode.sessions.list`) @@ -664,7 +674,12 @@ curl -sS http://127.0.0.1:8000/ \ "jsonrpc": "2.0", "id": 1, "method": "opencode.sessions.list", - "params": {"limit": 20} + "params": { + "directory": "services/api", + "roots": true, + "search": "planner", + "limit": 20 + } }' ``` @@ -680,11 +695,18 @@ curl -sS http://127.0.0.1:8000/ \ "method": "opencode.sessions.messages.list", "params": { "session_id": "", + "before": "", "limit": 50 } }' ``` +Message history responses include: + +- `result.items`: normalized A2A `Message[]` +- `result.next_cursor`: opaque cursor for the next older page, or `null` when + no older page is available + ### Session Prompt Async (`opencode.sessions.prompt_async`) ```bash diff --git a/src/opencode_a2a/contracts/extensions.py b/src/opencode_a2a/contracts/extensions.py index 3179fdb..e66f28d 100644 --- a/src/opencode_a2a/contracts/extensions.py +++ b/src/opencode_a2a/contracts/extensions.py @@ -102,17 +102,28 @@ class InterruptRecoveryMethodContract: *SHELL_REQUEST_OPTIONAL_FIELDS, ) -SESSION_QUERY_PAGINATION_MODE = "limit" +SESSION_QUERY_PAGINATION_MODE = "limit_and_optional_cursor" SESSION_QUERY_PAGINATION_BEHAVIOR = "passthrough" SESSION_QUERY_DEFAULT_LIMIT = 20 SESSION_QUERY_MAX_LIMIT = 100 -SESSION_QUERY_PAGINATION_PARAMS: tuple[str, ...] = ("limit",) +SESSION_QUERY_PAGINATION_PARAMS: tuple[str, ...] = ("limit", "before") SESSION_QUERY_PAGINATION_UNSUPPORTED: tuple[str, ...] = ("cursor", "page", "size") SESSION_QUERY_METHOD_CONTRACTS: dict[str, SessionQueryMethodContract] = { "list_sessions": SessionQueryMethodContract( method="opencode.sessions.list", - optional_params=("limit", "query.limit"), + optional_params=( + "limit", + "directory", + "roots", + "start", + "search", + "query.limit", + "query.directory", + "query.roots", + "query.start", + "query.search", + ), unsupported_params=SESSION_QUERY_PAGINATION_UNSUPPORTED, result_fields=("items",), items_type="Task[]", @@ -122,9 +133,9 @@ class InterruptRecoveryMethodContract: "get_session_messages": SessionQueryMethodContract( method="opencode.sessions.messages.list", required_params=("session_id",), - optional_params=("limit", "query.limit"), + optional_params=("limit", "before", "query.limit", "query.before"), unsupported_params=SESSION_QUERY_PAGINATION_UNSUPPORTED, - result_fields=("items",), + result_fields=("items", "next_cursor"), items_type="Message[]", notification_response_status=204, pagination_mode=SESSION_QUERY_PAGINATION_MODE, @@ -660,7 +671,10 @@ def build_session_query_extension_params( "max_limit": SESSION_QUERY_MAX_LIMIT, "behavior": SESSION_QUERY_PAGINATION_BEHAVIOR, "params": list(SESSION_QUERY_PAGINATION_PARAMS), + "cursor_param": "before", + "result_cursor_field": "next_cursor", "applies_to": pagination_applies_to, + "cursor_applies_to": [SESSION_QUERY_METHODS["get_session_messages"]], }, "method_contracts": method_contracts, "errors": { diff --git a/src/opencode_a2a/jsonrpc/handlers/session_queries.py b/src/opencode_a2a/jsonrpc/handlers/session_queries.py index e6c8a20..dad8df8 100644 --- a/src/opencode_a2a/jsonrpc/handlers/session_queries.py +++ b/src/opencode_a2a/jsonrpc/handlers/session_queries.py @@ -46,7 +46,6 @@ async def handle_session_query_request( params: dict[str, Any], request: Request, ) -> Response: - del request try: if base_request.method == context.method_list_sessions: query = parse_list_sessions_params(params) @@ -60,9 +59,33 @@ async def handle_session_query_request( ) limit = int(query["limit"]) + directory = None + if base_request.method == context.method_list_sessions: + requested_directory = query.pop("directory", None) + if requested_directory is not None and not isinstance(requested_directory, str): + return context.error_response( + base_request.id, + invalid_params_error( + "directory must be a string", + data={"type": "INVALID_FIELD", "field": "directory"}, + ), + ) + try: + directory = context.directory_resolver(requested_directory) + except ValueError as exc: + return context.error_response( + base_request.id, + invalid_params_error( + str(exc), + data={"type": "INVALID_FIELD", "field": "directory"}, + ), + ) try: if base_request.method == context.method_list_sessions: - raw_result = await context.upstream_client.list_sessions(params=query) + raw_result = await context.upstream_client.list_sessions( + params=query, + directory=directory, + ) else: assert session_id is not None raw_result = await context.upstream_client.list_messages(session_id, params=query) @@ -105,7 +128,7 @@ async def handle_session_query_request( if base_request.method == context.method_list_sessions: raw_items = _extract_raw_items(raw_result, kind="sessions") else: - raw_items = _extract_raw_items(raw_result, kind="messages") + raw_items = _extract_raw_items(raw_result.payload, kind="messages") except ValueError as exc: logger.warning("Upstream OpenCode payload mismatch: %s", exc) return build_upstream_payload_error_response( @@ -131,4 +154,7 @@ async def handle_session_query_request( mapped.append(message) items = mapped - return build_success_response(context, base_request.id, {"items": items}) + result: dict[str, Any] = {"items": items} + if base_request.method == context.method_get_session_messages: + result["next_cursor"] = raw_result.next_cursor + return build_success_response(context, base_request.id, result) diff --git a/src/opencode_a2a/jsonrpc/params.py b/src/opencode_a2a/jsonrpc/params.py index 6dae11f..ed08cfd 100644 --- a/src/opencode_a2a/jsonrpc/params.py +++ b/src/opencode_a2a/jsonrpc/params.py @@ -46,6 +46,66 @@ def _parse_positive_int(value: Any, *, field: str) -> int | None: return parsed +def _parse_non_negative_int(value: Any, *, field: str) -> int | None: + if value is None: + return None + if isinstance(value, bool): + raise JsonRpcParamsValidationError( + message=f"{field} must be an integer", + data={"type": "INVALID_FIELD", "field": field}, + ) + if isinstance(value, int): + parsed = value + elif isinstance(value, str): + try: + parsed = int(value) + except ValueError as exc: + raise JsonRpcParamsValidationError( + message=f"{field} must be an integer", + data={"type": "INVALID_FIELD", "field": field}, + ) from exc + else: + raise JsonRpcParamsValidationError( + message=f"{field} must be an integer", + data={"type": "INVALID_FIELD", "field": field}, + ) + if parsed < 0: + raise JsonRpcParamsValidationError( + message=f"{field} must be >= 0", + data={"type": "INVALID_FIELD", "field": field}, + ) + return parsed + + +def _parse_string_field(value: Any, *, field: str) -> str | None: + if value is None: + return None + if not isinstance(value, str): + raise JsonRpcParamsValidationError( + message=f"{field} must be a string", + data={"type": "INVALID_FIELD", "field": field}, + ) + normalized = value.strip() + return normalized or None + + +def _parse_bool_field(value: Any, *, field: str) -> bool | None: + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "1", "yes", "on"}: + return True + if normalized in {"false", "0", "no", "off"}: + return False + raise JsonRpcParamsValidationError( + message=f"{field} must be a boolean", + data={"type": "INVALID_FIELD", "field": field}, + ) + + def _parse_query_object(params: dict[str, Any]) -> dict[str, Any]: raw_query = params.get("query") if raw_query is None: @@ -104,10 +164,69 @@ def _normalize_session_query_limit( return normalized_query +def _normalize_alias_field( + *, + params: dict[str, Any], + query: dict[str, Any], + field: str, + parser, +) -> Any: + top_level_value = parser(params.get(field), field=field) + query_value = parser(query.get(field), field=field) + if top_level_value is not None and query_value is not None and top_level_value != query_value: + raise JsonRpcParamsValidationError( + message=f"{field} is ambiguous between params.{field} and params.query.{field}", + data={"type": "INVALID_FIELD", "field": field}, + ) + return top_level_value if top_level_value is not None else query_value + + def parse_list_sessions_params(params: dict[str, Any]) -> dict[str, Any]: query = _parse_query_object(params) _validate_pagination_fields(params, query) - return _normalize_session_query_limit(params=params, query=query) + normalized_query = _normalize_session_query_limit(params=params, query=query) + directory = _normalize_alias_field( + params=params, + query=query, + field="directory", + parser=_parse_string_field, + ) + roots = _normalize_alias_field( + params=params, + query=query, + field="roots", + parser=_parse_bool_field, + ) + start = _normalize_alias_field( + params=params, + query=query, + field="start", + parser=_parse_non_negative_int, + ) + search = _normalize_alias_field( + params=params, + query=query, + field="search", + parser=_parse_string_field, + ) + + if directory is not None: + normalized_query["directory"] = directory + else: + normalized_query.pop("directory", None) + if roots is not None: + normalized_query["roots"] = roots + else: + normalized_query.pop("roots", None) + if start is not None: + normalized_query["start"] = start + else: + normalized_query.pop("start", None) + if search is not None: + normalized_query["search"] = search + else: + normalized_query.pop("search", None) + return normalized_query def parse_get_session_messages_params(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: @@ -120,4 +239,15 @@ def parse_get_session_messages_params(params: dict[str, Any]) -> tuple[str, dict query = _parse_query_object(params) _validate_pagination_fields(params, query) - return raw_session_id.strip(), _normalize_session_query_limit(params=params, query=query) + normalized_query = _normalize_session_query_limit(params=params, query=query) + before = _normalize_alias_field( + params=params, + query=query, + field="before", + parser=_parse_string_field, + ) + if before is not None: + normalized_query["before"] = before + else: + normalized_query.pop("before", None) + return raw_session_id.strip(), normalized_query diff --git a/src/opencode_a2a/opencode_upstream_client.py b/src/opencode_a2a/opencode_upstream_client.py index 2d0783c..4983de9 100644 --- a/src/opencode_a2a/opencode_upstream_client.py +++ b/src/opencode_a2a/opencode_upstream_client.py @@ -45,6 +45,12 @@ class OpencodeMessage: raw: dict[str, Any] +@dataclass(frozen=True) +class OpencodeMessagePage: + payload: Any + next_cursor: str | None + + class _FastFailConcurrencyBudget: def __init__(self, *, category: str, limit: int) -> None: self._category = category @@ -410,23 +416,41 @@ async def abort_session(self, session_id: str, *, directory: str | None = None) params=self._query_params(directory=directory), ) - async def list_sessions(self, *, params: dict[str, Any] | None = None) -> Any: + async def list_sessions( + self, + *, + params: dict[str, Any] | None = None, + directory: str | None = None, + ) -> Any: """List sessions from OpenCode.""" - # Note: directory override is not explicitly supported by list_sessions params yet. - # If needed, we can add it later. For now we use the default. return await self._get_json( "/session", endpoint="/session", - params=self._merge_params(params), + params=self._merge_params(params, directory=directory), ) - async def list_messages(self, session_id: str, *, params: dict[str, Any] | None = None) -> Any: + async def list_messages( + self, + session_id: str, + *, + params: dict[str, Any] | None = None, + ) -> OpencodeMessagePage: """List messages for a session from OpenCode.""" - return await self._get_json( - f"/session/{session_id}/message", - endpoint="/session/{sessionID}/message", - params=self._merge_params(params), - ) + endpoint = "/session/{sessionID}/message" + async with self._request_budget.reserve(operation=endpoint): + response = await self._client.get( + f"/session/{session_id}/message", + params=self._merge_params(params), + ) + response.raise_for_status() + payload = self._decode_json_response(response, endpoint=endpoint) + raw_next_cursor = response.headers.get("X-Next-Cursor") + next_cursor = None + if isinstance(raw_next_cursor, str): + normalized_cursor = raw_next_cursor.strip() + if normalized_cursor: + next_cursor = normalized_cursor + return OpencodeMessagePage(payload=payload, next_cursor=next_cursor) async def session_prompt_async( self, diff --git a/src/opencode_a2a/server/agent_card.py b/src/opencode_a2a/server/agent_card.py index cc8ce9e..1ad4384 100644 --- a/src/opencode_a2a/server/agent_card.py +++ b/src/opencode_a2a/server/agent_card.py @@ -84,8 +84,8 @@ def _build_session_query_skill_examples( capability_snapshot: JsonRpcCapabilitySnapshot, ) -> list[str]: examples = [ - "List OpenCode sessions (method opencode.sessions.list).", - "List messages for a session (method opencode.sessions.messages.list).", + "List OpenCode sessions with filters (method opencode.sessions.list).", + ("List messages with cursor pagination (method opencode.sessions.messages.list)."), "Send async prompt to a session (method opencode.sessions.prompt_async).", "Send command to a session (method opencode.sessions.command).", ] diff --git a/src/opencode_a2a/server/openapi.py b/src/opencode_a2a/server/openapi.py index f6209f7..c8abe49 100644 --- a/src/opencode_a2a/server/openapi.py +++ b/src/opencode_a2a/server/openapi.py @@ -145,7 +145,12 @@ def _build_jsonrpc_extension_openapi_examples( "jsonrpc": "2.0", "id": 1, "method": SESSION_QUERY_METHODS["list_sessions"], - "params": {"limit": SESSION_QUERY_DEFAULT_LIMIT}, + "params": { + "directory": "services/api", + "roots": True, + "search": "planner", + "limit": SESSION_QUERY_DEFAULT_LIMIT, + }, }, }, "session_messages": { @@ -154,7 +159,11 @@ def _build_jsonrpc_extension_openapi_examples( "jsonrpc": "2.0", "id": 2, "method": SESSION_QUERY_METHODS["get_session_messages"], - "params": {"session_id": "s-1", "limit": SESSION_QUERY_DEFAULT_LIMIT}, + "params": { + "session_id": "s-1", + "before": "cursor-1", + "limit": SESSION_QUERY_DEFAULT_LIMIT, + }, }, }, "session_prompt_async": { diff --git a/tests/jsonrpc/test_jsonrpc_params.py b/tests/jsonrpc/test_jsonrpc_params.py index b86e166..26e1ccd 100644 --- a/tests/jsonrpc/test_jsonrpc_params.py +++ b/tests/jsonrpc/test_jsonrpc_params.py @@ -23,6 +23,24 @@ def test_parse_list_sessions_params_accepts_equivalent_query_and_top_level_limit } +def test_parse_list_sessions_params_accepts_filters() -> None: + assert parse_list_sessions_params( + { + "directory": "services/api", + "roots": "true", + "start": "123456789", + "search": "planner", + "limit": "10", + } + ) == { + "directory": "services/api", + "roots": True, + "start": 123456789, + "search": "planner", + "limit": 10, + } + + def test_parse_list_sessions_params_rejects_limit_above_max() -> None: with pytest.raises(JsonRpcParamsValidationError) as exc_info: parse_list_sessions_params({"limit": SESSION_QUERY_MAX_LIMIT + 1}) @@ -50,6 +68,19 @@ def test_parse_get_session_messages_params_applies_default_limit() -> None: assert query == {"limit": SESSION_QUERY_DEFAULT_LIMIT} +def test_parse_get_session_messages_params_accepts_before_cursor() -> None: + session_id, query = parse_get_session_messages_params( + { + "session_id": "s-1", + "before": "cursor-1", + "limit": "5", + } + ) + + assert session_id == "s-1" + assert query == {"limit": 5, "before": "cursor-1"} + + def test_parse_get_session_messages_params_rejects_ambiguous_limit() -> None: with pytest.raises(JsonRpcParamsValidationError) as exc_info: parse_get_session_messages_params({"session_id": "s-1", "limit": 5, "query": {"limit": 6}}) @@ -86,6 +117,30 @@ def test_parse_list_sessions_params_rejects_boolean_limit() -> None: assert exc_info.value.data == {"type": "INVALID_FIELD", "field": "limit"} +def test_parse_list_sessions_params_rejects_ambiguous_directory() -> None: + with pytest.raises(JsonRpcParamsValidationError) as exc_info: + parse_list_sessions_params( + { + "directory": "services/api", + "query": {"directory": "services/web"}, + } + ) + + assert ( + str(exc_info.value) + == "directory is ambiguous between params.directory and params.query.directory" + ) + assert exc_info.value.data == {"type": "INVALID_FIELD", "field": "directory"} + + +def test_parse_get_session_messages_params_rejects_invalid_before_type() -> None: + with pytest.raises(JsonRpcParamsValidationError) as exc_info: + parse_get_session_messages_params({"session_id": "s-1", "before": 123}) + + assert str(exc_info.value) == "before must be a string" + assert exc_info.value.data == {"type": "INVALID_FIELD", "field": "before"} + + def test_parse_get_session_messages_params_trims_session_id() -> None: session_id, query = parse_get_session_messages_params({"session_id": " s-1 "}) diff --git a/tests/jsonrpc/test_opencode_session_extension_queries.py b/tests/jsonrpc/test_opencode_session_extension_queries.py index c636639..d433fab 100644 --- a/tests/jsonrpc/test_opencode_session_extension_queries.py +++ b/tests/jsonrpc/test_opencode_session_extension_queries.py @@ -118,11 +118,123 @@ async def test_session_query_extension_returns_jsonrpc_result(monkeypatch): message = payload["result"]["items"][0] assert message["contextId"] == "ctx:opencode-session:s-1" assert message["parts"][0]["text"] == "SECRET_HISTORY" + assert payload["result"]["next_cursor"] is None assert _session_meta(message)["id"] == "s-1" assert dummy.last_messages_params is not None assert dummy.last_messages_params.get("limit") == 5 +@pytest.mark.asyncio +async def test_session_query_extension_supports_session_filters_and_message_cursor(monkeypatch): + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings( + a2a_bearer_token="t-1", + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + dummy._messages_next_cursor = "cursor-2" + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings( + a2a_bearer_token="t-1", + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + sessions_resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 10, + "method": "opencode.sessions.list", + "params": { + "directory": "services/api", + "roots": "true", + "start": "12345", + "search": "planner", + "limit": 3, + }, + }, + ) + assert sessions_resp.status_code == 200 + assert dummy.last_sessions_directory == "/workspace/services/api" + assert dummy.last_sessions_params == { + "roots": True, + "start": 12345, + "search": "planner", + "limit": 3, + } + + messages_resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 11, + "method": "opencode.sessions.messages.list", + "params": { + "session_id": "s-1", + "before": "cursor-1", + "limit": 5, + }, + }, + ) + assert messages_resp.status_code == 200 + payload = messages_resp.json() + assert payload["result"]["next_cursor"] == "cursor-2" + assert dummy.last_messages_params == {"before": "cursor-1", "limit": 5} + + +@pytest.mark.asyncio +async def test_session_query_extension_rejects_directory_outside_workspace(monkeypatch): + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings( + a2a_bearer_token="t-1", + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings( + a2a_bearer_token="t-1", + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 12, + "method": "opencode.sessions.list", + "params": {"directory": "../outside", "limit": 5}, + }, + ) + payload = resp.json() + assert payload["error"]["code"] == -32602 + assert payload["error"]["data"]["field"] == "directory" + + @pytest.mark.asyncio async def test_session_query_extension_applies_default_limit(monkeypatch): import opencode_a2a.server.application as app_module @@ -420,8 +532,8 @@ async def test_session_query_extension_maps_concurrency_limit_to_unreachable(mon import opencode_a2a.server.application as app_module class BusySessionQueryClient(DummyOpencodeUpstreamClient): - async def list_sessions(self, *, params=None): - del params + async def list_sessions(self, *, params=None, directory: str | None = None): + del params, directory raise UpstreamConcurrencyLimitError( category="request", operation="/session", diff --git a/tests/server/test_agent_card.py b/tests/server/test_agent_card.py index 9a57a07..32d5a46 100644 --- a/tests/server/test_agent_card.py +++ b/tests/server/test_agent_card.py @@ -181,10 +181,15 @@ def test_agent_card_injects_profile_into_extensions() -> None: } assert session_query.params["pagination"]["default_limit"] == SESSION_QUERY_DEFAULT_LIMIT assert session_query.params["pagination"]["max_limit"] == SESSION_QUERY_MAX_LIMIT + assert session_query.params["pagination"]["cursor_param"] == "before" + assert session_query.params["pagination"]["result_cursor_field"] == "next_cursor" assert session_query.params["pagination"]["applies_to"] == [ "opencode.sessions.list", "opencode.sessions.messages.list", ] + assert session_query.params["pagination"]["cursor_applies_to"] == [ + "opencode.sessions.messages.list" + ] prompt_contract = session_query.params["method_contracts"]["opencode.sessions.prompt_async"] command_contract = session_query.params["method_contracts"]["opencode.sessions.command"] list_contract = session_query.params["method_contracts"]["opencode.sessions.list"] @@ -197,6 +202,25 @@ def test_agent_card_injects_profile_into_extensions() -> None: "request.arguments", ] assert command_contract["result"]["fields"] == ["item"] + assert list_contract["params"]["optional"] == [ + "limit", + "directory", + "roots", + "start", + "search", + "query.limit", + "query.directory", + "query.roots", + "query.start", + "query.search", + ] + assert messages_contract["params"]["optional"] == [ + "limit", + "before", + "query.limit", + "query.before", + ] + assert messages_contract["result"]["fields"] == ["items", "next_cursor"] assert list_contract["notification_response_status"] == 204 assert messages_contract["notification_response_status"] == 204 assert prompt_contract["notification_response_status"] == 204 diff --git a/tests/server/test_app_behaviors.py b/tests/server/test_app_behaviors.py index 99b46be..06c9385 100644 --- a/tests/server/test_app_behaviors.py +++ b/tests/server/test_app_behaviors.py @@ -476,9 +476,9 @@ def _track_background_task(self, task): # noqa: ANN001 async def test_on_message_send_stream_emits_stable_failure_events_for_task_store_error() -> None: class _Aggregator: async def consume_and_emit(self, _consumer): - del _consumer + if _consumer is None: # pragma: no cover + yield None raise TaskStoreOperationError("save", "task-1") - yield # pragma: no cover class _Handler(OpencodeRequestHandler): def __init__(self) -> None: diff --git a/tests/support/helpers.py b/tests/support/helpers.py index 3639721..3be2e19 100644 --- a/tests/support/helpers.py +++ b/tests/support/helpers.py @@ -10,7 +10,7 @@ from a2a.types import Message, MessageSendParams, Part, Role, TextPart from opencode_a2a.config import Settings -from opencode_a2a.opencode_upstream_client import OpencodeMessage +from opencode_a2a.opencode_upstream_client import OpencodeMessage, OpencodeMessagePage def make_settings(**overrides: Any) -> Settings: @@ -212,7 +212,9 @@ def __init__(self, _settings: Settings) -> None: "parts": [{"type": "text", "text": "SECRET_HISTORY"}], } ] + self._messages_next_cursor: str | None = None self.last_sessions_params = None + self.last_sessions_directory: str | None = None self.last_messages_params = None self.prompt_async_calls: list[dict[str, Any]] = [] self.command_calls: list[dict[str, Any]] = [] @@ -266,14 +268,18 @@ def __init__(self, _settings: Settings) -> None: async def close(self) -> None: return None - async def list_sessions(self, *, params=None): + async def list_sessions(self, *, params=None, directory: str | None = None): + self.last_sessions_directory = directory self.last_sessions_params = params return self._sessions_payload async def list_messages(self, session_id: str, *, params=None): assert session_id self.last_messages_params = params - return self._messages_payload + return OpencodeMessagePage( + payload=self._messages_payload, + next_cursor=self._messages_next_cursor, + ) async def session_prompt_async( self, diff --git a/tests/upstream/test_opencode_upstream_client_params.py b/tests/upstream/test_opencode_upstream_client_params.py index 940bd3c..c860691 100644 --- a/tests/upstream/test_opencode_upstream_client_params.py +++ b/tests/upstream/test_opencode_upstream_client_params.py @@ -6,6 +6,7 @@ from opencode_a2a.opencode_upstream_client import ( _UNSET, + OpencodeMessagePage, OpencodeUpstreamClient, UpstreamConcurrencyLimitError, UpstreamContractError, @@ -86,16 +87,51 @@ async def fake_get(path: str, *, params=None, **_kwargs): monkeypatch.setattr(client._client, "get", fake_get) - await client.list_sessions(params={"directory": "/evil", "limit": 1, "roots": True}) + await client.list_sessions( + params={"directory": "/evil", "limit": 1, "roots": True}, + directory="/safe/services/api", + ) assert seen["path"] == "/session" - assert seen["params"]["directory"] == "/safe" + assert seen["params"]["directory"] == "/safe/services/api" assert seen["params"]["limit"] == "1" assert seen["params"]["roots"] == "True" - await client.list_messages("sess-1", params={"directory": "/evil", "limit": 10}) + page = await client.list_messages("sess-1", params={"directory": "/evil", "limit": 10}) assert seen["path"] == "/session/sess-1/message" assert seen["params"]["directory"] == "/safe" assert seen["params"]["limit"] == "10" + assert isinstance(page, OpencodeMessagePage) + assert page.next_cursor is None + + await client.close() + + +@pytest.mark.asyncio +async def test_list_messages_reads_next_cursor_from_headers(monkeypatch): + client = OpencodeUpstreamClient( + make_settings( + a2a_bearer_token="t-1", + opencode_workspace_root="/safe", + opencode_timeout=1.0, + a2a_log_level="DEBUG", + a2a_log_payloads=False, + ) + ) + + async def fake_get(path: str, *, params=None, **_kwargs): + del path, params + return _DummyResponse( + payload=[{"info": {"id": "m-1", "role": "assistant"}, "parts": []}], + headers={"X-Next-Cursor": "cursor-2"}, + ) + + monkeypatch.setattr(client._client, "get", fake_get) + + page = await client.list_messages("sess-1", params={"limit": 5, "before": "cursor-1"}) + + assert isinstance(page, OpencodeMessagePage) + assert page.next_cursor == "cursor-2" + assert page.payload == [{"info": {"id": "m-1", "role": "assistant"}, "parts": []}] await client.close()