diff --git a/sdk/python/adrian/__init__.py b/sdk/python/adrian/__init__.py index 6f6c024..69da468 100644 --- a/sdk/python/adrian/__init__.py +++ b/sdk/python/adrian/__init__.py @@ -339,9 +339,11 @@ def init( if loop is not None: _ws_client.schedule_connect(loop) else: - logger.debug( - "No running event loop at init(); WebSocket will connect on " - "first send from within an async context." + logger.warning( + "Adrian initialised without a running event loop. WebSocket " + "transport and BLOCK/HITL verdict handling may not be active " + "yet; sync ToolNode.invoke will fail closed until an event " + "loop connects the WebSocket and receives a policy LoginAck." ) if auto_instrument: @@ -857,15 +859,12 @@ def _should_halt(verdict: pb.Verdict) -> bool: "M4": verdict.policy.policy_m4, }.get(mad_prefix, False) - def _patch_tool_node() -> None: - """Patch ToolNode for callback injection + async verdict gate. + """Patch ``ToolNode.invoke`` / ``ainvoke``. - ToolNode dispatches tools via tool.invoke (sync) even within async - Pregel. BaseTool.invoke can't await a verdict from the event loop - thread, so we add the verdict gate here on ToolNode.ainvoke - the - entry point Pregel calls before tool dispatch begins. This is a - complementary gate to BaseTool (which covers direct callers). + ToolNode stays responsible for callback injection. The verdict gate lives + on ``BaseTool`` so async ToolNode dispatch does not consume verdict futures + before individual tools run. """ try: from langgraph.prebuilt import ToolNode @@ -885,7 +884,9 @@ def patched_invoke( config: Any = None, **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 + """Inject Adrian callbacks into sync ToolNode invocation.""" config = _inject_callbacks(config) + return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( diff --git a/sdk/python/tests/test_block_mode.py b/sdk/python/tests/test_block_mode.py index 742249b..deda212 100644 --- a/sdk/python/tests/test_block_mode.py +++ b/sdk/python/tests/test_block_mode.py @@ -95,6 +95,17 @@ def _tool_pair() -> PairedEvent: ) +def _resolved_verdict_future(verdict: pb.Verdict) -> asyncio.Future[pb.Verdict]: + """Build a completed future for sync ToolNode.invoke tests.""" + loop = asyncio.new_event_loop() + try: + fut: asyncio.Future[pb.Verdict] = loop.create_future() + fut.set_result(verdict) + return fut + finally: + loop.close() + + class TestRunIdCorrelation: async def test_llm_pair_populates_run_id_map(self) -> None: mock_ws = AsyncMock() @@ -192,6 +203,120 @@ async def _real_tool(x: str) -> str: assert len(msgs) == 1 assert "BLOCKED" in msgs[0].content + def test_sync_in_scope_block_verdict_halts_tool(self, tmp_path: Path) -> None: + """Sync MODE_BLOCK mirrors async: in-scope blocking verdict halts.""" + + def _real_tool(x: str) -> str: + """Real tool stub for sync block-mode tests.""" + _real_tool.called = True # type: ignore[attr-defined] + + return x + + _real_tool.called = False # type: ignore[attr-defined] + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + policy = _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + ws._pending_verdicts["llm-evt"] = _resolved_verdict_future( + pb.Verdict(event_id="llm-evt", mad_code="M4_a", policy=policy), + ) + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = tool_node.invoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert _real_tool.called is False # type: ignore[attr-defined] + msgs = result["messages"] + assert len(msgs) == 1 + assert "BLOCKED" in msgs[0].content + + def test_sync_missing_login_ack_halts_tool(self, tmp_path: Path) -> None: + """Sync ToolNode.invoke fails closed until server policy is known.""" + + def _real_tool(x: str) -> str: + """Real tool stub for sync block-mode tests.""" + _real_tool.called = True # type: ignore[attr-defined] + + return x + + _real_tool.called = False # type: ignore[attr-defined] + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = tool_node.invoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert _real_tool.called is False # type: ignore[attr-defined] + msgs = result["messages"] + assert len(msgs) == 1 + assert "BLOCKED" in msgs[0].content + + def test_sync_unresolved_active_policy_halts_tool(self, tmp_path: Path) -> None: + """Sync ToolNode.invoke fails closed when it cannot wait for verdicts.""" + + def _real_tool(x: str) -> str: + """Real tool stub for sync block-mode tests.""" + _real_tool.called = True # type: ignore[attr-defined] + + return x + + _real_tool.called = False # type: ignore[attr-defined] + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + result = tool_node.invoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert _real_tool.called is False # type: ignore[attr-defined] + msgs = result["messages"] + assert len(msgs) == 1 + assert "BLOCKED" in msgs[0].content + async def test_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None: """MODE_BLOCK with policy_m2=false + mad_code='M2' → continue (out-of-scope).""" @@ -233,6 +358,44 @@ async def _real_tool(x: str) -> str: assert captured == ["hi"] + def test_sync_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None: + """Sync MODE_BLOCK continues when the verdict family is out of scope.""" + captured: list[str] = [] + + def _real_tool(x: str) -> str: + """Real tool stub for sync block-mode tests.""" + captured.append(x) + + return x + + adrian.init( + api_key="k", + log_file=str(tmp_path / "events.jsonl"), + auto_instrument=True, + ws_url="ws://x", + block_timeout=1.0, + ) + + ws = adrian._ws_client + assert ws is not None + policy = _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) # m2 stays False + ws._connected.set() + ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" + ws._pending_verdicts["llm-evt"] = _resolved_verdict_future( + pb.Verdict(event_id="llm-evt", mad_code="M2", policy=policy), + ) + + tool_node = ToolNode([_real_tool]) + ai = AIMessage( + content="", + tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}], + ) + state: dict[str, Any] = {"messages": [ai]} + + tool_node.invoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + + assert captured == ["hi"] + async def test_timeout_fail_closed_blocks_tool(self, tmp_path: Path) -> None: """Verdict timeout in MODE_BLOCK → fail-closed (tool does NOT run).""" captured: list[str] = [] diff --git a/sdk/python/tests/test_init.py b/sdk/python/tests/test_init.py index 17dc6e6..9f5f197 100644 --- a/sdk/python/tests/test_init.py +++ b/sdk/python/tests/test_init.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio +import logging import os from collections.abc import Iterator from pathlib import Path @@ -70,6 +71,24 @@ def test_creates_jsonl_file(self, tmp_path: Path) -> None: assert log.exists() + def test_warns_when_ws_init_has_no_running_loop( + self, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, + ) -> None: + """init() should warn when WS enforcement starts without a loop.""" + caplog.set_level(logging.WARNING, logger="adrian") + log = tmp_path / "events.jsonl" + + adrian.init( + api_key="k", + log_file=str(log), + auto_instrument=False, + ws_url="ws://x", + ) + + assert "without a running event loop" in caplog.text + def test_sync_init_first_async_send_starts_connect_task(self) -> None: """First async send should start connect when init() ran without a loop.""" adrian.init(