diff --git a/sdk/python/adrian/ws.py b/sdk/python/adrian/ws.py index 9eb4bc3..9b26d70 100644 --- a/sdk/python/adrian/ws.py +++ b/sdk/python/adrian/ws.py @@ -447,6 +447,12 @@ def schedule_connect(self, loop: asyncio.AbstractEventLoop) -> None: if self._connect_task is None or self._connect_task.done(): self._connect_task = loop.create_task(self.connect()) + def _ensure_connect_task(self) -> None: + """Start the initial/reconnect task if none is currently running.""" + if self._connect_task is None or self._connect_task.done(): + loop = asyncio.get_running_loop() + self._connect_task = loop.create_task(self.connect()) + async def connect(self) -> None: """Establish the WebSocket with exponential-backoff retry. @@ -581,6 +587,8 @@ async def _send_frame(self, frame: pb.ClientFrame) -> None: if not self._connected.is_set() or self._replaying: self._buffer_frame(frame_bytes) + if not self._replaying: + self._ensure_connect_task() reason = "disconnected" if not self._connected.is_set() else "replaying" logger.info( "buffered for replay (session_id=%s, kind=%s, " @@ -597,6 +605,8 @@ async def _send_frame(self, frame: pb.ClientFrame) -> None: if ws is None: self._buffer_frame(frame_bytes) + if not self._connected.is_set(): + self._ensure_connect_task() return @@ -878,10 +888,7 @@ async def _handle_disconnect(self, reason: str) -> None: if self._closing: return - loop = asyncio.get_running_loop() - - if self._connect_task is None or self._connect_task.done(): - self._connect_task = loop.create_task(self.connect()) + self._ensure_connect_task() async def _fire_on_disconnect(self, reason: str) -> None: """Invoke the on_disconnect callback, catching any exception.""" diff --git a/sdk/python/tests/test_init.py b/sdk/python/tests/test_init.py index b27c8bf..17dc6e6 100644 --- a/sdk/python/tests/test_init.py +++ b/sdk/python/tests/test_init.py @@ -1,7 +1,10 @@ """Tests for adrian.init / shutdown and auto-instrumentation.""" +# pyright: reportPrivateUsage=false + from __future__ import annotations +import asyncio import os from collections.abc import Iterator from pathlib import Path @@ -10,6 +13,7 @@ import adrian import pytest from adrian.config import AdrianConfig, get_config, is_initialized +from adrian.proto import event_pb2 as pb from langchain_core.callbacks.manager import CallbackManager from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.runnables.base import Runnable @@ -66,6 +70,42 @@ def test_creates_jsonl_file(self, tmp_path: Path) -> None: assert log.exists() + def test_sync_init_first_async_send_starts_connect_task(self) -> None: + """First async send should start connect when init() ran without a loop.""" + adrian.init( + auto_instrument=False, + api_key="k", + ws_url="ws://127.0.0.1:9999/ws", + ) + + ws = adrian._ws_client + assert ws is not None + assert ws._connect_task is None + + frame = pb.ClientFrame() + event = frame.paired_batch.events.add() + event.event_id = "evt-1" + event.invocation_id = "inv-1" + event.session_id = "sess-1" + event.pair_type = pb.PAIR_TYPE_TOOL + event.tool.tool_name = "demo" + + connect_calls: list[int] = [] + + async def _fake_connect() -> None: + connect_calls.append(1) + + async def _send_once() -> None: + with patch.object(ws, "connect", _fake_connect): + await ws._send_frame(frame) + await asyncio.sleep(0) + + asyncio.run(_send_once()) + + assert connect_calls == [1] + assert ws._connect_task is not None + assert len(ws._replay_buffer) == 1 + class TestShutdown: """Tests for adrian.shutdown()."""