Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
844 changes: 37 additions & 807 deletions src/opencode_a2a/jsonrpc/application.py

Large diffs are not rendered by default.

139 changes: 139 additions & 0 deletions src/opencode_a2a/jsonrpc/dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable, Iterable
from dataclasses import dataclass
from typing import Any, TypeAlias

from a2a.server.apps.jsonrpc.jsonrpc_app import JSONRPCApplication
from a2a.types import A2AError, JSONRPCError, JSONRPCRequest
from fastapi.responses import JSONResponse
from starlette.requests import Request
from starlette.responses import Response

from ..opencode_upstream_client import OpencodeUpstreamClient

# Delegate all SDK-owned JSON-RPC methods to the base app, then let the local
# extension registry override only the OpenCode-specific methods.
CORE_JSONRPC_METHODS = frozenset(JSONRPCApplication.METHOD_TO_MODEL)

ErrorResponseFactory: TypeAlias = Callable[[str | int | None, JSONRPCError | A2AError], Response]
SuccessResponseFactory: TypeAlias = Callable[[str | int, Any], JSONResponse]
SessionClaimFunc: TypeAlias = Callable[..., Awaitable[bool]]
SessionFinalizeFunc: TypeAlias = Callable[..., Awaitable[None]]
SessionReleaseFunc: TypeAlias = Callable[..., Awaitable[None]]
ExtensionHandlerFunc: TypeAlias = Callable[
["ExtensionHandlerContext", JSONRPCRequest, dict[str, Any], Request],
Awaitable[Response],
]


@dataclass(frozen=True)
class ExtensionHandlerContext:
upstream_client: OpencodeUpstreamClient
method_list_sessions: str
method_get_session_messages: str
method_prompt_async: str
method_command: str
method_shell: str | None
method_list_providers: str
method_list_models: str
method_reply_permission: str
method_reply_question: str
method_reject_question: str
protocol_version: str
supported_methods: tuple[str, ...]
directory_resolver: Callable[[str | None], str | None]
session_claim: SessionClaimFunc
session_claim_finalize: SessionFinalizeFunc
session_claim_release: SessionReleaseFunc
error_response: ErrorResponseFactory
success_response: SuccessResponseFactory


@dataclass(frozen=True)
class ExtensionMethodSpec:
name: str
methods: frozenset[str]
handler: ExtensionHandlerFunc


class ExtensionMethodRegistry:
def __init__(self, specs: Iterable[ExtensionMethodSpec]) -> None:
method_map: dict[str, ExtensionMethodSpec] = {}
normalized_specs: list[ExtensionMethodSpec] = []
for spec in specs:
normalized_specs.append(spec)
for method in spec.methods:
existing = method_map.get(method)
if existing is not None:
raise ValueError(
f"Extension method {method!r} registered by both "
f"{existing.name!r} and {spec.name!r}"
)
method_map[method] = spec
self._specs = tuple(normalized_specs)
self._method_map = method_map

@property
def specs(self) -> tuple[ExtensionMethodSpec, ...]:
return self._specs

def methods(self) -> frozenset[str]:
return frozenset(self._method_map)

def resolve(self, method: str) -> ExtensionMethodSpec | None:
return self._method_map.get(method)


def build_extension_method_registry(
context: ExtensionHandlerContext,
) -> ExtensionMethodRegistry:
from .handlers.interrupt_callbacks import handle_interrupt_callback_request
from .handlers.provider_discovery import handle_provider_discovery_request
from .handlers.session_control import handle_session_control_request
from .handlers.session_queries import handle_session_query_request

session_control_methods = {context.method_prompt_async, context.method_command}
if context.method_shell is not None:
session_control_methods.add(context.method_shell)

return ExtensionMethodRegistry(
(
ExtensionMethodSpec(
name="session_query",
methods=frozenset(
{
context.method_list_sessions,
context.method_get_session_messages,
}
),
handler=handle_session_query_request,
),
ExtensionMethodSpec(
name="provider_discovery",
methods=frozenset(
{
context.method_list_providers,
context.method_list_models,
}
),
handler=handle_provider_discovery_request,
),
ExtensionMethodSpec(
name="session_control",
methods=frozenset(session_control_methods),
handler=handle_session_control_request,
),
ExtensionMethodSpec(
name="interrupt_callback",
methods=frozenset(
{
context.method_reply_permission,
context.method_reply_question,
context.method_reject_question,
}
),
handler=handle_interrupt_callback_request,
),
)
)
1 change: 1 addition & 0 deletions src/opencode_a2a/jsonrpc/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Domain handlers for OpenCode JSON-RPC extension methods."""
251 changes: 251 additions & 0 deletions src/opencode_a2a/jsonrpc/handlers/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
from __future__ import annotations

import logging
from typing import Any

from a2a.types import A2AError, InternalError
from starlette.responses import Response

from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES
from ...opencode_upstream_client import UpstreamConcurrencyLimitError
from ..dispatch import ExtensionHandlerContext
from ..error_responses import (
invalid_params_error,
session_forbidden_error,
upstream_http_error,
upstream_payload_error,
upstream_unreachable_error,
)

ERR_SESSION_FORBIDDEN = SESSION_QUERY_ERROR_BUSINESS_CODES["SESSION_FORBIDDEN"]
logger = logging.getLogger(__name__)


def build_success_response(
context: ExtensionHandlerContext,
request_id: str | int | None,
result: dict[str, Any],
) -> Response:
if request_id is None:
return Response(status_code=204)
return context.success_response(request_id, result)


def build_session_forbidden_response(
context: ExtensionHandlerContext,
request_id: str | int | None,
*,
session_id: str,
) -> Response:
return context.error_response(
request_id,
session_forbidden_error(ERR_SESSION_FORBIDDEN, session_id=session_id),
)


def extract_directory_from_metadata(
context: ExtensionHandlerContext,
*,
request_id: str | int | None,
params: dict[str, Any],
) -> tuple[str | None, Response | None]:
metadata = params.get("metadata")
if metadata is not None and not isinstance(metadata, dict):
return None, context.error_response(
request_id,
invalid_params_error(
"metadata must be an object",
data={"type": "INVALID_FIELD", "field": "metadata"},
),
)

opencode_metadata: dict[str, Any] | None = None
if isinstance(metadata, dict):
unknown_metadata_fields = sorted(set(metadata) - {"opencode", "shared"})
if unknown_metadata_fields:
prefixed_fields = [f"metadata.{field}" for field in unknown_metadata_fields]
return None, context.error_response(
request_id,
invalid_params_error(
f"Unsupported metadata fields: {', '.join(prefixed_fields)}",
data={"type": "INVALID_FIELD", "fields": prefixed_fields},
),
)
raw_opencode_metadata = metadata.get("opencode")
if raw_opencode_metadata is not None and not isinstance(raw_opencode_metadata, dict):
return None, context.error_response(
request_id,
invalid_params_error(
"metadata.opencode must be an object",
data={"type": "INVALID_FIELD", "field": "metadata.opencode"},
),
)
if isinstance(raw_opencode_metadata, dict):
opencode_metadata = raw_opencode_metadata
raw_shared_metadata = metadata.get("shared")
if raw_shared_metadata is not None and not isinstance(raw_shared_metadata, dict):
return None, context.error_response(
request_id,
invalid_params_error(
"metadata.shared must be an object",
data={"type": "INVALID_FIELD", "field": "metadata.shared"},
),
)

directory = None
if opencode_metadata is not None:
directory = opencode_metadata.get("directory")
if directory is not None and not isinstance(directory, str):
return None, context.error_response(
request_id,
invalid_params_error(
"metadata.opencode.directory must be a string",
data={"type": "INVALID_FIELD", "field": "metadata.opencode.directory"},
),
)

return directory, None


def resolve_directory(
context: ExtensionHandlerContext,
*,
request_id: str | int | None,
params: dict[str, Any],
) -> tuple[str | None, Response | None]:
directory, metadata_error = extract_directory_from_metadata(
context,
request_id=request_id,
params=params,
)
if metadata_error is not None:
return None, metadata_error

try:
return context.directory_resolver(directory), None
except ValueError as exc:
return None, context.error_response(
request_id,
invalid_params_error(
str(exc),
data={"type": "INVALID_FIELD", "field": "metadata.opencode.directory"},
),
)


def extract_interrupt_callback_directory_hint(
context: ExtensionHandlerContext,
*,
request_id: str | int | None,
params: dict[str, Any],
) -> tuple[str | None, Response | None]:
# Historical contract: interrupt callbacks accept raw metadata.opencode.directory
# and do not run it through the directory resolver used by session methods.
return extract_directory_from_metadata(
context,
request_id=request_id,
params=params,
)


def build_upstream_http_error_response(
context: ExtensionHandlerContext,
request_id: str | int | None,
code: int,
*,
upstream_status: int,
method: str | None = None,
session_id: str | None = None,
interrupt_request_id: str | None = None,
detail: str | None = None,
) -> Response:
return context.error_response(
request_id,
upstream_http_error(
code,
upstream_status=upstream_status,
method=method,
session_id=session_id,
request_id=interrupt_request_id,
detail=detail,
),
)


def build_upstream_unreachable_error_response(
context: ExtensionHandlerContext,
request_id: str | int | None,
code: int,
*,
method: str | None = None,
session_id: str | None = None,
interrupt_request_id: str | None = None,
detail: str | None = None,
) -> Response:
return context.error_response(
request_id,
upstream_unreachable_error(
code,
method=method,
session_id=session_id,
request_id=interrupt_request_id,
detail=detail,
),
)


def build_upstream_concurrency_error_response(
context: ExtensionHandlerContext,
request_id: str | int | None,
code: int,
*,
exc: UpstreamConcurrencyLimitError,
method: str | None = None,
session_id: str | None = None,
interrupt_request_id: str | None = None,
) -> Response:
return build_upstream_unreachable_error_response(
context,
request_id,
code,
method=method,
session_id=session_id,
interrupt_request_id=interrupt_request_id,
detail=str(exc),
)


def build_upstream_payload_error_response(
context: ExtensionHandlerContext,
request_id: str | int | None,
code: int,
*,
detail: str,
method: str | None = None,
session_id: str | None = None,
interrupt_request_id: str | None = None,
) -> Response:
return context.error_response(
request_id,
upstream_payload_error(
code,
detail=detail,
method=method,
session_id=session_id,
request_id=interrupt_request_id,
),
)


def build_internal_error_response(
context: ExtensionHandlerContext,
request_id: str | int | None,
*,
log_message: str,
exc: Exception,
) -> Response:
logger.exception(log_message)
return context.error_response(
request_id,
A2AError(root=InternalError(message=str(exc))),
)
Loading