diff --git a/README.md b/README.md index 49fada3e..d6a91552 100644 --- a/README.md +++ b/README.md @@ -425,12 +425,12 @@ result = agent.get("What's in this image?", "screenshot.png") #### Using response schemas -For structured data extraction, use Pydantic models extending `JsonSchemaBase`: +For structured data extraction, use Pydantic models extending `ResponseSchemaBase`: ```python -from askui import JsonSchemaBase +from askui import ResponseSchemaBase -class UserInfo(JsonSchemaBase): +class UserInfo(ResponseSchemaBase): username: str is_online: bool diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 58d5b144..30fa7a4f 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -3,18 +3,19 @@ __version__ = "0.3.0" from .agent import VisionAgent -from .models.router import ModelRouter +from .models import ModelComposition, ModelDefinition +from .models.router import Point from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase -from .tools.toolbox import AgentToolbox -from .tools.agent_os import AgentOs, ModifierKey, PcKey - +from .tools import ModifierKey, PcKey +from .utils.image_utils import Img __all__ = [ - "AgentOs", - "AgentToolbox", - "ModelRouter", + "Img", + "ModelComposition", + "ModelDefinition", "ModifierKey", "PcKey", + "Point", "ResponseSchema", "ResponseSchemaBase", "VisionAgent", diff --git a/src/askui/agent.py b/src/askui/agent.py index 2948e88e..3d6d76ac 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -7,14 +7,12 @@ from askui.locators.locators import Locator from askui.utils.image_utils import ImageSource, Img -from .tools.askui.askui_controller import ( +from .tools.askui import ( AskUiControllerClient, AskUiControllerServer, - ModifierKey, - PcKey, ) from .logger import logger, configure_logging -from .tools.toolbox import AgentToolbox +from .tools import AgentToolbox, ModifierKey, PcKey from .models import ModelComposition from .models.router import ModelRouter, Point from .reporting 
import CompositeReporter, Reporter @@ -34,23 +32,18 @@ class VisionAgent: This agent can perform various UI interactions like clicking, typing, scrolling, and more. It uses computer vision models to locate UI elements and execute actions on them. - Parameters: - log_level (int, optional): - The logging level to use. Defaults to logging.INFO. - display (int, optional): - The display number to use for screen interactions. Defaults to 1. - model_router (ModelRouter | None, optional): - Custom model router instance. If None, a default one will be created. - reporters (list[Reporter] | None, optional): - List of reporter instances for logging and reporting. If None, an empty list is used. - tools (AgentToolbox | None, optional): - Custom toolbox instance. If None, a default one will be created with AskUiControllerClient. - model (ModelComposition | str | None, optional): - The default composition or name of the model(s) to be used for vision tasks. - Can be overridden by the `model` parameter in the `click()`, `get()`, `act()` etc. methods. + Args: + log_level (int | str, optional): The logging level to use. Defaults to `logging.INFO`. + display (int, optional): The display number to use for screen interactions. Defaults to `1`. + model_router (ModelRouter | None, optional): Custom model router instance. If `None`, a default one will be created. + reporters (list[Reporter] | None, optional): List of reporter instances for logging and reporting. If `None`, an empty list is used. + tools (AgentToolbox | None, optional): Custom toolbox instance. If `None`, a default one will be created with `AskUiControllerClient`. + model (ModelComposition | str | None, optional): The default composition or name of the model(s) to be used for vision tasks. Can be overridden by the `model` parameter in the `click()`, `get()`, `act()` etc. methods. 
Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.click("Submit button") agent.type("Hello World") @@ -70,7 +63,7 @@ def __init__( ) -> None: load_dotenv() configure_logging(level=log_level) - self._reporter = CompositeReporter(reports=reporters) + self._reporter = CompositeReporter(reporters=reporters) self.tools = tools or AgentToolbox( agent_os=AskUiControllerClient( display=display, @@ -95,21 +88,16 @@ def click( """ Simulates a mouse click on the user interface element identified by the provided locator. - Parameters: - locator (str | Locator | None): - The identifier or description of the element to click. If None, clicks at current position. - button ('left' | 'middle' | 'right'): - Specifies which mouse button to click. Defaults to 'left'. - repeat (int): - The number of times to click. Must be greater than 0. Defaults to 1. - model (ModelComposition | str | None): - The composition or name of the model(s) to be used for locating the element to click on using the `locator`. - - Raises: - InvalidParameterError: If the 'repeat' parameter is less than 1. + Args: + locator (str | Locator | None, optional): The identifier or description of the element to click. If `None`, clicks at current position. + button ('left' | 'middle' | 'right', optional): Specifies which mouse button to click. Defaults to `'left'`. + repeat (int, optional): The number of times to click. Must be greater than `0`. Defaults to `1`. + model (ModelComposition | str | None, optional): The composition or name of the model(s) to be used for locating the element to click on using the `locator`. Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.click() # Left click on current position agent.click("Edit") # Left click on text "Edit" @@ -149,20 +137,18 @@ def locate( """ Locates the UI element identified by the provided locator. 
- Parameters: - locator (str | Locator): - The identifier or description of the element to locate. - screenshot (Img | None, optional): - The screenshot to use for locating the element. Can be a path to an image file, a PIL Image object or a data URL. - If None, takes a screenshot of the currently selected display. - model (ModelComposition | str | None): - The composition or name of the model(s) to be used for locating the element using the `locator`. + Args: + locator (str | Locator): The identifier or description of the element to locate. + screenshot (Img | None, optional): The screenshot to use for locating the element. Can be a path to an image file, a PIL Image object or a data URL. If `None`, takes a screenshot of the currently selected display. + model (ModelComposition | str | None, optional): The composition or name of the model(s) to be used for locating the element using the `locator`. Returns: Point: The coordinates of the element as a tuple (x, y). Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: point = agent.locate("Submit button") print(f"Element found at coordinates: {point}") @@ -186,14 +172,14 @@ def mouse_move( """ Moves the mouse cursor to the UI element identified by the provided locator. - Parameters: - locator (str | Locator): - The identifier or description of the element to move to. - model (ModelComposition | str | None): - The composition or name of the model(s) to be used for locating the element to move the mouse to using the `locator`. + Args: + locator (str | Locator): The identifier or description of the element to move to. + model (ModelComposition | str | None, optional): The composition or name of the model(s) to be used for locating the element to move the mouse to using the `locator`. 
Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.mouse_move("Submit button") # Moves cursor to submit button agent.mouse_move("Close") # Moves cursor to close element @@ -214,21 +200,21 @@ def mouse_scroll( """ Simulates scrolling the mouse wheel by the specified horizontal and vertical amounts. - Parameters: - x (int): - The horizontal scroll amount. Positive values typically scroll right, negative values scroll left. - y (int): - The vertical scroll amount. Positive values typically scroll down, negative values scroll up. + Args: + x (int): The horizontal scroll amount. Positive values typically scroll right, negative values scroll left. + y (int): The vertical scroll amount. Positive values typically scroll down, negative values scroll up. Note: The actual scroll direction depends on the operating system's configuration. Some systems may have "natural scrolling" enabled, which reverses the traditional direction. The meaning of scroll units varies across operating systems and applications. - A scroll value of 10 might result in different distances depending on the application and system settings. + A scroll value of `10` might result in different distances depending on the application and system settings. Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.mouse_scroll(0, 10) # Usually scrolls down 10 units agent.mouse_scroll(0, -5) # Usually scrolls up 5 units @@ -247,12 +233,13 @@ def type( """ Types the specified text as if it were entered on a keyboard. - Parameters: - text (str): - The text to be typed. Must be at least 1 character long. + Args: + text (str): The text to be typed. Must be at least `1` character long. Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.type("Hello, world!") # Types "Hello, world!" 
agent.type("user@example.com") # Types an email address @@ -293,21 +280,14 @@ def get( """ Retrieves information from an image (defaults to a screenshot of the current screen) based on the provided query. - Parameters: - query (str): - The query describing what information to retrieve. - image (Img | None, optional): - The image to extract information from. Defaults to a screenshot of the current screen. - Can be a path to an image file, a PIL Image object or a data URL. - response_schema (Type[ResponseSchema] | None, optional): - A Pydantic model class that defines the response schema. If not provided, returns a string. - model (ModelComposition | str | None, optional): - The composition or name of the model(s) to be used for retrieving information from the screen or image using the `query`. - Note: `response_schema` is not supported by all models. + Args: + query (str): The query describing what information to retrieve. + image (Img | None, optional): The image to extract information from. Defaults to a screenshot of the current screen. Can be a path to an image file, a PIL Image object or a data URL. + response_schema (Type[ResponseSchema] | None, optional): A Pydantic model class that defines the response schema. If not provided, returns a string. + model (ModelComposition | str | None, optional): The composition or name of the model(s) to be used for retrieving information from the screen or image using the `query`. Note: `response_schema` is not supported by all models. Returns: - ResponseSchema | str: - The extracted information, either as an instance of ResponseSchema or string if no response_schema is provided. + ResponseSchema | str: The extracted information, `str` if no `response_schema` is provided. 
Limitations: - Nested Pydantic schemas are not currently supported @@ -315,7 +295,7 @@ def get( Example: ```python - from askui import JsonSchemaBase + from askui import ResponseSchemaBase, VisionAgent from PIL import Image - class UrlResponse(JsonSchemaBase): + class UrlResponse(ResponseSchemaBase): @@ -376,15 +356,13 @@ def wait( """ Pauses the execution of the program for the specified number of seconds. - Parameters: - sec (float): - The number of seconds to wait. Must be greater than 0.0. - - Raises: - ValueError: If the provided `sec` is negative. + Args: + sec (float): The number of seconds to wait. Must be greater than `0.0`. Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.wait(5) # Pauses execution for 5 seconds agent.wait(0.5) # Pauses execution for 500 milliseconds @@ -401,12 +379,13 @@ def key_up( """ Simulates the release of a key. - Parameters: - key (PcKey | ModifierKey): - The key to be released. + Args: + key (PcKey | ModifierKey): The key to be released. Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.key_up('a') # Release the 'a' key agent.key_up('shift') # Release the 'Shift' key @@ -425,12 +404,13 @@ def key_down( """ Simulates the pressing of a key. - Parameters: - key (PcKey | ModifierKey): - The key to be pressed. + Args: + key (PcKey | ModifierKey): The key to be pressed. Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.key_down('a') # Press the 'a' key agent.key_down('shift') # Press the 'Shift' key @@ -454,14 +434,14 @@ def act( to accomplish the goal. This may include clicking, typing, scrolling, and other interface interactions. - Parameters: - goal (str): - A description of what the agent should achieve. - model (ModelComposition | str | None, optional): - The composition or name of the model(s) to be used for achieving the `goal`. + Args: + goal (str): A description of what the agent should achieve. 
+ model (ModelComposition | str | None, optional): The composition or name of the model(s) to be used for achieving the `goal`. Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.act("Open the settings menu") agent.act("Search for 'printer' in the search box") @@ -484,14 +464,14 @@ def keyboard( """ Simulates pressing a key or key combination on the keyboard. - Parameters: - key (PcKey | ModifierKey): - The main key to press. This can be a letter, number, special character, or function key. - modifier_keys (list[ModifierKey] | None, optional): - List of modifier keys to press along with the main key. Common modifier keys include 'ctrl', 'alt', 'shift'. + Args: + key (PcKey | ModifierKey): The main key to press. This can be a letter, number, special character, or function key. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to press along with the main key. Common modifier keys include `'ctrl'`, `'alt'`, `'shift'`. Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.keyboard('a') # Press 'a' key agent.keyboard('enter') # Press 'Enter' key @@ -514,12 +494,13 @@ def cli( This method allows running shell commands directly from the agent. The command is split on spaces and executed as a subprocess. - Parameters: - command (str): - The command to execute on the command line. + Args: + command (str): The command to execute on the command line. 
Example: ```python + from askui import VisionAgent + with VisionAgent() as agent: agent.cli("echo Hello World") # Prints "Hello World" agent.cli("ls -la") # Lists files in current directory with details diff --git a/src/askui/locators/__init__.py b/src/askui/locators/__init__.py index 23964220..2dd749b2 100644 --- a/src/askui/locators/__init__.py +++ b/src/askui/locators/__init__.py @@ -1,9 +1,16 @@ -from askui.locators.locators import AiElement, Element, Prompt, Image, Text +from .locators import AiElement, Element, Prompt, Image, Text, TextMatchType, Locator +from .relatable import CircularDependencyError, ReferencePoint, RelationIndex, Relatable __all__ = [ "AiElement", + "CircularDependencyError", "Element", - "Prompt", "Image", + "Locator", + "Prompt", + "ReferencePoint", + "Relatable", + "RelationIndex", "Text", + "TextMatchType", ] diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py index 24bc569a..5a5b518e 100644 --- a/src/askui/locators/locators.py +++ b/src/askui/locators/locators.py @@ -10,8 +10,24 @@ from askui.locators.relatable import Relatable +TextMatchType = Literal["similar", "exact", "contains", "regex"] +"""The type of match to use. + +- `"similar"` uses a similarity threshold to determine if the text is a match. +- `"exact"` requires the text to be exactly the same (this is not the same as `"similar"` + with a `similarity_threshold` of `100` as a `similarity_threshold` of `100` can still + allow for small differences in very long texts). +- `"contains"` requires the text to contain (exactly) the specified text. +- `"regex"` uses a regular expression to match the text. +""" + + +DEFAULT_TEXT_MATCH_TYPE: TextMatchType = "similar" +DEFAULT_SIMILARITY_THRESHOLD = 70 + + class Locator(Relatable, ABC): - """Base class for all locators.""" + """Abstract base class for all locators. 
Cannot be instantiated directly.""" def _str(self) -> str: return "locator" @@ -20,150 +36,147 @@ def _str(self) -> str: class Prompt(Locator): - """Locator for finding ui elements by a textual prompt / description of a ui element, e.g., "green sign up button".""" + """Locator for finding ui elements by a textual prompt / description of a ui element, e.g., "green sign up button". + + Args: + prompt (str): A textual prompt / description of a ui element, e.g., `"green sign up button"` + + Examples: + ```python + from askui import locators as loc + # locates a green sign up button + button = loc.Prompt("green sign up button") + # locates an email text field, e.g., with label "Email" or a placeholder "john.doe@example.com" + textfield = loc.Prompt("email text field") + # locates the avatar in the right hand corner of the application + avatar_top_right_corner = loc.Prompt("avatar in the top right corner of the application") + ``` + """ @validate_call def __init__( self, prompt: Annotated[ str, - Field( - description="""A textual prompt / description of a ui element, e.g., "green sign up button".""" - ), + Field(), ], ) -> None: - """Initialize a Prompt locator. - - Args: - prompt: A textual prompt / description of a ui element, e.g., "green sign up button" - """ super().__init__() self._prompt = prompt - - @property - def prompt(self) -> str: - return self._prompt def _str(self) -> str: - return f'element with prompt "{self.prompt}"' + return f'element with prompt "{self._prompt}"' class Element(Locator): - """Locator for finding ui elements by a class name assigned to the ui element, e.g., by a computer vision model.""" + """Locator for finding ui elements by their class. + + Args: + class_name (Literal["text", "textfield"] | None, optional): The class of the ui element, e.g., `'text'` or `'textfield'`. Defaults to `None`. 
+ + Examples: + ```python + from askui import locators as loc + # locates a text element + text = loc.Element(class_name="text") + # locates a textfield element + textfield = loc.Element(class_name="textfield") + # locates any ui element detected + element = loc.Element() + ``` + """ @validate_call def __init__( self, class_name: Annotated[ Literal["text", "textfield"] | None, - Field( - description="""The class name of the ui element, e.g., 'text' or 'textfield'.""" - ), + Field(), ] = None, ) -> None: - """Initialize an Element locator. - - Args: - class_name: The class name of the ui element, e.g., 'text' or 'textfield' - """ super().__init__() self._class_name = class_name - @property - def class_name(self) -> Literal["text", "textfield"] | None: - return self._class_name - def _str(self) -> str: return ( f'element with class "{self._class_name}"' if self._class_name else "element" ) -TextMatchType = Literal["similar", "exact", "contains", "regex"] -DEFAULT_TEXT_MATCH_TYPE: TextMatchType = "similar" -DEFAULT_SIMILARITY_THRESHOLD = 70 - - class Text(Element): - """Locator for finding text elements by their content.""" + """Locator for finding text elements by their textual content. + + Args: + text (str | None, optional): The text content of the ui element, e.g., `'Sign up'`. Defaults to `None`. + If `None`, the locator will match any text element. + match_type (TextMatchType, optional): The type of match to use. Defaults to `"similar"`. + similarity_threshold (int, optional): A threshold for how similar the actual text content of the ui element + needs to be to the specified text to be considered a match when `match_type` is `"similar"`. + Takes values between `0` and `100` (inclusive, higher is more similar). + Defaults to `70`. 
+ + Examples: + ```python + from askui import locators as loc + # locates a text element with text similar to "Sign up", e.g., "Sign up" or "Sign Up" or "Sign-Up" + text = loc.Text("Sign up") + # if it does not find an element, you can try decreasing the similarity threshold (default is `70`) + text = loc.Text("Sign up", match_type="similar", similarity_threshold=50) + # if it also locates "Sign In", you can try increasing the similarity threshold (default is `70`) + text = loc.Text("Sign up", match_type="similar", similarity_threshold=80) + # or use `match_type="exact"` to require an exact match (does not match other variations of "Sign up", e.g., "Sign Up" or "Sign-Up") + text = loc.Text("Sign up", match_type="exact") + # locates a text element starting with "Sign" or "sign" using a regular expression + text = loc.Text("^[Ss]ign.*", match_type="regex") + # locates a text element containing "Sign" (exact match) + text = loc.Text("Sign", match_type="contains") + ``` + """ @validate_call def __init__( self, text: Annotated[ str | None, - Field( - description="""The text content of the ui element, e.g., 'Sign up'.""" - ), + Field(), ] = None, match_type: Annotated[ TextMatchType, - Field( - description="""The type of match to use. Defaults to 'similar'. - 'similar' uses a similarity threshold to determine if the text is a match. - 'exact' requires the text to be exactly the same. - 'contains' requires the text to contain the specified text. - 'regex' uses a regular expression to match the text.""" - ), + Field(), ] = DEFAULT_TEXT_MATCH_TYPE, similarity_threshold: Annotated[ int, Field( ge=0, le=100, - description="""A threshold for how similar the text - needs to be to the text content of the ui element to be considered a match. - Takes values between 0 and 100 (higher is more similar). Defaults to 70. - Only used if match_type is 'similar'.""", ), ] = DEFAULT_SIMILARITY_THRESHOLD, ) -> None: - """Initialize a Text locator. 
- - Args: - text: The text content of the ui element, e.g., 'Sign up' - match_type: The type of match to use. Defaults to 'similar'. 'similar' uses a similarity threshold to - determine if the text is a match. 'exact' requires the text to be exactly the same. 'contains' - requires the text to contain the specified text. 'regex' uses a regular expression to match the text. - similarity_threshold: A threshold for how similar the text needs to be to the text content of the ui - element to be considered a match. Takes values between 0 and 100 (higher is more similar). - Defaults to 70. Only used if match_type is 'similar'. - """ super().__init__() self._text = text self._match_type = match_type self._similarity_threshold = similarity_threshold - @property - def text(self) -> str | None: - return self._text - - @property - def match_type(self) -> TextMatchType: - return self._match_type - - @property - def similarity_threshold(self) -> int: - return self._similarity_threshold - def _str(self) -> str: - if self.text is None: + if self._text is None: result = "text" else: result = "text " - match self.match_type: + match self._match_type: case "similar": - result += f'similar to "{self.text}" (similarity >= {self.similarity_threshold}%)' + result += f'similar to "{self._text}" (similarity >= {self._similarity_threshold}%)' case "exact": - result += f'"{self.text}"' + result += f'"{self._text}"' case "contains": - result += f'containing text "{self.text}"' + result += f'containing text "{self._text}"' case "regex": - result += f'matching regex "{self.text}"' + result += f'matching regex "{self._text}"' return result class ImageBase(Locator, ABC): + """Abstract base class for image locators. 
Cannot be instantiated directly.""" + def __init__( self, threshold: float, @@ -184,47 +197,23 @@ def __init__( self._rotation_degree_per_step = rotation_degree_per_step self._name = name self._image_compare_format = image_compare_format - - @property - def threshold(self) -> float: - return self._threshold - - @property - def stop_threshold(self) -> float: - return self._stop_threshold - - @property - def mask(self) -> list[tuple[float, float]] | None: - return self._mask - - @property - def rotation_degree_per_step(self) -> int: - return self._rotation_degree_per_step - - @property - def name(self) -> str: - return self._name - - @property - def image_compare_format(self) -> Literal["RGB", "grayscale", "edges"]: - return self._image_compare_format def _params_str(self) -> str: return ( "(" + ", ".join([ - f"threshold: {self.threshold}", - f"stop_threshold: {self.stop_threshold}", - f"rotation_degree_per_step: {self.rotation_degree_per_step}", - f"image_compare_format: {self.image_compare_format}", - f"mask: {self.mask}" + f"threshold: {self._threshold}", + f"stop_threshold: {self._stop_threshold}", + f"rotation_degree_per_step: {self._rotation_degree_per_step}", + f"image_compare_format: {self._image_compare_format}", + f"mask: {self._mask}" ]) + ")" ) def _str(self) -> str: return ( - f'element "{self.name}" located by image ' + f'element "{self._name}" located by image ' + self._params_str() ) @@ -234,7 +223,35 @@ def _generate_name() -> str: class Image(ImageBase): - """Locator for finding ui elements by an image.""" + """Locator for finding ui elements by an image. + + Args: + image (Union[PIL.Image.Image, pathlib.Path, str]): The image to match against (PIL Image, path, or string) + threshold (float, optional): A threshold for how similar UI elements need to be to the image to be considered a match. + Takes values between `0.0` (= all elements are recognized) and `1.0` (= elements need to look exactly + like defined). Defaults to `0.5`. 
Important: The threshold impacts the prediction quality. + stop_threshold (float | None, optional): A threshold for when to stop searching for UI elements similar to the image. As soon + as UI elements have been found that are at least as similar as the `stop_threshold`, the search stops. + Should be greater than or equal to `threshold`. Takes values between `0.0` and `1.0`. Defaults to value of + `threshold` if not provided. Important: The `stop_threshold` impacts the prediction speed. + mask (list[tuple[float, float]] | None, optional): A polygon to match only a certain area of the image. Must have at least 3 points. + Defaults to `None`. + rotation_degree_per_step (int, optional): A step size in rotation degree. Rotates the image by `rotation_degree_per_step` + until 360° is exceeded. Range is between `0°` - `360°`. Defaults to `0°`. Important: This increases the + prediction time quite a bit. So only use it when absolutely necessary. + name (str | None, optional): Name for the image. Defaults to random name. + image_compare_format (Literal["RGB", "grayscale", "edges"], optional): A color compare style. Defaults to `'grayscale'`. **Important**: + The `image_compare_format` impacts the prediction time as well as quality. As a rule of thumb, + `'edges'` is likely to be faster than `'grayscale'` and `'grayscale'` is likely to be faster than `'RGB'`. + For quality it is most often the other way around. + + Examples: + ```python + from askui import locators as loc + # locates an image element with an image similar to "sign up button" + image = loc.Image("path/to/image.png") + ``` + """ @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( @@ -243,29 +260,21 @@ def __init__( threshold: Annotated[ float, Field( - ge=0, - le=1, - description="""A threshold for how similar UI elements need to be to the image to be considered a match. 
- Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly - like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality.""", + ge=0.0, + le=1.0, ), ] = 0.5, stop_threshold: Annotated[ float | None, Field( - ge=0, - le=1, - description="""A threshold for when to stop searching for UI elements similar to the image. As soon - as UI elements have been found that are at least as similar as the stop_threshold, the search stops. Should - be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to value of `threshold` if - not provided. Important: The stop_threshold impacts the prediction speed.""", + ge=0.0, + le=1.0, ), ] = None, mask: Annotated[ list[tuple[float, float]] | None, Field( min_length=3, - description="A polygon to match only a certain area of the image.", ), ] = None, rotation_degree_per_step: Annotated[ @@ -273,43 +282,14 @@ def __init__( Field( ge=0, lt=360, - description="""A step size in rotation degree. Rotates the image by rotation_degree_per_step until - 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the prediction time - quite a bit. So only use it when absolutely necessary.""", ), ] = 0, name: str | None = None, image_compare_format: Annotated[ Literal["RGB", "grayscale", "edges"], - Field( - description="""A color compare style. Defaults to 'grayscale'. - Important: The image_compare_format impacts the prediction time as well as quality. As a rule of thumb, - 'edges' is likely to be faster than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For - quality it is most often the other way around.""" - ), + Field(), ] = "grayscale", ) -> None: - """Initialize an Image locator. - - Args: - image: The image to match against (PIL Image, path, or string) - threshold: A threshold for how similar UI elements need to be to the image to be considered a match. 
- Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly - like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality. - stop_threshold: A threshold for when to stop searching for UI elements similar to the image. As soon - as UI elements have been found that are at least as similar as the stop_threshold, the search stops. - Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to value of - `threshold` if not provided. Important: The stop_threshold impacts the prediction speed. - mask: A polygon to match only a certain area of the image. Must have at least 3 points. - rotation_degree_per_step: A step size in rotation degree. Rotates the image by rotation_degree_per_step - until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the - prediction time quite a bit. So only use it when absolutely necessary. - name: Optional name for the image. Defaults to generated UUID. - image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format - impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster - than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often - the other way around. 
- """ super().__init__( threshold=threshold, stop_threshold=stop_threshold or threshold, @@ -320,13 +300,31 @@ def __init__( ) # type: ignore self._image = ImageSource(image) - @property - def image(self) -> ImageSource: - return self._image - class AiElement(ImageBase): - """Locator for finding ui elements by an image and other kinds data saved on the disk.""" + """ + Locator for finding ui elements by data (e.g., image) collected with the [AskUIRemoteDeviceSnippingTool](http://localhost:3000/02-api-reference/02-askui-suite/02-askui-suite/AskUIRemoteDeviceSnippingTool/Public/AskUI-NewAIElement) using the `name` assigned to the AI element during *snipping* to retrieve the data used for locating the ui element(s). + + Args: + name (str): Name of the AI element + threshold (float, optional): A threshold for how similar UI elements need to be to the image to be considered a match. + Takes values between `0.0` (= all elements are recognized) and `1.0` (= elements need to look exactly + like defined). Defaults to `0.5`. Important: The threshold impacts the prediction quality. + stop_threshold (float | None, optional): A threshold for when to stop searching for UI elements similar to the image. As soon + as UI elements have been found that are at least as similar as the `stop_threshold`, the search stops. + Should be greater than or equal to `threshold`. Takes values between `0.0` and `1.0`. Defaults to value of + `threshold` if not provided. Important: The `stop_threshold` impacts the prediction speed. + mask (list[tuple[float, float]] | None, optional): A polygon to match only a certain area of the image. Must have at least 3 points. + Defaults to `None`. + rotation_degree_per_step (int, optional): A step size in rotation degree. Rotates the image by rotation_degree_per_step + until 360° is exceeded. Range is between `0°` - `360°`. Defaults to `0°`. Important: This increases the + prediction time quite a bit. So only use it when absolutely necessary. 
+ name (str | None, optional): Name for the image. Defaults to random name. + image_compare_format (Literal["RGB", "grayscale", "edges"], optional): A color compare style. Defaults to `'grayscale'`. **Important**: + The `image_compare_format` impacts the prediction time as well as quality. As a rule of thumb, + `'edges'` is likely to be faster than `'grayscale'` and `'grayscale'` is likely to be faster than `'RGB'`. + For quality it is most often the other way around. + """ @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( @@ -335,30 +333,21 @@ def __init__( threshold: Annotated[ float, Field( - ge=0, - le=1, - description="""A threshold for how similar UI elements need to be to be considered a match. - Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match). - Defaults to 0.5. Important: The threshold impacts the prediction quality.""", + ge=0.0, + le=1.0, ), ] = 0.5, stop_threshold: Annotated[ float | None, Field( - ge=0, - le=1, - description="""A threshold for when to stop searching for UI elements. As soon - as UI elements have been found that are at least as similar as the stop_threshold, the search stops. - Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. - Defaults to value of `threshold` if not provided. - Important: The stop_threshold impacts the prediction speed.""", + ge=0.0, + le=1.0, ), ] = None, mask: Annotated[ list[tuple[float, float]] | None, Field( min_length=3, - description="A polygon to match only a certain area of the image of the element saved on disk.", ), ] = None, rotation_degree_per_step: Annotated[ @@ -366,42 +355,13 @@ def __init__( Field( ge=0, lt=360, - description="""A step size in rotation degree. Rotates the image of the element saved on disk by - rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. - Important: This increases the prediction time quite a bit. 
So only use it when absolutely necessary.""", ), ] = 0, image_compare_format: Annotated[ Literal["RGB", "grayscale", "edges"], - Field( - description="""A color compare style. Defaults to 'grayscale'. - Important: The image_compare_format impacts the prediction time as well as quality. As a rule of thumb, - 'edges' is likely to be faster than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For - quality it is most often the other way around.""" - ), + Field(), ] = "grayscale", ) -> None: - """Initialize an AiElement locator. - - Args: - name: Name of the AI element - threshold: A threshold for how similar UI elements need to be to be considered a match. Takes values - between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match). - Defaults to 0.5. Important: The threshold impacts the prediction quality. - stop_threshold: A threshold for when to stop searching for UI elements. As soon as UI elements have - been found that are at least as similar as the stop_threshold, the search stops. Should be greater - than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to value of `threshold` if not - provided. Important: The stop_threshold impacts the prediction speed. - mask: A polygon to match only a certain area of the image of the element saved on disk. Must have at - least 3 points. - rotation_degree_per_step: A step size in rotation degree. Rotates the image of the element saved on - disk by rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. - Important: This increases the prediction time quite a bit. So only use it when absolutely necessary. - image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format - impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster - than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often - the other way around. 
- """ super().__init__( name=name, threshold=threshold, @@ -413,6 +373,6 @@ def __init__( def _str(self) -> str: return ( - f'ai element named "{self.name}" ' + f'ai element named "{self._name}" ' + self._params_str() ) diff --git a/src/askui/locators/relatable.py b/src/askui/locators/relatable.py index 1cb4df19..9a0a8bd8 100644 --- a/src/askui/locators/relatable.py +++ b/src/askui/locators/relatable.py @@ -5,6 +5,153 @@ ReferencePoint = Literal["center", "boundary", "any"] +""" +Defines under which conditions an element *A* is considered to be above, below, right or left of another element *B*. + +- `"center"`: *A* is considered to be above, below, right or left of *B* if it is above, below, right or left of *A*'s center (in a straight vertical or horizontal line). + +Examples: + + *A* being above *B* (imaginary straight vertical line also shown): + + ```text + =========== + | A | + =========== + | + =========== + | B | + =========== + ``` + + ```text + =========== + | A | + =========== + | + =========== + | B | + =========== + ``` + + *A* **NOT** being above *B* (imaginary straight vertical line also shown): + + ```text + |=========== + || A | + |=========== + | + =========== + | B | + =========== + ``` + + ```text + | =========== + | | A | + | =========== + | + =========== + | B | + =========== + ``` + + ```text + | + | + =========== + | B | + =========== + + =========== + | A | + =========== + ``` + + +- `"boundary"`: *A* is considered to be above, below, right or left of *B* if it is above, below, right or left of (any point of the bounding box of) *A* (in a straight vertical or horizontal line). 
+ +Examples: + + *A* being above *B* (imaginary straight vertical line also shown): + + ```text + | =========== + | | A | + | =========== + | | + =========== + | B | + =========== + ``` + + *A* **NOT** being above *B* (imaginary straight vertical line also shown): + + ```text + | | =========== + | | | A | + | | =========== + | | + =========== + | B | + =========== + ``` + + ```text + | | + | | + =========== + | B | + =========== + + =========== + | A | + =========== + ``` + + +- `"any"`: *A* is considered to be above, below, right or left of *B* if it is above, below, right or left of *B* no matter if it can be reached in a straight vertical or horizontal line from (a point of the bounding box of) *A*. + +Examples: + + *A* being above *B*: + + ```text + =========== + | A | + =========== + + =========== + | B | + =========== + ``` + + ```text + =========== + =========== | A | + | B | =========== + =========== + ``` + + + *A* **NOT** being above *B*: + + ```text + =========== + | B | + =========== + + =========== + | A | + =========== + ``` + + ```text + =========== =========== + | B | | A | + =========== =========== + ``` +""" RelationTypeMapping = { @@ -21,6 +168,37 @@ RelationIndex = Annotated[int, Field(ge=0)] +""" +Index of the element *A* above, below, right or left of the other element *B*, +e.g., the first (`0`), second (`1`), third (`2`) etc. element +above, below, right or left of the other element *B*. *A*'s position relative +to other elements above, below, right or left of *B* +(which determines its index) is determined by the relative position of its +lowest (above), highest (below), leftmost (right) or rightmost (left) point(s) +(edge of its bounding box). + +**Important**: Which elements are counted ("indexed") depends on the locator used, e.g., +when using `Text` only text matched is counted, and the `reference_point`. 
+ +Examples: + +```text +=========== +| A | =========== +=========== | B | + =========== + =========== + | C | + =========== =========== + | D | + =========== +``` + +For `reference_point` +- `"center"`, *A* is the first (`index=0`) element above *B*. +- `"boundary"`, *A* is the second (`index=1`) element above *B*. +- `"any"`, *A* is the third (`index=2`) element above *B*. +""" class RelationBase(BaseModel): @@ -95,19 +273,11 @@ def __init__( class Relatable(ABC): - """Base class for locators that can be related to other locators, e.g., spatially, logically, distance based etc. - - Attributes: - relations: List of relations to other locators - """ + """Abstract base class for locators that can be related to other locators, e.g., spatially, logically etc. Cannot be instantiated directly.""" def __init__(self) -> None: self._relations: list[Relation] = [] - @property - def relations(self) -> list[Relation]: - return self._relations - # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation def above_of( self, @@ -127,31 +297,15 @@ def above_of( bounding box). Args: - other_locator: - Locator for an element / elements to relate to - index: - Index of the element (located by *self*) above the other element(s) - (located by *other_locator*), e.g., the first (index=0), second - (index=1), third (index=2) etc. element above the other element(s). + other_locator (Relatable): Locator for an element / elements to relate to + index (RelationIndex, optional): Index of the element (located by *self*) above the other element(s) + (located by *other_locator*), e.g., the first (`0`), second (`1`), third (`2`) etc. element above the other element(s). Elements' (relative) position is determined by the **bottom border** (*y*-coordinate) of their bounding box. We don't guarantee the order of elements with the same bottom border - (*y*-coordinate). 
- reference_point: - Defines which element (located by *self*) is considered to be above the - other element(s) (located by *other_locator*): - - **"center"**: One point of the element (located by *self*) is above the - center (in a straight vertical line) of the other element(s) (located - by *other_locator*). - **"boundary"**: One point of the element (located by *self*) is above - any other point (in a straight vertical line) of the other element(s) - (located by *other_locator*). - **"any"**: No point of the element (located by *self*) has to be above - a point (in a straight vertical line) of the other element(s) (located - by *other_locator*). - - *Default is **"boundary".*** + (*y*-coordinate). Defaults to `0`. + reference_point (ReferencePoint, optional): Defines which element (located by *self*) is considered to be above the + other element(s) (located by *other_locator*). Defaults to `"boundary"`. Returns: Self: The locator with the relation added @@ -240,7 +394,7 @@ def above_of( ```python from askui import locators as loc # locates text "A" as it is the second (index 1) element above text "C" - # (reference point "any") + # (reference point "center" or "boundary" won't work here) text = loc.Text().above_of(loc.Text("C"), index=1, reference_point="any") # locates also text "A" as it is the first (index 0) element above text "C" # with reference point "boundary" @@ -276,34 +430,18 @@ def below_of( box). Args: - other_locator: - Locator for an element / elements to relate to. - index: - Index of the element (located by *self*) **below** the other - element(s) (located by *other_locator*), e.g., the first (*index=0*), - second (*index=1*), third (*index=2*) etc. element below the other - element(s). 
Elements' (relative) position is determined by the **top + other_locator (Relatable): Locator for an element / elements to relate to + index (RelationIndex, optional): Index of the element (located by *self*) **below** the other + element(s) (located by *other_locator*), e.g., the first (`0`), second (`1`), third (`2`) etc. element below the other + element(s). Elements' (relative) position is determined by the **top border** (*y*-coordinate) of their bounding box. We don't guarantee the order of elements with the same top border - (*y*-coordinate). - reference_point: - Defines which element (located by *self*) is considered to be - *below* the other element(s) (located by *other_locator*): - - **"center"**: One point of the element (located by *self*) is - **below** the *center* (in a straight vertical line) of the other - element(s) (located by *other_locator*). - **"boundary"**: One point of the element (located by *self*) is - **below** *any* other point (in a straight vertical line) of the - other element(s) (located by *other_locator*). - **"any"**: No point of the element (located by *self*) has to - be **below** a point (in a straight vertical line) of the other - element(s) (located by *other_locator*). - - *Default is **"boundary".*** + (*y*-coordinate). Defaults to `0`. + reference_point (ReferencePoint, optional): Defines which element (located by *self*) is considered to be + *below* the other element(s) (located by *other_locator*). Defaults to `"boundary"`. Returns: - Self: The locator with the relation added. + Self: The locator with the relation added Examples: ```text @@ -425,34 +563,18 @@ def right_of( bounding box). Args: - other_locator: - Locator for an element / elements to relate to. - index: - Index of the element (located by *self*) **right of** the other - element(s) (located by *other_locator*), e.g., the first (*index=0*), - second (*index=1*), third (*index=2*) etc. element right of the other - element(s). 
Elements' (relative) position is determined by the **left + other_locator (Relatable): Locator for an element / elements to relate to + index (RelationIndex, optional): Index of the element (located by *self*) **right of** the other + element(s) (located by *other_locator*), e.g., the first (`0`), second (`1`), third (`2`) etc. element right of the other + element(s). Elements' (relative) position is determined by the **left border** (*x*-coordinate) of their bounding box. We don't guarantee the order of elements with the same left border - (*x*-coordinate). - reference_point: - Defines which element (located by *self*) is considered to be - *right of* the other element(s) (located by *other_locator*): - - **"center"**: One point of the element (located by *self*) is - **right of** the *center* (in a straight horizontal line) of the - other element(s) (located by *other_locator*). - **"boundary"**: One point of the element (located by *self*) is - **right of** *any* other point (in a straight horizontal line) of - the other element(s) (located by *other_locator*). - **"any"**: No point of the element (located by *self*) has to - be **right of** a point (in a straight horizontal line) of the - other element(s) (located by *other_locator*). - - *Default is **"center".*** + (*x*-coordinate). Defaults to `0`. + reference_point (ReferencePoint, optional): Defines which element (located by *self*) is considered to be + *right of* the other element(s) (located by *other_locator*). Defaults to `"center"`. Returns: - Self: The locator with the relation added. + Self: The locator with the relation added Examples: ```text @@ -564,34 +686,18 @@ def left_of( bounding box). Args: - other_locator: - Locator for an element / elements to relate to. - index: - Index of the element (located by *self*) **left of** the other - element(s) (located by *other_locator*), e.g., the first (*index=0*), - second (*index=1*), third (*index=2*) etc. element left of the other - element(s). 
Elements' (relative) position is determined by the **right + other_locator (Relatable): Locator for an element / elements to relate to + index (RelationIndex, optional): Index of the element (located by *self*) **left of** the other + element(s) (located by *other_locator*), e.g., the first (`0`), second (`1`), third (`2`) etc. element left of the other + element(s). Elements' (relative) position is determined by the **right border** (*x*-coordinate) of their bounding box. We don't guarantee the order of elements with the same right border - (*x*-coordinate). - reference_point: - Defines which element (located by *self*) is considered to be - *left of* the other element(s) (located by *other_locator*): - - **"center"** : One point of the element (located by *self*) is - **left of** the *center* (in a straight horizontal line) of the - other element(s) (located by *other_locator*). - **"boundary"**: One point of the element (located by *self*) is - **left of** *any* other point (in a straight horizontal line) of - the other element(s) (located by *other_locator*). - **"any"** : No point of the element (located by *self*) has to - be **left of** a point (in a straight horizontal line) of the - other element(s) (located by *other_locator*). - - *Default is **"center".*** + (*x*-coordinate). Defaults to `0`. + reference_point (ReferencePoint, optional): Defines which element (located by *self*) is considered to be + *left of* the other element(s) (located by *other_locator*). Defaults to `"center"`. Returns: - Self: The locator with the relation added. + Self: The locator with the relation added Examples: ```text @@ -689,7 +795,7 @@ def containing(self, other_locator: "Relatable") -> Self: by *other_locator*). 
Args: - other_locator: The locator to check if it's contained + other_locator (Relatable): The locator to check if it's contained Returns: Self: The locator with the relation added @@ -725,7 +831,7 @@ def inside_of(self, other_locator: "Relatable") -> Self: (located by *other_locator*). Args: - other_locator: The locator to check if it contains this element + other_locator (Relatable): The locator to check if it contains this element Returns: Self: The locator with the relation added @@ -763,7 +869,7 @@ def nearest_to(self, other_locator: "Relatable") -> Self: (located by *other_locator*). Args: - other_locator: The locator to compare distance against + other_locator (Relatable): The locator to compare distance against Returns: Self: The locator with the relation added @@ -805,7 +911,7 @@ def and_(self, other_locator: "Relatable") -> Self: element to match multiple locators. Args: - other_locator: The locator to combine with + other_locator (Relatable): The locator to combine with Returns: Self: The locator with the relation added @@ -836,7 +942,7 @@ def or_(self, other_locator: "Relatable") -> Self: if no element is found for one of the locators. 
Args: - other_locator: The locator to combine with + other_locator (Relatable): The locator to combine with Returns: Self: The locator with the relation added @@ -898,7 +1004,7 @@ def _dfs(node: Relatable) -> bool: visited_ids.add(node_id) recursion_stack_ids.add(node_id) - for relation in node.relations: + for relation in node._relations: if _dfs(relation.other_locator): return True diff --git a/src/askui/locators/serializers.py b/src/askui/locators/serializers.py index 35e1f180..3d2063ed 100644 --- a/src/askui/locators/serializers.py +++ b/src/askui/locators/serializers.py @@ -27,7 +27,7 @@ class VlmLocatorSerializer: def serialize(self, locator: Relatable) -> str: locator.raise_if_cycle() - if len(locator.relations) > 0: + if len(locator._relations) > 0: raise NotImplementedError( "Serializing locators with relations is not yet supported for VLMs" ) @@ -50,17 +50,17 @@ def serialize(self, locator: Relatable) -> str: raise ValueError(f"Unsupported locator type: {type(locator)}") def _serialize_class(self, class_: Element) -> str: - if class_.class_name: - return f"an arbitrary {class_.class_name} shown" + if class_._class_name: + return f"an arbitrary {class_._class_name} shown" else: return "an arbitrary ui element (e.g., text, button, textfield, etc.)" def _serialize_prompt(self, prompt: Prompt) -> str: - return prompt.prompt + return prompt._prompt def _serialize_text(self, text: Text) -> str: - if text.match_type == "similar": - return f'text similar to "{text.text}"' + if text._match_type == "similar": + return f'text similar to "{text._text}"' return str(text) @@ -105,7 +105,7 @@ def __init__(self, ai_element_collection: AiElementCollection, reporter: Reporte def serialize(self, locator: Relatable) -> AskUiSerializedLocator: locator.raise_if_cycle() - if len(locator.relations) > 1: + if len(locator._relations) > 1: # If we lift this constraint, we also have to make sure that custom element references are still working + we need, e.g., some symbol or a 
structured format to indicate precedence raise NotImplementedError( "Serializing locators with multiple relations is not yet supported by AskUI" @@ -125,38 +125,38 @@ def serialize(self, locator: Relatable) -> AskUiSerializedLocator: else: raise ValueError(f'Unsupported locator type: "{type(locator)}"') - if len(locator.relations) == 0: + if len(locator._relations) == 0: return result - serialized_relation = self._serialize_relation(locator.relations[0]) + serialized_relation = self._serialize_relation(locator._relations[0]) result["instruction"] += f" {serialized_relation['instruction']}" result["customElements"] += serialized_relation["customElements"] return result def _serialize_class(self, class_: Element) -> str: - return class_.class_name or "element" + return class_._class_name or "element" def _serialize_prompt(self, prompt: Prompt) -> str: - return f"pta {self._TEXT_DELIMITER}{prompt.prompt}{self._TEXT_DELIMITER}" + return f"pta {self._TEXT_DELIMITER}{prompt._prompt}{self._TEXT_DELIMITER}" def _serialize_text(self, text: Text) -> str: - match text.match_type: + match text._match_type: case "similar": if ( - text.similarity_threshold == DEFAULT_SIMILARITY_THRESHOLD - and text.match_type == DEFAULT_TEXT_MATCH_TYPE + text._similarity_threshold == DEFAULT_SIMILARITY_THRESHOLD + and text._match_type == DEFAULT_TEXT_MATCH_TYPE ): # Necessary so that we can use wordlevel ocr for these texts return ( - f"text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" + f"text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" ) - return f"text with text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER} that matches to {text.similarity_threshold} %" + return f"text with text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER} that matches to {text._similarity_threshold} %" case "exact": - return f"text equals text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" + return f"text equals text 
{self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" case "contains": - return f"text contain text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" + return f"text contain text {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" case "regex": - return f"text match regex pattern {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" + return f"text match regex pattern {self._TEXT_DELIMITER}{text._text}{self._TEXT_DELIMITER}" case _: raise ValueError(f'Unsupported text match type: "{text.match_type}"') @@ -198,14 +198,14 @@ def _serialize_image_to_custom_element( ) -> CustomElement: custom_element: CustomElement = CustomElement( customImage=image_source.to_data_url(), - threshold=image_locator.threshold, - stopThreshold=image_locator.stop_threshold, - rotationDegreePerStep=image_locator.rotation_degree_per_step, - imageCompareFormat=image_locator.image_compare_format, - name=image_locator.name, + threshold=image_locator._threshold, + stopThreshold=image_locator._stop_threshold, + rotationDegreePerStep=image_locator._rotation_degree_per_step, + imageCompareFormat=image_locator._image_compare_format, + name=image_locator._name, ) - if image_locator.mask: - custom_element["mask"] = image_locator.mask + if image_locator._mask: + custom_element["mask"] = image_locator._mask return custom_element def _serialize_image_base( @@ -221,7 +221,7 @@ def _serialize_image_base( for image_source in image_sources ] return AskUiSerializedLocator( - instruction=f"custom element with text {self._TEXT_DELIMITER}{image_locator.name}{self._TEXT_DELIMITER}", + instruction=f"custom element with text {self._TEXT_DELIMITER}{image_locator._name}{self._TEXT_DELIMITER}", customElements=custom_elements, ) @@ -232,20 +232,20 @@ def _serialize_image( self._reporter.add_message( "AskUiLocatorSerializer", f"Image locator: {image}", - image=image.image.root, + image=image._image.root, ) return self._serialize_image_base( image_locator=image, - image_sources=[image.image], + 
image_sources=[image._image], ) def _serialize_ai_element( self, ai_element_locator: AiElementLocator ) -> AskUiSerializedLocator: - ai_elements = self._ai_element_collection.find(ai_element_locator.name) + ai_elements = self._ai_element_collection.find(ai_element_locator._name) self._reporter.add_message( "AskUiLocatorSerializer", - f"Found {len(ai_elements)} ai elements named {ai_element_locator.name}", + f"Found {len(ai_elements)} ai elements named {ai_element_locator._name}", image=[ai_element.image for ai_element in ai_elements], ) return self._serialize_image_base( diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 71da37b2..f28dc667 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -27,7 +27,16 @@ class ModelName(str, Enum): class ModelDefinition(BaseModel): """ A definition of a model. + + Args: + task (str): The task the model is trained for, e.g., end-to-end OCR (`"e2e_ocr"`) or object detection (`"od"`) + architecture (str): The architecture of the model, e.g., `"easy_ocr"` or `"yolo"` + version (str): The version of the model + interface (str): The interface the model is trained for, e.g., `"online_learning"` + use_case (str, optional): The use case the model is trained for. In the case of workspace specific AskUI models, this is often the workspace id but with "-" replaced by "_". Defaults to `"00000000_0000_0000_0000_000000000000"` (custom null value). 
+ tags (list[str], optional): Tags for identifying the model that cannot be represented by other properties, e.g., `["trained", "word_level"]` """ + model_config = ConfigDict( populate_by_name=True, ) @@ -41,7 +50,7 @@ class ModelDefinition(BaseModel): version: str = Field(pattern=r"^[0-9]{1,6}$") interface: ModelDefinitionProperty = Field( description="The interface the model is trained for", - examples=["online_learning", "offline_learning"], + examples=["online_learning"], ) use_case: ModelDefinitionProperty = Field( description='The use case the model is trained for. In the case of workspace specific AskUI models, this is often the workspace id but with "-" replaced by "_"', @@ -60,27 +69,28 @@ class ModelDefinition(BaseModel): @property def model_name(self) -> str: - return ( - "-".join( - [ - self.task, - self.architecture, - self.interface, - self.use_case, - self.version, - *self.tags, - ] - ) + """ + The name of the model. + """ + return "-".join( + [ + self.task, + self.architecture, + self.interface, + self.use_case, + self.version, + *self.tags, + ] ) class ModelComposition(RootModel[list[ModelDefinition]]): """ - A composition of models. + A composition of models (list of `ModelDefinition`) to be used for a task, e.g., locating an element on the screen to be able to click on it or extracting text from an image. """ def __iter__(self): return iter(self.root) - + def __getitem__(self, index: int) -> ModelDefinition: return self.root[index] diff --git a/src/askui/models/router.py b/src/askui/models/router.py index 7f3395cb..d9aaadcc 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -23,6 +23,9 @@ Point = tuple[int, int] +""" +A tuple of two integers representing the coordinates of a point on the screen. 
+""" def handle_response(response: tuple[int | None, int | None], locator: str | Locator): diff --git a/src/askui/models/types/response_schemas.py b/src/askui/models/types/response_schemas.py index e9eba25c..c75472af 100644 --- a/src/askui/models/types/response_schemas.py +++ b/src/askui/models/types/response_schemas.py @@ -3,6 +3,24 @@ class ResponseSchemaBase(BaseModel): + """Base class for response schemas to be used for defining the response of data extraction, e.g., using `askui.VisionAgent.get()`. + + This class extends Pydantic's BaseModel and adds constraints and configuration on top so that it can be used with models to define the schema (type) of the data to be extracted. + + Example: + ```python + class UrlResponse(ResponseSchemaBase): + url: str + + # nested models should also extend ResponseSchemaBase + class NestedResponse(ResponseSchemaBase): + nested: UrlResponse + + # metadata, e.g., `examples` or `description` of `Field`, is generally also passed to and considered by the models + class UrlResponse(ResponseSchemaBase): + url: str = Field(description="The URL of the response. Should used `\"https\"` scheme.", examples=["https://www.example.com"]) + ``` + """ model_config = ConfigDict(extra="forbid") @@ -13,6 +31,18 @@ class ResponseSchemaBase(BaseModel): ResponseSchema = TypeVar('ResponseSchema', ResponseSchemaBase, str, bool, int, float) +"""Type of the responses of data extracted, e.g., using `askui.VisionAgent.get()`. + +The following types are allowed: +- `ResponseSchemaBase`: Custom Pydantic models that extend `ResponseSchemaBase` +- `str`: String responses +- `bool`: Boolean responses +- `int`: Integer responses +- `float`: Floating point responses + +Usually, serialized as a JSON schema, e.g., `str` as `{"type": "string"}`, to be passed to model(s). +Also used for validating the responses of the model(s) used for data extraction. 
+""" @overload diff --git a/src/askui/reporting.py b/src/askui/reporting.py index c274fc80..d3f78ce2 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -15,6 +15,11 @@ class Reporter(ABC): + """Abstract base class for reporters. Cannot be instantiated directly. + + Defines the interface that all reporters must implement to be used with `askui.VisionAgent`. + """ + @abstractmethod def add_message( self, @@ -22,16 +27,37 @@ def add_message( content: Union[str, dict, list], image: Optional[Image.Image | list[Image.Image]] = None, ) -> None: + """Add a message to the report. + + Args: + role (str): The role of the message sender (e.g., `"User"`, `"Assistant"`, `"System"`) + content (Union[str, dict, list]): The message content, which can be a string, dictionary, or list, e.g. `'click 2x times on text "Edit"'` + image (Optional[PIL.Image.Image | list[PIL.Image.Image]], optional): PIL Image or list of PIL Images to include with the message + """ raise NotImplementedError() @abstractmethod def generate(self) -> None: - raise NotImplementedError() + """Generates the final report. + + Implementing this method is only required if the report is not generated in "real-time", e.g., on calls of `add_message()`, but must be generated at the end of the execution. + + This method is called when the `askui.VisionAgent` context is exited or `askui.VisionAgent.close()` is called. + """ + pass class CompositeReporter(Reporter): - def __init__(self, reports: list[Reporter] | None = None) -> None: - self._reports = reports or [] + """A reporter that combines multiple reporters. + + Allows generating different reports simultaneously. Each message added will be forwarded to all reporters passed to the constructor. The reporters are called (`add_message()`, `generate()`) in the order they are ordered in the `reporters` list. 
+ + Args: + reporters (list[Reporter] | None, optional): List of reporters to combine + """ + + def __init__(self, reporters: list[Reporter] | None = None) -> None: + self._reporters = reporters or [] @override def add_message( @@ -40,16 +66,24 @@ def add_message( content: Union[str, dict, list], image: Optional[Image.Image | list[Image.Image]] = None, ) -> None: - for report in self._reports: + """Add a message to the report.""" + for report in self._reporters: report.add_message(role, content, image) @override def generate(self) -> None: - for report in self._reports: + """Generates the final report.""" + for report in self._reporters: report.generate() class SimpleHtmlReporter(Reporter): + """A reporter that generates HTML reports with conversation logs and system information. + + Args: + report_dir (str, optional): Directory where reports will be saved. Defaults to `reports`. + """ + def __init__(self, report_dir: str = "reports") -> None: self.report_dir = Path(report_dir) self.report_dir.mkdir(exist_ok=True) @@ -67,13 +101,11 @@ def _collect_system_info(self) -> Dict[str, str]: } def _image_to_base64(self, image: Image.Image) -> str: - """Convert PIL Image to base64 string""" buffered = BytesIO() image.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode() def _format_content(self, content: Union[str, dict, list]) -> str: - """Format content based on its type""" if isinstance(content, (dict, list)): return json.dumps(content, indent=2) return str(content) @@ -85,7 +117,7 @@ def add_message( content: Union[str, dict, list], image: Optional[Image.Image | list[Image.Image]] = None, ) -> None: - """Add a message to the report, optionally with an image""" + """Add a message to the report.""" if image is None: _images = [] elif isinstance(image, list): @@ -104,7 +136,13 @@ def add_message( @override def generate(self) -> None: - """Generate HTML report using a Jinja template""" + """Generate an HTML report file. 
+ + Creates a timestamped HTML file in the `report_dir` containing: + - System information + - All collected messages with their content and images + - Syntax-highlighted JSON content + """ template_str = """ diff --git a/src/askui/tools/__init__.py b/src/askui/tools/__init__.py index e76623ba..3b1a761f 100644 --- a/src/askui/tools/__init__.py +++ b/src/askui/tools/__init__.py @@ -1,3 +1,9 @@ +from .agent_os import AgentOs, ModifierKey, PcKey from .toolbox import AgentToolbox -__all__ = ["AgentToolbox"] \ No newline at end of file +__all__ = [ + "AgentOs", + "AgentToolbox", + "ModifierKey", + "PcKey", +] \ No newline at end of file diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index e7bc437e..e6b2614d 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -3,6 +3,8 @@ from PIL import Image ModifierKey = Literal["command", "alt", "control", "shift", "right_shift"] +"""Modifier keys for keyboard actions.""" + PcKey = Literal[ "backspace", "delete", @@ -125,78 +127,162 @@ "}", "~", ] +"""PC keys for keyboard actions.""" class AgentOs(ABC): + """ + Abstract base class for Agent OS. Cannot be instantiated directly. + + This class defines the interface for operating system interactions including mouse control, + keyboard input, and screen capture functionality. Implementations should provide concrete + functionality for these abstract methods. + """ + @abstractmethod def connect(self) -> None: - """Connect to the Agent OS.""" + """ + Establishes a connection to the Agent OS. + + This method is called before performing any OS-level operations. + It handles any necessary setup or initialization required for the OS interaction. + """ pass @abstractmethod def disconnect(self) -> None: - """Disconnect from the Agent OS.""" + """ + Terminates the connection to the Agent OS. + + This method is called after all OS-level operations are complete. + It handles any necessary cleanup or resource release. 
+ """ pass @abstractmethod def screenshot(self, report: bool = True) -> Image.Image: - """Take a screenshot of the current display.""" + """ + Captures a screenshot of the current display. + + Args: + report (bool, optional): Whether to include the screenshot in reporting. Defaults to `True`. + + Returns: + Image.Image: A PIL Image object containing the screenshot. + """ raise NotImplementedError() @abstractmethod def mouse(self, x: int, y: int) -> None: - """Move mouse to specified coordinates.""" + """ + Moves the mouse cursor to specified screen coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to move to. + y (int): The vertical coordinate (in pixels) to move to. + """ raise NotImplementedError() @abstractmethod def type(self, text: str, typing_speed: int = 50) -> None: - """Type text.""" + """ + Simulates typing text as if entered on a keyboard. + + Args: + text (str): The text to be typed. + typing_speed (int, optional): The speed of typing in characters per second. Defaults to `50`. + """ raise NotImplementedError() @abstractmethod def click( self, button: Literal["left", "middle", "right"] = "left", count: int = 1 ) -> None: - """Click mouse button (repeatedly).""" + """ + Simulates clicking a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to click. Defaults to `"left"`. + count (int, optional): Number of times to click. Defaults to `1`. + """ raise NotImplementedError() @abstractmethod def mouse_down(self, button: Literal["left", "middle", "right"] = "left") -> None: - """Press and hold mouse button.""" + """ + Simulates pressing and holding a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to press. Defaults to `"left"`. + """ raise NotImplementedError() @abstractmethod def mouse_up(self, button: Literal["left", "middle", "right"] = "left") -> None: - """Release mouse button.""" + """ + Simulates releasing a mouse button.
+ + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to release. Defaults to `"left"`. + """ raise NotImplementedError() @abstractmethod def mouse_scroll(self, x: int, y: int) -> None: - """Scroll mouse wheel horizontally and vertically.""" + """ + Simulates scrolling the mouse wheel. + + Args: + x (int): The horizontal scroll amount. Positive values scroll right, negative values scroll left. + y (int): The vertical scroll amount. Positive values scroll down, negative values scroll up. + """ raise NotImplementedError() @abstractmethod def keyboard_pressed( self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None ) -> None: - """Press and hold keyboard key.""" + """ + Simulates pressing and holding a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to press. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to press along with the main key. Defaults to `None`. + """ raise NotImplementedError() @abstractmethod def keyboard_release( self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None ) -> None: - """Release keyboard key.""" + """ + Simulates releasing a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to release. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to release along with the main key. Defaults to `None`. + """ raise NotImplementedError() @abstractmethod def keyboard_tap( self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None ) -> None: - """Press and release keyboard key.""" + """ + Simulates pressing and immediately releasing a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to tap. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to press along with the main key. Defaults to `None`. 
+ """ raise NotImplementedError() @abstractmethod def set_display(self, displayNumber: int = 1) -> None: - """Set active display, e.g., when using multiple displays.""" + """ + Sets the active display for screen interactions. + + Args: + displayNumber (int, optional): The display number to set as active. Defaults to `1`. + """ raise NotImplementedError() diff --git a/src/askui/tools/askui/__init__.py b/src/askui/tools/askui/__init__.py index 657f2f1f..6f862b7d 100644 --- a/src/askui/tools/askui/__init__.py +++ b/src/askui/tools/askui/__init__.py @@ -1,3 +1,3 @@ -from .askui_controller import AskUiControllerClient +from .askui_controller import AskUiControllerClient, AskUiControllerServer -__all__ = ["AskUiControllerClient"] +__all__ = ["AskUiControllerClient", "AskUiControllerServer"] diff --git a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index 65c0506d..ffa81187 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod import pathlib from typing import Literal from typing_extensions import Self, override @@ -59,29 +58,13 @@ def validate_either_component_registry_or_installation_directory_is_set(self) -> if self.component_registry_file is None and self.installation_directory is None: raise ValueError("Either ASKUI_COMPONENT_REGISTRY_FILE or ASKUI_INSTALLATION_DIRECTORY environment variable must be set") return self - - -class ControllerServer(ABC): - @abstractmethod - def start(self, clean_up: bool = False) -> None: - raise NotImplementedError() - - @abstractmethod - def stop(self, force: bool = False) -> None: - raise NotImplementedError() - - -class EmptyControllerServer(ControllerServer): - @override - def start(self, clean_up: bool = False) -> None: - pass - - @override - def stop(self, force: bool = False) -> None: - pass -class AskUiControllerServer(ControllerServer): +class AskUiControllerServer: + """ + Concrete 
implementation for managing the AskUI Remote Device Controller process. + Handles process discovery, startup, and shutdown for the native controller binary. + """ def __init__(self) -> None: self._process = None self._settings = AskUiControllerSettings() # type: ignore @@ -118,9 +101,14 @@ def _find_remote_device_controller_by_legacy_path(self) -> pathlib.Path: def __start_process(self, path): self.process = subprocess.Popen(path) wait_for_port(23000) - - @override + def start(self, clean_up: bool = False) -> None: + """ + Start the controller process. + + Args: + clean_up (bool, optional): Whether to clean up existing processes (only on Windows) before starting. Defaults to `False`. + """ if sys.platform == 'win32' and clean_up and process_exists("AskuiRemoteDeviceController.exe"): self.clean_up() remote_device_controller_path = self._find_remote_device_controller() @@ -129,22 +117,44 @@ def start(self, clean_up: bool = False) -> None: time.sleep(0.5) # TODO Find better way to do this, e.g., waiting for something to be logged or port to be opened def clean_up(self): - if sys.platform == 'win32': - subprocess.run("taskkill.exe /IM AskUI*") - time.sleep(0.1) + subprocess.run("taskkill.exe /IM AskUI*") + time.sleep(0.1) - @override def stop(self, force: bool = False) -> None: - if force: - self.process.terminate() - self.clean_up() - return - self.process.kill() + """ + Stop the controller process. + + Args: + force (bool, optional): Whether to forcefully terminate the process. Defaults to `False`.
+ """ + if not hasattr(self, "process") or self.process is None: + return # Nothing to stop + + try: + if force: + self.process.kill() + if sys.platform == "win32": + self.clean_up() + else: + self.process.terminate() + except Exception as e: + logger.error("Failed to stop AskUI Remote Device Controller: %s", e) + pass + finally: + self.process = None class AskUiControllerClient(AgentOs): + """ + Implementation of `AgentOs` that communicates with the AskUI Remote Device Controller via gRPC. + + Args: + reporter (Reporter): Reporter used for reporting with the `"AgentOs"`. + display (int, optional): Display number to use. Defaults to `1`. + controller_server (AskUiControllerServer | None, optional): Custom controller server. Defaults to `ControllerServer`. + """ @telemetry.record_call(exclude={"report"}) - def __init__(self, reporter: Reporter, display: int = 1, controller_server: ControllerServer | None = None) -> None: + def __init__(self, reporter: Reporter, display: int = 1, controller_server: AskUiControllerServer | None = None) -> None: self.stub = None self.channel = None self.session_info = None @@ -153,11 +163,17 @@ def __init__(self, reporter: Reporter, display: int = 1, controller_server: Cont self.max_retries = 10 self.display = display self._reporter = reporter - self._controller_server = controller_server or EmptyControllerServer() + self._controller_server = controller_server or AskUiControllerServer() @telemetry.record_call() @override def connect(self) -> None: + """ + Establishes a connection to the AskUI Remote Device Controller. + + This method starts the controller server, establishes a gRPC channel, + creates a session, and sets up the initial display. 
+ """ self._controller_server.start() self.channel = grpc.insecure_channel('localhost:23000', options=[ ('grpc.max_send_message_length', 2**30 ), @@ -187,6 +203,12 @@ def _run_recorder_action(self, acion_class_id: controller_v1_pbs.ActionClassID, @telemetry.record_call() @override def disconnect(self) -> None: + """ + Terminates the connection to the AskUI Remote Device Controller. + + This method stops the execution, ends the session, closes the gRPC channel, + and stops the controller server. + """ self._stop_execution() self._stop_session() self.channel.close() @@ -194,11 +216,25 @@ def disconnect(self) -> None: @telemetry.record_call() def __enter__(self) -> Self: + """ + Context manager entry point that establishes the connection. + + Returns: + Self: The instance of AskUiControllerClient. + """ self.connect() return self @telemetry.record_call(exclude={"exc_value", "traceback"}) def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Context manager exit point that disconnects the client. + + Args: + exc_type: The exception type if an exception was raised. + exc_value: The exception value if an exception was raised. + traceback: The traceback if an exception was raised. + """ self.disconnect() def _start_session(self): @@ -217,6 +253,15 @@ def _stop_execution(self): @telemetry.record_call() @override def screenshot(self, report: bool = True) -> Image.Image: + """ + Captures a screenshot of the current display. + + Args: + report (bool, optional): Whether to include the screenshot in reporting. Defaults to `True`. + + Returns: + Image.Image: A PIL Image object containing the screenshot. 
+ """ assert isinstance(self.stub, controller_v1.ControllerAPIStub), "Stub is not initialized" screenResponse = self.stub.CaptureScreen(controller_v1_pbs.Request_CaptureScreen(sessionInfo=self.session_info, captureParameters=controller_v1_pbs.CaptureParameters(displayID=self.display))) r, g, b, _ = Image.frombytes('RGBA', (screenResponse.bitmap.width, screenResponse.bitmap.height), screenResponse.bitmap.data).split() @@ -227,19 +272,39 @@ def screenshot(self, report: bool = True) -> Image.Image: @telemetry.record_call() @override def mouse(self, x: int, y: int) -> None: + """ + Moves the mouse cursor to specified screen coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to move to. + y (int): The vertical coordinate (in pixels) to move to. + """ self._reporter.add_message("AgentOS", f"mouse({x}, {y})", draw_point_on_image(self.screenshot(report=False), x, y, size=5)) self._run_recorder_action(acion_class_id=controller_v1_pbs.ActionClassID_MouseMove, action_parameters=controller_v1_pbs.ActionParameters(mouseMove=controller_v1_pbs.ActionParameters_MouseMove(position=controller_v1_pbs.Coordinate2(x=x, y=y)))) - @telemetry.record_call(exclude={"text"}) @override def type(self, text: str, typing_speed: int = 50) -> None: + """ + Simulates typing text as if entered on a keyboard. + + Args: + text (str): The text to be typed. + typing_speed (int, optional): The speed of typing in characters per second. Defaults to `50`. 
+ """ self._reporter.add_message("AgentOS", f"type(\"{text}\", {typing_speed})") self._run_recorder_action(acion_class_id=controller_v1_pbs.ActionClassID_KeyboardType_UnicodeText, action_parameters=controller_v1_pbs.ActionParameters(keyboardTypeUnicodeText=controller_v1_pbs.ActionParameters_KeyboardType_UnicodeText(text=text.encode('utf-16-le'), typingSpeed=typing_speed, typingSpeedValue=controller_v1_pbs.TypingSpeedValue.TypingSpeedValue_CharactersPerSecond))) @telemetry.record_call() @override def click(self, button: Literal['left', 'middle', 'right'] = 'left', count: int = 1) -> None: + """ + Simulates clicking a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to click. Defaults to `"left"`. + count (int, optional): Number of times to click. Defaults to `1`. + """ self._reporter.add_message("AgentOS", f"click(\"{button}\", {count})") mouse_button = None match button: @@ -254,6 +319,12 @@ def click(self, button: Literal['left', 'middle', 'right'] = 'left', count: int @telemetry.record_call() @override def mouse_down(self, button: Literal['left', 'middle', 'right'] = 'left') -> None: + """ + Simulates pressing and holding a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to press. Defaults to `"left"`. + """ self._reporter.add_message("AgentOS", f"mouse_down(\"{button}\")") mouse_button = None match button: @@ -268,6 +339,12 @@ def mouse_down(self, button: Literal['left', 'middle', 'right'] = 'left') -> Non @telemetry.record_call() @override def mouse_up(self, button: Literal['left', 'middle', 'right'] = 'left') -> None: + """ + Simulates releasing a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to release. Defaults to `"left"`. 
+ """ self._reporter.add_message("AgentOS", f"mouse_up(\"{button}\")") mouse_button = None match button: @@ -282,6 +359,13 @@ def mouse_up(self, button: Literal['left', 'middle', 'right'] = 'left') -> None: @telemetry.record_call() @override def mouse_scroll(self, x: int, y: int) -> None: + """ + Simulates scrolling the mouse wheel. + + Args: + x (int): The horizontal scroll amount. Positive values scroll right, negative values scroll left. + y (int): The vertical scroll amount. Positive values scroll down, negative values scroll up. + """ self._reporter.add_message("AgentOS", f"mouse_scroll({x}, {y})") if x != 0: self._run_recorder_action(acion_class_id=controller_v1_pbs.ActionClassID_MouseWheelScroll, action_parameters=controller_v1_pbs.ActionParameters(mouseWheelScroll=controller_v1_pbs.ActionParameters_MouseWheelScroll( @@ -298,10 +382,16 @@ def mouse_scroll(self, x: int, y: int) -> None: milliseconds = 50 ))) - @telemetry.record_call() @override def keyboard_pressed(self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None) -> None: + """ + Simulates pressing and holding a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to press. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to press along with the main key. Defaults to `None`. + """ self._reporter.add_message("AgentOS", f"keyboard_pressed(\"{key}\", {modifier_keys})") if modifier_keys is None: modifier_keys = [] @@ -310,6 +400,13 @@ def keyboard_pressed(self, key: PcKey | ModifierKey, modifier_keys: list[Modifi @telemetry.record_call() @override def keyboard_release(self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None) -> None: + """ + Simulates releasing a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to release. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to release along with the main key. Defaults to `None`. 
+ """ self._reporter.add_message("AgentOS", f"keyboard_release(\"{key}\", {modifier_keys})") if modifier_keys is None: modifier_keys = [] @@ -318,6 +415,13 @@ def keyboard_release(self, key: PcKey | ModifierKey, modifier_keys: list[Modifi @telemetry.record_call() @override def keyboard_tap(self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None) -> None: + """ + Simulates pressing and immediately releasing a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to tap. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to press along with the main key. Defaults to `None`. + """ self._reporter.add_message("AgentOS", f"keyboard_tap(\"{key}\", {modifier_keys})") if modifier_keys is None: modifier_keys = [] @@ -326,6 +430,12 @@ def keyboard_tap(self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKe @telemetry.record_call() @override def set_display(self, displayNumber: int = 1) -> None: + """ + Sets the active display for screen interactions. + + Args: + displayNumber (int, optional): The display number to set as active. Defaults to `1`. + """ assert isinstance(self.stub, controller_v1.ControllerAPIStub), "Stub is not initialized" self._reporter.add_message("AgentOS", f"set_display({displayNumber})") self.stub.SetActiveDisplay(controller_v1_pbs.Request_SetActiveDisplay(displayID=displayNumber)) diff --git a/src/askui/tools/toolbox.py b/src/askui/tools/toolbox.py index 0affcec9..137a8843 100644 --- a/src/askui/tools/toolbox.py +++ b/src/askui/tools/toolbox.py @@ -6,6 +6,21 @@ class AgentToolbox: + """ + Toolbox for agent. + + Provides access to OS-level actions, clipboard, web browser, HTTP client etc. + + Args: + agent_os (AgentOs): The OS interface implementation to use for agent actions. + + Attributes: + webbrowser: Python's built-in `webbrowser` module for opening URLs. + clipboard: `pyperclip` module for clipboard access. + agent_os (AgentOs): The OS interface for mouse, keyboard, and screen actions. 
+ httpx: HTTPX client for HTTP requests. + hub (AskUIHub): Internal AskUI Hub instance. + """ def __init__(self, agent_os: AgentOs): self.webbrowser = webbrowser self.clipboard: pyperclip = pyperclip diff --git a/src/askui/utils/image_utils.py b/src/askui/utils/image_utils.py index dc677540..8de5b7db 100644 --- a/src/askui/utils/image_utils.py +++ b/src/askui/utils/image_utils.py @@ -248,6 +248,13 @@ def scale_coordinates_back( Img = Union[str, Path, PILImage.Image] +"""Type of the input images for `askui.VisionAgent.get()`, `askui.VisionAgent.locate()`, etc. + +Accepts: +- `PIL.Image.Image` +- Relative or absolute file path (`str` or `pathlib.Path`) +- Data URL (e.g., `"data:image/png;base64,..."`) +""" class ImageSource(RootModel): diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index ba8b859d..6b7151e1 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -17,7 +17,7 @@ class ReporterMock(Reporter): @override - def add_message(self, role: str, content: Union[str, dict, list], image: Optional[PILImage.Image] = None) -> None: + def add_message(self, role: str, content: Union[str, dict, list], image: Optional[PILImage.Image | list[PILImage.Image]] = None) -> None: pass @override diff --git a/tests/integration/tools/askui/test_askui_controller.py b/tests/integration/tools/askui/test_askui_controller.py index f477a73b..41cd1fd2 100644 --- a/tests/integration/tools/askui/test_askui_controller.py +++ b/tests/integration/tools/askui/test_askui_controller.py @@ -1,8 +1,37 @@ -from askui.tools.askui.askui_controller import AskUiControllerServer +import pytest +from askui.reporting import CompositeReporter +from askui.tools.askui.askui_controller import ( + AskUiControllerClient, + AskUiControllerServer, +) from pathlib import Path -def test_find_remote_device_controller_by_component_registry(): - controller = AskUiControllerServer() - remote_device_controller_path = 
Path(controller._find_remote_device_controller_by_component_registry()) +@pytest.fixture +def controller_server(): + return AskUiControllerServer() + + +@pytest.fixture +def controller_client(controller_server: AskUiControllerServer): + return AskUiControllerClient( + reporter=CompositeReporter(), + display=1, + controller_server=controller_server, + ) + + +def test_find_remote_device_controller_by_component_registry( + controller_server: AskUiControllerServer, +): + remote_device_controller_path = Path( + controller_server._find_remote_device_controller_by_component_registry() + ) assert "AskuiRemoteDeviceController" == remote_device_controller_path.stem + + +def test_actions(controller_client: AskUiControllerClient): + with controller_client: + controller_client.screenshot() + controller_client.mouse(0, 0) + controller_client.click() diff --git a/tests/unit/locators/serializers/test_askui_locator_serializer.py b/tests/unit/locators/serializers/test_askui_locator_serializer.py index e79eadfb..01afed76 100644 --- a/tests/unit/locators/serializers/test_askui_locator_serializer.py +++ b/tests/unit/locators/serializers/test_askui_locator_serializer.py @@ -84,12 +84,12 @@ def test_serialize_image(askui_serializer: AskUiLocatorSerializer) -> None: assert len(result["customElements"]) == 1 custom_element = result["customElements"][0] assert custom_element["customImage"] == f"data:image/png;base64,{TEST_IMAGE_BASE64}" - assert custom_element["threshold"] == image.threshold - assert custom_element["stopThreshold"] == image.stop_threshold + assert custom_element["threshold"] == image._threshold + assert custom_element["stopThreshold"] == image._stop_threshold assert "mask" not in custom_element - assert custom_element["rotationDegreePerStep"] == image.rotation_degree_per_step - assert custom_element["imageCompareFormat"] == image.image_compare_format - assert custom_element["name"] == image.name + assert custom_element["rotationDegreePerStep"] == 
image._rotation_degree_per_step + assert custom_element["imageCompareFormat"] == image._image_compare_format + assert custom_element["name"] == image._name def test_serialize_image_with_all_options( diff --git a/tests/unit/locators/test_locators.py b/tests/unit/locators/test_locators.py index 1b60fd9f..14eeffff 100644 --- a/tests/unit/locators/test_locators.py +++ b/tests/unit/locators/test_locators.py @@ -12,7 +12,7 @@ class TestDescriptionLocator: def test_initialization_with_description(self) -> None: desc = Prompt(prompt="test") - assert desc.prompt == "test" + assert desc._prompt == "test" assert str(desc) == 'element with prompt "test"' def test_initialization_without_description_raises(self) -> None: @@ -21,7 +21,7 @@ def test_initialization_without_description_raises(self) -> None: def test_initialization_with_positional_arg(self) -> None: desc = Prompt("test") - assert desc.prompt == "test" + assert desc._prompt == "test" def test_initialization_with_invalid_args_raises(self) -> None: with pytest.raises(ValueError): @@ -34,17 +34,17 @@ def test_initialization_with_invalid_args_raises(self) -> None: class TestClassLocator: def test_initialization_with_class_name(self) -> None: cls = Element(class_name="text") - assert cls.class_name == "text" + assert cls._class_name == "text" assert str(cls) == 'element with class "text"' def test_initialization_without_class_name(self) -> None: cls = Element() - assert cls.class_name is None + assert cls._class_name is None assert str(cls) == "element" def test_initialization_with_positional_arg(self) -> None: cls = Element("text") - assert cls.class_name == "text" + assert cls._class_name == "text" def test_initialization_with_invalid_args_raises(self) -> None: with pytest.raises(ValueError): @@ -60,20 +60,20 @@ def test_initialization_with_invalid_args_raises(self) -> None: class TestTextLocator: def test_initialization_with_positional_text(self) -> None: text = Text("Hello") - assert text.text == "Hello" - assert 
text.match_type == "similar" - assert text.similarity_threshold == 70 + assert text._text == "Hello" + assert text._match_type == "similar" + assert text._similarity_threshold == 70 assert str(text) == 'text similar to "Hello" (similarity >= 70%)' def test_initialization_with_named_text(self) -> None: text = Text(text="hello", match_type="exact") - assert text.text == "hello" - assert text.match_type == "exact" + assert text._text == "hello" + assert text._match_type == "exact" assert str(text) == 'text "hello"' def test_initialization_with_similarity(self) -> None: text = Text(text="hello", match_type="similar", similarity_threshold=80) - assert text.similarity_threshold == 80 + assert text._similarity_threshold == 80 assert str(text) == 'text similar to "hello" (similarity >= 80%)' def test_initialization_with_contains(self) -> None: @@ -86,7 +86,7 @@ def test_initialization_with_regex(self) -> None: def test_initialization_without_text(self) -> None: text = Text() - assert text.text is None + assert text._text is None assert str(text) == "text" def test_initialization_with_invalid_args(self) -> None: @@ -115,12 +115,12 @@ def test_image(self) -> PILImage.Image: def test_initialization_with_basic_params(self, test_image: PILImage.Image) -> None: locator = Image(image=test_image) - assert locator.image.root == test_image - assert locator.threshold == 0.5 - assert locator.stop_threshold == 0.5 - assert locator.mask is None - assert locator.rotation_degree_per_step == 0 - assert locator.image_compare_format == "grayscale" + assert locator._image.root == test_image + assert locator._threshold == 0.5 + assert locator._stop_threshold == 0.5 + assert locator._mask is None + assert locator._rotation_degree_per_step == 0 + assert locator._image_compare_format == "grayscale" assert re.match(self._STR_PATTERN, str(locator)) def test_initialization_with_name(self, test_image: PILImage.Image) -> None: @@ -136,11 +136,11 @@ def test_initialization_with_custom_params(self, 
test_image: PILImage.Image) -> rotation_degree_per_step=45, image_compare_format="RGB" ) - assert locator.threshold == 0.7 - assert locator.stop_threshold == 0.95 - assert locator.mask == [(0, 0), (1, 0), (1, 1)] - assert locator.rotation_degree_per_step == 45 - assert locator.image_compare_format == "RGB" + assert locator._threshold == 0.7 + assert locator._stop_threshold == 0.95 + assert locator._mask == [(0, 0), (1, 0), (1, 1)] + assert locator._rotation_degree_per_step == 45 + assert locator._image_compare_format == "RGB" assert re.match(r'^element "anonymous image [a-f0-9-]+" located by image \(threshold: 0.7, stop_threshold: 0.95, rotation_degree_per_step: 45, image_compare_format: RGB, mask: \[\(0.0, 0.0\), \(1.0, 0.0\), \(1.0, 1.0\)\]\)$', str(locator)) def test_initialization_with_invalid_args(self, test_image: PILImage.Image) -> None: @@ -175,7 +175,7 @@ def test_initialization_with_invalid_args(self, test_image: PILImage.Image) -> N class TestAiElementLocator: def test_initialization_with_name(self) -> None: locator = AiElement("github_com__icon") - assert locator.name == "github_com__icon" + assert locator._name == "github_com__icon" assert str(locator) == 'ai element named "github_com__icon" (threshold: 0.5, stop_threshold: 0.5, rotation_degree_per_step: 0, image_compare_format: grayscale, mask: None)' def test_initialization_without_name_raises(self) -> None: @@ -195,12 +195,12 @@ def test_initialization_with_custom_params(self) -> None: rotation_degree_per_step=45, image_compare_format="RGB" ) - assert locator.name == "test_element" - assert locator.threshold == 0.7 - assert locator.stop_threshold == 0.95 - assert locator.mask == [(0, 0), (1, 0), (1, 1)] - assert locator.rotation_degree_per_step == 45 - assert locator.image_compare_format == "RGB" + assert locator._name == "test_element" + assert locator._threshold == 0.7 + assert locator._stop_threshold == 0.95 + assert locator._mask == [(0, 0), (1, 0), (1, 1)] + assert 
locator._rotation_degree_per_step == 45 + assert locator._image_compare_format == "RGB" assert str(locator) == 'ai element named "test_element" (threshold: 0.7, stop_threshold: 0.95, rotation_degree_per_step: 45, image_compare_format: RGB, mask: [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])' def test_initialization_with_invalid_threshold(self) -> None: