Lease expiry during the beforeLease hook (issue #235): exit gracefully.

Wraps the lease monitor loop in try/finally so that, when the monitor stops
after the last known end time has already passed and lease_ended is still
unset, the lease-ending notification is delivered with a zero remainder.
Adds tests for the shell exit paths and for monitor cancellation.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Indentation and blank context lines are reconstructed from the hunk counts;
confirm them against the target files before applying. In particular, verify
whether `self._notify_lease_ending(remain)` belongs inside the threshold `if`
(as rendered below) or at loop level.

diff --git a/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py b/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py
index 880aceb32..41acfd679 100644
--- a/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py
+++ b/python/packages/jumpstarter-cli/jumpstarter_cli/shell_test.py
@@ -25,6 +25,7 @@
 )
 
 from jumpstarter.client.grpc import Lease, LeaseList
+from jumpstarter.common.exceptions import ExporterOfflineError
 from jumpstarter.config.client import ClientConfigV1Alpha1
 from jumpstarter.config.env import JMP_LEASE
 
@@ -776,3 +777,125 @@ async def test_warns_red_when_token_transitions_to_expired(
         assert len(yellow_calls) >= 1, "Expected yellow warning for near-expiry"
         assert len(red_calls) >= 1, "Expected red warning for actual expiry"
         assert token_state["expired_unrecovered"] is True
+
+
+class TestLeaseExpiryDuringHook:
+    """Tests for issue #235: graceful exit when lease expires during beforeLease hook."""
+
+    async def test_lease_ended_during_hook_exits_gracefully(self):
+        """When BaseExceptionGroup is raised and lease_ended is True,
+        the client should exit with code 0 instead of raising
+        ExporterOfflineError('Connection to exporter lost').
+        """
+        lease = Mock()
+        lease.release = True
+        lease.name = "test-lease"
+        lease.lease_ended = True
+        lease.lease_transferred = False
+
+        config = _DummyConfig()
+
+        @asynccontextmanager
+        async def lease_async(selector, exporter_name, lease_name, duration, portal, acquisition_timeout):
+            yield lease
+
+        config.lease_async = lease_async
+
+        async def fake_run_shell(*_args):
+            raise BaseExceptionGroup(
+                "connection errors",
+                [ConnectionError("stream broke")],
+            )
+
+        with (
+            patch("jumpstarter_cli.shell._monitor_token_expiry", new_callable=AsyncMock),
+            patch("jumpstarter_cli.shell._run_shell_with_lease_async", side_effect=fake_run_shell),
+        ):
+            exit_code = await _shell_with_signal_handling(
+                config, None, None, None, timedelta(minutes=1), False, (), None,
+            )
+
+        assert exit_code == 0
+
+    async def test_genuine_connection_loss_raises_error(self):
+        """When BaseExceptionGroup is raised and lease_ended is False and
+        lease_transferred is False, ExporterOfflineError('Connection to
+        exporter lost') must be raised (wrapped in an ExceptionGroup by
+        the outer task group).
+        """
+        lease = Mock()
+        lease.release = True
+        lease.name = "test-lease"
+        lease.lease_ended = False
+        lease.lease_transferred = False
+
+        config = _DummyConfig()
+
+        @asynccontextmanager
+        async def lease_async(selector, exporter_name, lease_name, duration, portal, acquisition_timeout):
+            yield lease
+
+        config.lease_async = lease_async
+
+        async def fake_run_shell(*_args):
+            raise BaseExceptionGroup(
+                "connection errors",
+                [ConnectionError("stream broke")],
+            )
+
+        with (
+            patch("jumpstarter_cli.shell._monitor_token_expiry", new_callable=AsyncMock),
+            patch("jumpstarter_cli.shell._run_shell_with_lease_async", side_effect=fake_run_shell),
+            pytest.raises(BaseExceptionGroup) as exc_info,
+        ):
+            await _shell_with_signal_handling(
+                config, None, None, None, timedelta(minutes=1), False, (), None,
+            )
+
+        offline_exceptions = [
+            e for e in exc_info.value.exceptions  # ty: ignore[unresolved-attribute]
+            if isinstance(e, ExporterOfflineError)
+        ]
+        assert len(offline_exceptions) == 1
+        assert "Connection to exporter lost" in str(offline_exceptions[0])
+
+    async def test_lease_transferred_raises_transfer_error(self):
+        """When BaseExceptionGroup is raised and lease_transferred is True,
+        the appropriate transfer error must be raised (wrapped in an
+        ExceptionGroup by the outer task group).
+        """
+        lease = Mock()
+        lease.release = True
+        lease.name = "test-lease"
+        lease.lease_ended = False
+        lease.lease_transferred = True
+
+        config = _DummyConfig()
+
+        @asynccontextmanager
+        async def lease_async(selector, exporter_name, lease_name, duration, portal, acquisition_timeout):
+            yield lease
+
+        config.lease_async = lease_async
+
+        async def fake_run_shell(*_args):
+            raise BaseExceptionGroup(
+                "connection errors",
+                [ConnectionError("stream broke")],
+            )
+
+        with (
+            patch("jumpstarter_cli.shell._monitor_token_expiry", new_callable=AsyncMock),
+            patch("jumpstarter_cli.shell._run_shell_with_lease_async", side_effect=fake_run_shell),
+            pytest.raises(BaseExceptionGroup) as exc_info,
+        ):
+            await _shell_with_signal_handling(
+                config, None, None, None, timedelta(minutes=1), False, (), None,
+            )
+
+        offline_exceptions = [
+            e for e in exc_info.value.exceptions  # ty: ignore[unresolved-attribute]
+            if isinstance(e, ExporterOfflineError)
+        ]
+        assert len(offline_exceptions) == 1
+        assert "transferred" in str(offline_exceptions[0])
diff --git a/python/packages/jumpstarter/jumpstarter/client/lease.py b/python/packages/jumpstarter/jumpstarter/client/lease.py
index a51c87db0..998a41932 100644
--- a/python/packages/jumpstarter/jumpstarter/client/lease.py
+++ b/python/packages/jumpstarter/jumpstarter/client/lease.py
@@ -404,48 +404,61 @@ async def monitor_async(self, threshold: timedelta = timedelta(minutes=5)):
         async def _monitor():
             check_interval = 30  # seconds - check periodically for external lease changes
             last_known_end_time = None
-            while True:
-                try:
-                    lease = await self.get()
-                except Exception as e:
-                    logger.warning("Failed to check lease %s status: %s", self.name, e)
-                    # If we know when the lease should end, use it to bound the sleep
-                    if last_known_end_time is not None:
-                        remain = (last_known_end_time - datetime.now().astimezone()).total_seconds()
-                        if remain <= 0:
-                            logger.info(
-                                "Lease %s estimated to have ended at %s (unable to confirm with server)",
-                                self.name,
-                                last_known_end_time,
+            try:
+                while True:
+                    try:
+                        lease = await self.get()
+                    except Exception as e:
+                        logger.warning("Failed to check lease %s status: %s", self.name, e)
+                        # If we know when the lease should end, use it to bound the sleep
+                        if last_known_end_time is not None:
+                            remain = (last_known_end_time - datetime.now().astimezone()).total_seconds()
+                            if remain <= 0:
+                                logger.info(
+                                    "Lease %s estimated to have ended at %s (unable to confirm with server)",
+                                    self.name,
+                                    last_known_end_time,
+                                )
+                                self._notify_lease_ending(timedelta(0))
+                                break
+                            await sleep(min(check_interval, remain))
+                        else:
+                            await sleep(check_interval)
+                        continue
+
+                    end_time = self._get_lease_end_time(lease)
+                    if end_time is None:
+                        await sleep(1)
+                        continue
+
+                    last_known_end_time = end_time
+                    remain = end_time - datetime.now().astimezone()
+                    if remain < timedelta(0):
+                        logger.info("Lease {} ended at {}".format(self.name, end_time))
+                        self._notify_lease_ending(timedelta(0))
+                        break
+
+                    # Log once when entering the threshold window
+                    if threshold - timedelta(seconds=check_interval) <= remain < threshold:
+                        logger.info(
+                            "Lease {} ending in {} minutes at {}".format(
+                                self.name, int((remain.total_seconds() + 30) // 60), end_time
                             )
-                            self._notify_lease_ending(timedelta(0))
-                            break
-                        await sleep(min(check_interval, remain))
-                    else:
-                        await sleep(check_interval)
-                    continue
-
-                end_time = self._get_lease_end_time(lease)
-                if end_time is None:
-                    await sleep(1)
-                    continue
-
-                last_known_end_time = end_time
-                remain = end_time - datetime.now().astimezone()
-                if remain < timedelta(0):
-                    logger.info("Lease {} ended at {}".format(self.name, end_time))
-                    self._notify_lease_ending(timedelta(0))
-                    break
-
-                # Log once when entering the threshold window
-                if threshold - timedelta(seconds=check_interval) <= remain < threshold:
-                    logger.info(
-                        "Lease {} ending in {} minutes at {}".format(
-                            self.name, int((remain.total_seconds() + 30) // 60), end_time
                         )
+                        self._notify_lease_ending(remain)
+                    await sleep(min(remain.total_seconds(), check_interval))
+            finally:
+                if (
+                    not self.lease_ended
+                    and last_known_end_time is not None
+                    and last_known_end_time <= datetime.now().astimezone()
+                ):
+                    logger.info(
+                        "Lease %s expired at %s (detected on monitor shutdown)",
+                        self.name,
+                        last_known_end_time,
                     )
-                    self._notify_lease_ending(remain)
-                await sleep(min(remain.total_seconds(), check_interval))
+                    self._notify_lease_ending(timedelta(0))
 
         async with create_task_group() as tg:
             tg.start_soon(_monitor)
diff --git a/python/packages/jumpstarter/jumpstarter/client/lease_test.py b/python/packages/jumpstarter/jumpstarter/client/lease_test.py
index 89b7ab818..80e25ea41 100644
--- a/python/packages/jumpstarter/jumpstarter/client/lease_test.py
+++ b/python/packages/jumpstarter/jumpstarter/client/lease_test.py
@@ -522,3 +522,53 @@ async def get_then_fail():
         callback.assert_called()
         _, remain_arg = callback.call_args[0]
         assert remain_arg == timedelta(0)
+
+    @pytest.mark.anyio
+    async def test_sets_lease_ended_on_cancellation_when_end_time_passed(self):
+        """When _monitor is cancelled while sleeping and the lease has expired
+        based on last_known_end_time, lease_ended must be set to True.
+
+        This reproduces issue #235: when a lease expires during the beforeLease
+        hook and the monitor is cancelled (by monitor_async.__aexit__) before
+        detecting the expiry, lease_ended stays False, causing the client to
+        report 'Connection to exporter lost' instead of exiting gracefully.
+
+        We simulate the boundary timing by:
+        1. Returning an end_time slightly in the future (100ms) so _monitor
+           caches it and starts sleeping for ~100ms
+        2. Sleeping 200ms in the body so the end_time passes during the monitor
+           sleep, then exiting the context (cancelling _monitor)
+        3. The finally block sees that last_known_end_time has passed and sets
+           lease_ended = True
+        """
+        lease = self._make_lease_for_monitor()
+        lease.lease_ended = False
+
+        end_time = datetime.now(tz=timezone.utc) + timedelta(milliseconds=100)
+
+        call_count = 0
+
+        async def get_with_end_time():
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                return Mock(
+                    effective_begin_time=end_time - timedelta(hours=1),
+                    effective_duration=timedelta(hours=1),
+                    effective_end_time=end_time,
+                )
+            raise Exception("connection lost")
+
+        lease.get = get_with_end_time
+
+        with patch("jumpstarter.client.lease.sleep", new_callable=AsyncMock) as mock_sleep:
+            async def controlled_sleep(duration):
+                await asyncio.sleep(min(duration, 0.5))
+
+            mock_sleep.side_effect = controlled_sleep
+
+            async with lease.monitor_async():
+                await asyncio.sleep(0.2)
+
+        assert call_count >= 1, "get() should have been called at least once to cache end_time"
+        assert lease.lease_ended is True