diff --git a/strix/interface/tui/live_view.py b/strix/interface/tui/live_view.py index 993074d67..77388676d 100644 --- a/strix/interface/tui/live_view.py +++ b/strix/interface/tui/live_view.py @@ -4,7 +4,7 @@ import json from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: @@ -20,7 +20,9 @@ def __init__(self) -> None: self.events: list[dict[str, Any]] = [] self._next_event_id = 1 self._open_assistant_event_by_agent: dict[str, dict[str, Any]] = {} - self._tool_event_by_call_id: dict[str, dict[str, Any]] = {} + # Keyed by (agent_id, call_id) so identical call_ids from different + # agents never collide and overwrite each other's tool events. + self._tool_event_by_call_id: dict[tuple[str, str], dict[str, Any]] = {} def hydrate_from_run_dir(self, run_dir: Path) -> None: state_dir = runtime_state_dir(run_dir) @@ -28,26 +30,37 @@ def hydrate_from_run_dir(self, run_dir: Path) -> None: if not agents_path.exists(): return try: - agents_data = json.loads(agents_path.read_text(encoding="utf-8")) + raw_data: Any = json.loads(agents_path.read_text(encoding="utf-8")) except (OSError, json.JSONDecodeError): return - statuses = agents_data.get("statuses") or {} - names = agents_data.get("names") or {} - parent_of = agents_data.get("parent_of") or {} - if not isinstance(statuses, dict): + + agents_data: dict[str, Any] = ( + cast("dict[str, Any]", raw_data) if isinstance(raw_data, dict) else {} + ) + + statuses = _as_str_any_dict(agents_data.get("statuses")) + names = _as_str_any_dict(agents_data.get("names")) + parent_of = _as_str_any_dict(agents_data.get("parent_of")) + + if not statuses: return + for agent_id, status in statuses.items(): - if not isinstance(agent_id, str): - continue + name_val = names.get(agent_id, agent_id) + name = name_val if isinstance(name_val, str) else agent_id + + parent_id_val = parent_of.get(agent_id) + parent_id = parent_id_val if isinstance(parent_id_val, str) else None + self.upsert_agent( agent_id, - name=names.get(agent_id, agent_id) if isinstance(names, dict) else agent_id, - parent_id=parent_of.get(agent_id) if isinstance(parent_of, dict) else None, + name=name, + parent_id=parent_id, status=str(status), ) - self._hydrate_sdk_session_history(run_dir, statuses.keys()) + self._hydrate_sdk_session_history(run_dir, list(statuses.keys())) - def _hydrate_sdk_session_history(self, run_dir: Path, agent_ids: Any) -> None: + def _hydrate_sdk_session_history(self, run_dir: Path, agent_ids: list[str]) -> None: for agent_id, item, timestamp in load_session_history(run_dir, agent_ids): self._ingest_session_history_item( agent_id, @@ -212,19 +225,27 @@ def _record_tool_call_data( timestamp: str | None = None, ) -> None: call_id = call["call_id"] - existing = self._tool_event_by_call_id.get(call_id) - tool_data = { - "tool_name": call["tool_name"], - "args": call["args"], - "status": "running", - "agent_id": agent_id, - "call_id": call_id, - } + key = (agent_id, call_id) + existing = self._tool_event_by_call_id.get(key) + if existing is None: + tool_data: dict[str, Any] = { + "tool_name": call["tool_name"], + "args": call["args"], + "status": "running", + "agent_id": agent_id, + "call_id": call_id, + } event = self._append_event(agent_id, "tool", tool_data, timestamp=timestamp) - self._tool_event_by_call_id[call_id] = event + self._tool_event_by_call_id[key] = event else: - existing["data"].update(tool_data) + # Refresh identifying fields only. Never clobber "status" or + # "result" here -- those are owned by _record_tool_output_data + # once the call has produced output. Overwriting them on a + # replayed/duplicate tool_call_item would regress a completed + # tool back to "running" and orphan its result. + existing["data"]["tool_name"] = call["tool_name"] + existing["data"]["args"] = call["args"] self._bump_event(existing, timestamp=timestamp) def _record_tool_output(self, agent_id: str, item: Any) -> None: @@ -238,7 +259,8 @@ def _record_tool_output_data( timestamp: str | None = None, ) -> None: call_id = output["call_id"] - event = self._tool_event_by_call_id.get(call_id) + key = (agent_id, call_id) + event = self._tool_event_by_call_id.get(key) if event is None: event = self._append_event( agent_id, @@ -252,7 +274,7 @@ def _record_tool_output_data( }, timestamp=timestamp, ) - self._tool_event_by_call_id[call_id] = event + self._tool_event_by_call_id[key] = event result = _parse_json_value(output["output"]) event["data"]["result"] = result @@ -267,7 +289,7 @@ def _append_event( *, timestamp: str | None = None, ) -> dict[str, Any]: - event = { + event: dict[str, Any] = { "id": f"{event_type}_{self._next_event_id}", "type": event_type, "agent_id": agent_id, @@ -285,6 +307,13 @@ def _bump_event(event: dict[str, Any], *, timestamp: str | None = None) -> None: event["timestamp"] = timestamp or datetime.now(UTC).isoformat() +def _as_str_any_dict(value: Any) -> dict[str, Any]: + """Narrow an Any-typed value to dict[str, Any], defaulting to {} otherwise.""" + if isinstance(value, dict): + return cast("dict[str, Any]", value) + return {} + + def _sdk_tool_call_data(item: Any) -> dict[str, Any]: raw = getattr(item, "raw_item", None) call_id = str(_raw_field(raw, "call_id") or _raw_field(raw, "id") or id(item)) @@ -319,7 +348,7 @@ def _session_message_text(item: dict[str, Any]) -> str: def _message_content_text(content: Any) -> str: parts: list[str] = [] - content_items = content if isinstance(content, list) else [content] + content_items: list[Any] = content if isinstance(content, list) else [content] for part in content_items: if isinstance(part, str): parts.append(part) @@ -332,13 +361,16 @@ def _message_content_text(content: Any) -> str: def _raw_field(raw: Any, key: str, default: Any = None) -> Any: if isinstance(raw, dict): - return raw.get(key, default) + raw_dict = cast("dict[str, Any]", raw) + return raw_dict.get(key, default) return getattr(raw, key, default) def _parse_json_object(value: Any) -> dict[str, Any]: parsed = _parse_json_value(value) - return parsed if isinstance(parsed, dict) else {} + if isinstance(parsed, dict): + return cast("dict[str, Any]", parsed) + return {} def _parse_json_value(value: Any) -> Any: @@ -351,6 +383,6 @@ def _parse_json_value(value: Any) -> Any: def _tool_status_from_result(result: Any) -> str: - if isinstance(result, dict) and result.get("success") is False: + if isinstance(result, dict) and cast("dict[str, Any]", result).get("success") is False: return "failed" return "completed" diff --git a/tests/test_live_view.py b/tests/test_live_view.py new file mode 100644 index 000000000..9b0eebd96 --- /dev/null +++ b/tests/test_live_view.py @@ -0,0 +1,59 @@ +"""Tests verifying the tool-event agent-isolation and status-regression fixes.""" + +from types import SimpleNamespace +from typing import Any + +from strix.interface.tui.live_view import TuiLiveView + + +def _call_item(call_id: str, name: str = "tool") -> Any: + raw = SimpleNamespace( + call_id=call_id, + id=call_id, + name=name, + arguments="{}", + type="function_call", + ) + return SimpleNamespace(type="tool_call_item", raw_item=raw, title=None) + + +def _output_item(call_id: str, name: str = "tool", output: str = '{"success": true}') -> Any: + raw = SimpleNamespace(call_id=call_id, id=call_id, name=name, type="function_call_output") + return SimpleNamespace(type="tool_call_output_item", raw_item=raw, output=output) + + +def test_same_call_id_different_agents_stay_isolated() -> None: + view = TuiLiveView() + view._record_tool_call("agent-A", _call_item("shared-id", "read_file")) + view._record_tool_call("agent-B", _call_item("shared-id", "write_file")) + + assert len(view.events_for_agent("agent-A")) == 1 + assert len(view.events_for_agent("agent-B")) == 1 + assert view.events_for_agent("agent-A")[0]["data"]["tool_name"] == "read_file" + assert view.events_for_agent("agent-B")[0]["data"]["tool_name"] == "write_file" + + +def test_replayed_tool_call_does_not_regress_completed_status() -> None: + view = TuiLiveView() + view._record_tool_call("agent-A", _call_item("id-1")) + view._record_tool_output( + "agent-A", + _output_item("id-1", output='{"success": true, "x": 1}'), + ) + + # simulate a duplicate/replayed tool_call_item + view._record_tool_call("agent-A", _call_item("id-1")) + + event = view.events_for_agent("agent-A")[0] + assert event["data"]["status"] == "completed" + assert event["data"]["result"] == {"success": True, "x": 1} + + +def test_failed_result_keeps_failed_status_after_replay() -> None: + view = TuiLiveView() + view._record_tool_call("agent-A", _call_item("id-2")) + view._record_tool_output("agent-A", _output_item("id-2", output='{"success": false}')) + view._record_tool_call("agent-A", _call_item("id-2")) # replay + + event = view.events_for_agent("agent-A")[0] + assert event["data"]["status"] == "failed"