Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 62 additions & 30 deletions strix/interface/tui/live_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import json
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast


if TYPE_CHECKING:
Expand All @@ -20,34 +20,47 @@ def __init__(self) -> None:
self.events: list[dict[str, Any]] = []
self._next_event_id = 1
self._open_assistant_event_by_agent: dict[str, dict[str, Any]] = {}
self._tool_event_by_call_id: dict[str, dict[str, Any]] = {}
# Keyed by (agent_id, call_id) so identical call_ids from different
# agents never collide and overwrite each other's tool events.
self._tool_event_by_call_id: dict[tuple[str, str], dict[str, Any]] = {}

def hydrate_from_run_dir(self, run_dir: Path) -> None:
    """Restore agents recorded in ``agents.json`` under the run's state dir.

    The file is looked up in ``runtime_state_dir(run_dir)``. Every entry of
    its ``statuses`` mapping is passed to ``upsert_agent`` together with the
    matching ``names`` / ``parent_of`` values, after which the agent ids are
    handed to ``_hydrate_sdk_session_history``.

    Hydration is best-effort: a missing, unreadable or malformed file, or a
    file without any statuses, leaves the view untouched.

    Args:
        run_dir: Directory of the run whose persisted state is loaded.
    """
    state_dir = runtime_state_dir(run_dir)
    agents_path = state_dir / "agents.json"
    if not agents_path.exists():
        return
    try:
        raw_data: Any = json.loads(agents_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and also the
        # UnicodeDecodeError that read_text raises on invalid UTF-8.
        return

    # The top-level JSON value may be a list or scalar; only a dict is usable.
    agents_data: dict[str, Any] = (
        cast("dict[str, Any]", raw_data) if isinstance(raw_data, dict) else {}
    )

    statuses = _as_str_any_dict(agents_data.get("statuses"))
    names = _as_str_any_dict(agents_data.get("names"))
    parent_of = _as_str_any_dict(agents_data.get("parent_of"))

    if not statuses:
        return

    for agent_id, status in statuses.items():
        # Fall back to the agent id when no string name was stored.
        name_val = names.get(agent_id, agent_id)
        name = name_val if isinstance(name_val, str) else agent_id

        # A parent id that is missing or not a string is treated as "no parent".
        parent_id_val = parent_of.get(agent_id)
        parent_id = parent_id_val if isinstance(parent_id_val, str) else None

        self.upsert_agent(
            agent_id,
            name=name,
            parent_id=parent_id,
            status=str(status),
        )
    self._hydrate_sdk_session_history(run_dir, list(statuses.keys()))

def _hydrate_sdk_session_history(self, run_dir: Path, agent_ids: Any) -> None:
def _hydrate_sdk_session_history(self, run_dir: Path, agent_ids: list[str]) -> None:
for agent_id, item, timestamp in load_session_history(run_dir, agent_ids):
self._ingest_session_history_item(
agent_id,
Expand Down Expand Up @@ -212,19 +225,27 @@ def _record_tool_call_data(
timestamp: str | None = None,
) -> None:
call_id = call["call_id"]
existing = self._tool_event_by_call_id.get(call_id)
tool_data = {
"tool_name": call["tool_name"],
"args": call["args"],
"status": "running",
"agent_id": agent_id,
"call_id": call_id,
}
key = (agent_id, call_id)
existing = self._tool_event_by_call_id.get(key)

if existing is None:
tool_data: dict[str, Any] = {
"tool_name": call["tool_name"],
"args": call["args"],
"status": "running",
"agent_id": agent_id,
"call_id": call_id,
}
event = self._append_event(agent_id, "tool", tool_data, timestamp=timestamp)
self._tool_event_by_call_id[call_id] = event
self._tool_event_by_call_id[key] = event
else:
existing["data"].update(tool_data)
# Refresh identifying fields only. Never clobber "status" or
# "result" here -- those are owned by _record_tool_output_data
# once the call has produced output. Overwriting them on a
# replayed/duplicate tool_call_item would regress a completed
# tool back to "running" and orphan its result.
existing["data"]["tool_name"] = call["tool_name"]
existing["data"]["args"] = call["args"]
self._bump_event(existing, timestamp=timestamp)

def _record_tool_output(self, agent_id: str, item: Any) -> None:
Expand All @@ -238,7 +259,8 @@ def _record_tool_output_data(
timestamp: str | None = None,
) -> None:
call_id = output["call_id"]
event = self._tool_event_by_call_id.get(call_id)
key = (agent_id, call_id)
event = self._tool_event_by_call_id.get(key)
if event is None:
event = self._append_event(
agent_id,
Expand All @@ -252,7 +274,7 @@ def _record_tool_output_data(
},
timestamp=timestamp,
)
self._tool_event_by_call_id[call_id] = event
self._tool_event_by_call_id[key] = event

result = _parse_json_value(output["output"])
event["data"]["result"] = result
Expand All @@ -267,7 +289,7 @@ def _append_event(
*,
timestamp: str | None = None,
) -> dict[str, Any]:
event = {
event: dict[str, Any] = {
"id": f"{event_type}_{self._next_event_id}",
"type": event_type,
"agent_id": agent_id,
Expand All @@ -285,6 +307,13 @@ def _bump_event(event: dict[str, Any], *, timestamp: str | None = None) -> None:
event["timestamp"] = timestamp or datetime.now(UTC).isoformat()


def _as_str_any_dict(value: Any) -> dict[str, Any]:
    """Narrow an Any-typed value to ``dict[str, Any]``, defaulting to ``{}``.

    Args:
        value: Arbitrary decoded data, typically a field of a JSON document.

    Returns:
        A new dict with the entries of *value* whose key is a ``str``, or an
        empty dict when *value* is not a dict at all.
    """
    if not isinstance(value, dict):
        return {}
    entries = cast("dict[Any, Any]", value)
    # A bare cast only silences the type checker; filtering on the key type
    # makes the declared return type hold at runtime as well.
    return {key: item for key, item in entries.items() if isinstance(key, str)}


def _sdk_tool_call_data(item: Any) -> dict[str, Any]:
raw = getattr(item, "raw_item", None)
call_id = str(_raw_field(raw, "call_id") or _raw_field(raw, "id") or id(item))
Expand Down Expand Up @@ -319,7 +348,7 @@ def _session_message_text(item: dict[str, Any]) -> str:

def _message_content_text(content: Any) -> str:
parts: list[str] = []
content_items = content if isinstance(content, list) else [content]
content_items: list[Any] = content if isinstance(content, list) else [content]
for part in content_items:
if isinstance(part, str):
parts.append(part)
Expand All @@ -332,13 +361,16 @@ def _message_content_text(content: Any) -> str:

def _raw_field(raw: Any, key: str, default: Any = None) -> Any:
    """Read *key* from *raw*, which may be a dict or an attribute-style object.

    Args:
        raw: A dict, or any object exposing the field as an attribute.
        key: Dict key or attribute name to look up.
        default: Value returned when the key/attribute is absent.

    Returns:
        ``raw[key]`` for dicts, ``getattr(raw, key)`` otherwise, or *default*
        when the field is missing.
    """
    if isinstance(raw, dict):
        raw_dict = cast("dict[str, Any]", raw)
        return raw_dict.get(key, default)
    return getattr(raw, key, default)


def _parse_json_object(value: Any) -> dict[str, Any]:
    """Parse *value* with ``_parse_json_value`` and keep the result only if it is a dict.

    Returns:
        The parsed dict, or ``{}`` when parsing yields anything else.
    """
    parsed = _parse_json_value(value)
    if isinstance(parsed, dict):
        return cast("dict[str, Any]", parsed)
    return {}


def _parse_json_value(value: Any) -> Any:
Expand All @@ -351,6 +383,6 @@ def _parse_json_value(value: Any) -> Any:


def _tool_status_from_result(result: Any) -> str:
    """Map a tool result to the status string stored on its event.

    Returns:
        ``"failed"`` only when *result* is a dict whose ``"success"`` entry is
        exactly ``False``; ``"completed"`` in every other case.
    """
    # Identity check on purpose: a missing key or a falsy non-bool value
    # (0, "", None) does not count as a failure.
    if isinstance(result, dict) and cast("dict[str, Any]", result).get("success") is False:
        return "failed"
    return "completed"
59 changes: 59 additions & 0 deletions tests/test_live_view.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Tests verifying the tool-event agent-isolation and status-regression fixes."""

from types import SimpleNamespace
from typing import Any

from strix.interface.tui.live_view import TuiLiveView


def _call_item(call_id: str, name: str = "tool") -> Any:
    """Build a stand-in tool-call item for the tests.

    The returned namespace has ``type="tool_call_item"``, ``title=None`` and a
    ``raw_item`` carrying the call id (as both ``call_id`` and ``id``), the
    tool name, empty JSON arguments and ``type="function_call"``.
    """
    raw = SimpleNamespace(
        call_id=call_id,
        id=call_id,
        name=name,
        arguments="{}",
        type="function_call",
    )
    return SimpleNamespace(type="tool_call_item", raw_item=raw, title=None)


def _output_item(call_id: str, name: str = "tool", output: str = '{"success": true}') -> Any:
    """Build a stand-in tool-output item for the tests.

    The returned namespace has ``type="tool_call_output_item"``, the given
    *output* string, and a ``raw_item`` carrying the call id (as both
    ``call_id`` and ``id``), the tool name and ``type="function_call_output"``.
    """
    raw = SimpleNamespace(call_id=call_id, id=call_id, name=name, type="function_call_output")
    return SimpleNamespace(type="tool_call_output_item", raw_item=raw, output=output)


def test_same_call_id_different_agents_stay_isolated() -> None:
    """Two agents reusing one call id must each keep their own tool event."""
    view = TuiLiveView()
    view._record_tool_call("agent-A", _call_item("shared-id", "read_file"))
    view._record_tool_call("agent-B", _call_item("shared-id", "write_file"))

    # One event per agent, each still carrying the tool name that agent sent.
    assert len(view.events_for_agent("agent-A")) == 1
    assert len(view.events_for_agent("agent-B")) == 1
    assert view.events_for_agent("agent-A")[0]["data"]["tool_name"] == "read_file"
    assert view.events_for_agent("agent-B")[0]["data"]["tool_name"] == "write_file"


def test_replayed_tool_call_does_not_regress_completed_status() -> None:
    """A duplicate tool call arriving after the output must keep status and result."""
    view = TuiLiveView()
    view._record_tool_call("agent-A", _call_item("id-1"))
    view._record_tool_output(
        "agent-A",
        _output_item("id-1", output='{"success": true, "x": 1}'),
    )

    # simulate a duplicate/replayed tool_call_item
    view._record_tool_call("agent-A", _call_item("id-1"))

    event = view.events_for_agent("agent-A")[0]
    assert event["data"]["status"] == "completed"
    assert event["data"]["result"] == {"success": True, "x": 1}


def test_failed_result_keeps_failed_status_after_replay() -> None:
    """A replayed tool call must not turn a failed tool event back to another status."""
    view = TuiLiveView()
    view._record_tool_call("agent-A", _call_item("id-2"))
    view._record_tool_output("agent-A", _output_item("id-2", output='{"success": false}'))
    view._record_tool_call("agent-A", _call_item("id-2"))  # replay

    event = view.events_for_agent("agent-A")[0]
    assert event["data"]["status"] == "failed"