diff --git a/.gitignore b/.gitignore index f2356f53..3f67f623 100644 --- a/.gitignore +++ b/.gitignore @@ -169,5 +169,6 @@ reports/ /askui_chat.db-shm /askui_chat.db-wal .cache/ +.askui_cache/* bom.json diff --git a/README.md b/README.md index 6937def6..b73c1435 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ Ready to build your first agent? Check out our documentation: 9. **[Reporting](docs/08_reporting.md)** - Obtain agent logs as execution reports and summaries as test reports 10. **[Observability](docs/09_observability_telemetry_tracing.md)** - Monitor and debug agents 11. **[Extracting Data](docs/10_extracting_data.md)** - Extracting structured data from screenshots and files +12. **[Callbacks](docs/11_callbacks.md)** - Inject custom logic into the control loop **Official documentation:** [docs.askui.com](https://docs.askui.com) diff --git a/docs/00_overview.md b/docs/00_overview.md index ca72f494..a048598a 100644 --- a/docs/00_overview.md +++ b/docs/00_overview.md @@ -90,6 +90,9 @@ Understand what data is collected and how to opt out. ### 10 - Extracting Data **Topics**: Using `get()`, file support (PDF. Excel, Word, CSV), structured data extraction, response schemas +### 11 - Callbacks +**Topics**: Inject custom logic at different positions of the control loop through callbacks + Extract information from screens and files using the `get()` method with Pydantic models. ## Additional Resources diff --git a/docs/06_caching.md b/docs/06_caching.md index 975990dd..44ded425 100644 --- a/docs/06_caching.md +++ b/docs/06_caching.md @@ -8,12 +8,12 @@ The caching system works by recording all tool use actions (mouse movements, cli ## Caching Strategies -The caching mechanism supports four strategies, configured via the `caching_settings` parameter in the `act()` method: +The caching mechanism supports three strategies, configured via the `caching_settings` parameter in the `act()` method: -- **`"no"`** (default): No caching is used. 
The agent executes normally without recording or replaying actions. -- **`"write"`**: Records all agent actions to a cache file for future replay. -- **`"read"`**: Provides tools to the agent to list and execute previously cached trajectories. -- **`"both"`**: Combines read and write modes - the agent can use existing cached trajectories and will also record new ones. +- **`None`** (default): No caching is used. The agent executes normally without recording or replaying actions. +- **`"record"`**: Records all agent actions to a cache file for future replay. +- **`"execute"`**: Provides tools to the agent to list and execute previously cached trajectories. +- **`"auto"`**: Combines execute and record modes - the agent can use existing cached trajectories and will also record new ones. ## Configuration @@ -23,20 +23,20 @@ Caching is configured using the `CachingSettings` class: from askui.models.shared.settings import CachingSettings, CachedExecutionToolSettings caching_settings = CachingSettings( - strategy="write", # One of: "read", "write", "both", "no" + strategy="record", # One of: "execute", "record", "auto", or None cache_dir=".cache", # Directory to store cache files - filename="my_test.json", # Filename for the cache file (optional for write mode) + filename="my_test.json", # Filename for the cache file (optional for record mode) execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( - delay_time_between_action=0.5 # Delay in seconds between each cached action + delay_time_between_actions=0.5 # Delay in seconds between each cached action ) ) ``` ### Parameters -- **`strategy`**: The caching strategy to use (`"read"`, `"write"`, `"both"`, or `"no"`). +- **`strategy`**: The caching strategy to use (`"execute"`, `"record"`, `"auto"`, or `None`). - **`cache_dir`**: Directory where cache files are stored. Defaults to `".cache"`. -- **`filename`**: Name of the cache file to write to or read from. 
If not specified in write mode, a timestamped filename will be generated automatically (format: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`). +- **`filename`**: Name of the cache file to write to or read from. If not specified in record mode, a timestamped filename will be generated automatically (format: `cached_trajectory_YYYYMMDDHHMMSSffffff.json`). - **`execute_cached_trajectory_tool_settings`**: Configuration for the trajectory execution tool (optional). See [Execution Settings](#execution-settings) below. ### Execution Settings @@ -47,13 +47,13 @@ The `CachedExecutionToolSettings` class allows you to configure how cached traje from askui.models.shared.settings import CachedExecutionToolSettings execution_settings = CachedExecutionToolSettings( - delay_time_between_action=0.5 # Delay in seconds between each action (default: 0.5) + delay_time_between_actions=0.5 # Delay in seconds between each action (default: 0.5) ) ``` #### Parameters -- **`delay_time_between_action`**: The time to wait (in seconds) between executing consecutive cached actions. This delay helps ensure UI elements can materialize before the next action is executed. Defaults to `0.5` seconds. +- **`delay_time_between_actions`**: The time to wait (in seconds) between executing consecutive cached actions. This delay helps ensure UI elements can materialize before the next action is executed. Defaults to `0.5` seconds. 
You can adjust this value based on your application's responsiveness: - For faster applications or quick interactions, you might use a smaller delay (e.g., `0.1` or `0.2` seconds) @@ -61,7 +61,7 @@ You can adjust this value based on your application's responsiveness: ## Usage Examples -### Writing a Cache (Recording) +### Recording a Cache Record agent actions to a cache file for later replay: @@ -73,7 +73,7 @@ with ComputerAgent() as agent: agent.act( goal="Fill out the login form with username 'admin' and password 'secret123'", caching_settings=CachingSettings( - strategy="write", # you could also use "both" here + strategy="record", # you could also use "auto" here cache_dir=".cache", filename="login_test.json" ) @@ -82,7 +82,7 @@ with ComputerAgent() as agent: After execution, a cache file will be created at `.cache/login_test.json` containing all the tool use actions performed by the agent. -### Reading from Cache (Replaying) +### Executing from Cache (Replaying) Provide the agent with access to previously recorded trajectories: @@ -94,13 +94,13 @@ with ComputerAgent() as agent: agent.act( goal="Fill out the login form", caching_settings=CachingSettings( - strategy="read", # you could also use "both" here + strategy="execute", # you could also use "auto" here cache_dir=".cache" ) ) ``` -When using `strategy="read"`, the agent receives two additional tools: +When using `strategy="execute"`, the agent receives two additional tools: 1. **`retrieve_available_trajectories_tool`**: Lists all available cache files in the cache directory 2. **`execute_cached_executions_tool`**: Executes a specific cached trajectory @@ -109,7 +109,7 @@ The agent will automatically check if a relevant cached trajectory exists and us ### Referencing Cache Files in Goal Prompts -When using `strategy="read"` or `strategy="both"`, **you need to inform the agent about which cache files are available and when to use them**. 
This is done by including cache file information directly in your goal prompt. +When using `strategy="execute"` or `strategy="auto"`, **you need to inform the agent about which cache files are available and when to use them**. This is done by including cache file information directly in your goal prompt. #### Explicit Cache File References @@ -126,7 +126,7 @@ with ComputerAgent() as agent: If the cache file "open_website_in_chrome.json" is available, please use it for this execution. It will open a new window in Chrome and navigate to the website.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -149,7 +149,7 @@ with ComputerAgent() as agent: Check if a cache file named "{test_id}.json" exists. If it does, use it to replay the test actions, then verify the results.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir="test_cache" ) ) @@ -171,7 +171,7 @@ with ComputerAgent() as agent: Choose the most recent one if multiple are available, as it likely contains the most up-to-date interaction sequence.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -195,7 +195,7 @@ with ComputerAgent() as agent: After each cached execution, verify the step completed successfully before proceeding.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) @@ -219,10 +219,10 @@ with ComputerAgent() as agent: agent.act( goal="Fill out the login form", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache", execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( - delay_time_between_action=1.0 # Wait 1 second between each action + delay_time_between_actions=1.0 # Wait 1 second between each action ) ) ) @@ -233,7 +233,7 @@ This is particularly useful when: - UI elements take time to become interactive after appearing - You're testing on slower 
hardware or environments -### Using Both Strategies +### Using Auto Strategy Enable both reading and writing simultaneously: @@ -245,7 +245,7 @@ with ComputerAgent() as agent: agent.act( goal="Complete the checkout process", caching_settings=CachingSettings( - strategy="both", + strategy="auto", cache_dir=".cache", filename="checkout_test.json" ) @@ -323,7 +323,7 @@ The delay between actions can be customized using `CachedExecutionToolSettings` ## Limitations - **UI State Sensitivity**: Cached trajectories assume the UI is in the same state as when they were recorded. If the UI has changed, the replay may fail or produce incorrect results. -- **No on_message Callback**: When using `strategy="write"` or `strategy="both"`, you cannot provide a custom `on_message` callback, as the caching system uses this callback to record actions. +- **No on_message Callback**: When using `strategy="record"` or `strategy="auto"`, you cannot provide a custom `on_message` callback, as the caching system uses this callback to record actions. - **Verification Required**: After executing a cached trajectory, the agent should verify that the results are correct, as UI changes may cause partial failures. ## Example: Complete Test Workflow @@ -340,7 +340,7 @@ with ComputerAgent() as agent: agent.act( goal="Navigate to the login page and log in with username 'testuser' and password 'testpass123'", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir="test_cache", filename="user_login.json" ) @@ -356,10 +356,10 @@ with ComputerAgent() as agent: the login sequence. 
It contains the steps to navigate to the login page and authenticate with the test credentials.""", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir="test_cache", execute_cached_trajectory_tool_settings=CachedExecutionToolSettings( - delay_time_between_action=1.0 + delay_time_between_actions=1.0 ) ) ) diff --git a/docs/11_callbacks.md b/docs/11_callbacks.md new file mode 100644 index 00000000..abd98382 --- /dev/null +++ b/docs/11_callbacks.md @@ -0,0 +1,83 @@ +# Callbacks + +Callbacks provide hooks into the agent's conversation lifecycle, similar to PyTorch Lightning's callback system. Use them for logging, monitoring, custom metrics, or extending agent behavior. + +## Usage + +Subclass `ConversationCallback` and override the hooks you need: + +```python +from askui import ComputerAgent, ConversationCallback + +class MetricsCallback(ConversationCallback): + def on_step_start(self, conversation, step_index): + print(f"Step {step_index} starting...") + + def on_step_end(self, conversation, step_index, result): + print(f"Step {step_index} finished: {result.status}") + +with ComputerAgent(callbacks=[MetricsCallback()]) as agent: + agent.act("Open the settings menu") +``` + +## Available Hooks + +| Hook | When Called | Parameters | +|------|-------------|------------| +| `on_conversation_start` | After setup, before control loop | `conversation` | +| `on_conversation_end` | After control loop, before cleanup | `conversation` | +| `on_control_loop_start` | Before the iteration loop begins | `conversation` | +| `on_control_loop_end` | After the iteration loop ends | `conversation` | +| `on_step_start` | Before each step execution | `conversation`, `step_index` | +| `on_step_end` | After each step execution | `conversation`, `step_index`, `result` | +| `_on_speaker_switch` | After a speaker switch | `from_speaker`, `to_speaker` | +| `on_tool_execution_start` | Before tools are executed | `conversation`, `tool_names` | +| 
`on_tool_execution_end` | After tools are executed | `conversation`, `tool_names` | + +### Parameters + +- **`conversation`**: The `Conversation` instance with access to messages, settings, and state +- **`step_index`**: Zero-based index of the current step +- **`result`**: `SpeakerResult` containing `status`, `messages_to_add`, and `usage` +- **`tool_names`**: List of tool names being executed + +## Example: Timing Callback + +```python +import time +from askui import ComputerAgent, ConversationCallback + +class TimingCallback(ConversationCallback): + def __init__(self): + self.start_time = None + self.step_times = [] + + def on_conversation_start(self, conversation): + self.start_time = time.time() + + def on_step_start(self, conversation, step_index): + self._step_start = time.time() + + def on_step_end(self, conversation, step_index, result): + elapsed = time.time() - self._step_start + self.step_times.append(elapsed) + print(f"Step {step_index}: {elapsed:.2f}s") + + def on_conversation_end(self, conversation): + total = time.time() - self.start_time + print(f"Total: {total:.2f}s across {len(self.step_times)} steps") + +with ComputerAgent(callbacks=[TimingCallback()]) as agent: + agent.act("Search for documents") +``` + +## Multiple Callbacks + +Pass multiple callbacks to combine behaviors: + +```python +with ComputerAgent(callbacks=[TimingCallback(), MetricsCallback()]) as agent: + agent.act("Complete the form") +``` + +Callbacks are called in the order they are provided. diff --git a/pdm.lock b/pdm.lock index f4f14cb6..bad16b37 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. 
[metadata] -groups = ["default", "all", "android", "bedrock", "dev", "pynput", "vertex", "web"] +groups = ["default", "all", "android", "bedrock", "dev", "vertex", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:0af1559aa395b49a67778024f8c05e96065d40d8c1c8dd0f3aa11d2542ca176a" +content_hash = "sha256:860b7990d08be9d842b6e507d955fd350dad1e81c45fcf4cce4627b44506be68" [[metadata.targets]] requires_python = ">=3.10,<3.14" @@ -1010,17 +1010,6 @@ files = [ {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"}, ] -[[package]] -name = "evdev" -version = "1.9.2" -requires_python = ">=3.8" -summary = "Bindings to the Linux input handling subsystem" -groups = ["all", "pynput"] -marker = "\"linux\" in sys_platform" -files = [ - {file = "evdev-1.9.2.tar.gz", hash = "sha256:5d3278892ce1f92a74d6bf888cc8525d9f68af85dbe336c95d1c87fb8f423069"}, -] - [[package]] name = "exceptiongroup" version = "1.3.0" @@ -1791,6 +1780,36 @@ files = [ {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, ] +[[package]] +name = "imagehash" +version = "4.3.2" +summary = "Image Hashing library" +groups = ["default"] +dependencies = [ + "PyWavelets", + "numpy", + "pillow", + "scipy", +] +files = [ + {file = "ImageHash-4.3.2-py2.py3-none-any.whl", hash = "sha256:02b0f965f8c77cd813f61d7d39031ea27d4780e7ebcad56c6cd6a709acc06e5f"}, + {file = "ImageHash-4.3.2.tar.gz", hash = "sha256:e54a79805afb82a34acde4746a16540503a9636fd1ffb31d8e099b29bbbf8156"}, +] + +[[package]] +name = "importlib-metadata" +version = "8.7.1" +requires_python = ">=3.9" +summary = "Read metadata from Python packages" +groups = ["default"] +dependencies = [ + "zipp>=3.20", +] +files = [ + {file = "importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151"}, + {file = "importlib_metadata-8.7.1.tar.gz", hash = 
"sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb"}, +] + [[package]] name = "inflect" version = "7.5.0" @@ -2468,17 +2487,6 @@ files = [ {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, ] -[[package]] -name = "mss" -version = "10.1.0" -requires_python = ">=3.9" -summary = "An ultra fast cross-platform multiple screenshots module in pure python using ctypes." -groups = ["all", "pynput"] -files = [ - {file = "mss-10.1.0-py3-none-any.whl", hash = "sha256:9179c110cadfef5dc6dc4a041a0cd161c74c379218648e6640b48c6b5cfe8918"}, - {file = "mss-10.1.0.tar.gz", hash = "sha256:7182baf7ee16ca569e2804028b6ab9bcbf6be5c46fc2880840f33b513b9cb4f8"}, -] - [[package]] name = "mypy" version = "1.18.2" @@ -2744,6 +2752,52 @@ files = [ {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"}, ] +[[package]] +name = "opentelemetry-api" +version = "1.39.1" +requires_python = ">=3.9" +summary = "OpenTelemetry Python API" +groups = ["default"] +dependencies = [ + "importlib-metadata<8.8.0,>=6.0", + "typing-extensions>=4.5.0", +] +files = [ + {file = "opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950"}, + {file = "opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c"}, +] + +[[package]] +name = "opentelemetry-sdk" +version = "1.39.1" +requires_python = ">=3.9" +summary = "OpenTelemetry Python SDK" +groups = ["default"] +dependencies = [ + "opentelemetry-api==1.39.1", + "opentelemetry-semantic-conventions==0.60b1", + "typing-extensions>=4.5.0", +] +files = [ + {file = "opentelemetry_sdk-1.39.1-py3-none-any.whl", hash = "sha256:4d5482c478513ecb0a5d938dcc61394e647066e0cc2676bee9f3af3f3f45f01c"}, + {file = "opentelemetry_sdk-1.39.1.tar.gz", hash = "sha256:cf4d4563caf7bff906c9f7967e2be22d0d6b349b908be0d90fb21c8e9c995cc6"}, 
+] + +[[package]] +name = "opentelemetry-semantic-conventions" +version = "0.60b1" +requires_python = ">=3.9" +summary = "OpenTelemetry Semantic Conventions" +groups = ["default"] +dependencies = [ + "opentelemetry-api==1.39.1", + "typing-extensions>=4.5.0", +] +files = [ + {file = "opentelemetry_semantic_conventions-0.60b1-py3-none-any.whl", hash = "sha256:9fa8c8b0c110da289809292b0591220d3a7b53c1526a23021e977d68597893fb"}, + {file = "opentelemetry_semantic_conventions-0.60b1.tar.gz", hash = "sha256:87c228b5a0669b748c76d76df6c364c369c28f1c465e50f661e39737e84bc953"}, +] + [[package]] name = "packageurl-python" version = "0.17.6" @@ -3289,132 +3343,6 @@ files = [ {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, ] -[[package]] -name = "pynput" -version = "1.8.1" -summary = "Monitor and control user input devices" -groups = ["all", "pynput"] -dependencies = [ - "enum34; python_version == \"2.7\"", - "evdev>=1.3; \"linux\" in sys_platform", - "pyobjc-framework-ApplicationServices>=8.0; sys_platform == \"darwin\"", - "pyobjc-framework-Quartz>=8.0; sys_platform == \"darwin\"", - "python-xlib>=0.17; \"linux\" in sys_platform", - "six", -] -files = [ - {file = "pynput-1.8.1-py2.py3-none-any.whl", hash = "sha256:42dfcf27404459ca16ca889c8fb8ffe42a9fe54f722fd1a3e130728e59e768d2"}, - {file = "pynput-1.8.1.tar.gz", hash = "sha256:70d7c8373ee98911004a7c938742242840a5628c004573d84ba849d4601df81e"}, -] - -[[package]] -name = "pyobjc-core" -version = "11.1" -requires_python = ">=3.8" -summary = "Python<->ObjC Interoperability Module" -groups = ["all", "pynput"] -marker = "sys_platform == \"darwin\"" -files = [ - {file = "pyobjc_core-11.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4c7536f3e94de0a3eae6bb382d75f1219280aa867cdf37beef39d9e7d580173c"}, - {file = "pyobjc_core-11.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:ec36680b5c14e2f73d432b03ba7c1457dc6ca70fa59fd7daea1073f2b4157d33"}, - {file = 
"pyobjc_core-11.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:765b97dea6b87ec4612b3212258024d8496ea23517c95a1c5f0735f96b7fd529"}, - {file = "pyobjc_core-11.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:18986f83998fbd5d3f56d8a8428b2f3e0754fd15cef3ef786ca0d29619024f2c"}, - {file = "pyobjc_core-11.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:8849e78cfe6595c4911fbba29683decfb0bf57a350aed8a43316976ba6f659d2"}, - {file = "pyobjc_core-11.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8cb9ed17a8d84a312a6e8b665dd22393d48336ea1d8277e7ad20c19a38edf731"}, - {file = "pyobjc_core-11.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:f2455683e807f8541f0d83fbba0f5d9a46128ab0d5cc83ea208f0bec759b7f96"}, - {file = "pyobjc_core-11.1.tar.gz", hash = "sha256:b63d4d90c5df7e762f34739b39cc55bc63dbcf9fb2fb3f2671e528488c7a87fe"}, -] - -[[package]] -name = "pyobjc-framework-applicationservices" -version = "11.1" -requires_python = ">=3.9" -summary = "Wrappers for the framework ApplicationServices on macOS" -groups = ["all", "pynput"] -marker = "sys_platform == \"darwin\"" -dependencies = [ - "pyobjc-core>=11.1", - "pyobjc-framework-Cocoa>=11.1", - "pyobjc-framework-CoreText>=11.1", - "pyobjc-framework-Quartz>=11.1", -] -files = [ - {file = "pyobjc_framework_applicationservices-11.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:89aa713f16f1de66efd82f3be77c632ad1068e51e0ef0c2b0237ac7c7f580814"}, - {file = "pyobjc_framework_applicationservices-11.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:cf45d15eddae36dec2330a9992fc852476b61c8f529874b9ec2805c768a75482"}, - {file = "pyobjc_framework_applicationservices-11.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:f4a85ccd78bab84f7f05ac65ff9be117839dfc09d48c39edd65c617ed73eb01c"}, - {file = "pyobjc_framework_applicationservices-11.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:385a89f4d0838c97a331e247519d9e9745aa3f7427169d18570e3c664076a63c"}, - {file 
= "pyobjc_framework_applicationservices-11.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:f480fab20f3005e559c9d06c9a3874a1f1c60dde52c6d28a53ab59b45e79d55f"}, - {file = "pyobjc_framework_applicationservices-11.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:e8dee91c6a14fd042f98819dc0ac4a182e0e816282565534032f0e544bfab143"}, - {file = "pyobjc_framework_applicationservices-11.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:a0ce40a57a9b993793b6f72c4fd93f80618ef54a69d76a1da97b8360a2f3ffc5"}, - {file = "pyobjc_framework_applicationservices-11.1.tar.gz", hash = "sha256:03fcd8c0c600db98fa8b85eb7b3bc31491701720c795e3f762b54e865138bbaf"}, -] - -[[package]] -name = "pyobjc-framework-cocoa" -version = "11.1" -requires_python = ">=3.9" -summary = "Wrappers for the Cocoa frameworks on macOS" -groups = ["all", "pynput"] -marker = "sys_platform == \"darwin\"" -dependencies = [ - "pyobjc-core>=11.1", -] -files = [ - {file = "pyobjc_framework_cocoa-11.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b27a5bdb3ab6cdeb998443ff3fce194ffae5f518c6a079b832dbafc4426937f9"}, - {file = "pyobjc_framework_cocoa-11.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7b9a9b8ba07f5bf84866399e3de2aa311ed1c34d5d2788a995bdbe82cc36cfa0"}, - {file = "pyobjc_framework_cocoa-11.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:806de56f06dfba8f301a244cce289d54877c36b4b19818e3b53150eb7c2424d0"}, - {file = "pyobjc_framework_cocoa-11.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:54e93e1d9b0fc41c032582a6f0834befe1d418d73893968f3f450281b11603da"}, - {file = "pyobjc_framework_cocoa-11.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:fd5245ee1997d93e78b72703be1289d75d88ff6490af94462b564892e9266350"}, - {file = "pyobjc_framework_cocoa-11.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:aede53a1afc5433e1e7d66568cc52acceeb171b0a6005407a42e8e82580b4fc0"}, - {file = 
"pyobjc_framework_cocoa-11.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:1b5de4e1757bb65689d6dc1f8d8717de9ec8587eb0c4831c134f13aba29f9b71"}, - {file = "pyobjc_framework_cocoa-11.1.tar.gz", hash = "sha256:87df76b9b73e7ca699a828ff112564b59251bb9bbe72e610e670a4dc9940d038"}, -] - -[[package]] -name = "pyobjc-framework-coretext" -version = "11.1" -requires_python = ">=3.9" -summary = "Wrappers for the framework CoreText on macOS" -groups = ["all", "pynput"] -marker = "sys_platform == \"darwin\"" -dependencies = [ - "pyobjc-core>=11.1", - "pyobjc-framework-Cocoa>=11.1", - "pyobjc-framework-Quartz>=11.1", -] -files = [ - {file = "pyobjc_framework_coretext-11.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:515be6beb48c084ee413c00c4e9fbd6e730c1b8a24270f4c618fc6c7ba0011ce"}, - {file = "pyobjc_framework_coretext-11.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b4f4d2d2a6331fa64465247358d7aafce98e4fb654b99301a490627a073d021e"}, - {file = "pyobjc_framework_coretext-11.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1597bf7234270ee1b9963bf112e9061050d5fb8e1384b3f50c11bde2fe2b1570"}, - {file = "pyobjc_framework_coretext-11.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:37e051e8f12a0f47a81b8efc8c902156eb5bc3d8123c43e5bd4cebd24c222228"}, - {file = "pyobjc_framework_coretext-11.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:56a3a02202e0d50be3c43e781c00f9f1859ab9b73a8342ff56260b908e911e37"}, - {file = "pyobjc_framework_coretext-11.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:15650ba99692d00953e91e53118c11636056a22c90d472020f7ba31500577bf5"}, - {file = "pyobjc_framework_coretext-11.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:fb27f66a56660c31bb956191d64b85b95bac99cfb833f6e99622ca0ac4b3ba12"}, - {file = "pyobjc_framework_coretext-11.1.tar.gz", hash = "sha256:a29bbd5d85c77f46a8ee81d381b847244c88a3a5a96ac22f509027ceceaffaf6"}, -] - -[[package]] -name = "pyobjc-framework-quartz" -version = 
"11.1" -requires_python = ">=3.9" -summary = "Wrappers for the Quartz frameworks on macOS" -groups = ["all", "pynput"] -marker = "sys_platform == \"darwin\"" -dependencies = [ - "pyobjc-core>=11.1", - "pyobjc-framework-Cocoa>=11.1", -] -files = [ - {file = "pyobjc_framework_quartz-11.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b5ef75c416b0209e25b2eb07a27bd7eedf14a8c6b2f968711969d45ceceb0f84"}, - {file = "pyobjc_framework_quartz-11.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:2d501fe95ef15d8acf587cb7dc4ab4be3c5a84e2252017da8dbb7df1bbe7a72a"}, - {file = "pyobjc_framework_quartz-11.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9ac806067541917d6119b98d90390a6944e7d9bd737f5c0a79884202327c9204"}, - {file = "pyobjc_framework_quartz-11.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:43a1138280571bbf44df27a7eef519184b5c4183a588598ebaaeb887b9e73e76"}, - {file = "pyobjc_framework_quartz-11.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b23d81c30c564adf6336e00b357f355b35aad10075dd7e837cfd52a9912863e5"}, - {file = "pyobjc_framework_quartz-11.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:07cbda78b4a8fcf3a2d96e047a2ff01f44e3e1820f46f0f4b3b6d77ff6ece07c"}, - {file = "pyobjc_framework_quartz-11.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:39d02a3df4b5e3eee1e0da0fb150259476910d2a9aa638ab94153c24317a9561"}, - {file = "pyobjc_framework_quartz-11.1.tar.gz", hash = "sha256:a57f35ccfc22ad48c87c5932818e583777ff7276605fef6afad0ac0741169f75"}, -] - [[package]] name = "pyparsing" version = "3.3.1" @@ -3579,20 +3507,6 @@ files = [ {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, ] -[[package]] -name = "python-xlib" -version = "0.33" -summary = "Python X Library" -groups = ["all", "pynput"] -marker = "\"linux\" in sys_platform" -dependencies = [ - "six>=1.10.0", -] -files = [ - {file = "python-xlib-0.33.tar.gz", hash = 
"sha256:55af7906a2c75ce6cb280a584776080602444f75815a7aff4d287bb2d7018b32"}, - {file = "python_xlib-0.33-py2.py3-none-any.whl", hash = "sha256:c3534038d42e0df2f1392a1b30a15a4ff5fdc2b86cfa94f072bf11b10a164398"}, -] - [[package]] name = "pytokens" version = "0.1.10" @@ -3614,6 +3528,56 @@ files = [ {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, ] +[[package]] +name = "pywavelets" +version = "1.8.0" +requires_python = ">=3.10" +summary = "PyWavelets, wavelet transform module" +groups = ["default"] +dependencies = [ + "numpy<3,>=1.23", +] +files = [ + {file = "pywavelets-1.8.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f5c86fcb203c8e61d1f3d4afbfc08d626c64e4e3708207315577264c724632bf"}, + {file = "pywavelets-1.8.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fafb5fa126277e1690c3d6329287122fc08e4d25a262ce126e3d81b1f5709308"}, + {file = "pywavelets-1.8.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dec23dfe6d5a3f4312b12456b8c546aa90a11c1138e425a885987505f0658ae0"}, + {file = "pywavelets-1.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:880a0197e9fa108939af50a95e97c1bf9b7d3e148e0fad92ea60a9ed8c8947c0"}, + {file = "pywavelets-1.8.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8bfa833d08b60d0bf53a7939fbbf3d98015dd34efe89cbe4e53ced880d085fc1"}, + {file = "pywavelets-1.8.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e10c3fc7f4a796e94da4bca9871be2186a7bb7a3b3536a0ca9376d84263140f0"}, + {file = "pywavelets-1.8.0-cp310-cp310-win32.whl", hash = "sha256:31baf4be6940fde72cc85663154360857ac1b93c251822deaf72bb804da95031"}, + {file = "pywavelets-1.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:560c39f1ff8cb37f8b8ea4b7b6eb8a14f6926c11f5cf8c09f013a58f895ed5bc"}, + {file = "pywavelets-1.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e8dd5be4faed994581a8a4b3c0169be20567a9346e523f0b57f903c8f6722bce"}, + {file = 
"pywavelets-1.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8d8abaf7c120b151ef309c9ff57e0a44ba9febf49045056dbc1577526ecec6c8"}, + {file = "pywavelets-1.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b43a4c58707b1e8d941bec7f1d83e67c482278575ff0db3189d5c0dfae23a57"}, + {file = "pywavelets-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1c1aad0b97714e3079a2bfe48e4fb8ccd60778d0427e9ee5e0a9ff922e6c61e4"}, + {file = "pywavelets-1.8.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a0e1db96dcf3ce08156859df8b359e9ff66fa15061a1b90e70e020bf4cd077a0"}, + {file = "pywavelets-1.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e62c8fb52ab0e8ff212fff9acae681a8f12d68b76c36fe24cc48809d5b6825ba"}, + {file = "pywavelets-1.8.0-cp311-cp311-win32.whl", hash = "sha256:bf327528d10de471b04bb725c4e10677fac5a49e13d41bf0d0b3a1f6d7097abf"}, + {file = "pywavelets-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:3814d354dd109e244ffaac3d480d29a5202212fe24570c920268237c8d276f95"}, + {file = "pywavelets-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3f431c9e2aff1a2240765eff5e804975d0fcc24c82d6f3d4271243f228e5963b"}, + {file = "pywavelets-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e39b0e2314e928cb850ee89b9042733a10ea044176a495a54dc84d2c98407a51"}, + {file = "pywavelets-1.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cae701117f5c7244b7c8d48b9e92a0289637cdc02a9c205e8be83361f0c11fae"}, + {file = "pywavelets-1.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649936baee933e80083788e0adc4d8bc2da7cdd8b10464d3b113475be2cc5308"}, + {file = "pywavelets-1.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8c68e9d072c536bc646e8bdce443bb1826eeb9aa21b2cb2479a43954dea692a3"}, + {file = "pywavelets-1.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = 
"sha256:63f67fa2ee1610445de64f746fb9c1df31980ad13d896ea2331fc3755f49b3ae"}, + {file = "pywavelets-1.8.0-cp312-cp312-win32.whl", hash = "sha256:4b3c2ab669c91e3474fd63294355487b7dd23f0b51d32f811327ddf3546f4f3d"}, + {file = "pywavelets-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:810a23a631da596fef7196ddec49b345b1aab13525bb58547eeebe1769edbbc1"}, + {file = "pywavelets-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:441ba45c8dff8c6916dbe706958d0d7f91da675695ca0c0d75e483f6f52d0a12"}, + {file = "pywavelets-1.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:24bb282bab09349d9d128ed0536fa50fff5c2147891971a69c2c36155dfeeeac"}, + {file = "pywavelets-1.8.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:426ff3799446cb4da1db04c2084e6e58edfe24225596805665fd39c14f53dece"}, + {file = "pywavelets-1.8.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa0607a9c085b8285bc0d04e33d461a6c80f8c325389221ffb1a45141861138e"}, + {file = "pywavelets-1.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d31c36a39110e8fcc7b1a4a11cfed7d22b610c285d3e7f4fe73ec777aa49fa39"}, + {file = "pywavelets-1.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fa7c68ed1e5bab23b1bafe60ccbcf709b878652d03de59e961baefa5210fcd0a"}, + {file = "pywavelets-1.8.0-cp313-cp313-win32.whl", hash = "sha256:2c6b359b55d713ef683e9da1529181b865a80d759881ceb9adc1c5742e4da4d8"}, + {file = "pywavelets-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:4dbebcfd55ea8a85b7fc8802d411e75337170422abf6e96019d7e46c394e80e5"}, + {file = "pywavelets-1.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2e1c79784bebeafd3715c1bea6621daa2e2e6ed37b687719322e2078fb35bb70"}, + {file = "pywavelets-1.8.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f489380c95013cc8fb3ef338f6d8c1a907125db453cc4dc739e2cca06fcd8b6"}, + {file = "pywavelets-1.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = 
"sha256:06786201a91b5e74540f4f3c115c49a29190de2eb424823abbd3a1fd75ea3e28"}, + {file = "pywavelets-1.8.0-cp313-cp313t-win32.whl", hash = "sha256:f2877fb7b58c94211257dcf364b204d6ed259146fc87d5a90bf9d93c97af6226"}, + {file = "pywavelets-1.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ec5d723c3335ff8aa630fd4b14097077f12cc02893c91cafd60dd7b1730e780f"}, + {file = "pywavelets-1.8.0.tar.gz", hash = "sha256:f3800245754840adc143cbc29534a1b8fc4b8cff6e9d403326bd52b7bb5c35aa"}, +] + [[package]] name = "pywin32" version = "311" @@ -3979,6 +3943,64 @@ files = [ {file = "s3transfer-0.14.0.tar.gz", hash = "sha256:eff12264e7c8b4985074ccce27a3b38a485bb7f7422cc8046fee9be4983e4125"}, ] +[[package]] +name = "scipy" +version = "1.15.3" +requires_python = ">=3.10" +summary = "Fundamental algorithms for scientific computing in Python" +groups = ["default"] +dependencies = [ + "numpy<2.5,>=1.23.5", +] +files = [ + {file = "scipy-1.15.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a345928c86d535060c9c2b25e71e87c39ab2f22fc96e9636bd74d1dbf9de448c"}, + {file = "scipy-1.15.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ad3432cb0f9ed87477a8d97f03b763fd1d57709f1bbde3c9369b1dff5503b253"}, + {file = "scipy-1.15.3-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:aef683a9ae6eb00728a542b796f52a5477b78252edede72b8327a886ab63293f"}, + {file = "scipy-1.15.3-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:1c832e1bd78dea67d5c16f786681b28dd695a8cb1fb90af2e27580d3d0967e92"}, + {file = "scipy-1.15.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:263961f658ce2165bbd7b99fa5135195c3a12d9bef045345016b8b50c315cb82"}, + {file = "scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e2abc762b0811e09a0d3258abee2d98e0c703eee49464ce0069590846f31d40"}, + {file = "scipy-1.15.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:ed7284b21a7a0c8f1b6e5977ac05396c0d008b89e05498c8b7e8f4a1423bba0e"}, + {file = 
"scipy-1.15.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5380741e53df2c566f4d234b100a484b420af85deb39ea35a1cc1be84ff53a5c"}, + {file = "scipy-1.15.3-cp310-cp310-win_amd64.whl", hash = "sha256:9d61e97b186a57350f6d6fd72640f9e99d5a4a2b8fbf4b9ee9a841eab327dc13"}, + {file = "scipy-1.15.3-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:993439ce220d25e3696d1b23b233dd010169b62f6456488567e830654ee37a6b"}, + {file = "scipy-1.15.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:34716e281f181a02341ddeaad584205bd2fd3c242063bd3423d61ac259ca7eba"}, + {file = "scipy-1.15.3-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3b0334816afb8b91dab859281b1b9786934392aa3d527cd847e41bb6f45bee65"}, + {file = "scipy-1.15.3-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:6db907c7368e3092e24919b5e31c76998b0ce1684d51a90943cb0ed1b4ffd6c1"}, + {file = "scipy-1.15.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:721d6b4ef5dc82ca8968c25b111e307083d7ca9091bc38163fb89243e85e3889"}, + {file = "scipy-1.15.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39cb9c62e471b1bb3750066ecc3a3f3052b37751c7c3dfd0fd7e48900ed52982"}, + {file = "scipy-1.15.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:795c46999bae845966368a3c013e0e00947932d68e235702b5c3f6ea799aa8c9"}, + {file = "scipy-1.15.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:18aaacb735ab38b38db42cb01f6b92a2d0d4b6aabefeb07f02849e47f8fb3594"}, + {file = "scipy-1.15.3-cp311-cp311-win_amd64.whl", hash = "sha256:ae48a786a28412d744c62fd7816a4118ef97e5be0bee968ce8f0a2fba7acf3bb"}, + {file = "scipy-1.15.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6ac6310fdbfb7aa6612408bd2f07295bcbd3fda00d2d702178434751fe48e019"}, + {file = "scipy-1.15.3-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:185cd3d6d05ca4b44a8f1595af87f9c372bb6acf9c808e99aa3e9aa03bd98cf6"}, + {file = "scipy-1.15.3-cp312-cp312-macosx_14_0_arm64.whl", hash = 
"sha256:05dc6abcd105e1a29f95eada46d4a3f251743cfd7d3ae8ddb4088047f24ea477"}, + {file = "scipy-1.15.3-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:06efcba926324df1696931a57a176c80848ccd67ce6ad020c810736bfd58eb1c"}, + {file = "scipy-1.15.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c05045d8b9bfd807ee1b9f38761993297b10b245f012b11b13b91ba8945f7e45"}, + {file = "scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:271e3713e645149ea5ea3e97b57fdab61ce61333f97cfae392c28ba786f9bb49"}, + {file = "scipy-1.15.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6cfd56fc1a8e53f6e89ba3a7a7251f7396412d655bca2aa5611c8ec9a6784a1e"}, + {file = "scipy-1.15.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0ff17c0bb1cb32952c09217d8d1eed9b53d1463e5f1dd6052c7857f83127d539"}, + {file = "scipy-1.15.3-cp312-cp312-win_amd64.whl", hash = "sha256:52092bc0472cfd17df49ff17e70624345efece4e1a12b23783a1ac59a1b728ed"}, + {file = "scipy-1.15.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2c620736bcc334782e24d173c0fdbb7590a0a436d2fdf39310a8902505008759"}, + {file = "scipy-1.15.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:7e11270a000969409d37ed399585ee530b9ef6aa99d50c019de4cb01e8e54e62"}, + {file = "scipy-1.15.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:8c9ed3ba2c8a2ce098163a9bdb26f891746d02136995df25227a20e71c396ebb"}, + {file = "scipy-1.15.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0bdd905264c0c9cfa74a4772cdb2070171790381a5c4d312c973382fc6eaf730"}, + {file = "scipy-1.15.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79167bba085c31f38603e11a267d862957cbb3ce018d8b38f79ac043bc92d825"}, + {file = "scipy-1.15.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9deabd6d547aee2c9a81dee6cc96c6d7e9a9b1953f74850c179f91fdc729cb7"}, + {file = "scipy-1.15.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:dde4fc32993071ac0c7dd2d82569e544f0bdaff66269cb475e0f369adad13f11"}, + {file = "scipy-1.15.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f77f853d584e72e874d87357ad70f44b437331507d1c311457bed8ed2b956126"}, + {file = "scipy-1.15.3-cp313-cp313-win_amd64.whl", hash = "sha256:b90ab29d0c37ec9bf55424c064312930ca5f4bde15ee8619ee44e69319aab163"}, + {file = "scipy-1.15.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:3ac07623267feb3ae308487c260ac684b32ea35fd81e12845039952f558047b8"}, + {file = "scipy-1.15.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:6487aa99c2a3d509a5227d9a5e889ff05830a06b2ce08ec30df6d79db5fcd5c5"}, + {file = "scipy-1.15.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:50f9e62461c95d933d5c5ef4a1f2ebf9a2b4e83b0db374cb3f1de104d935922e"}, + {file = "scipy-1.15.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:14ed70039d182f411ffc74789a16df3835e05dc469b898233a245cdfd7f162cb"}, + {file = "scipy-1.15.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a769105537aa07a69468a0eefcd121be52006db61cdd8cac8a0e68980bbb723"}, + {file = "scipy-1.15.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9db984639887e3dffb3928d118145ffe40eff2fa40cb241a306ec57c219ebbbb"}, + {file = "scipy-1.15.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:40e54d5c7e7ebf1aa596c374c49fa3135f04648a0caabcb66c52884b943f02b4"}, + {file = "scipy-1.15.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5e721fed53187e71d0ccf382b6bf977644c533e506c4d33c3fb24de89f5c3ed5"}, + {file = "scipy-1.15.3-cp313-cp313t-win_amd64.whl", hash = "sha256:76ad1fb5f8752eabf0fa02e4cc0336b4e8f021e2d5f061ed37d6d264db35e3ca"}, + {file = "scipy-1.15.3.tar.gz", hash = "sha256:eae3cf522bc7df64b42cad3925c876e1b0b6c35c1337c93e12c0f366f55b0eaf"}, +] + [[package]] name = "segment-analytics-python" version = "2.3.4" @@ -4081,7 +4103,7 @@ name = "six" version = "1.17.0" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" 
summary = "Python 2 and 3 compatibility utilities" -groups = ["default", "all", "bedrock", "dev", "pynput", "vertex"] +groups = ["default", "all", "bedrock", "dev", "vertex"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -4618,3 +4640,14 @@ files = [ {file = "xlrd-2.0.2-py2.py3-none-any.whl", hash = "sha256:ea762c3d29f4cca48d82df517b6d89fbce4db3107f9d78713e48cd321d5c9aa9"}, {file = "xlrd-2.0.2.tar.gz", hash = "sha256:08b5e25de58f21ce71dc7db3b3b8106c1fa776f3024c54e45b45b374e89234c9"}, ] + +[[package]] +name = "zipp" +version = "3.23.0" +requires_python = ">=3.9" +summary = "Backport of pathlib-compatible object wrapper for zip files" +groups = ["default"] +files = [ + {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, + {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, +] diff --git a/pyproject.toml b/pyproject.toml index bbae731b..79beffc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ dependencies = [ "anyio==4.10.0", # We need to pin this version otherwise listing mcp tools using fastmcp within runner fails "sqlalchemy[mypy]>=2.0.44", "apscheduler==4.0.0a6", + "opentelemetry-api>=1.38.0", + "imagehash>=4.3.0", ] requires-python = ">=3.10,<3.14" readme = "README.md" diff --git a/src/askui/__init__.py b/src/askui/__init__.py index d31c18ea..e49913a2 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -30,6 +30,7 @@ ToolUseBlockParam, UrlImageSourceParam, ) +from .models.shared.conversation_callback import ConversationCallback from .models.shared.settings import ( DEFAULT_GET_RESOLUTION, DEFAULT_LOCATE_RESOLUTION, @@ -76,6 +77,7 @@ "CitationPageLocationParam", "ConfigurableRetry", 
"ContentBlockParam", + "ConversationCallback", "DEFAULT_GET_RESOLUTION", "DEFAULT_LOCATE_RESOLUTION", "GetSettings", diff --git a/src/askui/agent.py b/src/askui/agent.py index fab00ede..a869ca33 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -9,6 +9,7 @@ from askui.container import telemetry from askui.locators.locators import Locator from askui.models.models import Point +from askui.models.shared.conversation_callback import ConversationCallback from askui.models.shared.settings import ActSettings, LocateSettings, MessageSettings from askui.models.shared.tools import Tool from askui.prompts.act_prompts import ( @@ -67,7 +68,7 @@ class ComputerAgent(Agent): ``` """ - @telemetry.record_call(exclude={"reporters", "tools", "act_tools"}) + @telemetry.record_call(exclude={"reporters", "tools", "act_tools", "callbacks"}) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, @@ -77,6 +78,7 @@ def __init__( settings: AgentSettings | None = None, retry: Retry | None = None, act_tools: list[Tool] | None = None, + callbacks: list[ConversationCallback] | None = None, ) -> None: reporter = CompositeReporter(reporters=reporters) self.tools = tools or AgentToolbox( @@ -109,6 +111,7 @@ def __init__( + (act_tools or []), agent_os=self.tools.os, settings=settings, + callbacks=callbacks, ) self.act_agent_os_facade: ComputerAgentOsFacade = ComputerAgentOsFacade( self.tools.os diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 4afdd5a1..6e5e94cd 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -2,7 +2,7 @@ import time import types from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Literal, Optional, Type, overload +from typing import Annotated, Literal, Optional, Type, overload from dotenv import load_dotenv from PIL import Image as PILImage @@ -13,26 +13,29 @@ from askui.container import telemetry from askui.locators.locators import Locator from askui.models.shared.agent_message_param 
import MessageParam -from askui.models.shared.agent_on_message_cb import OnMessageCb +from askui.models.shared.conversation import Conversation, Speakers +from askui.models.shared.conversation_callback import ConversationCallback from askui.models.shared.settings import ( ActSettings, + CacheWritingSettings, CachingSettings, GetSettings, LocateSettings, ) from askui.models.shared.tools import Tool, ToolCollection -from askui.prompts.act_prompts import create_default_prompt -from askui.prompts.caching import CACHE_USE_PROMPT +from askui.models.shared.usage_tracking_callback import UsageTrackingCallback +from askui.prompts.act_prompts import CACHE_USE_PROMPT, create_default_prompt from askui.tools.agent_os import AgentOs from askui.tools.android.agent_os import AndroidAgentOs from askui.tools.caching_tools import ( - ExecuteCachedTrajectory, + InspectCacheMetadata, RetrieveCachedTestExecutions, + VerifyCacheExecution, ) from askui.tools.get_tool import GetTool from askui.tools.locate_tool import LocateTool from askui.utils.annotation_writer import AnnotationWriter -from askui.utils.cache_writer import CacheWriter +from askui.utils.caching.cache_manager import CacheManager from askui.utils.image_utils import ImageSource from askui.utils.source_utils import InputSource, load_image_source, load_source @@ -42,9 +45,7 @@ from .models.types.response_schemas import ResponseSchema from .reporting import CompositeReporter, Reporter from .retry import ConfigurableRetry, Retry - -if TYPE_CHECKING: - from askui.models.models import ActModel +from .speaker import CacheExecutor logger = logging.getLogger(__name__) @@ -57,6 +58,7 @@ def __init__( tools: list[Tool] | None = None, agent_os: AgentOs | AndroidAgentOs | None = None, settings: AgentSettings | None = None, + callbacks: list[ConversationCallback] | None = None, ) -> None: load_dotenv() self._reporter: Reporter = reporter or CompositeReporter(reporters=None) @@ -64,14 +66,23 @@ def __init__( self._tools = tools or [] - # 
Build models: use provided settings or fall back to AskUI defaults - from askui.models.shared.agent import AskUIAgent - + # Store settings and model providers _settings = settings or AgentSettings() - self._act_model: ActModel = AskUIAgent( - model_id=_settings.vlm_provider.model_id, - messages_api=_settings.to_messages_api(), + self._vlm_provider = _settings.vlm_provider + self._image_qa_provider = _settings.image_qa_provider + self._detection_provider = _settings.detection_provider + + # Create conversation with speakers and model providers + speakers = Speakers() + _callbacks = list(callbacks or []) + _callbacks.append(UsageTrackingCallback(reporter=self._reporter)) + self._conversation = Conversation( + speakers=speakers, + vlm_provider=self._vlm_provider, + image_qa_provider=self._image_qa_provider, + detection_provider=self._detection_provider, reporter=self._reporter, + callbacks=_callbacks, ) # Provider-based tools @@ -102,13 +113,12 @@ def __init__( self.locate_settings = LocateSettings() self.caching_settings = CachingSettings() - @telemetry.record_call(exclude={"goal", "on_message", "act_settings", "tools"}) + @telemetry.record_call(exclude={"goal", "act_settings", "tools"}) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def act( self, goal: Annotated[str | list[MessageParam], Field(min_length=1)], act_settings: ActSettings | None = None, - on_message: OnMessageCb | None = None, tools: list[Tool] | ToolCollection | None = None, caching_settings: CachingSettings | None = None, ) -> None: @@ -127,16 +137,14 @@ def act( act_model (ActModel | None, optional): Model to use for this act execution. Overrides the agent's default model if provided. - on_message (OnMessageCb | None, optional): Callback for new messages. If - it returns `None`, stops and does not add the message. Cannot be used - with caching_settings strategy "write" or "both". tools (list[Tool] | ToolCollection | None, optional): The tools for the agent. 
Defaults to default tools depending on the selected model. caching_settings (CachingSettings | None, optional): The caching settings for the act execution. Controls recording and replaying of action - sequences (trajectories). Available strategies: "no" (default, no - caching), "write" (record actions to cache file), "read" (replay from - cached trajectories), "both" (read and write). Defaults to no caching. + sequences (trajectories). Available strategies: None (default, no + caching), "record" (record actions to cache file), "execute" (replay + from cached trajectories), "auto" (execute and record). Defaults to + no caching. Returns: None @@ -145,8 +153,6 @@ def act( MaxTokensExceededError: If the model reaches the maximum token limit defined in the agent settings. ModelRefusalError: If the model refuses to process the request. - ValueError: If on_message callback is provided with caching strategy - "write" or "both". Example: Basic usage without caching: @@ -171,14 +177,14 @@ def act( "username 'admin' and password 'secret123'" ), caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=".cache", filename="login_flow.json" ) ) ``` - Replaying cached actions: + Executing cached actions: ```python from askui import ComputerAgent from askui.models.shared.settings import CachingSettings @@ -187,14 +193,14 @@ def act( agent.act( goal="Log in to the application", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=".cache" ) ) # Agent will automatically find and use "login_flow.json" ``` - Using both read and write modes: + Using both execute and record modes: ```python from askui import ComputerAgent from askui.models.shared.settings import CachingSettings @@ -203,7 +209,7 @@ def act( agent.act( goal="Complete the checkout process", caching_settings=CachingSettings( - strategy="both", + strategy="auto", cache_dir=".cache", filename="checkout.json" ) @@ -227,19 +233,23 @@ def act( _caching_settings: 
CachingSettings = caching_settings or self.caching_settings - tools, on_message, cached_execution_tool = self._patch_act_with_cache( - _caching_settings, _act_settings, tools, on_message + tools, cache_manager = self._patch_act_with_cache( + _caching_settings, _act_settings, tools, goal_str ) _tools = self._build_tools(tools) - if cached_execution_tool: - cached_execution_tool.set_toolbox(_tools) + # Set toolbox on cache_manager for non-cacheable tool detection + if cache_manager: + cache_manager.set_toolbox(_tools) + + # Set cache_manager on conversation for recording + self._conversation.cache_manager = cache_manager - self._act_model.act( + # Use conversation-based architecture for execution + self._conversation.execute_conversation( messages=messages, - act_settings=_act_settings, - on_message=on_message, tools=_tools, + settings=_act_settings, ) def _build_tools(self, tools: list[Tool] | ToolCollection | None) -> ToolCollection: @@ -255,33 +265,35 @@ def _patch_act_with_cache( caching_settings: CachingSettings, settings: ActSettings, tools: list[Tool] | ToolCollection | None, - on_message: OnMessageCb | None, - ) -> tuple[ - list[Tool] | ToolCollection, OnMessageCb | None, ExecuteCachedTrajectory | None - ]: + goal: str, + ) -> tuple[list[Tool] | ToolCollection, CacheManager | None]: """Patch act settings and tools with caching functionality. 
Args: caching_settings: The caching settings to apply settings: The act settings to modify tools: The tools list to extend with caching tools - on_message: The message callback (may be replaced for write mode) + goal: The goal string for cache recording Returns: - A tuple of (modified_tools, modified_on_message, cached_execution_tool) + A tuple of (modified_tools, cache_manager) """ caching_tools: list[Tool] = [] - cached_execution_tool: ExecuteCachedTrajectory | None = None + cache_manager: CacheManager | None = None - # Setup read mode: add caching tools and modify system prompt - if caching_settings.strategy in ["read", "both"]: - cached_execution_tool = ExecuteCachedTrajectory( - caching_settings.execute_cached_trajectory_tool_settings - ) + # Setup execute mode: add caching tools and modify system prompt + if caching_settings.strategy in ["execute", "auto"]: + # Create CacheExecutor with execution settings and add to speakers + cache_executor = CacheExecutor(caching_settings.execution_settings) + self._conversation.speakers.add_speaker(cache_executor) + + # Add caching tools (switch_speaker tool is added automatically + # by Conversation._setup_speaker_handoff) caching_tools.extend( [ RetrieveCachedTestExecutions(caching_settings.cache_dir), - cached_execution_tool, + VerifyCacheExecution(), + InspectCacheMetadata(), ] ) if settings.messages.system is None: @@ -296,18 +308,23 @@ def _patch_act_with_cache( else: tools = caching_tools - # Setup write mode: create cache writer and set message callback - if caching_settings.strategy in ["write", "both"]: - cache_writer = CacheWriter( - caching_settings.cache_dir, caching_settings.filename + # Setup record mode: create cache manager for recording + if caching_settings.strategy in ["record", "auto"]: + cache_writer_settings = ( + caching_settings.writing_settings or CacheWritingSettings() + ) + filename = cache_writer_settings.filename or "" + + cache_manager = CacheManager() + cache_manager.start_recording( + 
cache_dir=caching_settings.cache_dir, + file_name=filename, + goal=goal, + cache_writer_settings=cache_writer_settings, + vlm_provider=self._vlm_provider, ) - if on_message is None: - on_message = cache_writer.add_message_cb - else: - error_message = "Cannot use on_message callback when writing Cache" - raise ValueError(error_message) - return tools, on_message, cached_execution_tool + return tools, cache_manager @overload def get( diff --git a/src/askui/android_agent.py b/src/askui/android_agent.py index 3c4b29ad..9cd4711f 100644 --- a/src/askui/android_agent.py +++ b/src/askui/android_agent.py @@ -9,6 +9,7 @@ from askui.container import telemetry from askui.locators.locators import Locator from askui.models.models import Point +from askui.models.shared.conversation_callback import ConversationCallback from askui.models.shared.settings import ActSettings, MessageSettings from askui.models.shared.tools import Tool from askui.prompts.act_prompts import create_android_agent_prompt @@ -63,7 +64,7 @@ class AndroidAgent(Agent): ``` """ - @telemetry.record_call(exclude={"reporters", "tools", "act_tools"}) + @telemetry.record_call(exclude={"reporters", "tools", "act_tools", "callbacks"}) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, @@ -72,6 +73,7 @@ def __init__( settings: AgentSettings | None = None, retry: Retry | None = None, act_tools: list[Tool] | None = None, + callbacks: list[ConversationCallback] | None = None, ) -> None: reporter = CompositeReporter(reporters=reporters) self.os = PpadbAgentOs(device_identifier=device, reporter=reporter) @@ -98,6 +100,7 @@ def __init__( + (act_tools or []), agent_os=self.os, settings=settings, + callbacks=callbacks, ) self.act_tool_collection.add_agent_os(self.act_agent_os_facade) # Override default act settings with Android-specific settings diff --git a/src/askui/custom_agent.py b/src/askui/custom_agent.py deleted file mode 100644 index 3192fea2..00000000 --- a/src/askui/custom_agent.py +++ 
/dev/null @@ -1,53 +0,0 @@ -from typing import Annotated - -from pydantic import ConfigDict, Field, validate_call - -from askui.container import telemetry -from askui.models.models import ActModel -from askui.models.shared.agent_message_param import MessageParam -from askui.models.shared.agent_on_message_cb import OnMessageCb -from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import Tool, ToolCollection - - -class CustomAgent: - """Custom agent for headless agentic tasks without OS integration.""" - - def __init__(self, act_model: ActModel | None = None) -> None: - from askui.agent_settings import AgentSettings - from askui.models.shared.agent import AskUIAgent - - if act_model is not None: - self._act_model = act_model - else: - _settings = AgentSettings() - self._act_model = AskUIAgent( - model_id=_settings.vlm_provider.model_id, - messages_api=_settings.to_messages_api(), - ) - self.act_settings = ActSettings() - - @telemetry.record_call(exclude={"messages", "on_message", "settings", "tools"}) - @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) - def act( - self, - messages: Annotated[list[MessageParam], Field(min_length=1)], - on_message: OnMessageCb | None = None, - tools: list[Tool] | ToolCollection | None = None, - settings: ActSettings | None = None, - ) -> None: - _settings = settings or self.act_settings - _tools = self._build_tools(tools) - self._act_model.act( - messages=messages, - act_settings=_settings, - on_message=on_message, - tools=_tools, - ) - - def _build_tools(self, tools: list[Tool] | ToolCollection | None) -> ToolCollection: - if isinstance(tools, list): - return ToolCollection(tools=tools) - if isinstance(tools, ToolCollection): - return tools - return ToolCollection() diff --git a/src/askui/models/anthropic/messages_api.py b/src/askui/models/anthropic/messages_api.py index fc3a0986..56c237ff 100644 --- a/src/askui/models/anthropic/messages_api.py +++ 
b/src/askui/models/anthropic/messages_api.py @@ -46,6 +46,23 @@ def _is_retryable_error(exception: BaseException) -> bool: return isinstance(exception, (APIConnectionError, APITimeoutError, APIError)) +def _sanitize_message_for_api(message: MessageParam) -> dict[str, Any]: + """Remove non-API fields from a message before sending to Anthropic API. + + Fields like `usage`, `stop_reason`, and `visual_representation` are used + internally but not accepted by the API. + """ + msg_dict = message.model_dump(exclude={"stop_reason", "usage"}) + + # Remove visual_representation from tool_use blocks in content + if isinstance(msg_dict.get("content"), list): + for block in msg_dict["content"]: + if isinstance(block, dict) and block.get("type") == "tool_use": + block.pop("visual_representation", None) + + return msg_dict + + def built_messages_for_get_and_locate( scaled_image: Image, prompt: str ) -> list[MessageParam]: @@ -140,7 +157,7 @@ def create_message( provider_options: dict[str, Any] | None = None, ) -> MessageParam: _messages = [ - cast("BetaMessageParam", message.model_dump(exclude={"stop_reason"})) + cast("BetaMessageParam", _sanitize_message_for_api(message)) for message in messages ] diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 15061b0d..12c65fe9 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -6,7 +6,6 @@ from askui.locators.locators import Locator from askui.models.shared.agent_message_param import MessageParam -from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.settings import ActSettings, GetSettings, LocateSettings from askui.models.shared.tools import ToolCollection from askui.models.types.geometry import Point, PointList @@ -140,7 +139,6 @@ def act( self, messages: list[MessageParam], model: str, - on_message: OnMessageCb | None = None, tools: list[Tool] | None = None, settings: AgentSettings | None = None, ) -> None: @@ -155,7 +153,6 @@ def act( self, messages: 
list[MessageParam], act_settings: ActSettings, - on_message: OnMessageCb | None = None, tools: ToolCollection | None = None, ) -> None: """ @@ -174,14 +171,6 @@ def act( Args: messages (list[MessageParam]): The message history to start that determines the actions and following messages. - on_message (OnMessageCb | None, optional): Callback for new messages - from either an assistant/agent or a user (including - automatic/programmatic tool use, e.g., taking a screenshot). - If it returns `None`, the acting is canceled and `act()` returns - immediately. If it returns a `MessageParam`, this `MessageParma` is - added to the message history and the acting continues based on the - message. The message may be modified by the callback to allow for - directing the assistant/agent or tool use. tools (ToolCollection | None, optional): The tools for the agent. Defaults to `None`. act_settings (ActSettings): The settings for this act operation, diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py deleted file mode 100644 index 683ada6a..00000000 --- a/src/askui/models/shared/agent.py +++ /dev/null @@ -1,194 +0,0 @@ -"""AskUIAgent — core tool-calling act loop for autonomous agents.""" - -import logging - -from typing_extensions import override - -from askui.models.exceptions import MaxTokensExceededError, ModelRefusalError -from askui.models.models import ActModel -from askui.models.shared.agent_message_param import MessageParam -from askui.models.shared.agent_on_message_cb import ( - NULL_ON_MESSAGE_CB, - OnMessageCb, - OnMessageCbParam, -) -from askui.models.shared.messages_api import MessagesApi -from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import ToolCollection -from askui.models.shared.truncation_strategies import ( - SimpleTruncationStrategyFactory, - TruncationStrategy, - TruncationStrategyFactory, -) -from askui.reporting import NULL_REPORTER, Reporter - -logger = logging.getLogger(__name__) - - -class 
class AskUIAgent(ActModel):
    """Base class for agents that can execute autonomous actions.

    This class provides common functionality for both AskUI and Anthropic agents,
    including tool handling, message processing, and image filtering.

    Args:
        model_id (str): The identifier of the LLM to use.
        messages_api (MessagesApi): Messages API for creating messages.
        reporter (Reporter, optional): The reporter for logging messages and
            actions. Defaults to `NULL_REPORTER`.
        truncation_strategy_factory (TruncationStrategyFactory, optional): The
            truncation strategy factory to use. This is used to create the
            truncation strategy to truncate the message history before sending
            it to the model. Defaults to `SimpleTruncationStrategyFactory`.
    """

    def __init__(
        self,
        model_id: str,
        messages_api: MessagesApi,
        reporter: Reporter = NULL_REPORTER,
        truncation_strategy_factory: TruncationStrategyFactory | None = None,
    ) -> None:
        self._model_id = model_id
        self._messages_api = messages_api
        self._reporter = reporter
        self._truncation_strategy_factory = (
            truncation_strategy_factory or SimpleTruncationStrategyFactory()
        )

    def _step(
        self,
        on_message: OnMessageCb,
        settings: ActSettings,
        tool_collection: ToolCollection,
        truncation_strategy: TruncationStrategy,
    ) -> None:
        """Run the act loop until no further action is required.

        Fix: this method previously recursed once per model/tool turn, so a
        long-running agent session could exhaust Python's recursion limit
        (~1000 frames) and crash with ``RecursionError``. It now iterates
        with a ``while`` loop; the per-turn semantics are unchanged.

        If the last message is an assistant's message and does not contain
        tool use blocks, the loop terminates immediately, as there is nothing
        to act upon.

        Args:
            on_message (OnMessageCb): Callback on new messages. Returning
                `None` from the callback cancels the act loop.
            settings (ActSettings): The settings for the step.
            tool_collection (ToolCollection): The tools to use for the step.
            truncation_strategy (TruncationStrategy): The truncation strategy
                holding (and truncating) the message history.

        Returns:
            None
        """
        while True:
            if truncation_strategy.messages[-1].role == "user":
                # Last turn came from the user (or a tool result): query the
                # model for the next assistant message.
                response_message = self._messages_api.create_message(
                    messages=truncation_strategy.messages,
                    model_id=self._model_id,
                    tools=tool_collection,
                    max_tokens=settings.messages.max_tokens,
                    system=settings.messages.system,
                    thinking=settings.messages.thinking,
                    tool_choice=settings.messages.tool_choice,
                    temperature=settings.messages.temperature,
                    provider_options=settings.messages.provider_options,
                )
                message_by_assistant = self._call_on_message(
                    on_message, response_message, truncation_strategy.messages
                )
                if message_by_assistant is None:
                    # Callback canceled the act loop.
                    return
                message_by_assistant_dict = message_by_assistant.model_dump(
                    mode="json"
                )
                logger.debug(message_by_assistant_dict)
                truncation_strategy.append_message(message_by_assistant)
                self._reporter.add_message(
                    self.__class__.__name__, message_by_assistant_dict
                )
            else:
                # History already ends with an assistant message (only
                # possible on the first iteration): act on it directly.
                message_by_assistant = truncation_strategy.messages[-1]
            self._handle_stop_reason(
                message_by_assistant, settings.messages.max_tokens
            )
            tool_result_message = self._use_tools(
                message_by_assistant, tool_collection
            )
            if tool_result_message is None:
                # No tool was used -> nothing left to act upon.
                return
            tool_result_message = self._call_on_message(
                on_message, tool_result_message, truncation_strategy.messages
            )
            if tool_result_message is None:
                # Callback canceled the act loop.
                return
            tool_result_message_dict = tool_result_message.model_dump(mode="json")
            logger.debug(tool_result_message_dict)
            truncation_strategy.append_message(tool_result_message)
            # The appended tool result has role "user", so the next iteration
            # queries the model again (this is where the recursion used to be).

    def _call_on_message(
        self,
        on_message: OnMessageCb | None,
        message: MessageParam,
        messages: list[MessageParam],
    ) -> MessageParam | None:
        """Invoke the callback; a `None` callback passes the message through."""
        if on_message is None:
            return message
        return on_message(OnMessageCbParam(message=message, messages=messages))

    @override
    def act(
        self,
        messages: list[MessageParam],
        act_settings: ActSettings,
        on_message: OnMessageCb | None = None,
        tools: ToolCollection | None = None,
    ) -> None:
        """Start the autonomous act loop on the given message history.

        Args:
            messages (list[MessageParam]): The message history to start from.
            act_settings (ActSettings): The settings for this act operation.
            on_message (OnMessageCb | None, optional): Callback for new
                messages. Returning `None` cancels the act loop; returning a
                (possibly modified) `MessageParam` continues with it.
            tools (ToolCollection | None, optional): The tools for the agent.
                Defaults to `None` (an empty collection).
        """
        _tool_collection = tools or ToolCollection()

        truncation_strategy = (
            self._truncation_strategy_factory.create_truncation_strategy(
                tools=_tool_collection.to_params(),
                system=act_settings.messages.system,
                messages=messages,
                model=self._model_id,
            )
        )
        self._step(
            on_message=on_message or NULL_ON_MESSAGE_CB,
            settings=act_settings,
            tool_collection=_tool_collection,
            truncation_strategy=truncation_strategy,
        )

    def _use_tools(
        self,
        message: MessageParam,
        tool_collection: ToolCollection,
    ) -> MessageParam | None:
        """Process tool use blocks in a message.

        Args:
            message (MessageParam): The message containing tool use blocks.
            tool_collection (ToolCollection): The tools available for running.

        Returns:
            MessageParam | None: A message containing tool results or `None`
                if no tools were used.
        """
        # A plain-string message cannot contain tool use blocks.
        if isinstance(message.content, str):
            return None

        tool_use_content_blocks = [
            content_block
            for content_block in message.content
            if content_block.type == "tool_use"
        ]
        content = tool_collection.run(tool_use_content_blocks)
        if len(content) == 0:
            return None

        # Tool results are fed back into the conversation as a user message.
        return MessageParam(
            content=content,
            role="user",
        )

    def _handle_stop_reason(self, message: MessageParam, max_tokens: int) -> None:
        """Raise if the model stopped for a reason that should abort acting."""
        if message.stop_reason == "max_tokens":
            raise MaxTokensExceededError(max_tokens)
        if message.stop_reason == "refusal":
            raise ModelRefusalError
dict[str, Any] +class UsageParam(BaseModel): + """Token usage statistics from model API calls.""" + + input_tokens: int | None = None + output_tokens: int | None = None + cache_creation_input_tokens: int | None = None + cache_read_input_tokens: int | None = None + + class MessageParam(BaseModel): + """A message in a conversation.""" + role: Literal["user", "assistant"] content: str | list[ContentBlockParam] stop_reason: StopReason | None = None + usage: UsageParam | None = None __all__ = [ @@ -134,6 +147,7 @@ class MessageParam(BaseModel): "ToolResultBlockParam", "ToolUseBlockParam", "UrlImageSourceParam", + "UsageParam", "ThinkingConfigParam", "ToolChoiceParam", ] diff --git a/src/askui/models/shared/conversation.py b/src/askui/models/shared/conversation.py new file mode 100644 index 00000000..ee812a20 --- /dev/null +++ b/src/askui/models/shared/conversation.py @@ -0,0 +1,444 @@ +"""Conversation class for managing speaker-based agent interactions.""" + +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from opentelemetry import trace + +from askui.model_providers.detection_provider import DetectionProvider +from askui.model_providers.image_qa_provider import ImageQAProvider +from askui.model_providers.vlm_provider import VlmProvider +from askui.models.shared.agent_message_param import MessageParam +from askui.models.shared.settings import ActSettings +from askui.models.shared.tools import ToolCollection +from askui.models.shared.truncation_strategies import ( + SimpleTruncationStrategyFactory, + TruncationStrategy, + TruncationStrategyFactory, +) +from askui.reporting import NULL_REPORTER, Reporter +from askui.speaker.speaker import SpeakerResult, Speakers +from askui.tools.switch_speaker_tool import SwitchSpeakerTool + +if TYPE_CHECKING: + from askui.models.shared.conversation_callback import ConversationCallback + from askui.utils.caching.cache_manager import CacheManager + +logger = logging.getLogger(__name__) +tracer = 
tracer = trace.get_tracer(__name__)


class ConversationException(Exception):
    """Exception raised during conversation execution."""

    def __init__(self, msg: str) -> None:
        super().__init__(msg)
        # Keep the message accessible as an attribute for programmatic handling.
        self.msg = msg


class Conversation:
    """Manages conversation state and delegates execution to speakers.

    The Conversation holds all model providers (`VlmProvider`, `ImageQAProvider`,
    `DetectionProvider`), message history, truncation strategy, token usage,
    and current speaker. It orchestrates the conversation by delegating each
    step to the appropriate speaker.

    Speakers access the model providers via the conversation instance
    (e.g., `conversation.vlm_provider`).

    Args:
        speakers: Collection of speakers to use
        vlm_provider: VLM provider for LLM API calls
        image_qa_provider: Image Q&A provider (optional)
        detection_provider: Detection provider (optional)
        reporter: Reporter for logging messages and actions
        cache_manager: Cache manager for recording/playback (optional)
        truncation_strategy_factory: Factory for creating truncation strategies
        callbacks: List of callbacks for conversation lifecycle hooks (optional)
    """

    def __init__(
        self,
        speakers: Speakers,
        vlm_provider: VlmProvider,
        image_qa_provider: ImageQAProvider | None = None,
        detection_provider: DetectionProvider | None = None,
        reporter: Reporter = NULL_REPORTER,
        cache_manager: "CacheManager | None" = None,
        truncation_strategy_factory: TruncationStrategyFactory | None = None,
        callbacks: "list[ConversationCallback] | None" = None,
    ) -> None:
        """Initialize conversation with speakers and model providers.

        Raises:
            ValueError: If `speakers` is empty.
        """
        if not speakers:
            msg = "At least one speaker must be provided"
            raise ValueError(msg)

        # Identity
        self.conversation_id: str = str(uuid.uuid4())

        # Speakers and current state
        self.speakers = speakers
        self.current_speaker = speakers[speakers.default_speaker]

        # Model providers - accessible by speakers via conversation instance
        self.vlm_provider = vlm_provider
        self.image_qa_provider = image_qa_provider
        self.detection_provider = detection_provider

        # Infrastructure
        self._reporter = reporter
        self.cache_manager = cache_manager
        self._truncation_strategy_factory = (
            truncation_strategy_factory or SimpleTruncationStrategyFactory()
        )
        # Created lazily in _setup_control_loop(); None until the first run.
        self._truncation_strategy: TruncationStrategy | None = None
        self._callbacks: "list[ConversationCallback]" = callbacks or []

        # State for current execution (set in start())
        self.settings: ActSettings = ActSettings()
        self.tools: ToolCollection = ToolCollection()
        self._reporters: list[Reporter] = []
        self._step_index: int = 0

        # Track if cache execution was used (to prevent recording during playback)
        self._executed_from_cache: bool = False

    # --- Callback dispatchers: each fans out to every registered callback
    # in registration order. An exception in a callback propagates and
    # aborts the conversation.

    def _on_conversation_start(self) -> None:
        for callback in self._callbacks:
            callback.on_conversation_start(self)

    def _on_conversation_end(self) -> None:
        for callback in self._callbacks:
            callback.on_conversation_end(self)

    def _on_control_loop_start(self) -> None:
        for callback in self._callbacks:
            callback.on_control_loop_start(self)

    def _on_control_loop_end(self) -> None:
        for callback in self._callbacks:
            callback.on_control_loop_end(self)

    def _on_step_start(self, step_index: int) -> None:
        for callback in self._callbacks:
            callback.on_step_start(self, step_index)

    def _on_step_end(self, step_index: int, result: SpeakerResult) -> None:
        for callback in self._callbacks:
            callback.on_step_end(self, step_index, result)

    def _on_speaker_switch(self, from_speaker: str, to_speaker: str) -> None:
        for callback in self._callbacks:
            callback.on_speaker_switch(self, from_speaker, to_speaker)

    def _on_tool_execution_start(self, tool_names: list[str]) -> None:
        for callback in self._callbacks:
            callback.on_tool_execution_start(self, tool_names)

    def _on_tool_execution_end(self, tool_names: list[str]) -> None:
        for callback in self._callbacks:
            callback.on_tool_execution_end(self, tool_names)

    @tracer.start_as_current_span("execute_conversation")
    def execute_conversation(
        self,
        messages: list[MessageParam],
        tools: ToolCollection | None = None,
        settings: ActSettings | None = None,
        reporters: list[Reporter] | None = None,
    ) -> None:
        """Setup conversation state and start control loop.

        Model providers are accessed via self.vlm_provider, etc.
        Speakers can access them via conversation.vlm_provider.

        Args:
            messages: Initial message history
            tools: Available tools
            settings: Agent settings
            reporters: Optional list of additional reporters for this conversation
        """
        msg = f"Starting conversation with speaker: {self.current_speaker.get_name()}"
        logger.info(msg)

        self._setup_control_loop(messages, tools, settings, reporters)

        self._on_conversation_start()
        self._execute_control_loop()
        self._on_conversation_end()

        # Cache recording (if any) is finalized only after the
        # conversation-end callbacks have fired.
        self._conclude_control_loop()

    @tracer.start_as_current_span("_setup_control_loop")
    def _setup_control_loop(
        self,
        messages: list[MessageParam],
        tools: ToolCollection | None = None,
        settings: ActSettings | None = None,
        reporters: list[Reporter] | None = None,
    ) -> None:
        """Reset per-run state and build the truncation strategy for `messages`."""
        # Reset state
        self._executed_from_cache = False
        self.speakers.reset_state()

        # Store execution parameters
        self.settings = settings or ActSettings()
        self.tools = tools or ToolCollection()
        self._reporters = reporters or []

        # Auto-populate speaker descriptions and switch_speaker tool
        self._setup_speaker_handoff()

        # Initialize truncation strategy
        self._truncation_strategy = (
            self._truncation_strategy_factory.create_truncation_strategy(
                tools=self.tools.to_params(),
                system=self.settings.messages.system,
                messages=messages,
                model=self.vlm_provider.model_id,
            )
        )

    @tracer.start_as_current_span("_execute_control_loop")
    def _execute_control_loop(self) -> None:
        """Run steps until a step reports that the conversation is finished."""
        self._on_control_loop_start()
        self._step_index = 0
        continue_execution = True
        while continue_execution:
            continue_execution = self._execute_step()
        self._on_control_loop_end()

    @tracer.start_as_current_span("_conclude_control_loop")
    def _conclude_control_loop(self) -> None:
        """Finalize the run (currently: flush the cache recording, if any)."""
        # Finish recording if cache_manager is active and not executing from cache
        if self.cache_manager is not None and not self._executed_from_cache:
            self.cache_manager.finish_recording(self.get_messages())

    def _setup_speaker_handoff(self) -> None:
        """Set up speaker handoff infrastructure.

        If there are speakers with descriptions (handoff targets), this method:
        1. Appends a speaker-description section to ``system_capabilities``
        2. Adds a ``SwitchSpeakerTool`` to the tool collection
        """
        speaker_descriptions = self._build_speaker_descriptions()
        if not speaker_descriptions:
            return

        # Append speaker descriptions to system_capabilities
        if self.settings.messages.system is not None:
            has_capabilities = self.settings.messages.system.system_capabilities
            separator = "\n\n" if has_capabilities else ""
            # NOTE(review): the empty "" literal below (and the bare
            # f"{separator}\n") look like surrounding markup tags may have
            # been lost from this prompt text — confirm the intended format.
            self.settings.messages.system.system_capabilities += (
                f"{separator}\n"
                "The following specialized speakers are available in this "
                "conversation. Use the switch_speaker tool to hand off to "
                "them when appropriate.\n\n"
                f"{speaker_descriptions}\n"
                ""
            )

        # Create switch_speaker tool with valid speaker names
        handoff_speakers = [
            speaker.get_name() for speaker in self.speakers if speaker.get_description()
        ]
        switch_tool = SwitchSpeakerTool(speaker_names=handoff_speakers)
        self.tools.append_tool(switch_tool)

    def _build_speaker_descriptions(self) -> str:
        """Build formatted speaker descriptions for the system prompt.

        Returns:
            Formatted string with speaker names and descriptions,
            or empty string if no speakers have descriptions.
        """
        descriptions: list[str] = []
        for speaker in self.speakers:
            description = speaker.get_description()
            if description:
                descriptions.append(f"### {speaker.get_name()}\n{description}")
        return "\n\n".join(descriptions)

    @tracer.start_as_current_span("_execute_step")
    def _execute_step(self) -> bool:
        """Execute one step of the conversation loop with speakers.

        Each step includes:
        1. Infer next speaker
        2. Get message(s) from active speaker and add to history
        3. Execute tool calls if applicable and add result to history
        4. Check if conversation should continue and switch speaker if necessary
        5. Collect Statistics

        Returns:
            True if loop should continue, False if done
        """
        self._on_step_start(self._step_index)

        # 1. Infer next speaker
        self._switch_speaker_if_needed()

        # 2. Get next message(s) from speaker and add to history
        logger.debug("Executing step with speaker: %s", self.current_speaker.get_name())
        result: SpeakerResult = self.current_speaker.handle_step(
            self, self.cache_manager
        )
        for message in result.messages_to_add:
            self._add_message(message)

        # 3. Execute tool calls if applicable (only the speaker's last
        # message is inspected for tool use blocks).
        continue_loop = False
        if result.messages_to_add:
            last_message = result.messages_to_add[-1]
            tool_result_message = self._execute_tools_if_present(last_message)
            if tool_result_message:
                self._add_message(tool_result_message)
                continue_loop = True  # we always continue after a tool was called

        # 4. Check if conversation should continue and switch speaker if necessary
        # Note: _handle_continue_conversation must always be called (not
        # short-circuited) because it has side effects (e.g., triggering
        # speaker switches).
        status_continue = self._handle_continue_conversation(result)
        continue_loop = continue_loop or status_continue

        self._on_step_end(self._step_index, result)
        self._step_index += 1

        return continue_loop

    @tracer.start_as_current_span("_execute_tools_if_present")
    def _execute_tools_if_present(self, message: MessageParam) -> MessageParam | None:
        """Execute tools if the message contains tool use blocks.

        Args:
            message: Message to check for tool calls

        Returns:
            MessageParam with tool results, or None if no tools to execute
        """
        # Only process assistant messages
        if message.role != "assistant":
            return None

        # Check if content is a list (could contain tool use blocks)
        if isinstance(message.content, str):
            return None

        # Find tool use blocks
        tool_use_blocks = [
            block for block in message.content if block.type == "tool_use"
        ]

        if not tool_use_blocks:
            return None

        # Execute tools, bracketed by the tool-execution callbacks.
        tool_names = [block.name for block in tool_use_blocks]
        logger.debug("Executing %d tool(s)", len(tool_use_blocks))
        self._on_tool_execution_start(tool_names)
        tool_results = self.tools.run(tool_use_blocks)
        self._on_tool_execution_end(tool_names)

        if not tool_results:
            return None

        # Return tool results as a user message
        return MessageParam(content=tool_results, role="user")

    def _add_message(self, message: MessageParam) -> None:
        """Add message to conversation history.

        Args:
            message: Message to add
        """
        if not self._truncation_strategy:
            # Should not happen after _setup_control_loop(); drop and log
            # rather than crash mid-conversation.
            logger.error("No truncation strategy, cannot add message")
            return

        # Add to truncation strategy
        self._truncation_strategy.append_message(message)

        # Report to reporter, attributed to the currently active speaker.
        self._reporter.add_message(
            self.current_speaker.get_name(), message.model_dump(mode="json")
        )

    @tracer.start_as_current_span("_handle_continue_conversation")
    def _handle_continue_conversation(self, result: SpeakerResult) -> bool:
        """Handle speaker result status and determine if loop should continue.

        Args:
            result: Result from speaker

        Returns:
            True if loop should continue, False if done
        """
        if result.status == "done":
            logger.info("Conversation completed successfully")
            return False
        if result.status == "failed":
            logger.error("Conversation failed")
            return False
        if result.status == "switch_speaker":
            if result.next_speaker:
                self.switch_speaker(
                    result.next_speaker,
                    speaker_context=result.speaker_context,
                )
            return True
        # status == "continue"
        return True

    def _switch_speaker_if_needed(self) -> None:
        """Switch to default speaker if current one cannot handle."""
        if not self.current_speaker.can_handle(self):
            logger.debug(
                "Speaker %s cannot handle current state, switching to default",
                self.current_speaker.get_name(),
            )
            self.switch_speaker(self.speakers.default_speaker)

    @tracer.start_as_current_span("switch_speaker")
    def switch_speaker(
        self,
        speaker_name: str,
        speaker_context: dict[str, Any] | None = None,
    ) -> None:
        """Switch to a different speaker, optionally passing activation context.

        Args:
            speaker_name: Name of the speaker to switch to.
            speaker_context: Optional activation context to pass to the
                target speaker via ``on_activate()``.
        """
        old_speaker = self.current_speaker
        self.current_speaker = self.speakers[speaker_name]
        logger.info(
            "Switched speaker: %s => %s",
            old_speaker.get_name(),
            self.current_speaker.get_name(),
        )
        self._on_speaker_switch(
            old_speaker.get_name(),
            self.current_speaker.get_name(),
        )
        # on_activate is only invoked when a context was explicitly provided.
        if speaker_context is not None:
            self.current_speaker.on_activate(speaker_context)

    def get_messages(self) -> list[MessageParam]:
        """Get current message history from truncation strategy.

        Returns:
            List of messages in current conversation. Note: this is the
            strategy's live list — callers should treat it as read-only.
        """
        return self._truncation_strategy.messages if self._truncation_strategy else []

    def get_truncation_strategy(self) -> TruncationStrategy | None:
        """Get current truncation strategy.

        Returns:
            Current truncation strategy or None if not initialized
        """
        return self._truncation_strategy
`on_step_end` - After step execution + 4. `on_control_loop_end` - After the while loop ends + 5. `on_conversation_end` - Before cleanup + + Example: + ```python + class LoggingCallback(ConversationCallback): + def on_step_start(self, conversation, step_index): + print(f"Starting step {step_index}") + + def on_step_end(self, conversation, step_index, result): + print(f"Step {step_index} completed: {result.status}") + + + with ComputerAgent(callbacks=[LoggingCallback()]) as agent: + agent.act("Open the settings menu") + ``` + """ + + def on_conversation_start(self, conversation: "Conversation") -> None: + """Called when conversation begins (after setup, before control loop). + + Args: + conversation: The conversation instance with initialized state. + """ + + def on_conversation_end(self, conversation: "Conversation") -> None: + """Called when conversation ends (after control loop, before cleanup). + + Args: + conversation: The conversation instance. + """ + + def on_control_loop_start(self, conversation: "Conversation") -> None: + """Called before the control loop starts iterating. + + Args: + conversation: The conversation instance. + """ + + def on_control_loop_end(self, conversation: "Conversation") -> None: + """Called after the control loop finishes (success or failure). + + Args: + conversation: The conversation instance. + """ + + def on_step_start(self, conversation: "Conversation", step_index: int) -> None: + """Called before each step execution. + + Args: + conversation: The conversation instance. + step_index: Zero-based index of the current step. + """ + + def on_step_end( + self, + conversation: "Conversation", + step_index: int, + result: "SpeakerResult", + ) -> None: + """Called after each step execution. + + Args: + conversation: The conversation instance. + step_index: Zero-based index of the completed step. + result: The result from the speaker. 
+ """ + + def on_speaker_switch( + self, + conversation: "Conversation", + from_speaker: str, + to_speaker: str, + ) -> None: + """Called when a speaker switch occurs. + + Args: + conversation: The conversation instance. + from_speaker: Name of the speaker being switched from. + to_speaker: Name of the speaker being switched to. + """ + + def on_tool_execution_start( + self, conversation: "Conversation", tool_names: list[str] + ) -> None: + """Called before tools are executed. + + Args: + conversation: The conversation instance. + tool_names: Names of tools about to be executed. + """ + + def on_tool_execution_end( + self, conversation: "Conversation", tool_names: list[str] + ) -> None: + """Called after tools are executed. + + Args: + conversation: The conversation instance. + tool_names: Names of tools that were executed. + """ diff --git a/src/askui/models/shared/settings.py b/src/askui/models/shared/settings.py index ee376c8b..3fc7624a 100644 --- a/src/askui/models/shared/settings.py +++ b/src/askui/models/shared/settings.py @@ -1,9 +1,15 @@ +from datetime import datetime from typing import Any, NamedTuple from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Literal -from askui.models.shared.agent_message_param import ThinkingConfigParam, ToolChoiceParam +from askui.models.shared.agent_message_param import ( + ThinkingConfigParam, + ToolChoiceParam, + ToolUseBlockParam, + UsageParam, +) from askui.models.shared.prompts import ( ActSystemPrompt, GetSystemPrompt, @@ -28,7 +34,9 @@ class Resolution(NamedTuple): DEFAULT_LOCATE_RESOLUTION = Resolution(1280, 800) DEFAULT_GET_RESOLUTION = Resolution(1280, 800) -CACHING_STRATEGY = Literal["read", "write", "both", "no"] +CACHING_STRATEGY = Literal["execute", "record", "auto"] +CACHE_PARAMETER_IDENTIFICATION_STRATEGY = Literal["llm", "preset"] +CACHING_VISUAL_VERIFICATION_METHOD = Literal["phash", "ahash", "none"] class MessageSettings(BaseModel): @@ -155,15 +163,99 @@ class 
class CacheFailure(BaseModel):
    """Record of a single cache execution failure.

    Args:
        timestamp: When the failure occurred
        step_index: Index of the step that failed
        error_message: Description of the failure
        failure_count_at_step: Running count of failures at this step
    """

    timestamp: datetime
    step_index: int
    error_message: str
    failure_count_at_step: int


class VisualValidationMetadata(BaseModel):
    """Visual-validation configuration recorded in a cache file's metadata.

    Args:
        enabled: Whether visual validation is enabled for this cache
        method: Visual hash method used ("phash", "ahash", or "none")
        region_size: Size of the region hashed around recorded coordinates
    """

    enabled: bool
    method: CACHING_VISUAL_VERIFICATION_METHOD
    region_size: int


class CacheMetadata(BaseModel):
    """Metadata for a cache file including execution history and validation state.

    Args:
        version: Cache format version
        created_at: When the cache was created
        goal: Original goal text (may be parameterized)
        last_executed_at: When the cache was last executed
        token_usage: Accumulated token usage from recording
        execution_attempts: Total number of execution attempts
        failures: List of recorded failures
        is_valid: Whether cache is still valid
        invalidation_reason: Why cache was invalidated (if applicable)
        visual_validation: Visual validation configuration
    """

    version: str = "0.2"  # current cache format version
    created_at: datetime
    goal: str | None = None
    last_executed_at: datetime | None = None
    token_usage: UsageParam | None = None
    execution_attempts: int = 0
    failures: list[CacheFailure] = Field(default_factory=list)
    is_valid: bool = True
    invalidation_reason: str | None = None
    visual_validation: VisualValidationMetadata | None = None


class CacheFile(BaseModel):
    """Complete cache file structure with metadata and trajectory.

    Args:
        metadata: Cache metadata and execution history
        trajectory: List of tool use blocks to execute
        cache_parameters: Dict mapping parameter names to descriptions
    """

    metadata: CacheMetadata
    trajectory: list[ToolUseBlockParam]
    cache_parameters: dict[str, str] = Field(default_factory=dict)


class CacheWritingSettings(BaseModel):
    """Settings for recording cache files.

    Args:
        filename: Name for the cache file (auto-generated if empty)
        parameter_identification_strategy: How to identify parameters
            ("llm" or "preset")
        visual_verification_method: Visual hash method ("phash", "ahash", or "none")
        visual_validation_region_size: Size of region to hash around coordinates
    """

    filename: str = ""
    parameter_identification_strategy: CACHE_PARAMETER_IDENTIFICATION_STRATEGY = "llm"
    visual_verification_method: CACHING_VISUAL_VERIFICATION_METHOD = "phash"
    visual_validation_region_size: int = 100


class CacheExecutionSettings(BaseModel):
    """Settings for executing/replaying cached trajectories.

    Args:
        delay_time_between_actions: Delay in seconds between actions
        skip_visual_validation: Override to disable visual validation
        visual_validation_threshold: Max Hamming distance for validation
    """

    delay_time_between_actions: float = 1.0  # keep >1s to give UI time to materialize
    skip_visual_validation: bool = False
    visual_validation_threshold: int = 10


class CachingSettings(BaseModel):
    """Top-level caching configuration passed to `act()`.

    Caching records and replays agent tool-use trajectories as a
    performance optimization.

    Args:
        strategy (CACHING_STRATEGY | None): Caching mode. Options:
            - None: Caching disabled (default)
            - "execute": Replay actions from cache
            - "record": Record actions to cache
            - "auto": Execute from cache if available, otherwise record
        cache_dir (str): Directory path for storing cache files.
            Default: ".askui_cache".
        writing_settings: Settings for cache recording (used in "record"/"auto" modes)
        execution_settings: Settings for cache playback (used in "execute"/"auto" modes)
    """

    strategy: CACHING_STRATEGY | None = None
    cache_dir: str = ".askui_cache"
    writing_settings: CacheWritingSettings | None = None
    execution_settings: CacheExecutionSettings | None = None
+ ), + ) + def to_params( self, ) -> ToolParam: @@ -456,12 +465,50 @@ def _run_tool( mcp_tool = self._get_mcp_tools().get(tool_use_block_param.name) if mcp_tool: return self._run_mcp_tool(tool_use_block_param) + # Fallback: try prefix matching (for cached trajectories with different UUIDs) + tool = self.find_tool_by_prefix(tool_use_block_param.name) + if tool: + return self._run_regular_tool(tool_use_block_param, tool) + msg = f"no matching tool found with name {tool_use_block_param.name}" + logger.error(msg) return ToolResultBlockParam( content=f"Tool not found: {tool_use_block_param.name}", is_error=True, tool_use_id=tool_use_block_param.id, ) + def find_tool_by_prefix(self, cached_name: str) -> Tool | None: + """Find a tool by matching name prefix (without UUID suffix). + + Tool names have format: {base_name}_tags_{tags}_{uuid} or {base_name}_{uuid} + This method strips the UUID suffix and matches by the remaining prefix. + + This is useful for cached trajectories where tool names may have different + UUIDs than the current session. + + Args: + cached_name: Tool name from cached trajectory (may have different UUID) + + Returns: + Matching Tool if found, None otherwise + """ + # Extract prefix by removing trailing UUID (pattern: _xxxxxxxx-xxxx-...) 
+ # UUIDs start with 8 hex chars after an underscore + uuid_pattern = re.compile(r"_[0-9a-f]{8}-[0-9a-f]{4}.*$", re.IGNORECASE) + cached_prefix = uuid_pattern.sub("", cached_name) + + if not cached_prefix or cached_prefix == cached_name: + # No UUID found or name unchanged, can't match by prefix + return None + + # Find a tool whose name starts with the same prefix + for tool_name, tool in self.tool_map.items(): + tool_prefix = uuid_pattern.sub("", tool_name) + if tool_prefix == cached_prefix: + return tool + + return None + async def _list_mcp_tools(self, mcp_client: McpClientProtocol) -> list[McpTool]: async with mcp_client: return await mcp_client.list_tools() diff --git a/src/askui/models/shared/usage_tracking_callback.py b/src/askui/models/shared/usage_tracking_callback.py new file mode 100644 index 00000000..3bc58206 --- /dev/null +++ b/src/askui/models/shared/usage_tracking_callback.py @@ -0,0 +1,77 @@ +"""Callback for tracking token usage and reporting usage summaries.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from opentelemetry import trace +from typing_extensions import override + +from askui.models.shared.agent_message_param import UsageParam +from askui.models.shared.conversation_callback import ConversationCallback +from askui.reporting import NULL_REPORTER, Reporter + +if TYPE_CHECKING: + from askui.models.shared.conversation import Conversation + from askui.speaker.speaker import SpeakerResult + + +class UsageTrackingCallback(ConversationCallback): + """Tracks token usage per step and reports a summary at conversation end. + + Args: + reporter: Reporter to write the final usage summary to. 
+ """ + + def __init__(self, reporter: Reporter = NULL_REPORTER) -> None: + self._reporter = reporter + self._accumulated_usage = UsageParam() + + @override + def on_conversation_start(self, conversation: Conversation) -> None: + self._accumulated_usage = UsageParam() + + @override + def on_step_end( + self, + conversation: Conversation, + step_index: int, + result: SpeakerResult, + ) -> None: + if result.usage: + self._accumulate(result.usage) + + @override + def on_conversation_end(self, conversation: Conversation) -> None: + self._reporter.add_usage_summary(self._accumulated_usage.model_dump()) + + @property + def accumulated_usage(self) -> UsageParam: + """Current accumulated usage statistics.""" + return self._accumulated_usage + + def _accumulate(self, step_usage: UsageParam) -> None: + self._accumulated_usage.input_tokens = ( + self._accumulated_usage.input_tokens or 0 + ) + (step_usage.input_tokens or 0) + self._accumulated_usage.output_tokens = ( + self._accumulated_usage.output_tokens or 0 + ) + (step_usage.output_tokens or 0) + self._accumulated_usage.cache_creation_input_tokens = ( + self._accumulated_usage.cache_creation_input_tokens or 0 + ) + (step_usage.cache_creation_input_tokens or 0) + self._accumulated_usage.cache_read_input_tokens = ( + self._accumulated_usage.cache_read_input_tokens or 0 + ) + (step_usage.cache_read_input_tokens or 0) + + current_span = trace.get_current_span() + current_span.set_attributes( + { + "input_tokens": step_usage.input_tokens or 0, + "output_tokens": step_usage.output_tokens or 0, + "cache_creation_input_tokens": ( + step_usage.cache_creation_input_tokens or 0 + ), + "cache_read_input_tokens": (step_usage.cache_read_input_tokens or 0), + } + ) diff --git a/src/askui/prompts/act_prompts.py b/src/askui/prompts/act_prompts.py index e2c65e87..c28ebc35 100644 --- a/src/askui/prompts/act_prompts.py +++ b/src/askui/prompts/act_prompts.py @@ -86,6 +86,66 @@ * Use appropriate gestures (tap, swipe, drag) based on context * 
Verify element visibility before interaction""" +MULTI_DEVICE_CAPABILITIES = """You are an autonomous AI agent that can interact +with user interfaces through computer vision and input control. + +* Your primary goal is to execute tasks efficiently and reliably while + maintaining system stability. +* Operate independently and make informed decisions without requiring + user input. +* Focus on completing the exact task given without deviation or expansion. +* Task completion includes all necessary verification and correction steps. +* Ensure actions are repeatable and maintain system stability. +* Optimize operations to minimize latency and resource usage. +* Always verify actions before execution, even with full system access. +TOOL USAGE: +* You will be able to operate 2 devices: an android device, and a computer device. +* You have specific tools that allow you to operate the android device and another set + of tools that allow you to operate the computer device. +* The tool names have a prefix of either 'computer_' or 'android_'. The + 'computer_' tools will operate the computer, the 'android_' tools will + operate the android device. For example, when taking a screenshot, + you will have to use 'computer_screenshot' for taking a screenshot from the + computer, and 'android_screenshot' for taking a screenshot from the android + device. +* Use the most direct and efficient tool for each task +* Combine tools strategically for complex operations +* Prefer built-in tools over shell commands when possible + +**Error Handling:** +* When you cannot find something (application window, ui element etc.) on + the currently selected/active display/screen/device, check the other available + displays by listing them and checking which one is currently active and + then going through the other displays one by one until you find it or + you have checked all of them. Do not forget to also check the other device! 
+* Assess failures systematically: check tool availability, permissions, + and device state +* Implement retry logic with exponential backoff for transient failures +* Use fallback strategies when primary approaches fail +* Provide clear, actionable error messages with diagnostic information + +**Performance Optimization:** +* On the android device, you can use one-liner shell commands with inline filtering +(grep, cut, awk, jq) for efficiency +* Minimize screen captures and coordinate calculations +* When using your function calls, they take a while to run and send back + to you. Where possible/feasible, try to chain multiple of these calls + all into one function call request. + +**Screen Interaction:** +* Ensure all coordinates are integers and within screen bounds +* Implement smart scrolling for off-screen elements +* For android, use appropriate gestures (tap, swipe, drag) based on context +* Verify element visibility before interaction +* When asked to perform web tasks try to open the browser (firefox, + chrome, safari, ...) if not already open. Often you can find the + browser icons in the toolbars of the operating systems or in the dock of the android + device. +* On the computer device it can be helpful to zoom in/out when viewing a page +so that you can see everything on the page. Either that, or make sure you scroll + down/up to see everything before deciding something isn't available. +""" + WEB_BROWSER_CAPABILITIES = """You are an autonomous AI agent that can interact with web interfaces through computer vision and browser control. @@ -111,6 +171,18 @@ * Access Level: Full system access * Test Device: Yes, with full permissions""" +MULTI_DEVICE_INFORMATION = f"""* You will be operating two devices: a computer and an +android device!
Here are some details on both of them: +COMPUTER_DEVICE: +* Platform: {sys.platform} +* Architecture: {platform.machine()} +* Internet Access: Available +ANDROID_DEVICE: +* Device Type: Android device +* Connection: ADB (Android Debug Bridge) +* Access Level: Full system access +* Test Device: Yes, with full permissions""" + WEB_AGENT_DEVICE_INFORMATION = """* Environment: Web browser in full-screen mode * Visibility: Only current webpage content (single tab) * Interaction: Mouse, keyboard, and browser-specific controls""" @@ -172,121 +244,214 @@ * Provide clear, actionable feedback for all operations * Use the most efficient method for each task""" +CACHE_VERIFICATION_PROTOCOL = """ +CACHE EXECUTION VERIFICATION (MANDATORY): +After any cache execution (whether successful, failed, or paused), you MUST +perform the following verification steps before reporting completion: + +1. CAPTURE STATE: Take a screenshot from EVERY device involved in the task: + - If task involves computer: call computer_screenshot + - If task involves Android: call android_screenshot_tool + - Do this EVEN IF the cache execution reported success + +2. ANALYZE COMPLETENESS: For each device, verify against the goal: + - Check if expected UI changes occurred + - Verify expected applications/screens are visible + - Confirm expected data/text is present + +3. IDENTIFY GAPS: Determine which operations completed vs which did not: + - Cache may have completed only computer operations + - Cache may have completed only Android operations + - Cache may have partially completed on one or both devices + +4. COMPLETE REMAINING WORK: If ANY device is not in the expected state: + - DO NOT report the task as complete + - Manually execute the remaining operations + - Continue until ALL devices reach the expected state + +5. 
FINAL VERIFICATION: Only after completing step 4: + - Take final screenshots of ALL devices + - Confirm ALL goal requirements are met + - Then and ONLY then report task completion + +CRITICAL: The message "[CACHE EXECUTION COMPLETED]" does NOT mean the task +is complete. It only means the cache executor finished replaying cached steps. +You MUST verify actual task completion across ALL devices as described above. +""" + +MULTI_DEVICE_SUCCESS_CRITERIA = """ +A multi-device task/test is ONLY successful when ALL of the following are true: +1. ✓ Every device mentioned in the goal has been verified with a screenshot +2. ✓ Each device's screenshot shows the EXPECTED end state as defined in the goal +3. ✓ All operations specified in the goal have been confirmed complete +4. ✓ No errors or unexpected states are visible on any device +5. ✓ The EXPECTED OUTCOME has been OBSERVED (not just steps executed) + +CRITICAL DISTINCTION: +- EXECUTION SUCCESS: You performed all the steps without technical errors +- TEST SUCCESS: The expected outcome/behavior was actually observed in the UI +- These are NOT the same! A test can FAIL even if execution was successful. + +RED FLAGS - DO NOT report success if: +- ✗ You haven't taken a screenshot from every device in the goal +- ✗ Any device's screenshot doesn't match the expected state +- ✗ You only completed operations on one device but goal mentioned multiple +- ✗ The cache executor completed but you haven't verified device states +- ✗ You see error messages, wrong screens, or unexpected UI on any device +- ✗ The expected outcome was NOT observed (even if all steps were executed) +- ✗ The UI shows different content/state than what the test expected + +IF EXPECTED OUTCOME IS NOT OBSERVED: +- Report the test as FAILED +- Describe what was expected vs what was actually observed +- This likely indicates a defect in the application under test +""" + +MULTI_DEVICE_OPERATION_RULES = """ +MULTI-DEVICE OPERATION RULES: + +1. 
DEVICE SELECTION: + - You control TWO devices: a computer and an Android device + - Tools have prefixes: 'computer_' for computer, 'android_' for Android + - Example: computer_screenshot vs android_screenshot_tool + - Always use the correct prefix for the target device + +2. TASK EXECUTION: + - Read the goal carefully to identify which device(s) are involved + - Execute operations on the correct device as specified + - Some tasks require operations on BOTH devices + +3. STATE VERIFICATION (CRITICAL): + - After completing operations, verify EVERY device's state + - Take screenshots from ALL devices mentioned in the goal + - Confirm each device reached its expected state + - Do NOT skip verification even if no errors occurred + +4. MULTI-DEVICE SUCCESS: + - Task is complete ONLY when ALL devices are verified correct + - Partial completion (one device done, one not) is NOT success + - Report success only after confirming all devices with screenshots + +5. CACHE EXECUTION WITH MULTIPLE DEVICES: + - Cache execution may include steps for both devices + - Cache completing does NOT guarantee both devices are correct + - Follow CACHE VERIFICATION PROTOCOL after every cache execution + - Verify BOTH devices even if cache reports success +""" + +TEST_OUTCOME_EVALUATION = """ +TEST OUTCOME EVALUATION (CRITICAL): + +IMPORTANT: Completing all test steps does NOT automatically mean the test PASSED. +A test is only successful if the EXPECTED OUTCOME is actually observed. + +DISTINGUISH BETWEEN: +1. EXECUTION COMPLETION: All steps in the test were performed without errors +2. TEST SUCCESS: The expected outcome/behavior was observed in the UI + +EVALUATION PROCESS: +1. After executing all steps, verify the EXPECTED OUTCOME specified in the test +2. Compare the ACTUAL state with the EXPECTED state +3. If they match → Report test as PASSED +4. 
If they don't match → Report test as FAILED (even if all steps executed successfully) + +WHEN THE EXPECTED OUTCOME IS NOT OBSERVED: +- Clearly state that the test FAILED +- Describe what was EXPECTED (from the test specification/goal) +- Describe what was ACTUALLY observed on screen +- Identify the discrepancy as a potential UI defect or bug +- Include relevant details about the unexpected behavior +- DO NOT report the test as successful + +EXAMPLE OF CORRECT FAILURE REPORTING: +- Test goal: "Click submit button and verify success message appears" +- Steps executed: ✓ Clicked submit button successfully +- Expected outcome: Success message "Order placed" should be visible +- Actual outcome: Error message "Something went wrong" is displayed +- Correct report: TEST FAILED - Expected success message but error was shown. + This indicates a potential defect in the submit functionality. + +COMMON MISTAKES TO AVOID: +- ✗ Reporting success because "all steps completed" when outcome doesn't match +- ✗ Ignoring unexpected error messages or UI states +- ✗ Assuming the UI behaved correctly without verifying the expected outcome +- ✗ Conflating "I did my job" with "the test passed" + +REMEMBER: You are testing software that may have defects. Your job is to +accurately report whether the expected behavior was observed, not just whether +you were able to perform the actions. A UI defect means the TEST FAILED, even +if your execution was flawless. 
+""" + CACHE_USE_PROMPT = ( - "\n" - " You can use precomputed trajectories to make the execution of the " - "task more robust and faster!\n" - " To do so, first use the RetrieveCachedTestExecutions tool to check " - "which trajectories are available for you.\n" - " The details what each trajectory that is available for you does are " - "at the end of this prompt.\n" - " A trajectory contains all necessary mouse movements, clicks, and " - "typing actions from a previously successful execution.\n" - " If there is a trajectory available for a step you need to take, " - "always use it!\n" + "\n" + "CRITICAL: Before taking ANY action, you MUST first call the" + " retrieve_available_trajectories_tool to check for cached trajectories. If the" + " name of an available cached trajectory matches the one specified by the user," + " you MUST switch to the CacheExecutor speaker using the switch_speaker tool" + " before calling any other tools!\n" + "\n" + "WORKFLOW:\n" + "1. ALWAYS start by calling retrieve_available_trajectories_tool\n" + "2. If a matching cached trajectory exists, switch to CacheExecutor using" + " the switch_speaker tool with speaker_context containing the trajectory details\n" + "3. 
Only proceed with manual execution if no matching trajectory is available\n" "\n" - " EXECUTING TRAJECTORIES:\n" - " - Use ExecuteCachedTrajectory to execute a cached trajectory\n" - " - You will see all screenshots and results from the execution in " - "the message history\n" - " - After execution completes, verify the results are correct\n" - " - If execution fails partway, you'll see exactly where it failed " - "and can decide how to proceed\n" + "EXECUTING TRAJECTORIES:\n" + "- Use switch_speaker(speaker_name='CacheExecutor', speaker_context={" + "'trajectory_file': '', 'parameter_values': {...}}) to start execution\n" + "- Trajectories contain complete sequences of mouse movements, clicks, and typing" + " from successful executions\n" + "- You'll see all screenshots and results in message history\n" + "- Verify results after execution completes\n" "\n" - " CACHING_PARAMETERS:\n" - " - Trajectories may contain dynamic parameters like " - "{{current_date}} or {{user_name}}\n" - " - When executing a trajectory, check if it requires " - "parameter values\n" - " - Provide parameter values using the parameter_values " - "parameter as a dictionary\n" - " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', " - "parameter_values={'current_date': '2025-12-11'})\n" - " - If required parameters are missing, execution will fail with " - "a clear error message\n" + "DYNAMIC PARAMETERS:\n" + "- Trajectories may require parameters like {{current_date}} or {{user_name}}\n" + "- Provide values via parameter_values in the speaker_context\n" + "- Example: switch_speaker(speaker_name='CacheExecutor', speaker_context={" + "'trajectory_file': 'test.json', 'parameter_values': {" + "'current_date': '2025-12-11'}})\n" + "- Missing required parameters will cause execution failure with an error message." 
+ " In that case try again with providing the correct parameters\n" "\n" - " NON-CACHEABLE STEPS:\n" - " - Some tools cannot be cached and require your direct execution " - "(e.g., print_debug, contextual decisions)\n" - " - When trajectory execution reaches a non-cacheable step, it will " - "pause and return control to you\n" - " - You'll receive a NEEDS_AGENT status with the current " - "step index\n" - " - Execute the non-cacheable step manually using your " - "regular tools\n" - " - After completing the non-cacheable step, continue the trajectory " - "using ExecuteCachedTrajectory with start_from_step_index\n" + "NON-CACHEABLE STEPS:\n" + "- Some steps cannot be cached (e.g., print_debug, contextual decisions)\n" + "- Trajectory pauses at non-cacheable steps, returning NEEDS_AGENT status with" + " current step index\n" + "- Execute the non-cacheable step manually\n" + "- Resume by switching to CacheExecutor again with start_from_step_index" + " in the speaker_context\n" "\n" - " CONTINUING TRAJECTORIES:\n" - " - Use ExecuteCachedTrajectory with start_from_step_index to resume " - "execution after handling a non-cacheable step\n" - " - Provide the same trajectory file and the step index where " - "execution should continue\n" - " - Example: ExecuteCachedTrajectory(trajectory_file='test.json', " - "start_from_step_index=5, parameter_values={...})\n" - " - The tool will execute remaining steps from that index onwards\n" + "CONTINUING TRAJECTORIES:\n" + "- Resume after non-cacheable steps: switch_speaker(speaker_name='CacheExecutor'," + " speaker_context={'trajectory_file': 'test.json'," + " 'start_from_step_index': 5, 'parameter_values': {...}})\n" "\n" - " FAILURE HANDLING:\n" - " - If a trajectory fails during execution, you'll see the error " - "message and the step where it failed\n" - " - Analyze the failure: Was it due to UI changes, timing issues, " - "or incorrect state?\n" - " - Options for handling failures:\n" - " 1. 
Execute the remaining steps manually\n" - " 2. Fix the issue and retry from a specific step using " - "ExecuteCachedTrajectory with start_from_step_index\n" - " 3. Report that the cached trajectory is outdated and needs " - "re-recording\n" + "FAILURE HANDLING:\n" + "- On failure, you'll see the error and failed step index\n" + "- You MUST take one of these actions:\n" + " (1) Execute ALL remaining steps manually to complete the task, OR\n" + " (2) Retry from the failed step using start_from_step_index\n" + "- DO NOT report the task as complete until you have verified success\n" + "- If the cache is consistently failing, mark it as invalid using\n" + " verify_cache_execution with success=False\n" "\n" - " BEST PRACTICES:\n" - " - Always verify results after trajectory execution completes\n" - " - While trajectories work most of the time, occasionally " - "execution can be partly incorrect\n" - " - Make corrections where necessary after cached execution\n" - " - if you need to make any corrections after a trajectory " - "execution, please mark the cached execution as failed\n" - " - If a trajectory consistently fails, it may be invalid and " - "should be re-recorded\n" - " - There might be several trajectories available to you.\n" - " - Their filename is a unique testID.\n" - " - If executed using the ExecuteCachedTrajectory tool, a trajectory " - "will automatically execute all necessary steps for the test with " - "that id.\n" + "PARTIAL COMPLETION:\n" + "- Cache execution completing some steps does NOT mean task completion\n" + "- You MUST verify ALL devices are in the expected state (see CACHE\n" + " VERIFICATION PROTOCOL)\n" + "- Complete any remaining operations before reporting success\n" + "\n" + "BEST PRACTICES:\n" + "- Always verify results after execution\n" + "- Never assume cache success means task success\n" + "- Check EVERY device involved in the task\n" + "- Mark executions as failed if ANY corrections were needed\n" + "- Trajectory filenames are 
unique test IDs that automatically execute all steps" + " for that test\n" + "\n" ) -CAESR_CAPABILITIES = """ - You are Caesr, a(n) AI {{agent_name}} developed - by AskUI (Germany company), who democratizes automation. - - - - Confident but approachable - you handle complexity so users don't - have to - - Slightly cheeky but always helpful - use humor to make tech less - intimidating - - Direct communicator - no corporate fluff or technical jargon - - Empowering - remind users they don't need to be developers - - Results-focused - "let's make this actually work" attitude - - Anti-elitist - AI should be accessible to everyone, not just - engineers - - - - **When things don't work perfectly (which they won't at first):** - - Frame failures as part of the revolution - "We're literally - pioneering this stuff" - - Be collaborative: "Let's figure this out together" not "You did - something wrong" - - Normalize iteration: "Rome wasn't automated in a day - first attempt - rarely nails it" - - Make prompt improvement feel like skill-building: "You're learning to - speak the language of automation" - - Use inclusive language: "We're all learning how to command these - digital allies" - - Celebrate small wins: "By Jupiter, that's progress!" 
- - Position debugging as building something lasting: "We're constructing - your personal automation empire" - - """ # ============================================================================= # ActSystemPrompt INSTANCES (recommended usage) # ============================================================================= @@ -313,9 +478,9 @@ def create_computer_agent_prompt( Returns: ActSystemPrompt instance for computer agent """ - combined_rules = BROWSER_SPECIFIC_RULES + combined_rules = f"{BROWSER_SPECIFIC_RULES}\n\n{TEST_OUTCOME_EVALUATION}" if additional_rules: - combined_rules = f"{BROWSER_SPECIFIC_RULES}\n\n{additional_rules}" + combined_rules = f"{combined_rules}\n\n{additional_rules}" return ActSystemPrompt( system_capabilities=COMPUTER_USE_CAPABILITIES, @@ -340,9 +505,9 @@ def create_android_agent_prompt( Returns: ActSystemPrompt instance for Android agent """ - combined_rules = ANDROID_RECOVERY_RULES + combined_rules = f"{ANDROID_RECOVERY_RULES}\n\n{TEST_OUTCOME_EVALUATION}" if additional_rules: - combined_rules = f"{ANDROID_RECOVERY_RULES}\n\n{additional_rules}" + combined_rules = f"{combined_rules}\n\n{additional_rules}" return ActSystemPrompt( system_capabilities=ANDROID_CAPABILITIES, @@ -367,9 +532,9 @@ def create_web_agent_prompt( Returns: ActSystemPrompt instance for web agent """ - combined_rules = BROWSER_INSTALL_RULES + combined_rules = f"{BROWSER_INSTALL_RULES}\n\n{TEST_OUTCOME_EVALUATION}" if additional_rules: - combined_rules = f"{BROWSER_INSTALL_RULES}\n\n{additional_rules}" + combined_rules = f"{combined_rules}\n\n{additional_rules}" return ActSystemPrompt( system_capabilities=WEB_BROWSER_CAPABILITIES, @@ -380,6 +545,48 @@ def create_web_agent_prompt( ) +def create_multidevice_agent_prompt( + ui_information: str = "", + additional_rules: str = "", +) -> ActSystemPrompt: + """ + Create a multi-device agent (super agent) prompt with optional + custom UI information and rules. 
+ + Args: + ui_information: Custom UI-specific information + additional_rules: Additional rules beyond the default browser install rules + + Returns: + ActSystemPrompt instance for multi-device agent + """ + combined_rules = f""" +{CACHE_VERIFICATION_PROTOCOL} + +{MULTI_DEVICE_SUCCESS_CRITERIA} + +{TEST_OUTCOME_EVALUATION} + +{MULTI_DEVICE_OPERATION_RULES} + +BROWSER RULES: +{BROWSER_SPECIFIC_RULES} + +ANDROID RECOVERY: +{ANDROID_RECOVERY_RULES} +""" + if additional_rules: + combined_rules = f"{combined_rules}\n{additional_rules}" + + return ActSystemPrompt( + system_capabilities=MULTI_DEVICE_CAPABILITIES, + device_information=MULTI_DEVICE_INFORMATION, + ui_information=ui_information, + report_format=NO_REPORT_FORMAT, + additional_rules=combined_rules, + ) + + # ============================================================================= # LEGACY PROMPTS (for backwards compatibility) # ============================================================================= @@ -620,20 +827,3 @@ def create_web_agent_prompt( [Synthesize all the information into a cohesive response to the original user prompt] ] """ - - -def caesr_system_prompt( - agent_name: str = "agent", - assistant_prompt: str = "", - metadata: str = "", -) -> ActSystemPrompt: - prompt = CAESR_CAPABILITIES.replace("{{agent_name}}", agent_name) - prompt += "\n" - if assistant_prompt: - prompt += assistant_prompt - prompt += "\n" - prompt += "Metadata of current conversation: " - prompt += "\n" - prompt += metadata - - return ActSystemPrompt(prompt=prompt) diff --git a/src/askui/prompts/caching.py b/src/askui/prompts/caching.py index a89cf224..26de3deb 100644 --- a/src/askui/prompts/caching.py +++ b/src/askui/prompts/caching.py @@ -1,25 +1,36 @@ -CACHE_USE_PROMPT = ( - "\n" - " You can use precomputed trajectories to make the execution of the " - "task more robust and faster!\n" - " To do so, first use the RetrieveCachedTestExecutions tool to check " - "which trajectories are available for you.\n" - " The 
details what each trajectory that is available for you does are " -"at the end of this prompt.\n" -" A trajectory contains all necessary mouse movements, clicks, and " -"typing actions from a previously successful execution.\n" -" If there is a trajectory available for a step you need to take, " -"always use it!\n" -" You can execute a trajectory with the ExecuteCachedExecution tool.\n" -" After a trajectory was executed, make sure to verify the results! " -"While it works most of the time, occasionally, the execution can be " -"(partly) incorrect. So make sure to verify if everything is filled out " -"as expected, and make corrections where necessary!\n" -" \n" -" \n" -" There are several trajectories available to you.\n" -" Their filename is a unique testID.\n" -" If executed using the ExecuteCachedExecution tool, a trajectory will " -"automatically execute all necessary steps for the test with that id.\n" -" \n" -) +CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT = """You are analyzing UI automation \ +trajectories to identify values that should be parameterized. + +Identify values that are likely to change between executions, such as: +- Dates and timestamps (e.g., "2025-12-11", "10:30 AM", "2025-12-11T14:30:00Z") +- Usernames, emails, names (e.g., "john.doe", "test@example.com", "John Smith") +- Session IDs, tokens, UUIDs, API keys +- Dynamic text that references current state or time-sensitive information +- File paths with user-specific or time-specific components +- Temporary or generated identifiers + +DO NOT mark as parameters: +- UI element coordinates (x, y positions) +- Fixed button labels or static UI text +- Configuration values that don't change (e.g., timeouts, retry counts) +- Generic action names like "click", "type", "scroll" +- Tool names +- Boolean values or common constants + +For each parameter, provide: +1. A descriptive name in snake_case (e.g., "current_date", "user_email") +2.
The actual value found in the trajectory +3. A brief description of what it represents + +Return your analysis as a JSON object with this structure: +{ + "parameters": [ + { + "name": "current_date", + "value": "2025-12-11", + "description": "Current date in YYYY-MM-DD format" + } + ] +} + +If no parameters are found, return an empty parameters array.""" diff --git a/src/askui/reporting.py b/src/askui/reporting.py index beb20e17..cc859a5b 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -31,6 +31,29 @@ def normalize_to_pil_images( return [image] +def truncate_content( + content: Any, + max_string_length: int = 100000, +) -> Any: + """Filter out long strings (i.e. the base64 image data) to keep reports readable.""" + if isinstance(content, str): + if len(content) > max_string_length: + return f"[truncated: {len(content)} characters]" + return content + + if isinstance(content, dict): + return { + key: truncate_content(value, max_string_length) + for key, value in content.items() + } + + if isinstance(content, list): + return [truncate_content(item, max_string_length) for item in content] + + # For other types (int, float, bool, None), return as-is + return content + + class Reporter(ABC): """Abstract base class for reporters. Cannot be instantiated directly. @@ -56,6 +79,35 @@ def add_message( """ raise NotImplementedError + @abstractmethod + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + """Add usage statistics summary to the report. + + Called at the end of an act() execution with accumulated token usage. + + Args: + usage (dict[str, int | None]): Accumulated usage statistics containing: + - input_tokens: Total input tokens sent to API + - output_tokens: Total output tokens generated + """ + raise NotImplementedError + + @abstractmethod + def add_cache_execution_statistics( + self, original_usage: dict[str, int | None] + ) -> None: + """Add cache execution statistics showing token savings. 
+ + Called when a cached trajectory is executed. The original_usage contains + the token usage from when the cache was originally recorded. + + Args: + original_usage (dict[str, int | None]): Token usage from cache recording: + - input_tokens: Input tokens used during original recording + - output_tokens: Output tokens used during original recording + """ + raise NotImplementedError + @abstractmethod def generate(self) -> None: """Generates the final report. @@ -81,6 +133,16 @@ def add_message( ) -> None: pass + @override + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + pass + + @override + def add_cache_execution_statistics( + self, original_usage: dict[str, int | None] + ) -> None: + pass + @override def generate(self) -> None: pass @@ -114,6 +176,20 @@ def add_message( for reporter in self._reporters: reporter.add_message(role, content, image) + @override + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + """Add usage summary to all reporters.""" + for reporter in self._reporters: + reporter.add_usage_summary(usage) + + @override + def add_cache_execution_statistics( + self, original_usage: dict[str, int | None] + ) -> None: + """Add cache execution statistics to all reporters.""" + for reporter in self._reporters: + reporter.add_cache_execution_statistics(original_usage) + @override def generate(self) -> None: """Generates the final report.""" @@ -139,6 +215,9 @@ def __init__(self, report_dir: str = "reports") -> None: self.report_dir = Path(report_dir) self.messages: list[dict[str, Any]] = [] self.system_info = self._collect_system_info() + self.usage_summary: dict[str, int | None] | None = None + self.cache_original_usage: dict[str, int | None] | None = None + self._start_time: datetime | None = None def _collect_system_info(self) -> SystemInfo: """Collect system and Python information""" @@ -168,17 +247,34 @@ def add_message( image: Optional[Image.Image | list[Image.Image] | AnnotatedImage] = None, ) -> None: """Add 
a message to the report.""" + # Track start time from first message + if self._start_time is None: + self._start_time = datetime.now(tz=timezone.utc) + _images = normalize_to_pil_images(image) + _content = truncate_content(content) message = { "timestamp": datetime.now(tz=timezone.utc), "role": role, - "content": self._format_content(content), - "is_json": isinstance(content, (dict, list)), + "content": self._format_content(_content), + "is_json": isinstance(_content, (dict, list)), "images": [self._image_to_base64(img) for img in _images], } self.messages.append(message) + @override + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + """Store usage summary for inclusion in the report.""" + self.usage_summary = usage + + @override + def add_cache_execution_statistics( + self, original_usage: dict[str, int | None] + ) -> None: + """Store original cache usage for calculating savings.""" + self.cache_original_usage = original_usage + @override def generate(self) -> None: """Generate an HTML report file. @@ -684,6 +780,68 @@ def generate(self) -> None: +
+

Execution Statistics

+ + {% if execution_time_seconds is not none %} + + + + + {% endif %} + {% if usage_summary is not none %} + {% if usage_summary.get('input_tokens') is not none %} + + + + + {% endif %} + {% if usage_summary.get('output_tokens') is not none %} + + + + + {% endif %} + {% endif %} + {% if cache_original_usage is not none %} + {% if cache_original_usage.get('input_tokens') is not none %} + + + + + {% endif %} + {% if cache_original_usage.get('output_tokens') is not none %} + + + + + {% endif %} + {% endif %} +
Execution Time{{ "%.2f"|format(execution_time_seconds) }} seconds
Input Tokens + {{ "{:,}".format(usage_summary.get('input_tokens')) }} + {% if cache_original_usage and cache_original_usage.get('input_tokens') %} + {% set original = cache_original_usage.get('input_tokens') %} + {% set current = usage_summary.get('input_tokens') %} + {% set saved = original - current %} + {% if saved > 0 and original > 0 %} + {% set savings_pct = (saved / original * 100) %} + ({{ "%.1f"|format(savings_pct) }}% saved via trajectory caching) + {% endif %} + {% endif %} +
Output Tokens + {{ "{:,}".format(usage_summary.get('output_tokens')) }} + {% if cache_original_usage and cache_original_usage.get('output_tokens') %} + {% set original = cache_original_usage.get('output_tokens') %} + {% set current = usage_summary.get('output_tokens') %} + {% set saved = original - current %} + {% if saved > 0 and original > 0 %} + {% set savings_pct = (saved / original * 100) %} + ({{ "%.1f"|format(savings_pct) }}% saved via trajectory caching) + {% endif %} + {% endif %} +
Original Input Tokens{{ "{:,}".format(cache_original_usage.get('input_tokens')) }}
Original Output Tokens{{ "{:,}".format(cache_original_usage.get('output_tokens')) }}
+
+

Conversation Log

@@ -725,10 +883,20 @@ def generate(self) -> None: """ template = Template(template_str) + + # Calculate execution time + end_time = datetime.now(tz=timezone.utc) + execution_time_seconds: float | None = None + if self._start_time is not None: + execution_time_seconds = (end_time - self._start_time).total_seconds() + html = template.render( - timestamp=datetime.now(tz=timezone.utc), + timestamp=end_time, messages=self.messages, system_info=self.system_info, + usage_summary=self.usage_summary, + cache_original_usage=self.cache_original_usage, + execution_time_seconds=execution_time_seconds, ) report_path = ( @@ -811,6 +979,16 @@ def add_message( attachment_type=self.allure.attachment_type.PNG, ) + @override + def add_usage_summary(self, usage: dict[str, int | None]) -> None: + """No-op for AllureReporter - usage is not tracked.""" + + @override + def add_cache_execution_statistics( + self, original_usage: dict[str, int | None] + ) -> None: + """No-op for AllureReporter - cache statistics are not tracked.""" + @override def generate(self) -> None: """No-op for AllureReporter as reports are generated in real-time.""" diff --git a/src/askui/speaker/__init__.py b/src/askui/speaker/__init__.py new file mode 100644 index 00000000..611ea3a6 --- /dev/null +++ b/src/askui/speaker/__init__.py @@ -0,0 +1,22 @@ +"""Speaker module for conversation-based agent architecture. 
+ +This module provides the speaker pattern for managing conversation flow: +- `Speaker`: Abstract base class for conversation speakers +- `SpeakerResult`: Result of a speaker handling a conversation step +- `Speakers`: Collection and manager of speakers +- `Conversation`: Main orchestrator for conversation execution +- `AgentSpeaker`: Default speaker for LLM API calls +- `CacheExecutor`: Speaker for cached trajectory playback +""" + +from .agent_speaker import AgentSpeaker +from .cache_executor import CacheExecutor +from .speaker import Speaker, SpeakerResult, Speakers + +__all__ = [ + "AgentSpeaker", + "CacheExecutor", + "Speaker", + "SpeakerResult", + "Speakers", +] diff --git a/src/askui/speaker/agent_speaker.py b/src/askui/speaker/agent_speaker.py new file mode 100644 index 00000000..07bae697 --- /dev/null +++ b/src/askui/speaker/agent_speaker.py @@ -0,0 +1,181 @@ +"""Agent speaker for normal LLM API interactions.""" + +import logging +from typing import TYPE_CHECKING, Any + +from typing_extensions import override + +from askui.models.exceptions import MaxTokensExceededError, ModelRefusalError +from askui.models.shared.agent_message_param import MessageParam + +from .speaker import Speaker, SpeakerResult + +if TYPE_CHECKING: + from askui.models.shared.conversation import Conversation + from askui.utils.caching.cache_manager import CacheManager + +logger = logging.getLogger(__name__) + + +class AgentSpeaker(Speaker): + """Speaker that handles normal agent API calls. + + This speaker generates messages from the LLM by: + 1. Making API calls to get agent responses via VlmProvider + 2. Handling stop reasons (max_tokens, refusal) + 3. Returning messages for the Conversation to process + + Tool execution is handled by the Conversation class, not by this speaker. + + The VlmProvider is accessed from the Conversation instance. 
+ """ + + @override + def can_handle(self, conversation: "Conversation") -> bool: # noqa: ARG002 + """AgentSpeaker can always handle normal conversation flow. + + Args: + conversation: The conversation instance + + Returns: + Always True - this is the default speaker + """ + return True + + @override + def handle_step( + self, + conversation: "Conversation", + cache_manager: "CacheManager | None", # noqa: ARG002 + ) -> SpeakerResult: + """Get next message from the agent API. + + This speaker only generates messages from the LLM. Tool execution + is handled by the Conversation class. + + Args: + conversation: The conversation instance with current state + cache_manager: Optional cache manager (not used by this speaker) + + Returns: + SpeakerResult with the agent's message + """ + messages = conversation.get_messages() + truncation_strategy = conversation.get_truncation_strategy() + + if not truncation_strategy: + logger.error("No truncation strategy available") + return SpeakerResult(status="failed") + + # Only call agent if last message is from user + if not messages or messages[-1].role != "user": + logger.debug("Last message not from user, nothing to do") + return SpeakerResult(status="done") + + # Make API call to get agent response using VlmProvider + try: + response = conversation.vlm_provider.create_message( + messages=truncation_strategy.messages, + tools=conversation.tools, + max_tokens=conversation.settings.messages.max_tokens, + system=conversation.settings.messages.system, + thinking=conversation.settings.messages.thinking, + tool_choice=conversation.settings.messages.tool_choice, + temperature=conversation.settings.messages.temperature, + provider_options=conversation.settings.messages.provider_options, + ) + + # Log response + logger.debug("Agent response: %s", response.model_dump(mode="json")) + + except Exception: + logger.exception("Error calling agent API") + return SpeakerResult(status="failed") + + # Handle stop reason + try: + 
self._handle_stop_reason( + response, conversation.settings.messages.max_tokens + ) + except (MaxTokensExceededError, ModelRefusalError): + logger.exception("Agent stopped with error") + return SpeakerResult(status="failed", messages_to_add=[response]) + + # Check for switch_speaker tool call + switch_info = self._extract_switch_speaker(response) + if switch_info: + speaker_name, speaker_context = switch_info + return SpeakerResult( + status="switch_speaker", + next_speaker=speaker_name, + speaker_context=speaker_context, + messages_to_add=[response], + usage=response.usage, + ) + + return SpeakerResult( + status="done", + messages_to_add=[response], + usage=response.usage, + ) + + @override + def get_name(self) -> str: + """Return speaker name. + + Returns: + "AgentSpeaker" + """ + return "AgentSpeaker" + + @override + def get_description(self) -> str: + """AgentSpeaker is the default coordinator and not a handoff target. + + Returns: + Empty string. + """ + return "" + + def _extract_switch_speaker( + self, message: MessageParam + ) -> tuple[str, dict[str, Any]] | None: + """Extract switch_speaker tool call from message if present. + + Args: + message: The assistant message to inspect. + + Returns: + Tuple of (speaker_name, speaker_context) if found, None otherwise. + """ + if isinstance(message.content, str): + return None + + for block in message.content: + if block.type == "tool_use" and block.name.startswith("switch_speaker"): + input_data: dict[str, Any] = ( + dict(block.input) if isinstance(block.input, dict) else {} + ) + speaker_name = str(input_data.get("speaker_name", "")) + speaker_context: dict[str, Any] = ( + input_data.get("speaker_context", {}) or {} + ) + return speaker_name, speaker_context + + return None + + def _handle_stop_reason(self, message: MessageParam, max_tokens: int) -> None: + """Handle agent stop reasons. 
+ + Args: + message: Message to check stop reason + max_tokens: Maximum tokens configured + + Raises: + MaxTokensExceededError: If agent stopped due to max tokens + ModelRefusalError: If agent refused the request + """ + if message.stop_reason == "max_tokens": + raise MaxTokensExceededError(max_tokens) + if message.stop_reason == "refusal": + raise ModelRefusalError diff --git a/src/askui/speaker/cache_executor.py b/src/askui/speaker/cache_executor.py new file mode 100644 index 00000000..8e91f1ec --- /dev/null +++ b/src/askui/speaker/cache_executor.py @@ -0,0 +1,695 @@ +"""Cache Executor speaker for executing cached trajectories.""" + +import logging +import time +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from PIL import Image +from pydantic import BaseModel, Field +from typing_extensions import Literal, override + +from askui.models.shared.agent_message_param import ( + MessageParam, + TextBlockParam, + ToolUseBlockParam, +) +from askui.models.shared.settings import CacheExecutionSettings +from askui.utils.caching.cache_manager import CacheManager +from askui.utils.caching.cache_parameter_handler import CacheParameterHandler +from askui.utils.visual_validation import ( + compute_ahash, + compute_hamming_distance, + compute_phash, + extract_region, + find_recent_screenshot, + get_validation_coordinate, +) + +from .speaker import Speaker, SpeakerResult + +if TYPE_CHECKING: + from askui.models.shared.conversation import Conversation + from askui.models.shared.settings import CacheFile + from askui.models.shared.tools import ToolCollection + from askui.reporting import Reporter + +logger = logging.getLogger(__name__) + + +class ExecutionResult(BaseModel): + """Result of executing a single step in a trajectory. 
+ + Attributes: + status: Execution status (SUCCESS, FAILED, NEEDS_AGENT, COMPLETED) + step_index: Index of the step that was executed + tool_result: The tool result or tool use block for reference + error_message: Error message if execution failed + screenshots_taken: List of screenshots captured during this step + message_history: List of MessageParam representing the conversation history + """ + + status: Literal["SUCCESS", "FAILED", "NEEDS_AGENT", "COMPLETED"] + step_index: int + tool_result: Any | None = None + error_message: str | None = None + screenshots_taken: list[Any] = Field(default_factory=list) + message_history: list[MessageParam] = Field(default_factory=list) + + +class CacheExecutor(Speaker): + """Speaker that handles cached trajectory playback. + + This speaker generates messages from a cached trajectory: + 1. Get next step from cached trajectory (as tool use message) + 2. Track progress through the trajectory + 3. Pause for non-cacheable tools (switch to agent) + 4. Handle completion (switch to agent for verification) + 5. Handle failures (update metadata, switch to agent) + + Tool execution is handled by the Conversation class, not by this speaker. + """ + + def __init__( + self, execution_settings: CacheExecutionSettings | None = None + ) -> None: + """Initialize Cache Executor speaker. + + Args: + execution_settings: Settings for cache execution including delay time, + visual validation threshold, etc. If None, default settings are used. 
+ """ + _settings = execution_settings or CacheExecutionSettings() + + # Cache execution state + self._executing_from_cache: bool = False + self._cache_verification_pending: bool = False + self._cache_file: "CacheFile | None" = None + self._cache_file_path: str | None = None + + # Cache Execution Settings + self._skip_visual_validation: bool = _settings.skip_visual_validation + self._visual_validation_threshold: int = _settings.visual_validation_threshold + self._delay_time_between_actions: float = _settings.delay_time_between_actions + + self._trajectory: list[ToolUseBlockParam] = [] + self._toolbox: "ToolCollection | None" = None + self._parameter_values: dict[str, str] = {} + self._visual_validation_enabled: bool = False + self._visual_validation_method: str = "phash" + self._visual_validation_region_size: int = 100 + + self._current_step_index: int = 0 + self._message_history: list[MessageParam] = [] + + # Activation context received via on_activate() + self._activation_context: dict[str, Any] = {} + + @override + def can_handle(self, conversation: "Conversation") -> bool: # noqa: ARG002 + """Check if cache execution is active or should be activated. + + Args: + conversation: The conversation instance. + + Returns: + True if in cache execution mode or if activation context is set. + """ + can_handle_step = self._executing_from_cache or bool(self._activation_context) + if not can_handle_step: + logger.warning("CacheExecutor can't handle the next step") + return can_handle_step + + @override + def handle_step( + self, conversation: "Conversation", cache_manager: "CacheManager | None" + ) -> SpeakerResult: + """Get next cached step message. + + This speaker only generates messages (tool use blocks from cache). + Tool execution is handled by the Conversation class. 
+ + Args: + conversation: The conversation instance with current state + cache_manager: Cache manager for recording/playback + + Returns: + SpeakerResult with the next cached tool use message + """ + if cache_manager is None: + error_msg = "CacheManager must be provided if executing from Cache" + raise RuntimeError(error_msg) + + # Check if we need to activate cache execution from internal context + if not self._executing_from_cache and self._activation_context: + try: + # Augment context with toolbox and reporter from conversation + activation_context = { + **self._activation_context, + "toolbox": conversation.tools, + "reporter": conversation._reporter, # noqa: SLF001 + } + self._activate_from_context(activation_context, cache_manager) + self._activation_context = {} + conversation._executed_from_cache = True # noqa: SLF001 + except Exception as e: + # Validation or loading failed - report error and switch back + logger.exception("Failed to activate cache execution") + self._activation_context = {} + + error_message = MessageParam( + role="user", + content=[ + TextBlockParam( + type="text", + text=f"Cache execution failed: {e}", + ) + ], + ) + + return SpeakerResult( + status="switch_speaker", + next_speaker="AgentSpeaker", + messages_to_add=[error_message], + ) + + messages = conversation.get_messages() + + # Check if last message was a tool result - if so, move to next step + if messages and messages[-1].role == "user": + # Last message is tool result, check if it's from our current step + if self._message_history and len(self._message_history) > 0: + # Tool was executed, move to next step + self._current_step_index += 1 + # Add delay between actions + if self._current_step_index < len(self._trajectory): + time.sleep(self._delay_time_between_actions) + + # Check if we have a trajectory + if not self._trajectory or not self._toolbox: + logger.error("Cache executor called but no trajectory or toolbox available") + return SpeakerResult( + status="switch_speaker", 
+ next_speaker="AgentSpeaker", + ) + + # Get next step from cache (doesn't execute, just prepares the message) + logger.debug("Getting next step from cache") + result: ExecutionResult = self._get_next_step(conversation_messages=messages) + + # Handle result based on status + return self._handle_result(result, cache_manager) + + @override + def get_name(self) -> str: + """Return speaker name. + + Returns: + "CacheExecutor" + """ + return "CacheExecutor" + + @override + def get_description(self) -> str: + """Return description of CacheExecutor and expected context keys. + + Returns: + Description with expected context keys for activation. + """ + return ( + "Replays a pre-recorded UI interaction trajectory from a cache " + "file. Use this speaker to fast-forward through previously " + "recorded action sequences instead of executing each step from " + "scratch.\n" + "Expected context keys:\n" + " - trajectory_file (str, required): Full path to the " + "trajectory file\n" + " - start_from_step_index (int, optional, default=0): Step " + "index to start from\n" + " - parameter_values (dict[str, str], optional, default={}): " + "Dynamic parameter values for the trajectory" + ) + + @override + def on_activate(self, context: dict[str, Any]) -> None: + """Store activation context for use in the next `handle_step()` call. + + Args: + context: Dict with trajectory_file, start_from_step_index, + parameter_values. 
+ """ + self._activation_context = context + + def _handle_result( + self, result: ExecutionResult, cache_manager: "CacheManager" + ) -> SpeakerResult: + """Handle execution result and return appropriate SpeakerResult.""" + if result.status == "SUCCESS": + return self._handle_success(result) + if result.status == "NEEDS_AGENT": + return self._handle_needs_agent(result) + if result.status == "COMPLETED": + return self._handle_completed(result) + # FAILED + return self._handle_failed(cache_manager, result) + + def _handle_success(self, result: ExecutionResult) -> SpeakerResult: + """Handle successful preparation of next cache step.""" + if not result.message_history: + return SpeakerResult( + status="switch_speaker", + next_speaker="AgentSpeaker", + ) + + # Get assistant message (tool use) + assistant_msg = result.message_history[-1] + + # Store this message for tracking + self._message_history.append(assistant_msg) + + # Continue with cache execution + return SpeakerResult( + status="continue", + messages_to_add=[assistant_msg], + ) + + def _handle_needs_agent(self, result: ExecutionResult) -> SpeakerResult: + """Handle cache execution pausing for non-cacheable tool.""" + logger.info( + "Paused cache execution at step %d " + "(non-cacheable tool - agent will handle this step)", + result.step_index, + ) + self._executing_from_cache = False + + tool_to_execute = result.tool_result + + if tool_to_execute: + instruction_message = MessageParam( + role="user", + content=[ + TextBlockParam( + type="text", + text=( + f"Cache execution paused at step {result.step_index}. " + "The previous steps were executed successfully " + f"from cache. The next step requires the " + f"'{tool_to_execute.name}' tool, which cannot be " + "executed from cache. Please execute this tool with " + "the necessary parameters." 
+ ), + ) + ], + ) + + return SpeakerResult( + status="switch_speaker", + next_speaker="AgentSpeaker", + messages_to_add=[instruction_message], + ) + + return SpeakerResult( + status="switch_speaker", + next_speaker="AgentSpeaker", + ) + + def _handle_completed( + self, + result: ExecutionResult, # noqa: ARG002 + ) -> SpeakerResult: + """Handle cache execution completion.""" + logger.info( + "Cache trajectory execution completed - requesting agent verification" + ) + self._executing_from_cache = False + self._cache_verification_pending = True + + verification_request = MessageParam( + role="user", + content=[ + TextBlockParam( + type="text", + text=( + "[CACHE EXECUTION COMPLETED]\n\n" + "The CacheExecutor has automatically executed" + f" {len(self._trajectory)} steps from the cached trajectory" + f" '{self._cache_file_path}'. All previous tool calls in this" + " conversation were replayed from cache, not performed by the" + " agent.\n\n Please verify if the cached execution correctly" + " achieved the target system state using the" + " verify_cache_execution tool." 
+ ), + ) + ], + ) + + return SpeakerResult( + status="switch_speaker", + next_speaker="AgentSpeaker", + messages_to_add=[verification_request], + ) + + def _handle_failed( + self, cache_manager: CacheManager, result: ExecutionResult + ) -> SpeakerResult: + """Handle cache execution failure.""" + logger.error( + "Cache execution failed at step %d: %s", + result.step_index, + result.error_message, + ) + self._executing_from_cache = False + + # Update cache metadata + if self._cache_file and self._cache_file_path: + cache_manager.update_metadata_on_failure( + cache_file=self._cache_file, + cache_file_path=self._cache_file_path, + step_index=result.step_index, + error_message=result.error_message or "Unknown error", + ) + + # Add failure message to inform the agent about what happened + failure_message = MessageParam( + role="user", + content=[ + TextBlockParam( + type="text", + text=( + "[CACHE EXECUTION FAILED]\n\n" + f"The CacheExecutor failed to execute the cached trajectory " + f"'{self._cache_file_path}' at step {result.step_index}.\n\n" + f"Error: {result.error_message}\n\n" + "The cache file is potentially invalid. " + "Please complete the remaining steps manually. After that, use " + "the verify_cache_execution tool with success=False to " + "potentially invalidate the cache file." + ), + ) + ], + ) + + return SpeakerResult( + status="switch_speaker", + next_speaker="AgentSpeaker", + messages_to_add=[failure_message], + ) + + def _activate_from_context( + self, context: dict[str, Any], cache_manager: "CacheManager" + ) -> None: + """Activate cache execution from conversation context. 
+ + Args: + context: Dict containing cache execution parameters + cache_manager: Cache manager for loading cache files + + Raises: + FileNotFoundError: If cache file doesn't exist + ValueError: If validation fails + """ + # Extract parameters + trajectory_file: str = context["trajectory_file"] + start_from_step_index: int = context.get("start_from_step_index", 0) + parameter_values: dict[str, Any] = context.get("parameter_values", {}) + toolbox: "ToolCollection" = context["toolbox"] + + logger.debug("Activating cache execution from: %s", trajectory_file) + + # Load and validate cache file + if not self._cache_file_path or self._cache_file_path != trajectory_file: + self._cache_file_path = trajectory_file + self._cache_file = cache_manager.read_cache_file(Path(trajectory_file)) + else: + if not self._cache_file: + self._cache_file = cache_manager.read_cache_file(Path(trajectory_file)) + + # Validate step index + if start_from_step_index < 0 or start_from_step_index >= len( + self._cache_file.trajectory + ): + error_msg = ( + f"Invalid start_from_step_index: {start_from_step_index}. " + f"Trajectory has {len(self._cache_file.trajectory)} steps " + f"(valid indices: 0-{len(self._cache_file.trajectory) - 1})." + ) + raise ValueError(error_msg) + + # Validate parameters + is_valid, missing = CacheParameterHandler.validate_parameters( + self._cache_file.trajectory, parameter_values + ) + if not is_valid: + error_msg = ( + f"Missing required parameter values: {', '.join(missing)}. " + f"The trajectory contains the following parameters: " + f"{', '.join(self._cache_file.cache_parameters.keys())}. " + "Please provide values for all parameters." + ) + raise ValueError(error_msg) + + # Warn if cache is invalid + if not self._cache_file.metadata.is_valid: + logger.warning( + "Using invalid cache from %s. Reason: %s. 
" + "This cache may not work correctly.", + Path(trajectory_file).name, + self._cache_file.metadata.invalidation_reason, + ) + + # Set up execution state + self._trajectory = self._cache_file.trajectory + self._toolbox = toolbox + self._parameter_values = parameter_values + self._current_step_index = start_from_step_index + self._message_history = [] + self._executing_from_cache = True + + # Configure visual validation + visual_validation_config = self._cache_file.metadata.visual_validation + + if self._skip_visual_validation: + self._visual_validation_enabled = False + logger.info("Visual validation disabled by execution settings") + elif visual_validation_config and visual_validation_config.enabled: + self._visual_validation_enabled = True + self._visual_validation_method = visual_validation_config.method + self._visual_validation_region_size = visual_validation_config.region_size + logger.info( + "Visual validation enabled (method=%s, threshold=%d)", + self._visual_validation_method, + self._visual_validation_threshold, + ) + else: + self._visual_validation_enabled = False + logger.debug("Visual validation disabled or not configured") + + logger.info( + "Cache execution activated: %s (%d steps, starting from step %d)", + Path(trajectory_file).name, + len(self._cache_file.trajectory), + start_from_step_index, + ) + + # Report cache execution statistics to the reporter + reporter: Reporter | None = context.get("reporter") + if reporter and self._cache_file.metadata.token_usage: + reporter.add_cache_execution_statistics( + self._cache_file.metadata.token_usage.model_dump() + ) + + def reset_state(self) -> None: + """Reset cache execution state.""" + self._executing_from_cache = False + self._cache_verification_pending = False + self._cache_file = None + self._cache_file_path = None + self._trajectory = [] + self._toolbox = None + self._parameter_values = {} + self._current_step_index = 0 + self._message_history = [] + self._activation_context = {} + + def 
_get_next_step( + self, conversation_messages: list[MessageParam] | None = None + ) -> ExecutionResult: + """Get the next step message from the trajectory. + + This method does NOT execute tools - it only prepares the message. + Tool execution is handled by the Conversation class. + + Args: + conversation_messages: Optional conversation messages for visual validation + + Returns: + ExecutionResult with status and the prepared message + """ + # Check if we've completed all steps + if self._current_step_index >= len(self._trajectory): + return ExecutionResult( + status="COMPLETED", + step_index=self._current_step_index - 1, + message_history=self._message_history, + ) + + step = self._trajectory[self._current_step_index] + step_index = self._current_step_index + + # Check if step should be skipped + if self._should_skip_step(step): + logger.debug("Skipping step %d: %s", step_index, step.name) + self._current_step_index += 1 + return self._get_next_step(conversation_messages=conversation_messages) + + # Check if step needs agent intervention (non-cacheable) + if self._should_pause_for_agent(step): + logger.info( + "Pausing at step %d: %s (non-cacheable tool)", + step_index, + step.name, + ) + return ExecutionResult( + status="NEEDS_AGENT", + step_index=step_index, + message_history=self._message_history.copy(), + tool_result=step, + ) + + # Visual validation + if self._visual_validation_enabled: + current_screenshot = None + if conversation_messages: + current_screenshot = find_recent_screenshot(conversation_messages) + + is_valid, error_msg = self._validate_step_visually( + step, current_screenshot=current_screenshot + ) + if not is_valid: + return ExecutionResult( + status="FAILED", + step_index=step_index, + error_message=error_msg, + message_history=self._message_history.copy(), + ) + + # Substitute parameters + substituted_step = CacheParameterHandler.substitute_parameters( + step, self._parameter_values + ) + + # Create assistant message (tool use) - DON'T 
execute yet + try: + logger.debug("Preparing step %d: %s", step_index, step.name) + + assistant_message = MessageParam( + role="assistant", + content=[substituted_step], + ) + + return ExecutionResult( + status="SUCCESS", + step_index=step_index, + tool_result=None, + message_history=[assistant_message], + ) + + except Exception as e: + logger.exception("Error preparing step %d: %s", step_index, step.name) + return ExecutionResult( + status="FAILED", + step_index=step_index, + error_message=str(e), + message_history=self._message_history.copy(), + ) + + def _should_pause_for_agent(self, step: ToolUseBlockParam) -> bool: + """Check if execution should pause for agent intervention.""" + if not self._toolbox: + return False + + # Try exact match first, then prefix match (for tools with UUID suffixes) + tool = self._toolbox.tool_map.get(step.name) + if tool is None: + tool = self._toolbox.find_tool_by_prefix(step.name) + + if tool is None: + # Tool not found - should pause for agent to handle + return True + + return not tool.is_cacheable + + def _should_skip_step(self, step: ToolUseBlockParam) -> bool: + """Check if a step should be skipped during execution.""" + # Use startswith() to handle tool names with UUID suffixes + tools_to_skip: list[str] = [ + "retrieve_available_trajectories_tool", + "switch_speaker", + "execute_cached_executions_tool", # backward compat for old caches + ] + return any(step.name.startswith(prefix) for prefix in tools_to_skip) + + def _validate_step_visually( + self, step: ToolUseBlockParam, current_screenshot: Image.Image | None = None + ) -> tuple[bool, str | None]: + """Validate cached step using visual hash comparison.""" + if not self._visual_validation_enabled: + return True, None + + if not step.visual_representation: + return True, None + + if current_screenshot is None: + current_screenshot = find_recent_screenshot(self._message_history) + if not current_screenshot: + logger.warning("No screenshot found for visual validation, 
skipping") + return True, None + + try: + # Extract coordinate using the same logic as cache_manager + tool_input: dict[str, Any] = ( + step.input if isinstance(step.input, dict) else {} + ) + coordinate = get_validation_coordinate(tool_input) + + if coordinate is None: + # No coordinate found - skip visual validation for this step + logger.info( + "No coordinate found in step input, skipping visual validation" + ) + return True, None + + # Pass coordinate in the format extract_region expects + region = extract_region( + current_screenshot, + {"coordinate": list(coordinate)}, + region_size=self._visual_validation_region_size, + ) + + if self._visual_validation_method == "phash": + current_hash = compute_phash(region, hash_size=8) + else: + current_hash = compute_ahash(region, hash_size=8) + + distance = compute_hamming_distance( + step.visual_representation, current_hash + ) + + if distance > self._visual_validation_threshold: + error_msg = ( + f"Visual validation failed: UI has changed significantly. " + f"Hamming distance: {distance} > threshold: " + f"{self._visual_validation_threshold}." 
+ ) + return False, error_msg + + logger.debug( + "Visual validation passed (distance=%d, threshold=%d)", + distance, + self._visual_validation_threshold, + ) + return True, None # noqa: TRY300 + + except Exception as e: + logger.exception("Failed to perform visual validation") + return True, f"Visual validation skipped due to error: {e}" diff --git a/src/askui/speaker/speaker.py b/src/askui/speaker/speaker.py new file mode 100644 index 00000000..d1f84c1b --- /dev/null +++ b/src/askui/speaker/speaker.py @@ -0,0 +1,174 @@ +"""Base speaker class and result types for conversation architecture.""" + +import logging +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Iterator + +from pydantic import BaseModel, Field +from typing_extensions import Literal + +from askui.models.shared.agent_message_param import MessageParam, UsageParam + +if TYPE_CHECKING: + from askui.models.shared.conversation import Conversation + from askui.utils.caching.cache_manager import CacheManager + +logger = logging.getLogger(__name__) + +SPEAKER_RESULT_STATUS = Literal["continue", "switch_speaker", "done", "failed"] + + +class SpeakerResult(BaseModel): + """Result of a speaker handling a conversation step. + + Attributes: + status: Execution status + - "continue": Continue with same speaker (recurse) + - "switch_speaker": Switch to a different speaker + - "done": Conversation finished successfully + - "failed": Conversation failed with error + next_speaker: Name of next speaker to switch to + (required if status="switch_speaker") + messages_to_add: Messages to add to conversation history + usage: Token usage from this step (if applicable) + """ + + status: SPEAKER_RESULT_STATUS + next_speaker: str | None = None + messages_to_add: list[MessageParam] = Field(default_factory=list) + usage: UsageParam | None = None + speaker_context: dict[str, Any] | None = None + + +class Speaker(ABC): + """Abstract base class for conversation speakers. 
+ + A speaker handles specific types of conversation steps (e.g., normal agent + API calls, cache execution, human-in-the-loop). + + Each speaker determines whether it can handle the current conversation state + and executes one step when activated. + """ + + @abstractmethod + def can_handle(self, conversation: "Conversation") -> bool: + """Check if this speaker can handle the current conversation state. + + Args: + conversation: The conversation instance with current state + + Returns: + True if this speaker can handle the current state + """ + ... + + @abstractmethod + def handle_step( + self, conversation: "Conversation", cache_manager: "CacheManager | None" + ) -> SpeakerResult: + """Execute one conversation step. + + Args: + conversation: The conversation instance with current state + cache_manager: Optional cache manager for recording/playback + + Returns: + SpeakerResult indicating what to do next + """ + ... + + @abstractmethod + def get_name(self) -> str: + """Return the speaker's name for logging and identification. + + Returns: + Speaker name (e.g., "AgentSpeaker", "CacheExecutor") + """ + ... + + @abstractmethod + def get_description(self) -> str: + """Return a description of what this speaker does. + + This description is auto-populated into the system prompt so the + coordinating speaker knows which speakers are available and when + to hand off to them. Return an empty string for speakers that + should not be handoff targets (e.g., the default coordinator). + + Returns: + Human-readable description including expected context keys. + """ + ... + + def on_activate(self, context: dict[str, Any]) -> None: # noqa: B027 + """Called when this speaker is activated via a speaker switch with context. + + Override in subclasses that need activation context. + The default implementation does nothing. + + Args: + context: Activation context passed from the switch_speaker tool. + """ + + +class Speakers: + """Collection and manager of conversation speakers. 
+ + Holds a dictionary of speakers and tracks the default speaker. + Provides dictionary-like access to speakers by name. + """ + + def __init__(self, speakers: dict[str, Speaker] | None = None) -> None: + # Lazy import to avoid circular dependency + from .agent_speaker import AgentSpeaker + + self.speakers: dict[str, Speaker] = speakers or {"AgentSpeaker": AgentSpeaker()} + self.default_speaker: str = ( + "AgentSpeaker" + if "AgentSpeaker" in self.speakers + else next(iter(self.speakers.keys())) + ) + + def add_speaker(self, speaker: Speaker) -> None: + """Add a speaker to the collection.""" + self.speakers[speaker.get_name()] = speaker + + def get_names(self) -> list[str]: + """Get list of all speaker names.""" + return list(self.speakers.keys()) + + def __add__(self, other: "Speakers") -> "Speakers": + """Combine two Speakers collections.""" + result = Speakers(speakers={}) + result.speakers = self.speakers | other.speakers + result.default_speaker = self.default_speaker + return result + + def __getitem__(self, name: str) -> Speaker: + """Get speaker by name, falling back to default if not found.""" + if name in self.speakers: + return self.speakers[name] + msg = ( + f"Speaker {name} is not part of Speakers. " + f"Will use default Speaker {self.default_speaker} instead" + ) + logger.warning(msg) + return self.speakers[self.default_speaker] + + def __contains__(self, name: str) -> bool: + """Check if a speaker exists in the collection.""" + return name in self.speakers + + def __iter__(self) -> "Iterator[Speaker]": + """Iterate over all speakers in the collection.""" + return iter(self.speakers.values()) + + def reset_state(self) -> None: + """Reset state for all speakers that have a reset_state method. + + This allows stateful speakers (e.g., CacheExecutor) to be reset + between conversation runs. 
+ """ + for speaker in self.speakers.values(): + if hasattr(speaker, "reset_state") and callable(speaker.reset_state): + speaker.reset_state() diff --git a/src/askui/tools/android/tools.py b/src/askui/tools/android/tools.py index 48fab49f..c85c22ea 100644 --- a/src/askui/tools/android/tools.py +++ b/src/askui/tools/android/tools.py @@ -28,6 +28,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None) -> None: agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) + self.is_cacheable = True @override def __call__(self) -> tuple[str, Image.Image]: @@ -83,6 +84,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None) -> None: agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) + self.is_cacheable = True @override def __call__( @@ -134,6 +136,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self, text: str) -> str: @@ -191,6 +194,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None) -> None: agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) + self.is_cacheable = True @override def __call__(self, x1: int, y1: int, x2: int, y2: int, duration: int = 1000) -> str: @@ -224,6 +228,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self, key_name: ANDROID_KEY) -> str: @@ -298,6 +303,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None) -> None: agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) + self.is_cacheable = True @override def __call__(self, x1: int, y1: int, x2: int, y2: int, duration: int = 1000) -> str: @@ -354,6 +360,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self, keys: list[ANDROID_KEY], duration: int = 100) -> str: @@ -397,6 
+404,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self, command: str) -> str: @@ -415,6 +423,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None): description="Can be used to get all connected devices serial numbers.", agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self) -> str: @@ -434,6 +443,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None): "current selected device.", agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self) -> str: @@ -455,6 +465,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None): """, agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self) -> str: @@ -488,6 +499,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None): }, agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self, device_sn: str) -> str: @@ -516,6 +528,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None): }, agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self, display_unique_id: int) -> str: @@ -536,6 +549,7 @@ def __init__(self, agent_os: AndroidAgentOsFacade | None = None): """, agent_os=agent_os, ) + self.is_cacheable = True @override def __call__(self) -> str: diff --git a/src/askui/tools/caching_tools.py b/src/askui/tools/caching_tools.py index 4abb3a55..c491d946 100644 --- a/src/askui/tools/caching_tools.py +++ b/src/askui/tools/caching_tools.py @@ -1,13 +1,11 @@ import logging -import time from pathlib import Path from pydantic import validate_call from typing_extensions import override -from ..models.shared.settings import CachedExecutionToolSettings -from ..models.shared.tools import Tool, ToolCollection -from ..utils.cache_writer import CacheWriter +from ..models.shared.tools import Tool +from ..utils.caching.cache_manager import CacheManager logger = logging.getLogger() @@ 
-27,54 +25,198 @@ def __init__(self, cache_dir: str, trajectories_format: str = ".json") -> None: "replayed using the execute_trajectory_tool. Call this tool " "first to see which trajectories are available before " "executing one. The tool returns a list of file paths to " - "available trajectory files." + "available trajectory files.\n\n" + "By default, only valid (non-invalidated) caches are returned. " + "Set include_invalid=True to see all caches including those " + "marked as invalid due to repeated failures." ), + input_schema={ + "type": "object", + "properties": { + "include_invalid": { + "type": "boolean", + "description": ( + "Whether to include invalid/invalidated caches in " + "the results. Default is False (only show valid " + "caches)." + ), + "default": False, + }, + }, + "required": [], + }, ) self._cache_dir = Path(cache_dir) self._trajectories_format = trajectories_format + self.is_cacheable = True @override @validate_call - def __call__(self) -> list[str]: # type: ignore + def __call__(self, include_invalid: bool = False) -> list[str]: # type: ignore + """Retrieve available cached trajectories. + + Args: + include_invalid: Whether to include invalid caches + + Returns: + List of strings with filename and parameters info. 
+ """ + logger.info( + "Retrieving cached trajectories from %s (include_invalid=%s)", + self._cache_dir, + include_invalid, + ) + if not Path.is_dir(self._cache_dir): error_msg = f"Trajectories directory not found: {self._cache_dir}" logger.error(error_msg) raise FileNotFoundError(error_msg) - available = [ - str(f) + all_files = [ + f for f in self._cache_dir.iterdir() if str(f).endswith(self._trajectories_format) ] + logger.debug("Found %d total cache files", len(all_files)) + + available: list[str] = [] + invalid_count = 0 + unreadable_count = 0 + + for f in all_files: + try: + cache_file = CacheManager.read_cache_file(f) + + # Check if we should include this cache + if not include_invalid and not cache_file.metadata.is_valid: + invalid_count += 1 + logger.debug( + "Excluding invalid cache: %s (reason: %s)", + f.name, + cache_file.metadata.invalidation_reason, + ) + continue + + # Add cache info with filename and parameters + available.append( + f"filename: {f!s} (parameters: {cache_file.cache_parameters})" + ) + + except Exception: # noqa: PERF203 + unreadable_count += 1 + logger.exception("Failed to read cache file %s", f.name) + continue + + logger.info( + "Found %d cache(s), excluded %d invalid, %d unreadable", + len(available), + invalid_count, + unreadable_count, + ) if not available: - warning_msg = f"Warning: No trajectory files found in {self._cache_dir}" + if include_invalid: + warning_msg = f"Warning: No trajectory files found in {self._cache_dir}" + else: + warning_msg = ( + f"Warning: No valid trajectory files found in " + f"{self._cache_dir}. " + "Try include_invalid=True to see all caches." 
+ ) logger.warning(warning_msg) return available -class ExecuteCachedTrajectory(Tool): +class VerifyCacheExecution(Tool): + """Tool for agent to explicitly report cache execution verification results.""" + + def __init__(self) -> None: + super().__init__( + name="verify_cache_execution", + description=( + "IMPORTANT: Call this tool immediately after reviewing a " + "cached trajectory execution.\n\n" + "Report whether the cached execution successfully achieved " + "the target system state. You MUST call this tool to complete " + "the cache verification process.\n\n" + "Set success=True if:\n" + "- The cached execution correctly achieved the intended goal\n" + "- The final state matches what was expected\n" + "- No corrections or additional actions were needed\n\n" + "Set success=False if:\n" + "- The execution did not achieve the target state\n" + "- You had to make corrections or perform additional actions\n" + "- The final state is incorrect or incomplete" + ), + input_schema={ + "type": "object", + "properties": { + "success": { + "type": "boolean", + "description": ( + "True if cached execution correctly " + "achieved target state, " + "False if execution was incorrect or " + "corrections were needed" + ), + }, + "verification_notes": { + "type": "string", + "description": ( + "Brief explanation of what you verified. " + "If success=False, describe what was " + "wrong and what corrections you made." + ), + }, + }, + "required": ["success", "verification_notes"], + }, + ) + self.is_cacheable = False # Verification is not cacheable + + @override + @validate_call + def __call__(self, success: bool, verification_notes: str) -> str: + """Record cache verification result. 
+ + Args: + success: Whether cache execution achieved target state + verification_notes: Explanation of verification result + + Returns: + Confirmation message + """ + message = ( + f"Cache verification reported: success={success}, " + f"notes={verification_notes}" + ) + logger.info( + "Cache verification reported: success=%s, notes=%s", + success, + verification_notes, + ) + return message + + +class InspectCacheMetadata(Tool): """ - Execute a predefined trajectory to fast-forward through UI interactions + Inspect detailed metadata for a cached trajectory file """ - def __init__(self, settings: CachedExecutionToolSettings | None = None) -> None: + def __init__(self) -> None: super().__init__( - name="execute_cached_executions_tool", + name="inspect_cache_metadata_tool", description=( - "Execute a pre-recorded trajectory to automatically perform a " - "sequence of UI interactions. This tool replays mouse movements, " - "clicks, and typing actions from a previously successful execution.\n\n" - "Before using this tool:\n" - "1. Use retrieve_available_trajectories_tool to see which " - "trajectory files are available\n" - "2. Select the appropriate trajectory file path from the " - "returned list\n" - "3. Pass the full file path to this tool\n\n" - "The trajectory will be executed step-by-step, and you should " - "verify the results afterward. Note: Trajectories may fail if " - "the UI state has changed since they were recorded." + "Inspect and display detailed metadata for a cached trajectory " + "file. This tool shows information about:\n" + "- Cache version and creation timestamp\n" + "- Execution statistics (attempts, last execution time)\n" + "- Validity status and invalidation reason (if invalid)\n" + "- Failure history with timestamps and error messages\n" + "- Parameters and trajectory step count\n\n" + "Use this tool to debug cache issues or understand why a cache " + "might be failing or invalidated." 
), input_schema={ "type": "object", @@ -82,30 +224,28 @@ def __init__(self, settings: CachedExecutionToolSettings | None = None) -> None: "trajectory_file": { "type": "string", "description": ( - "Full path to the trajectory file (use " - "retrieve_available_trajectories_tool to find " - "available files)" + "Full path to the trajectory file to inspect. " + "Use retrieve_available_trajectories_tool to " + "find available files." ), }, }, "required": ["trajectory_file"], }, ) - if not settings: - settings = CachedExecutionToolSettings() - self._settings = settings - - def set_toolbox(self, toolbox: ToolCollection) -> None: - """Set the AgentOS/AskUiControllerClient reference for executing actions.""" - self._toolbox = toolbox @override @validate_call def __call__(self, trajectory_file: str) -> str: - if not hasattr(self, "_toolbox"): - error_msg = "Toolbox not set. Call set_toolbox() first." - logger.error(error_msg) - raise RuntimeError(error_msg) + """Inspect cache metadata. + + Args: + trajectory_file: Path to the trajectory file + + Returns: + Formatted metadata string + """ + logger.info("Inspecting cache metadata: %s", Path(trajectory_file).name) if not Path(trajectory_file).is_file(): error_msg = ( @@ -113,32 +253,64 @@ def __call__(self, trajectory_file: str) -> str: "Use retrieve_available_trajectories_tool to see available files." 
) logger.error(error_msg) - raise FileNotFoundError(error_msg) + return error_msg - # Load and execute trajectory - trajectory = CacheWriter.read_cache_file(Path(trajectory_file)) - info_msg = f"Executing cached trajectory from {trajectory_file}" - logger.info(info_msg) - for step in trajectory: - if ( - "screenshot" in step.name - or step.name == "retrieve_available_trajectories_tool" - ): - continue - try: - self._toolbox.run([step]) - except Exception as e: - error_msg = f"An error occured during the cached execution: {e}" - logger.exception(error_msg) - return ( - f"An error occured while executing the trajectory from " - f"{trajectory_file}. Please verify the UI state and " - "continue without cache." - ) - time.sleep(self._settings.delay_time_between_action) + try: + cache_file = CacheManager.read_cache_file(Path(trajectory_file)) + except Exception: + error_msg = f"Failed to read cache file {Path(trajectory_file).name}" + logger.exception(error_msg) + return error_msg - logger.info("Finished executing cached trajectory") - return ( - f"Successfully executed trajectory from {trajectory_file}. " - "Please verify the UI state." 
+ metadata = cache_file.metadata + logger.debug( + "Metadata loaded: version=%s, valid=%s, attempts=%d, failures=%d", + metadata.version, + metadata.is_valid, + metadata.execution_attempts, + len(metadata.failures), ) + + # Format the metadata into a readable string + lines = [ + "=== Cache Metadata ===", + f"File: {trajectory_file}", + "", + "--- Basic Info ---", + f"Version: {metadata.version}", + f"Created: {metadata.created_at}", + f"Last Executed: {metadata.last_executed_at or 'Never'}", + "", + "--- Execution Statistics ---", + f"Total Execution Attempts: {metadata.execution_attempts}", + f"Total Failures: {len(metadata.failures)}", + "", + "--- Validity Status ---", + f"Is Valid: {metadata.is_valid}", + ] + + if not metadata.is_valid: + lines.append(f"Invalidation Reason: {metadata.invalidation_reason}") + + lines.append("") + lines.append("--- Trajectory Info ---") + lines.append(f"Total Steps: {len(cache_file.trajectory)}") + lines.append(f"Parameters: {len(cache_file.cache_parameters)}") + if cache_file.cache_parameters: + lines.append( + f"Parameter Names: {', '.join(cache_file.cache_parameters.keys())}" + ) + + if metadata.failures: + lines.append("") + lines.append("--- Failure History ---") + for i, failure in enumerate(metadata.failures, 1): + lines.append(f"Failure {i}:") + lines.append(f" Timestamp: {failure.timestamp}") + lines.append(f" Step Index: {failure.step_index}") + lines.append( + f" Failure Count at Step: {failure.failure_count_at_step}" + ) + lines.append(f" Error: {failure.error_message}") + + return "\n".join(lines) diff --git a/src/askui/tools/computer/connect_tool.py b/src/askui/tools/computer/connect_tool.py index 8fddb45a..7e0e35f4 100644 --- a/src/askui/tools/computer/connect_tool.py +++ b/src/askui/tools/computer/connect_tool.py @@ -18,6 +18,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: ), agent_os=agent_os, ) + self.is_cacheable = True def __call__(self) -> str: try: diff --git 
a/src/askui/tools/computer/disconnect_tool.py b/src/askui/tools/computer/disconnect_tool.py index 1e8b83a4..6f3cea25 100644 --- a/src/askui/tools/computer/disconnect_tool.py +++ b/src/askui/tools/computer/disconnect_tool.py @@ -15,6 +15,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: ), agent_os=agent_os, ) + self.is_cacheable = True def __call__(self) -> str: self.agent_os.disconnect() diff --git a/src/askui/tools/computer/get_mouse_position_tool.py b/src/askui/tools/computer/get_mouse_position_tool.py index 829769a7..059822a5 100644 --- a/src/askui/tools/computer/get_mouse_position_tool.py +++ b/src/askui/tools/computer/get_mouse_position_tool.py @@ -12,6 +12,7 @@ def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) + self.is_cacheable = True def __call__(self) -> str: cursor_position = self.agent_os.get_mouse_position() diff --git a/src/askui/tools/computer/keyboard_pressed_tool.py b/src/askui/tools/computer/keyboard_pressed_tool.py index 0a82595e..e85fad88 100644 --- a/src/askui/tools/computer/keyboard_pressed_tool.py +++ b/src/askui/tools/computer/keyboard_pressed_tool.py @@ -34,6 +34,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True def __call__( self, diff --git a/src/askui/tools/computer/keyboard_release_tool.py b/src/askui/tools/computer/keyboard_release_tool.py index 49a338f4..13603f4b 100644 --- a/src/askui/tools/computer/keyboard_release_tool.py +++ b/src/askui/tools/computer/keyboard_release_tool.py @@ -34,6 +34,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True def __call__( self, diff --git a/src/askui/tools/computer/keyboard_tap_tool.py b/src/askui/tools/computer/keyboard_tap_tool.py index de5233e2..62f48227 100644 --- a/src/askui/tools/computer/keyboard_tap_tool.py +++ 
b/src/askui/tools/computer/keyboard_tap_tool.py @@ -42,6 +42,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True def __call__( self, diff --git a/src/askui/tools/computer/list_displays_tool.py b/src/askui/tools/computer/list_displays_tool.py index 8ce94177..68f3c207 100644 --- a/src/askui/tools/computer/list_displays_tool.py +++ b/src/askui/tools/computer/list_displays_tool.py @@ -11,6 +11,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: """, agent_os=agent_os, ) + self.is_cacheable = True def __call__(self) -> str: return self.agent_os.list_displays().model_dump_json( diff --git a/src/askui/tools/computer/mouse_click_tool.py b/src/askui/tools/computer/mouse_click_tool.py index ee07c654..002f7902 100644 --- a/src/askui/tools/computer/mouse_click_tool.py +++ b/src/askui/tools/computer/mouse_click_tool.py @@ -32,6 +32,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True def __call__(self, mouse_button: MouseButton, number_of_clicks: int = 1) -> str: self.agent_os.click(mouse_button, number_of_clicks) diff --git a/src/askui/tools/computer/mouse_hold_down_tool.py b/src/askui/tools/computer/mouse_hold_down_tool.py index b3923cd6..9387b117 100644 --- a/src/askui/tools/computer/mouse_hold_down_tool.py +++ b/src/askui/tools/computer/mouse_hold_down_tool.py @@ -24,6 +24,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True def __call__(self, mouse_button: MouseButton) -> str: self.agent_os.mouse_down(mouse_button) diff --git a/src/askui/tools/computer/mouse_release_tool.py b/src/askui/tools/computer/mouse_release_tool.py index 2e8e5bb0..b8227d9c 100644 --- a/src/askui/tools/computer/mouse_release_tool.py +++ b/src/askui/tools/computer/mouse_release_tool.py @@ -24,6 +24,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: }, agent_os=agent_os, ) + 
self.is_cacheable = True def __call__(self, mouse_button: MouseButton) -> str: self.agent_os.mouse_up(mouse_button) diff --git a/src/askui/tools/computer/mouse_scroll_tool.py b/src/askui/tools/computer/mouse_scroll_tool.py index a02645d5..17aeaf67 100644 --- a/src/askui/tools/computer/mouse_scroll_tool.py +++ b/src/askui/tools/computer/mouse_scroll_tool.py @@ -32,6 +32,7 @@ def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) + self.is_cacheable = True def __call__(self, dx: int, dy: int) -> str: self.agent_os.mouse_scroll(dx, dy) diff --git a/src/askui/tools/computer/move_mouse_tool.py b/src/askui/tools/computer/move_mouse_tool.py index dabcae50..d04db9dd 100644 --- a/src/askui/tools/computer/move_mouse_tool.py +++ b/src/askui/tools/computer/move_mouse_tool.py @@ -14,11 +14,11 @@ def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: "properties": { "x": { "type": "integer", - "description": "The x coordinate of the mouse position.", + "description": "The x coordinate of the mouse position as int.", }, "y": { "type": "integer", - "description": "The y coordinate of the mouse position.", + "description": "The y coordinate of the mouse position as int.", }, }, "required": ["x", "y"], @@ -26,7 +26,12 @@ def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) + self.is_cacheable = True def __call__(self, x: int, y: int) -> str: + # For some reason, the agent occasionally calls the tool with the coords + # encoded as strings, which causes the tool to fail. To prevent this, we + # explicitly convert to int here. + x, y = int(x), int(y) self.agent_os.mouse_move(x, y) return f"Mouse was moved to position ({x}, {y})."
diff --git a/src/askui/tools/computer/retrieve_active_display_tool.py b/src/askui/tools/computer/retrieve_active_display_tool.py index 00f22977..7eef6cfd 100644 --- a/src/askui/tools/computer/retrieve_active_display_tool.py +++ b/src/askui/tools/computer/retrieve_active_display_tool.py @@ -12,6 +12,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: """, agent_os=agent_os, ) + self.is_cacheable = True def __call__(self) -> str: return str( diff --git a/src/askui/tools/computer/screenshot_tool.py b/src/askui/tools/computer/screenshot_tool.py index 30f8278b..fcf46553 100644 --- a/src/askui/tools/computer/screenshot_tool.py +++ b/src/askui/tools/computer/screenshot_tool.py @@ -14,6 +14,7 @@ def __init__(self, agent_os: ComputerAgentOsFacade | None = None) -> None: agent_os=agent_os, required_tags=[ToolTags.SCALED_AGENT_OS.value], ) + self.is_cacheable = True def __call__(self) -> tuple[str, Image.Image]: screenshot = self.agent_os.screenshot() diff --git a/src/askui/tools/computer/set_active_display_tool.py b/src/askui/tools/computer/set_active_display_tool.py index 22e3e710..94719dec 100644 --- a/src/askui/tools/computer/set_active_display_tool.py +++ b/src/askui/tools/computer/set_active_display_tool.py @@ -21,6 +21,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True def __call__(self, display_id: int) -> None: self.agent_os.set_display(display_id) diff --git a/src/askui/tools/computer/type_tool.py b/src/askui/tools/computer/type_tool.py index 9c874f41..adcfbf62 100644 --- a/src/askui/tools/computer/type_tool.py +++ b/src/askui/tools/computer/type_tool.py @@ -29,6 +29,7 @@ def __init__(self, agent_os: AgentOs | None = None) -> None: }, agent_os=agent_os, ) + self.is_cacheable = True def __call__(self, text: str, typing_speed: int = 50) -> str: self.agent_os.type(text, typing_speed) diff --git a/src/askui/tools/playwright/tools.py b/src/askui/tools/playwright/tools.py index 
96aa7843..144eec41 100644 --- a/src/askui/tools/playwright/tools.py +++ b/src/askui/tools/playwright/tools.py @@ -35,6 +35,7 @@ def __init__(self, agent_os: PlaywrightAgentOs) -> None: }, ) self._agent_os = agent_os + self.is_cacheable = True @override def __call__(self, url: str) -> str: @@ -60,6 +61,7 @@ def __init__(self, agent_os: PlaywrightAgentOs) -> None: ), ) self._agent_os = agent_os + self.is_cacheable = True @override def __call__(self) -> str: @@ -85,6 +87,7 @@ def __init__(self, agent_os: PlaywrightAgentOs) -> None: ), ) self._agent_os = agent_os + self.is_cacheable = True @override def __call__(self) -> str: @@ -109,6 +112,7 @@ def __init__(self, agent_os: PlaywrightAgentOs) -> None: ), ) self._agent_os = agent_os + self.is_cacheable = True @override def __call__(self) -> str: @@ -133,6 +137,7 @@ def __init__(self, agent_os: PlaywrightAgentOs) -> None: ), ) self._agent_os = agent_os + self.is_cacheable = True @override def __call__(self) -> str: diff --git a/src/askui/tools/store/android/save_screenshot_tool.py b/src/askui/tools/store/android/save_screenshot_tool.py index 65b1464d..28e0a221 100644 --- a/src/askui/tools/store/android/save_screenshot_tool.py +++ b/src/askui/tools/store/android/save_screenshot_tool.py @@ -69,6 +69,7 @@ def __init__(self, base_dir: str) -> None: }, ) self._base_dir = base_dir + self.is_cacheable = True def __call__(self, image_path: str) -> str: """ diff --git a/src/askui/tools/store/computer/save_screenshot_tool.py b/src/askui/tools/store/computer/save_screenshot_tool.py index fd3f3f7c..411d52a2 100644 --- a/src/askui/tools/store/computer/save_screenshot_tool.py +++ b/src/askui/tools/store/computer/save_screenshot_tool.py @@ -69,6 +69,7 @@ def __init__(self, base_dir: str) -> None: }, ) self._base_dir = base_dir + self.is_cacheable = True def __call__(self, image_path: str) -> str: """ diff --git a/src/askui/tools/store/universal/get_current_time.py b/src/askui/tools/store/universal/get_current_time.py index 
7b0a20f9..1e92259c 100644 --- a/src/askui/tools/store/universal/get_current_time.py +++ b/src/askui/tools/store/universal/get_current_time.py @@ -40,6 +40,7 @@ def __init__(self) -> None: "required": [], }, ) + self.is_cacheable = True def __call__(self) -> str: """ diff --git a/src/askui/tools/store/universal/list_files_tool.py b/src/askui/tools/store/universal/list_files_tool.py index e2d93aa0..65c60f7c 100644 --- a/src/askui/tools/store/universal/list_files_tool.py +++ b/src/askui/tools/store/universal/list_files_tool.py @@ -69,7 +69,8 @@ def __init__(self, base_dir: str | Path) -> None: "required": [], }, ) - self._base_dir = base_dir + self._base_dir = Path(base_dir) + self.is_cacheable = True def __call__(self, directory_path: str = "", recursive: bool = False) -> str: """ diff --git a/src/askui/tools/store/universal/print_to_console.py b/src/askui/tools/store/universal/print_to_console.py index 7d8949b4..609e2412 100644 --- a/src/askui/tools/store/universal/print_to_console.py +++ b/src/askui/tools/store/universal/print_to_console.py @@ -61,6 +61,7 @@ def __init__(self, source_name: str | None = None): }, ) self._source_name = source_name + self.is_cacheable = False def __call__(self, content: str) -> str: """ diff --git a/src/askui/tools/store/universal/read_from_file_tool.py b/src/askui/tools/store/universal/read_from_file_tool.py index f6cb4f11..b8aba05d 100644 --- a/src/askui/tools/store/universal/read_from_file_tool.py +++ b/src/askui/tools/store/universal/read_from_file_tool.py @@ -66,8 +66,9 @@ def __init__( "required": ["file_path"], }, ) - self._base_dir = base_dir self._encodings = encodings or ["utf-8", "latin-1"] + self._base_dir = Path(base_dir) + self.is_cacheable = True def __call__(self, file_path: str) -> str: """ diff --git a/src/askui/tools/store/universal/wait_tool.py b/src/askui/tools/store/universal/wait_tool.py index 5f9cd369..d2bfdc2c 100644 --- a/src/askui/tools/store/universal/wait_tool.py +++ 
b/src/askui/tools/store/universal/wait_tool.py @@ -58,6 +58,7 @@ def __init__(self, max_wait_time: int = 10 * 60) -> None: }, ) self._max_wait_time = max_wait_time + self.is_cacheable = True def __call__(self, wait_duration: int) -> str: """ diff --git a/src/askui/tools/store/universal/wait_until_condition_tool.py b/src/askui/tools/store/universal/wait_until_condition_tool.py index 2993cb18..89b884ce 100644 --- a/src/askui/tools/store/universal/wait_until_condition_tool.py +++ b/src/askui/tools/store/universal/wait_until_condition_tool.py @@ -87,6 +87,7 @@ def __init__( ) self._condition_check = condition_check self._max_wait_time = max_wait_time + self.is_cacheable = True def __call__(self, max_wait_time: int, check_interval: int = 1) -> str: """ diff --git a/src/askui/tools/store/universal/wait_with_progress_tool.py b/src/askui/tools/store/universal/wait_with_progress_tool.py index fc6e5dae..81323b5b 100644 --- a/src/askui/tools/store/universal/wait_with_progress_tool.py +++ b/src/askui/tools/store/universal/wait_with_progress_tool.py @@ -66,6 +66,7 @@ def __init__(self, max_wait_time: int = 10 * 60) -> None: }, ) self._max_wait_time = max_wait_time + self.is_cacheable = True def __call__(self, wait_duration: int, message: str = "Waiting") -> str: """ diff --git a/src/askui/tools/store/universal/write_to_file_tool.py b/src/askui/tools/store/universal/write_to_file_tool.py index cef8c1ea..2bac570c 100644 --- a/src/askui/tools/store/universal/write_to_file_tool.py +++ b/src/askui/tools/store/universal/write_to_file_tool.py @@ -94,7 +94,9 @@ def __init__(self, base_dir: str | Path) -> None: "required": ["file_path", "content", "append"], }, ) - self._base_dir = base_dir + self._base_dir = Path(base_dir) + self._base_dir.mkdir(parents=True, exist_ok=True) + self.is_cacheable = False def __call__(self, file_path: str, content: str, append: bool) -> str: """ diff --git a/src/askui/tools/switch_speaker_tool.py b/src/askui/tools/switch_speaker_tool.py new file mode 
"""Generic tool for switching conversation speakers."""

import logging
from typing import Any

from pydantic import validate_call
from typing_extensions import override

from askui.models.shared.tools import Tool

logger = logging.getLogger(__name__)


class SwitchSpeakerTool(Tool):
    """Signal tool that lets the VLM request a speaker handoff.

    The tool is constructed with the valid speaker names, which become an
    enum constraint on the input schema. Calling it performs no action:
    `AgentSpeaker` watches for the tool call in the VLM response and returns
    a `SpeakerResult` with ``status="switch_speaker"`` — the actual switch
    happens there, not here.
    """

    # A speaker handoff is conversation-state dependent, so it must never be
    # replayed from a cached trajectory.
    is_cacheable: bool = False

    def __init__(self, speaker_names: list[str]) -> None:
        """Initialize with valid speaker names.

        Args:
            speaker_names: Names of the speakers that may be switched to;
                these form the enum constraint in the input schema.
        """
        tool_description = (
            "Switch the conversation to a different specialized speaker. "
            "Use this tool when the current task is better handled by a "
            "different speaker. The speaker_context parameter passes "
            "activation data to the target speaker. See "
            "AVAILABLE_SPEAKERS in the system prompt for descriptions "
            "of each speaker."
        )
        schema: dict[str, Any] = {
            "type": "object",
            "properties": {
                "speaker_name": {
                    "type": "string",
                    "enum": speaker_names,
                    "description": ("Name of the speaker to switch to."),
                },
                "speaker_context": {
                    "type": "object",
                    "description": (
                        "Activation context to pass to the target "
                        "speaker. Each speaker expects specific context "
                        "keys — see the speaker descriptions in "
                        "AVAILABLE_SPEAKERS."
                    ),
                    "additionalProperties": True,
                    "default": {},
                },
            },
            "required": ["speaker_name"],
        }
        super().__init__(
            name="switch_speaker",
            description=tool_description,
            input_schema=schema,
        )

    @override
    @validate_call
    def __call__(
        self,
        speaker_name: str,
        speaker_context: dict[str, Any] | None = None,
    ) -> str:
        """No-op execution — the call is only a signal to `AgentSpeaker`.

        Args:
            speaker_name: Target speaker name.
            speaker_context: Activation context for the target speaker.

        Returns:
            Acknowledgment message.
        """
        context_keys = list((speaker_context or {}).keys())
        logger.info(
            "Speaker switch requested to '%s' with context keys: %s",
            speaker_name,
            context_keys,
        )
        return f"Switching to speaker '{speaker_name}'"
set_file_name(self, file_name: str) -> None: - if not file_name.endswith(".json"): - file_name += ".json" - self.file_name = file_name - - def reset(self, file_name: str = "") -> None: - self.messages = [] - if file_name and not file_name.endswith(".json"): - file_name += ".json" - self.file_name = file_name - self.was_cached_execution = False - - def generate(self) -> None: - if self.was_cached_execution: - logger.info("Will not write cache file as this was a cached execution") - return - if not self.file_name: - self.file_name = ( - f"cached_trajectory_{datetime.now(tz=timezone.utc):%Y%m%d%H%M%S%f}.json" - ) - cache_file_path = self.cache_dir / self.file_name - - messages_json = [m.model_dump() for m in self.messages] - with cache_file_path.open("w", encoding="utf-8") as f: - json.dump(messages_json, f, indent=4) - info_msg = f"Cache File written at {str(cache_file_path)}" - logger.info(info_msg) - self.reset() - - @staticmethod - def read_cache_file(cache_file_path: Path) -> list[ToolUseBlockParam]: - with cache_file_path.open("r", encoding="utf-8") as f: - raw_trajectory = json.load(f) - return [ToolUseBlockParam(**step) for step in raw_trajectory] diff --git a/src/askui/utils/caching/__init__.py b/src/askui/utils/caching/__init__.py new file mode 100644 index 00000000..d9894a93 --- /dev/null +++ b/src/askui/utils/caching/__init__.py @@ -0,0 +1,27 @@ +"""Caching utilities for agent trajectory recording and playback. 
+ +This module provides: +- `CacheManager`: High-level cache operations (recording, validation, playback) +- `CacheParameterHandler`: Parameter identification and substitution +- `CacheValidator`: Validation strategies for cache invalidation +""" + +from .cache_manager import CacheManager +from .cache_parameter_handler import CacheParameterHandler +from .cache_validator import ( + CacheValidator, + CompositeCacheValidator, + StaleCacheValidator, + StepFailureCountValidator, + TotalFailureRateValidator, +) + +__all__ = [ + "CacheManager", + "CacheParameterHandler", + "CacheValidator", + "CompositeCacheValidator", + "StaleCacheValidator", + "StepFailureCountValidator", + "TotalFailureRateValidator", +] diff --git a/src/askui/utils/caching/cache_manager.py b/src/askui/utils/caching/cache_manager.py new file mode 100644 index 00000000..d59b1821 --- /dev/null +++ b/src/askui/utils/caching/cache_manager.py @@ -0,0 +1,679 @@ +"""Cache manager for handling cache metadata, validation, and recording.""" + +import json +import logging +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from PIL import Image + +from askui.models.shared.agent_message_param import ( + MessageParam, + ToolUseBlockParam, + UsageParam, +) +from askui.models.shared.settings import ( + CacheFailure, + CacheFile, + CacheMetadata, + CacheWritingSettings, + VisualValidationMetadata, +) +from askui.models.shared.tools import ToolCollection +from askui.utils.caching.cache_parameter_handler import CacheParameterHandler +from askui.utils.caching.cache_validator import ( + CacheValidator, + CompositeCacheValidator, + StaleCacheValidator, + StepFailureCountValidator, + TotalFailureRateValidator, +) +from askui.utils.visual_validation import ( + compute_ahash, + compute_phash, + extract_region, + find_recent_screenshot, + get_validation_coordinate, +) + +if TYPE_CHECKING: + from askui.model_providers.vlm_provider import VlmProvider + +logger = 
class CacheManager:
    """Manages cache metadata, validation, updates, and recording.

    This class provides high-level operations for cache management including:
    - Reading cache files from disk
    - Writing cache files to disk
    - Recording trajectories during execution (record mode)
    - Recording execution attempts and failures
    - Validating caches using pluggable validation strategies
    - Invalidating caches when they fail validation
    - Updating metadata on disk
    """

    def __init__(self, validators: list[CacheValidator] | None = None) -> None:
        """Initialize cache manager.

        Args:
            validators: Optional list of cache validators. If None, uses default
                validators (StepFailureCount, TotalFailureRate, StaleCache).
        """
        # Validation
        if validators is None:
            # Use default validators
            self.validators = CompositeCacheValidator(
                [
                    StepFailureCountValidator(max_failures_per_step=3),
                    TotalFailureRateValidator(min_attempts=10, max_failure_rate=0.5),
                    StaleCacheValidator(max_age_days=30),
                ]
            )
        else:
            self.validators = CompositeCacheValidator(validators)

        # Recording state (for record mode)
        self._recording = False
        self._tool_blocks: list[ToolUseBlockParam] = []
        self._cache_dir: Path | None = None
        self._file_name: str = ""
        self._goal: str | None = None
        self._toolbox: ToolCollection | None = None
        self._accumulated_usage = UsageParam()
        self._was_cached_execution = False
        self._cache_writer_settings = CacheWritingSettings()
        self._vlm_provider: "VlmProvider | None" = None

    def set_toolbox(self, toolbox: ToolCollection) -> None:
        """Set the toolbox for checking which tools are cacheable.

        Args:
            toolbox: ToolCollection to use for cacheable tool detection
        """
        self._toolbox = toolbox

    def record_execution_attempt(
        self,
        cache_file: CacheFile,
        success: bool,
        failure_info: CacheFailure | None = None,
    ) -> None:
        """Record an execution attempt in cache metadata.

        Args:
            cache_file: The cache file to update
            success: Whether the execution was successful
            failure_info: Optional failure information (required if success=False)
        """
        cache_file.metadata.execution_attempts += 1
        cache_file.metadata.last_executed_at = datetime.now(tz=timezone.utc)

        if not success and failure_info:
            cache_file.metadata.failures.append(failure_info)
            logger.debug(
                "Recorded failure at step %d: %s",
                failure_info.step_index,
                failure_info.error_message,
            )

    def record_step_failure(
        self, cache_file: CacheFile, step_index: int, error_message: str
    ) -> None:
        """Record a step failure in cache metadata.

        Args:
            cache_file: The cache file to update
            step_index: The step index where failure occurred
            error_message: Error message describing the failure
        """
        # Reuse the canonical counter instead of duplicating the sum here
        failures_at_step = self.get_failure_count_for_step(cache_file, step_index)

        failure = CacheFailure(
            timestamp=datetime.now(tz=timezone.utc),
            step_index=step_index,
            error_message=error_message,
            failure_count_at_step=failures_at_step + 1,
        )
        cache_file.metadata.failures.append(failure)
        logger.debug("Recorded failure at step %d", step_index)

    def should_invalidate(
        self, cache_file: CacheFile, step_index: int | None = None
    ) -> tuple[bool, str | None]:
        """Check if cache should be invalidated.

        Args:
            cache_file: The cache file to validate
            step_index: Optional step index where failure occurred

        Returns:
            Tuple of (should_invalidate: bool, reason: Optional[str])
        """
        return self.validators.should_invalidate(cache_file, step_index)

    def invalidate_cache(self, cache_file: CacheFile, reason: str) -> None:
        """Invalidate a cache file.

        Args:
            cache_file: The cache file to invalidate
            reason: Reason for invalidation
        """
        cache_file.metadata.is_valid = False
        cache_file.metadata.invalidation_reason = reason
        logger.warning("Cache invalidated: %s", reason)

    def mark_cache_valid(self, cache_file: CacheFile) -> None:
        """Mark a cache file as valid.

        This resets the is_valid flag to True and clears the invalidation reason.
        Useful for manual revalidation of caches that were previously invalidated.

        Args:
            cache_file: The cache file to mark as valid
        """
        cache_file.metadata.is_valid = True
        cache_file.metadata.invalidation_reason = None
        logger.info("Cache marked as valid")

    def get_failure_count_for_step(self, cache_file: CacheFile, step_index: int) -> int:
        """Get the total number of failures for a specific step.

        Args:
            cache_file: The cache file to check
            step_index: The step index to get failure count for

        Returns:
            Number of failures recorded for the given step index
        """
        return sum(
            1 for f in cache_file.metadata.failures if f.step_index == step_index
        )

    def update_metadata_on_failure(
        self,
        cache_file: CacheFile,
        cache_file_path: str,
        step_index: int,
        error_message: str,
    ) -> None:
        """Update cache metadata after execution failure and write to disk.

        This is a convenience method that combines recording the failure,
        checking validation, potentially invalidating, and writing to disk.

        Args:
            cache_file: The cache file to update
            cache_file_path: Path to write the updated cache file
            step_index: The step index where failure occurred
            error_message: Error message describing the failure
        """
        try:
            # Record the attempt and failure
            self.record_execution_attempt(cache_file, success=False)
            self.record_step_failure(
                cache_file, step_index=step_index, error_message=error_message
            )

            # Check if cache should be invalidated
            should_inv, reason = self.should_invalidate(
                cache_file, step_index=step_index
            )
            if should_inv and reason:
                self.invalidate_cache(cache_file, reason=reason)

            # Write updated metadata back to disk
            self._write_cache_file(cache_file, cache_file_path)
            logger.debug(
                "Updated cache metadata after failure: %s", Path(cache_file_path).name
            )
        except Exception:
            # Metadata bookkeeping must never break the main execution flow
            logger.exception("Failed to update cache metadata")

    def update_metadata_on_completion(
        self,
        cache_file: CacheFile,
        cache_file_path: str,
        success: bool,
    ) -> None:
        """Update cache metadata after execution completion and write to disk.

        Args:
            cache_file: The cache file to update
            cache_file_path: Path to write the updated cache file
            success: Whether the execution was successful
        """
        try:
            self.record_execution_attempt(cache_file, success=success)

            # Write updated metadata back to disk
            self._write_cache_file(cache_file, cache_file_path)
            logger.info("Updated cache metadata: %s", Path(cache_file_path).name)
        except Exception:
            # Metadata bookkeeping must never break the main execution flow
            logger.exception("Failed to update cache metadata")

    def _write_cache_file(self, cache_file: CacheFile, cache_file_path: str) -> None:
        """Write cache file to disk.

        Args:
            cache_file: The cache file to write
            cache_file_path: Path to write the cache file
        """
        cache_path = Path(cache_file_path)
        with cache_path.open("w", encoding="utf-8") as f:
            json.dump(
                cache_file.model_dump(mode="json"),
                f,
                indent=2,
                # Safety net for values model_dump(mode="json") leaves
                # non-serializable (e.g. datetimes in legacy data)
                default=str,
            )

    @staticmethod
    def read_cache_file(cache_file_path: Path) -> CacheFile:
        """Read cache file with backward compatibility for legacy format.

        Supports two formats:
        1. Legacy format: Just a list of ToolUseBlockParam dicts
        2. New format: CacheFile with metadata and trajectory

        Args:
            cache_file_path: Path to the cache file

        Returns:
            CacheFile object with metadata and trajectory
        """
        logger.debug("Reading cache file: %s", cache_file_path)
        with cache_file_path.open("r", encoding="utf-8") as f:
            raw_data = json.load(f)

        # Handle legacy format (just a list of tool blocks)
        if isinstance(raw_data, list):
            logger.info("Detected legacy cache format, converting to CacheFile")
            trajectory = [ToolUseBlockParam(**step) for step in raw_data]
            cache_file = CacheFile(
                metadata=CacheMetadata(
                    version="0.0",
                    created_at=datetime.now(tz=timezone.utc),
                ),
                trajectory=trajectory,
            )
        else:
            cache_file = CacheFile(**raw_data)

        logger.info(
            "Successfully loaded cache: %s steps, %s parameters",
            len(cache_file.trajectory),
            len(cache_file.cache_parameters),
        )
        if cache_file.metadata.goal:
            logger.debug("Cache goal: %s", cache_file.metadata.goal)
        return cache_file

    def start_recording(
        self,
        cache_dir: str | Path,
        file_name: str = "",
        goal: str | None = None,
        toolbox: ToolCollection | None = None,
        cache_writer_settings: CacheWritingSettings | None = None,
        vlm_provider: "VlmProvider | None" = None,
    ) -> None:
        """Start recording a new trajectory.

        Args:
            cache_dir: Directory to store cache files
            file_name: Filename for cache file (auto-generated if not provided)
            goal: Goal string for this execution
            toolbox: ToolCollection to check which tools are cacheable
            cache_writer_settings: Settings for cache recording
            vlm_provider: VlmProvider instance to use for parameter identification
        """
        self._recording = True
        self._tool_blocks = []
        self._cache_dir = Path(cache_dir)
        # parents=True so nested cache dirs (e.g. ".askui_cache/runs") work,
        # matching the mkdir behavior of WriteToFileTool
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._file_name = (
            file_name
            if file_name.endswith(".json") or not file_name
            else f"{file_name}.json"
        )
        self._goal = goal
        self._toolbox = toolbox
        self._accumulated_usage = UsageParam()
        self._was_cached_execution = False
        self._cache_writer_settings = cache_writer_settings or CacheWritingSettings()
        self._vlm_provider = vlm_provider or self._vlm_provider

        logger.info(
            "Started recording trajectory to %s",
            self._cache_dir / (self._file_name or "[auto-generated]"),
        )

    def finish_recording(self, messages: list[MessageParam]) -> str:
        """Finish recording and write cache file to disk.

        Extracts tool blocks and usage from the message history.

        Args:
            messages: Complete message history from the conversation

        Returns:
            Success message with cache file path
        """
        if not self._recording:
            return "No recording in progress"

        # Extract tool blocks and usage from message history
        self._extract_from_messages(messages)

        if self._was_cached_execution:
            logger.info("Will not write cache file as this was a cached execution")
            self._reset_recording_state()
            return "Skipped writing cache (was cached execution)"

        # Blank non-cacheable tool inputs BEFORE parameterization
        # (so they don't get sent to LLM for parameter identification)
        if self._toolbox is not None:
            self._tool_blocks = self._blank_non_cacheable_tool_inputs(self._tool_blocks)
        else:
            logger.info("No toolbox set, skipping non-cacheable tool input blanking")

        # Auto-generate filename if not provided
        if not self._file_name:
            self._file_name = (
                f"cached_trajectory_{datetime.now(tz=timezone.utc):%Y%m%d%H%M%S%f}.json"
            )

        # Explicit check instead of assert (asserts are stripped under -O);
        # _cache_dir is always set by start_recording when _recording is True
        if self._cache_dir is None:
            msg = "Recording state is inconsistent: cache_dir is not set"
            raise RuntimeError(msg)
        cache_file_path = self._cache_dir / self._file_name

        # Parameterize trajectory (this creates NEW tool blocks)
        goal_to_save, trajectory_to_save, parameters_dict = (
            self._parameterize_trajectory()
        )

        # Add visual validation hashes to trajectory AFTER parameterization
        # (so visual_representation fields don't get lost during parameterization)
        self._add_visual_validation_to_trajectory(trajectory_to_save, messages)

        # Generate cache file
        self._generate_cache_file(
            goal_to_save, trajectory_to_save, parameters_dict, cache_file_path
        )

        # Reset recording state
        self._reset_recording_state()

        return f"Cache file written: {cache_file_path}"

    def _reset_recording_state(self) -> None:
        """Reset all recording state variables."""
        self._recording = False
        self._tool_blocks = []
        self._file_name = ""
        self._was_cached_execution = False
        self._accumulated_usage = UsageParam()

    def _extract_from_messages(self, messages: list[MessageParam]) -> None:
        """Extract tool blocks and usage from message history.

        Args:
            messages: Complete message history from the conversation
        """
        for message in messages:
            if message.role == "assistant":
                contents = message.content
                if isinstance(contents, list):
                    for content in contents:
                        if isinstance(content, ToolUseBlockParam):
                            self._tool_blocks.append(content)
                            # Check if this was a cached execution
                            if content.name == "execute_cached_executions_tool":
                                self._was_cached_execution = True

                # Accumulate usage from assistant messages
                if message.usage:
                    self._accumulate_usage(message.usage)

    def _parameterize_trajectory(
        self,
    ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]:
        """Identify parameters and return parameterized trajectory + goal."""
        return CacheParameterHandler.identify_and_parameterize(
            trajectory=self._tool_blocks,
            goal=self._goal,
            identification_strategy=self._cache_writer_settings.parameter_identification_strategy,
            vlm_provider=self._vlm_provider,
        )

    def _blank_non_cacheable_tool_inputs(
        self, trajectory: list[ToolUseBlockParam]
    ) -> list[ToolUseBlockParam]:
        """Blank out input fields for non-cacheable tools to save space.

        For tools marked as is_cacheable=False, we replace their input with an
        empty dict since we won't be executing them from cache anyway.

        Args:
            trajectory: The trajectory to process

        Returns:
            New trajectory with non-cacheable tool inputs blanked out
        """
        if self._toolbox is None:
            return trajectory

        blanked_count = 0
        result: list[ToolUseBlockParam] = []
        tools = self._toolbox.tool_map
        for tool_block in trajectory:
            # Check if this tool is cacheable
            tool = tools.get(tool_block.name)

            # If tool is not cacheable, blank out its input
            if tool is not None and not tool.is_cacheable:
                logger.debug(
                    "Blanking input for non-cacheable tool: %s", tool_block.name
                )
                blanked_count += 1
                result.append(
                    ToolUseBlockParam(
                        id=tool_block.id,
                        name=tool_block.name,
                        input={},  # Blank out the input
                        type=tool_block.type,
                        cache_control=tool_block.cache_control,
                    )
                )
            else:
                # Keep the tool block as-is
                result.append(tool_block)

        if blanked_count > 0:
            logger.info(
                "Blanked inputs for %s non-cacheable tool(s) to save space",
                blanked_count,
            )

        return result

    def _add_visual_validation_to_trajectory(  # noqa: C901
        self, trajectory: list[ToolUseBlockParam], messages: list[MessageParam]
    ) -> None:
        """Add visual validation hashes to tool use blocks in the trajectory.

        This method processes the complete message history to find screenshots
        and compute visual hashes for actions that require validation.
        The hashes are stored in the visual_representation field of each
        ToolUseBlockParam in the provided trajectory.

        Args:
            trajectory: The parameterized trajectory to add validation to
            messages: Complete message history from the conversation
        """
        if self._cache_writer_settings.visual_verification_method == "none":
            logger.info("Visual validation disabled, skipping hash computation")
            return

        # Build a mapping from tool_use_id to tool_block in the trajectory
        # This allows us to update the correct tool block in the trajectory
        tool_block_map: dict[str, ToolUseBlockParam] = {
            block.id: block for block in trajectory
        }

        # Iterate through messages to find tool uses and their context
        validated_count = 0
        for i, message in enumerate(messages):
            if message.role != "assistant":
                continue

            if isinstance(message.content, str):
                continue

            # Process tool use blocks in this message
            for block in message.content:
                if block.type != "tool_use":
                    continue

                # Find the corresponding block in the trajectory
                trajectory_block = tool_block_map.get(block.id)
                if not trajectory_block:
                    # This tool use is not in the trajectory (might be non-cacheable)
                    continue

                # Check if this tool has coordinates for visual validation
                tool_input: dict[str, Any] = (
                    block.input if isinstance(block.input, dict) else {}
                )
                coordinate = get_validation_coordinate(tool_input)
                if coordinate is None:
                    # Tools without coordinates don't need visual validation
                    trajectory_block.visual_representation = None
                    continue

                # Find most recent screenshot BEFORE this tool use
                screenshot = find_recent_screenshot(messages, from_index=i - 1)
                if not screenshot:
                    logger.warning(
                        "No screenshot found before tool_id=%s, "
                        "skipping visual validation",
                        block.id,
                    )
                    trajectory_block.visual_representation = None
                    continue

                # Extract region and compute hash
                try:
                    # Pass coordinate in the format extract_region expects
                    region = extract_region(
                        screenshot,
                        {"coordinate": list(coordinate)},
                        region_size=self._cache_writer_settings.visual_validation_region_size,
                    )
                    visual_hash = self._compute_visual_hash(
                        region, self._cache_writer_settings.visual_verification_method
                    )
                    trajectory_block.visual_representation = visual_hash
                    validated_count += 1
                    logger.debug(
                        "Added visual validation hash for tool_id=%s",
                        block.id,
                    )
                except Exception:
                    # Hashing failure must not abort recording; just skip
                    # validation for this action
                    logger.exception(
                        "Failed to compute visual hash for tool_id=%s", block.id
                    )
                    trajectory_block.visual_representation = None

        if validated_count > 0:
            logger.info(
                "Added visual validation to %d action(s) in trajectory",
                validated_count,
            )

    def _compute_visual_hash(self, image: Image.Image, method: str) -> str:
        """Compute visual hash using specified method.

        Args:
            image: PIL Image to hash
            method: Hash method ("phash", "ahash", or "none")

        Returns:
            String representation of the hash

        Raises:
            ValueError: If method is not supported
        """
        if method == "phash":
            return compute_phash(image, hash_size=8)
        if method == "ahash":
            return compute_ahash(image, hash_size=8)
        if method == "none":
            return ""
        msg = f"Unsupported visual verification method: {method}"
        raise ValueError(msg)

    def _generate_cache_file(
        self,
        goal_to_save: str | None,
        trajectory_to_save: list[ToolUseBlockParam],
        parameters_dict: dict[str, str],
        cache_file_path: Path,
    ) -> None:
        """Write cache file to disk with metadata.

        Args:
            goal_to_save: Goal string (may be parameterized)
            trajectory_to_save: Trajectory (parameterized and blanked)
            parameters_dict: Cache parameters dictionary
            cache_file_path: Path to write cache file
        """
        # Prepare visual validation metadata
        visual_validation_metadata: VisualValidationMetadata | None = None
        if self._cache_writer_settings.visual_verification_method != "none":
            visual_validation_metadata = VisualValidationMetadata(
                enabled=True,
                method=self._cache_writer_settings.visual_verification_method,
                region_size=self._cache_writer_settings.visual_validation_region_size,
            )

        cache_file = CacheFile(
            metadata=CacheMetadata(
                version="0.2",
                created_at=datetime.now(tz=timezone.utc),
                goal=goal_to_save,
                token_usage=self._accumulated_usage,
                visual_validation=visual_validation_metadata,
            ),
            trajectory=trajectory_to_save,
            cache_parameters=parameters_dict,
        )

        with cache_file_path.open("w", encoding="utf-8") as f:
            json.dump(cache_file.model_dump(mode="json"), f, indent=4)
        logger.info("Cache file successfully written: %s", cache_file_path)

    def _accumulate_usage(self, step_usage: UsageParam) -> None:
        """Accumulate usage statistics from a single API call.

        Args:
            step_usage: Usage from a single message
        """
        # Fields may be None on either side, so coalesce to 0 before adding
        self._accumulated_usage.input_tokens = (
            self._accumulated_usage.input_tokens or 0
        ) + (step_usage.input_tokens or 0)
        self._accumulated_usage.output_tokens = (
            self._accumulated_usage.output_tokens or 0
        ) + (step_usage.output_tokens or 0)
        self._accumulated_usage.cache_creation_input_tokens = (
            self._accumulated_usage.cache_creation_input_tokens or 0
        ) + (step_usage.cache_creation_input_tokens or 0)
        self._accumulated_usage.cache_read_input_tokens = (
            self._accumulated_usage.cache_read_input_tokens or 0
        ) + (step_usage.cache_read_input_tokens or 0)
+""" + +import json +import logging +import re +from typing import TYPE_CHECKING, Any + +from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam +from askui.models.shared.prompts import SystemPrompt +from askui.prompts.caching import CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT + +if TYPE_CHECKING: + from askui.model_providers.vlm_provider import VlmProvider + +logger = logging.getLogger(__name__) + +# Regex pattern for matching parameters: {{parameter_name}} +# Allows alphanumeric characters and underscores, must start with letter/underscore +CACHE_PARAMETER_PATTERN = r"\{\{([a-zA-Z_][a-zA-Z0-9_]*)\}\}" + + +class CacheParameterDefinition: + """Represents a cache parameter identified in a trajectory.""" + + def __init__(self, name: str, value: Any, description: str) -> None: + self.name = name + self.value = value + self.description = description + + def __repr__(self) -> str: + return f"CacheParameterDefinition(name={self.name}, value={self.value})" + + +class CacheParameterHandler: + """Handles all cache parameter operations for trajectory recording and execution.""" + + # ======================================================================== + # RECORDING PHASE: Parameter identification and templatization + # ======================================================================== + + @staticmethod + def identify_and_parameterize( + trajectory: list[ToolUseBlockParam], + goal: str | None, + identification_strategy: str, + vlm_provider: "VlmProvider | None" = None, + ) -> tuple[str | None, list[ToolUseBlockParam], dict[str, str]]: + """Identify parameters and return parameterized trajectory + goal. + + This is the main entry point for the recording phase. It orchestrates + parameter identification and templatization of both trajectory and goal. 
+ + Args: + trajectory: The trajectory to analyze and parameterize + goal: The goal text to parameterize (optional) + identification_strategy: "llm" for AI-based or "preset" for manual + vlm_provider: VlmProvider instance to use for LLM-based identification. + The model to use is determined by vlm_provider.model_id. + + Returns: + Tuple of: + - Parameterized goal text (or None if no goal) + - Parameterized trajectory (with {{param}} syntax) + - Dict mapping parameter names to descriptions + """ + if identification_strategy == "llm" and trajectory and vlm_provider is not None: + # Use LLM to identify parameters + parameters_dict, parameter_definitions = ( + CacheParameterHandler._identify_parameters_with_llm( + trajectory, vlm_provider + ) + ) + + if parameter_definitions: + # Replace values with {{parameter}} syntax in trajectory + parameterized_trajectory = ( + CacheParameterHandler._replace_values_with_parameters( + trajectory, parameter_definitions + ) + ) + + # Apply same replacement to goal text + parameterized_goal = goal + if goal: + parameterized_goal = ( + CacheParameterHandler._apply_parameters_to_text( + goal, parameter_definitions + ) + ) + + n_parameters = len(parameter_definitions) + logger.info("Replaced %s parameter values in trajectory", n_parameters) + return parameterized_goal, parameterized_trajectory, parameters_dict + + # No parameters identified + logger.info("No parameters identified in trajectory") + return goal, trajectory, {} + + # Manual extraction (preset strategy) + parameter_names = CacheParameterHandler.extract_parameters(trajectory) + parameters_dict = { + name: f"Parameter for {name}" + for name in parameter_names # Generic desc + } + n_parameters = len(parameter_names) + logger.info("Extracted %s manual parameters from trajectory", n_parameters) + return goal, trajectory, parameters_dict + + @staticmethod + def _identify_parameters_with_llm( + trajectory: list[ToolUseBlockParam], + vlm_provider: "VlmProvider", + ) -> 
tuple[dict[str, str], list[CacheParameterDefinition]]: + """Identify parameters in a trajectory using LLM analysis. + + Args: + trajectory: The trajectory to analyze (list of tool use blocks) + vlm_provider: VlmProvider instance for LLM calls. + The model used is determined by vlm_provider.model_id. + + Returns: + Tuple of: + - Dict mapping parameter names to descriptions + - List of CacheParameterDefinition objects with name, value, description + """ + if not trajectory: + logger.debug("Empty trajectory provided, skipping parameter identification") + return {}, [] + + logger.info( + "Starting parameter identification for trajectory with %s steps", + len(trajectory), + ) + + # Convert trajectory to serializable format for analysis + trajectory_data = [tool.model_dump(mode="json") for tool in trajectory] + logger.debug("Converted %s tool blocks to JSON format", len(trajectory_data)) + + user_message = ( + "Analyze this UI automation trajectory and identify all values that " + "should be parameters:\n\n" + f"```json\n{json.dumps(trajectory_data, indent=2)}\n```\n\n" + "Return only the JSON object with identified parameters. " + "Be thorough but conservative - only mark values that are clearly " + "dynamic or time-sensitive." 
+ ) + + response_text = "" # Initialize for error logging + try: + # Make single API call using VlmProvider + logger.debug( + "Calling LLM (%s) to analyze trajectory for parameters", + vlm_provider.model_id, + ) + response = vlm_provider.create_message( + messages=[MessageParam(role="user", content=user_message)], + system=SystemPrompt(prompt=CACHING_PARAMETER_IDENTIFIER_SYSTEM_PROMPT), + max_tokens=4096, + temperature=0.0, # Deterministic for analysis + ) + logger.debug("Received response from LLM") + + # Extract text from response + if isinstance(response.content, list): + response_text = next( + ( + block.text + for block in response.content + if hasattr(block, "text") + ), + "", + ) + else: + response_text = str(response.content) + + # Parse the JSON response + logger.debug("Parsing LLM response to extract parameter definitions") + # Handle markdown code blocks if present + if "```json" in response_text: + logger.debug("Removing JSON markdown code block wrapper from response") + response_text = ( + response_text.split("```json")[1].split("```")[0].strip() + ) + elif "```" in response_text: + logger.debug("Removing code block wrapper from response") + response_text = response_text.split("```")[1].split("```")[0].strip() + + parameter_data = json.loads(response_text) + logger.debug( + "Successfully parsed JSON response with %s parameters", + len(parameter_data.get("parameters", [])), + ) + + # Convert to our data structures + parameter_definitions = [ + CacheParameterDefinition( + name=p["name"], value=p["value"], description=p["description"] + ) + for p in parameter_data.get("parameters", []) + ] + + parameters_dict = {p.name: p.description for p in parameter_definitions} + + if parameter_definitions: + logger.info( + "Successfully identified %s parameters in trajectory", + len(parameter_definitions), + ) + for p in parameter_definitions: + logger.debug(" - %s: %s (%s)", p.name, p.value, p.description) + else: + logger.info( + "No parameters identified in 
trajectory " + "(this is normal for trajectories with only static values)" + ) + + except json.JSONDecodeError as e: + logger.warning( + "Failed to parse LLM response as JSON: %s. " + "Falling back to empty parameter list.", + e, + extra={"response_text": response_text[:500]}, # Log first 500 chars + ) + return {}, [] + except Exception as e: # noqa: BLE001 + logger.warning( + "Failed to identify parameters with LLM: %s. " + "Falling back to empty parameter list.", + e, + exc_info=True, + ) + return {}, [] + else: + return parameters_dict, parameter_definitions + + @staticmethod + def _replace_values_with_parameters( + trajectory: list[ToolUseBlockParam], + parameter_definitions: list[CacheParameterDefinition], + ) -> list[ToolUseBlockParam]: + """Replace actual values in trajectory with {{parameter_name}} syntax. + + This is the reverse of substitute_parameters - it takes identified values + and replaces them with parameter syntax for saving to cache. + + Args: + trajectory: The trajectory to templatize + parameter_definitions: List of CacheParameterDefinition objects with + name and value attributes + + Returns: + New trajectory with values replaced by parameters + """ + # Build replacement map: value -> parameter name + replacements = { + str(p.value): f"{{{{{p.name}}}}}" for p in parameter_definitions + } + + # Apply replacements to each tool block + parameterized_trajectory = [] + for tool_block in trajectory: + parameterized_input = CacheParameterHandler._replace_values_in_value( + tool_block.input, replacements + ) + + parameterized_trajectory.append( + ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=parameterized_input, + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + ) + + return parameterized_trajectory + + @staticmethod + def _apply_parameters_to_text( + text: str, parameter_definitions: list[CacheParameterDefinition] + ) -> str: + """Apply parameter replacement to a text string (e.g., goal). 
+ + Args: + text: The text to parameterize + parameter_definitions: List of parameter definitions + + Returns: + Text with values replaced by {{parameter}} syntax + """ + # Build replacement map: value -> parameter syntax + replacements = { + str(p.value): f"{{{{{p.name}}}}}" for p in parameter_definitions + } + # Sort by length descending to replace longer matches first + result = text + for actual_value in sorted(replacements.keys(), key=len, reverse=True): + if actual_value in result: + result = result.replace(actual_value, replacements[actual_value]) + return result + + @staticmethod + def _replace_values_in_value(value: Any, replacements: dict[str, str]) -> Any: + """Recursively replace actual values with parameter syntax. + + Args: + value: Any value (str, dict, list, etc.) to process + replacements: Dict mapping actual values to parameter syntax + + Returns: + New value with replacements applied + """ + if isinstance(value, str): + # Replace exact matches and substring matches + result = value + # Sort by length descending to replace longer matches first + # This prevents partial replacements + for actual_value in sorted(replacements.keys(), key=len, reverse=True): + if actual_value in result: + result = result.replace(actual_value, replacements[actual_value]) + return result + if isinstance(value, dict): + # Recursively replace in dict values + return { + k: CacheParameterHandler._replace_values_in_value(v, replacements) + for k, v in value.items() + } + if isinstance(value, list): + # Recursively replace in list items + return [ + CacheParameterHandler._replace_values_in_value(item, replacements) + for item in value + ] + # For non-string types, check if the value matches exactly + str_value = str(value) + if str_value in replacements: + # Return the parameter as a string + return replacements[str_value] + return value + + # ======================================================================== + # EXECUTION PHASE: Parameter extraction, validation, and 
substitution + # ======================================================================== + + @staticmethod + def extract_parameters(trajectory: list[ToolUseBlockParam]) -> set[str]: + """Extract all parameter names from a trajectory. + + Scans all tool inputs for {{parameter_name}} patterns and returns + a set of unique parameter names. + + Args: + trajectory: List of tool use blocks to scan + + Returns: + Set of unique parameter names found in the trajectory + """ + parameters: set[str] = set() + + for step in trajectory: + # Recursively find parameters in the input object + parameters.update(CacheParameterHandler._extract_from_value(step.input)) + + return parameters + + @staticmethod + def _extract_from_value(value: Any) -> set[str]: + """Recursively extract parameters from a value. + + Args: + value: Any value (str, dict, list, etc.) to search for parameters + + Returns: + Set of parameter names found + """ + parameters: set[str] = set() + + if isinstance(value, str): + # Find all matches in the string + matches = re.finditer(CACHE_PARAMETER_PATTERN, value) + parameters.update(match.group(1) for match in matches) + elif isinstance(value, dict): + # Recursively search dict values + for v in value.values(): + parameters.update(CacheParameterHandler._extract_from_value(v)) + elif isinstance(value, list): + # Recursively search list items + for item in value: + parameters.update(CacheParameterHandler._extract_from_value(item)) + + return parameters + + @staticmethod + def validate_parameters( + trajectory: list[ToolUseBlockParam], provided_values: dict[str, str] + ) -> tuple[bool, list[str]]: + """Validate that all required parameters have values. 
+ + Args: + trajectory: List of tool use blocks containing parameters + provided_values: Dict of parameter names to their values + + Returns: + Tuple of (is_valid, missing_parameters) + - is_valid: True if all parameters have values, False otherwise + - missing_parameters: List of parameter names that are missing values + """ + required_parameters = CacheParameterHandler.extract_parameters(trajectory) + missing = [name for name in required_parameters if name not in provided_values] + + return len(missing) == 0, missing + + @staticmethod + def substitute_parameters( + tool_block: ToolUseBlockParam, parameter_values: dict[str, str] + ) -> ToolUseBlockParam: + """Replace parameters in a tool block with actual values. + + Creates a new ToolUseBlockParam with all {{parameter}} occurrences + replaced with their corresponding values from parameter_values. + + Args: + tool_block: The tool use block containing parameters + parameter_values: Dict mapping parameter names to replacement values + + Returns: + New ToolUseBlockParam with parameters substituted + """ + # Deep copy the input and substitute parameters + substituted_input = CacheParameterHandler._substitute_in_value( + tool_block.input, parameter_values + ) + + # Create new ToolUseBlockParam with substituted values + return ToolUseBlockParam( + id=tool_block.id, + name=tool_block.name, + input=substituted_input, + type=tool_block.type, + cache_control=tool_block.cache_control, + ) + + @staticmethod + def _substitute_in_value(value: Any, parameter_values: dict[str, str]) -> Any: + """Recursively substitute parameters in a value. + + Args: + value: Any value (str, dict, list, etc.) 
containing parameters + parameter_values: Dict of parameter names to replacement values + + Returns: + New value with parameters substituted + """ + if isinstance(value, str): + # Replace all parameters in the string + result = value + for name, replacement in parameter_values.items(): + placeholder = f"{{{{{name}}}}}" # Creates "{{parameter_name}}" + result = result.replace(placeholder, replacement) + return result + if isinstance(value, dict): + # Recursively substitute in dict values + return { + k: CacheParameterHandler._substitute_in_value(v, parameter_values) + for k, v in value.items() + } + if isinstance(value, list): + # Recursively substitute in list items + return [ + CacheParameterHandler._substitute_in_value(item, parameter_values) + for item in value + ] + # Return other types as-is + return value diff --git a/src/askui/utils/caching/cache_validator.py b/src/askui/utils/caching/cache_validator.py new file mode 100644 index 00000000..5faaf4f5 --- /dev/null +++ b/src/askui/utils/caching/cache_validator.py @@ -0,0 +1,245 @@ +"""Cache validation strategies for automatic cache invalidation. + +This module provides an extensible validator pattern that allows users to +define custom cache invalidation logic. The system includes built-in validators +for common scenarios like step failure counts, overall failure rates, and stale caches. +""" + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone + +from askui.models.shared.settings import CacheFile + + +class CacheValidator(ABC): + """Abstract base class for cache validation strategies. + + Users can implement custom validators by subclassing this and implementing + the should_invalidate method. + """ + + @abstractmethod + def should_invalidate( + self, cache_file: CacheFile, step_index: int | None = None + ) -> tuple[bool, str | None]: + """Check if cache should be invalidated. 
+ + Args: + cache_file: The cache file with metadata and trajectory + step_index: Optional step index where failure occurred + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + + @abstractmethod + def get_name(self) -> str: + """Return validator name for logging/debugging.""" + + +class CompositeCacheValidator(CacheValidator): + """Composite validator that combines multiple validation strategies. + + Invalidates cache if ANY of the validators returns True. + Users can add custom validators via add_validator(). + """ + + def __init__(self, validators: list[CacheValidator] | None = None) -> None: + """Initialize composite validator. + + Args: + validators: Optional list of validators to include + """ + self.validators: list[CacheValidator] = validators or [] + + def add_validator(self, validator: CacheValidator) -> None: + """Add a validator to the composite. + + Args: + validator: The validator to add + """ + self.validators.append(validator) + + def should_invalidate( + self, cache_file: CacheFile, step_index: int | None = None + ) -> tuple[bool, str | None]: + """Check all validators, invalidate if any returns True. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: Optional step index where failure occurred + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + If multiple validators trigger, reasons are combined with "; " + """ + reasons = [] + for validator in self.validators: + should_inv, reason = validator.should_invalidate(cache_file, step_index) + if should_inv and reason: + reasons.append(f"{validator.get_name()}: {reason}") + + if reasons: + return True, "; ".join(reasons) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "CompositeValidator" + + +# Built-in validators + + +class StepFailureCountValidator(CacheValidator): + """Invalidate if same step fails too many times. 
+ + This validator counts how many times a specific step has failed + and invalidates the cache if it exceeds the threshold. + """ + + def __init__(self, max_failures_per_step: int = 3) -> None: + """Initialize validator. + + Args: + max_failures_per_step: Maximum number of failures allowed per step + """ + self.max_failures_per_step = max_failures_per_step + + def should_invalidate( + self, cache_file: CacheFile, step_index: int | None = None + ) -> tuple[bool, str | None]: + """Check if step has failed too many times. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: The step index to check (required for this validator) + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + if step_index is None: + return False, None + + # Count failures at this specific step + failures_at_step = sum( + 1 for f in cache_file.metadata.failures if f.step_index == step_index + ) + + if failures_at_step >= self.max_failures_per_step: + return ( + True, + f"Step {step_index} failed {failures_at_step} times " + f"(max: {self.max_failures_per_step})", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "StepFailureCount" + + +class TotalFailureRateValidator(CacheValidator): + """Invalidate if overall failure rate is too high. + + This validator calculates the ratio of failures to execution attempts + and invalidates if the rate exceeds the threshold after a minimum + number of attempts. + """ + + def __init__(self, min_attempts: int = 10, max_failure_rate: float = 0.5) -> None: + """Initialize validator. 
+ + Args: + min_attempts: Minimum execution attempts before checking rate + max_failure_rate: Maximum acceptable failure rate (0.0 to 1.0) + """ + self.min_attempts = min_attempts + self.max_failure_rate = max_failure_rate + + def should_invalidate( + self, + cache_file: CacheFile, + step_index: int | None = None, # noqa: ARG002 + ) -> tuple[bool, str | None]: + """Check if overall failure rate is too high. + + Args: + cache_file: The cache file with metadata and trajectory + step_index: Unused for this validator + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + attempts = cache_file.metadata.execution_attempts + if attempts < self.min_attempts: + return False, None + + failures = len(cache_file.metadata.failures) + failure_rate = failures / attempts if attempts > 0 else 0.0 + + if failure_rate > self.max_failure_rate: + return ( + True, + f"Failure rate {failure_rate:.1%} exceeds " + f"{self.max_failure_rate:.1%} after {attempts} attempts", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "TotalFailureRate" + + +class StaleCacheValidator(CacheValidator): + """Invalidate if cache is old and has failures. + + This validator checks if a cache hasn't been successfully executed + in a long time AND has failures. Caches without failures are not + considered stale regardless of age. + """ + + def __init__(self, max_age_days: int = 30) -> None: + """Initialize validator. + + Args: + max_age_days: Maximum age in days for cache with failures + """ + self.max_age_days = max_age_days + + def should_invalidate( + self, + cache_file: CacheFile, + step_index: int | None = None, # noqa: ARG002 + ) -> tuple[bool, str | None]: + """Check if cache is stale (old + has failures). 
+ + Args: + cache_file: The cache file with metadata and trajectory + step_index: Unused for this validator + + Returns: + Tuple of (should_invalidate: bool, reason: Optional[str]) + """ + if not cache_file.metadata.last_executed_at: + return False, None + + if not cache_file.metadata.failures: + return False, None # No failures, age doesn't matter + + # Ensure last_executed_at is timezone-aware + last_executed = cache_file.metadata.last_executed_at + if last_executed.tzinfo is None: + last_executed = last_executed.replace(tzinfo=timezone.utc) + + age = datetime.now(tz=timezone.utc) - last_executed + if age > timedelta(days=self.max_age_days): + return ( + True, + f"Cache not successfully executed in {age.days} days and has failures", + ) + return False, None + + def get_name(self) -> str: + """Return validator name.""" + return "StaleCache" diff --git a/src/askui/utils/visual_validation.py b/src/askui/utils/visual_validation.py new file mode 100644 index 00000000..1e1fd404 --- /dev/null +++ b/src/askui/utils/visual_validation.py @@ -0,0 +1,228 @@ +"""Visual validation utilities for cache execution. + +This module provides utilities for visual validation of cached trajectories: +- Image hashing functions (perceptual hash, average hash) +- Hamming distance computation +- Region extraction from images +- Screenshot extraction from message history +""" + +import logging +from typing import TYPE_CHECKING, Any + +import imagehash +from PIL import Image + +if TYPE_CHECKING: + from askui.models.shared.agent_message_param import MessageParam + +logger = logging.getLogger(__name__) + + +def compute_phash(image: Image.Image, hash_size: int = 8) -> str: + """Compute perceptual hash (pHash) of an image. + + Uses DCT-based perceptual hashing which is robust to scaling, aspect ratio + changes, and minor modifications. 
+ + Args: + image: PIL Image to hash + hash_size: Size of the hash (default: 8, produces 64-bit hash) + + Returns: + String representation of the hash (hex format) + """ + phash = imagehash.phash(image, hash_size=hash_size) + return str(phash) + + +def compute_ahash(image: Image.Image, hash_size: int = 8) -> str: + """Compute average hash (aHash) of an image. + + Average hash is faster but less robust than perceptual hash. + Good for detecting exact duplicates or very similar images. + + Args: + image: PIL Image to hash + hash_size: Size of the hash (default: 8, produces 64-bit hash) + + Returns: + String representation of the hash (hex format) + """ + ahash = imagehash.average_hash(image, hash_size=hash_size) + return str(ahash) + + +def compute_hamming_distance(hash1: str, hash2: str) -> int: + """Compute Hamming distance between two image hashes. + + The Hamming distance is the number of bit positions in which the two + hashes differ. A distance of 0 means the images are identical (or very + similar). Larger distances indicate more visual difference. + + Typical thresholds: + - 0-5: Nearly identical images + - 6-10: Similar images with minor differences + - 11+: Different images + + Args: + hash1: First hash string (hex format) + hash2: Second hash string (hex format) + + Returns: + Hamming distance (number of differing bits) + + Raises: + ValueError: If hashes have different lengths + """ + if len(hash1) != len(hash2): + msg = f"Hashes must have same length. Got {len(hash1)} and {len(hash2)}" + raise ValueError(msg) + + # Convert hex strings to imagehash objects + ihash1 = imagehash.hex_to_hash(hash1) + ihash2 = imagehash.hex_to_hash(hash2) + + # Compute Hamming distance + return ihash1 - ihash2 + + +def extract_region( + image: Image.Image, + action_input: dict[str, Any], + region_size: int = 50, +) -> Image.Image: + """Extract a square region around an action coordinate. + + Extracts a square region centered on the coordinate specified in the + action input. 
Handles edge cases where the region would extend beyond + image boundaries by clipping to valid bounds. + + Args: + image: PIL Image to extract region from + action_input: Action input dict containing 'coordinate' key with [x, y] + region_size: Size of the square region to extract (default: 50 pixels) + + Returns: + Extracted region as PIL Image (may be smaller than region_size if + near image edges) + """ + coordinate = action_input.get("coordinate") + if not coordinate: + msg = f"No coordinate found in action_input: {action_input}" + logger.warning(msg) + return image + + x, y = coordinate + width, height = image.size + + # Calculate region bounds (centered on coordinate) and clip to valid bounds + half_size = region_size // 2 + left = max(0, x - half_size) + top = max(0, y - half_size) + right = min(width, x + half_size) + bottom = min(height, y + half_size) + + # Handle edge case where coordinates are completely out of bounds + # In this case, return an empty or minimal region + if left >= right or top >= bottom: + # Return minimal 1x1 region from top-left corner + return image.crop((0, 0, min(1, width), min(1, height))) + + # Extract and return region + return image.crop((left, top, right, bottom)) + + +def find_recent_screenshot( + messages: list["MessageParam"], + from_index: int | None = None, +) -> Image.Image | None: + """Extract most recent screenshot from message history. + + Looks backwards through message history for the most recent tool result + containing an image block (screenshot). This is used during both recording + and validation to extract the "before" state screenshot. + + Args: + messages: Message history to search through + from_index: Optional index to start searching backwards from. + If None, starts from end of list. 
+ + Returns: + PIL Image from most recent screenshot, or None if not found + """ + start_idx = from_index if from_index is not None else len(messages) - 1 + + # Look backwards from start index + for i in range(start_idx, -1, -1): + message = messages[i] + if message.role != "user": + continue + + # Check if message content is a list of blocks + if isinstance(message.content, str): + continue + + # Look for tool result blocks with images + for block in message.content: + if block.type == "tool_result": + # Check for image blocks within tool result + if isinstance(block.content, list): + for content_item in block.content: + if content_item.type == "image": + # Found screenshot - decode and return + from askui.utils.image_utils import base64_to_image + + # Only base64 images have data attribute + if hasattr(content_item.source, "data"): + return base64_to_image(content_item.source.data) + + return None + + +def get_validation_coordinate(tool_input: dict[str, Any]) -> tuple[int, int] | None: + """Extract the coordinate for visual validation from tool input. + + Args: + tool_input: Tool input dictionary + + Returns: + (x, y) coordinate tuple or None if not applicable + + For click actions, returns the click coordinate. + For type actions, returns the coordinate of the text input field. 
+ """ + + def try_pair(x_val: Any, y_val: Any) -> tuple[int, int] | None: + x = _safe_int(x_val) + y = _safe_int(y_val) + if x is None or y is None: + return None + return (x, y) + + if "coordinate" in tool_input: + coord = tool_input["coordinate"] + if isinstance(coord, list) and len(coord) == 2: + result = try_pair(coord[0], coord[1]) + if result is not None: + return result + + if "x" in tool_input and "y" in tool_input: + result = try_pair(tool_input["x"], tool_input["y"]) + if result is not None: + return result + + if "x1" in tool_input and "y1" in tool_input: + result = try_pair(tool_input["x1"], tool_input["y1"]) + if result is not None: + return result + + return None + + +def _safe_int(value: Any) -> int | None: + """Try converting value to int, return None if not possible.""" + try: + return int(value) + except (TypeError, ValueError): + return None diff --git a/tests/e2e/agent/test_act_caching.py b/tests/e2e/agent/test_act_caching.py index 91f5ffaf..b13a2953 100644 --- a/tests/e2e/agent/test_act_caching.py +++ b/tests/e2e/agent/test_act_caching.py @@ -4,44 +4,40 @@ import tempfile from pathlib import Path -import pytest - from askui.agent import ComputerAgent -from askui.models.shared.agent_message_param import MessageParam -from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.models.shared.settings import CachedExecutionToolSettings, CachingSettings +from askui.models.shared.settings import CacheExecutionSettings, CachingSettings -def test_act_with_caching_strategy_read(vision_agent: ComputerAgent) -> None: - """Test that caching_strategy='read' adds retrieve and execute tools.""" +def test_act_with_caching_strategy_execute(vision_agent: ComputerAgent) -> None: + """Test that caching_strategy='execute' adds retrieve and execute tools.""" with tempfile.TemporaryDirectory() as temp_dir: # Create a dummy cache file cache_dir = Path(temp_dir) cache_file = cache_dir / "test_cache.json" cache_file.write_text("[]", 
encoding="utf-8") - # Act with read caching strategy + # Act with execute caching strategy vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=str(cache_dir), ), ) assert True -def test_act_with_caching_strategy_write(vision_agent: ComputerAgent) -> None: - """Test that caching_strategy='write' writes cache file.""" +def test_act_with_caching_strategy_record(vision_agent: ComputerAgent) -> None: + """Test that caching_strategy='record' writes cache file.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) cache_filename = "test_output.json" - # Act with write caching strategy + # Act with record caching strategy vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=str(cache_dir), filename=cache_filename, ), @@ -53,12 +49,12 @@ def test_act_with_caching_strategy_write(vision_agent: ComputerAgent) -> None: def test_act_with_caching_strategy_both(vision_agent: ComputerAgent) -> None: - """Test that caching_strategy='both' enables both read and write.""" + """Test that caching_strategy='both' enables both execute and record.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) cache_filename = "test_both.json" - # Create a dummy cache file for reading + # Create a dummy cache file for executing cache_file = cache_dir / "existing_cache.json" cache_file.write_text("[]", encoding="utf-8") @@ -66,7 +62,7 @@ def test_act_with_caching_strategy_both(vision_agent: ComputerAgent) -> None: vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="both", + strategy="auto", cache_dir=str(cache_dir), filename=cache_filename, ), @@ -77,8 +73,8 @@ def test_act_with_caching_strategy_both(vision_agent: ComputerAgent) -> None: assert output_file.exists() -def test_act_with_caching_strategy_no(vision_agent: ComputerAgent) -> None: - """Test that 
caching_strategy='no' doesn't create cache files.""" +def test_act_with_caching_strategy_none(vision_agent: ComputerAgent) -> None: + """Test that caching_strategy=None doesn't create cache files.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) @@ -86,7 +82,7 @@ def test_act_with_caching_strategy_no(vision_agent: ComputerAgent) -> None: vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="no", + strategy=None, cache_dir=str(cache_dir), ), ) @@ -106,7 +102,7 @@ def test_act_with_custom_cache_dir_and_filename(vision_agent: ComputerAgent) -> vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=str(custom_cache_dir), filename=custom_filename, ), @@ -118,48 +114,6 @@ def test_act_with_custom_cache_dir_and_filename(vision_agent: ComputerAgent) -> assert cache_file.exists() -def test_act_with_on_message_and_write_caching_raises_error( - vision_agent: ComputerAgent, -) -> None: - """Test that providing on_message callback with write caching raises ValueError.""" - with tempfile.TemporaryDirectory() as temp_dir: - - def dummy_callback(param: OnMessageCbParam) -> MessageParam: - return param.message - - # Should raise ValueError when on_message is provided with write strategy - with pytest.raises(ValueError, match="Cannot use on_message callback"): - vision_agent.act( - goal="Tell me a joke", - caching_settings=CachingSettings( - strategy="write", - cache_dir=str(temp_dir), - ), - on_message=dummy_callback, - ) - - -def test_act_with_on_message_and_both_caching_raises_error( - vision_agent: ComputerAgent, -) -> None: - """Test that providing on_message callback with both caching raises ValueError.""" - with tempfile.TemporaryDirectory() as temp_dir: - - def dummy_callback(param: OnMessageCbParam) -> MessageParam: - return param.message - - # Should raise ValueError when on_message is provided with both strategy - with 
pytest.raises(ValueError, match="Cannot use on_message callback"): - vision_agent.act( - goal="Tell me a joke", - caching_settings=CachingSettings( - strategy="both", - cache_dir=str(temp_dir), - ), - on_message=dummy_callback, - ) - - def test_cache_file_contains_tool_use_blocks(vision_agent: ComputerAgent) -> None: """Test that cache file contains ToolUseBlockParam entries.""" with tempfile.TemporaryDirectory() as temp_dir: @@ -170,7 +124,7 @@ def test_cache_file_contains_tool_use_blocks(vision_agent: ComputerAgent) -> Non vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="write", + strategy="record", cache_dir=str(cache_dir), filename=cache_filename, ), @@ -196,22 +150,22 @@ def test_cache_file_contains_tool_use_blocks(vision_agent: ComputerAgent) -> Non def test_act_with_custom_cached_execution_tool_settings( vision_agent: ComputerAgent, ) -> None: - """Test that custom CachedExecutionToolSettings are applied.""" + """Test that custom CacheExecutionSettings are applied.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - # Create a dummy cache file for reading + # Create a dummy cache file for executing cache_file = cache_dir / "test_cache.json" cache_file.write_text("[]", encoding="utf-8") - # Act with custom execution tool settings - custom_settings = CachedExecutionToolSettings(delay_time_between_action=2.0) + # Act with custom execution settings + custom_settings = CacheExecutionSettings(delay_time_between_actions=2.0) vision_agent.act( goal="Tell me a joke", caching_settings=CachingSettings( - strategy="read", + strategy="execute", cache_dir=str(cache_dir), - execute_cached_trajectory_tool_settings=custom_settings, + execution_settings=custom_settings, ), ) diff --git a/tests/unit/tools/test_caching_tools.py b/tests/unit/tools/test_caching_tools.py index f1f622c4..4f162c86 100644 --- a/tests/unit/tools/test_caching_tools.py +++ b/tests/unit/tools/test_caching_tools.py @@ -3,27 +3,40 @@ import 
json import tempfile from pathlib import Path -from typing import Any -from unittest.mock import MagicMock, patch import pytest -from askui.models.shared.settings import CachedExecutionToolSettings -from askui.models.shared.tools import ToolCollection from askui.tools.caching_tools import ( - ExecuteCachedTrajectory, + InspectCacheMetadata, RetrieveCachedTestExecutions, + VerifyCacheExecution, ) +def _create_valid_cache_file(path: Path, is_valid: bool = True) -> None: + """Create a valid cache file with required metadata structure.""" + cache_data = { + "metadata": { + "version": "1.0", + "created_at": "2025-01-01T00:00:00Z", + "is_valid": is_valid, + "execution_attempts": 0, + "failures": [], + }, + "trajectory": [], + "cache_parameters": {}, + } + path.write_text(json.dumps(cache_data), encoding="utf-8") + + def test_retrieve_cached_test_executions_lists_json_files() -> None: """Test that RetrieveCachedTestExecutions lists all JSON files in cache dir.""" with tempfile.TemporaryDirectory() as temp_dir: cache_dir = Path(temp_dir) - # Create some cache files - (cache_dir / "cache1.json").write_text("{}", encoding="utf-8") - (cache_dir / "cache2.json").write_text("{}", encoding="utf-8") + # Create valid cache files + _create_valid_cache_file(cache_dir / "cache1.json") + _create_valid_cache_file(cache_dir / "cache2.json") (cache_dir / "not_cache.txt").write_text("text", encoding="utf-8") tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) @@ -60,8 +73,8 @@ def test_retrieve_cached_test_executions_respects_custom_format() -> None: cache_dir = Path(temp_dir) # Create files with different extensions - (cache_dir / "cache1.json").write_text("{}", encoding="utf-8") - (cache_dir / "cache2.traj").write_text("{}", encoding="utf-8") + _create_valid_cache_file(cache_dir / "cache1.json") + _create_valid_cache_file(cache_dir / "cache2.traj") # Default format (.json) tool_json = RetrieveCachedTestExecutions( @@ -80,254 +93,136 @@ def 
test_retrieve_cached_test_executions_respects_custom_format() -> None: assert "cache2.traj" in result_traj[0] -def test_execute_cached_execution_initializes_without_toolbox() -> None: - """Test that ExecuteCachedExecution can be initialized without toolbox.""" - tool = ExecuteCachedTrajectory() - assert tool.name.startswith("execute_cached_executions_tool") - - -def test_execute_cached_execution_raises_error_without_toolbox() -> None: - """Test that ExecuteCachedExecution raises error when toolbox not set.""" - tool = ExecuteCachedTrajectory() - - with pytest.raises(RuntimeError, match="Toolbox not set"): - tool(trajectory_file="some_file.json") - - -def test_execute_cached_execution_raises_error_when_file_not_found() -> None: - """Test that ExecuteCachedExecution raises error if trajectory file doesn't exist""" - tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) - - with pytest.raises(FileNotFoundError, match="Trajectory file not found"): - tool(trajectory_file="/non/existent/file.json") - - -def test_execute_cached_execution_executes_trajectory() -> None: - """Test that ExecuteCachedExecution executes tools from trajectory file.""" +def test_retrieve_cached_test_executions_filters_invalid_by_default() -> None: + """Test that invalid caches are filtered out by default.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_trajectory.json" - - # Create a trajectory file - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", - }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", - }, - ] - - with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + cache_dir = Path(temp_dir) - # Execute the trajectory - tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + # 
Create valid and invalid cache files + _create_valid_cache_file(cache_dir / "valid.json", is_valid=True) + _create_valid_cache_file(cache_dir / "invalid.json", is_valid=False) - result = tool(trajectory_file=str(cache_file)) + tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) + result = tool(include_invalid=False) - # Verify success message - assert "Successfully executed trajectory" in result - # Verify toolbox.run was called for each tool (2 calls) - assert mock_toolbox.run.call_count == 2 + assert len(result) == 1 + assert any("valid.json" in path for path in result) + assert not any("invalid.json" in path for path in result) -def test_execute_cached_execution_skips_screenshot_tools() -> None: - """Test that ExecuteCachedExecution skips screenshot-related tools.""" +def test_retrieve_cached_test_executions_includes_invalid_when_requested() -> None: + """Test that invalid caches are included when include_invalid=True.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_trajectory.json" - - # Create a trajectory with screenshot tools - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "screenshot", - "input": {}, - "type": "tool_use", - }, - { - "id": "tool2", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", - }, - { - "id": "tool3", - "name": "retrieve_available_trajectories_tool", - "input": {}, - "type": "tool_use", - }, - ] - - with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) + cache_dir = Path(temp_dir) - # Execute the trajectory - tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) + # Create valid and invalid cache files + _create_valid_cache_file(cache_dir / "valid.json", is_valid=True) + _create_valid_cache_file(cache_dir / "invalid.json", is_valid=False) - result = tool(trajectory_file=str(cache_file)) + tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) + 
result = tool(include_invalid=True) - # Verify only click_tool was executed (screenshot and retrieve tools skipped) - assert mock_toolbox.run.call_count == 1 - assert "Successfully executed trajectory" in result + assert len(result) == 2 + assert any("valid.json" in path for path in result) + assert any("invalid.json" in path for path in result) -def test_execute_cached_execution_handles_errors_gracefully() -> None: - """Test that ExecuteCachedExecution handles errors during execution.""" +def test_retrieve_cached_test_executions_returns_parameter_info() -> None: + """Test that cache parameter info is included in the result.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_trajectory.json" - - # Create a trajectory - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "failing_tool", - "input": {}, - "type": "tool_use", - }, - ] - - with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) - - # Execute the trajectory with a failing tool - tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) - mock_toolbox.run.side_effect = Exception("Tool execution failed") - tool.set_toolbox(mock_toolbox) - - result = tool(trajectory_file=str(cache_file)) + cache_dir = Path(temp_dir) - # Verify error message - assert "error occured" in result.lower() - assert "verify the UI state" in result + # Create cache file with parameters + cache_data = { + "metadata": { + "version": "1.0", + "created_at": "2025-01-01T00:00:00Z", + "is_valid": True, + "execution_attempts": 0, + "failures": [], + }, + "trajectory": [], + "cache_parameters": {"target_url": "placeholder", "user_id": "123"}, + } + cache_file = cache_dir / "with_params.json" + cache_file.write_text(json.dumps(cache_data), encoding="utf-8") + tool = RetrieveCachedTestExecutions(cache_dir=str(cache_dir)) + result = tool() -def test_execute_cached_execution_set_toolbox() -> None: - """Test that set_toolbox properly sets the 
toolbox reference.""" - tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) + assert len(result) == 1 + assert "parameters:" in result[0] + assert "target_url" in result[0] - tool.set_toolbox(mock_toolbox) - # After setting toolbox, should be able to access it - assert hasattr(tool, "_toolbox") - assert tool._toolbox == mock_toolbox +def test_verify_cache_execution_initializes_correctly() -> None: + """Test that VerifyCacheExecution initializes correctly.""" + tool = VerifyCacheExecution() + assert tool.name.startswith("verify_cache_execution") + assert "success" in tool.input_schema["properties"] + assert "verification_notes" in tool.input_schema["properties"] + assert tool.is_cacheable is False -def test_execute_cached_execution_initializes_with_default_settings() -> None: - """Test that ExecuteCachedTrajectory uses default settings when none provided.""" - tool = ExecuteCachedTrajectory() +def test_verify_cache_execution_reports_success() -> None: + """Test that VerifyCacheExecution reports success correctly.""" + tool = VerifyCacheExecution() + result = tool(success=True, verification_notes="UI state matches expected") - # Should have default settings initialized - assert hasattr(tool, "_settings") + assert "success=True" in result + assert "UI state matches expected" in result -def test_execute_cached_execution_initializes_with_custom_settings() -> None: - """Test that ExecuteCachedTrajectory accepts custom settings.""" - custom_settings = CachedExecutionToolSettings(delay_time_between_action=1.0) - tool = ExecuteCachedTrajectory(settings=custom_settings) +def test_verify_cache_execution_reports_failure() -> None: + """Test that VerifyCacheExecution reports failure correctly.""" + tool = VerifyCacheExecution() + result = tool(success=False, verification_notes="Button was not clicked") - # Should have custom settings initialized - assert hasattr(tool, "_settings") + assert "success=False" in result + assert "Button was not clicked" 
in result -def test_execute_cached_execution_uses_delay_time_between_actions() -> None: - """Test that ExecuteCachedTrajectory uses the configured delay time.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_trajectory.json" - - # Create a trajectory with 3 actions - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", - }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", - }, - { - "id": "tool3", - "name": "move_tool", - "input": {"x": 300, "y": 400}, - "type": "tool_use", - }, - ] +def test_inspect_cache_metadata_initializes_correctly() -> None: + """Test that InspectCacheMetadata initializes correctly.""" + tool = InspectCacheMetadata() + assert tool.name.startswith("inspect_cache_metadata_tool") + assert "trajectory_file" in tool.input_schema["properties"] - with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) - # Execute with custom delay time - custom_settings = CachedExecutionToolSettings(delay_time_between_action=0.1) - tool = ExecuteCachedTrajectory(settings=custom_settings) - mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) +def test_inspect_cache_metadata_returns_error_when_file_not_found() -> None: + """Test that InspectCacheMetadata returns error if file doesn't exist.""" + tool = InspectCacheMetadata() - # Mock time.sleep to verify it's called with correct delay - with patch("time.sleep") as mock_sleep: - result = tool(trajectory_file=str(cache_file)) + result = tool(trajectory_file="/non/existent/file.json") - # Verify success - assert "Successfully executed trajectory" in result - # Verify sleep was called 3 times (once after each action) - assert mock_sleep.call_count == 3 - # Verify it was called with the configured delay time - for call in mock_sleep.call_args_list: - assert call[0][0] == 0.1 + assert "Trajectory file not 
found" in result -def test_execute_cached_execution_default_delay_time() -> None: - """Test that ExecuteCachedTrajectory uses default delay time of 0.5s.""" +def test_inspect_cache_metadata_returns_metadata() -> None: + """Test that InspectCacheMetadata returns formatted metadata.""" with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_trajectory.json" - - # Create a trajectory with 2 actions - trajectory: list[dict[str, Any]] = [ - { - "id": "tool1", - "name": "click_tool", - "input": {"x": 100, "y": 200}, - "type": "tool_use", - }, - { - "id": "tool2", - "name": "type_tool", - "input": {"text": "hello"}, - "type": "tool_use", + cache_file = Path(temp_dir) / "test_cache.json" + cache_data = { + "metadata": { + "version": "1.0", + "created_at": "2025-01-01T00:00:00Z", + "is_valid": True, + "execution_attempts": 5, + "failures": [], }, - ] - - with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) - - # Execute with default settings - tool = ExecuteCachedTrajectory() - mock_toolbox = MagicMock(spec=ToolCollection) - tool.set_toolbox(mock_toolbox) - - # Mock time.sleep to verify default delay is used - with patch("time.sleep") as mock_sleep: - result = tool(trajectory_file=str(cache_file)) - - # Verify success - assert "Successfully executed trajectory" in result - # Verify sleep was called with default delay of 0.5s - assert mock_sleep.call_count == 2 - for call in mock_sleep.call_args_list: - assert call[0][0] == 0.5 + "trajectory": [ + {"id": "1", "name": "click", "input": {}, "type": "tool_use"} + ], + "cache_parameters": {"url": "test"}, + } + cache_file.write_text(json.dumps(cache_data), encoding="utf-8") + + tool = InspectCacheMetadata() + result = tool(trajectory_file=str(cache_file)) + + assert "=== Cache Metadata ===" in result + assert "Version: 1.0" in result + assert "Is Valid: True" in result + assert "Total Execution Attempts: 5" in result + assert "Total Steps: 1" in result + assert "url" in 
result diff --git a/tests/unit/tools/test_switch_speaker_tool.py b/tests/unit/tools/test_switch_speaker_tool.py new file mode 100644 index 00000000..cff61ecd --- /dev/null +++ b/tests/unit/tools/test_switch_speaker_tool.py @@ -0,0 +1,39 @@ +"""Unit tests for SwitchSpeakerTool.""" + +from askui.tools.switch_speaker_tool import SwitchSpeakerTool + + +def test_switch_speaker_tool_name() -> None: + """Test that the tool name starts with switch_speaker.""" + tool = SwitchSpeakerTool(speaker_names=["CacheExecutor"]) + assert tool.name.startswith("switch_speaker") + + +def test_switch_speaker_tool_enum_constraint() -> None: + """Test that speaker names are set as enum in input schema.""" + tool = SwitchSpeakerTool(speaker_names=["CacheExecutor", "ValidationAgent"]) + enum_values = tool.input_schema["properties"]["speaker_name"]["enum"] + assert enum_values == ["CacheExecutor", "ValidationAgent"] + + +def test_switch_speaker_tool_is_not_cacheable() -> None: + """Test that the tool is marked as not cacheable.""" + tool = SwitchSpeakerTool(speaker_names=["CacheExecutor"]) + assert tool.is_cacheable is False + + +def test_switch_speaker_tool_call_returns_acknowledgment() -> None: + """Test that calling the tool returns an acknowledgment message.""" + tool = SwitchSpeakerTool(speaker_names=["CacheExecutor"]) + result = tool(speaker_name="CacheExecutor") + assert "Switching to speaker 'CacheExecutor'" == result + + +def test_switch_speaker_tool_call_with_context() -> None: + """Test that calling the tool with context works.""" + tool = SwitchSpeakerTool(speaker_names=["CacheExecutor"]) + result = tool( + speaker_name="CacheExecutor", + speaker_context={"trajectory_file": "test.json"}, + ) + assert "CacheExecutor" in result diff --git a/tests/unit/utils/test_cache_writer.py b/tests/unit/utils/test_cache_writer.py deleted file mode 100644 index 2c875ae4..00000000 --- a/tests/unit/utils/test_cache_writer.py +++ /dev/null @@ -1,312 +0,0 @@ -"""Unit tests for CacheWriter utility.""" 
- -import json -import tempfile -from pathlib import Path -from typing import Any - -from askui.models.shared.agent_message_param import MessageParam, ToolUseBlockParam -from askui.models.shared.agent_on_message_cb import OnMessageCbParam -from askui.utils.cache_writer import CacheWriter - - -def test_cache_writer_initialization() -> None: - """Test CacheWriter initialization.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") - assert cache_writer.cache_dir == Path(temp_dir) - assert cache_writer.file_name == "test.json" - assert cache_writer.messages == [] - assert cache_writer.was_cached_execution is False - - -def test_cache_writer_creates_cache_directory() -> None: - """Test that CacheWriter creates the cache directory if it doesn't exist.""" - with tempfile.TemporaryDirectory() as temp_dir: - non_existent_dir = Path(temp_dir) / "new_cache_dir" - assert not non_existent_dir.exists() - - CacheWriter(cache_dir=str(non_existent_dir)) - assert non_existent_dir.exists() - assert non_existent_dir.is_dir() - - -def test_cache_writer_adds_json_extension() -> None: - """Test that CacheWriter adds .json extension if not present.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test") - assert cache_writer.file_name == "test.json" - - cache_writer2 = CacheWriter(cache_dir=temp_dir, file_name="test.json") - assert cache_writer2.file_name == "test.json" - - -def test_cache_writer_add_message_cb_stores_tool_use_blocks() -> None: - """Test that add_message_cb stores ToolUseBlockParam from assistant messages.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") - - tool_use_block = ToolUseBlockParam( - id="test_id", - name="test_tool", - input={"param": "value"}, - type="tool_use", - ) - - message = MessageParam( - role="assistant", - content=[tool_use_block], - 
stop_reason=None, - ) - - param = OnMessageCbParam( - message=message, - messages=[message], - ) - - result = cache_writer.add_message_cb(param) - assert result == param.message - assert len(cache_writer.messages) == 1 - assert cache_writer.messages[0] == tool_use_block - - -def test_cache_writer_add_message_cb_ignores_non_tool_use_content() -> None: - """Test that add_message_cb ignores non-ToolUseBlockParam content.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") - - message = MessageParam( - role="assistant", - content="Just a text message", - stop_reason=None, - ) - - param = OnMessageCbParam( - message=message, - messages=[message], - ) - - cache_writer.add_message_cb(param) - assert len(cache_writer.messages) == 0 - - -def test_cache_writer_add_message_cb_ignores_user_messages() -> None: - """Test that add_message_cb ignores user messages.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") - - message = MessageParam( - role="user", - content="User message", - stop_reason=None, - ) - - param = OnMessageCbParam( - message=message, - messages=[message], - ) - - cache_writer.add_message_cb(param) - assert len(cache_writer.messages) == 0 - - -def test_cache_writer_detects_cached_execution() -> None: - """Test that CacheWriter detects when execute_cached_executions_tool is used.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="test.json") - - tool_use_block = ToolUseBlockParam( - id="cached_exec_id", - name="execute_cached_executions_tool", - input={"trajectory_file": "test.json"}, - type="tool_use", - ) - - message = MessageParam( - role="assistant", - content=[tool_use_block], - stop_reason=None, - ) - - param = OnMessageCbParam( - message=message, - messages=[message], - ) - - cache_writer.add_message_cb(param) - assert 
cache_writer.was_cached_execution is True - - -def test_cache_writer_generate_writes_file() -> None: - """Test that generate() writes messages to a JSON file.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="output.json") - - # Add some tool use blocks - tool_use1 = ToolUseBlockParam( - id="id1", - name="tool1", - input={"param": "value1"}, - type="tool_use", - ) - tool_use2 = ToolUseBlockParam( - id="id2", - name="tool2", - input={"param": "value2"}, - type="tool_use", - ) - - cache_writer.messages = [tool_use1, tool_use2] - cache_writer.generate() - - # Verify file was created - cache_file = cache_dir / "output.json" - assert cache_file.exists() - - # Verify file content - with cache_file.open("r", encoding="utf-8") as f: - data = json.load(f) - - assert len(data) == 2 - assert data[0]["id"] == "id1" - assert data[0]["name"] == "tool1" - assert data[1]["id"] == "id2" - assert data[1]["name"] == "tool2" - - -def test_cache_writer_generate_auto_names_file() -> None: - """Test that generate() auto-generates filename if not provided.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="") - - tool_use = ToolUseBlockParam( - id="id1", - name="tool1", - input={}, - type="tool_use", - ) - cache_writer.messages = [tool_use] - cache_writer.generate() - - # Verify a file was created with auto-generated name - json_files = list(cache_dir.glob("*.json")) - assert len(json_files) == 1 - assert json_files[0].name.startswith("cached_trajectory_") - - -def test_cache_writer_generate_skips_cached_execution() -> None: - """Test that generate() doesn't write file for cached executions.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="test.json") - - cache_writer.was_cached_execution = True - 
cache_writer.messages = [ - ToolUseBlockParam( - id="id1", - name="tool1", - input={}, - type="tool_use", - ) - ] - - cache_writer.generate() - - # Verify no file was created - json_files = list(cache_dir.glob("*.json")) - assert len(json_files) == 0 - - -def test_cache_writer_reset() -> None: - """Test that reset() clears messages and filename.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="original.json") - - # Add some data - cache_writer.messages = [ - ToolUseBlockParam( - id="id1", - name="tool1", - input={}, - type="tool_use", - ) - ] - cache_writer.was_cached_execution = True - - # Reset - cache_writer.reset(file_name="new.json") - - assert cache_writer.messages == [] - assert cache_writer.file_name == "new.json" - assert cache_writer.was_cached_execution is False - - -def test_cache_writer_read_cache_file() -> None: - """Test that read_cache_file() loads ToolUseBlockParam from JSON.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_file = Path(temp_dir) / "test_cache.json" - - # Create a cache file - trajectory: list[dict[str, Any]] = [ - { - "id": "id1", - "name": "tool1", - "input": {"param": "value1"}, - "type": "tool_use", - }, - { - "id": "id2", - "name": "tool2", - "input": {"param": "value2"}, - "type": "tool_use", - }, - ] - - with cache_file.open("w", encoding="utf-8") as f: - json.dump(trajectory, f) - - # Read cache file - result = CacheWriter.read_cache_file(cache_file) - - assert len(result) == 2 - assert isinstance(result[0], ToolUseBlockParam) - assert result[0].id == "id1" - assert result[0].name == "tool1" - assert isinstance(result[1], ToolUseBlockParam) - assert result[1].id == "id2" - assert result[1].name == "tool2" - - -def test_cache_writer_set_file_name() -> None: - """Test that set_file_name() updates the filename.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_writer = CacheWriter(cache_dir=temp_dir, file_name="original.json") - - 
cache_writer.set_file_name("new_name") - assert cache_writer.file_name == "new_name.json" - - cache_writer.set_file_name("another.json") - assert cache_writer.file_name == "another.json" - - -def test_cache_writer_generate_resets_after_writing() -> None: - """Test that generate() calls reset() after writing the file.""" - with tempfile.TemporaryDirectory() as temp_dir: - cache_dir = Path(temp_dir) - cache_writer = CacheWriter(cache_dir=str(cache_dir), file_name="test.json") - - cache_writer.messages = [ - ToolUseBlockParam( - id="id1", - name="tool1", - input={}, - type="tool_use", - ) - ] - - cache_writer.generate() - - # After generate, messages should be empty - assert cache_writer.messages == []