Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 31 additions & 35 deletions python/packages/jumpstarter-cli/jumpstarter_cli/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
_TOKEN_REFRESH_THRESHOLD_SECONDS = 120



def _run_shell_only(lease, config, command, path: str) -> int:
"""Run just the shell command without log streaming."""
allow = config.drivers.allow if config is not None else getattr(lease, "allow", [])
Expand Down Expand Up @@ -259,6 +258,35 @@ async def _monitor_token_expiry(config, lease, cancel_scope, token_state=None) -
return


def _handle_after_lease_result(result, monitor):
    """Interpret the outcome of waiting for the afterLease hook.

    Raises ExporterOfflineError when the hook definitively failed.
    Returns silently for every acceptable outcome: the hook completed,
    the connection dropped while the hook was still running (the
    exporter finishes it on its own), or the wait timed out.
    """
    if result == ExporterStatus.AVAILABLE:
        # Hook finished; surface any warning the exporter attached to its status.
        message = monitor.status_message
        if message and message.startswith(HOOK_WARNING_PREFIX):
            warning_text = message[len(HOOK_WARNING_PREFIX) :]
            click.echo(click.style(f"Warning: {warning_text}", fg="yellow", bold=True))
        logger.info("afterLease hook completed")
        return
    if result == ExporterStatus.AFTER_LEASE_HOOK_FAILED:
        raise ExporterOfflineError(monitor.status_message or "afterLease hook failed")
    if monitor.connection_lost:
        status = monitor.current_status
        if status == ExporterStatus.AFTER_LEASE_HOOK_FAILED:
            raise ExporterOfflineError(
                monitor.status_message or "afterLease hook failed (connection lost)"
            )
        if status == ExporterStatus.AFTER_LEASE_HOOK:
            logger.info(
                "Connection lost while afterLease hook is running; exporter will continue the hook autonomously"
            )
        else:
            # Connection lost but the hook was not running — e.g. lease
            # timed out and the exporter handles its own cleanup.
            logger.info("Connection lost, skipping afterLease hook wait")
        return
    if result is None:
        logger.warning("Timeout waiting for afterLease hook to complete")


async def _run_shell_with_lease_async(lease, exporter_logs, config, command, cancel_scope): # noqa: C901
"""Run shell with lease context managers and wait for afterLease hook if logs enabled.

Expand Down Expand Up @@ -368,41 +396,11 @@ async def _run_shell_with_lease_async(lease, exporter_logs, config, command, can
with anyio.move_on_after(10):
success = await client.end_session_async()
if success:
# Wait for hook to complete using background monitor
# This allows afterLease logs to be displayed in real-time
result = await monitor.wait_for_any_of(
[ExporterStatus.AVAILABLE, ExporterStatus.AFTER_LEASE_HOOK_FAILED],
timeout=300.0,
)
if result == ExporterStatus.AVAILABLE:
if monitor.status_message and monitor.status_message.startswith(
HOOK_WARNING_PREFIX
):
warning_text = monitor.status_message[len(HOOK_WARNING_PREFIX) :]
click.echo(
click.style(f"Warning: {warning_text}", fg="yellow", bold=True)
)
logger.info("afterLease hook completed")
elif result == ExporterStatus.AFTER_LEASE_HOOK_FAILED:
reason = monitor.status_message or "afterLease hook failed"
raise ExporterOfflineError(reason)
elif monitor.connection_lost:
# If connection lost during afterLease hook lifecycle
# (running or failed), the exporter shut down
if monitor.current_status in (
ExporterStatus.AFTER_LEASE_HOOK,
ExporterStatus.AFTER_LEASE_HOOK_FAILED,
):
reason = (
monitor.status_message
or "afterLease hook failed (connection lost)"
)
raise ExporterOfflineError(reason)
# Connection lost but hook wasn't running. This is expected when
# the lease times out — exporter handles its own cleanup.
logger.info("Connection lost, skipping afterLease hook wait")
elif result is None:
logger.warning("Timeout waiting for afterLease hook to complete")
_handle_after_lease_result(result, monitor)
else:
logger.debug("EndSession not implemented, skipping hook wait")
except ExporterOfflineError:
Expand Down Expand Up @@ -570,9 +568,7 @@ async def _shell_direct_async(
async with create_task_group() as tg:
tg.start_soon(signal_handler, tg.cancel_scope)
try:
exit_code = await _run_shell_with_lease_async(
lease, exporter_logs, config, command, tg.cancel_scope
)
exit_code = await _run_shell_with_lease_async(lease, exporter_logs, config, command, tg.cancel_scope)
except grpc.aio.AioRpcError as e:
if e.code() == grpc.StatusCode.UNAUTHENTICATED:
raise click.ClickException("Authentication failed: invalid or missing passphrase") from None
Expand Down
87 changes: 63 additions & 24 deletions python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from jumpstarter_cli.shell import (
_attempt_token_recovery,
_handle_after_lease_result,
_monitor_token_expiry,
_resolve_lease_from_active_async,
_shell_with_signal_handling,
Expand All @@ -25,6 +26,8 @@
)

from jumpstarter.client.grpc import Lease, LeaseList
from jumpstarter.common import ExporterStatus
from jumpstarter.common.exceptions import ExporterOfflineError
from jumpstarter.config.client import ClientConfigV1Alpha1
from jumpstarter.config.env import JMP_LEASE

Expand Down Expand Up @@ -321,6 +324,7 @@ def test_shell_allows_env_lease_without_selector_or_name():

mock_exit.assert_called_once_with(0)


def test_resolve_lease_handles_async_list_leases():
config = Mock(spec=ClientConfigV1Alpha1)
config.metadata = type("Metadata", (), {"name": "test-client"})()
Expand All @@ -333,6 +337,7 @@ def test_resolve_lease_handles_async_list_leases():

def _make_expired_jwt() -> str:
"""Create a JWT with an exp claim in the past (no signature verification needed)."""

def b64url(data: bytes) -> str:
return base64.urlsafe_b64encode(data).rstrip(b"=").decode()

Expand Down Expand Up @@ -437,9 +442,7 @@ async def test_successful_refresh(self, _mock_issuer, mock_oidc_cls, mock_save):
@patch("jumpstarter_cli.shell.ClientConfigV1Alpha1")
@patch("jumpstarter_cli.shell.Config")
@patch("jumpstarter_cli.shell.decode_jwt_issuer", return_value="https://issuer")
async def test_successful_refresh_without_new_refresh_token(
self, _mock_issuer, mock_oidc_cls, _mock_save
):
async def test_successful_refresh_without_new_refresh_token(self, _mock_issuer, mock_oidc_cls, _mock_save):
config = _make_config()
lease = _make_mock_lease()

Expand Down Expand Up @@ -471,9 +474,7 @@ async def test_rollback_on_failure(self, _mock_issuer):
@patch("jumpstarter_cli.shell.ClientConfigV1Alpha1")
@patch("jumpstarter_cli.shell.Config")
@patch("jumpstarter_cli.shell.decode_jwt_issuer", return_value="https://issuer")
async def test_save_failure_does_not_fail_refresh(
self, _mock_issuer, mock_oidc_cls, mock_save, caplog
):
async def test_save_failure_does_not_fail_refresh(self, _mock_issuer, mock_oidc_cls, mock_save, caplog):
"""Disk save is best-effort; refresh should still succeed."""
config = _make_config()
lease = _make_mock_lease()
Expand Down Expand Up @@ -547,9 +548,7 @@ async def test_returns_false_when_disk_token_is_same(self, mock_client_cfg):

@patch("jumpstarter_cli.shell.ClientConfigV1Alpha1")
@patch("jumpstarter_cli.shell.get_token_remaining_seconds", return_value=-10)
async def test_returns_false_when_disk_token_is_expired(
self, _mock_remaining, mock_client_cfg
):
async def test_returns_false_when_disk_token_is_expired(self, _mock_remaining, mock_client_cfg):
config = _make_config(token="old_tok")
disk_config = Mock()
disk_config.token = "disk_tok"
Expand Down Expand Up @@ -640,9 +639,7 @@ async def test_returns_when_remaining_is_none(self, _mock_remaining, _mock_sleep
@patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell.get_token_remaining_seconds")
async def test_refreshes_when_below_threshold(
self, mock_remaining, mock_recovery, mock_sleep, mock_click
):
async def test_refreshes_when_below_threshold(self, mock_remaining, mock_recovery, mock_sleep, mock_click):
# First call: below threshold; second call: raise to exit
mock_remaining.side_effect = [60, Exception("done")]
mock_recovery.return_value = "Token refreshed automatically."
Expand All @@ -659,9 +656,7 @@ async def test_refreshes_when_below_threshold(
@patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell.get_token_remaining_seconds")
async def test_warns_when_refresh_fails(
self, mock_remaining, mock_recovery, mock_sleep, mock_click
):
async def test_warns_when_refresh_fails(self, mock_remaining, mock_recovery, mock_sleep, mock_click):
mock_remaining.side_effect = [60, Exception("done")]
mock_recovery.return_value = None # all recovery failed
config = _make_config()
Expand All @@ -674,9 +669,7 @@ async def test_warns_when_refresh_fails(
@patch("jumpstarter_cli.shell.click")
@patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell.get_token_remaining_seconds")
async def test_warns_within_expiry_window(
self, mock_remaining, mock_sleep, mock_click
):
async def test_warns_within_expiry_window(self, mock_remaining, mock_sleep, mock_click):
from jumpstarter_cli_common.oidc import TOKEN_EXPIRY_WARNING_SECONDS

# First iteration: within warning window but above refresh threshold
Expand Down Expand Up @@ -721,9 +714,7 @@ class _CancelScope(Mock):
@patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell.get_token_remaining_seconds")
async def test_sleeps_5s_when_below_threshold(
self, mock_remaining, mock_recovery, mock_sleep, _mock_click
):
async def test_sleeps_5s_when_below_threshold(self, mock_remaining, mock_recovery, mock_sleep, _mock_click):
mock_remaining.side_effect = [60, Exception("done")]
mock_recovery.return_value = None
config = _make_config()
Expand All @@ -737,9 +728,7 @@ async def test_sleeps_5s_when_below_threshold(
@patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock)
@patch("jumpstarter_cli.shell.get_token_remaining_seconds")
async def test_does_not_cancel_scope_on_expiry(
self, mock_remaining, mock_recovery, mock_sleep, _mock_click
):
async def test_does_not_cancel_scope_on_expiry(self, mock_remaining, mock_recovery, mock_sleep, _mock_click):
"""The monitor must never cancel the scope — the shell stays alive."""
mock_remaining.side_effect = [60, Exception("done")]
mock_recovery.return_value = None
Expand Down Expand Up @@ -776,3 +765,53 @@ async def test_warns_red_when_token_transitions_to_expired(
assert len(yellow_calls) >= 1, "Expected yellow warning for near-expiry"
assert len(red_calls) >= 1, "Expected red warning for actual expiry"
assert token_state["expired_unrecovered"] is True


class TestHandleAfterLeaseResult:
    """Unit tests for _handle_after_lease_result outcome handling."""

    def _make_monitor(self, connection_lost=False, current_status=None, status_message=""):
        # Mock(**kwargs) assigns each keyword argument as an attribute.
        return Mock(
            connection_lost=connection_lost,
            current_status=current_status,
            status_message=status_message,
        )

    def test_hook_completed_successfully(self):
        monitor = self._make_monitor(current_status=ExporterStatus.AVAILABLE)
        # Must not raise.
        _handle_after_lease_result(ExporterStatus.AVAILABLE, monitor)

    def test_hook_failed_raises_error(self):
        monitor = self._make_monitor(
            current_status=ExporterStatus.AFTER_LEASE_HOOK_FAILED,
            status_message="hook script exited with code 1",
        )
        with pytest.raises(ExporterOfflineError, match="hook script exited with code 1"):
            _handle_after_lease_result(ExporterStatus.AFTER_LEASE_HOOK_FAILED, monitor)

    def test_connection_lost_during_running_hook_does_not_raise(self):
        """When connection is lost while hook is still running, client should
        exit gracefully. The exporter continues the hook autonomously."""
        monitor = self._make_monitor(
            connection_lost=True, current_status=ExporterStatus.AFTER_LEASE_HOOK
        )
        _handle_after_lease_result(None, monitor)

    def test_connection_lost_after_hook_failed_raises_error(self):
        monitor = self._make_monitor(
            connection_lost=True,
            current_status=ExporterStatus.AFTER_LEASE_HOOK_FAILED,
            status_message="hook crashed",
        )
        with pytest.raises(ExporterOfflineError, match="hook crashed"):
            _handle_after_lease_result(None, monitor)

    def test_connection_lost_hook_not_running_exits_gracefully(self):
        monitor = self._make_monitor(
            connection_lost=True, current_status=ExporterStatus.LEASE_READY
        )
        _handle_after_lease_result(None, monitor)

    def test_timeout_returns_none(self):
        # A None result with an intact connection is a timeout; only logged.
        _handle_after_lease_result(None, self._make_monitor(connection_lost=False))
13 changes: 9 additions & 4 deletions python/packages/jumpstarter/jumpstarter/client/lease.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,18 +320,23 @@ async def handle_async(self, stream):
response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name))
break
except AioRpcError as e:
if e.code() == grpc.StatusCode.FAILED_PRECONDITION and "not ready" in str(e.details()):
is_not_ready = (
e.code() == grpc.StatusCode.FAILED_PRECONDITION
and "not ready" in str(e.details())
)
is_buffer_full = e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED
if is_not_ready or is_buffer_full:
remaining = deadline - time.monotonic()
if remaining <= 0:
logger.debug(
"Exporter not ready and dial timeout (%.1fs) exceeded after %d attempts",
"Dial retry timeout (%.1fs) exceeded after %d attempts",
self.dial_timeout, attempt + 1
)
raise
delay = min(base_delay * (2 ** attempt), max_delay, remaining)
logger.debug(
"Exporter not ready, retrying Dial in %.1fs (attempt %d, %.1fs remaining)",
delay, attempt + 1, remaining
"Dial transient error (%s), retrying in %.1fs (attempt %d, %.1fs remaining)",
e.code().name, delay, attempt + 1, remaining
)
await sleep(delay)
attempt += 1
Expand Down
66 changes: 66 additions & 0 deletions python/packages/jumpstarter/jumpstarter/client/lease_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,25 @@
from unittest.mock import AsyncMock, Mock, patch

import pytest
from grpc import StatusCode
from grpc.aio import AioRpcError
from rich.console import Console

from jumpstarter.client.lease import Lease, LeaseAcquisitionSpinner


class MockAioRpcError(AioRpcError):
    """Minimal AioRpcError stand-in exposing only code() and details().

    Deliberately skips AioRpcError.__init__: the tests exercise only the
    status code and details accessors, not the full gRPC call metadata.
    """

    def __init__(self, status_code: StatusCode, message: str = ""):
        self._code = status_code
        self._details = message

    def code(self) -> StatusCode:
        return self._code

    def details(self) -> str:
        return self._details


class TestLeaseAcquisitionSpinner:
"""Test cases for LeaseAcquisitionSpinner class."""

Expand Down Expand Up @@ -522,3 +536,55 @@ async def get_then_fail():
callback.assert_called()
_, remain_arg = callback.call_args[0]
assert remain_arg == timedelta(0)


class TestHandleAsyncDialRetry:
    """Tests for Lease.handle_async retry behavior on transient Dial errors."""

    def _make_lease(self):
        # Bypass Lease.__init__ (which presumably needs real gRPC machinery —
        # TODO confirm) and set only the attributes handle_async reads.
        lease = object.__new__(Lease)
        lease.name = "test-lease"
        lease.dial_timeout = 5.0
        lease.controller = Mock()
        lease.tls_config = Mock()
        lease.grpc_options = {}
        lease.lease_transferred = False
        return lease

    @pytest.mark.anyio
    async def test_retries_on_resource_exhausted(self):
        """A single RESOURCE_EXHAUSTED Dial failure should be retried, not raised."""
        lease = self._make_lease()
        call_count = 0

        async def mock_dial(request):
            # First call fails with the transient buffer-full error;
            # second call succeeds with a minimal Dial response.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise MockAioRpcError(StatusCode.RESOURCE_EXHAUSTED, "listener buffer full on lease test-lease")
            response = Mock()
            response.router_endpoint = "ep"
            response.router_token = "tok"
            return response

        lease.controller.Dial = mock_dial

        # Stub out the router connection and the backoff sleep so the
        # test exercises only the retry loop itself.
        with patch("jumpstarter.client.lease.connect_router_stream") as mock_connect:
            mock_connect.return_value.__aenter__ = AsyncMock()
            mock_connect.return_value.__aexit__ = AsyncMock(return_value=False)
            with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock):
                await lease.handle_async(Mock())

        # One failure plus one successful retry.
        assert call_count == 2

    @pytest.mark.anyio
    async def test_resource_exhausted_respects_dial_timeout(self):
        """With a zero dial timeout, RESOURCE_EXHAUSTED is re-raised immediately."""
        lease = self._make_lease()
        lease.dial_timeout = 0.0

        async def mock_dial(request):
            raise MockAioRpcError(StatusCode.RESOURCE_EXHAUSTED, "listener buffer full on lease test-lease")

        lease.controller.Dial = mock_dial

        with pytest.raises(AioRpcError) as exc_info:
            await lease.handle_async(Mock())

        assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED
Loading