diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py index 47e7952b5..ddde61f19 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -41,7 +41,6 @@ _TOKEN_REFRESH_THRESHOLD_SECONDS = 120 - def _run_shell_only(lease, config, command, path: str) -> int: """Run just the shell command without log streaming.""" allow = config.drivers.allow if config is not None else getattr(lease, "allow", []) @@ -259,6 +258,35 @@ async def _monitor_token_expiry(config, lease, cancel_scope, token_state=None) - return +def _handle_after_lease_result(result, monitor): + """Handle the result of waiting for afterLease hook completion. + + Raises ExporterOfflineError when the hook definitively failed. + Returns silently when the outcome is acceptable (hook completed, + connection lost while hook was still running, or timeout). + """ + if result == ExporterStatus.AVAILABLE: + if monitor.status_message and monitor.status_message.startswith(HOOK_WARNING_PREFIX): + warning_text = monitor.status_message[len(HOOK_WARNING_PREFIX) :] + click.echo(click.style(f"Warning: {warning_text}", fg="yellow", bold=True)) + logger.info("afterLease hook completed") + elif result == ExporterStatus.AFTER_LEASE_HOOK_FAILED: + reason = monitor.status_message or "afterLease hook failed" + raise ExporterOfflineError(reason) + elif monitor.connection_lost: + if monitor.current_status == ExporterStatus.AFTER_LEASE_HOOK_FAILED: + reason = monitor.status_message or "afterLease hook failed (connection lost)" + raise ExporterOfflineError(reason) + if monitor.current_status == ExporterStatus.AFTER_LEASE_HOOK: + logger.info( + "Connection lost while afterLease hook is running; exporter will continue the hook autonomously" + ) + else: + logger.info("Connection lost, skipping afterLease hook wait") + elif result is None: + logger.warning("Timeout waiting for afterLease hook to complete") + + async def _run_shell_with_lease_async(lease, exporter_logs, config, command, cancel_scope): # noqa: C901 """Run shell with lease context managers and wait for afterLease hook if logs enabled. @@ -368,41 +396,11 @@ async def _run_shell_with_lease_async(lease, exporter_logs, config, command, can with anyio.move_on_after(10): success = await client.end_session_async() if success: - # Wait for hook to complete using background monitor - # This allows afterLease logs to be displayed in real-time result = await monitor.wait_for_any_of( [ExporterStatus.AVAILABLE, ExporterStatus.AFTER_LEASE_HOOK_FAILED], timeout=300.0, ) - if result == ExporterStatus.AVAILABLE: - if monitor.status_message and monitor.status_message.startswith( - HOOK_WARNING_PREFIX - ): - warning_text = monitor.status_message[len(HOOK_WARNING_PREFIX) :] - click.echo( - click.style(f"Warning: {warning_text}", fg="yellow", bold=True) - ) - logger.info("afterLease hook completed") - elif result == ExporterStatus.AFTER_LEASE_HOOK_FAILED: - reason = monitor.status_message or "afterLease hook failed" - raise ExporterOfflineError(reason) - elif monitor.connection_lost: - # If connection lost during afterLease hook lifecycle - # (running or failed), the exporter shut down - if monitor.current_status in ( - ExporterStatus.AFTER_LEASE_HOOK, - ExporterStatus.AFTER_LEASE_HOOK_FAILED, - ): - reason = ( - monitor.status_message - or "afterLease hook failed (connection lost)" - ) - raise ExporterOfflineError(reason) - # Connection lost but hook wasn't running. This is expected when - # the lease times out — exporter handles its own cleanup. - logger.info("Connection lost, skipping afterLease hook wait") - elif result is None: - logger.warning("Timeout waiting for afterLease hook to complete") + _handle_after_lease_result(result, monitor) else: logger.debug("EndSession not implemented, skipping hook wait") except ExporterOfflineError: @@ -570,9 +568,7 @@ async def _shell_direct_async( async with create_task_group() as tg: tg.start_soon(signal_handler, tg.cancel_scope) try: - exit_code = await _run_shell_with_lease_async( - lease, exporter_logs, config, command, tg.cancel_scope - ) + exit_code = await _run_shell_with_lease_async(lease, exporter_logs, config, command, tg.cancel_scope) except grpc.aio.AioRpcError as e: if e.code() == grpc.StatusCode.UNAUTHENTICATED: raise click.ClickException("Authentication failed: invalid or missing passphrase") from None diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py b/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py index 880aceb32..0db2e170c 100644 --- a/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py +++ b/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py @@ -14,6 +14,7 @@ from jumpstarter_cli.shell import ( _attempt_token_recovery, + _handle_after_lease_result, _monitor_token_expiry, _resolve_lease_from_active_async, _shell_with_signal_handling, @@ -25,6 +26,8 @@ ) from jumpstarter.client.grpc import Lease, LeaseList +from jumpstarter.common import ExporterStatus +from jumpstarter.common.exceptions import ExporterOfflineError from jumpstarter.config.client import ClientConfigV1Alpha1 from jumpstarter.config.env import JMP_LEASE @@ -321,6 +324,7 @@ def test_shell_allows_env_lease_without_selector_or_name(): mock_exit.assert_called_once_with(0) + def test_resolve_lease_handles_async_list_leases(): config = Mock(spec=ClientConfigV1Alpha1) config.metadata = type("Metadata", (), {"name": "test-client"})() @@ -333,6 +337,7 @@ def test_resolve_lease_handles_async_list_leases(): def _make_expired_jwt() -> str: """Create a JWT with an exp claim in the past (no signature verification needed).""" + def b64url(data: bytes) -> str: return base64.urlsafe_b64encode(data).rstrip(b"=").decode() @@ -437,9 +442,7 @@ async def test_successful_refresh(self, _mock_issuer, mock_oidc_cls, mock_save): @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") @patch("jumpstarter_cli.shell.Config") @patch("jumpstarter_cli.shell.decode_jwt_issuer", return_value="https://issuer") - async def test_successful_refresh_without_new_refresh_token( - self, _mock_issuer, mock_oidc_cls, _mock_save - ): + async def test_successful_refresh_without_new_refresh_token(self, _mock_issuer, mock_oidc_cls, _mock_save): config = _make_config() lease = _make_mock_lease() @@ -471,9 +474,7 @@ async def test_rollback_on_failure(self, _mock_issuer): @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") @patch("jumpstarter_cli.shell.Config") @patch("jumpstarter_cli.shell.decode_jwt_issuer", return_value="https://issuer") - async def test_save_failure_does_not_fail_refresh( - self, _mock_issuer, mock_oidc_cls, mock_save, caplog - ): + async def test_save_failure_does_not_fail_refresh(self, _mock_issuer, mock_oidc_cls, mock_save, caplog): """Disk save is best-effort; refresh should still succeed.""" config = _make_config() lease = _make_mock_lease() @@ -547,9 +548,7 @@ async def test_returns_false_when_disk_token_is_same(self, mock_client_cfg): @patch("jumpstarter_cli.shell.ClientConfigV1Alpha1") @patch("jumpstarter_cli.shell.get_token_remaining_seconds", return_value=-10) - async def test_returns_false_when_disk_token_is_expired( - self, _mock_remaining, mock_client_cfg - ): + async def test_returns_false_when_disk_token_is_expired(self, _mock_remaining, mock_client_cfg): config = _make_config(token="old_tok") disk_config = Mock() disk_config.token = "disk_tok" @@ -640,9 +639,7 @@ async def test_returns_when_remaining_is_none(self, _mock_remaining, _mock_sleep @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) @patch("jumpstarter_cli.shell.get_token_remaining_seconds") - async def test_refreshes_when_below_threshold( - self, mock_remaining, mock_recovery, mock_sleep, mock_click - ): + async def test_refreshes_when_below_threshold(self, mock_remaining, mock_recovery, mock_sleep, mock_click): # First call: below threshold; second call: raise to exit mock_remaining.side_effect = [60, Exception("done")] mock_recovery.return_value = "Token refreshed automatically." @@ -659,9 +656,7 @@ async def test_refreshes_when_below_threshold( @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) @patch("jumpstarter_cli.shell.get_token_remaining_seconds") - async def test_warns_when_refresh_fails( - self, mock_remaining, mock_recovery, mock_sleep, mock_click - ): + async def test_warns_when_refresh_fails(self, mock_remaining, mock_recovery, mock_sleep, mock_click): mock_remaining.side_effect = [60, Exception("done")] mock_recovery.return_value = None # all recovery failed config = _make_config() @@ -674,9 +669,7 @@ async def test_warns_when_refresh_fails( @patch("jumpstarter_cli.shell.click") @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) @patch("jumpstarter_cli.shell.get_token_remaining_seconds") - async def test_warns_within_expiry_window( - self, mock_remaining, mock_sleep, mock_click - ): + async def test_warns_within_expiry_window(self, mock_remaining, mock_sleep, mock_click): from jumpstarter_cli_common.oidc import TOKEN_EXPIRY_WARNING_SECONDS # First iteration: within warning window but above refresh threshold @@ -721,9 +714,7 @@ class _CancelScope(Mock): @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) @patch("jumpstarter_cli.shell.get_token_remaining_seconds") - async def test_sleeps_5s_when_below_threshold( - self, mock_remaining, mock_recovery, mock_sleep, _mock_click - ): + async def test_sleeps_5s_when_below_threshold(self, mock_remaining, mock_recovery, mock_sleep, _mock_click): mock_remaining.side_effect = [60, Exception("done")] mock_recovery.return_value = None config = _make_config() @@ -737,9 +728,7 @@ async def test_sleeps_5s_when_below_threshold( @patch("jumpstarter_cli.shell.anyio.sleep", new_callable=AsyncMock) @patch("jumpstarter_cli.shell._attempt_token_recovery", new_callable=AsyncMock) @patch("jumpstarter_cli.shell.get_token_remaining_seconds") - async def test_does_not_cancel_scope_on_expiry( - self, mock_remaining, mock_recovery, mock_sleep, _mock_click - ): + async def test_does_not_cancel_scope_on_expiry(self, mock_remaining, mock_recovery, mock_sleep, _mock_click): """The monitor must never cancel the scope — the shell stays alive.""" mock_remaining.side_effect = [60, Exception("done")] mock_recovery.return_value = None @@ -776,3 +765,53 @@ async def test_warns_red_when_token_transitions_to_expired( assert len(yellow_calls) >= 1, "Expected yellow warning for near-expiry" assert len(red_calls) >= 1, "Expected red warning for actual expiry" assert token_state["expired_unrecovered"] is True + + +class TestHandleAfterLeaseResult: + def _make_monitor(self, connection_lost=False, current_status=None, status_message=""): + monitor = Mock() + monitor.connection_lost = connection_lost + monitor.current_status = current_status + monitor.status_message = status_message + return monitor + + def test_hook_completed_successfully(self): + monitor = self._make_monitor(current_status=ExporterStatus.AVAILABLE) + _handle_after_lease_result(ExporterStatus.AVAILABLE, monitor) + + def test_hook_failed_raises_error(self): + monitor = self._make_monitor( + current_status=ExporterStatus.AFTER_LEASE_HOOK_FAILED, + status_message="hook script exited with code 1", + ) + with pytest.raises(ExporterOfflineError, match="hook script exited with code 1"): + _handle_after_lease_result(ExporterStatus.AFTER_LEASE_HOOK_FAILED, monitor) + + def test_connection_lost_during_running_hook_does_not_raise(self): + """When connection is lost while hook is still running, client should + exit gracefully. The exporter continues the hook autonomously.""" + monitor = self._make_monitor( + connection_lost=True, + current_status=ExporterStatus.AFTER_LEASE_HOOK, + ) + _handle_after_lease_result(None, monitor) + + def test_connection_lost_after_hook_failed_raises_error(self): + monitor = self._make_monitor( + connection_lost=True, + current_status=ExporterStatus.AFTER_LEASE_HOOK_FAILED, + status_message="hook crashed", + ) + with pytest.raises(ExporterOfflineError, match="hook crashed"): + _handle_after_lease_result(None, monitor) + + def test_connection_lost_hook_not_running_exits_gracefully(self): + monitor = self._make_monitor( + connection_lost=True, + current_status=ExporterStatus.LEASE_READY, + ) + _handle_after_lease_result(None, monitor) + + def test_timeout_returns_none(self): + monitor = self._make_monitor(connection_lost=False) + _handle_after_lease_result(None, monitor) diff --git a/python/packages/jumpstarter/jumpstarter/client/lease.py b/python/packages/jumpstarter/jumpstarter/client/lease.py index a51c87db0..971b75db8 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease.py @@ -320,18 +320,23 @@ async def handle_async(self, stream): response = await self.controller.Dial(jumpstarter_pb2.DialRequest(lease_name=self.name)) break except AioRpcError as e: - if e.code() == grpc.StatusCode.FAILED_PRECONDITION and "not ready" in str(e.details()): + is_not_ready = ( + e.code() == grpc.StatusCode.FAILED_PRECONDITION + and "not ready" in str(e.details()) + ) + is_buffer_full = e.code() == grpc.StatusCode.RESOURCE_EXHAUSTED + if is_not_ready or is_buffer_full: remaining = deadline - time.monotonic() if remaining <= 0: logger.debug( - "Exporter not ready and dial timeout (%.1fs) exceeded after %d attempts", + "Dial retry timeout (%.1fs) exceeded after %d attempts", self.dial_timeout, attempt + 1 ) raise delay = min(base_delay * (2 ** attempt), max_delay, remaining) logger.debug( - "Exporter not ready, retrying Dial in %.1fs (attempt %d, %.1fs remaining)", - delay, attempt + 1, remaining + "Dial transient error (%s), retrying in %.1fs (attempt %d, %.1fs remaining)", + e.code().name, delay, attempt + 1, remaining ) await sleep(delay) attempt += 1 diff --git a/python/packages/jumpstarter/jumpstarter/client/lease_test.py b/python/packages/jumpstarter/jumpstarter/client/lease_test.py index 89b7ab818..68918462c 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease_test.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease_test.py @@ -5,11 +5,25 @@ from unittest.mock import AsyncMock, Mock, patch import pytest +from grpc import StatusCode +from grpc.aio import AioRpcError from rich.console import Console from jumpstarter.client.lease import Lease, LeaseAcquisitionSpinner +class MockAioRpcError(AioRpcError): + def __init__(self, status_code: StatusCode, message: str = ""): + self._status_code = status_code + self._message = message + + def code(self) -> StatusCode: + return self._status_code + + def details(self) -> str: + return self._message + + class TestLeaseAcquisitionSpinner: """Test cases for LeaseAcquisitionSpinner class.""" @@ -522,3 +536,55 @@ async def get_then_fail(): callback.assert_called() _, remain_arg = callback.call_args[0] assert remain_arg == timedelta(0) + + +class TestHandleAsyncDialRetry: + def _make_lease(self): + lease = object.__new__(Lease) + lease.name = "test-lease" + lease.dial_timeout = 5.0 + lease.controller = Mock() + lease.tls_config = Mock() + lease.grpc_options = {} + lease.lease_transferred = False + return lease + + @pytest.mark.anyio + async def test_retries_on_resource_exhausted(self): + lease = self._make_lease() + call_count = 0 + + async def mock_dial(request): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise MockAioRpcError(StatusCode.RESOURCE_EXHAUSTED, "listener buffer full on lease test-lease") + response = Mock() + response.router_endpoint = "ep" + response.router_token = "tok" + return response + + lease.controller.Dial = mock_dial + + with patch("jumpstarter.client.lease.connect_router_stream") as mock_connect: + mock_connect.return_value.__aenter__ = AsyncMock() + mock_connect.return_value.__aexit__ = AsyncMock(return_value=False) + with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock): + await lease.handle_async(Mock()) + + assert call_count == 2 + + @pytest.mark.anyio + async def test_resource_exhausted_respects_dial_timeout(self): + lease = self._make_lease() + lease.dial_timeout = 0.0 + + async def mock_dial(request): + raise MockAioRpcError(StatusCode.RESOURCE_EXHAUSTED, "listener buffer full on lease test-lease") + + lease.controller.Dial = mock_dial + + with pytest.raises(AioRpcError) as exc_info: + await lease.handle_async(Mock()) + + assert exc_info.value.code() == StatusCode.RESOURCE_EXHAUSTED