diff --git a/controller/internal/service/controller_service.go b/controller/internal/service/controller_service.go index f0ec66f0a..e90800ad8 100644 --- a/controller/internal/service/controller_service.go +++ b/controller/internal/service/controller_service.go @@ -372,14 +372,14 @@ func syncOnlineConditionWithStatus(exporter *jumpstarterdevv1alpha1.Exporter) { // Allowed statuses: // - LeaseReady: Normal operation, lease is active // - BeforeLeaseHook: Hook is running, allows j commands from hooks -// - AfterLeaseHook: Hook is running, allows j commands from hooks // - Unspecified/"": Backwards compatibility with old exporters that don't report status func checkExporterStatusForDriverCalls(exporterStatus string) error { switch exporterStatus { case jumpstarterdevv1alpha1.ExporterStatusLeaseReady, - jumpstarterdevv1alpha1.ExporterStatusBeforeLeaseHook, - jumpstarterdevv1alpha1.ExporterStatusAfterLeaseHook: + jumpstarterdevv1alpha1.ExporterStatusBeforeLeaseHook: return nil + case jumpstarterdevv1alpha1.ExporterStatusAfterLeaseHook: + return status.Errorf(codes.FailedPrecondition, "exporter is not ready (status: %s)", exporterStatus) case jumpstarterdevv1alpha1.ExporterStatusUnspecified, "": // Allow for backwards compatibility with old exporters that don't report status. // The exporter-side check will still validate if it's a new exporter. diff --git a/controller/internal/service/controller_service_test.go b/controller/internal/service/controller_service_test.go index e4d21f717..31ff53175 100644 --- a/controller/internal/service/controller_service_test.go +++ b/controller/internal/service/controller_service_test.go @@ -110,9 +110,11 @@ func TestCheckExporterStatusForDriverCalls(t *testing.T) { expectError: false, }, { - name: "AfterLeaseHook allows driver calls (for j commands in hooks)", - status: jumpstarterdevv1alpha1.ExporterStatusAfterLeaseHook, - expectError: false, + name: "AfterLeaseHook is rejected to prevent dial during cleanup", + status: jumpstarterdevv1alpha1.ExporterStatusAfterLeaseHook, + expectError: true, + expectedCode: codes.FailedPrecondition, + expectedSubstr: "not ready", }, { name: "Unspecified allows driver calls (backwards compatibility)", diff --git a/python/packages/jumpstarter-driver-someip/jumpstarter_driver_someip/driver.py b/python/packages/jumpstarter-driver-someip/jumpstarter_driver_someip/driver.py old mode 100755 new mode 100644 index 25d07affd..d859c4baa --- a/python/packages/jumpstarter-driver-someip/jumpstarter_driver_someip/driver.py +++ b/python/packages/jumpstarter-driver-someip/jumpstarter_driver_someip/driver.py @@ -111,21 +111,17 @@ def _ensure_client(self) -> OsipClient: def close(self): """Stop the opensomeip client.""" - if self._osip_client is not None: - try: - self._osip_client.stop() - except Exception: - logger.warning("failed to close opensomeip client", exc_info=True) + with self._osip_lock: + if self._osip_client is not None: + try: + self._osip_client.stop() + except Exception: + logger.warning("failed to close opensomeip client", exc_info=True) + self._osip_client = None super().close() # --- RPC --- - @export - @validate_call(validate_return=True) - def start(self) -> None: - """Force start the SOME/IP client.""" - self._ensure_client() - @export @validate_call(validate_return=True) def rpc_call( @@ -230,25 +226,33 @@ def receive_event(self, timeout: float = 5.0) -> SomeIpEventNotification: # --- Connection Management --- + @export + @validate_call(validate_return=True) + def start(self) -> None: + """Force start the SOME/IP client.""" + self._ensure_client() + @export @validate_call(validate_return=True) def close_connection(self) -> None: """Close the SOME/IP connection.""" - if self._osip_client is not None: - try: - self._osip_client.stop() - except Exception: - logger.warning("failed to stop opensomeip client during close_connection", exc_info=True) + with self._osip_lock: + if self._osip_client is not None: + try: + self._osip_client.stop() + except Exception: + logger.warning("failed to stop opensomeip client during close_connection", exc_info=True) + self._osip_client = None @export @validate_call(validate_return=True) def reconnect(self) -> None: """Reconnect to the SOME/IP endpoint.""" - if self._osip_client is not None: - try: - self._osip_client.stop() - except Exception: - logger.warning("failed to stop opensomeip client during reconnect", exc_info=True) - self._osip_client.start() - else: - self._ensure_client() + with self._osip_lock: + if self._osip_client is not None: + try: + self._osip_client.stop() + except Exception: + logger.warning("failed to stop opensomeip client during reconnect", exc_info=True) + self._osip_client = None + self._ensure_client() diff --git a/python/packages/jumpstarter-driver-someip/jumpstarter_driver_someip/driver_test.py b/python/packages/jumpstarter-driver-someip/jumpstarter_driver_someip/driver_test.py index d27dddced..705736f93 100644 --- a/python/packages/jumpstarter-driver-someip/jumpstarter_driver_someip/driver_test.py +++ b/python/packages/jumpstarter-driver-someip/jumpstarter_driver_someip/driver_test.py @@ -203,6 +203,26 @@ def test_someip_close_connection(mock_osip_cls): mock_client.stop.assert_called() +@patch("jumpstarter_driver_someip.driver.OsipClient") +def test_someip_close_connection_resets_client_for_fresh_creation(mock_osip_cls): + """After close_connection, the next operation must create a fresh client.""" + first_client = _make_mock_osip_client() + second_client = _make_mock_osip_client() + mock_osip_cls.side_effect = [first_client, second_client] + + driver = SomeIp(host="127.0.0.1", port=30490) + with serve(driver) as client: + client.start() + mock_osip_cls.assert_called_once() + + client.close_connection() + first_client.stop.assert_called() + + client.start() + assert mock_osip_cls.call_count == 2 + second_client.start.assert_called_once() + + @patch("jumpstarter_driver_someip.driver.OsipClient") def test_someip_reconnect(mock_osip_cls): mock_client = _make_mock_osip_client() diff --git a/python/packages/jumpstarter/jumpstarter/streams/router.py b/python/packages/jumpstarter/jumpstarter/streams/router.py index b626ad6b3..09f90c753 100644 --- a/python/packages/jumpstarter/jumpstarter/streams/router.py +++ b/python/packages/jumpstarter/jumpstarter/streams/router.py @@ -58,6 +58,8 @@ async def receive(self) -> bytes: async def send_eof(self): with contextlib.suppress(grpc.aio.AioRpcError, asyncio.exceptions.InvalidStateError): + if self.context.done(): + return await self.context.write(self.cls(frame_type=router_pb2.FRAME_TYPE_GOAWAY)) if isinstance(self.context, grpc.aio.StreamStreamCall): await self.context.done_writing() diff --git a/python/packages/jumpstarter/jumpstarter/streams/router_test.py b/python/packages/jumpstarter/jumpstarter/streams/router_test.py new file mode 100644 index 000000000..ec7d07fde --- /dev/null +++ b/python/packages/jumpstarter/jumpstarter/streams/router_test.py @@ -0,0 +1,38 @@ +from unittest.mock import AsyncMock + +import grpc +import pytest +from jumpstarter_protocol import router_pb2 + +from .router import RouterStream + + +@pytest.fixture +def mock_client_context(): + ctx = AsyncMock(spec=grpc.aio.StreamStreamCall) + ctx.done = lambda: False + return ctx + + +@pytest.fixture +def client_stream(mock_client_context): + stream = RouterStream(context=mock_client_context) + return stream + + +class TestSendEofSkipsWriteWhenDone: + @pytest.mark.anyio + async def test_send_eof_writes_goaway_when_context_active(self, client_stream, mock_client_context): + await client_stream.send_eof() + + mock_client_context.write.assert_awaited_once() + frame = mock_client_context.write.call_args[0][0] + assert frame.frame_type == router_pb2.FRAME_TYPE_GOAWAY + + @pytest.mark.anyio + async def test_send_eof_skips_write_when_context_done(self, client_stream, mock_client_context): + mock_client_context.done = lambda: True + + await client_stream.send_eof() + + mock_client_context.write.assert_not_awaited()