From 363e672d219553993d4401144bafc504051579a9 Mon Sep 17 00:00:00 2001 From: Samir mlika <105347215+mlikasam-askui@users.noreply.github.com> Date: Mon, 16 Jun 2025 10:41:32 +0200 Subject: [PATCH 1/6] Allow multiple image in tool result --- src/askui/models/shared/computer_agent.py | 17 +++++++++-------- src/askui/tools/anthropic/base.py | 6 +++--- src/askui/tools/anthropic/computer.py | 2 +- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/askui/models/shared/computer_agent.py b/src/askui/models/shared/computer_agent.py index 0c4a74be..8b637578 100644 --- a/src/askui/models/shared/computer_agent.py +++ b/src/askui/models/shared/computer_agent.py @@ -411,15 +411,16 @@ def _make_api_tool_result( ), ) ) - if result.base64_image: - tool_result_content.append( - ImageBlockParam( - source=Base64ImageSourceParam( - media_type="image/png", - data=result.base64_image, - ), + if result.base64_images: + for base64_image in result.base64_images: + tool_result_content.append( + ImageBlockParam( + source=Base64ImageSourceParam( + media_type="image/png", + data=base64_image, + ), + ) ) - ) return ToolResultBlockParam( content=tool_result_content, tool_use_id=tool_use_id, diff --git a/src/askui/tools/anthropic/base.py b/src/askui/tools/anthropic/base.py index bd4116ef..af98730d 100644 --- a/src/askui/tools/anthropic/base.py +++ b/src/askui/tools/anthropic/base.py @@ -27,13 +27,13 @@ class ToolResult: Args: output (str | None, optional): The output of the tool. error (str | None, optional): The error message of the tool. - base64_image (str | None, optional): The base64 image of the tool. + base64_images (list[str], optional): The base64 images of the tool. system (str | None, optional): The system message of the tool. """ output: str | None = None error: str | None = None - base64_image: str | None = None + base64_images: list[str] = [] system: str | None = None def __bool__(self) -> bool: @@ -53,7 +53,7 @@ def combine_fields( return ToolResult( output=combine_fields(self.output, other.output), error=combine_fields(self.error, other.error), - base64_image=combine_fields(self.base64_image, other.base64_image, False), + base64_images=self.base64_images + other.base64_images, system=combine_fields(self.system, other.system), ) diff --git a/src/askui/tools/anthropic/computer.py b/src/askui/tools/anthropic/computer.py index 3997e282..e0023a85 100644 --- a/src/askui/tools/anthropic/computer.py +++ b/src/askui/tools/anthropic/computer.py @@ -356,4 +356,4 @@ def screenshot(self) -> ToolResult: screenshot, self._width, self._height ) base64_image = image_to_base64(scaled_screenshot) - return ToolResult(base64_image=base64_image) + return ToolResult(base64_images=[base64_image]) From 4e72ba7a49ad2260a841ce330ae569112441902a Mon Sep 17 00:00:00 2001 From: Samir mlika <105347215+mlikasam-askui@users.noreply.github.com> Date: Mon, 16 Jun 2025 12:04:49 +0200 Subject: [PATCH 2/6] Add base agent --- src/askui/models/shared/base_agent.py | 306 ++++++++++++++++++++++ src/askui/models/shared/computer_agent.py | 280 +------------------- src/askui/tools/anthropic/__init__.py | 3 +- src/askui/tools/anthropic/base.py | 3 +- 4 files changed, 318 insertions(+), 274 deletions(-) create mode 100644 src/askui/models/shared/base_agent.py diff --git a/src/askui/models/shared/base_agent.py b/src/askui/models/shared/base_agent.py new file mode 100644 index 00000000..9a2b68ed --- /dev/null +++ b/src/askui/models/shared/base_agent.py @@ -0,0 +1,306 @@ +from abc import ABC, abstractmethod +from typing import Generic + +from anthropic.types.beta import BetaTextBlockParam +from pydantic import BaseModel +from typing_extensions import TypeVar, override + +from askui.models.models import ActModel +from askui.models.shared.computer_agent_cb_param import OnMessageCb, OnMessageCbParam +from askui.models.shared.computer_agent_message_param import ( + Base64ImageSourceParam, + ContentBlockParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, +) +from askui.reporting import Reporter +from askui.tools.anthropic import ToolCollection, ToolResult +from askui.tools.anthropic.base import BaseAnthropicTool + +from ...logger import logger + + +class AgentSettingsBase(BaseModel): + """Settings for agents.""" + + max_tokens: int = 4096 + only_n_most_recent_images: int = 3 + image_truncation_threshold: int = 10 + betas: list[str] = [] + + +AgentSettings = TypeVar("AgentSettings", bound=AgentSettingsBase) + + +class BaseAgent(ActModel, ABC, Generic[AgentSettings]): + """Base class for agents that can execute autonomous actions. + + This class provides common functionality for both AskUI and Anthropic agents, + including tool handling, message processing, and image filtering. + """ + + def __init__( + self, + settings: AgentSettings, + tools: list[BaseAnthropicTool], + system_prompt: str, + reporter: Reporter, + ) -> None: + """Initialize the agent. + + Args: + settings (AgentSettings): The settings for the agent. + tools (list[BaseAnthropicTool]): The tools for the agent. + system_prompt (str): The system prompt for the agent. + reporter (Reporter): The reporter for logging messages and actions. + """ + self._settings: AgentSettings = settings + self._reporter = reporter + self._tool_collection = ToolCollection( + *tools, + ) + self._system = BetaTextBlockParam( + type="text", + text=system_prompt, + ) + + @abstractmethod + def _create_message( + self, messages: list[MessageParam], model_choice: str + ) -> MessageParam: + """Create a message using the agent's API. + + Args: + messages (list[MessageParam]): The message history. + model_choice (str): The model to use for message creation. + + Returns: + MessageParam: The created message. + """ + raise NotImplementedError + + def set_system_prompt(self, system_prompt: str) -> None: + self._system = BetaTextBlockParam(type="text", text=f"{system_prompt}") + + def set_tool_collection(self, tools: list[BaseAnthropicTool]) -> None: + self._tool_collection = ToolCollection(*tools) + + def _step( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: + """Execute a single step in the conversation. + + Args: + messages (list[MessageParam]): The message history. + model_choice (str): The model to use for message creation. + on_message (OnMessageCb | None, optional): Callback on new messages + + Returns: + None + """ + if self._settings.only_n_most_recent_images: + messages = self._maybe_filter_to_n_most_recent_images( + messages, + self._settings.only_n_most_recent_images, + self._settings.image_truncation_threshold, + ) + response_message = self._create_message(messages, model_choice) + message_by_assistant = self._call_on_message( + on_message, response_message, messages + ) + if message_by_assistant is None: + return + message_by_assistant_dict = message_by_assistant.model_dump(mode="json") + logger.debug(message_by_assistant_dict) + messages.append(message_by_assistant) + self._reporter.add_message(self.__class__.__name__, message_by_assistant_dict) + if tool_result_message := self._use_tools(message_by_assistant): + if tool_result_message := self._call_on_message( + on_message, tool_result_message, messages + ): + tool_result_message_dict = tool_result_message.model_dump(mode="json") + logger.debug(tool_result_message_dict) + messages.append(tool_result_message) + self._step( + messages=messages, + model_choice=model_choice, + on_message=on_message, + ) + + def _call_on_message( + self, + on_message: OnMessageCb | None, + message: MessageParam, + messages: list[MessageParam], + ) -> MessageParam | None: + if on_message is None: + return message + return on_message(OnMessageCbParam(message=message, messages=messages)) + + @override + def act( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: + self._step( + messages=messages, + model_choice=model_choice, + on_message=on_message, + ) + + def _use_tools( + self, + message: MessageParam, + ) -> MessageParam | None: + """Process tool use blocks in a message. + + Args: + message (MessageParam): The message containing tool use blocks. + + Returns: + MessageParam | None: A message containing tool results or `None` + if no tools were used. + """ + tool_result_content: list[ContentBlockParam] = [] + if isinstance(message.content, str): + return None + + for content_block in message.content: + if content_block.type == "tool_use": + result = self._tool_collection.run( + name=content_block.name, + tool_input=content_block.input, # type: ignore[arg-type] + ) + tool_result_content.append( + self._make_api_tool_result(result, content_block.id) + ) + if len(tool_result_content) == 0: + return None + + return MessageParam( + content=tool_result_content, + role="user", + ) + + @staticmethod + def _maybe_filter_to_n_most_recent_images( + messages: list[MessageParam], + images_to_keep: int | None, + min_removal_threshold: int, + ) -> list[MessageParam]: + """ + Filter the message history in-place to keep only the most recent images, + according to the given chunking policy. + + Args: + messages (list[MessageParam]): The message history. + images_to_keep (int | None): Number of most recent images to keep. + min_removal_threshold (int): Minimum number of images to remove at once. + + Returns: + list[MessageParam]: The filtered message history. + """ + if images_to_keep is None: + return messages + + tool_result_blocks = [ + item + for message in messages + for item in (message.content if isinstance(message.content, list) else []) + if item.type == "tool_result" + ] + total_images = sum( + 1 + for tool_result in tool_result_blocks + if not isinstance(tool_result.content, str) + for content in tool_result.content + if content.type == "image" + ) + images_to_remove = total_images - images_to_keep + if images_to_remove < min_removal_threshold: + return messages + # for better cache behavior, we want to remove in chunks + images_to_remove -= images_to_remove % min_removal_threshold + if images_to_remove <= 0: + return messages + + # Remove images from the oldest tool_result blocks first + for tool_result in tool_result_blocks: + if images_to_remove <= 0: + break + if isinstance(tool_result.content, list): + new_content: list[TextBlockParam | ImageBlockParam] = [] + for content in tool_result.content: + if content.type == "image" and images_to_remove > 0: + images_to_remove -= 1 + continue + new_content.append(content) + tool_result.content = new_content + return messages + + def _make_api_tool_result( + self, result: ToolResult, tool_use_id: str + ) -> ToolResultBlockParam: + """Convert a tool result to an API tool result block. + + Args: + result (ToolResult): The tool result to convert. + tool_use_id (str): The ID of the tool use block. + + Returns: + ToolResultBlockParam: The API tool result block. + """ + tool_result_content: list[TextBlockParam | ImageBlockParam] | str = [] + is_error = False + if result.error: + is_error = True + tool_result_content = self._maybe_prepend_system_tool_result( + result, result.error + ) + else: + assert isinstance(tool_result_content, list) + if result.output: + tool_result_content.append( + TextBlockParam( + text=self._maybe_prepend_system_tool_result( + result, result.output + ), + ) + ) + if result.base64_images: + for base64_image in result.base64_images: + tool_result_content.append( + ImageBlockParam( + source=Base64ImageSourceParam( + media_type="image/png", + data=base64_image, + ), + ) + ) + return ToolResultBlockParam( + content=tool_result_content, + tool_use_id=tool_use_id, + is_error=is_error, + ) + + @staticmethod + def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: + """Prepend system message to tool result text if available. + + Args: + result (ToolResult): The tool result. + result_text (str): The result text. + + Returns: + str: The result text with optional system message prepended. + """ + if result.system: + result_text = f"{result.system}\n{result_text}" + return result_text diff --git a/src/askui/models/shared/computer_agent.py b/src/askui/models/shared/computer_agent.py index 8b637578..c388a113 100644 --- a/src/askui/models/shared/computer_agent.py +++ b/src/askui/models/shared/computer_agent.py @@ -1,28 +1,14 @@ import platform import sys -from abc import ABC, abstractmethod from datetime import datetime, timezone -from typing import Generic -from anthropic.types.beta import BetaTextBlockParam -from pydantic import BaseModel, Field -from typing_extensions import TypeVar, override +from pydantic import Field +from typing_extensions import TypeVar -from askui.models.models import ActModel -from askui.models.shared.computer_agent_cb_param import OnMessageCb, OnMessageCbParam -from askui.models.shared.computer_agent_message_param import ( - Base64ImageSourceParam, - ContentBlockParam, - ImageBlockParam, - MessageParam, - TextBlockParam, - ToolResultBlockParam, -) +from askui.models.shared.base_agent import AgentSettingsBase, BaseAgent from askui.reporting import Reporter from askui.tools.agent_os import AgentOs -from askui.tools.anthropic import ComputerTool, ToolCollection, ToolResult - -from ...logger import logger +from askui.tools.anthropic import BaseAnthropicTool, ComputerTool COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" @@ -164,12 +150,9 @@ """ # noqa: DTZ002, E501 -class ComputerAgentSettingsBase(BaseModel): +class ComputerAgentSettingsBase(AgentSettingsBase): """Settings for computer agents.""" - max_tokens: int = 4096 - only_n_most_recent_images: int = 3 - image_truncation_threshold: int = 10 betas: list[str] = Field(default_factory=lambda: [COMPUTER_USE_BETA_FLAG]) @@ -178,7 +161,7 @@ class ComputerAgentSettingsBase(BaseModel): ) -class ComputerAgent(ActModel, ABC, Generic[ComputerAgentSettings]): +class ComputerAgent(BaseAgent[ComputerAgentSettings]): """Base class for computer agents that can execute autonomous actions. This class provides common functionality for both AskUI and Anthropic @@ -192,252 +175,5 @@ def __init__( agent_os: AgentOs, reporter: Reporter, ) -> None: - """Initialize the computer agent. - - Args: - settings (ComputerAgentSettings): The settings for the computer agent. - agent_os (AgentOs): The operating system agent for executing commands. - reporter (Reporter): The reporter for logging messages and actions. - """ - self._settings = settings - self._reporter = reporter - self._tool_collection = ToolCollection( - ComputerTool(agent_os), - ) - self._system = BetaTextBlockParam( - type="text", - text=f"{SYSTEM_PROMPT}", - ) - - @abstractmethod - def _create_message( - self, messages: list[MessageParam], model_choice: str - ) -> MessageParam: - """Create a message using the agent's API. - - Args: - messages (list[MessageParam]): The message history. - model_choice (str): The model to use for message creation. - - Returns: - MessageParam: The created message. - """ - raise NotImplementedError - - def _step( - self, - messages: list[MessageParam], - model_choice: str, - on_message: OnMessageCb | None = None, - ) -> None: - """Execute a single step in the conversation. - - Args: - messages (list[MessageParam]): The message history. - model_choice (str): The model to use for message creation. - on_message (OnMessageCb | None, optional): Callback on new messages - - Returns: - None - """ - if self._settings.only_n_most_recent_images: - messages = self._maybe_filter_to_n_most_recent_images( - messages, - self._settings.only_n_most_recent_images, - self._settings.image_truncation_threshold, - ) - response_message = self._create_message(messages, model_choice) - message_by_assistant = self._call_on_message( - on_message, response_message, messages - ) - if message_by_assistant is None: - return - message_by_assistant_dict = message_by_assistant.model_dump(mode="json") - logger.debug(message_by_assistant_dict) - messages.append(message_by_assistant) - self._reporter.add_message(self.__class__.__name__, message_by_assistant_dict) - if tool_result_message := self._use_tools(message_by_assistant): - if tool_result_message := self._call_on_message( - on_message, tool_result_message, messages - ): - tool_result_message_dict = tool_result_message.model_dump(mode="json") - logger.debug(tool_result_message_dict) - messages.append(tool_result_message) - self._step( - messages=messages, - model_choice=model_choice, - on_message=on_message, - ) - - def _call_on_message( - self, - on_message: OnMessageCb | None, - message: MessageParam, - messages: list[MessageParam], - ) -> MessageParam | None: - if on_message is None: - return message - return on_message(OnMessageCbParam(message=message, messages=messages)) - - @override - def act( - self, - messages: list[MessageParam], - model_choice: str, - on_message: OnMessageCb | None = None, - ) -> None: - self._step( - messages=messages, - model_choice=model_choice, - on_message=on_message, - ) - - def _use_tools( - self, - message: MessageParam, - ) -> MessageParam | None: - """Process tool use blocks in a message. - - Args: - message (MessageParam): The message containing tool use blocks. - - Returns: - MessageParam | None: A message containing tool results or `None` - if no tools were used. - """ - tool_result_content: list[ContentBlockParam] = [] - if isinstance(message.content, str): - return None - - for content_block in message.content: - if content_block.type == "tool_use": - result = self._tool_collection.run( - name=content_block.name, - tool_input=content_block.input, # type: ignore[arg-type] - ) - tool_result_content.append( - self._make_api_tool_result(result, content_block.id) - ) - if len(tool_result_content) == 0: - return None - - return MessageParam( - content=tool_result_content, - role="user", - ) - - @staticmethod - def _maybe_filter_to_n_most_recent_images( - messages: list[MessageParam], - images_to_keep: int | None, - min_removal_threshold: int, - ) -> list[MessageParam]: - """ - Filter the message history in-place to keep only the most recent images, - according to the given chunking policy. - - Args: - messages (list[MessageParam]): The message history. - images_to_keep (int | None): Number of most recent images to keep. - min_removal_threshold (int): Minimum number of images to remove at once. - - Returns: - list[MessageParam]: The filtered message history. - """ - if images_to_keep is None: - return messages - - tool_result_blocks = [ - item - for message in messages - for item in (message.content if isinstance(message.content, list) else []) - if item.type == "tool_result" - ] - total_images = sum( - 1 - for tool_result in tool_result_blocks - if not isinstance(tool_result.content, str) - for content in tool_result.content - if content.type == "image" - ) - images_to_remove = total_images - images_to_keep - if images_to_remove < min_removal_threshold: - return messages - # for better cache behavior, we want to remove in chunks - images_to_remove -= images_to_remove % min_removal_threshold - if images_to_remove <= 0: - return messages - - # Remove images from the oldest tool_result blocks first - for tool_result in tool_result_blocks: - if images_to_remove <= 0: - break - if isinstance(tool_result.content, list): - new_content: list[TextBlockParam | ImageBlockParam] = [] - for content in tool_result.content: - if content.type == "image" and images_to_remove > 0: - images_to_remove -= 1 - continue - new_content.append(content) - tool_result.content = new_content - return messages - - def _make_api_tool_result( - self, result: ToolResult, tool_use_id: str - ) -> ToolResultBlockParam: - """Convert a tool result to an API tool result block. - - Args: - result (ToolResult): The tool result to convert. - tool_use_id (str): The ID of the tool use block. - - Returns: - ToolResultBlockParam: The API tool result block. - """ - tool_result_content: list[TextBlockParam | ImageBlockParam] | str = [] - is_error = False - if result.error: - is_error = True - tool_result_content = self._maybe_prepend_system_tool_result( - result, result.error - ) - else: - assert isinstance(tool_result_content, list) - if result.output: - tool_result_content.append( - TextBlockParam( - text=self._maybe_prepend_system_tool_result( - result, result.output - ), - ) - ) - if result.base64_images: - for base64_image in result.base64_images: - tool_result_content.append( - ImageBlockParam( - source=Base64ImageSourceParam( - media_type="image/png", - data=base64_image, - ), - ) - ) - return ToolResultBlockParam( - content=tool_result_content, - tool_use_id=tool_use_id, - is_error=is_error, - ) - - @staticmethod - def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: - """Prepend system message to tool result text if available. - - Args: - result (ToolResult): The tool result. - result_text (str): The result text. - - Returns: - str: The result text with optional system message prepended. - """ - if result.system: - result_text = f"{result.system}\n{result_text}" - return result_text + tool_list: list[BaseAnthropicTool] = [ComputerTool(agent_os)] + super().__init__(settings, tool_list, SYSTEM_PROMPT, reporter) diff --git a/src/askui/tools/anthropic/__init__.py b/src/askui/tools/anthropic/__init__.py index 0a058914..9bb516a5 100644 --- a/src/askui/tools/anthropic/__init__.py +++ b/src/askui/tools/anthropic/__init__.py @@ -1,4 +1,4 @@ -from .base import CLIResult, ToolResult +from .base import BaseAnthropicTool, CLIResult, ToolResult from .collection import ToolCollection from .computer import ComputerTool @@ -7,4 +7,5 @@ ComputerTool, ToolCollection, ToolResult, + BaseAnthropicTool, ] diff --git a/src/askui/tools/anthropic/base.py b/src/askui/tools/anthropic/base.py index af98730d..64f11ffc 100644 --- a/src/askui/tools/anthropic/base.py +++ b/src/askui/tools/anthropic/base.py @@ -3,6 +3,7 @@ from typing import Any, Optional from anthropic.types.beta import BetaToolUnionParam +from pydantic import Field class BaseAnthropicTool(metaclass=ABCMeta): @@ -33,7 +34,7 @@ class ToolResult: output: str | None = None error: str | None = None - base64_images: list[str] = [] + base64_images: list[str] = Field(default_factory=list) system: str | None = None def __bool__(self) -> bool: From f880f75b031328ee64fe69b8a79204af30192aa4 Mon Sep 17 00:00:00 2001 From: Samir mlika <105347215+mlikasam-askui@users.noreply.github.com> Date: Mon, 16 Jun 2025 14:43:01 +0200 Subject: [PATCH 3/6] Add BaseTool --- src/askui/tools/anthropic/__init__.py | 3 ++- src/askui/tools/anthropic/base.py | 30 ++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/askui/tools/anthropic/__init__.py b/src/askui/tools/anthropic/__init__.py index 9bb516a5..9e4ec274 100644 --- a/src/askui/tools/anthropic/__init__.py +++ b/src/askui/tools/anthropic/__init__.py @@ -1,4 +1,4 @@ -from .base import BaseAnthropicTool, CLIResult, ToolResult +from .base import BaseAnthropicTool, CLIResult, Tool, ToolResult from .collection import ToolCollection from .computer import ComputerTool @@ -8,4 +8,5 @@ ToolCollection, ToolResult, BaseAnthropicTool, + Tool, ] diff --git a/src/askui/tools/anthropic/base.py b/src/askui/tools/anthropic/base.py index 64f11ffc..bb524a5f 100644 --- a/src/askui/tools/anthropic/base.py +++ b/src/askui/tools/anthropic/base.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, fields, replace from typing import Any, Optional -from anthropic.types.beta import BetaToolUnionParam +from anthropic.types.beta import BetaToolParam, BetaToolUnionParam from pydantic import Field @@ -83,3 +83,31 @@ def __init__(self, message: str, result: Optional[ToolResult] = None): self.message = message self.result = result super().__init__(self.message) + + +class Tool(BaseAnthropicTool): + """A tool that can be used in an agent.""" + + def __init__(self, name: str, description: str, input_schema: dict[str, Any]): + if not name: + error_msg = "Tool name is required" + raise ValueError(error_msg) + if not description: + error_msg = "Tool description is required" + raise ValueError(error_msg) + if not input_schema: + input_schema = {"type": "object", "properties": {}, "required": []} + self.name = name + self.description = description + self.input_schema = input_schema + + def to_params(self) -> BetaToolParam: + return { + "name": self.name, + "description": self.description, + "input_schema": self.input_schema, + } + + def __call__(self, **kwargs: Any) -> ToolResult: + error_msg = "Tool subclasses must implement __call__ method" + raise NotImplementedError(error_msg) From e4b5b57b3ea95ce3a00f2aa210563376746ca62d Mon Sep 17 00:00:00 2001 From: Samir mlika <105347215+mlikasam-askui@users.noreply.github.com> Date: Tue, 17 Jun 2025 15:19:19 +0200 Subject: [PATCH 4/6] Add android os agent --- pdm.lock | 11 +- pyproject.toml | 1 + src/askui/tools/android_agent_os.py | 393 ++++++++++++++++++ .../tools/askui/askui_android_controller.py | 227 ++++++++++ 4 files changed, 631 insertions(+), 1 deletion(-) create mode 100644 src/askui/tools/android_agent_os.py create mode 100644 src/askui/tools/askui/askui_android_controller.py diff --git a/pdm.lock b/pdm.lock index 7ac45d7b..33576605 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "chat", "mcp", "pynput", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:21d28e90c53b9d7f3439e469fe2125b83225d66dd3f6d182bedd1ae774544485" +content_hash = "sha256:4a874d7a4f5c1b1756be67be2ddf90d521ad41b2c76c13d38404d8efd110f2cc" [[metadata.targets]] requires_python = ">=3.10" @@ -1253,6 +1253,15 @@ files = [ {file = "protobuf-5.29.4.tar.gz", hash = "sha256:4f1dfcd7997b31ef8f53ec82781ff434a28bf71d9102ddde14d076adcfc78c99"}, ] +[[package]] +name = "pure-python-adb" +version = "0.3.0.dev0" +summary = "Pure python implementation of the adb client" +groups = ["default"] +files = [ + {file = "pure-python-adb-0.3.0.dev0.tar.gz", hash = "sha256:0ecc89d780160cfe03260ba26df2c471a05263b2cad0318363573ee8043fb94d"}, +] + [[package]] name = "py-machineid" version = "0.7.0" diff --git a/pyproject.toml b/pyproject.toml index 0be5b9c3..762a3048 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "segment-analytics-python>=2.3.4", "py-machineid>=0.7.0", "httpx>=0.28.1", + "pure-python-adb>=0.3.0.dev0", ] requires-python = ">=3.10" readme = "README.md" diff --git a/src/askui/tools/android_agent_os.py b/src/askui/tools/android_agent_os.py new file mode 100644 index 00000000..7b577cf1 --- /dev/null +++ b/src/askui/tools/android_agent_os.py @@ -0,0 +1,393 @@ +from abc import ABC, abstractmethod +from typing import List, Literal + +from PIL import Image + +ANDROID_KEY = Literal[ # pylint: disable=C0103 + "home", + "back", + "call", + "endcall", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "star", + "pound", + "dpad_up", + "dpad_down", + "dpad_left", + "dpad_right", + "dpad_center", + "volume_up", + "volume_down", + "power", + "camera", + "clear", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "comma", + "period", + "alt_left", + "alt_right", + "shift_left", + "shift_right", + "tab", + "space", + "sym", + "explorer", + "envelope", + "enter", + "del", + "grave", + "minus", + "equals", + "left_bracket", + "right_bracket", + "backslash", + "semicolon", + "apostrophe", + "slash", + "at", + "num", + "headsethook", + "focus", + "plus", + "menu", + "notification", + "search", + "media_play_pause", + "media_stop", + "media_next", + "media_previous", + "media_rewind", + "media_fast_forward", + "mute", + "page_up", + "page_down", + "switch_charset", + "escape", + "forward_del", + "ctrl_left", + "ctrl_right", + "caps_lock", + "scroll_lock", + "function", + "break", + "move_home", + "move_end", + "insert", + "forward", + "media_play", + "media_pause", + "media_close", + "media_eject", + "media_record", + "f1", + "f2", + "f3", + "f4", + "f5", + "f6", + "f7", + "f8", + "f9", + "f10", + "f11", + "f12", + "num_lock", + "numpad_0", + "numpad_1", + "numpad_2", + "numpad_3", + "numpad_4", + "numpad_5", + "numpad_6", + "numpad_7", + "numpad_8", + "numpad_9", + "numpad_divide", + "numpad_multiply", + "numpad_subtract", + "numpad_add", + "numpad_dot", + "numpad_comma", + "numpad_enter", + "numpad_equals", + "numpad_left_paren", + "numpad_right_paren", + "volume_mute", + "info", + "channel_up", + "channel_down", + "zoom_in", + "zoom_out", + "window", + "guide", + "bookmark", + "captions", + "settings", + "app_switch", + "language_switch", + "contacts", + "calendar", + "music", + "calculator", + "assist", + "brightness_down", + "brightness_up", + "media_audio_track", + "sleep", + "wakeup", + "pairing", + "media_top_menu", + "last_channel", + "tv_data_service", + "voice_assist", + "help", + "navigate_previous", + "navigate_next", + "navigate_in", + "navigate_out", + "dpad_up_left", + "dpad_down_left", + "dpad_up_right", + "dpad_down_right", + "media_skip_forward", + "media_skip_backward", + "media_step_forward", + "media_step_backward", + "soft_sleep", + "cut", + "copy", + "paste", + "all_apps", + "refresh", +] + + +class AndroidDisplay: + def __init__( + self, unique_display_id: int, display_name: str, display_index: int + ) -> None: + self.unique_display_id: int = unique_display_id + self.display_name: str = display_name + self.display_index: int = display_index + + def __repr__(self) -> str: + return ( + f"AndroidDisplay(unique_display_id={self.unique_display_id}, " + f"display_name={self.display_name}, display_index={self.display_index})" + ) + + +class AndroidAgentOs(ABC): + """ + Abstract base class for Android Agent OS. Cannot be instantiated directly. + + This class defines the interface for operating system interactions including + mouse control, keyboard input, and screen capture functionality. + Implementations should provide concrete functionality for these abstract + methods. + """ + + @abstractmethod + def connect(self) -> None: + """ + Establishes a connection to the Agent OS. + + This method is called before performing any OS-level operations. + It handles any necessary setup or initialization required for the OS + interaction. + """ + raise NotImplementedError + + @abstractmethod + def disconnect(self) -> None: + """ + Terminates the connection to the Agent OS. + + This method is called after all OS-level operations are complete. + It handles any necessary cleanup or resource release. + """ + raise NotImplementedError + + @abstractmethod + def screenshot(self) -> Image.Image: + """ + Captures a screenshot of the current display. + + Returns: + Image.Image: A PIL Image object containing the screenshot. + """ + raise NotImplementedError + + @abstractmethod + def mouse_move(self, x: int, y: int) -> None: + """ + Moves the mouse cursor to specified screen coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to move to. + y (int): The vertical coordinate (in pixels) to move to. + """ + raise NotImplementedError + + @abstractmethod + def type(self, text: str) -> None: + """ + Simulates typing text as if entered on a keyboard. + + Args: + text (str): The text to be typed. + """ + raise NotImplementedError + + @abstractmethod + def tap(self, x: int, y: int) -> None: + """ + Simulates tapping a screen at specified coordinates. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse + button to click. Defaults to `"left"`. + count (int, optional): Number of times to click. Defaults to `1`. + """ + raise NotImplementedError + + @abstractmethod + def swipe( + self, + x1: int, + y1: int, + x2: int, + y2: int, + duration_in_ms: int = 1000, + ) -> None: + """ + Simulates swiping a screen from one point to another. + + Args: + x1 (int): The horizontal coordinate of the start point. + y1 (int): The vertical coordinate of the start point. + x2 (int): The horizontal coordinate of the end point. + y2 (int): The vertical coordinate of the end point. + duration_in_ms (int, optional): The duration of the swipe in + milliseconds. Defaults to `1000`. + """ + raise NotImplementedError + + @abstractmethod + def drag_and_drop(self, x1: int, y1: int, x2: int, y2: int) -> None: + """ + Simulates dragging and dropping an object from one point to another. + + Args: + x1 (int): The horizontal coordinate of the start point. + y1 (int): The vertical coordinate of the start point. + x2 (int): The horizontal coordinate of the end point. + y2 (int): The vertical coordinate of the end point. + duration_in_ms (int, optional): The duration of the drag and drop in + milliseconds. Defaults to `1000`. + """ + raise NotImplementedError + + @abstractmethod + def shell(self, command: str) -> str: + """ + Executes a shell command on the Android device. + """ + raise NotImplementedError + + @abstractmethod + def key_event(self, key: ANDROID_KEY) -> None: + """ + Simulates a key event on the Android device. + """ + raise NotImplementedError + + @abstractmethod + def key_combination( + self, keys: List[ANDROID_KEY], duration_in_ms: int = 100 + ) -> None: + """ + Simulates a key combination on the Android device. + + Args: + keys (List[ANDROID_KEY]): The keys to be pressed. + duration_in_ms (int, optional): The duration of the key combination in + milliseconds. Defaults to `100`. + """ + raise NotImplementedError + + @abstractmethod + def set_display_by_index(self, display_index: int = 0) -> None: + """ + Sets the active display for screen interactions by index. + """ + raise NotImplementedError + + @abstractmethod + def set_display_by_id(self, display_id: int) -> None: + """ + Sets the active display for screen interactions by id. + """ + raise NotImplementedError + + @abstractmethod + def set_display_by_name(self, display_name: str) -> None: + """ + Sets the active display for screen interactions by name. + """ + raise NotImplementedError + + @abstractmethod + def set_device_by_index(self, device_index: int = 0) -> None: + """ + Sets the active device for screen interactions by index. + """ + raise NotImplementedError + + @abstractmethod + def set_device_by_name(self, device_name: str) -> None: + """ + Sets the active device for screen interactions by name. + """ + raise NotImplementedError + + @abstractmethod + def get_connected_displays(self) -> list[AndroidDisplay]: + """ + Gets the connected displays for screen interactions. + """ + raise NotImplementedError diff --git a/src/askui/tools/askui/askui_android_controller.py b/src/askui/tools/askui/askui_android_controller.py new file mode 100644 index 00000000..a2d13668 --- /dev/null +++ b/src/askui/tools/askui/askui_android_controller.py @@ -0,0 +1,227 @@ +import io +import re +import string +from typing import List, Optional, get_args + +from PIL import Image +from ppadb.client import Client as AdbClient +from ppadb.device import Device as AndroidDevice + +from askui.tools.android_agent_os import ANDROID_KEY, AndroidAgentOs, AndroidDisplay + + +class AskUiAndroidController(AndroidAgentOs): + def __init__(self, report: bool = True) -> None: + self._client: Optional[AdbClient] = None + self._device: Optional[AndroidDevice] = None + self._mouse_position: tuple[int, int] = (0, 0) + self._displays: list[AndroidDisplay] = [] + self._selected_display: Optional[AndroidDisplay] = None + self.report = report + + def connect(self) -> None: + self._client = AdbClient() + self.set_device_by_index(0) + self._device.wait_boot_complete() # type: ignore + + def disconnect(self) -> None: + self._client = None + self._device = None + + def _set_display(self, display: AndroidDisplay) -> None: + self._selected_display = display + self._mouse_position = (0, 0) + + def get_connected_displays(self) -> list[AndroidDisplay]: + if not self._device: + msg = "No device connected" + raise RuntimeError(msg) + displays: list[AndroidDisplay] = [] + output: str = self._device.shell( # type: ignore + "dumpsys SurfaceFlinger --display-id", + ) # type: ignore + + index = 0 + for line in output.splitlines(): + if line.startswith("Display"): + match = re.match( + r"Display (\d+) .* displayName=\"([^\"]+)\"", + line, + ) + if match: + unique_display_id: int = int(match.group(1)) + display_name: str = match.group(2) + displays.append( + AndroidDisplay(unique_display_id, display_name, index) + ) + index += 1 + if not displays: + return [AndroidDisplay(0, "Default", 0)] + return displays + + def set_display_by_index(self, display_index: int = 0) -> None: + self._displays = self.get_connected_displays() + if not self._displays: + self._displays = [AndroidDisplay(0, "Default", 0)] + if display_index >= len(self._displays): + msg = ( + f"Display index {display_index} out of range it must be less than " + f"{len(self._displays)}." + ) + raise RuntimeError(msg) + self._set_display(self._displays[display_index]) + + def set_display_by_id(self, display_id: int) -> None: + self._displays = self.get_connected_displays() + if not self._displays: + msg = "No displays connected" + raise RuntimeError(msg) + for display in self._displays: + if display.unique_display_id == display_id: + self._set_display(display) + return + msg = f"Display ID {display_id} not found" + raise RuntimeError(msg) + + def set_display_by_name(self, display_name: str) -> None: + self._displays = self.get_connected_displays() + if not self._displays: + msg = "No displays connected" + raise RuntimeError(msg) + for display in self._displays: + if display.display_name == display_name: + self._set_display(display) + return + msg = f"Display name {display_name} not found" + raise RuntimeError(msg) + + def set_device_by_index(self, device_index: int = 0) -> None: + devices = self._get_connected_devices() + if device_index >= len(devices): + msg = ( + f"Device index {device_index} out of range it must be less than " + f"{len(devices)}." + ) + raise RuntimeError(msg) + self._device = devices[device_index] + self.set_display_by_index(0) + + def set_device_by_name(self, device_name: str) -> None: + devices = self._get_connected_devices() + for device in devices: + if device.serial == device_name: + self._device = device + self.set_display_by_index(0) + return + msg = f"Device name {device_name} not found" + raise RuntimeError(msg) + + def screenshot(self) -> Image.Image: # type: ignore + self._check_if_device_is_connected() + connection_to_device = self._device.create_connection() # type: ignore + selected_device_id = self._selected_display.unique_display_id # type: ignore + connection_to_device.send( # type: ignore + f"shell:/system/bin/screencap -p -d {selected_device_id}" + ) + response = connection_to_device.read_all() # type: ignore + if response and len(response) > 5 and response[5] == 0x0D: # type: ignore + response = response.replace(b"\r\n", b"\n") # type: ignore + return Image.open(io.BytesIO(response)) # type: ignore + + def shell(self, command: str) -> str: + self._check_if_device_is_connected() + response: str = self._device.shell(command) # type: ignore + return response + + def tap(self, x: int, y: int) -> None: + self._check_if_device_is_connected() + display_index: int = self._selected_display.display_index # type: ignore + self.shell(f"input -d {display_index} tap {x} {y}") + self._mouse_position = (x, y) + + def swipe( + self, + x1: int, + y1: int, + x2: int, + y2: int, + duration_in_ms: int = 1000, + ) -> None: + display_index: int = self._selected_display.display_index # type: ignore + self.shell( + f"input -d {display_index} swipe {x1} {y1} {x2} {y2} {duration_in_ms}" + ) + self._mouse_position = (x2, y2) + + def drag_and_drop( + self, + x1: int, + y1: int, + x2: int, + y2: int, + duration_in_ms: int = 1000, + ) -> None: + display_index: int = self._selected_display.display_index # type: ignore + self.shell( + f"input -d {display_index} draganddrop {x1} {y1} {x2} {y2} {duration_in_ms}" + ) + self._mouse_position = (x2, y2) + + def type(self, text: str) -> None: + if any(c not in string.printable or ord(c) < 32 or ord(c) > 126 for c in text): + error_message: str = ( + f"Text contains non-printable characters: {text} " + + "or special characters which are not supported by the device" + ) + raise RuntimeError(error_message) + display_index: int = self._selected_display.display_index # type: ignore + self.shell(f"input -d {display_index} text {text}") + + def key_tap(self, key: ANDROID_KEY) -> None: + if key not in get_args(ANDROID_KEY): + error_message: str = f"Invalid key: {key}" + raise RuntimeError(error_message) + display_index: int = self._selected_display.display_index # type: ignore + self.shell(f"input -d {display_index} keyevent {key.capitalize()}") + + def key_combination( + self, keys: List[ANDROID_KEY], duration_in_ms: int = 100 + ) -> None: + if any(key not in get_args(ANDROID_KEY) for key in keys): + error_message: str = f"Invalid key: {keys}" + raise RuntimeError(error_message) + + if len(keys) < 2: + error_message: str = "Key combination must contain at least 2 keys" + raise RuntimeError(error_message) + + keys_string = " ".join(keys) + display_index: int = self._selected_display.display_index # type: ignore + self.shell( + f"input -d {display_index} keycombination -t {duration_in_ms} {keys_string}" + ) + + def _check_if_device_is_connected(self) -> None: + if not self._client or not self._device: + msg = "No device connected" + raise RuntimeError(msg) + devices: list[AndroidDevice] = self._client.devices() # type: ignore + if not devices: + msg = "No devices connected" + raise RuntimeError(msg) + + for device in devices: + if device.serial == self._device.serial: # type: ignore + return + msg = f"Device {self._device.serial} not found in connected devices" + raise RuntimeError(msg) + + def _get_connected_devices(self) -> list[AndroidDevice]: + if not self._client: + msg = "No client connected" + raise RuntimeError(msg) + devices: list[AndroidDevice] = self._client.devices() # type: ignore + if not devices: + msg = "No devices connected" + raise RuntimeError(msg) + return devices From 23971c2041918a13bb428d96149fb10032478d39 Mon Sep 17 00:00:00 2001 From: Samir mlika <105347215+mlikasam-askui@users.noreply.github.com> Date: Wed, 18 Jun 2025 15:10:00 +0200 Subject: [PATCH 5/6] Add android tools --- pyproject.toml | 3 + src/askui/__init__.py | 2 + src/askui/agent.py | 2 +- src/askui/android_agent.py | 600 ++++++++++++++++++ src/askui/models/askui/android_agent.py | 67 ++ src/askui/models/askui/settings.py | 8 + src/askui/models/model_router.py | 87 ++- src/askui/models/shared/android_agent.py | 115 ++++ src/askui/tools/android_agent_os.py | 22 +- src/askui/tools/anthropic/__init__.py | 18 + src/askui/tools/anthropic/android_tools.py | 311 +++++++++ src/askui/tools/anthropic/base.py | 13 +- .../tools/askui/askui_android_controller.py | 151 ++++- tests/unit/models/test_model_router.py | 2 +- 14 files changed, 1373 insertions(+), 28 deletions(-) create mode 100644 src/askui/android_agent.py create mode 100644 src/askui/models/askui/android_agent.py create mode 100644 src/askui/models/shared/android_agent.py create mode 100644 src/askui/tools/anthropic/android_tools.py diff --git a/pyproject.toml b/pyproject.toml index 762a3048..fa724391 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -195,6 +195,9 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [tool.ruff.lint.per-file-ignores] "src/askui/agent.py" = ["E501"] +"src/askui/android_agent.py" = ["E501"] +"src/askui/tools/anthropic/android_tools.py" = ["E501"] +"src/askui/models/shared/android_agent.py" = ["E501"] "src/askui/chat/*" = ["E501", "F401", "F403"] "src/askui/tools/askui/askui_workspaces/*" = ["ALL"] "src/askui/tools/askui/askui_ui_controller_grpc/*" = ["ALL"] diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 90cc1776..6db9f2fb 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -3,6 +3,7 @@ __version__ = "0.6.0" from .agent import VisionAgent +from .android_agent import AndroidVisionAgent from .locators import Locator from .models import ( ActModel, @@ -72,4 +73,5 @@ "ToolUseBlockParam", "UrlImageSourceParam", "VisionAgent", + "AndroidVisionAgent", ] diff --git a/src/askui/agent.py b/src/askui/agent.py index e1568778..8e112723 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -78,7 +78,7 @@ def __init__( reporter=self._reporter, ), ) - self._model_router = ModelRouter( + self._model_router = ModelRouter.build_default_computer_router( tools=self.tools, reporter=self._reporter, models=models ) self.model = model diff --git a/src/askui/android_agent.py b/src/askui/android_agent.py new file mode 100644 index 00000000..28bb44c0 --- /dev/null +++ b/src/askui/android_agent.py @@ -0,0 +1,600 @@ +import logging +import time +import types +from typing import Annotated, Optional, Type, overload + +from dotenv import load_dotenv +from pydantic import ConfigDict, Field, validate_call + +from askui.container import telemetry +from askui.locators.locators import Locator +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam +from askui.tools.android_agent_os import ANDROID_KEY +from askui.tools.askui.askui_android_controller import AskUiAndroidController +from askui.utils.image_utils import ImageSource, Img + +from .logger import configure_logging, logger +from .models import ModelComposition +from .models.exceptions import ElementNotFoundError +from .models.model_router import ModelRouter +from .models.models import ( + ModelChoice, + ModelName, + ModelRegistry, + Point, + TotalModelChoice, +) +from .models.types.response_schemas import ResponseSchema +from .reporting import CompositeReporter, Reporter +from .retry import ConfigurableRetry, Retry + + +class AndroidVisionAgent: + """ """ + + @telemetry.record_call(exclude={"model_router", "reporters", "tools"}) + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def __init__( + self, + log_level: int | str = logging.INFO, + reporters: list[Reporter] | None = None, + model: ModelChoice | ModelComposition | str | None = None, + retry: Retry | None = None, + models: ModelRegistry | None = None, + ) -> None: + load_dotenv() + configure_logging(level=log_level) + self.os = AskUiAndroidController() + self._reporter = CompositeReporter(reporters=reporters) + self._model_router = ModelRouter.build_default_android_router( + os=self.os, + reporter=self._reporter, + models=models, + ) + self.model = model + self._retry = retry or ConfigurableRetry( + strategy="Exponential", + base_delay=1000, + retry_count=3, + on_exception_types=(ElementNotFoundError,), + ) + self._model_choice = self._initialize_model_choice(model) + + def _initialize_model_choice( + self, model_choice: ModelComposition | ModelChoice | str | None + ) -> TotalModelChoice: + """Initialize the model choice based on the provided model parameter. + + Args: + model (ModelComposition | ModelChoice | str | None): + The model to initialize from. Can be a ModelComposition, + ModelChoice dict, string, or None. + + Returns: + TotalModelChoice: A dict with keys "act", "get", and "locate" + mapping to model names (or a ModelComposition for "locate"). + """ + if isinstance(model_choice, ModelComposition): + return { + "act": ModelName.ASKUI, + "get": ModelName.ASKUI, + "locate": model_choice, + } + if isinstance(model_choice, str) or model_choice is None: + return { + "act": model_choice or ModelName.ASKUI, + "get": model_choice or ModelName.ASKUI, + "locate": model_choice or ModelName.ASKUI, + } + return { + "act": model_choice.get("act", ModelName.ASKUI), + "get": model_choice.get("get", ModelName.ASKUI), + "locate": model_choice.get("locate", ModelName.ASKUI), + } + + @overload + def tap( + self, + target: str | Locator, + model: ModelComposition | str | None = None, + ) -> None: ... + + @overload + def tap( + self, + target: Point, + model: ModelComposition | str | None = None, + ) -> None: ... + + @telemetry.record_call(exclude={"locator"}) + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def tap( + self, + target: str | Locator | tuple[int, int], + model: ModelComposition | str | None = None, + ) -> None: + """ + Taps on the specified target. + + Args: + target (str | Locator | Point): The target to tap on. Can be a locator, a point, or a string. + model (ModelComposition | str | None, optional): The composition or name of the model(s) to be used for tapping on the target. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.tap("Submit button") + agent.tap((100, 100)) + """ + msg = "tap" + if isinstance(target, tuple): + msg += f" at ({target[0]}, {target[1]})" + self._reporter.add_message("User", msg) + self.os.tap(target[0], target[1]) + else: + msg += f" on {target}" + self._reporter.add_message("User", msg) + logger.debug("VisionAgent received instruction to click on %s", target) + point = self._locate(locator=target, model=model) + self.os.tap(point[0], point[1]) + + def _locate( + self, + locator: str | Locator, + screenshot: Optional[Img] = None, + model: ModelComposition | str | None = None, + ) -> Point: + def locate_with_screenshot() -> Point: + _screenshot = ImageSource( + self.os.screenshot() if screenshot is None else screenshot + ) + return self._model_router.locate( + screenshot=_screenshot, + locator=locator, + model_choice=model or self._model_choice["locate"], + ) + + point = self._retry.attempt(locate_with_screenshot) + self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})") + logger.debug("ModelRouter locate: (%d, %d)", point[0], point[1]) + return point + + @telemetry.record_call(exclude={"locator", "screenshot"}) + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def locate( + self, + locator: str | Locator, + screenshot: Optional[Img] = None, + model: ModelComposition | str | None = None, + ) -> Point: + """ + Locates the UI element identified by the provided locator. + + Args: + locator (str | Locator): The identifier or description of the element to locate. + screenshot (Img | None, optional): The screenshot to use for locating the element. Can be a path to an image file, a PIL Image object or a data URL. If `None`, takes a screenshot of the currently selected display. + model (ModelComposition | str | None, optional): The composition or name of the model(s) to be used for locating the element using the `locator`. + + Returns: + Point: The coordinates of the element as a tuple (x, y). + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + point = agent.locate("Submit button") + print(f"Element found at coordinates: {point}") + ``` + """ + self._reporter.add_message("User", f"locate {locator}") + logger.debug("VisionAgent received instruction to locate %s", locator) + return self._locate(locator, screenshot, model) + + @telemetry.record_call(exclude={"text"}) + @validate_call + def type( + self, + text: Annotated[str, Field(min_length=1)], + ) -> None: + """ + Types the specified text as if it were entered on a keyboard. + + Args: + text (str): The text to be typed. Must be at least `1` character long. + Only ASCII printable characters are supported. other characters will raise an error. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.type("Hello, world!") # Types "Hello, world!" + agent.type("user@example.com") # Types an email address + agent.type("password123") # Types a password + ``` + """ + self._reporter.add_message("User", f'type: "{text}"') + logger.debug("VisionAgent received instruction to type '%s'", text) + self.os.type(text) + + @overload + def get( + self, + query: Annotated[str, Field(min_length=1)], + response_schema: None = None, + model: str | None = None, + image: Optional[Img] = None, + ) -> str: ... + @overload + def get( + self, + query: Annotated[str, Field(min_length=1)], + response_schema: Type[ResponseSchema], + model: str | None = None, + image: Optional[Img] = None, + ) -> ResponseSchema: ... + + @telemetry.record_call(exclude={"query", "image", "response_schema"}) + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def get( + self, + query: Annotated[str, Field(min_length=1)], + response_schema: Type[ResponseSchema] | None = None, + model: str | None = None, + image: Optional[Img] = None, + ) -> ResponseSchema | str: + """ + Retrieves information from an image (defaults to a screenshot of the current screen) based on the provided query. + + Args: + query (str): The query describing what information to retrieve. + image (Img | None, optional): The image to extract information from. Defaults to a screenshot of the current screen. Can be a path to an image file, a PIL Image object or a data URL. + response_schema (Type[ResponseSchema] | None, optional): A Pydantic model class that defines the response schema. If not provided, returns a string. + model (str | None, optional): The composition or name of the model(s) to be used for retrieving information from the screen or image using the `query`. Note: `response_schema` is not supported by all models. + + Returns: + ResponseSchema | str: The extracted information, `str` if no `response_schema` is provided. + + Example: + ```python + from askui import ResponseSchemaBase, VisionAgent + from PIL import Image + import json + + class UrlResponse(ResponseSchemaBase): + url: str + + class NestedResponse(ResponseSchemaBase): + nested: UrlResponse + + class LinkedListNode(ResponseSchemaBase): + value: str + next: "LinkedListNode | None" + + with AndroidVisionAgent() as agent: + # Get URL as string + url = agent.get("What is the current url shown in the url bar?") + + # Get URL as Pydantic model from image at (relative) path + response = agent.get( + "What is the current url shown in the url bar?", + response_schema=UrlResponse, + image="screenshot.png", + ) + # Dump whole model + print(response.model_dump_json(indent=2)) + # or + response_json_dict = response.model_dump(mode="json") + print(json.dumps(response_json_dict, indent=2)) + # or for regular dict + response_dict = response.model_dump() + print(response_dict["url"]) + + # Get boolean response from PIL Image + is_login_page = agent.get( + "Is this a login page?", + response_schema=bool, + image=Image.open("screenshot.png"), + ) + print(is_login_page) + + # Get integer response + input_count = agent.get( + "How many input fields are visible on this page?", + response_schema=int, + ) + print(input_count) + + # Get float response + design_rating = agent.get( + "Rate the page design quality from 0 to 1", + response_schema=float, + ) + print(design_rating) + + # Get nested response + nested = agent.get( + "Extract the URL and its metadata from the page", + response_schema=NestedResponse, + ) + print(nested.nested.url) + + # Get recursive response + linked_list = agent.get( + "Extract the breadcrumb navigation as a linked list", + response_schema=LinkedListNode, + ) + current = linked_list + while current: + print(current.value) + current = current.next + ``` + """ + logger.debug("VisionAgent received instruction to get '%s'", query) + _image = ImageSource(self.os.screenshot() if image is None else image) + self._reporter.add_message("User", f'get: "{query}"', image=_image.root) + response = self._model_router.get( + image=_image, + query=query, + response_schema=response_schema, + model_choice=model or self._model_choice["get"], + ) + message_content = ( + str(response) + if isinstance(response, (str, bool, int, float)) + else response.model_dump() + ) + self._reporter.add_message("Agent", message_content) + return response + + @telemetry.record_call() + @validate_call + def wait( + self, + sec: Annotated[float, Field(gt=0.0)], + ) -> None: + """ + Pauses the execution of the program for the specified number of seconds. + + Args: + sec (float): The number of seconds to wait. Must be greater than `0.0`. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.wait(5) # Pauses execution for 5 seconds + agent.wait(0.5) # Pauses execution for 500 milliseconds + ``` + """ + time.sleep(sec) + + @telemetry.record_call() + @validate_call + def key_tap( + self, + key: ANDROID_KEY, + ) -> None: + """ + Taps the specified key on the Android device. + + Args: + key (ANDROID_KEY): The key to tap. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.key_tap("KEYCODE_HOME") # Taps the home key + agent.key_tap("KEYCODE_BACK") # Taps the back key + ``` + """ + self.os.key_tap(key) + + @telemetry.record_call() + @validate_call + def key_combination( + self, + keys: Annotated[list[ANDROID_KEY], Field(min_length=1)], + duration_in_ms: int = 100, + ) -> None: + """ + Taps the specified keys on the Android device. + + Args: + keys (list[ANDROID_KEY]): The keys to tap. + duration_in_ms (int, optional): The duration in milliseconds to hold the key combination. Default is 100ms. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.key_combination(["KEYCODE_HOME", "KEYCODE_BACK"]) # Taps the home key and then the back key + agent.key_combination(["KEYCODE_HOME", "KEYCODE_BACK"], duration_in_ms=200) # Taps the home key and then the back key with a 200ms delay + ``` + """ + self.os.key_combination(keys, duration_in_ms) + + @telemetry.record_call() + @validate_call + def shell( + self, + command: str, + ) -> str: + """ + Executes a shell command on the Android device. + + Args: + command (str): The shell command to execute. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.shell("pm list packages") # Lists all installed packages + agent.shell("dumpsys battery") # Displays battery information + ``` + """ + return self.os.shell(command) + + @telemetry.record_call() + @validate_call + def drag_and_drop( + self, + x1: int, + y1: int, + x2: int, + y2: int, + duration_in_ms: int = 1000, + ) -> None: + """ + Drags and drops the specified target. + + Args: + x1 (int): The x-coordinate of the starting point. + y1 (int): The y-coordinate of the starting point. + x2 (int): The x-coordinate of the ending point. + y2 (int): The y-coordinate of the ending point. + duration_in_ms (int, optional): The duration in milliseconds to hold the drag and drop. Default is 1000ms. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.drag_and_drop(100, 100, 200, 200) # Drags and drops from (100, 100) to (200, 200) + agent.drag_and_drop(100, 100, 200, 200, duration_in_ms=2000) # Drags and drops from (100, 100) to (200, 200) with a 2000ms duration + """ + self.os.drag_and_drop(x1, y1, x2, y2, duration_in_ms) + + @telemetry.record_call() + @validate_call + def swipe( + self, + x1: int, + y1: int, + x2: int, + y2: int, + duration_in_ms: int = 1000, + ) -> None: + """ + Swipes the specified target. + + Args: + x1 (int): The x-coordinate of the starting point. + y1 (int): The y-coordinate of the starting point. + x2 (int): The x-coordinate of the ending point. + y2 (int): The y-coordinate of the ending point. + duration_in_ms (int, optional): The duration in milliseconds to hold the swipe. Default is 1000ms. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.swipe(100, 100, 200, 200) # Swipes from (100, 100) to (200, 200) + agent.swipe(100, 100, 200, 200, duration_in_ms=2000) # Swipes from (100, 100) to (200, 200) with a 2000ms duration + """ + self.os.swipe(x1, y1, x2, y2, duration_in_ms) + + @telemetry.record_call( + exclude={"device_name"}, + ) + @validate_call + def set_device_by_name( + self, + device_name: str, + ) -> None: + """ + Sets the active device for screen interactions by name. + + Args: + device_name (str): The name of the device to set as active. + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.set_device_by_name("Pixel 6") # Sets the active device to the Pixel 6 + """ + self.os.set_device_by_name(device_name) + + @telemetry.record_call(exclude={"goal", "on_message"}) + @validate_call + def act( + self, + goal: Annotated[str | list[MessageParam], Field(min_length=1)], + model: str | None = None, + on_message: OnMessageCb | None = None, + ) -> None: + """ + Instructs the agent to achieve a specified goal through autonomous actions. + + The agent will analyze the screen, determine necessary steps, and perform actions + to accomplish the goal. This may include clicking, typing, scrolling, and other + interface interactions. + + Args: + goal (str | list[MessageParam]): A description of what the agent should achieve. + model (str | None, optional): The composition or name of the model(s) to be used for achieving the `goal`. + on_message (OnMessageCb | None, optional): Callback for new messages. If it returns `None`, stops and does not add the message. + + Returns: + None + + Example: + ```python + from askui import AndroidVisionAgent + + with AndroidVisionAgent() as agent: + agent.act("Open the settings menu") + agent.act("Log in with username 'admin' and password '1234'") + ``` + """ + goal_str = ( + goal + if isinstance(goal, str) + else "\n".join(msg.model_dump_json() for msg in goal) + ) + self._reporter.add_message("User", f'act: "{goal_str}"') + logger.debug( + "VisionAgent received instruction to act towards the goal '%s'", goal_str + ) + messages: list[MessageParam] = ( + [MessageParam(role="user", content=goal)] if isinstance(goal, str) else goal + ) + self._model_router.act(messages, model or self._model_choice["act"], on_message) + + @telemetry.record_call(flush=True) + def close(self) -> None: + """Disconnects from the Android device.""" + self.os.disconnect() + self._reporter.generate() + + @telemetry.record_call() + def open(self) -> None: + """Connects to the Android device.""" + self.os.connect() + + @telemetry.record_call() + def __enter__(self) -> "AndroidVisionAgent": + self.open() + return self + + @telemetry.record_call(exclude={"exc_value", "traceback"}) + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + self.close() diff --git a/src/askui/models/askui/android_agent.py b/src/askui/models/askui/android_agent.py new file mode 100644 index 00000000..a78ba3dd --- /dev/null +++ b/src/askui/models/askui/android_agent.py @@ -0,0 +1,67 @@ +import httpx +from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential +from typing_extensions import override + +from askui.models.askui.settings import AskUiAndroidAgentSettings +from askui.models.shared.android_agent import AndroidAgent +from askui.models.shared.computer_agent_message_param import MessageParam +from askui.reporting import Reporter +from askui.tools.android_agent_os import AndroidAgentOs + +from ...logger import logger + + +def is_retryable_error(exception: BaseException) -> bool: + """Check if the exception is a retryable error (status codes 429 or 529).""" + if isinstance(exception, httpx.HTTPStatusError): + return exception.response.status_code in (429, 529) + return False + + +class AskUiAndroidAgent(AndroidAgent[AskUiAndroidAgentSettings]): + def __init__( + self, + agent_os: AndroidAgentOs, + reporter: Reporter, + settings: AskUiAndroidAgentSettings, + ) -> None: + super().__init__(settings, agent_os, reporter) + self._client = httpx.Client( + base_url=f"{self._settings.askui.base_url}", + headers={ + "Content-Type": "application/json", + "Authorization": self._settings.askui.authorization_header, + }, + ) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=30, max=240), + retry=retry_if_exception(is_retryable_error), + reraise=True, + ) + @override + def _create_message( + self, + messages: list[MessageParam], + model_choice: str, # noqa: ARG002 + ) -> MessageParam: + try: + request_body = { + "max_tokens": self._settings.max_tokens, + "messages": [msg.model_dump(mode="json") for msg in messages], + "model": self._settings.model, + "tools": self._tool_collection.to_params(), + "betas": [], + "system": [self._system], + } + response = self._client.post( + "/act/inference", json=request_body, timeout=300.0 + ) + response.raise_for_status() + response_data = response.json() + return MessageParam.model_validate(response_data) + except Exception as e: # noqa: BLE001 + if is_retryable_error(e): + logger.debug(e) + raise diff --git a/src/askui/models/askui/settings.py b/src/askui/models/askui/settings.py index b018b40b..f09c3f5b 100644 --- a/src/askui/models/askui/settings.py +++ b/src/askui/models/askui/settings.py @@ -5,6 +5,7 @@ from pydantic_settings import BaseSettings from askui.models.models import ModelName +from askui.models.shared.base_agent import AgentSettingsBase from askui.models.shared.computer_agent import ComputerAgentSettingsBase @@ -39,3 +40,10 @@ def base_url(self) -> str: class AskUiComputerAgentSettings(ComputerAgentSettingsBase): model: str = ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 askui: AskUiSettings = Field(default_factory=AskUiSettings) + + +class AskUiAndroidAgentSettings(AgentSettingsBase): + """Settings for AskUI Android agent.""" + + model: str = ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 + askui: AskUiSettings = Field(default_factory=AskUiSettings) diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 2f14c0a1..a8e20a8e 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -11,9 +11,13 @@ ClaudeSettings, ) from askui.models.askui.ai_element_utils import AiElementCollection +from askui.models.askui.android_agent import AskUiAndroidAgent from askui.models.askui.computer_agent import AskUiComputerAgent from askui.models.askui.model_router import AskUiModelRouter -from askui.models.askui.settings import AskUiComputerAgentSettings +from askui.models.askui.settings import ( + AskUiAndroidAgentSettings, + AskUiComputerAgentSettings, +) from askui.models.exceptions import ModelNotFoundError, ModelTypeMismatchError from askui.models.huggingface.spaces_api import HFSpacesHandler from askui.models.models import ( @@ -32,6 +36,7 @@ from askui.models.shared.facade import ModelFacade from askui.models.types.response_schemas import ResponseSchema from askui.reporting import CompositeReporter, Reporter +from askui.tools.android_agent_os import AndroidAgentOs from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource @@ -138,17 +143,87 @@ def ui_tars_api_handler() -> UiTarsApiHandler: } +def _initialize_default_android_model_registry( # noqa: C901 + os: AndroidAgentOs, + reporter: Reporter, +) -> ModelRegistry: + @functools.cache + @functools.cache + def askui_settings() -> AskUiSettings: + return AskUiSettings() + + @functools.cache + def askui_inference_api() -> AskUiInferenceApi: + return AskUiInferenceApi( + locator_serializer=AskUiLocatorSerializer( + ai_element_collection=AiElementCollection(), + reporter=reporter, + ), + settings=askui_settings(), + ) + + @functools.cache + def askui_model_router() -> AskUiModelRouter: + return AskUiModelRouter( + inference_api=askui_inference_api(), + ) + + @functools.cache + def askui_facade() -> ModelFacade: + android_agent = AskUiAndroidAgent( + agent_os=os, + reporter=reporter, + settings=AskUiAndroidAgentSettings( + askui=askui_settings(), + ), + ) + return ModelFacade( + act_model=android_agent, + get_model=askui_inference_api(), + locate_model=askui_model_router(), + ) + + return { + ModelName.ASKUI: askui_facade, + ModelName.ASKUI__AI_ELEMENT: askui_model_router, + ModelName.ASKUI__COMBO: askui_model_router, + ModelName.ASKUI__OCR: askui_model_router, + ModelName.ASKUI__PTA: askui_model_router, + } + + class ModelRouter: def __init__( self, + models: ModelRegistry, + ): + self._models = models + + @staticmethod + def build_default_computer_router( tools: AgentToolbox, reporter: Reporter | None = None, models: ModelRegistry | None = None, - ): - self._tools = tools - self._reporter = reporter or CompositeReporter() - self._models = _initialize_default_model_registry(tools, self._reporter) - self._models.update(models or {}) + ) -> "ModelRouter": + """Build a default model router for computer agents.""" + if reporter is None: + reporter = CompositeReporter() + models_registry = _initialize_default_model_registry(tools, reporter) + models_registry.update(models or {}) + return ModelRouter(models=models_registry) + + @staticmethod + def build_default_android_router( + os: AndroidAgentOs, + reporter: Reporter | None = None, + models: ModelRegistry | None = None, + ) -> "ModelRouter": + """Build a default model router for Android agents.""" + if reporter is None: + reporter = CompositeReporter() + models_registry = _initialize_default_android_model_registry(os, reporter) + models_registry.update(models or {}) + return ModelRouter(models=models_registry) @overload def _get_model(self, model_choice: str, model_type: Literal["act"]) -> ActModel: ... diff --git a/src/askui/models/shared/android_agent.py b/src/askui/models/shared/android_agent.py new file mode 100644 index 00000000..ab28d134 --- /dev/null +++ b/src/askui/models/shared/android_agent.py @@ -0,0 +1,115 @@ +from askui.models.shared.base_agent import AgentSettings, BaseAgent +from askui.reporting import Reporter +from askui.tools.android_agent_os import AndroidAgentOs +from askui.tools.anthropic import ( + AndroidDragAndDropTool, + AndroidKeyCombinationTool, + AndroidKeyTapEventTool, + AndroidScreenshotTool, + AndroidShellTool, + AndroidSwipeTool, + AndroidTapTool, + AndroidTypeTool, + BaseAnthropicTool, +) +from askui.tools.askui.askui_android_controller import AndroidAgentOSHandler + +ANDROID_SYSTEM_PROMPT = """ +You are an autonomous Android device control agent operating via ADB on a test device with full system access. +Your primary goal is to execute tasks efficiently and reliably while maintaining system stability. + + +* Autonomy: Operate independently and make informed decisions without requiring user input. +* Reliability: Ensure actions are repeatable and maintain system stability. +* Efficiency: Optimize operations to minimize latency and resource usage. +* Safety: Always verify actions before execution, even with full system access. + + + +1. Tool Usage: + * Verify tool availability before starting any operation + * Use the most direct and efficient tool for each task + * Combine tools strategically for complex operations + * Prefer built-in tools over shell commands when possible + +2. Error Handling: + * Assess failures systematically: check tool availability, permissions, and device state + * Implement retry logic with exponential backoff for transient failures + * Use fallback strategies when primary approaches fail + * Provide clear, actionable error messages with diagnostic information + +3. Performance Optimization: + * Use one-liner shell commands with inline filtering (grep, cut, awk, jq) for efficiency + * Minimize screen captures and coordinate calculations + * Cache device state information when appropriate + * Batch related operations when possible + +4. Screen Interaction: + * Ensure all coordinates are integers and within screen bounds + * Implement smart scrolling for off-screen elements + * Use appropriate gestures (tap, swipe, drag) based on context + * Verify element visibility before interaction + +5. System Access: + * Leverage full system access responsibly + * Use shell commands for system-level operations + * Monitor system state and resource usage + * Maintain system stability during operations + +6. Recovery Strategies: + * If an element is not visible, try: + - Scrolling in different directions + - Adjusting view parameters + - Using alternative interaction methods + * If a tool fails: + - Check device connection and state + - Verify tool availability and permissions + - Try alternative tools or approaches + * If stuck: + - Provide clear diagnostic information + - Suggest potential solutions + - Request user intervention only if necessary + +7. Best Practices: + * Document all significant operations + * Maintain operation logs for debugging + * Implement proper cleanup after operations + * Follow Android best practices for UI interaction + + +* This is a test device with full system access - use this capability responsibly +* Always verify the success of critical operations +* Maintain system stability as the highest priority +* Provide clear, actionable feedback for all operations +* Use the most efficient method for each task + +""" + + +class AndroidAgent(BaseAgent[AgentSettings]): + """Base class for computer agents that can execute autonomous actions. + + This class provides common functionality for both AskUI and Anthropic + computer agents, + including tool handling, message processing, and image filtering. + """ + + def __init__( + self, + settings: AgentSettings, + android_agent_os: AndroidAgentOs, + reporter: Reporter, + ) -> None: + android_os_handler = AndroidAgentOSHandler(android_agent_os, reporter) + tool_list: list[BaseAnthropicTool] = [ + AndroidScreenshotTool(android_os_handler), + AndroidTapTool(android_os_handler), + AndroidTypeTool(android_os_handler), + AndroidDragAndDropTool(android_os_handler), + AndroidKeyTapEventTool(android_os_handler), + AndroidSwipeTool(android_os_handler), + AndroidKeyCombinationTool(android_os_handler), + AndroidShellTool(android_os_handler), + ] + + super().__init__(settings, tool_list, ANDROID_SYSTEM_PROMPT, reporter) diff --git a/src/askui/tools/android_agent_os.py b/src/askui/tools/android_agent_os.py index 7b577cf1..2a369cb6 100644 --- a/src/askui/tools/android_agent_os.py +++ b/src/askui/tools/android_agent_os.py @@ -252,17 +252,6 @@ def screenshot(self) -> Image.Image: """ raise NotImplementedError - @abstractmethod - def mouse_move(self, x: int, y: int) -> None: - """ - Moves the mouse cursor to specified screen coordinates. - - Args: - x (int): The horizontal coordinate (in pixels) to move to. - y (int): The vertical coordinate (in pixels) to move to. - """ - raise NotImplementedError - @abstractmethod def type(self, text: str) -> None: """ @@ -308,7 +297,14 @@ def swipe( raise NotImplementedError @abstractmethod - def drag_and_drop(self, x1: int, y1: int, x2: int, y2: int) -> None: + def drag_and_drop( + self, + x1: int, + y1: int, + x2: int, + y2: int, + duration_in_ms: int = 1000, + ) -> None: """ Simulates dragging and dropping an object from one point to another. @@ -330,7 +326,7 @@ def shell(self, command: str) -> str: raise NotImplementedError @abstractmethod - def key_event(self, key: ANDROID_KEY) -> None: + def key_tap(self, key: ANDROID_KEY) -> None: """ Simulates a key event on the Android device. """ diff --git a/src/askui/tools/anthropic/__init__.py b/src/askui/tools/anthropic/__init__.py index 9e4ec274..19ce9bb0 100644 --- a/src/askui/tools/anthropic/__init__.py +++ b/src/askui/tools/anthropic/__init__.py @@ -1,3 +1,13 @@ +from .android_tools import ( + AndroidDragAndDropTool, + AndroidKeyCombinationTool, + AndroidKeyTapEventTool, + AndroidScreenshotTool, + AndroidShellTool, + AndroidSwipeTool, + AndroidTapTool, + AndroidTypeTool, +) from .base import BaseAnthropicTool, CLIResult, Tool, ToolResult from .collection import ToolCollection from .computer import ComputerTool @@ -9,4 +19,12 @@ ToolResult, BaseAnthropicTool, Tool, + AndroidScreenshotTool, + AndroidTapTool, + AndroidTypeTool, + AndroidDragAndDropTool, + AndroidKeyTapEventTool, + AndroidSwipeTool, + AndroidKeyCombinationTool, + AndroidShellTool, ] diff --git a/src/askui/tools/anthropic/android_tools.py b/src/askui/tools/anthropic/android_tools.py new file mode 100644 index 00000000..9555ae7d --- /dev/null +++ b/src/askui/tools/anthropic/android_tools.py @@ -0,0 +1,311 @@ +from typing import get_args + +from askui.tools.android_agent_os import ANDROID_KEY +from askui.tools.askui.askui_android_controller import AndroidAgentOSHandler +from askui.utils.image_utils import image_to_base64 + +from .base import Tool, ToolResult + + +class AndroidScreenshotTool(Tool): + """ + Takes a screenshot from the currently connected Android device. + """ + + def __init__(self, os_agent_handler: AndroidAgentOSHandler): + super().__init__( + name="android_screenshot_tool", + description=""" + Takes a screenshot of the currently active window. + The image can be used to check the current state of the device. + It's recommended to use this tool to check the current state of the device + before and after an action. + """, + ) + self.os_agent_handler = os_agent_handler + + def __call__(self) -> ToolResult: + screenshot = self.os_agent_handler.screenshot() + base64_image = image_to_base64(screenshot) + return ToolResult(output="Screenshot was taken.", base64_images=[base64_image]) + + +class AndroidTapTool(Tool): + """ + Performs a tap (touch) gesture at the given (x, y) coordinates on the + Android device screen. + The coordinates are absolute coordinates on the screen. + The top left corner of the screen is (0, 0). + """ + + def __init__(self, os_agent_handler: AndroidAgentOSHandler): + super().__init__( + name="android_tap_tool", + description=""" + Performs a tap (touch) gesture at the given (x, y) coordinates on the + Android device screen. + The coordinates are absolute coordinates on the screen. + The top left corner of the screen is (0, 0). + """, + input_schema={ + "type": "object", + "properties": { + "x": { + "type": "integer", + "description": "The x coordinate of the tap gesture in pixels.", + }, + "y": { + "type": "integer", + "description": "The y coordinate of the tap gesture in pixels.", + }, + }, + "required": ["x", "y"], + }, + ) + self.os_agent_handler = os_agent_handler + + def __call__(self, x: int, y: int) -> ToolResult: + self.os_agent_handler.tap(x, y) + return ToolResult(output=f"Tapped at ({x}, {y})") + + +class AndroidTypeTool(Tool): + """ + Types the given text on the Android device screen. + """ + + def __init__(self, os_agent_handler: AndroidAgentOSHandler): + super().__init__( + name="android_type_tool", + description=""" + Types the given text on the Android device screen. + The to typed text can not contains non ASCII printable characters. + """, + input_schema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": """ + The text to type. It must be a valid ASCII printable string. + text such as "Hello, world!" is valid, + but "Hello, 世界!" is not valid and will raise an error. + """, + }, + }, + "required": ["text"], + }, + ) + self.os_agent_handler = os_agent_handler + + def __call__(self, text: str) -> ToolResult: + self.os_agent_handler.type(text) + return ToolResult(output=f"Typed: {text}") + + +class AndroidDragAndDropTool(Tool): + """ + Performs a drag and drop gesture on the Android device screen. + """ + + def __init__(self, os_agent_handler: AndroidAgentOSHandler): + self.os_agent_handler = os_agent_handler + super().__init__( + name="android_drag_and_drop_tool", + description=""" + Performs a drag and drop gesture on the Android device screen. + Will hold the element at the requested start position and drag and drop + it in the requested end position in pixels in the given duration. + TopLeftCorner of the screen is (0, 0). + To get the coordinates of an element, take and analyze a screenshot of the screen. + """, + input_schema={ + "type": "object", + "properties": { + "x1": { + "type": "integer", + "description": "The x1 pixel coordinate of the start position", + }, + "y1": { + "type": "integer", + "description": "The y1 pixel coordinate of the start position", + }, + "x2": { + "type": "integer", + "description": "The x2 pixel coordinate of the end position", + }, + "y2": { + "type": "integer", + "description": "The y2 pixel coordinate of the end position", + }, + "duration": { + "type": "integer", + "description": "The duration of the drag and drop gesture in milliseconds", + "default": 1000, + }, + }, + "required": ["x1", "y1", "x2", "y2"], + }, + ) + + def __call__( + self, x1: int, y1: int, x2: int, y2: int, duration: int = 1000 + ) -> ToolResult: + self.os_agent_handler.drag_and_drop(x1, y1, x2, y2, duration) + return ToolResult( + output=f"Dragged and dropped from ({x1}, {y1}) to ({x2}, {y2}) in {duration}ms" + ) + + +class AndroidKeyTapEventTool(Tool): + def __init__(self, os_agent_handler: AndroidAgentOSHandler): + super().__init__( + name="android_key_event_tool", + description=""" + Performs a key press on the android device. + e.g 'HOME' to simulate the home button press. + """, + input_schema={ + "type": "object", + "properties": { + "key_name": { + "type": "string", + "description": "The key event to perform. e.g 'HOME' to simulate the home button press.", + "enum": get_args(ANDROID_KEY), + }, + }, + "required": ["key_name"], + }, + ) + self.os_agent_handler = os_agent_handler + + def __call__(self, key_name: ANDROID_KEY) -> ToolResult: + self.os_agent_handler.key_tap(key_name) + return ToolResult(output=f"Tapped on {key_name}") + + +class AndroidSwipeTool(Tool): + """ + Performs a swipe gesture on the Android device screen. + """ + + def __init__(self, os_agent_handler: AndroidAgentOSHandler): + super().__init__( + name="android_swipe_tool", + description=""" + Performs a swipe gesture on the Android device screen, similar to how a user would swipe their finger across the screen. + This is useful for scrolling through content, navigating between screens, or revealing hidden elements. + The gesture will start at the specified coordinates and move to the end coordinates over the given duration. + The screen coordinates are absolute, with (0,0) at the top-left corner of the screen. + For best results, ensure the coordinates are within the visible screen bounds. + """, + input_schema={ + "type": "object", + "properties": { + "x1": { + "type": "integer", + "description": "The starting x-coordinate in pixels from the left edge of the screen. Must be a positive integer.", + }, + "y1": { + "type": "integer", + "description": "The starting y-coordinate in pixels from the top edge of the screen. Must be a positive integer.", + }, + "x2": { + "type": "integer", + "description": "The ending x-coordinate in pixels from the left edge of the screen. Must be a positive integer.", + }, + "y2": { + "type": "integer", + "description": "The ending y-coordinate in pixels from the top edge of the screen. Must be a positive integer.", + }, + "duration": { + "type": "integer", + "description": "The duration of the swipe gesture in milliseconds. A longer duration creates a slower swipe. Default is 1000ms (1 second).", + "default": 1000, + }, + }, + "required": ["x1", "y1", "x2", "y2"], + }, + ) + self.os_agent_handler = os_agent_handler + + def __call__( + self, x1: int, y1: int, x2: int, y2: int, duration: int = 1000 + ) -> ToolResult: + self.os_agent_handler.swipe(x1, y1, x2, y2, duration) + return ToolResult( + output=f"Swiped from ({x1}, {y1}) to ({x2}, {y2}) in {duration}ms" + ) + + +class AndroidKeyCombinationTool(Tool): + """ + Performs a key combination on the Android device. + """ + + def __init__(self, os_agent_handler: AndroidAgentOSHandler): + super().__init__( + name="android_key_combination_tool", + description=""" + Performs a combination of key presses on the Android device, similar to keyboard shortcuts on a computer. + This is useful for performing complex actions that require multiple keys to be pressed simultaneously. + For example, you can use this to copy text (ctrl+c), switch apps (alt+tab), or perform other system-wide shortcuts. + The keys will be pressed in the order specified, with a small delay between each press. + """, + input_schema={ + "type": "object", + "properties": { + "keys": { + "type": "array", + "items": { + "type": "string", + "enum": get_args(ANDROID_KEY), + }, + "description": "An array of keys to press in combination. Each key must be a valid Android key code. For example: ['ctrl_left', 'c'] for copy.", + }, + "duration": { + "type": "integer", + "description": "The duration in milliseconds to hold the key combination. A longer duration may be needed for some system actions. Default is 100ms.", + "default": 100, + }, + }, + "required": ["keys"], + }, + ) + self.os_agent_handler = os_agent_handler + + def __call__(self, keys: list[ANDROID_KEY], duration: int = 100) -> ToolResult: + self.os_agent_handler.key_combination(keys, duration) + return ToolResult(output=f"Performed key combination: {keys}") + + +class AndroidShellTool(Tool): + """ + Executes a shell command on the Android device. + """ + + def __init__(self, os_agent_handler: AndroidAgentOSHandler): + super().__init__( + name="android_shell_tool", + description=""" + Executes a shell command directly on the Android device through ADB. + This provides low-level access to the Android system, allowing you to run system commands, + check device status, or perform administrative tasks. + The command will be executed in the Android shell environment with the current user's permissions. + """, + input_schema={ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The shell command to execute on the Android device. This can be any valid Android shell command, such as 'pm list packages' to list installed apps or 'dumpsys battery' to check battery status.", + }, + }, + "required": ["command"], + }, + ) + self.os_agent_handler = os_agent_handler + + def __call__(self, command: str) -> ToolResult: + output = self.os_agent_handler.shell(command) + return ToolResult(output=f"Shell command executed. Output: {output}") diff --git a/src/askui/tools/anthropic/base.py b/src/askui/tools/anthropic/base.py index bb524a5f..a0647239 100644 --- a/src/askui/tools/anthropic/base.py +++ b/src/askui/tools/anthropic/base.py @@ -10,9 +10,9 @@ class BaseAnthropicTool(metaclass=ABCMeta): """Abstract base class for Anthropic-defined tools.""" @abstractmethod - def __call__(self, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> Any: """Executes the tool with the given arguments.""" - ... + raise NotImplementedError @abstractmethod def to_params( @@ -88,7 +88,12 @@ def __init__(self, message: str, result: Optional[ToolResult] = None): class Tool(BaseAnthropicTool): """A tool that can be used in an agent.""" - def __init__(self, name: str, description: str, input_schema: dict[str, Any]): + def __init__( + self, + name: str, + description: str, + input_schema: dict[str, Any] | None = None, + ) -> None: if not name: error_msg = "Tool name is required" raise ValueError(error_msg) @@ -108,6 +113,6 @@ def to_params(self) -> BetaToolParam: "input_schema": self.input_schema, } - def __call__(self, **kwargs: Any) -> ToolResult: + def __call__(self, *_args: Any, **_kwargs: Any) -> ToolResult: error_msg = "Tool subclasses must implement __call__ method" raise NotImplementedError(error_msg) diff --git a/src/askui/tools/askui/askui_android_controller.py b/src/askui/tools/askui/askui_android_controller.py index a2d13668..1eec29b9 100644 --- a/src/askui/tools/askui/askui_android_controller.py +++ b/src/askui/tools/askui/askui_android_controller.py @@ -1,23 +1,28 @@ import io import re import string -from typing import List, Optional, get_args +from typing import List, Optional, Tuple, get_args from PIL import Image from ppadb.client import Client as AdbClient from ppadb.device import Device as AndroidDevice +from askui.reporting import Reporter from askui.tools.android_agent_os import ANDROID_KEY, AndroidAgentOs, AndroidDisplay +from askui.utils.image_utils import scale_coordinates_back, scale_image_with_padding class AskUiAndroidController(AndroidAgentOs): - def __init__(self, report: bool = True) -> None: + """ + This class is used to control the Android device. + """ + + def __init__(self) -> None: self._client: Optional[AdbClient] = None self._device: Optional[AndroidDevice] = None self._mouse_position: tuple[int, int] = (0, 0) self._displays: list[AndroidDisplay] = [] self._selected_display: Optional[AndroidDisplay] = None - self.report = report def connect(self) -> None: self._client = AdbClient() @@ -225,3 +230,143 @@ def _get_connected_devices(self) -> list[AndroidDevice]: msg = "No devices connected" raise RuntimeError(msg) return devices + + +class AndroidAgentOSHandler(AndroidAgentOs): + """ + This class is used to handle the AndroidAgentOs class. + It is used to scale the coordinates to the target resolution + and back to the real screen resolution. + """ + + def __init__(self, os_agent: AndroidAgentOs, reporter: Reporter) -> None: + self._os_agent: AndroidAgentOs = os_agent + self._reporter: Reporter = reporter + self._target_resolution: Tuple[int, int] = (1280, 800) + self._real_screen_resolution: Optional[Tuple[int, int]] = None + + def connect(self) -> None: + self._os_agent.connect() + self._reporter.add_message("AndroidAgentOS", "Connected to device") + self._real_screen_resolution = self._os_agent.screenshot().size + + def disconnect(self) -> None: + self._os_agent.disconnect() + self._real_screen_resolution = None + + def screenshot(self) -> Image.Image: + screenshot = self._os_agent.screenshot() + self._real_screen_resolution = screenshot.size + scaled_image = scale_image_with_padding( + screenshot, + self._target_resolution[0], + self._target_resolution[1], + ) + + self._reporter.add_message("AndroidAgentOS", "Screenshot taken", screenshot) + return scaled_image + + def _scale_coordinates_back(self, x: int, y: int) -> Tuple[int, int]: + if self._real_screen_resolution is None: + self._real_screen_resolution = self._os_agent.screenshot().size + + scaled_x, scaled_y = scale_coordinates_back( + x, + y, + self._real_screen_resolution[0], + self._real_screen_resolution[1], + self._target_resolution[0], + self._target_resolution[1], + ) + return int(scaled_x), int(scaled_y) + + def tap(self, x: int, y: int) -> None: + scaled_x, scaled_y = self._scale_coordinates_back(x, y) + self._os_agent.tap(scaled_x, scaled_y) + self._reporter.add_message("AndroidAgentOS", f"Tapped on {x}, {y}") + + def swipe( + self, x1: int, y1: int, x2: int, y2: int, duration_in_ms: int = 1000 + ) -> None: + scaled_x1, scaled_y1 = self._scale_coordinates_back(x1, y1) + scaled_x2, scaled_y2 = self._scale_coordinates_back(x2, y2) + self._os_agent.swipe(scaled_x1, scaled_y1, scaled_x2, scaled_y2, duration_in_ms) + self._reporter.add_message( + "AndroidAgentOS", f"Swiped from {x1}, {y1} to {x2}, {y2}" + ) + + def drag_and_drop( + self, x1: int, y1: int, x2: int, y2: int, duration_in_ms: int = 1000 + ) -> None: + scaled_x1, scaled_y1 = self._scale_coordinates_back(x1, y1) + scaled_x2, scaled_y2 = self._scale_coordinates_back(x2, y2) + self._os_agent.drag_and_drop( + scaled_x1, scaled_y1, scaled_x2, scaled_y2, duration_in_ms + ) + self._reporter.add_message( + "AndroidAgentOS", + f"Dragged and dropped from {x1}, {y1} to {x2}, {y2}", + ) + + def type(self, text: str) -> None: + self._os_agent.type(text) + self._reporter.add_message("AndroidAgentOS", f"Typed {text}") + + def key_tap(self, key: ANDROID_KEY) -> None: + self._os_agent.key_tap(key) + self._reporter.add_message("AndroidAgentOS", f"Tapped on {key}") + + def key_combination( + self, keys: List[ANDROID_KEY], duration_in_ms: int = 100 + ) -> None: + self._os_agent.key_combination(keys, duration_in_ms) + self._reporter.add_message( + "AndroidAgentOS", + f"Tapped on {keys}", + ) + + def shell(self, command: str) -> str: + shell_output = self._os_agent.shell(command) + self._reporter.add_message("AndroidAgentOS", f"Ran shell command: {command}") + return shell_output + + def get_connected_displays(self) -> list[AndroidDisplay]: + displays = self._os_agent.get_connected_displays() + self._reporter.add_message( + "AndroidAgentOS", + f"Retrieved connected displays, length: {len(displays)}", + ) + return displays + + def set_display_by_index(self, display_index: int = 0) -> None: + self._os_agent.set_display_by_index(display_index) + self._real_screen_resolution = None + self._reporter.add_message( + "AndroidAgentOS", f"Set display by index: {display_index}" + ) + + def set_display_by_id(self, display_id: int) -> None: + self._os_agent.set_display_by_id(display_id) + self._real_screen_resolution = None + self._reporter.add_message("AndroidAgentOS", f"Set display by id: {display_id}") + + def set_display_by_name(self, display_name: str) -> None: + self._os_agent.set_display_by_name(display_name) + self._real_screen_resolution = None + self._reporter.add_message( + "AndroidAgentOS", f"Set display by name: {display_name}" + ) + + def set_device_by_index(self, device_index: int = 0) -> None: + self._os_agent.set_device_by_index(device_index) + self._real_screen_resolution = None + self._reporter.add_message( + "AndroidAgentOS", f"Set device by index: {device_index}" + ) + + def set_device_by_name(self, device_name: str) -> None: + self._os_agent.set_device_by_name(device_name) + self._real_screen_resolution = None + self._reporter.add_message( + "AndroidAgentOS", f"Set device by name: {device_name}" + ) diff --git a/tests/unit/models/test_model_router.py b/tests/unit/models/test_model_router.py index 5cb6f192..6231ac04 100644 --- a/tests/unit/models/test_model_router.py +++ b/tests/unit/models/test_model_router.py @@ -83,7 +83,7 @@ def model_router( mock_hf_spaces: HFSpacesHandler, ) -> ModelRouter: """Fixture providing a ModelRouter instance with mocked dependencies.""" - return ModelRouter( + return ModelRouter.build_default_computer_router( tools=agent_toolbox_mock, reporter=CompositeReporter(), models={ From 9d12e2676e13285d3ff441f19e674fb0e55717ca Mon Sep 17 00:00:00 2001 From: Samir mlika <105347215+mlikasam-askui@users.noreply.github.com> Date: Wed, 18 Jun 2025 15:23:53 +0200 Subject: [PATCH 6/6] add exception tool --- src/askui/models/shared/android_agent.py | 2 ++ src/askui/models/shared/base_agent.py | 10 +++++-- src/askui/tools/anthropic/__init__.py | 2 ++ src/askui/tools/anthropic/base.py | 10 +++++++ src/askui/tools/anthropic/collection.py | 19 ++++++++----- src/askui/tools/anthropic/exception_tool.py | 31 +++++++++++++++++++++ 6 files changed, 65 insertions(+), 9 deletions(-) create mode 100644 src/askui/tools/anthropic/exception_tool.py diff --git a/src/askui/models/shared/android_agent.py b/src/askui/models/shared/android_agent.py index ab28d134..3a00580c 100644 --- a/src/askui/models/shared/android_agent.py +++ b/src/askui/models/shared/android_agent.py @@ -12,6 +12,7 @@ AndroidTypeTool, BaseAnthropicTool, ) +from askui.tools.anthropic.exception_tool import ExceptionTool from askui.tools.askui.askui_android_controller import AndroidAgentOSHandler ANDROID_SYSTEM_PROMPT = """ @@ -110,6 +111,7 @@ def __init__( AndroidSwipeTool(android_os_handler), AndroidKeyCombinationTool(android_os_handler), AndroidShellTool(android_os_handler), + ExceptionTool(), ] super().__init__(settings, tool_list, ANDROID_SYSTEM_PROMPT, reporter) diff --git a/src/askui/models/shared/base_agent.py b/src/askui/models/shared/base_agent.py index 9a2b68ed..50fa602a 100644 --- a/src/askui/models/shared/base_agent.py +++ b/src/askui/models/shared/base_agent.py @@ -59,7 +59,7 @@ def __init__( self._settings: AgentSettings = settings self._reporter = reporter self._tool_collection = ToolCollection( - *tools, + tools, ) self._system = BetaTextBlockParam( type="text", @@ -82,10 +82,16 @@ def _create_message( raise NotImplementedError def set_system_prompt(self, system_prompt: str) -> None: + """Set the system prompt for the agent.""" self._system = BetaTextBlockParam(type="text", text=f"{system_prompt}") def set_tool_collection(self, tools: list[BaseAnthropicTool]) -> None: - self._tool_collection = ToolCollection(*tools) + """Set the tool collection for the agent.""" + self._tool_collection = ToolCollection(tools) + + def add_tool(self, tool: BaseAnthropicTool) -> None: + """Add a tool to the agent.""" + self._tool_collection.add_tool(tool) def _step( self, diff --git a/src/askui/tools/anthropic/__init__.py b/src/askui/tools/anthropic/__init__.py index 19ce9bb0..455ad737 100644 --- a/src/askui/tools/anthropic/__init__.py +++ b/src/askui/tools/anthropic/__init__.py @@ -11,6 +11,7 @@ from .base import BaseAnthropicTool, CLIResult, Tool, ToolResult from .collection import ToolCollection from .computer import ComputerTool +from .exception_tool import ExceptionTool __ALL__ = [ CLIResult, @@ -27,4 +28,5 @@ AndroidSwipeTool, AndroidKeyCombinationTool, AndroidShellTool, + ExceptionTool, ] diff --git a/src/askui/tools/anthropic/base.py b/src/askui/tools/anthropic/base.py index a0647239..5e0da8fd 100644 --- a/src/askui/tools/anthropic/base.py +++ b/src/askui/tools/anthropic/base.py @@ -116,3 +116,13 @@ def to_params(self) -> BetaToolParam: def __call__(self, *_args: Any, **_kwargs: Any) -> ToolResult: error_msg = "Tool subclasses must implement __call__ method" raise NotImplementedError(error_msg) + + +class AgentException(Exception): + """ + Exception raised by the agent. + """ + + def __init__(self, message: str): + self.message = message + super().__init__(self.message) diff --git a/src/askui/tools/anthropic/collection.py b/src/askui/tools/anthropic/collection.py index 2141c8fa..5e49f996 100644 --- a/src/askui/tools/anthropic/collection.py +++ b/src/askui/tools/anthropic/collection.py @@ -4,18 +4,13 @@ from anthropic.types.beta import BetaToolUnionParam -from .base import ( - BaseAnthropicTool, - ToolError, - ToolFailure, - ToolResult, -) +from .base import AgentException, BaseAnthropicTool, ToolError, ToolFailure, ToolResult class ToolCollection: """A collection of anthropic-defined tools.""" - def __init__(self, *tools: BaseAnthropicTool): + def __init__(self, tools: list[BaseAnthropicTool]): self.tools = tools self.tool_map = {tool.to_params()["name"]: tool for tool in tools} @@ -24,11 +19,21 @@ def to_params( ) -> list[BetaToolUnionParam]: return [tool.to_params() for tool in self.tools] + def add_tool(self, tool: BaseAnthropicTool) -> None: + """Add a tool to the collection.""" + self.tools.append(tool) + self.tool_map[tool.to_params()["name"]] = tool + def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult: tool = self.tool_map.get(name) if not tool: return ToolFailure(error=f"Tool {name} is invalid") try: return cast("ToolResult", tool(**tool_input)) + except AgentException as e: + raise e # noqa: TRY201 except ToolError as e: return ToolFailure(error=e.message) + except Exception as e: # noqa: BLE001 + error_message = f"Unexpected error occurred with tool {name}: {e}" + return ToolFailure(error=error_message) diff --git a/src/askui/tools/anthropic/exception_tool.py b/src/askui/tools/anthropic/exception_tool.py new file mode 100644 index 00000000..c502b855 --- /dev/null +++ b/src/askui/tools/anthropic/exception_tool.py @@ -0,0 +1,31 @@ +from askui.tools.anthropic.base import AgentException, Tool, ToolResult + + +class ExceptionTool(Tool): + """ + Exception tool that can be used to raise an exception. + """ + + def __init__(self) -> None: + super().__init__( + name="exception_tool", + description=""" + Exception tool that can be used to raise an exception. + which will stop the execution of the agent. + """, + input_schema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": """ + The exception message to raise. this will be displayed to the user. + """, + }, + }, + "required": ["text"], + }, + ) + + def __call__(self, text: str) -> ToolResult: + raise AgentException(text)