diff --git a/docs/guide.md b/docs/guide.md index bb1a76a..cff60e8 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -657,10 +657,14 @@ No extra custom REST endpoint is introduced. - session title is available at `metadata.shared.session.title` - Session list filters: - optional `directory`, `roots`, `start`, `search`, `limit` + - optional `metadata.opencode.workspace.id` - `directory` is normalized through the same workspace-boundary rules used by other OpenCode directory overrides before reaching upstream + - when `metadata.opencode.workspace.id` is present, the adapter routes by + workspace and ignores `directory` - Session message history filters: - optional `limit`, `before` + - optional `metadata.opencode.workspace.id` - `before` is an opaque cursor for loading older messages and is only supported on `opencode.sessions.messages.list` @@ -752,6 +756,9 @@ Validation notes: - `metadata.opencode.directory` follows the same normalization and boundary rules as message send (`realpath` + workspace boundary check). +- `metadata.opencode.workspace.id` is a provider-private routing hint. When it + is present, the adapter routes the request to that workspace and does not + apply directory override resolution for the same call. - `request.model` uses the same shape as `metadata.shared.model` and is scoped only to the current session-control request. - Control methods enforce session owner guard based on request identity. @@ -847,6 +854,9 @@ curl -sS http://127.0.0.1:8000/ \ Response: - success => `{"items": [...], "default_by_provider": {...}, "connected": [...]}` (JSON-RPC result) +- optional `metadata.opencode.workspace.id` routes discovery against a specific + OpenCode workspace; otherwise the adapter falls back to directory routing + when `metadata.opencode.directory` is provided ### Model List (`opencode.models.list`) @@ -872,6 +882,100 @@ Response: - success => `{"items": [...], "default_by_provider": {...}, "connected": [...]}` (JSON-RPC result) +## Workspace Control (Provider-Private Extension) + +The runtime also exposes the OpenCode project/workspace/worktree control plane +through provider-private JSON-RPC methods: + +- `opencode.projects.list` +- `opencode.projects.current` +- `opencode.workspaces.list` +- `opencode.workspaces.create` +- `opencode.workspaces.remove` +- `opencode.worktrees.list` +- `opencode.worktrees.create` +- `opencode.worktrees.remove` +- `opencode.worktrees.reset` + +Behavior notes: + +- These methods target the active OpenCode deployment project. They are not + routed through per-request workspace forwarding. +- `metadata.opencode.workspace.id` is declared consistently across the adapter, + but current workspace-control methods do not use it to change the target + project. +- Mutating methods should be treated as operator-only control-plane actions. + +### Project Discovery (`opencode.projects.list`, `opencode.projects.current`) + +```bash +curl -sS http://127.0.0.1:8000/ \ + -H 'content-type: application/json' \ + -H 'Authorization: Bearer ' \ + -d '{ + "jsonrpc": "2.0", + "id": 31, + "method": "opencode.projects.current", + "params": {} + }' +``` + +Response: + +- `opencode.projects.list` => `{"items": [...]}` +- `opencode.projects.current` => `{"item": {...}}` + +### Workspace Discovery and Mutation + +```bash +curl -sS http://127.0.0.1:8000/ \ + -H 'content-type: application/json' \ + -H 'Authorization: Bearer ' \ + -d '{ + "jsonrpc": "2.0", + "id": 32, + "method": "opencode.workspaces.create", + "params": { + "request": { + "id": "wrk-api", + "type": "git", + "branch": "main" + } + } + }' +``` + +Response: + +- `opencode.workspaces.list` => `{"items": [...]}` +- `opencode.workspaces.create` => `{"item": {...}}` +- `opencode.workspaces.remove` => `{"item": {...}}` + +### Worktree Discovery and Mutation + +```bash +curl -sS http://127.0.0.1:8000/ \ + -H 'content-type: application/json' \ + -H 'Authorization: Bearer ' \ + -d '{ + "jsonrpc": "2.0", + "id": 33, + "method": "opencode.worktrees.reset", + "params": { + "request": { + "directory": "/repo/services/api" + } + } + }' +``` + +Response: + +- `opencode.worktrees.list` => `{"items": [...]}` +- `opencode.worktrees.create` => `{"item": {...}}` +- `opencode.worktrees.remove` => `{"ok": true|false}` +- `opencode.worktrees.reset` => `{"ok": true|false}` + ## Interrupt Recovery (Provider-Private Extension) The runtime also exposes provider-private recovery queries for pending diff --git a/src/opencode_a2a/contracts/extensions.py b/src/opencode_a2a/contracts/extensions.py index e66f28d..6682600 100644 --- a/src/opencode_a2a/contracts/extensions.py +++ b/src/opencode_a2a/contracts/extensions.py @@ -13,6 +13,7 @@ SHARED_INTERRUPT_METADATA_FIELD = "metadata.shared.interrupt" SHARED_USAGE_METADATA_FIELD = "metadata.shared.usage" OPENCODE_DIRECTORY_METADATA_FIELD = "metadata.opencode.directory" +OPENCODE_WORKSPACE_METADATA_FIELD = "metadata.opencode.workspace.id" SESSION_BINDING_EXTENSION_URI = "urn:a2a:session-binding/v1" MODEL_SELECTION_EXTENSION_URI = "urn:a2a:model-selection/v1" @@ -21,6 +22,7 @@ PROVIDER_DISCOVERY_EXTENSION_URI = "urn:opencode-a2a:provider-discovery/v1" INTERRUPT_CALLBACK_EXTENSION_URI = "urn:a2a:interactive-interrupt/v1" INTERRUPT_RECOVERY_EXTENSION_URI = "urn:opencode-a2a:interrupt-recovery/v1" +WORKSPACE_CONTROL_EXTENSION_URI = "urn:opencode-a2a:workspace-control/v1" COMPATIBILITY_PROFILE_EXTENSION_URI = "urn:a2a:compatibility-profile/v1" WIRE_CONTRACT_EXTENSION_URI = "urn:a2a:wire-contract/v1" SERVICE_BEHAVIOR_CLASSIFICATION = "service-level-semantic-enhancement" @@ -68,6 +70,16 @@ class InterruptRecoveryMethodContract: notification_response_status: int | None = None +@dataclass(frozen=True) +class WorkspaceControlMethodContract: + method: str + required_params: tuple[str, ...] = () + optional_params: tuple[str, ...] = () + result_fields: tuple[str, ...] = () + items_type: str | None = None + notification_response_status: int | None = None + + PROMPT_ASYNC_REQUEST_REQUIRED_FIELDS: tuple[str, ...] = ("parts",) PROMPT_ASYNC_REQUEST_OPTIONAL_FIELDS: tuple[str, ...] = ( "messageID", @@ -115,6 +127,7 @@ class InterruptRecoveryMethodContract: optional_params=( "limit", "directory", + OPENCODE_WORKSPACE_METADATA_FIELD, "roots", "start", "search", @@ -133,7 +146,13 @@ class InterruptRecoveryMethodContract: "get_session_messages": SessionQueryMethodContract( method="opencode.sessions.messages.list", required_params=("session_id",), - optional_params=("limit", "before", "query.limit", "query.before"), + optional_params=( + "limit", + "before", + OPENCODE_WORKSPACE_METADATA_FIELD, + "query.limit", + "query.before", + ), unsupported_params=SESSION_QUERY_PAGINATION_UNSUPPORTED, result_fields=("items", "next_cursor"), items_type="Message[]", @@ -153,6 +172,7 @@ class InterruptRecoveryMethodContract: "request.system", "request.variant", OPENCODE_DIRECTORY_METADATA_FIELD, + OPENCODE_WORKSPACE_METADATA_FIELD, ), result_fields=("ok", "session_id"), notification_response_status=204, @@ -167,6 +187,7 @@ class InterruptRecoveryMethodContract: "request.variant", "request.parts", OPENCODE_DIRECTORY_METADATA_FIELD, + OPENCODE_WORKSPACE_METADATA_FIELD, ), result_fields=("item",), notification_response_status=204, @@ -174,7 +195,11 @@ class InterruptRecoveryMethodContract: "shell": SessionQueryMethodContract( method="opencode.sessions.shell", required_params=("session_id", "request.agent", "request.command"), - optional_params=("request.model", OPENCODE_DIRECTORY_METADATA_FIELD), + optional_params=( + "request.model", + OPENCODE_DIRECTORY_METADATA_FIELD, + OPENCODE_WORKSPACE_METADATA_FIELD, + ), result_fields=("item",), notification_response_status=204, ), @@ -295,6 +320,73 @@ class InterruptRecoveryMethodContract: key: contract.method for key, contract in INTERRUPT_RECOVERY_METHOD_CONTRACTS.items() } +WORKSPACE_CONTROL_METHOD_CONTRACTS: dict[str, WorkspaceControlMethodContract] = { + "list_projects": WorkspaceControlMethodContract( + method="opencode.projects.list", + result_fields=("items",), + items_type="Project[]", + notification_response_status=204, + ), + "get_current_project": WorkspaceControlMethodContract( + method="opencode.projects.current", + result_fields=("item",), + items_type="Project", + notification_response_status=204, + ), + "list_workspaces": WorkspaceControlMethodContract( + method="opencode.workspaces.list", + result_fields=("items",), + items_type="Workspace[]", + notification_response_status=204, + ), + "create_workspace": WorkspaceControlMethodContract( + method="opencode.workspaces.create", + required_params=("request.type",), + optional_params=("request.id", "request.branch", "request.extra"), + result_fields=("item",), + items_type="Workspace", + notification_response_status=204, + ), + "remove_workspace": WorkspaceControlMethodContract( + method="opencode.workspaces.remove", + required_params=("workspace_id",), + result_fields=("item",), + items_type="Workspace|null", + notification_response_status=204, + ), + "list_worktrees": WorkspaceControlMethodContract( + method="opencode.worktrees.list", + result_fields=("items",), + items_type="string[]", + notification_response_status=204, + ), + "create_worktree": WorkspaceControlMethodContract( + method="opencode.worktrees.create", + optional_params=("request.name", "request.startCommand"), + result_fields=("item",), + items_type="Worktree", + notification_response_status=204, + ), + "remove_worktree": WorkspaceControlMethodContract( + method="opencode.worktrees.remove", + required_params=("request.directory",), + result_fields=("ok",), + items_type="boolean", + notification_response_status=204, + ), + "reset_worktree": WorkspaceControlMethodContract( + method="opencode.worktrees.reset", + required_params=("request.directory",), + result_fields=("ok",), + items_type="boolean", + notification_response_status=204, + ), +} + +WORKSPACE_CONTROL_METHODS: dict[str, str] = { + key: contract.method for key, contract in WORKSPACE_CONTROL_METHOD_CONTRACTS.items() +} + INTERRUPT_SUCCESS_RESULT_FIELDS: tuple[str, ...] = ("ok", "request_id") INTERRUPT_ERROR_BUSINESS_CODES: dict[str, int] = { "INTERRUPT_REQUEST_NOT_FOUND": -32004, @@ -345,6 +437,22 @@ class InterruptRecoveryMethodContract: "field", "fields", ) +WORKSPACE_CONTROL_ERROR_BUSINESS_CODES: dict[str, int] = { + "UPSTREAM_UNREACHABLE": -32002, + "UPSTREAM_HTTP_ERROR": -32003, + "UPSTREAM_PAYLOAD_ERROR": -32005, +} +WORKSPACE_CONTROL_ERROR_DATA_FIELDS: tuple[str, ...] = ( + "type", + "method", + "upstream_status", + "detail", +) +WORKSPACE_CONTROL_INVALID_PARAMS_DATA_FIELDS: tuple[str, ...] = ( + "type", + "field", + "fields", +) @dataclass(frozen=True) @@ -410,6 +518,9 @@ def interrupt_recovery_methods(self) -> dict[str, str]: def interrupt_callback_methods(self) -> dict[str, str]: return dict(INTERRUPT_CALLBACK_METHODS) + def workspace_control_methods(self) -> dict[str, str]: + return dict(WORKSPACE_CONTROL_METHODS) + def supported_jsonrpc_methods(self) -> list[str]: methods = [ *CORE_JSONRPC_METHODS, @@ -418,6 +529,7 @@ def supported_jsonrpc_methods(self) -> list[str]: SESSION_CONTROL_METHODS["prompt_async"], SESSION_CONTROL_METHODS["command"], *PROVIDER_DISCOVERY_METHODS.values(), + *WORKSPACE_CONTROL_METHODS.values(), *INTERRUPT_RECOVERY_METHODS.values(), *INTERRUPT_CALLBACK_METHODS.values(), ] @@ -432,6 +544,7 @@ def extension_jsonrpc_methods(self) -> list[str]: SESSION_CONTROL_METHODS["prompt_async"], SESSION_CONTROL_METHODS["command"], *PROVIDER_DISCOVERY_METHODS.values(), + *WORKSPACE_CONTROL_METHODS.values(), *INTERRUPT_RECOVERY_METHODS.values(), *INTERRUPT_CALLBACK_METHODS.values(), ] @@ -499,8 +612,9 @@ def build_session_binding_extension_params( "supported_metadata": [ "shared.session.id", "opencode.directory", + "opencode.workspace.id", ], - "provider_private_metadata": ["opencode.directory"], + "provider_private_metadata": ["opencode.directory", "opencode.workspace.id"], "profile": runtime_profile.summary_dict(), "notes": [ ( @@ -512,6 +626,11 @@ def build_session_binding_extension_params( "the (identity, contextId)->session_id mapping according to the " "configured task/state store backend and TTL policy." ), + ( + "If metadata.opencode.workspace.id is provided, the server routes the " + "request with workspace precedence and falls back to directory binding only " + "when workspace metadata is absent." + ), ], } @@ -722,10 +841,11 @@ def build_interrupt_callback_extension_params( "answers": "array of answer arrays (same order as asked questions)" }, "request_id_field": f"{SHARED_INTERRUPT_METADATA_FIELD}.request_id", - "supported_metadata": ["opencode.directory"], - "provider_private_metadata": ["opencode.directory"], + "supported_metadata": ["opencode.directory", "opencode.workspace.id"], + "provider_private_metadata": ["opencode.directory", "opencode.workspace.id"], "context_fields": { "directory": OPENCODE_DIRECTORY_METADATA_FIELD, + "workspace_id": OPENCODE_WORKSPACE_METADATA_FIELD, }, "success_result_fields": list(INTERRUPT_SUCCESS_RESULT_FIELDS), "errors": { @@ -827,10 +947,11 @@ def build_provider_discovery_extension_params( return { "methods": dict(PROVIDER_DISCOVERY_METHODS), "method_contracts": method_contracts, - "supported_metadata": ["opencode.directory"], - "provider_private_metadata": ["opencode.directory"], + "supported_metadata": ["opencode.directory", "opencode.workspace.id"], + "provider_private_metadata": ["opencode.directory", "opencode.workspace.id"], "context_fields": { "directory": OPENCODE_DIRECTORY_METADATA_FIELD, + "workspace_id": OPENCODE_WORKSPACE_METADATA_FIELD, }, "provider_item_fields": { "provider_id": "items[].provider_id", @@ -867,6 +988,65 @@ def build_provider_discovery_extension_params( "The server normalizes upstream provider catalogs into summary records so " "downstream callers do not need to parse raw OpenCode payloads." ), + ( + "If metadata.opencode.workspace.id is present, provider/model discovery is " + "routed to that workspace; otherwise the adapter falls back to directory " + "routing when metadata.opencode.directory is provided." + ), + ], + } + + +def build_workspace_control_extension_params( + *, + runtime_profile: RuntimeProfile, +) -> dict[str, Any]: + method_contracts: dict[str, Any] = {} + + for method_contract in WORKSPACE_CONTROL_METHOD_CONTRACTS.values(): + params_contract = _build_method_contract_params( + required=method_contract.required_params, + optional=method_contract.optional_params, + unsupported=(), + ) + result_contract: dict[str, Any] = {"fields": list(method_contract.result_fields)} + if method_contract.items_type: + result_contract["items_type"] = method_contract.items_type + contract_doc: dict[str, Any] = { + "params": params_contract, + "result": result_contract, + } + if method_contract.notification_response_status is not None: + contract_doc["notification_response_status"] = ( + method_contract.notification_response_status + ) + method_contracts[method_contract.method] = contract_doc + + return { + "methods": dict(WORKSPACE_CONTROL_METHODS), + "method_contracts": method_contracts, + "supported_metadata": ["opencode.workspace.id", "opencode.directory"], + "provider_private_metadata": ["opencode.workspace.id", "opencode.directory"], + "routing_fields": { + "workspace_id": OPENCODE_WORKSPACE_METADATA_FIELD, + "directory": OPENCODE_DIRECTORY_METADATA_FIELD, + }, + "errors": { + "business_codes": dict(WORKSPACE_CONTROL_ERROR_BUSINESS_CODES), + "error_data_fields": list(WORKSPACE_CONTROL_ERROR_DATA_FIELDS), + "invalid_params_data_fields": list(WORKSPACE_CONTROL_INVALID_PARAMS_DATA_FIELDS), + }, + "profile": runtime_profile.summary_dict(), + "notes": [ + ( + "Workspace control methods expose the OpenCode project/workspace/worktree " + "control plane through provider-private JSON-RPC methods." + ), + ( + "Workspace routing metadata is declared for consistency, but the current " + "control-plane methods operate on the active deployment project rather than " + "per-request workspace forwarding." + ), ], } @@ -914,6 +1094,17 @@ def build_compatibility_profile_params( for method in PROVIDER_DISCOVERY_METHODS.values() } ) + method_retention.update( + { + method: { + "surface": "extension", + "availability": "always", + "retention": "stable", + "extension_uri": WORKSPACE_CONTROL_EXTENSION_URI, + } + for method in WORKSPACE_CONTROL_METHODS.values() + } + ) method_retention.update( { method: { @@ -968,6 +1159,11 @@ def build_compatibility_profile_params( "availability": "always", "retention": "stable", }, + WORKSPACE_CONTROL_EXTENSION_URI: { + "surface": "jsonrpc-extension", + "availability": "always", + "retention": "stable", + }, INTERRUPT_RECOVERY_EXTENSION_URI: { "surface": "jsonrpc-extension", "availability": "always", @@ -993,6 +1189,7 @@ def build_compatibility_profile_params( ), ( "Treat opencode.sessions.*, opencode.providers.*, opencode.models.*, " + "opencode.projects.*, opencode.workspaces.*, opencode.worktrees.*, " "opencode.permissions.list, and opencode.questions.list as provider-private " "operational surfaces rather than portable A2A baseline capabilities." ), @@ -1041,6 +1238,7 @@ def build_wire_contract_params( STREAMING_EXTENSION_URI, SESSION_QUERY_EXTENSION_URI, PROVIDER_DISCOVERY_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, INTERRUPT_RECOVERY_EXTENSION_URI, INTERRUPT_CALLBACK_EXTENSION_URI, ], diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index 9150c23..5972550 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -28,6 +28,7 @@ TextPart, ) +from ..invocation import call_with_supported_kwargs from ..opencode_upstream_client import ( OpencodeUpstreamClient, UpstreamConcurrencyLimitError, @@ -44,6 +45,7 @@ from .request_context import ( _build_history, _extract_opencode_directory, + _extract_opencode_workspace_id, _extract_shared_model, _extract_shared_session_id, ) @@ -144,6 +146,22 @@ class _PreparedExecution: bound_session_id: str | None model_override: dict[str, str] | None directory: str | None + workspace_id: str | None + session_binding_context_id: str + + +def _build_session_binding_context_id( + *, + context_id: str, + directory: str | None, + workspace_id: str | None, + use_directory_binding: bool, +) -> str: + if isinstance(workspace_id, str) and workspace_id.strip(): + return f"{context_id}::workspace:{workspace_id.strip()}" + if use_directory_binding and isinstance(directory, str) and directory.strip(): + return f"{context_id}::directory:{directory.strip()}" + return context_id class _ExecutionCoordinator: @@ -192,19 +210,22 @@ async def run(self) -> None: while True: send_kwargs: dict[str, Any] = { "directory": self._prepared.directory, + "workspace_id": self._prepared.workspace_id, "model_override": self._prepared.model_override, } if self._prepared.streaming_request: send_kwargs["timeout_override"] = self._executor._client.stream_timeout if not self._prepared.use_structured_parts and not turn_request_parts: - response = await self._executor._client.send_message( + response = await call_with_supported_kwargs( + self._executor._client.send_message, self._session_id, user_text, **send_kwargs, ) else: - response = await self._executor._client.send_message( + response = await call_with_supported_kwargs( + self._executor._client.send_message, self._session_id, user_text or None, parts=turn_request_parts, @@ -214,7 +235,7 @@ async def run(self) -> None: if self._pending_preferred_claim: await self._executor._session_manager.finalize_preferred_session_binding( identity=self._prepared.identity, - context_id=self._context_id, + context_id=self._prepared.session_binding_context_id, session_id=self._session_id, ) self._pending_preferred_claim = False @@ -314,10 +335,11 @@ async def _bind_session(self) -> None: self._pending_preferred_claim, ) = await self._executor._session_manager.get_or_create_session( self._prepared.identity, - self._context_id, + self._prepared.session_binding_context_id, self._prepared.session_title or self._prepared.user_text, preferred_session_id=self._prepared.bound_session_id, directory=self._prepared.directory, + workspace_id=self._prepared.workspace_id, ) self._session_lock = await self._executor._session_manager.get_session_lock( self._session_id @@ -326,6 +348,10 @@ async def _bind_session(self) -> None: async with self._executor._lock: self._executor._running_session_ids[self._execution_key] = self._session_id self._executor._running_directories[self._execution_key] = self._prepared.directory + self._executor._running_workspace_ids[self._execution_key] = self._prepared.workspace_id + self._executor._running_binding_context_ids[self._execution_key] = ( + self._prepared.session_binding_context_id + ) if self._prepared.streaming_request: self._stream_terminal_signal = asyncio.get_running_loop().create_future() @@ -340,6 +366,7 @@ async def _bind_session(self) -> None: event_queue=self._event_queue, stop_event=self._stop_event, directory=self._prepared.directory, + workspace_id=self._prepared.workspace_id, terminal_signal=self._stream_terminal_signal, ) ) @@ -535,6 +562,8 @@ async def _cleanup(self) -> None: self._executor._running_identities.pop(self._execution_key, None) self._executor._running_session_ids.pop(self._execution_key, None) self._executor._running_directories.pop(self._execution_key, None) + self._executor._running_workspace_ids.pop(self._execution_key, None) + self._executor._running_binding_context_ids.pop(self._execution_key, None) class OpencodeAgentExecutor(AgentExecutor): @@ -576,6 +605,8 @@ def __init__( self._running_identities: dict[tuple[str, str], str] = {} self._running_session_ids: dict[tuple[str, str], str] = {} self._running_directories: dict[tuple[str, str], str | None] = {} + self._running_workspace_ids: dict[tuple[str, str], str | None] = {} + self._running_binding_context_ids: dict[tuple[str, str], str] = {} @staticmethod def _emit_metric( @@ -782,24 +813,34 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non streaming_request=streaming_request, ) return + workspace_id = _extract_opencode_workspace_id(context) requested_dir = _extract_opencode_directory(context) - try: - directory = self._sandbox_policy.resolve_directory( - requested_dir, - default_directory=self._client.directory, - ) - except ValueError as e: - logger.warning("Directory validation failed: %s", e) - await self._emit_error( - event_queue, - task_id=task_id, - context_id=context_id, - message=str(e), - state=TaskState.failed, - streaming_request=streaming_request, - ) - return + directory: str | None = None + if workspace_id is None: + try: + directory = self._sandbox_policy.resolve_directory( + requested_dir, + default_directory=self._client.directory, + ) + except ValueError as e: + logger.warning("Directory validation failed: %s", e) + await self._emit_error( + event_queue, + task_id=task_id, + context_id=context_id, + message=str(e), + state=TaskState.failed, + streaming_request=streaming_request, + ) + return + + session_binding_context_id = _build_session_binding_context_id( + context_id=context_id, + directory=directory, + workspace_id=workspace_id, + use_directory_binding=requested_dir is not None, + ) if not user_text and not request_parts: await self._emit_error( @@ -834,6 +875,8 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non bound_session_id=bound_session_id, model_override=model_override, directory=directory, + workspace_id=workspace_id, + session_binding_context_id=session_binding_context_id, ) coordinator = _ExecutionCoordinator( self, @@ -882,9 +925,14 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None stop_event = self._running_stop_events.get(execution_key) running_session_id = self._running_session_ids.get(execution_key) running_directory = self._running_directories.get(execution_key) + running_workspace_id = self._running_workspace_ids.get(execution_key) + running_binding_context_id = self._running_binding_context_ids.get( + execution_key, + context_id, + ) inflight = await self._session_manager.pop_cached_session( identity=running_identity, - context_id=context_id, + context_id=running_binding_context_id, ) if stop_event: stop_event.set() @@ -896,10 +944,14 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None if running_session_id and should_cancel_running_task: self._emit_metric("a2a_cancel_abort_attempt_total") try: + abort_kwargs: dict[str, Any] = {"directory": running_directory} + if running_workspace_id is not None: + abort_kwargs["workspace_id"] = running_workspace_id await asyncio.wait_for( - self._client.abort_session( + call_with_supported_kwargs( + self._client.abort_session, running_session_id, - directory=running_directory, + **abort_kwargs, ), timeout=self._cancel_abort_timeout_seconds, ) @@ -1044,6 +1096,7 @@ async def _consume_opencode_stream( stop_event: asyncio.Event, terminal_signal: asyncio.Future[_StreamTerminalSignal], directory: str | None = None, + workspace_id: str | None = None, ) -> None: await self._stream_runtime.consume( session_id=session_id, @@ -1056,6 +1109,7 @@ async def _consume_opencode_stream( stop_event=stop_event, terminal_signal=terminal_signal, directory=directory, + workspace_id=workspace_id, ) diff --git a/src/opencode_a2a/execution/request_context.py b/src/opencode_a2a/execution/request_context.py index 45e35c7..f70dcac 100644 --- a/src/opencode_a2a/execution/request_context.py +++ b/src/opencode_a2a/execution/request_context.py @@ -89,3 +89,11 @@ def _extract_opencode_directory(context: RequestContext) -> str | None: namespace="opencode", path=("directory",), ) + + +def _extract_opencode_workspace_id(context: RequestContext) -> str | None: + return _extract_namespaced_string_metadata( + context, + namespace="opencode", + path=("workspace", "id"), + ) diff --git a/src/opencode_a2a/execution/session_manager.py b/src/opencode_a2a/execution/session_manager.py index d033105..955c385 100644 --- a/src/opencode_a2a/execution/session_manager.py +++ b/src/opencode_a2a/execution/session_manager.py @@ -2,6 +2,7 @@ import asyncio +from ..invocation import call_with_supported_kwargs from ..server.state_store import MemorySessionStateRepository, SessionStateRepository @@ -33,6 +34,7 @@ async def get_or_create_session( *, preferred_session_id: str | None = None, directory: str | None = None, + workspace_id: str | None = None, ) -> tuple[str, bool]: if preferred_session_id: pending_claim = await self.claim_preferred_session( @@ -60,7 +62,12 @@ async def get_or_create_session( task = self._inflight_session_creates.get(cache_key) if task is None: task = asyncio.create_task( - self._client.create_session(title=title, directory=directory) + call_with_supported_kwargs( + self._client.create_session, + title=title, + directory=directory, + workspace_id=workspace_id, + ) ) self._inflight_session_creates[cache_key] = task diff --git a/src/opencode_a2a/execution/stream_runtime.py b/src/opencode_a2a/execution/stream_runtime.py index 9725de0..63ca1a1 100644 --- a/src/opencode_a2a/execution/stream_runtime.py +++ b/src/opencode_a2a/execution/stream_runtime.py @@ -17,6 +17,7 @@ TextPart, ) +from ..invocation import call_with_supported_kwargs from .event_helpers import _enqueue_artifact_update from .stream_events import ( BlockType, @@ -75,6 +76,7 @@ async def consume( stop_event: asyncio.Event, terminal_signal: asyncio.Future[_StreamTerminalSignal], directory: str | None = None, + workspace_id: str | None = None, ) -> None: part_states: dict[str, _StreamPartState] = {} pending_deltas: defaultdict[str, list[_PendingDelta]] = defaultdict(list) @@ -372,9 +374,11 @@ def _tool_chunks( try: while not stop_event.is_set(): try: - async for event in self._client.stream_events( + async for event in call_with_supported_kwargs( + self._client.stream_events, stop_event=stop_event, directory=directory, + workspace_id=workspace_id, ): if stop_event.is_set(): break diff --git a/src/opencode_a2a/invocation.py b/src/opencode_a2a/invocation.py new file mode 100644 index 0000000..c333d22 --- /dev/null +++ b/src/opencode_a2a/invocation.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import inspect +from typing import Any + + +def _resolve_signature_target(target): # noqa: ANN001 + side_effect = getattr(target, "side_effect", None) + if callable(side_effect): + return side_effect + return target + + +def call_with_supported_kwargs(target, /, *args: Any, **kwargs: Any): # noqa: ANN001 + signature_target = _resolve_signature_target(target) + try: + signature = inspect.signature(signature_target) + except (TypeError, ValueError): + return target(*args, **kwargs) + + if any( + parameter.kind == inspect.Parameter.VAR_KEYWORD + for parameter in signature.parameters.values() + ): + return target(*args, **kwargs) + + supported_kwargs = { + name: value for name, value in kwargs.items() if name in signature.parameters + } + return target(*args, **supported_kwargs) diff --git a/src/opencode_a2a/jsonrpc/application.py b/src/opencode_a2a/jsonrpc/application.py index 1c1cc5c..f357858 100644 --- a/src/opencode_a2a/jsonrpc/application.py +++ b/src/opencode_a2a/jsonrpc/application.py @@ -84,6 +84,15 @@ def __init__( self._method_shell = methods.get("shell") self._method_list_providers = methods["list_providers"] self._method_list_models = methods["list_models"] + self._method_list_projects = methods["list_projects"] + self._method_get_current_project = methods["get_current_project"] + self._method_list_workspaces = methods["list_workspaces"] + self._method_create_workspace = methods["create_workspace"] + self._method_remove_workspace = methods["remove_workspace"] + self._method_list_worktrees = methods["list_worktrees"] + self._method_create_worktree = methods["create_worktree"] + self._method_remove_worktree = methods["remove_worktree"] + self._method_reset_worktree = methods["reset_worktree"] self._method_list_permissions = methods["list_permissions"] self._method_list_questions = methods["list_questions"] self._method_reply_permission = methods["reply_permission"] @@ -118,6 +127,15 @@ def __init__( method_shell=self._method_shell, method_list_providers=self._method_list_providers, method_list_models=self._method_list_models, + method_list_projects=self._method_list_projects, + method_get_current_project=self._method_get_current_project, + method_list_workspaces=self._method_list_workspaces, + method_create_workspace=self._method_create_workspace, + method_remove_workspace=self._method_remove_workspace, + method_list_worktrees=self._method_list_worktrees, + method_create_worktree=self._method_create_worktree, + method_remove_worktree=self._method_remove_worktree, + method_reset_worktree=self._method_reset_worktree, method_list_permissions=self._method_list_permissions, method_list_questions=self._method_list_questions, method_reply_permission=self._method_reply_permission, diff --git a/src/opencode_a2a/jsonrpc/dispatch.py b/src/opencode_a2a/jsonrpc/dispatch.py index 33c5591..d909929 100644 --- a/src/opencode_a2a/jsonrpc/dispatch.py +++ b/src/opencode_a2a/jsonrpc/dispatch.py @@ -37,6 +37,15 @@ class ExtensionHandlerContext: method_shell: str | None method_list_providers: str method_list_models: str + method_list_projects: str + method_get_current_project: str + method_list_workspaces: str + method_create_workspace: str + method_remove_workspace: str + method_list_worktrees: str + method_create_worktree: str + method_remove_worktree: str + method_reset_worktree: str method_list_permissions: str method_list_questions: str method_reply_permission: str @@ -95,6 +104,7 @@ def build_extension_method_registry( from .handlers.provider_discovery import handle_provider_discovery_request from .handlers.session_control import handle_session_control_request from .handlers.session_queries import handle_session_query_request + from .handlers.workspace_control import handle_workspace_control_request session_control_methods = {context.method_prompt_async, context.method_command} if context.method_shell is not None: @@ -132,6 +142,23 @@ def build_extension_method_registry( ), handler=handle_interrupt_query_request, ), + ExtensionMethodSpec( + name="workspace_control", + methods=frozenset( + { + context.method_list_projects, + context.method_get_current_project, + context.method_list_workspaces, + context.method_create_workspace, + context.method_remove_workspace, + context.method_list_worktrees, + context.method_create_worktree, + context.method_remove_worktree, + context.method_reset_worktree, + } + ), + handler=handle_workspace_control_request, + ), ExtensionMethodSpec( name="session_control", methods=frozenset(session_control_methods), diff --git a/src/opencode_a2a/jsonrpc/handlers/common.py b/src/opencode_a2a/jsonrpc/handlers/common.py index 9fc4ac5..66311e2 100644 --- a/src/opencode_a2a/jsonrpc/handlers/common.py +++ b/src/opencode_a2a/jsonrpc/handlers/common.py @@ -107,6 +107,106 @@ def extract_directory_from_metadata( return directory, None +def extract_workspace_id_from_metadata( + context: ExtensionHandlerContext, + *, + request_id: str | int | None, + params: dict[str, Any], +) -> tuple[str | None, Response | None]: + metadata = params.get("metadata") + if metadata is None: + return None, None + if not isinstance(metadata, dict): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata must be an object", + data={"type": "INVALID_FIELD", "field": "metadata"}, + ), + ) + + raw_opencode_metadata = metadata.get("opencode") + if raw_opencode_metadata is None: + return None, None + if not isinstance(raw_opencode_metadata, dict): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata.opencode must be an object", + data={"type": "INVALID_FIELD", "field": "metadata.opencode"}, + ), + ) + + raw_workspace = raw_opencode_metadata.get("workspace") + if raw_workspace is None: + return None, None + if not isinstance(raw_workspace, dict): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata.opencode.workspace must be an object", + data={"type": "INVALID_FIELD", "field": "metadata.opencode.workspace"}, + ), + ) + + raw_workspace_id = raw_workspace.get("id") + if raw_workspace_id is None: + return None, None + if not isinstance(raw_workspace_id, str): + return None, context.error_response( + request_id, + invalid_params_error( + "metadata.opencode.workspace.id must be a string", + data={"type": "INVALID_FIELD", "field": "metadata.opencode.workspace.id"}, + ), + ) + workspace_id = raw_workspace_id.strip() + return workspace_id or None, None + + +def resolve_routing_context( + context: ExtensionHandlerContext, + *, + request_id: str | int | None, + params: dict[str, Any], + requested_directory: str | None = None, +) -> tuple[str | None, str | None, Response | None]: + workspace_id, workspace_error = extract_workspace_id_from_metadata( + context, + request_id=request_id, + params=params, + ) + if workspace_error is not None: + return None, None, workspace_error + if workspace_id is not None: + return None, workspace_id, None + + if requested_directory is not None: + try: + return context.directory_resolver(requested_directory), None, None + except ValueError as exc: + return ( + None, + None, + context.error_response( + request_id, + invalid_params_error( + str(exc), + data={"type": "INVALID_FIELD", "field": "directory"}, + ), + ), + ) + + directory, directory_error = resolve_directory( + context, + request_id=request_id, + params=params, + ) + if directory_error is not None: + return None, None, directory_error + return directory, None, None + + def resolve_directory( context: ExtensionHandlerContext, *, diff --git a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py index c647217..44186bd 100644 --- a/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py +++ b/src/opencode_a2a/jsonrpc/handlers/interrupt_callbacks.py @@ -9,6 +9,7 @@ from starlette.responses import Response from ...contracts.extensions import INTERRUPT_ERROR_BUSINESS_CODES +from ...invocation import call_with_supported_kwargs from ...opencode_upstream_client import UpstreamConcurrencyLimitError from ..dispatch import ExtensionHandlerContext from ..error_responses import ( @@ -24,6 +25,7 @@ build_upstream_http_error_response, build_upstream_unreachable_error_response, extract_interrupt_callback_directory_hint, + extract_workspace_id_from_metadata, ) logger = logging.getLogger(__name__) @@ -60,6 +62,13 @@ async def handle_interrupt_callback_request( ) if directory_error is not None: return directory_error + workspace_id, workspace_error = extract_workspace_id_from_metadata( + context, + request_id=base_request.id, + params=params, + ) + if workspace_error is not None: + return workspace_error expected_interrupt_type = ( "permission" if base_request.method == context.method_reply_permission else "question" @@ -136,21 +145,30 @@ async def handle_interrupt_callback_request( message = params.get("message") if message is not None and not isinstance(message, str): raise ValueError("message must be a string") - await context.upstream_client.permission_reply( + await call_with_supported_kwargs( + context.upstream_client.permission_reply, request_id, reply=reply, message=message, directory=directory, + workspace_id=workspace_id, ) elif base_request.method == context.method_reply_question: answers = _parse_question_answers(params.get("answers")) - await context.upstream_client.question_reply( + await call_with_supported_kwargs( + context.upstream_client.question_reply, request_id, answers=answers, directory=directory, + workspace_id=workspace_id, ) else: - await context.upstream_client.question_reject(request_id, directory=directory) + await call_with_supported_kwargs( + context.upstream_client.question_reject, + request_id, + directory=directory, + workspace_id=workspace_id, + ) discard_request = getattr(context.upstream_client, "discard_interrupt_request", None) if callable(discard_request): await discard_request(request_id) diff --git a/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py b/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py index 3b17dd8..ce4f508 100644 --- a/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py +++ b/src/opencode_a2a/jsonrpc/handlers/provider_discovery.py @@ -9,6 +9,7 @@ from starlette.responses import Response from ...contracts.extensions import PROVIDER_DISCOVERY_ERROR_BUSINESS_CODES +from ...invocation import call_with_supported_kwargs from ...opencode_upstream_client import UpstreamConcurrencyLimitError from ..dispatch import ExtensionHandlerContext from ..error_responses import invalid_params_error @@ -24,7 +25,7 @@ build_upstream_http_error_response, build_upstream_payload_error_response, build_upstream_unreachable_error_response, - resolve_directory, + resolve_routing_context, ) logger = logging.getLogger(__name__) @@ -71,16 +72,20 @@ async def handle_provider_discovery_request( ) provider_id = raw_provider_id.strip() - directory, directory_error = resolve_directory( + directory, workspace_id, routing_error = resolve_routing_context( context, request_id=base_request.id, params=params, ) - if directory_error is not None: - return directory_error + if routing_error is not None: + return routing_error try: - raw_result = await context.upstream_client.list_provider_catalog(directory=directory) + raw_result = await call_with_supported_kwargs( + context.upstream_client.list_provider_catalog, + directory=directory, + workspace_id=workspace_id, + ) except httpx.HTTPStatusError as exc: upstream_status = exc.response.status_code return build_upstream_http_error_response( diff --git a/src/opencode_a2a/jsonrpc/handlers/session_control.py b/src/opencode_a2a/jsonrpc/handlers/session_control.py index e81a1a6..a812ce8 100644 --- a/src/opencode_a2a/jsonrpc/handlers/session_control.py +++ b/src/opencode_a2a/jsonrpc/handlers/session_control.py @@ -9,6 +9,7 @@ from starlette.responses import Response from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES +from ...invocation import call_with_supported_kwargs from ...opencode_upstream_client import UpstreamConcurrencyLimitError, UpstreamContractError from ..dispatch import ExtensionHandlerContext from ..error_responses import invalid_params_error, session_not_found_error @@ -27,7 +28,7 @@ build_upstream_http_error_response, build_upstream_payload_error_response, build_upstream_unreachable_error_response, - resolve_directory, + resolve_routing_context, ) logger = logging.getLogger(__name__) @@ -121,13 +122,13 @@ def _log_shell_audit(outcome: str) -> None: invalid_params_error(str(exc), data={"type": "INVALID_FIELD", "field": exc.field}), ) - directory, directory_error = resolve_directory( + directory, workspace_id, routing_error = resolve_routing_context( context, request_id=base_request.id, params=params, ) - if directory_error is not None: - return directory_error + if routing_error is not None: + return routing_error pending_claim = False claim_finalized = False @@ -148,17 +149,21 @@ def _log_shell_audit(outcome: str) -> None: try: result: dict[str, Any] if base_request.method == context.method_prompt_async: - await context.upstream_client.session_prompt_async( + await call_with_supported_kwargs( + context.upstream_client.session_prompt_async, session_id, request=dict(raw_request), directory=directory, + workspace_id=workspace_id, ) result = {"ok": True, "session_id": session_id} elif base_request.method == context.method_command: - raw_result = await context.upstream_client.session_command( + raw_result = await call_with_supported_kwargs( + context.upstream_client.session_command, session_id, request=dict(raw_request), directory=directory, + workspace_id=workspace_id, ) item = _as_a2a_message(session_id, raw_result) if item is None: @@ -168,10 +173,12 @@ def _log_shell_audit(outcome: str) -> None: ) result = {"item": item} else: - raw_result = await context.upstream_client.session_shell( + raw_result = await call_with_supported_kwargs( + context.upstream_client.session_shell, session_id, request=dict(raw_request), directory=directory, + workspace_id=workspace_id, ) item = _as_a2a_message(session_id, raw_result) if item is None: diff --git a/src/opencode_a2a/jsonrpc/handlers/session_queries.py b/src/opencode_a2a/jsonrpc/handlers/session_queries.py index dad8df8..35247d4 100644 --- a/src/opencode_a2a/jsonrpc/handlers/session_queries.py +++ b/src/opencode_a2a/jsonrpc/handlers/session_queries.py @@ -9,6 +9,7 @@ from starlette.responses import Response from ...contracts.extensions import SESSION_QUERY_ERROR_BUSINESS_CODES +from ...invocation import call_with_supported_kwargs from ...opencode_upstream_client import UpstreamConcurrencyLimitError from ..dispatch import ExtensionHandlerContext from ..error_responses import invalid_params_error, session_not_found_error @@ -30,6 +31,7 @@ build_upstream_http_error_response, build_upstream_payload_error_response, build_upstream_unreachable_error_response, + resolve_routing_context, ) logger = logging.getLogger(__name__) @@ -60,6 +62,7 @@ async def handle_session_query_request( limit = int(query["limit"]) directory = None + workspace_id = None if base_request.method == context.method_list_sessions: requested_directory = query.pop("directory", None) if requested_directory is not None and not isinstance(requested_directory, str): @@ -70,25 +73,38 @@ async def handle_session_query_request( data={"type": "INVALID_FIELD", "field": "directory"}, ), ) - try: - directory = context.directory_resolver(requested_directory) - except ValueError as exc: - return context.error_response( - base_request.id, - invalid_params_error( - str(exc), - data={"type": "INVALID_FIELD", "field": "directory"}, - ), - ) + directory, workspace_id, routing_error = resolve_routing_context( + context, + request_id=base_request.id, + params=params, + requested_directory=requested_directory, + ) + if routing_error is not None: + return routing_error + else: + directory, workspace_id, routing_error = resolve_routing_context( + context, + request_id=base_request.id, + params=params, + ) + if routing_error is not None: + return routing_error try: if base_request.method == context.method_list_sessions: - raw_result = await context.upstream_client.list_sessions( + raw_result = await call_with_supported_kwargs( + context.upstream_client.list_sessions, params=query, directory=directory, + workspace_id=workspace_id, ) else: assert session_id is not None - raw_result = await context.upstream_client.list_messages(session_id, params=query) + raw_result = await call_with_supported_kwargs( + context.upstream_client.list_messages, + session_id, + params=query, + workspace_id=workspace_id, + ) except httpx.HTTPStatusError as exc: upstream_status = exc.response.status_code if upstream_status == 404 and base_request.method == context.method_get_session_messages: diff --git a/src/opencode_a2a/jsonrpc/handlers/workspace_control.py b/src/opencode_a2a/jsonrpc/handlers/workspace_control.py new file mode 100644 index 0000000..c846485 --- /dev/null +++ b/src/opencode_a2a/jsonrpc/handlers/workspace_control.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import logging +from typing import Any + +import httpx +from a2a.types import JSONRPCRequest +from starlette.requests import Request +from starlette.responses import Response + +from ...contracts.extensions import WORKSPACE_CONTROL_ERROR_BUSINESS_CODES +from ...opencode_upstream_client import UpstreamConcurrencyLimitError +from ..dispatch import ExtensionHandlerContext +from ..error_responses import invalid_params_error +from .common import ( + build_internal_error_response, + build_success_response, + build_upstream_concurrency_error_response, + build_upstream_http_error_response, + build_upstream_payload_error_response, + build_upstream_unreachable_error_response, +) + +logger = logging.getLogger(__name__) + +ERR_UPSTREAM_UNREACHABLE = WORKSPACE_CONTROL_ERROR_BUSINESS_CODES["UPSTREAM_UNREACHABLE"] +ERR_UPSTREAM_HTTP_ERROR = WORKSPACE_CONTROL_ERROR_BUSINESS_CODES["UPSTREAM_HTTP_ERROR"] +ERR_UPSTREAM_PAYLOAD_ERROR = WORKSPACE_CONTROL_ERROR_BUSINESS_CODES["UPSTREAM_PAYLOAD_ERROR"] + + +def _parse_optional_request_object( + params: dict[str, Any], + *, + required: bool, +) -> dict[str, Any] | None: + value = params.get("request") + if value is None: + if required: + raise ValueError("Missing required params.request") + return None + if not isinstance(value, dict): + raise TypeError("params.request must be an object") + return dict(value) + + +def _parse_workspace_id(params: dict[str, Any]) -> str: + raw_workspace_id = params.get("workspace_id") + if not isinstance(raw_workspace_id, str) or not raw_workspace_id.strip(): + raise ValueError("Missing required params.workspace_id") + return raw_workspace_id.strip() + + +def _validate_workspace_request(method: str, request: dict[str, Any]) -> None: + if method == "create_workspace": + allowed_fields = {"id", "type", "branch", "extra"} + if "type" not in request: + raise ValueError("Missing required params.request.type") + request_type = request.get("type") + if not isinstance(request_type, str) or not request_type.strip(): + raise TypeError("params.request.type must be a non-empty string") + elif method == "create_worktree": + allowed_fields = {"name", "startCommand"} + elif method in {"remove_worktree", "reset_worktree"}: + allowed_fields = {"directory"} + directory = request.get("directory") + if not isinstance(directory, str) or not directory.strip(): + raise TypeError("params.request.directory must be a non-empty string") + else: + allowed_fields = set() + + unknown_fields = sorted(set(request) - allowed_fields) + if unknown_fields: + raise ValueError( + "Unsupported request fields: " + + ", ".join(f"request.{field}" for field in unknown_fields) + ) + + for field in ("id", "type", "branch", "name", "startCommand", "directory"): + if field not in request: + continue + value = request[field] + if value is not None and not isinstance(value, str): + raise TypeError(f"params.request.{field} must be a string") + + +def _validate_allowed_fields( + method: str, + params: dict[str, Any], +) -> None: + allowed_fields = {"metadata"} + if method in {"create_workspace", "create_worktree", "remove_worktree", "reset_worktree"}: + allowed_fields.add("request") + if method == "remove_workspace": + allowed_fields.add("workspace_id") + + unknown_fields = sorted(set(params) - allowed_fields) + if unknown_fields: + raise ValueError("Unsupported fields: " + ", ".join(unknown_fields)) + + +def _validate_response_payload(method: str, payload: Any) -> dict[str, Any]: + if method in {"list_projects", "list_workspaces", "list_worktrees"}: + if not isinstance(payload, list): + raise ValueError("Upstream list response must be an array") + return {"items": payload} + if method in {"get_current_project", "create_workspace", "remove_workspace", "create_worktree"}: + if payload is not None and not isinstance(payload, dict): + raise ValueError("Upstream item response must be an object or null") + return {"item": payload} + if method in {"remove_worktree", "reset_worktree"}: + if not isinstance(payload, bool): + raise ValueError("Upstream boolean response must be a boolean") + return {"ok": payload} + raise ValueError(f"Unsupported workspace control method: {method}") + + +async def handle_workspace_control_request( + context: ExtensionHandlerContext, + base_request: JSONRPCRequest, + params: dict[str, Any], + request: Request, +) -> Response: + del request + + method_map = { + context.method_list_projects: "list_projects", + context.method_get_current_project: "get_current_project", + context.method_list_workspaces: "list_workspaces", + context.method_create_workspace: "create_workspace", + context.method_remove_workspace: "remove_workspace", + context.method_list_worktrees: "list_worktrees", + context.method_create_worktree: "create_worktree", + context.method_remove_worktree: "remove_worktree", + context.method_reset_worktree: "reset_worktree", + } + method_key = method_map.get(base_request.method) + if method_key is None: + return context.error_response( + base_request.id, + invalid_params_error( + f"Unsupported method: {base_request.method}", + data={"type": "INVALID_FIELD", "field": "method"}, + ), + ) + + try: + _validate_allowed_fields(method_key, params) + request_body: dict[str, Any] | None = None + workspace_id: str | None = None + if method_key == "remove_workspace": + workspace_id = _parse_workspace_id(params) + elif method_key in { + "create_workspace", + "create_worktree", + "remove_worktree", + "reset_worktree", + }: + request_body = _parse_optional_request_object( + params, + required=True, + ) + assert request_body is not None + _validate_workspace_request(method_key, request_body) + except ValueError as exc: + field = "workspace_id" if "workspace_id" in str(exc) else "request" + return context.error_response( + base_request.id, + invalid_params_error(str(exc), data={"type": "INVALID_FIELD", "field": field}), + ) + except TypeError as exc: + return context.error_response( + base_request.id, + invalid_params_error(str(exc), data={"type": "INVALID_FIELD"}), + ) + + try: + if method_key == "list_projects": + raw_result = await context.upstream_client.list_projects() + elif method_key == "get_current_project": + raw_result = await context.upstream_client.get_current_project() + elif method_key == "list_workspaces": + raw_result = await context.upstream_client.list_workspaces() + elif method_key == "create_workspace": + raw_result = await context.upstream_client.create_workspace(request_body or {}) + elif method_key == "remove_workspace": + assert workspace_id is not None + raw_result = await context.upstream_client.remove_workspace(workspace_id) + elif method_key == "list_worktrees": + raw_result = await context.upstream_client.list_worktrees() + elif method_key == "create_worktree": + raw_result = await context.upstream_client.create_worktree(request_body or {}) + elif method_key == "remove_worktree": + raw_result = await context.upstream_client.remove_worktree(request_body or {}) + else: + raw_result = await context.upstream_client.reset_worktree(request_body or {}) + except httpx.HTTPStatusError as exc: + return build_upstream_http_error_response( + context, + base_request.id, + ERR_UPSTREAM_HTTP_ERROR, + upstream_status=exc.response.status_code, + method=base_request.method, + ) + except httpx.HTTPError: + return build_upstream_unreachable_error_response( + context, + base_request.id, + ERR_UPSTREAM_UNREACHABLE, + method=base_request.method, + ) + except UpstreamConcurrencyLimitError as exc: + return build_upstream_concurrency_error_response( + context, + base_request.id, + ERR_UPSTREAM_UNREACHABLE, + exc=exc, + method=base_request.method, + ) + except Exception as exc: + return build_internal_error_response( + context, + base_request.id, + log_message="OpenCode workspace control JSON-RPC method failed", + exc=exc, + ) + + try: + result = _validate_response_payload(method_key, raw_result) + except ValueError as exc: + logger.warning("Upstream OpenCode workspace payload mismatch: %s", exc) + return build_upstream_payload_error_response( + context, + base_request.id, + ERR_UPSTREAM_PAYLOAD_ERROR, + detail=str(exc), + method=base_request.method, + ) + + return build_success_response(context, base_request.id, result) diff --git a/src/opencode_a2a/opencode_upstream_client.py b/src/opencode_a2a/opencode_upstream_client.py index 4983de9..249b44f 100644 --- a/src/opencode_a2a/opencode_upstream_client.py +++ b/src/opencode_a2a/opencode_upstream_client.py @@ -333,32 +333,51 @@ def _normalize_model_ref(value: Mapping[str, Any] | None) -> dict[str, str] | No "modelID": model_id, } - def _query_params(self, directory: str | None = None) -> dict[str, str]: + def _query_params( + self, + directory: str | None = None, + *, + workspace_id: str | None = None, + ) -> dict[str, str]: + if isinstance(workspace_id, str): + normalized_workspace_id = workspace_id.strip() + if normalized_workspace_id: + return {"workspace": normalized_workspace_id} d = directory or self._directory if not d: return {} return {"directory": d} def _merge_params( - self, extra: dict[str, Any] | None, *, directory: str | None = None + self, + extra: dict[str, Any] | None, + *, + directory: str | None = None, + workspace_id: str | None = None, ) -> dict[str, Any]: - params: dict[str, Any] = dict(self._query_params(directory=directory)) + params: dict[str, Any] = dict( + self._query_params(directory=directory, workspace_id=workspace_id) + ) if not extra: return params for key, value in extra.items(): if value is None: continue # "directory" is server-controlled. Client overrides are handled via explicit parameter. - if key == "directory": + if key in {"directory", "workspace"}: continue # FastAPI query params are strings; keep them as-is. Coerce other primitives to str. params[key] = value if isinstance(value, str) else str(value) return params async def stream_events( - self, stop_event: asyncio.Event | None = None, *, directory: str | None = None + self, + stop_event: asyncio.Event | None = None, + *, + directory: str | None = None, + workspace_id: str | None = None, ) -> AsyncIterator[dict[str, Any]]: - params = self._query_params(directory=directory) + params = self._query_params(directory=directory, workspace_id=workspace_id) async with self._stream_budget.reserve(operation="/event"): async with self._client.stream( "GET", @@ -393,7 +412,11 @@ async def stream_events( continue async def create_session( - self, title: str | None = None, *, directory: str | None = None + self, + title: str | None = None, + *, + directory: str | None = None, + workspace_id: str | None = None, ) -> str: payload: dict[str, Any] = {} if title: @@ -401,7 +424,7 @@ async def create_session( data = await self._post_json( "/session", endpoint="/session", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), json_body=payload, ) session_id = data.get("id") @@ -409,11 +432,17 @@ async def create_session( raise RuntimeError("OpenCode session response missing id") return session_id - async def abort_session(self, session_id: str, *, directory: str | None = None) -> bool: + async def abort_session( + self, + session_id: str, + *, + directory: str | None = None, + workspace_id: str | None = None, + ) -> bool: return await self._post_boolean( f"/session/{session_id}/abort", endpoint="/session/{sessionID}/abort", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), ) async def list_sessions( @@ -421,12 +450,13 @@ async def list_sessions( *, params: dict[str, Any] | None = None, directory: str | None = None, + workspace_id: str | None = None, ) -> Any: """List sessions from OpenCode.""" return await self._get_json( "/session", endpoint="/session", - params=self._merge_params(params, directory=directory), + params=self._merge_params(params, directory=directory, workspace_id=workspace_id), ) async def list_messages( @@ -434,13 +464,14 @@ async def list_messages( session_id: str, *, params: dict[str, Any] | None = None, + workspace_id: str | None = None, ) -> OpencodeMessagePage: """List messages for a session from OpenCode.""" endpoint = "/session/{sessionID}/message" async with self._request_budget.reserve(operation=endpoint): response = await self._client.get( f"/session/{session_id}/message", - params=self._merge_params(params), + params=self._merge_params(params, workspace_id=workspace_id), ) response.raise_for_status() payload = self._decode_json_response(response, endpoint=endpoint) @@ -458,12 +489,13 @@ async def session_prompt_async( request: dict[str, Any], *, directory: str | None = None, + workspace_id: str | None = None, ) -> None: endpoint = "/session/{sessionID}/prompt_async" async with self._request_budget.reserve(operation=endpoint): response = await self._client.post( f"/session/{session_id}/prompt_async", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), json=request, ) response.raise_for_status() @@ -478,11 +510,12 @@ async def session_command( request: dict[str, Any], *, directory: str | None = None, + workspace_id: str | None = None, ) -> Any: return await self._post_json( f"/session/{session_id}/command", endpoint="/session/{sessionID}/command", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), json_body=request, ) @@ -492,19 +525,87 @@ async def session_shell( request: dict[str, Any], *, directory: str | None = None, + workspace_id: str | None = None, ) -> Any: return await self._post_json( f"/session/{session_id}/shell", endpoint="/session/{sessionID}/shell", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), json_body=request, ) - async def list_provider_catalog(self, *, directory: str | None = None) -> Any: + async def list_provider_catalog( + self, + *, + directory: str | None = None, + workspace_id: str | None = None, + ) -> Any: return await self._get_json( "/provider", endpoint="/provider", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), + ) + + async def list_projects(self) -> Any: + return await self._get_json("/project", endpoint="/project") + + async def get_current_project(self) -> Any: + return await self._get_json("/project/current", endpoint="/project/current") + + async def list_workspaces(self) -> Any: + return await self._get_json( + "/experimental/workspace", + endpoint="/experimental/workspace", + ) + + async def create_workspace(self, request: dict[str, Any]) -> Any: + return await self._post_json( + "/experimental/workspace", + endpoint="/experimental/workspace", + json_body=request, + ) + + async def remove_workspace(self, workspace_id: str) -> Any: + async with self._request_budget.reserve(operation="/experimental/workspace/{id}"): + response = await self._client.delete(f"/experimental/workspace/{workspace_id}") + response.raise_for_status() + return self._decode_json_response( + response, + endpoint="/experimental/workspace/{id}", + ) + + async def list_worktrees(self) -> Any: + return await self._get_json( + "/experimental/worktree", + endpoint="/experimental/worktree", + ) + + async def create_worktree(self, request: dict[str, Any]) -> Any: + return await self._post_json( + "/experimental/worktree", + endpoint="/experimental/worktree", + json_body=request, + ) + + async def remove_worktree(self, request: dict[str, Any]) -> bool: + async with self._request_budget.reserve(operation="/experimental/worktree"): + response = await self._client.request( + "DELETE", + "/experimental/worktree", + json=request, + ) + response.raise_for_status() + payload = self._decode_json_response(response, endpoint="/experimental/worktree") + return self._require_boolean_response( + endpoint="/experimental/worktree", + payload=payload, + ) + + async def reset_worktree(self, request: dict[str, Any]) -> bool: + return await self._post_boolean( + "/experimental/worktree/reset", + endpoint="/experimental/worktree/reset", + json_body=request, ) async def send_message( @@ -514,6 +615,7 @@ async def send_message( *, parts: Sequence[Mapping[str, Any]] | None = None, directory: str | None = None, + workspace_id: str | None = None, model_override: Mapping[str, Any] | None = None, timeout_override: float | None | object = _UNSET, ) -> OpencodeMessage: @@ -550,7 +652,7 @@ async def send_message( data = await self._post_json( f"/session/{session_id}/message", endpoint="/session/{sessionID}/message", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), json_body=payload, timeout=timeout_override, ) @@ -575,6 +677,7 @@ async def permission_reply( reply: str, message: str | None = None, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: payload: dict[str, Any] = {"reply": reply} if message: @@ -582,7 +685,7 @@ async def permission_reply( return await self._post_boolean( f"/permission/{request_id}/reply", endpoint="/permission/{requestID}/reply", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), json_body=payload, ) @@ -592,11 +695,12 @@ async def question_reply( *, answers: list[list[str]], directory: str | None = None, + workspace_id: str | None = None, ) -> bool: return await self._post_boolean( f"/question/{request_id}/reply", endpoint="/question/{requestID}/reply", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), json_body={"answers": answers}, ) @@ -605,9 +709,10 @@ async def question_reject( request_id: str, *, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: return await self._post_boolean( f"/question/{request_id}/reject", endpoint="/question/{requestID}/reject", - params=self._query_params(directory=directory), + params=self._query_params(directory=directory, workspace_id=workspace_id), ) diff --git a/src/opencode_a2a/profile/runtime.py b/src/opencode_a2a/profile/runtime.py index 86de2f8..e44863a 100644 --- a/src/opencode_a2a/profile/runtime.py +++ b/src/opencode_a2a/profile/runtime.py @@ -10,6 +10,7 @@ DEPLOYMENT_ID = "single_tenant_shared_workspace" SESSION_SHELL_TOGGLE = "A2A_ENABLE_SESSION_SHELL" DIRECTORY_OVERRIDE_METADATA_FIELD = "metadata.opencode.directory" +WORKSPACE_OVERRIDE_METADATA_FIELD = "metadata.opencode.workspace.id" @dataclass(frozen=True) @@ -42,6 +43,22 @@ def as_dict(self) -> dict[str, Any]: } +@dataclass(frozen=True) +class WorkspaceBindingProfile: + enabled: bool + metadata_field: str = WORKSPACE_OVERRIDE_METADATA_FIELD + upstream_query_param: str = "workspace" + precedence: str = "prefer_workspace_else_directory" + + def as_dict(self) -> dict[str, Any]: + return { + "enabled": self.enabled, + "metadata_field": self.metadata_field, + "upstream_query_param": self.upstream_query_param, + "precedence": self.precedence, + } + + @dataclass(frozen=True) class SessionShellProfile: enabled: bool @@ -161,6 +178,7 @@ class RuntimeProfile: profile_id: str deployment: DeploymentProfile directory_binding: DirectoryBindingProfile + workspace_binding: WorkspaceBindingProfile session_shell: SessionShellProfile execution_environment: ExecutionEnvironmentProfile service_features: ServiceFeaturesProfile @@ -169,6 +187,7 @@ class RuntimeProfile: def runtime_features_dict(self) -> dict[str, Any]: return { "directory_binding": self.directory_binding.as_dict(), + "workspace_binding": self.workspace_binding.as_dict(), "session_shell": self.session_shell.as_dict(), "execution_environment": self.execution_environment.as_dict(), "service_features": self.service_features.as_dict(), @@ -219,6 +238,9 @@ def build_runtime_profile(settings: Settings) -> RuntimeProfile: allow_override=settings.a2a_allow_directory_override, scope=directory_scope, ), + workspace_binding=WorkspaceBindingProfile( + enabled=True, + ), session_shell=SessionShellProfile( enabled=shell_enabled, availability="enabled" if shell_enabled else "disabled", diff --git a/src/opencode_a2a/server/agent_card.py b/src/opencode_a2a/server/agent_card.py index 1ad4384..bb60cd1 100644 --- a/src/opencode_a2a/server/agent_card.py +++ b/src/opencode_a2a/server/agent_card.py @@ -23,6 +23,7 @@ SESSION_QUERY_METHODS, STREAMING_EXTENSION_URI, WIRE_CONTRACT_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, JsonRpcCapabilitySnapshot, build_capability_snapshot, build_compatibility_profile_params, @@ -34,6 +35,7 @@ build_session_query_extension_params, build_streaming_extension_params, build_wire_contract_params, + build_workspace_control_extension_params, ) from ..jsonrpc.application import SESSION_CONTEXT_PREFIX from ..profile.runtime import RuntimeProfile, build_runtime_profile @@ -46,8 +48,8 @@ def _build_agent_card_description(settings: Settings, runtime_profile: RuntimePr "(message/send, message/stream), task APIs (tasks/get, tasks/cancel, " "tasks/resubscribe; REST mapping: GET /v1/tasks/{id}:subscribe), shared " "session-binding/model-selection/streaming contracts, provider-private " - "OpenCode session/provider/model/interrupt recovery extensions, and " - "shared interrupt callback extensions." + "OpenCode session/provider/model/workspace-control/interrupt recovery " + "extensions, and shared interrupt callback extensions." ) parts: list[str] = [base, summary] parts.append( @@ -101,6 +103,14 @@ def _build_interrupt_recovery_skill_examples() -> list[str]: ] +def _build_workspace_control_skill_examples() -> list[str]: + return [ + "List OpenCode projects (method opencode.projects.list).", + "List workspaces for the active project (method opencode.workspaces.list).", + "Create a worktree (method opencode.worktrees.create).", + ] + + def build_agent_card(settings: Settings) -> AgentCard: public_url = settings.a2a_public_url.rstrip("/") base_url = public_url @@ -131,6 +141,9 @@ def build_agent_card(settings: Settings) -> AgentCard: provider_discovery_extension_params = build_provider_discovery_extension_params( runtime_profile=runtime_profile, ) + workspace_control_extension_params = build_workspace_control_extension_params( + runtime_profile=runtime_profile, + ) interrupt_recovery_extension_params = build_interrupt_recovery_extension_params( runtime_profile=runtime_profile, ) @@ -208,6 +221,15 @@ def build_agent_card(settings: Settings) -> AgentCard: ), params=provider_discovery_extension_params, ), + AgentExtension( + uri=WORKSPACE_CONTROL_EXTENSION_URI, + required=False, + description=( + "Expose OpenCode-specific project/workspace/worktree control-plane " + "methods through JSON-RPC extensions." + ), + params=workspace_control_extension_params, + ), AgentExtension( uri=INTERRUPT_RECOVERY_EXTENSION_URI, required=False, @@ -285,6 +307,16 @@ def build_agent_card(settings: Settings) -> AgentCard: "List available models for a provider (method opencode.models.list).", ], ), + AgentSkill( + id="opencode.workspace.control", + name="OpenCode Workspace Control", + description=( + "provider-private OpenCode project/workspace/worktree control surface " + "exposed through JSON-RPC extensions." + ), + tags=["opencode", "project", "workspace", "worktree", "provider-private"], + examples=_build_workspace_control_skill_examples(), + ), AgentSkill( id="opencode.interrupt.recovery", name="OpenCode Interrupt Recovery", diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index acd0e25..178bd75 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -59,6 +59,8 @@ SESSION_QUERY_METHODS, STREAMING_EXTENSION_URI, WIRE_CONTRACT_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, + WORKSPACE_CONTROL_METHODS, build_capability_snapshot, ) from ..execution.executor import OpencodeAgentExecutor, _emit_metric @@ -122,6 +124,8 @@ "SESSION_QUERY_METHODS", "STREAMING_EXTENSION_URI", "WIRE_CONTRACT_EXTENSION_URI", + "WORKSPACE_CONTROL_EXTENSION_URI", + "WORKSPACE_CONTROL_METHODS", "_build_agent_card_description", "_build_chat_examples", "_build_jsonrpc_extension_openapi_description", @@ -695,6 +699,7 @@ def create_app(settings: Settings) -> FastAPI: jsonrpc_methods = { **capability_snapshot.session_query_methods(), **capability_snapshot.provider_discovery_methods(), + **capability_snapshot.workspace_control_methods(), **capability_snapshot.interrupt_recovery_methods(), **capability_snapshot.interrupt_callback_methods(), } diff --git a/src/opencode_a2a/server/openapi.py b/src/opencode_a2a/server/openapi.py index c8abe49..a26c197 100644 --- a/src/opencode_a2a/server/openapi.py +++ b/src/opencode_a2a/server/openapi.py @@ -11,6 +11,7 @@ PROVIDER_DISCOVERY_METHODS, SESSION_QUERY_DEFAULT_LIMIT, SESSION_QUERY_METHODS, + WORKSPACE_CONTROL_METHODS, JsonRpcCapabilitySnapshot, build_capability_snapshot, build_compatibility_profile_params, @@ -22,6 +23,7 @@ build_session_query_extension_params, build_streaming_extension_params, build_wire_contract_params, + build_workspace_control_extension_params, ) from ..jsonrpc.application import SESSION_CONTEXT_PREFIX from ..profile.runtime import RuntimeProfile @@ -33,6 +35,7 @@ def _build_jsonrpc_extension_openapi_description( ) -> str: session_methods = list(capability_snapshot.session_query_methods().values()) provider_methods = ", ".join(sorted(PROVIDER_DISCOVERY_METHODS.values())) + workspace_methods = ", ".join(sorted(WORKSPACE_CONTROL_METHODS.values())) interrupt_recovery_methods = ", ".join(sorted(INTERRUPT_RECOVERY_METHODS.values())) interrupt_methods = ", ".join(sorted(INTERRUPT_CALLBACK_METHODS.values())) return ( @@ -42,6 +45,7 @@ def _build_jsonrpc_extension_openapi_description( "interrupt recovery extensions, and shared interrupt callback methods.\n\n" f"OpenCode session query/control methods: {', '.join(session_methods)}.\n" f"OpenCode provider/model discovery methods: {provider_methods}.\n" + f"OpenCode project/workspace/worktree control methods: {workspace_methods}.\n" f"OpenCode interrupt recovery methods: {interrupt_recovery_methods}.\n" f"Shared interrupt callback methods: {interrupt_methods}.\n\n" "Notification semantics: extension requests without JSON-RPC id return HTTP 204." @@ -213,6 +217,92 @@ def _build_jsonrpc_extension_openapi_examples( "params": {"provider_id": "openai"}, }, }, + "projects_list": { + "summary": "List OpenCode projects", + "value": { + "jsonrpc": "2.0", + "id": 28, + "method": WORKSPACE_CONTROL_METHODS["list_projects"], + "params": {}, + }, + }, + "projects_current": { + "summary": "Get the current OpenCode project", + "value": { + "jsonrpc": "2.0", + "id": 281, + "method": WORKSPACE_CONTROL_METHODS["get_current_project"], + "params": {}, + }, + }, + "workspaces_list": { + "summary": "List workspaces for the active project", + "value": { + "jsonrpc": "2.0", + "id": 29, + "method": WORKSPACE_CONTROL_METHODS["list_workspaces"], + "params": {}, + }, + }, + "workspaces_create": { + "summary": "Create a workspace for the active project", + "value": { + "jsonrpc": "2.0", + "id": 291, + "method": WORKSPACE_CONTROL_METHODS["create_workspace"], + "params": {"request": {"type": "git", "branch": "main"}}, + }, + }, + "workspaces_remove": { + "summary": "Remove a workspace", + "value": { + "jsonrpc": "2.0", + "id": 292, + "method": WORKSPACE_CONTROL_METHODS["remove_workspace"], + "params": {"workspace_id": "wrk-1"}, + }, + }, + "worktrees_list": { + "summary": "List worktrees for the active project", + "value": { + "jsonrpc": "2.0", + "id": 293, + "method": WORKSPACE_CONTROL_METHODS["list_worktrees"], + "params": {}, + }, + }, + "worktrees_create": { + "summary": "Create a new worktree", + "value": { + "jsonrpc": "2.0", + "id": 30, + "method": WORKSPACE_CONTROL_METHODS["create_worktree"], + "params": { + "request": { + "name": "feature-branch", + "startCommand": "pnpm install", + } + }, + }, + }, + "worktrees_remove": { + "summary": "Remove a worktree", + "value": { + "jsonrpc": "2.0", + "id": 301, + "method": WORKSPACE_CONTROL_METHODS["remove_worktree"], + "params": {"request": {"directory": "/tmp/worktrees/feature-branch"}}, + }, + }, + "worktrees_reset": { + "summary": "Reset a worktree branch", + "value": { + "jsonrpc": "2.0", + "id": 302, + "method": WORKSPACE_CONTROL_METHODS["reset_worktree"], + "params": {"request": {"directory": "/tmp/worktrees/feature-branch"}}, + }, + }, "permissions_list": { "summary": "List pending permission interrupts for the current caller", "value": { @@ -365,6 +455,9 @@ def _patch_jsonrpc_openapi_contract( provider_discovery = build_provider_discovery_extension_params( runtime_profile=runtime_profile, ) + workspace_control = build_workspace_control_extension_params( + runtime_profile=runtime_profile, + ) interrupt_recovery = build_interrupt_recovery_extension_params( runtime_profile=runtime_profile, ) @@ -403,6 +496,7 @@ def custom_openapi() -> dict[str, Any]: "streaming": streaming, "session_query": session_query, "provider_discovery": provider_discovery, + "workspace_control": workspace_control, "interrupt_recovery": interrupt_recovery, "interrupt_callback": interrupt_callback, "compatibility_profile": compatibility_profile, diff --git a/src/opencode_a2a/server/request_parsing.py b/src/opencode_a2a/server/request_parsing.py index fe3bdb7..28f2c93 100644 --- a/src/opencode_a2a/server/request_parsing.py +++ b/src/opencode_a2a/server/request_parsing.py @@ -9,6 +9,7 @@ INTERRUPT_CALLBACK_METHODS, INTERRUPT_RECOVERY_METHODS, SESSION_QUERY_METHODS, + WORKSPACE_CONTROL_METHODS, ) logger = logging.getLogger(__name__) @@ -32,6 +33,7 @@ def _detect_sensitive_extension_method(payload: dict | None) -> str | None: set(SESSION_QUERY_METHODS.values()) | set(INTERRUPT_CALLBACK_METHODS.values()) | set(INTERRUPT_RECOVERY_METHODS.values()) + | set(WORKSPACE_CONTROL_METHODS.values()) ) if method in sensitive_methods: return method diff --git a/tests/contracts/test_extension_contract_consistency.py b/tests/contracts/test_extension_contract_consistency.py index 9a00fed..4a9ba8d 100644 --- a/tests/contracts/test_extension_contract_consistency.py +++ b/tests/contracts/test_extension_contract_consistency.py @@ -5,6 +5,7 @@ INTERRUPT_CALLBACK_METHODS, SESSION_QUERY_DEFAULT_LIMIT, SESSION_QUERY_MAX_LIMIT, + WORKSPACE_CONTROL_METHODS, build_capability_snapshot, build_compatibility_profile_params, build_interrupt_callback_extension_params, @@ -15,6 +16,7 @@ build_session_query_extension_params, build_streaming_extension_params, build_wire_contract_params, + build_workspace_control_extension_params, ) from opencode_a2a.jsonrpc.application import SESSION_CONTEXT_PREFIX from opencode_a2a.profile.runtime import build_runtime_profile @@ -28,6 +30,7 @@ SESSION_QUERY_EXTENSION_URI, STREAMING_EXTENSION_URI, WIRE_CONTRACT_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, build_agent_card, create_app, ) @@ -46,6 +49,7 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: streaming = ext_by_uri[STREAMING_EXTENSION_URI] session_query = ext_by_uri[SESSION_QUERY_EXTENSION_URI] provider_discovery = ext_by_uri[PROVIDER_DISCOVERY_EXTENSION_URI] + workspace_control = ext_by_uri[WORKSPACE_CONTROL_EXTENSION_URI] interrupt_recovery = ext_by_uri[INTERRUPT_RECOVERY_EXTENSION_URI] interrupt_callback = ext_by_uri[INTERRUPT_CALLBACK_EXTENSION_URI] compatibility_profile = ext_by_uri[COMPATIBILITY_PROFILE_EXTENSION_URI] @@ -66,6 +70,9 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: expected_provider_discovery = build_provider_discovery_extension_params( runtime_profile=runtime_profile, ) + expected_workspace_control = build_workspace_control_extension_params( + runtime_profile=runtime_profile, + ) expected_interrupt_recovery = build_interrupt_recovery_extension_params( runtime_profile=runtime_profile, ) @@ -98,6 +105,9 @@ def test_extension_ssot_matches_agent_card_contracts() -> None: assert provider_discovery.params == expected_provider_discovery, ( "Provider discovery extension drifted from contracts.extensions SSOT." ) + assert workspace_control.params == expected_workspace_control, ( + "Workspace control extension drifted from contracts.extensions SSOT." + ) assert interrupt_recovery.params == expected_interrupt_recovery, ( "Interrupt recovery extension drifted from contracts.extensions SSOT." ) @@ -127,6 +137,7 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: streaming = contract["streaming"] session_query = contract["session_query"] provider_discovery = contract["provider_discovery"] + workspace_control = contract["workspace_control"] interrupt_recovery = contract["interrupt_recovery"] interrupt_callback = contract["interrupt_callback"] compatibility_profile = contract["compatibility_profile"] @@ -147,6 +158,9 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: expected_provider_discovery = build_provider_discovery_extension_params( runtime_profile=runtime_profile, ) + expected_workspace_control = build_workspace_control_extension_params( + runtime_profile=runtime_profile, + ) expected_interrupt_recovery = build_interrupt_recovery_extension_params( runtime_profile=runtime_profile, ) @@ -177,6 +191,9 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: assert provider_discovery == expected_provider_discovery, ( "OpenAPI provider discovery contract drifted from contracts.extensions SSOT." ) + assert workspace_control == expected_workspace_control, ( + "OpenAPI workspace control contract drifted from contracts.extensions SSOT." + ) assert interrupt_recovery == expected_interrupt_recovery, ( "OpenAPI interrupt recovery contract drifted from contracts.extensions SSOT." ) @@ -213,6 +230,7 @@ def test_openapi_jsonrpc_contract_extension_matches_ssot() -> None: expected_methods |= { "opencode.providers.list", "opencode.models.list", + *WORKSPACE_CONTROL_METHODS.values(), "opencode.permissions.list", "opencode.questions.list", } @@ -295,6 +313,15 @@ async def test_runtime_supported_methods_align_with_capability_snapshot( ), ("opencode.providers.list", {}, None), ("opencode.models.list", {"provider_id": "openai"}, None), + ("opencode.projects.list", {}, None), + ("opencode.projects.current", {}, None), + ("opencode.workspaces.list", {}, None), + ("opencode.workspaces.create", {"request": {"type": "git"}}, None), + ("opencode.workspaces.remove", {"workspace_id": "wrk-1"}, None), + ("opencode.worktrees.list", {}, None), + ("opencode.worktrees.create", {"request": {"name": "feature-branch"}}, None), + ("opencode.worktrees.remove", {"request": {"directory": "/tmp/worktree"}}, None), + ("opencode.worktrees.reset", {"request": {"directory": "/tmp/worktree"}}, None), ("opencode.permissions.list", {}, None), ("opencode.questions.list", {}, None), ( diff --git a/tests/execution/test_opencode_agent_session_binding.py b/tests/execution/test_opencode_agent_session_binding.py index 418ba3e..189ef05 100644 --- a/tests/execution/test_opencode_agent_session_binding.py +++ b/tests/execution/test_opencode_agent_session_binding.py @@ -109,9 +109,14 @@ async def create_session( title: str | None = None, *, directory: str | None = None, + workspace_id: str | None = None, ) -> str: await asyncio.sleep(0.05) - return await super().create_session(title=title, directory=directory) + return await super().create_session( + title=title, + directory=directory, + workspace_id=workspace_id, + ) client = SlowCreateClient() executor = OpencodeAgentExecutor(client, streaming_enabled=False) @@ -126,6 +131,49 @@ async def run_one(task_id: str) -> None: assert client.created_sessions == 1 +@pytest.mark.asyncio +async def test_agent_passes_opencode_workspace_metadata_to_upstream() -> None: + client = DummyChatOpencodeUpstreamClient() + executor = OpencodeAgentExecutor(client, streaming_enabled=False) + q = DummyEventQueue() + + ctx = make_request_context( + task_id="t-workspace", + context_id="c-workspace", + text="hello", + metadata={"opencode": {"workspace": {"id": "wrk-1"}}}, + ) + await executor.execute(ctx, q) + + assert client.created_workspace_ids == ["wrk-1"] + assert client.sent_workspace_ids == ["wrk-1"] + + +@pytest.mark.asyncio +async def test_agent_scopes_cached_sessions_by_workspace_binding() -> None: + client = DummyChatOpencodeUpstreamClient() + executor = OpencodeAgentExecutor(client, streaming_enabled=False) + + ctx1 = make_request_context( + task_id="t-workspace-1", + context_id="c-shared", + text="hello", + metadata={"opencode": {"workspace": {"id": "wrk-1"}}}, + ) + ctx2 = make_request_context( + task_id="t-workspace-2", + context_id="c-shared", + text="hello again", + metadata={"opencode": {"workspace": {"id": "wrk-2"}}}, + ) + + await executor.execute(ctx1, DummyEventQueue()) + await executor.execute(ctx2, DummyEventQueue()) + + assert client.created_sessions == 2 + assert client.created_workspace_ids == ["wrk-1", "wrk-2"] + + @pytest.mark.asyncio async def test_agent_uses_stable_fallback_message_id_when_upstream_missing_message_id() -> None: class MissingMessageIdClient(DummyChatOpencodeUpstreamClient): diff --git a/tests/jsonrpc/test_dispatch_registry.py b/tests/jsonrpc/test_dispatch_registry.py index 141c6cb..ff07422 100644 --- a/tests/jsonrpc/test_dispatch_registry.py +++ b/tests/jsonrpc/test_dispatch_registry.py @@ -26,6 +26,7 @@ async def test_extension_registry_tracks_configured_methods(monkeypatch) -> None registry_methods = _jsonrpc_app(app)._extension_method_registry.methods() # noqa: SLF001 assert "opencode.sessions.list" in registry_methods assert "opencode.providers.list" in registry_methods + assert "opencode.projects.list" in registry_methods assert "opencode.permissions.list" in registry_methods assert "a2a.interrupt.permission.reply" in registry_methods assert "opencode.sessions.shell" not in registry_methods diff --git a/tests/jsonrpc/test_opencode_session_extension_commands.py b/tests/jsonrpc/test_opencode_session_extension_commands.py index 297e4aa..796c926 100644 --- a/tests/jsonrpc/test_opencode_session_extension_commands.py +++ b/tests/jsonrpc/test_opencode_session_extension_commands.py @@ -58,6 +58,40 @@ async def test_session_command_extension_success(monkeypatch): assert dummy.command_calls[0]["directory"] == "/workspace" +@pytest.mark.asyncio +async def test_session_command_extension_prefers_workspace_metadata(monkeypatch): + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer t-1"}, + json={ + "jsonrpc": "2.0", + "id": 3202, + "method": "opencode.sessions.command", + "params": { + "session_id": "s-1", + "request": {"command": "/review", "arguments": "security"}, + "metadata": {"opencode": {"workspace": {"id": "wrk-1"}}}, + }, + }, + ) + + assert response.status_code == 200 + assert dummy.command_calls[0]["directory"] is None + assert dummy.command_calls[0]["workspace_id"] == "wrk-1" + + @pytest.mark.asyncio async def test_session_command_extension_accepts_request_model(monkeypatch): import opencode_a2a.server.application as app_module diff --git a/tests/jsonrpc/test_opencode_session_extension_interrupts.py b/tests/jsonrpc/test_opencode_session_extension_interrupts.py index 73b494c..0cabc98 100644 --- a/tests/jsonrpc/test_opencode_session_extension_interrupts.py +++ b/tests/jsonrpc/test_opencode_session_extension_interrupts.py @@ -26,6 +26,7 @@ async def permission_reply( reply: str, message: str | None = None, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: self.permission_reply_calls.append( { @@ -33,6 +34,7 @@ async def permission_reply( "reply": reply, "message": message, "directory": directory, + "workspace_id": workspace_id, } ) return True @@ -69,6 +71,7 @@ async def permission_reply( "metadata": { "opencode": { "directory": "/workspace", + "workspace": {"id": "wrk-1"}, } }, }, @@ -83,6 +86,7 @@ async def permission_reply( assert dummy.permission_reply_calls[0]["request_id"] == "perm-1" assert dummy.permission_reply_calls[0]["reply"] == "once" assert dummy.permission_reply_calls[0]["directory"] == "/workspace" + assert dummy.permission_reply_calls[0]["workspace_id"] == "wrk-1" @pytest.mark.asyncio @@ -166,9 +170,15 @@ async def question_reply( *, answers: list[list[str]], directory: str | None = None, + workspace_id: str | None = None, ) -> bool: self.question_reply_calls.append( - {"request_id": request_id, "answers": answers, "directory": directory} + { + "request_id": request_id, + "answers": answers, + "directory": directory, + "workspace_id": workspace_id, + } ) return True @@ -177,8 +187,15 @@ async def question_reject( request_id: str, *, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: - self.question_reject_calls.append({"request_id": request_id, "directory": directory}) + self.question_reject_calls.append( + { + "request_id": request_id, + "directory": directory, + "workspace_id": workspace_id, + } + ) return True dummy = InterruptClient( @@ -262,8 +279,9 @@ async def permission_reply( reply: str, message: str | None = None, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: - del request_id, reply, message, directory + del request_id, reply, message, directory, workspace_id request = httpx.Request("POST", "http://opencode/permission/x/reply") response = httpx.Response(404, request=request) raise httpx.HTTPStatusError("Not Found", request=request, response=response) @@ -346,8 +364,9 @@ async def permission_reply( reply: str, message: str | None = None, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: - del reply, message, directory + del reply, message, directory, workspace_id self.permission_reply_calls.append(request_id) return True @@ -469,8 +488,9 @@ async def permission_reply( reply: str, message: str | None = None, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: - del request_id, reply, message, directory + del request_id, reply, message, directory, workspace_id raise UpstreamConcurrencyLimitError( category="request", operation="/permission/{requestID}/reply", diff --git a/tests/jsonrpc/test_opencode_session_extension_prompt_async.py b/tests/jsonrpc/test_opencode_session_extension_prompt_async.py index a3a454b..c576768 100644 --- a/tests/jsonrpc/test_opencode_session_extension_prompt_async.py +++ b/tests/jsonrpc/test_opencode_session_extension_prompt_async.py @@ -66,6 +66,45 @@ async def test_session_prompt_async_extension_success(monkeypatch): assert dummy.prompt_async_calls[0]["request"]["parts"][0]["text"] == "Continue the task" +@pytest.mark.asyncio +async def test_session_prompt_async_extension_prefers_workspace_metadata(monkeypatch): + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer t-1"}, + json={ + "jsonrpc": "2.0", + "id": 3011, + "method": "opencode.sessions.prompt_async", + "params": { + "session_id": "s-1", + "request": {"parts": [{"type": "text", "text": "Continue the task"}]}, + "metadata": { + "opencode": { + "directory": "/workspace", + "workspace": {"id": "wrk-1"}, + } + }, + }, + }, + ) + + assert response.status_code == 200 + assert dummy.prompt_async_calls[0]["directory"] is None + assert dummy.prompt_async_calls[0]["workspace_id"] == "wrk-1" + + @pytest.mark.asyncio async def test_session_prompt_async_extension_rejects_invalid_params(monkeypatch): import opencode_a2a.server.application as app_module diff --git a/tests/jsonrpc/test_opencode_session_extension_queries.py b/tests/jsonrpc/test_opencode_session_extension_queries.py index d433fab..a1a986a 100644 --- a/tests/jsonrpc/test_opencode_session_extension_queries.py +++ b/tests/jsonrpc/test_opencode_session_extension_queries.py @@ -195,6 +195,65 @@ async def test_session_query_extension_supports_session_filters_and_message_curs assert dummy.last_messages_params == {"before": "cursor-1", "limit": 5} +@pytest.mark.asyncio +async def test_session_query_extension_prefers_workspace_metadata_for_routing(monkeypatch): + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings( + a2a_bearer_token="t-1", + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings( + a2a_bearer_token="t-1", + a2a_log_payloads=False, + opencode_workspace_root="/workspace", + **_BASE_SETTINGS, + ) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 12, + "method": "opencode.sessions.list", + "params": { + "directory": "services/api", + "limit": 2, + "metadata": {"opencode": {"workspace": {"id": "wrk-1"}}}, + }, + }, + ) + await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 13, + "method": "opencode.sessions.messages.list", + "params": { + "session_id": "s-1", + "limit": 2, + "metadata": {"opencode": {"workspace": {"id": "wrk-1"}}}, + }, + }, + ) + + assert dummy.last_sessions_workspace_id == "wrk-1" + assert dummy.last_sessions_directory is None + assert dummy.last_messages_workspace_id == "wrk-1" + + @pytest.mark.asyncio async def test_session_query_extension_rejects_directory_outside_workspace(monkeypatch): import opencode_a2a.server.application as app_module @@ -368,7 +427,7 @@ async def test_provider_discovery_extension_returns_normalized_catalog(monkeypat "jsonrpc": "2.0", "id": 11, "method": "opencode.providers.list", - "params": {}, + "params": {"metadata": {"opencode": {"workspace": {"id": "wrk-1"}}}}, }, ) assert providers_resp.status_code == 200 @@ -377,6 +436,7 @@ async def test_provider_discovery_extension_returns_normalized_catalog(monkeypat assert providers_payload["connected"] == ["openai"] assert providers_payload["items"][0]["provider_id"] == "openai" assert providers_payload["items"][0]["default_model_id"] == "gpt-5" + assert dummy.workspace_control_calls[0]["workspace_id"] == "wrk-1" models_resp = await client.post( "/", diff --git a/tests/jsonrpc/test_opencode_workspace_control_extension.py b/tests/jsonrpc/test_opencode_workspace_control_extension.py new file mode 100644 index 0000000..112531b --- /dev/null +++ b/tests/jsonrpc/test_opencode_workspace_control_extension.py @@ -0,0 +1,199 @@ +import httpx +import pytest + +from tests.support.helpers import ( + DummySessionQueryOpencodeUpstreamClient as DummyOpencodeUpstreamClient, +) +from tests.support.helpers import make_settings +from tests.support.session_extensions import _BASE_SETTINGS + + +@pytest.mark.asyncio +async def test_workspace_control_extension_supports_read_only_methods(monkeypatch) -> None: + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + projects = await client.post( + "/", + headers=headers, + json={"jsonrpc": "2.0", "id": 1, "method": "opencode.projects.list", "params": {}}, + ) + current = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 2, + "method": "opencode.projects.current", + "params": {}, + }, + ) + workspaces = await client.post( + "/", + headers=headers, + json={"jsonrpc": "2.0", "id": 3, "method": "opencode.workspaces.list", "params": {}}, + ) + worktrees = await client.post( + "/", + headers=headers, + json={"jsonrpc": "2.0", "id": 4, "method": "opencode.worktrees.list", "params": {}}, + ) + + assert projects.status_code == 200 + assert projects.json()["result"]["items"][0]["id"] == "proj-1" + assert current.status_code == 200 + assert current.json()["result"]["item"]["id"] == "proj-1" + assert workspaces.status_code == 200 + assert workspaces.json()["result"]["items"][0]["id"] == "wrk-1" + assert worktrees.status_code == 200 + assert worktrees.json()["result"]["items"] == ["/tmp/worktrees/alpha"] + + +@pytest.mark.asyncio +async def test_workspace_control_extension_supports_mutating_methods(monkeypatch) -> None: + import opencode_a2a.server.application as app_module + + dummy = DummyOpencodeUpstreamClient( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", lambda _settings: dummy) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": "Bearer t-1"} + create_workspace = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 10, + "method": "opencode.workspaces.create", + "params": {"request": {"type": "git", "branch": "main"}}, + }, + ) + remove_workspace = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 11, + "method": "opencode.workspaces.remove", + "params": {"workspace_id": "wrk-1"}, + }, + ) + create_worktree = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 12, + "method": "opencode.worktrees.create", + "params": {"request": {"name": "feature-branch"}}, + }, + ) + remove_worktree = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 13, + "method": "opencode.worktrees.remove", + "params": {"request": {"directory": "/tmp/worktrees/feature-branch"}}, + }, + ) + reset_worktree = await client.post( + "/", + headers=headers, + json={ + "jsonrpc": "2.0", + "id": 14, + "method": "opencode.worktrees.reset", + "params": {"request": {"directory": "/tmp/worktrees/feature-branch"}}, + }, + ) + + assert create_workspace.status_code == 200 + assert create_workspace.json()["result"]["item"]["type"] == "git" + assert remove_workspace.status_code == 200 + assert remove_workspace.json()["result"]["item"]["id"] == "wrk-1" + assert create_worktree.status_code == 200 + assert create_worktree.json()["result"]["item"]["directory"] == "/tmp/worktrees/feature-branch" + assert remove_worktree.status_code == 200 + assert remove_worktree.json()["result"] == {"ok": True} + assert reset_worktree.status_code == 200 + assert reset_worktree.json()["result"] == {"ok": True} + + +@pytest.mark.asyncio +async def test_workspace_control_extension_validates_request_shape(monkeypatch) -> None: + import opencode_a2a.server.application as app_module + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", DummyOpencodeUpstreamClient) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer t-1"}, + json={ + "jsonrpc": "2.0", + "id": 20, + "method": "opencode.workspaces.create", + "params": {"request": {"branch": "main"}}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["error"]["code"] == -32602 + assert payload["error"]["data"]["field"] == "request" + + +@pytest.mark.asyncio +async def test_workspace_control_extension_maps_upstream_http_error(monkeypatch) -> None: + import opencode_a2a.server.application as app_module + + class UpstreamErrorClient(DummyOpencodeUpstreamClient): + async def list_workspaces(self): + request = httpx.Request("GET", "http://test/experimental/workspace") + response = httpx.Response(503, request=request) + raise httpx.HTTPStatusError("boom", request=request, response=response) + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", UpstreamErrorClient) + app = app_module.create_app( + make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) + ) + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer t-1"}, + json={ + "jsonrpc": "2.0", + "id": 21, + "method": "opencode.workspaces.list", + "params": {}, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["error"]["data"]["type"] == "UPSTREAM_HTTP_ERROR" + assert payload["error"]["data"]["upstream_status"] == 503 diff --git a/tests/profile/test_profile_runtime.py b/tests/profile/test_profile_runtime.py index 775f176..238f08a 100644 --- a/tests/profile/test_profile_runtime.py +++ b/tests/profile/test_profile_runtime.py @@ -39,6 +39,12 @@ def test_profile_runtime_splits_deployment_runtime_features_and_health_payload() "scope": "workspace_root_only", "metadata_field": "metadata.opencode.directory", }, + "workspace_binding": { + "enabled": True, + "metadata_field": "metadata.opencode.workspace.id", + "upstream_query_param": "workspace", + "precedence": "prefer_workspace_else_directory", + }, "session_shell": { "enabled": False, "availability": "disabled", diff --git a/tests/server/test_agent_card.py b/tests/server/test_agent_card.py index 32d5a46..9152e77 100644 --- a/tests/server/test_agent_card.py +++ b/tests/server/test_agent_card.py @@ -14,6 +14,7 @@ SESSION_QUERY_EXTENSION_URI, STREAMING_EXTENSION_URI, WIRE_CONTRACT_EXTENSION_URI, + WORKSPACE_CONTROL_EXTENSION_URI, build_agent_card, ) from tests.support.helpers import make_settings @@ -63,8 +64,12 @@ def test_agent_card_injects_profile_into_extensions() -> None: assert binding.params["supported_metadata"] == [ "shared.session.id", "opencode.directory", + "opencode.workspace.id", + ] + assert binding.params["provider_private_metadata"] == [ + "opencode.directory", + "opencode.workspace.id", ] - assert binding.params["provider_private_metadata"] == ["opencode.directory"] assert profile["profile_id"] == "opencode-a2a-single-tenant-coding-v1" assert profile["deployment"] == { "id": "single_tenant_shared_workspace", @@ -83,6 +88,12 @@ def test_agent_card_injects_profile_into_extensions() -> None: "scope": "workspace_root_only", "metadata_field": "metadata.opencode.directory", } + assert profile["runtime_features"]["workspace_binding"] == { + "enabled": True, + "metadata_field": "metadata.opencode.workspace.id", + "upstream_query_param": "workspace", + "precedence": "prefer_workspace_else_directory", + } assert profile["runtime_features"]["execution_environment"] == { "sandbox": { "mode": "workspace-write", @@ -205,6 +216,7 @@ def test_agent_card_injects_profile_into_extensions() -> None: assert list_contract["params"]["optional"] == [ "limit", "directory", + "metadata.opencode.workspace.id", "roots", "start", "search", @@ -217,6 +229,7 @@ def test_agent_card_injects_profile_into_extensions() -> None: assert messages_contract["params"]["optional"] == [ "limit", "before", + "metadata.opencode.workspace.id", "query.limit", "query.before", ] @@ -261,6 +274,10 @@ def test_agent_card_injects_profile_into_extensions() -> None: "list_providers": "opencode.providers.list", "list_models": "opencode.models.list", } + assert provider_discovery.params["supported_metadata"] == [ + "opencode.directory", + "opencode.workspace.id", + ] assert "result_envelope" not in provider_discovery.params assert provider_discovery.params["method_contracts"]["opencode.providers.list"]["result"] == { "fields": ["items", "default_by_provider", "connected"], @@ -279,6 +296,35 @@ def test_agent_card_injects_profile_into_extensions() -> None: "UPSTREAM_PAYLOAD_ERROR": -32005, } + workspace_control = ext_by_uri[WORKSPACE_CONTROL_EXTENSION_URI] + assert workspace_control.params["profile"]["runtime_context"]["project"] == "alpha" + assert workspace_control.params["methods"] == { + "list_projects": "opencode.projects.list", + "get_current_project": "opencode.projects.current", + "list_workspaces": "opencode.workspaces.list", + "create_workspace": "opencode.workspaces.create", + "remove_workspace": "opencode.workspaces.remove", + "list_worktrees": "opencode.worktrees.list", + "create_worktree": "opencode.worktrees.create", + "remove_worktree": "opencode.worktrees.remove", + "reset_worktree": "opencode.worktrees.reset", + } + assert workspace_control.params["routing_fields"]["workspace_id"] == ( + "metadata.opencode.workspace.id" + ) + assert workspace_control.params["method_contracts"]["opencode.projects.list"]["result"] == { + "fields": ["items"], + "items_type": "Project[]", + } + assert workspace_control.params["method_contracts"]["opencode.workspaces.create"]["params"] == { + "required": ["request.type"], + "optional": ["request.id", "request.branch", "request.extra"], + } + assert workspace_control.params["method_contracts"]["opencode.worktrees.reset"]["result"] == { + "fields": ["ok"], + "items_type": "boolean", + } + interrupt_recovery = ext_by_uri[INTERRUPT_RECOVERY_EXTENSION_URI] assert interrupt_recovery.params["profile"]["runtime_context"]["project"] == "alpha" assert interrupt_recovery.params["methods"] == { @@ -299,9 +345,16 @@ def test_agent_card_injects_profile_into_extensions() -> None: interrupt = ext_by_uri[INTERRUPT_CALLBACK_EXTENSION_URI] assert interrupt.params["profile"]["runtime_context"]["project"] == "alpha" assert interrupt.params["request_id_field"] == "metadata.shared.interrupt.request_id" - assert interrupt.params["supported_metadata"] == ["opencode.directory"] - assert interrupt.params["provider_private_metadata"] == ["opencode.directory"] + assert interrupt.params["supported_metadata"] == [ + "opencode.directory", + "opencode.workspace.id", + ] + assert interrupt.params["provider_private_metadata"] == [ + "opencode.directory", + "opencode.workspace.id", + ] assert interrupt.params["context_fields"]["directory"] == "metadata.opencode.directory" + assert interrupt.params["context_fields"]["workspace_id"] == "metadata.opencode.workspace.id" assert interrupt.params["errors"]["business_codes"] == { "INTERRUPT_REQUEST_NOT_FOUND": -32004, "INTERRUPT_REQUEST_EXPIRED": -32007, @@ -381,6 +434,7 @@ def test_agent_card_injects_profile_into_extensions() -> None: assert wire_contract.params["profile"]["profile_id"] == "opencode-a2a-single-tenant-coding-v1" assert MODEL_SELECTION_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert PROVIDER_DISCOVERY_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] + assert WORKSPACE_CONTROL_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert INTERRUPT_RECOVERY_EXTENSION_URI in wire_contract.params["extensions"]["extension_uris"] assert "opencode.sessions.shell" not in wire_contract.params["all_jsonrpc_methods"] assert wire_contract.params["service_behaviors"] == expected_service_behaviors @@ -435,12 +489,16 @@ def test_agent_card_skills_hide_shell_when_disabled_by_default() -> None: session_skill = next(skill for skill in card.skills if skill.id == "opencode.sessions.query") provider_skill = next(skill for skill in card.skills if skill.id == "opencode.providers.query") + workspace_skill = next( + skill for skill in card.skills if skill.id == "opencode.workspace.control" + ) assert "provider-private" in session_skill.tags assert "provider-private" in session_skill.description assert all("opencode.sessions.shell" not in example for example in session_skill.examples) assert "provider-private" in provider_skill.tags assert any("opencode.providers.list" in example for example in provider_skill.examples) + assert any("opencode.projects.list" in example for example in workspace_skill.examples) interrupt_recovery_skill = next( skill for skill in card.skills if skill.id == "opencode.interrupt.recovery" ) diff --git a/tests/server/test_app_behaviors.py b/tests/server/test_app_behaviors.py index 06c9385..f441cfb 100644 --- a/tests/server/test_app_behaviors.py +++ b/tests/server/test_app_behaviors.py @@ -153,6 +153,12 @@ def test_agent_card_helper_builders_cover_optional_branches() -> None: "scope": "workspace_root_only", "metadata_field": "metadata.opencode.directory", }, + "workspace_binding": { + "enabled": True, + "metadata_field": "metadata.opencode.workspace.id", + "upstream_query_param": "workspace", + "precedence": "prefer_workspace_else_directory", + }, "session_shell": { "enabled": True, "availability": "enabled", @@ -265,6 +271,12 @@ async def close(self) -> None: "scope": "workspace_root_or_descendant", "metadata_field": "metadata.opencode.directory", }, + "workspace_binding": { + "enabled": True, + "metadata_field": "metadata.opencode.workspace.id", + "upstream_query_param": "workspace", + "precedence": "prefer_workspace_else_directory", + }, "session_shell": { "enabled": True, "availability": "enabled", diff --git a/tests/support/helpers.py b/tests/support/helpers.py index 3be2e19..38c0b78 100644 --- a/tests/support/helpers.py +++ b/tests/support/helpers.py @@ -125,6 +125,8 @@ def __init__(self, settings: Settings | None = None) -> None: self.created_sessions = 0 self.sent_session_ids: list[str] = [] self.sent_model_overrides: list[dict[str, str] | None] = [] + self.sent_workspace_ids: list[str | None] = [] + self.created_workspace_ids: list[str | None] = [] self.stream_timeout = None self.directory = None self.settings = settings or make_settings( @@ -140,9 +142,11 @@ async def create_session( title: str | None = None, *, directory: str | None = None, + workspace_id: str | None = None, ) -> str: del title, directory self.created_sessions += 1 + self.created_workspace_ids.append(workspace_id) return f"ses-created-{self.created_sessions}" async def send_message( @@ -152,12 +156,14 @@ async def send_message( *, parts: list[dict[str, Any]] | None = None, directory: str | None = None, + workspace_id: str | None = None, model_override: dict[str, str] | None = None, timeout_override=None, # noqa: ANN001 ) -> OpencodeMessage: del directory, timeout_override, parts self.sent_session_ids.append(session_id) self.sent_model_overrides.append(model_override) + self.sent_workspace_ids.append(workspace_id) return OpencodeMessage( text=f"echo:{text or ''}", session_id=session_id, @@ -165,8 +171,14 @@ async def send_message( raw={}, ) - async def stream_events(self, stop_event=None, *, directory: str | None = None): # noqa: ANN001 - del stop_event, directory + async def stream_events( # noqa: ANN001 + self, + stop_event=None, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + del stop_event, directory, workspace_id for _ in (): yield {} @@ -215,10 +227,13 @@ def __init__(self, _settings: Settings) -> None: self._messages_next_cursor: str | None = None self.last_sessions_params = None self.last_sessions_directory: str | None = None + self.last_sessions_workspace_id: str | None = None self.last_messages_params = None + self.last_messages_workspace_id: str | None = None self.prompt_async_calls: list[dict[str, Any]] = [] self.command_calls: list[dict[str, Any]] = [] self.shell_calls: list[dict[str, Any]] = [] + self.workspace_control_calls: list[dict[str, Any]] = [] self.provider_catalog_payload: dict[str, Any] = { "all": [ { @@ -268,14 +283,22 @@ def __init__(self, _settings: Settings) -> None: async def close(self) -> None: return None - async def list_sessions(self, *, params=None, directory: str | None = None): + async def list_sessions( + self, + *, + params=None, + directory: str | None = None, + workspace_id: str | None = None, + ): self.last_sessions_directory = directory + self.last_sessions_workspace_id = workspace_id self.last_sessions_params = params return self._sessions_payload - async def list_messages(self, session_id: str, *, params=None): + async def list_messages(self, session_id: str, *, params=None, workspace_id: str | None = None): assert session_id self.last_messages_params = params + self.last_messages_workspace_id = workspace_id return OpencodeMessagePage( payload=self._messages_payload, next_cursor=self._messages_next_cursor, @@ -287,12 +310,14 @@ async def session_prompt_async( request: dict[str, Any], *, directory: str | None = None, + workspace_id: str | None = None, ) -> None: self.prompt_async_calls.append( { "session_id": session_id, "request": request, "directory": directory, + "workspace_id": workspace_id, } ) @@ -302,12 +327,14 @@ async def session_command( request: dict[str, Any], *, directory: str | None = None, + workspace_id: str | None = None, ) -> dict[str, Any]: self.command_calls.append( { "session_id": session_id, "request": request, "directory": directory, + "workspace_id": workspace_id, } ) return { @@ -321,12 +348,14 @@ async def session_shell( request: dict[str, Any], *, directory: str | None = None, + workspace_id: str | None = None, ) -> dict[str, Any]: self.shell_calls.append( { "session_id": session_id, "request": request, "directory": directory, + "workspace_id": workspace_id, } ) return { @@ -335,10 +364,63 @@ async def session_shell( "parts": [{"type": "text", "text": "Shell command executed."}], } - async def list_provider_catalog(self, *, directory: str | None = None): - del directory + async def list_provider_catalog( + self, + *, + directory: str | None = None, + workspace_id: str | None = None, + ): + self.workspace_control_calls.append( + { + "method": "provider_catalog", + "directory": directory, + "workspace_id": workspace_id, + } + ) return self.provider_catalog_payload + async def list_projects(self): + self.workspace_control_calls.append({"method": "list_projects"}) + return [{"id": "proj-1", "name": "Alpha", "directory": "/workspace"}] + + async def get_current_project(self): + self.workspace_control_calls.append({"method": "get_current_project"}) + return {"id": "proj-1", "name": "Alpha", "directory": "/workspace"} + + async def list_workspaces(self): + self.workspace_control_calls.append({"method": "list_workspaces"}) + return [{"id": "wrk-1", "type": "git", "branch": "main", "directory": None}] + + async def create_workspace(self, request: dict[str, Any]): + self.workspace_control_calls.append({"method": "create_workspace", "request": request}) + return {"id": "wrk-2", **request} + + async def remove_workspace(self, workspace_id: str): + self.workspace_control_calls.append( + {"method": "remove_workspace", "workspace_id": workspace_id} + ) + return {"id": workspace_id, "type": "git", "branch": "main", "directory": None} + + async def list_worktrees(self): + self.workspace_control_calls.append({"method": "list_worktrees"}) + return ["/tmp/worktrees/alpha"] + + async def create_worktree(self, request: dict[str, Any]): + self.workspace_control_calls.append({"method": "create_worktree", "request": request}) + return { + "name": request.get("name") or "feature-branch", + "branch": "opencode/feature-branch", + "directory": "/tmp/worktrees/feature-branch", + } + + async def remove_worktree(self, request: dict[str, Any]) -> bool: + self.workspace_control_calls.append({"method": "remove_worktree", "request": request}) + return True + + async def reset_worktree(self, request: dict[str, Any]) -> bool: + self.workspace_control_calls.append({"method": "reset_worktree", "request": request}) + return True + async def remember_interrupt_request( self, *, @@ -443,8 +525,9 @@ async def permission_reply( reply: str, message: str | None = None, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: - del request_id, reply, message, directory + del request_id, reply, message, directory, workspace_id return True async def question_reply( @@ -453,8 +536,9 @@ async def question_reply( *, answers: list[list[str]], directory: str | None = None, + workspace_id: str | None = None, ) -> bool: - del request_id, answers, directory + del request_id, answers, directory, workspace_id return True async def question_reject( @@ -462,6 +546,7 @@ async def question_reject( request_id: str, *, directory: str | None = None, + workspace_id: str | None = None, ) -> bool: - del request_id, directory + del request_id, directory, workspace_id return True diff --git a/tests/upstream/test_opencode_upstream_client_params.py b/tests/upstream/test_opencode_upstream_client_params.py index c860691..29e2154 100644 --- a/tests/upstream/test_opencode_upstream_client_params.py +++ b/tests/upstream/test_opencode_upstream_client_params.py @@ -814,10 +814,16 @@ def test_merge_params_keeps_empty_directory_out_of_query() -> None: ) assert client._query_params() == {} + assert client._query_params(workspace_id="wrk-1") == {"workspace": "wrk-1"} assert client._merge_params({"limit": 5, "enabled": False}, directory=None) == { "limit": "5", "enabled": "False", } + assert client._merge_params( + {"limit": 5, "workspace": "ignored"}, + directory="/safe", + workspace_id="wrk-1", + ) == {"workspace": "wrk-1", "limit": "5"} @pytest.mark.asyncio @@ -895,6 +901,35 @@ async def fake_get(path: str, *, params=None, **_kwargs): await client.close() +@pytest.mark.asyncio +async def test_list_provider_catalog_prefers_workspace_query(monkeypatch) -> None: + client = OpencodeUpstreamClient( + make_settings( + a2a_bearer_token="t-1", + opencode_workspace_root="/safe", + opencode_timeout=1.0, + a2a_log_level="DEBUG", + a2a_log_payloads=False, + ) + ) + + seen = {} + + async def fake_get(path: str, *, params=None, **_kwargs): + seen["path"] = path + seen["params"] = params + return _DummyResponse({"all": [], "default": {}, "connected": []}) + + monkeypatch.setattr(client._client, "get", fake_get) + + await client.list_provider_catalog(directory="/safe/nested", workspace_id="wrk-1") + + assert seen["path"] == "/provider" + assert seen["params"] == {"workspace": "wrk-1"} + + await client.close() + + @pytest.mark.asyncio async def test_send_message_requires_text_or_parts() -> None: client = OpencodeUpstreamClient(