diff --git a/python/packages/jumpstarter/jumpstarter/client/lease.py b/python/packages/jumpstarter/jumpstarter/client/lease.py index a51c87db0..02a32b3be 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease.py @@ -89,6 +89,7 @@ class Lease(ContextManagerMixin, AsyncContextManagerMixin): controller: jumpstarter_pb2_grpc.ControllerServiceStub = field(init=False) tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) grpc_options: dict[str, Any] = field(default_factory=dict) + client_name: str | None = None # Name of the current client, used for ownership validation acquisition_timeout: int = field(default=7200) # Timeout in seconds for lease acquisition, polled in 5s intervals dial_timeout: float = field(default=30.0) # Timeout in seconds for Dial retry loop when exporter not ready exporter_name: str = field(default="remote", init=False) # Populated during acquisition @@ -174,6 +175,12 @@ async def request_async(self): if self.name: logger.debug("using existing lease via env or flag %s", self.name) existing_lease = await self.get() + # Verify the lease belongs to the current client + if self.client_name and existing_lease.client != self.client_name: + raise LeaseError( + f"lease {self.name} belongs to client '{existing_lease.client}', " + f"not the current client '{self.client_name}'" + ) if self.selector is not None and existing_lease.selector != self.selector: logger.warning( "Existing lease from env or flag %s has selector '%s' but requested selector is '%s'. " @@ -237,9 +244,7 @@ async def _acquire(self): # Old controllers (pre-918d6341) mark offline-but-matching # exporters as Unsatisfiable with reason "NoExporter". # This is transient — retry with a new lease. - if condition_present_and_equal( - result.conditions, "Unsatisfiable", "True", "NoExporter" - ): + if condition_present_and_equal(result.conditions, "Unsatisfiable", "True", "NoExporter"): await self._handle_no_exporter_retry(spinner, message) continue logger.debug("Lease %s cannot be satisfied: %s", self.name, message) @@ -325,13 +330,16 @@ async def handle_async(self, stream): if remaining <= 0: logger.debug( "Exporter not ready and dial timeout (%.1fs) exceeded after %d attempts", - self.dial_timeout, attempt + 1 + self.dial_timeout, + attempt + 1, ) raise - delay = min(base_delay * (2 ** attempt), max_delay, remaining) + delay = min(base_delay * (2**attempt), max_delay, remaining) logger.debug( "Exporter not ready, retrying Dial in %.1fs (attempt %d, %.1fs remaining)", - delay, attempt + 1, remaining + delay, + attempt + 1, + remaining, ) await sleep(delay) attempt += 1 @@ -340,8 +348,7 @@ async def handle_async(self, stream): if "permission denied" in str(e.details()).lower(): self.lease_transferred = True logger.warning( - "Lease %s has been transferred to another client. " - "Your session is no longer valid.", + "Lease %s has been transferred to another client. Your session is no longer valid.", self.name, ) else: diff --git a/python/packages/jumpstarter/jumpstarter/client/lease_test.py b/python/packages/jumpstarter/jumpstarter/client/lease_test.py index 89b7ab818..87a3f16be 100644 --- a/python/packages/jumpstarter/jumpstarter/client/lease_test.py +++ b/python/packages/jumpstarter/jumpstarter/client/lease_test.py @@ -7,6 +7,7 @@ import pytest from rich.console import Console +from jumpstarter.client.exceptions import LeaseError from jumpstarter.client.lease import Lease, LeaseAcquisitionSpinner @@ -336,6 +337,37 @@ def test_throttling_not_applied_when_console_available(self): assert mock_spinner.update.call_count == 4 +class TestRequestAsyncOwnership: + """Tests for lease ownership validation in request_async.""" + + def _make_lease(self, *, name="test-lease", client_name="my-client"): + lease = object.__new__(Lease) + lease.name = name + lease.client_name = client_name + lease.selector = None + lease.get = AsyncMock() + return lease + + @pytest.mark.anyio + async def test_raises_when_lease_belongs_to_different_client(self): + """request_async should raise LeaseError when the lease belongs to another client.""" + lease = self._make_lease(client_name="my-client") + lease.get.return_value = Mock(client="other-client", selector=None) + + with pytest.raises(LeaseError, match="belongs to client 'other-client'"): + await lease.request_async() + + @pytest.mark.anyio + async def test_skips_check_when_client_name_is_none(self): + """request_async should skip ownership check when client_name is not set.""" + lease = self._make_lease(client_name=None) + lease.get.return_value = Mock(client="other-client", selector=None) + lease._acquire = AsyncMock(return_value=lease) + + result = await lease.request_async() + assert result is lease + + class TestRefreshChannel: """Tests for Lease.refresh_channel.""" diff --git a/python/packages/jumpstarter/jumpstarter/config/client.py b/python/packages/jumpstarter/jumpstarter/config/client.py index 153377eb7..6491da085 100644 --- a/python/packages/jumpstarter/jumpstarter/config/client.py +++ b/python/packages/jumpstarter/jumpstarter/config/client.py @@ -56,7 +56,6 @@ def _attach_config_if_expired_token(exc: ConnectionError, config: ClientConfigV1 exc.set_config(config) - def _handle_connection_error(f): @wraps(f) async def wrapper(*args, **kwargs): @@ -330,6 +329,7 @@ async def lease_async( release=release_lease, tls_config=self.tls, grpc_options=self.grpcOptions, + client_name=self.metadata.name, acquisition_timeout=acquisition_timeout_seconds, dial_timeout=self.leases.dial_timeout, ) as lease: