class MockAioRpcError(AioRpcError):
    """Test double for a failed gRPC call that is a real ``AioRpcError``.

    Production code catches ``AioRpcError`` and inspects ``code()`` and
    ``details()``; this class lets tests raise such an error with a chosen
    status code and message.

    Args:
        status_code: The ``grpc.StatusCode`` that ``code()`` should report.
        message: The text that ``details()`` should report.
    """

    def __init__(self, status_code, message=""):
        # Initialise the base class so the fields it uses for str()/repr()
        # and the metadata accessors exist. Without this, merely printing or
        # logging the exception raises AttributeError and hides the real
        # test failure.
        super().__init__(
            code=status_code,
            initial_metadata=grpc.aio.Metadata(),
            trailing_metadata=grpc.aio.Metadata(),
            details=message,
        )
        # Kept for backward compatibility with code that reads these directly.
        self._status_code = status_code
        self._message = message

    def code(self):
        """Return the status code supplied at construction."""
        return self._status_code

    def details(self):
        """Return the message supplied at construction."""
        return self._message
class TestHandleAsyncUnavailableRetry:
    """Checks how Lease.handle_async behaves when Dial reports UNAVAILABLE."""

    def _make_lease_for_handle(self):
        """Return a Lease created without running __init__, with test attributes set."""
        bare_lease = object.__new__(Lease)
        bare_lease.name = "test-lease"
        bare_lease.dial_timeout = 5.0
        bare_lease.lease_transferred = False
        bare_lease.tls_config = Mock()
        bare_lease.grpc_options = {}
        bare_lease.controller = Mock()
        return bare_lease

    @pytest.mark.anyio
    async def test_handle_async_retries_unavailable_then_succeeds(self):
        """First Dial fails with UNAVAILABLE, the second one returns a router endpoint."""
        lease = self._make_lease_for_handle()
        seen_requests = []

        async def mock_dial(request):
            # Record every attempt; only the very first one is rejected.
            seen_requests.append(request)
            if len(seen_requests) == 1:
                raise MockAioRpcError(grpc.StatusCode.UNAVAILABLE, "temporarily unavailable")
            return Mock(router_endpoint="endpoint", router_token="token")

        lease.controller.Dial = mock_dial

        with patch("jumpstarter.client.lease.connect_router_stream") as mock_connect:
            # Make the patched call usable as an async context manager.
            router_cm = mock_connect.return_value
            router_cm.__aenter__ = AsyncMock()
            router_cm.__aexit__ = AsyncMock(return_value=False)
            stream = Mock()

            await lease.handle_async(stream)

        assert len(seen_requests) == 2
        mock_connect.assert_called_once_with("endpoint", "token", stream, lease.tls_config, lease.grpc_options)

    @pytest.mark.anyio
    async def test_handle_async_unavailable_exceeds_dial_timeout(self):
        """Every Dial fails with UNAVAILABLE; the error must surface once dial_timeout runs out."""
        lease = self._make_lease_for_handle()
        lease.dial_timeout = 0.5
        seen_requests = []

        async def mock_dial(request):
            seen_requests.append(request)
            raise MockAioRpcError(grpc.StatusCode.UNAVAILABLE, "permanently unavailable")

        lease.controller.Dial = mock_dial
        stream = Mock()

        with pytest.raises(AioRpcError) as exc_info:
            await lease.handle_async(stream)

        assert exc_info.value.code() == grpc.StatusCode.UNAVAILABLE
        # At least one retry must have happened before giving up.
        assert len(seen_requests) >= 2
async def test_poll_loop_handles_unavailable_as_transient(self) -> None:
    """A lone UNAVAILABLE between two good responses must not mark the connection lost.

    The stub answers AVAILABLE, fails once with UNAVAILABLE, then answers
    LEASE_READY. The monitor is expected to keep polling past the failure
    and end up reporting LEASE_READY.
    """
    scripted = [
        create_status_response(ExporterStatus.AVAILABLE, version=1),
        create_mock_rpc_error(StatusCode.UNAVAILABLE),
        create_status_response(ExporterStatus.LEASE_READY, version=2),
    ]
    monitor = StatusMonitor(MockExporterStub(scripted), poll_interval=0.05)

    async with anyio.create_task_group() as tg:
        await monitor.start(tg)
        await anyio.sleep(0.3)
        await monitor.stop()

    assert not monitor.connection_lost
    assert monitor.current_status == ExporterStatus.LEASE_READY

async def test_poll_loop_unavailable_threshold(self) -> None:
    """A long unbroken run of UNAVAILABLE must end with connection_lost set.

    One good response is followed by 15 UNAVAILABLE errors, which is more
    than the 10 consecutive failures the poll loop tolerates.
    """
    scripted = [create_status_response(ExporterStatus.AVAILABLE, version=1)]
    scripted.extend(create_mock_rpc_error(StatusCode.UNAVAILABLE) for _ in range(15))
    monitor = StatusMonitor(MockExporterStub(scripted, repeat_last=False), poll_interval=0.01)

    async with anyio.create_task_group() as tg:
        await monitor.start(tg)
        await anyio.sleep(2.0)
        await monitor.stop()

    assert monitor.connection_lost

async def test_poll_loop_unavailable_below_threshold(self) -> None:
    """Five UNAVAILABLE errors in a row are tolerated if a good response follows.

    Five is below the threshold of 10, so the monitor should still pick up
    the LEASE_READY response that comes after the failures.
    """
    scripted = [create_status_response(ExporterStatus.AVAILABLE, version=1)]
    scripted.extend(create_mock_rpc_error(StatusCode.UNAVAILABLE) for _ in range(5))
    scripted.append(create_status_response(ExporterStatus.LEASE_READY, version=2))
    monitor = StatusMonitor(MockExporterStub(scripted), poll_interval=0.01)

    async with anyio.create_task_group() as tg:
        await monitor.start(tg)
        await anyio.sleep(1.0)
        await monitor.stop()

    assert not monitor.connection_lost
    assert monitor.current_status == ExporterStatus.LEASE_READY

async def test_poll_loop_unavailable_counter_resets_on_success(self) -> None:
    """Two separate runs of five UNAVAILABLE errors must not add up to a lost connection.

    A successful poll sits between the two runs. Each run on its own is
    below the threshold, so connection_lost should stay unset even though
    ten failures were seen in total.
    """
    scripted = [create_status_response(ExporterStatus.AVAILABLE, version=1)]
    scripted.extend(create_mock_rpc_error(StatusCode.UNAVAILABLE) for _ in range(5))
    scripted.append(create_status_response(ExporterStatus.LEASE_READY, version=2))
    scripted.extend(create_mock_rpc_error(StatusCode.UNAVAILABLE) for _ in range(5))
    scripted.append(create_status_response(ExporterStatus.AVAILABLE, version=3))
    monitor = StatusMonitor(MockExporterStub(scripted), poll_interval=0.01)

    async with anyio.create_task_group() as tg:
        await monitor.start(tg)
        await anyio.sleep(2.0)
        await monitor.stop()

    assert not monitor.connection_lost
class TestStatusMonitorUnavailableRetryDelay:
    """Timing check for the pause between UNAVAILABLE retries in the poll loop."""

    async def test_unavailable_retries_include_inter_retry_delay(self) -> None:
        """Reaching the UNAVAILABLE threshold must take real wall-clock time.

        UNAVAILABLE errors come back almost instantly, so without a pause
        between attempts all ten retries would be used up within
        milliseconds and an exporter would get no chance to restart. The
        test measures how long ``monitor.start()`` takes to return and
        requires at least ``poll_interval`` per gap between the retries.
        """
        import time

        retry_count = 10
        poll_interval = 0.05
        failures = [create_mock_rpc_error(StatusCode.UNAVAILABLE) for _ in range(retry_count)]
        monitor = StatusMonitor(
            MockExporterStub(failures, repeat_last=False),
            poll_interval=poll_interval,
        )

        began = time.monotonic()
        await monitor.start()
        elapsed = time.monotonic() - began

        assert monitor.connection_lost
        # N attempts are separated by N - 1 pauses.
        minimum_expected = poll_interval * (retry_count - 1)
        assert elapsed >= minimum_expected, (
            f"UNAVAILABLE retries completed in {elapsed:.3f}s, "
            f"expected at least {minimum_expected:.3f}s with inter-retry delay"
        )
""" responses = [ create_status_response(ExporterStatus.LEASE_READY, version=1), - create_mock_rpc_error(StatusCode.UNAVAILABLE), + ] + [ + create_mock_rpc_error(StatusCode.UNAVAILABLE) + for _ in range(15) ] - stub = MockExporterStub(responses) - monitor = StatusMonitor(stub, poll_interval=0.05) - - import time + stub = MockExporterStub(responses, repeat_last=False) + monitor = StatusMonitor(stub, poll_interval=0.01) async with anyio.create_task_group() as tg: await monitor.start(tg) - start = time.monotonic() result = await monitor.wait_for_any_of( [ExporterStatus.AVAILABLE, ExporterStatus.AFTER_LEASE_HOOK], timeout=5.0, ) - elapsed = time.monotonic() - start await monitor.stop() assert monitor.connection_lost is True assert result is None - assert elapsed < 2.0, f"Connection loss detection took {elapsed:.1f}s, expected < 2.0s" async def test_lease_timeout_during_before_hook_detects_connection_loss(self) -> None: """Issue C2: Lease timeout during beforeLease should detect connection loss. - When the exporter transitions from BEFORE_LEASE_HOOK to UNAVAILABLE + When the exporter transitions from BEFORE_LEASE_HOOK to sustained UNAVAILABLE (lease timeout at boundary of beforeLease), wait_for_status(LEASE_READY) - should return False with connection_lost=True. + should return False with connection_lost=True after retry threshold is exceeded. 
""" responses = [ create_status_response(ExporterStatus.BEFORE_LEASE_HOOK, version=1), - create_mock_rpc_error(StatusCode.UNAVAILABLE), + ] + [ + create_mock_rpc_error(StatusCode.UNAVAILABLE) + for _ in range(15) ] - stub = MockExporterStub(responses) - monitor = StatusMonitor(stub, poll_interval=0.05) + stub = MockExporterStub(responses, repeat_last=False) + monitor = StatusMonitor(stub, poll_interval=0.01) async with anyio.create_task_group() as tg: await monitor.start(tg) - result = await monitor.wait_for_status(ExporterStatus.LEASE_READY, timeout=2.0) + result = await monitor.wait_for_status(ExporterStatus.LEASE_READY, timeout=5.0) await monitor.stop()