diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 41748453..5cd929ac 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -18,7 +18,7 @@ jobs: with: cache: true - run: pdm install - - run: pdm run typecheck:all - run: pdm run format --check - - run: pdm run lint + - run: pdm run typecheck:all + - run: pdm run lint --output-format=github - run: pdm run test:unit diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..132af3da --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + "version": "0.2.0", + "configurations": [ + + { + "name": "Debug Tests", + "type": "debugpy", + "request": "launch", + "module": "pytest", + "args": [ + "-vv", + ], + "console": "integratedTerminal" + }, + ] +} diff --git a/pdm.lock b/pdm.lock index 090558f7..bd737d38 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "chat", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:ee418ee4c04af70fff91a2b30c650ad1d04eb5e4bb4a1cbb47c2f43a1327a1cc" +content_hash = "sha256:fc9c6d1e6e03e722cea57c2a4ede898373b6ca551e8ab76321688873b5325d2d" [[metadata.targets]] requires_python = ">=3.10" @@ -1435,6 +1435,20 @@ files = [ {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +requires_python = ">=3.7" +summary = "pytest plugin to abort hanging tests" +groups = ["test"] +dependencies = [ + "pytest>=7.0.0", +] +files = [ + {file = "pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2"}, + {file = "pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a"}, +] + [[package]] name = "pytest-xdist" version = "3.6.1" diff --git a/pyproject.toml b/pyproject.toml index 47c9fc4f..24348b44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ test = [ "types-protobuf>=4.24.0.20240311", "grpc-stubs>=1.53.0.3", "types-pyperclip>=1.8.2.20240311", + "pytest-timeout>=2.4.0", ] chat = [ "streamlit>=1.42.0", diff --git a/src/askui/agent.py b/src/askui/agent.py index ddaecc8b..eb666f11 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -135,6 +135,7 @@ def _locate( self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})") return point + @telemetry.record_call(exclude={"locator", "screenshot"}) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def locate( self, @@ -475,13 +476,15 @@ def keyboard( self, key: PcKey | ModifierKey, modifier_keys: Optional[list[ModifierKey]] = None, + repeat: Annotated[int, Field(gt=0)] = 1, ) -> None: """ - Simulates pressing a key or key combination on the keyboard. + Simulates pressing (and releasing) a key or key combination on the keyboard. Args: key (PcKey | ModifierKey): The main key to press. This can be a letter, number, special character, or function key. modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to press along with the main key. Common modifier keys include `'ctrl'`, `'alt'`, `'shift'`. + repeat (int, optional): The number of times to press (and release) the key. Must be greater than `0`. Defaults to `1`. Example: ```python @@ -492,10 +495,18 @@ def keyboard( agent.keyboard('enter') # Press 'Enter' key agent.keyboard('v', ['control']) # Press Ctrl+V (paste) agent.keyboard('s', ['control', 'shift']) # Press Ctrl+Shift+S + agent.keyboard('a', repeat=2) # Press 'a' key twice ``` """ + msg = f"press and release key '{key}'" + if modifier_keys is not None: + modifier_keys_str = ' + '.join(f"'{key}'" for key in modifier_keys) + msg += f" with modifiers key{'s' if len(modifier_keys) > 1 else ''} {modifier_keys_str}" + if repeat > 1: + msg += f" {repeat}x times" + self._reporter.add_message("User", msg) logger.debug("VisionAgent received instruction to press '%s'", key) - self.tools.agent_os.keyboard_tap(key, modifier_keys) + self.tools.agent_os.keyboard_tap(key, modifier_keys, count=repeat) @telemetry.record_call(exclude={"command"}) @validate_call diff --git a/src/askui/models/anthropic/claude.py b/src/askui/models/anthropic/claude.py index 92a53820..be12b62a 100644 --- a/src/askui/models/anthropic/claude.py +++ b/src/askui/models/anthropic/claude.py @@ -1,5 +1,4 @@ import json -import os import anthropic from PIL import Image @@ -10,6 +9,7 @@ UnexpectedResponseToQueryError, ) from askui.logger import logger +from askui.models.anthropic.settings import ClaudeSettings from askui.utils.image_utils import ( ImageSource, image_to_base64, @@ -21,21 +21,19 @@ class ClaudeHandler: - def __init__(self) -> None: - self.model = "claude-3-5-sonnet-20241022" - self.client = anthropic.Anthropic() - self.resolution = (1280, 800) - self.authenticated = True - if os.getenv("ANTHROPIC_API_KEY") is None: - self.authenticated = False + def __init__(self, settings: ClaudeSettings) -> None: + self._settings = settings + self._client = anthropic.Anthropic( + api_key=self._settings.anthropic.api_key.get_secret_value() + ) def _inference( self, base64_image: str, prompt: str, system_prompt: str ) -> list[anthropic.types.ContentBlock]: - message = self.client.messages.create( - model=self.model, - max_tokens=1000, - temperature=0, + message = self._client.messages.create( + model=self._settings.model, + max_tokens=self._settings.max_tokens, + temperature=self._settings.temperature, system=system_prompt, messages=[ { @@ -58,7 +56,8 @@ def _inference( def locate_inference(self, image: Image.Image, locator: str) -> tuple[int, int]: prompt = f"Click on {locator}" - screen_width, screen_height = self.resolution[0], self.resolution[1] + screen_width = self._settings.resolution[0] + screen_height = self._settings.resolution[1] system_prompt = f"Use a mouse and keyboard to interact with a computer, and take screenshots.\n* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try taking another screenshot.\n* The screen's resolution is {screen_width}x{screen_height}.\n* The display number is 0\n* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges unless asked.\n" # noqa: E501 scaled_image = scale_image_with_padding(image, screen_width, screen_height) response = self._inference(image_to_base64(scaled_image), prompt, system_prompt) @@ -79,8 +78,8 @@ def locate_inference(self, image: Image.Image, locator: str) -> tuple[int, int]: def get_inference(self, image: ImageSource, query: str) -> str: scaled_image = scale_image_with_padding( image=image.root, - max_width=self.resolution[0], - max_height=self.resolution[1], + max_width=self._settings.resolution[0], + max_height=self._settings.resolution[1], ) system_prompt = "You are an agent to process screenshots and answer questions about things on the screen or extract information from it. Answer only with the response to the question and keep it short and precise." # noqa: E501 response = self._inference( diff --git a/src/askui/models/anthropic/claude_agent.py b/src/askui/models/anthropic/claude_agent.py index 564f4e69..ae1bb2b9 100644 --- a/src/askui/models/anthropic/claude_agent.py +++ b/src/askui/models/anthropic/claude_agent.py @@ -20,6 +20,9 @@ BetaToolUseBlockParam, ) +from askui.models.anthropic.settings import ( + ClaudeComputerAgentSettings, +) from askui.reporting import Reporter from askui.tools.agent_os import AgentOs @@ -27,8 +30,6 @@ from ...tools.anthropic import ComputerTool, ToolCollection, ToolResult from ...utils.str_utils import truncate_long_strings -COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" -PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" PC_KEY = [ "backspace", "delete", @@ -169,39 +170,41 @@ class ClaudeComputerAgent: - def __init__(self, agent_os: AgentOs, reporter: Reporter) -> None: + def __init__( + self, + agent_os: AgentOs, + reporter: Reporter, + settings: ClaudeComputerAgentSettings, + ) -> None: + self._settings = settings + self._client = Anthropic( + api_key=self._settings.anthropic.api_key.get_secret_value() + ) self._reporter = reporter - self.tool_collection = ToolCollection( + self._tool_collection = ToolCollection( ComputerTool(agent_os), ) - self.system = BetaTextBlockParam( + self._system = BetaTextBlockParam( type="text", text=f"{SYSTEM_PROMPT}", ) - self.enable_prompt_caching = False - self.betas = [COMPUTER_USE_BETA_FLAG] - self.image_truncation_threshold = 10 - self.only_n_most_recent_images = 3 - self.max_tokens = 4096 - self.client = Anthropic() - self.model = "claude-3-5-sonnet-20241022" def step(self, messages: list) -> list: - if self.only_n_most_recent_images: + if self._settings.only_n_most_recent_images: self._maybe_filter_to_n_most_recent_images( messages, - self.only_n_most_recent_images, - min_removal_threshold=self.image_truncation_threshold, + self._settings.only_n_most_recent_images, + min_removal_threshold=self._settings.image_truncation_threshold, ) try: - raw_response = self.client.beta.messages.with_raw_response.create( - max_tokens=self.max_tokens, + raw_response = self._client.beta.messages.with_raw_response.create( + max_tokens=self._settings.max_tokens, messages=messages, - model=self.model, - system=[self.system], - tools=self.tool_collection.to_params(), - betas=self.betas, + model=self._settings.model, + system=[self._system], + tools=self._tool_collection.to_params(), + betas=self._settings.betas, ) except (APIStatusError, APIResponseValidationError) as e: logger.error(e) @@ -224,7 +227,7 @@ def step(self, messages: list) -> list: tool_result_content: list[BetaToolResultBlockParam] = [] for content_block in response_params: if content_block["type"] == "tool_use": - result = self.tool_collection.run( + result = self._tool_collection.run( name=content_block["name"], tool_input=cast("dict[str, Any]", content_block["input"]), ) @@ -237,7 +240,7 @@ def step(self, messages: list) -> list: messages.append(another_new_message) return messages - def run(self, goal: str) -> None: + def act(self, goal: str) -> None: messages = [{"role": "user", "content": goal}] logger.debug(messages[0]) while messages[-1]["role"] == "user": diff --git a/src/askui/models/anthropic/settings.py b/src/askui/models/anthropic/settings.py new file mode 100644 index 00000000..22781f82 --- /dev/null +++ b/src/askui/models/anthropic/settings.py @@ -0,0 +1,29 @@ +from pydantic import BaseModel, Field, SecretStr +from pydantic_settings import BaseSettings + +COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" + + +class AnthropicSettings(BaseSettings): + api_key: SecretStr = Field( + min_length=1, + validation_alias="ANTHROPIC_API_KEY", + ) + + +class ClaudeSettingsBase(BaseModel): + anthropic: AnthropicSettings = Field(default_factory=AnthropicSettings) + model: str = "claude-3-5-sonnet-20241022" + + +class ClaudeSettings(ClaudeSettingsBase): + resolution: tuple[int, int] = Field(default_factory=lambda: (1280, 800)) + max_tokens: int = 1000 + temperature: float = 0.0 + + +class ClaudeComputerAgentSettings(ClaudeSettingsBase): + max_tokens: int = 4096 + only_n_most_recent_images: int = 3 + image_truncation_threshold: int = 10 + betas: list[str] = Field(default_factory=lambda: [COMPUTER_USE_BETA_FLAG]) diff --git a/src/askui/models/askui/api.py b/src/askui/models/askui/api.py index 4b3f066e..9da69089 100644 --- a/src/askui/models/askui/api.py +++ b/src/askui/models/askui/api.py @@ -1,12 +1,13 @@ import base64 import json as json_lib -import os import pathlib +from functools import cached_property from typing import Any, Type, Union import requests from PIL import Image -from pydantic import RootModel +from pydantic import UUID4, Field, HttpUrl, RootModel, SecretStr +from pydantic_settings import BaseSettings from askui.locators.locators import Locator from askui.locators.serializers import AskUiLocatorSerializer @@ -15,46 +16,50 @@ from askui.utils.image_utils import ImageSource, image_to_base64 from ..types.response_schemas import ResponseSchema, to_response_schema -from .exceptions import ApiResponseError, TokenNotSetError +from .exceptions import ApiResponseError -class AskUiInferenceApi: - def __init__(self, locator_serializer: AskUiLocatorSerializer): - self._locator_serializer = locator_serializer - self.inference_endpoint = os.getenv( - "ASKUI_INFERENCE_ENDPOINT", "https://inference.askui.com" - ) - self.workspace_id = os.getenv("ASKUI_WORKSPACE_ID") - self.token = os.getenv("ASKUI_TOKEN") - self.authenticated = True - if self.workspace_id is None or self.token is None: - logger.warning("ASKUI_WORKSPACE_ID or ASKUI_TOKEN missing.") - self.authenticated = False +class AskUiSettings(BaseSettings): + """Settings for AskUI API.""" - def _build_askui_token_auth_header( - self, bearer_token: str | None = None - ) -> dict[str, str]: - if bearer_token is not None: - return {"Authorization": f"Bearer {bearer_token}"} + inference_endpoint: HttpUrl = Field( + default_factory=lambda: HttpUrl("https://inference.askui.com"), + validation_alias="ASKUI_INFERENCE_ENDPOINT", + ) + workspace_id: UUID4 = Field( + validation_alias="ASKUI_WORKSPACE_ID", + ) + token: SecretStr = Field( + validation_alias="ASKUI_TOKEN", + ) - if self.token is None: - raise TokenNotSetError - token_base64 = base64.b64encode(self.token.encode("utf-8")).decode("utf-8") - return {"Authorization": f"Basic {token_base64}"} + @cached_property + def authorization_header(self) -> str: + token_str = self.token.get_secret_value() + token_base64 = base64.b64encode(token_str.encode()).decode() + return f"Basic {token_base64}" - def _build_base_url(self, endpoint: str) -> str: - return ( - f"{self.inference_endpoint}/api/v3/workspaces/" - f"{self.workspace_id}/{endpoint}" - ) + @cached_property + def base_url(self) -> str: + return f"{self.inference_endpoint}/api/v1/workspaces/{self.workspace_id}" + + +class AskUiInferenceApi: + def __init__( + self, + locator_serializer: AskUiLocatorSerializer, + settings: AskUiSettings, + ) -> None: + self._locator_serializer = locator_serializer + self._settings = settings def _request(self, endpoint: str, json: dict[str, Any] | None = None) -> Any: response = requests.post( - self._build_base_url(endpoint), + f"{self._settings.base_url}/{endpoint}", json=json, headers={ "Content-Type": "application/json", - **self._build_askui_token_auth_header(), + "Authorization": self._settings.authorization_header, }, timeout=30, ) diff --git a/src/askui/models/exceptions.py b/src/askui/models/exceptions.py new file mode 100644 index 00000000..baa3419f --- /dev/null +++ b/src/askui/models/exceptions.py @@ -0,0 +1,15 @@ +from askui.exceptions import AutomationError +from askui.models.models import ModelComposition + + +class InvalidModelError(AutomationError): + """Exception raised when an invalid model is used. + + Args: + model (str | ModelComposition): The model that was used. + """ + + def __init__(self, model: str | ModelComposition): + self.model = model + model_str = model if isinstance(model, str) else model.model_dump_json() + super().__init__(f"Invalid model: {model_str}") diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 553659b1..9949dcd1 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -8,7 +8,6 @@ class ModelName(str, Enum): ANTHROPIC__CLAUDE__3_5__SONNET__20241022 = "anthropic-claude-3-5-sonnet-20241022" - ANTHROPIC = "anthropic" ASKUI = "askui" ASKUI__AI_ELEMENT = "askui-ai-element" ASKUI__COMBO = "askui-combo" diff --git a/src/askui/models/router.py b/src/askui/models/router.py index 4d411417..c5145433 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from functools import cached_property from typing import Type from PIL import Image @@ -7,7 +8,13 @@ from askui.container import telemetry from askui.locators.locators import AiElement, Locator, Prompt, Text from askui.locators.serializers import AskUiLocatorSerializer, VlmLocatorSerializer +from askui.models.anthropic.settings import ( + AnthropicSettings, + ClaudeComputerAgentSettings, + ClaudeSettings, +) from askui.models.askui.ai_element_utils import AiElementCollection +from askui.models.exceptions import InvalidModelError from askui.models.models import ModelComposition, ModelName from askui.models.types.response_schemas import ResponseSchema from askui.reporting import CompositeReporter, Reporter @@ -18,9 +25,9 @@ from ..logger import logger from .anthropic.claude import ClaudeHandler from .anthropic.claude_agent import ClaudeComputerAgent -from .askui.api import AskUiInferenceApi +from .askui.api import AskUiInferenceApi, AskUiSettings from .huggingface.spaces_api import HFSpacesHandler -from .ui_tars_ep.ui_tars_api import UITarsAPIHandler +from .ui_tars_ep.ui_tars_api import UiTarsApiHandler, UiTarsApiHandlerSettings Point = tuple[int, int] """ @@ -52,10 +59,6 @@ def locate( def is_responsible(self, model: ModelComposition | str | None = None) -> bool: pass - @abstractmethod - def is_authenticated(self) -> bool: - pass - class AskUiModelRouter(GroundingModelRouter): def __init__(self, inference_api: AskUiInferenceApi): @@ -75,12 +78,6 @@ def locate( locator: str | Locator, model: ModelComposition | str | None = None, ) -> Point: - if not self._inference_api.authenticated: - error_msg = ( - "NoAskUIAuthenticationSet! Please set 'AskUI ASKUI_WORKSPACE_ID' or " - "'ASKUI_TOKEN' as env variables!" - ) - raise AutomationError(error_msg) if not isinstance(model, str) or model == ModelName.ASKUI: logger.debug("Routing locate prediction to askui") locator = Text(locator) if isinstance(locator, str) else locator @@ -113,16 +110,17 @@ def locate( _locator = AiElement(locator) x, y = self._inference_api.predict(screenshot, _locator) return handle_response((x, y), _locator) - error_msg = f'Invalid model: "{model}"' - raise AutomationError(error_msg) + raise InvalidModelError(model) @override def is_responsible(self, model: ModelComposition | str | None = None) -> bool: - return not isinstance(model, str) or model.startswith(ModelName.ASKUI) - - @override - def is_authenticated(self) -> bool: - return self._inference_api.authenticated + return not isinstance(model, str) or model in [ + ModelName.ASKUI, + ModelName.ASKUI__AI_ELEMENT, + ModelName.ASKUI__OCR, + ModelName.ASKUI__COMBO, + ModelName.ASKUI__PTA, + ] class ModelRouter: @@ -131,36 +129,99 @@ def __init__( tools: AgentToolbox, grounding_model_routers: list[GroundingModelRouter] | None = None, reporter: Reporter | None = None, + anthropic_settings: AnthropicSettings | None = None, + askui_inference_api: AskUiInferenceApi | None = None, + askui_settings: AskUiSettings | None = None, + claude: ClaudeHandler | None = None, + claude_computer_agent: ClaudeComputerAgent | None = None, + huggingface_spaces: HFSpacesHandler | None = None, + tars: UiTarsApiHandler | None = None, ): - _reporter = reporter or CompositeReporter() - self._askui = AskUiInferenceApi( + self._tools = tools + self._reporter = reporter or CompositeReporter() + self._grounding_model_routers_base = grounding_model_routers + self._anthropic_settings_base = anthropic_settings + self._askui_inference_api_base = askui_inference_api + self._askui_settings_base = askui_settings + self._claude_base = claude + self._claude_computer_agent_base = claude_computer_agent + self._huggingface_spaces = huggingface_spaces or HFSpacesHandler() + self._tars_base = tars + self._locator_serializer = VlmLocatorSerializer() + + @cached_property + def _anthropic_settings(self) -> AnthropicSettings: + if self._anthropic_settings_base is not None: + return self._anthropic_settings_base + return AnthropicSettings() + + @cached_property + def _askui_inference_api(self) -> AskUiInferenceApi: + if self._askui_inference_api_base is not None: + return self._askui_inference_api_base + return AskUiInferenceApi( locator_serializer=AskUiLocatorSerializer( ai_element_collection=AiElementCollection(), - reporter=_reporter, + reporter=self._reporter, ), + settings=self._askui_settings, + ) + + @cached_property + def _askui_settings(self) -> AskUiSettings: + if self._askui_settings_base is not None: + return self._askui_settings_base + return AskUiSettings() + + @cached_property + def _claude(self) -> ClaudeHandler: + if self._claude_base is not None: + return self._claude_base + claude_settings = ClaudeSettings( + anthropic=self._anthropic_settings, ) - self._grounding_model_routers = grounding_model_routers or [ - AskUiModelRouter(inference_api=self._askui) + return ClaudeHandler(settings=claude_settings) + + @cached_property + def _claude_computer_agent(self) -> ClaudeComputerAgent: + if self._claude_computer_agent_base is not None: + return self._claude_computer_agent_base + claude_computer_agent_settings = ClaudeComputerAgentSettings( + anthropic=self._anthropic_settings, + ) + return ClaudeComputerAgent( + agent_os=self._tools.agent_os, + reporter=self._reporter, + settings=claude_computer_agent_settings, + ) + + @cached_property + def _grounding_model_routers(self) -> list[GroundingModelRouter]: + return self._grounding_model_routers_base or [ + AskUiModelRouter(inference_api=self._askui_inference_api) ] - self._claude = ClaudeHandler() - self._huggingface_spaces = HFSpacesHandler() - self._tars = UITarsAPIHandler(agent_os=tools.agent_os, reporter=_reporter) - self._claude_computer_agent = ClaudeComputerAgent( - agent_os=tools.agent_os, reporter=_reporter + + @cached_property + def _tars(self) -> UiTarsApiHandler: + if self._tars_base is not None: + return self._tars_base + tars_settings = UiTarsApiHandlerSettings() + return UiTarsApiHandler( + agent_os=self._tools.agent_os, + reporter=self._reporter, + settings=tars_settings, ) - self._locator_serializer = VlmLocatorSerializer() def act(self, goal: str, model: ModelComposition | str | None = None) -> None: - if self._tars.authenticated and model == ModelName.TARS: - self._tars.act(goal) - if self._claude.authenticated and ( - model is None - or isinstance(model, str) - and model.startswith(ModelName.ANTHROPIC) - ): - self._claude_computer_agent.run(goal) - error_msg = f"Invalid model for act: {model}" - raise AutomationError(error_msg) + if model == ModelName.TARS: + logger.debug(f"Routing act prediction to {ModelName.TARS}") + return self._tars.act(goal) + if model == ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 or model is None: + logger.debug( + f"Routing act prediction to {ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022}" # noqa: E501 + ) + return self._claude_computer_agent.act(goal) + raise InvalidModelError(model) def get_inference( self, @@ -169,35 +230,31 @@ def get_inference( response_schema: Type[ResponseSchema] | None = None, model: ModelComposition | str | None = None, ) -> ResponseSchema | str: - if self._tars.authenticated and model == ModelName.TARS: - if response_schema not in [str, None]: - error_msg = ( - "(Non-String) Response schema is not yet supported for " - "UI-TARS models." - ) - raise NotImplementedError(error_msg) + if model in [ + ModelName.TARS, + ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, + ] and response_schema not in [str, None]: + error_msg = ( + "(Non-String) Response schema is not yet supported for " + f'"{model}" model.' + ) + raise NotImplementedError(error_msg) + if model == ModelName.TARS: + logger.debug(f"Routing get inference to {ModelName.TARS}") return self._tars.get_inference(image=image, query=query) - if self._claude.authenticated and ( - isinstance(model, str) and model.startswith(ModelName.ANTHROPIC) - ): - if response_schema not in [str, None]: - error_msg = ( - "(Non-String) Response schema is not yet supported for " - "Anthropic models." - ) - raise NotImplementedError(error_msg) + if model == ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022: + logger.debug( + f"Routing get inference to {ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022}" # noqa: E501 + ) return self._claude.get_inference(image=image, query=query) - if self._askui.authenticated and (model == ModelName.ASKUI or model is None): - return self._askui.get_inference( + if model == ModelName.ASKUI or model is None: + logger.debug(f"Routing get inference to {ModelName.ASKUI}") + return self._askui_inference_api.get_inference( image=image, query=query, response_schema=response_schema, ) - error_msg = ( - "Executing get commands requires to authenticate with an Automation " - f"Model Provider supporting it: {model}" - ) - raise AutomationError(error_msg) + raise InvalidModelError(model) def _serialize_locator(self, locator: str | Locator) -> str: if isinstance(locator, Locator): @@ -211,63 +268,53 @@ def locate( # noqa: C901 locator: str | Locator, model: ModelComposition | str | None = None, ) -> Point: - x: int | None = None - y: int | None = None - if ( - isinstance(model, str) - and model in self._huggingface_spaces.get_spaces_names() - ): - x, y = self._huggingface_spaces.predict( + point: tuple[int | None, int | None] | None = None + if model in self._huggingface_spaces.get_spaces_names(): + logger.debug(f"Routing locate prediction to {model}") + point = self._huggingface_spaces.predict( screenshot=screenshot, locator=self._serialize_locator(locator), - model_name=model, + model_name=model, # type: ignore ) - return handle_response((x, y), locator) - if isinstance(model, str): - if model.startswith(ModelName.ANTHROPIC) and not self._claude.authenticated: - error_msg = ( - "You need to provide Anthropic credentials to use Anthropic models." - ) - raise AutomationError(error_msg) - if model.startswith(ModelName.TARS) and not self._tars.authenticated: - error_msg = ( - "You need to provide UI-TARS HF Endpoint credentials to use " - "UI-TARS models." - ) - raise AutomationError(error_msg) - if self._tars.authenticated and model == ModelName.TARS: - x, y = self._tars.locate_prediction( + return handle_response(point, locator) + if model == ModelName.TARS: + logger.debug(f"Routing locate prediction to {ModelName.TARS}") + point = self._tars.locate_prediction( screenshot, self._serialize_locator(locator) ) - return handle_response((x, y), locator) - if ( - self._claude.authenticated - and isinstance(model, str) - and model.startswith(ModelName.ANTHROPIC) - ): - logger.debug("Routing locate prediction to Anthropic") - x, y = self._claude.locate_inference( + return handle_response(point, locator) + if model == ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022: + logger.debug( + f"Routing locate prediction to {ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022}" # noqa: E501 + ) + point = self._claude.locate_inference( screenshot, self._serialize_locator(locator) ) - return handle_response((x, y), locator) - - for grounding_model_router in self._grounding_model_routers: - if ( - grounding_model_router.is_responsible(model) - and grounding_model_router.is_authenticated() - ): - return grounding_model_router.locate(screenshot, locator, model) - - if model is None: - if self._claude.authenticated: - logger.debug("Routing locate prediction to Anthropic") - x, y = self._claude.locate_inference( - screenshot, self._serialize_locator(locator) - ) - return handle_response((x, y), locator) + return handle_response(point, locator) + point = self._try_locating_using_grounding_model(screenshot, locator, model) + if point: + return handle_response(point, locator) + if not point and model is None: + logger.debug( + f"Routing locate prediction to {ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022}" # noqa: E501 + ) + point = self._claude.locate_inference( + screenshot, self._serialize_locator(locator) + ) + return handle_response(point, locator) + raise InvalidModelError(model) - error_msg = ( - "Executing locate commands requires to authenticate with an " - "Automation Model Provider." - ) - raise AutomationError(error_msg) + def _try_locating_using_grounding_model( + self, + screenshot: Image.Image, + locator: str | Locator, + model: ModelComposition | str | None = None, + ) -> Point | None: + try: + for grounding_model_router in self._grounding_model_routers: + if grounding_model_router.is_responsible(model): + return grounding_model_router.locate(screenshot, locator, model) + except (InvalidModelError, ValueError): + if model is not None: + raise + return None diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index 0bb34ce1..24018c97 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -1,4 +1,3 @@ -import os import pathlib import re import time @@ -6,6 +5,8 @@ from openai import OpenAI from PIL import Image +from pydantic import Field, HttpUrl, SecretStr +from pydantic_settings import BaseSettings from askui.exceptions import NoResponseToQueryError from askui.reporting import Reporter @@ -16,20 +17,35 @@ from .prompts import PROMPT, PROMPT_QA -class UITarsAPIHandler: - def __init__(self, agent_os: AgentOs, reporter: Reporter) -> None: +class UiTarsApiHandlerSettings(BaseSettings): + """Settings for TARS API.""" + + tars_url: HttpUrl = Field( + validation_alias="TARS_URL", + ) + tars_api_key: SecretStr = Field( + min_length=1, + validation_alias="TARS_API_KEY", + ) + + +class UiTarsApiHandler: + def __init__( + self, + agent_os: AgentOs, + reporter: Reporter, + settings: UiTarsApiHandlerSettings, + ) -> None: self._agent_os = agent_os self._reporter = reporter - if os.getenv("TARS_URL") is None or os.getenv("TARS_API_KEY") is None: - self.authenticated = False - else: - self.authenticated = True - self.client = OpenAI( - base_url=os.getenv("TARS_URL"), api_key=os.getenv("TARS_API_KEY") - ) + self._settings = settings + self._client = OpenAI( + api_key=self._settings.tars_api_key.get_secret_value(), + base_url=str(self._settings.tars_url), + ) def _predict(self, image_url: str, instruction: str, prompt: str) -> str | None: - chat_completion = self.client.chat.completions.create( + chat_completion = self._client.chat.completions.create( model="tgi", messages=[ { @@ -173,7 +189,7 @@ def filter_message_thread( def execute_act(self, message_history: list[dict[str, Any]]) -> None: message_history = self.filter_message_thread(message_history) - chat_completion = self.client.chat.completions.create( + chat_completion = self._client.chat.completions.create( model="tgi", messages=message_history, top_p=None, diff --git a/src/askui/telemetry/processors.py b/src/askui/telemetry/processors.py index 695fb733..88560bfd 100644 --- a/src/askui/telemetry/processors.py +++ b/src/askui/telemetry/processors.py @@ -3,7 +3,7 @@ from typing import Any, TypedDict import httpx -from pydantic import BaseModel, HttpUrl +from pydantic import BaseModel, Field, HttpUrl from askui.logger import logger from askui.telemetry.context import TelemetryContext @@ -30,8 +30,12 @@ class TelemetryEvent(TypedDict): class SegmentSettings(BaseModel): - api_url: HttpUrl = HttpUrl("https://tracking.askui.com/v1") + api_url: HttpUrl = Field( + default_factory=lambda: HttpUrl("https://tracking.askui.com") + ) write_key: str = "Iae4oWbOo509Acu5ZeEb2ihqSpemjnhY" + timeout: int = 10 + max_retries: int = 3 class Segment(TelemetryProcessor): @@ -42,6 +46,9 @@ def __init__(self, settings: SegmentSettings) -> None: self._analytics = analytics self._analytics.write_key = settings.write_key + self._analytics.host = settings.api_url.encoded_string() + self._analytics.timeout = settings.timeout + self._analytics.max_retries = settings.max_retries def record_event( self, @@ -80,7 +87,7 @@ def record_event( logger.debug(f'Failed to track event "{name}" using Segment: {e}') def flush(self) -> None: - self._analytics.flush() + self._analytics.shutdown() class InMemoryProcessor(TelemetryProcessor): diff --git a/src/askui/telemetry/telemetry.py b/src/askui/telemetry/telemetry.py index 62809daf..e24d297f 100644 --- a/src/askui/telemetry/telemetry.py +++ b/src/askui/telemetry/telemetry.py @@ -96,8 +96,22 @@ def __init__(self, settings: TelemetrySettings) -> None: self._call_stack = CallStack() self._context = self._init_context() + def set_processors(self, processors: list[TelemetryProcessor]) -> None: + """Set the telemetry processors that will be called in order + + *IMPORTANT*: This will replace the existing processors. + + Args: + processors (list[TelemetryProcessor]): The list of telemetry processors to set (may be empty) + """ + self._processors = processors + def add_processor(self, processor: TelemetryProcessor) -> None: - """Add a telemetry processor that will be called in order of addition""" + """Add a telemetry processor that will be called in order of addition + + Args: + processor (TelemetryProcessor): The telemetry processor to add + """ self._processors.append(processor) def _init_context(self) -> TelemetryContext: diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index 0b3b4e51..a88ddbc1 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -276,7 +276,10 @@ def keyboard_release( @abstractmethod def keyboard_tap( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + self, + key: PcKey | ModifierKey, + modifier_keys: list[ModifierKey] | None = None, + count: int = 1, ) -> None: """ Simulates pressing and immediately releasing a keyboard key. @@ -285,6 +288,7 @@ def keyboard_tap( key (PcKey | ModifierKey): The key to tap. modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to press along with the main key. Defaults to `None`. + count (int, optional): The number of times to tap the key. Defaults to `1`. """ raise NotImplementedError diff --git a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index 342ba20b..9fd1b4d1 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -648,7 +648,10 @@ def keyboard_release( @telemetry.record_call() @override def keyboard_tap( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + self, + key: PcKey | ModifierKey, + modifier_keys: list[ModifierKey] | None = None, + count: int = 1, ) -> None: """ Press and immediately release a keyboard key. @@ -657,18 +660,23 @@ def keyboard_tap( key (PcKey | ModifierKey): The key to tap. modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to press along with the main key. Defaults to `None`. + count (int, optional): The number of times to tap the key. Defaults to `1`. """ - self._reporter.add_message("AgentOS", f'keyboard_tap("{key}", {modifier_keys})') + self._reporter.add_message( + "AgentOS", + f'keyboard_tap("{key}", {modifier_keys}, {count})', + ) if modifier_keys is None: modifier_keys = [] - self._run_recorder_action( - acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_PressAndRelease, - action_parameters=controller_v1_pbs.ActionParameters( - keyboardKeyPressAndRelease=controller_v1_pbs.ActionParameters_KeyboardKey_PressAndRelease( - keyName=key, modifierKeyNames=modifier_keys - ) - ), - ) + for _ in range(count): + self._run_recorder_action( + acion_class_id=controller_v1_pbs.ActionClassID_KeyboardKey_PressAndRelease, + action_parameters=controller_v1_pbs.ActionParameters( + keyboardKeyPressAndRelease=controller_v1_pbs.ActionParameters_KeyboardKey_PressAndRelease( + keyName=key, modifierKeyNames=modifier_keys + ) + ), + ) @telemetry.record_call() @override diff --git a/tests/conftest.py b/tests/conftest.py index d61c0e4e..4dd2b56c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,15 @@ def path_fixtures_images(path_fixtures: pathlib.Path) -> pathlib.Path: return path_fixtures / "images" +@pytest.fixture +def github_login_screenshot(path_fixtures: pathlib.Path) -> Image.Image: + """Fixture providing the GitHub login screenshot.""" + screenshot_path = ( + path_fixtures / "screenshots" / "macos__chrome__github_com__login.png" + ) + return Image.open(screenshot_path) + + @pytest.fixture def path_fixtures_github_com__icon(path_fixtures_images: pathlib.Path) -> pathlib.Path: """Fixture providing the path to the github com icon image.""" @@ -51,3 +60,10 @@ def model_router_mock(mocker: MockerFixture) -> ModelRouter: "Mock response" # Return fixed response for all get_inference calls ) return cast("ModelRouter", mock) + + +@pytest.fixture(autouse=True) +def disable_telemetry() -> None: + from askui.container import telemetry + + telemetry.set_processors([]) diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index 784fb9df..a2d0bdf1 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -10,7 +10,7 @@ from askui.agent import VisionAgent from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import AiElementCollection -from askui.models.askui.api import AskUiInferenceApi +from askui.models.askui.api import AskUiInferenceApi, AskUiSettings from askui.models.router import AskUiModelRouter, ModelRouter from askui.reporting import Reporter, SimpleHtmlReporter from askui.tools.toolbox import AgentToolbox @@ -43,7 +43,10 @@ def vision_agent( serializer = AskUiLocatorSerializer( ai_element_collection=ai_element_collection, reporter=reporter ) - inference_api = AskUiInferenceApi(locator_serializer=serializer) + inference_api = AskUiInferenceApi( + locator_serializer=serializer, + settings=AskUiSettings(), + ) model_router = ModelRouter( tools=agent_toolbox_mock, reporter=reporter, @@ -53,12 +56,3 @@ def vision_agent( reporters=[reporter], model_router=model_router, tools=agent_toolbox_mock ) as agent: yield agent - - -@pytest.fixture -def github_login_screenshot(path_fixtures: pathlib.Path) -> PILImage.Image: - """Fixture providing the GitHub login screenshot.""" - screenshot_path = ( - path_fixtures / "screenshots" / "macos__chrome__github_com__login.png" - ) - return PILImage.open(screenshot_path) diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index f65fad53..ecab67b5 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -20,7 +20,9 @@ class BrowserContextResponse(ResponseSchemaBase): browser_type: Literal["chrome", "firefox", "edge", "safari"] -@pytest.mark.parametrize("model", [None, ModelName.ASKUI, ModelName.ANTHROPIC]) +@pytest.mark.parametrize( + "model", [None, ModelName.ASKUI, ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022] +) def test_get( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, @@ -91,7 +93,7 @@ def test_get_with_response_schema_with_anthropic_model_raises_not_implemented( "What is the current url shown in the url bar?", image=github_login_screenshot, response_schema=UrlResponse, - model=ModelName.ANTHROPIC, + model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ) diff --git a/tests/e2e/test_telemetry.py b/tests/e2e/test_telemetry.py new file mode 100644 index 00000000..dceb13be --- /dev/null +++ b/tests/e2e/test_telemetry.py @@ -0,0 +1,30 @@ +import logging + +import pytest +from PIL import Image + +from askui import locators as loc +from askui.agent import VisionAgent +from askui.container import telemetry +from askui.telemetry.processors import Segment, SegmentSettings +from askui.tools.toolbox import AgentToolbox + + +@pytest.mark.timeout(60) +def test_telemetry_with_nonexistent_domain_should_not_block( + github_login_screenshot: Image.Image, + agent_toolbox_mock: AgentToolbox, +) -> None: + telemetry.set_processors( + [ + Segment( + SegmentSettings( + api_url="https://this-domain-does-not-exist-123456789.com", + write_key="1234567890", + ) + ) + ] + ) + with VisionAgent(tools=agent_toolbox_mock, log_level=logging.DEBUG) as agent: + agent.locate(loc.Text(), screenshot=github_login_screenshot) + assert True diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a4699d3a..944f7b69 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -3,5 +3,4 @@ @pytest.fixture(autouse=True) def set_env_variable(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("ASKUI__VA__TELEMETRY__ENABLED", "False") monkeypatch.setenv("ASKUI_WORKSPACE_ID", "test_workspace_id") diff --git a/tests/unit/models/test_router.py b/tests/unit/models/test_router.py new file mode 100644 index 00000000..d06ec0db --- /dev/null +++ b/tests/unit/models/test_router.py @@ -0,0 +1,422 @@ +"""Unit tests for the ModelRouter class.""" + +import os +import uuid +from typing import cast +from unittest.mock import MagicMock, patch + +import pytest +from PIL import Image +from pydantic import ValidationError +from pytest_mock import MockerFixture + +from askui.models.anthropic.claude import ClaudeHandler +from askui.models.anthropic.claude_agent import ClaudeComputerAgent +from askui.models.askui.api import AskUiInferenceApi +from askui.models.exceptions import InvalidModelError +from askui.models.huggingface.spaces_api import HFSpacesHandler +from askui.models.models import ModelName +from askui.models.router import ModelRouter +from askui.models.types.response_schemas import ResponseSchemaBase +from askui.models.ui_tars_ep.ui_tars_api import UiTarsApiHandler +from askui.reporting import CompositeReporter +from askui.tools.toolbox import AgentToolbox +from askui.utils.image_utils import ImageSource + +# Test UUID for workspace_id +TEST_WORKSPACE_ID = uuid.uuid4() + + +@pytest.fixture +def mock_image() -> Image.Image: + """Fixture providing a mock PIL Image.""" + return Image.new("RGB", (100, 100)) + + +@pytest.fixture +def mock_image_source(mock_image: Image.Image) -> ImageSource: + """Fixture providing a mock ImageSource.""" + return ImageSource(root=mock_image) + + +@pytest.fixture +def mock_askui_inference_api(mocker: MockerFixture) -> AskUiInferenceApi: + """Fixture providing a mock AskUI inference API.""" + mock = cast("AskUiInferenceApi", mocker.MagicMock(spec=AskUiInferenceApi)) + mock.predict.return_value = (50, 50) # type: ignore[attr-defined] + mock.get_inference.return_value = "Mock response" # type: ignore[attr-defined] + return mock + + +@pytest.fixture +def mock_claude(mocker: MockerFixture) -> ClaudeHandler: + """Fixture providing a mock Claude handler.""" + mock = cast("ClaudeHandler", mocker.MagicMock(spec=ClaudeHandler)) + mock.locate_inference.return_value = (50, 50) # type: ignore[attr-defined] + mock.get_inference.return_value = "Mock response" # type: ignore[attr-defined] + return mock + + +@pytest.fixture +def mock_claude_agent(mocker: MockerFixture) -> ClaudeComputerAgent: + """Fixture providing a mock Claude computer agent.""" + mock = cast("ClaudeComputerAgent", mocker.MagicMock(spec=ClaudeComputerAgent)) + mock.act = MagicMock(return_value=None) # type: ignore[method-assign] + return mock + + +@pytest.fixture +def mock_tars(mocker: MockerFixture) -> UiTarsApiHandler: + """Fixture providing a mock TARS API handler.""" + mock = cast("UiTarsApiHandler", mocker.MagicMock(spec=UiTarsApiHandler)) + mock.locate_prediction.return_value = (50, 50) # type: ignore[attr-defined] + mock.get_inference.return_value = "Mock response" # type: ignore[attr-defined] + mock.act = MagicMock(return_value=None) # type: ignore[method-assign] + return mock + + +@pytest.fixture +def mock_hf_spaces(mocker: MockerFixture) -> HFSpacesHandler: + """Fixture providing a mock HuggingFace spaces handler.""" + mock = cast("HFSpacesHandler", mocker.MagicMock(spec=HFSpacesHandler)) + mock.predict.return_value = (50, 50) # type: ignore[attr-defined] + mock.get_spaces_names.return_value = ["hf-space-1", "hf-space-2"] # type: ignore[attr-defined] + return mock + + +@pytest.fixture +def model_router( + agent_toolbox_mock: AgentToolbox, + mock_askui_inference_api: AskUiInferenceApi, + mock_claude: ClaudeHandler, + mock_claude_agent: ClaudeComputerAgent, + mock_tars: UiTarsApiHandler, + mock_hf_spaces: HFSpacesHandler, + mocker: MockerFixture, +) -> ModelRouter: + """Fixture providing a ModelRouter instance with mocked dependencies.""" + return ModelRouter( + tools=agent_toolbox_mock, + reporter=CompositeReporter(), + askui_inference_api=mock_askui_inference_api, + claude=mock_claude, + claude_computer_agent=mock_claude_agent, + tars=mock_tars, + huggingface_spaces=mock_hf_spaces, + askui_settings=mocker.MagicMock(workspace_id=TEST_WORKSPACE_ID), + ) + + +class TestModelRouter: + """Test class for ModelRouter.""" + + def test_locate_with_askui_model( + self, + model_router: ModelRouter, + mock_image: Image.Image, + mock_askui_inference_api: AskUiInferenceApi, + ) -> None: + """Test locating elements using AskUI model.""" + locator = "test locator" + x, y = model_router.locate(mock_image, locator, ModelName.ASKUI) + assert x == 50 + assert y == 50 + mock_askui_inference_api.predict.assert_called_once() # type: ignore + + def test_locate_with_askui_pta_model( + self, + model_router: ModelRouter, + mock_image: Image.Image, + mock_askui_inference_api: AskUiInferenceApi, + ) -> None: + """Test locating elements using AskUI PTA model.""" + locator = "test locator" + x, y = model_router.locate(mock_image, locator, ModelName.ASKUI__PTA) + assert x == 50 + assert y == 50 + mock_askui_inference_api.predict.assert_called_once() # type: ignore + + def test_locate_with_askui_ocr_model( + self, + model_router: ModelRouter, + mock_image: Image.Image, + mock_askui_inference_api: AskUiInferenceApi, + ) -> None: + """Test locating elements using AskUI OCR model.""" + locator = "test locator" + x, y = model_router.locate(mock_image, locator, ModelName.ASKUI__OCR) + assert x == 50 + assert y == 50 + mock_askui_inference_api.predict.assert_called_once() # type: ignore + + def test_locate_with_askui_combo_model( + self, + model_router: ModelRouter, + mock_image: Image.Image, + mock_askui_inference_api: AskUiInferenceApi, + ) -> None: + """Test locating elements using AskUI combo model.""" + locator = "test locator" + x, y = model_router.locate(mock_image, locator, ModelName.ASKUI__COMBO) + assert x == 50 + assert y == 50 + mock_askui_inference_api.predict.assert_called_once() # type: ignore + + def test_locate_with_askui_ai_element_model( + self, + model_router: ModelRouter, + mock_image: Image.Image, + mock_askui_inference_api: AskUiInferenceApi, + ) -> None: + """Test locating elements using AskUI AI element model.""" + locator = "test locator" + x, y = model_router.locate(mock_image, locator, ModelName.ASKUI__AI_ELEMENT) + assert x == 50 + assert y == 50 + mock_askui_inference_api.predict.assert_called_once() # type: ignore + + def test_locate_with_tars_model( + self, + model_router: ModelRouter, + mock_image: Image.Image, + mock_tars: UiTarsApiHandler, + ) -> None: + """Test locating elements using TARS model.""" + locator = "test locator" + x, y = model_router.locate(mock_image, locator, ModelName.TARS) + assert x == 50 + assert y == 50 + mock_tars.locate_prediction.assert_called_once() # type: ignore + + def test_locate_with_claude_model( + self, + model_router: ModelRouter, + mock_image: Image.Image, + mock_claude: ClaudeHandler, + ) -> None: + """Test locating elements using Claude model.""" + locator = "test locator" + x, y = model_router.locate( + mock_image, locator, ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 + ) + assert x == 50 + assert y == 50 + mock_claude.locate_inference.assert_called_once() # type: ignore + + def test_locate_with_hf_space_model( + self, + model_router: ModelRouter, + mock_image: Image.Image, + mock_hf_spaces: HFSpacesHandler, + ) -> None: + """Test locating elements using HuggingFace space model.""" + locator = "test locator" + x, y = model_router.locate(mock_image, locator, "hf-space-1") + assert x == 50 + assert y == 50 + mock_hf_spaces.predict.assert_called_once() # type: ignore + + def test_locate_with_invalid_model( + self, model_router: ModelRouter, mock_image: Image.Image + ) -> None: + """Test that locating with invalid model raises InvalidModelError.""" + with pytest.raises(InvalidModelError): + model_router.locate(mock_image, "test locator", "invalid-model") + + def test_get_inference_with_askui_model( + self, + model_router: ModelRouter, + mock_image_source: ImageSource, + mock_askui_inference_api: AskUiInferenceApi, + ) -> None: + """Test getting inference using AskUI model.""" + response = model_router.get_inference( + "test query", mock_image_source, model=ModelName.ASKUI + ) + assert response == "Mock response" + mock_askui_inference_api.get_inference.assert_called_once() # type: ignore + + def test_get_inference_with_tars_model( + self, + model_router: ModelRouter, + mock_image_source: ImageSource, + mock_tars: UiTarsApiHandler, + ) -> None: + """Test getting inference using TARS model.""" + response = model_router.get_inference( + "test query", mock_image_source, model=ModelName.TARS + ) + assert response == "Mock response" + mock_tars.get_inference.assert_called_once() # type: ignore + + def test_get_inference_with_claude_model( + self, + model_router: ModelRouter, + mock_image_source: ImageSource, + mock_claude: ClaudeHandler, + ) -> None: + """Test getting inference using Claude model.""" + response = model_router.get_inference( + "test query", + mock_image_source, + model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, + ) + assert response == "Mock response" + mock_claude.get_inference.assert_called_once() # type: ignore + + def test_get_inference_with_invalid_model( + self, model_router: ModelRouter, mock_image_source: ImageSource + ) -> None: + """Test that getting inference with invalid model raises InvalidModelError.""" + with pytest.raises(InvalidModelError): + model_router.get_inference( + "test query", mock_image_source, model="invalid-model" + ) + + def test_get_inference_with_response_schema_not_implemented( + self, model_router: ModelRouter, mock_image_source: ImageSource + ) -> None: + """ + Test that getting inference with response schema for non-AskUI models raises + NotImplementedError. + """ + + class TestSchema(ResponseSchemaBase): + pass + + with pytest.raises(NotImplementedError): + model_router.get_inference( + "test query", + mock_image_source, + response_schema=TestSchema, + model=ModelName.TARS, + ) + + def test_act_with_tars_model( + self, model_router: ModelRouter, mock_tars: UiTarsApiHandler + ) -> None: + """Test acting using TARS model.""" + model_router.act("test goal", ModelName.TARS) + mock_tars.act.assert_called_once_with("test goal") # type: ignore + + def test_act_with_claude_model( + self, model_router: ModelRouter, mock_claude_agent: ClaudeComputerAgent + ) -> None: + """Test acting using Claude model.""" + model_router.act( + "test goal", ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 + ) + mock_claude_agent.act.assert_called_once_with("test goal") # type: ignore + + def test_act_with_invalid_model(self, model_router: ModelRouter) -> None: + """Test that acting with invalid model raises InvalidModelError.""" + with pytest.raises(InvalidModelError): + model_router.act("test goal", "invalid-model") + + def test_act_with_missing_anthropic_credentials( + self, model_router: ModelRouter + ) -> None: + """ + Test that acting with Claude model raises ValidationError when credentials are + missing. + """ + with patch.dict(os.environ, {}, clear=True): + router = ModelRouter( + tools=model_router._tools, reporter=model_router._reporter + ) + with pytest.raises(ValidationError, match="ANTHROPIC_API_KEY"): + router.act( + "test goal", ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 + ) + + def test_act_with_default_missing_credentials( + self, model_router: ModelRouter + ) -> None: + """ + Test that acting with default model raises ValidationError when credentials are + missing. + """ + with patch.dict(os.environ, {}, clear=True): + router = ModelRouter( + tools=model_router._tools, reporter=model_router._reporter + ) + with pytest.raises(ValidationError, match="ANTHROPIC_API_KEY"): + router.act("test goal") + + def test_locate_with_missing_askui_credentials( + self, model_router: ModelRouter, mock_image: Image.Image + ) -> None: + with patch.dict(os.environ, {}, clear=True): + router = ModelRouter( + tools=model_router._tools, + reporter=model_router._reporter, + ) + with pytest.raises(ValueError, match="ASKUI_WORKSPACE_ID"): + router.locate(mock_image, "test locator", ModelName.ASKUI) + + def test_locate_with_missing_askui_credentials_only_token( + self, model_router: ModelRouter, mock_image: Image.Image + ) -> None: + with patch.dict( + os.environ, {"ASKUI_WORKSPACE_ID": str(uuid.uuid4())}, clear=True + ): + router = ModelRouter( + tools=model_router._tools, + reporter=model_router._reporter, + ) + with pytest.raises(ValueError, match="ASKUI_TOKEN"): + router.locate(mock_image, "test locator", ModelName.ASKUI) + + def test_get_inference_with_missing_askui_credentials( + self, + model_router: ModelRouter, + mock_image_source: ImageSource, + ) -> None: + with patch.dict(os.environ, {}, clear=True): + router = ModelRouter( + tools=model_router._tools, + reporter=model_router._reporter, + ) + with pytest.raises(ValueError, match="ASKUI_WORKSPACE_ID"): + router.get_inference( + "test query", mock_image_source, model=ModelName.ASKUI + ) + + def test_get_inference_with_default_missing_credentials( + self, + model_router: ModelRouter, + mock_image_source: ImageSource, + ) -> None: + with patch.dict(os.environ, {}, clear=True): + router = ModelRouter( + tools=model_router._tools, + reporter=model_router._reporter, + ) + with pytest.raises(ValueError, match="ASKUI_WORKSPACE_ID"): + router.get_inference("test query", mock_image_source) + + def test_locate_with_missing_anthropic_credentials( + self, model_router: ModelRouter, mock_image: Image.Image + ) -> None: + with patch.dict(os.environ, {}, clear=True): + router = ModelRouter( + tools=model_router._tools, + reporter=model_router._reporter, + ) + with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"): + router.locate( + mock_image, + "test locator", + ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, + ) + + def test_locate_with_default_missing_credentials( + self, model_router: ModelRouter, mock_image: Image.Image + ) -> None: + with patch.dict(os.environ, {}, clear=True): + router = ModelRouter( + tools=model_router._tools, + reporter=model_router._reporter, + ) + with pytest.raises(ValueError, match="ANTHROPIC_API_KEY"): + router.locate(mock_image, "test locator")