From f10c20337202ab37d1a1fc3b346995598d50cc11 Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Fri, 27 Mar 2026 02:01:14 -0400 Subject: [PATCH] refactor(jsonrpc): register extension method handlers --- src/opencode_a2a/jsonrpc/application.py | 844 +----------------- src/opencode_a2a/jsonrpc/dispatch.py | 139 +++ src/opencode_a2a/jsonrpc/handlers/__init__.py | 1 + src/opencode_a2a/jsonrpc/handlers/common.py | 251 ++++++ .../jsonrpc/handlers/interrupt_callbacks.py | 205 +++++ .../jsonrpc/handlers/provider_discovery.py | 151 ++++ .../jsonrpc/handlers/session_control.py | 265 ++++++ .../jsonrpc/handlers/session_queries.py | 134 +++ tests/jsonrpc/test_dispatch_registry.py | 119 +++ 9 files changed, 1302 insertions(+), 807 deletions(-) create mode 100644 src/opencode_a2a/jsonrpc/dispatch.py create mode 100644 src/opencode_a2a/jsonrpc/handlers/__init__.py create mode 100644 src/opencode_a2a/jsonrpc/handlers/common.py create mode 100644 src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py create mode 100644 src/opencode_a2a/jsonrpc/handlers/provider_discovery.py create mode 100644 src/opencode_a2a/jsonrpc/handlers/session_control.py create mode 100644 src/opencode_a2a/jsonrpc/handlers/session_queries.py create mode 100644 tests/jsonrpc/test_dispatch_registry.py diff --git a/src/opencode_a2a/jsonrpc/application.py b/src/opencode_a2a/jsonrpc/application.py index 77569e1..413cbb1 100644 --- a/src/opencode_a2a/jsonrpc/application.py +++ b/src/opencode_a2a/jsonrpc/application.py @@ -4,11 +4,9 @@ from collections.abc import Awaitable, Callable from typing import Any, cast -import httpx from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication from a2a.types import ( A2AError, - InternalError, InvalidRequestError, JSONRPCRequest, ) @@ -16,34 +14,19 @@ from starlette.requests import Request from starlette.responses import Response -from ..contracts.extensions import ( - INTERRUPT_ERROR_BUSINESS_CODES, - PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES, - SESSION_QUERY_ERROR_BUSINESS_CODES, -) -from ..opencode_upstream_client import ( - OpencodeUpstreamClient, - UpstreamConcurrencyLimitError, - UpstreamContractError, +from ..opencode_upstream_client import OpencodeUpstreamClient +from .dispatch import ( + CORE_JSONRPC_METHODS, + ExtensionHandlerContext, + build_extension_method_registry, ) from .error_responses import ( - interrupt_not_found_error, - interrupt_type_mismatch_error, invalid_params_error, method_not_supported_error, - session_forbidden_error, - session_not_found_error, - upstream_http_error, - upstream_payload_error, - upstream_unreachable_error, ) from .methods import ( SESSION_CONTEXT_PREFIX, - _apply_session_query_limit, - _as_a2a_message, - _as_a2a_session_task, _extract_provider_catalog, - _extract_raw_items, _normalize_model_summaries, _normalize_permission_reply, _normalize_provider_summaries, @@ -52,45 +35,25 @@ _validate_command_request_payload, _validate_prompt_async_format, _validate_prompt_async_part, - _validate_prompt_async_request_payload, _validate_shell_request_payload, ) -from .params import ( - JsonRpcParamsValidationError, - parse_get_session_messages_params, - parse_list_sessions_params, -) logger = logging.getLogger(__name__) __all__ = [ "SESSION_CONTEXT_PREFIX", - "_PromptAsyncValidationError", "_extract_provider_catalog", "_normalize_model_summaries", "_normalize_permission_reply", "_normalize_provider_summaries", "_parse_question_answers", + "_PromptAsyncValidationError", "_validate_command_request_payload", "_validate_prompt_async_format", "_validate_prompt_async_part", "_validate_shell_request_payload", ] -ERR_SESSION_NOT_FOUND = SESSION_QUERY_ERROR_BUSINESS_CODES["SESSION_NOT_FOUND"] -ERR_SESSION_FORBIDDEN = SESSION_QUERY_ERROR_BUSINESS_CODES["SESSION_FORBIDDEN"] -ERR_UPSTREAM_UNREACHABLE = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] -ERR_UPSTREAM_HTTP_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] -ERR_INTERRUPT_NOT_FOUND = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_REQUEST_NOT_FOUND"] -ERR_INTERRUPT_EXPIRED = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_REQUEST_EXPIRED"] -ERR_INTERRUPT_TYPE_MISMATCH = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_TYPE_MISMATCH"] -ERR_UPSTREAM_PAYLOAD_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_PAYLOAD_ERROR"] -ERR_DISCOVERY_UPSTREAM_UNREACHABLE = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] -ERR_DISCOVERY_UPSTREAM_HTTP_ERROR = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] -ERR_DISCOVERY_UPSTREAM_PAYLOAD_ERROR = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES[ - "UPSTREAM_PAYLOAD_ERROR" -] - class OpencodeSessionQueryJSONRPCApplication(A2AFastAPIApplication): """Extend A2A JSON-RPC endpoint with OpenCode session methods. @@ -144,80 +107,30 @@ def __init__( self._session_claim = cast(Callable[..., Awaitable[bool]], session_claim) self._session_claim_finalize = cast(Callable[..., Awaitable[None]], session_claim_finalize) self._session_claim_release = cast(Callable[..., Awaitable[None]], session_claim_release) - - def _session_forbidden_response( - self, - request_id: str | int | None, - *, - session_id: str, - ) -> Response: - return self._generate_error_response( - request_id, - session_forbidden_error(ERR_SESSION_FORBIDDEN, session_id=session_id), + self._extension_handler_context = ExtensionHandlerContext( + upstream_client=self._upstream_client, + method_list_sessions=self._method_list_sessions, + method_get_session_messages=self._method_get_session_messages, + method_prompt_async=self._method_prompt_async, + method_command=self._method_command, + method_shell=self._method_shell, + method_list_providers=self._method_list_providers, + method_list_models=self._method_list_models, + method_reply_permission=self._method_reply_permission, + method_reply_question=self._method_reply_question, + method_reject_question=self._method_reject_question, + protocol_version=self._protocol_version, + supported_methods=tuple(self._supported_methods), + directory_resolver=self._directory_resolver, + session_claim=self._session_claim, + session_claim_finalize=self._session_claim_finalize, + session_claim_release=self._session_claim_release, + error_response=self._generate_error_response, + success_response=self._jsonrpc_success_response, + ) + self._extension_method_registry = build_extension_method_registry( + self._extension_handler_context ) - - def _extract_directory_from_metadata( - self, - *, - request_id: str | int | None, - params: dict[str, Any], - ) -> tuple[str | None, Response | None]: - metadata = params.get("metadata") - if metadata is not None and not isinstance(metadata, dict): - return None, self._generate_error_response( - request_id, - invalid_params_error( - "metadata must be an object", - data={"type": "INVALID_FIELD", "field": "metadata"}, - ), - ) - - opencode_metadata: dict[str, Any] | None = None - if isinstance(metadata, dict): - unknown_metadata_fields = sorted(set(metadata) - {"opencode", "shared"}) - if unknown_metadata_fields: - prefixed_fields = [f"metadata.{field}" for field in unknown_metadata_fields] - return None, self._generate_error_response( - request_id, - invalid_params_error( - f"Unsupported metadata fields: {', '.join(prefixed_fields)}", - data={"type": "INVALID_FIELD", "fields": prefixed_fields}, - ), - ) - raw_opencode_metadata = metadata.get("opencode") - if raw_opencode_metadata is not None and not isinstance(raw_opencode_metadata, dict): - return None, self._generate_error_response( - request_id, - invalid_params_error( - "metadata.opencode must be an object", - data={"type": "INVALID_FIELD", "field": "metadata.opencode"}, - ), - ) - if isinstance(raw_opencode_metadata, dict): - opencode_metadata = raw_opencode_metadata - raw_shared_metadata = metadata.get("shared") - if raw_shared_metadata is not None and not isinstance(raw_shared_metadata, dict): - return None, self._generate_error_response( - request_id, - invalid_params_error( - "metadata.shared must be an object", - data={"type": "INVALID_FIELD", "field": "metadata.shared"}, - ), - ) - - directory = None - if opencode_metadata is not None: - directory = opencode_metadata.get("directory") - if directory is not None and not isinstance(directory, str): - return None, self._generate_error_response( - request_id, - invalid_params_error( - "metadata.opencode.directory must be a string", - data={"type": "INVALID_FIELD", "field": "metadata.opencode.directory"}, - ), - ) - - return directory, None async def _handle_requests(self, request: Request) -> Response: # Fast path: sniff method first then either handle here or delegate. @@ -240,42 +153,10 @@ async def _handle_requests(self, request: Request) -> Response: # Delegate to base implementation for consistent error handling. return await super()._handle_requests(request) - session_query_methods = { - self._method_list_sessions, - self._method_get_session_messages, - } - provider_discovery_methods = { - self._method_list_providers, - self._method_list_models, - } - session_control_methods = { - self._method_prompt_async, - self._method_command, - } - if self._method_shell is not None: - session_control_methods.add(self._method_shell) - interrupt_callback_methods = { - self._method_reply_permission, - self._method_reply_question, - self._method_reject_question, - } - if ( - base_request.method - not in session_query_methods - | provider_discovery_methods - | session_control_methods - | interrupt_callback_methods - ): - core_methods = { - "message/send", - "message/stream", - "tasks/get", - "tasks/cancel", - "tasks/resubscribe", - } - if base_request.method in core_methods: + extension_spec = self._extension_method_registry.resolve(base_request.method) + if extension_spec is None: + if base_request.method in CORE_JSONRPC_METHODS: return await super()._handle_requests(request) - if base_request.id is None: return Response(status_code=204) @@ -294,664 +175,13 @@ async def _handle_requests(self, request: Request) -> Response: base_request.id, invalid_params_error("params must be an object"), ) - - if base_request.method in session_query_methods: - return await self._handle_session_query_request(base_request, params) - if base_request.method in provider_discovery_methods: - return await self._handle_provider_discovery_request(base_request, params) - if base_request.method in session_control_methods: - return await self._handle_session_control_request( - base_request, - params, - request=request, - ) - return await self._handle_interrupt_callback_request(base_request, params, request=request) - - async def _handle_session_query_request( - self, - base_request: JSONRPCRequest, - params: dict[str, Any], - ) -> Response: - try: - if base_request.method == self._method_list_sessions: - query = parse_list_sessions_params(params) - session_id: str | None = None - else: - session_id, query = parse_get_session_messages_params(params) - except JsonRpcParamsValidationError as exc: - return self._generate_error_response( - base_request.id, - invalid_params_error(str(exc), data=exc.data), - ) - - limit = int(query["limit"]) - try: - if base_request.method == self._method_list_sessions: - raw_result = await self._upstream_client.list_sessions(params=query) - else: - assert session_id is not None - raw_result = await self._upstream_client.list_messages(session_id, params=query) - except httpx.HTTPStatusError as exc: - upstream_status = exc.response.status_code - if upstream_status == 404 and base_request.method == self._method_get_session_messages: - assert session_id is not None - return self._generate_error_response( - base_request.id, - session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id), - ) - return self._generate_error_response( - base_request.id, - upstream_http_error( - ERR_UPSTREAM_HTTP_ERROR, - upstream_status=upstream_status, - ), - ) - except httpx.HTTPError: - return self._generate_error_response( - base_request.id, - upstream_unreachable_error(ERR_UPSTREAM_UNREACHABLE), - ) - except UpstreamConcurrencyLimitError as exc: - return self._generate_error_response( - base_request.id, - upstream_unreachable_error( - ERR_UPSTREAM_UNREACHABLE, - detail=str(exc), - ), - ) - except Exception as exc: - logger.exception("OpenCode session query JSON-RPC method failed") - return self._generate_error_response( - base_request.id, - A2AError(root=InternalError(message=str(exc))), - ) - - try: - if base_request.method == self._method_list_sessions: - raw_items = _extract_raw_items(raw_result, kind="sessions") - else: - raw_items = _extract_raw_items(raw_result, kind="messages") - except ValueError as exc: - logger.warning("Upstream OpenCode payload mismatch: %s", exc) - return self._generate_error_response( - base_request.id, - upstream_payload_error( - ERR_UPSTREAM_PAYLOAD_ERROR, - detail=str(exc), - ), - ) - - # Protocol: items are always arrays of A2A objects. - # Task for sessions; Message for messages. - if base_request.method == self._method_list_sessions: - mapped: list[dict[str, Any]] = [] - for item in raw_items: - task = _as_a2a_session_task(item) - if task is not None: - mapped.append(task) - # OpenCode documents `limit` for message history, not for session list. - # Enforce the adapter contract locally so the declared pagination stays true. - items: list[dict[str, Any]] = _apply_session_query_limit(mapped, limit=limit) - else: - assert session_id is not None - mapped = [] - for item in raw_items: - message = _as_a2a_message(session_id, item) - if message is not None: - mapped.append(message) - items = mapped - - result = { - "items": items, - } - - # Notifications (id omitted) should not yield a response. - if base_request.id is None: - return Response(status_code=204) - - return self._jsonrpc_success_response( - base_request.id, - result, + return await extension_spec.handler( + self._extension_handler_context, + base_request, + params, + request, ) - async def _handle_provider_discovery_request( - self, - base_request: JSONRPCRequest, - params: dict[str, Any], - ) -> Response: - allowed_fields = {"metadata"} - if base_request.method == self._method_list_models: - allowed_fields.add("provider_id") - unknown_fields = sorted(set(params) - allowed_fields) - if unknown_fields: - prefixed_fields = [f"params.{field}" for field in unknown_fields] - return self._generate_error_response( - base_request.id, - invalid_params_error( - f"Unsupported params fields: {', '.join(prefixed_fields)}", - data={"type": "INVALID_FIELD", "fields": prefixed_fields}, - ), - ) - - provider_id: str | None = None - if base_request.method == self._method_list_models: - raw_provider_id = params.get("provider_id") - if raw_provider_id is not None: - if not isinstance(raw_provider_id, str) or not raw_provider_id.strip(): - return self._generate_error_response( - base_request.id, - invalid_params_error( - "provider_id must be a non-empty string", - data={"type": "INVALID_FIELD", "field": "provider_id"}, - ), - ) - provider_id = raw_provider_id.strip() - - directory, metadata_error = self._extract_directory_from_metadata( - request_id=base_request.id, - params=params, - ) - if metadata_error is not None: - return metadata_error - - try: - directory = self._directory_resolver(directory) - except ValueError as exc: - return self._generate_error_response( - base_request.id, - invalid_params_error( - str(exc), - data={"type": "INVALID_FIELD", "field": "metadata.opencode.directory"}, - ), - ) - - try: - raw_result = await self._upstream_client.list_provider_catalog(directory=directory) - except httpx.HTTPStatusError as exc: - upstream_status = exc.response.status_code - return self._generate_error_response( - base_request.id, - upstream_http_error( - ERR_DISCOVERY_UPSTREAM_HTTP_ERROR, - upstream_status=upstream_status, - method=base_request.method, - ), - ) - except httpx.HTTPError: - return self._generate_error_response( - base_request.id, - upstream_unreachable_error( - ERR_DISCOVERY_UPSTREAM_UNREACHABLE, - method=base_request.method, - ), - ) - except UpstreamConcurrencyLimitError as exc: - return self._generate_error_response( - base_request.id, - upstream_unreachable_error( - ERR_DISCOVERY_UPSTREAM_UNREACHABLE, - method=base_request.method, - detail=str(exc), - ), - ) - except Exception as exc: - logger.exception("OpenCode provider discovery JSON-RPC method failed") - return self._generate_error_response( - base_request.id, - A2AError(root=InternalError(message=str(exc))), - ) - - try: - raw_providers, default_by_provider, connected = _extract_provider_catalog(raw_result) - if base_request.method == self._method_list_providers: - items = _normalize_provider_summaries( - raw_providers, - default_by_provider=default_by_provider, - connected=connected, - ) - else: - items = _normalize_model_summaries( - raw_providers, - default_by_provider=default_by_provider, - connected=connected, - provider_id=provider_id, - ) - except ValueError as exc: - logger.warning("Upstream OpenCode provider payload mismatch: %s", exc) - return self._generate_error_response( - base_request.id, - upstream_payload_error( - ERR_DISCOVERY_UPSTREAM_PAYLOAD_ERROR, - detail=str(exc), - method=base_request.method, - ), - ) - - result = { - "items": items, - "default_by_provider": default_by_provider, - "connected": connected, - } - - if base_request.id is None: - return Response(status_code=204) - - return self._jsonrpc_success_response(base_request.id, result) - - async def _handle_session_control_request( - self, - base_request: JSONRPCRequest, - params: dict[str, Any], - *, - request: Request, - ) -> Response: - allowed_fields = {"session_id", "request", "metadata"} - unknown_fields = sorted(set(params) - allowed_fields) - if unknown_fields: - return self._generate_error_response( - base_request.id, - invalid_params_error( - f"Unsupported fields: {', '.join(unknown_fields)}", - data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) - - session_id = params.get("session_id") - if not isinstance(session_id, str) or not session_id.strip(): - return self._generate_error_response( - base_request.id, - invalid_params_error( - "Missing required params.session_id", - data={"type": "MISSING_FIELD", "field": "session_id"}, - ), - ) - session_id = session_id.strip() - - raw_request = params.get("request") - if raw_request is None: - return self._generate_error_response( - base_request.id, - invalid_params_error( - "Missing required params.request", - data={"type": "MISSING_FIELD", "field": "request"}, - ), - ) - if not isinstance(raw_request, dict): - return self._generate_error_response( - base_request.id, - invalid_params_error( - "params.request must be an object", - data={"type": "INVALID_FIELD", "field": "request"}, - ), - ) - - request_identity = getattr(request.state, "user_identity", None) - identity = request_identity if isinstance(request_identity, str) else None - task_id = getattr(request.state, "task_id", None) - context_id = getattr(request.state, "context_id", None) - - def _log_shell_audit(outcome: str) -> None: - if base_request.method != self._method_shell: - return - logger.info( - "session_shell_audit method=%s identity=%s task_id=%s context_id=%s " - "session_id=%s outcome=%s", - base_request.method, - identity if identity else "-", - task_id if isinstance(task_id, str) and task_id.strip() else "-", - context_id if isinstance(context_id, str) and context_id.strip() else "-", - session_id, - outcome, - ) - - try: - if base_request.method == self._method_prompt_async: - _validate_prompt_async_request_payload(raw_request) - elif base_request.method == self._method_command: - _validate_command_request_payload(raw_request) - elif base_request.method == self._method_shell: - _validate_shell_request_payload(raw_request) - else: - raise _PromptAsyncValidationError( - field="method", - message=f"Unsupported method: {base_request.method}", - ) - except _PromptAsyncValidationError as exc: - return self._generate_error_response( - base_request.id, - invalid_params_error(str(exc), data={"type": "INVALID_FIELD", "field": exc.field}), - ) - - directory, metadata_error = self._extract_directory_from_metadata( - request_id=base_request.id, - params=params, - ) - if metadata_error is not None: - return metadata_error - - try: - directory = self._directory_resolver(directory) - except ValueError as exc: - return self._generate_error_response( - base_request.id, - invalid_params_error( - str(exc), - data={"type": "INVALID_FIELD", "field": "metadata.opencode.directory"}, - ), - ) - - pending_claim = False - claim_finalized = False - if identity: - try: - pending_claim = await self._session_claim( - identity=identity, - session_id=session_id, - ) - except PermissionError: - _log_shell_audit("forbidden") - return self._session_forbidden_response( - base_request.id, - session_id=session_id, - ) - - try: - result: dict[str, Any] - if base_request.method == self._method_prompt_async: - await self._upstream_client.session_prompt_async( - session_id, - request=dict(raw_request), - directory=directory, - ) - result = {"ok": True, "session_id": session_id} - elif base_request.method == self._method_command: - raw_result = await self._upstream_client.session_command( - session_id, - request=dict(raw_request), - directory=directory, - ) - item = _as_a2a_message(session_id, raw_result) - if item is None: - raise UpstreamContractError( - "OpenCode /session/{sessionID}/command response could not be mapped " - "to A2A Message" - ) - result = {"item": item} - else: - raw_result = await self._upstream_client.session_shell( - session_id, - request=dict(raw_request), - directory=directory, - ) - item = _as_a2a_message(session_id, raw_result) - if item is None: - raise UpstreamContractError( - "OpenCode /session/{sessionID}/shell response could not be mapped " - "to A2A Message" - ) - result = {"item": item} - - if pending_claim and identity: - await self._session_claim_finalize( - identity=identity, - session_id=session_id, - ) - claim_finalized = True - _log_shell_audit("success") - except httpx.HTTPStatusError as exc: - upstream_status = exc.response.status_code - if upstream_status == 404: - _log_shell_audit("upstream_404") - return self._generate_error_response( - base_request.id, - session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id), - ) - _log_shell_audit("upstream_http_error") - return self._generate_error_response( - base_request.id, - upstream_http_error( - ERR_UPSTREAM_HTTP_ERROR, - upstream_status=upstream_status, - method=base_request.method, - session_id=session_id, - ), - ) - except httpx.HTTPError: - _log_shell_audit("upstream_unreachable") - return self._generate_error_response( - base_request.id, - upstream_unreachable_error( - ERR_UPSTREAM_UNREACHABLE, - method=base_request.method, - session_id=session_id, - ), - ) - except UpstreamConcurrencyLimitError as exc: - _log_shell_audit("upstream_backpressure") - return self._generate_error_response( - base_request.id, - upstream_unreachable_error( - ERR_UPSTREAM_UNREACHABLE, - method=base_request.method, - session_id=session_id, - detail=str(exc), - ), - ) - except UpstreamContractError as exc: - _log_shell_audit("upstream_payload_error") - return self._generate_error_response( - base_request.id, - upstream_payload_error( - ERR_UPSTREAM_PAYLOAD_ERROR, - detail=str(exc), - method=base_request.method, - session_id=session_id, - ), - ) - except PermissionError: - _log_shell_audit("forbidden") - return self._session_forbidden_response( - base_request.id, - session_id=session_id, - ) - except Exception as exc: - _log_shell_audit("internal_error") - logger.exception("OpenCode session control JSON-RPC method failed") - return self._generate_error_response( - base_request.id, - A2AError(root=InternalError(message=str(exc))), - ) - finally: - if pending_claim and not claim_finalized and identity: - try: - await self._session_claim_release( - identity=identity, - session_id=session_id, - ) - except Exception: - logger.exception( - "Failed to release pending session claim for session_id=%s", - session_id, - ) - - if base_request.id is None: - return Response(status_code=204) - return self._jsonrpc_success_response( - base_request.id, - result, - ) - - async def _handle_interrupt_callback_request( - self, - base_request: JSONRPCRequest, - params: dict[str, Any], - *, - request: Request, - ) -> Response: - request_id = params.get("request_id") - if not isinstance(request_id, str) or not request_id.strip(): - return self._generate_error_response( - base_request.id, - invalid_params_error( - "Missing required params.request_id", - data={"type": "MISSING_FIELD", "field": "request_id"}, - ), - ) - request_id = request_id.strip() - request_identity = getattr(request.state, "user_identity", None) - directory, metadata_error = self._extract_directory_from_metadata( - request_id=base_request.id, - params=params, - ) - if metadata_error is not None: - return metadata_error - expected_interrupt_type = ( - "permission" if base_request.method == self._method_reply_permission else "question" - ) - resolve_request = getattr(self._upstream_client, "resolve_interrupt_request", None) - if callable(resolve_request): - status, binding = await resolve_request(request_id) - if status != "active" or binding is None: - return self._generate_error_response( - base_request.id, - interrupt_not_found_error( - ERR_INTERRUPT_EXPIRED if status == "expired" else ERR_INTERRUPT_NOT_FOUND, - request_id=request_id, - expired=status == "expired", - ), - ) - if binding.interrupt_type != expected_interrupt_type: - return self._generate_error_response( - base_request.id, - interrupt_type_mismatch_error( - ERR_INTERRUPT_TYPE_MISMATCH, - request_id=request_id, - expected_interrupt_type=expected_interrupt_type, - actual_interrupt_type=binding.interrupt_type, - ), - ) - if ( - isinstance(request_identity, str) - and request_identity - and binding.identity - and binding.identity != request_identity - ): - return self._generate_error_response( - base_request.id, - interrupt_not_found_error( - ERR_INTERRUPT_NOT_FOUND, - request_id=request_id, - ), - ) - else: - resolve_session = getattr(self._upstream_client, "resolve_interrupt_session", None) - if callable(resolve_session): - if not await resolve_session(request_id): - return self._generate_error_response( - base_request.id, - interrupt_not_found_error( - ERR_INTERRUPT_NOT_FOUND, - request_id=request_id, - ), - ) - if base_request.method == self._method_reply_permission: - allowed_fields = {"request_id", "reply", "message", "metadata"} - elif base_request.method == self._method_reply_question: - allowed_fields = {"request_id", "answers", "metadata"} - else: - allowed_fields = {"request_id", "metadata"} - unknown_fields = sorted(set(params) - allowed_fields) - if unknown_fields: - return self._generate_error_response( - base_request.id, - invalid_params_error( - f"Unsupported fields: {', '.join(unknown_fields)}", - data={"type": "INVALID_FIELD", "fields": unknown_fields}, - ), - ) - - try: - result: dict[str, Any] = { - "ok": True, - "request_id": request_id, - } - if base_request.method == self._method_reply_permission: - reply = _normalize_permission_reply(params.get("reply")) - message = params.get("message") - if message is not None and not isinstance(message, str): - raise ValueError("message must be a string") - await self._upstream_client.permission_reply( - request_id, - reply=reply, - message=message, - directory=directory, - ) - elif base_request.method == self._method_reply_question: - answers = _parse_question_answers(params.get("answers")) - await self._upstream_client.question_reply( - request_id, - answers=answers, - directory=directory, - ) - else: - await self._upstream_client.question_reject(request_id, directory=directory) - discard_request = getattr(self._upstream_client, "discard_interrupt_request", None) - if callable(discard_request): - await discard_request(request_id) - except ValueError as exc: - return self._generate_error_response( - base_request.id, - invalid_params_error(str(exc), data={"type": "INVALID_FIELD"}), - ) - except httpx.HTTPStatusError as exc: - upstream_status = exc.response.status_code - if upstream_status == 404: - discard_request = getattr(self._upstream_client, "discard_interrupt_request", None) - if callable(discard_request): - await discard_request(request_id) - return self._generate_error_response( - base_request.id, - interrupt_not_found_error( - ERR_INTERRUPT_NOT_FOUND, - request_id=request_id, - ), - ) - return self._generate_error_response( - base_request.id, - upstream_http_error( - ERR_UPSTREAM_HTTP_ERROR, - upstream_status=upstream_status, - request_id=request_id, - ), - ) - except httpx.HTTPError: - return self._generate_error_response( - base_request.id, - upstream_unreachable_error( - ERR_UPSTREAM_UNREACHABLE, - request_id=request_id, - ), - ) - except UpstreamConcurrencyLimitError as exc: - return self._generate_error_response( - base_request.id, - upstream_unreachable_error( - ERR_UPSTREAM_UNREACHABLE, - request_id=request_id, - detail=str(exc), - ), - ) - except Exception as exc: - logger.exception("OpenCode interrupt callback JSON-RPC method failed") - return self._generate_error_response( - base_request.id, - A2AError(root=InternalError(message=str(exc))), - ) - - if base_request.id is None: - return Response(status_code=204) - return self._jsonrpc_success_response(base_request.id, result) - def _jsonrpc_success_response(self, request_id: str | int, result: Any) -> JSONResponse: return JSONResponse( { diff --git a/src/opencode_a2a/jsonrpc/dispatch.py b/src/opencode_a2a/jsonrpc/dispatch.py new file mode 100644 index 0000000..f726b36 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/dispatch.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Iterable +from dataclasses import dataclass +from typing import Any, TypeAlias + +from a2a.server.apps.jsonrpc.jsonrpc_app import JSONRPCApplication +from a2a.types import A2AError, JSONRPCError, JSONRPCRequest +from fastapi.responses import JSONResponse +from starlette.requests import Request +from starlette.responses import Response + +from ..opencode_upstream_client import OpencodeUpstreamClient + +# Delegate all SDK-owned JSON-RPC methods to the base app, then let the local +# extension registry override only the OpenCode-specific methods. +CORE_JSONRPC_METHODS = frozenset(JSONRPCApplication.METHOD_TO_MODEL) + +ErrorResponseFactory: TypeAlias = Callable[[str | int | None, JSONRPCError | A2AError], Response] +SuccessResponseFactory: TypeAlias = Callable[[str | int, Any], JSONResponse] +SessionClaimFunc: TypeAlias = Callable[..., Awaitable[bool]] +SessionFinalizeFunc: TypeAlias = Callable[..., Awaitable[None]] +SessionReleaseFunc: TypeAlias = Callable[..., Awaitable[None]] +ExtensionHandlerFunc: TypeAlias = Callable[ + ["ExtensionHandlerContext", JSONRPCRequest, dict[str, Any], Request], + Awaitable[Response], +] + + +@dataclass(frozen=True) +class ExtensionHandlerContext: + upstream_client: OpencodeUpstreamClient + method_list_sessions: str + method_get_session_messages: str + method_prompt_async: str + method_command: str + method_shell: str | None + method_list_providers: str + method_list_models: str + method_reply_permission: str + method_reply_question: str + method_reject_question: str + protocol_version: str + supported_methods: tuple[str, ...] + directory_resolver: Callable[[str | None], str | None] + session_claim: SessionClaimFunc + session_claim_finalize: SessionFinalizeFunc + session_claim_release: SessionReleaseFunc + error_response: ErrorResponseFactory + success_response: SuccessResponseFactory + + +@dataclass(frozen=True) +class ExtensionMethodSpec: + name: str + methods: frozenset[str] + handler: ExtensionHandlerFunc + + +class ExtensionMethodRegistry: + def __init__(self, specs: Iterable[ExtensionMethodSpec]) -> None: + method_map: dict[str, ExtensionMethodSpec] = {} + normalized_specs: list[ExtensionMethodSpec] = [] + for spec in specs: + normalized_specs.append(spec) + for method in spec.methods: + existing = method_map.get(method) + if existing is not None: + raise ValueError( + f"Extension method {method!r} registered by both " + f"{existing.name!r} and {spec.name!r}" + ) + method_map[method] = spec + self._specs = tuple(normalized_specs) + self._method_map = method_map + + @property + def specs(self) -> tuple[ExtensionMethodSpec, ...]: + return self._specs + + def methods(self) -> frozenset[str]: + return frozenset(self._method_map) + + def resolve(self, method: str) -> ExtensionMethodSpec | None: + return self._method_map.get(method) + + +def build_extension_method_registry( + context: ExtensionHandlerContext, +) -> ExtensionMethodRegistry: + from .handlers.interrupt_callbacks import handle_interrupt_callback_request + from .handlers.provider_discovery import handle_provider_discovery_request + from .handlers.session_control import handle_session_control_request + from .handlers.session_queries import handle_session_query_request + + session_control_methods = {context.method_prompt_async, context.method_command} + if context.method_shell is not None: + session_control_methods.add(context.method_shell) + + return ExtensionMethodRegistry( + ( + ExtensionMethodSpec( + name="session_query", + methods=frozenset( + { + context.method_list_sessions, + context.method_get_session_messages, + } + ), + handler=handle_session_query_request, + ), + ExtensionMethodSpec( + name="provider_discovery", + methods=frozenset( + { + context.method_list_providers, + context.method_list_models, + } + ), + handler=handle_provider_discovery_request, + ), + ExtensionMethodSpec( + name="session_control", + methods=frozenset(session_control_methods), + handler=handle_session_control_request, + ), + ExtensionMethodSpec( + name="interrupt_callback", + methods=frozenset( + { + context.method_reply_permission, + context.method_reply_question, + context.method_reject_question, + } + ), + handler=handle_interrupt_callback_request, + ), + ) + ) diff --git a/src/opencode_a2a/jsonrpc/handlers/__init__.py b/src/opencode_a2a/jsonrpc/handlers/__init__.py new file mode 100644 index 0000000..967ae51 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/handlers/__init__.py @@ -0,0 +1 @@ +"""Domain handlers for OpenCode JSON-RPC extension methods.""" diff --git a/src/opencode_a2a/jsonrpc/handlers/common.py b/src/opencode_a2a/jsonrpc/handlers/common.py new file mode 100644 index 0000000..9fc4ac5 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/handlers/common.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +import logging +from typing import Any + +from a2a.types import A2AError, InternalError +from starlette.responses import Response + +from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES +from ...opencode_upstream_client import UpstreamConcurrencyLimitError +from ..dispatch import ExtensionHandlerContext +from ..error_responses import ( + invalid_params_error, + session_forbidden_error, + upstream_http_error, + upstream_payload_error, + upstream_unreachable_error, +) + +ERR_SESSION_FORBIDDEN = SESSION_QUERY_ERROR_BUSINESS_CODES["SESSION_FORBIDDEN"] +logger = logging.getLogger(__name__) + + +def build_success_response( + context: ExtensionHandlerContext, + request_id: str | int | None, + result: dict[str, Any], +) -> Response: + if request_id is None: + return Response(status_code=204) + return context.success_response(request_id, result) + + +def build_session_forbidden_response( + context: ExtensionHandlerContext, + request_id: str | int | None, + *, + session_id: str, +) -> Response: + return context.error_response( + request_id, + session_forbidden_error(ERR_SESSION_FORBIDDEN, session_id=session_id), + ) + + +def extract_directory_from_metadata( + context: ExtensionHandlerContext, + *, + request_id: str | int | None, + params: dict[str, Any], +) -> tuple[str | None, Response | None]: + metadata = params.get("metadata") + if metadata is not None and not isinstance(metadata, dict): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata must be an object", + data={"type": "INVALID_FIELD", "field": "metadata"}, + ), + ) + + opencode_metadata: dict[str, Any] | None = None + if isinstance(metadata, dict): + unknown_metadata_fields = sorted(set(metadata) - {"opencode", "shared"}) + if unknown_metadata_fields: + prefixed_fields = [f"metadata.{field}" for field in unknown_metadata_fields] + return None, context.error_response( + request_id, + invalid_params_error( + f"Unsupported metadata fields: {', '.join(prefixed_fields)}", + data={"type": "INVALID_FIELD", "fields": prefixed_fields}, + ), + ) + raw_opencode_metadata = metadata.get("opencode") + if raw_opencode_metadata is not None and not isinstance(raw_opencode_metadata, dict): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata.opencode must be an object", + data={"type": "INVALID_FIELD", "field": "metadata.opencode"}, + ), + ) + if isinstance(raw_opencode_metadata, dict): + opencode_metadata = raw_opencode_metadata + raw_shared_metadata = metadata.get("shared") + if raw_shared_metadata is not None and not isinstance(raw_shared_metadata, dict): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata.shared must be an object", + data={"type": "INVALID_FIELD", "field": "metadata.shared"}, + ), + ) + + directory = None + if opencode_metadata is not None: + directory = opencode_metadata.get("directory") + if directory is not None and not isinstance(directory, str): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata.opencode.directory must be a string", + data={"type": "INVALID_FIELD", "field": "metadata.opencode.directory"}, + ), + ) + + return directory, None + + +def resolve_directory( + context: ExtensionHandlerContext, + *, + request_id: str | int | None, + params: dict[str, Any], +) -> tuple[str | None, Response | None]: + directory, metadata_error = extract_directory_from_metadata( + context, + request_id=request_id, + params=params, + ) + if metadata_error is not None: + return None, metadata_error + + try: + return context.directory_resolver(directory), None + except ValueError as exc: + return None, context.error_response( + request_id, + invalid_params_error( + str(exc), + data={"type": "INVALID_FIELD", "field": "metadata.opencode.directory"}, + ), + ) + + +def extract_interrupt_callback_directory_hint( + context: ExtensionHandlerContext, + *, + request_id: str | int | None, + params: dict[str, Any], +) -> tuple[str | None, Response | None]: + # Historical contract: interrupt callbacks accept raw metadata.opencode.directory + # and do not run it through the directory resolver used by session methods. + return extract_directory_from_metadata( + context, + request_id=request_id, + params=params, + ) + + +def build_upstream_http_error_response( + context: ExtensionHandlerContext, + request_id: str | int | None, + code: int, + *, + upstream_status: int, + method: str | None = None, + session_id: str | None = None, + interrupt_request_id: str | None = None, + detail: str | None = None, +) -> Response: + return context.error_response( + request_id, + upstream_http_error( + code, + upstream_status=upstream_status, + method=method, + session_id=session_id, + request_id=interrupt_request_id, + detail=detail, + ), + ) + + +def build_upstream_unreachable_error_response( + context: ExtensionHandlerContext, + request_id: str | int | None, + code: int, + *, + method: str | None = None, + session_id: str | None = None, + interrupt_request_id: str | None = None, + detail: str | None = None, +) -> Response: + return context.error_response( + request_id, + upstream_unreachable_error( + code, + method=method, + session_id=session_id, + request_id=interrupt_request_id, + detail=detail, + ), + ) + + +def build_upstream_concurrency_error_response( + context: ExtensionHandlerContext, + request_id: str | int | None, + code: int, + *, + exc: UpstreamConcurrencyLimitError, + method: str | None = None, + session_id: str | None = None, + interrupt_request_id: str | None = None, +) -> Response: + return build_upstream_unreachable_error_response( + context, + request_id, + code, + method=method, + session_id=session_id, + interrupt_request_id=interrupt_request_id, + detail=str(exc), + ) + + +def build_upstream_payload_error_response( + context: ExtensionHandlerContext, + request_id: str | int | None, + code: int, + *, + detail: str, + method: str | None = None, + session_id: str | None = None, + interrupt_request_id: str | None = None, +) -> Response: + return context.error_response( + request_id, + upstream_payload_error( + code, + detail=detail, + method=method, + session_id=session_id, + request_id=interrupt_request_id, + ), + ) + + +def build_internal_error_response( + context: ExtensionHandlerContext, + request_id: str | int | None, + *, + log_message: str, + exc: Exception, +) -> Response: + logger.exception(log_message) + return context.error_response( + request_id, + A2AError(root=InternalError(message=str(exc))), + ) diff --git a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py new file mode 100644 index 0000000..c647217 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import logging +from typing import Any + +import httpx +from a2a.types import JSONRPCRequest +from starlette.requests import Request +from starlette.responses import Response + +from ...contracts.extensions import INTERRUPT_ERROR_BUSINESS_CODES +from ...opencode_upstream_client import UpstreamConcurrencyLimitError +from ..dispatch import ExtensionHandlerContext +from ..error_responses import ( + interrupt_not_found_error, + interrupt_type_mismatch_error, + invalid_params_error, +) +from ..methods import _normalize_permission_reply, _parse_question_answers +from .common import ( + build_internal_error_response, + build_success_response, + build_upstream_concurrency_error_response, + build_upstream_http_error_response, + build_upstream_unreachable_error_response, + extract_interrupt_callback_directory_hint, +) + +logger = logging.getLogger(__name__) + +ERR_INTERRUPT_NOT_FOUND = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_REQUEST_NOT_FOUND"] +ERR_INTERRUPT_EXPIRED = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_REQUEST_EXPIRED"] +ERR_INTERRUPT_TYPE_MISMATCH = INTERRUPT_ERROR_BUSINESS_CODES["INTERRUPT_TYPE_MISMATCH"] +ERR_UPSTREAM_UNREACHABLE = INTERRUPT_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] +ERR_UPSTREAM_HTTP_ERROR = INTERRUPT_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] + + +async def handle_interrupt_callback_request( + context: ExtensionHandlerContext, + base_request: JSONRPCRequest, + params: dict[str, Any], + request: Request, +) -> Response: + request_id = params.get("request_id") + if not isinstance(request_id, str) or not request_id.strip(): + return context.error_response( + base_request.id, + invalid_params_error( + "Missing required params.request_id", + data={"type": "MISSING_FIELD", "field": "request_id"}, + ), + ) + request_id = request_id.strip() + request_identity = getattr(request.state, "user_identity", None) + + directory, directory_error = extract_interrupt_callback_directory_hint( + context, + request_id=base_request.id, + params=params, + ) + if directory_error is not None: + return directory_error + + expected_interrupt_type = ( + "permission" if base_request.method == context.method_reply_permission else "question" + ) + resolve_request = getattr(context.upstream_client, "resolve_interrupt_request", None) + if callable(resolve_request): + status, binding = await resolve_request(request_id) + if status != "active" or binding is None: + return context.error_response( + base_request.id, + interrupt_not_found_error( + ERR_INTERRUPT_EXPIRED if status == "expired" else ERR_INTERRUPT_NOT_FOUND, + request_id=request_id, + expired=status == "expired", + ), + ) + if binding.interrupt_type != expected_interrupt_type: + return context.error_response( + base_request.id, + interrupt_type_mismatch_error( + ERR_INTERRUPT_TYPE_MISMATCH, + request_id=request_id, + expected_interrupt_type=expected_interrupt_type, + actual_interrupt_type=binding.interrupt_type, + ), + ) + if ( + isinstance(request_identity, str) + and request_identity + and binding.identity + and binding.identity != request_identity + ): + return context.error_response( + base_request.id, + interrupt_not_found_error( + ERR_INTERRUPT_NOT_FOUND, + request_id=request_id, + ), + ) + else: + resolve_session = getattr(context.upstream_client, "resolve_interrupt_session", None) + if callable(resolve_session) and not await resolve_session(request_id): + return context.error_response( + base_request.id, + interrupt_not_found_error( + ERR_INTERRUPT_NOT_FOUND, + request_id=request_id, + ), + ) + + if base_request.method == context.method_reply_permission: + allowed_fields = {"request_id", "reply", "message", "metadata"} + elif base_request.method == context.method_reply_question: + allowed_fields = {"request_id", "answers", "metadata"} + else: + allowed_fields = {"request_id", "metadata"} + unknown_fields = sorted(set(params) - allowed_fields) + if unknown_fields: + return context.error_response( + base_request.id, + invalid_params_error( + f"Unsupported fields: {', '.join(unknown_fields)}", + data={"type": "INVALID_FIELD", "fields": unknown_fields}, + ), + ) + + try: + result: dict[str, Any] = { + "ok": True, + "request_id": request_id, + } + if base_request.method == context.method_reply_permission: + reply = _normalize_permission_reply(params.get("reply")) + message = params.get("message") + if message is not None and not isinstance(message, str): + raise ValueError("message must be a string") + await context.upstream_client.permission_reply( + request_id, + reply=reply, + message=message, + directory=directory, + ) + elif base_request.method == context.method_reply_question: + answers = _parse_question_answers(params.get("answers")) + await context.upstream_client.question_reply( + request_id, + answers=answers, + directory=directory, + ) + else: + await context.upstream_client.question_reject(request_id, directory=directory) + discard_request = getattr(context.upstream_client, "discard_interrupt_request", None) + if callable(discard_request): + await discard_request(request_id) + except ValueError as exc: + return context.error_response( + base_request.id, + invalid_params_error(str(exc), data={"type": "INVALID_FIELD"}), + ) + except httpx.HTTPStatusError as exc: + upstream_status = exc.response.status_code + if upstream_status == 404: + discard_request = getattr(context.upstream_client, "discard_interrupt_request", None) + if callable(discard_request): + await discard_request(request_id) + return context.error_response( + base_request.id, + interrupt_not_found_error( + ERR_INTERRUPT_NOT_FOUND, + request_id=request_id, + ), + ) + return build_upstream_http_error_response( + context, + base_request.id, + ERR_UPSTREAM_HTTP_ERROR, + upstream_status=upstream_status, + interrupt_request_id=request_id, + ) + except httpx.HTTPError: + return build_upstream_unreachable_error_response( + context, + base_request.id, + ERR_UPSTREAM_UNREACHABLE, + interrupt_request_id=request_id, + ) + except UpstreamConcurrencyLimitError as exc: + return build_upstream_concurrency_error_response( + context, + base_request.id, + ERR_UPSTREAM_UNREACHABLE, + exc=exc, + interrupt_request_id=request_id, + ) + except Exception as exc: + return build_internal_error_response( + context, + base_request.id, + log_message="OpenCode interrupt callback JSON-RPC method failed", + exc=exc, + ) + + return build_success_response(context, base_request.id, result) diff --git a/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py b/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py new file mode 100644 index 0000000..25c6313 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import logging +from typing import Any + +import httpx +from a2a.types import JSONRPCRequest +from starlette.requests import Request +from starlette.responses import Response + +from ...contracts.extensions import PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES +from ...opencode_upstream_client import UpstreamConcurrencyLimitError +from ..dispatch import ExtensionHandlerContext +from ..error_responses import invalid_params_error +from ..methods import ( + _extract_provider_catalog, + _normalize_model_summaries, + _normalize_provider_summaries, +) +from .common import ( + build_internal_error_response, + build_success_response, + build_upstream_concurrency_error_response, + build_upstream_http_error_response, + build_upstream_payload_error_response, + build_upstream_unreachable_error_response, + resolve_directory, +) + +logger = logging.getLogger(__name__) + +ERR_DISCOVERY_UPSTREAM_UNREACHABLE = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES[ + "UPSTREAM_UNREACHABLE" +] +ERR_DISCOVERY_UPSTREAM_HTTP_ERROR = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] +ERR_DISCOVERY_UPSTREAM_PAYLOAD_ERROR = PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES[ + "UPSTREAM_PAYLOAD_ERROR" +] + + +async def handle_provider_discovery_request( + context: ExtensionHandlerContext, + base_request: JSONRPCRequest, + params: dict[str, Any], + request: Request, +) -> Response: + del request + allowed_fields = {"metadata"} + if base_request.method == context.method_list_models: + allowed_fields.add("provider_id") + unknown_fields = sorted(set(params) - allowed_fields) + if unknown_fields: + prefixed_fields = [f"params.{field}" for field in unknown_fields] + return context.error_response( + base_request.id, + invalid_params_error( + f"Unsupported params fields: {', '.join(prefixed_fields)}", + data={"type": "INVALID_FIELD", "fields": prefixed_fields}, + ), + ) + + provider_id: str | None = None + if base_request.method == context.method_list_models: + raw_provider_id = params.get("provider_id") + if raw_provider_id is not None: + if not isinstance(raw_provider_id, str) or not raw_provider_id.strip(): + return context.error_response( + base_request.id, + invalid_params_error( + "provider_id must be a non-empty string", + data={"type": "INVALID_FIELD", "field": "provider_id"}, + ), + ) + provider_id = raw_provider_id.strip() + + directory, directory_error = resolve_directory( + context, + request_id=base_request.id, + params=params, + ) + if directory_error is not None: + return directory_error + + try: + raw_result = await context.upstream_client.list_provider_catalog(directory=directory) + except httpx.HTTPStatusError as exc: + upstream_status = exc.response.status_code + return build_upstream_http_error_response( + context, + base_request.id, + ERR_DISCOVERY_UPSTREAM_HTTP_ERROR, + upstream_status=upstream_status, + method=base_request.method, + ) + except httpx.HTTPError: + return build_upstream_unreachable_error_response( + context, + base_request.id, + ERR_DISCOVERY_UPSTREAM_UNREACHABLE, + method=base_request.method, + ) + except UpstreamConcurrencyLimitError as exc: + return build_upstream_concurrency_error_response( + context, + base_request.id, + ERR_DISCOVERY_UPSTREAM_UNREACHABLE, + exc=exc, + method=base_request.method, + ) + except Exception as exc: + return build_internal_error_response( + context, + base_request.id, + log_message="OpenCode provider discovery JSON-RPC method failed", + exc=exc, + ) + + try: + raw_providers, default_by_provider, connected = _extract_provider_catalog(raw_result) + if base_request.method == context.method_list_providers: + items = _normalize_provider_summaries( + raw_providers, + default_by_provider=default_by_provider, + connected=connected, + ) + else: + items = _normalize_model_summaries( + raw_providers, + default_by_provider=default_by_provider, + connected=connected, + provider_id=provider_id, + ) + except ValueError as exc: + logger.warning("Upstream OpenCode provider payload mismatch: %s", exc) + return build_upstream_payload_error_response( + context, + base_request.id, + ERR_DISCOVERY_UPSTREAM_PAYLOAD_ERROR, + detail=str(exc), + method=base_request.method, + ) + + return build_success_response( + context, + base_request.id, + { + "items": items, + "default_by_provider": default_by_provider, + "connected": connected, + }, + ) diff --git a/src/opencode_a2a/jsonrpc/handlers/session_control.py b/src/opencode_a2a/jsonrpc/handlers/session_control.py new file mode 100644 index 0000000..e81a1a6 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/handlers/session_control.py @@ -0,0 +1,265 @@ +from __future__ import annotations + +import logging +from typing import Any + +import httpx +from a2a.types import JSONRPCRequest +from starlette.requests import Request +from starlette.responses import Response + +from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES +from ...opencode_upstream_client import UpstreamConcurrencyLimitError, UpstreamContractError +from ..dispatch import ExtensionHandlerContext +from ..error_responses import invalid_params_error, session_not_found_error +from ..methods import ( + _as_a2a_message, + _PromptAsyncValidationError, + _validate_command_request_payload, + _validate_prompt_async_request_payload, + _validate_shell_request_payload, +) +from .common import ( + build_internal_error_response, + build_session_forbidden_response, + build_success_response, + build_upstream_concurrency_error_response, + build_upstream_http_error_response, + build_upstream_payload_error_response, + build_upstream_unreachable_error_response, + resolve_directory, +) + +logger = logging.getLogger(__name__) + +ERR_SESSION_NOT_FOUND = SESSION_QUERY_ERROR_BUSINESS_CODES["SESSION_NOT_FOUND"] +ERR_UPSTREAM_UNREACHABLE = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] +ERR_UPSTREAM_HTTP_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] +ERR_UPSTREAM_PAYLOAD_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_PAYLOAD_ERROR"] + + +async def handle_session_control_request( + context: ExtensionHandlerContext, + base_request: JSONRPCRequest, + params: dict[str, Any], + request: Request, +) -> Response: + allowed_fields = {"session_id", "request", "metadata"} + unknown_fields = sorted(set(params) - allowed_fields) + if unknown_fields: + return context.error_response( + base_request.id, + invalid_params_error( + f"Unsupported fields: {', '.join(unknown_fields)}", + data={"type": "INVALID_FIELD", "fields": unknown_fields}, + ), + ) + + session_id = params.get("session_id") + if not isinstance(session_id, str) or not session_id.strip(): + return context.error_response( + base_request.id, + invalid_params_error( + "Missing required params.session_id", + data={"type": "MISSING_FIELD", "field": "session_id"}, + ), + ) + session_id = session_id.strip() + + raw_request = params.get("request") + if raw_request is None: + return context.error_response( + base_request.id, + invalid_params_error( + "Missing required params.request", + data={"type": "MISSING_FIELD", "field": "request"}, + ), + ) + if not isinstance(raw_request, dict): + return context.error_response( + base_request.id, + invalid_params_error( + "params.request must be an object", + data={"type": "INVALID_FIELD", "field": "request"}, + ), + ) + + request_identity = getattr(request.state, "user_identity", None) + identity = request_identity if isinstance(request_identity, str) else None + task_id = getattr(request.state, "task_id", None) + context_id = getattr(request.state, "context_id", None) + + def _log_shell_audit(outcome: str) -> None: + if base_request.method != context.method_shell: + return + logger.info( + "session_shell_audit method=%s identity=%s task_id=%s context_id=%s " + "session_id=%s outcome=%s", + base_request.method, + identity if identity else "-", + task_id if isinstance(task_id, str) and task_id.strip() else "-", + context_id if isinstance(context_id, str) and context_id.strip() else "-", + session_id, + outcome, + ) + + try: + if base_request.method == context.method_prompt_async: + _validate_prompt_async_request_payload(raw_request) + elif base_request.method == context.method_command: + _validate_command_request_payload(raw_request) + elif base_request.method == context.method_shell: + _validate_shell_request_payload(raw_request) + else: + raise _PromptAsyncValidationError( + field="method", + message=f"Unsupported method: {base_request.method}", + ) + except _PromptAsyncValidationError as exc: + return context.error_response( + base_request.id, + invalid_params_error(str(exc), data={"type": "INVALID_FIELD", "field": exc.field}), + ) + + directory, directory_error = resolve_directory( + context, + request_id=base_request.id, + params=params, + ) + if directory_error is not None: + return directory_error + + pending_claim = False + claim_finalized = False + if identity: + try: + pending_claim = await context.session_claim( + identity=identity, + session_id=session_id, + ) + except PermissionError: + _log_shell_audit("forbidden") + return build_session_forbidden_response( + context, + base_request.id, + session_id=session_id, + ) + + try: + result: dict[str, Any] + if base_request.method == context.method_prompt_async: + await context.upstream_client.session_prompt_async( + session_id, + request=dict(raw_request), + directory=directory, + ) + result = {"ok": True, "session_id": session_id} + elif base_request.method == context.method_command: + raw_result = await context.upstream_client.session_command( + session_id, + request=dict(raw_request), + directory=directory, + ) + item = _as_a2a_message(session_id, raw_result) + if item is None: + raise UpstreamContractError( + "OpenCode /session/{sessionID}/command response could not be mapped " + "to A2A Message" + ) + result = {"item": item} + else: + raw_result = await context.upstream_client.session_shell( + session_id, + request=dict(raw_request), + directory=directory, + ) + item = _as_a2a_message(session_id, raw_result) + if item is None: + raise UpstreamContractError( + "OpenCode /session/{sessionID}/shell response could not be mapped " + "to A2A Message" + ) + result = {"item": item} + + if pending_claim and identity: + await context.session_claim_finalize( + identity=identity, + session_id=session_id, + ) + claim_finalized = True + _log_shell_audit("success") + except httpx.HTTPStatusError as exc: + upstream_status = exc.response.status_code + if upstream_status == 404: + _log_shell_audit("upstream_404") + return context.error_response( + base_request.id, + session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id), + ) + _log_shell_audit("upstream_http_error") + return build_upstream_http_error_response( + context, + base_request.id, + ERR_UPSTREAM_HTTP_ERROR, + upstream_status=upstream_status, + method=base_request.method, + session_id=session_id, + ) + except httpx.HTTPError: + _log_shell_audit("upstream_unreachable") + return build_upstream_unreachable_error_response( + context, + base_request.id, + ERR_UPSTREAM_UNREACHABLE, + method=base_request.method, + session_id=session_id, + ) + except UpstreamConcurrencyLimitError as exc: + _log_shell_audit("upstream_backpressure") + return build_upstream_concurrency_error_response( + context, + base_request.id, + ERR_UPSTREAM_UNREACHABLE, + exc=exc, + method=base_request.method, + session_id=session_id, + ) + except UpstreamContractError as exc: + _log_shell_audit("upstream_payload_error") + return build_upstream_payload_error_response( + context, + base_request.id, + ERR_UPSTREAM_PAYLOAD_ERROR, + detail=str(exc), + method=base_request.method, + session_id=session_id, + ) + except PermissionError: + _log_shell_audit("forbidden") + return build_session_forbidden_response( + context, + base_request.id, + session_id=session_id, + ) + except Exception as exc: + _log_shell_audit("internal_error") + return build_internal_error_response( + context, + base_request.id, + log_message="OpenCode session control JSON-RPC method failed", + exc=exc, + ) + finally: + if pending_claim and not claim_finalized and identity: + try: + await context.session_claim_release( + identity=identity, + session_id=session_id, + ) + except Exception: + logger.exception( + "Failed to release pending session claim for session_id=%s", + session_id, + ) + + return build_success_response(context, base_request.id, result) diff --git a/src/opencode_a2a/jsonrpc/handlers/session_queries.py b/src/opencode_a2a/jsonrpc/handlers/session_queries.py new file mode 100644 index 0000000..e6c8a20 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/handlers/session_queries.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import logging +from typing import Any + +import httpx +from a2a.types import JSONRPCRequest +from starlette.requests import Request +from starlette.responses import Response + +from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES +from ...opencode_upstream_client import UpstreamConcurrencyLimitError +from ..dispatch import ExtensionHandlerContext +from ..error_responses import invalid_params_error, session_not_found_error +from ..methods import ( + _apply_session_query_limit, + _as_a2a_message, + _as_a2a_session_task, + _extract_raw_items, +) +from ..params import ( + JsonRpcParamsValidationError, + parse_get_session_messages_params, + parse_list_sessions_params, +) +from .common import ( + build_internal_error_response, + build_success_response, + build_upstream_concurrency_error_response, + build_upstream_http_error_response, + build_upstream_payload_error_response, + build_upstream_unreachable_error_response, +) + +logger = logging.getLogger(__name__) + +ERR_SESSION_NOT_FOUND = SESSION_QUERY_ERROR_BUSINESS_CODES["SESSION_NOT_FOUND"] +ERR_UPSTREAM_UNREACHABLE = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] +ERR_UPSTREAM_HTTP_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] +ERR_UPSTREAM_PAYLOAD_ERROR = SESSION_QUERY_ERROR_BUSINESS_CODES["UPSTREAM_PAYLOAD_ERROR"] + + +async def handle_session_query_request( + context: ExtensionHandlerContext, + base_request: JSONRPCRequest, + params: dict[str, Any], + request: Request, +) -> Response: + del request + try: + if base_request.method == context.method_list_sessions: + query = parse_list_sessions_params(params) + session_id: str | None = None + else: + session_id, query = parse_get_session_messages_params(params) + except JsonRpcParamsValidationError as exc: + return context.error_response( + base_request.id, + invalid_params_error(str(exc), data=exc.data), + ) + + limit = int(query["limit"]) + try: + if base_request.method == context.method_list_sessions: + raw_result = await context.upstream_client.list_sessions(params=query) + else: + assert session_id is not None + raw_result = await context.upstream_client.list_messages(session_id, params=query) + except httpx.HTTPStatusError as exc: + upstream_status = exc.response.status_code + if upstream_status == 404 and base_request.method == context.method_get_session_messages: + assert session_id is not None + return context.error_response( + base_request.id, + session_not_found_error(ERR_SESSION_NOT_FOUND, session_id=session_id), + ) + return build_upstream_http_error_response( + context, + base_request.id, + ERR_UPSTREAM_HTTP_ERROR, + upstream_status=upstream_status, + ) + except httpx.HTTPError: + return build_upstream_unreachable_error_response( + context, + base_request.id, + ERR_UPSTREAM_UNREACHABLE, + ) + except UpstreamConcurrencyLimitError as exc: + return build_upstream_concurrency_error_response( + context, + base_request.id, + ERR_UPSTREAM_UNREACHABLE, + exc=exc, + ) + except Exception as exc: + return build_internal_error_response( + context, + base_request.id, + log_message="OpenCode session query JSON-RPC method failed", + exc=exc, + ) + + try: + if base_request.method == context.method_list_sessions: + raw_items = _extract_raw_items(raw_result, kind="sessions") + else: + raw_items = _extract_raw_items(raw_result, kind="messages") + except ValueError as exc: + logger.warning("Upstream OpenCode payload mismatch: %s", exc) + return build_upstream_payload_error_response( + context, + base_request.id, + ERR_UPSTREAM_PAYLOAD_ERROR, + detail=str(exc), + ) + + if base_request.method == context.method_list_sessions: + mapped: list[dict[str, Any]] = [] + for item in raw_items: + task = _as_a2a_session_task(item) + if task is not None: + mapped.append(task) + items: list[dict[str, Any]] = _apply_session_query_limit(mapped, limit=limit) + else: + assert session_id is not None + mapped = [] + for item in raw_items: + message = _as_a2a_message(session_id, item) + if message is not None: + mapped.append(message) + items = mapped + + return build_success_response(context, base_request.id, {"items": items}) diff --git a/tests/jsonrpc/test_dispatch_registry.py b/tests/jsonrpc/test_dispatch_registry.py new file mode 100644 index 0000000..70d1fc5 --- /dev/null +++ b/tests/jsonrpc/test_dispatch_registry.py @@ -0,0 +1,119 @@ +import httpx +import pytest +from a2a.server.apps.jsonrpc.fastapi_app import A2AFastAPIApplication +from fastapi.responses import JSONResponse + +import opencode_a2a.server.application as app_module +from tests.support.helpers import DummySessionQueryOpencodeUpstreamClient, make_settings +from tests.support.session_extensions import _BASE_SETTINGS, _jsonrpc_app + + +@pytest.mark.asyncio +async def test_extension_registry_tracks_configured_methods(monkeypatch) -> None: + monkeypatch.setattr( + app_module, + "OpencodeUpstreamClient", + DummySessionQueryOpencodeUpstreamClient, + ) + app = app_module.create_app( + make_settings( + a2a_bearer_token="test-token", + a2a_enable_session_shell=False, + **_BASE_SETTINGS, + ) + ) + + registry_methods = _jsonrpc_app(app)._extension_method_registry.methods() # noqa: SLF001 + assert "opencode.sessions.list" in registry_methods + assert "opencode.providers.list" in registry_methods + assert "a2a.interrupt.permission.reply" in registry_methods + assert "opencode.sessions.shell" not in registry_methods + + +@pytest.mark.asyncio +async def test_core_jsonrpc_methods_delegate_to_base_app(monkeypatch) -> None: + async def _fake_base_handle(self, request): # noqa: ANN001 + payload = await request.json() + return JSONResponse({"delegated_method": payload["method"]}) + + monkeypatch.setattr(A2AFastAPIApplication, "_handle_requests", _fake_base_handle) + app = app_module.create_app(make_settings(a2a_bearer_token="test-token", **_BASE_SETTINGS)) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer test-token"}, + json={"jsonrpc": "2.0", "id": 1, "method": "message/send", "params": {}}, + ) + + assert response.status_code == 200 + assert response.json() == {"delegated_method": "message/send"} + + +@pytest.mark.asyncio +async def test_sdk_owned_non_chat_jsonrpc_methods_delegate_to_base_app(monkeypatch) -> None: + async def _fake_base_handle(self, request): # noqa: ANN001 + payload = await request.json() + return JSONResponse({"delegated_method": payload["method"]}) + + monkeypatch.setattr(A2AFastAPIApplication, "_handle_requests", _fake_base_handle) + app = app_module.create_app(make_settings(a2a_bearer_token="test-token", **_BASE_SETTINGS)) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer test-token"}, + json={ + "jsonrpc": "2.0", + "id": 2, + "method": "tasks/pushNotificationConfig/get", + "params": {}, + }, + ) + + assert response.status_code == 200 + assert response.json() == {"delegated_method": "tasks/pushNotificationConfig/get"} + + +@pytest.mark.asyncio +async def test_extension_methods_stay_on_local_registry(monkeypatch) -> None: + dummy = DummySessionQueryOpencodeUpstreamClient( + make_settings( + a2a_bearer_token="test-token", + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + + async def _unexpected_delegate(self, request): # noqa: ANN001 + raise AssertionError("extension method should not delegate to base JSON-RPC app") + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + monkeypatch.setattr(A2AFastAPIApplication, "_handle_requests", _unexpected_delegate) + app = app_module.create_app( + make_settings( + a2a_bearer_token="test-token", + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer test-token"}, + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "opencode.sessions.list", + "params": {"limit": 1}, + }, + ) + + assert response.status_code == 200 + assert response.json()["result"]["items"][0]["id"] == "s-1"