Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions sdk/python/adrian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,11 @@ def init(
if loop is not None:
_ws_client.schedule_connect(loop)
else:
logger.debug(
"No running event loop at init(); WebSocket will connect on "
"first send from within an async context."
logger.warning(
"Adrian initialised without a running event loop. WebSocket "
"transport and BLOCK/HITL verdict handling may not be active "
"yet; sync ToolNode.invoke will fail closed until an event "
"loop connects the WebSocket and receives a policy LoginAck."
)

if auto_instrument:
Expand Down Expand Up @@ -857,15 +859,12 @@ def _should_halt(verdict: pb.Verdict) -> bool:
"M4": verdict.policy.policy_m4,
}.get(mad_prefix, False)


def _patch_tool_node() -> None:
"""Patch ToolNode for callback injection + async verdict gate.
"""Patch ``ToolNode.invoke`` / ``ainvoke``.

ToolNode dispatches tools via tool.invoke (sync) even within async
Pregel. BaseTool.invoke can't await a verdict from the event loop
thread, so we add the verdict gate here on ToolNode.ainvoke - the
entry point Pregel calls before tool dispatch begins. This is a
complementary gate to BaseTool (which covers direct callers).
ToolNode stays responsible for callback injection. The verdict gate lives
on ``BaseTool`` so async ToolNode dispatch does not consume verdict futures
before individual tools run.
"""
try:
from langgraph.prebuilt import ToolNode
Expand All @@ -885,7 +884,9 @@ def patched_invoke(
config: Any = None,
**kwargs: Any, # noqa: A002, ANN401
) -> Any: # noqa: ANN401
"""Inject Adrian callbacks into sync ToolNode invocation."""
config = _inject_callbacks(config)

return original_invoke(self, input, config=config, **kwargs)

async def patched_ainvoke(
Expand Down
163 changes: 163 additions & 0 deletions sdk/python/tests/test_block_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ def _tool_pair() -> PairedEvent:
)


def _resolved_verdict_future(verdict: pb.Verdict) -> asyncio.Future[pb.Verdict]:
"""Build a completed future for sync ToolNode.invoke tests."""
loop = asyncio.new_event_loop()
try:
fut: asyncio.Future[pb.Verdict] = loop.create_future()
fut.set_result(verdict)
return fut
finally:
loop.close()


class TestRunIdCorrelation:
async def test_llm_pair_populates_run_id_map(self) -> None:
mock_ws = AsyncMock()
Expand Down Expand Up @@ -192,6 +203,120 @@ async def _real_tool(x: str) -> str:
assert len(msgs) == 1
assert "BLOCKED" in msgs[0].content

def test_sync_in_scope_block_verdict_halts_tool(self, tmp_path: Path) -> None:
    """Sync MODE_BLOCK: an in-scope blocking verdict stops the tool.

    Seeds the WebSocket client with a pre-resolved ``M4_a`` verdict for the
    LLM event mapped to tool call ``tc-1`` (``policy_m4`` enabled), drives
    ``ToolNode.invoke`` synchronously, and expects exactly one message
    containing "BLOCKED" while the tool body never runs.
    """
    invocations: list[str] = []

    def _real_tool(x: str) -> str:
        """Real tool stub for sync block-mode tests."""
        invocations.append(x)
        return x

    adrian.init(
        api_key="k",
        log_file=str(tmp_path / "events.jsonl"),
        auto_instrument=True,
        ws_url="ws://x",
        block_timeout=1.0,
    )

    client = adrian._ws_client
    assert client is not None
    active_policy = _apply_mode(client, pb.MODE_BLOCK, policy_m4=True)
    client._connected.set()
    client._tool_call_id_to_event_id["tc-1"] = "llm-evt"
    blocking_verdict = pb.Verdict(
        event_id="llm-evt",
        mad_code="M4_a",
        policy=active_policy,
    )
    client._pending_verdicts["llm-evt"] = _resolved_verdict_future(blocking_verdict)

    node = ToolNode([_real_tool])
    request = AIMessage(
        content="",
        tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}],
    )
    graph_state: dict[str, Any] = {"messages": [request]}

    outcome = node.invoke(graph_state, config=_runtime_config())  # pyright: ignore[reportUnknownMemberType]

    # The tool body must never have executed.
    assert invocations == []
    replies = outcome["messages"]
    assert len(replies) == 1
    assert "BLOCKED" in replies[0].content

def test_sync_missing_login_ack_halts_tool(self, tmp_path: Path) -> None:
    """Sync ``ToolNode.invoke`` fails closed before any policy is applied.

    Only ``adrian.init`` is called: no mode is applied to the WebSocket
    client, it is not marked connected, and no verdict is registered. The
    sync invocation is expected to yield exactly one message containing
    "BLOCKED" without running the tool body.
    """
    invocations: list[str] = []

    def _real_tool(x: str) -> str:
        """Real tool stub for sync block-mode tests."""
        invocations.append(x)
        return x

    adrian.init(
        api_key="k",
        log_file=str(tmp_path / "events.jsonl"),
        auto_instrument=True,
        ws_url="ws://x",
        block_timeout=1.0,
    )

    node = ToolNode([_real_tool])
    request = AIMessage(
        content="",
        tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}],
    )
    graph_state: dict[str, Any] = {"messages": [request]}

    outcome = node.invoke(graph_state, config=_runtime_config())  # pyright: ignore[reportUnknownMemberType]

    # The tool body must never have executed.
    assert invocations == []
    replies = outcome["messages"]
    assert len(replies) == 1
    assert "BLOCKED" in replies[0].content

def test_sync_unresolved_active_policy_halts_tool(self, tmp_path: Path) -> None:
    """Sync ``ToolNode.invoke`` fails closed when no verdict is available.

    MODE_BLOCK with ``policy_m4`` is applied, the client is marked connected
    and tool call ``tc-1`` is mapped to ``llm-evt``, but no pending verdict
    is registered for that event. The sync invocation is expected to yield
    exactly one message containing "BLOCKED" without running the tool body.
    """
    invocations: list[str] = []

    def _real_tool(x: str) -> str:
        """Real tool stub for sync block-mode tests."""
        invocations.append(x)
        return x

    adrian.init(
        api_key="k",
        log_file=str(tmp_path / "events.jsonl"),
        auto_instrument=True,
        ws_url="ws://x",
        block_timeout=1.0,
    )

    client = adrian._ws_client
    assert client is not None
    _apply_mode(client, pb.MODE_BLOCK, policy_m4=True)
    client._connected.set()
    # Deliberately no entry in _pending_verdicts for this event.
    client._tool_call_id_to_event_id["tc-1"] = "llm-evt"

    node = ToolNode([_real_tool])
    request = AIMessage(
        content="",
        tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}],
    )
    graph_state: dict[str, Any] = {"messages": [request]}

    outcome = node.invoke(graph_state, config=_runtime_config())  # pyright: ignore[reportUnknownMemberType]

    # The tool body must never have executed.
    assert invocations == []
    replies = outcome["messages"]
    assert len(replies) == 1
    assert "BLOCKED" in replies[0].content

async def test_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None:
"""MODE_BLOCK with policy_m2=false + mad_code='M2' → continue (out-of-scope)."""

Expand Down Expand Up @@ -233,6 +358,44 @@ async def _real_tool(x: str) -> str:

assert captured == ["hi"]

def test_sync_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None:
    """Sync MODE_BLOCK lets the tool run when the verdict is out of scope.

    Only ``policy_m4`` is enabled, while the pre-resolved verdict carries
    ``mad_code="M2"``; the tool is therefore expected to execute once with
    the argument from the tool call.
    """
    seen_args: list[str] = []

    def _real_tool(x: str) -> str:
        """Real tool stub for sync block-mode tests."""
        seen_args.append(x)
        return x

    adrian.init(
        api_key="k",
        log_file=str(tmp_path / "events.jsonl"),
        auto_instrument=True,
        ws_url="ws://x",
        block_timeout=1.0,
    )

    client = adrian._ws_client
    assert client is not None
    active_policy = _apply_mode(client, pb.MODE_BLOCK, policy_m4=True)  # m2 stays False
    client._connected.set()
    client._tool_call_id_to_event_id["tc-1"] = "llm-evt"
    out_of_scope_verdict = pb.Verdict(
        event_id="llm-evt",
        mad_code="M2",
        policy=active_policy,
    )
    client._pending_verdicts["llm-evt"] = _resolved_verdict_future(out_of_scope_verdict)

    node = ToolNode([_real_tool])
    request = AIMessage(
        content="",
        tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}],
    )
    graph_state: dict[str, Any] = {"messages": [request]}

    node.invoke(graph_state, config=_runtime_config())  # pyright: ignore[reportUnknownMemberType]

    assert seen_args == ["hi"]

async def test_timeout_fail_closed_blocks_tool(self, tmp_path: Path) -> None:
"""Verdict timeout in MODE_BLOCK → fail-closed (tool does NOT run)."""
captured: list[str] = []
Expand Down
19 changes: 19 additions & 0 deletions sdk/python/tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import asyncio
import logging
import os
from collections.abc import Iterator
from pathlib import Path
Expand Down Expand Up @@ -70,6 +71,24 @@ def test_creates_jsonl_file(self, tmp_path: Path) -> None:

assert log.exists()

def test_warns_when_ws_init_has_no_running_loop(
    self,
    tmp_path: Path,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """``init()`` with a ``ws_url`` from sync code should log a warning.

    Captures WARNING-level records on the ``adrian`` logger and checks that
    the captured text mentions the missing running event loop.
    """
    events_path = tmp_path / "events.jsonl"
    caplog.set_level(logging.WARNING, logger="adrian")

    init_kwargs: dict[str, Any] = {
        "api_key": "k",
        "log_file": str(events_path),
        "auto_instrument": False,
        "ws_url": "ws://x",
    }
    adrian.init(**init_kwargs)

    captured_text = caplog.text
    assert "without a running event loop" in captured_text

def test_sync_init_first_async_send_starts_connect_task(self) -> None:
"""First async send should start connect when init() ran without a loop."""
adrian.init(
Expand Down
Loading