class SizeInPixels(BaseModel):
    """Size of a display in pixels."""

    width: int
    height: int


class DisplayInformation(BaseModel):
    """
    Information about a single display.

    Instances can be validated from data that uses the aliases `displayID` and
    `sizeInPixels`, or constructed with the Python field names `display_id` and
    `size_in_pixels`.
    """

    # `validation_alias` alone makes the aliases the only accepted input keys.
    # Allowing population by field name keeps alias-based validation working while
    # letting callers build instances with the Python attribute names as well.
    model_config = {"populate_by_name": True}

    display_id: int = Field(validation_alias="displayID")
    size_in_pixels: SizeInPixels = Field(validation_alias="sizeInPixels")


class GetDisplayInformationResponse(BaseModel):
    """
    Response model for display information requests.

    `displays` defaults to an empty list so that input without a `displays` key
    (for example a protobuf message converted with `MessageToDict`, which leaves
    out empty repeated fields by default) validates instead of failing on a
    missing required field.
    """

    displays: list[DisplayInformation] = Field(default_factory=list)
get_display_information(self) -> GetDisplayInformationResponse: + """ + Get information about all available displays and virtual screen. + """ + raise NotImplementedError + + def get_active_display(self) -> int: + """ + Get the active display. + """ + raise NotImplementedError + def run_command(self, command: str, timeout_ms: int = 30000) -> None: """ Executes a shell command. diff --git a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index 2ff1f2d8..1010a232 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -7,6 +7,7 @@ from typing import Literal, Type import grpc +from google.protobuf.json_format import MessageToDict from PIL import Image from pydantic import BaseModel, Field, model_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -15,7 +16,13 @@ from askui.container import telemetry from askui.logger import logger from askui.reporting import Reporter -from askui.tools.agent_os import AgentOs, Coordinate, ModifierKey, PcKey +from askui.tools.agent_os import ( + AgentOs, + Coordinate, + GetDisplayInformationResponse, + ModifierKey, + PcKey, +) from askui.tools.askui.askui_ui_controller_grpc.generated import ( Controller_V1_pb2 as controller_v1_pbs, ) @@ -723,6 +730,14 @@ def set_display(self, display: int = 1) -> None: ) self._display = display + @telemetry.record_call() + @override + def get_active_display(self) -> int: + """ + Get the active display. + """ + return self._display + @telemetry.record_call(exclude={"command"}) @override def run_command(self, command: str, timeout_ms: int = 30000) -> None: @@ -747,14 +762,13 @@ def run_command(self, command: str, timeout_ms: int = 30000) -> None: @telemetry.record_call() def get_display_information( self, - ) -> controller_v1_pbs.Response_GetDisplayInformation: + ) -> GetDisplayInformationResponse: """ Get information about all available displays and virtual screen. 
Returns: - controller_v1_pbs.Response_GetDisplayInformation: - - displays: List of DisplayInformation objects - - virtualScreenRectangle: Overall virtual screen bounds + GetDisplayInformationResponse: A Pydantic model containing information + about all available displays and the virtual screen. """ assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( "Stub is not initialized" @@ -765,8 +779,11 @@ def get_display_information( response: controller_v1_pbs.Response_GetDisplayInformation = ( self._stub.GetDisplayInformation(controller_v1_pbs.Request_Void()) ) - - return response + response_dict = MessageToDict( + response, + preserving_proto_field_name=True, + ) + return GetDisplayInformationResponse.model_validate(response_dict) @telemetry.record_call() def get_process_list( @@ -774,12 +791,10 @@ def get_process_list( ) -> controller_v1_pbs.Response_GetProcessList: """ Get a list of running processes. - Args: get_extended_info (bool, optional): Whether to include extended process information. Defaults to `False`. - Returns: controller_v1_pbs.Response_GetProcessList: Process list response containing: - processes: List of ProcessInfo objects @@ -802,10 +817,8 @@ def get_window_list( ) -> controller_v1_pbs.Response_GetWindowList: """ Get a list of windows for a specific process. - Args: process_id (int): The ID of the process to get windows for. - Returns: controller_v1_pbs.Response_GetWindowList: Window list response containing: - windows: List of WindowInfo objects with ID and name @@ -828,7 +841,6 @@ def get_automation_target_list( ) -> controller_v1_pbs.Response_GetAutomationTargetList: """ Get a list of available automation targets. - Returns: controller_v1_pbs.Response_GetAutomationTargetList: Automation target list response: @@ -850,7 +862,6 @@ def get_automation_target_list( def set_mouse_delay(self, delay_ms: int) -> None: """ Configure mouse action delay. - Args: delay_ms (int): The delay in milliseconds to set for mouse actions. 
""" @@ -870,7 +881,6 @@ def set_mouse_delay(self, delay_ms: int) -> None: def set_keyboard_delay(self, delay_ms: int) -> None: """ Configure keyboard action delay. - Args: delay_ms (int): The delay in milliseconds to set for keyboard actions. """ @@ -890,7 +900,6 @@ def set_keyboard_delay(self, delay_ms: int) -> None: def set_active_window(self, process_id: int, window_id: int) -> None: """ Set the active window for automation. - Args: process_id (int): The ID of the process that owns the window. window_id (int): The ID of the window to set as active. @@ -913,7 +922,6 @@ def set_active_window(self, process_id: int, window_id: int) -> None: def set_active_automation_target(self, target_id: int) -> None: """ Set the active automation target. - Args: target_id (int): The ID of the automation target to set as active. """ @@ -937,13 +945,11 @@ def schedule_batched_action( ) -> controller_v1_pbs.Response_ScheduleBatchedAction: """ Schedule an action for batch execution. - Args: action_class_id (controller_v1_pbs.ActionClassID): The class ID of the action to schedule. action_parameters (controller_v1_pbs.ActionParameters): Parameters for the action. - Returns: controller_v1_pbs.Response_ScheduleBatchedAction: Response containing the scheduled action ID. @@ -1003,7 +1009,6 @@ def stop_batch_run(self) -> None: def get_action_count(self) -> controller_v1_pbs.Response_GetActionCount: """ Get the count of recorded or batched actions. - Returns: controller_v1_pbs.Response_GetActionCount: Response containing the action count. @@ -1024,10 +1029,8 @@ def get_action_count(self) -> controller_v1_pbs.Response_GetActionCount: def get_action(self, action_index: int) -> controller_v1_pbs.Response_GetAction: """ Get a specific action by its index. - Args: action_index (int): The index of the action to retrieve. 
- Returns: controller_v1_pbs.Response_GetAction: Action information containing: - actionID: The action ID @@ -1052,7 +1055,6 @@ def get_action(self, action_index: int) -> controller_v1_pbs.Response_GetAction: def remove_action(self, action_id: int) -> None: """ Remove a specific action by its ID. - Args: action_id (int): The ID of the action to remove. """ @@ -1086,10 +1088,8 @@ def remove_all_actions(self) -> None: def _send_message(self, message: str) -> controller_v1_pbs.Response_Send: """ Send a general message to the controller. - Args: message (str): The message to send to the controller. - Returns: controller_v1_pbs.Response_Send: Response containing the message from the controller. @@ -1110,7 +1110,6 @@ def _send_message(self, message: str) -> controller_v1_pbs.Response_Send: def get_mouse_position(self) -> Coordinate: """ Get the mouse cursor position - Returns: Coordinate: Response containing the result of the mouse position change. """ @@ -1132,7 +1131,6 @@ def get_mouse_position(self) -> Coordinate: def set_mouse_position(self, x: int, y: int) -> None: """ Set the mouse cursor position to specific coordinates. - Args: x (int): The horizontal coordinate (in pixels) to set the cursor to. y (int): The vertical coordinate (in pixels) to set the cursor to. @@ -1150,10 +1148,8 @@ def set_mouse_position(self, x: int, y: int) -> None: def render_quad(self, style: RenderObjectStyle) -> int: """ Render a quad object to the display. - Args: style (RenderObjectStyle): The style properties for the quad. - Returns: int: Object ID. """ @@ -1174,11 +1170,9 @@ def render_quad(self, style: RenderObjectStyle) -> int: def render_line(self, style: RenderObjectStyle, points: list[Coordinate]) -> int: """ Render a line object to the display. - Args: style (RenderObjectStyle): The style properties for the line. points (list[Coordinates]): The points defining the line. - Returns: int: Object ID. 
class ScreenSwitchTool(Tool):
    """
    Tool that switches the agent's active display to the next available one.

    The list of displays is queried from the agent OS once, at construction time,
    and reused for every call. The number of displays is embedded in the tool
    description.

    Args:
        agent_os (AgentOs): The agent OS used to query the available displays and
            to read and set the active display.
    """

    def __init__(self, agent_os: AgentOs) -> None:
        # We need to determine the number of displays available to provide context
        # to the agent indicating that screen switching can only be done this number
        # of times.
        displays: list[DisplayInformation] = agent_os.get_display_information().displays

        super().__init__(
            name="screen_switch",
            description=f"""
            This tool is useful for switching between multiple displays to find
            information not present on the current active screen.
            If more than one display is available, this tool cycles through them.
            Number of displays available: {len(displays)}.
            """,
        )
        self._agent_os: AgentOs = agent_os
        self._displays: list[DisplayInformation] = displays

    def __call__(self) -> None:
        """
        Switch to the display following the active one, wrapping around at the end.

        Does nothing when at most one display is known. If the active display id is
        not among the displays known to this tool, the first known display is
        selected.
        """
        if len(self._displays) <= 1:
            return

        active_display_id: int = self._agent_os.get_active_display()

        # `next()` without a default raises `StopIteration` when no display matches
        # the active id. Falling back to -1 makes the modulo below yield index 0,
        # i.e. the first known display.
        current_display_index: int = next(
            (
                i
                for i, d in enumerate(self._displays)
                if d.display_id == active_display_id
            ),
            -1,
        )
        # If the current index is the last index, wrap around to the first index.
        next_index: int = (current_display_index + 1) % len(self._displays)

        self._agent_os.set_display(self._displays[next_index].display_id)