diff --git a/src/opencode_a2a/contracts/extensions.py b/src/opencode_a2a/contracts/extensions.py index 4e95cdd..e098bf0 100644 --- a/src/opencode_a2a/contracts/extensions.py +++ b/src/opencode_a2a/contracts/extensions.py @@ -257,6 +257,8 @@ class ProviderDiscoveryMethodContract: INTERRUPT_SUCCESS_RESULT_FIELDS: tuple[str, ...] = ("ok", "request_id") INTERRUPT_ERROR_BUSINESS_CODES: dict[str, int] = { "INTERRUPT_REQUEST_NOT_FOUND": -32004, + "INTERRUPT_REQUEST_EXPIRED": -32007, + "INTERRUPT_TYPE_MISMATCH": -32008, "UPSTREAM_UNREACHABLE": -32002, "UPSTREAM_HTTP_ERROR": -32003, } @@ -267,14 +269,19 @@ class ProviderDiscoveryMethodContract: "UPSTREAM_UNREACHABLE", "UPSTREAM_HTTP_ERROR", ) -INTERRUPT_ERROR_DATA_FIELDS: tuple[str, ...] = ("type", "request_id", "upstream_status") +INTERRUPT_ERROR_DATA_FIELDS: tuple[str, ...] = ( + "type", + "request_id", + "expected_interrupt_type", + "actual_interrupt_type", + "upstream_status", + "detail", +) INTERRUPT_INVALID_PARAMS_DATA_FIELDS: tuple[str, ...] = ( "type", "field", "fields", "request_id", - "expected", - "actual", ) PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES: dict[str, int] = { "UPSTREAM_UNREACHABLE": -32002, diff --git a/src/opencode_a2a/jsonrpc/application.py b/src/opencode_a2a/jsonrpc/application.py index 0b2419e..77569e1 100644 --- a/src/opencode_a2a/jsonrpc/application.py +++ b/src/opencode_a2a/jsonrpc/application.py @@ -28,6 +28,7 @@ ) from .error_responses import ( interrupt_not_found_error, + interrupt_type_mismatch_error, invalid_params_error, method_not_supported_error, session_forbidden_error, @@ -81,6 +82,8 @@ ERR_UPSTREAM_UNREACHABLE = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] ERR_UPSTREAM_HTTP_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] ERR_INTERRUPT_NOT_FOUND = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_REQUEST_NOT_FOUND"] +ERR_INTERRUPT_EXPIRED = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_REQUEST_EXPIRED"] +ERR_INTERRUPT_TYPE_MISMATCH = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_TYPE_MISMATCH"] ERR_UPSTREAM_PAYLOAD_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_PAYLOAD_ERROR"] ERR_DISCOVERY_UPSTREAM_UNREACHABLE = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] ERR_DISCOVERY_UPSTREAM_HTTP_ERROR = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] @@ -812,7 +815,7 @@ async def _handle_interrupt_callback_request( return self._generate_error_response( base_request.id, interrupt_not_found_error( - ERR_INTERRUPT_NOT_FOUND, + ERR_INTERRUPT_EXPIRED if status == "expired" else ERR_INTERRUPT_NOT_FOUND, request_id=request_id, expired=status == "expired", ), @@ -820,17 +823,11 @@ async def _handle_interrupt_callback_request( if binding.interrupt_type != expected_interrupt_type: return self._generate_error_response( base_request.id, - invalid_params_error( - ( - "Interrupt type mismatch: " - f"expected {expected_interrupt_type}, got {binding.interrupt_type}" - ), - data={ - "type": "INTERRUPT_TYPE_MISMATCH", - "request_id": request_id, - "expected": expected_interrupt_type, - "actual": binding.interrupt_type, - }, + interrupt_type_mismatch_error( + ERR_INTERRUPT_TYPE_MISMATCH, + request_id=request_id, + expected_interrupt_type=expected_interrupt_type, + actual_interrupt_type=binding.interrupt_type, ), ) if ( diff --git a/src/opencode_a2a/jsonrpc/error_responses.py b/src/opencode_a2a/jsonrpc/error_responses.py index 8b9a3ac..b170b0d 100644 --- a/src/opencode_a2a/jsonrpc/error_responses.py +++ b/src/opencode_a2a/jsonrpc/error_responses.py @@ -63,6 +63,25 @@ def interrupt_not_found_error( ) +def interrupt_type_mismatch_error( + code: int, + *, + request_id: str, + expected_interrupt_type: str, + actual_interrupt_type: str, +) -> JSONRPCError: + return JSONRPCError( + code=code, + message="Interrupt callback type mismatch", + data={ + "type": "INTERRUPT_TYPE_MISMATCH", + "request_id": request_id, + "expected_interrupt_type": expected_interrupt_type, + "actual_interrupt_type": actual_interrupt_type, + }, + ) + + def upstream_http_error( code: int, *, @@ -130,6 +149,7 @@ def upstream_payload_error( __all__ = [ "interrupt_not_found_error", + "interrupt_type_mismatch_error", "invalid_params_error", "method_not_supported_error", "session_forbidden_error", diff --git a/tests/jsonrpc/test_error_responses.py b/tests/jsonrpc/test_error_responses.py index 59cdee2..e864f53 100644 --- a/tests/jsonrpc/test_error_responses.py +++ b/tests/jsonrpc/test_error_responses.py @@ -4,6 +4,7 @@ from opencode_a2a.jsonrpc.error_responses import ( interrupt_not_found_error, + interrupt_type_mismatch_error, invalid_params_error, method_not_supported_error, session_forbidden_error, @@ -36,6 +37,19 @@ def test_jsonrpc_error_mapping_helpers_preserve_business_contract_fields() -> No "request_id": "req-1", } + mismatch_interrupt = interrupt_type_mismatch_error( + -32008, + request_id="req-2", + expected_interrupt_type="permission", + actual_interrupt_type="question", + ) + assert mismatch_interrupt.data == { + "type": "INTERRUPT_TYPE_MISMATCH", + "request_id": "req-2", + "expected_interrupt_type": "permission", + "actual_interrupt_type": "question", + } + def test_jsonrpc_error_mapping_helpers_build_upstream_envelopes() -> None: backpressure_detail = ( diff --git a/tests/jsonrpc/test_opencode_session_extension_interrupts.py b/tests/jsonrpc/test_opencode_session_extension_interrupts.py index 93a2227..73b494c 100644 --- a/tests/jsonrpc/test_opencode_session_extension_interrupts.py +++ b/tests/jsonrpc/test_opencode_session_extension_interrupts.py @@ -326,7 +326,7 @@ async def resolve_interrupt_request(self, request_id: str): }, ) payload = resp.json() - assert payload["error"]["code"] == -32004 + assert payload["error"]["code"] == -32007 assert payload["error"]["data"]["type"] == "INTERRUPT_REQUEST_EXPIRED" @@ -412,8 +412,10 @@ class InterruptClient(DummyOpencodeUpstreamClient): }, ) payload = resp.json() - assert payload["error"]["code"] == -32602 + assert payload["error"]["code"] == -32008 assert payload["error"]["data"]["type"] == "INTERRUPT_TYPE_MISMATCH" + assert payload["error"]["data"]["expected_interrupt_type"] == "permission" + assert payload["error"]["data"]["actual_interrupt_type"] == "question" @pytest.mark.asyncio diff --git a/tests/server/test_agent_card.py b/tests/server/test_agent_card.py index d540ff6..5c3ad76 100644 --- a/tests/server/test_agent_card.py +++ b/tests/server/test_agent_card.py @@ -262,6 +262,8 @@ def test_agent_card_injects_profile_into_extensions() -> None: assert interrupt.params["context_fields"]["directory"] == "metadata.opencode.directory" assert interrupt.params["errors"]["business_codes"] == { "INTERRUPT_REQUEST_NOT_FOUND": -32004, + "INTERRUPT_REQUEST_EXPIRED": -32007, + "INTERRUPT_TYPE_MISMATCH": -32008, "UPSTREAM_UNREACHABLE": -32002, "UPSTREAM_HTTP_ERROR": -32003, } @@ -272,13 +274,19 @@ def test_agent_card_injects_profile_into_extensions() -> None: "UPSTREAM_UNREACHABLE", "UPSTREAM_HTTP_ERROR", ] + assert interrupt.params["errors"]["error_data_fields"] == [ + "type", + "request_id", + "expected_interrupt_type", + "actual_interrupt_type", + "upstream_status", + "detail", + ] assert interrupt.params["errors"]["invalid_params_data_fields"] == [ "type", "field", "fields", "request_id", - "expected", - "actual", ] for method_name in ( "a2a.interrupt.permission.reply",