diff --git a/docs/guide.md b/docs/guide.md index 099be4a..49678f3 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -850,6 +850,33 @@ Response: - success => `{"items": [...], "default_by_provider": {...}, "connected": [...]}` (JSON-RPC result) +## Interrupt Recovery (Provider-Private Extension) + +The runtime also exposes provider-private recovery queries for pending +interactive interrupts: + +- `opencode.permissions.list` +- `opencode.questions.list` + +These methods return recovery views over the local interrupt binding registry. +They do not replace the shared `a2a.interrupt.*` callback methods. + +Response shape: + +- success => `{"items": [{"request_id", "session_id", "interrupt_type", "task_id", "context_id", "details", "expires_at"}]}` (JSON-RPC result) + +Notes: + +- Recovery results are scoped to the current authenticated caller identity when + the runtime can resolve one. +- The runtime stores normalized interrupt `details` alongside request bindings, + so recovery results match the shape emitted in + `metadata.shared.interrupt.details`. +- The first implementation stage reads from the local interrupt registry rather + than proxying upstream global `/permission` or `/question` pending lists. +- Use recovery queries to rediscover pending requests after reconnecting; use + `a2a.interrupt.*` methods to resolve them. + ## Shared Interrupt Callback (A2A Extension) When stream metadata reports an interrupt request at `metadata.shared.interrupt`, @@ -871,7 +898,8 @@ clients can reply through JSON-RPC extension methods: Notes: - `request_id` must be a live interrupt request observed from stream metadata - (`metadata.shared.interrupt.request_id`). + (`metadata.shared.interrupt.request_id`) or rediscovered through + `opencode.permissions.list` / `opencode.questions.list`. - The server keeps an interrupt binding registry; callbacks with unknown or expired `request_id` are rejected. - The cache retention windows are controlled by diff --git a/src/opencode_a2a/contracts/extensions.py b/src/opencode_a2a/contracts/extensions.py index e098bf0..3179fdb 100644 --- a/src/opencode_a2a/contracts/extensions.py +++ b/src/opencode_a2a/contracts/extensions.py @@ -20,6 +20,7 @@ SESSION_QUERY_EXTENSION_URI = "urn:opencode-a2a:session-query/v1" PROVIDER_DISCOVERY_EXTENSION_URI = "urn:opencode-a2a:provider-discovery/v1" INTERRUPT_CALLBACK_EXTENSION_URI = "urn:a2a:interactive-interrupt/v1" +INTERRUPT_RECOVERY_EXTENSION_URI = "urn:opencode-a2a:interrupt-recovery/v1" COMPATIBILITY_PROFILE_EXTENSION_URI = "urn:a2a:compatibility-profile/v1" WIRE_CONTRACT_EXTENSION_URI = "urn:a2a:wire-contract/v1" SERVICE_BEHAVIOR_CLASSIFICATION = "service-level-semantic-enhancement" @@ -57,6 +58,16 @@ class ProviderDiscoveryMethodContract: notification_response_status: int | None = None +@dataclass(frozen=True) +class InterruptRecoveryMethodContract: + method: str + required_params: tuple[str, ...] = () + optional_params: tuple[str, ...] = () + result_fields: tuple[str, ...] = () + items_type: str | None = None + notification_response_status: int | None = None + + PROMPT_ASYNC_REQUEST_REQUIRED_FIELDS: tuple[str, ...] = ("parts",) PROMPT_ASYNC_REQUEST_OPTIONAL_FIELDS: tuple[str, ...] = ( "messageID", @@ -254,6 +265,25 @@ class ProviderDiscoveryMethodContract: key: contract.method for key, contract in PROVIDER_DISCOVERY_METHOD_CONTRACTS.items() } +INTERRUPT_RECOVERY_METHOD_CONTRACTS: dict[str, InterruptRecoveryMethodContract] = { + "list_permissions": InterruptRecoveryMethodContract( + method="opencode.permissions.list", + result_fields=("items",), + items_type="InterruptRequest[]", + notification_response_status=204, + ), + "list_questions": InterruptRecoveryMethodContract( + method="opencode.questions.list", + result_fields=("items",), + items_type="InterruptRequest[]", + notification_response_status=204, + ), +} + +INTERRUPT_RECOVERY_METHODS: dict[str, str] = { + key: contract.method for key, contract in INTERRUPT_RECOVERY_METHOD_CONTRACTS.items() +} + INTERRUPT_SUCCESS_RESULT_FIELDS: tuple[str, ...] = ("ok", "request_id") INTERRUPT_ERROR_BUSINESS_CODES: dict[str, int] = { "INTERRUPT_REQUEST_NOT_FOUND": -32004, @@ -299,6 +329,11 @@ class ProviderDiscoveryMethodContract: "field", "fields", ) +INTERRUPT_RECOVERY_INVALID_PARAMS_DATA_FIELDS: tuple[str, ...] = ( + "type", + "field", + "fields", +) @dataclass(frozen=True) @@ -358,6 +393,9 @@ def session_control_methods(self) -> dict[str, str]: def provider_discovery_methods(self) -> dict[str, str]: return dict(PROVIDER_DISCOVERY_METHODS) + def interrupt_recovery_methods(self) -> dict[str, str]: + return dict(INTERRUPT_RECOVERY_METHODS) + def interrupt_callback_methods(self) -> dict[str, str]: return dict(INTERRUPT_CALLBACK_METHODS) @@ -369,6 +407,7 @@ def supported_jsonrpc_methods(self) -> list[str]: SESSION_CONTROL_METHODS["prompt_async"], SESSION_CONTROL_METHODS["command"], *PROVIDER_DISCOVERY_METHODS.values(), + *INTERRUPT_RECOVERY_METHODS.values(), *INTERRUPT_CALLBACK_METHODS.values(), ] if self.is_method_enabled(SESSION_CONTROL_METHODS["shell"]): @@ -382,6 +421,7 @@ def extension_jsonrpc_methods(self) -> list[str]: SESSION_CONTROL_METHODS["prompt_async"], SESSION_CONTROL_METHODS["command"], *PROVIDER_DISCOVERY_METHODS.values(), + *INTERRUPT_RECOVERY_METHODS.values(), *INTERRUPT_CALLBACK_METHODS.values(), ] if self.is_method_enabled(SESSION_CONTROL_METHODS["shell"]): @@ -684,6 +724,66 @@ def build_interrupt_callback_extension_params( } +def build_interrupt_recovery_extension_params( + *, + runtime_profile: RuntimeProfile, +) -> dict[str, Any]: + method_contracts: dict[str, Any] = {} + + for method_contract in INTERRUPT_RECOVERY_METHOD_CONTRACTS.values(): + params_contract = _build_method_contract_params( + required=method_contract.required_params, + optional=method_contract.optional_params, + unsupported=(), + ) + result_contract: dict[str, Any] = {"fields": list(method_contract.result_fields)} + if method_contract.items_type: + result_contract["items_type"] = method_contract.items_type + contract_doc: dict[str, Any] = { + "params": params_contract, + "result": result_contract, + } + if method_contract.notification_response_status is not None: + contract_doc["notification_response_status"] = ( + method_contract.notification_response_status + ) + method_contracts[method_contract.method] = contract_doc + + return { + "methods": dict(INTERRUPT_RECOVERY_METHODS), + "method_contracts": method_contracts, + "supported_metadata": [], + "provider_private_metadata": [], + "item_fields": { + "request_id": "items[].request_id", + "session_id": "items[].session_id", + "interrupt_type": "items[].interrupt_type", + "task_id": "items[].task_id", + "context_id": "items[].context_id", + "details": "items[].details", + "expires_at": "items[].expires_at", + }, + "errors": { + "invalid_params_data_fields": list(INTERRUPT_RECOVERY_INVALID_PARAMS_DATA_FIELDS), + }, + "profile": runtime_profile.summary_dict(), + "notes": [ + ( + "Interrupt recovery methods read from the local interrupt binding registry " + "instead of directly proxying upstream global pending lists." + ), + ( + "Results are scoped to the current authenticated caller identity when the " + "runtime can resolve one." + ), + ( + "Use a2a.interrupt.* methods to resolve requests; opencode.permissions.list " + "and opencode.questions.list are recovery surfaces only." + ), + ], + } + + def build_provider_discovery_extension_params( *, runtime_profile: RuntimeProfile, @@ -800,6 +900,17 @@ def build_compatibility_profile_params( for method in PROVIDER_DISCOVERY_METHODS.values() } ) + method_retention.update( + { + method: { + "surface": "extension", + "availability": "always", + "retention": "stable", + "extension_uri": INTERRUPT_RECOVERY_EXTENSION_URI, + } + for method in INTERRUPT_RECOVERY_METHODS.values() + } + ) method_retention.update( { method: { @@ -843,6 +954,11 @@ def build_compatibility_profile_params( "availability": "always", "retention": "stable", }, + INTERRUPT_RECOVERY_EXTENSION_URI: { + "surface": "jsonrpc-extension", + "availability": "always", + "retention": "stable", + }, INTERRUPT_CALLBACK_EXTENSION_URI: { "surface": "jsonrpc-extension", "availability": "always", @@ -862,9 +978,9 @@ def build_compatibility_profile_params( "surface for the main chat path; provider defaults still belong to OpenCode." ), ( - "Treat opencode.sessions.*, opencode.providers.*, and opencode.models.* as " - "provider-private operational surfaces rather than portable A2A baseline " - "capabilities." + "Treat opencode.sessions.*, opencode.providers.*, opencode.models.*, " + "opencode.permissions.list, and opencode.questions.list as provider-private " + "operational surfaces rather than portable A2A baseline capabilities." ), ( "Treat a2a.interrupt.* methods as declared shared extensions and opencode.* " @@ -911,6 +1027,7 @@ def build_wire_contract_params( STREAMING_EXTENSION_URI, SESSION_QUERY_EXTENSION_URI, PROVIDER_DISCOVERY_EXTENSION_URI, + INTERRUPT_RECOVERY_EXTENSION_URI, INTERRUPT_CALLBACK_EXTENSION_URI, ], }, diff --git a/src/opencode_a2a/execution/stream_runtime.py b/src/opencode_a2a/execution/stream_runtime.py index 59ad183..9725de0 100644 --- a/src/opencode_a2a/execution/stream_runtime.py +++ b/src/opencode_a2a/execution/stream_runtime.py @@ -436,6 +436,7 @@ def _tool_chunks( identity=identity, task_id=task_id, context_id=context_id, + details=asked["details"], ) await _emit_interrupt_status( state=TaskState.input_required, diff --git a/src/opencode_a2a/jsonrpc/application.py b/src/opencode_a2a/jsonrpc/application.py index 413cbb1..1c1cc5c 100644 --- a/src/opencode_a2a/jsonrpc/application.py +++ b/src/opencode_a2a/jsonrpc/application.py @@ -84,6 +84,8 @@ def __init__( self._method_shell = methods.get("shell") self._method_list_providers = methods["list_providers"] self._method_list_models = methods["list_models"] + self._method_list_permissions = methods["list_permissions"] + self._method_list_questions = methods["list_questions"] self._method_reply_permission = methods["reply_permission"] self._method_reply_question = methods["reply_question"] self._method_reject_question = methods["reject_question"] @@ -116,6 +118,8 @@ def __init__( method_shell=self._method_shell, method_list_providers=self._method_list_providers, method_list_models=self._method_list_models, + method_list_permissions=self._method_list_permissions, + method_list_questions=self._method_list_questions, method_reply_permission=self._method_reply_permission, method_reply_question=self._method_reply_question, method_reject_question=self._method_reject_question, diff --git a/src/opencode_a2a/jsonrpc/dispatch.py b/src/opencode_a2a/jsonrpc/dispatch.py index f726b36..33c5591 100644 --- a/src/opencode_a2a/jsonrpc/dispatch.py +++ b/src/opencode_a2a/jsonrpc/dispatch.py @@ -37,6 +37,8 @@ class ExtensionHandlerContext: method_shell: str | None method_list_providers: str method_list_models: str + method_list_permissions: str + method_list_questions: str method_reply_permission: str method_reply_question: str method_reject_question: str @@ -89,6 +91,7 @@ def build_extension_method_registry( context: ExtensionHandlerContext, ) -> ExtensionMethodRegistry: from .handlers.interrupt_callbacks import handle_interrupt_callback_request + from .handlers.interrupt_queries import handle_interrupt_query_request from .handlers.provider_discovery import handle_provider_discovery_request from .handlers.session_control import handle_session_control_request from .handlers.session_queries import handle_session_query_request @@ -119,6 +122,16 @@ def build_extension_method_registry( ), handler=handle_provider_discovery_request, ), + ExtensionMethodSpec( + name="interrupt_query", + methods=frozenset( + { + context.method_list_permissions, + context.method_list_questions, + } + ), + handler=handle_interrupt_query_request, + ), ExtensionMethodSpec( name="session_control", methods=frozenset(session_control_methods), diff --git a/src/opencode_a2a/jsonrpc/handlers/interrupt_queries.py b/src/opencode_a2a/jsonrpc/handlers/interrupt_queries.py new file mode 100644 index 0000000..cfd0d61 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/handlers/interrupt_queries.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from typing import Any + +from a2a.types import JSONRPCRequest +from starlette.requests import Request +from starlette.responses import Response + +from ..dispatch import ExtensionHandlerContext +from ..error_responses import invalid_params_error +from .common import build_internal_error_response, build_success_response + + +def _binding_to_result_item(binding: Any) -> dict[str, Any]: + return { + "request_id": binding.request_id, + "session_id": binding.session_id, + "interrupt_type": binding.interrupt_type, + "task_id": binding.task_id, + "context_id": binding.context_id, + "details": dict(binding.details) if isinstance(binding.details, dict) else None, + "expires_at": binding.expires_at, + } + + +async def handle_interrupt_query_request( + context: ExtensionHandlerContext, + base_request: JSONRPCRequest, + params: dict[str, Any], + request: Request, +) -> Response: + unknown_fields = sorted(params) + if unknown_fields: + return context.error_response( + base_request.id, + invalid_params_error( + f"Unsupported fields: {', '.join(unknown_fields)}", + data={"type": "INVALID_FIELD", "fields": unknown_fields}, + ), + ) + + request_identity = getattr(request.state, "user_identity", None) + identity = request_identity.strip() if isinstance(request_identity, str) else "" + if not identity: + return build_success_response(context, base_request.id, {"items": []}) + + try: + if base_request.method == context.method_list_permissions: + items = await context.upstream_client.list_permission_requests(identity=identity) + else: + items = await context.upstream_client.list_question_requests(identity=identity) + except Exception as exc: + return build_internal_error_response( + context, + base_request.id, + log_message="Interrupt recovery JSON-RPC method failed", + exc=exc, + ) + + return build_success_response( + context, + base_request.id, + {"items": [_binding_to_result_item(item) for item in items]}, + ) diff --git a/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py b/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py index 25c6313..3b17dd8 100644 --- a/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py +++ b/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py @@ -29,9 +29,7 @@ logger = logging.getLogger(__name__) -ERR_DISCOVERY_UPSTREAM_UNREACHABLE = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES[ - "UPSTREAM_UNREACHABLE" -] +ERR_DISCOVERY_UPSTREAM_UNREACHABLE = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] ERR_DISCOVERY_UPSTREAM_HTTP_ERROR = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] ERR_DISCOVERY_UPSTREAM_PAYLOAD_ERROR = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES[ "UPSTREAM_PAYLOAD_ERROR" diff --git a/src/opencode_a2a/opencode_upstream_client.py b/src/opencode_a2a/opencode_upstream_client.py index 0af00d5..2d0783c 100644 --- a/src/opencode_a2a/opencode_upstream_client.py +++ b/src/opencode_a2a/opencode_upstream_client.py @@ -233,6 +233,7 @@ async def remember_interrupt_request( identity: str | None = None, task_id: str | None = None, context_id: str | None = None, + details: dict[str, Any] | None = None, ttl_seconds: float | None = None, ) -> None: request = request_id.strip() @@ -250,6 +251,7 @@ async def remember_interrupt_request( context_id=( context_id.strip() if isinstance(context_id, str) and context_id.strip() else None ), + details=dict(details) if isinstance(details, dict) else None, ttl_seconds=ttl_seconds, ) @@ -275,6 +277,27 @@ async def discard_interrupt_request(self, request_id: str) -> None: return await self._interrupt_request_repository.discard(request_id=request) + async def list_interrupt_requests( + self, + *, + identity: str, + interrupt_type: str | None = None, + ) -> list[InterruptRequestBinding]: + normalized_identity = identity.strip() + if not normalized_identity: + return [] + self._sync_interrupt_clock() + return await self._interrupt_request_repository.list_pending( + identity=normalized_identity, + interrupt_type=interrupt_type, + ) + + async def list_permission_requests(self, *, identity: str) -> list[InterruptRequestBinding]: + return await self.list_interrupt_requests(identity=identity, interrupt_type="permission") + + async def list_question_requests(self, *, identity: str) -> list[InterruptRequestBinding]: + return await self.list_interrupt_requests(identity=identity, interrupt_type="question") + @property def stream_timeout(self) -> float | None: return self._stream_timeout diff --git a/src/opencode_a2a/runtime_state.py b/src/opencode_a2a/runtime_state.py index 39efbc6..65720ad 100644 --- a/src/opencode_a2a/runtime_state.py +++ b/src/opencode_a2a/runtime_state.py @@ -1,6 +1,7 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Any @dataclass(frozen=True) @@ -11,6 +12,7 @@ class InterruptRequestBinding: identity: str | None task_id: str | None context_id: str | None + details: dict[str, Any] | None expires_at: float diff --git a/src/opencode_a2a/server/agent_card.py b/src/opencode_a2a/server/agent_card.py index b0f58e9..cc8ce9e 100644 --- a/src/opencode_a2a/server/agent_card.py +++ b/src/opencode_a2a/server/agent_card.py @@ -15,6 +15,7 @@ from ..contracts.extensions import ( COMPATIBILITY_PROFILE_EXTENSION_URI, INTERRUPT_CALLBACK_EXTENSION_URI, + INTERRUPT_RECOVERY_EXTENSION_URI, MODEL_SELECTION_EXTENSION_URI, PROVIDER_DISCOVERY_EXTENSION_URI, SESSION_BINDING_EXTENSION_URI, @@ -26,6 +27,7 @@ build_capability_snapshot, build_compatibility_profile_params, build_interrupt_callback_extension_params, + build_interrupt_recovery_extension_params, build_model_selection_extension_params, build_provider_discovery_extension_params, build_session_binding_extension_params, @@ -44,8 +46,8 @@ def _build_agent_card_description(settings: Settings, runtime_profile: RuntimePr "(message/send, message/stream), task APIs (tasks/get, tasks/cancel, " "tasks/resubscribe; REST mapping: GET /v1/tasks/{id}:subscribe), shared " "session-binding/model-selection/streaming contracts, provider-private " - "OpenCode session/provider/model extensions, and shared interrupt " - "callback extensions." + "OpenCode session/provider/model/interrupt recovery extensions, and " + "shared interrupt callback extensions." ) parts: list[str] = [base, summary] parts.append( @@ -92,6 +94,13 @@ def _build_session_query_skill_examples( return examples +def _build_interrupt_recovery_skill_examples() -> list[str]: + return [ + "List pending permission interrupts (method opencode.permissions.list).", + "List pending question interrupts (method opencode.questions.list).", + ] + + def build_agent_card(settings: Settings) -> AgentCard: public_url = settings.a2a_public_url.rstrip("/") base_url = public_url @@ -122,6 +131,9 @@ def build_agent_card(settings: Settings) -> AgentCard: provider_discovery_extension_params = build_provider_discovery_extension_params( runtime_profile=runtime_profile, ) + interrupt_recovery_extension_params = build_interrupt_recovery_extension_params( + runtime_profile=runtime_profile, + ) interrupt_callback_extension_params = build_interrupt_callback_extension_params( runtime_profile=runtime_profile, ) @@ -196,6 +208,15 @@ def build_agent_card(settings: Settings) -> AgentCard: ), params=provider_discovery_extension_params, ), + AgentExtension( + uri=INTERRUPT_RECOVERY_EXTENSION_URI, + required=False, + description=( + "Expose provider-private interrupt recovery methods so clients can " + "list pending permission/question requests after reconnecting." + ), + params=interrupt_recovery_extension_params, + ), AgentExtension( uri=INTERRUPT_CALLBACK_EXTENSION_URI, required=False, @@ -264,6 +285,16 @@ def build_agent_card(settings: Settings) -> AgentCard: "List available models for a provider (method opencode.models.list).", ], ), + AgentSkill( + id="opencode.interrupt.recovery", + name="OpenCode Interrupt Recovery", + description=( + "provider-private OpenCode interrupt recovery surface exposed through " + "JSON-RPC extensions." + ), + tags=["interrupt", "permission", "question", "provider-private"], + examples=_build_interrupt_recovery_skill_examples(), + ), AgentSkill( id="opencode.interrupt.callback", name="Shared Interrupt Callback", diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index 9f1b812..acd0e25 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -48,6 +48,8 @@ COMPATIBILITY_PROFILE_EXTENSION_URI, INTERRUPT_CALLBACK_EXTENSION_URI, INTERRUPT_CALLBACK_METHODS, + INTERRUPT_RECOVERY_EXTENSION_URI, + INTERRUPT_RECOVERY_METHODS, MODEL_SELECTION_EXTENSION_URI, PROVIDER_DISCOVERY_EXTENSION_URI, PROVIDER_DISCOVERY_METHODS, @@ -109,6 +111,8 @@ "COMPATIBILITY_PROFILE_EXTENSION_URI", "INTERRUPT_CALLBACK_EXTENSION_URI", "INTERRUPT_CALLBACK_METHODS", + "INTERRUPT_RECOVERY_EXTENSION_URI", + "INTERRUPT_RECOVERY_METHODS", "MODEL_SELECTION_EXTENSION_URI", "PROVIDER_DISCOVERY_EXTENSION_URI", "PROVIDER_DISCOVERY_METHODS", @@ -691,6 +695,7 @@ def create_app(settings: Settings) -> FastAPI: jsonrpc_methods = { **capability_snapshot.session_query_methods(), **capability_snapshot.provider_discovery_methods(), + **capability_snapshot.interrupt_recovery_methods(), **capability_snapshot.interrupt_callback_methods(), } diff --git a/src/opencode_a2a/server/openapi.py b/src/opencode_a2a/server/openapi.py index ec8387f..f6209f7 100644 --- a/src/opencode_a2a/server/openapi.py +++ b/src/opencode_a2a/server/openapi.py @@ -7,6 +7,7 @@ from ..config import Settings from ..contracts.extensions import ( INTERRUPT_CALLBACK_METHODS, + INTERRUPT_RECOVERY_METHODS, PROVIDER_DISCOVERY_METHODS, SESSION_QUERY_DEFAULT_LIMIT, SESSION_QUERY_METHODS, @@ -14,6 +15,7 @@ build_capability_snapshot, build_compatibility_profile_params, build_interrupt_callback_extension_params, + build_interrupt_recovery_extension_params, build_model_selection_extension_params, build_provider_discovery_extension_params, build_session_binding_extension_params, @@ -31,14 +33,16 @@ def _build_jsonrpc_extension_openapi_description( ) -> str: session_methods = list(capability_snapshot.session_query_methods().values()) provider_methods = ", ".join(sorted(PROVIDER_DISCOVERY_METHODS.values())) + interrupt_recovery_methods = ", ".join(sorted(INTERRUPT_RECOVERY_METHODS.values())) interrupt_methods = ", ".join(sorted(INTERRUPT_CALLBACK_METHODS.values())) return ( "A2A JSON-RPC entrypoint. Supports core A2A methods " "(message/send, message/stream, tasks/get, tasks/cancel, tasks/resubscribe) " "plus shared model-selection metadata, OpenCode session/provider extensions, " - "and shared interrupt callback methods.\n\n" + "interrupt recovery extensions, and shared interrupt callback methods.\n\n" f"OpenCode session query/control methods: {', '.join(session_methods)}.\n" f"OpenCode provider/model discovery methods: {provider_methods}.\n" + f"OpenCode interrupt recovery methods: {interrupt_recovery_methods}.\n" f"Shared interrupt callback methods: {interrupt_methods}.\n\n" "Notification semantics: extension requests without JSON-RPC id return HTTP 204." ) @@ -200,6 +204,24 @@ def _build_jsonrpc_extension_openapi_examples( "params": {"provider_id": "openai"}, }, }, + "permissions_list": { + "summary": "List pending permission interrupts for the current caller", + "value": { + "jsonrpc": "2.0", + "id": 26, + "method": INTERRUPT_RECOVERY_METHODS["list_permissions"], + "params": {}, + }, + }, + "questions_list": { + "summary": "List pending question interrupts for the current caller", + "value": { + "jsonrpc": "2.0", + "id": 27, + "method": INTERRUPT_RECOVERY_METHODS["list_questions"], + "params": {}, + }, + }, "permission_reply": { "summary": "Reply to permission interrupt request", "value": { @@ -334,6 +356,9 @@ def _patch_jsonrpc_openapi_contract( provider_discovery = build_provider_discovery_extension_params( runtime_profile=runtime_profile, ) + interrupt_recovery = build_interrupt_recovery_extension_params( + runtime_profile=runtime_profile, + ) interrupt_callback = build_interrupt_callback_extension_params( runtime_profile=runtime_profile, ) @@ -369,6 +394,7 @@ def custom_openapi() -> dict[str, Any]: "streaming": streaming, "session_query": session_query, "provider_discovery": provider_discovery, + "interrupt_recovery": interrupt_recovery, "interrupt_callback": interrupt_callback, "compatibility_profile": compatibility_profile, "wire_contract": wire_contract, diff --git a/src/opencode_a2a/server/request_parsing.py b/src/opencode_a2a/server/request_parsing.py index 5b929e5..fe3bdb7 100644 --- a/src/opencode_a2a/server/request_parsing.py +++ b/src/opencode_a2a/server/request_parsing.py @@ -5,7 +5,11 @@ from fastapi.responses import JSONResponse -from ..contracts.extensions import INTERRUPT_CALLBACK_METHODS, SESSION_QUERY_METHODS +from ..contracts.extensions import ( + INTERRUPT_CALLBACK_METHODS, + INTERRUPT_RECOVERY_METHODS, + SESSION_QUERY_METHODS, +) logger = logging.getLogger(__name__) @@ -24,8 +28,10 @@ def _detect_sensitive_extension_method(payload: dict | None) -> str | None: method = payload.get("method") if not isinstance(method, str): return None - sensitive_methods = set(SESSION_QUERY_METHODS.values()) | set( - INTERRUPT_CALLBACK_METHODS.values() + sensitive_methods = ( + set(SESSION_QUERY_METHODS.values()) + | set(INTERRUPT_CALLBACK_METHODS.values()) + | set(INTERRUPT_RECOVERY_METHODS.values()) ) if method in sensitive_methods: return method diff --git a/src/opencode_a2a/server/state_store.py b/src/opencode_a2a/server/state_store.py index 90ec36e..3b53900 100644 --- a/src/opencode_a2a/server/state_store.py +++ b/src/opencode_a2a/server/state_store.py @@ -1,9 +1,10 @@ from __future__ import annotations +import json import time from abc import ABC, abstractmethod from collections.abc import Callable -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import ( Column, @@ -60,6 +61,7 @@ Column("identity", String, nullable=True), Column("task_id", String, nullable=True), Column("context_id", String, nullable=True), + Column("details_json", String, nullable=True), Column("expires_at", Float, nullable=True), Column("tombstone_expires_at", Float, nullable=True), ) @@ -110,6 +112,7 @@ async def remember( identity: str | None, task_id: str | None, context_id: str | None, + details: dict[str, Any] | None, ttl_seconds: float | None, ) -> None: ... @@ -123,6 +126,14 @@ async def resolve( @abstractmethod async def discard(self, *, request_id: str) -> None: ... + @abstractmethod + async def list_pending( + self, + *, + identity: str, + interrupt_type: str | None = None, + ) -> list[InterruptRequestBinding]: ... + class MemorySessionStateRepository(SessionStateRepository): def __init__( @@ -426,6 +437,7 @@ async def remember( identity: str | None, task_id: str | None, context_id: str | None, + details: dict[str, Any] | None, ttl_seconds: float | None, ) -> None: now = self._clock() @@ -439,6 +451,7 @@ async def remember( identity=identity, task_id=task_id, context_id=context_id, + details=dict(details) if isinstance(details, dict) else None, expires_at=now + max(0.0, float(ttl)), ) self._interrupt_request_tombstones.pop(request_id, None) @@ -469,8 +482,43 @@ async def discard(self, *, request_id: str) -> None: self._interrupt_requests.pop(request_id, None) self._interrupt_request_tombstones.pop(request_id, None) + async def list_pending( + self, + *, + identity: str, + interrupt_type: str | None = None, + ) -> list[InterruptRequestBinding]: + now = self._clock() + self._prune_interrupt_requests(now=now) + self._prune_interrupt_request_tombstones(now=now) + normalized_type = interrupt_type.strip() if isinstance(interrupt_type, str) else None + items = [ + binding + for binding in self._interrupt_requests.values() + if binding.identity == identity + and (normalized_type is None or binding.interrupt_type == normalized_type) + and binding.expires_at > now + ] + return sorted(items, key=lambda item: (item.expires_at, item.request_id)) + class DatabaseInterruptRequestRepository(InterruptRequestRepository): + @staticmethod + def _encode_details(details: dict[str, Any] | None) -> str | None: + if not isinstance(details, dict): + return None + return json.dumps(details, ensure_ascii=False, sort_keys=True, separators=(",", ":")) + + @staticmethod + def _decode_details(value: Any) -> dict[str, Any] | None: + if not isinstance(value, str) or not value: + return None + try: + decoded = json.loads(value) + except json.JSONDecodeError: + return None + return decoded if isinstance(decoded, dict) else None + def __init__( self, *, @@ -522,6 +570,7 @@ async def _set_tombstone(self, session: AsyncSession, *, request_id: str, now: f identity=None, task_id=None, context_id=None, + details_json=None, expires_at=None, tombstone_expires_at=tombstone_expires_at, ) @@ -536,6 +585,7 @@ async def remember( identity: str | None, task_id: str | None, context_id: str | None, + details: dict[str, Any] | None, ttl_seconds: float | None, ) -> None: await self._ensure_initialized() @@ -555,6 +605,7 @@ async def remember( "identity": identity, "task_id": task_id, "context_id": context_id, + "details_json": self._encode_details(details), "expires_at": expires_at, "tombstone_expires_at": None, } @@ -604,6 +655,7 @@ async def resolve( identity=cast("str | None", row["identity"]), task_id=cast("str | None", row["task_id"]), context_id=cast("str | None", row["context_id"]), + details=self._decode_details(row.get("details_json")), expires_at=cast("float", expires_at), ), ) @@ -615,6 +667,49 @@ async def discard(self, *, request_id: str) -> None: delete(_INTERRUPT_REQUESTS).where(_INTERRUPT_REQUESTS.c.request_id == request_id) ) + async def list_pending( + self, + *, + identity: str, + interrupt_type: str | None = None, + ) -> list[InterruptRequestBinding]: + await self._ensure_initialized() + now = self._clock() + normalized_type = interrupt_type.strip() if isinstance(interrupt_type, str) else None + async with self._session_maker.begin() as session: + await self._prune_tombstones(session, now=now) + stmt = ( + select(_INTERRUPT_REQUESTS) + .where( + and_( + _INTERRUPT_REQUESTS.c.identity == identity, + _INTERRUPT_REQUESTS.c.expires_at.is_not(None), + _INTERRUPT_REQUESTS.c.expires_at > now, + ) + ) + .order_by( + _INTERRUPT_REQUESTS.c.expires_at.asc(), + _INTERRUPT_REQUESTS.c.request_id.asc(), + ) + ) + if normalized_type is not None: + stmt = stmt.where(_INTERRUPT_REQUESTS.c.interrupt_type == normalized_type) + result = await session.execute(stmt) + rows = result.mappings().all() + return [ + InterruptRequestBinding( + request_id=cast("str", row["request_id"]), + session_id=cast("str", row["session_id"]), + interrupt_type=cast("str", row["interrupt_type"]), + identity=cast("str | None", row["identity"]), + task_id=cast("str | None", row["task_id"]), + context_id=cast("str | None", row["context_id"]), + details=self._decode_details(row.get("details_json")), + expires_at=cast("float", row["expires_at"]), + ) + for row in rows + ] + def build_session_state_repository( settings: Settings, diff --git a/tests/contracts/test_extension_contract_consistency.py b/tests/contracts/test_extension_contract_consistency.py index 89e78e0..9a00fed 100644 --- a/tests/contracts/test_extension_contract_consistency.py +++ b/tests/contracts/test_extension_contract_consistency.py @@ -8,6 +8,7 @@ build_capability_snapshot, build_compatibility_profile_params, build_interrupt_callback_extension_params, + build_interrupt_recovery_extension_params, build_model_selection_extension_params, build_provider_discovery_extension_params, build_session_binding_extension_params, @@ -20,6 +21,7 @@ from opencode_a2a.server.application import ( COMPATIBILITY_PROFILE_EXTENSION_URI, INTERRUPT_CALLBACK_EXTENSION_URI, + INTERRUPT_RECOVERY_EXTENSION_URI, MODEL_SELECTION_EXTENSION_URI, PROVIDER_DISCOVERY_EXTENSION_URI, SESSION_BINDING_EXTENSION_URI, @@ -44,6 +46,7 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: streaming = ext_by_uri[STREAMING_EXTENSION_URI] session_query = ext_by_uri[SESSION_QUERY_EXTENSION_URI] provider_discovery = ext_by_uri[PROVIDER_DISCOVERY_EXTENSION_URI] + interrupt_recovery = ext_by_uri[INTERRUPT_RECOVERY_EXTENSION_URI] interrupt_callback = ext_by_uri[INTERRUPT_CALLBACK_EXTENSION_URI] compatibility_profile = ext_by_uri[COMPATIBILITY_PROFILE_EXTENSION_URI] wire_contract = ext_by_uri[WIRE_CONTRACT_EXTENSION_URI] @@ -63,6 +66,9 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: expected_provider_discovery = build_provider_discovery_extension_params( runtime_profile=runtime_profile, ) + expected_interrupt_recovery = build_interrupt_recovery_extension_params( + runtime_profile=runtime_profile, + ) assert expected_session_query["pagination"]["default_limit"] == SESSION_QUERY_DEFAULT_LIMIT assert expected_session_query["pagination"]["max_limit"] == SESSION_QUERY_MAX_LIMIT expected_interrupt_callback = build_interrupt_callback_extension_params( @@ -92,6 +98,9 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: assert provider_discovery.params == expected_provider_discovery, ( "Provider discovery extension drifted from contracts.extensions SSOT." ) + assert interrupt_recovery.params == expected_interrupt_recovery, ( + "Interrupt recovery extension drifted from contracts.extensions SSOT." + ) assert interrupt_callback.params == expected_interrupt_callback, ( "Interrupt callback extension drifted from contracts.extensions SSOT." ) @@ -118,6 +127,7 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: streaming = contract["streaming"] session_query = contract["session_query"] provider_discovery = contract["provider_discovery"] + interrupt_recovery = contract["interrupt_recovery"] interrupt_callback = contract["interrupt_callback"] compatibility_profile = contract["compatibility_profile"] wire_contract = contract["wire_contract"] @@ -137,6 +147,9 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: expected_provider_discovery = build_provider_discovery_extension_params( runtime_profile=runtime_profile, ) + expected_interrupt_recovery = build_interrupt_recovery_extension_params( + runtime_profile=runtime_profile, + ) expected_interrupt_callback = build_interrupt_callback_extension_params( runtime_profile=runtime_profile, ) @@ -164,6 +177,9 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: assert provider_discovery == expected_provider_discovery, ( "OpenAPI provider discovery contract drifted from contracts.extensions SSOT." ) + assert interrupt_recovery == expected_interrupt_recovery, ( + "OpenAPI interrupt recovery contract drifted from contracts.extensions SSOT." + ) assert interrupt_callback == expected_interrupt_callback, ( "OpenAPI interrupt callback contract drifted from contracts.extensions SSOT." ) @@ -194,7 +210,12 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: expected_methods = set(session_query["methods"].values()) | set( INTERRUPT_CALLBACK_METHODS.values() ) - expected_methods |= {"opencode.providers.list", "opencode.models.list"} + expected_methods |= { + "opencode.providers.list", + "opencode.models.list", + "opencode.permissions.list", + "opencode.questions.list", + } missing_methods = sorted(method for method in expected_methods if method not in example_methods) assert not missing_methods, ( "OpenAPI JSON-RPC examples are missing extension methods: " + ", ".join(missing_methods) @@ -274,6 +295,8 @@ async def test_runtime_supported_methods_align_with_capability_snapshot( ), ("opencode.providers.list", {}, None), ("opencode.models.list", {"provider_id": "openai"}, None), + ("opencode.permissions.list", {}, None), + ("opencode.questions.list", {}, None), ( "a2a.interrupt.permission.reply", {"request_id": "req-perm", "reply": "once"}, diff --git a/tests/execution/test_metrics.py b/tests/execution/test_metrics.py index ceb0acd..b3c773d 100644 --- a/tests/execution/test_metrics.py +++ b/tests/execution/test_metrics.py @@ -120,9 +120,10 @@ async def remember_interrupt_request( identity: str | None = None, task_id: str | None = None, context_id: str | None = None, + details: dict | None = None, ttl_seconds: float | None = None, ) -> None: - del interrupt_type, identity, task_id, context_id, ttl_seconds + del interrupt_type, identity, task_id, context_id, details, ttl_seconds self._interrupt_requests[request_id] = session_id async def discard_interrupt_request(self, request_id: str) -> None: diff --git a/tests/execution/test_streaming_output_contract_interrupts.py b/tests/execution/test_streaming_output_contract_interrupts.py index a9e3eee..585d7cc 100644 --- a/tests/execution/test_streaming_output_contract_interrupts.py +++ b/tests/execution/test_streaming_output_contract_interrupts.py @@ -57,6 +57,10 @@ async def test_streaming_emits_interrupt_status_for_permission_asked_event() -> assert "metadata" not in interrupt["details"] assert "tool" not in interrupt["details"] assert interrupt_statuses[0].status.state == TaskState.input_required + assert client._interrupt_requests["perm-req-1"]["details"] == { + "permission": "read", + "patterns": ["/data/project/.env.secret"], + } @pytest.mark.asyncio @@ -97,6 +101,15 @@ async def test_streaming_emits_interrupt_status_for_question_asked_event() -> No ] assert "tool" not in interrupt["details"] assert interrupt_statuses[0].status.state == TaskState.input_required + assert client._interrupt_requests["q-req-1"]["details"] == { + "questions": [ + { + "header": "Confirm", + "question": "Proceed?", + "options": [{"label": "Yes", "value": "yes"}], + } + ] + } @pytest.mark.asyncio diff --git a/tests/jsonrpc/test_dispatch_registry.py b/tests/jsonrpc/test_dispatch_registry.py index 70d1fc5..141c6cb 100644 --- a/tests/jsonrpc/test_dispatch_registry.py +++ b/tests/jsonrpc/test_dispatch_registry.py @@ -26,6 +26,7 @@ async def test_extension_registry_tracks_configured_methods(monkeypatch) -> None registry_methods = _jsonrpc_app(app)._extension_method_registry.methods() # noqa: SLF001 assert "opencode.sessions.list" in registry_methods assert "opencode.providers.list" in registry_methods + assert "opencode.permissions.list" in registry_methods assert "a2a.interrupt.permission.reply" in registry_methods assert "opencode.sessions.shell" not in registry_methods diff --git a/tests/jsonrpc/test_opencode_session_extension_queries.py b/tests/jsonrpc/test_opencode_session_extension_queries.py index c2b97ba..c636639 100644 --- a/tests/jsonrpc/test_opencode_session_extension_queries.py +++ b/tests/jsonrpc/test_opencode_session_extension_queries.py @@ -1,3 +1,4 @@ +import hashlib import logging import httpx @@ -16,6 +17,10 @@ from tests.support.session_extensions import _BASE_SETTINGS, _session_meta +def _identity_for_token(token: str) -> str: + return f"bearer:{hashlib.sha256(token.encode()).hexdigest()[:12]}" + + @pytest.mark.asyncio async def test_session_query_extension_requires_bearer_token(monkeypatch): import opencode_a2a.server.application as app_module @@ -442,6 +447,146 @@ async def list_sessions(self, *, params=None): assert "concurrency limit exceeded" in payload["error"]["data"]["detail"] +@pytest.mark.asyncio +async def test_interrupt_recovery_extension_returns_identity_scoped_items(monkeypatch): + import opencode_a2a.server.application as app_module + + token = "t-1" + identity = _identity_for_token(token) + dummy = DummyOpencodeUpstreamClient( + make_settings( + a2a_bearer_token=token, + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + await dummy.remember_interrupt_request( + request_id="perm-1", + session_id="ses-1", + interrupt_type="permission", + identity=identity, + task_id="task-1", + context_id="ctx-1", + details={"permission": "read", "patterns": ["/tmp/config.yml"]}, + ) + await dummy.remember_interrupt_request( + request_id="q-1", + session_id="ses-2", + interrupt_type="question", + identity=identity, + task_id="task-2", + context_id="ctx-2", + details={"questions": [{"question": "Proceed?"}]}, + ) + await dummy.remember_interrupt_request( + request_id="perm-other", + session_id="ses-3", + interrupt_type="permission", + identity="bearer:other-user", + task_id="task-3", + context_id="ctx-3", + details={"permission": "write"}, + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings( + a2a_bearer_token=token, + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": f"Bearer {token}"} + permission_resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 201, + "method": "opencode.permissions.list", + "params": {}, + }, + ) + question_resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 202, + "method": "opencode.questions.list", + "params": {}, + }, + ) + + permission_items = permission_resp.json()["result"]["items"] + question_items = question_resp.json()["result"]["items"] + assert [item["request_id"] for item in permission_items] == ["perm-1"] + assert permission_items[0]["details"] == { + "permission": "read", + "patterns": ["/tmp/config.yml"], + } + assert [item["request_id"] for item in question_items] == ["q-1"] + assert question_items[0]["details"] == {"questions": [{"question": "Proceed?"}]} + + +@pytest.mark.asyncio +async def test_interrupt_recovery_extension_rejects_unsupported_fields(monkeypatch): + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + resp = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 203, + "method": "opencode.permissions.list", + "params": {"session_id": "ses-1"}, + }, + ) + + payload = resp.json() + assert payload["error"]["code"] == -32602 + assert payload["error"]["data"]["fields"] == ["session_id"] + + +@pytest.mark.asyncio +async def test_interrupt_recovery_extension_notification_returns_204(monkeypatch): + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer t-1"}, + json={"jsonrpc": "2.0", "method": "opencode.questions.list", "params": {}}, + ) + + assert response.status_code == 204 + + @pytest.mark.asyncio async def test_session_query_extension_session_title_is_extracted_or_placeholder(monkeypatch): import opencode_a2a.server.application as app_module diff --git a/tests/server/test_agent_card.py b/tests/server/test_agent_card.py index 5c3ad76..9a57a07 100644 --- a/tests/server/test_agent_card.py +++ b/tests/server/test_agent_card.py @@ -7,6 +7,7 @@ from opencode_a2a.server.application import ( COMPATIBILITY_PROFILE_EXTENSION_URI, INTERRUPT_CALLBACK_EXTENSION_URI, + INTERRUPT_RECOVERY_EXTENSION_URI, MODEL_SELECTION_EXTENSION_URI, PROVIDER_DISCOVERY_EXTENSION_URI, SESSION_BINDING_EXTENSION_URI, @@ -254,6 +255,23 @@ def test_agent_card_injects_profile_into_extensions() -> None: "UPSTREAM_PAYLOAD_ERROR": -32005, } + interrupt_recovery = ext_by_uri[INTERRUPT_RECOVERY_EXTENSION_URI] + assert interrupt_recovery.params["profile"]["runtime_context"]["project"] == "alpha" + assert interrupt_recovery.params["methods"] == { + "list_permissions": "opencode.permissions.list", + "list_questions": "opencode.questions.list", + } + assert interrupt_recovery.params["method_contracts"]["opencode.permissions.list"]["result"] == { + "fields": ["items"], + "items_type": "InterruptRequest[]", + } + assert interrupt_recovery.params["item_fields"]["details"] == "items[].details" + assert interrupt_recovery.params["errors"]["invalid_params_data_fields"] == [ + "type", + "field", + "fields", + ] + interrupt = ext_by_uri[INTERRUPT_CALLBACK_EXTENSION_URI] assert interrupt.params["profile"]["runtime_context"]["project"] == "alpha" assert interrupt.params["request_id_field"] == "metadata.shared.interrupt.request_id" @@ -339,6 +357,7 @@ def test_agent_card_injects_profile_into_extensions() -> None: assert wire_contract.params["profile"]["profile_id"] == "opencode-a2a-single-tenant-coding-v1" assert MODEL_SELECTION_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert PROVIDER_DISCOVERY_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] + assert INTERRUPT_RECOVERY_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert "opencode.sessions.shell" not in wire_contract.params["all_jsonrpc_methods"] assert wire_contract.params["service_behaviors"] == expected_service_behaviors assert wire_contract.params["extensions"]["conditionally_available_methods"] == { @@ -398,6 +417,12 @@ def test_agent_card_skills_hide_shell_when_disabled_by_default() -> None: assert all("opencode.sessions.shell" not in example for example in session_skill.examples) assert "provider-private" in provider_skill.tags assert any("opencode.providers.list" in example for example in provider_skill.examples) + interrupt_recovery_skill = next( + skill for skill in card.skills if skill.id == "opencode.interrupt.recovery" + ) + assert any( + "opencode.permissions.list" in example for example in interrupt_recovery_skill.examples + ) def test_agent_card_hides_shell_when_policy_disables_it() -> None: diff --git a/tests/server/test_database_app_persistence.py b/tests/server/test_database_app_persistence.py index 74618c9..45e5731 100644 --- a/tests/server/test_database_app_persistence.py +++ b/tests/server/test_database_app_persistence.py @@ -1,5 +1,6 @@ from __future__ import annotations +import hashlib from pathlib import Path import httpx @@ -83,6 +84,7 @@ async def remember_interrupt_request( identity: str | None = None, task_id: str | None = None, context_id: str | None = None, + details: dict | None = None, ttl_seconds: float | None = None, ) -> None: assert self._interrupt_request_repository is not None @@ -93,6 +95,7 @@ async def remember_interrupt_request( identity=identity, task_id=task_id, context_id=context_id, + details=details, ttl_seconds=ttl_seconds, ) @@ -106,6 +109,20 @@ async def resolve_interrupt_session(self, request_id: str) -> str | None: return None return binding.session_id + async def list_permission_requests(self, *, identity: str): + assert self._interrupt_request_repository is not None + return await self._interrupt_request_repository.list_pending( + identity=identity, + interrupt_type="permission", + ) + + async def list_question_requests(self, *, identity: str): + assert self._interrupt_request_repository is not None + return await self._interrupt_request_repository.list_pending( + identity=identity, + interrupt_type="question", + ) + async def discard_interrupt_request(self, request_id: str) -> None: assert self._interrupt_request_repository is not None await self._interrupt_request_repository.discard(request_id=request_id) @@ -156,9 +173,10 @@ async def permission_reply( request_id="perm-1", session_id=session_id, interrupt_type="permission", - identity=None, + identity=f"bearer:{hashlib.sha256(b'test-token').hexdigest()[:12]}", task_id="task-1", context_id="ctx-1", + details={"permission": "read", "patterns": ["/tmp/config.yml"]}, ttl_seconds=60.0, ) @@ -182,6 +200,16 @@ async def permission_reply( transport = httpx.ASGITransport(app=app2) async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + query_response = await client.post( + "/", + headers={"Authorization": "Bearer test-token"}, + json={ + "jsonrpc": "2.0", + "id": 0, + "method": "opencode.permissions.list", + "params": {}, + }, + ) response = await client.post( "/", headers={"Authorization": "Bearer test-token"}, @@ -196,6 +224,19 @@ async def permission_reply( }, ) + query_payload = query_response.json() + assert query_payload["result"]["items"] == [ + { + "request_id": "perm-1", + "session_id": "ses-1", + "interrupt_type": "permission", + "task_id": "task-1", + "context_id": "ctx-1", + "details": {"permission": "read", "patterns": ["/tmp/config.yml"]}, + "expires_at": query_payload["result"]["items"][0]["expires_at"], + } + ] + payload = response.json() assert payload.get("error") is None assert payload["result"]["ok"] is True diff --git a/tests/server/test_state_store.py b/tests/server/test_state_store.py index 41887cf..4860b73 100644 --- a/tests/server/test_state_store.py +++ b/tests/server/test_state_store.py @@ -142,6 +142,7 @@ async def test_database_interrupt_request_repository_persists_active_binding( identity="user-1", task_id="task-1", context_id="ctx-1", + details={"permission": "read", "patterns": ["/tmp/config.yml"]}, ttl_seconds=30.0, ) await engine.dispose() @@ -158,5 +159,67 @@ async def test_database_interrupt_request_repository_persists_active_binding( assert binding.identity == "user-1" assert binding.task_id == "task-1" assert binding.context_id == "ctx-1" + assert binding.details == {"permission": "read", "patterns": ["/tmp/config.yml"]} + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_interrupt_request_repository_lists_pending_items_by_identity_and_type( + tmp_path: Path, +) -> None: + database_url = f"sqlite+aiosqlite:///{tmp_path / 'interrupt-list.db'}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_database_url=database_url, + ) + engine = build_database_engine(settings) + repository = build_interrupt_request_repository(settings, engine=engine) + await initialize_state_repository(repository) + + await repository.remember( + request_id="perm-1", + session_id="ses-1", + interrupt_type="permission", + identity="user-1", + task_id="task-1", + context_id="ctx-1", + details={"permission": "read"}, + ttl_seconds=60.0, + ) + await repository.remember( + request_id="q-1", + session_id="ses-2", + interrupt_type="question", + identity="user-1", + task_id="task-2", + context_id="ctx-2", + details={"questions": [{"question": "Proceed?"}]}, + ttl_seconds=60.0, + ) + await repository.remember( + request_id="perm-other", + session_id="ses-3", + interrupt_type="permission", + identity="user-2", + task_id="task-3", + context_id="ctx-3", + details={"permission": "write"}, + ttl_seconds=60.0, + ) + + permission_items = await repository.list_pending( + identity="user-1", + interrupt_type="permission", + ) + question_items = await repository.list_pending( + identity="user-1", + interrupt_type="question", + ) + + assert [item.request_id for item in permission_items] == ["perm-1"] + assert permission_items[0].details == {"permission": "read"} + assert [item.request_id for item in question_items] == ["q-1"] + assert question_items[0].details == {"questions": [{"question": "Proceed?"}]} await engine.dispose() diff --git a/tests/support/helpers.py b/tests/support/helpers.py index 82ee026..3639721 100644 --- a/tests/support/helpers.py +++ b/tests/support/helpers.py @@ -179,9 +179,19 @@ async def remember_interrupt_request( identity: str | None = None, task_id: str | None = None, context_id: str | None = None, + details: dict[str, Any] | None = None, ttl_seconds: float | None = None, ) -> None: - del request_id, session_id, interrupt_type, identity, task_id, context_id, ttl_seconds + del ( + request_id, + session_id, + interrupt_type, + identity, + task_id, + context_id, + details, + ttl_seconds, + ) async def resolve_interrupt_session(self, request_id: str) -> str | None: del request_id @@ -251,6 +261,7 @@ def __init__(self, _settings: Settings) -> None: "connected": ["openai"], } self._interrupt_requests: dict[str, dict[str, str | None]] = {} + self._interrupt_request_details: dict[str, dict[str, Any] | None] = {} async def close(self) -> None: return None @@ -331,6 +342,7 @@ async def remember_interrupt_request( identity: str | None = None, task_id: str | None = None, context_id: str | None = None, + details: dict[str, Any] | None = None, ttl_seconds: float | None = None, ) -> None: del ttl_seconds @@ -341,6 +353,9 @@ async def remember_interrupt_request( "task_id": task_id, "context_id": context_id, } + self._interrupt_request_details[request_id] = ( + dict(details) if isinstance(details, dict) else None + ) async def resolve_interrupt_request(self, request_id: str): payload = self._interrupt_requests.get(request_id) @@ -349,11 +364,15 @@ async def resolve_interrupt_request(self, request_id: str): class _Binding: def __init__(self, data: dict[str, str | None]) -> None: + self.request_id = request_id self.session_id = data.get("session_id") self.interrupt_type = data.get("interrupt_type") self.identity = data.get("identity") self.task_id = data.get("task_id") self.context_id = data.get("context_id") + self.details = self_details + + self_details = self._interrupt_request_details.get(request_id) return "active", _Binding(payload) @@ -365,6 +384,51 @@ async def resolve_interrupt_session(self, request_id: str) -> str | None: async def discard_interrupt_request(self, request_id: str) -> None: self._interrupt_requests.pop(request_id, None) + self._interrupt_request_details.pop(request_id, None) + + async def list_interrupt_requests( + self, + *, + identity: str, + interrupt_type: str | None = None, + ): + class _Binding: + def __init__( + self, + *, + request_id: str, + data: dict[str, str | None], + details: dict[str, Any] | None, + ) -> None: + self.request_id = request_id + self.session_id = data.get("session_id") + self.interrupt_type = data.get("interrupt_type") + self.identity = data.get("identity") + self.task_id = data.get("task_id") + self.context_id = data.get("context_id") + self.details = details + self.expires_at = 0.0 + + items = [] + for request_id, payload in self._interrupt_requests.items(): + if payload.get("identity") != identity: + continue + if interrupt_type is not None and payload.get("interrupt_type") != interrupt_type: + continue + items.append( + _Binding( + request_id=request_id, + data=payload, + details=self._interrupt_request_details.get(request_id), + ) + ) + return items + + async def list_permission_requests(self, *, identity: str): + return await self.list_interrupt_requests(identity=identity, interrupt_type="permission") + + async def list_question_requests(self, *, identity: str): + return await self.list_interrupt_requests(identity=identity, interrupt_type="question") async def permission_reply( self, diff --git a/tests/support/streaming_output.py b/tests/support/streaming_output.py index 4cf8727..aa221a5 100644 --- a/tests/support/streaming_output.py +++ b/tests/support/streaming_output.py @@ -36,6 +36,7 @@ def __init__( self.stream_timeout = None self.directory = None self._interrupt_sessions: dict[str, str] = {} + self._interrupt_requests: dict[str, dict] = {} self.settings = make_settings( a2a_bearer_token="test", opencode_base_url="http://localhost", @@ -97,16 +98,26 @@ async def remember_interrupt_request( identity: str | None = None, task_id: str | None = None, context_id: str | None = None, + details: dict | None = None, ttl_seconds: float | None = None, ) -> None: - del interrupt_type, identity, task_id, context_id, ttl_seconds + del ttl_seconds self._interrupt_sessions[request_id] = session_id + self._interrupt_requests[request_id] = { + "session_id": session_id, + "interrupt_type": interrupt_type, + "identity": identity, + "task_id": task_id, + "context_id": context_id, + "details": details, + } async def resolve_interrupt_session(self, request_id: str) -> str | None: return self._interrupt_sessions.get(request_id) async def discard_interrupt_request(self, request_id: str) -> None: self._interrupt_sessions.pop(request_id, None) + self._interrupt_requests.pop(request_id, None) def _event( diff --git a/tests/upstream/test_opencode_upstream_client_params.py b/tests/upstream/test_opencode_upstream_client_params.py index 5f9ce52..940bd3c 100644 --- a/tests/upstream/test_opencode_upstream_client_params.py +++ b/tests/upstream/test_opencode_upstream_client_params.py @@ -593,6 +593,7 @@ async def test_interrupt_request_binding_expires_after_ttl() -> None: task_id="task-1", context_id="ctx-1", identity="user-1", + details={"permission": "read"}, ttl_seconds=5.0, ) @@ -601,6 +602,7 @@ async def test_interrupt_request_binding_expires_after_ttl() -> None: assert binding is not None assert binding.session_id == "ses-1" assert binding.interrupt_type == "permission" + assert binding.details == {"permission": "read"} now = 1006.0 status, binding = await client.resolve_interrupt_request("perm-1") @@ -983,6 +985,7 @@ async def test_interrupt_request_helpers_ignore_invalid_and_trim_values() -> Non identity=" user-1 ", task_id=" task-1 ", context_id=" ctx-1 ", + details={"questions": [{"question": "Proceed?"}]}, ) status, binding = await client.resolve_interrupt_request("perm-3") assert status == "active" @@ -992,6 +995,7 @@ async def test_interrupt_request_helpers_ignore_invalid_and_trim_values() -> Non assert binding.identity == "user-1" assert binding.task_id == "task-1" assert binding.context_id == "ctx-1" + assert binding.details == {"questions": [{"question": "Proceed?"}]} assert await client.resolve_interrupt_request(" ") == ("missing", None) await client.discard_interrupt_request(" ") @@ -999,3 +1003,47 @@ async def test_interrupt_request_helpers_ignore_invalid_and_trim_values() -> Non assert await client.resolve_interrupt_session("perm-3") is None await client.close() + + +@pytest.mark.asyncio +async def test_interrupt_request_helpers_list_pending_by_identity_and_type() -> None: + client = OpencodeUpstreamClient( + make_settings( + a2a_bearer_token="t-1", + opencode_timeout=1.0, + a2a_log_level="DEBUG", + a2a_log_payloads=False, + ) + ) + + await client.remember_interrupt_request( + request_id="perm-1", + session_id="ses-1", + interrupt_type="permission", + identity="user-1", + details={"permission": "read"}, + ) + await client.remember_interrupt_request( + request_id="q-1", + session_id="ses-2", + interrupt_type="question", + identity="user-1", + details={"questions": [{"question": "Proceed?"}]}, + ) + await client.remember_interrupt_request( + request_id="perm-2", + session_id="ses-3", + interrupt_type="permission", + identity="user-2", + details={"permission": "write"}, + ) + + permissions = await client.list_permission_requests(identity="user-1") + questions = await client.list_question_requests(identity="user-1") + + assert [item.request_id for item in permissions] == ["perm-1"] + assert permissions[0].details == {"permission": "read"} + assert [item.request_id for item in questions] == ["q-1"] + assert questions[0].details == {"questions": [{"question": "Proceed?"}]} + + await client.close()