diff --git a/README.md b/README.md
index 596614d1..fc3bb70f 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ pip install askui
| | AskUI [INFO](https://hub.askui.com/) | Anthropic [INFO](https://console.anthropic.com/settings/keys) |
|----------|----------|----------|
| ENV Variables | `ASKUI_WORKSPACE_ID`, `ASKUI_TOKEN` | `ANTHROPIC_API_KEY` |
-| Supported Commands | `click()` | `click()`, `get()`, `act()` |
+| Supported Commands | `click()`, `get()`, `locate()`, `mouse_move()` | `act()`, `click()`, `get()`, `locate()`, `mouse_move()` |
| Description | Faster Inference, European Server, Enterprise Ready | Supports complex actions |
To get started, set the environment variables required to authenticate with your chosen model provider.
@@ -130,7 +130,7 @@ You can test the Vision Agent with Huggingface models via their Spaces API. Plea
**Example Code:**
```python
-agent.click("search field", model_name="OS-Copilot/OS-Atlas-Base-7B")
+agent.click("search field", model="OS-Copilot/OS-Atlas-Base-7B")
```
### 3c. Host your own **AI Models**
@@ -143,7 +143,7 @@ You can use Vision Agent with UI-TARS if you provide your own UI-TARS API endpoi
2. Step: Provide the `TARS_URL` and `TARS_API_KEY` environment variables to Vision Agent.
-3. Step: Use the `model_name="tars"` parameter in your `click()`, `get()` and `act()` commands.
+3. Step: Use the `model="tars"` parameter in your `click()`, `get()` and `act()` etc. commands or when initializing the `VisionAgent`.
## ▶️ Start Building
@@ -171,46 +171,68 @@ with VisionAgent() as agent:
### 🎛️ Model Selection
-Instead of relying on the default model for the entire automation script, you can specify a model for each `click` command using the `model_name` parameter.
+Instead of relying on the default model for the entire automation script, you can specify a model for each `click()` (or `act()`, `get()` etc.) command using the `model` parameter or when initializing the `VisionAgent` (overridden by the `model` parameter of individual commands).
| | AskUI | Anthropic |
|----------|----------|----------|
-| `click()` | `askui-combo`, `askui-pta`, `askui-ocr` | `anthropic-claude-3-5-sonnet-20241022` |
+| `act()` | | `anthropic-claude-3-5-sonnet-20241022` |
+| `click()` | `askui`, `askui-combo`, `askui-pta`, `askui-ocr`, `askui-ai-element` | `anthropic-claude-3-5-sonnet-20241022` |
+| `get()` | | `askui`, `anthropic-claude-3-5-sonnet-20241022` |
+| `locate()` | `askui`, `askui-combo`, `askui-pta`, `askui-ocr`, `askui-ai-element` | `anthropic-claude-3-5-sonnet-20241022` |
+| `mouse_move()` | `askui`, `askui-combo`, `askui-pta`, `askui-ocr`, `askui-ai-element` | `anthropic-claude-3-5-sonnet-20241022` |
-**Example:** `agent.click("Preview", model_name="askui-combo")`
-
- Antrophic AI Models
-
-Supported commands are: `click()`, `type()`, `mouse_move()`, `get()`, `act()`
-| Model Name | Info | Execution Speed | Security | Cost | Reliability |
-|-------------|--------------------|--------------|--------------|--------------|--------------|
-| `anthropic-claude-3-5-sonnet-20241022` | The [Computer Use](https://docs.anthropic.com/en/docs/agents-and-tools/computer-use) model from Antrophic is a Large Action Model (LAM), which can autonomously achieve goals. e.g. `"Book me a flight from Berlin to Rom"` | slow, >1s per step | Model hosting by Anthropic | High, up to 1,5$ per act | Not recommended for production usage |
-> **Note:** Configure your Antrophic Model Provider [here](#3a-authenticate-with-an-ai-model-provider)
+**Example:**
+```python
+from askui import VisionAgent
-
+with VisionAgent() as agent:
+ # Uses the default model (depending on the environment variables set, see above)
+ agent.click("Next")
+
+with VisionAgent(model="askui-combo") as agent:
+ # Uses the "askui-combo" model because it was specified when initializing the agent
+ agent.click("Next")
+ # Uses the "anthropic-claude-3-5-sonnet-20241022" model
+ agent.click("Previous", model="anthropic-claude-3-5-sonnet-20241022")
+ # Uses the "askui-combo" model again as no model was specified
+ agent.click("Next")
+```
AskUI AI Models
-Supported commands are: `click()`, `type()`, `mouse_move()`
+Supported commands are: `click()`, `locate()`, `mouse_move()`
| Model Name | Info | Execution Speed | Security | Cost | Reliability |
|-------------|--------------------|--------------|--------------|--------------|--------------|
+| `askui` | `AskUI` is a combination of all the following models: `askui-pta`, `askui-ocr`, `askui-combo`, `askui-ai-element` where AskUI chooses the best model for the task depending on the input. | Fast, <500ms per step | Secure hosting by AskUI or on-premise | Low, <0,05$ per step | Recommended for production usage, can be (at least partially) retrained |
| `askui-pta` | [`PTA-1`](https://huggingface.co/AskUI/PTA-1) (Prompt-to-Automation) is a vision language model (VLM) trained by [AskUI](https://www.askui.com/) which to address all kinds of UI elements by a textual description e.g. "`Login button`", "`Text login`" | fast, <500ms per step | Secure hosting by AskUI or on-premise | Low, <0,05$ per step | Recommended for production usage, can be retrained |
| `askui-ocr` | `AskUI OCR` is an OCR model trained to address texts on UI Screens e.g. "`Login`", "`Search`" | Fast, <500ms per step | Secure hosting by AskUI or on-premise | low, <0,05$ per step | Recommended for production usage, can be retrained |
| `askui-combo` | AskUI Combo is an combination from the `askui-pta` and the `askui-ocr` model to improve the accuracy. | Fast, <500ms per step | Secure hosting by AskUI or on-premise | low, <0,05$ per step | Recommended for production usage, can be retrained |
-| `askui-ai-element`| [AskUI AI Element](https://docs.askui.com/docs/general/Element%20Selection/aielement) allows you to address visual elements like icons or images by demonstrating what you looking for. Therefore, you have to crop out the element and give it a name. | Very fast, <5ms per step | Secure hosting by AskUI or on-premise | Low, <0,05$ per step | Recommended for production usage, determinitic behaviour |
+| `askui-ai-element`| [AskUI AI Element](https://docs.askui.com/docs/general/Element%20Selection/aielement) allows you to address visual elements like icons or images by demonstrating what you are looking for. Therefore, you have to crop out the element and give it a name. | Very fast, <5ms per step | Secure hosting by AskUI or on-premise | Low, <0,05$ per step | Recommended for production usage, deterministic behaviour |
> **Note:** Configure your AskUI Model Provider [here](#3a-authenticate-with-an-ai-model-provider)
+
+
+
+ Anthropic AI Models
+
+Supported commands are: `act()`, `get()`, `click()`, `locate()`, `mouse_move()`
+| Model Name | Info | Execution Speed | Security | Cost | Reliability |
+|-------------|--------------------|--------------|--------------|--------------|--------------|
+| `anthropic-claude-3-5-sonnet-20241022` | The [Computer Use](https://docs.anthropic.com/en/docs/agents-and-tools/computer-use) model from Anthropic is a Large Action Model (LAM), which can autonomously achieve goals. e.g. `"Book me a flight from Berlin to Rome"` | slow, >1s per step | Model hosting by Anthropic | High, up to 1,5$ per act | Not recommended for production usage |
+> **Note:** Configure your Anthropic Model Provider [here](#3a-authenticate-with-an-ai-model-provider)
+
+
Huggingface AI Models (Spaces API)
-Supported commands are: `click()`, `type()`, `mouse_move()`
+Supported commands are: `click()`, `locate()`, `mouse_move()`
| Model Name | Info | Execution Speed | Security | Cost | Reliability |
|-------------|--------------------|--------------|--------------|--------------|--------------|
| `AskUI/PTA-1` | [`PTA-1`](https://huggingface.co/AskUI/PTA-1) (Prompt-to-Automation) is a vision language model (VLM) trained by [AskUI](https://www.askui.com/) which to address all kinds of UI elements by a textual description e.g. "`Login button`", "`Text login`" | fast, <500ms per step | Huggingface hosted | Prices for Huggingface hosting | Not recommended for production applications |
@@ -226,7 +248,7 @@ Supported commands are: `click()`, `type()`, `mouse_move()`
Self Hosted UI Models
-Supported commands are: `click()`, `type()`, `mouse_move()`, `get()`, `act()`
+Supported commands are: `click()`, `locate()`, `mouse_move()`, `get()`, `act()`
| Model Name | Info | Execution Speed | Security | Cost | Reliability |
|-------------|--------------------|--------------|--------------|--------------|--------------|
| `tars` | [`UI-Tars`](https://github.com/bytedance/UI-TARS) is a Large Action Model (LAM) based on Qwen2 and fine-tuned by [ByteDance](https://www.bytedance.com/) on UI data e.g. "`Book me a flight to rom`" | slow, >1s per step | Self-hosted | Depening on infrastructure | Out-of-the-box not recommended for production usage |
@@ -269,26 +291,160 @@ agent.tools.clipboard.copy("...")
result = agent.tools.clipboard.paste()
```
-### 📜 Logging & Reporting
+### 📜 Logging
-You want a better understanding of what you agent is doing? Set the `log_level` to DEBUG. You can also generate a report of the automation run by setting `enable_report` to `True`.
+You want a better understanding of what your agent is doing? Set the `log_level` to DEBUG.
```python
import logging
-with VisionAgent(log_level=logging.DEBUG, enable_report=True) as agent:
+with VisionAgent(log_level=logging.DEBUG) as agent:
+ agent...
+```
+
+### 📜 Reporting
+
+You want to see a report of the actions your agent took? Register a reporter using the `reporters` parameter.
+
+```python
+from typing import Optional, Union
+from typing_extensions import override
+from askui.reporting import SimpleHtmlReporter
+from PIL import Image
+
+with VisionAgent(reporters=[SimpleHtmlReporter()]) as agent:
+ agent...
+```
+
+You can also create your own reporter by implementing the `Reporter` interface.
+
+```python
+from askui.reporting import Reporter
+
+class CustomReporter(Reporter):
+ @override
+ def add_message(
+ self,
+ role: str,
+ content: Union[str, dict, list],
+ image: Optional[Image.Image] = None,
+ ) -> None:
+ # adding message to the report (see implementation of `SimpleHtmlReporter` as an example)
+ pass
+
+ @override
+ def generate(self) -> None:
+ # generate the report if not generated live (see implementation of `SimpleHtmlReporter` as an example)
+ pass
+
+
+with VisionAgent(reporters=[CustomReporter()]) as agent:
+ agent...
+```
+
+You can also use multiple reporters at once. Their `generate()` and `add_message()` methods will be called in the order of the reporters in the list.
+
+```python
+with VisionAgent(reporters=[SimpleHtmlReporter(), CustomReporter()]) as agent:
agent...
```
### 🖥️ Multi-Monitor Support
-You have multiple monitors? Choose which one to automate by setting `display` to 1 or 2.
+You have multiple monitors? Choose which one to automate by setting `display` to `1`, `2` etc. To find the correct display or monitor, you have to play around a bit by setting it to different values. We are going to improve this soon. By default, the agent will use display 1.
```python
with VisionAgent(display=1) as agent:
agent...
```
+### 🎯 Locating elements
+
+If you have a hard time locating elements (for clicking, moving the mouse to, etc.) by simply using text, e.g.,
+
+```python
+agent.click("Password textfield")
+agent.type("********")
+```
+
+you can build more sophisticated locators.
+
+**⚠️ Warning:** Support can vary depending on the model you are using. Currently, only the `askui` model provides full support for locators. This model is chosen by default if the `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` environment variables are set and it is not overridden using the `model` parameter.
+
+Example:
+
+```python
+from askui import locators as loc
+
+password_textfield_label = loc.Text("Password")
+password_textfield = loc.Element("textfield").right_of(password_textfield_label)
+
+agent.click(password_textfield)
+agent.type("********")
+```
+
+### 📊 Extracting information
+
+The `get()` method allows you to extract information from the screen. You can use it to:
+
+- Get text or data from the screen
+- Check the state of UI elements
+- Make decisions based on screen content
+- Analyze static images
+
+#### Basic usage
+
+```python
+# Get text from screen
+url = agent.get("What is the current url shown in the url bar?")
+print(url) # e.g., "github.com/login"
+
+# Check UI state
+# Just as an example, may be flaky if used as is, better use a response schema to check for a boolean value (see below)
+is_logged_in = agent.get("Is the user logged in? Answer with 'yes' or 'no'.") == "yes"
+if is_logged_in:
+ agent.click("Logout")
+else:
+ agent.click("Login")
+```
+
+#### Using custom images
+
+Instead of taking a screenshot, you can analyze specific images:
+
+```python
+from PIL import Image
+
+# From PIL Image
+image = Image.open("screenshot.png")
+result = agent.get("What's in this image?", image)
+
+# From file path
+result = agent.get("What's in this image?", "screenshot.png")
+```
+
+#### Using response schemas
+
+For structured data extraction, use Pydantic models extending `ResponseSchemaBase`:
+
+```python
+from askui import ResponseSchemaBase
+
+class UserInfo(ResponseSchemaBase):
+ username: str
+ is_online: bool
+
+# Get structured data
+user_info = agent.get(
+ "What is the username and online status?",
+ response_schema=UserInfo
+)
+print(f"User {user_info.username} is {'online' if user_info.is_online else 'offline'}")
+```
+
+**⚠️ Limitations:**
+- Nested Pydantic schemas are not currently supported
+- Response schema is currently only supported by "askui" model (default model if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` are set)
## What is AskUI Vision Agent?
diff --git a/pdm.lock b/pdm.lock
index aa9660da..56efc204 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
groups = ["default", "test"]
strategy = ["inherit_metadata"]
lock_version = "4.5.0"
-content_hash = "sha256:8c2ae022f9280b62be3fc98d0e14053aece0661cc6dfca089149ff784b0b2efe"
+content_hash = "sha256:797a6cf550f6ec6264f8e851a84dff73bd155ed75264cf80adf458c4a3ecb832"
[[metadata.targets]]
requires_python = ">=3.10"
@@ -240,6 +240,17 @@ files = [
{file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"},
]
+[[package]]
+name = "execnet"
+version = "2.1.1"
+requires_python = ">=3.8"
+summary = "execnet: rapid multi-Python deployment"
+groups = ["test"]
+files = [
+ {file = "execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc"},
+ {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"},
+]
+
[[package]]
name = "filelock"
version = "3.16.1"
@@ -980,6 +991,21 @@ files = [
{file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"},
]
+[[package]]
+name = "pytest-xdist"
+version = "3.6.1"
+requires_python = ">=3.8"
+summary = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs"
+groups = ["test"]
+dependencies = [
+ "execnet>=2.1",
+ "pytest>=7.0.0",
+]
+files = [
+ {file = "pytest_xdist-3.6.1-py3-none-any.whl", hash = "sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7"},
+ {file = "pytest_xdist-3.6.1.tar.gz", hash = "sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d"},
+]
+
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
diff --git a/pyproject.toml b/pyproject.toml
index 9690c759..e3cc885a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,9 +39,10 @@ path = "src/askui/__init__.py"
distribution = true
[tool.pdm.scripts]
-test = "pytest"
-"test:unit" = "pytest tests/unit"
-"test:integration" = "pytest tests/integration"
+test = "pytest -n auto"
+"test:e2e" = "pytest -n auto tests/e2e"
+"test:integration" = "pytest -n auto tests/integration"
+"test:unit" = "pytest -n auto tests/unit"
sort = "isort ."
format = "black ."
lint = "ruff check ."
@@ -56,6 +57,7 @@ test = [
"black>=25.1.0",
"ruff>=0.9.5",
"pytest-mock>=3.14.0",
+ "pytest-xdist>=3.6.1",
]
chat = [
"streamlit>=1.42.0",
diff --git a/src/askui/__init__.py b/src/askui/__init__.py
index 5b7ab018..6cd6a904 100644
--- a/src/askui/__init__.py
+++ b/src/askui/__init__.py
@@ -3,7 +3,19 @@
__version__ = "0.2.4"
from .agent import VisionAgent
+from .models.router import ModelRouter
+from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase
+from .tools.toolbox import AgentToolbox
+from .tools.agent_os import AgentOs, ModifierKey, PcKey
+
__all__ = [
+ "AgentOs",
+ "AgentToolbox",
+ "ModelRouter",
+ "ModifierKey",
+ "PcKey",
+ "ResponseSchema",
+ "ResponseSchemaBase",
"VisionAgent",
]
diff --git a/src/askui/agent.py b/src/askui/agent.py
index 5ca927a7..2948e88e 100644
--- a/src/askui/agent.py
+++ b/src/askui/agent.py
@@ -1,282 +1,452 @@
import logging
import subprocess
-from typing import Annotated, Any, Literal, Optional, Callable
-
-from pydantic import Field, validate_call
+from typing import Annotated, Any, Literal, Optional, Type, overload
+from pydantic import ConfigDict, Field, validate_call
from askui.container import telemetry
+from askui.locators.locators import Locator
+from askui.utils.image_utils import ImageSource, Img
from .tools.askui.askui_controller import (
AskUiControllerClient,
AskUiControllerServer,
- PC_AND_MODIFIER_KEY,
- MODIFIER_KEY,
+ ModifierKey,
+ PcKey,
)
-from .models.anthropic.claude import ClaudeHandler
from .logger import logger, configure_logging
from .tools.toolbox import AgentToolbox
-from .models.router import ModelRouter
-from .reporting.report import SimpleReportGenerator
+from .models import ModelComposition
+from .models.router import ModelRouter, Point
+from .reporting import CompositeReporter, Reporter
import time
from dotenv import load_dotenv
-from PIL import Image
+from .models.types.response_schemas import ResponseSchema
+
class InvalidParameterError(Exception):
pass
class VisionAgent:
- @telemetry.record_call(exclude={"report_callback"})
+ """
+ A vision-based agent that can interact with user interfaces through computer vision and AI.
+
+ This agent can perform various UI interactions like clicking, typing, scrolling, and more.
+ It uses computer vision models to locate UI elements and execute actions on them.
+
+ Parameters:
+ log_level (int, optional):
+ The logging level to use. Defaults to logging.INFO.
+ display (int, optional):
+ The display number to use for screen interactions. Defaults to 1.
+ model_router (ModelRouter | None, optional):
+ Custom model router instance. If None, a default one will be created.
+ reporters (list[Reporter] | None, optional):
+ List of reporter instances for logging and reporting. If None, an empty list is used.
+ tools (AgentToolbox | None, optional):
+ Custom toolbox instance. If None, a default one will be created with AskUiControllerClient.
+ model (ModelComposition | str | None, optional):
+ The default composition or name of the model(s) to be used for vision tasks.
+ Can be overridden by the `model` parameter in the `click()`, `get()`, `act()` etc. methods.
+
+ Example:
+ ```python
+ with VisionAgent() as agent:
+ agent.click("Submit button")
+ agent.type("Hello World")
+ agent.act("Open settings menu")
+ ```
+ """
+ @telemetry.record_call(exclude={"model_router", "reporters", "tools"})
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def __init__(
self,
- log_level=logging.INFO,
- display: int = 1,
- enable_report: bool = False,
- enable_askui_controller: bool = True,
- report_callback: Callable[[str | dict[str, Any]], None] | None = None,
+ log_level: int | str = logging.INFO,
+ display: Annotated[int, Field(ge=1)] = 1,
+ model_router: ModelRouter | None = None,
+ reporters: list[Reporter] | None = None,
+ tools: AgentToolbox | None = None,
+ model: ModelComposition | str | None = None,
) -> None:
load_dotenv()
configure_logging(level=log_level)
- self.report = None
- if enable_report:
- self.report = SimpleReportGenerator(report_callback=report_callback)
- self.controller = None
- self.client = None
- if enable_askui_controller:
- self.controller = AskUiControllerServer()
- self.controller.start(True)
- time.sleep(0.5)
- self.client = AskUiControllerClient(display, self.report)
- self.client.connect()
- self.client.set_display(display)
- self.model_router = ModelRouter(log_level, self.report)
- self.claude = ClaudeHandler(log_level=log_level)
- self.tools = AgentToolbox(os_controller=self.client)
-
- def _check_askui_controller_enabled(self) -> None:
- if not self.client:
- raise ValueError(
- "AskUI Controller is not initialized. Please, set `enable_askui_controller` to `True` when initializing the `VisionAgent`."
- )
-
- @telemetry.record_call(exclude={"instruction"})
- def click(self, instruction: Optional[str] = None, button: Literal['left', 'middle', 'right'] = 'left', repeat: int = 1, model_name: Optional[str] = None) -> None:
+ self._reporter = CompositeReporter(reports=reporters)
+ self.tools = tools or AgentToolbox(
+ agent_os=AskUiControllerClient(
+ display=display,
+ reporter=self._reporter,
+ controller_server=AskUiControllerServer()
+ ),
+ )
+ self.model_router = (
+ ModelRouter(tools=self.tools, reporter=self._reporter) if model_router is None else model_router
+ )
+ self._model = model
+
+ @telemetry.record_call(exclude={"locator"})
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+ def click(
+ self,
+ locator: Optional[str | Locator] = None,
+ button: Literal['left', 'middle', 'right'] = 'left',
+ repeat: Annotated[int, Field(gt=0)] = 1,
+ model: ModelComposition | str | None = None,
+ ) -> None:
"""
- Simulates a mouse click on the user interface element identified by the provided instruction.
+ Simulates a mouse click on the user interface element identified by the provided locator.
Parameters:
- instruction (str | None): The identifier or description of the element to click.
- button ('left' | 'middle' | 'right'): Specifies which mouse button to click. Defaults to 'left'.
- repeat (int): The number of times to click. Must be greater than 0. Defaults to 1.
- model_name (str | None): The model name to be used for element detection. Optional.
+ locator (str | Locator | None):
+ The identifier or description of the element to click. If None, clicks at current position.
+ button ('left' | 'middle' | 'right'):
+ Specifies which mouse button to click. Defaults to 'left'.
+ repeat (int):
+ The number of times to click. Must be greater than 0. Defaults to 1.
+ model (ModelComposition | str | None):
+ The composition or name of the model(s) to be used for locating the element to click on using the `locator`.
Raises:
InvalidParameterError: If the 'repeat' parameter is less than 1.
Example:
- ```python
- with VisionAgent() as agent:
- agent.click() # Left click on current position
- agent.click("Edit") # Left click on text "Edit"
- agent.click("Edit", button="right") # Right click on text "Edit"
- agent.click(repeat=2) # Double left click on current position
- agent.click("Edit", button="middle", repeat=4) # 4x middle click on text "Edit"
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.click() # Left click on current position
+ agent.click("Edit") # Left click on text "Edit"
+ agent.click("Edit", button="right") # Right click on text "Edit"
+ agent.click(repeat=2) # Double left click on current position
+ agent.click("Edit", button="middle", repeat=4) # 4x middle click on text "Edit"
+ ```
"""
if repeat < 1:
raise InvalidParameterError("InvalidParameterError! The parameter 'repeat' needs to be greater than 0.")
- self._check_askui_controller_enabled()
- if self.report is not None:
- msg = 'click'
- if button != 'left':
- msg = f'{button} ' + msg
- if repeat > 1:
- msg += f' {repeat}x times'
- if instruction is not None:
- msg += f' on "{instruction}"'
- self.report.add_message("User", msg)
- if instruction is not None:
- logger.debug("VisionAgent received instruction to click '%s'", instruction)
- self.__mouse_move(instruction, model_name)
- self.client.click(button, repeat) # type: ignore
-
- def __mouse_move(self, instruction: str, model_name: Optional[str] = None) -> None:
- self._check_askui_controller_enabled()
- screenshot = self.client.screenshot() # type: ignore
- x, y = self.model_router.locate(screenshot, instruction, model_name)
- if self.report is not None:
- self.report.add_message("ModelRouter", f"locate: ({x}, {y})")
- self.client.mouse(x, y) # type: ignore
-
- @telemetry.record_call(exclude={"instruction"})
- def mouse_move(self, instruction: str, model_name: Optional[str] = None) -> None:
+ msg = 'click'
+ if button != 'left':
+ msg = f'{button} ' + msg
+ if repeat > 1:
+ msg += f' {repeat}x times'
+ if locator is not None:
+ msg += f' on {locator}'
+ self._reporter.add_message("User", msg)
+ if locator is not None:
+ logger.debug("VisionAgent received instruction to click on %s", locator)
+ self._mouse_move(locator, model or self._model)
+ self.tools.agent_os.click(button, repeat) # type: ignore
+
+ def _locate(self, locator: str | Locator, screenshot: Optional[Img] = None, model: ModelComposition | str | None = None) -> Point:
+ _screenshot = ImageSource(self.tools.agent_os.screenshot() if screenshot is None else screenshot)
+ point = self.model_router.locate(_screenshot.root, locator, model or self._model)
+ self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})")
+ return point
+
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+ def locate(
+ self,
+ locator: str | Locator,
+ screenshot: Optional[Img] = None,
+ model: ModelComposition | str | None = None,
+ ) -> Point:
"""
- Moves the mouse cursor to the UI element identified by the provided instruction.
+ Locates the UI element identified by the provided locator.
Parameters:
- instruction (str): The identifier or description of the element to move to.
- model_name (str | None): The model name to be used for element detection. Optional.
+ locator (str | Locator):
+ The identifier or description of the element to locate.
+ screenshot (Img | None, optional):
+ The screenshot to use for locating the element. Can be a path to an image file, a PIL Image object or a data URL.
+ If None, takes a screenshot of the currently selected display.
+ model (ModelComposition | str | None):
+ The composition or name of the model(s) to be used for locating the element using the `locator`.
+
+ Returns:
+ Point: The coordinates of the element as a tuple (x, y).
Example:
- ```python
- with VisionAgent() as agent:
- agent.mouse_move("Submit button") # Moves cursor to submit button
- agent.mouse_move("Close") # Moves cursor to close element
- agent.mouse_move("Profile picture", model_name="custom_model") # Uses specific model
- ```
+ ```python
+ with VisionAgent() as agent:
+ point = agent.locate("Submit button")
+ print(f"Element found at coordinates: {point}")
+ ```
"""
- if self.report is not None:
- self.report.add_message("User", f'mouse_move: "{instruction}"')
- logger.debug("VisionAgent received instruction to mouse_move '%s'", instruction)
- self.__mouse_move(instruction, model_name)
+ self._reporter.add_message("User", f"locate {locator}")
+ logger.debug("VisionAgent received instruction to locate %s", locator)
+ return self._locate(locator, screenshot, model or self._model)
+
+ def _mouse_move(self, locator: str | Locator, model: ModelComposition | str | None = None) -> None:
+ point = self._locate(locator=locator, model=model or self._model)
+ self.tools.agent_os.mouse(point[0], point[1]) # type: ignore
+
+ @telemetry.record_call(exclude={"locator"})
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+ def mouse_move(
+ self,
+ locator: str | Locator,
+ model: ModelComposition | str | None = None,
+ ) -> None:
+ """
+ Moves the mouse cursor to the UI element identified by the provided locator.
+
+ Parameters:
+ locator (str | Locator):
+ The identifier or description of the element to move to.
+ model (ModelComposition | str | None):
+ The composition or name of the model(s) to be used for locating the element to move the mouse to using the `locator`.
+
+ Example:
+ ```python
+ with VisionAgent() as agent:
+ agent.mouse_move("Submit button") # Moves cursor to submit button
+ agent.mouse_move("Close") # Moves cursor to close element
+ agent.mouse_move("Profile picture", model="custom_model") # Uses specific model
+ ```
+ """
+ self._reporter.add_message("User", f'mouse_move: {locator}')
+ logger.debug("VisionAgent received instruction to mouse_move to %s", locator)
+ self._mouse_move(locator, model or self._model)
@telemetry.record_call()
- def mouse_scroll(self, x: int, y: int) -> None:
+ @validate_call
+ def mouse_scroll(
+ self,
+ x: int,
+ y: int,
+ ) -> None:
"""
Simulates scrolling the mouse wheel by the specified horizontal and vertical amounts.
Parameters:
- x (int): The horizontal scroll amount. Positive values typically scroll right, negative values scroll left.
- y (int): The vertical scroll amount. Positive values typically scroll down, negative values scroll up.
+ x (int):
+ The horizontal scroll amount. Positive values typically scroll right, negative values scroll left.
+ y (int):
+ The vertical scroll amount. Positive values typically scroll down, negative values scroll up.
Note:
- The actual `scroll direction` depends on the operating system's configuration.
+ The actual scroll direction depends on the operating system's configuration.
Some systems may have "natural scrolling" enabled, which reverses the traditional direction.
- The meaning of scroll `units` varies` acro`ss oper`ating` systems and applications.
+ The meaning of scroll units varies across operating systems and applications.
A scroll value of 10 might result in different distances depending on the application and system settings.
Example:
- ```python
- with VisionAgent() as agent:
- agent.mouse_scroll(0, 10) # Usually scrolls down 10 units
- agent.mouse_scroll(0, -5) # Usually scrolls up 5 units
- agent.mouse_scroll(3, 0) # Usually scrolls right 3 units
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.mouse_scroll(0, 10) # Usually scrolls down 10 units
+ agent.mouse_scroll(0, -5) # Usually scrolls up 5 units
+ agent.mouse_scroll(3, 0) # Usually scrolls right 3 units
+ ```
"""
- self._check_askui_controller_enabled()
- if self.report is not None:
- self.report.add_message("User", f'mouse_scroll: "{x}", "{y}"')
- self.client.mouse_scroll(x, y)
+ self._reporter.add_message("User", f'mouse_scroll: "{x}", "{y}"')
+ self.tools.agent_os.mouse_scroll(x, y)
@telemetry.record_call(exclude={"text"})
- def type(self, text: str) -> None:
+ @validate_call
+ def type(
+ self,
+ text: Annotated[str, Field(min_length=1)],
+ ) -> None:
"""
Types the specified text as if it were entered on a keyboard.
Parameters:
- text (str): The text to be typed.
+ text (str):
+ The text to be typed. Must be at least 1 character long.
Example:
- ```python
- with VisionAgent() as agent:
- agent.type("Hello, world!") # Types "Hello, world!"
- agent.type("user@example.com") # Types an email address
- agent.type("password123") # Types a password
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.type("Hello, world!") # Types "Hello, world!"
+ agent.type("user@example.com") # Types an email address
+ agent.type("password123") # Types a password
+ ```
"""
- self._check_askui_controller_enabled()
- if self.report is not None:
- self.report.add_message("User", f'type: "{text}"')
+ self._reporter.add_message("User", f'type: "{text}"')
logger.debug("VisionAgent received instruction to type '%s'", text)
- self.client.type(text) # type: ignore
+ self.tools.agent_os.type(text) # type: ignore
- @telemetry.record_call(exclude={"instruction", "screenshot"})
- def get(self, instruction: str, model_name: Optional[str] = None, screenshot: Optional[Image.Image] = None) -> str:
+
+ @overload
+ def get(
+ self,
+ query: Annotated[str, Field(min_length=1)],
+ response_schema: None = None,
+ image: Optional[Img] = None,
+ model: ModelComposition | str | None = None,
+ ) -> str: ...
+ @overload
+ def get(
+ self,
+ query: Annotated[str, Field(min_length=1)],
+ response_schema: Type[ResponseSchema],
+ image: Optional[Img] = None,
+ model: ModelComposition | str | None = None,
+ ) -> ResponseSchema: ...
+
+ @telemetry.record_call(exclude={"query", "image", "response_schema"})
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+ def get(
+ self,
+ query: Annotated[str, Field(min_length=1)],
+ image: Optional[Img] = None,
+ response_schema: Type[ResponseSchema] | None = None,
+ model: ModelComposition | str | None = None,
+ ) -> ResponseSchema | str:
"""
- Retrieves text or information from the screen based on the provided instruction.
+ Retrieves information from an image (defaults to a screenshot of the current screen) based on the provided query.
Parameters:
- instruction (str): The instruction describing what information to retrieve.
- model_name (str | None): The model name to be used for information extraction. Optional.
+ query (str):
+ The query describing what information to retrieve.
+ image (Img | None, optional):
+ The image to extract information from. Defaults to a screenshot of the current screen.
+ Can be a path to an image file, a PIL Image object or a data URL.
+ response_schema (Type[ResponseSchema] | None, optional):
+ A Pydantic model class that defines the response schema. If not provided, returns a string.
+ model (ModelComposition | str | None, optional):
+ The composition or name of the model(s) to be used for retrieving information from the screen or image using the `query`.
+ Note: `response_schema` is not supported by all models.
Returns:
- str: The extracted text or information.
+ ResponseSchema | str:
+ The extracted information, either as an instance of ResponseSchema or string if no response_schema is provided.
+
+ Limitations:
+            - Nested Pydantic schemas are not currently supported.
+            - Schema support is only available with the "askui" model (default model if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` are set) at the moment.
Example:
- ```python
- with VisionAgent() as agent:
- price = agent.get("What is the price displayed?")
- username = agent.get("What is the username shown in the profile?")
- error_message = agent.get("What does the error message say?")
- ```
+ ```python
+ from askui import JsonSchemaBase
+ from PIL import Image
+
+ class UrlResponse(JsonSchemaBase):
+ url: str
+
+ with VisionAgent() as agent:
+ # Get URL as string
+ url = agent.get("What is the current url shown in the url bar?")
+
+ # Get URL as Pydantic model from image at (relative) path
+ response = agent.get(
+ "What is the current url shown in the url bar?",
+ response_schema=UrlResponse,
+ image="screenshot.png",
+ )
+ print(response.url)
+
+ # Get boolean response from PIL Image
+ is_login_page = agent.get(
+ "Is this a login page?",
+ response_schema=bool,
+ image=Image.open("screenshot.png"),
+ )
+
+ # Get integer response
+ input_count = agent.get(
+ "How many input fields are visible on this page?",
+ response_schema=int,
+ )
+
+ # Get float response
+ design_rating = agent.get(
+ "Rate the page design quality from 0 to 1",
+ response_schema=float,
+ )
+ ```
"""
- self._check_askui_controller_enabled()
- if self.report is not None:
- self.report.add_message("User", f'get: "{instruction}"')
- logger.debug("VisionAgent received instruction to get '%s'", instruction)
- if screenshot is None:
- screenshot = self.client.screenshot() # type: ignore
- response = self.model_router.get_inference(screenshot, instruction, model_name)
- if self.report is not None:
- self.report.add_message("Agent", response)
+ logger.debug("VisionAgent received instruction to get '%s'", query)
+ _image = ImageSource(self.tools.agent_os.screenshot() if image is None else image) # type: ignore
+ self._reporter.add_message("User", f'get: "{query}"', image=_image.root)
+ response = self.model_router.get_inference(
+ image=_image,
+ query=query,
+ model=model or self._model,
+ response_schema=response_schema,
+ )
+ if self._reporter is not None:
+ message_content = str(response) if isinstance(response, (str, bool, int, float)) else response.model_dump()
+ self._reporter.add_message("Agent", message_content)
return response
@telemetry.record_call()
@validate_call
- def wait(self, sec: Annotated[float, Field(gt=0)]) -> None:
+ def wait(
+ self,
+ sec: Annotated[float, Field(gt=0.0)],
+ ) -> None:
"""
Pauses the execution of the program for the specified number of seconds.
Parameters:
- sec (float): The number of seconds to wait. Must be greater than 0.
+ sec (float):
+ The number of seconds to wait. Must be greater than 0.0.
Raises:
ValueError: If the provided `sec` is negative.
Example:
- ```python
- with VisionAgent() as agent:
- agent.wait(5) # Pauses execution for 5 seconds
- agent.wait(0.5) # Pauses execution for 500 milliseconds
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.wait(5) # Pauses execution for 5 seconds
+ agent.wait(0.5) # Pauses execution for 500 milliseconds
+ ```
"""
time.sleep(sec)
@telemetry.record_call()
- def key_up(self, key: PC_AND_MODIFIER_KEY) -> None:
+ @validate_call
+ def key_up(
+ self,
+ key: PcKey | ModifierKey,
+ ) -> None:
"""
Simulates the release of a key.
Parameters:
- key (PC_AND_MODIFIER_KEY): The key to be released.
+ key (PcKey | ModifierKey):
+ The key to be released.
Example:
- ```python
- with VisionAgent() as agent:
- agent.key_up('a') # Release the 'a' key
- agent.key_up('shift') # Release the 'Shift' key
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.key_up('a') # Release the 'a' key
+ agent.key_up('shift') # Release the 'Shift' key
+ ```
"""
- self._check_askui_controller_enabled()
- if self.report is not None:
- self.report.add_message("User", f'key_up "{key}"')
+ self._reporter.add_message("User", f'key_up "{key}"')
logger.debug("VisionAgent received in key_up '%s'", key)
- self.client.keyboard_release(key)
+ self.tools.agent_os.keyboard_release(key)
@telemetry.record_call()
- def key_down(self, key: PC_AND_MODIFIER_KEY) -> None:
+ @validate_call
+ def key_down(
+ self,
+ key: PcKey | ModifierKey,
+ ) -> None:
"""
Simulates the pressing of a key.
Parameters:
- key (PC_AND_MODIFIER_KEY): The key to be pressed.
+ key (PcKey | ModifierKey):
+ The key to be pressed.
Example:
- ```python
- with VisionAgent() as agent:
- agent.key_down('a') # Press the 'a' key
- agent.key_down('shift') # Press the 'Shift' key
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.key_down('a') # Press the 'a' key
+ agent.key_down('shift') # Press the 'Shift' key
+ ```
"""
- self._check_askui_controller_enabled()
- if self.report is not None:
- self.report.add_message("User", f'key_down "{key}"')
+ self._reporter.add_message("User", f'key_down "{key}"')
logger.debug("VisionAgent received in key_down '%s'", key)
- self.client.keyboard_pressed(key)
+ self.tools.agent_os.keyboard_pressed(key)
@telemetry.record_call(exclude={"goal"})
- def act(self, goal: str, model_name: Optional[str] = None) -> None:
+ @validate_call
+ def act(
+ self,
+ goal: Annotated[str, Field(min_length=1)],
+ model: ModelComposition | str | None = None,
+ ) -> None:
"""
Instructs the agent to achieve a specified goal through autonomous actions.
@@ -285,54 +455,59 @@ def act(self, goal: str, model_name: Optional[str] = None) -> None:
interface interactions.
Parameters:
- goal (str): A description of what the agent should achieve.
- model_name (str | None): The specific model to use for vision analysis.
- If None, uses the default model.
+ goal (str):
+ A description of what the agent should achieve.
+ model (ModelComposition | str | None, optional):
+ The composition or name of the model(s) to be used for achieving the `goal`.
Example:
- ```python
- with VisionAgent() as agent:
- agent.act("Open the settings menu")
- agent.act("Search for 'printer' in the search box")
- agent.act("Log in with username 'admin' and password '1234'")
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.act("Open the settings menu")
+ agent.act("Search for 'printer' in the search box")
+ agent.act("Log in with username 'admin' and password '1234'")
+ ```
"""
- self._check_askui_controller_enabled()
- if self.report is not None:
- self.report.add_message("User", f'act: "{goal}"')
+ self._reporter.add_message("User", f'act: "{goal}"')
logger.debug(
"VisionAgent received instruction to act towards the goal '%s'", goal
)
- self.model_router.act(self.client, goal, model_name)
+ self.model_router.act(goal, model or self._model)
@telemetry.record_call()
+ @validate_call
def keyboard(
- self, key: PC_AND_MODIFIER_KEY, modifier_keys: list[MODIFIER_KEY] | None = None
+ self,
+ key: PcKey | ModifierKey,
+ modifier_keys: Optional[list[ModifierKey]] = None,
) -> None:
"""
Simulates pressing a key or key combination on the keyboard.
Parameters:
- key (PC_AND_MODIFIER_KEY): The main key to press. This can be a letter, number,
- special character, or function key.
- modifier_keys (list[MODIFIER_KEY] | None): Optional list of modifier keys to press
- along with the main key. Common modifier keys include 'ctrl', 'alt', 'shift'.
+ key (PcKey | ModifierKey):
+ The main key to press. This can be a letter, number, special character, or function key.
+ modifier_keys (list[ModifierKey] | None, optional):
+ List of modifier keys to press along with the main key. Common modifier keys include 'ctrl', 'alt', 'shift'.
Example:
- ```python
- with VisionAgent() as agent:
- agent.keyboard('a') # Press 'a' key
- agent.keyboard('enter') # Press 'Enter' key
- agent.keyboard('v', ['control']) # Press Ctrl+V (paste)
- agent.keyboard('s', ['control', 'shift']) # Press Ctrl+Shift+S
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.keyboard('a') # Press 'a' key
+ agent.keyboard('enter') # Press 'Enter' key
+ agent.keyboard('v', ['control']) # Press Ctrl+V (paste)
+ agent.keyboard('s', ['control', 'shift']) # Press Ctrl+Shift+S
+ ```
"""
- self._check_askui_controller_enabled()
logger.debug("VisionAgent received instruction to press '%s'", key)
- self.client.keyboard_tap(key, modifier_keys) # type: ignore
+ self.tools.agent_os.keyboard_tap(key, modifier_keys) # type: ignore
@telemetry.record_call(exclude={"command"})
- def cli(self, command: str) -> None:
+ @validate_call
+ def cli(
+ self,
+ command: Annotated[str, Field(min_length=1)],
+ ) -> None:
"""
Executes a command on the command line interface.
@@ -340,32 +515,39 @@ def cli(self, command: str) -> None:
is split on spaces and executed as a subprocess.
Parameters:
- command (str): The command to execute on the command line.
+ command (str):
+ The command to execute on the command line.
Example:
- ```python
- with VisionAgent() as agent:
- agent.cli("echo Hello World") # Prints "Hello World"
- agent.cli("ls -la") # Lists files in current directory with details
- agent.cli("python --version") # Displays Python version
- ```
+ ```python
+ with VisionAgent() as agent:
+ agent.cli("echo Hello World") # Prints "Hello World"
+ agent.cli("ls -la") # Lists files in current directory with details
+ agent.cli("python --version") # Displays Python version
+ ```
"""
logger.debug("VisionAgent received instruction to execute '%s' on cli", command)
subprocess.run(command.split(" "))
@telemetry.record_call(flush=True)
def close(self) -> None:
- if self.client:
- self.client.disconnect()
- if self.controller:
- self.controller.stop(True)
+ self.tools.agent_os.disconnect()
+ self._reporter.generate()
+
+ @telemetry.record_call()
+ def open(self) -> None:
+ self.tools.agent_os.connect()
@telemetry.record_call()
def __enter__(self) -> "VisionAgent":
+ self.open()
return self
@telemetry.record_call(exclude={"exc_value", "traceback"})
- def __exit__(self, exc_type, exc_value, traceback) -> None:
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[Any],
+ ) -> None:
self.close()
- if self.report is not None:
- self.report.generate_report()
diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py
index f212521e..97cf18b2 100644
--- a/src/askui/chat/__main__.py
+++ b/src/askui/chat/__main__.py
@@ -1,17 +1,22 @@
from random import randint
from PIL import Image, ImageDraw
-from typing import Any, Callable, Literal
+from typing import Union
+from typing_extensions import override, TypedDict
import streamlit as st
from askui import VisionAgent
import logging
from askui.chat.click_recorder import ClickRecorder
-from askui.utils import base64_to_image, draw_point_on_image
+from askui.models import ModelName
+from askui.reporting import Reporter
+from askui.utils.image_utils import base64_to_image
import json
-from datetime import date, datetime
+from datetime import datetime
import os
import glob
import re
+from askui.utils.image_utils import draw_point_on_image
+
st.set_page_config(
page_title="Vision Agent Chat",
@@ -25,14 +30,6 @@
click_recorder = ClickRecorder()
-def json_serial(obj):
- """JSON serializer for objects not serializable by default json code"""
-
- if isinstance(obj, (datetime, date)):
- return obj.isoformat()
- raise TypeError("Type %s not serializable" % type(obj))
-
-
def setup_chat_dirs():
os.makedirs(CHAT_SESSIONS_DIR_PATH, exist_ok=True)
os.makedirs(CHAT_IMAGES_DIR_PATH, exist_ok=True)
@@ -70,19 +67,24 @@ def get_image(img_b64_str_or_path: str) -> Image.Image:
def write_message(
- role: Literal["User", "Anthropic Computer Use", "AgentOS", "User (Demonstration)"],
- content: str,
+ role: str,
+ content: str | dict | list,
timestamp: str,
- image: Image.Image |str | None = None,
+ image: Image.Image | str | list[str | Image.Image] | list[str] | list[Image.Image] | None = None,
):
_role = ROLE_MAP.get(role.lower(), UNKNOWN_ROLE)
avatar = None if _role != UNKNOWN_ROLE else "❔"
with st.chat_message(_role, avatar=avatar):
st.markdown(f"*{timestamp}* - **{role}**\n\n")
- st.markdown(content)
+ st.markdown(json.dumps(content, indent=2) if isinstance(content, (dict, list)) else content)
if image:
- img = get_image(image) if isinstance(image, str) else image
- st.image(img)
+ if isinstance(image, list):
+ for img in image:
+ img = get_image(img) if isinstance(img, str) else img
+ st.image(img)
+ else:
+ img = get_image(image) if isinstance(image, str) else image
+ st.image(img)
def save_image(image: Image.Image) -> str:
@@ -92,31 +94,44 @@ def save_image(image: Image.Image) -> str:
return image_path
-def chat_history_appender(session_id: str) -> Callable[[str | dict[str, Any]], None]:
- def append_to_chat_history(report: str | dict) -> None:
- if isinstance(report, dict):
- if report.get("image"):
- if not os.path.isfile(report["image"]):
- report["image"] = save_image(base64_to_image(report["image"]))
+class Message(TypedDict):
+ role: str
+ content: str | dict | list
+ timestamp: str
+ image: str | list[str] | None
+
+
+class ChatHistoryAppender(Reporter):
+ def __init__(self, session_id: str) -> None:
+ self._session_id = session_id
+
+ @override
+ def add_message(self, role: str, content: Union[str, dict, list], image: Image.Image | list[Image.Image] | None = None) -> None:
+ image_paths: list[str] = []
+ if image is None:
+ _images = []
+ elif isinstance(image, list):
+ _images = image
else:
- report = {
- "role": "unknown",
- "content": f"🔄 {report}",
- "timestamp": datetime.now().isoformat(),
- }
- write_message(
- report["role"],
- report["content"],
- report["timestamp"],
- report.get("image"),
+ _images = [image]
+ for img in _images:
+ image_paths.append(save_image(img))
+ message = Message(
+ role=role,
+ content=content,
+ timestamp=datetime.now().isoformat(),
+ image=image_paths,
)
+ write_message(**message)
with open(
- os.path.join(CHAT_SESSIONS_DIR_PATH, f"{session_id}.jsonl"), "a"
+ os.path.join(CHAT_SESSIONS_DIR_PATH, f"{self._session_id}.jsonl"), "a"
) as f:
- json.dump(report, f, default=json_serial)
+ json.dump(message, f)
f.write("\n")
- return append_to_chat_history
+ @override
+ def generate(self) -> None:
+ pass
def get_available_sessions():
@@ -200,9 +215,9 @@ def rerun():
screenshot, (x, y)
)
element_description = agent.get(
- prompt,
- screenshot=screenshot_with_crosshair,
- model_name="anthropic-claude-3-5-sonnet-20241022",
+ query=prompt,
+ image=screenshot_with_crosshair,
+ model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022,
)
write_message(
message["role"],
@@ -211,8 +226,8 @@ def rerun():
image=screenshot_with_crosshair,
)
agent.mouse_move(
- instruction=element_description.replace('"', ""),
- model_name="anthropic-claude-3-5-sonnet-20241022",
+ locator=element_description.replace('"', ""),
+ model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022,
)
else:
write_message(
@@ -255,7 +270,7 @@ def rerun():
st.session_state.session_id = session_id
st.rerun()
-report_callback = chat_history_appender(session_id)
+reporter = ChatHistoryAppender(session_id)
st.title(f"Vision Agent Chat - {session_id}")
st.session_state.messages = load_chat_history(session_id)
@@ -270,26 +285,16 @@ def rerun():
)
if value_to_type := st.chat_input("Simulate Typing for User (Demonstration)"):
- report_callback(
- {
- "role": "User (Demonstration)",
- "content": f'type("{value_to_type}", 50)',
- "timestamp": datetime.now().isoformat(),
- "is_json": False,
- "image": None,
- }
+ reporter.add_message(
+ role="User (Demonstration)",
+ content=f'type("{value_to_type}", 50)',
)
st.rerun()
if st.button("Simulate left click"):
- report_callback(
- {
- "role": "User (Demonstration)",
- "content": 'click("left", 1)',
- "timestamp": datetime.now().isoformat(),
- "is_json": False,
- "image": None,
- }
+ reporter.add_message(
+ role="User (Demonstration)",
+ content='click("left", 1)',
)
st.rerun()
@@ -298,35 +303,24 @@ def rerun():
"Demonstrate where to move mouse"
): # only single step, only click supported for now, independent of click always registered as click
image, coordinates = click_recorder.record()
- report_callback(
- {
- "role": "User (Demonstration)",
- "content": "screenshot()",
- "timestamp": datetime.now().isoformat(),
- "is_json": False,
- "image": save_image(image),
- }
+ reporter.add_message(
+ role="User (Demonstration)",
+ content="screenshot()",
+ image=image,
)
- report_callback(
- {
- "role": "User (Demonstration)",
- "content": f"mouse({coordinates[0]}, {coordinates[1]})",
- "timestamp": datetime.now().isoformat(),
- "is_json": False,
- "image": save_image(
- draw_point_on_image(image, coordinates[0], coordinates[1])
- ),
- }
+ reporter.add_message(
+ role="User (Demonstration)",
+ content=f"mouse({coordinates[0]}, {coordinates[1]})",
+ image=draw_point_on_image(image, coordinates[0], coordinates[1]),
)
st.rerun()
if act_prompt := st.chat_input("Ask AI"):
with VisionAgent(
log_level=logging.DEBUG,
- enable_report=True,
- report_callback=report_callback,
+ reporters=[reporter],
) as agent:
- agent.act(act_prompt, model_name="claude")
+ agent.act(act_prompt, model="claude")
st.rerun()
if st.button("Rerun"):
diff --git a/src/askui/exceptions.py b/src/askui/exceptions.py
new file mode 100644
index 00000000..467882da
--- /dev/null
+++ b/src/askui/exceptions.py
@@ -0,0 +1,8 @@
+class AutomationError(Exception):
+ """Exception raised when the automation step cannot complete."""
+ pass
+
+
+class ElementNotFoundError(AutomationError):
+ """Exception raised when an element cannot be located."""
+ pass
diff --git a/src/askui/locators/__init__.py b/src/askui/locators/__init__.py
new file mode 100644
index 00000000..23964220
--- /dev/null
+++ b/src/askui/locators/__init__.py
@@ -0,0 +1,9 @@
+from askui.locators.locators import AiElement, Element, Prompt, Image, Text
+
+__all__ = [
+ "AiElement",
+ "Element",
+ "Prompt",
+ "Image",
+ "Text",
+]
diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py
new file mode 100644
index 00000000..24bc569a
--- /dev/null
+++ b/src/askui/locators/locators.py
@@ -0,0 +1,418 @@
+from abc import ABC
+import pathlib
+from typing import Annotated, Literal, Union
+import uuid
+
+from PIL import Image as PILImage
+from pydantic import ConfigDict, Field, validate_call
+
+from askui.utils.image_utils import ImageSource
+from askui.locators.relatable import Relatable
+
+
+class Locator(Relatable, ABC):
+ """Base class for all locators."""
+
+ def _str(self) -> str:
+ return "locator"
+
+ pass
+
+
+class Prompt(Locator):
+ """Locator for finding ui elements by a textual prompt / description of a ui element, e.g., "green sign up button"."""
+
+ @validate_call
+ def __init__(
+ self,
+ prompt: Annotated[
+ str,
+ Field(
+ description="""A textual prompt / description of a ui element, e.g., "green sign up button"."""
+ ),
+ ],
+ ) -> None:
+ """Initialize a Prompt locator.
+
+ Args:
+ prompt: A textual prompt / description of a ui element, e.g., "green sign up button"
+ """
+ super().__init__()
+ self._prompt = prompt
+
+ @property
+ def prompt(self) -> str:
+ return self._prompt
+
+ def _str(self) -> str:
+ return f'element with prompt "{self.prompt}"'
+
+
+class Element(Locator):
+ """Locator for finding ui elements by a class name assigned to the ui element, e.g., by a computer vision model."""
+
+ @validate_call
+ def __init__(
+ self,
+ class_name: Annotated[
+ Literal["text", "textfield"] | None,
+ Field(
+ description="""The class name of the ui element, e.g., 'text' or 'textfield'."""
+ ),
+ ] = None,
+ ) -> None:
+ """Initialize an Element locator.
+
+ Args:
+ class_name: The class name of the ui element, e.g., 'text' or 'textfield'
+ """
+ super().__init__()
+ self._class_name = class_name
+
+ @property
+ def class_name(self) -> Literal["text", "textfield"] | None:
+ return self._class_name
+
+ def _str(self) -> str:
+ return (
+ f'element with class "{self.class_name}"' if self.class_name else "element"
+ )
+
+
+TextMatchType = Literal["similar", "exact", "contains", "regex"]
+DEFAULT_TEXT_MATCH_TYPE: TextMatchType = "similar"
+DEFAULT_SIMILARITY_THRESHOLD = 70
+
+
+class Text(Element):
+ """Locator for finding text elements by their content."""
+
+ @validate_call
+ def __init__(
+ self,
+ text: Annotated[
+ str | None,
+ Field(
+ description="""The text content of the ui element, e.g., 'Sign up'."""
+ ),
+ ] = None,
+ match_type: Annotated[
+ TextMatchType,
+ Field(
+ description="""The type of match to use. Defaults to 'similar'.
+ 'similar' uses a similarity threshold to determine if the text is a match.
+ 'exact' requires the text to be exactly the same.
+ 'contains' requires the text to contain the specified text.
+ 'regex' uses a regular expression to match the text."""
+ ),
+ ] = DEFAULT_TEXT_MATCH_TYPE,
+ similarity_threshold: Annotated[
+ int,
+ Field(
+ ge=0,
+ le=100,
+ description="""A threshold for how similar the text
+ needs to be to the text content of the ui element to be considered a match.
+ Takes values between 0 and 100 (higher is more similar). Defaults to 70.
+ Only used if match_type is 'similar'.""",
+ ),
+ ] = DEFAULT_SIMILARITY_THRESHOLD,
+ ) -> None:
+ """Initialize a Text locator.
+
+ Args:
+ text: The text content of the ui element, e.g., 'Sign up'
+ match_type: The type of match to use. Defaults to 'similar'. 'similar' uses a similarity threshold to
+ determine if the text is a match. 'exact' requires the text to be exactly the same. 'contains'
+ requires the text to contain the specified text. 'regex' uses a regular expression to match the text.
+ similarity_threshold: A threshold for how similar the text needs to be to the text content of the ui
+ element to be considered a match. Takes values between 0 and 100 (higher is more similar).
+ Defaults to 70. Only used if match_type is 'similar'.
+ """
+ super().__init__()
+ self._text = text
+ self._match_type = match_type
+ self._similarity_threshold = similarity_threshold
+
+ @property
+ def text(self) -> str | None:
+ return self._text
+
+ @property
+ def match_type(self) -> TextMatchType:
+ return self._match_type
+
+ @property
+ def similarity_threshold(self) -> int:
+ return self._similarity_threshold
+
+ def _str(self) -> str:
+ if self.text is None:
+ result = "text"
+ else:
+ result = "text "
+ match self.match_type:
+ case "similar":
+ result += f'similar to "{self.text}" (similarity >= {self.similarity_threshold}%)'
+ case "exact":
+ result += f'"{self.text}"'
+ case "contains":
+ result += f'containing text "{self.text}"'
+ case "regex":
+ result += f'matching regex "{self.text}"'
+ return result
+
+
+class ImageBase(Locator, ABC):
+ def __init__(
+ self,
+ threshold: float,
+ stop_threshold: float,
+ mask: list[tuple[float, float]] | None,
+ rotation_degree_per_step: int,
+ name: str,
+ image_compare_format: Literal["RGB", "grayscale", "edges"],
+ ) -> None:
+ super().__init__()
+ if threshold > stop_threshold:
+ raise ValueError(
+ f"threshold ({threshold}) must be less than or equal to stop_threshold ({stop_threshold})"
+ )
+ self._threshold = threshold
+ self._stop_threshold = stop_threshold
+ self._mask = mask
+ self._rotation_degree_per_step = rotation_degree_per_step
+ self._name = name
+ self._image_compare_format = image_compare_format
+
+ @property
+ def threshold(self) -> float:
+ return self._threshold
+
+ @property
+ def stop_threshold(self) -> float:
+ return self._stop_threshold
+
+ @property
+ def mask(self) -> list[tuple[float, float]] | None:
+ return self._mask
+
+ @property
+ def rotation_degree_per_step(self) -> int:
+ return self._rotation_degree_per_step
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def image_compare_format(self) -> Literal["RGB", "grayscale", "edges"]:
+ return self._image_compare_format
+
+ def _params_str(self) -> str:
+ return (
+ "("
+ + ", ".join([
+ f"threshold: {self.threshold}",
+ f"stop_threshold: {self.stop_threshold}",
+ f"rotation_degree_per_step: {self.rotation_degree_per_step}",
+ f"image_compare_format: {self.image_compare_format}",
+ f"mask: {self.mask}"
+ ])
+ + ")"
+ )
+
+ def _str(self) -> str:
+ return (
+ f'element "{self.name}" located by image '
+ + self._params_str()
+ )
+
+
+def _generate_name() -> str:
+ return f"anonymous image {uuid.uuid4()}"
+
+
+class Image(ImageBase):
+ """Locator for finding ui elements by an image."""
+
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+ def __init__(
+ self,
+ image: Union[PILImage.Image, pathlib.Path, str],
+ threshold: Annotated[
+ float,
+ Field(
+ ge=0,
+ le=1,
+ description="""A threshold for how similar UI elements need to be to the image to be considered a match.
+ Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly
+ like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality.""",
+ ),
+ ] = 0.5,
+ stop_threshold: Annotated[
+ float | None,
+ Field(
+ ge=0,
+ le=1,
+ description="""A threshold for when to stop searching for UI elements similar to the image. As soon
+ as UI elements have been found that are at least as similar as the stop_threshold, the search stops. Should
+ be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to value of `threshold` if
+ not provided. Important: The stop_threshold impacts the prediction speed.""",
+ ),
+ ] = None,
+ mask: Annotated[
+ list[tuple[float, float]] | None,
+ Field(
+ min_length=3,
+ description="A polygon to match only a certain area of the image.",
+ ),
+ ] = None,
+ rotation_degree_per_step: Annotated[
+ int,
+ Field(
+ ge=0,
+ lt=360,
+ description="""A step size in rotation degree. Rotates the image by rotation_degree_per_step until
+ 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the prediction time
+ quite a bit. So only use it when absolutely necessary.""",
+ ),
+ ] = 0,
+ name: str | None = None,
+ image_compare_format: Annotated[
+ Literal["RGB", "grayscale", "edges"],
+ Field(
+ description="""A color compare style. Defaults to 'grayscale'.
+ Important: The image_compare_format impacts the prediction time as well as quality. As a rule of thumb,
+ 'edges' is likely to be faster than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For
+ quality it is most often the other way around."""
+ ),
+ ] = "grayscale",
+ ) -> None:
+ """Initialize an Image locator.
+
+ Args:
+ image: The image to match against (PIL Image, path, or string)
+ threshold: A threshold for how similar UI elements need to be to the image to be considered a match.
+ Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly
+ like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality.
+ stop_threshold: A threshold for when to stop searching for UI elements similar to the image. As soon
+ as UI elements have been found that are at least as similar as the stop_threshold, the search stops.
+ Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to value of
+ `threshold` if not provided. Important: The stop_threshold impacts the prediction speed.
+ mask: A polygon to match only a certain area of the image. Must have at least 3 points.
+ rotation_degree_per_step: A step size in rotation degree. Rotates the image by rotation_degree_per_step
+ until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the
+ prediction time quite a bit. So only use it when absolutely necessary.
+ name: Optional name for the image. Defaults to generated UUID.
+ image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format
+ impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster
+ than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often
+ the other way around.
+ """
+ super().__init__(
+ threshold=threshold,
+ stop_threshold=stop_threshold or threshold,
+ mask=mask,
+ rotation_degree_per_step=rotation_degree_per_step,
+ image_compare_format=image_compare_format,
+ name=_generate_name() if name is None else name,
+ ) # type: ignore
+ self._image = ImageSource(image)
+
+ @property
+ def image(self) -> ImageSource:
+ return self._image
+
+
+class AiElement(ImageBase):
+ """Locator for finding ui elements by an image and other kinds data saved on the disk."""
+
+ @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
+ def __init__(
+ self,
+ name: str,
+ threshold: Annotated[
+ float,
+ Field(
+ ge=0,
+ le=1,
+ description="""A threshold for how similar UI elements need to be to be considered a match.
+ Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match).
+ Defaults to 0.5. Important: The threshold impacts the prediction quality.""",
+ ),
+ ] = 0.5,
+ stop_threshold: Annotated[
+ float | None,
+ Field(
+ ge=0,
+ le=1,
+ description="""A threshold for when to stop searching for UI elements. As soon
+ as UI elements have been found that are at least as similar as the stop_threshold, the search stops.
+ Should be greater than or equal to threshold. Takes values between 0.0 and 1.0.
+ Defaults to value of `threshold` if not provided.
+ Important: The stop_threshold impacts the prediction speed.""",
+ ),
+ ] = None,
+ mask: Annotated[
+ list[tuple[float, float]] | None,
+ Field(
+ min_length=3,
+ description="A polygon to match only a certain area of the image of the element saved on disk.",
+ ),
+ ] = None,
+ rotation_degree_per_step: Annotated[
+ int,
+ Field(
+ ge=0,
+ lt=360,
+ description="""A step size in rotation degree. Rotates the image of the element saved on disk by
+ rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°.
+ Important: This increases the prediction time quite a bit. So only use it when absolutely necessary.""",
+ ),
+ ] = 0,
+ image_compare_format: Annotated[
+ Literal["RGB", "grayscale", "edges"],
+ Field(
+ description="""A color compare style. Defaults to 'grayscale'.
+ Important: The image_compare_format impacts the prediction time as well as quality. As a rule of thumb,
+ 'edges' is likely to be faster than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For
+ quality it is most often the other way around."""
+ ),
+ ] = "grayscale",
+ ) -> None:
+ """Initialize an AiElement locator.
+
+ Args:
+ name: Name of the AI element
+ threshold: A threshold for how similar UI elements need to be to be considered a match. Takes values
+ between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match).
+ Defaults to 0.5. Important: The threshold impacts the prediction quality.
+ stop_threshold: A threshold for when to stop searching for UI elements. As soon as UI elements have
+ been found that are at least as similar as the stop_threshold, the search stops. Should be greater
+ than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to value of `threshold` if not
+ provided. Important: The stop_threshold impacts the prediction speed.
+ mask: A polygon to match only a certain area of the image of the element saved on disk. Must have at
+ least 3 points.
+ rotation_degree_per_step: A step size in rotation degree. Rotates the image of the element saved on
+ disk by rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°.
+ Important: This increases the prediction time quite a bit. So only use it when absolutely necessary.
+ image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format
+ impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster
+ than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often
+ the other way around.
+ """
+ super().__init__(
+ name=name,
+ threshold=threshold,
+ stop_threshold=stop_threshold or threshold,
+ mask=mask,
+ rotation_degree_per_step=rotation_degree_per_step,
+ image_compare_format=image_compare_format,
+ ) # type: ignore
+
+ def _str(self) -> str:
+ return (
+ f'ai element named "{self.name}" '
+ + self._params_str()
+ )
diff --git a/src/askui/locators/relatable.py b/src/askui/locators/relatable.py
new file mode 100644
index 00000000..1cb4df19
--- /dev/null
+++ b/src/askui/locators/relatable.py
@@ -0,0 +1,912 @@
+from abc import ABC
+from typing import Annotated, Literal
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Self
+
+
+ReferencePoint = Literal["center", "boundary", "any"]
+
+
+RelationTypeMapping = {
+ "above_of": "above of",
+ "below_of": "below of",
+ "right_of": "right of",
+ "left_of": "left of",
+ "and": "and",
+ "or": "or",
+ "containing": "containing",
+ "inside_of": "inside of",
+ "nearest_to": "nearest to",
+}
+
+
+RelationIndex = Annotated[int, Field(ge=0)]
+
+
+class RelationBase(BaseModel):
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+ other_locator: "Relatable"
+ type: Literal[
+ "above_of",
+ "below_of",
+ "right_of",
+ "left_of",
+ "and",
+ "or",
+ "containing",
+ "inside_of",
+ "nearest_to",
+ ]
+
+ def __str__(self):
+ return f"{RelationTypeMapping[self.type]} {self.other_locator._str_with_relation()}"
+
+
+class NeighborRelation(RelationBase):
+ type: Literal["above_of", "below_of", "right_of", "left_of"]
+ index: RelationIndex
+ reference_point: ReferencePoint
+
+ def __str__(self):
+ i = self.index + 1
+ if i == 11 or i == 12 or i == 13:
+ index_str = f"{i}th"
+ else:
+ index_str = (
+ f"{i}st"
+ if i % 10 == 1
+ else f"{i}nd" if i % 10 == 2 else f"{i}rd" if i % 10 == 3 else f"{i}th"
+ )
+ reference_point_str = (
+ " center of"
+ if self.reference_point == "center"
+ else " boundary of" if self.reference_point == "boundary" else ""
+ )
+ return f"{RelationTypeMapping[self.type]}{reference_point_str} the {index_str} {self.other_locator._str_with_relation()}"
+
+
+class LogicalRelation(RelationBase):
+ type: Literal["and", "or"]
+
+
+class BoundingRelation(RelationBase):
+ type: Literal["containing", "inside_of"]
+
+
+class NearestToRelation(RelationBase):
+ type: Literal["nearest_to"]
+
+
+Relation = NeighborRelation | LogicalRelation | BoundingRelation | NearestToRelation
+
+
+class CircularDependencyError(ValueError):
+ """Exception raised for circular dependencies in locator relations."""
+
+ def __init__(
+ self,
+ message: str = (
+ "Detected circular dependency in locator relations. "
+ "This occurs when locators reference each other in a way that creates an infinite loop "
+ "(e.g., A is above B and B is above A)."
+ ),
+ ) -> None:
+ super().__init__(message)
+
+
+class Relatable(ABC):
+ """Base class for locators that can be related to other locators, e.g., spatially, logically, distance based etc.
+
+ Attributes:
+ relations: List of relations to other locators
+ """
+
+ def __init__(self) -> None:
+ self._relations: list[Relation] = []
+
+ @property
+ def relations(self) -> list[Relation]:
+ return self._relations
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation
+ def above_of(
+ self,
+ other_locator: "Relatable",
+ index: RelationIndex = 0,
+ reference_point: ReferencePoint = "boundary",
+ ) -> Self:
+ """Defines the element (located by *self*) to be **above** another element /
+ other elements (located by *other_locator*).
+
+ An element **A** is considered to be *above* another element / other elements **B**
+
+ - if most of **A** (or, more specifically, **A**'s bounding box) is *above* **B**
+ (or, more specifically, the **top border** of **B**'s bounding box) **and**
+ - if the **bottom border** of **A** (or, more specifically, **A**'s bounding box)
+ is *above* the **bottom border** of **B** (or, more specifically, **B**'s
+ bounding box).
+
+ Args:
+ other_locator:
+ Locator for an element / elements to relate to
+ index:
+ Index of the element (located by *self*) above the other element(s)
+ (located by *other_locator*), e.g., the first (index=0), second
+ (index=1), third (index=2) etc. element above the other element(s).
+ Elements' (relative) position is determined by the **bottom border**
+ (*y*-coordinate) of their bounding box.
+ We don't guarantee the order of elements with the same bottom border
+ (*y*-coordinate).
+ reference_point:
+ Defines which element (located by *self*) is considered to be above the
+ other element(s) (located by *other_locator*):
+
+ **"center"**: One point of the element (located by *self*) is above the
+ center (in a straight vertical line) of the other element(s) (located
+ by *other_locator*).
+ **"boundary"**: One point of the element (located by *self*) is above
+ any other point (in a straight vertical line) of the other element(s)
+ (located by *other_locator*).
+ **"any"**: No point of the element (located by *self*) has to be above
+ a point (in a straight vertical line) of the other element(s) (located
+ by *other_locator*).
+
+ *Default is **"boundary".***
+
+ Returns:
+ Self: The locator with the relation added
+
+ Examples:
+ ```text
+
+ ===========
+ | A |
+ ===========
+ ===========
+ | B |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element above ("center" of)
+ # text "B"
+ text = loc.Text().above_of(loc.Text("B"), reference_point="center")
+ ```
+
+ ```text
+
+ ===========
+ | A |
+ ===========
+ ===========
+ | B |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element above
+ # ("boundary" of / any point of) text "B"
+ # (reference point "center" won't work here)
+ text = loc.Text().above_of(loc.Text("B"), reference_point="boundary")
+ ```
+
+ ```text
+
+ ===========
+ | A |
+ ===========
+ ===========
+ | B |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element above text "B"
+ # (reference point "center" or "boundary" won't work here)
+ text = loc.Text().above_of(loc.Text("B"), reference_point="any")
+ ```
+
+ ```text
+
+ ===========
+ | A |
+ ===========
+ ===========
+ | B |
+ ===========
+ ===========
+ | C |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the second (index 1) element above text "C"
+ # (reference point "center" or "boundary" won't work here)
+ text = loc.Text().above_of(loc.Text("C"), index=1, reference_point="any")
+ ```
+
+ ```text
+
+ ===========
+ | A |
+ ===========
+ ===========
+ =========== | B |
+ | | ===========
+ | C |
+ | |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the second (index 1) element above text "C"
+ # (reference point "any")
+ text = loc.Text().above_of(loc.Text("C"), index=1, reference_point="any")
+ # locates also text "A" as it is the first (index 0) element above text "C"
+ # with reference point "boundary"
+ text = loc.Text().above_of(loc.Text("C"), index=0, reference_point="boundary")
+ ```
+ """
+ self._relations.append(
+ NeighborRelation(
+ type="above_of",
+ other_locator=other_locator,
+ index=index,
+ reference_point=reference_point,
+ )
+ )
+ return self
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation
+ def below_of(
+ self,
+ other_locator: "Relatable",
+ index: RelationIndex = 0,
+ reference_point: ReferencePoint = "boundary",
+ ) -> Self:
+ """Defines the element (located by *self*) to be **below** another element /
+ other elements (located by *other_locator*).
+
+ An element **A** is considered to be *below* another element / other elements **B**
+
+ - if most of **A** (or, more specifically, **A**'s bounding box) is *below* **B**
+ (or, more specifically, the **bottom border** of **B**'s bounding box) **and**
+ - if the **top border** of **A** (or, more specifically, **A**'s bounding box) is
+ *below* the **top border** of **B** (or, more specifically, **B**'s bounding
+ box).
+
+ Args:
+ other_locator:
+ Locator for an element / elements to relate to.
+ index:
+ Index of the element (located by *self*) **below** the other
+ element(s) (located by *other_locator*), e.g., the first (*index=0*),
+ second (*index=1*), third (*index=2*) etc. element below the other
+ element(s). Elements' (relative) position is determined by the **top
+ border** (*y*-coordinate) of their bounding box.
+ We don't guarantee the order of elements with the same top border
+ (*y*-coordinate).
+ reference_point:
+ Defines which element (located by *self*) is considered to be
+ *below* the other element(s) (located by *other_locator*):
+
+ **"center"**: One point of the element (located by *self*) is
+ **below** the *center* (in a straight vertical line) of the other
+ element(s) (located by *other_locator*).
+ **"boundary"**: One point of the element (located by *self*) is
+ **below** *any* other point (in a straight vertical line) of the
+ other element(s) (located by *other_locator*).
+ **"any"**: No point of the element (located by *self*) has to
+ be **below** a point (in a straight vertical line) of the other
+ element(s) (located by *other_locator*).
+
+ *Default is **"boundary".***
+
+ Returns:
+ Self: The locator with the relation added.
+
+ Examples:
+ ```text
+
+ ===========
+ | B |
+ ===========
+ ===========
+ | A |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element below ("center" of)
+ # text "B"
+ text = loc.Text().below_of(loc.Text("B"), reference_point="center")
+ ```
+
+ ```text
+
+ ===========
+ | B |
+ ===========
+ ===========
+ | A |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element below
+ # ("boundary" of / any point of) text "B"
+ # (reference point "center" won't work here)
+ text = loc.Text().below_of(loc.Text("B"), reference_point="boundary")
+ ```
+
+ ```text
+
+ ===========
+ | B |
+ ===========
+ ===========
+ | A |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element below text "B"
+ # (reference point "center" or "boundary" won't work here)
+ text = loc.Text().below_of(loc.Text("B"), reference_point="any")
+ ```
+
+ ```text
+
+ ===========
+ | C |
+ ===========
+ ===========
+ | B |
+ ===========
+ ===========
+ | A |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the second (index 1) element below text "C"
+ # (reference point "center" or "boundary" won't work here)
+ text = loc.Text().below_of(loc.Text("C"), index=1, reference_point="any")
+ ```
+
+ ```text
+
+ ===========
+ | |
+ | C |
+ | |===========
+ ===========| B |
+ ===========
+ ===========
+ | A |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the second (index 1) element below text "C"
+ # (reference point "any")
+ text = loc.Text().below_of(loc.Text("C"), index=1, reference_point="any")
+ # locates also text "A" as it is the first (index 0) element below text "C"
+ # with reference point "boundary"
+ text = loc.Text().below_of(loc.Text("C"), index=0, reference_point="boundary")
+ ```
+ """
+ self._relations.append(
+ NeighborRelation(
+ type="below_of",
+ other_locator=other_locator,
+ index=index,
+ reference_point=reference_point,
+ )
+ )
+ return self
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation
+ def right_of(
+ self,
+ other_locator: "Relatable",
+ index: RelationIndex = 0,
+ reference_point: ReferencePoint = "center",
+ ) -> Self:
+ """Defines the element (located by *self*) to be **right of** another element /
+ other elements (located by *other_locator*).
+
+ An element **A** is considered to be *right of* another element / other elements **B**
+
+ - if most of **A** (or, more specifically, **A**'s bounding box) is *right of* **B**
+ (or, more specifically, the **right border** of **B**'s bounding box) **and**
+ - if the **left border** of **A** (or, more specifically, **A**'s bounding box) is
+ *right of* the **left border** of **B** (or, more specifically, **B**'s
+ bounding box).
+
+ Args:
+ other_locator:
+ Locator for an element / elements to relate to.
+ index:
+ Index of the element (located by *self*) **right of** the other
+ element(s) (located by *other_locator*), e.g., the first (*index=0*),
+ second (*index=1*), third (*index=2*) etc. element right of the other
+ element(s). Elements' (relative) position is determined by the **left
+ border** (*x*-coordinate) of their bounding box.
+ We don't guarantee the order of elements with the same left border
+ (*x*-coordinate).
+ reference_point:
+ Defines which element (located by *self*) is considered to be
+ *right of* the other element(s) (located by *other_locator*):
+
+ **"center"**: One point of the element (located by *self*) is
+ **right of** the *center* (in a straight horizontal line) of the
+ other element(s) (located by *other_locator*).
+ **"boundary"**: One point of the element (located by *self*) is
+ **right of** *any* other point (in a straight horizontal line) of
+ the other element(s) (located by *other_locator*).
+ **"any"**: No point of the element (located by *self*) has to
+ be **right of** a point (in a straight horizontal line) of the
+ other element(s) (located by *other_locator*).
+
+ *Default is **"center".***
+
+ Returns:
+ Self: The locator with the relation added.
+
+ Examples:
+ ```text
+
+ =========== ===========
+ | B | | A |
+ =========== ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element right of ("center"
+ # of) text "B"
+ text = loc.Text().right_of(loc.Text("B"), reference_point="center")
+ ```
+
+ ```text
+
+ ===========
+ | B |
+ =========== ===========
+ | A |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element right of
+ # ("boundary" of / any point of) text "B"
+ # (reference point "center" won't work here)
+ text = loc.Text().right_of(loc.Text("B"), reference_point="boundary")
+ ```
+
+ ```text
+
+ ===========
+ | B |
+ ===========
+ ===========
+ | A |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element right of text "B"
+ # (reference point "center" or "boundary" won't work here)
+ text = loc.Text().right_of(loc.Text("B"), reference_point="any")
+ ```
+
+ ```text
+
+ ===========
+ | A |
+ ===========
+ =========== ===========
+ | C | | B |
+ =========== ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the second (index 1) element right of text "C"
+ # (reference point "center" or "boundary" won't work here)
+ text = loc.Text().right_of(loc.Text("C"), index=1, reference_point="any")
+ ```
+
+ ```text
+
+ ===========
+ | B |
+ =========== ===========
+ =========== | A |
+ | C | ===========
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the second (index 1) element right of text "C"
+ # (reference point "any")
+ text = loc.Text().right_of(loc.Text("C"), index=1, reference_point="any")
+ # locates also text "A" as it is the first (index 0) element right of text
+ # "C" with reference point "boundary"
+ text = loc.Text().right_of(loc.Text("C"), index=0, reference_point="boundary")
+ ```
+ """
+ self._relations.append(
+ NeighborRelation(
+ type="right_of",
+ other_locator=other_locator,
+ index=index,
+ reference_point=reference_point,
+ )
+ )
+ return self
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation
+ def left_of(
+ self,
+ other_locator: "Relatable",
+ index: RelationIndex = 0,
+ reference_point: ReferencePoint = "center",
+ ) -> Self:
+ """Defines the element (located by *self*) to be **left of** another element /
+ other elements (located by *other_locator*).
+
+ An element **A** is considered to be *left of* another element / other elements **B**
+
+ - if most of **A** (or, more specifically, **A**'s bounding box) is *left of* **B**
+ (or, more specifically, the **left border** of **B**'s bounding box) **and**
+ - if the **right border** of **A** (or, more specifically, **A**'s bounding box) is
+ *left of* the **right border** of **B** (or, more specifically, **B**'s
+ bounding box).
+
+ Args:
+ other_locator:
+ Locator for an element / elements to relate to.
+ index:
+ Index of the element (located by *self*) **left of** the other
+ element(s) (located by *other_locator*), e.g., the first (*index=0*),
+ second (*index=1*), third (*index=2*) etc. element left of the other
+ element(s). Elements' (relative) position is determined by the **right
+ border** (*x*-coordinate) of their bounding box.
+ We don't guarantee the order of elements with the same right border
+ (*x*-coordinate).
+ reference_point:
+ Defines which element (located by *self*) is considered to be
+ *left of* the other element(s) (located by *other_locator*):
+
+ **"center"** : One point of the element (located by *self*) is
+ **left of** the *center* (in a straight horizontal line) of the
+ other element(s) (located by *other_locator*).
+ **"boundary"**: One point of the element (located by *self*) is
+ **left of** *any* other point (in a straight horizontal line) of
+ the other element(s) (located by *other_locator*).
+ **"any"** : No point of the element (located by *self*) has to
+ be **left of** a point (in a straight horizontal line) of the
+ other element(s) (located by *other_locator*).
+
+ *Default is **"center".***
+
+ Returns:
+ Self: The locator with the relation added.
+
+ Examples:
+ ```text
+
+ =========== ===========
+ | A | | B |
+ =========== ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element left of ("center"
+ # of) text "B"
+ text = loc.Text().left_of(loc.Text("B"), reference_point="center")
+ ```
+
+ ```text
+
+ ===========
+ =========== | B |
+ | A | ===========
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element left of ("boundary"
+ # of / any point of) text "B"
+ # (reference point "center" won't work here)
+ text = loc.Text().left_of(loc.Text("B"), reference_point="boundary")
+ ```
+
+ ```text
+
+ ===========
+ | B |
+ ===========
+ ===========
+ | A |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the first (index 0) element left of text "B"
+ # (reference point "center" or "boundary" won't work here)
+ text = loc.Text().left_of(loc.Text("B"), reference_point="any")
+ ```
+
+ ```text
+
+ ===========
+ | A |
+ ===========
+ =========== ===========
+ | B | | C |
+ =========== ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the second (index 1) element left of text "C"
+ # (reference point "center" or "boundary" won't work here)
+ text = loc.Text().left_of(loc.Text("C"), index=1, reference_point="any")
+ ```
+
+ ```text
+
+ ===========
+ | B |
+ =========== ===========
+ | A | ===========
+ =========== | C |
+ ===========
+ ```
+ ```python
+ from askui import locators as loc
+ # locates text "A" as it is the second (index 1) element left of text "C"
+ # (reference point "any")
+ text = loc.Text().left_of(loc.Text("C"), index=1, reference_point="any")
+ # locates also text "A" as it is the first (index 0) element left of text
+ # "C" with reference point "boundary"
+ text = loc.Text().left_of(loc.Text("C"), index=0, reference_point="boundary")
+ ```
+ """
+ self._relations.append(
+ NeighborRelation(
+ type="left_of",
+ other_locator=other_locator,
+ index=index,
+ reference_point=reference_point,
+ )
+ )
+ return self
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using BoundingRelation
+ def containing(self, other_locator: "Relatable") -> Self:
+ """Defines the element (located by *self*) to contain another element (located
+ by *other_locator*).
+
+ Args:
+ other_locator: The locator to check if it's contained
+
+ Returns:
+ Self: The locator with the relation added
+
+ Examples:
+ ```text
+ ---------------------------
+ | textfield |
+ | --------------------- |
+ | | placeholder text | |
+ | --------------------- |
+ | |
+ ---------------------------
+ ```
+ ```python
+ from askui import locators as loc
+
+ # Returns the textfield because it contains the placeholder text
+ textfield = loc.Element("textfield").containing(loc.Text("placeholder"))
+ ```
+ """
+ self._relations.append(
+ BoundingRelation(
+ type="containing",
+ other_locator=other_locator,
+ )
+ )
+ return self
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using BoundingRelation
+ def inside_of(self, other_locator: "Relatable") -> Self:
+ """Defines the element (located by *self*) to be inside of another element
+ (located by *other_locator*).
+
+ Args:
+ other_locator: The locator to check if it contains this element
+
+ Returns:
+ Self: The locator with the relation added
+
+ Examples:
+ ```text
+ ---------------------------
+ | textfield |
+ | --------------------- |
+ | | placeholder text | |
+ | --------------------- |
+ | |
+ ---------------------------
+ ```
+ ```python
+ from askui import locators as loc
+
+ # Returns the placeholder text of the textfield
+ placeholder_text = loc.Text("placeholder").inside_of(
+ loc.Element("textfield")
+ )
+ ```
+ """
+ self._relations.append(
+ BoundingRelation(
+ type="inside_of",
+ other_locator=other_locator,
+ )
+ )
+ return self
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NearestToRelation
+ def nearest_to(self, other_locator: "Relatable") -> Self:
+ """Defines the element (located by *self*) to be the nearest to another element
+ (located by *other_locator*).
+
+ Args:
+ other_locator: The locator to compare distance against
+
+ Returns:
+ Self: The locator with the relation added
+
+ Examples:
+ ```text
+ --------------
+ | text |
+ --------------
+ ---------------
+ | textfield 1 |
+ ---------------
+
+
+
+
+ ---------------
+ | textfield 2 |
+ ---------------
+ ```
+ ```python
+ from askui import locators as loc
+
+ # Returns textfield 1 because it is nearer to the text than textfield 2
+ textfield = loc.Element("textfield").nearest_to(loc.Text())
+ ```
+ """
+ self._relations.append(
+ NearestToRelation(
+ type="nearest_to",
+ other_locator=other_locator,
+ )
+ )
+ return self
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using LogicalRelation
+ def and_(self, other_locator: "Relatable") -> Self:
+ """Logical and operator to combine multiple locators, e.g., to require an
+ element to match multiple locators.
+
+ Args:
+ other_locator: The locator to combine with
+
+ Returns:
+ Self: The locator with the relation added
+
+ Examples:
+ ```python
+ from askui import locators as loc
+
+ # Searches for an element that contains the text "Google" and is a
+ # multi-colored Google logo (instead of, e.g., simply some text that says
+ # "Google")
+ icon_user = loc.Element().containing(
+ loc.Text("Google").and_(loc.Description("Multi-colored Google logo"))
+ )
+ ```
+ """
+ self._relations.append(
+ LogicalRelation(
+ type="and",
+ other_locator=other_locator,
+ )
+ )
+ return self
+
+ # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using LogicalRelation
+ def or_(self, other_locator: "Relatable") -> Self:
+ """Logical or operator to combine multiple locators, e.g., to provide a fallback
+ if no element is found for one of the locators.
+
+ Args:
+ other_locator: The locator to combine with
+
+ Returns:
+ Self: The locator with the relation added
+
+ Examples:
+ ```python
+ from askui import locators as loc
+
+ # Searches for element using a description and if the element cannot be
+ # found, searches for it using an image
+ search_icon = loc.Description("search icon").or_(
+ loc.Image("search_icon.png")
+ )
+ ```
+ """
+ self._relations.append(
+ LogicalRelation(
+ type="or",
+ other_locator=other_locator,
+ )
+ )
+ return self
+
+ def _str(self) -> str:
+ return "relatable"
+
+ def _relations_str(self) -> str:
+ if not self._relations:
+ return ""
+
+ result = []
+ for i, relation in enumerate(self._relations):
+ [other_locator_str, *nested_relation_strs] = str(relation).split("\n")
+ result.append(f" {i + 1}. {other_locator_str}")
+ for nested_relation_str in nested_relation_strs:
+ result.append(f" {nested_relation_str}")
+ return "\n" + "\n".join(result)
+
+ def _str_with_relation(self) -> str:
+ return self._str() + self._relations_str()
+
+ def raise_if_cycle(self) -> None:
+ """Raises CircularDependencyError if the relations form a cycle (see [Cycle (graph theory)](https://en.wikipedia.org/wiki/Cycle_(graph_theory)))."""
+ if self._has_cycle():
+ raise CircularDependencyError()
+
+ def _has_cycle(self) -> bool:
+ """Check if the relations form a cycle (see [Cycle (graph theory)](https://en.wikipedia.org/wiki/Cycle_(graph_theory)))."""
+ visited_ids: set[int] = set()
+ recursion_stack_ids: set[int] = set()
+
+ def _dfs(node: Relatable) -> bool:
+ node_id = id(node)
+ if node_id in recursion_stack_ids:
+ return True
+ if node_id in visited_ids:
+ return False
+
+ visited_ids.add(node_id)
+ recursion_stack_ids.add(node_id)
+
+ for relation in node.relations:
+ if _dfs(relation.other_locator):
+ return True
+
+ recursion_stack_ids.remove(node_id)
+ return False
+
+ return _dfs(self)
+
+ def __str__(self) -> str:
+ self.raise_if_cycle()
+ return self._str_with_relation()
diff --git a/src/askui/locators/serializers.py b/src/askui/locators/serializers.py
new file mode 100644
index 00000000..35e1f180
--- /dev/null
+++ b/src/askui/locators/serializers.py
@@ -0,0 +1,257 @@
+from typing_extensions import NotRequired, TypedDict
+
+from askui.reporting import Reporter
+from askui.utils.image_utils import ImageSource
+from askui.models.askui.ai_element_utils import AiElementCollection
+from .locators import (
+ DEFAULT_SIMILARITY_THRESHOLD,
+ DEFAULT_TEXT_MATCH_TYPE,
+ ImageBase,
+ AiElement as AiElementLocator,
+ Element,
+ Prompt,
+ Image,
+ Text,
+)
+from .relatable import (
+ BoundingRelation,
+ LogicalRelation,
+ NearestToRelation,
+ NeighborRelation,
+ ReferencePoint,
+ Relatable,
+ Relation,
+)
+
+
+class VlmLocatorSerializer:
+ def serialize(self, locator: Relatable) -> str:
+ locator.raise_if_cycle()
+ if len(locator.relations) > 0:
+ raise NotImplementedError(
+ "Serializing locators with relations is not yet supported for VLMs"
+ )
+
+ if isinstance(locator, Text):
+ return self._serialize_text(locator)
+ elif isinstance(locator, Element):
+ return self._serialize_class(locator)
+ elif isinstance(locator, Prompt):
+ return self._serialize_prompt(locator)
+ elif isinstance(locator, Image):
+ raise NotImplementedError(
+ "Serializing image locators is not yet supported for VLMs"
+ )
+ elif isinstance(locator, AiElementLocator):
+ raise NotImplementedError(
+ "Serializing AI element locators is not yet supported for VLMs"
+ )
+ else:
+ raise ValueError(f"Unsupported locator type: {type(locator)}")
+
+ def _serialize_class(self, class_: Element) -> str:
+ if class_.class_name:
+ return f"an arbitrary {class_.class_name} shown"
+ else:
+ return "an arbitrary ui element (e.g., text, button, textfield, etc.)"
+
+ def _serialize_prompt(self, prompt: Prompt) -> str:
+ return prompt.prompt
+
+ def _serialize_text(self, text: Text) -> str:
+ if text.match_type == "similar":
+ return f'text similar to "{text.text}"'
+
+ return str(text)
+
+
+class CustomElement(TypedDict):
+ threshold: NotRequired[float]
+ stopThreshold: NotRequired[float]
+ customImage: str
+ mask: NotRequired[list[tuple[float, float]]]
+ rotationDegreePerStep: NotRequired[int]
+ imageCompareFormat: NotRequired[str]
+ name: NotRequired[str]
+
+
+class AskUiSerializedLocator(TypedDict):
+ instruction: str
+ customElements: list[CustomElement]
+
+
+class AskUiLocatorSerializer:
+ _TEXT_DELIMITER = "<|string|>"
+ _RP_TO_INTERSECTION_AREA_MAPPING: dict[ReferencePoint, str] = {
+ "center": "element_center_line",
+ "boundary": "element_edge_area",
+ "any": "display_edge_area",
+ }
+ _RELATION_TYPE_MAPPING: dict[str, str] = {
+ "above_of": "above",
+ "below_of": "below",
+ "right_of": "right of",
+ "left_of": "left of",
+ "containing": "contains",
+ "inside_of": "in",
+ "nearest_to": "nearest to",
+ "and": "and",
+ "or": "or",
+ }
+
+ def __init__(self, ai_element_collection: AiElementCollection, reporter: Reporter):
+ self._ai_element_collection = ai_element_collection
+ self._reporter = reporter
+
+ def serialize(self, locator: Relatable) -> AskUiSerializedLocator:
+ locator.raise_if_cycle()
+ if len(locator.relations) > 1:
+ # If we lift this constraint, we also have to make sure that custom element references are still working + we need, e.g., some symbol or a structured format to indicate precedence
+ raise NotImplementedError(
+ "Serializing locators with multiple relations is not yet supported by AskUI"
+ )
+
+ result = AskUiSerializedLocator(instruction="", customElements=[])
+ if isinstance(locator, Text):
+ result["instruction"] = self._serialize_text(locator)
+ elif isinstance(locator, Element):
+ result["instruction"] = self._serialize_class(locator)
+ elif isinstance(locator, Prompt):
+ result["instruction"] = self._serialize_prompt(locator)
+ elif isinstance(locator, Image):
+ result = self._serialize_image(locator)
+ elif isinstance(locator, AiElementLocator):
+ result = self._serialize_ai_element(locator)
+ else:
+ raise ValueError(f'Unsupported locator type: "{type(locator)}"')
+
+ if len(locator.relations) == 0:
+ return result
+
+ serialized_relation = self._serialize_relation(locator.relations[0])
+ result["instruction"] += f" {serialized_relation['instruction']}"
+ result["customElements"] += serialized_relation["customElements"]
+ return result
+
+ def _serialize_class(self, class_: Element) -> str:
+ return class_.class_name or "element"
+
+ def _serialize_prompt(self, prompt: Prompt) -> str:
+ return f"pta {self._TEXT_DELIMITER}{prompt.prompt}{self._TEXT_DELIMITER}"
+
+ def _serialize_text(self, text: Text) -> str:
+ match text.match_type:
+ case "similar":
+ if (
+ text.similarity_threshold == DEFAULT_SIMILARITY_THRESHOLD
+ and text.match_type == DEFAULT_TEXT_MATCH_TYPE
+ ):
+ # Necessary so that we can use wordlevel ocr for these texts
+ return (
+ f"text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}"
+ )
+ return f"text with text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER} that matches to {text.similarity_threshold} %"
+ case "exact":
+ return f"text equals text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}"
+ case "contains":
+ return f"text contain text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}"
+ case "regex":
+ return f"text match regex pattern {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}"
+ case _:
+ raise ValueError(f'Unsupported text match type: "{text.match_type}"')
+
+ def _serialize_relation(self, relation: Relation) -> AskUiSerializedLocator:
+ match relation.type:
+ case "above_of" | "below_of" | "right_of" | "left_of":
+ assert isinstance(relation, NeighborRelation)
+ return self._serialize_neighbor_relation(relation)
+ case "containing" | "inside_of" | "nearest_to" | "and" | "or":
+ assert isinstance(
+ relation, LogicalRelation | BoundingRelation | NearestToRelation
+ )
+ return self._serialize_non_neighbor_relation(relation)
+ case _:
+ raise ValueError(f'Unsupported relation type: "{relation.type}"')
+
+ def _serialize_neighbor_relation(
+ self, relation: NeighborRelation
+ ) -> AskUiSerializedLocator:
+ serialized_other_locator = self.serialize(relation.other_locator)
+ return AskUiSerializedLocator(
+ instruction=f"index {relation.index} {self._RELATION_TYPE_MAPPING[relation.type]} intersection_area {self._RP_TO_INTERSECTION_AREA_MAPPING[relation.reference_point]} {serialized_other_locator['instruction']}",
+ customElements=serialized_other_locator["customElements"],
+ )
+
+ def _serialize_non_neighbor_relation(
+ self, relation: LogicalRelation | BoundingRelation | NearestToRelation
+ ) -> AskUiSerializedLocator:
+ serialized_other_locator = self.serialize(relation.other_locator)
+ return AskUiSerializedLocator(
+ instruction=f"{self._RELATION_TYPE_MAPPING[relation.type]} {serialized_other_locator['instruction']}",
+ customElements=serialized_other_locator["customElements"],
+ )
+
+ def _serialize_image_to_custom_element(
+ self,
+ image_locator: ImageBase,
+ image_source: ImageSource,
+ ) -> CustomElement:
+ custom_element: CustomElement = CustomElement(
+ customImage=image_source.to_data_url(),
+ threshold=image_locator.threshold,
+ stopThreshold=image_locator.stop_threshold,
+ rotationDegreePerStep=image_locator.rotation_degree_per_step,
+ imageCompareFormat=image_locator.image_compare_format,
+ name=image_locator.name,
+ )
+ if image_locator.mask:
+ custom_element["mask"] = image_locator.mask
+ return custom_element
+
+ def _serialize_image_base(
+ self,
+ image_locator: ImageBase,
+ image_sources: list[ImageSource],
+ ) -> AskUiSerializedLocator:
+ custom_elements: list[CustomElement] = [
+ self._serialize_image_to_custom_element(
+ image_locator=image_locator,
+ image_source=image_source,
+ )
+ for image_source in image_sources
+ ]
+ return AskUiSerializedLocator(
+ instruction=f"custom element with text {self._TEXT_DELIMITER}{image_locator.name}{self._TEXT_DELIMITER}",
+ customElements=custom_elements,
+ )
+
+ def _serialize_image(
+ self,
+ image: Image,
+ ) -> AskUiSerializedLocator:
+ self._reporter.add_message(
+ "AskUiLocatorSerializer",
+ f"Image locator: {image}",
+ image=image.image.root,
+ )
+ return self._serialize_image_base(
+ image_locator=image,
+ image_sources=[image.image],
+ )
+
+ def _serialize_ai_element(
+ self, ai_element_locator: AiElementLocator
+ ) -> AskUiSerializedLocator:
+ ai_elements = self._ai_element_collection.find(ai_element_locator.name)
+ self._reporter.add_message(
+ "AskUiLocatorSerializer",
+ f"Found {len(ai_elements)} ai elements named {ai_element_locator.name}",
+ image=[ai_element.image for ai_element in ai_elements],
+ )
+ return self._serialize_image_base(
+ image_locator=ai_element_locator,
+ image_sources=[
+ ImageSource.model_construct(root=ai_element.image)
+ for ai_element in ai_elements
+ ],
+ )
diff --git a/src/askui/logger.py b/src/askui/logger.py
index e6da1743..2038ecf9 100644
--- a/src/askui/logger.py
+++ b/src/askui/logger.py
@@ -11,7 +11,7 @@
logger.setLevel(logging.INFO)
-def configure_logging(level=logging.INFO):
+def configure_logging(level: str | int = logging.INFO):
logger.setLevel(level)
diff --git a/src/askui/models/__init__.py b/src/askui/models/__init__.py
new file mode 100644
index 00000000..efc2755c
--- /dev/null
+++ b/src/askui/models/__init__.py
@@ -0,0 +1,7 @@
+from .models import ModelName, ModelComposition, ModelDefinition
+
+__all__ = [
+ "ModelName",
+ "ModelComposition",
+ "ModelDefinition",
+]
diff --git a/src/askui/models/anthropic/claude.py b/src/askui/models/anthropic/claude.py
index ce5813be..4d54f8e8 100644
--- a/src/askui/models/anthropic/claude.py
+++ b/src/askui/models/anthropic/claude.py
@@ -2,24 +2,25 @@
import anthropic
from PIL import Image
+from askui.utils.image_utils import ImageSource, image_to_base64, scale_coordinates_back, scale_image_with_padding
+
from ...logger import logger
-from ...utils import AutomationError
-from ..utils import scale_image_with_padding, scale_coordinates_back, extract_click_coordinates, image_to_base64
+from ...exceptions import ElementNotFoundError
+from .utils import extract_click_coordinates
class ClaudeHandler:
- def __init__(self, log_level):
- self.model_name = "claude-3-5-sonnet-20241022"
+ def __init__(self):
+ self.model = "claude-3-5-sonnet-20241022"
self.client = anthropic.Anthropic()
self.resolution = (1280, 800)
- self.log_level = log_level
self.authenticated = True
if os.getenv("ANTHROPIC_API_KEY") is None:
self.authenticated = False
- def inference(self, base64_image, prompt, system_prompt) -> list[anthropic.types.ContentBlock]:
+ def _inference(self, base64_image: str, prompt: str, system_prompt: str) -> list[anthropic.types.ContentBlock]:
message = self.client.messages.create(
- model=self.model_name,
+ model=self.model,
max_tokens=1000,
temperature=0,
system=system_prompt,
@@ -32,7 +33,7 @@ def inference(self, base64_image, prompt, system_prompt) -> list[anthropic.types
"source": {
"type": "base64",
"media_type": "image/png",
- "data": base64_image
+ "data": base64_image,
}
},
{
@@ -50,19 +51,27 @@ def locate_inference(self, image: Image.Image, locator: str) -> tuple[int, int]:
screen_width, screen_height = self.resolution[0], self.resolution[1]
system_prompt = f"Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try taking another screenshot.\n* The screen's resolution is {screen_width}x{screen_height}.\n* The display number is 0\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\n"
scaled_image = scale_image_with_padding(image, screen_width, screen_height)
- response = self.inference(image_to_base64(scaled_image), prompt, system_prompt)
+ response = self._inference(image_to_base64(scaled_image), prompt, system_prompt)
response = response[0].text
logger.debug("ClaudeHandler received locator: %s", response)
try:
scaled_x, scaled_y = extract_click_coordinates(response)
except Exception as e:
- raise AutomationError(f"Couldn't locate '{locator}' on the screen.")
+ raise ElementNotFoundError(f"Element not found: {locator}")
x, y = scale_coordinates_back(scaled_x, scaled_y, image.width, image.height, screen_width, screen_height)
return int(x), int(y)
- def get_inference(self, image: Image.Image, instruction: str) -> str:
- scaled_image = scale_image_with_padding(image, self.resolution[0], self.resolution[1])
+ def get_inference(self, image: ImageSource, query: str) -> str:
+ scaled_image = scale_image_with_padding(
+ image=image.root,
+ max_width=self.resolution[0],
+ max_height=self.resolution[1],
+ )
system_prompt = "You are an agent to process screenshots and answer questions about things on the screen or extract information from it. Answer only with the response to the question and keep it short and precise."
- response = self.inference(image_to_base64(scaled_image), instruction, system_prompt)
+ response = self._inference(
+ base64_image=image_to_base64(scaled_image),
+ prompt=query,
+ system_prompt=system_prompt
+ )
response = response[0].text
return response
diff --git a/src/askui/models/anthropic/claude_agent.py b/src/askui/models/anthropic/claude_agent.py
index 54bc7922..05599433 100644
--- a/src/askui/models/anthropic/claude_agent.py
+++ b/src/askui/models/anthropic/claude_agent.py
@@ -20,10 +20,12 @@
BetaToolUseBlockParam,
)
+from askui.tools.agent_os import AgentOs
+
from ...tools.anthropic import ComputerTool, ToolCollection, ToolResult
from ...logger import logger
-from ...utils import truncate_long_strings
-from askui.reporting.report import SimpleReportGenerator
+from ...utils.str_utils import truncate_long_strings
+from askui.reporting import Reporter
COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
@@ -60,10 +62,10 @@
class ClaudeComputerAgent:
- def __init__(self, controller_client, report: SimpleReportGenerator | None = None) -> None:
- self.report = report
+ def __init__(self, agent_os: AgentOs, reporter: Reporter) -> None:
+ self._reporter = reporter
self.tool_collection = ToolCollection(
- ComputerTool(controller_client),
+ ComputerTool(agent_os),
)
self.system = BetaTextBlockParam(
type="text",
@@ -109,8 +111,8 @@ def step(self, messages: list):
}
logger.debug(new_message)
messages.append(new_message)
- if self.report is not None:
- self.report.add_message("Anthropic Computer Use", response_params)
+ if self._reporter is not None:
+ self._reporter.add_message("Anthropic Computer Use", response_params)
tool_result_content: list[BetaToolResultBlockParam] = []
for content_block in response_params:
diff --git a/src/askui/models/anthropic/utils.py b/src/askui/models/anthropic/utils.py
new file mode 100644
index 00000000..7b27a065
--- /dev/null
+++ b/src/askui/models/anthropic/utils.py
@@ -0,0 +1,8 @@
+import re
+
+
+def extract_click_coordinates(text: str):
+ pattern = r'(\d+),\s*(\d+)'
+ matches = re.findall(pattern, text)
+ x, y = matches[-1]
+ return int(x), int(y)
diff --git a/src/askui/models/askui/ai_element_utils.py b/src/askui/models/askui/ai_element_utils.py
index bda8fb73..c8f3ad4b 100644
--- a/src/askui/models/askui/ai_element_utils.py
+++ b/src/askui/models/askui/ai_element_utils.py
@@ -61,48 +61,49 @@ def from_json_file(cls, json_file_path: pathlib.Path) -> "AiElement":
image = Image.open(image_path))
-class AiElementNotFound(Exception):
- pass
+class AiElementNotFound(ValueError):
+ def __init__(self, name: str, locations: list[pathlib.Path]):
+ self.name = name
+ self.locations = locations
+ locations_str = ", ".join([str(location) for location in locations])
+ super().__init__(
+ f'AI element "{name}" not found in {locations_str}\n'
+ 'Solutions:\n'
+ '1. Verify the element exists in these locations and try again if you are sure it is present\n'
+ '2. Add location to ASKUI_AI_ELEMENT_LOCATIONS env var (paths, comma separated)\n'
+ '3. Create new AI element (see https://docs.askui.com/02-api-reference/02-askui-suite/02-askui-suite/AskUIRemoteDeviceSnippingTool/Public/AskUI-NewAIElement)'
+ )
class AiElementCollection:
-
def __init__(self, additional_ai_element_locations: Optional[List[pathlib.Path]] = None):
+ additional_ai_element_locations = additional_ai_element_locations or []
+
workspace_id = os.getenv("ASKUI_WORKSPACE_ID")
if workspace_id is None:
raise ValueError("ASKUI_WORKSPACE_ID is not set")
- if additional_ai_element_locations is None:
- additional_ai_element_locations = []
-
- addional_ai_element_from_env = []
- if os.getenv("ASKUI_AI_ELEMENT_LOCATIONS", "") != "":
- addional_ai_element_from_env = [pathlib.Path(ai_element_loc) for ai_element_loc in os.getenv("ASKUI_AI_ELEMENT_LOCATIONS", "").split(",")],
+ locations_from_env: list[pathlib.Path] = []
+ if locations_env := os.getenv("ASKUI_AI_ELEMENT_LOCATIONS"):
+ locations_from_env = [pathlib.Path(loc) for loc in locations_env.split(",")]
- self.ai_element_locations = [
+ self._ai_element_locations = [
pathlib.Path.home() / ".askui" / "SnippingTool" / "AIElement" / workspace_id,
- *addional_ai_element_from_env,
+ *locations_from_env,
*additional_ai_element_locations
]
- logger.debug("AI Element locations: %s", self.ai_element_locations)
+ logger.debug("AI Element locations: %s", self._ai_element_locations)
- def find(self, name: str):
- ai_elements = []
-
- for location in self.ai_element_locations:
+ def find(self, name: str) -> list[AiElement]:
+ ai_elements: list[AiElement] = []
+ for location in self._ai_element_locations:
path = pathlib.Path(location)
-
json_files = list(path.glob("*.json"))
-
- if not json_files:
- logger.warning(f"No JSON files found in: {location}")
- continue
-
for json_file in json_files:
ai_element = AiElement.from_json_file(json_file)
-
if ai_element.metadata.name == name:
ai_elements.append(ai_element)
-
- return ai_elements
\ No newline at end of file
+ if len(ai_elements) == 0:
+ raise AiElementNotFound(name=name, locations=self._ai_element_locations)
+ return ai_elements
diff --git a/src/askui/models/askui/api.py b/src/askui/models/askui/api.py
index 915bd2de..cc39cc8a 100644
--- a/src/askui/models/askui/api.py
+++ b/src/askui/models/askui/api.py
@@ -1,105 +1,95 @@
import os
import base64
import pathlib
+from pydantic import RootModel
import requests
-
+import json as json_lib
from PIL import Image
-from typing import List, Union
-from askui.models.askui.ai_element_utils import AiElement, AiElementCollection, AiElementNotFound
-from askui.utils import image_to_base64
+from typing import Any, Type, Union
+from askui.models.models import ModelComposition
+from askui.utils.image_utils import ImageSource
+from askui.locators.serializers import AskUiLocatorSerializer
+from askui.locators.locators import Locator
+from askui.utils.image_utils import image_to_base64
from askui.logger import logger
+from ..types.response_schemas import ResponseSchema, to_response_schema
-class AskUIHandler:
- def __init__(self):
+class AskUiInferenceApi:
+ def __init__(self, locator_serializer: AskUiLocatorSerializer):
+ self._locator_serializer = locator_serializer
self.inference_endpoint = os.getenv("ASKUI_INFERENCE_ENDPOINT", "https://inference.askui.com")
self.workspace_id = os.getenv("ASKUI_WORKSPACE_ID")
self.token = os.getenv("ASKUI_TOKEN")
-
self.authenticated = True
if self.workspace_id is None or self.token is None:
logger.warning("ASKUI_WORKSPACE_ID or ASKUI_TOKEN missing.")
self.authenticated = False
- self.ai_element_collection = AiElementCollection()
-
-
-
def _build_askui_token_auth_header(self, bearer_token: str | None = None) -> dict[str, str]:
if bearer_token is not None:
return {"Authorization": f"Bearer {bearer_token}"}
+
+ if self.token is None:
+ raise Exception("ASKUI_TOKEN is not set.")
token_base64 = base64.b64encode(self.token.encode("utf-8")).decode("utf-8")
return {"Authorization": f"Basic {token_base64}"}
-
- def _build_custom_elements(self, ai_elements: List[AiElement] | None):
- """
- Converts AiElements to the CustomElementDto format expected by the backend.
-
- Args:
- ai_elements (List[AiElement]): List of AI elements to convert
-
- Returns:
- dict: Custom elements in the format expected by the backend
- """
- if not ai_elements:
- return {}
-
- custom_elements = []
- for element in ai_elements:
- custom_element = {
- "customImage": "," + image_to_base64(element.image),
- "imageCompareFormat": "grayscale",
- "name": element.metadata.name
- }
- custom_elements.append(custom_element)
-
- return {
- "customElements": custom_elements
- }
- def __build_model_composition(self):
- return {}
-
- def __build_base_url(self, endpoint: str = "inference") -> str:
+
+ def _build_base_url(self, endpoint: str) -> str:
return f"{self.inference_endpoint}/api/v3/workspaces/{self.workspace_id}/{endpoint}"
- def predict(self, image: Union[pathlib.Path, Image.Image], locator: str, ai_elements: List[pathlib.Path] = None) -> tuple[int | None, int | None]:
+ def _request(self, endpoint: str, json: dict[str, Any] | None = None) -> Any:
response = requests.post(
- self.__build_base_url(),
- json={
- "image": f",{image_to_base64(image)}",
- **({"instruction": locator} if locator is not None else {}),
- **self.__build_model_composition(),
- **self._build_custom_elements(ai_elements)
- },
+ self._build_base_url(endpoint),
+ json=json,
headers={"Content-Type": "application/json", **self._build_askui_token_auth_header()},
timeout=30,
)
if response.status_code != 200:
raise Exception(f"{response.status_code}: Unknown Status Code\n", response.text)
- content = response.json()
+ return response.json()
+
+ def predict(self, image: Union[pathlib.Path, Image.Image], locator: Locator, model: ModelComposition | None = None) -> tuple[int | None, int | None]:
+ serialized_locator = self._locator_serializer.serialize(locator=locator)
+ logger.debug(f"serialized_locator:\n{json_lib.dumps(serialized_locator)}")
+ json: dict[str, Any] = {
+ "image": f",{image_to_base64(image)}",
+ "instruction": f"Click on {serialized_locator['instruction']}",
+ }
+ if "customElements" in serialized_locator:
+ json["customElements"] = serialized_locator["customElements"]
+ if model is not None:
+ json["modelComposition"] = model.model_dump(by_alias=True)
+ logger.debug(f"modelComposition:\n{json_lib.dumps(json['modelComposition'])}")
+ content = self._request(endpoint="inference", json=json)
assert content["type"] == "COMMANDS", f"Received unknown content type {content['type']}"
actions = [el for el in content["data"]["actions"] if el["inputEvent"] == "MOUSE_MOVE"]
if len(actions) == 0:
return None, None
- position = actions[0]["position"]
+ position = actions[0]["position"]
return int(position["x"]), int(position["y"])
-
- def locate_pta_prediction(self, image: Union[pathlib.Path, Image.Image], locator: str) -> tuple[int | None, int | None]:
- askui_locator = f'Click on pta "{locator}"'
- return self.predict(image, askui_locator)
-
- def locate_ocr_prediction(self, image: Union[pathlib.Path, Image.Image], locator: str) -> tuple[int | None, int | None]:
- askui_locator = f'Click on with text "{locator}"'
- return self.predict(image, askui_locator)
-
- def locate_ai_element_prediction(self, image: Union[pathlib.Path, Image.Image], name: str) -> tuple[int | None, int | None]:
- ai_elements = self.ai_element_collection.find(name)
- if len(ai_elements) == 0:
- raise AiElementNotFound(f"Could not locate AI element with name '{name}'")
-
- askui_instruction = f'Click on custom element with text "{name}"'
- return self.predict(image, askui_instruction, ai_elements=ai_elements)
+ def get_inference(
+ self,
+ image: ImageSource,
+ query: str,
+ response_schema: Type[ResponseSchema] | None = None
+ ) -> ResponseSchema | str:
+ json: dict[str, Any] = {
+ "image": image.to_data_url(),
+ "prompt": query,
+ }
+ _response_schema = to_response_schema(response_schema)
+ json["config"] = {
+ "json_schema": _response_schema.model_json_schema()
+ }
+ logger.debug(f"json_schema:\n{json_lib.dumps(json['config']['json_schema'])}")
+ content = self._request(endpoint="vqa/inference", json=json)
+ response = content["data"]["response"]
+ validated_response = _response_schema.model_validate(response)
+ if isinstance(validated_response, RootModel):
+ return validated_response.root
+ return validated_response
diff --git a/src/askui/models/huggingface/spaces_api.py b/src/askui/models/huggingface/spaces_api.py
index 5d2b5a7a..f8499206 100644
--- a/src/askui/models/huggingface/spaces_api.py
+++ b/src/askui/models/huggingface/spaces_api.py
@@ -2,7 +2,7 @@
import tempfile
from gradio_client import Client, handle_file
-from askui.utils import AutomationError
+from askui.exceptions import AutomationError
class HFSpacesHandler:
diff --git a/src/askui/models/models.py b/src/askui/models/models.py
new file mode 100644
index 00000000..71da37b2
--- /dev/null
+++ b/src/askui/models/models.py
@@ -0,0 +1,86 @@
+from collections.abc import Iterator
+from enum import Enum
+import re
+from typing import Annotated
+from pydantic import BaseModel, ConfigDict, Field, RootModel
+
+
+class ModelName(str, Enum):
+ ANTHROPIC__CLAUDE__3_5__SONNET__20241022 = "anthropic-claude-3-5-sonnet-20241022"
+ ANTHROPIC = "anthropic"
+ ASKUI = "askui"
+ ASKUI__AI_ELEMENT = "askui-ai-element"
+ ASKUI__COMBO = "askui-combo"
+ ASKUI__OCR = "askui-ocr"
+ ASKUI__PTA = "askui-pta"
+ TARS = "tars"
+
+
+MODEL_DEFINITION_PROPERTY_REGEX_PATTERN = re.compile(r"^[A-Za-z0-9_]+$")
+
+
+ModelDefinitionProperty = Annotated[
+ str, Field(pattern=MODEL_DEFINITION_PROPERTY_REGEX_PATTERN)
+]
+
+
+class ModelDefinition(BaseModel):
+ """
+ A definition of a model.
+ """
+ model_config = ConfigDict(
+ populate_by_name=True,
+ )
+ task: ModelDefinitionProperty = Field(
+ description="The task the model is trained for, e.g., end-to-end OCR (e2e_ocr) or object detection (od)",
+ examples=["e2e_ocr", "od"],
+ )
+ architecture: ModelDefinitionProperty = Field(
+ description="The architecture of the model", examples=["easy_ocr", "yolo"]
+ )
+ version: str = Field(pattern=r"^[0-9]{1,6}$")
+ interface: ModelDefinitionProperty = Field(
+ description="The interface the model is trained for",
+ examples=["online_learning", "offline_learning"],
+ )
+ use_case: ModelDefinitionProperty = Field(
+ description='The use case the model is trained for. In the case of workspace specific AskUI models, this is often the workspace id but with "-" replaced by "_"',
+ examples=[
+ "fb3b9a7b_3aea_41f7_ba02_e55fd66d1c1e",
+ "00000000_0000_0000_0000_000000000000",
+ ],
+ default="00000000_0000_0000_0000_000000000000",
+ serialization_alias="useCase",
+ )
+ tags: list[ModelDefinitionProperty] = Field(
+ default_factory=list,
+ description="Tags for identifying the model that cannot be represented by other properties",
+ examples=["trained", "word_level"],
+ )
+
+ @property
+ def model_name(self) -> str:
+ return (
+ "-".join(
+ [
+ self.task,
+ self.architecture,
+ self.interface,
+ self.use_case,
+ self.version,
+ *self.tags,
+ ]
+ )
+ )
+
+
+class ModelComposition(RootModel[list[ModelDefinition]]):
+ """
+ A composition of models.
+ """
+
+ def __iter__(self):
+ return iter(self.root)
+
+ def __getitem__(self, index: int) -> ModelDefinition:
+ return self.root[index]
diff --git a/src/askui/models/router.py b/src/askui/models/router.py
index 377486dc..7f3395cb 100644
--- a/src/askui/models/router.py
+++ b/src/askui/models/router.py
@@ -1,12 +1,22 @@
-from typing import Optional
+from typing import Type
+from typing_extensions import override
from PIL import Image
from askui.container import telemetry
-from .askui.api import AskUIHandler
+from askui.locators.locators import AiElement, Prompt, Text
+from askui.locators.serializers import AskUiLocatorSerializer, VlmLocatorSerializer
+from askui.locators.locators import Locator
+from askui.models.askui.ai_element_utils import AiElementCollection
+from askui.models.models import ModelComposition, ModelName
+from askui.models.types.response_schemas import ResponseSchema
+from askui.reporting import CompositeReporter, Reporter
+from askui.tools.toolbox import AgentToolbox
+from askui.utils.image_utils import ImageSource
+from .askui.api import AskUiInferenceApi
from .anthropic.claude import ClaudeHandler
from .huggingface.spaces_api import HFSpacesHandler
+from ..exceptions import AutomationError, ElementNotFoundError
from ..logger import logger
-from ..utils import AutomationError
from .ui_tars_ep.ui_tars_api import UITarsAPIHandler
from .anthropic.claude_agent import ClaudeComputerAgent
from abc import ABC, abstractmethod
@@ -14,115 +24,208 @@
Point = tuple[int, int]
-def handle_response(response: tuple[int | None, int | None], locator: str):
+
+def handle_response(response: tuple[int | None, int | None], locator: str | Locator):
if response[0] is None or response[1] is None:
- raise AutomationError(f'Could not locate "{locator}"')
+ raise ElementNotFoundError(f"Element not found: {locator}")
return response
-class GroundingModelRouter(ABC):
+class GroundingModelRouter(ABC):
@abstractmethod
- def locate(self, screenshot: Image.Image, locator: str, model_name: str | None = None) -> Point:
+ def locate(
+ self,
+ screenshot: Image.Image,
+ locator: str | Locator,
+ model: ModelComposition | str | None = None,
+ ) -> Point:
pass
@abstractmethod
- def is_responsible(self, model_name: Optional[str]) -> bool:
+ def is_responsible(self, model: ModelComposition | str | None = None) -> bool:
pass
-
+
@abstractmethod
def is_authenticated(self) -> bool:
pass
-class AskUIModelRouter(GroundingModelRouter):
-
- def __init__(self):
- self.askui = AskUIHandler()
-
- def locate(self, screenshot: Image.Image, locator: str, model_name: str | None = None) -> Point:
- if not self.askui.authenticated:
- raise AutomationError(f"NoAskUIAuthenticationSet! Please set 'AskUI ASKUI_WORKSPACE_ID' or 'ASKUI_TOKEN' as env variables!")
-
- if model_name == "askui-pta":
- logger.debug(f"Routing locate prediction to askui-pta")
- x, y = self.askui.locate_pta_prediction(screenshot, locator)
+class AskUiModelRouter(GroundingModelRouter):
+ def __init__(self, inference_api: AskUiInferenceApi):
+ self._inference_api = inference_api
+
+ def _locate_with_askui_ocr(self, screenshot: Image.Image, locator: str | Text) -> Point:
+ locator = Text(locator) if isinstance(locator, str) else locator
+ x, y = self._inference_api.predict(screenshot, locator)
+ return handle_response((x, y), locator)
+
+ @override
+ def locate(
+ self,
+ screenshot: Image.Image,
+ locator: str | Locator,
+ model: ModelComposition | str | None = None,
+ ) -> Point:
+ if not self._inference_api.authenticated:
+ raise AutomationError(
+ "NoAskUIAuthenticationSet! Please set 'AskUI ASKUI_WORKSPACE_ID' or 'ASKUI_TOKEN' as env variables!"
+ )
+ if not isinstance(model, str) or model == ModelName.ASKUI:
+ logger.debug("Routing locate prediction to askui")
+ locator = Text(locator) if isinstance(locator, str) else locator
+ _model = model if not isinstance(model, str) else None
+ x, y = self._inference_api.predict(screenshot, locator, _model)
return handle_response((x, y), locator)
- if model_name == "askui-ocr":
- logger.debug(f"Routing locate prediction to askui-ocr")
- x, y = self.askui.locate_ocr_prediction(screenshot, locator)
+ if not isinstance(locator, str):
+ raise AutomationError(
+ f'Locators of type `{type(locator)}` are not supported for models "askui-pta", "askui-ocr" and "askui-combo" and "askui-ai-element". Please provide a `str`.'
+ )
+ if model == ModelName.ASKUI__PTA:
+ logger.debug("Routing locate prediction to askui-pta")
+ x, y = self._inference_api.predict(screenshot, Prompt(locator))
return handle_response((x, y), locator)
- if model_name == "askui-combo" or model_name is None:
- logger.debug(f"Routing locate prediction to askui-combo")
- x, y = self.askui.locate_pta_prediction(screenshot, locator)
+ if model == ModelName.ASKUI__OCR:
+ logger.debug("Routing locate prediction to askui-ocr")
+ return self._locate_with_askui_ocr(screenshot, locator)
+ if model == ModelName.ASKUI__COMBO or model is None:
+ logger.debug("Routing locate prediction to askui-combo")
+ prompt_locator = Prompt(locator)
+ x, y = self._inference_api.predict(screenshot, prompt_locator)
if x is None or y is None:
- x, y = self.askui.locate_ocr_prediction(screenshot, locator)
- return handle_response((x, y), locator)
- if model_name == "askui-ai-element":
- logger.debug(f"Routing click prediction to askui-ai-element")
- x, y = self.askui.locate_ai_element_prediction(screenshot, locator)
- return handle_response((x, y), locator)
- raise AutomationError(f"Invalid model name {model_name} for click")
-
- def is_responsible(self, model_name: Optional[str]):
- return model_name is None or model_name.startswith("askui")
-
+ return self._locate_with_askui_ocr(screenshot, locator)
+ return handle_response((x, y), prompt_locator)
+ if model == ModelName.ASKUI__AI_ELEMENT:
+ logger.debug("Routing click prediction to askui-ai-element")
+ _locator = AiElement(locator)
+ x, y = self._inference_api.predict(screenshot, _locator)
+ return handle_response((x, y), _locator)
+ raise AutomationError(f'Invalid model: "{model}"')
+
+ @override
+ def is_responsible(self, model: ModelComposition | str | None = None) -> bool:
+ return not isinstance(model, str) or model.startswith(ModelName.ASKUI)
+
+ @override
def is_authenticated(self) -> bool:
- return self.askui.authenticated
+ return self._inference_api.authenticated
-
class ModelRouter:
- def __init__(self, log_level, report,
- grounding_model_routers: list[GroundingModelRouter] | None = None):
- self.report = report
-
- self.grounding_model_routers = grounding_model_routers or [AskUIModelRouter()]
-
- self.claude = ClaudeHandler(log_level)
- self.huggingface_spaces = HFSpacesHandler()
- self.tars = UITarsAPIHandler(self.report)
-
- def act(self, controller_client, goal: str, model_name: str | None = None):
- if self.tars.authenticated and model_name == "tars":
- return self.tars.act(controller_client, goal)
- if self.claude.authenticated and (model_name == "claude" or model_name is None):
- agent = ClaudeComputerAgent(controller_client, self.report)
- return agent.run(goal)
- raise AutomationError("Invalid model name for act")
-
- def get_inference(self, screenshot: Image.Image, locator: str, model_name: str | None = None):
- if self.tars.authenticated and model_name == "tars":
- return self.tars.get_prediction(screenshot, locator)
- if self.claude.authenticated and (model_name == "anthropic-claude-3-5-sonnet-20241022" or model_name is None):
- return self.claude.get_inference(screenshot, locator)
- raise AutomationError("Executing get commands requires to authenticate with an Automation Model Provider supporting it.")
+ def __init__(
+ self,
+ tools: AgentToolbox,
+ grounding_model_routers: list[GroundingModelRouter] | None = None,
+ reporter: Reporter | None = None,
+ ):
+ _reporter = reporter or CompositeReporter()
+ self._askui = AskUiInferenceApi(
+ locator_serializer=AskUiLocatorSerializer(
+ ai_element_collection=AiElementCollection(),
+ reporter=_reporter,
+ ),
+ )
+ self._grounding_model_routers = grounding_model_routers or [AskUiModelRouter(inference_api=self._askui)]
+ self._claude = ClaudeHandler()
+ self._huggingface_spaces = HFSpacesHandler()
+ self._tars = UITarsAPIHandler(agent_os=tools.agent_os, reporter=_reporter)
+ self._claude_computer_agent = ClaudeComputerAgent(agent_os=tools.agent_os, reporter=_reporter)
+ self._locator_serializer = VlmLocatorSerializer()
+
+ def act(self, goal: str, model: ModelComposition | str | None = None):
+ if self._tars.authenticated and model == ModelName.TARS:
+ return self._tars.act(goal)
+ if self._claude.authenticated and (model is None or isinstance(model, str) and model.startswith(ModelName.ANTHROPIC)):
+ return self._claude_computer_agent.run(goal)
+ raise AutomationError(f"Invalid model for act: {model}")
+
+ def get_inference(
+ self,
+ query: str,
+ image: ImageSource,
+ response_schema: Type[ResponseSchema] | None = None,
+ model: ModelComposition | str | None = None,
+ ) -> ResponseSchema | str:
+ if self._tars.authenticated and model == ModelName.TARS:
+ if response_schema not in [str, None]:
+ raise NotImplementedError("(Non-String) Response schema is not yet supported for UI-TARS models.")
+ return self._tars.get_inference(image=image, query=query)
+ if self._claude.authenticated and (
+ isinstance(model, str) and model.startswith(ModelName.ANTHROPIC)
+ ):
+ if response_schema not in [str, None]:
+ raise NotImplementedError("(Non-String) Response schema is not yet supported for Anthropic models.")
+ return self._claude.get_inference(image=image, query=query)
+ if self._askui.authenticated and (model == ModelName.ASKUI or model is None):
+ return self._askui.get_inference(
+ image=image,
+ query=query,
+ response_schema=response_schema,
+ )
+ raise AutomationError(
+ f"Executing get commands requires to authenticate with an Automation Model Provider supporting it: {model}"
+ )
+
+ def _serialize_locator(self, locator: str | Locator) -> str:
+ if isinstance(locator, Locator):
+ return self._locator_serializer.serialize(locator=locator)
+ return locator
@telemetry.record_call(exclude={"locator", "screenshot"})
- def locate(self, screenshot: Image.Image, locator: str, model_name: str | None = None) -> Point:
- if model_name is not None and model_name in self.huggingface_spaces.get_spaces_names():
- x, y = self.huggingface_spaces.predict(screenshot, locator, model_name)
+ def locate(
+ self,
+ screenshot: Image.Image,
+ locator: str | Locator,
+ model: ModelComposition | str | None = None,
+ ) -> Point:
+ if (
+ isinstance(model, str)
+ and model in self._huggingface_spaces.get_spaces_names()
+ ):
+ x, y = self._huggingface_spaces.predict(
+ screenshot=screenshot,
+ locator=self._serialize_locator(locator),
+ model_name=model,
+ )
return handle_response((x, y), locator)
- if model_name is not None:
- if model_name.startswith("anthropic") and not self.claude.authenticated:
- raise AutomationError("You need to provide Anthropic credentials to use Anthropic models.")
- if model_name.startswith("tars") and not self.tars.authenticated:
- raise AutomationError("You need to provide UI-TARS HF Endpoint credentials to use UI-TARS models.")
- if self.tars.authenticated and model_name == "tars":
- x, y = self.tars.locate_prediction(screenshot, locator)
+ if isinstance(model, str):
+ if model.startswith(ModelName.ANTHROPIC) and not self._claude.authenticated:
+ raise AutomationError(
+ "You need to provide Anthropic credentials to use Anthropic models."
+ )
+ if model.startswith(ModelName.TARS) and not self._tars.authenticated:
+ raise AutomationError(
+ "You need to provide UI-TARS HF Endpoint credentials to use UI-TARS models."
+ )
+ if self._tars.authenticated and model == ModelName.TARS:
+ x, y = self._tars.locate_prediction(
+ screenshot, self._serialize_locator(locator)
+ )
return handle_response((x, y), locator)
- if self.claude.authenticated and model_name == "anthropic-claude-3-5-sonnet-20241022":
+ if (
+ self._claude.authenticated
+ and isinstance(model, str) and model.startswith(ModelName.ANTHROPIC)
+ ):
logger.debug("Routing locate prediction to Anthropic")
- x, y = self.claude.locate_inference(screenshot, locator)
+ x, y = self._claude.locate_inference(
+ screenshot, self._serialize_locator(locator)
+ )
return handle_response((x, y), locator)
-
- for grounding_model_router in self.grounding_model_routers:
- if grounding_model_router.is_responsible(model_name) and grounding_model_router.is_authenticated():
- return grounding_model_router.locate(screenshot, locator, model_name)
- if model_name is None:
- if self.claude.authenticated:
+ for grounding_model_router in self._grounding_model_routers:
+ if (
+ grounding_model_router.is_responsible(model)
+ and grounding_model_router.is_authenticated()
+ ):
+ return grounding_model_router.locate(screenshot, locator, model)
+
+ if model is None:
+ if self._claude.authenticated:
logger.debug("Routing locate prediction to Anthropic")
- x, y = self.claude.locate_inference(screenshot, locator)
+ x, y = self._claude.locate_inference(
+ screenshot, self._serialize_locator(locator)
+ )
return handle_response((x, y), locator)
-
- raise AutomationError("Executing locate commands requires to authenticate with an Automation Model Provider.")
+
+ raise AutomationError(
+ "Executing locate commands requires to authenticate with an Automation Model Provider."
+ )
diff --git a/src/askui/models/types/__init__.py b/src/askui/models/types/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/askui/models/types/response_schemas.py b/src/askui/models/types/response_schemas.py
new file mode 100644
index 00000000..e9eba25c
--- /dev/null
+++ b/src/askui/models/types/response_schemas.py
@@ -0,0 +1,43 @@
+from typing import Type, TypeVar, overload
+from pydantic import BaseModel, ConfigDict, RootModel
+
+
+class ResponseSchemaBase(BaseModel):
+ model_config = ConfigDict(extra="forbid")
+
+
+String = RootModel[str]
+Boolean = RootModel[bool]
+Integer = RootModel[int]
+Float = RootModel[float]
+
+
+ResponseSchema = TypeVar('ResponseSchema', ResponseSchemaBase, str, bool, int, float)
+
+
+@overload
+def to_response_schema(response_schema: None) -> Type[String]: ...
+@overload
+def to_response_schema(response_schema: Type[str]) -> Type[String]: ...
+@overload
+def to_response_schema(response_schema: Type[bool]) -> Type[Boolean]: ...
+@overload
+def to_response_schema(response_schema: Type[int]) -> Type[Integer]: ...
+@overload
+def to_response_schema(response_schema: Type[float]) -> Type[Float]: ...
+@overload
+def to_response_schema(response_schema: Type[ResponseSchemaBase]) -> Type[ResponseSchemaBase]: ...
+def to_response_schema(response_schema: Type[ResponseSchemaBase] | Type[str] | Type[bool] | Type[int] | Type[float] | None = None) -> Type[ResponseSchemaBase] | Type[String] | Type[Boolean] | Type[Integer] | Type[Float]:
+ if response_schema is None:
+ return String
+ if response_schema is str:
+ return String
+ if response_schema is bool:
+ return Boolean
+ if response_schema is int:
+ return Integer
+ if response_schema is float:
+ return Float
+ if issubclass(response_schema, ResponseSchemaBase):
+ return response_schema
+ raise ValueError(f"Invalid response schema type: {response_schema}")
diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py
index 312e6a56..0bc97c96 100644
--- a/src/askui/models/ui_tars_ep/ui_tars_api.py
+++ b/src/askui/models/ui_tars_ep/ui_tars_api.py
@@ -1,18 +1,23 @@
import re
import os
import pathlib
-from typing import Union
+from typing import Any, Union
from openai import OpenAI
-from askui.utils import image_to_base64
+from askui.reporting import Reporter
+from askui.tools.agent_os import AgentOs
+from askui.utils.image_utils import image_to_base64
from PIL import Image
+
+from askui.utils.image_utils import ImageSource
from .prompts import PROMPT, PROMPT_QA
from .parser import UITarsEPMessage
import time
class UITarsAPIHandler:
- def __init__(self, report):
- self.report = report
+ def __init__(self, agent_os: AgentOs, reporter: Reporter):
+ self._agent_os = agent_os
+ self._reporter = reporter
if os.getenv("TARS_URL") is None or os.getenv("TARS_API_KEY") is None:
self.authenticated = False
else:
@@ -22,7 +27,7 @@ def __init__(self, report):
api_key=os.getenv("TARS_API_KEY")
)
- def predict(self, screenshot, instruction: str, prompt: str):
+ def _predict(self, image_url: str, instruction: str, prompt: str) -> Any:
chat_completion = self.client.chat.completions.create(
model="tgi",
messages=[
@@ -32,7 +37,7 @@ def predict(self, screenshot, instruction: str, prompt: str):
{
"type": "image_url",
"image_url": {
- "url": f"data:image/png;base64,{image_to_base64(screenshot)}"
+ "url": image_url,
}
},
{
@@ -55,7 +60,11 @@ def predict(self, screenshot, instruction: str, prompt: str):
def locate_prediction(self, image: Union[pathlib.Path, Image.Image], locator: str) -> tuple[int | None, int | None]:
askui_locator = f'Click on "{locator}"'
- prediction = self.predict(image, askui_locator, PROMPT)
+ prediction = self._predict(
+ image_url=f"data:image/png;base64,{image_to_base64(image)}",
+ instruction=askui_locator,
+ prompt=PROMPT,
+ )
pattern = r"click\(start_box='(\(\d+,\d+\))'\)"
match = re.search(pattern, prediction)
if match:
@@ -69,11 +78,15 @@ def locate_prediction(self, image: Union[pathlib.Path, Image.Image], locator: st
return x, y
return None, None
- def get_prediction(self, image: Image.Image, instruction: str) -> str:
- return self.predict(image, instruction, PROMPT_QA)
+ def get_inference(self, image: ImageSource, query: str) -> str:
+ return self._predict(
+ image_url=image.to_data_url(),
+ instruction=query,
+ prompt=PROMPT_QA,
+ )
- def act(self, controller_client, goal: str) -> str:
- screenshot = controller_client.screenshot()
+ def act(self, goal: str) -> None:
+ screenshot = self._agent_os.screenshot()
self.act_history = [
{
"role": "user",
@@ -91,10 +104,10 @@ def act(self, controller_client, goal: str) -> str:
]
}
]
- self.execute_act(controller_client, self.act_history)
+ self.execute_act(self.act_history)
- def add_screenshot_to_history(self, controller_client, message_history):
- screenshot = controller_client.screenshot()
+ def add_screenshot_to_history(self, message_history):
+ screenshot = self._agent_os.screenshot()
message_history.append(
{
"role": "user",
@@ -148,7 +161,7 @@ def filter_message_thread(self, message_history, max_screenshots=3):
return filtered_messages
- def execute_act(self, controller_client, message_history):
+ def execute_act(self, message_history):
message_history = self.filter_message_thread(message_history)
chat_completion = self.client.chat.completions.create(
@@ -166,8 +179,8 @@ def execute_act(self, controller_client, message_history):
raw_message = chat_completion.choices[-1].message.content
print(raw_message)
- if self.report is not None:
- self.report.add_message("UI-TARS", raw_message)
+ if self._reporter is not None:
+ self._reporter.add_message("UI-TARS", raw_message)
try:
message = UITarsEPMessage.parse_message(raw_message)
@@ -184,21 +197,21 @@ def execute_act(self, controller_client, message_history):
]
}
)
- self.execute_act(controller_client, message_history)
+ self.execute_act(message_history)
return
action = message.parsed_action
if action.action_type == "click":
- controller_client.mouse(action.start_box.x, action.start_box.y)
- controller_client.click("left")
+ self._agent_os.mouse(action.start_box.x, action.start_box.y)
+ self._agent_os.click("left")
time.sleep(1)
if action.action_type == "type":
- controller_client.click("left")
- controller_client.type(action.content)
+ self._agent_os.click("left")
+ self._agent_os.type(action.content)
time.sleep(0.5)
if action.action_type == "hotkey":
- controller_client.keyboard_pressed(action.content)
- controller_client.keyboard_release(action.content)
+ self._agent_os.keyboard_pressed(action.content)
+ self._agent_os.keyboard_release(action.content)
time.sleep(0.5)
if action.action_type == "call_user":
time.sleep(1)
@@ -207,5 +220,5 @@ def execute_act(self, controller_client, message_history):
if action.action_type == "finished":
return
- self.add_screenshot_to_history(controller_client, message_history)
- self.execute_act(controller_client, message_history)
\ No newline at end of file
+ self.add_screenshot_to_history(message_history)
+ self.execute_act(message_history)
diff --git a/src/askui/models/utils.py b/src/askui/models/utils.py
deleted file mode 100644
index a5f0cd43..00000000
--- a/src/askui/models/utils.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import re
-import base64
-
-from io import BytesIO
-from PIL import Image, ImageOps
-
-
-def scale_image_with_padding(image, max_width, max_height):
- original_width, original_height = image.size
- aspect_ratio = original_width / original_height
- if (max_width / max_height) > aspect_ratio:
- scale_factor = max_height / original_height
- else:
- scale_factor = max_width / original_width
- scaled_width = int(original_width * scale_factor)
- scaled_height = int(original_height * scale_factor)
- scaled_image = image.resize((scaled_width, scaled_height), Image.Resampling.LANCZOS)
- pad_left = (max_width - scaled_width) // 2
- pad_top = (max_height - scaled_height) // 2
- padded_image = ImageOps.expand(
- scaled_image,
- border=(pad_left, pad_top, max_width - scaled_width - pad_left, max_height - scaled_height - pad_top),
- fill=(0, 0, 0) # Black padding
- )
- return padded_image
-
-
-def scale_coordinates_back(x, y, original_width, original_height, max_width, max_height):
- aspect_ratio = original_width / original_height
- if (max_width / max_height) > aspect_ratio:
- scale_factor = max_height / original_height
- scaled_width = int(original_width * scale_factor)
- scaled_height = max_height
- else:
- scale_factor = max_width / original_width
- scaled_width = max_width
- scaled_height = int(original_height * scale_factor)
- pad_left = (max_width - scaled_width) // 2
- pad_top = (max_height - scaled_height) // 2
- adjusted_x = x - pad_left
- adjusted_y = y - pad_top
- if adjusted_x < 0 or adjusted_x > scaled_width or adjusted_y < 0 or adjusted_y > scaled_height:
- raise ValueError("Coordinates are outside the padded image area")
- original_x = adjusted_x / scale_factor
- original_y = adjusted_y / scale_factor
- return original_x, original_y
-
-
-def extract_click_coordinates(text: str):
- pattern = r'(\d+),\s*(\d+)'
- matches = re.findall(pattern, text)
- x, y = matches[-1]
- return int(x), int(y)
-
-
-def base64_to_image(base64_string):
- base64_string = base64_string.split(",")[1]
- while len(base64_string) % 4 != 0:
- base64_string += '='
- image_data = base64.b64decode(base64_string)
- image = Image.open(BytesIO(image_data))
- return image
-
-
-def image_to_base64(image):
- buffered = BytesIO()
- image.save(buffered, format="PNG")
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
- return img_str
diff --git a/src/askui/py.typed b/src/askui/py.typed
new file mode 100644
index 00000000..0519ecba
--- /dev/null
+++ b/src/askui/py.typed
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/src/askui/reporting/report.py b/src/askui/reporting.py
similarity index 80%
rename from src/askui/reporting/report.py
rename to src/askui/reporting.py
index accb9a76..c274fc80 100644
--- a/src/askui/reporting/report.py
+++ b/src/askui/reporting.py
@@ -1,7 +1,10 @@
+from abc import ABC, abstractmethod
from pathlib import Path
+import random
from jinja2 import Template
from datetime import datetime
-from typing import Any, List, Dict, Optional, Union, Callable
+from typing import List, Dict, Optional, Union
+from typing_extensions import override
import platform
import sys
from importlib.metadata import distributions
@@ -11,49 +14,96 @@
import json
-class SimpleReportGenerator:
- def __init__(self, report_dir: str = "reports", report_callback: Callable[[str | dict[str, Any]], None] | None = None) -> None:
+class Reporter(ABC):
+ @abstractmethod
+ def add_message(
+ self,
+ role: str,
+ content: Union[str, dict, list],
+ image: Optional[Image.Image | list[Image.Image]] = None,
+ ) -> None:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def generate(self) -> None:
+ raise NotImplementedError()
+
+
+class CompositeReporter(Reporter):
+ def __init__(self, reports: list[Reporter] | None = None) -> None:
+ self._reports = reports or []
+
+ @override
+ def add_message(
+ self,
+ role: str,
+ content: Union[str, dict, list],
+ image: Optional[Image.Image | list[Image.Image]] = None,
+ ) -> None:
+ for report in self._reports:
+ report.add_message(role, content, image)
+
+ @override
+ def generate(self) -> None:
+ for report in self._reports:
+ report.generate()
+
+
+class SimpleHtmlReporter(Reporter):
+ def __init__(self, report_dir: str = "reports") -> None:
self.report_dir = Path(report_dir)
self.report_dir.mkdir(exist_ok=True)
self.messages: List[Dict] = []
self.system_info = self._collect_system_info()
- self.report_callback = report_callback
def _collect_system_info(self) -> Dict[str, str]:
"""Collect system and Python information"""
return {
"platform": platform.platform(),
"python_version": sys.version.split()[0],
- "packages": sorted([f"{dist.metadata['Name']}=={dist.version}"
- for dist in distributions()])
+ "packages": sorted(
+ [f"{dist.metadata['Name']}=={dist.version}" for dist in distributions()]
+ ),
}
-
+
def _image_to_base64(self, image: Image.Image) -> str:
"""Convert PIL Image to base64 string"""
buffered = BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
-
+
def _format_content(self, content: Union[str, dict, list]) -> str:
"""Format content based on its type"""
if isinstance(content, (dict, list)):
return json.dumps(content, indent=2)
return str(content)
-
- def add_message(self, role: str, content: Union[str, dict, list], image: Optional[Image.Image] = None):
+
+ @override
+ def add_message(
+ self,
+ role: str,
+ content: Union[str, dict, list],
+ image: Optional[Image.Image | list[Image.Image]] = None,
+ ) -> None:
"""Add a message to the report, optionally with an image"""
+ if image is None:
+ _images = []
+ elif isinstance(image, list):
+ _images = image
+ else:
+ _images = [image]
+
message = {
"timestamp": datetime.now(),
"role": role,
"content": self._format_content(content),
"is_json": isinstance(content, (dict, list)),
- "image": self._image_to_base64(image) if image else None
+ "images": [self._image_to_base64(img) for img in _images],
}
self.messages.append(message)
- if self.report_callback is not None:
- self.report_callback(message)
- def generate_report(self) -> str:
+ @override
+ def generate(self) -> None:
"""Generate HTML report using a Jinja template"""
template_str = """
@@ -190,12 +240,12 @@ def generate_report(self) -> str:
{% else %}
{{ msg.content }}
{% endif %}
- {% if msg.image %}
+ {% for image in msg.images %}
-
- {% endif %}
+ {% endfor %}
{% endfor %}
@@ -203,14 +253,13 @@ def generate_report(self) -> str: