diff --git a/.gitignore b/.gitignore index 49db86ed..246d7df9 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ reports/ .DS_Store /chat /askui_chat.db +.cache/ diff --git a/docs/caching.md b/docs/caching.md new file mode 100644 index 00000000..3730e628 --- /dev/null +++ b/docs/caching.md @@ -0,0 +1,262 @@ +# Caching + +The caching mechanism allows you to record and replay agent action sequences (trajectories) for faster and more robust test execution. This feature is particularly useful for regression testing, where you want to replay known-good interaction sequences to verify that your application still behaves correctly. + +## Overview + +The caching system works by recording all tool use actions (mouse movements, clicks, typing, etc.) performed by the agent during an `act()` execution. These recorded sequences can then be replayed in subsequent executions, allowing the agent to skip the decision-making process and execute the actions directly. + +## Caching Strategies + +The caching mechanism supports four strategies, configured via the `caching_settings` parameter in the `act()` method: + +- **`"no"`** (default): No caching is used. The agent executes normally without recording or replaying actions. +- **`"write"`**: Records all agent actions to a cache file for future replay. +- **`"read"`**: Provides tools to the agent to list and execute previously cached trajectories. +- **`"both"`**: Combines read and write modes - the agent can use existing cached trajectories and will also record new ones. 
+ +## Configuration + +Caching is configured using the `CachingSettings` class: + +```python +from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings + +caching_settings = CachingSettings( + strategy="write", # One of: "read", "write", "both", "no" + cache_dir=".cache", # Directory to store cache files + filename="my_test.json", # Filename for the cache file (optional for write mode) + execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( + delay_time_between_action=0.5 # Delay in seconds between each cached action + ) +) +``` + +### Parameters + +- **`strategy`**: The caching strategy to use (`"read"`, `"write"`, `"both"`, or `"no"`). +- **`cache_dir`**: Directory where cache files are stored. Defaults to `".cache"`. +- **`filename`**: Name of the cache file to write to or read from. If not specified in write mode, a timestamped filename will be generated automatically (format: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`). +- **`execute_cached_trajectory_tool_settings`**: Configuration for the trajectory execution tool (optional). See [Execution Settings](#execution-settings) below. + +### Execution Settings + +The `CachedExecutionToolSettings` class allows you to configure how cached trajectories are executed: + +```python +from askui.models.shared.settings import CachedExecutionToolSettings + +execution_settings = CachedExecutionToolSettings( + delay_time_between_action=0.5 # Delay in seconds between each action (default: 0.5) +) +``` + +#### Parameters + +- **`delay_time_between_action`**: The time to wait (in seconds) between executing consecutive cached actions. This delay helps ensure UI elements have time to respond before the next action is executed. Defaults to `0.5` seconds. 
+ +You can adjust this value based on your application's responsiveness: +- For faster applications or quick interactions, you might use a smaller delay (e.g., `0.1` or `0.2` seconds) +- For slower applications or complex UI updates, you might need a longer delay (e.g., `1.0` or `2.0` seconds) + +## Usage Examples + +### Writing a Cache (Recording) + +Record agent actions to a cache file for later replay: + +```python +from askui import VisionAgent +from askui.models.shared.settings import CachingSettings + +with VisionAgent() as agent: + agent.act( + goal="Fill out the login form with username 'admin' and password 'secret123'", + caching_settings=CachingSettings( + strategy="write", + cache_dir=".cache", + filename="login_test.json" + ) + ) +``` + +After execution, a cache file will be created at `.cache/login_test.json` containing all the tool use actions performed by the agent. + +### Reading from Cache (Replaying) + +Provide the agent with access to previously recorded trajectories: + +```python +from askui import VisionAgent +from askui.models.shared.settings import CachingSettings + +with VisionAgent() as agent: + agent.act( + goal="Fill out the login form", + caching_settings=CachingSettings( + strategy="read", + cache_dir=".cache" + ) + ) +``` + +When using `strategy="read"`, the agent receives two additional tools: + +1. **`retrieve_available_trajectories_tool`**: Lists all available cache files in the cache directory +2. **`execute_cached_executions_tool`**: Executes a specific cached trajectory + +The agent will automatically check if a relevant cached trajectory exists and use it if appropriate. After executing a cached trajectory, the agent will verify the results and make corrections if needed. 
+ +### Using Custom Execution Settings + +You can customize the delay between cached actions to match your application's responsiveness: + +```python +from askui import VisionAgent +from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings + +with VisionAgent() as agent: + agent.act( + goal="Fill out the login form", + caching_settings=CachingSettings( + strategy="read", + cache_dir=".cache", + execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( + delay_time_between_action=1.0 # Wait 1 second between each action + ) + ) + ) +``` + +This is particularly useful when: +- Your application has animations or transitions that need time to complete +- UI elements take time to become interactive after appearing +- You're testing on slower hardware or environments + +### Using Both Strategies + +Enable both reading and writing simultaneously: + +```python +from askui import VisionAgent +from askui.models.shared.settings import CachingSettings + +with VisionAgent() as agent: + agent.act( + goal="Complete the checkout process", + caching_settings=CachingSettings( + strategy="both", + cache_dir=".cache", + filename="checkout_test.json" + ) + ) +``` + +In this mode: +- The agent can use existing cached trajectories to speed up execution +- New actions will be recorded to the specified cache file +- If a cached execution is used, no new cache file will be written (to avoid duplicates) + +## Cache File Format + +Cache files are JSON files containing an array of tool use blocks. 
Each block represents a single tool invocation with the following structure: + +```json +[ + { + "type": "tool_use", + "id": "toolu_01AbCdEfGhIjKlMnOpQrStUv", + "name": "computer", + "input": { + "action": "mouse_move", + "coordinate": [150, 200] + } + }, + { + "type": "tool_use", + "id": "toolu_02AbCdEfGhIjKlMnOpQrStUv", + "name": "computer", + "input": { + "action": "left_click" + } + }, + { + "type": "tool_use", + "id": "toolu_03AbCdEfGhIjKlMnOpQrStUv", + "name": "computer", + "input": { + "action": "type", + "text": "admin" + } + } +] +``` + +Note: Screenshot actions are excluded from cached trajectories as they don't modify the UI state. + +## How It Works + +### Write Mode + +In write mode, the `CacheWriter` class: + +1. Intercepts all assistant messages via a callback function +2. Extracts tool use blocks from the messages +3. Stores them in memory during execution +4. Writes them to a JSON file when the agent finishes (on `stop_reason="end_turn"`) +5. Automatically skips writing if a cached execution was used (to avoid recording replays) + +### Read Mode + +In read mode: + +1. Two caching tools are added to the agent's toolbox +2. A special system prompt (`CACHE_USE_PROMPT`) is appended to instruct the agent on how to use trajectories +3. The agent can call `retrieve_available_trajectories_tool` to see available cache files +4. The agent can call `execute_cached_executions_tool` with a trajectory file path to replay it +5. During replay, each tool use block is executed sequentially with a configurable delay between actions (default: 0.5 seconds) +6. Screenshot and trajectory retrieval tools are skipped during replay +7. The agent is instructed to verify results after replay and make corrections if needed + +The delay between actions can be customized using `CachedExecutionToolSettings` to accommodate different application response times. 
+ + ## Limitations + + - **UI State Sensitivity**: Cached trajectories assume the UI is in the same state as when they were recorded. If the UI has changed, the replay may fail or produce incorrect results. + - **No on_message Callback**: When using `strategy="write"` or `strategy="both"`, you cannot provide a custom `on_message` callback, as the caching system uses this callback to record actions. + - **Verification Required**: After executing a cached trajectory, the agent should verify that the results are correct, as UI changes may cause partial failures. + + ## Example: Complete Test Workflow + + Here's a complete example showing how to record and replay a test: + + ```python + from askui import VisionAgent + from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings + + # Step 1: Record a successful login flow + print("Recording login flow...") + with VisionAgent() as agent: + agent.act( + goal="Navigate to the login page and log in with username 'testuser' and password 'testpass123'", + caching_settings=CachingSettings( + strategy="write", + cache_dir="test_cache", + filename="user_login.json" + ) + ) + + # Step 2: Later, replay the login flow for regression testing + print("\nReplaying login flow for regression test...") + with VisionAgent() as agent: + agent.act( + goal="Log in to the application", + caching_settings=CachingSettings( + strategy="read", + cache_dir="test_cache", + execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( + delay_time_between_action=1.0 + ) + ) + ) + ``` diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index bd4fa826..6c5a7b48 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -4,6 +4,7 @@ from abc import ABC from typing import Annotated, Literal, Optional, Type, overload +from anthropic.types.beta import BetaTextBlockParam from dotenv import load_dotenv from pydantic import ConfigDict, Field, field_validator, validate_call from pydantic_settings import BaseSettings,
SettingsConfigDict @@ -14,11 +15,17 @@ from askui.locators.locators import Locator from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb -from askui.models.shared.settings import ActSettings +from askui.models.shared.settings import ActSettings, CachingSettings from askui.models.shared.tools import Tool, ToolCollection +from askui.prompts.caching import CACHE_USE_PROMPT from askui.tools.agent_os import AgentOs from askui.tools.android.agent_os import AndroidAgentOs +from askui.tools.caching_tools import ( + ExecuteCachedTrajectory, + RetrieveCachedTestExecutions, +) from askui.utils.annotation_writer import AnnotationWriter +from askui.utils.cache_writer import CacheWriter from askui.utils.image_utils import ImageSource from askui.utils.source_utils import InputSource, load_image_source @@ -180,6 +187,7 @@ def act( on_message: OnMessageCb | None = None, tools: list[Tool] | ToolCollection | None = None, settings: ActSettings | None = None, + caching_settings: CachingSettings | None = None, ) -> None: """ Instructs the agent to achieve a specified goal through autonomous actions. @@ -194,11 +202,17 @@ def act( model (str | None, optional): The composition or name of the model(s) to be used for achieving the `goal`. on_message (OnMessageCb | None, optional): Callback for new messages. If - it returns `None`, stops and does not add the message. + it returns `None`, stops and does not add the message. Cannot be used + with caching_settings strategy "write" or "both". tools (list[Tool] | ToolCollection | None, optional): The tools for the agent. Defaults to default tools depending on the selected model. settings (AgentSettings | None, optional): The settings for the agent. Defaults to a default settings depending on the selected model. + caching_settings (CachingSettings | None, optional): The caching settings + for the act execution. 
Controls recording and replaying of action + sequences (trajectories). Available strategies: "no" (default, no + caching), "write" (record actions to cache file), "read" (replay from + cached trajectories), "both" (read and write). Defaults to no caching. Returns: None @@ -207,8 +221,11 @@ def act( MaxTokensExceededError: If the model reaches the maximum token limit defined in the agent settings. ModelRefusalError: If the model refuses to process the request. + ValueError: If on_message callback is provided with caching strategy + "write" or "both". Example: + Basic usage without caching: ```python from askui import VisionAgent @@ -217,6 +234,58 @@ def act( agent.act("Search for 'printer' in the search box") agent.act("Log in with username 'admin' and password '1234'") ``` + + Recording actions to a cache file: + ```python + from askui import VisionAgent + from askui.models.shared.settings import CachingSettings + + with VisionAgent() as agent: + agent.act( + goal=( + "Fill out the login form with " + "username 'admin' and password 'secret123'" + ), + caching_settings=CachingSettings( + strategy="write", + cache_dir=".cache", + filename="login_flow.json" + ) + ) + ``` + + Replaying cached actions: + ```python + from askui import VisionAgent + from askui.models.shared.settings import CachingSettings + + with VisionAgent() as agent: + agent.act( + goal="Log in to the application", + caching_settings=CachingSettings( + strategy="read", + cache_dir=".cache" + ) + ) + # Agent will automatically find and use "login_flow.json" + ``` + + Using both read and write modes: + ```python + from askui import VisionAgent + from askui.models.shared.settings import CachingSettings + + with VisionAgent() as agent: + agent.act( + goal="Complete the checkout process", + caching_settings=CachingSettings( + strategy="both", + cache_dir=".cache", + filename="checkout.json" + ) + ) + # Agent can use existing caches and will record new actions + ``` """ goal_str = ( goal @@ -232,7 +301,19 
@@ def act( ) _model = self._get_model(model, "act") _settings = settings or self._get_default_settings_for_act(_model) + + _caching_settings: CachingSettings = ( + caching_settings or self._get_default_caching_settings_for_act(_model) + ) + + tools, on_message, cached_execution_tool = self._patch_act_with_cache( + _caching_settings, _settings, tools, on_message + ) _tools = self._build_tools(tools, _model) + + if cached_execution_tool: + cached_execution_tool.set_toolbox(_tools) + self._model_router.act( messages=messages, model=_model, @@ -251,9 +332,79 @@ def _build_tools( return ToolCollection(default_tools) + tools return ToolCollection(tools=default_tools) + def _patch_act_with_cache( + self, + caching_settings: CachingSettings, + settings: ActSettings, + tools: list[Tool] | ToolCollection | None, + on_message: OnMessageCb | None, + ) -> tuple[ + list[Tool] | ToolCollection, OnMessageCb | None, ExecuteCachedTrajectory | None + ]: + """Patch act settings and tools with caching functionality. 
+ + Args: + caching_settings: The caching settings to apply + settings: The act settings to modify + tools: The tools list to extend with caching tools + on_message: The message callback (may be replaced for write mode) + + Returns: + A tuple of (modified_tools, modified_on_message, cached_execution_tool) + """ + caching_tools: list[Tool] = [] + cached_execution_tool: ExecuteCachedTrajectory | None = None + + # Setup read mode: add caching tools and modify system prompt + if caching_settings.strategy in ["read", "both"]: + cached_execution_tool = ExecuteCachedTrajectory( + caching_settings.execute_cached_trajectory_tool_settings + ) + caching_tools.extend( + [ + RetrieveCachedTestExecutions(caching_settings.cache_dir), + cached_execution_tool, + ] + ) + if isinstance(settings.messages.system, str): + settings.messages.system = ( + settings.messages.system + "\n" + CACHE_USE_PROMPT + ) + elif isinstance(settings.messages.system, list): + # Append as a new text block + settings.messages.system = settings.messages.system + [ + BetaTextBlockParam(type="text", text=CACHE_USE_PROMPT) + ] + else: # Omit or None + settings.messages.system = CACHE_USE_PROMPT + + # Add caching tools to the tools list + if isinstance(tools, list): + tools = caching_tools + tools + elif isinstance(tools, ToolCollection): + tools.append_tool(*caching_tools) + else: + tools = caching_tools + + # Setup write mode: create cache writer and set message callback + if caching_settings.strategy in ["write", "both"]: + cache_writer = CacheWriter( + caching_settings.cache_dir, caching_settings.filename + ) + if on_message is None: + on_message = cache_writer.add_message_cb + else: + error_message = "Cannot use on_message callback when writing Cache" + raise ValueError(error_message) + + return tools, on_message, cached_execution_tool + def _get_default_settings_for_act(self, model: str) -> ActSettings: # noqa: ARG002 return ActSettings() + def _get_default_caching_settings_for_act(self, model: str) -> 
CachingSettings: # noqa: ARG002 + return CachingSettings() + def _get_default_tools_for_act(self, model: str) -> list[Tool]: # noqa: ARG002 return self._tools diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index 968493e1..de2141da 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -6,9 +6,12 @@ BetaToolChoiceParam, ) from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Literal COMPUTER_USE_20250124_BETA_FLAG = "computer-use-2025-01-24" +CACHING_STRATEGY = Literal["read", "write", "both", "no"] + class MessageSettings(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -25,3 +28,16 @@ class ActSettings(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) messages: MessageSettings = Field(default_factory=MessageSettings) + + +class CachedExecutionToolSettings(BaseModel): + delay_time_between_action: float = 0.5 + + +class CachingSettings(BaseModel): + strategy: CACHING_STRATEGY = "no" + cache_dir: str = ".cache" + filename: str = "" + execute_cached_trajectory_tool_settings: CachedExecutionToolSettings = ( + CachedExecutionToolSettings() + ) diff --git a/src/askui/prompts/caching.py b/src/askui/prompts/caching.py new file mode 100644 index 00000000..a89cf224 --- /dev/null +++ b/src/askui/prompts/caching.py @@ -0,0 +1,25 @@ +CACHE_USE_PROMPT = ( + "\n" + " You can use precomputed trajectories to make the execution of the " + "task more robust and faster!\n" + " To do so, first use the RetrieveCachedTestExecutions tool to check " + "which trajectories are available for you.\n" + " The details what each trajectory that is available for you does are " + "at the end of this prompt.\n" + " A trajectory contains all necessary mouse movements, clicks, and " + "typing actions from a previously successful execution.\n" + " If there is a trajectory available for a step you need to take, " + "always use it!\n" + " You can execute a 
trajectory with the ExecuteCachedExecution tool.\n" + " After a trajectory was executed, make sure to verify the results! " + "While it works most of the time, occasionally, the execution can be " + "(partly) incorrect. So make sure to verify if everything is filled out " + "as expected, and make corrections where necessary!\n" + " \n" + " \n" + " There are several trajectories available to you.\n" + " Their filename is a unique testID.\n" + " If executed using the ExecuteCachedExecution tool, a trajectory will " + "automatically execute all necessary steps for the test with that id.\n" + " \n" +) diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py new file mode 100644 index 00000000..4abb3a55 --- /dev/null +++ b/src/askui/tools/caching_tools.py @@ -0,0 +1,144 @@ +import logging +import time +from pathlib import Path + +from pydantic import validate_call +from typing_extensions import override + +from ..models.shared.settings import CachedExecutionToolSettings +from ..models.shared.tools import Tool, ToolCollection +from ..utils.cache_writer import CacheWriter + +logger = logging.getLogger() + + +class RetrieveCachedTestExecutions(Tool): + """ + List all available trajectory files that can be used for fast-forward execution + """ + + def __init__(self, cache_dir: str, trajectories_format: str = ".json") -> None: + super().__init__( + name="retrieve_available_trajectories_tool", + description=( + "Use this tool to list all available pre-recorded trajectory " + "files in the trajectories directory. These trajectories " + "represent successful UI interaction sequences that can be " + "replayed using the execute_trajectory_tool. Call this tool " + "first to see which trajectories are available before " + "executing one. The tool returns a list of file paths to " + "available trajectory files." 
+ ), + ) + self._cache_dir = Path(cache_dir) + self._trajectories_format = trajectories_format + + @override + @validate_call + def __call__(self) -> list[str]: # type: ignore + if not Path.is_dir(self._cache_dir): + error_msg = f"Trajectories directory not found: {self._cache_dir}" + logger.error(error_msg) + raise FileNotFoundError(error_msg) + + available = [ + str(f) + for f in self._cache_dir.iterdir() + if str(f).endswith(self._trajectories_format) + ] + + if not available: + warning_msg = f"Warning: No trajectory files found in {self._cache_dir}" + logger.warning(warning_msg) + + return available + + +class ExecuteCachedTrajectory(Tool): + """ + Execute a predefined trajectory to fast-forward through UI interactions + """ + + def __init__(self, settings: CachedExecutionToolSettings | None = None) -> None: + super().__init__( + name="execute_cached_executions_tool", + description=( + "Execute a pre-recorded trajectory to automatically perform a " + "sequence of UI interactions. This tool replays mouse movements, " + "clicks, and typing actions from a previously successful execution.\n\n" + "Before using this tool:\n" + "1. Use retrieve_available_trajectories_tool to see which " + "trajectory files are available\n" + "2. Select the appropriate trajectory file path from the " + "returned list\n" + "3. Pass the full file path to this tool\n\n" + "The trajectory will be executed step-by-step, and you should " + "verify the results afterward. Note: Trajectories may fail if " + "the UI state has changed since they were recorded." 
+ ), + input_schema={ + "type": "object", + "properties": { + "trajectory_file": { + "type": "string", + "description": ( + "Full path to the trajectory file (use " + "retrieve_available_trajectories_tool to find " + "available files)" + ), + }, + }, + "required": ["trajectory_file"], + }, + ) + if not settings: + settings = CachedExecutionToolSettings() + self._settings = settings + + def set_toolbox(self, toolbox: ToolCollection) -> None: + """Set the AgentOS/AskUiControllerClient reference for executing actions.""" + self._toolbox = toolbox + + @override + @validate_call + def __call__(self, trajectory_file: str) -> str: + if not hasattr(self, "_toolbox"): + error_msg = "Toolbox not set. Call set_toolbox() first." + logger.error(error_msg) + raise RuntimeError(error_msg) + + if not Path(trajectory_file).is_file(): + error_msg = ( + f"Trajectory file not found: {trajectory_file}\n" + "Use retrieve_available_trajectories_tool to see available files." + ) + logger.error(error_msg) + raise FileNotFoundError(error_msg) + + # Load and execute trajectory + trajectory = CacheWriter.read_cache_file(Path(trajectory_file)) + info_msg = f"Executing cached trajectory from {trajectory_file}" + logger.info(info_msg) + for step in trajectory: + if ( + "screenshot" in step.name + or step.name == "retrieve_available_trajectories_tool" + ): + continue + try: + self._toolbox.run([step]) + except Exception as e: + error_msg = f"An error occured during the cached execution: {e}" + logger.exception(error_msg) + return ( + f"An error occured while executing the trajectory from " + f"{trajectory_file}. Please verify the UI state and " + "continue without cache." + ) + time.sleep(self._settings.delay_time_between_action) + + logger.info("Finished executing cached trajectory") + return ( + f"Successfully executed trajectory from {trajectory_file}. " + "Please verify the UI state." 
+ ) diff --git a/src/askui/utils/cache_writer.py b/src/askui/utils/cache_writer.py new file mode 100644 index 00000000..36508c73 --- /dev/null +++ b/src/askui/utils/cache_writer.py @@ -0,0 +1,70 @@ +import json +import logging +from datetime import datetime, timezone +from pathlib import Path + +from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam +from askui.models.shared.agent_on_message_cb import OnMessageCbParam + +logger = logging.getLogger(__name__) + + +class CacheWriter: + def __init__(self, cache_dir: str = ".cache", file_name: str = "") -> None: + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(exist_ok=True) + self.messages: list[ToolUseBlockParam] = [] + if file_name and not file_name.endswith(".json"): + file_name += ".json" + self.file_name = file_name + self.was_cached_execution = False + + def add_message_cb(self, param: OnMessageCbParam) -> MessageParam: + """Add a message to cache.""" + if param.message.role == "assistant": + contents = param.message.content + if isinstance(contents, list): + for content in contents: + if isinstance(content, ToolUseBlockParam): + self.messages.append(content) + if content.name == "execute_cached_executions_tool": + self.was_cached_execution = True + if param.message.stop_reason == "end_turn": + self.generate() + + return param.message + + def set_file_name(self, file_name: str) -> None: + if not file_name.endswith(".json"): + file_name += ".json" + self.file_name = file_name + + def reset(self, file_name: str = "") -> None: + self.messages = [] + if file_name and not file_name.endswith(".json"): + file_name += ".json" + self.file_name = file_name + self.was_cached_execution = False + + def generate(self) -> None: + if self.was_cached_execution: + logger.info("Will not write cache file as this was a cached execution") + return + if not self.file_name: + self.file_name = ( + f"cached_trajectory_{datetime.now(tz=timezone.utc):%Y%m%d%H%M%S%f}.json" + ) + cache_file_path = 
self.cache_dir / self.file_name + + messages_json = [m.model_dump() for m in self.messages] + with cache_file_path.open("w", encoding="utf-8") as f: + json.dump(messages_json, f, indent=4) + info_msg = f"Cache File written at {str(cache_file_path)}" + logger.info(info_msg) + self.reset() + + @staticmethod + def read_cache_file(cache_file_path: Path) -> list[ToolUseBlockParam]: + with cache_file_path.open("r", encoding="utf-8") as f: + raw_trajectory = json.load(f) + return [ToolUseBlockParam(**step) for step in raw_trajectory] diff --git a/tests/e2e/agent/test_act_caching.py b/tests/e2e/agent/test_act_caching.py new file mode 100644 index 00000000..70652b43 --- /dev/null +++ b/tests/e2e/agent/test_act_caching.py @@ -0,0 +1,219 @@ +"""Tests for caching functionality in the act method.""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from askui.agent import VisionAgent +from askui.models.shared.agent_message_param import MessageParam +from askui.models.shared.agent_on_message_cb import OnMessageCbParam +from askui.models.shared.settings import CachedExecutionToolSettings, CachingSettings + + +def test_act_with_caching_strategy_read(vision_agent: VisionAgent) -> None: + """Test that caching_strategy='read' adds retrieve and execute tools.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create a dummy cache file + cache_dir = Path(temp_dir) + cache_file = cache_dir / "test_cache.json" + cache_file.write_text("[]", encoding="utf-8") + + # Act with read caching strategy + vision_agent.act( + goal="Tell me a joke", + caching_settings=CachingSettings( + strategy="read", + cache_dir=str(cache_dir), + ), + ) + assert True + + +def test_act_with_caching_strategy_write(vision_agent: VisionAgent) -> None: + """Test that caching_strategy='write' writes cache file.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache_filename = "test_output.json" + + # Act with write caching strategy + vision_agent.act( 
+ goal="Tell me a joke", + caching_settings=CachingSettings( + strategy="write", + cache_dir=str(cache_dir), + filename=cache_filename, + ), + ) + + # Verify cache file was created + cache_file = cache_dir / cache_filename + assert cache_file.exists() + + +def test_act_with_caching_strategy_both(vision_agent: VisionAgent) -> None: + """Test that caching_strategy='both' enables both read and write.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache_filename = "test_both.json" + + # Create a dummy cache file for reading + cache_file = cache_dir / "existing_cache.json" + cache_file.write_text("[]", encoding="utf-8") + + # Act with both caching strategies + vision_agent.act( + goal="Tell me a joke", + caching_settings=CachingSettings( + strategy="both", + cache_dir=str(cache_dir), + filename=cache_filename, + ), + ) + + # Verify new cache file was created + output_file = cache_dir / cache_filename + assert output_file.exists() + + +def test_act_with_caching_strategy_no(vision_agent: VisionAgent) -> None: + """Test that caching_strategy='no' doesn't create cache files.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + + # Act without caching + vision_agent.act( + goal="Tell me a joke", + caching_settings=CachingSettings( + strategy="no", + cache_dir=str(cache_dir), + ), + ) + + # Verify no cache files were created + cache_files = list(cache_dir.glob("*.json")) + assert len(cache_files) == 0 + + +def test_act_with_custom_cache_dir_and_filename(vision_agent: VisionAgent) -> None: + """Test that custom cache_dir and cache_filename are used.""" + with tempfile.TemporaryDirectory() as temp_dir: + custom_cache_dir = Path(temp_dir) / "custom_cache" + custom_filename = "my_custom_cache.json" + + # Act with custom cache settings + vision_agent.act( + goal="Tell me a joke", + caching_settings=CachingSettings( + strategy="write", + cache_dir=str(custom_cache_dir), + filename=custom_filename, + ), + ) + + # 
Verify custom cache directory and file were created + assert custom_cache_dir.exists() + cache_file = custom_cache_dir / custom_filename + assert cache_file.exists() + + +def test_act_with_on_message_and_write_caching_raises_error( + vision_agent: VisionAgent, +) -> None: + """Test that providing on_message callback with write caching raises ValueError.""" + with tempfile.TemporaryDirectory() as temp_dir: + + def dummy_callback(param: OnMessageCbParam) -> MessageParam: + return param.message + + # Should raise ValueError when on_message is provided with write strategy + with pytest.raises(ValueError, match="Cannot use on_message callback"): + vision_agent.act( + goal="Tell me a joke", + caching_settings=CachingSettings( + strategy="write", + cache_dir=str(temp_dir), + ), + on_message=dummy_callback, + ) + + +def test_act_with_on_message_and_both_caching_raises_error( + vision_agent: VisionAgent, +) -> None: + """Test that providing on_message callback with both caching raises ValueError.""" + with tempfile.TemporaryDirectory() as temp_dir: + + def dummy_callback(param: OnMessageCbParam) -> MessageParam: + return param.message + + # Should raise ValueError when on_message is provided with both strategy + with pytest.raises(ValueError, match="Cannot use on_message callback"): + vision_agent.act( + goal="Tell me a joke", + caching_settings=CachingSettings( + strategy="both", + cache_dir=str(temp_dir), + ), + on_message=dummy_callback, + ) + + +def test_cache_file_contains_tool_use_blocks(vision_agent: VisionAgent) -> None: + """Test that cache file contains ToolUseBlockParam entries.""" + with tempfile.TemporaryDirectory() as temp_dir: + cache_dir = Path(temp_dir) + cache_filename = "tool_blocks.json" + + # Act with caching + vision_agent.act( + goal="Tell me a joke", + caching_settings=CachingSettings( + strategy="write", + cache_dir=str(cache_dir), + filename=cache_filename, + ), + ) + + # Read and verify cache file structure + cache_file = cache_dir / 
def test_act_with_custom_cached_execution_tool_settings(
    vision_agent: VisionAgent,
) -> None:
    """Test that custom CachedExecutionToolSettings are accepted by act().

    The call itself is the assertion: act() must complete without raising
    when custom execution-tool settings are attached in read mode.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        cache_dir = Path(temp_dir)

        # Provide an empty trajectory so read mode has something to list.
        cache_file = cache_dir / "test_cache.json"
        cache_file.write_text("[]", encoding="utf-8")

        custom_settings = CachedExecutionToolSettings(delay_time_between_action=2.0)
        # Fixed: dropped the vacuous trailing `assert True` — pytest already
        # passes the test when no exception is raised.
        vision_agent.act(
            goal="Tell me a joke",
            caching_settings=CachingSettings(
                strategy="read",
                cache_dir=str(cache_dir),
                execute_cached_trajectory_tool_settings=custom_settings,
            ),
        )


# --- new file: tests/unit/tools/test_caching_tools.py ---
"""Unit tests for caching tools."""

import json
import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from askui.models.shared.settings import CachedExecutionToolSettings
from askui.models.shared.tools import ToolCollection
from askui.tools.caching_tools import (
    ExecuteCachedTrajectory,
    RetrieveCachedTestExecutions,
)
def test_retrieve_cached_test_executions_lists_json_files() -> None:
    """RetrieveCachedTestExecutions lists exactly the JSON files in the dir."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        for name in ("cache1.json", "cache2.json"):
            (root / name).write_text("{}", encoding="utf-8")
        (root / "not_cache.txt").write_text("text", encoding="utf-8")

        listed = RetrieveCachedTestExecutions(cache_dir=str(root))()

        assert len(listed) == 2
        assert any("cache1.json" in path for path in listed)
        assert any("cache2.json" in path for path in listed)
        # Non-JSON files must be filtered out of the listing.
        assert all("not_cache.txt" not in path for path in listed)


def test_retrieve_cached_test_executions_returns_empty_list_when_no_files() -> None:
    """An existing but empty cache directory yields an empty listing."""
    with tempfile.TemporaryDirectory() as tmp:
        tool = RetrieveCachedTestExecutions(cache_dir=str(Path(tmp)))
        assert tool() == []


def test_retrieve_cached_test_executions_raises_error_when_dir_not_found() -> None:
    """A missing cache directory raises FileNotFoundError on listing."""
    tool = RetrieveCachedTestExecutions(cache_dir="/non/existent/directory")

    with pytest.raises(FileNotFoundError, match="Trajectories directory not found"):
        tool()


def test_retrieve_cached_test_executions_respects_custom_format() -> None:
    """The trajectories_format suffix selects which files are listed."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / "cache1.json").write_text("{}", encoding="utf-8")
        (root / "cache2.traj").write_text("{}", encoding="utf-8")

        # Default format (.json) only sees the JSON file.
        json_result = RetrieveCachedTestExecutions(
            cache_dir=str(root), trajectories_format=".json"
        )()
        assert len(json_result) == 1
        assert "cache1.json" in json_result[0]

        # Custom format (.traj) only sees the .traj file.
        traj_result = RetrieveCachedTestExecutions(
            cache_dir=str(root), trajectories_format=".traj"
        )()
        assert len(traj_result) == 1
        assert "cache2.traj" in traj_result[0]


def test_execute_cached_execution_initializes_without_toolbox() -> None:
    """ExecuteCachedTrajectory can be constructed before a toolbox is attached."""
    assert ExecuteCachedTrajectory().name == "execute_cached_executions_tool"


def test_execute_cached_execution_raises_error_without_toolbox() -> None:
    """Calling the tool before set_toolbox() raises RuntimeError."""
    tool = ExecuteCachedTrajectory()

    with pytest.raises(RuntimeError, match="Toolbox not set"):
        tool(trajectory_file="some_file.json")


def test_execute_cached_execution_raises_error_when_file_not_found() -> None:
    """A missing trajectory file raises FileNotFoundError even with a toolbox."""
    tool = ExecuteCachedTrajectory()
    tool.set_toolbox(MagicMock(spec=ToolCollection))

    with pytest.raises(FileNotFoundError, match="Trajectory file not found"):
        tool(trajectory_file="/non/existent/file.json")


def test_execute_cached_execution_executes_trajectory() -> None:
    """Every tool-use entry in the trajectory file is dispatched to the toolbox."""
    with tempfile.TemporaryDirectory() as tmp:
        trajectory_path = Path(tmp) / "test_trajectory.json"
        steps: list[dict[str, Any]] = [
            {
                "id": "tool1",
                "name": "click_tool",
                "input": {"x": 100, "y": 200},
                "type": "tool_use",
            },
            {
                "id": "tool2",
                "name": "type_tool",
                "input": {"text": "hello"},
                "type": "tool_use",
            },
        ]
        trajectory_path.write_text(json.dumps(steps), encoding="utf-8")

        tool = ExecuteCachedTrajectory()
        toolbox = MagicMock(spec=ToolCollection)
        tool.set_toolbox(toolbox)

        result = tool(trajectory_file=str(trajectory_path))

        assert "Successfully executed trajectory" in result
        # One toolbox.run call per recorded step.
        assert toolbox.run.call_count == 2
def test_execute_cached_execution_skips_screenshot_tools() -> None:
    """Screenshot and trajectory-listing steps are skipped during replay."""
    with tempfile.TemporaryDirectory() as tmp:
        trajectory_path = Path(tmp) / "test_trajectory.json"
        steps: list[dict[str, Any]] = [
            {"id": "tool1", "name": "screenshot", "input": {}, "type": "tool_use"},
            {
                "id": "tool2",
                "name": "click_tool",
                "input": {"x": 100, "y": 200},
                "type": "tool_use",
            },
            {
                "id": "tool3",
                "name": "retrieve_available_trajectories_tool",
                "input": {},
                "type": "tool_use",
            },
        ]
        trajectory_path.write_text(json.dumps(steps), encoding="utf-8")

        tool = ExecuteCachedTrajectory()
        toolbox = MagicMock(spec=ToolCollection)
        tool.set_toolbox(toolbox)

        result = tool(trajectory_file=str(trajectory_path))

        # Only click_tool runs; screenshot and retrieve tools are filtered out.
        assert toolbox.run.call_count == 1
        assert "Successfully executed trajectory" in result


def test_execute_cached_execution_handles_errors_gracefully() -> None:
    """A failing tool aborts replay with an advisory message, not an exception."""
    with tempfile.TemporaryDirectory() as tmp:
        trajectory_path = Path(tmp) / "test_trajectory.json"
        steps: list[dict[str, Any]] = [
            {"id": "tool1", "name": "failing_tool", "input": {}, "type": "tool_use"},
        ]
        trajectory_path.write_text(json.dumps(steps), encoding="utf-8")

        tool = ExecuteCachedTrajectory()
        toolbox = MagicMock(spec=ToolCollection)
        toolbox.run.side_effect = Exception("Tool execution failed")
        tool.set_toolbox(toolbox)

        result = tool(trajectory_file=str(trajectory_path))

        # NOTE(review): "occured" (sic) matches the tool's message text;
        # keep in sync if the tool's spelling is ever fixed.
        assert "error occured" in result.lower()
        assert "verify the UI state" in result


def test_execute_cached_execution_set_toolbox() -> None:
    """set_toolbox() stores the collection on the private _toolbox attribute."""
    tool = ExecuteCachedTrajectory()
    toolbox = MagicMock(spec=ToolCollection)

    tool.set_toolbox(toolbox)

    assert hasattr(tool, "_toolbox")
    assert tool._toolbox == toolbox


def test_execute_cached_execution_initializes_with_default_settings() -> None:
    """Without explicit settings, the documented default delay (0.5s) applies."""
    tool = ExecuteCachedTrajectory()

    assert hasattr(tool, "_settings")
    # Strengthened: checking only hasattr() passes even if defaults are wrong;
    # verify the documented default value as well.
    assert tool._settings.delay_time_between_action == 0.5


def test_execute_cached_execution_initializes_with_custom_settings() -> None:
    """Custom settings passed to the constructor are stored and effective."""
    custom_settings = CachedExecutionToolSettings(delay_time_between_action=1.0)
    tool = ExecuteCachedTrajectory(settings=custom_settings)

    assert hasattr(tool, "_settings")
    # Strengthened: verify the custom delay actually took effect.
    assert tool._settings.delay_time_between_action == 1.0


def test_execute_cached_execution_uses_delay_time_between_actions() -> None:
    """The configured delay is slept once after every replayed action."""
    with tempfile.TemporaryDirectory() as tmp:
        trajectory_path = Path(tmp) / "test_trajectory.json"
        steps: list[dict[str, Any]] = [
            {
                "id": "tool1",
                "name": "click_tool",
                "input": {"x": 100, "y": 200},
                "type": "tool_use",
            },
            {
                "id": "tool2",
                "name": "type_tool",
                "input": {"text": "hello"},
                "type": "tool_use",
            },
            {
                "id": "tool3",
                "name": "move_tool",
                "input": {"x": 300, "y": 400},
                "type": "tool_use",
            },
        ]
        trajectory_path.write_text(json.dumps(steps), encoding="utf-8")

        tool = ExecuteCachedTrajectory(
            settings=CachedExecutionToolSettings(delay_time_between_action=0.1)
        )
        tool.set_toolbox(MagicMock(spec=ToolCollection))

        # Mock time.sleep so the test stays fast while we observe the delays.
        with patch("time.sleep") as mock_sleep:
            result = tool(trajectory_file=str(trajectory_path))

        assert "Successfully executed trajectory" in result
        # One sleep per action, each with the configured 0.1s delay.
        assert mock_sleep.call_count == 3
        assert all(call.args[0] == 0.1 for call in mock_sleep.call_args_list)
def test_execute_cached_execution_default_delay_time() -> None:
    """Without custom settings, replay sleeps 0.5s after each action."""
    with tempfile.TemporaryDirectory() as tmp:
        trajectory_path = Path(tmp) / "test_trajectory.json"
        steps: list[dict[str, Any]] = [
            {
                "id": "tool1",
                "name": "click_tool",
                "input": {"x": 100, "y": 200},
                "type": "tool_use",
            },
            {
                "id": "tool2",
                "name": "type_tool",
                "input": {"text": "hello"},
                "type": "tool_use",
            },
        ]
        trajectory_path.write_text(json.dumps(steps), encoding="utf-8")

        tool = ExecuteCachedTrajectory()
        tool.set_toolbox(MagicMock(spec=ToolCollection))

        # Mock time.sleep so the test stays fast while we observe the delays.
        with patch("time.sleep") as mock_sleep:
            result = tool(trajectory_file=str(trajectory_path))

        assert "Successfully executed trajectory" in result
        # Default delay (0.5s) is applied once per action.
        assert mock_sleep.call_count == 2
        assert all(call.args[0] == 0.5 for call in mock_sleep.call_args_list)


# --- new file: tests/unit/utils/test_cache_writer.py ---
"""Unit tests for CacheWriter utility."""

import json
import tempfile
from pathlib import Path
from typing import Any

from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam
from askui.models.shared.agent_on_message_cb import OnMessageCbParam
from askui.utils.cache_writer import CacheWriter
def test_cache_writer_initialization() -> None:
    """A fresh CacheWriter starts empty with the given dir and filename."""
    with tempfile.TemporaryDirectory() as tmp:
        writer = CacheWriter(cache_dir=tmp, file_name="test.json")
        assert writer.cache_dir == Path(tmp)
        assert writer.file_name == "test.json"
        assert writer.messages == []
        assert writer.was_cached_execution is False


def test_cache_writer_creates_cache_directory() -> None:
    """Constructing a CacheWriter creates a missing cache directory."""
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "new_cache_dir"
        assert not target.exists()

        CacheWriter(cache_dir=str(target))

        assert target.exists()
        assert target.is_dir()


def test_cache_writer_adds_json_extension() -> None:
    """A .json suffix is appended when missing and kept when present."""
    with tempfile.TemporaryDirectory() as tmp:
        without_ext = CacheWriter(cache_dir=tmp, file_name="test")
        assert without_ext.file_name == "test.json"

        with_ext = CacheWriter(cache_dir=tmp, file_name="test.json")
        assert with_ext.file_name == "test.json"


def test_cache_writer_add_message_cb_stores_tool_use_blocks() -> None:
    """add_message_cb records ToolUseBlockParam content of assistant messages."""
    with tempfile.TemporaryDirectory() as tmp:
        writer = CacheWriter(cache_dir=tmp, file_name="test.json")

        block = ToolUseBlockParam(
            id="test_id",
            name="test_tool",
            input={"param": "value"},
            type="tool_use",
        )
        message = MessageParam(
            role="assistant",
            content=[block],
            stop_reason=None,
        )
        param = OnMessageCbParam(message=message, messages=[message])

        returned = writer.add_message_cb(param)

        # The callback passes the message through unchanged...
        assert returned == param.message
        # ...while keeping a copy of the tool-use block.
        assert writer.messages == [block]


def test_cache_writer_add_message_cb_ignores_non_tool_use_content() -> None:
    """Plain-text assistant content is not recorded."""
    with tempfile.TemporaryDirectory() as tmp:
        writer = CacheWriter(cache_dir=tmp, file_name="test.json")

        message = MessageParam(
            role="assistant",
            content="Just a text message",
            stop_reason=None,
        )
        writer.add_message_cb(OnMessageCbParam(message=message, messages=[message]))

        assert writer.messages == []


def test_cache_writer_add_message_cb_ignores_user_messages() -> None:
    """User-role messages are never recorded."""
    with tempfile.TemporaryDirectory() as tmp:
        writer = CacheWriter(cache_dir=tmp, file_name="test.json")

        message = MessageParam(
            role="user",
            content="User message",
            stop_reason=None,
        )
        writer.add_message_cb(OnMessageCbParam(message=message, messages=[message]))

        assert writer.messages == []


def test_cache_writer_detects_cached_execution() -> None:
    """Seeing execute_cached_executions_tool flags the run as a cached replay."""
    with tempfile.TemporaryDirectory() as tmp:
        writer = CacheWriter(cache_dir=tmp, file_name="test.json")

        block = ToolUseBlockParam(
            id="cached_exec_id",
            name="execute_cached_executions_tool",
            input={"trajectory_file": "test.json"},
            type="tool_use",
        )
        message = MessageParam(
            role="assistant",
            content=[block],
            stop_reason=None,
        )
        writer.add_message_cb(OnMessageCbParam(message=message, messages=[message]))

        assert writer.was_cached_execution is True
def test_cache_writer_generate_writes_file() -> None:
    """generate() serializes the recorded blocks to the configured JSON file."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        writer = CacheWriter(cache_dir=str(root), file_name="output.json")

        writer.messages = [
            ToolUseBlockParam(
                id="id1", name="tool1", input={"param": "value1"}, type="tool_use"
            ),
            ToolUseBlockParam(
                id="id2", name="tool2", input={"param": "value2"}, type="tool_use"
            ),
        ]
        writer.generate()

        cache_file = root / "output.json"
        assert cache_file.exists()

        with cache_file.open("r", encoding="utf-8") as f:
            data = json.load(f)

        # Identity and order of the recorded blocks are preserved on disk.
        assert len(data) == 2
        assert (data[0]["id"], data[0]["name"]) == ("id1", "tool1")
        assert (data[1]["id"], data[1]["name"]) == ("id2", "tool2")


def test_cache_writer_generate_auto_names_file() -> None:
    """With an empty file_name, generate() invents a timestamped one."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        writer = CacheWriter(cache_dir=str(root), file_name="")

        writer.messages = [
            ToolUseBlockParam(id="id1", name="tool1", input={}, type="tool_use")
        ]
        writer.generate()

        created = list(root.glob("*.json"))
        assert len(created) == 1
        assert created[0].name.startswith("cached_trajectory_")


def test_cache_writer_generate_skips_cached_execution() -> None:
    """generate() writes nothing when the run replayed a cached trajectory."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        writer = CacheWriter(cache_dir=str(root), file_name="test.json")

        writer.was_cached_execution = True
        writer.messages = [
            ToolUseBlockParam(id="id1", name="tool1", input={}, type="tool_use")
        ]
        writer.generate()

        # Replays must not be re-recorded as new trajectories.
        assert list(root.glob("*.json")) == []


def test_cache_writer_reset() -> None:
    """reset() clears recorded state and installs the new filename."""
    with tempfile.TemporaryDirectory() as tmp:
        writer = CacheWriter(cache_dir=tmp, file_name="original.json")

        writer.messages = [
            ToolUseBlockParam(id="id1", name="tool1", input={}, type="tool_use")
        ]
        writer.was_cached_execution = True

        writer.reset(file_name="new.json")

        assert writer.messages == []
        assert writer.file_name == "new.json"
        assert writer.was_cached_execution is False


def test_cache_writer_read_cache_file() -> None:
    """read_cache_file() parses a JSON trajectory into ToolUseBlockParam objects."""
    with tempfile.TemporaryDirectory() as tmp:
        cache_file = Path(tmp) / "test_cache.json"
        trajectory: list[dict[str, Any]] = [
            {
                "id": "id1",
                "name": "tool1",
                "input": {"param": "value1"},
                "type": "tool_use",
            },
            {
                "id": "id2",
                "name": "tool2",
                "input": {"param": "value2"},
                "type": "tool_use",
            },
        ]
        cache_file.write_text(json.dumps(trajectory), encoding="utf-8")

        blocks = CacheWriter.read_cache_file(cache_file)

        assert len(blocks) == 2
        for block, expected in zip(blocks, trajectory):
            assert isinstance(block, ToolUseBlockParam)
            assert block.id == expected["id"]
            assert block.name == expected["name"]


def test_cache_writer_set_file_name() -> None:
    """set_file_name() normalizes the .json suffix like the constructor does."""
    with tempfile.TemporaryDirectory() as tmp:
        writer = CacheWriter(cache_dir=tmp, file_name="original.json")

        writer.set_file_name("new_name")
        assert writer.file_name == "new_name.json"

        writer.set_file_name("another.json")
        assert writer.file_name == "another.json"
def test_cache_writer_generate_resets_after_writing() -> None:
    """After generate() the recorded messages are cleared (writer was reset)."""
    with tempfile.TemporaryDirectory() as tmp:
        cache_root = Path(tmp)
        writer = CacheWriter(cache_dir=str(cache_root), file_name="test.json")

        writer.messages = [
            ToolUseBlockParam(id="id1", name="tool1", input={}, type="tool_use")
        ]
        writer.generate()

        # generate() should hand the writer back in a reusable, empty state.
        assert writer.messages == []