diff --git a/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver.py b/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver.py index 23e394bb2..fe5173578 100644 --- a/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver.py +++ b/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver.py @@ -1,3 +1,4 @@ +from dataclasses import field from functools import reduce from pydantic.dataclasses import dataclass @@ -20,22 +21,33 @@ class Composite(CompositeInterface, Driver): @dataclass(kw_only=True) class Proxy(Driver): ref: str + _proxy_target: Driver | None = field(default=None, init=False, repr=False) @classmethod def client(cls) -> str: - return "jumpstarter.client.DriverClient" # unused + raise NotImplementedError("Proxy.client() should never be called; report() delegates to target") - def __target(self, root, name): + def _resolve_proxy_target(self, root, name): + if self._proxy_target: + return self._proxy_target try: path = self.ref.split(".") if not path: raise ConfigurationError(f"Proxy driver {name} has empty path") - return reduce(lambda instance, name: instance.children[name], path, root) + self._proxy_target = reduce(lambda instance, name: instance.children[name], path, root) + return self._proxy_target except KeyError: raise ConfigurationError(f"Proxy driver {name} references nonexistent driver {self.ref}") from None - def report(self, *, root=None, parent=None, name=None): - return self.__target(root, name).report(root=root, parent=parent, name=name) + def report(self, *, parent=None, name=None): + if not self._proxy_target: + raise RuntimeError("Proxy target not resolved. Call enumerate() before report()") + return self._proxy_target.report(parent=parent, name=name) def enumerate(self, *, root=None, parent=None, name=None): - return self.__target(root, name).enumerate(root=root, parent=parent, name=name) + return self._resolve_proxy_target(root or self, name).enumerate(root=root or self, parent=parent, name=name) + + def __getattr__(self, name): + if not self._proxy_target: + raise RuntimeError(f"Proxy target not resolved. Call enumerate() before accessing '{name}'") + return getattr(self._proxy_target, name) diff --git a/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py b/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py index fa38cb85c..1378f9551 100644 --- a/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py +++ b/packages/jumpstarter-driver-composite/jumpstarter_driver_composite/driver_test.py @@ -1,7 +1,42 @@ from jumpstarter_driver_power.driver import MockPower +from pydantic.dataclasses import dataclass from .driver import Composite, Proxy from jumpstarter.common.utils import serve +from jumpstarter.driver import Driver, export + + +# Mock serial driver with a connect() method +@dataclass(kw_only=True) +class MockSerial(Driver): + connected: bool = False + + @classmethod + def client(cls) -> str: + return "jumpstarter.client.DriverClient" + + @export + def connect(self): + self.connected = True + return "connected" + + @export + def read(self): + return "data" + + +# Mock parent driver that accesses proxy child methods +@dataclass(kw_only=True) +class MockParent(Driver): + @classmethod + def client(cls) -> str: + return "jumpstarter.client.DriverClient" + + @export + def initialize(self): + # This simulates RideSX accessing self.children["serial"].connect() + result = self.children["serial"].connect() + return f"initialized with {result}" def test_drivers_composite(): @@ -23,3 +58,54 @@ def test_drivers_composite(): client.composite1.power1.on() client.proxy0.on() client.proxy1.power1.on() + + +def test_proxy_method_forwarding(): + """Test that Proxy forwards method calls to target driver""" + # Server-side test: verify __getattr__ works on Proxy + actual_serial = MockSerial() + proxy = Proxy(ref="test") + composite = Composite( + children={ + "proxy_serial": proxy, + "test": actual_serial, + } + ) + + # Simulate enumerate() being called (happens during serve()) + composite.enumerate() + + # Now test that proxy forwards method calls to target + result = proxy.connect() + assert result == "connected" + assert actual_serial.connected is True + + data = proxy.read() + assert data == "data" + + +def test_proxy_in_parent_child(): + """Test that parent driver can call methods on Proxy child (RideSX scenario)""" + # Server-side test: verify parent accessing self.children["serial"].method() + actual_serial = MockSerial() + proxy = Proxy(ref="actual_serial") + parent = MockParent( + children={ + "serial": proxy, + } + ) + composite = Composite( + children={ + "parent": parent, + "actual_serial": actual_serial, + } + ) + + # Simulate enumerate() being called (happens during serve()) + composite.enumerate() + + # Now test that parent.initialize() works, which internally calls + # self.children["serial"].connect() on the Proxy + result = parent.initialize() + assert result == "initialized with connected" + assert actual_serial.connected is True diff --git a/packages/jumpstarter/jumpstarter/driver/base.py b/packages/jumpstarter/jumpstarter/driver/base.py index 895b007c2..78273ef43 100644 --- a/packages/jumpstarter/jumpstarter/driver/base.py +++ b/packages/jumpstarter/jumpstarter/driver/base.py @@ -195,16 +195,12 @@ async def Stream(self, request, context): ) as stream: yield stream - def report(self, *, root=None, parent=None, name=None): + def report(self, *, parent=None, name=None): """ Create DriverInstanceReport :meta private: """ - - if root is None: - root = self - return jumpstarter_pb2.DriverInstanceReport( uuid=str(self.uuid), parent_uuid=str(parent.uuid) if parent else None,