From c0077a0cd6b6987ce25a779ed904222105697baf Mon Sep 17 00:00:00 2001
From: haosenwang1018
Date: Thu, 12 Mar 2026 04:33:56 +0000
Subject: [PATCH] fix(python): honor auto_restart on process exit

---
Review notes (text after the "---" marker is not part of the commit message):
- Indentation of the context lines below was reconstructed from Python syntax
  and the hunk line counts; confirm against blob df09a755b before applying.
- NOTE(review): _request_with_auto_restart replays any request once after
  ProcessExitedError, including session.create and session.delete; confirm
  that replaying these against a freshly started CLI is intended.
- NOTE(review): client.py gets no import hunk, so ProcessExitedError, Any and
  asyncio are assumed to be imported there already; verify.

 python/copilot/client.py | 82 ++++++++++++++++++++++++++++++++--------
 python/test_client.py    | 52 +++++++++++++++++++++++++
 2 files changed, 119 insertions(+), 15 deletions(-)

diff --git a/python/copilot/client.py b/python/copilot/client.py
index df09a755b..f0be85e9b 100644
--- a/python/copilot/client.py
+++ b/python/copilot/client.py
@@ -75,6 +75,16 @@ def _get_bundled_cli_path() -> str | None:
     return None
 
 
+class _AutoRestartRequestProxy:
+    """Stable request facade that can retry once after reconnecting."""
+
+    def __init__(self, owner: "CopilotClient"):
+        self._owner = owner
+
+    async def request(self, method: str, params: dict | None = None, **kwargs: Any) -> Any:
+        return await self._owner._request_with_auto_restart(method, params, **kwargs)
+
+
 class CopilotClient:
     """
     Main client for interacting with the Copilot CLI.
@@ -216,6 +226,8 @@ def __init__(self, options: CopilotClientOptions | None = None):
         ] = {}
         self._lifecycle_handlers_lock = threading.Lock()
         self._rpc: ServerRpc | None = None
+        self._request_proxy = _AutoRestartRequestProxy(self)
+        self._reconnect_lock = asyncio.Lock()
         self._negotiated_protocol_version: int | None = None
 
     @property
@@ -279,6 +291,44 @@ def _parse_cli_url(self, url: str) -> tuple[str, int]:
 
         return (host, port)
 
+    async def _request_with_auto_restart(
+        self, method: str, params: dict | None = None, **kwargs: Any
+    ) -> Any:
+        """Send an RPC request, reconnecting and retrying once after process exit."""
+        if not self._client:
+            raise RuntimeError("Client not connected")
+
+        client = self._client
+        try:
+            return await client.request(method, params, **kwargs)
+        except ProcessExitedError:
+            if not self.options.get("auto_restart", True):
+                raise
+            await self._reconnect(client)
+            if not self._client:
+                raise RuntimeError("Client not connected")
+            return await self._client.request(method, params, **kwargs)
+
+    async def _reconnect(self, failed_client: JsonRpcClient | None = None) -> None:
+        """Reconnect the transport while preserving session objects."""
+        async with self._reconnect_lock:
+            if (
+                failed_client is not None
+                and self._client is not failed_client
+                and self._state == "connected"
+            ):
+                return
+
+            with self._sessions_lock:
+                saved_sessions = dict(self._sessions)
+
+            await self.force_stop()
+
+            with self._sessions_lock:
+                self._sessions = saved_sessions
+
+            await self.start()
+
     async def start(self) -> None:
         """
         Start the CLI server and establish a connection.
@@ -614,7 +664,7 @@ async def create_session(self, config: SessionConfig) -> CopilotSession:
 
         # Create and register the session before issuing the RPC so that
         # events emitted by the CLI (e.g. session.start) are not dropped.
-        session = CopilotSession(session_id, self._client, None)
+        session = CopilotSession(session_id, self._request_proxy, None)
         session._register_tools(tools)
         session._register_permission_handler(on_permission_request)
         if on_user_input_request:
@@ -628,7 +678,7 @@ async def create_session(self, config: SessionConfig) -> CopilotSession:
             self._sessions[session_id] = session
 
         try:
-            response = await self._client.request("session.create", payload)
+            response = await self._request_proxy.request("session.create", payload)
             session._workspace_path = response.get("workspacePath")
         except BaseException:
             with self._sessions_lock:
@@ -813,7 +863,7 @@ async def resume_session(self, session_id: str, config: ResumeSessionConfig) ->
 
         # Create and register the session before issuing the RPC so that
         # events emitted by the CLI (e.g. session.start) are not dropped.
-        session = CopilotSession(session_id, self._client, None)
+        session = CopilotSession(session_id, self._request_proxy, None)
         session._register_tools(cfg.get("tools"))
         session._register_permission_handler(on_permission_request)
         if on_user_input_request:
@@ -827,7 +877,7 @@ async def resume_session(self, session_id: str, config: ResumeSessionConfig) ->
             self._sessions[session_id] = session
 
         try:
-            response = await self._client.request("session.resume", payload)
+            response = await self._request_proxy.request("session.resume", payload)
             session._workspace_path = response.get("workspacePath")
         except BaseException:
             with self._sessions_lock:
@@ -870,7 +920,7 @@ async def ping(self, message: str | None = None) -> "PingResponse":
         if not self._client:
             raise RuntimeError("Client not connected")
 
-        result = await self._client.request("ping", {"message": message})
+        result = await self._request_proxy.request("ping", {"message": message})
         return PingResponse.from_dict(result)
 
     async def get_status(self) -> "GetStatusResponse":
@@ -890,7 +940,7 @@ async def get_status(self) -> "GetStatusResponse":
         if not self._client:
             raise RuntimeError("Client not connected")
 
-        result = await self._client.request("status.get", {})
+        result = await self._request_proxy.request("status.get", {})
         return GetStatusResponse.from_dict(result)
 
     async def get_auth_status(self) -> "GetAuthStatusResponse":
@@ -911,7 +961,7 @@ async def get_auth_status(self) -> "GetAuthStatusResponse":
         if not self._client:
             raise RuntimeError("Client not connected")
 
-        result = await self._client.request("auth.getStatus", {})
+        result = await self._request_proxy.request("auth.getStatus", {})
         return GetAuthStatusResponse.from_dict(result)
 
     async def list_models(self) -> list["ModelInfo"]:
@@ -955,7 +1005,7 @@ async def list_models(self) -> list["ModelInfo"]:
                     raise RuntimeError("Client not connected")
 
                 # Cache miss - fetch from backend while holding lock
-                response = await self._client.request("models.list", {})
+                response = await self._request_proxy.request("models.list", {})
                 models_data = response.get("models", [])
                 models = [ModelInfo.from_dict(model) for model in models_data]
 
@@ -997,7 +1047,7 @@ async def list_sessions(
         if filter is not None:
             payload["filter"] = filter.to_dict()
 
-        response = await self._client.request("session.list", payload)
+        response = await self._request_proxy.request("session.list", payload)
         sessions_data = response.get("sessions", [])
         return [SessionMetadata.from_dict(session) for session in sessions_data]
 
@@ -1022,7 +1072,7 @@ async def delete_session(self, session_id: str) -> None:
         if not self._client:
             raise RuntimeError("Client not connected")
 
-        response = await self._client.request("session.delete", {"sessionId": session_id})
+        response = await self._request_proxy.request("session.delete", {"sessionId": session_id})
 
         success = response.get("success", False)
         if not success:
@@ -1056,7 +1106,7 @@ async def get_last_session_id(self) -> str | None:
         if not self._client:
             raise RuntimeError("Client not connected")
 
-        response = await self._client.request("session.getLastId", {})
+        response = await self._request_proxy.request("session.getLastId", {})
         return response.get("sessionId")
 
     async def get_foreground_session_id(self) -> str | None:
@@ -1080,7 +1130,7 @@ async def get_foreground_session_id(self) -> str | None:
         if not self._client:
             raise RuntimeError("Client not connected")
 
-        response = await self._client.request("session.getForeground", {})
+        response = await self._request_proxy.request("session.getForeground", {})
         return response.get("sessionId")
 
     async def set_foreground_session_id(self, session_id: str) -> None:
@@ -1102,7 +1152,9 @@ async def set_foreground_session_id(self, session_id: str) -> None:
         if not self._client:
             raise RuntimeError("Client not connected")
 
-        response = await self._client.request("session.setForeground", {"sessionId": session_id})
+        response = await self._request_proxy.request(
+            "session.setForeground", {"sessionId": session_id}
+        )
 
         success = response.get("success", False)
         if not success:
@@ -1406,7 +1458,7 @@ async def _connect_via_stdio(self) -> None:
 
         # Create JSON-RPC client with the process
         self._client = JsonRpcClient(self._process)
-        self._rpc = ServerRpc(self._client)
+        self._rpc = ServerRpc(self._request_proxy)
 
         # Set up notification handler for session events
         # Note: This handler is called from the event loop (thread-safe scheduling)
@@ -1493,7 +1545,7 @@ def wait(self, timeout=None):
 
         self._process = SocketWrapper(sock_file, sock)  # type: ignore
         self._client = JsonRpcClient(self._process)
-        self._rpc = ServerRpc(self._client)
+        self._rpc = ServerRpc(self._request_proxy)
 
         # Set up notification handler for session events
         def handle_notification(method: str, params: dict):
diff --git a/python/test_client.py b/python/test_client.py
index 4a06966d4..3e7111cd9 100644
--- a/python/test_client.py
+++ b/python/test_client.py
@@ -4,9 +4,13 @@
 This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.py instead.
 """
 
+from types import SimpleNamespace
+from unittest.mock import AsyncMock
+
 import pytest
 
 from copilot import CopilotClient, PermissionHandler, define_tool
+from copilot.jsonrpc import ProcessExitedError
 from copilot.types import ModelCapabilities, ModelInfo, ModelLimits, ModelSupports
 from e2e.testharness import CLI_PATH
 
@@ -151,6 +155,54 @@ def test_use_logged_in_user_with_cli_url_raises(self):
             )
 
 
+class TestAutoRestart:
+    @pytest.mark.asyncio
+    async def test_request_proxy_retries_once_after_process_exit(self):
+        client = CopilotClient(
+            {"cli_path": CLI_PATH, "auto_restart": True, "log_level": "error"}
+        )
+        client._state = "connected"
+        failed_request = AsyncMock(side_effect=ProcessExitedError("boom"))
+        client._client = SimpleNamespace(request=failed_request)
+
+        replacement_request = AsyncMock(
+            return_value={
+                "message": "pong: health check",
+                "timestamp": 123,
+                "protocolVersion": 2,
+            }
+        )
+        replacement_client = SimpleNamespace(request=replacement_request)
+        reconnect = AsyncMock(
+            side_effect=lambda failed_client=None: setattr(client, "_client", replacement_client)
+        )
+        client._reconnect = reconnect
+
+        response = await client.ping("health check")
+
+        assert response.message == "pong: health check"
+        assert response.timestamp == 123
+        reconnect.assert_awaited_once()
+        failed_request.assert_awaited_once_with("ping", {"message": "health check"})
+        replacement_request.assert_awaited_once_with("ping", {"message": "health check"})
+
+    @pytest.mark.asyncio
+    async def test_request_proxy_propagates_process_exit_when_auto_restart_disabled(self):
+        client = CopilotClient(
+            {"cli_path": CLI_PATH, "auto_restart": False, "log_level": "error"}
+        )
+        client._state = "connected"
+        client._client = SimpleNamespace(
+            request=AsyncMock(side_effect=ProcessExitedError("boom"))
+        )
+        client._reconnect = AsyncMock()
+
+        with pytest.raises(ProcessExitedError, match="boom"):
+            await client.ping()
+
+        client._reconnect.assert_not_awaited()
+
+
 class TestOverridesBuiltInTool:
     @pytest.mark.asyncio
     async def test_overrides_built_in_tool_sent_in_tool_definition(self):