Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 67 additions & 15 deletions python/copilot/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ def _get_bundled_cli_path() -> str | None:
return None


class _AutoRestartRequestProxy:
"""Stable request facade that can retry once after reconnecting."""

def __init__(self, owner: "CopilotClient"):
self._owner = owner

async def request(self, method: str, params: dict | None = None, **kwargs: Any) -> Any:
return await self._owner._request_with_auto_restart(method, params, **kwargs)


class CopilotClient:
"""
Main client for interacting with the Copilot CLI.
Expand Down Expand Up @@ -216,6 +226,8 @@ def __init__(self, options: CopilotClientOptions | None = None):
] = {}
self._lifecycle_handlers_lock = threading.Lock()
self._rpc: ServerRpc | None = None
self._request_proxy = _AutoRestartRequestProxy(self)
self._reconnect_lock = asyncio.Lock()
self._negotiated_protocol_version: int | None = None

@property
Expand Down Expand Up @@ -279,6 +291,44 @@ def _parse_cli_url(self, url: str) -> tuple[str, int]:

return (host, port)

async def _request_with_auto_restart(
self, method: str, params: dict | None = None, **kwargs: Any
) -> Any:
"""Send an RPC request, reconnecting and retrying once after process exit."""
if not self._client:
raise RuntimeError("Client not connected")

client = self._client
try:
return await client.request(method, params, **kwargs)
except ProcessExitedError:
if not self.options.get("auto_restart", True):
raise
await self._reconnect(client)
if not self._client:
raise RuntimeError("Client not connected")
return await self._client.request(method, params, **kwargs)

async def _reconnect(self, failed_client: JsonRpcClient | None = None) -> None:
"""Reconnect the transport while preserving session objects."""
async with self._reconnect_lock:
if (
failed_client is not None
and self._client is not failed_client
and self._state == "connected"
):
return

with self._sessions_lock:
saved_sessions = dict(self._sessions)

await self.force_stop()

with self._sessions_lock:
self._sessions = saved_sessions

await self.start()

async def start(self) -> None:
"""
Start the CLI server and establish a connection.
Expand Down Expand Up @@ -614,7 +664,7 @@ async def create_session(self, config: SessionConfig) -> CopilotSession:

# Create and register the session before issuing the RPC so that
# events emitted by the CLI (e.g. session.start) are not dropped.
session = CopilotSession(session_id, self._client, None)
session = CopilotSession(session_id, self._request_proxy, None)
session._register_tools(tools)
session._register_permission_handler(on_permission_request)
if on_user_input_request:
Expand All @@ -628,7 +678,7 @@ async def create_session(self, config: SessionConfig) -> CopilotSession:
self._sessions[session_id] = session

try:
response = await self._client.request("session.create", payload)
response = await self._request_proxy.request("session.create", payload)
session._workspace_path = response.get("workspacePath")
except BaseException:
with self._sessions_lock:
Expand Down Expand Up @@ -813,7 +863,7 @@ async def resume_session(self, session_id: str, config: ResumeSessionConfig) ->

# Create and register the session before issuing the RPC so that
# events emitted by the CLI (e.g. session.start) are not dropped.
session = CopilotSession(session_id, self._client, None)
session = CopilotSession(session_id, self._request_proxy, None)
session._register_tools(cfg.get("tools"))
session._register_permission_handler(on_permission_request)
if on_user_input_request:
Expand All @@ -827,7 +877,7 @@ async def resume_session(self, session_id: str, config: ResumeSessionConfig) ->
self._sessions[session_id] = session

try:
response = await self._client.request("session.resume", payload)
response = await self._request_proxy.request("session.resume", payload)
session._workspace_path = response.get("workspacePath")
except BaseException:
with self._sessions_lock:
Expand Down Expand Up @@ -870,7 +920,7 @@ async def ping(self, message: str | None = None) -> "PingResponse":
if not self._client:
raise RuntimeError("Client not connected")

result = await self._client.request("ping", {"message": message})
result = await self._request_proxy.request("ping", {"message": message})
return PingResponse.from_dict(result)

async def get_status(self) -> "GetStatusResponse":
Expand All @@ -890,7 +940,7 @@ async def get_status(self) -> "GetStatusResponse":
if not self._client:
raise RuntimeError("Client not connected")

result = await self._client.request("status.get", {})
result = await self._request_proxy.request("status.get", {})
return GetStatusResponse.from_dict(result)

async def get_auth_status(self) -> "GetAuthStatusResponse":
Expand All @@ -911,7 +961,7 @@ async def get_auth_status(self) -> "GetAuthStatusResponse":
if not self._client:
raise RuntimeError("Client not connected")

result = await self._client.request("auth.getStatus", {})
result = await self._request_proxy.request("auth.getStatus", {})
return GetAuthStatusResponse.from_dict(result)

async def list_models(self) -> list["ModelInfo"]:
Expand Down Expand Up @@ -955,7 +1005,7 @@ async def list_models(self) -> list["ModelInfo"]:
raise RuntimeError("Client not connected")

# Cache miss - fetch from backend while holding lock
response = await self._client.request("models.list", {})
response = await self._request_proxy.request("models.list", {})
models_data = response.get("models", [])
models = [ModelInfo.from_dict(model) for model in models_data]

Expand Down Expand Up @@ -997,7 +1047,7 @@ async def list_sessions(
if filter is not None:
payload["filter"] = filter.to_dict()

response = await self._client.request("session.list", payload)
response = await self._request_proxy.request("session.list", payload)
sessions_data = response.get("sessions", [])
return [SessionMetadata.from_dict(session) for session in sessions_data]

Expand All @@ -1022,7 +1072,7 @@ async def delete_session(self, session_id: str) -> None:
if not self._client:
raise RuntimeError("Client not connected")

response = await self._client.request("session.delete", {"sessionId": session_id})
response = await self._request_proxy.request("session.delete", {"sessionId": session_id})

success = response.get("success", False)
if not success:
Expand Down Expand Up @@ -1056,7 +1106,7 @@ async def get_last_session_id(self) -> str | None:
if not self._client:
raise RuntimeError("Client not connected")

response = await self._client.request("session.getLastId", {})
response = await self._request_proxy.request("session.getLastId", {})
return response.get("sessionId")

async def get_foreground_session_id(self) -> str | None:
Expand All @@ -1080,7 +1130,7 @@ async def get_foreground_session_id(self) -> str | None:
if not self._client:
raise RuntimeError("Client not connected")

response = await self._client.request("session.getForeground", {})
response = await self._request_proxy.request("session.getForeground", {})
return response.get("sessionId")

async def set_foreground_session_id(self, session_id: str) -> None:
Expand All @@ -1102,7 +1152,9 @@ async def set_foreground_session_id(self, session_id: str) -> None:
if not self._client:
raise RuntimeError("Client not connected")

response = await self._client.request("session.setForeground", {"sessionId": session_id})
response = await self._request_proxy.request(
"session.setForeground", {"sessionId": session_id}
)

success = response.get("success", False)
if not success:
Expand Down Expand Up @@ -1406,7 +1458,7 @@ async def _connect_via_stdio(self) -> None:

# Create JSON-RPC client with the process
self._client = JsonRpcClient(self._process)
self._rpc = ServerRpc(self._client)
self._rpc = ServerRpc(self._request_proxy)

# Set up notification handler for session events
# Note: This handler is called from the event loop (thread-safe scheduling)
Expand Down Expand Up @@ -1493,7 +1545,7 @@ def wait(self, timeout=None):

self._process = SocketWrapper(sock_file, sock) # type: ignore
self._client = JsonRpcClient(self._process)
self._rpc = ServerRpc(self._client)
self._rpc = ServerRpc(self._request_proxy)

# Set up notification handler for session events
def handle_notification(method: str, params: dict):
Expand Down
52 changes: 52 additions & 0 deletions python/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
This file is for unit tests. Where relevant, prefer to add e2e tests in e2e/*.py instead.
"""

from types import SimpleNamespace
from unittest.mock import AsyncMock

import pytest

from copilot import CopilotClient, PermissionHandler, define_tool
from copilot.jsonrpc import ProcessExitedError
from copilot.types import ModelCapabilities, ModelInfo, ModelLimits, ModelSupports
from e2e.testharness import CLI_PATH

Expand Down Expand Up @@ -151,6 +155,54 @@ def test_use_logged_in_user_with_cli_url_raises(self):
)


class TestAutoRestart:
    """Unit tests for the retry-after-process-exit request path."""

    @pytest.mark.asyncio
    async def test_request_proxy_retries_once_after_process_exit(self):
        """A failed ping triggers one reconnect and one retry on the new client."""
        pong = {
            "message": "pong: health check",
            "timestamp": 123,
            "protocolVersion": 2,
        }
        dead_request = AsyncMock(side_effect=ProcessExitedError("boom"))
        live_request = AsyncMock(return_value=pong)
        live_transport = SimpleNamespace(request=live_request)

        sdk = CopilotClient(
            {"cli_path": CLI_PATH, "auto_restart": True, "log_level": "error"}
        )
        sdk._state = "connected"
        sdk._client = SimpleNamespace(request=dead_request)

        def _install_replacement(failed_client=None):
            # Stand-in for a real reconnect: just swap in the working client.
            setattr(sdk, "_client", live_transport)

        fake_reconnect = AsyncMock(side_effect=_install_replacement)
        sdk._reconnect = fake_reconnect

        reply = await sdk.ping("health check")

        assert reply.message == "pong: health check"
        assert reply.timestamp == 123
        fake_reconnect.assert_awaited_once()
        dead_request.assert_awaited_once_with("ping", {"message": "health check"})
        live_request.assert_awaited_once_with("ping", {"message": "health check"})

    @pytest.mark.asyncio
    async def test_request_proxy_propagates_process_exit_when_auto_restart_disabled(self):
        """With auto_restart off, the error surfaces and no reconnect is attempted."""
        always_fails = AsyncMock(side_effect=ProcessExitedError("boom"))

        sdk = CopilotClient(
            {"cli_path": CLI_PATH, "auto_restart": False, "log_level": "error"}
        )
        sdk._state = "connected"
        sdk._client = SimpleNamespace(request=always_fails)
        sdk._reconnect = AsyncMock()

        with pytest.raises(ProcessExitedError, match="boom"):
            await sdk.ping()

        sdk._reconnect.assert_not_awaited()


class TestOverridesBuiltInTool:
@pytest.mark.asyncio
async def test_overrides_built_in_tool_sent_in_tool_definition(self):
Expand Down