From e80fc5079cb03278811cfe03a0a672d7a23f4baa Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 15 Apr 2025 17:44:52 +0200 Subject: [PATCH 01/14] refactor!(agent): rename `model_name` parameter to `model` --- README.md | 10 +- src/askui/agent.py | 51 +++++---- src/askui/chat/__main__.py | 6 +- src/askui/models/anthropic/claude.py | 4 +- src/askui/models/router.py | 66 +++++------ tests/e2e/agent/test_get.py | 24 ++-- tests/e2e/agent/test_locate.py | 58 +++++----- .../test_locate_with_different_models.py | 54 ++++----- tests/e2e/agent/test_locate_with_relations.py | 106 +++++++++--------- 9 files changed, 196 insertions(+), 183 deletions(-) diff --git a/README.md b/README.md index ee576c37..0bc22c82 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ You can test the Vision Agent with Huggingface models via their Spaces API. Plea **Example Code:** ```python -agent.click("search field", model_name="OS-Copilot/OS-Atlas-Base-7B") +agent.click("search field", model="OS-Copilot/OS-Atlas-Base-7B") ``` ### 3c. Host your own **AI Models** @@ -143,7 +143,7 @@ You can use Vision Agent with UI-TARS if you provide your own UI-TARS API endpoi 2. Step: Provide the `TARS_URL` and `TARS_API_KEY` environment variables to Vision Agent. -3. Step: Use the `model_name="tars"` parameter in your `click()`, `get()` and `act()` etc. commands. +3. Step: Use the `model="tars"` parameter in your `click()`, `get()` and `act()` etc. commands. ## ▶️ Start Building @@ -171,7 +171,7 @@ with VisionAgent() as agent: ### 🎛️ Model Selection -Instead of relying on the default model for the entire automation script, you can specify a model for each `click()` (or `act()`, `get()` etc.) command using the `model_name` parameter. +Instead of relying on the default model for the entire automation script, you can specify a model for each `click()` (or `act()`, `get()` etc.) command using the `model` parameter. 
| | AskUI | Anthropic | |----------|----------|----------| @@ -182,7 +182,7 @@ Instead of relying on the default model for the entire automation script, you ca | `mouse_move()` | `askui`, `askui-combo`, `askui-pta`, `askui-ocr`, `askui-ai-element` | `anthropic-claude-3-5-sonnet-20241022` | -**Example:** `agent.click("Preview", model_name="askui-combo")` +**Example:** `agent.click("Preview", model="askui-combo")`
AskUI AI Models @@ -353,7 +353,7 @@ agent.type("********") you can build more sophisticated locators. -**⚠️ Warning:** Support can vary depending on the model you are using. Currently, only, the `askui` model provides best support for locators. This model is chosen by default if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` environment variables are set and it is not overridden using the `model_name` parameter. +**⚠️ Warning:** Support can vary depending on the model you are using. Currently, only, the `askui` model provides best support for locators. This model is chosen by default if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` environment variables are set and it is not overridden using the `model` parameter. Example: diff --git a/src/askui/agent.py b/src/askui/agent.py index 3163398f..5f1a6ea8 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -51,7 +51,7 @@ def __init__( self._controller = AskUiControllerServer() @telemetry.record_call(exclude={"locator"}) - def click(self, locator: Optional[str | Locator] = None, button: Literal['left', 'middle', 'right'] = 'left', repeat: int = 1, model_name: Optional[str] = None) -> None: + def click(self, locator: Optional[str | Locator] = None, button: Literal['left', 'middle', 'right'] = 'left', repeat: int = 1, model: Optional[str] = None) -> None: """ Simulates a mouse click on the user interface element identified by the provided locator. @@ -59,7 +59,7 @@ def click(self, locator: Optional[str | Locator] = None, button: Literal['left', locator (str | Locator | None): The identifier or description of the element to click. button ('left' | 'middle' | 'right'): Specifies which mouse button to click. Defaults to 'left'. repeat (int): The number of times to click. Must be greater than 0. Defaults to 1. - model_name (str | None): The model name to be used for element detection. Optional. + model (str | None): The model name to be used for element detection. Optional. 
Raises: InvalidParameterError: If the 'repeat' parameter is less than 1. @@ -86,45 +86,56 @@ def click(self, locator: Optional[str | Locator] = None, button: Literal['left', self._reporter.add_message("User", msg) if locator is not None: logger.debug("VisionAgent received instruction to click on %s", locator) - self._mouse_move(locator, model_name) + self._mouse_move(locator, model) self.tools.os.click(button, repeat) # type: ignore - def _locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model_name: Optional[str] = None) -> Point: + def _locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: Optional[str] = None) -> Point: if screenshot is None: screenshot = self.tools.os.screenshot() # type: ignore - point = self.model_router.locate(screenshot, locator, model_name) + point = self.model_router.locate(screenshot, locator, model) self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})") return point - def locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model_name: Optional[str] = None) -> Point: + def locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: Optional[str] = None) -> Point: + """ + Locates the UI element identified by the provided locator. + + Args: + locator (str | Locator): The identifier or description of the element to locate. + screenshot (Optional[Image.Image], optional): The screenshot to use for locating the element. Defaults to None. + model (Optional[str], optional): The model to use for locating the element. Defaults to None. + + Returns: + Point: The coordinates of the element. 
+ """ self._reporter.add_message("User", f"locate {locator}") logger.debug("VisionAgent received instruction to locate %s", locator) - return self._locate(locator, screenshot, model_name) + return self._locate(locator, screenshot, model) - def _mouse_move(self, locator: str | Locator, model_name: Optional[str] = None) -> None: - point = self._locate(locator=locator, model_name=model_name) + def _mouse_move(self, locator: str | Locator, model: Optional[str] = None) -> None: + point = self._locate(locator=locator, model=model) self.tools.os.mouse(point[0], point[1]) # type: ignore @telemetry.record_call(exclude={"locator"}) - def mouse_move(self, locator: str | Locator, model_name: Optional[str] = None) -> None: + def mouse_move(self, locator: str | Locator, model: Optional[str] = None) -> None: """ Moves the mouse cursor to the UI element identified by the provided locator. Parameters: locator (str | Locator): The identifier or description of the element to move to. - model_name (str | None): The model name to be used for element detection. Optional. + model (str | None): The model name to be used for element detection. Optional. 
Example: ```python with VisionAgent() as agent: agent.mouse_move("Submit button") # Moves cursor to submit button agent.mouse_move("Close") # Moves cursor to close element - agent.mouse_move("Profile picture", model_name="custom_model") # Uses specific model + agent.mouse_move("Profile picture", model="custom_model") # Uses specific model ``` """ self._reporter.add_message("User", f'mouse_move: {locator}') logger.debug("VisionAgent received instruction to mouse_move to %s", locator) - self._mouse_move(locator, model_name) + self._mouse_move(locator, model) @telemetry.record_call() def mouse_scroll(self, x: int, y: int) -> None: @@ -179,7 +190,7 @@ def get( query: str, image: Optional[ImageSource] = None, response_schema: Type[JsonSchema] | None = None, - model_name: Optional[str] = None, + model: Optional[str] = None, ) -> JsonSchema | str: """ Retrieves information from an image (defaults to a screenshot of the current screen) based on the provided query. @@ -191,8 +202,8 @@ def get( The image to extract information from. Optional. Defaults to a screenshot of the current screen. response_schema (type[ResponseSchema] | None): A Pydantic model class that defines the response schema. Optional. If not provided, returns a string. - model_name (str | None): - The model name to be used for information extraction. Optional. + model (str | None): + The model to be used for information extraction. Optional. Note: response_schema is only supported with models that support JSON output (like the default askui model). 
Returns: @@ -228,7 +239,7 @@ class UrlResponse(JsonSchemaBase): response = self.model_router.get_inference( image=image, query=query, - model_name=model_name, + model=model, response_schema=response_schema, ) if self._reporter is not None: @@ -296,7 +307,7 @@ def key_down(self, key: PcKey | ModifierKey) -> None: self.tools.os.keyboard_pressed(key) @telemetry.record_call(exclude={"goal"}) - def act(self, goal: str, model_name: Optional[str] = None) -> None: + def act(self, goal: str, model: Optional[str] = None) -> None: """ Instructs the agent to achieve a specified goal through autonomous actions. @@ -306,7 +317,7 @@ def act(self, goal: str, model_name: Optional[str] = None) -> None: Parameters: goal (str): A description of what the agent should achieve. - model_name (str | None): The specific model to use for vision analysis. + model (str | None): The specific model to use for vision analysis. If None, uses the default model. Example: @@ -321,7 +332,7 @@ def act(self, goal: str, model_name: Optional[str] = None) -> None: logger.debug( "VisionAgent received instruction to act towards the goal '%s'", goal ) - self.model_router.act(self.tools.os, goal, model_name) + self.model_router.act(self.tools.os, goal, model) @telemetry.record_call() def keyboard( diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 85e5c316..add72aa8 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -203,7 +203,7 @@ def rerun(): element_description = agent.get( query=prompt, image=screenshot_with_crosshair, - model_name="anthropic-claude-3-5-sonnet-20241022", + model="anthropic-claude-3-5-sonnet-20241022", ) write_message( message["role"], @@ -213,7 +213,7 @@ def rerun(): ) agent.mouse_move( locator=element_description.replace('"', ""), - model_name="anthropic-claude-3-5-sonnet-20241022", + model="anthropic-claude-3-5-sonnet-20241022", ) else: write_message( @@ -306,7 +306,7 @@ def rerun(): log_level=logging.DEBUG, reporters=[reporter], ) as 
agent: - agent.act(act_prompt, model_name="claude") + agent.act(act_prompt, model="claude") st.rerun() if st.button("Rerun"): diff --git a/src/askui/models/anthropic/claude.py b/src/askui/models/anthropic/claude.py index b4229c8b..12f1cf14 100644 --- a/src/askui/models/anthropic/claude.py +++ b/src/askui/models/anthropic/claude.py @@ -11,7 +11,7 @@ class ClaudeHandler: def __init__(self, log_level): - self.model_name = "claude-3-5-sonnet-20241022" + self.model = "claude-3-5-sonnet-20241022" self.client = anthropic.Anthropic() self.resolution = (1280, 800) self.log_level = log_level @@ -21,7 +21,7 @@ def __init__(self, log_level): def _inference(self, base64_image: str, prompt: str, system_prompt: str) -> list[anthropic.types.ContentBlock]: message = self.client.messages.create( - model=self.model_name, + model=self.model, max_tokens=1000, temperature=0, system=system_prompt, diff --git a/src/askui/models/router.py b/src/askui/models/router.py index 22f0d57a..4009908f 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -36,12 +36,12 @@ def locate( self, screenshot: Image.Image, locator: str | Locator, - model_name: str | None = None, + model: str | None = None, ) -> Point: pass @abstractmethod - def is_responsible(self, model_name: str | None = None) -> bool: + def is_responsible(self, model: str | None = None) -> bool: pass @abstractmethod @@ -63,13 +63,13 @@ def locate( self, screenshot: Image.Image, locator: str | Locator, - model_name: str | None = None, + model: str | None = None, ) -> Point: if not self._inference_api.authenticated: raise AutomationError( "NoAskUIAuthenticationSet! Please set 'AskUI ASKUI_WORKSPACE_ID' or 'ASKUI_TOKEN' as env variables!" 
) - if model_name == "askui" or model_name is None: + if model == "askui" or model is None: logger.debug("Routing locate prediction to askui") locator = Text(locator) if isinstance(locator, str) else locator x, y = self._inference_api.predict(screenshot, locator) @@ -78,30 +78,30 @@ def locate( raise AutomationError( f'Locators of type `{type(locator)}` are not supported for models "askui-pta", "askui-ocr" and "askui-combo" and "askui-ai-element". Please provide a `str`.' ) - if model_name == "askui-pta": + if model == "askui-pta": logger.debug("Routing locate prediction to askui-pta") x, y = self._inference_api.predict(screenshot, Description(locator)) return handle_response((x, y), locator) - if model_name == "askui-ocr": + if model == "askui-ocr": logger.debug("Routing locate prediction to askui-ocr") return self._locate_with_askui_ocr(screenshot, locator) - if model_name == "askui-combo" or model_name is None: + if model == "askui-combo" or model is None: logger.debug("Routing locate prediction to askui-combo") description_locator = Description(locator) x, y = self._inference_api.predict(screenshot, description_locator) if x is None or y is None: return self._locate_with_askui_ocr(screenshot, locator) return handle_response((x, y), description_locator) - if model_name == "askui-ai-element": + if model == "askui-ai-element": logger.debug("Routing click prediction to askui-ai-element") _locator = AiElement(locator) x, y = self._inference_api.predict(screenshot, _locator) return handle_response((x, y), _locator) - raise AutomationError(f'Invalid model name: "{model_name}"') + raise AutomationError(f'Invalid model: "{model}"') @override - def is_responsible(self, model_name: str | None = None) -> bool: - return model_name is None or model_name.startswith("askui") + def is_responsible(self, model: str | None = None) -> bool: + return model is None or model.startswith("askui") @override def is_authenticated(self) -> bool: @@ -127,39 +127,39 @@ def __init__( self.tars 
= UITarsAPIHandler(self._reporter) self._locator_serializer = VlmLocatorSerializer() - def act(self, controller_client, goal: str, model_name: str | None = None): - if self.tars.authenticated and model_name == "tars": + def act(self, controller_client, goal: str, model: str | None = None): + if self.tars.authenticated and model == "tars": return self.tars.act(controller_client, goal) - if self.claude.authenticated and (model_name == "claude" or model_name is None): + if self.claude.authenticated and (model == "claude" or model is None): agent = ClaudeComputerAgent(controller_client, self._reporter) return agent.run(goal) - raise AutomationError(f"Invalid model name for act: {model_name}") + raise AutomationError(f"Invalid model for act: {model}") def get_inference( self, query: str, image: ImageSource, response_schema: Type[JsonSchema] | None = None, - model_name: str | None = None, + model: str | None = None, ) -> JsonSchema | str: - if self.tars.authenticated and model_name == "tars": + if self.tars.authenticated and model == "tars": if response_schema is not None: raise NotImplementedError("Response schema is not yet supported for UI-TARS models.") return self.tars.get_inference(image=image, query=query) if self.claude.authenticated and ( - model_name == "anthropic-claude-3-5-sonnet-20241022" + model == "anthropic-claude-3-5-sonnet-20241022" ): if response_schema is not None: raise NotImplementedError("Response schema is not yet supported for Anthropic models.") return self.claude.get_inference(image=image, query=query) - if self.askui.authenticated and (model_name == "askui" or model_name is None): + if self.askui.authenticated and (model == "askui" or model is None): return self.askui.get_inference( image=image, query=query, response_schema=response_schema, ) raise AutomationError( - f"Executing get commands requires to authenticate with an Automation Model Provider supporting it: {model_name}" + f"Executing get commands requires to authenticate with an 
Automation Model Provider supporting it: {model}" ) def _serialize_locator(self, locator: str | Locator) -> str: @@ -172,33 +172,35 @@ def locate( self, screenshot: Image.Image, locator: str | Locator, - model_name: str | None = None, + model: str | None = None, ) -> Point: if ( - model_name is not None - and model_name in self.huggingface_spaces.get_spaces_names() + model is not None + and model in self.huggingface_spaces.get_spaces_names() ): x, y = self.huggingface_spaces.predict( - screenshot, self._serialize_locator(locator), model_name + screenshot=screenshot, + locator=self._serialize_locator(locator), + model_name=model, ) return handle_response((x, y), locator) - if model_name is not None: - if model_name.startswith("anthropic") and not self.claude.authenticated: + if model is not None: + if model.startswith("anthropic") and not self.claude.authenticated: raise AutomationError( "You need to provide Anthropic credentials to use Anthropic models." ) - if model_name.startswith("tars") and not self.tars.authenticated: + if model.startswith("tars") and not self.tars.authenticated: raise AutomationError( "You need to provide UI-TARS HF Endpoint credentials to use UI-TARS models." 
) - if self.tars.authenticated and model_name == "tars": + if self.tars.authenticated and model == "tars": x, y = self.tars.locate_prediction( screenshot, self._serialize_locator(locator) ) return handle_response((x, y), locator) if ( self.claude.authenticated - and model_name == "anthropic-claude-3-5-sonnet-20241022" + and model == "anthropic-claude-3-5-sonnet-20241022" ): logger.debug("Routing locate prediction to Anthropic") x, y = self.claude.locate_inference( @@ -208,12 +210,12 @@ def locate( for grounding_model_router in self.grounding_model_routers: if ( - grounding_model_router.is_responsible(model_name) + grounding_model_router.is_responsible(model) and grounding_model_router.is_authenticated() ): - return grounding_model_router.locate(screenshot, locator, model_name) + return grounding_model_router.locate(screenshot, locator, model) - if model_name is None: + if model is None: if self.claude.authenticated: logger.debug("Routing locate prediction to Anthropic") x, y = self.claude.locate_inference( diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index c47ab838..2e9f5ef1 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -20,16 +20,16 @@ class BrowserContextResponse(JsonSchemaBase): browser_type: Literal["chrome", "firefox", "edge", "safari"] -@pytest.mark.parametrize("model_name", [None, models.ASKUI, models.ANTHROPIC]) +@pytest.mark.parametrize("model", [None, models.ASKUI, models.ANTHROPIC]) def test_get( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: url = vision_agent.get( "What is the current url shown in the url bar?", ImageSource(github_login_screenshot), - model_name=model_name, + model=model, ) assert url == "github.com/login" @@ -44,7 +44,7 @@ def test_get_with_response_schema_without_additional_properties_with_askui_model "What is the current url shown in the url bar?", ImageSource(github_login_screenshot), 
response_schema=UrlResponse, - model_name=models.ASKUI, + model=models.ASKUI, ) @@ -58,21 +58,21 @@ def test_get_with_response_schema_without_required_with_askui_model_raises( "What is the current url shown in the url bar?", ImageSource(github_login_screenshot), response_schema=UrlResponse, - model_name=models.ASKUI, + model=models.ASKUI, ) -@pytest.mark.parametrize("model_name", [None, models.ASKUI]) +@pytest.mark.parametrize("model", [None, models.ASKUI]) def test_get_with_response_schema( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: response = vision_agent.get( "What is the current url shown in the url bar?", ImageSource(github_login_screenshot), response_schema=UrlResponse, - model_name=model_name, + model=model, ) assert isinstance(response, UrlResponse) assert response.url in ["https://github.com/login", "github.com/login"] @@ -87,22 +87,22 @@ def test_get_with_response_schema_with_anthropic_model_raises_not_implemented( "What is the current url shown in the url bar?", ImageSource(github_login_screenshot), response_schema=UrlResponse, - model_name=models.ANTHROPIC, + model=models.ANTHROPIC, ) -@pytest.mark.parametrize("model_name", [None, models.ASKUI]) +@pytest.mark.parametrize("model", [None, models.ASKUI]) @pytest.mark.skip("Skip as there is currently a bug on the api side not supporting definitions used for nested schemas") def test_get_with_nested_and_inherited_response_schema( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: response = vision_agent.get( "What is the current browser context?", ImageSource(github_login_screenshot), response_schema=BrowserContextResponse, - model_name=model_name, + model=model, ) assert isinstance(response, BrowserContextResponse) assert response.page_context.url in ["https://github.com/login", "github.com/login"] diff --git a/tests/e2e/agent/test_locate.py b/tests/e2e/agent/test_locate.py index 
ad6f7a3f..fe20fad5 100644 --- a/tests/e2e/agent/test_locate.py +++ b/tests/e2e/agent/test_locate.py @@ -17,7 +17,7 @@ @pytest.mark.skip("Skipping tests for now") @pytest.mark.parametrize( - "model_name", + "model", [ "askui", "anthropic-claude-3-5-sonnet-20241022", @@ -30,12 +30,12 @@ def test_locate_with_string_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a simple string locator.""" locator = "Forgot password?" x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -44,12 +44,12 @@ def test_locate_with_textfield_class_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a class locator.""" locator = Class("textfield") x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 50 <= x <= 860 or 350 <= x <= 570 assert 0 <= y <= 80 or 160 <= y <= 280 @@ -58,12 +58,12 @@ def test_locate_with_unspecified_class_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a class locator.""" locator = Class() x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 0 <= x <= github_login_screenshot.width assert 0 <= y <= github_login_screenshot.height @@ -72,12 +72,12 @@ def test_locate_with_description_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a description locator.""" locator = Description("Username textfield") x, y = vision_agent.locate( - 
locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 160 <= y <= 230 @@ -86,12 +86,12 @@ def test_locate_with_similar_text_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a text locator.""" locator = Text("Forgot password?") x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -100,12 +100,12 @@ def test_locate_with_typo_text_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a text locator with a typo.""" locator = Text("Forgot pasword", similarity_threshold=90) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -114,12 +114,12 @@ def test_locate_with_exact_text_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a text locator.""" locator = Text("Forgot password?", match_type="exact") x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -128,12 +128,12 @@ def test_locate_with_regex_text_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a text locator.""" locator = Text(r"F.*?", match_type="regex") x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 
<= y <= 260 @@ -142,12 +142,12 @@ def test_locate_with_contains_text_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using a text locator.""" locator = Text("Forgot", match_type="contains") x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -156,7 +156,7 @@ def test_locate_with_image( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator.""" @@ -164,7 +164,7 @@ def test_locate_with_image( image = PILImage.open(image_path) locator = Image(image=image) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 @@ -173,7 +173,7 @@ def test_locate_with_image_and_custom_params( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator with custom parameters.""" @@ -188,7 +188,7 @@ def test_locate_with_image_and_custom_params( name="Sign in button" ) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 @@ -197,7 +197,7 @@ def test_locate_with_image_should_fail_when_threshold_is_too_high( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator with custom parameters.""" @@ -209,18 +209,18 @@ def test_locate_with_image_should_fail_when_threshold_is_too_high( stop_threshold=1.0 ) with 
pytest.raises(ElementNotFoundError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) def test_locate_with_ai_element_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using an AI element locator.""" locator = AiElement("github_com__icon") x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 @@ -229,9 +229,9 @@ def test_locate_with_ai_element_locator_should_fail_when_threshold_is_too_high( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using image locator with custom parameters.""" locator = AiElement("github_com__icon") with pytest.raises(ElementNotFoundError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) diff --git a/tests/e2e/agent/test_locate_with_different_models.py b/tests/e2e/agent/test_locate_with_different_models.py index e50cbca9..2c9ebb5b 100644 --- a/tests/e2e/agent/test_locate_with_different_models.py +++ b/tests/e2e/agent/test_locate_with_different_models.py @@ -15,125 +15,125 @@ class TestVisionAgentLocateWithDifferentModels: """Test class for VisionAgent.locate() method with different AskUI models.""" - @pytest.mark.parametrize("model_name", ["askui-pta"]) + @pytest.mark.parametrize("model", ["askui-pta"]) def test_locate_with_pta_model( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using PTA model with description locator.""" locator = "Username textfield" x, y = vision_agent.locate( - locator, github_login_screenshot, 
model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 160 <= y <= 230 - @pytest.mark.parametrize("model_name", ["askui-pta"]) + @pytest.mark.parametrize("model", ["askui-pta"]) def test_locate_with_pta_model_fails_with_wrong_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test that PTA model fails with wrong locator type.""" locator = Text("Username textfield") with pytest.raises(AutomationError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) - @pytest.mark.parametrize("model_name", ["askui-ocr"]) + @pytest.mark.parametrize("model", ["askui-ocr"]) def test_locate_with_ocr_model( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using OCR model with text locator.""" locator = "Forgot password?" 
x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 - @pytest.mark.parametrize("model_name", ["askui-ocr"]) + @pytest.mark.parametrize("model", ["askui-ocr"]) def test_locate_with_ocr_model_fails_with_wrong_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test that OCR model fails with wrong locator type.""" locator = Description("Forgot password?") with pytest.raises(AutomationError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) - @pytest.mark.parametrize("model_name", ["askui-ai-element"]) + @pytest.mark.parametrize("model", ["askui-ai-element"]) def test_locate_with_ai_element_model( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using AI element model.""" locator = "github_com__signin__button" x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 - @pytest.mark.parametrize("model_name", ["askui-ai-element"]) + @pytest.mark.parametrize("model", ["askui-ai-element"]) def test_locate_with_ai_element_model_fails_with_wrong_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test that AI element model fails with wrong locator type.""" locator = Text("Sign in") with pytest.raises(AutomationError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) - @pytest.mark.parametrize("model_name", ["askui-combo"]) + @pytest.mark.parametrize("model", ["askui-combo"]) def 
test_locate_with_combo_model_description_first( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using combo model with description locator.""" locator = "Username textfield" x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 160 <= y <= 230 - @pytest.mark.parametrize("model_name", ["askui-combo"]) + @pytest.mark.parametrize("model", ["askui-combo"]) def test_locate_with_combo_model_text_fallback( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using combo model with text locator as fallback.""" locator = "Forgot password?" x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 - @pytest.mark.parametrize("model_name", ["askui-combo"]) + @pytest.mark.parametrize("model", ["askui-combo"]) def test_locate_with_combo_model_fails_with_wrong_locator( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test that combo model fails with wrong locator type.""" locator = AiElement("github_com__signin__button") with pytest.raises(AutomationError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) diff --git a/tests/e2e/agent/test_locate_with_relations.py b/tests/e2e/agent/test_locate_with_relations.py index ed58be62..dabbba13 100644 --- a/tests/e2e/agent/test_locate_with_relations.py +++ b/tests/e2e/agent/test_locate_with_relations.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize( - "model_name", + "model", [ "askui", ], @@ -27,12 +27,12 @@ def test_locate_with_above_relation( self, 
vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using above_of relation.""" locator = Text("Forgot password?").above_of(Class("textfield")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -41,12 +41,12 @@ def test_locate_with_below_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using below_of relation.""" locator = Text("Forgot password?").below_of(Class("textfield")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -55,12 +55,12 @@ def test_locate_with_right_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using right_of relation.""" locator = Text("Forgot password?").right_of(Text("Password")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -69,12 +69,12 @@ def test_locate_with_left_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using left_of relation.""" locator = Text("Password").left_of(Text("Forgot password?")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 450 assert 190 <= y <= 260 @@ -83,12 +83,12 @@ def test_locate_with_containing_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, 
) -> None: """Test locating elements using containing relation.""" locator = Class("textfield").containing(Text("github.com/login")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 50 <= x <= 860 assert 0 <= y <= 80 @@ -97,12 +97,12 @@ def test_locate_with_inside_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using inside_of relation.""" locator = Text("github.com/login").inside_of(Class("textfield")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 70 <= x <= 200 assert 10 <= y <= 75 @@ -111,12 +111,12 @@ def test_locate_with_nearest_to_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using nearest_to relation.""" locator = Class("textfield").nearest_to(Text("Password")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 210 <= y <= 280 @@ -126,12 +126,12 @@ def test_locate_with_and_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using and_ relation.""" locator = Text("Forgot password?").and_(Class("text")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -140,14 +140,14 @@ def test_locate_with_or_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using or_ relation.""" locator = Class("textfield").nearest_to( 
Text("Password").or_(Text("Username or email address")) ) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 160 <= y <= 280 @@ -156,14 +156,14 @@ def test_locate_with_relation_index( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with index.""" locator = Class("textfield").below_of( Text("Username or email address"), index=0 ) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 160 <= y <= 230 @@ -172,12 +172,12 @@ def test_locate_with_relation_index_greater_0( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with index.""" locator = Class("textfield").below_of(Class("textfield"), index=1) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 210 <= y <= 280 @@ -187,12 +187,12 @@ def test_locate_with_relation_index_greater_1( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with index.""" locator = Text("Sign in").below_of(Text(), index=4, reference_point="any") x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 420 <= x <= 500 assert 250 <= y <= 310 @@ -201,14 +201,14 @@ def test_locate_with_relation_reference_point_center( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with center reference point.""" 
locator = Text("Forgot password?").right_of( Text("Password"), reference_point="center" ) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -217,25 +217,25 @@ def test_locate_with_relation_reference_point_center_raises_when_element_cannot_ self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with center reference point.""" locator = Text("Sign in").below_of(Text("Password"), reference_point="center") with pytest.raises(ElementNotFoundError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) def test_locate_with_relation_reference_point_boundary( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with boundary reference point.""" locator = Text("Forgot password?").right_of( Text("Password"), reference_point="boundary" ) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 450 <= x <= 570 assert 190 <= y <= 260 @@ -244,23 +244,23 @@ def test_locate_with_relation_reference_point_boundary_raises_when_element_canno self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with boundary reference point.""" locator = Text("Sign in").below_of(Text("Password"), reference_point="boundary") with pytest.raises(ElementNotFoundError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) def test_locate_with_relation_reference_point_any( self, vision_agent: 
VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with any reference point.""" locator = Text("Sign in").below_of(Text("Password"), reference_point="any") x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 420 <= x <= 500 assert 250 <= y <= 310 @@ -269,7 +269,7 @@ def test_locate_with_multiple_relations_with_same_locator_raises( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using multiple relations with same locator which is not supported by AskUI.""" locator = ( @@ -278,13 +278,13 @@ def test_locate_with_multiple_relations_with_same_locator_raises( .below_of(Class("textfield")) ) with pytest.raises(NotImplementedError): - vision_agent.locate(locator, github_login_screenshot, model_name=model_name) + vision_agent.locate(locator, github_login_screenshot, model=model) def test_locate_with_chained_relations( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using chained relations.""" locator = Text("Sign in").below_of( @@ -292,7 +292,7 @@ def test_locate_with_chained_relations( reference_point="any", ) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 420 <= x <= 500 assert 250 <= y <= 310 @@ -301,7 +301,7 @@ def test_locate_with_relation_different_locator_types( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using relation with different locator types.""" locator = Text("Sign in").below_of( @@ -309,7 +309,7 @@ def test_locate_with_relation_different_locator_types( reference_point="center", ) x, y = 
vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 420 <= x <= 500 assert 250 <= y <= 310 @@ -318,12 +318,12 @@ def test_locate_with_description_and_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using description with relation.""" locator = Description("Sign in button").below_of(Description("Password field")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 @@ -333,14 +333,14 @@ def test_locate_with_description_and_complex_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using description with relation.""" locator = Description("Sign in button").below_of( Class("textfield").below_of(Text("Password")) ) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 @@ -349,7 +349,7 @@ def test_locate_with_image_and_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator with relation.""" @@ -357,7 +357,7 @@ def test_locate_with_image_and_relation( image = PILImage.open(image_path) locator = Image(image=image).containing(Text("Sign in")) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 @@ -366,7 +366,7 @@ def test_locate_with_image_in_relation_to_other_image( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + 
model: str, path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator with relation.""" @@ -377,7 +377,7 @@ def test_locate_with_image_in_relation_to_other_image( github_icon = Image(image=github_icon_image) signin_button = Image(image=signin_button_image).below_of(github_icon) x, y = vision_agent.locate( - signin_button, github_login_screenshot, model_name=model_name + signin_button, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 @@ -386,7 +386,7 @@ def test_locate_with_image_and_complex_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator with complex relation.""" @@ -396,7 +396,7 @@ def test_locate_with_image_and_complex_relation( Class("textfield").below_of(Text("Password")) ) x, y = vision_agent.locate( - locator, github_login_screenshot, model_name=model_name + locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 @@ -405,13 +405,13 @@ def test_locate_with_ai_element_locator_relation( self, vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, - model_name: str, + model: str, ) -> None: """Test locating elements using an AI element locator with relation.""" icon_locator = AiElement("github_com__icon") signin_locator = AiElement("github_com__signin__button") x, y = vision_agent.locate( - signin_locator.below_of(icon_locator), github_login_screenshot, model_name=model_name + signin_locator.below_of(icon_locator), github_login_screenshot, model=model ) assert 350 <= x <= 570 assert 240 <= y <= 320 From b6f78370b67a398af194182bca5e9ac266d27fca Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 16 Apr 2025 10:55:31 +0200 Subject: [PATCH 02/14] feat!(agent): enable selecting models using composition / for whole agent - set default model for `VisionAgent` using `model` parameter - use 
`ModelComposition` for askui models - possible values of model of `VisionAgent.act()` changed - only for locate (not for get or act) BREAKING CHANGE: - model value "claude" for `VisionAgent.act()` changed to "anthropic-claude-3-5-sonnet-20241022" --- README.md | 26 +++- src/askui/agent.py | 75 ++++++++---- src/askui/chat/__main__.py | 5 +- src/askui/locators/locators.py | 10 +- src/askui/locators/serializers.py | 5 + src/askui/models/__init__.py | 8 +- src/askui/models/askui/api.py | 6 +- src/askui/models/models.py | 93 +++++++++++++-- src/askui/models/router.py | 52 ++++---- src/askui/telemetry/telemetry.py | 5 + tests/e2e/agent/test_get.py | 14 +-- tests/e2e/agent/test_locate.py | 5 +- .../test_locate_with_different_models.py | 19 +-- tests/e2e/agent/test_model_composition.py | 111 ++++++++++++++++++ .../test_askui_locator_serializer.py | 24 ++-- tests/unit/{unit => models}/__init__.py | 0 tests/unit/models/test_models.py | 85 ++++++++++++++ .../unit/{unit => utils}/test_image_utils.py | 0 18 files changed, 443 insertions(+), 100 deletions(-) create mode 100644 tests/e2e/agent/test_model_composition.py rename tests/unit/{unit => models}/__init__.py (100%) create mode 100644 tests/unit/models/test_models.py rename tests/unit/{unit => utils}/test_image_utils.py (100%) diff --git a/README.md b/README.md index 0bc22c82..5183b739 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ pip install askui | | AskUI [INFO](https://hub.askui.com/) | Anthropic [INFO](https://console.anthropic.com/settings/keys) | |----------|----------|----------| | ENV Variables | `ASKUI_WORKSPACE_ID`, `ASKUI_TOKEN` | `ANTHROPIC_API_KEY` | -| Supported Commands | `click()`, `locate()`, `mouse_move()` | `act()`, `get()`, `click()`, `locate()`, `mouse_move()` | +| Supported Commands | `click()`, `get()`, `locate()`, `mouse_move()` | `act()`, `click()`, `get()`, `locate()`, `mouse_move()` | | Description | Faster Inference, European Server, Enterprise Ready | Supports complex actions | To 
get started, set the environment variables required to authenticate with your chosen model provider. @@ -143,7 +143,7 @@ You can use Vision Agent with UI-TARS if you provide your own UI-TARS API endpoi 2. Step: Provide the `TARS_URL` and `TARS_API_KEY` environment variables to Vision Agent. -3. Step: Use the `model="tars"` parameter in your `click()`, `get()` and `act()` etc. commands. +3. Step: Use the `model="tars"` parameter in your `click()`, `get()` and `act()` etc. commands or when initializing the `VisionAgent`. ## ▶️ Start Building @@ -171,18 +171,34 @@ with VisionAgent() as agent: ### 🎛️ Model Selection -Instead of relying on the default model for the entire automation script, you can specify a model for each `click()` (or `act()`, `get()` etc.) command using the `model` parameter. +Instead of relying on the default model for the entire automation script, you can specify a model for each `click()` (or `act()`, `get()` etc.) command using the `model` parameter or when initializing the `VisionAgent` (overridden by the `model` parameter of individual commands). 
| | AskUI | Anthropic | |----------|----------|----------| | `act()` | | `anthropic-claude-3-5-sonnet-20241022` | | `click()` | `askui`, `askui-combo`, `askui-pta`, `askui-ocr`, `askui-ai-element` | `anthropic-claude-3-5-sonnet-20241022` | -| `get()` | | `anthropic-claude-3-5-sonnet-20241022` | +| `get()` | | `askui`, `anthropic-claude-3-5-sonnet-20241022` | | `locate()` | `askui`, `askui-combo`, `askui-pta`, `askui-ocr`, `askui-ai-element` | `anthropic-claude-3-5-sonnet-20241022` | | `mouse_move()` | `askui`, `askui-combo`, `askui-pta`, `askui-ocr`, `askui-ai-element` | `anthropic-claude-3-5-sonnet-20241022` | -**Example:** `agent.click("Preview", model="askui-combo")` +**Example:** + +```python +from askui import VisionAgent + +with VisionAgent() as agent: + # Uses the default model (depending on the environment variables set, see above) + agent.click("Next") + +with VisionAgent(model="askui-combo") as agent: + # Uses the "askui-combo" model because it was specified when initializing the agent + agent.click("Next") + # Uses the "anthropic-claude-3-5-sonnet-20241022" model + agent.click("Previous", model="anthropic-claude-3-5-sonnet-20241022") + # Uses the "askui-combo" model again as no model was specified + agent.click("Next") +```
AskUI AI Models diff --git a/src/askui/agent.py b/src/askui/agent.py index 5f1a6ea8..03d174a8 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -16,6 +16,7 @@ from .models.anthropic.claude import ClaudeHandler from .logger import logger, configure_logging from .tools.toolbox import AgentToolbox +from .models import ModelComposition from .models.router import ModelRouter, Point from .reporting import CompositeReporter, Reporter import time @@ -29,6 +30,35 @@ class InvalidParameterError(Exception): class VisionAgent: + """ + A vision-based agent that can interact with user interfaces through computer vision and AI. + + This agent can perform various UI interactions like clicking, typing, scrolling, and more. + It uses computer vision models to locate UI elements and execute actions on them. + + Parameters: + log_level (int, optional): + The logging level to use. Defaults to logging.INFO. + display (int, optional): + The display number to use for screen interactions. Defaults to 1. + model_router (ModelRouter | None, optional): + Custom model router instance. If None, a default one will be created. + reporters (list[Reporter] | None, optional): + List of reporter instances for logging and reporting. If None, an empty list is used. + tools (AgentToolbox | None, optional): + Custom toolbox instance. If None, a default one will be created with AskUiControllerClient. + model (ModelComposition | str | None, optional): + The default composition or name of the model(s) to be used for vision tasks. + Can be overridden by the `model` parameter in the `click()`, `get()`, `act()` etc. methods. 
+ + Example: + ```python + with VisionAgent() as agent: + agent.click("Submit button") + agent.type("Hello World") + agent.act("Open settings menu") + ``` + """ @telemetry.record_call(exclude={"model_router", "reporters", "tools"}) def __init__( self, @@ -37,6 +67,7 @@ def __init__( model_router: ModelRouter | None = None, reporters: list[Reporter] | None = None, tools: AgentToolbox | None = None, + model: ModelComposition | str | None = None, ) -> None: load_dotenv() configure_logging(level=log_level) @@ -49,9 +80,10 @@ def __init__( self.claude = ClaudeHandler(log_level=log_level) self.tools = tools or AgentToolbox(os=AskUiControllerClient(display=display, reporter=self._reporter)) self._controller = AskUiControllerServer() + self._model = model @telemetry.record_call(exclude={"locator"}) - def click(self, locator: Optional[str | Locator] = None, button: Literal['left', 'middle', 'right'] = 'left', repeat: int = 1, model: Optional[str] = None) -> None: + def click(self, locator: Optional[str | Locator] = None, button: Literal['left', 'middle', 'right'] = 'left', repeat: int = 1, model: ModelComposition | str | None = None) -> None: """ Simulates a mouse click on the user interface element identified by the provided locator. @@ -59,7 +91,7 @@ def click(self, locator: Optional[str | Locator] = None, button: Literal['left', locator (str | Locator | None): The identifier or description of the element to click. button ('left' | 'middle' | 'right'): Specifies which mouse button to click. Defaults to 'left'. repeat (int): The number of times to click. Must be greater than 0. Defaults to 1. - model (str | None): The model name to be used for element detection. Optional. + model (ModelComposition | str | None): The composition or name of the model(s) to be used for locating the element to click on using the `locator`. Raises: InvalidParameterError: If the 'repeat' parameter is less than 1. 
@@ -86,44 +118,44 @@ def click(self, locator: Optional[str | Locator] = None, button: Literal['left', self._reporter.add_message("User", msg) if locator is not None: logger.debug("VisionAgent received instruction to click on %s", locator) - self._mouse_move(locator, model) + self._mouse_move(locator, model or self._model) self.tools.os.click(button, repeat) # type: ignore - def _locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: Optional[str] = None) -> Point: + def _locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: ModelComposition | str | None = None) -> Point: if screenshot is None: screenshot = self.tools.os.screenshot() # type: ignore - point = self.model_router.locate(screenshot, locator, model) + point = self.model_router.locate(screenshot, locator, model or self._model) self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})") return point - def locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: Optional[str] = None) -> Point: + def locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: ModelComposition | str | None = None) -> Point: """ Locates the UI element identified by the provided locator. Args: locator (str | Locator): The identifier or description of the element to locate. screenshot (Optional[Image.Image], optional): The screenshot to use for locating the element. Defaults to None. - model (Optional[str], optional): The model to use for locating the element. Defaults to None. + model (ModelComposition | str | None): The composition or name of the model(s) to be used for locating the element using the `locator`. Returns: Point: The coordinates of the element. 
""" self._reporter.add_message("User", f"locate {locator}") logger.debug("VisionAgent received instruction to locate %s", locator) - return self._locate(locator, screenshot, model) + return self._locate(locator, screenshot, model or self._model) - def _mouse_move(self, locator: str | Locator, model: Optional[str] = None) -> None: - point = self._locate(locator=locator, model=model) + def _mouse_move(self, locator: str | Locator, model: ModelComposition | str | None = None) -> None: + point = self._locate(locator=locator, model=model or self._model) self.tools.os.mouse(point[0], point[1]) # type: ignore @telemetry.record_call(exclude={"locator"}) - def mouse_move(self, locator: str | Locator, model: Optional[str] = None) -> None: + def mouse_move(self, locator: str | Locator, model: ModelComposition | str | None = None) -> None: """ Moves the mouse cursor to the UI element identified by the provided locator. Parameters: locator (str | Locator): The identifier or description of the element to move to. - model (str | None): The model name to be used for element detection. Optional. + model (ModelComposition | str | None): The composition or name of the model(s) to be used for locating the element to move the mouse to using the `locator`. 
 Example: ```python @@ -135,7 +167,7 @@ def mouse_move(self, locator: str | Locator, model: Optional[str] = None) -> Non """ self._reporter.add_message("User", f'mouse_move: {locator}') logger.debug("VisionAgent received instruction to mouse_move to %s", locator) - self._mouse_move(locator, model) + self._mouse_move(locator, model or self._model) @telemetry.record_call() def mouse_scroll(self, x: int, y: int) -> None: @@ -190,7 +222,7 @@ def get( query: str, image: Optional[ImageSource] = None, response_schema: Type[JsonSchema] | None = None, - model: Optional[str] = None, + model: ModelComposition | str | None = None, ) -> JsonSchema | str: """ Retrieves information from an image (defaults to a screenshot of the current screen) based on the provided query. @@ -202,9 +234,9 @@ def get( The image to extract information from. Optional. Defaults to a screenshot of the current screen. response_schema (type[ResponseSchema] | None): A Pydantic model class that defines the response schema. Optional. If not provided, returns a string. - model (str | None): - The model to be used for information extraction. Optional. - Note: response_schema is only supported with models that support JSON output (like the default askui model). + model (ModelComposition | str | None): + The composition or name of the model(s) to be used for retrieving information from the screen or image using the `query`. + Note: `response_schema` is not supported by all models. Returns: ResponseSchema | str: The extracted information, either as a Pydantic model instance or a string. 
@@ -239,7 +271,7 @@ class UrlResponse(JsonSchemaBase): response = self.model_router.get_inference( image=image, query=query, - model=model, + model=model or self._model, response_schema=response_schema, ) if self._reporter is not None: @@ -307,7 +339,7 @@ def key_down(self, key: PcKey | ModifierKey) -> None: self.tools.os.keyboard_pressed(key) @telemetry.record_call(exclude={"goal"}) - def act(self, goal: str, model: Optional[str] = None) -> None: + def act(self, goal: str, model: ModelComposition | str | None = None) -> None: """ Instructs the agent to achieve a specified goal through autonomous actions. @@ -317,8 +349,7 @@ def act(self, goal: str, model: Optional[str] = None) -> None: Parameters: goal (str): A description of what the agent should achieve. - model (str | None): The specific model to use for vision analysis. - If None, uses the default model. + model (ModelComposition | str | None): The composition or name of the model(s) to be used for achieving the `goal`. Example: ```python @@ -332,7 +363,7 @@ def act(self, goal: str, model: Optional[str] = None) -> None: logger.debug( "VisionAgent received instruction to act towards the goal '%s'", goal ) - self.model_router.act(self.tools.os, goal, model) + self.model_router.act(self.tools.os, goal, model or self._model) @telemetry.record_call() def keyboard( diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index add72aa8..7eb98f7a 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -6,6 +6,7 @@ from askui import VisionAgent import logging from askui.chat.click_recorder import ClickRecorder +from askui.models import ModelName from askui.reporting import Reporter from askui.utils.image_utils import base64_to_image import json @@ -203,7 +204,7 @@ def rerun(): element_description = agent.get( query=prompt, image=screenshot_with_crosshair, - model="anthropic-claude-3-5-sonnet-20241022", + model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ) write_message( 
message["role"], @@ -213,7 +214,7 @@ def rerun(): ) agent.mouse_move( locator=element_description.replace('"', ""), - model="anthropic-claude-3-5-sonnet-20241022", + model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ) else: write_message( diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py index fd64d0bf..0eb63c5e 100644 --- a/src/askui/locators/locators.py +++ b/src/askui/locators/locators.py @@ -57,19 +57,21 @@ def __str__(self) -> str: TextMatchType = Literal["similar", "exact", "contains", "regex"] +DEFAULT_TEXT_MATCH_TYPE = "similar" +DEFAULT_SIMILARITY_THRESHOLD = 70 class Text(Class): """Locator for finding text elements by their content.""" text: str | None = None - match_type: TextMatchType = "similar" - similarity_threshold: int = Field(default=70, ge=0, le=100) + match_type: TextMatchType = DEFAULT_TEXT_MATCH_TYPE + similarity_threshold: int = Field(default=DEFAULT_SIMILARITY_THRESHOLD, ge=0, le=100) def __init__( self, text: str | None = None, - match_type: TextMatchType = "similar", - similarity_threshold: int = 70, + match_type: TextMatchType = DEFAULT_TEXT_MATCH_TYPE, + similarity_threshold: int = DEFAULT_SIMILARITY_THRESHOLD, **kwargs, ) -> None: super().__init__( diff --git a/src/askui/locators/serializers.py b/src/askui/locators/serializers.py index 18b077d0..c5b1bf58 100644 --- a/src/askui/locators/serializers.py +++ b/src/askui/locators/serializers.py @@ -3,6 +3,8 @@ from askui.utils.image_utils import ImageSource from askui.models.askui.ai_element_utils import AiElementCollection, AiElementNotFound from .locators import ( + DEFAULT_SIMILARITY_THRESHOLD, + DEFAULT_TEXT_MATCH_TYPE, ImageMetadata, AiElement as AiElementLocator, Class, @@ -139,6 +141,9 @@ def _serialize_description(self, description: Description) -> str: def _serialize_text(self, text: Text) -> str: match text.match_type: case "similar": + if text.similarity_threshold == DEFAULT_SIMILARITY_THRESHOLD and text.match_type == DEFAULT_TEXT_MATCH_TYPE: + 
# Necessary so that we can use wordlevel ocr for these texts + return f"text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" return f"text with text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER} that matches to {text.similarity_threshold} %" case "exact": return f"text equals text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" diff --git a/src/askui/models/__init__.py b/src/askui/models/__init__.py index 5ffcdcab..efc2755c 100644 --- a/src/askui/models/__init__.py +++ b/src/askui/models/__init__.py @@ -1,7 +1,7 @@ -from .models import ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ANTHROPIC, ASKUI +from .models import ModelName, ModelComposition, ModelDefinition __all__ = [ - "ANTHROPIC__CLAUDE__3_5__SONNET__20241022", - "ANTHROPIC", - "ASKUI", + "ModelName", + "ModelComposition", + "ModelDefinition", ] diff --git a/src/askui/models/askui/api.py b/src/askui/models/askui/api.py index 4d12ba06..a44a5a48 100644 --- a/src/askui/models/askui/api.py +++ b/src/askui/models/askui/api.py @@ -5,6 +5,7 @@ import json as json_lib from PIL import Image from typing import Any, Type, Union +from askui.models.models import ModelComposition from askui.utils.image_utils import ImageSource from askui.locators.serializers import AskUiLocatorSerializer from askui.locators.locators import Locator @@ -49,7 +50,7 @@ def _request(self, endpoint: str, json: dict[str, Any] | None = None) -> Any: return response.json() - def predict(self, image: Union[pathlib.Path, Image.Image], locator: Locator) -> tuple[int | None, int | None]: + def predict(self, image: Union[pathlib.Path, Image.Image], locator: Locator, model: ModelComposition | None = None) -> tuple[int | None, int | None]: serialized_locator = self._locator_serializer.serialize(locator=locator) json: dict[str, Any] = { "image": f",{image_to_base64(image)}", @@ -57,6 +58,9 @@ def predict(self, image: Union[pathlib.Path, Image.Image], locator: Locator) -> } if "customElements" in serialized_locator: 
json["customElements"] = serialized_locator["customElements"] + if model is not None: + json["modelComposition"] = model.model_dump(by_alias=True) + logger.debug(f"modelComposition:\n{json_lib.dumps(json['modelComposition'])}") content = self._request(endpoint="inference", json=json) assert content["type"] == "COMMANDS", f"Received unknown content type {content['type']}" actions = [el for el in content["data"]["actions"] if el["inputEvent"] == "MOUSE_MOVE"] diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 7326d901..71da37b2 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -1,7 +1,86 @@ -ANTHROPIC__CLAUDE__3_5__SONNET__20241022 = "anthropic-claude-3-5-sonnet-20241022" -ANTHROPIC = ANTHROPIC__CLAUDE__3_5__SONNET__20241022 -ASKUI = "askui" -ASKUI__AI_ELEMENT = "askui-ai-element" -ASKUI__COMBO = "askui-combo" -ASKUI__OCR = "askui-ocr" -ASKUI__PTA = "askui-pta" +from collections.abc import Iterator +from enum import Enum +import re +from typing import Annotated +from pydantic import BaseModel, ConfigDict, Field, RootModel + + +class ModelName(str, Enum): + ANTHROPIC__CLAUDE__3_5__SONNET__20241022 = "anthropic-claude-3-5-sonnet-20241022" + ANTHROPIC = "anthropic" + ASKUI = "askui" + ASKUI__AI_ELEMENT = "askui-ai-element" + ASKUI__COMBO = "askui-combo" + ASKUI__OCR = "askui-ocr" + ASKUI__PTA = "askui-pta" + TARS = "tars" + + +MODEL_DEFINITION_PROPERTY_REGEX_PATTERN = re.compile(r"^[A-Za-z0-9_]+$") + + +ModelDefinitionProperty = Annotated[ + str, Field(pattern=MODEL_DEFINITION_PROPERTY_REGEX_PATTERN) +] + + +class ModelDefinition(BaseModel): + """ + A definition of a model. 
+ """ + model_config = ConfigDict( + populate_by_name=True, + ) + task: ModelDefinitionProperty = Field( + description="The task the model is trained for, e.g., end-to-end OCR (e2e_ocr) or object detection (od)", + examples=["e2e_ocr", "od"], + ) + architecture: ModelDefinitionProperty = Field( + description="The architecture of the model", examples=["easy_ocr", "yolo"] + ) + version: str = Field(pattern=r"^[0-9]{1,6}$") + interface: ModelDefinitionProperty = Field( + description="The interface the model is trained for", + examples=["online_learning", "offline_learning"], + ) + use_case: ModelDefinitionProperty = Field( + description='The use case the model is trained for. In the case of workspace specific AskUI models, this is often the workspace id but with "-" replaced by "_"', + examples=[ + "fb3b9a7b_3aea_41f7_ba02_e55fd66d1c1e", + "00000000_0000_0000_0000_000000000000", + ], + default="00000000_0000_0000_0000_000000000000", + serialization_alias="useCase", + ) + tags: list[ModelDefinitionProperty] = Field( + default_factory=list, + description="Tags for identifying the model that cannot be represented by other properties", + examples=["trained", "word_level"], + ) + + @property + def model_name(self) -> str: + return ( + "-".join( + [ + self.task, + self.architecture, + self.interface, + self.use_case, + self.version, + *self.tags, + ] + ) + ) + + +class ModelComposition(RootModel[list[ModelDefinition]]): + """ + A composition of models. 
+ """ + + def __iter__(self): + return iter(self.root) + + def __getitem__(self, index: int) -> ModelDefinition: + return self.root[index] diff --git a/src/askui/models/router.py b/src/askui/models/router.py index 4009908f..42756d84 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -8,6 +8,7 @@ from askui.locators.serializers import AskUiLocatorSerializer, VlmLocatorSerializer from askui.locators.locators import Locator from askui.models.askui.ai_element_utils import AiElementCollection +from askui.models.models import ModelComposition, ModelName from askui.models.types import JsonSchema from askui.reporting import Reporter from askui.utils.image_utils import ImageSource @@ -36,12 +37,12 @@ def locate( self, screenshot: Image.Image, locator: str | Locator, - model: str | None = None, + model: ModelComposition | str | None = None, ) -> Point: pass @abstractmethod - def is_responsible(self, model: str | None = None) -> bool: + def is_responsible(self, model: ModelComposition | str | None = None) -> bool: pass @abstractmethod @@ -63,36 +64,37 @@ def locate( self, screenshot: Image.Image, locator: str | Locator, - model: str | None = None, + model: ModelComposition | str | None = None, ) -> Point: if not self._inference_api.authenticated: raise AutomationError( "NoAskUIAuthenticationSet! Please set 'AskUI ASKUI_WORKSPACE_ID' or 'ASKUI_TOKEN' as env variables!" 
) - if model == "askui" or model is None: + if not isinstance(model, str) or model == ModelName.ASKUI: logger.debug("Routing locate prediction to askui") locator = Text(locator) if isinstance(locator, str) else locator - x, y = self._inference_api.predict(screenshot, locator) + _model = model if not isinstance(model, str) else None + x, y = self._inference_api.predict(screenshot, locator, _model) return handle_response((x, y), locator) if not isinstance(locator, str): raise AutomationError( f'Locators of type `{type(locator)}` are not supported for models "askui-pta", "askui-ocr" and "askui-combo" and "askui-ai-element". Please provide a `str`.' ) - if model == "askui-pta": + if model == ModelName.ASKUI__PTA: logger.debug("Routing locate prediction to askui-pta") x, y = self._inference_api.predict(screenshot, Description(locator)) return handle_response((x, y), locator) - if model == "askui-ocr": + if model == ModelName.ASKUI__OCR: logger.debug("Routing locate prediction to askui-ocr") return self._locate_with_askui_ocr(screenshot, locator) - if model == "askui-combo" or model is None: + if model == ModelName.ASKUI__COMBO or model is None: logger.debug("Routing locate prediction to askui-combo") description_locator = Description(locator) x, y = self._inference_api.predict(screenshot, description_locator) if x is None or y is None: return self._locate_with_askui_ocr(screenshot, locator) return handle_response((x, y), description_locator) - if model == "askui-ai-element": + if model == ModelName.ASKUI__AI_ELEMENT: logger.debug("Routing click prediction to askui-ai-element") _locator = AiElement(locator) x, y = self._inference_api.predict(screenshot, _locator) @@ -100,8 +102,8 @@ def locate( raise AutomationError(f'Invalid model: "{model}"') @override - def is_responsible(self, model: str | None = None) -> bool: - return model is None or model.startswith("askui") + def is_responsible(self, model: ModelComposition | str | None = None) -> bool: + return not 
isinstance(model, str) or model.startswith(ModelName.ASKUI) @override def is_authenticated(self) -> bool: @@ -127,10 +129,10 @@ def __init__( self.tars = UITarsAPIHandler(self._reporter) self._locator_serializer = VlmLocatorSerializer() - def act(self, controller_client, goal: str, model: str | None = None): - if self.tars.authenticated and model == "tars": + def act(self, controller_client, goal: str, model: ModelComposition | str | None = None): + if self.tars.authenticated and model == ModelName.TARS: return self.tars.act(controller_client, goal) - if self.claude.authenticated and (model == "claude" or model is None): + if self.claude.authenticated and (model is None or isinstance(model, str) and model.startswith(ModelName.ANTHROPIC)): agent = ClaudeComputerAgent(controller_client, self._reporter) return agent.run(goal) raise AutomationError(f"Invalid model for act: {model}") @@ -140,19 +142,19 @@ def get_inference( query: str, image: ImageSource, response_schema: Type[JsonSchema] | None = None, - model: str | None = None, + model: ModelComposition | str | None = None, ) -> JsonSchema | str: - if self.tars.authenticated and model == "tars": + if self.tars.authenticated and model == ModelName.TARS: if response_schema is not None: raise NotImplementedError("Response schema is not yet supported for UI-TARS models.") return self.tars.get_inference(image=image, query=query) if self.claude.authenticated and ( - model == "anthropic-claude-3-5-sonnet-20241022" + isinstance(model, str) and model.startswith(ModelName.ANTHROPIC) ): if response_schema is not None: raise NotImplementedError("Response schema is not yet supported for Anthropic models.") return self.claude.get_inference(image=image, query=query) - if self.askui.authenticated and (model == "askui" or model is None): + if self.askui.authenticated and (model == ModelName.ASKUI or model is None): return self.askui.get_inference( image=image, query=query, @@ -172,10 +174,10 @@ def locate( self, screenshot: 
Image.Image, locator: str | Locator, - model: str | None = None, + model: ModelComposition | str | None = None, ) -> Point: if ( - model is not None + isinstance(model, str) and model in self.huggingface_spaces.get_spaces_names() ): x, y = self.huggingface_spaces.predict( @@ -184,23 +186,23 @@ def locate( model_name=model, ) return handle_response((x, y), locator) - if model is not None: - if model.startswith("anthropic") and not self.claude.authenticated: + if isinstance(model, str): + if model.startswith(ModelName.ANTHROPIC) and not self.claude.authenticated: raise AutomationError( "You need to provide Anthropic credentials to use Anthropic models." ) - if model.startswith("tars") and not self.tars.authenticated: + if model.startswith(ModelName.TARS) and not self.tars.authenticated: raise AutomationError( "You need to provide UI-TARS HF Endpoint credentials to use UI-TARS models." ) - if self.tars.authenticated and model == "tars": + if self.tars.authenticated and model == ModelName.TARS: x, y = self.tars.locate_prediction( screenshot, self._serialize_locator(locator) ) return handle_response((x, y), locator) if ( self.claude.authenticated - and model == "anthropic-claude-3-5-sonnet-20241022" + and isinstance(model, str) and model.startswith(ModelName.ANTHROPIC) ): logger.debug("Routing locate prediction to Anthropic") x, y = self.claude.locate_inference( diff --git a/src/askui/telemetry/telemetry.py b/src/askui/telemetry/telemetry.py index 182c30f0..5ddc61c4 100644 --- a/src/askui/telemetry/telemetry.py +++ b/src/askui/telemetry/telemetry.py @@ -174,10 +174,15 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: ) if exclude_first_arg: processed_args = processed_args[1:] if processed_args else () + processed_args = tuple(arg.model_dump() if isinstance(arg, BaseModel) else arg for arg in processed_args) processed_kwargs = { k: v if k not in _exclude else self._EXCLUDE_MASK for k, v in kwargs.items() } + processed_kwargs = { + k: v.model_dump() if isinstance(v, 
BaseModel) else v + for k, v in processed_kwargs.items() + } attributes: dict[str, Any] = { "module": module, "fn_name": fn_name, diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index 2e9f5ef1..ca9940c5 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -1,7 +1,7 @@ from typing import Literal import pytest from PIL import Image as PILImage -from askui import models +from askui.models import ModelName from askui import VisionAgent from askui.utils.image_utils import ImageSource from askui import JsonSchemaBase @@ -20,7 +20,7 @@ class BrowserContextResponse(JsonSchemaBase): browser_type: Literal["chrome", "firefox", "edge", "safari"] -@pytest.mark.parametrize("model", [None, models.ASKUI, models.ANTHROPIC]) +@pytest.mark.parametrize("model", [None, ModelName.ASKUI, ModelName.ANTHROPIC]) def test_get( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, @@ -44,7 +44,7 @@ def test_get_with_response_schema_without_additional_properties_with_askui_model "What is the current url shown in the url bar?", ImageSource(github_login_screenshot), response_schema=UrlResponse, - model=models.ASKUI, + model=ModelName.ASKUI, ) @@ -58,11 +58,11 @@ def test_get_with_response_schema_without_required_with_askui_model_raises( "What is the current url shown in the url bar?", ImageSource(github_login_screenshot), response_schema=UrlResponse, - model=models.ASKUI, + model=ModelName.ASKUI, ) -@pytest.mark.parametrize("model", [None, models.ASKUI]) +@pytest.mark.parametrize("model", [None, ModelName.ASKUI]) def test_get_with_response_schema( vision_agent: VisionAgent, github_login_screenshot: PILImage.Image, @@ -87,11 +87,11 @@ def test_get_with_response_schema_with_anthropic_model_raises_not_implemented( "What is the current url shown in the url bar?", ImageSource(github_login_screenshot), response_schema=UrlResponse, - model=models.ANTHROPIC, + model=ModelName.ANTHROPIC, ) -@pytest.mark.parametrize("model", [None, 
models.ASKUI]) +@pytest.mark.parametrize("model", [None, ModelName.ASKUI]) @pytest.mark.skip("Skip as there is currently a bug on the api side not supporting definitions used for nested schemas") def test_get_with_nested_and_inherited_response_schema( vision_agent: VisionAgent, diff --git a/tests/e2e/agent/test_locate.py b/tests/e2e/agent/test_locate.py index fe20fad5..af061519 100644 --- a/tests/e2e/agent/test_locate.py +++ b/tests/e2e/agent/test_locate.py @@ -13,14 +13,15 @@ ) from askui.locators.locators import Image from askui.exceptions import ElementNotFoundError +from askui.models import ModelName @pytest.mark.skip("Skipping tests for now") @pytest.mark.parametrize( "model", [ - "askui", - "anthropic-claude-3-5-sonnet-20241022", + ModelName.ASKUI, + ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ], ) class TestVisionAgentLocate: diff --git a/tests/e2e/agent/test_locate_with_different_models.py b/tests/e2e/agent/test_locate_with_different_models.py index 2c9ebb5b..8b3ad9cd 100644 --- a/tests/e2e/agent/test_locate_with_different_models.py +++ b/tests/e2e/agent/test_locate_with_different_models.py @@ -10,12 +10,13 @@ AiElement, ) from askui.exceptions import AutomationError +from askui.models.models import ModelName class TestVisionAgentLocateWithDifferentModels: """Test class for VisionAgent.locate() method with different AskUI models.""" - @pytest.mark.parametrize("model", ["askui-pta"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__PTA]) def test_locate_with_pta_model( self, vision_agent: VisionAgent, @@ -30,7 +31,7 @@ def test_locate_with_pta_model( assert 350 <= x <= 570 assert 160 <= y <= 230 - @pytest.mark.parametrize("model", ["askui-pta"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__PTA]) def test_locate_with_pta_model_fails_with_wrong_locator( self, vision_agent: VisionAgent, @@ -42,7 +43,7 @@ def test_locate_with_pta_model_fails_with_wrong_locator( with pytest.raises(AutomationError): vision_agent.locate(locator, 
github_login_screenshot, model=model) - @pytest.mark.parametrize("model", ["askui-ocr"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__OCR]) def test_locate_with_ocr_model( self, vision_agent: VisionAgent, @@ -57,7 +58,7 @@ def test_locate_with_ocr_model( assert 450 <= x <= 570 assert 190 <= y <= 260 - @pytest.mark.parametrize("model", ["askui-ocr"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__OCR]) def test_locate_with_ocr_model_fails_with_wrong_locator( self, vision_agent: VisionAgent, @@ -69,7 +70,7 @@ def test_locate_with_ocr_model_fails_with_wrong_locator( with pytest.raises(AutomationError): vision_agent.locate(locator, github_login_screenshot, model=model) - @pytest.mark.parametrize("model", ["askui-ai-element"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__AI_ELEMENT]) def test_locate_with_ai_element_model( self, vision_agent: VisionAgent, @@ -84,7 +85,7 @@ def test_locate_with_ai_element_model( assert 350 <= x <= 570 assert 240 <= y <= 320 - @pytest.mark.parametrize("model", ["askui-ai-element"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__AI_ELEMENT]) def test_locate_with_ai_element_model_fails_with_wrong_locator( self, vision_agent: VisionAgent, @@ -96,7 +97,7 @@ def test_locate_with_ai_element_model_fails_with_wrong_locator( with pytest.raises(AutomationError): vision_agent.locate(locator, github_login_screenshot, model=model) - @pytest.mark.parametrize("model", ["askui-combo"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__COMBO]) def test_locate_with_combo_model_description_first( self, vision_agent: VisionAgent, @@ -111,7 +112,7 @@ def test_locate_with_combo_model_description_first( assert 350 <= x <= 570 assert 160 <= y <= 230 - @pytest.mark.parametrize("model", ["askui-combo"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__COMBO]) def test_locate_with_combo_model_text_fallback( self, vision_agent: VisionAgent, @@ -126,7 +127,7 @@ def test_locate_with_combo_model_text_fallback( assert 450 <= x <= 
570 assert 190 <= y <= 260 - @pytest.mark.parametrize("model", ["askui-combo"]) + @pytest.mark.parametrize("model", [ModelName.ASKUI__COMBO]) def test_locate_with_combo_model_fails_with_wrong_locator( self, vision_agent: VisionAgent, diff --git a/tests/e2e/agent/test_model_composition.py b/tests/e2e/agent/test_model_composition.py new file mode 100644 index 00000000..8ae8b165 --- /dev/null +++ b/tests/e2e/agent/test_model_composition.py @@ -0,0 +1,111 @@ +"""Tests for VisionAgent with different model compositions""" + +import pytest +from PIL import Image as PILImage +from askui.agent import VisionAgent +from askui.locators.locators import DEFAULT_SIMILARITY_THRESHOLD, Text +from askui.models import ModelComposition, ModelDefinition + + +@pytest.mark.parametrize( + "model", + [ + ModelComposition( + [ + ModelDefinition( + task="e2e_ocr", + architecture="easy_ocr", + version="1", + interface="online_learning", + ) + ] + ), + ModelComposition( + [ + ModelDefinition( + task="e2e_ocr", + architecture="easy_ocr", + version="1", + interface="online_learning", + use_case="fb3b9a7b_3aea_41f7_ba02_e55fd66d1c1e", + tags=["trained"], + ) + ] + ), + ], +) +class TestSimpleOcrModel: + """Test class for simple OCR model compositions.""" + + def test_locate_with_simple_ocr( + self, + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: ModelComposition, + ) -> None: + """Test locating elements using simple OCR model.""" + x, y = vision_agent.locate("Sign in", github_login_screenshot, model=model) + assert isinstance(x, int) + assert isinstance(y, int) + assert 0 <= x <= github_login_screenshot.width + assert 0 <= y <= github_login_screenshot.height + + +@pytest.mark.parametrize( + "model", + [ + ModelComposition( + [ + ModelDefinition( + task="e2e_ocr", + architecture="easy_ocr", + version="1", + interface="online_learning", + tags=["word_level"], + ) + ] + ), + ModelComposition( + [ + ModelDefinition( + task="e2e_ocr", + architecture="easy_ocr", + 
version="1", + interface="online_learning", + use_case="fb3b9a7b_3aea_41f7_ba02_e55fd66d1c1e", + tags=["trained", "word_level"], + ) + ] + ), + ], +) +class TestWordLevelOcrModel: + """Test class for word-level OCR model compositions.""" + + def test_locate_with_word_level_ocr( + self, + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: ModelComposition, + ) -> None: + """Test locating elements using word-level OCR model.""" + x, y = vision_agent.locate("Sign", github_login_screenshot, model=model) + assert isinstance(x, int) + assert isinstance(y, int) + assert 0 <= x <= github_login_screenshot.width + assert 0 <= y <= github_login_screenshot.height + + def test_locate_with_trained_word_level_ocr_with_non_default_text_raises( + self, + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: ModelComposition, + ) -> None: + if any("trained" not in m.tags for m in model): + pytest.skip("Skipping test for non-trained model") + with pytest.raises(Exception): + vision_agent.locate(Text("Sign in", text_type="exact"), github_login_screenshot, model=model) + vision_agent.locate(Text("Sign in", text_type="regex"), github_login_screenshot, model=model) + vision_agent.locate(Text("Sign in", text_type="contains"), github_login_screenshot, model=model) + assert DEFAULT_SIMILARITY_THRESHOLD != 80 + vision_agent.locate(Text("Sign in", similarity_threshold=80), github_login_screenshot, model=model) diff --git a/tests/unit/locators/serializers/test_askui_locator_serializer.py b/tests/unit/locators/serializers/test_askui_locator_serializer.py index fd11ea84..a2c58ad9 100644 --- a/tests/unit/locators/serializers/test_askui_locator_serializer.py +++ b/tests/unit/locators/serializers/test_askui_locator_serializer.py @@ -124,7 +124,7 @@ def test_serialize_above_relation(askui_serializer: AskUiLocatorSerializer) -> N result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> 
that matches to 70 % index 1 above intersection_area element_center_line text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> index 1 above intersection_area element_center_line text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -135,7 +135,7 @@ def test_serialize_below_relation(askui_serializer: AskUiLocatorSerializer) -> N result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % index 0 below intersection_area element_edge_area text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> index 0 below intersection_area element_edge_area text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -146,7 +146,7 @@ def test_serialize_right_relation(askui_serializer: AskUiLocatorSerializer) -> N result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % index 0 right of intersection_area element_edge_area text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> index 0 right of intersection_area element_edge_area text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -157,7 +157,7 @@ def test_serialize_left_relation(askui_serializer: AskUiLocatorSerializer) -> No result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % index 0 left of intersection_area element_edge_area text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> index 0 left of intersection_area element_edge_area text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -170,7 +170,7 @@ def test_serialize_containing_relation( result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with 
text <|string|>hello<|string|> that matches to 70 % contains text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> contains text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -181,7 +181,7 @@ def test_serialize_inside_relation(askui_serializer: AskUiLocatorSerializer) -> result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % in text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> in text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -194,7 +194,7 @@ def test_serialize_nearest_to_relation( result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % nearest to text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> nearest to text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -205,7 +205,7 @@ def test_serialize_and_relation(askui_serializer: AskUiLocatorSerializer) -> Non result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % and text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> and text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -216,7 +216,7 @@ def test_serialize_or_relation(askui_serializer: AskUiLocatorSerializer) -> None result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % or text with text <|string|>world<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> or text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -240,7 +240,7 @@ def test_serialize_relations_chain(askui_serializer: AskUiLocatorSerializer) -> 
result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % index 0 above intersection_area element_edge_area text with text <|string|>world<|string|> that matches to 70 % index 0 below intersection_area element_edge_area text with text <|string|>earth<|string|> that matches to 70 %" + == "text <|string|>hello<|string|> index 0 above intersection_area element_edge_area text <|string|>world<|string|> index 0 below intersection_area element_edge_area text <|string|>earth<|string|>" ) assert result["customElements"] == [] @@ -335,7 +335,7 @@ def test_serialize_image_with_relation( result = askui_serializer.serialize(image) assert ( result["instruction"] - == "custom element with text <|string|>image<|string|> index 0 above intersection_area element_edge_area text with text <|string|>world<|string|> that matches to 70 %" + == "custom element with text <|string|>image<|string|> index 0 above intersection_area element_edge_area text <|string|>world<|string|>" ) assert len(result["customElements"]) == 1 custom_element = result["customElements"][0] @@ -350,7 +350,7 @@ def test_serialize_text_with_image_relation( result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text with text <|string|>hello<|string|> that matches to 70 % index 0 above intersection_area element_edge_area custom element with text <|string|>image<|string|>" + == "text <|string|>hello<|string|> index 0 above intersection_area element_edge_area custom element with text <|string|>image<|string|>" ) assert len(result["customElements"]) == 1 custom_element = result["customElements"][0] diff --git a/tests/unit/unit/__init__.py b/tests/unit/models/__init__.py similarity index 100% rename from tests/unit/unit/__init__.py rename to tests/unit/models/__init__.py diff --git a/tests/unit/models/test_models.py b/tests/unit/models/test_models.py new file mode 100644 index 00000000..c28402f9 --- /dev/null +++ 
b/tests/unit/models/test_models.py @@ -0,0 +1,85 @@ +import pytest +from src.askui.models.models import ModelComposition, ModelDefinition + + +MODEL_DEFINITIONS = { + "e2e_ocr": ModelDefinition( + task="e2e_ocr", + architecture="easy_ocr", + version="1", + interface="online_learning", + use_case="test_workspace", + tags=["trained"] + ), + "od": ModelDefinition( + task="od", + architecture="yolo", + version="789012", + interface="offline_learning", + use_case="test_workspace2" + ) +} + + +def test_model_composition_initialization(): + composition = ModelComposition([MODEL_DEFINITIONS["e2e_ocr"]]) + assert len(composition.root) == 1 + assert composition.root[0].model_name == "e2e_ocr-easy_ocr-online_learning-test_workspace-1-trained" + + +def test_model_composition_initialization_with_multiple_models(): + composition = ModelComposition([MODEL_DEFINITIONS["e2e_ocr"], MODEL_DEFINITIONS["od"]]) + assert len(composition.root) == 2 + assert composition.root[0].model_name == "e2e_ocr-easy_ocr-online_learning-test_workspace-1-trained" + assert composition.root[1].model_name == "od-yolo-offline_learning-test_workspace2-789012" + + +def test_model_composition_serialization(): + model_def = MODEL_DEFINITIONS["e2e_ocr"] + composition = ModelComposition([model_def]) + serialized = composition.model_dump(by_alias=True) + assert isinstance(serialized, list) + assert len(serialized) == 1 + assert serialized[0]["task"] == "e2e_ocr" + assert serialized[0]["architecture"] == "easy_ocr" + assert serialized[0]["version"] == "1" + assert serialized[0]["interface"] == "online_learning" + assert serialized[0]["useCase"] == "test_workspace" + assert serialized[0]["tags"] == ["trained"] + + +def test_model_composition_serialization_with_multiple_models(): + composition = ModelComposition([MODEL_DEFINITIONS["e2e_ocr"], MODEL_DEFINITIONS["od"]]) + serialized = composition.model_dump(by_alias=True) + assert isinstance(serialized, list) + assert len(serialized) == 2 + assert 
serialized[0]["task"] == "e2e_ocr" + assert serialized[1]["task"] == "od" + + +def test_model_composition_validation_with_invalid_task(): + with pytest.raises(ValueError): + ModelComposition([{ + "task": "invalid task!", + "architecture": "easy_ocr", + "version": "123456", + "interface": "online_learning", + "useCase": "test_workspace" + }]) + + +def test_model_composition_validation_with_invalid_version(): + with pytest.raises(ValueError): + ModelComposition([{ + "task": "e2e_ocr", + "architecture": "easy_ocr", + "version": "invalid", + "interface": "online_learning", + "useCase": "test_workspace" + }]) + + +def test_model_composition_with_empty_tags_and_use_case(): + model_def = ModelDefinition(**{**MODEL_DEFINITIONS["e2e_ocr"].model_dump(exclude={"tags", "use_case"}), "tags": []}) + composition = ModelComposition([model_def]) + assert composition.root[0].model_name == "e2e_ocr-easy_ocr-online_learning-00000000_0000_0000_0000_000000000000-1" diff --git a/tests/unit/unit/test_image_utils.py b/tests/unit/utils/test_image_utils.py similarity index 100% rename from tests/unit/unit/test_image_utils.py rename to tests/unit/utils/test_image_utils.py From b4a584ace97ce1dbcbb1622ffe3c45145c9c4a81 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 17 Apr 2025 09:17:25 +0200 Subject: [PATCH 03/14] refactor!(locators): rename `Class` to `Element` --- src/askui/locators/__init__.py | 4 +-- src/askui/locators/locators.py | 8 ++--- src/askui/locators/serializers.py | 10 +++---- tests/e2e/agent/test_locate.py | 6 ++-- tests/e2e/agent/test_locate_with_relations.py | 30 +++++++++---------- .../test_askui_locator_serializer.py | 4 +-- .../test_locator_string_representation.py | 16 +++++----- .../test_vlm_locator_serializer.py | 6 ++-- tests/unit/locators/test_locators.py | 16 +++++----- 9 files changed, 50 insertions(+), 50 deletions(-) diff --git a/src/askui/locators/__init__.py b/src/askui/locators/__init__.py index b830a0e1..d98f9484 100644 --- 
a/src/askui/locators/__init__.py +++ b/src/askui/locators/__init__.py @@ -1,8 +1,8 @@ -from askui.locators.locators import AiElement, Class, Description, Image, Text +from askui.locators.locators import AiElement, Element, Description, Image, Text __all__ = [ "AiElement", - "Class", + "Element", "Description", "Image", "Text", diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py index 0eb63c5e..3ec306ae 100644 --- a/src/askui/locators/locators.py +++ b/src/askui/locators/locators.py @@ -32,7 +32,7 @@ def __str__(self) -> str: return self._str_with_relation() -class Class(Locator): +class Element(Locator): """Locator for finding ui elements by a class name assigned to the ui element, e.g., by a computer vision model.""" class_name: Literal["text", "textfield"] | None = None @@ -47,7 +47,7 @@ def _str_with_relation(self) -> str: result = ( f'element with class "{self.class_name}"' if self.class_name - else "element that has a class" + else "element" ) return result + super()._relations_str() @@ -57,11 +57,11 @@ def __str__(self) -> str: TextMatchType = Literal["similar", "exact", "contains", "regex"] -DEFAULT_TEXT_MATCH_TYPE = "similar" +DEFAULT_TEXT_MATCH_TYPE: TextMatchType = "similar" DEFAULT_SIMILARITY_THRESHOLD = 70 -class Text(Class): +class Text(Element): """Locator for finding text elements by their content.""" text: str | None = None match_type: TextMatchType = DEFAULT_TEXT_MATCH_TYPE diff --git a/src/askui/locators/serializers.py b/src/askui/locators/serializers.py index c5b1bf58..5140784e 100644 --- a/src/askui/locators/serializers.py +++ b/src/askui/locators/serializers.py @@ -7,7 +7,7 @@ DEFAULT_TEXT_MATCH_TYPE, ImageMetadata, AiElement as AiElementLocator, - Class, + Element, Description, Image, Text, @@ -33,7 +33,7 @@ def serialize(self, locator: Relatable) -> str: if isinstance(locator, Text): return self._serialize_text(locator) - elif isinstance(locator, Class): + elif isinstance(locator, Element): return 
self._serialize_class(locator) elif isinstance(locator, Description): return self._serialize_description(locator) @@ -44,7 +44,7 @@ def serialize(self, locator: Relatable) -> str: else: raise ValueError(f"Unsupported locator type: {type(locator)}") - def _serialize_class(self, class_: Class) -> str: + def _serialize_class(self, class_: Element) -> str: if class_.class_name: return f"an arbitrary {class_.class_name} shown" else: @@ -108,7 +108,7 @@ def serialize(self, locator: Relatable) -> AskUiSerializedLocator: result = AskUiSerializedLocator(instruction="", customElements=[]) if isinstance(locator, Text): result["instruction"] = self._serialize_text(locator) - elif isinstance(locator, Class): + elif isinstance(locator, Element): result["instruction"] = self._serialize_class(locator) elif isinstance(locator, Description): result["instruction"] = self._serialize_description(locator) @@ -130,7 +130,7 @@ def serialize(self, locator: Relatable) -> AskUiSerializedLocator: result["customElements"] += serialized_relation["customElements"] return result - def _serialize_class(self, class_: Class) -> str: + def _serialize_class(self, class_: Element) -> str: return class_.class_name or "element" def _serialize_description(self, description: Description) -> str: diff --git a/tests/e2e/agent/test_locate.py b/tests/e2e/agent/test_locate.py index af061519..2edefc6a 100644 --- a/tests/e2e/agent/test_locate.py +++ b/tests/e2e/agent/test_locate.py @@ -7,7 +7,7 @@ from askui.agent import VisionAgent from askui.locators import ( Description, - Class, + Element, Text, AiElement, ) @@ -48,7 +48,7 @@ def test_locate_with_textfield_class_locator( model: str, ) -> None: """Test locating elements using a class locator.""" - locator = Class("textfield") + locator = Element("textfield") x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -62,7 +62,7 @@ def test_locate_with_unspecified_class_locator( model: str, ) -> None: """Test locating elements using a class 
locator.""" - locator = Class() + locator = Element() x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) diff --git a/tests/e2e/agent/test_locate_with_relations.py b/tests/e2e/agent/test_locate_with_relations.py index dabbba13..98305cc1 100644 --- a/tests/e2e/agent/test_locate_with_relations.py +++ b/tests/e2e/agent/test_locate_with_relations.py @@ -8,7 +8,7 @@ from askui.agent import VisionAgent from askui.locators import ( Description, - Class, + Element, Text, Image, ) @@ -30,7 +30,7 @@ def test_locate_with_above_relation( model: str, ) -> None: """Test locating elements using above_of relation.""" - locator = Text("Forgot password?").above_of(Class("textfield")) + locator = Text("Forgot password?").above_of(Element("textfield")) x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -44,7 +44,7 @@ def test_locate_with_below_relation( model: str, ) -> None: """Test locating elements using below_of relation.""" - locator = Text("Forgot password?").below_of(Class("textfield")) + locator = Text("Forgot password?").below_of(Element("textfield")) x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -86,7 +86,7 @@ def test_locate_with_containing_relation( model: str, ) -> None: """Test locating elements using containing relation.""" - locator = Class("textfield").containing(Text("github.com/login")) + locator = Element("textfield").containing(Text("github.com/login")) x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -100,7 +100,7 @@ def test_locate_with_inside_relation( model: str, ) -> None: """Test locating elements using inside_of relation.""" - locator = Text("github.com/login").inside_of(Class("textfield")) + locator = Text("github.com/login").inside_of(Element("textfield")) x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -114,7 +114,7 @@ def test_locate_with_nearest_to_relation( model: str, ) -> None: """Test locating elements 
using nearest_to relation.""" - locator = Class("textfield").nearest_to(Text("Password")) + locator = Element("textfield").nearest_to(Text("Password")) x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -129,7 +129,7 @@ def test_locate_with_and_relation( model: str, ) -> None: """Test locating elements using and_ relation.""" - locator = Text("Forgot password?").and_(Class("text")) + locator = Text("Forgot password?").and_(Element("text")) x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -143,7 +143,7 @@ def test_locate_with_or_relation( model: str, ) -> None: """Test locating elements using or_ relation.""" - locator = Class("textfield").nearest_to( + locator = Element("textfield").nearest_to( Text("Password").or_(Text("Username or email address")) ) x, y = vision_agent.locate( @@ -159,7 +159,7 @@ def test_locate_with_relation_index( model: str, ) -> None: """Test locating elements using relation with index.""" - locator = Class("textfield").below_of( + locator = Element("textfield").below_of( Text("Username or email address"), index=0 ) x, y = vision_agent.locate( @@ -175,7 +175,7 @@ def test_locate_with_relation_index_greater_0( model: str, ) -> None: """Test locating elements using relation with index.""" - locator = Class("textfield").below_of(Class("textfield"), index=1) + locator = Element("textfield").below_of(Element("textfield"), index=1) x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -274,8 +274,8 @@ def test_locate_with_multiple_relations_with_same_locator_raises( """Test locating elements using multiple relations with same locator which is not supported by AskUI.""" locator = ( Text("Forgot password?") - .below_of(Class("textfield")) - .below_of(Class("textfield")) + .below_of(Element("textfield")) + .below_of(Element("textfield")) ) with pytest.raises(NotImplementedError): vision_agent.locate(locator, github_login_screenshot, model=model) @@ -305,7 +305,7 @@ def 
test_locate_with_relation_different_locator_types( ) -> None: """Test locating elements using relation with different locator types.""" locator = Text("Sign in").below_of( - Class("textfield").below_of(Text("Username or email address")), + Element("textfield").below_of(Text("Username or email address")), reference_point="center", ) x, y = vision_agent.locate( @@ -337,7 +337,7 @@ def test_locate_with_description_and_complex_relation( ) -> None: """Test locating elements using description with relation.""" locator = Description("Sign in button").below_of( - Class("textfield").below_of(Text("Password")) + Element("textfield").below_of(Text("Password")) ) x, y = vision_agent.locate( locator, github_login_screenshot, model=model @@ -393,7 +393,7 @@ def test_locate_with_image_and_complex_relation( image_path = path_fixtures / "images" / "github_com__signin__button.png" image = PILImage.open(image_path) locator = Image(image=image).below_of( - Class("textfield").below_of(Text("Password")) + Element("textfield").below_of(Text("Password")) ) x, y = vision_agent.locate( locator, github_login_screenshot, model=model diff --git a/tests/unit/locators/serializers/test_askui_locator_serializer.py b/tests/unit/locators/serializers/test_askui_locator_serializer.py index a2c58ad9..67840e9d 100644 --- a/tests/unit/locators/serializers/test_askui_locator_serializer.py +++ b/tests/unit/locators/serializers/test_askui_locator_serializer.py @@ -7,7 +7,7 @@ from pytest_mock import MockerFixture from askui.locators.locators import Locator -from askui.locators import Class, Description, Text, Image +from askui.locators import Element, Description, Text, Image from askui.locators.relatable import RelationBase from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import AiElementCollection @@ -62,7 +62,7 @@ def test_serialize_text_regex(askui_serializer: AskUiLocatorSerializer) -> None: def test_serialize_class_no_name(askui_serializer: 
AskUiLocatorSerializer) -> None: - class_ = Class() + class_ = Element() result = askui_serializer.serialize(class_) assert result["instruction"] == "element" assert result["customElements"] == [] diff --git a/tests/unit/locators/serializers/test_locator_string_representation.py b/tests/unit/locators/serializers/test_locator_string_representation.py index 2271f446..6bc026f2 100644 --- a/tests/unit/locators/serializers/test_locator_string_representation.py +++ b/tests/unit/locators/serializers/test_locator_string_representation.py @@ -1,6 +1,6 @@ import re import pytest -from askui.locators import Class, Description, Text, Image +from askui.locators import Element, Description, Text, Image from askui.locators.relatable import CircularDependencyError from PIL import Image as PILImage @@ -29,13 +29,13 @@ def test_text_regex_str() -> None: def test_class_with_name_str() -> None: - class_ = Class("textfield") + class_ = Element("textfield") assert str(class_) == 'element with class "textfield"' def test_class_without_name_str() -> None: - class_ = Class() - assert str(class_) == "element that has a class" + class_ = Element() + assert str(class_) == "element" def test_description_str() -> None: @@ -145,7 +145,7 @@ def test_text_with_chained_relations_str() -> None: def test_mixed_locator_types_with_relations_str() -> None: text = Text("hello") - text.above_of(Class("textfield")) + text.above_of(Element("textfield")) assert ( str(text) == 'text similar to "hello" (similarity >= 70%)\n 1. 
above of boundary of the 1st element with class "textfield"' @@ -164,12 +164,12 @@ def test_description_with_relation_str() -> None: def test_complex_relation_chain_str() -> None: text = Text("hello") text.above_of( - Class("textfield") + Element("textfield") .right_of(Text("world", match_type="exact")) .and_( Description("input") .below_of(Text("earth", match_type="contains")) - .nearest_to(Class("textfield")) + .nearest_to(Element("textfield")) ) ) assert ( @@ -228,7 +228,7 @@ def test_deep_cycle_str() -> None: def test_multiple_references_no_cycle_str() -> None: heading = Text("heading") - textfield = Class("textfield") + textfield = Element("textfield") textfield.right_of(heading) textfield.below_of(heading) assert str(textfield) == 'element with class "textfield"\n 1. right of boundary of the 1st text similar to "heading" (similarity >= 70%)\n 2. below of boundary of the 1st text similar to "heading" (similarity >= 70%)' diff --git a/tests/unit/locators/serializers/test_vlm_locator_serializer.py b/tests/unit/locators/serializers/test_vlm_locator_serializer.py index 05b07013..00ec5425 100644 --- a/tests/unit/locators/serializers/test_vlm_locator_serializer.py +++ b/tests/unit/locators/serializers/test_vlm_locator_serializer.py @@ -1,6 +1,6 @@ import pytest from askui.locators.locators import Locator -from askui.locators import Class, Description, Text +from askui.locators import Element, Description, Text from askui.locators.locators import Image from askui.locators.relatable import CircularDependencyError from askui.locators.serializers import VlmLocatorSerializer @@ -41,13 +41,13 @@ def test_serialize_text_regex(vlm_serializer: VlmLocatorSerializer) -> None: def test_serialize_class(vlm_serializer: VlmLocatorSerializer) -> None: - class_ = Class("textfield") + class_ = Element("textfield") result = vlm_serializer.serialize(class_) assert result == "an arbitrary textfield shown" def test_serialize_class_no_name(vlm_serializer: VlmLocatorSerializer) -> None: - 
class_ = Class() + class_ = Element() result = vlm_serializer.serialize(class_) assert result == "an arbitrary ui element (e.g., text, button, textfield, etc.)" diff --git a/tests/unit/locators/test_locators.py b/tests/unit/locators/test_locators.py index 3d6d7378..2f9d2847 100644 --- a/tests/unit/locators/test_locators.py +++ b/tests/unit/locators/test_locators.py @@ -3,7 +3,7 @@ import pytest from PIL import Image as PILImage -from askui.locators import Description, Class, Text, Image, AiElement +from askui.locators import Description, Element, Text, Image, AiElement TEST_IMAGE_PATH = Path("tests/fixtures/images/github_com__icon.png") @@ -33,28 +33,28 @@ def test_initialization_with_invalid_args_raises(self) -> None: class TestClassLocator: def test_initialization_with_class_name(self) -> None: - cls = Class(class_name="text") + cls = Element(class_name="text") assert cls.class_name == "text" assert str(cls) == 'element with class "text"' def test_initialization_without_class_name(self) -> None: - cls = Class() + cls = Element() assert cls.class_name is None - assert str(cls) == "element that has a class" + assert str(cls) == "element" def test_initialization_with_positional_arg(self) -> None: - cls = Class("text") + cls = Element("text") assert cls.class_name == "text" def test_initialization_with_invalid_args_raises(self) -> None: with pytest.raises(ValueError): - Class(class_name="button") # type: ignore + Element(class_name="button") # type: ignore with pytest.raises(ValueError): - Class(class_name=123) # type: ignore + Element(class_name=123) # type: ignore with pytest.raises(ValueError): - Class(123) # type: ignore + Element(123) # type: ignore class TestTextLocator: From 0e3238c0942e25d450ff614ae94b049a159ceacd Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 17 Apr 2025 11:08:40 +0200 Subject: [PATCH 04/14] feat(agent): support primitive types as response_schema in get method --- src/askui/__init__.py | 5 +- src/askui/agent.py | 32 +++++++-- 
src/askui/models/askui/api.py | 24 ++++--- src/askui/models/router.py | 14 ++-- src/askui/models/types.py | 9 --- src/askui/models/types/__init__.py | 0 src/askui/models/types/response_schemas.py | 43 +++++++++++ tests/e2e/agent/test_get.py | 83 +++++++++++++++++++--- 8 files changed, 164 insertions(+), 46 deletions(-) delete mode 100644 src/askui/models/types.py create mode 100644 src/askui/models/types/__init__.py create mode 100644 src/askui/models/types/response_schemas.py diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 79633296..2f71c341 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -3,7 +3,7 @@ __version__ = "0.2.4" from .agent import VisionAgent -from .models.types import JsonSchemaBase +from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .tools.toolbox import AgentToolbox from .tools.agent_os import AgentOs, ModifierKey, PcKey @@ -11,8 +11,9 @@ __all__ = [ "AgentOs", "AgentToolbox", - "JsonSchemaBase", "ModifierKey", "PcKey", + "ResponseSchema", + "ResponseSchemaBase", "VisionAgent", ] diff --git a/src/askui/agent.py b/src/askui/agent.py index 03d174a8..775d0e10 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -1,6 +1,6 @@ import logging import subprocess -from typing import Annotated, Literal, Optional, Type +from typing import Annotated, Literal, Optional, Type, overload from pydantic import Field, validate_call from askui.container import telemetry @@ -22,7 +22,7 @@ import time from dotenv import load_dotenv from PIL import Image -from .models.types import JsonSchema +from .models.types.response_schemas import ResponseSchema class InvalidParameterError(Exception): @@ -216,14 +216,32 @@ def type(self, text: str) -> None: logger.debug("VisionAgent received instruction to type '%s'", text) self.tools.os.type(text) # type: ignore + + @overload + def get( + self, + query: str, + response_schema: None = None, + image: Optional[ImageSource] = None, + model: ModelComposition | str | 
None = None, + ) -> str: ... + + @overload + def get( + self, + query: str, + response_schema: Type[ResponseSchema], + image: Optional[ImageSource] = None, + model: ModelComposition | str | None = None, + ) -> ResponseSchema: ... + + @telemetry.record_call(exclude={"query", "image", "response_schema"}) def get( self, query: str, image: Optional[ImageSource] = None, - response_schema: Type[JsonSchema] | None = None, + response_schema: Type[ResponseSchema] | None = None, model: ModelComposition | str | None = None, - ) -> JsonSchema | str: + ) -> ResponseSchema | str: """ Retrieves information from an image (defaults to a screenshot of the current screen) based on the provided query. @@ -232,14 +250,14 @@ def get( The query describing what information to retrieve. image (ImageSource | None): The image to extract information from. Optional. Defaults to a screenshot of the current screen. - response_schema (type[ResponseSchema] | None): + response_schema (Type[ResponseSchema] | None): A Pydantic model class that defines the response schema. Optional. If not provided, returns a string. model (ModelComposition | str | None): The composition or name of the model(s) to be used for retrieving information from the screen or image using the `query`. Note: `response_schema` is only supported with not supported by all models. Returns: - ResponseSchema | str: The extracted information, either as a Pydantic model instance or a string. + ResponseSchema: The extracted information, either as an instance of ResponseSchemaBase or the primitive type passed, or a string if no response_schema is provided. 
Limitations: - Nested Pydantic schemas are not currently supported @@ -275,7 +293,7 @@ class UrlResponse(JsonSchemaBase): response_schema=response_schema, ) if self._reporter is not None: - message_content = response if isinstance(response, str) else response.model_dump() + message_content = str(response) if isinstance(response, (str, bool, int, float)) else response.model_dump() self._reporter.add_message("Agent", message_content) return response diff --git a/src/askui/models/askui/api.py b/src/askui/models/askui/api.py index a44a5a48..fada2dc4 100644 --- a/src/askui/models/askui/api.py +++ b/src/askui/models/askui/api.py @@ -1,6 +1,7 @@ import os import base64 import pathlib +from pydantic import RootModel import requests import json as json_lib from PIL import Image @@ -11,7 +12,7 @@ from askui.locators.locators import Locator from askui.utils.image_utils import image_to_base64 from askui.logger import logger -from ..types import JsonSchema +from ..types.response_schemas import ResponseSchema, to_response_schema @@ -74,19 +75,20 @@ def get_inference( self, image: ImageSource, query: str, - response_schema: Type[JsonSchema] | None = None - ) -> JsonSchema | str: + response_schema: Type[ResponseSchema] | None = None + ) -> ResponseSchema | str: json: dict[str, Any] = { "image": image.to_data_url(), "prompt": query, } - if response_schema is not None: - json["config"] = { - "json_schema": response_schema.model_json_schema() - } - logger.debug(f"json_schema:\n{json_lib.dumps(json['config']['json_schema'])}") + _response_schema = to_response_schema(response_schema) + json["config"] = { + "json_schema": _response_schema.model_json_schema() + } + logger.debug(f"json_schema:\n{json_lib.dumps(json['config']['json_schema'])}") content = self._request(endpoint="vqa/inference", json=json) response = content["data"]["response"] - if response_schema is not None: - return response_schema.model_validate(response) - return response + validated_response = 
_response_schema.model_validate(response) + if isinstance(validated_response, RootModel): + return validated_response.root + return validated_response diff --git a/src/askui/models/router.py b/src/askui/models/router.py index 42756d84..e0f4d092 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -9,7 +9,7 @@ from askui.locators.locators import Locator from askui.models.askui.ai_element_utils import AiElementCollection from askui.models.models import ModelComposition, ModelName -from askui.models.types import JsonSchema +from askui.models.types.response_schemas import ResponseSchema from askui.reporting import Reporter from askui.utils.image_utils import ImageSource from .askui.api import AskUiInferenceApi @@ -141,18 +141,18 @@ def get_inference( self, query: str, image: ImageSource, - response_schema: Type[JsonSchema] | None = None, + response_schema: Type[ResponseSchema] | None = None, model: ModelComposition | str | None = None, - ) -> JsonSchema | str: + ) -> ResponseSchema | str: if self.tars.authenticated and model == ModelName.TARS: - if response_schema is not None: - raise NotImplementedError("Response schema is not yet supported for UI-TARS models.") + if response_schema not in [str, None]: + raise NotImplementedError("(Non-String) Response schema is not yet supported for UI-TARS models.") return self.tars.get_inference(image=image, query=query) if self.claude.authenticated and ( isinstance(model, str) and model.startswith(ModelName.ANTHROPIC) ): - if response_schema is not None: - raise NotImplementedError("Response schema is not yet supported for Anthropic models.") + if response_schema not in [str, None]: + raise NotImplementedError("(Non-String) Response schema is not yet supported for Anthropic models.") return self.claude.get_inference(image=image, query=query) if self.askui.authenticated and (model == ModelName.ASKUI or model is None): return self.askui.get_inference( diff --git a/src/askui/models/types.py 
b/src/askui/models/types.py deleted file mode 100644 index 82a6b929..00000000 --- a/src/askui/models/types.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import TypeVar -from pydantic import BaseModel, ConfigDict - - -class JsonSchemaBase(BaseModel): - model_config = ConfigDict(extra="forbid") - - -JsonSchema = TypeVar('JsonSchema', bound=JsonSchemaBase) diff --git a/src/askui/models/types/__init__.py b/src/askui/models/types/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/models/types/response_schemas.py b/src/askui/models/types/response_schemas.py new file mode 100644 index 00000000..e9eba25c --- /dev/null +++ b/src/askui/models/types/response_schemas.py @@ -0,0 +1,43 @@ +from typing import Type, TypeVar, overload +from pydantic import BaseModel, ConfigDict, RootModel + + +class ResponseSchemaBase(BaseModel): + model_config = ConfigDict(extra="forbid") + + +String = RootModel[str] +Boolean = RootModel[bool] +Integer = RootModel[int] +Float = RootModel[float] + + +ResponseSchema = TypeVar('ResponseSchema', ResponseSchemaBase, str, bool, int, float) + + +@overload +def to_response_schema(response_schema: None) -> Type[String]: ... +@overload +def to_response_schema(response_schema: Type[str]) -> Type[String]: ... +@overload +def to_response_schema(response_schema: Type[bool]) -> Type[Boolean]: ... +@overload +def to_response_schema(response_schema: Type[int]) -> Type[Integer]: ... +@overload +def to_response_schema(response_schema: Type[float]) -> Type[Float]: ... +@overload +def to_response_schema(response_schema: Type[ResponseSchemaBase]) -> Type[ResponseSchemaBase]: ... 
+def to_response_schema(response_schema: Type[ResponseSchemaBase] | Type[str] | Type[bool] | Type[int] | Type[float] | None = None) -> Type[ResponseSchemaBase] | Type[String] | Type[Boolean] | Type[Integer] | Type[Float]: + if response_schema is None: + return String + if response_schema is str: + return String + if response_schema is bool: + return Boolean + if response_schema is int: + return Integer + if response_schema is float: + return Float + if issubclass(response_schema, ResponseSchemaBase): + return response_schema + raise ValueError(f"Invalid response schema type: {response_schema}") diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index ca9940c5..17391994 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -4,10 +4,10 @@ from askui.models import ModelName from askui import VisionAgent from askui.utils.image_utils import ImageSource -from askui import JsonSchemaBase +from askui.response_schemas import ResponseSchemaBase -class UrlResponse(JsonSchemaBase): +class UrlResponse(ResponseSchemaBase): url: str @@ -15,7 +15,7 @@ class PageContextResponse(UrlResponse): title: str -class BrowserContextResponse(JsonSchemaBase): +class BrowserContextResponse(ResponseSchemaBase): page_context: PageContextResponse browser_type: Literal["chrome", "firefox", "edge", "safari"] @@ -28,7 +28,7 @@ def test_get( ) -> None: url = vision_agent.get( "What is the current url shown in the url bar?", - ImageSource(github_login_screenshot), + image=ImageSource(github_login_screenshot), model=model, ) assert url == "github.com/login" @@ -42,7 +42,7 @@ def test_get_with_response_schema_without_additional_properties_with_askui_model with pytest.raises(Exception): vision_agent.get( "What is the current url shown in the url bar?", - ImageSource(github_login_screenshot), + image=ImageSource(github_login_screenshot), response_schema=UrlResponse, model=ModelName.ASKUI, ) @@ -56,7 +56,7 @@ def 
test_get_with_response_schema_without_required_with_askui_model_raises( with pytest.raises(Exception): vision_agent.get( "What is the current url shown in the url bar?", - ImageSource(github_login_screenshot), + image=ImageSource(github_login_screenshot), response_schema=UrlResponse, model=ModelName.ASKUI, ) @@ -70,7 +70,7 @@ def test_get_with_response_schema( ) -> None: response = vision_agent.get( "What is the current url shown in the url bar?", - ImageSource(github_login_screenshot), + image=ImageSource(github_login_screenshot), response_schema=UrlResponse, model=model, ) @@ -85,13 +85,13 @@ def test_get_with_response_schema_with_anthropic_model_raises_not_implemented( with pytest.raises(NotImplementedError): vision_agent.get( "What is the current url shown in the url bar?", - ImageSource(github_login_screenshot), + image=ImageSource(github_login_screenshot), response_schema=UrlResponse, model=ModelName.ANTHROPIC, ) -@pytest.mark.parametrize("model", [None, ModelName.ASKUI]) +@pytest.mark.parametrize("model", [ModelName.ASKUI]) @pytest.mark.skip("Skip as there is currently a bug on the api side not supporting definitions used for nested schemas") def test_get_with_nested_and_inherited_response_schema( vision_agent: VisionAgent, @@ -100,7 +100,7 @@ def test_get_with_nested_and_inherited_response_schema( ) -> None: response = vision_agent.get( "What is the current browser context?", - ImageSource(github_login_screenshot), + image=ImageSource(github_login_screenshot), response_schema=BrowserContextResponse, model=model, ) @@ -108,3 +108,66 @@ def test_get_with_nested_and_inherited_response_schema( assert response.page_context.url in ["https://github.com/login", "github.com/login"] assert "Github" in response.page_context.title assert response.browser_type in ["chrome", "firefox", "edge", "safari"] + + +@pytest.mark.parametrize("model", [ModelName.ASKUI]) +def test_get_with_string_schema( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + 
model: str, +) -> None: + response = vision_agent.get( + "What is the current url shown in the url bar?", + image=ImageSource(github_login_screenshot), + response_schema=str, + model=model, + ) + assert response in ["https://github.com/login", "github.com/login"] + + +@pytest.mark.parametrize("model", [ModelName.ASKUI]) +def test_get_with_boolean_schema( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: str, +) -> None: + response = vision_agent.get( + "Is this a login page?", + image=ImageSource(github_login_screenshot), + response_schema=bool, + model=model, + ) + assert isinstance(response, bool) + assert response is True + + +@pytest.mark.parametrize("model", [ModelName.ASKUI]) +def test_get_with_integer_schema( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: str, +) -> None: + response = vision_agent.get( + "How many input fields are visible on this page?", + image=ImageSource(github_login_screenshot), + response_schema=int, + model=model, + ) + assert isinstance(response, int) + assert response > 0 + + +@pytest.mark.parametrize("model", [ModelName.ASKUI]) +def test_get_with_float_schema( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: str, +) -> None: + response = vision_agent.get( + "Return a floating point number between 0 and 1 as a rating for how you well this page is designed (0 is the worst, 1 is the best)", + image=ImageSource(github_login_screenshot), + response_schema=float, + model=model, + ) + assert isinstance(response, float) + assert response > 0 From 469058c82ac3f9afd9f75c7d7cbf56cebace0426 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 17 Apr 2025 15:35:48 +0200 Subject: [PATCH 05/14] refactor: validate all public methods & make locators non-pydantic based - non-pydantic based locators are way easier to use because they have less methods (autocompletion) --> instead use validate_call & make properties read-only --- 
src/askui/__init__.py | 2 + src/askui/agent.py | 131 +++++++++++----- src/askui/locators/locators.py | 145 ++++++++++++------ src/askui/locators/relatable.py | 103 +++++++++---- src/askui/locators/serializers.py | 28 ++-- src/askui/logger.py | 2 +- src/askui/models/anthropic/claude.py | 3 +- src/askui/models/anthropic/claude_agent.py | 6 +- src/askui/models/askui/api.py | 1 + src/askui/models/router.py | 67 ++++---- src/askui/models/ui_tars_ep/ui_tars_api.py | 34 ++-- src/askui/reporting.py | 4 +- src/askui/tools/__init__.py | 3 + src/askui/tools/askui/__init__.py | 3 + src/askui/tools/toolbox.py | 4 +- tests/conftest.py | 2 +- tests/e2e/agent/conftest.py | 1 + tests/e2e/agent/test_get.py | 4 +- .../test_askui_locator_serializer.py | 17 -- tests/unit/locators/test_locators.py | 4 +- tests/unit/test_validate_call.py | 9 ++ 21 files changed, 357 insertions(+), 216 deletions(-) create mode 100644 tests/unit/test_validate_call.py diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 2f71c341..6cd6a904 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -3,6 +3,7 @@ __version__ = "0.2.4" from .agent import VisionAgent +from .models.router import ModelRouter from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .tools.toolbox import AgentToolbox from .tools.agent_os import AgentOs, ModifierKey, PcKey @@ -11,6 +12,7 @@ __all__ = [ "AgentOs", "AgentToolbox", + "ModelRouter", "ModifierKey", "PcKey", "ResponseSchema", diff --git a/src/askui/agent.py b/src/askui/agent.py index 775d0e10..78e41173 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -1,7 +1,7 @@ import logging import subprocess -from typing import Annotated, Literal, Optional, Type, overload -from pydantic import Field, validate_call +from typing import Annotated, Any, Literal, Optional, Type, overload +from pydantic import ConfigDict, Field, validate_call from askui.container import telemetry from askui.locators.locators import Locator @@ -13,7 +13,6 
@@ ModifierKey, PcKey, ) -from .models.anthropic.claude import ClaudeHandler from .logger import logger, configure_logging from .tools.toolbox import AgentToolbox from .models import ModelComposition @@ -60,10 +59,11 @@ class VisionAgent: ``` """ @telemetry.record_call(exclude={"model_router", "reporters", "tools"}) + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, - log_level=logging.INFO, - display: int = 1, + log_level: int | str = logging.INFO, + display: Annotated[int, Field(ge=1)] = 1, model_router: ModelRouter | None = None, reporters: list[Reporter] | None = None, tools: AgentToolbox | None = None, @@ -71,19 +71,23 @@ def __init__( ) -> None: load_dotenv() configure_logging(level=log_level) - self._reporter = CompositeReporter(reports=reporters or []) + self._reporter = CompositeReporter(reports=reporters) + self.tools = tools or AgentToolbox(agent_os=AskUiControllerClient(display=display, reporter=self._reporter)) self.model_router = ( - ModelRouter(log_level=log_level, reporter=self._reporter) - if model_router is None - else model_router + ModelRouter(tools=self.tools, reporter=self._reporter) if model_router is None else model_router ) - self.claude = ClaudeHandler(log_level=log_level) - self.tools = tools or AgentToolbox(os=AskUiControllerClient(display=display, reporter=self._reporter)) self._controller = AskUiControllerServer() self._model = model @telemetry.record_call(exclude={"locator"}) - def click(self, locator: Optional[str | Locator] = None, button: Literal['left', 'middle', 'right'] = 'left', repeat: int = 1, model: ModelComposition | str | None = None) -> None: + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def click( + self, + locator: Optional[str | Locator] = None, + button: Literal['left', 'middle', 'right'] = 'left', + repeat: Annotated[int, Field(gt=0)] = 1, + model: ModelComposition | str | None = None, + ) -> None: """ Simulates a mouse click on the user interface element 
identified by the provided locator. @@ -119,16 +123,22 @@ def click(self, locator: Optional[str | Locator] = None, button: Literal['left', if locator is not None: logger.debug("VisionAgent received instruction to click on %s", locator) self._mouse_move(locator, model or self._model) - self.tools.os.click(button, repeat) # type: ignore + self.tools.agent_os.click(button, repeat) # type: ignore def _locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: ModelComposition | str | None = None) -> Point: if screenshot is None: - screenshot = self.tools.os.screenshot() # type: ignore + screenshot = self.tools.agent_os.screenshot() # type: ignore point = self.model_router.locate(screenshot, locator, model or self._model) self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})") return point - def locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: ModelComposition | str | None = None) -> Point: + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def locate( + self, + locator: str | Locator, + screenshot: Optional[Image.Image] = None, + model: ModelComposition | str | None = None, + ) -> Point: """ Locates the UI element identified by the provided locator. 
@@ -146,10 +156,15 @@ def locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = Non def _mouse_move(self, locator: str | Locator, model: ModelComposition | str | None = None) -> None: point = self._locate(locator=locator, model=model or self._model) - self.tools.os.mouse(point[0], point[1]) # type: ignore + self.tools.agent_os.mouse(point[0], point[1]) # type: ignore @telemetry.record_call(exclude={"locator"}) - def mouse_move(self, locator: str | Locator, model: ModelComposition | str | None = None) -> None: + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def mouse_move( + self, + locator: str | Locator, + model: ModelComposition | str | None = None, + ) -> None: """ Moves the mouse cursor to the UI element identified by the provided locator. @@ -170,7 +185,12 @@ def mouse_move(self, locator: str | Locator, model: ModelComposition | str | Non self._mouse_move(locator, model or self._model) @telemetry.record_call() - def mouse_scroll(self, x: int, y: int) -> None: + @validate_call + def mouse_scroll( + self, + x: int, + y: int, + ) -> None: """ Simulates scrolling the mouse wheel by the specified horizontal and vertical amounts. @@ -194,10 +214,14 @@ def mouse_scroll(self, x: int, y: int) -> None: ``` """ self._reporter.add_message("User", f'mouse_scroll: "{x}", "{y}"') - self.tools.os.mouse_scroll(x, y) + self.tools.agent_os.mouse_scroll(x, y) @telemetry.record_call(exclude={"text"}) - def type(self, text: str) -> None: + @validate_call + def type( + self, + text: Annotated[str, Field(min_length=1)], + ) -> None: """ Types the specified text as if it were entered on a keyboard. 
@@ -214,13 +238,13 @@ def type(self, text: str) -> None: """ self._reporter.add_message("User", f'type: "{text}"') logger.debug("VisionAgent received instruction to type '%s'", text) - self.tools.os.type(text) # type: ignore + self.tools.agent_os.type(text) # type: ignore @overload def get( self, - query: str, + query: Annotated[str, Field(min_length=1)], response_schema: None = None, image: Optional[ImageSource] = None, model: ModelComposition | str | None = None, @@ -228,16 +252,17 @@ def get( @overload def get( self, - query: str, + query: Annotated[str, Field(min_length=1)], response_schema: Type[ResponseSchema], image: Optional[ImageSource] = None, model: ModelComposition | str | None = None, ) -> ResponseSchema: ... @telemetry.record_call(exclude={"query", "image", "response_schema"}) + @validate_call def get( self, - query: str, + query: Annotated[str, Field(min_length=1)], image: Optional[ImageSource] = None, response_schema: Type[ResponseSchema] | None = None, model: ModelComposition | str | None = None, @@ -285,7 +310,7 @@ class UrlResponse(JsonSchemaBase): self._reporter.add_message("User", f'get: "{query}"') logger.debug("VisionAgent received instruction to get '%s'", query) if image is None: - image = ImageSource(self.tools.os.screenshot()) # type: ignore + image = ImageSource(self.tools.agent_os.screenshot()) # type: ignore response = self.model_router.get_inference( image=image, query=query, @@ -299,12 +324,15 @@ class UrlResponse(JsonSchemaBase): @telemetry.record_call() @validate_call - def wait(self, sec: Annotated[float, Field(gt=0)]) -> None: + def wait( + self, + sec: Annotated[float, Field(gt=0.0)], + ) -> None: """ Pauses the execution of the program for the specified number of seconds. Parameters: - sec (float): The number of seconds to wait. Must be greater than 0. + sec (float): The number of seconds to wait. Must be greater than 0.0. Raises: ValueError: If the provided `sec` is negative. 
@@ -319,7 +347,11 @@ def wait(self, sec: Annotated[float, Field(gt=0)]) -> None: time.sleep(sec) @telemetry.record_call() - def key_up(self, key: PcKey | ModifierKey) -> None: + @validate_call + def key_up( + self, + key: PcKey | ModifierKey, + ) -> None: """ Simulates the release of a key. @@ -335,10 +367,14 @@ def key_up(self, key: PcKey | ModifierKey) -> None: """ self._reporter.add_message("User", f'key_up "{key}"') logger.debug("VisionAgent received in key_up '%s'", key) - self.tools.os.keyboard_release(key) + self.tools.agent_os.keyboard_release(key) @telemetry.record_call() - def key_down(self, key: PcKey | ModifierKey) -> None: + @validate_call + def key_down( + self, + key: PcKey | ModifierKey, + ) -> None: """ Simulates the pressing of a key. @@ -354,10 +390,15 @@ def key_down(self, key: PcKey | ModifierKey) -> None: """ self._reporter.add_message("User", f'key_down "{key}"') logger.debug("VisionAgent received in key_down '%s'", key) - self.tools.os.keyboard_pressed(key) + self.tools.agent_os.keyboard_pressed(key) @telemetry.record_call(exclude={"goal"}) - def act(self, goal: str, model: ModelComposition | str | None = None) -> None: + @validate_call + def act( + self, + goal: Annotated[str, Field(min_length=1)], + model: ModelComposition | str | None = None, + ) -> None: """ Instructs the agent to achieve a specified goal through autonomous actions. @@ -381,11 +422,14 @@ def act(self, goal: str, model: ModelComposition | str | None = None) -> None: logger.debug( "VisionAgent received instruction to act towards the goal '%s'", goal ) - self.model_router.act(self.tools.os, goal, model or self._model) + self.model_router.act(goal, model or self._model) @telemetry.record_call() + @validate_call def keyboard( - self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + self, + key: PcKey | ModifierKey, + modifier_keys: Optional[list[ModifierKey]] = None, ) -> None: """ Simulates pressing a key or key combination on the keyboard. 
@@ -406,10 +450,14 @@ def keyboard( ``` """ logger.debug("VisionAgent received instruction to press '%s'", key) - self.tools.os.keyboard_tap(key, modifier_keys) # type: ignore + self.tools.agent_os.keyboard_tap(key, modifier_keys) # type: ignore @telemetry.record_call(exclude={"command"}) - def cli(self, command: str) -> None: + @validate_call + def cli( + self, + command: Annotated[str, Field(min_length=1)], + ) -> None: """ Executes a command on the command line interface. @@ -432,7 +480,7 @@ def cli(self, command: str) -> None: @telemetry.record_call(flush=True) def close(self) -> None: - self.tools.os.disconnect() + self.tools.agent_os.disconnect() if self._controller: self._controller.stop(True) self._reporter.generate() @@ -440,7 +488,7 @@ def close(self) -> None: @telemetry.record_call() def open(self) -> None: self._controller.start(True) - self.tools.os.connect() + self.tools.agent_os.connect() @telemetry.record_call() def __enter__(self) -> "VisionAgent": @@ -448,5 +496,10 @@ def __enter__(self) -> "VisionAgent": return self @telemetry.record_call(exclude={"exc_value", "traceback"}) - def __exit__(self, exc_type, exc_value, traceback) -> None: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[Any], + ) -> None: self.close() diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py index 3ec306ae..93bbc04f 100644 --- a/src/askui/locators/locators.py +++ b/src/askui/locators/locators.py @@ -1,16 +1,16 @@ from abc import ABC import pathlib -from typing import Literal, Union +from typing import Annotated, Literal, Union import uuid from PIL import Image as PILImage -from pydantic import BaseModel, Field +from pydantic import ConfigDict, Field, validate_call from askui.utils.image_utils import ImageSource from askui.locators.relatable import Relatable -class Locator(Relatable, BaseModel, ABC): +class Locator(Relatable, ABC): """Base class for all locators.""" pass 
@@ -18,10 +18,14 @@ class Locator(Relatable, BaseModel, ABC): class Description(Locator): """Locator for finding ui elements by a textual description of the ui element.""" - description: str - - def __init__(self, description: str, **kwargs) -> None: - super().__init__(description=description, **kwargs) # type: ignore + @validate_call + def __init__(self, description: str) -> None: + super().__init__() + self._description = description + + @property + def description(self) -> str: + return self._description def _str_with_relation(self) -> str: result = f'element with description "{self.description}"' @@ -34,14 +38,17 @@ def __str__(self) -> str: class Element(Locator): """Locator for finding ui elements by a class name assigned to the ui element, e.g., by a computer vision model.""" - class_name: Literal["text", "textfield"] | None = None - + @validate_call def __init__( self, class_name: Literal["text", "textfield"] | None = None, - **kwargs, ) -> None: - super().__init__(class_name=class_name, **kwargs) # type: ignore + super().__init__() + self._class_name = class_name + + @property + def class_name(self) -> Literal["text", "textfield"] | None: + return self._class_name def _str_with_relation(self) -> str: result = ( @@ -63,23 +70,29 @@ def __str__(self) -> str: class Text(Element): """Locator for finding text elements by their content.""" - text: str | None = None - match_type: TextMatchType = DEFAULT_TEXT_MATCH_TYPE - similarity_threshold: int = Field(default=DEFAULT_SIMILARITY_THRESHOLD, ge=0, le=100) - + @validate_call def __init__( self, text: str | None = None, match_type: TextMatchType = DEFAULT_TEXT_MATCH_TYPE, - similarity_threshold: int = DEFAULT_SIMILARITY_THRESHOLD, - **kwargs, + similarity_threshold: Annotated[int, Field(ge=0, le=100)] = DEFAULT_SIMILARITY_THRESHOLD, ) -> None: - super().__init__( - text=text, - match_type=match_type, - similarity_threshold=similarity_threshold, - **kwargs, - ) # type: ignore + super().__init__() + self._text = text 
+ self._match_type = match_type + self._similarity_threshold = similarity_threshold + + @property + def text(self) -> str | None: + return self._text + + @property + def match_type(self) -> TextMatchType: + return self._match_type + + @property + def similarity_threshold(self) -> int: + return self._similarity_threshold def _str_with_relation(self) -> str: if self.text is None: @@ -102,44 +115,79 @@ def __str__(self) -> str: return self._str_with_relation() -class ImageMetadata(Locator): - threshold: float = Field(default=0.5, ge=0, le=1) - stop_threshold: float = Field(default=0.9, ge=0, le=1) - mask: list[tuple[float, float]] | None = Field(default=None, min_length=3) - rotation_degree_per_step: int = Field(default=0, ge=0, lt=360) - image_compare_format: Literal["RGB", "grayscale", "edges"] = "grayscale" - name: str +class ImageBase(Locator, ABC): + def __init__( + self, + threshold: float, + stop_threshold: float, + mask: list[tuple[float, float]] | None, + rotation_degree_per_step: int, + name: str, + image_compare_format: Literal["RGB", "grayscale", "edges"], + ) -> None: + super().__init__() + self._threshold = threshold + self._stop_threshold = stop_threshold + self._mask = mask + self._rotation_degree_per_step = rotation_degree_per_step + self._name = name + self._image_compare_format = image_compare_format + + @property + def threshold(self) -> float: + return self._threshold + + @property + def stop_threshold(self) -> float: + return self._stop_threshold + + @property + def mask(self) -> list[tuple[float, float]] | None: + return self._mask + + @property + def rotation_degree_per_step(self) -> int: + return self._rotation_degree_per_step + + @property + def name(self) -> str: + return self._name + + @property + def image_compare_format(self) -> Literal["RGB", "grayscale", "edges"]: + return self._image_compare_format def _generate_name() -> str: return f"anonymous custom element {uuid.uuid4()}" -class Image(ImageMetadata): +class Image(ImageBase): 
"""Locator for finding ui elements by an image.""" - image: ImageSource - + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, - image: Union[ImageSource, PILImage.Image, pathlib.Path, str], - threshold: float = 0.5, - stop_threshold: float = 0.9, - mask: list[tuple[float, float]] | None = None, - rotation_degree_per_step: int = 0, - image_compare_format: Literal["RGB", "grayscale", "edges"] = "grayscale", + image: Union[PILImage.Image, pathlib.Path, str], + threshold: Annotated[float, Field(ge=0, le=1)] = 0.5, + stop_threshold: Annotated[float, Field(ge=0, le=1)] = 0.9, + mask: Annotated[list[tuple[float, float]] | None, Field(min_length=3)] = None, + rotation_degree_per_step: Annotated[int, Field(ge=0, lt=360)] = 0, name: str | None = None, - **kwargs, + image_compare_format: Literal["RGB", "grayscale", "edges"] = "grayscale", ) -> None: super().__init__( - image=image, threshold=threshold, stop_threshold=stop_threshold, mask=mask, rotation_degree_per_step=rotation_degree_per_step, image_compare_format=image_compare_format, name=_generate_name() if name is None else name, - **kwargs, ) # type: ignore + self._image = ImageSource(image) + + @property + def image(self) -> ImageSource: + return self._image def _str_with_relation(self) -> str: result = f'element "{self.name}" located by image' @@ -150,17 +198,17 @@ def __str__(self) -> str: return self._str_with_relation() -class AiElement(ImageMetadata): +class AiElement(ImageBase): """Locator for finding ui elements by an image and other kinds data saved on the disk.""" + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, name: str, - threshold: float = 0.5, - stop_threshold: float = 0.9, - mask: list[tuple[float, float]] | None = None, - rotation_degree_per_step: int = 0, + threshold: Annotated[float, Field(ge=0, le=1)] = 0.5, + stop_threshold: Annotated[float, Field(ge=0, le=1)] = 0.9, + mask: Annotated[list[tuple[float, float]] | None, 
Field(min_length=3)] = None, + rotation_degree_per_step: Annotated[int, Field(ge=0, lt=360)] = 0, image_compare_format: Literal["RGB", "grayscale", "edges"] = "grayscale", - **kwargs, ) -> None: super().__init__( name=name, @@ -169,7 +217,6 @@ def __init__( mask=mask, rotation_degree_per_step=rotation_degree_per_step, image_compare_format=image_compare_format, - **kwargs, ) # type: ignore def _str_with_relation(self) -> str: diff --git a/src/askui/locators/relatable.py b/src/askui/locators/relatable.py index 6b77beae..69c0774a 100644 --- a/src/askui/locators/relatable.py +++ b/src/askui/locators/relatable.py @@ -1,7 +1,6 @@ from abc import ABC -from dataclasses import dataclass -from typing import Literal -from pydantic import BaseModel, Field +from typing import Annotated, Literal +from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Self @@ -21,19 +20,32 @@ } -@dataclass(kw_only=True) -class RelationBase(ABC): +RelationIndex = Annotated[int, Field(ge=0)] + + +class RelationBase(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) other_locator: "Relatable" - type: Literal["above_of", "below_of", "right_of", "left_of", "and", "or", "containing", "inside_of", "nearest_to"] + type: Literal[ + "above_of", + "below_of", + "right_of", + "left_of", + "and", + "or", + "containing", + "inside_of", + "nearest_to", + ] def __str__(self): return f"{RelationTypeMapping[self.type]} {self.other_locator._str_with_relation()}" -@dataclass(kw_only=True) + class NeighborRelation(RelationBase): type: Literal["above_of", "below_of", "right_of", "left_of"] - index: int + index: RelationIndex reference_point: ReferencePoint def __str__(self): @@ -41,21 +53,28 @@ def __str__(self): if i == 11 or i == 12 or i == 13: index_str = f"{i}th" else: - index_str = f"{i}st" if i % 10 == 1 else f"{i}nd" if i % 10 == 2 else f"{i}rd" if i % 10 == 3 else f"{i}th" - reference_point_str = " center of" if self.reference_point == "center" else " boundary of" if 
self.reference_point == "boundary" else "" + index_str = ( + f"{i}st" + if i % 10 == 1 + else f"{i}nd" if i % 10 == 2 else f"{i}rd" if i % 10 == 3 else f"{i}th" + ) + reference_point_str = ( + " center of" + if self.reference_point == "center" + else " boundary of" if self.reference_point == "boundary" else "" + ) return f"{RelationTypeMapping[self.type]}{reference_point_str} the {index_str} {self.other_locator._str_with_relation()}" -@dataclass(kw_only=True) + class LogicalRelation(RelationBase): type: Literal["and", "or"] -@dataclass(kw_only=True) + class BoundingRelation(RelationBase): type: Literal["containing", "inside_of"] -@dataclass(kw_only=True) class NearestToRelation(RelationBase): type: Literal["nearest_to"] @@ -65,6 +84,7 @@ class NearestToRelation(RelationBase): class CircularDependencyError(ValueError): """Exception raised for circular dependencies in locator relations.""" + def __init__( self, message: str = ( @@ -76,21 +96,28 @@ def __init__( super().__init__(message) -class Relatable(BaseModel, ABC): +class Relatable(ABC): """Base class for locators that can be related to other locators, e.g., spatially, logically, distance based etc. 
- + Attributes: relations: List of relations to other locators """ - relations: list[Relation] = Field(default_factory=list) + def __init__(self) -> None: + self._relations: list[Relation] = [] + + @property + def relations(self) -> list[Relation]: + return self._relations + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation def above_of( self, other_locator: "Relatable", - index: int = 0, + index: RelationIndex = 0, reference_point: ReferencePoint = "boundary", ) -> Self: - self.relations.append( + + self._relations.append( NeighborRelation( type="above_of", other_locator=other_locator, @@ -100,13 +127,14 @@ def above_of( ) return self + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation def below_of( self, other_locator: "Relatable", - index: int = 0, + index: RelationIndex = 0, reference_point: ReferencePoint = "boundary", ) -> Self: - self.relations.append( + self._relations.append( NeighborRelation( type="below_of", other_locator=other_locator, @@ -116,13 +144,14 @@ def below_of( ) return self + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation def right_of( self, other_locator: "Relatable", - index: int = 0, + index: RelationIndex = 0, reference_point: ReferencePoint = "boundary", ) -> Self: - self.relations.append( + self._relations.append( NeighborRelation( type="right_of", other_locator=other_locator, @@ -132,13 +161,14 @@ def right_of( ) return self + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NeighborRelation def left_of( self, other_locator: "Relatable", - index: int = 0, + index: RelationIndex = 0, reference_point: ReferencePoint = "boundary", ) -> Self: - self.relations.append( + self._relations.append( 
NeighborRelation( type="left_of", other_locator=other_locator, @@ -148,8 +178,9 @@ def left_of( ) return self + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using BoundingRelation def containing(self, other_locator: "Relatable") -> Self: - self.relations.append( + self._relations.append( BoundingRelation( type="containing", other_locator=other_locator, @@ -157,8 +188,9 @@ def containing(self, other_locator: "Relatable") -> Self: ) return self + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using BoundingRelation def inside_of(self, other_locator: "Relatable") -> Self: - self.relations.append( + self._relations.append( BoundingRelation( type="inside_of", other_locator=other_locator, @@ -166,8 +198,9 @@ def inside_of(self, other_locator: "Relatable") -> Self: ) return self + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NearestToRelation def nearest_to(self, other_locator: "Relatable") -> Self: - self.relations.append( + self._relations.append( NearestToRelation( type="nearest_to", other_locator=other_locator, @@ -175,8 +208,9 @@ def nearest_to(self, other_locator: "Relatable") -> Self: ) return self + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using LogicalRelation def and_(self, other_locator: "Relatable") -> Self: - self.relations.append( + self._relations.append( LogicalRelation( type="and", other_locator=other_locator, @@ -184,8 +218,9 @@ def and_(self, other_locator: "Relatable") -> Self: ) return self + # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using LogicalRelation def or_(self, other_locator: "Relatable") -> Self: - self.relations.append( + self._relations.append( LogicalRelation( type="or", 
other_locator=other_locator, @@ -194,21 +229,21 @@ def or_(self, other_locator: "Relatable") -> Self: return self def _relations_str(self) -> str: - if not self.relations: + if not self._relations: return "" - + result = [] - for i, relation in enumerate(self.relations): + for i, relation in enumerate(self._relations): [other_locator_str, *nested_relation_strs] = str(relation).split("\n") result.append(f" {i + 1}. {other_locator_str}") for nested_relation_str in nested_relation_strs: result.append(f" {nested_relation_str}") return "\n" + "\n".join(result) - + def raise_if_cycle(self) -> None: if self._has_cycle(): raise CircularDependencyError() - + def _has_cycle(self) -> bool: """Check if the relations form a cycle.""" visited_ids: set[int] = set() diff --git a/src/askui/locators/serializers.py b/src/askui/locators/serializers.py index 5140784e..bcef4e07 100644 --- a/src/askui/locators/serializers.py +++ b/src/askui/locators/serializers.py @@ -5,7 +5,7 @@ from .locators import ( DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TEXT_MATCH_TYPE, - ImageMetadata, + ImageBase, AiElement as AiElementLocator, Element, Description, @@ -114,7 +114,7 @@ def serialize(self, locator: Relatable) -> AskUiSerializedLocator: result["instruction"] = self._serialize_description(locator) elif isinstance(locator, Image): result = self._serialize_image( - image_metadata=locator, + image_locator=locator, image_sources=[locator.image], ) elif isinstance(locator, AiElementLocator): @@ -187,35 +187,35 @@ def _serialize_non_neighbor_relation( def _serialize_image_to_custom_element( self, - image_metadata: ImageMetadata, + image_locator: ImageBase, image_source: ImageSource, ) -> CustomElement: custom_element: CustomElement = CustomElement( customImage=image_source.to_data_url(), - threshold=image_metadata.threshold, - stopThreshold=image_metadata.stop_threshold, - rotationDegreePerStep=image_metadata.rotation_degree_per_step, - imageCompareFormat=image_metadata.image_compare_format, - 
name=image_metadata.name, + threshold=image_locator.threshold, + stopThreshold=image_locator.stop_threshold, + rotationDegreePerStep=image_locator.rotation_degree_per_step, + imageCompareFormat=image_locator.image_compare_format, + name=image_locator.name, ) - if image_metadata.mask: - custom_element["mask"] = image_metadata.mask + if image_locator.mask: + custom_element["mask"] = image_locator.mask return custom_element def _serialize_image( self, - image_metadata: ImageMetadata, + image_locator: ImageBase, image_sources: list[ImageSource], ) -> AskUiSerializedLocator: custom_elements: list[CustomElement] = [ self._serialize_image_to_custom_element( - image_metadata=image_metadata, + image_locator=image_locator, image_source=image_source, ) for image_source in image_sources ] return AskUiSerializedLocator( - instruction=f"custom element with text {self._TEXT_DELIMITER}{image_metadata.name}{self._TEXT_DELIMITER}", + instruction=f"custom element with text {self._TEXT_DELIMITER}{image_locator.name}{self._TEXT_DELIMITER}", customElements=custom_elements, ) @@ -228,6 +228,6 @@ def _serialize_ai_element( f"Could not find AI element with name \"{ai_element_locator.name}\"" ) return self._serialize_image( - image_metadata=ai_element_locator, + image_locator=ai_element_locator, image_sources=[ImageSource.model_construct(root=ai_element.image) for ai_element in ai_elements], ) diff --git a/src/askui/logger.py b/src/askui/logger.py index e6da1743..2038ecf9 100644 --- a/src/askui/logger.py +++ b/src/askui/logger.py @@ -11,7 +11,7 @@ logger.setLevel(logging.INFO) -def configure_logging(level=logging.INFO): +def configure_logging(level: str | int = logging.INFO): logger.setLevel(level) diff --git a/src/askui/models/anthropic/claude.py b/src/askui/models/anthropic/claude.py index 12f1cf14..8965a5e5 100644 --- a/src/askui/models/anthropic/claude.py +++ b/src/askui/models/anthropic/claude.py @@ -10,11 +10,10 @@ class ClaudeHandler: - def __init__(self, log_level): + def 
__init__(self): self.model = "claude-3-5-sonnet-20241022" self.client = anthropic.Anthropic() self.resolution = (1280, 800) - self.log_level = log_level self.authenticated = True if os.getenv("ANTHROPIC_API_KEY") is None: self.authenticated = False diff --git a/src/askui/models/anthropic/claude_agent.py b/src/askui/models/anthropic/claude_agent.py index c489a2dd..05599433 100644 --- a/src/askui/models/anthropic/claude_agent.py +++ b/src/askui/models/anthropic/claude_agent.py @@ -20,6 +20,8 @@ BetaToolUseBlockParam, ) +from askui.tools.agent_os import AgentOs + from ...tools.anthropic import ComputerTool, ToolCollection, ToolResult from ...logger import logger from ...utils.str_utils import truncate_long_strings @@ -60,10 +62,10 @@ class ClaudeComputerAgent: - def __init__(self, controller_client, reporter: Reporter) -> None: + def __init__(self, agent_os: AgentOs, reporter: Reporter) -> None: self._reporter = reporter self.tool_collection = ToolCollection( - ComputerTool(controller_client), + ComputerTool(agent_os), ) self.system = BetaTextBlockParam( type="text", diff --git a/src/askui/models/askui/api.py b/src/askui/models/askui/api.py index fada2dc4..cc39cc8a 100644 --- a/src/askui/models/askui/api.py +++ b/src/askui/models/askui/api.py @@ -53,6 +53,7 @@ def _request(self, endpoint: str, json: dict[str, Any] | None = None) -> Any: def predict(self, image: Union[pathlib.Path, Image.Image], locator: Locator, model: ModelComposition | None = None) -> tuple[int | None, int | None]: serialized_locator = self._locator_serializer.serialize(locator=locator) + logger.debug(f"serialized_locator:\n{json_lib.dumps(serialized_locator)}") json: dict[str, Any] = { "image": f",{image_to_base64(image)}", "instruction": f"Click on {serialized_locator['instruction']}", diff --git a/src/askui/models/router.py b/src/askui/models/router.py index e0f4d092..abefd5df 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -1,4 +1,3 @@ -import logging from typing 
import Type from typing_extensions import override from PIL import Image @@ -10,7 +9,9 @@ from askui.models.askui.ai_element_utils import AiElementCollection from askui.models.models import ModelComposition, ModelName from askui.models.types.response_schemas import ResponseSchema -from askui.reporting import Reporter +from askui.reporting import CompositeReporter, Reporter +from askui.tools.askui.askui_controller import AskUiControllerClient +from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource from .askui.api import AskUiInferenceApi from .anthropic.claude import ClaudeHandler @@ -113,28 +114,28 @@ def is_authenticated(self) -> bool: class ModelRouter: def __init__( self, - reporter: Reporter, - log_level: int = logging.INFO, + tools: AgentToolbox, grounding_model_routers: list[GroundingModelRouter] | None = None, + reporter: Reporter | None = None, ): - self._reporter = reporter - self.askui = AskUiInferenceApi( + _reporter = reporter or CompositeReporter() + self._askui = AskUiInferenceApi( locator_serializer=AskUiLocatorSerializer( ai_element_collection=AiElementCollection(), ), ) - self.grounding_model_routers = grounding_model_routers or [AskUiModelRouter(inference_api=self.askui)] - self.claude = ClaudeHandler(log_level) - self.huggingface_spaces = HFSpacesHandler() - self.tars = UITarsAPIHandler(self._reporter) + self._grounding_model_routers = grounding_model_routers or [AskUiModelRouter(inference_api=self._askui)] + self._claude = ClaudeHandler() + self._huggingface_spaces = HFSpacesHandler() + self._tars = UITarsAPIHandler(agent_os=tools.agent_os, reporter=_reporter) + self._claude_computer_agent = ClaudeComputerAgent(agent_os=tools.agent_os, reporter=_reporter) self._locator_serializer = VlmLocatorSerializer() - def act(self, controller_client, goal: str, model: ModelComposition | str | None = None): - if self.tars.authenticated and model == ModelName.TARS: - return self.tars.act(controller_client, goal) - if 
self.claude.authenticated and (model is None or isinstance(model, str) and model.startswith(ModelName.ANTHROPIC)): - agent = ClaudeComputerAgent(controller_client, self._reporter) - return agent.run(goal) + def act(self, goal: str, model: ModelComposition | str | None = None): + if self._tars.authenticated and model == ModelName.TARS: + return self._tars.act(goal) + if self._claude.authenticated and (model is None or isinstance(model, str) and model.startswith(ModelName.ANTHROPIC)): + return self._claude_computer_agent.run(goal) raise AutomationError(f"Invalid model for act: {model}") def get_inference( @@ -144,18 +145,18 @@ def get_inference( response_schema: Type[ResponseSchema] | None = None, model: ModelComposition | str | None = None, ) -> ResponseSchema | str: - if self.tars.authenticated and model == ModelName.TARS: + if self._tars.authenticated and model == ModelName.TARS: if response_schema not in [str, None]: raise NotImplementedError("(Non-String) Response schema is not yet supported for UI-TARS models.") - return self.tars.get_inference(image=image, query=query) - if self.claude.authenticated and ( + return self._tars.get_inference(image=image, query=query) + if self._claude.authenticated and ( isinstance(model, str) and model.startswith(ModelName.ANTHROPIC) ): if response_schema not in [str, None]: raise NotImplementedError("(Non-String) Response schema is not yet supported for Anthropic models.") - return self.claude.get_inference(image=image, query=query) - if self.askui.authenticated and (model == ModelName.ASKUI or model is None): - return self.askui.get_inference( + return self._claude.get_inference(image=image, query=query) + if self._askui.authenticated and (model == ModelName.ASKUI or model is None): + return self._askui.get_inference( image=image, query=query, response_schema=response_schema, @@ -178,39 +179,39 @@ def locate( ) -> Point: if ( isinstance(model, str) - and model in self.huggingface_spaces.get_spaces_names() + and model in 
self._huggingface_spaces.get_spaces_names() ): - x, y = self.huggingface_spaces.predict( + x, y = self._huggingface_spaces.predict( screenshot=screenshot, locator=self._serialize_locator(locator), model_name=model, ) return handle_response((x, y), locator) if isinstance(model, str): - if model.startswith(ModelName.ANTHROPIC) and not self.claude.authenticated: + if model.startswith(ModelName.ANTHROPIC) and not self._claude.authenticated: raise AutomationError( "You need to provide Anthropic credentials to use Anthropic models." ) - if model.startswith(ModelName.TARS) and not self.tars.authenticated: + if model.startswith(ModelName.TARS) and not self._tars.authenticated: raise AutomationError( "You need to provide UI-TARS HF Endpoint credentials to use UI-TARS models." ) - if self.tars.authenticated and model == ModelName.TARS: - x, y = self.tars.locate_prediction( + if self._tars.authenticated and model == ModelName.TARS: + x, y = self._tars.locate_prediction( screenshot, self._serialize_locator(locator) ) return handle_response((x, y), locator) if ( - self.claude.authenticated + self._claude.authenticated and isinstance(model, str) and model.startswith(ModelName.ANTHROPIC) ): logger.debug("Routing locate prediction to Anthropic") - x, y = self.claude.locate_inference( + x, y = self._claude.locate_inference( screenshot, self._serialize_locator(locator) ) return handle_response((x, y), locator) - for grounding_model_router in self.grounding_model_routers: + for grounding_model_router in self._grounding_model_routers: if ( grounding_model_router.is_responsible(model) and grounding_model_router.is_authenticated() @@ -218,9 +219,9 @@ def locate( return grounding_model_router.locate(screenshot, locator, model) if model is None: - if self.claude.authenticated: + if self._claude.authenticated: logger.debug("Routing locate prediction to Anthropic") - x, y = self.claude.locate_inference( + x, y = self._claude.locate_inference( screenshot, self._serialize_locator(locator) ) 
return handle_response((x, y), locator) diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index 663d9fc9..0bc97c96 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -4,6 +4,7 @@ from typing import Any, Union from openai import OpenAI from askui.reporting import Reporter +from askui.tools.agent_os import AgentOs from askui.utils.image_utils import image_to_base64 from PIL import Image @@ -14,7 +15,8 @@ class UITarsAPIHandler: - def __init__(self, reporter: Reporter): + def __init__(self, agent_os: AgentOs, reporter: Reporter): + self._agent_os = agent_os self._reporter = reporter if os.getenv("TARS_URL") is None or os.getenv("TARS_API_KEY") is None: self.authenticated = False @@ -83,8 +85,8 @@ def get_inference(self, image: ImageSource, query: str) -> str: prompt=PROMPT_QA, ) - def act(self, controller_client, goal: str) -> None: - screenshot = controller_client.screenshot() + def act(self, goal: str) -> None: + screenshot = self._agent_os.screenshot() self.act_history = [ { "role": "user", @@ -102,10 +104,10 @@ def act(self, controller_client, goal: str) -> None: ] } ] - self.execute_act(controller_client, self.act_history) + self.execute_act(self.act_history) - def add_screenshot_to_history(self, controller_client, message_history): - screenshot = controller_client.screenshot() + def add_screenshot_to_history(self, message_history): + screenshot = self._agent_os.screenshot() message_history.append( { "role": "user", @@ -159,7 +161,7 @@ def filter_message_thread(self, message_history, max_screenshots=3): return filtered_messages - def execute_act(self, controller_client, message_history): + def execute_act(self, message_history): message_history = self.filter_message_thread(message_history) chat_completion = self.client.chat.completions.create( @@ -195,21 +197,21 @@ def execute_act(self, controller_client, message_history): ] } ) - 
self.execute_act(controller_client, message_history) + self.execute_act(message_history) return action = message.parsed_action if action.action_type == "click": - controller_client.mouse(action.start_box.x, action.start_box.y) - controller_client.click("left") + self._agent_os.mouse(action.start_box.x, action.start_box.y) + self._agent_os.click("left") time.sleep(1) if action.action_type == "type": - controller_client.click("left") - controller_client.type(action.content) + self._agent_os.click("left") + self._agent_os.type(action.content) time.sleep(0.5) if action.action_type == "hotkey": - controller_client.keyboard_pressed(action.content) - controller_client.keyboard_release(action.content) + self._agent_os.keyboard_pressed(action.content) + self._agent_os.keyboard_release(action.content) time.sleep(0.5) if action.action_type == "call_user": time.sleep(1) @@ -218,5 +220,5 @@ def execute_act(self, controller_client, message_history): if action.action_type == "finished": return - self.add_screenshot_to_history(controller_client, message_history) - self.execute_act(controller_client, message_history) \ No newline at end of file + self.add_screenshot_to_history(message_history) + self.execute_act(message_history) diff --git a/src/askui/reporting.py b/src/askui/reporting.py index 65f21545..08973427 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -29,8 +29,8 @@ def generate(self) -> None: class CompositeReporter(Reporter): - def __init__(self, reports: list[Reporter]) -> None: - self._reports = reports + def __init__(self, reports: list[Reporter] | None = None) -> None: + self._reports = reports or [] @override def add_message( diff --git a/src/askui/tools/__init__.py b/src/askui/tools/__init__.py index e69de29b..e76623ba 100644 --- a/src/askui/tools/__init__.py +++ b/src/askui/tools/__init__.py @@ -0,0 +1,3 @@ +from .toolbox import AgentToolbox + +__all__ = ["AgentToolbox"] \ No newline at end of file diff --git a/src/askui/tools/askui/__init__.py 
b/src/askui/tools/askui/__init__.py index e69de29b..657f2f1f 100644 --- a/src/askui/tools/askui/__init__.py +++ b/src/askui/tools/askui/__init__.py @@ -0,0 +1,3 @@ +from .askui_controller import AskUiControllerClient + +__all__ = ["AskUiControllerClient"] diff --git a/src/askui/tools/toolbox.py b/src/askui/tools/toolbox.py index 5f5694d1..0affcec9 100644 --- a/src/askui/tools/toolbox.py +++ b/src/askui/tools/toolbox.py @@ -6,10 +6,10 @@ class AgentToolbox: - def __init__(self, os: AgentOs): + def __init__(self, agent_os: AgentOs): self.webbrowser = webbrowser self.clipboard: pyperclip = pyperclip - self.os = os + self.agent_os = agent_os self._hub = AskUIHub() self.httpx = httpx diff --git a/tests/conftest.py b/tests/conftest.py index dbadb991..ce33ac4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,7 +34,7 @@ def agent_os_mock(mocker: MockerFixture) -> AgentOs: @pytest.fixture def agent_toolbox_mock(agent_os_mock: AgentOs) -> AgentToolbox: """Fixture providing a mock agent toolbox.""" - return AgentToolbox(os=agent_os_mock) + return AgentToolbox(agent_os=agent_os_mock) @pytest.fixture def model_router_mock(mocker: MockerFixture) -> ModelRouter: diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index 6d01a416..b8cb6b1f 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -37,6 +37,7 @@ def vision_agent( inference_api = AskUiInferenceApi(locator_serializer=serializer) reporter = SimpleHtmlReporter() model_router = ModelRouter( + tools=agent_toolbox_mock, reporter=reporter, grounding_model_routers=[AskUiModelRouter(inference_api=inference_api)] ) diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index 17391994..d067ef3b 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -4,7 +4,7 @@ from askui.models import ModelName from askui import VisionAgent from askui.utils.image_utils import ImageSource -from askui.response_schemas import ResponseSchemaBase +from askui 
import ResponseSchemaBase class UrlResponse(ResponseSchemaBase): @@ -31,7 +31,7 @@ def test_get( image=ImageSource(github_login_screenshot), model=model, ) - assert url == "github.com/login" + assert url in ["github.com/login", "https://github.com/login"] @pytest.mark.skip("Skip for now as this pops up in our observability systems as a false positive") diff --git a/tests/unit/locators/serializers/test_askui_locator_serializer.py b/tests/unit/locators/serializers/test_askui_locator_serializer.py index 67840e9d..cc6f6f23 100644 --- a/tests/unit/locators/serializers/test_askui_locator_serializer.py +++ b/tests/unit/locators/serializers/test_askui_locator_serializer.py @@ -1,14 +1,11 @@ -from dataclasses import dataclass import pathlib import re -from typing import Literal import pytest from PIL import Image as PILImage from pytest_mock import MockerFixture from askui.locators.locators import Locator from askui.locators import Element, Description, Text, Image -from askui.locators.relatable import RelationBase from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import AiElementCollection from askui.utils.image_utils import image_to_base64 @@ -255,20 +252,6 @@ class UnsupportedLocator(Locator): askui_serializer.serialize(UnsupportedLocator()) -def test_serialize_unsupported_relation_type( - askui_serializer: AskUiLocatorSerializer, -) -> None: - @dataclass(kw_only=True) - class UnsupportedRelation(RelationBase): - type: Literal["unsupported"] # type: ignore - - text = Text("hello") - text.relations.append(UnsupportedRelation(type="unsupported", other_locator=Text("world"))) # type: ignore - - with pytest.raises(ValueError, match='Unsupported relation type: "unsupported"'): - askui_serializer.serialize(text) - - def test_serialize_simple_cycle_raises(askui_serializer: AskUiLocatorSerializer) -> None: text1 = Text("hello") text2 = Text("world") diff --git a/tests/unit/locators/test_locators.py 
b/tests/unit/locators/test_locators.py index 2f9d2847..86305228 100644 --- a/tests/unit/locators/test_locators.py +++ b/tests/unit/locators/test_locators.py @@ -16,7 +16,7 @@ def test_initialization_with_description(self) -> None: assert str(desc) == 'element with description "test"' def test_initialization_without_description_raises(self) -> None: - with pytest.raises(TypeError): + with pytest.raises(ValueError): Description() # type: ignore def test_initialization_with_positional_arg(self) -> None: @@ -179,7 +179,7 @@ def test_initialization_with_name(self) -> None: assert str(locator) == 'ai element named "github_com__icon"' def test_initialization_without_name_raises(self) -> None: - with pytest.raises(TypeError): + with pytest.raises(ValueError): AiElement() # type: ignore def test_initialization_with_invalid_args_raises(self) -> None: diff --git a/tests/unit/test_validate_call.py b/tests/unit/test_validate_call.py new file mode 100644 index 00000000..c8b11b29 --- /dev/null +++ b/tests/unit/test_validate_call.py @@ -0,0 +1,9 @@ +import pytest +from askui import VisionAgent + + +def test_validate_call_with_non_pydantic_invalid_types_raises_value_error(): + class InvalidModelRouter: + pass + with pytest.raises(ValueError): + VisionAgent(model_router=InvalidModelRouter()) From 5fb40b4c57bd6e18c490f89c072e09c6b240fcfd Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 17 Apr 2025 17:22:31 +0200 Subject: [PATCH 06/14] fix(reporting): fix reports overriding each other - make the file name more unique to avoid collisions --- src/askui/reporting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/askui/reporting.py b/src/askui/reporting.py index 08973427..8c6e36f2 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path +import random from jinja2 import Template from datetime import datetime from typing import List, Dict, Optional, Union @@ -253,5 
+254,5 @@ def generate(self) -> None: system_info=self.system_info, ) - report_path = self.report_dir / f"report_{datetime.now():%Y%m%d_%H%M%S}.html" + report_path = self.report_dir / f"report_{datetime.now():%Y%m%d%H%M%S%f}{random.randint(0, 1000):03}.html" report_path.write_text(html) From c04b4b9f9c1f4812caca368c486de5783f1e271d Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 17 Apr 2025 17:23:40 +0200 Subject: [PATCH 07/14] refactor(agent): make agent more modular / better testable - allow injecting a custom controller server - move controller server starting/stopping to client - --- src/askui/agent.py | 12 ++++--- src/askui/tools/askui/askui_controller.py | 38 +++++++++++++++++++---- tests/e2e/agent/conftest.py | 9 +++--- tests/e2e/agent/test_get.py | 34 ++++++++++++++++++++ 4 files changed, 78 insertions(+), 15 deletions(-) diff --git a/src/askui/agent.py b/src/askui/agent.py index 78e41173..1ca14e66 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -72,11 +72,16 @@ def __init__( load_dotenv() configure_logging(level=log_level) self._reporter = CompositeReporter(reports=reporters) - self.tools = tools or AgentToolbox(agent_os=AskUiControllerClient(display=display, reporter=self._reporter)) + self.tools = tools or AgentToolbox( + agent_os=AskUiControllerClient( + display=display, + reporter=self._reporter, + controller_server=AskUiControllerServer() + ), + ) self.model_router = ( ModelRouter(tools=self.tools, reporter=self._reporter) if model_router is None else model_router ) - self._controller = AskUiControllerServer() self._model = model @telemetry.record_call(exclude={"locator"}) @@ -481,13 +486,10 @@ def cli( @telemetry.record_call(flush=True) def close(self) -> None: self.tools.agent_os.disconnect() - if self._controller: - self._controller.stop(True) self._reporter.generate() @telemetry.record_call() def open(self) -> None: - self._controller.start(True) self.tools.agent_os.connect() @telemetry.record_call() diff --git 
a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index 89125ca9..65c0506d 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod import pathlib from typing import Literal from typing_extensions import Self, override @@ -58,9 +59,29 @@ def validate_either_component_registry_or_installation_directory_is_set(self) -> if self.component_registry_file is None and self.installation_directory is None: raise ValueError("Either ASKUI_COMPONENT_REGISTRY_FILE or ASKUI_INSTALLATION_DIRECTORY environment variable must be set") return self + + +class ControllerServer(ABC): + @abstractmethod + def start(self, clean_up: bool = False) -> None: + raise NotImplementedError() + + @abstractmethod + def stop(self, force: bool = False) -> None: + raise NotImplementedError() + + +class EmptyControllerServer(ControllerServer): + @override + def start(self, clean_up: bool = False) -> None: + pass + + @override + def stop(self, force: bool = False) -> None: + pass -class AskUiControllerServer: +class AskUiControllerServer(ControllerServer): def __init__(self) -> None: self._process = None self._settings = AskUiControllerSettings() # type: ignore @@ -97,8 +118,9 @@ def _find_remote_device_controller_by_legacy_path(self) -> pathlib.Path: def __start_process(self, path): self.process = subprocess.Popen(path) wait_for_port(23000) - - def start(self, clean_up=False): + + @override + def start(self, clean_up: bool = False) -> None: if sys.platform == 'win32' and clean_up and process_exists("AskuiRemoteDeviceController.exe"): self.clean_up() remote_device_controller_path = self._find_remote_device_controller() @@ -111,7 +133,8 @@ def clean_up(self): subprocess.run("taskkill.exe /IM AskUI*") time.sleep(0.1) - def stop(self, force=False): + @override + def stop(self, force: bool = False) -> None: if force: self.process.terminate() self.clean_up() @@ -121,7 +144,7 
@@ def stop(self, force=False): class AskUiControllerClient(AgentOs): @telemetry.record_call(exclude={"report"}) - def __init__(self, reporter: Reporter, display: int = 1) -> None: + def __init__(self, reporter: Reporter, display: int = 1, controller_server: ControllerServer | None = None) -> None: self.stub = None self.channel = None self.session_info = None @@ -130,10 +153,12 @@ def __init__(self, reporter: Reporter, display: int = 1) -> None: self.max_retries = 10 self.display = display self._reporter = reporter + self._controller_server = controller_server or EmptyControllerServer() @telemetry.record_call() @override def connect(self) -> None: + self._controller_server.start() self.channel = grpc.insecure_channel('localhost:23000', options=[ ('grpc.max_send_message_length', 2**30 ), ('grpc.max_receive_message_length', 2**30 ), @@ -165,7 +190,8 @@ def disconnect(self) -> None: self._stop_execution() self._stop_session() self.channel.close() - + self._controller_server.stop() + @telemetry.record_call() def __enter__(self) -> Self: self.connect() diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index b8cb6b1f..71dd2c81 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -1,7 +1,7 @@ """Shared pytest fixtures for e2e tests.""" import pathlib -from typing import Optional, Union +from typing import Generator, Optional, Union from typing_extensions import override import pytest from PIL import Image as PILImage @@ -28,7 +28,7 @@ def generate(self) -> None: @pytest.fixture def vision_agent( path_fixtures: pathlib.Path, agent_toolbox_mock: AgentToolbox -) -> VisionAgent: +) -> Generator[VisionAgent, None, None]: """Fixture providing a VisionAgent instance.""" ai_element_collection = AiElementCollection( additional_ai_element_locations=[path_fixtures / "images"] @@ -41,9 +41,10 @@ def vision_agent( reporter=reporter, grounding_model_routers=[AskUiModelRouter(inference_api=inference_api)] ) - return VisionAgent( + with 
VisionAgent( reporters=[reporter], model_router=model_router, tools=agent_toolbox_mock - ) + ) as agent: + yield agent @pytest.fixture diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index d067ef3b..9229bb7e 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -171,3 +171,37 @@ def test_get_with_float_schema( ) assert isinstance(response, float) assert response > 0 + + +@pytest.mark.parametrize("model", [ModelName.ASKUI]) +def test_get_returns_str_when_no_schema_specified( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: str, +) -> None: + response = vision_agent.get( + "What is the display showing?", + image=ImageSource(github_login_screenshot), + model=model, + ) + assert isinstance(response, str) + + +class Basis(ResponseSchemaBase): + answer: str + + +@pytest.mark.parametrize("model", [ModelName.ASKUI]) +def test_get_with_basis_schema( + vision_agent: VisionAgent, + github_login_screenshot: PILImage.Image, + model: str, +) -> None: + response = vision_agent.get( + "What is the display showing?", + image=ImageSource(github_login_screenshot), + response_schema=Basis, + model=model, + ) + assert isinstance(response, Basis) + assert response.answer != "\"What is the display showing?\"" From e4bbf11a375926a4bddde0e9f444c7c3a54d7236 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 17 Apr 2025 17:47:35 +0200 Subject: [PATCH 08/14] feat(agent): make it easier to pass image to locate() and get() - allow passing PIL Image, path or data url instead of custom type --- README.md | 5 +- src/askui/agent.py | 294 +++++++++++++++++++-------------- src/askui/utils/image_utils.py | 5 +- tests/e2e/agent/test_get.py | 25 ++- 4 files changed, 188 insertions(+), 141 deletions(-) diff --git a/README.md b/README.md index 5183b739..d1d7aabe 100644 --- a/README.md +++ b/README.md @@ -414,14 +414,13 @@ Instead of taking a screenshot, you can analyze specific images: ```python from PIL import Image 
-from askui.utils.image_utils import ImageSource # From PIL Image image = Image.open("screenshot.png") -result = agent.get("What's in this image?", ImageSource(image)) +result = agent.get("What's in this image?", image) # From file path -result = agent.get("What's in this image?", ImageSource("screenshot.png")) +result = agent.get("What's in this image?", "screenshot.png") ``` #### Using response schemas diff --git a/src/askui/agent.py b/src/askui/agent.py index 1ca14e66..b6a4aceb 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -5,7 +5,7 @@ from askui.container import telemetry from askui.locators.locators import Locator -from askui.utils.image_utils import ImageSource +from askui.utils.image_utils import ImageSource, Img from .tools.askui.askui_controller import ( AskUiControllerClient, @@ -20,7 +20,6 @@ from .reporting import CompositeReporter, Reporter import time from dotenv import load_dotenv -from PIL import Image from .models.types.response_schemas import ResponseSchema @@ -97,23 +96,27 @@ def click( Simulates a mouse click on the user interface element identified by the provided locator. Parameters: - locator (str | Locator | None): The identifier or description of the element to click. - button ('left' | 'middle' | 'right'): Specifies which mouse button to click. Defaults to 'left'. - repeat (int): The number of times to click. Must be greater than 0. Defaults to 1. - model (ModelComposition | str | None): The composition or name of the model(s) to be used for locating the element to click on using the `locator`. + locator (str | Locator | None): + The identifier or description of the element to click. If None, clicks at current position. + button ('left' | 'middle' | 'right'): + Specifies which mouse button to click. Defaults to 'left'. + repeat (int): + The number of times to click. Must be greater than 0. Defaults to 1. 
+ model (ModelComposition | str | None): + The composition or name of the model(s) to be used for locating the element to click on using the `locator`. Raises: InvalidParameterError: If the 'repeat' parameter is less than 1. Example: - ```python - with VisionAgent() as agent: - agent.click() # Left click on current position - agent.click("Edit") # Left click on text "Edit" - agent.click("Edit", button="right") # Right click on text "Edit" - agent.click(repeat=2) # Double left click on current position - agent.click("Edit", button="middle", repeat=4) # 4x middle click on text "Edit" - ``` + ```python + with VisionAgent() as agent: + agent.click() # Left click on current position + agent.click("Edit") # Left click on text "Edit" + agent.click("Edit", button="right") # Right click on text "Edit" + agent.click(repeat=2) # Double left click on current position + agent.click("Edit", button="middle", repeat=4) # 4x middle click on text "Edit" + ``` """ if repeat < 1: raise InvalidParameterError("InvalidParameterError! 
The parameter 'repeat' needs to be greater than 0.") @@ -130,10 +133,9 @@ def click( self._mouse_move(locator, model or self._model) self.tools.agent_os.click(button, repeat) # type: ignore - def _locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model: ModelComposition | str | None = None) -> Point: - if screenshot is None: - screenshot = self.tools.agent_os.screenshot() # type: ignore - point = self.model_router.locate(screenshot, locator, model or self._model) + def _locate(self, locator: str | Locator, screenshot: Optional[Img] = None, model: ModelComposition | str | None = None) -> Point: + _screenshot = ImageSource(self.tools.agent_os.screenshot() if screenshot is None else screenshot) + point = self.model_router.locate(_screenshot.root, locator, model or self._model) self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})") return point @@ -141,19 +143,30 @@ def _locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = No def locate( self, locator: str | Locator, - screenshot: Optional[Image.Image] = None, + screenshot: Optional[Img] = None, model: ModelComposition | str | None = None, ) -> Point: """ Locates the UI element identified by the provided locator. - Args: - locator (str | Locator): The identifier or description of the element to locate. - screenshot (Optional[Image.Image], optional): The screenshot to use for locating the element. Defaults to None. - model (ModelComposition | str | None): The composition or name of the model(s) to be used for locating the element using the `locator`. + Parameters: + locator (str | Locator): + The identifier or description of the element to locate. + screenshot (Img | None, optional): + The screenshot to use for locating the element. Can be a path to an image file, a PIL Image object or a data URL. + If None, takes a screenshot of the currently selected display. 
+ model (ModelComposition | str | None): + The composition or name of the model(s) to be used for locating the element using the `locator`. Returns: - Point: The coordinates of the element. + Point: The coordinates of the element as a tuple (x, y). + + Example: + ```python + with VisionAgent() as agent: + point = agent.locate("Submit button") + print(f"Element found at coordinates: {point}") + ``` """ self._reporter.add_message("User", f"locate {locator}") logger.debug("VisionAgent received instruction to locate %s", locator) @@ -174,16 +187,18 @@ def mouse_move( Moves the mouse cursor to the UI element identified by the provided locator. Parameters: - locator (str | Locator): The identifier or description of the element to move to. - model (ModelComposition | str | None): The composition or name of the model(s) to be used for locating the element to move the mouse to using the `locator`. + locator (str | Locator): + The identifier or description of the element to move to. + model (ModelComposition | str | None): + The composition or name of the model(s) to be used for locating the element to move the mouse to using the `locator`. Example: - ```python - with VisionAgent() as agent: - agent.mouse_move("Submit button") # Moves cursor to submit button - agent.mouse_move("Close") # Moves cursor to close element - agent.mouse_move("Profile picture", model="custom_model") # Uses specific model - ``` + ```python + with VisionAgent() as agent: + agent.mouse_move("Submit button") # Moves cursor to submit button + agent.mouse_move("Close") # Moves cursor to close element + agent.mouse_move("Profile picture", model="custom_model") # Uses specific model + ``` """ self._reporter.add_message("User", f'mouse_move: {locator}') logger.debug("VisionAgent received instruction to mouse_move to %s", locator) @@ -200,23 +215,25 @@ def mouse_scroll( Simulates scrolling the mouse wheel by the specified horizontal and vertical amounts. Parameters: - x (int): The horizontal scroll amount. 
Positive values typically scroll right, negative values scroll left. - y (int): The vertical scroll amount. Positive values typically scroll down, negative values scroll up. + x (int): + The horizontal scroll amount. Positive values typically scroll right, negative values scroll left. + y (int): + The vertical scroll amount. Positive values typically scroll down, negative values scroll up. Note: - The actual `scroll direction` depends on the operating system's configuration. + The actual scroll direction depends on the operating system's configuration. Some systems may have "natural scrolling" enabled, which reverses the traditional direction. - The meaning of scroll `units` varies` acro`ss oper`ating` systems and applications. + The meaning of scroll units varies across operating systems and applications. A scroll value of 10 might result in different distances depending on the application and system settings. Example: - ```python - with VisionAgent() as agent: - agent.mouse_scroll(0, 10) # Usually scrolls down 10 units - agent.mouse_scroll(0, -5) # Usually scrolls up 5 units - agent.mouse_scroll(3, 0) # Usually scrolls right 3 units - ``` + ```python + with VisionAgent() as agent: + agent.mouse_scroll(0, 10) # Usually scrolls down 10 units + agent.mouse_scroll(0, -5) # Usually scrolls up 5 units + agent.mouse_scroll(3, 0) # Usually scrolls right 3 units + ``` """ self._reporter.add_message("User", f'mouse_scroll: "{x}", "{y}"') self.tools.agent_os.mouse_scroll(x, y) @@ -231,15 +248,16 @@ def type( Types the specified text as if it were entered on a keyboard. Parameters: - text (str): The text to be typed. + text (str): + The text to be typed. Must be at least 1 character long. Example: - ```python - with VisionAgent() as agent: - agent.type("Hello, world!") # Types "Hello, world!" 
- agent.type("user@example.com") # Types an email address - agent.type("password123") # Types a password - ``` + ```python + with VisionAgent() as agent: + agent.type("Hello, world!") # Types "Hello, world!" + agent.type("user@example.com") # Types an email address + agent.type("password123") # Types a password + ``` """ self._reporter.add_message("User", f'type: "{text}"') logger.debug("VisionAgent received instruction to type '%s'", text) @@ -251,7 +269,7 @@ def get( self, query: Annotated[str, Field(min_length=1)], response_schema: None = None, - image: Optional[ImageSource] = None, + image: Optional[Img] = None, model: ModelComposition | str | None = None, ) -> str: ... @overload @@ -259,16 +277,16 @@ def get( self, query: Annotated[str, Field(min_length=1)], response_schema: Type[ResponseSchema], - image: Optional[ImageSource] = None, + image: Optional[Img] = None, model: ModelComposition | str | None = None, ) -> ResponseSchema: ... @telemetry.record_call(exclude={"query", "image", "response_schema"}) - @validate_call + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def get( self, query: Annotated[str, Field(min_length=1)], - image: Optional[ImageSource] = None, + image: Optional[Img] = None, response_schema: Type[ResponseSchema] | None = None, model: ModelComposition | str | None = None, ) -> ResponseSchema | str: @@ -278,46 +296,68 @@ def get( Parameters: query (str): The query describing what information to retrieve. - image (ImageSource | None): - The image to extract information from. Optional. Defaults to a screenshot of the current screen. - response_schema (Type[ResponseSchema] | None): - A Pydantic model class that defines the response schema. Optional. If not provided, returns a string. - model (ModelComposition | str | None): + image (Img | None, optional): + The image to extract information from. Defaults to a screenshot of the current screen. + Can be a path to an image file, a PIL Image object or a data URL. 
+ response_schema (Type[ResponseSchema] | None, optional): + A Pydantic model class that defines the response schema. If not provided, returns a string. + model (ModelComposition | str | None, optional): The composition or name of the model(s) to be used for retrieving information from the screen or image using the `query`. - Note: `response_schema` is only supported with not supported by all models. + Note: `response_schema` is not supported by all models. Returns: - ResponseSchema: The extracted information, either as an instance of ResponseSchemaBase or the primite type passed or string if no response_schema is provided. + ResponseSchema | str: + The extracted information, either as an instance of ResponseSchema or string if no response_schema is provided. Limitations: - Nested Pydantic schemas are not currently supported - Schema support is only available with "askui" model (default model if `ASKUI_WORKSPACE_ID` and `ASKUI_TOKEN` are set) at the moment Example: - ```python - from askui import JsonSchemaBase - - class UrlResponse(JsonSchemaBase): - url: str - - with VisionAgent() as agent: - # Get URL as string - url = agent.get("What is the current url shown in the url bar?") - - # Get URL as Pydantic model - response = agent.get( - "What is the current url shown in the url bar?", - response_schema=UrlResponse - ) - print(response.url) - ``` + ```python + from askui import JsonSchemaBase + from PIL import Image + + class UrlResponse(JsonSchemaBase): + url: str + + with VisionAgent() as agent: + # Get URL as string + url = agent.get("What is the current url shown in the url bar?") + + # Get URL as Pydantic model from image at (relative) path + response = agent.get( + "What is the current url shown in the url bar?", + response_schema=UrlResponse, + image="screenshot.png", + ) + print(response.url) + + # Get boolean response from PIL Image + is_login_page = agent.get( + "Is this a login page?", + response_schema=bool, + image=Image.open("screenshot.png"), + ) + + 
# Get integer response + input_count = agent.get( + "How many input fields are visible on this page?", + response_schema=int, + ) + + # Get float response + design_rating = agent.get( + "Rate the page design quality from 0 to 1", + response_schema=float, + ) + ``` """ self._reporter.add_message("User", f'get: "{query}"') logger.debug("VisionAgent received instruction to get '%s'", query) - if image is None: - image = ImageSource(self.tools.agent_os.screenshot()) # type: ignore + _image = ImageSource(self.tools.agent_os.screenshot() if image is None else image) # type: ignore response = self.model_router.get_inference( - image=image, + image=_image, query=query, model=model or self._model, response_schema=response_schema, @@ -337,17 +377,18 @@ def wait( Pauses the execution of the program for the specified number of seconds. Parameters: - sec (float): The number of seconds to wait. Must be greater than 0.0. + sec (float): + The number of seconds to wait. Must be greater than 0.0. Raises: ValueError: If the provided `sec` is negative. Example: - ```python - with VisionAgent() as agent: - agent.wait(5) # Pauses execution for 5 seconds - agent.wait(0.5) # Pauses execution for 500 milliseconds - ``` + ```python + with VisionAgent() as agent: + agent.wait(5) # Pauses execution for 5 seconds + agent.wait(0.5) # Pauses execution for 500 milliseconds + ``` """ time.sleep(sec) @@ -361,14 +402,15 @@ def key_up( Simulates the release of a key. Parameters: - key (PcKey | ModifierKey): The key to be released. + key (PcKey | ModifierKey): + The key to be released. 
Example: - ```python - with VisionAgent() as agent: - agent.key_up('a') # Release the 'a' key - agent.key_up('shift') # Release the 'Shift' key - ``` + ```python + with VisionAgent() as agent: + agent.key_up('a') # Release the 'a' key + agent.key_up('shift') # Release the 'Shift' key + ``` """ self._reporter.add_message("User", f'key_up "{key}"') logger.debug("VisionAgent received in key_up '%s'", key) @@ -384,14 +426,15 @@ def key_down( Simulates the pressing of a key. Parameters: - key (PcKey | ModifierKey): The key to be pressed. + key (PcKey | ModifierKey): + The key to be pressed. Example: - ```python - with VisionAgent() as agent: - agent.key_down('a') # Press the 'a' key - agent.key_down('shift') # Press the 'Shift' key - ``` + ```python + with VisionAgent() as agent: + agent.key_down('a') # Press the 'a' key + agent.key_down('shift') # Press the 'Shift' key + ``` """ self._reporter.add_message("User", f'key_down "{key}"') logger.debug("VisionAgent received in key_down '%s'", key) @@ -412,16 +455,18 @@ def act( interface interactions. Parameters: - goal (str): A description of what the agent should achieve. - model (ModelComposition | str | None): The composition or name of the model(s) to be used for achieving the `goal`. + goal (str): + A description of what the agent should achieve. + model (ModelComposition | str | None, optional): + The composition or name of the model(s) to be used for achieving the `goal`. 
Example: - ```python - with VisionAgent() as agent: - agent.act("Open the settings menu") - agent.act("Search for 'printer' in the search box") - agent.act("Log in with username 'admin' and password '1234'") - ``` + ```python + with VisionAgent() as agent: + agent.act("Open the settings menu") + agent.act("Search for 'printer' in the search box") + agent.act("Log in with username 'admin' and password '1234'") + ``` """ self._reporter.add_message("User", f'act: "{goal}"') logger.debug( @@ -440,19 +485,19 @@ def keyboard( Simulates pressing a key or key combination on the keyboard. Parameters: - key (PcKey | ModifierKey): The main key to press. This can be a letter, number, - special character, or function key. - modifier_keys (list[MODIFIER_KEY] | None): Optional list of modifier keys to press - along with the main key. Common modifier keys include 'ctrl', 'alt', 'shift'. + key (PcKey | ModifierKey): + The main key to press. This can be a letter, number, special character, or function key. + modifier_keys (list[ModifierKey] | None, optional): + List of modifier keys to press along with the main key. Common modifier keys include 'ctrl', 'alt', 'shift'. Example: - ```python - with VisionAgent() as agent: - agent.keyboard('a') # Press 'a' key - agent.keyboard('enter') # Press 'Enter' key - agent.keyboard('v', ['control']) # Press Ctrl+V (paste) - agent.keyboard('s', ['control', 'shift']) # Press Ctrl+Shift+S - ``` + ```python + with VisionAgent() as agent: + agent.keyboard('a') # Press 'a' key + agent.keyboard('enter') # Press 'Enter' key + agent.keyboard('v', ['control']) # Press Ctrl+V (paste) + agent.keyboard('s', ['control', 'shift']) # Press Ctrl+Shift+S + ``` """ logger.debug("VisionAgent received instruction to press '%s'", key) self.tools.agent_os.keyboard_tap(key, modifier_keys) # type: ignore @@ -470,15 +515,16 @@ def cli( is split on spaces and executed as a subprocess. Parameters: - command (str): The command to execute on the command line. 
+ command (str): + The command to execute on the command line. Example: - ```python - with VisionAgent() as agent: - agent.cli("echo Hello World") # Prints "Hello World" - agent.cli("ls -la") # Lists files in current directory with details - agent.cli("python --version") # Displays Python version - ``` + ```python + with VisionAgent() as agent: + agent.cli("echo Hello World") # Prints "Hello World" + agent.cli("ls -la") # Lists files in current directory with details + agent.cli("python --version") # Displays Python version + ``` """ logger.debug("VisionAgent received instruction to execute '%s' on cli", command) subprocess.run(command.split(" ")) diff --git a/src/askui/utils/image_utils.py b/src/askui/utils/image_utils.py index 831e76f4..dc677540 100644 --- a/src/askui/utils/image_utils.py +++ b/src/askui/utils/image_utils.py @@ -247,6 +247,9 @@ def scale_coordinates_back( return original_x, original_y +Img = Union[str, Path, PILImage.Image] + + class ImageSource(RootModel): """ A Pydantic model that represents an image source and provides methods to convert it to different formats. 
@@ -260,7 +263,7 @@ class ImageSource(RootModel): model_config = ConfigDict(arbitrary_types_allowed=True) root: PILImage.Image - def __init__(self, root: Union[str, Path, PILImage.Image], **kwargs) -> None: + def __init__(self, root: Img, **kwargs) -> None: super().__init__(root=root, **kwargs) @field_validator("root", mode="before") diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index 9229bb7e..73ae576f 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -3,7 +3,6 @@ from PIL import Image as PILImage from askui.models import ModelName from askui import VisionAgent -from askui.utils.image_utils import ImageSource from askui import ResponseSchemaBase @@ -28,7 +27,7 @@ def test_get( ) -> None: url = vision_agent.get( "What is the current url shown in the url bar?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, model=model, ) assert url in ["github.com/login", "https://github.com/login"] @@ -42,7 +41,7 @@ def test_get_with_response_schema_without_additional_properties_with_askui_model with pytest.raises(Exception): vision_agent.get( "What is the current url shown in the url bar?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=UrlResponse, model=ModelName.ASKUI, ) @@ -56,7 +55,7 @@ def test_get_with_response_schema_without_required_with_askui_model_raises( with pytest.raises(Exception): vision_agent.get( "What is the current url shown in the url bar?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=UrlResponse, model=ModelName.ASKUI, ) @@ -70,7 +69,7 @@ def test_get_with_response_schema( ) -> None: response = vision_agent.get( "What is the current url shown in the url bar?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=UrlResponse, model=model, ) @@ -85,7 +84,7 @@ def test_get_with_response_schema_with_anthropic_model_raises_not_implemented( 
with pytest.raises(NotImplementedError): vision_agent.get( "What is the current url shown in the url bar?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=UrlResponse, model=ModelName.ANTHROPIC, ) @@ -100,7 +99,7 @@ def test_get_with_nested_and_inherited_response_schema( ) -> None: response = vision_agent.get( "What is the current browser context?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=BrowserContextResponse, model=model, ) @@ -118,7 +117,7 @@ def test_get_with_string_schema( ) -> None: response = vision_agent.get( "What is the current url shown in the url bar?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=str, model=model, ) @@ -133,7 +132,7 @@ def test_get_with_boolean_schema( ) -> None: response = vision_agent.get( "Is this a login page?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=bool, model=model, ) @@ -149,7 +148,7 @@ def test_get_with_integer_schema( ) -> None: response = vision_agent.get( "How many input fields are visible on this page?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=int, model=model, ) @@ -165,7 +164,7 @@ def test_get_with_float_schema( ) -> None: response = vision_agent.get( "Return a floating point number between 0 and 1 as a rating for how you well this page is designed (0 is the worst, 1 is the best)", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=float, model=model, ) @@ -181,7 +180,7 @@ def test_get_returns_str_when_no_schema_specified( ) -> None: response = vision_agent.get( "What is the display showing?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, model=model, ) assert isinstance(response, str) @@ -199,7 +198,7 @@ def test_get_with_basis_schema( ) -> None: response = vision_agent.get( "What is 
the display showing?", - image=ImageSource(github_login_screenshot), + image=github_login_screenshot, response_schema=Basis, model=model, ) From 30bac04665f4e1033af20e643b3995891f171c98 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 22 Apr 2025 16:00:49 +0200 Subject: [PATCH 09/14] docs(locators): improve docs of relations --- README.md | 2 +- src/askui/locators/relatable.py | 641 +++++++++++++++++++++++++++++++- 2 files changed, 638 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index d1d7aabe..fc3bb70f 100644 --- a/README.md +++ b/README.md @@ -377,7 +377,7 @@ Example: from askui import locators as loc password_textfield_label = loc.Text("Password") -password_textfield = loc.Class("textfield").right_of(password_textfield_label) +password_textfield = loc.Element("textfield").right_of(password_textfield_label) agent.click(password_textfield) agent.type("********") diff --git a/src/askui/locators/relatable.py b/src/askui/locators/relatable.py index 69c0774a..c3ef846e 100644 --- a/src/askui/locators/relatable.py +++ b/src/askui/locators/relatable.py @@ -42,7 +42,6 @@ def __str__(self): return f"{RelationTypeMapping[self.type]} {self.other_locator._str_with_relation()}" - class NeighborRelation(RelationBase): type: Literal["above_of", "below_of", "right_of", "left_of"] index: RelationIndex @@ -66,7 +65,6 @@ def __str__(self): return f"{RelationTypeMapping[self.type]}{reference_point_str} the {index_str} {self.other_locator._str_with_relation()}" - class LogicalRelation(RelationBase): type: Literal["and", "or"] @@ -102,6 +100,7 @@ class Relatable(ABC): Attributes: relations: List of relations to other locators """ + def __init__(self) -> None: self._relations: list[Relation] = [] @@ -116,7 +115,138 @@ def above_of( index: RelationIndex = 0, reference_point: ReferencePoint = "boundary", ) -> Self: - + """Defines the element (located by *self*) to be **above** another element / + other elements (located by *other_locator*). 
+ + An element **A** is considered to be *above* another element / other elements **B** + + - if most of **A** (or, more specifically, **A**'s bounding box) is *above* **B** + (or, more specifically, the **top border** of **B**'s bounding box) **and** + - if the **bottom border** of **A** (or, more specifically, **A**'s bounding box) + is *above* the **bottom border** of **B** (or, more specifically, **B**'s + bounding box). + + Args: + other_locator: + Locator for an element / elements to relate to + index: + Index of the element (located by *self*) above the other element(s) + (located by *other_locator*), e.g., the first (index=0), second + (index=1), third (index=2) etc. element above the other element(s). + Elements' (relative) position is determined by the **bottom border** + (*y*-coordinate) of their bounding box. + We don't guarantee the order of elements with the same bottom border + (*y*-coordinate). + reference_point: + Defines which element (located by *self*) is considered to be above the + other element(s) (located by *other_locator*): + + **"center"**: One point of the element (located by *self*) is above the + center (in a straight vertical line) of the other element(s) (located + by *other_locator*). + **"boundary"**: One point of the element (located by *self*) is above + any other point (in a straight vertical line) of the other element(s) + (located by *other_locator*). + **"any"**: No point of the element (located by *self*) has to be above + a point (in a straight vertical line) of the other element(s) (located + by *other_locator*). 
+ + *Default is **"boundary".*** + + Returns: + Self: The locator with the relation added + + Examples: + ```text + + =========== + | A | + =========== + =========== + | B | + =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the first (index 0) element above ("center" of) + # text "B" + text = loc.Text().above_of(loc.Text("B"), reference_point="center") + ``` + + ```text + + =========== + | A | + =========== + =========== + | B | + =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the first (index 0) element above + # ("boundary" of / any point of) text "B" + # (reference point "center" won't work here) + text = loc.Text().above_of(loc.Text("B"), reference_point="boundary") + ``` + + ```text + + =========== + | A | + =========== + =========== + | B | + =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the first (index 0) element above text "B" + # (reference point "center" or "boundary" won't work here) + text = loc.Text().above_of(loc.Text("B"), reference_point="any") + ``` + + ```text + + =========== + | A | + =========== + =========== + | B | + =========== + =========== + | C | + =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the second (index 1) element above text "C" + # (reference point "center" or "boundary" won't work here) + text = loc.Text().above_of(loc.Text("C"), index=1, reference_point="any") + ``` + + ```text + + =========== + | A | + =========== + =========== + =========== | B | + | | =========== + | C | + | | + =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the second (index 1) element above text "C" + # (reference point "any") + text = loc.Text().above_of(loc.Text("C"), index=1, reference_point="any") + # locates also text "A" as it is the first (index 0) element above text "C" + # with reference point "boundary" + 
text = loc.Text().above_of(loc.Text("C"), index=0, reference_point="boundary") + ``` + """ self._relations.append( NeighborRelation( type="above_of", @@ -134,6 +264,138 @@ def below_of( index: RelationIndex = 0, reference_point: ReferencePoint = "boundary", ) -> Self: + """Defines the element (located by *self*) to be **below** another element / + other elements (located by *other_locator*). + + An element **A** is considered to be *below* another element / other elements **B** + + - if most of **A** (or, more specifically, **A**'s bounding box) is *below* **B** + (or, more specifically, the **bottom border** of **B**'s bounding box) **and** + - if the **top border** of **A** (or, more specifically, **A**'s bounding box) is + *below* the **top border** of **B** (or, more specifically, **B**'s bounding + box). + + Args: + other_locator: + Locator for an element / elements to relate to. + index: + Index of the element (located by *self*) **below** the other + element(s) (located by *other_locator*), e.g., the first (*index=0*), + second (*index=1*), third (*index=2*) etc. element below the other + element(s). Elements' (relative) position is determined by the **top + border** (*y*-coordinate) of their bounding box. + We don't guarantee the order of elements with the same top border + (*y*-coordinate). + reference_point: + Defines which element (located by *self*) is considered to be + *below* the other element(s) (located by *other_locator*): + + **"center"**: One point of the element (located by *self*) is + **below** the *center* (in a straight vertical line) of the other + element(s) (located by *other_locator*). + **"boundary"**: One point of the element (located by *self*) is + **below** *any* other point (in a straight vertical line) of the + other element(s) (located by *other_locator*). + **"any"**: No point of the element (located by *self*) has to + be **below** a point (in a straight vertical line) of the other + element(s) (located by *other_locator*). 
+
+            *Default is **"boundary".***
+
+        Returns:
+            Self: The locator with the relation added.
+
+        Examples:
+            ```text
+
+            ===========
+            |    B    |
+            ===========
+            ===========
+            |    A    |
+            ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the first (index 0) element below ("center" of)
+            # text "B"
+            text = loc.Text().below_of(loc.Text("B"), reference_point="center")
+            ```
+
+            ```text
+
+            ===========
+            |    B    |
+            ===========
+                  ===========
+                  |    A    |
+                  ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the first (index 0) element below
+            # ("boundary" of / any point of) text "B"
+            # (reference point "center" won't work here)
+            text = loc.Text().below_of(loc.Text("B"), reference_point="boundary")
+            ```
+
+            ```text
+
+            ===========
+            |    B    |
+            ===========
+                        ===========
+                        |    A    |
+                        ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the first (index 0) element below text "B"
+            # (reference point "center" or "boundary" won't work here)
+            text = loc.Text().below_of(loc.Text("B"), reference_point="any")
+            ```
+
+            ```text
+
+            ===========
+            |    C    |
+            ===========
+            ===========
+            |    B    |
+            ===========
+            ===========
+            |    A    |
+            ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the second (index 1) element below text "C"
+            # (reference point "center" or "boundary" won't work here)
+            text = loc.Text().below_of(loc.Text("C"), index=1, reference_point="any")
+            ```
+
+            ```text
+
+            ===========
+            |         |
+            |    C    |
+            |         |===========
+            ===========|    B    |
+                       ===========
+            ===========
+            |    A    |
+            ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the second (index 1) element below text "C"
+            # (reference point "any")
+            text = loc.Text().below_of(loc.Text("C"), index=1, reference_point="any")
+            # locates also text "A" as it is the first (index 0) element below text "C"
+            # with reference point "boundary"
+            
text = loc.Text().below_of(loc.Text("C"), index=0, reference_point="boundary") + ``` + """ self._relations.append( NeighborRelation( type="below_of", @@ -151,6 +413,128 @@ def right_of( index: RelationIndex = 0, reference_point: ReferencePoint = "boundary", ) -> Self: + """Defines the element (located by *self*) to be **right of** another element / + other elements (located by *other_locator*). + + An element **A** is considered to be *right of* another element / other elements **B** + + - if most of **A** (or, more specifically, **A**'s bounding box) is *right of* **B** + (or, more specifically, the **right border** of **B**'s bounding box) **and** + - if the **left border** of **A** (or, more specifically, **A**'s bounding box) is + *right of* the **left border** of **B** (or, more specifically, **B**'s + bounding box). + + Args: + other_locator: + Locator for an element / elements to relate to. + index: + Index of the element (located by *self*) **right of** the other + element(s) (located by *other_locator*), e.g., the first (*index=0*), + second (*index=1*), third (*index=2*) etc. element right of the other + element(s). Elements' (relative) position is determined by the **left + border** (*x*-coordinate) of their bounding box. + We don't guarantee the order of elements with the same left border + (*x*-coordinate). + reference_point: + Defines which element (located by *self*) is considered to be + *right of* the other element(s) (located by *other_locator*): + + **"center"**: One point of the element (located by *self*) is + **right of** the *center* (in a straight horizontal line) of the + other element(s) (located by *other_locator*). + **"boundary"**: One point of the element (located by *self*) is + **right of** *any* other point (in a straight horizontal line) of + the other element(s) (located by *other_locator*). 
+ **"any"**: No point of the element (located by *self*) has to + be **right of** a point (in a straight horizontal line) of the + other element(s) (located by *other_locator*). + + *Default is **"boundary".*** + + Returns: + Self: The locator with the relation added. + + Examples: + ```text + + =========== =========== + | B | | A | + =========== =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the first (index 0) element right of ("center" + # of) text "B" + text = loc.Text().right_of(loc.Text("B"), reference_point="center") + ``` + + ```text + + =========== + | B | + =========== =========== + | A | + =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the first (index 0) element right of + # ("boundary" of / any point of) text "B" + # (reference point "center" won't work here) + text = loc.Text().right_of(loc.Text("B"), reference_point="boundary") + ``` + + ```text + + =========== + | B | + =========== + =========== + | A | + =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the first (index 0) element right of text "B" + # (reference point "center" or "boundary" won't work here) + text = loc.Text().right_of(loc.Text("B"), reference_point="any") + ``` + + ```text + + =========== + | A | + =========== + =========== =========== + | C | | B | + =========== =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the second (index 1) element right of text "C" + # (reference point "center" or "boundary" won't work here) + text = loc.Text().right_of(loc.Text("C"), index=1, reference_point="any") + ``` + + ```text + + =========== + | B | + =========== =========== + =========== | A | + | C | =========== + =========== + ``` + ```python + from askui import locators as loc + # locates text "A" as it is the second (index 1) element right of text "C" + # (reference point "any") + text = 
loc.Text().right_of(loc.Text("C"), index=1, reference_point="any") + # locates also text "A" as it is the first (index 0) element right of text + # "C" with reference point "boundary" + text = loc.Text().right_of(loc.Text("C"), index=0, reference_point="boundary") + ``` + """ self._relations.append( NeighborRelation( type="right_of", @@ -168,6 +552,127 @@ def left_of( index: RelationIndex = 0, reference_point: ReferencePoint = "boundary", ) -> Self: + """Defines the element (located by *self*) to be **left of** another element / + other elements (located by *other_locator*). + + An element **A** is considered to be *left of* another element / other elements **B** + + - if most of **A** (or, more specifically, **A**'s bounding box) is *left of* **B** + (or, more specifically, the **left border** of **B**'s bounding box) **and** + - if the **right border** of **A** (or, more specifically, **A**'s bounding box) is + *left of* the **right border** of **B** (or, more specifically, **B**'s + bounding box). + + Args: + other_locator: + Locator for an element / elements to relate to. + index: + Index of the element (located by *self*) **left of** the other + element(s) (located by *other_locator*), e.g., the first (*index=0*), + second (*index=1*), third (*index=2*) etc. element left of the other + element(s). Elements' (relative) position is determined by the **right + border** (*x*-coordinate) of their bounding box. + We don't guarantee the order of elements with the same right border + (*x*-coordinate). + reference_point: + Defines which element (located by *self*) is considered to be + *left of* the other element(s) (located by *other_locator*): + + **"center"** : One point of the element (located by *self*) is + **left of** the *center* (in a straight horizontal line) of the + other element(s) (located by *other_locator*). 
+            **"boundary"**: One point of the element (located by *self*) is
+            **left of** *any* other point (in a straight horizontal line) of
+            the other element(s) (located by *other_locator*).
+            **"any"** : No point of the element (located by *self*) has to
+            be **left of** a point (in a straight horizontal line) of the
+            other element(s) (located by *other_locator*).
+
+            *Default is **"boundary".***
+
+        Returns:
+            Self: The locator with the relation added.
+
+        Examples:
+            ```text
+
+            ===========  ===========
+            |    A    |  |    B    |
+            ===========  ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the first (index 0) element left of ("center"
+            # of) text "B"
+            text = loc.Text().left_of(loc.Text("B"), reference_point="center")
+            ```
+
+            ```text
+
+                         ===========
+            ===========  |    B    |
+            |    A    |  ===========
+            ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the first (index 0) element left of ("boundary"
+            # of / any point of) text "B"
+            # (reference point "center" won't work here)
+            text = loc.Text().left_of(loc.Text("B"), reference_point="boundary")
+            ```
+
+            ```text
+
+                  ===========
+                  |    B    |
+                  ===========
+            ===========
+            |    A    |
+            ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the first (index 0) element left of text "B"
+            # (reference point "center" or "boundary" won't work here)
+            text = loc.Text().left_of(loc.Text("B"), reference_point="any")
+            ```
+
+            ```text
+
+            ===========
+            |    A    |
+            ===========
+            ===========  ===========
+            |    B    |  |    C    |
+            ===========  ===========
+            ```
+            ```python
+            from askui import locators as loc
+            # locates text "A" as it is the second (index 1) element left of text "C"
+            # (reference point "center" or "boundary" won't work here)
+            text = loc.Text().left_of(loc.Text("C"), index=1, reference_point="any")
+            ```
+
+            ```text
+
+            ===========
+            |    B    |
+            ===========  ===========
+            |    A    |  ===========
+            ===========  |    C    |
+                         ===========
+            ```
+            
```python
+            from askui import locators as loc
+            # locates text "A" as it is the second (index 1) element left of text "C"
+            # (reference point "any")
+            text = loc.Text().left_of(loc.Text("C"), index=1, reference_point="any")
+            # locates also text "A" as it is the first (index 0) element left of text
+            # "C" with reference point "boundary"
+            text = loc.Text().left_of(loc.Text("C"), index=0, reference_point="boundary")
+            ```
+        """
         self._relations.append(
             NeighborRelation(
                 type="left_of",
@@ -180,6 +685,32 @@ def left_of(
 
     # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using BoundingRelation
     def containing(self, other_locator: "Relatable") -> Self:
+        """Defines the element (located by *self*) to contain another element (located
+        by *other_locator*).
+
+        Args:
+            other_locator: The locator to check if it's contained
+
+        Returns:
+            Self: The locator with the relation added
+
+        Examples:
+            ```text
+            ---------------------------
+            |        textfield        |
+            |  ---------------------  |
+            |  |  placeholder text |  |
+            |  ---------------------  |
+            |                         |
+            ---------------------------
+            ```
+            ```python
+            from askui import locators as loc
+
+            # Returns the textfield because it contains the placeholder text
+            textfield = loc.Element("textfield").containing(loc.Text("placeholder"))
+            ```
+        """
         self._relations.append(
             BoundingRelation(
                 type="containing",
@@ -190,6 +721,34 @@ def containing(self, other_locator: "Relatable") -> Self:
 
     # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using BoundingRelation
     def inside_of(self, other_locator: "Relatable") -> Self:
+        """Defines the element (located by *self*) to be inside of another element
+        (located by *other_locator*).
+ + Args: + other_locator: The locator to check if it contains this element + + Returns: + Self: The locator with the relation added + + Examples: + ```text + --------------------------- + | textfield | + | --------------------- | + | | placeholder text | | + | --------------------- | + | | + --------------------------- + ``` + ```python + from askui import locators as loc + + # Returns the placeholder text of the textfield + placeholder_text = loc.Text("placeholder").inside_of( + loc.Element("textfield") + ) + ``` + """ self._relations.append( BoundingRelation( type="inside_of", @@ -200,6 +759,38 @@ def inside_of(self, other_locator: "Relatable") -> Self: # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using NearestToRelation def nearest_to(self, other_locator: "Relatable") -> Self: + """Defines the element (located by *self*) to be the nearest to another element + (located by *other_locator*). + + Args: + other_locator: The locator to compare distance against + + Returns: + Self: The locator with the relation added + + Examples: + ```text + -------------- + | text | + -------------- + --------------- + | textfield 1 | + --------------- + + + + + --------------- + | textfield 2 | + --------------- + ``` + ```python + from askui import locators as loc + + # Returns textfield 1 because it is nearer to the text than textfield 2 + textfield = loc.Element("textfield").nearest_to(loc.Text()) + ``` + """ self._relations.append( NearestToRelation( type="nearest_to", @@ -210,6 +801,27 @@ def nearest_to(self, other_locator: "Relatable") -> Self: # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using LogicalRelation def and_(self, other_locator: "Relatable") -> Self: + """Logical and operator to combine multiple locators, e.g., to require an + element to match multiple locators. 
+ + Args: + other_locator: The locator to combine with + + Returns: + Self: The locator with the relation added + + Examples: + ```python + from askui import locators as loc + + # Searches for an element that contains the text "Google" and is a + # multi-colored Google logo (instead of, e.g., simply some text that says + # "Google") + icon_user = loc.Element().containing( + loc.Text("Google").and_(loc.Description("Multi-colored Google logo")) + ) + ``` + """ self._relations.append( LogicalRelation( type="and", @@ -220,6 +832,26 @@ def and_(self, other_locator: "Relatable") -> Self: # cannot be validated by pydantic using @validate_call because of the recursive nature of the relations --> validate using LogicalRelation def or_(self, other_locator: "Relatable") -> Self: + """Logical or operator to combine multiple locators, e.g., to provide a fallback + if no element is found for one of the locators. + + Args: + other_locator: The locator to combine with + + Returns: + Self: The locator with the relation added + + Examples: + ```python + from askui import locators as loc + + # Searches for element using a description and if the element cannot be + # found, searches for it using an image + search_icon = loc.Description("search icon").or_( + loc.Image("search_icon.png") + ) + ``` + """ self._relations.append( LogicalRelation( type="or", @@ -241,11 +873,12 @@ def _relations_str(self) -> str: return "\n" + "\n".join(result) def raise_if_cycle(self) -> None: + """Raises CircularDependencyError if the relations form a cycle (see [Cycle (graph theory)](https://en.wikipedia.org/wiki/Cycle_(graph_theory))).""" if self._has_cycle(): raise CircularDependencyError() def _has_cycle(self) -> bool: - """Check if the relations form a cycle.""" + """Check if the relations form a cycle (see [Cycle (graph theory)](https://en.wikipedia.org/wiki/Cycle_(graph_theory))).""" visited_ids: set[int] = set() recursion_stack_ids: set[int] = set() From 5406f26b296e9862e12d81a3a1de7a5ff72beca6 Mon 
Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 22 Apr 2025 22:11:15 +0200 Subject: [PATCH 10/14] feat(reporting): add image for get() to report --- src/askui/agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/askui/agent.py b/src/askui/agent.py index b6a4aceb..2948e88e 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -353,9 +353,9 @@ class UrlResponse(JsonSchemaBase): ) ``` """ - self._reporter.add_message("User", f'get: "{query}"') logger.debug("VisionAgent received instruction to get '%s'", query) _image = ImageSource(self.tools.agent_os.screenshot() if image is None else image) # type: ignore + self._reporter.add_message("User", f'get: "{query}"', image=_image.root) response = self.model_router.get_inference( image=_image, query=query, From 3712cfdd514640124a0ec6f1e51cf81be1bead7e Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 22 Apr 2025 22:18:56 +0200 Subject: [PATCH 11/14] refactor(locators)!: rename Description to Prompt --- src/askui/locators/__init__.py | 4 ++-- src/askui/locators/locators.py | 14 +++++++------- src/askui/locators/serializers.py | 18 +++++++++--------- src/askui/models/router.py | 11 +++++------ tests/e2e/agent/test_locate.py | 4 ++-- .../agent/test_locate_with_different_models.py | 4 ++-- tests/e2e/agent/test_locate_with_relations.py | 6 +++--- .../test_askui_locator_serializer.py | 4 ++-- .../test_locator_string_representation.py | 10 +++++----- .../serializers/test_vlm_locator_serializer.py | 4 ++-- tests/unit/locators/test_locators.py | 16 ++++++++-------- 11 files changed, 47 insertions(+), 48 deletions(-) diff --git a/src/askui/locators/__init__.py b/src/askui/locators/__init__.py index d98f9484..23964220 100644 --- a/src/askui/locators/__init__.py +++ b/src/askui/locators/__init__.py @@ -1,9 +1,9 @@ -from askui.locators.locators import AiElement, Element, Description, Image, Text +from askui.locators.locators import AiElement, Element, Prompt, Image, Text __all__ = [ 
"AiElement", "Element", - "Description", + "Prompt", "Image", "Text", ] diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py index 93bbc04f..57b70f10 100644 --- a/src/askui/locators/locators.py +++ b/src/askui/locators/locators.py @@ -15,20 +15,20 @@ class Locator(Relatable, ABC): pass -class Description(Locator): - """Locator for finding ui elements by a textual description of the ui element.""" +class Prompt(Locator): + """Locator for finding ui elements by a textual prompt / description of a ui element, e.g., "green sign up button".""" @validate_call - def __init__(self, description: str) -> None: + def __init__(self, prompt: str) -> None: super().__init__() - self._description = description + self._prompt = prompt @property - def description(self) -> str: - return self._description + def prompt(self) -> str: + return self._prompt def _str_with_relation(self) -> str: - result = f'element with description "{self.description}"' + result = f'element with prompt "{self.prompt}"' return result + super()._relations_str() def __str__(self) -> str: diff --git a/src/askui/locators/serializers.py b/src/askui/locators/serializers.py index bcef4e07..9b0ce33a 100644 --- a/src/askui/locators/serializers.py +++ b/src/askui/locators/serializers.py @@ -8,7 +8,7 @@ ImageBase, AiElement as AiElementLocator, Element, - Description, + Prompt, Image, Text, ) @@ -35,8 +35,8 @@ def serialize(self, locator: Relatable) -> str: return self._serialize_text(locator) elif isinstance(locator, Element): return self._serialize_class(locator) - elif isinstance(locator, Description): - return self._serialize_description(locator) + elif isinstance(locator, Prompt): + return self._serialize_prompt(locator) elif isinstance(locator, Image): raise NotImplementedError( "Serializing image locators is not yet supported for VLMs" @@ -50,8 +50,8 @@ def _serialize_class(self, class_: Element) -> str: else: return "an arbitrary ui element (e.g., text, button, textfield, etc.)" - def 
_serialize_description(self, description: Description) -> str: - return description.description + def _serialize_prompt(self, prompt: Prompt) -> str: + return prompt.prompt def _serialize_text(self, text: Text) -> str: if text.match_type == "similar": @@ -110,8 +110,8 @@ def serialize(self, locator: Relatable) -> AskUiSerializedLocator: result["instruction"] = self._serialize_text(locator) elif isinstance(locator, Element): result["instruction"] = self._serialize_class(locator) - elif isinstance(locator, Description): - result["instruction"] = self._serialize_description(locator) + elif isinstance(locator, Prompt): + result["instruction"] = self._serialize_prompt(locator) elif isinstance(locator, Image): result = self._serialize_image( image_locator=locator, @@ -133,9 +133,9 @@ def serialize(self, locator: Relatable) -> AskUiSerializedLocator: def _serialize_class(self, class_: Element) -> str: return class_.class_name or "element" - def _serialize_description(self, description: Description) -> str: + def _serialize_prompt(self, prompt: Prompt) -> str: return ( - f"pta {self._TEXT_DELIMITER}{description.description}{self._TEXT_DELIMITER}" + f"pta {self._TEXT_DELIMITER}{prompt.prompt}{self._TEXT_DELIMITER}" ) def _serialize_text(self, text: Text) -> str: diff --git a/src/askui/models/router.py b/src/askui/models/router.py index abefd5df..9e87c22c 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -3,14 +3,13 @@ from PIL import Image from askui.container import telemetry -from askui.locators.locators import AiElement, Description, Text +from askui.locators.locators import AiElement, Prompt, Text from askui.locators.serializers import AskUiLocatorSerializer, VlmLocatorSerializer from askui.locators.locators import Locator from askui.models.askui.ai_element_utils import AiElementCollection from askui.models.models import ModelComposition, ModelName from askui.models.types.response_schemas import ResponseSchema from askui.reporting import 
CompositeReporter, Reporter -from askui.tools.askui.askui_controller import AskUiControllerClient from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource from .askui.api import AskUiInferenceApi @@ -83,18 +82,18 @@ def locate( ) if model == ModelName.ASKUI__PTA: logger.debug("Routing locate prediction to askui-pta") - x, y = self._inference_api.predict(screenshot, Description(locator)) + x, y = self._inference_api.predict(screenshot, Prompt(locator)) return handle_response((x, y), locator) if model == ModelName.ASKUI__OCR: logger.debug("Routing locate prediction to askui-ocr") return self._locate_with_askui_ocr(screenshot, locator) if model == ModelName.ASKUI__COMBO or model is None: logger.debug("Routing locate prediction to askui-combo") - description_locator = Description(locator) - x, y = self._inference_api.predict(screenshot, description_locator) + prompt_locator = Prompt(locator) + x, y = self._inference_api.predict(screenshot, prompt_locator) if x is None or y is None: return self._locate_with_askui_ocr(screenshot, locator) - return handle_response((x, y), description_locator) + return handle_response((x, y), prompt_locator) if model == ModelName.ASKUI__AI_ELEMENT: logger.debug("Routing click prediction to askui-ai-element") _locator = AiElement(locator) diff --git a/tests/e2e/agent/test_locate.py b/tests/e2e/agent/test_locate.py index 2edefc6a..f7cb49e1 100644 --- a/tests/e2e/agent/test_locate.py +++ b/tests/e2e/agent/test_locate.py @@ -6,7 +6,7 @@ from askui.agent import VisionAgent from askui.locators import ( - Description, + Prompt, Element, Text, AiElement, @@ -76,7 +76,7 @@ def test_locate_with_description_locator( model: str, ) -> None: """Test locating elements using a description locator.""" - locator = Description("Username textfield") + locator = Prompt("Username textfield") x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) diff --git 
a/tests/e2e/agent/test_locate_with_different_models.py b/tests/e2e/agent/test_locate_with_different_models.py index 8b3ad9cd..2a3b887e 100644 --- a/tests/e2e/agent/test_locate_with_different_models.py +++ b/tests/e2e/agent/test_locate_with_different_models.py @@ -5,7 +5,7 @@ from askui.agent import VisionAgent from askui.locators import ( - Description, + Prompt, Text, AiElement, ) @@ -66,7 +66,7 @@ def test_locate_with_ocr_model_fails_with_wrong_locator( model: str, ) -> None: """Test that OCR model fails with wrong locator type.""" - locator = Description("Forgot password?") + locator = Prompt("Forgot password?") with pytest.raises(AutomationError): vision_agent.locate(locator, github_login_screenshot, model=model) diff --git a/tests/e2e/agent/test_locate_with_relations.py b/tests/e2e/agent/test_locate_with_relations.py index 98305cc1..21a5425e 100644 --- a/tests/e2e/agent/test_locate_with_relations.py +++ b/tests/e2e/agent/test_locate_with_relations.py @@ -7,7 +7,7 @@ from askui.exceptions import ElementNotFoundError from askui.agent import VisionAgent from askui.locators import ( - Description, + Prompt, Element, Text, Image, @@ -321,7 +321,7 @@ def test_locate_with_description_and_relation( model: str, ) -> None: """Test locating elements using description with relation.""" - locator = Description("Sign in button").below_of(Description("Password field")) + locator = Prompt("Sign in button").below_of(Prompt("Password field")) x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) @@ -336,7 +336,7 @@ def test_locate_with_description_and_complex_relation( model: str, ) -> None: """Test locating elements using description with relation.""" - locator = Description("Sign in button").below_of( + locator = Prompt("Sign in button").below_of( Element("textfield").below_of(Text("Password")) ) x, y = vision_agent.locate( diff --git a/tests/unit/locators/serializers/test_askui_locator_serializer.py 
b/tests/unit/locators/serializers/test_askui_locator_serializer.py index cc6f6f23..9541398f 100644 --- a/tests/unit/locators/serializers/test_askui_locator_serializer.py +++ b/tests/unit/locators/serializers/test_askui_locator_serializer.py @@ -5,7 +5,7 @@ from pytest_mock import MockerFixture from askui.locators.locators import Locator -from askui.locators import Element, Description, Text, Image +from askui.locators import Element, Prompt, Text, Image from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import AiElementCollection from askui.utils.image_utils import image_to_base64 @@ -66,7 +66,7 @@ def test_serialize_class_no_name(askui_serializer: AskUiLocatorSerializer) -> No def test_serialize_description(askui_serializer: AskUiLocatorSerializer) -> None: - desc = Description("a big red button") + desc = Prompt("a big red button") result = askui_serializer.serialize(desc) assert result["instruction"] == "pta <|string|>a big red button<|string|>" assert result["customElements"] == [] diff --git a/tests/unit/locators/serializers/test_locator_string_representation.py b/tests/unit/locators/serializers/test_locator_string_representation.py index 6bc026f2..b43433aa 100644 --- a/tests/unit/locators/serializers/test_locator_string_representation.py +++ b/tests/unit/locators/serializers/test_locator_string_representation.py @@ -1,6 +1,6 @@ import re import pytest -from askui.locators import Element, Description, Text, Image +from askui.locators import Element, Prompt, Text, Image from askui.locators.relatable import CircularDependencyError from PIL import Image as PILImage @@ -39,7 +39,7 @@ def test_class_without_name_str() -> None: def test_description_str() -> None: - desc = Description("a big red button") + desc = Prompt("a big red button") assert str(desc) == 'element with description "a big red button"' @@ -153,8 +153,8 @@ def test_mixed_locator_types_with_relations_str() -> None: def 
test_description_with_relation_str() -> None: - desc = Description("button") - desc.above_of(Description("input")) + desc = Prompt("button") + desc.above_of(Prompt("input")) assert ( str(desc) == 'element with description "button"\n 1. above of boundary of the 1st element with description "input"' @@ -167,7 +167,7 @@ def test_complex_relation_chain_str() -> None: Element("textfield") .right_of(Text("world", match_type="exact")) .and_( - Description("input") + Prompt("input") .below_of(Text("earth", match_type="contains")) .nearest_to(Element("textfield")) ) diff --git a/tests/unit/locators/serializers/test_vlm_locator_serializer.py b/tests/unit/locators/serializers/test_vlm_locator_serializer.py index 00ec5425..86e70c1d 100644 --- a/tests/unit/locators/serializers/test_vlm_locator_serializer.py +++ b/tests/unit/locators/serializers/test_vlm_locator_serializer.py @@ -1,6 +1,6 @@ import pytest from askui.locators.locators import Locator -from askui.locators import Element, Description, Text +from askui.locators import Element, Prompt, Text from askui.locators.locators import Image from askui.locators.relatable import CircularDependencyError from askui.locators.serializers import VlmLocatorSerializer @@ -53,7 +53,7 @@ def test_serialize_class_no_name(vlm_serializer: VlmLocatorSerializer) -> None: def test_serialize_description(vlm_serializer: VlmLocatorSerializer) -> None: - desc = Description("a big red button") + desc = Prompt("a big red button") result = vlm_serializer.serialize(desc) assert result == "a big red button" diff --git a/tests/unit/locators/test_locators.py b/tests/unit/locators/test_locators.py index 86305228..60c65571 100644 --- a/tests/unit/locators/test_locators.py +++ b/tests/unit/locators/test_locators.py @@ -3,7 +3,7 @@ import pytest from PIL import Image as PILImage -from askui.locators import Description, Element, Text, Image, AiElement +from askui.locators import Prompt, Element, Text, Image, AiElement TEST_IMAGE_PATH = 
Path("tests/fixtures/images/github_com__icon.png") @@ -11,24 +11,24 @@ class TestDescriptionLocator: def test_initialization_with_description(self) -> None: - desc = Description(description="test") - assert desc.description == "test" + desc = Prompt(prompt="test") + assert desc.prompt == "test" assert str(desc) == 'element with description "test"' def test_initialization_without_description_raises(self) -> None: with pytest.raises(ValueError): - Description() # type: ignore + Prompt() # type: ignore def test_initialization_with_positional_arg(self) -> None: - desc = Description("test") - assert desc.description == "test" + desc = Prompt("test") + assert desc.prompt == "test" def test_initialization_with_invalid_args_raises(self) -> None: with pytest.raises(ValueError): - Description(description=123) # type: ignore + Prompt(prompt=123) # type: ignore with pytest.raises(ValueError): - Description(123) # type: ignore + Prompt(123) # type: ignore class TestClassLocator: From 5ff4774c42b4c5b988657f0194d283b197889552 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 22 Apr 2025 22:25:09 +0200 Subject: [PATCH 12/14] feat(locators): change default reference point to center for right_of and left_of relations --- src/askui/locators/relatable.py | 8 ++++---- .../serializers/test_askui_locator_serializer.py | 4 ++-- .../test_locator_string_representation.py | 12 ++++++------ tests/unit/locators/test_locators.py | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/askui/locators/relatable.py b/src/askui/locators/relatable.py index c3ef846e..ec10b1ab 100644 --- a/src/askui/locators/relatable.py +++ b/src/askui/locators/relatable.py @@ -411,7 +411,7 @@ def right_of( self, other_locator: "Relatable", index: RelationIndex = 0, - reference_point: ReferencePoint = "boundary", + reference_point: ReferencePoint = "center", ) -> Self: """Defines the element (located by *self*) to be **right of** another element / other elements (located by 
*other_locator*). @@ -449,7 +449,7 @@ def right_of( be **right of** a point (in a straight horizontal line) of the other element(s) (located by *other_locator*). - *Default is **"boundary".*** + *Default is **"center".*** Returns: Self: The locator with the relation added. @@ -550,7 +550,7 @@ def left_of( self, other_locator: "Relatable", index: RelationIndex = 0, - reference_point: ReferencePoint = "boundary", + reference_point: ReferencePoint = "center", ) -> Self: """Defines the element (located by *self*) to be **left of** another element / other elements (located by *other_locator*). @@ -588,7 +588,7 @@ def left_of( be **left of** a point (in a straight horizontal line) of the other element(s) (located by *other_locator*). - *Default is **"boundary".*** + *Default is **"center".*** Returns: Self: The locator with the relation added. diff --git a/tests/unit/locators/serializers/test_askui_locator_serializer.py b/tests/unit/locators/serializers/test_askui_locator_serializer.py index 9541398f..bae48adb 100644 --- a/tests/unit/locators/serializers/test_askui_locator_serializer.py +++ b/tests/unit/locators/serializers/test_askui_locator_serializer.py @@ -143,7 +143,7 @@ def test_serialize_right_relation(askui_serializer: AskUiLocatorSerializer) -> N result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text <|string|>hello<|string|> index 0 right of intersection_area element_edge_area text <|string|>world<|string|>" + == "text <|string|>hello<|string|> index 0 right of intersection_area element_center_line text <|string|>world<|string|>" ) assert result["customElements"] == [] @@ -154,7 +154,7 @@ def test_serialize_left_relation(askui_serializer: AskUiLocatorSerializer) -> No result = askui_serializer.serialize(text) assert ( result["instruction"] - == "text <|string|>hello<|string|> index 0 left of intersection_area element_edge_area text <|string|>world<|string|>" + == "text <|string|>hello<|string|> index 0 left of intersection_area 
element_center_line text <|string|>world<|string|>" ) assert result["customElements"] == [] diff --git a/tests/unit/locators/serializers/test_locator_string_representation.py b/tests/unit/locators/serializers/test_locator_string_representation.py index b43433aa..84ddbb28 100644 --- a/tests/unit/locators/serializers/test_locator_string_representation.py +++ b/tests/unit/locators/serializers/test_locator_string_representation.py @@ -40,7 +40,7 @@ def test_class_without_name_str() -> None: def test_description_str() -> None: desc = Prompt("a big red button") - assert str(desc) == 'element with description "a big red button"' + assert str(desc) == 'element with prompt "a big red button"' def test_text_with_above_relation_str() -> None: @@ -66,7 +66,7 @@ def test_text_with_right_relation_str() -> None: text.right_of(Text("world")) assert ( str(text) - == 'text similar to "hello" (similarity >= 70%)\n 1. right of boundary of the 1st text similar to "world" (similarity >= 70%)' + == 'text similar to "hello" (similarity >= 70%)\n 1. right of center of the 1st text similar to "world" (similarity >= 70%)' ) @@ -75,7 +75,7 @@ def test_text_with_left_relation_str() -> None: text.left_of(Text("world")) assert ( str(text) - == 'text similar to "hello" (similarity >= 70%)\n 1. left of boundary of the 1st text similar to "world" (similarity >= 70%)' + == 'text similar to "hello" (similarity >= 70%)\n 1. left of center of the 1st text similar to "world" (similarity >= 70%)' ) @@ -157,7 +157,7 @@ def test_description_with_relation_str() -> None: desc.above_of(Prompt("input")) assert ( str(desc) - == 'element with description "button"\n 1. above of boundary of the 1st element with description "input"' + == 'element with prompt "button"\n 1. above of boundary of the 1st element with prompt "input"' ) @@ -174,7 +174,7 @@ def test_complex_relation_chain_str() -> None: ) assert ( str(text) - == 'text similar to "hello" (similarity >= 70%)\n 1. 
above of boundary of the 1st element with class "textfield"\n 1. right of boundary of the 1st text "world"\n 2. and element with description "input"\n 1. below of boundary of the 1st text containing text "earth"\n 2. nearest to element with class "textfield"' + == 'text similar to "hello" (similarity >= 70%)\n 1. above of boundary of the 1st element with class "textfield"\n 1. right of center of the 1st text "world"\n 2. and element with prompt "input"\n 1. below of boundary of the 1st text containing text "earth"\n 2. nearest to element with class "textfield"' ) @@ -231,7 +231,7 @@ def test_multiple_references_no_cycle_str() -> None: textfield = Element("textfield") textfield.right_of(heading) textfield.below_of(heading) - assert str(textfield) == 'element with class "textfield"\n 1. right of boundary of the 1st text similar to "heading" (similarity >= 70%)\n 2. below of boundary of the 1st text similar to "heading" (similarity >= 70%)' + assert str(textfield) == 'element with class "textfield"\n 1. right of center of the 1st text similar to "heading" (similarity >= 70%)\n 2. 
below of boundary of the 1st text similar to "heading" (similarity >= 70%)' def test_image_cycle_str() -> None: diff --git a/tests/unit/locators/test_locators.py b/tests/unit/locators/test_locators.py index 60c65571..9a32f1ef 100644 --- a/tests/unit/locators/test_locators.py +++ b/tests/unit/locators/test_locators.py @@ -13,7 +13,7 @@ class TestDescriptionLocator: def test_initialization_with_description(self) -> None: desc = Prompt(prompt="test") assert desc.prompt == "test" - assert str(desc) == 'element with description "test"' + assert str(desc) == 'element with prompt "test"' def test_initialization_without_description_raises(self) -> None: with pytest.raises(ValueError): From a256e5e66fd7b3a64058aff56620caf1e68fc97b Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 22 Apr 2025 22:44:13 +0200 Subject: [PATCH 13/14] docs(locators): document all parameters --- src/askui/locators/locators.py | 167 +++++++++++++++++++++++++++++---- 1 file changed, 151 insertions(+), 16 deletions(-) diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py index 57b70f10..f8f1397c 100644 --- a/src/askui/locators/locators.py +++ b/src/askui/locators/locators.py @@ -19,7 +19,14 @@ class Prompt(Locator): """Locator for finding ui elements by a textual prompt / description of a ui element, e.g., "green sign up button".""" @validate_call - def __init__(self, prompt: str) -> None: + def __init__(self, prompt: Annotated[str, Field( + description="""A textual prompt / description of a ui element, e.g., "green sign up button".""" + )]) -> None: + """Initialize a Prompt locator. 
+ + Args: + prompt: A textual prompt / description of a ui element, e.g., "green sign up button" + """ super().__init__() self._prompt = prompt @@ -41,8 +48,15 @@ class Element(Locator): @validate_call def __init__( self, - class_name: Literal["text", "textfield"] | None = None, + class_name: Annotated[Literal["text", "textfield"] | None, Field( + description="""The class name of the ui element, e.g., 'text' or 'textfield'.""" + )] = None, ) -> None: + """Initialize an Element locator. + + Args: + class_name: The class name of the ui element, e.g., 'text' or 'textfield' + """ super().__init__() self._class_name = class_name @@ -73,10 +87,35 @@ class Text(Element): @validate_call def __init__( self, - text: str | None = None, - match_type: TextMatchType = DEFAULT_TEXT_MATCH_TYPE, - similarity_threshold: Annotated[int, Field(ge=0, le=100)] = DEFAULT_SIMILARITY_THRESHOLD, + text: Annotated[str | None, Field( + description="""The text content of the ui element, e.g., 'Sign up'.""" + )] = None, + match_type: Annotated[TextMatchType, Field( + description="""The type of match to use. Defaults to 'similar'. + 'similar' uses a similarity threshold to determine if the text is a match. + 'exact' requires the text to be exactly the same. + 'contains' requires the text to contain the specified text. + 'regex' uses a regular expression to match the text.""" + )] = DEFAULT_TEXT_MATCH_TYPE, + similarity_threshold: Annotated[int, Field( + ge=0, + le=100, + description="""A threshold for how similar the text + needs to be to the text content of the ui element to be considered a match. + Takes values between 0 and 100 (higher is more similar). Defaults to 70. + Only used if match_type is 'similar'.""")] = DEFAULT_SIMILARITY_THRESHOLD, ) -> None: + """Initialize a Text locator. + + Args: + text: The text content of the ui element, e.g., 'Sign up' + match_type: The type of match to use. Defaults to 'similar'. 'similar' uses a similarity threshold to + determine if the text is a match. 
'exact' requires the text to be exactly the same. 'contains' + requires the text to contain the specified text. 'regex' uses a regular expression to match the text. + similarity_threshold: A threshold for how similar the text needs to be to the text content of the ui + element to be considered a match. Takes values between 0 and 100 (higher is more similar). + Defaults to 70. Only used if match_type is 'similar'. + """ super().__init__() self._text = text self._match_type = match_type @@ -159,7 +198,7 @@ def image_compare_format(self) -> Literal["RGB", "grayscale", "edges"]: def _generate_name() -> str: - return f"anonymous custom element {uuid.uuid4()}" + return f"anonymous image {uuid.uuid4()}" class Image(ImageBase): @@ -168,13 +207,61 @@ class Image(ImageBase): def __init__( self, image: Union[PILImage.Image, pathlib.Path, str], - threshold: Annotated[float, Field(ge=0, le=1)] = 0.5, - stop_threshold: Annotated[float, Field(ge=0, le=1)] = 0.9, - mask: Annotated[list[tuple[float, float]] | None, Field(min_length=3)] = None, - rotation_degree_per_step: Annotated[int, Field(ge=0, lt=360)] = 0, + threshold: Annotated[float, Field( + ge=0, + le=1, + description="""A threshold for how similar UI elements need to be to the image to be considered a match. + Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly + like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality.""" + )] = 0.5, + stop_threshold: Annotated[float, Field( + ge=0, + le=1, + description="""A threshold for when to stop searching for UI elements similar to the image. As soon + as UI elements have been found that are at least as similar as the stop_threshold, the search stops. Should + be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to 0.9. 
Important: The + stop_threshold impacts the prediction speed.""" + )] = 0.9, + mask: Annotated[list[tuple[float, float]] | None, Field( + min_length=3, + description="A polygon to match only a certain area of the image." + )] = None, + rotation_degree_per_step: Annotated[int, Field( + ge=0, + lt=360, + description="""A step size in rotation degree. Rotates the image by rotation_degree_per_step until + 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the prediction time + quite a bit. So only use it when absolutely necessary.""" + )] = 0, name: str | None = None, - image_compare_format: Literal["RGB", "grayscale", "edges"] = "grayscale", + image_compare_format: Annotated[Literal["RGB", "grayscale", "edges"], Field( + description="""A color compare style. Defaults to 'grayscale'. + Important: The image_compare_format impacts the prediction time as well as quality. As a rule of thumb, + 'edges' is likely to be faster than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For + quality it is most often the other way around.""" + )] = "grayscale", ) -> None: + """Initialize an Image locator. + + Args: + image: The image to match against (PIL Image, path, or string) + threshold: A threshold for how similar UI elements need to be to the image to be considered a match. + Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly + like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality. + stop_threshold: A threshold for when to stop searching for UI elements similar to the image. As soon + as UI elements have been found that are at least as similar as the stop_threshold, the search stops. + Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to 0.9. + Important: The stop_threshold impacts the prediction speed. + mask: A polygon to match only a certain area of the image. Must have at least 3 points. 
+ rotation_degree_per_step: A step size in rotation degree. Rotates the image by rotation_degree_per_step + until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the + prediction time quite a bit. So only use it when absolutely necessary. + name: Optional name for the image. Defaults to generated UUID. + image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format + impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster + than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often + the other way around. + """ super().__init__( threshold=threshold, stop_threshold=stop_threshold, @@ -204,12 +291,60 @@ class AiElement(ImageBase): def __init__( self, name: str, - threshold: Annotated[float, Field(ge=0, le=1)] = 0.5, - stop_threshold: Annotated[float, Field(ge=0, le=1)] = 0.9, - mask: Annotated[list[tuple[float, float]] | None, Field(min_length=3)] = None, - rotation_degree_per_step: Annotated[int, Field(ge=0, lt=360)] = 0, - image_compare_format: Literal["RGB", "grayscale", "edges"] = "grayscale", + threshold: Annotated[float, Field( + ge=0, + le=1, + description="""A threshold for how similar UI elements need to be to be considered a match. + Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match). + Defaults to 0.5. Important: The threshold impacts the prediction quality.""" + )] = 0.5, + stop_threshold: Annotated[float, Field( + ge=0, + le=1, + description="""A threshold for when to stop searching for UI elements. As soon + as UI elements have been found that are at least as similar as the stop_threshold, the search stops. + Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to 0.9. 
+ Important: The stop_threshold impacts the prediction speed.""" + )] = 0.9, + mask: Annotated[list[tuple[float, float]] | None, Field( + min_length=3, + description="A polygon to match only a certain area of the image of the element saved on disk." + )] = None, + rotation_degree_per_step: Annotated[int, Field( + ge=0, + lt=360, + description="""A step size in rotation degree. Rotates the image of the element saved on disk by + rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. + Important: This increases the prediction time quite a bit. So only use it when absolutely necessary.""" + )] = 0, + image_compare_format: Annotated[Literal["RGB", "grayscale", "edges"], Field( + description="""A color compare style. Defaults to 'grayscale'. + Important: The image_compare_format impacts the prediction time as well as quality. As a rule of thumb, + 'edges' is likely to be faster than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For + quality it is most often the other way around.""" + )] = "grayscale", ) -> None: + """Initialize an AiElement locator. + + Args: + name: Name of the AI element + threshold: A threshold for how similar UI elements need to be to be considered a match. Takes values + between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match). + Defaults to 0.5. Important: The threshold impacts the prediction quality. + stop_threshold: A threshold for when to stop searching for UI elements. As soon as UI elements have + been found that are at least as similar as the stop_threshold, the search stops. Should be greater + than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to 0.9. Important: The + stop_threshold impacts the prediction speed. + mask: A polygon to match only a certain area of the image of the element saved on disk. Must have at + least 3 points. + rotation_degree_per_step: A step size in rotation degree. 
Rotates the image of the element saved on + disk by rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. + Important: This increases the prediction time quite a bit. So only use it when absolutely necessary. + image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format + impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster + than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often + the other way around. + """ super().__init__( name=name, threshold=threshold, From b7e9c576623ffb7a46562e11bdb9fe636e13bb49 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 23 Apr 2025 10:24:09 +0200 Subject: [PATCH 14/14] feat(reporting): add image of Image / AIElement to report --- src/askui/chat/__main__.py | 27 +- src/askui/locators/locators.py | 357 ++++++++++-------- src/askui/locators/relatable.py | 10 + src/askui/locators/serializers.py | 60 ++- src/askui/models/anthropic/claude.py | 2 +- src/askui/models/askui/ai_element_utils.py | 35 +- src/askui/models/router.py | 3 +- src/askui/reporting.py | 21 +- tests/e2e/agent/conftest.py | 4 +- tests/e2e/agent/test_locate.py | 15 +- .../test_askui_locator_serializer.py | 4 +- .../test_locator_string_representation.py | 6 +- tests/unit/locators/test_locators.py | 12 +- 13 files changed, 344 insertions(+), 212 deletions(-) diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 7eb98f7a..97cf18b2 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -70,7 +70,7 @@ def write_message( role: str, content: str | dict | list, timestamp: str, - image: Image.Image |str | None = None, + image: Image.Image | str | list[str | Image.Image] | list[str] | list[Image.Image] | None = None, ): _role = ROLE_MAP.get(role.lower(), UNKNOWN_ROLE) avatar = None if _role != UNKNOWN_ROLE else "❔" @@ -78,8 +78,13 @@ def write_message( 
st.markdown(f"*{timestamp}* - **{role}**\n\n") st.markdown(json.dumps(content, indent=2) if isinstance(content, (dict, list)) else content) if image: - img = get_image(image) if isinstance(image, str) else image - st.image(img) + if isinstance(image, list): + for img in image: + img = get_image(img) if isinstance(img, str) else img + st.image(img) + else: + img = get_image(image) if isinstance(image, str) else image + st.image(img) def save_image(image: Image.Image) -> str: @@ -93,7 +98,7 @@ class Message(TypedDict): role: str content: str | dict | list timestamp: str - image: str | None + image: str | list[str] | None class ChatHistoryAppender(Reporter): @@ -101,13 +106,21 @@ def __init__(self, session_id: str) -> None: self._session_id = session_id @override - def add_message(self, role: str, content: Union[str, dict, list], image: Image.Image | None = None) -> None: - image_path = save_image(image) if image else None + def add_message(self, role: str, content: Union[str, dict, list], image: Image.Image | list[Image.Image] | None = None) -> None: + image_paths: list[str] = [] + if image is None: + _images = [] + elif isinstance(image, list): + _images = image + else: + _images = [image] + for img in _images: + image_paths.append(save_image(img)) message = Message( role=role, content=content, timestamp=datetime.now().isoformat(), - image=image_path, + image=image_paths, ) write_message(**message) with open( diff --git a/src/askui/locators/locators.py b/src/askui/locators/locators.py index f8f1397c..24bc569a 100644 --- a/src/askui/locators/locators.py +++ b/src/askui/locators/locators.py @@ -12,6 +12,10 @@ class Locator(Relatable, ABC): """Base class for all locators.""" + + def _str(self) -> str: + return "locator" + pass @@ -19,41 +23,46 @@ class Prompt(Locator): """Locator for finding ui elements by a textual prompt / description of a ui element, e.g., "green sign up button".""" @validate_call - def __init__(self, prompt: Annotated[str, Field( - description="""A 
textual prompt / description of a ui element, e.g., "green sign up button".""" - )]) -> None: + def __init__( + self, + prompt: Annotated[ + str, + Field( + description="""A textual prompt / description of a ui element, e.g., "green sign up button".""" + ), + ], + ) -> None: """Initialize a Prompt locator. - + Args: prompt: A textual prompt / description of a ui element, e.g., "green sign up button" """ super().__init__() self._prompt = prompt - + @property def prompt(self) -> str: return self._prompt - - def _str_with_relation(self) -> str: - result = f'element with prompt "{self.prompt}"' - return result + super()._relations_str() - - def __str__(self) -> str: - self.raise_if_cycle() - return self._str_with_relation() + + def _str(self) -> str: + return f'element with prompt "{self.prompt}"' class Element(Locator): """Locator for finding ui elements by a class name assigned to the ui element, e.g., by a computer vision model.""" + @validate_call def __init__( self, - class_name: Annotated[Literal["text", "textfield"] | None, Field( - description="""The class name of the ui element, e.g., 'text' or 'textfield'.""" - )] = None, + class_name: Annotated[ + Literal["text", "textfield"] | None, + Field( + description="""The class name of the ui element, e.g., 'text' or 'textfield'.""" + ), + ] = None, ) -> None: """Initialize an Element locator. 
- + Args: class_name: The class name of the ui element, e.g., 'text' or 'textfield' """ @@ -64,17 +73,10 @@ def __init__( def class_name(self) -> Literal["text", "textfield"] | None: return self._class_name - def _str_with_relation(self) -> str: - result = ( - f'element with class "{self.class_name}"' - if self.class_name - else "element" + def _str(self) -> str: + return ( + f'element with class "{self.class_name}"' if self.class_name else "element" ) - return result + super()._relations_str() - - def __str__(self) -> str: - self.raise_if_cycle() - return self._str_with_relation() TextMatchType = Literal["similar", "exact", "contains", "regex"] @@ -84,36 +86,47 @@ def __str__(self) -> str: class Text(Element): """Locator for finding text elements by their content.""" + @validate_call def __init__( self, - text: Annotated[str | None, Field( - description="""The text content of the ui element, e.g., 'Sign up'.""" - )] = None, - match_type: Annotated[TextMatchType, Field( - description="""The type of match to use. Defaults to 'similar'. + text: Annotated[ + str | None, + Field( + description="""The text content of the ui element, e.g., 'Sign up'.""" + ), + ] = None, + match_type: Annotated[ + TextMatchType, + Field( + description="""The type of match to use. Defaults to 'similar'. 'similar' uses a similarity threshold to determine if the text is a match. 'exact' requires the text to be exactly the same. 'contains' requires the text to contain the specified text. 'regex' uses a regular expression to match the text.""" - )] = DEFAULT_TEXT_MATCH_TYPE, - similarity_threshold: Annotated[int, Field( - ge=0, - le=100, - description="""A threshold for how similar the text + ), + ] = DEFAULT_TEXT_MATCH_TYPE, + similarity_threshold: Annotated[ + int, + Field( + ge=0, + le=100, + description="""A threshold for how similar the text needs to be to the text content of the ui element to be considered a match. Takes values between 0 and 100 (higher is more similar). Defaults to 70. 
- Only used if match_type is 'similar'.""")] = DEFAULT_SIMILARITY_THRESHOLD, + Only used if match_type is 'similar'.""", + ), + ] = DEFAULT_SIMILARITY_THRESHOLD, ) -> None: """Initialize a Text locator. - + Args: text: The text content of the ui element, e.g., 'Sign up' - match_type: The type of match to use. Defaults to 'similar'. 'similar' uses a similarity threshold to - determine if the text is a match. 'exact' requires the text to be exactly the same. 'contains' + match_type: The type of match to use. Defaults to 'similar'. 'similar' uses a similarity threshold to + determine if the text is a match. 'exact' requires the text to be exactly the same. 'contains' requires the text to contain the specified text. 'regex' uses a regular expression to match the text. - similarity_threshold: A threshold for how similar the text needs to be to the text content of the ui - element to be considered a match. Takes values between 0 and 100 (higher is more similar). + similarity_threshold: A threshold for how similar the text needs to be to the text content of the ui + element to be considered a match. Takes values between 0 and 100 (higher is more similar). Defaults to 70. Only used if match_type is 'similar'. 
""" super().__init__() @@ -128,12 +141,12 @@ def text(self) -> str | None: @property def match_type(self) -> TextMatchType: return self._match_type - + @property def similarity_threshold(self) -> int: return self._similarity_threshold - def _str_with_relation(self) -> str: + def _str(self) -> str: if self.text is None: result = "text" else: @@ -147,11 +160,7 @@ def _str_with_relation(self) -> str: result += f'containing text "{self.text}"' case "regex": result += f'matching regex "{self.text}"' - return result + super()._relations_str() - - def __str__(self) -> str: - self.raise_if_cycle() - return self._str_with_relation() + return result class ImageBase(Locator, ABC): @@ -165,36 +174,59 @@ def __init__( image_compare_format: Literal["RGB", "grayscale", "edges"], ) -> None: super().__init__() + if threshold > stop_threshold: + raise ValueError( + f"threshold ({threshold}) must be less than or equal to stop_threshold ({stop_threshold})" + ) self._threshold = threshold self._stop_threshold = stop_threshold self._mask = mask self._rotation_degree_per_step = rotation_degree_per_step self._name = name self._image_compare_format = image_compare_format - + @property def threshold(self) -> float: return self._threshold - + @property def stop_threshold(self) -> float: return self._stop_threshold - + @property def mask(self) -> list[tuple[float, float]] | None: return self._mask - + @property def rotation_degree_per_step(self) -> int: return self._rotation_degree_per_step - + @property def name(self) -> str: return self._name - + @property def image_compare_format(self) -> Literal["RGB", "grayscale", "edges"]: return self._image_compare_format + + def _params_str(self) -> str: + return ( + "(" + + ", ".join([ + f"threshold: {self.threshold}", + f"stop_threshold: {self.stop_threshold}", + f"rotation_degree_per_step: {self.rotation_degree_per_step}", + f"image_compare_format: {self.image_compare_format}", + f"mask: {self.mask}" + ]) + + ")" + ) + + def _str(self) -> str: + 
return ( + f'element "{self.name}" located by image ' + + self._params_str() + ) def _generate_name() -> str: @@ -203,161 +235,184 @@ def _generate_name() -> str: class Image(ImageBase): """Locator for finding ui elements by an image.""" + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, image: Union[PILImage.Image, pathlib.Path, str], - threshold: Annotated[float, Field( - ge=0, - le=1, - description="""A threshold for how similar UI elements need to be to the image to be considered a match. + threshold: Annotated[ + float, + Field( + ge=0, + le=1, + description="""A threshold for how similar UI elements need to be to the image to be considered a match. Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly - like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality.""" - )] = 0.5, - stop_threshold: Annotated[float, Field( - ge=0, - le=1, - description="""A threshold for when to stop searching for UI elements similar to the image. As soon + like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality.""", + ), + ] = 0.5, + stop_threshold: Annotated[ + float | None, + Field( + ge=0, + le=1, + description="""A threshold for when to stop searching for UI elements similar to the image. As soon as UI elements have been found that are at least as similar as the stop_threshold, the search stops. Should - be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to 0.9. Important: The - stop_threshold impacts the prediction speed.""" - )] = 0.9, - mask: Annotated[list[tuple[float, float]] | None, Field( - min_length=3, - description="A polygon to match only a certain area of the image." - )] = None, - rotation_degree_per_step: Annotated[int, Field( - ge=0, - lt=360, - description="""A step size in rotation degree. Rotates the image by rotation_degree_per_step until + be greater than or equal to threshold. 
Takes values between 0.0 and 1.0. Defaults to value of `threshold` if + not provided. Important: The stop_threshold impacts the prediction speed.""", + ), + ] = None, + mask: Annotated[ + list[tuple[float, float]] | None, + Field( + min_length=3, + description="A polygon to match only a certain area of the image.", + ), + ] = None, + rotation_degree_per_step: Annotated[ + int, + Field( + ge=0, + lt=360, + description="""A step size in rotation degree. Rotates the image by rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the prediction time - quite a bit. So only use it when absolutely necessary.""" - )] = 0, + quite a bit. So only use it when absolutely necessary.""", + ), + ] = 0, name: str | None = None, - image_compare_format: Annotated[Literal["RGB", "grayscale", "edges"], Field( - description="""A color compare style. Defaults to 'grayscale'. + image_compare_format: Annotated[ + Literal["RGB", "grayscale", "edges"], + Field( + description="""A color compare style. Defaults to 'grayscale'. Important: The image_compare_format impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often the other way around.""" - )] = "grayscale", + ), + ] = "grayscale", ) -> None: """Initialize an Image locator. - + Args: image: The image to match against (PIL Image, path, or string) - threshold: A threshold for how similar UI elements need to be to the image to be considered a match. - Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly + threshold: A threshold for how similar UI elements need to be to the image to be considered a match. + Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to look exactly like defined). Defaults to 0.5. Important: The threshold impacts the prediction quality. 
- stop_threshold: A threshold for when to stop searching for UI elements similar to the image. As soon - as UI elements have been found that are at least as similar as the stop_threshold, the search stops. - Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to 0.9. - Important: The stop_threshold impacts the prediction speed. + stop_threshold: A threshold for when to stop searching for UI elements similar to the image. As soon + as UI elements have been found that are at least as similar as the stop_threshold, the search stops. + Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to value of + `threshold` if not provided. Important: The stop_threshold impacts the prediction speed. mask: A polygon to match only a certain area of the image. Must have at least 3 points. - rotation_degree_per_step: A step size in rotation degree. Rotates the image by rotation_degree_per_step - until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the + rotation_degree_per_step: A step size in rotation degree. Rotates the image by rotation_degree_per_step + until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the prediction time quite a bit. So only use it when absolutely necessary. name: Optional name for the image. Defaults to generated UUID. - image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format - impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster - than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often + image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format + impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster + than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. 
For quality it is most often the other way around. """ super().__init__( threshold=threshold, - stop_threshold=stop_threshold, + stop_threshold=stop_threshold or threshold, mask=mask, rotation_degree_per_step=rotation_degree_per_step, image_compare_format=image_compare_format, name=_generate_name() if name is None else name, ) # type: ignore self._image = ImageSource(image) - + @property def image(self) -> ImageSource: return self._image - def _str_with_relation(self) -> str: - result = f'element "{self.name}" located by image' - return result + super()._relations_str() - - def __str__(self) -> str: - self.raise_if_cycle() - return self._str_with_relation() - class AiElement(ImageBase): """Locator for finding ui elements by an image and other kinds data saved on the disk.""" + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def __init__( self, name: str, - threshold: Annotated[float, Field( - ge=0, - le=1, - description="""A threshold for how similar UI elements need to be to be considered a match. + threshold: Annotated[ + float, + Field( + ge=0, + le=1, + description="""A threshold for how similar UI elements need to be to be considered a match. Takes values between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match). - Defaults to 0.5. Important: The threshold impacts the prediction quality.""" - )] = 0.5, - stop_threshold: Annotated[float, Field( - ge=0, - le=1, - description="""A threshold for when to stop searching for UI elements. As soon + Defaults to 0.5. Important: The threshold impacts the prediction quality.""", + ), + ] = 0.5, + stop_threshold: Annotated[ + float | None, + Field( + ge=0, + le=1, + description="""A threshold for when to stop searching for UI elements. As soon as UI elements have been found that are at least as similar as the stop_threshold, the search stops. - Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to 0.9. 
- Important: The stop_threshold impacts the prediction speed.""" - )] = 0.9, - mask: Annotated[list[tuple[float, float]] | None, Field( - min_length=3, - description="A polygon to match only a certain area of the image of the element saved on disk." - )] = None, - rotation_degree_per_step: Annotated[int, Field( - ge=0, - lt=360, - description="""A step size in rotation degree. Rotates the image of the element saved on disk by + Should be greater than or equal to threshold. Takes values between 0.0 and 1.0. + Defaults to value of `threshold` if not provided. + Important: The stop_threshold impacts the prediction speed.""", + ), + ] = None, + mask: Annotated[ + list[tuple[float, float]] | None, + Field( + min_length=3, + description="A polygon to match only a certain area of the image of the element saved on disk.", + ), + ] = None, + rotation_degree_per_step: Annotated[ + int, + Field( + ge=0, + lt=360, + description="""A step size in rotation degree. Rotates the image of the element saved on disk by rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. - Important: This increases the prediction time quite a bit. So only use it when absolutely necessary.""" - )] = 0, - image_compare_format: Annotated[Literal["RGB", "grayscale", "edges"], Field( - description="""A color compare style. Defaults to 'grayscale'. + Important: This increases the prediction time quite a bit. So only use it when absolutely necessary.""", + ), + ] = 0, + image_compare_format: Annotated[ + Literal["RGB", "grayscale", "edges"], + Field( + description="""A color compare style. Defaults to 'grayscale'. Important: The image_compare_format impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often the other way around.""" - )] = "grayscale", + ), + ] = "grayscale", ) -> None: """Initialize an AiElement locator. 
- + Args: name: Name of the AI element - threshold: A threshold for how similar UI elements need to be to be considered a match. Takes values - between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match). + threshold: A threshold for how similar UI elements need to be to be considered a match. Takes values + between 0.0 (= all elements are recognized) and 1.0 (= elements need to be an exact match). Defaults to 0.5. Important: The threshold impacts the prediction quality. - stop_threshold: A threshold for when to stop searching for UI elements. As soon as UI elements have - been found that are at least as similar as the stop_threshold, the search stops. Should be greater - than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to 0.9. Important: The - stop_threshold impacts the prediction speed. - mask: A polygon to match only a certain area of the image of the element saved on disk. Must have at + stop_threshold: A threshold for when to stop searching for UI elements. As soon as UI elements have + been found that are at least as similar as the stop_threshold, the search stops. Should be greater + than or equal to threshold. Takes values between 0.0 and 1.0. Defaults to value of `threshold` if not + provided. Important: The stop_threshold impacts the prediction speed. + mask: A polygon to match only a certain area of the image of the element saved on disk. Must have at least 3 points. - rotation_degree_per_step: A step size in rotation degree. Rotates the image of the element saved on - disk by rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. + rotation_degree_per_step: A step size in rotation degree. Rotates the image of the element saved on + disk by rotation_degree_per_step until 360° is exceeded. Range is between 0° - 360°. Defaults to 0°. Important: This increases the prediction time quite a bit. So only use it when absolutely necessary. 
- image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format - impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster - than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often + image_compare_format: A color compare style. Defaults to 'grayscale'. Important: The image_compare_format + impacts the prediction time as well as quality. As a rule of thumb, 'edges' is likely to be faster + than 'grayscale' and 'grayscale' is likely to be faster than 'RGB'. For quality it is most often the other way around. """ super().__init__( name=name, threshold=threshold, - stop_threshold=stop_threshold, + stop_threshold=stop_threshold or threshold, mask=mask, rotation_degree_per_step=rotation_degree_per_step, image_compare_format=image_compare_format, ) # type: ignore - def _str_with_relation(self) -> str: - result = f'ai element named "{self.name}"' - return result + super()._relations_str() - - def __str__(self) -> str: - self.raise_if_cycle() - return self._str_with_relation() + def _str(self) -> str: + return ( + f'ai element named "{self.name}" ' + + self._params_str() + ) diff --git a/src/askui/locators/relatable.py b/src/askui/locators/relatable.py index ec10b1ab..1cb4df19 100644 --- a/src/askui/locators/relatable.py +++ b/src/askui/locators/relatable.py @@ -860,6 +860,9 @@ def or_(self, other_locator: "Relatable") -> Self: ) return self + def _str(self) -> str: + return "relatable" + def _relations_str(self) -> str: if not self._relations: return "" @@ -871,6 +874,9 @@ def _relations_str(self) -> str: for nested_relation_str in nested_relation_strs: result.append(f" {nested_relation_str}") return "\n" + "\n".join(result) + + def _str_with_relation(self) -> str: + return self._str() + self._relations_str() def raise_if_cycle(self) -> None: """Raises CircularDependencyError if the relations form a cycle (see [Cycle (graph 
theory)](https://en.wikipedia.org/wiki/Cycle_(graph_theory))).""" @@ -900,3 +906,7 @@ def _dfs(node: Relatable) -> bool: return False return _dfs(self) + + def __str__(self) -> str: + self.raise_if_cycle() + return self._str_with_relation() diff --git a/src/askui/locators/serializers.py b/src/askui/locators/serializers.py index 9b0ce33a..35e1f180 100644 --- a/src/askui/locators/serializers.py +++ b/src/askui/locators/serializers.py @@ -1,7 +1,8 @@ from typing_extensions import NotRequired, TypedDict +from askui.reporting import Reporter from askui.utils.image_utils import ImageSource -from askui.models.askui.ai_element_utils import AiElementCollection, AiElementNotFound +from askui.models.askui.ai_element_utils import AiElementCollection from .locators import ( DEFAULT_SIMILARITY_THRESHOLD, DEFAULT_TEXT_MATCH_TYPE, @@ -41,6 +42,10 @@ def serialize(self, locator: Relatable) -> str: raise NotImplementedError( "Serializing image locators is not yet supported for VLMs" ) + elif isinstance(locator, AiElementLocator): + raise NotImplementedError( + "Serializing AI element locators is not yet supported for VLMs" + ) else: raise ValueError(f"Unsupported locator type: {type(locator)}") @@ -94,8 +99,9 @@ class AskUiLocatorSerializer: "or": "or", } - def __init__(self, ai_element_collection: AiElementCollection): + def __init__(self, ai_element_collection: AiElementCollection, reporter: Reporter): self._ai_element_collection = ai_element_collection + self._reporter = reporter def serialize(self, locator: Relatable) -> AskUiSerializedLocator: locator.raise_if_cycle() @@ -113,10 +119,7 @@ def serialize(self, locator: Relatable) -> AskUiSerializedLocator: elif isinstance(locator, Prompt): result["instruction"] = self._serialize_prompt(locator) elif isinstance(locator, Image): - result = self._serialize_image( - image_locator=locator, - image_sources=[locator.image], - ) + result = self._serialize_image(locator) elif isinstance(locator, AiElementLocator): result = 
self._serialize_ai_element(locator) else: @@ -134,16 +137,19 @@ def _serialize_class(self, class_: Element) -> str: return class_.class_name or "element" def _serialize_prompt(self, prompt: Prompt) -> str: - return ( - f"pta {self._TEXT_DELIMITER}{prompt.prompt}{self._TEXT_DELIMITER}" - ) + return f"pta {self._TEXT_DELIMITER}{prompt.prompt}{self._TEXT_DELIMITER}" def _serialize_text(self, text: Text) -> str: match text.match_type: case "similar": - if text.similarity_threshold == DEFAULT_SIMILARITY_THRESHOLD and text.match_type == DEFAULT_TEXT_MATCH_TYPE: + if ( + text.similarity_threshold == DEFAULT_SIMILARITY_THRESHOLD + and text.match_type == DEFAULT_TEXT_MATCH_TYPE + ): # Necessary so that we can use wordlevel ocr for these texts - return f"text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" + return ( + f"text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" + ) return f"text with text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER} that matches to {text.similarity_threshold} %" case "exact": return f"text equals text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}" @@ -202,7 +208,7 @@ def _serialize_image_to_custom_element( custom_element["mask"] = image_locator.mask return custom_element - def _serialize_image( + def _serialize_image_base( self, image_locator: ImageBase, image_sources: list[ImageSource], @@ -218,16 +224,34 @@ def _serialize_image( instruction=f"custom element with text {self._TEXT_DELIMITER}{image_locator.name}{self._TEXT_DELIMITER}", customElements=custom_elements, ) + + def _serialize_image( + self, + image: Image, + ) -> AskUiSerializedLocator: + self._reporter.add_message( + "AskUiLocatorSerializer", + f"Image locator: {image}", + image=image.image.root, + ) + return self._serialize_image_base( + image_locator=image, + image_sources=[image.image], + ) def _serialize_ai_element( self, ai_element_locator: AiElementLocator ) -> AskUiSerializedLocator: ai_elements = 
self._ai_element_collection.find(ai_element_locator.name) - if len(ai_elements) == 0: - raise AiElementNotFound( - f"Could not find AI element with name \"{ai_element_locator.name}\"" - ) - return self._serialize_image( + self._reporter.add_message( + "AskUiLocatorSerializer", + f"Found {len(ai_elements)} ai elements named {ai_element_locator.name}", + image=[ai_element.image for ai_element in ai_elements], + ) + return self._serialize_image_base( image_locator=ai_element_locator, - image_sources=[ImageSource.model_construct(root=ai_element.image) for ai_element in ai_elements], + image_sources=[ + ImageSource.model_construct(root=ai_element.image) + for ai_element in ai_elements + ], ) diff --git a/src/askui/models/anthropic/claude.py b/src/askui/models/anthropic/claude.py index 8965a5e5..4d54f8e8 100644 --- a/src/askui/models/anthropic/claude.py +++ b/src/askui/models/anthropic/claude.py @@ -57,7 +57,7 @@ def locate_inference(self, image: Image.Image, locator: str) -> tuple[int, int]: try: scaled_x, scaled_y = extract_click_coordinates(response) except Exception as e: - raise ElementNotFoundError(f"Couldn't locate {locator} on the screen.") + raise ElementNotFoundError(f"Element not found: {locator}") x, y = scale_coordinates_back(scaled_x, scaled_y, image.width, image.height, screen_width, screen_height) return int(x), int(y) diff --git a/src/askui/models/askui/ai_element_utils.py b/src/askui/models/askui/ai_element_utils.py index b977de33..c8f3ad4b 100644 --- a/src/askui/models/askui/ai_element_utils.py +++ b/src/askui/models/askui/ai_element_utils.py @@ -61,38 +61,49 @@ def from_json_file(cls, json_file_path: pathlib.Path) -> "AiElement": image = Image.open(image_path)) -class AiElementNotFound(Exception): - pass +class AiElementNotFound(ValueError): + def __init__(self, name: str, locations: list[pathlib.Path]): + self.name = name + self.locations = locations + locations_str = ", ".join([str(location) for location in locations]) + super().__init__( + f'AI 
element "{name}" not found in {locations_str}\n' + 'Solutions:\n' + '1. Verify the element exists in these locations and try again if you are sure it is present\n' + '2. Add location to ASKUI_AI_ELEMENT_LOCATIONS env var (paths, comma separated)\n' + '3. Create new AI element (see https://docs.askui.com/02-api-reference/02-askui-suite/02-askui-suite/AskUIRemoteDeviceSnippingTool/Public/AskUI-NewAIElement)' + ) class AiElementCollection: def __init__(self, additional_ai_element_locations: Optional[List[pathlib.Path]] = None): + additional_ai_element_locations = additional_ai_element_locations or [] + workspace_id = os.getenv("ASKUI_WORKSPACE_ID") if workspace_id is None: raise ValueError("ASKUI_WORKSPACE_ID is not set") - if additional_ai_element_locations is None: - additional_ai_element_locations = [] - - addional_ai_element_from_env = [] - if os.getenv("ASKUI_AI_ELEMENT_LOCATIONS", "") != "": - addional_ai_element_from_env = [pathlib.Path(ai_element_loc) for ai_element_loc in os.getenv("ASKUI_AI_ELEMENT_LOCATIONS", "").split(",")], + locations_from_env: list[pathlib.Path] = [] + if locations_env := os.getenv("ASKUI_AI_ELEMENT_LOCATIONS"): + locations_from_env = [pathlib.Path(loc) for loc in locations_env.split(",")] - self.ai_element_locations = [ + self._ai_element_locations = [ pathlib.Path.home() / ".askui" / "SnippingTool" / "AIElement" / workspace_id, - *addional_ai_element_from_env, + *locations_from_env, *additional_ai_element_locations ] - logger.debug("AI Element locations: %s", self.ai_element_locations) + logger.debug("AI Element locations: %s", self._ai_element_locations) def find(self, name: str) -> list[AiElement]: ai_elements: list[AiElement] = [] - for location in self.ai_element_locations: + for location in self._ai_element_locations: path = pathlib.Path(location) json_files = list(path.glob("*.json")) for json_file in json_files: ai_element = AiElement.from_json_file(json_file) if ai_element.metadata.name == name: ai_elements.append(ai_element) 
+ if len(ai_elements) == 0: + raise AiElementNotFound(name=name, locations=self._ai_element_locations) return ai_elements diff --git a/src/askui/models/router.py b/src/askui/models/router.py index 9e87c22c..7f3395cb 100644 --- a/src/askui/models/router.py +++ b/src/askui/models/router.py @@ -27,7 +27,7 @@ def handle_response(response: tuple[int | None, int | None], locator: str | Locator): if response[0] is None or response[1] is None: - raise ElementNotFoundError(f"Could not locate\n{locator}") + raise ElementNotFoundError(f"Element not found: {locator}") return response @@ -121,6 +121,7 @@ def __init__( self._askui = AskUiInferenceApi( locator_serializer=AskUiLocatorSerializer( ai_element_collection=AiElementCollection(), + reporter=_reporter, ), ) self._grounding_model_routers = grounding_model_routers or [AskUiModelRouter(inference_api=self._askui)] diff --git a/src/askui/reporting.py b/src/askui/reporting.py index 8c6e36f2..c274fc80 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -20,7 +20,7 @@ def add_message( self, role: str, content: Union[str, dict, list], - image: Optional[Image.Image] = None, + image: Optional[Image.Image | list[Image.Image]] = None, ) -> None: raise NotImplementedError() @@ -38,7 +38,7 @@ def add_message( self, role: str, content: Union[str, dict, list], - image: Optional[Image.Image] = None, + image: Optional[Image.Image | list[Image.Image]] = None, ) -> None: for report in self._reports: report.add_message(role, content, image) @@ -83,15 +83,22 @@ def add_message( self, role: str, content: Union[str, dict, list], - image: Optional[Image.Image] = None, + image: Optional[Image.Image | list[Image.Image]] = None, ) -> None: """Add a message to the report, optionally with an image""" + if image is None: + _images = [] + elif isinstance(image, list): + _images = image + else: + _images = [image] + message = { "timestamp": datetime.now(), "role": role, "content": self._format_content(content), "is_json": 
isinstance(content, (dict, list)), - "image": self._image_to_base64(image) if image else None, + "images": [self._image_to_base64(img) for img in _images], } self.messages.append(message) @@ -233,12 +240,12 @@ def generate(self) -> None: {% else %} {{ msg.content }} {% endif %} - {% if msg.image %} + {% for image in msg.images %}
- Message image - {% endif %} + {% endfor %} {% endfor %} diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index 71dd2c81..ba8b859d 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -33,9 +33,9 @@ def vision_agent( ai_element_collection = AiElementCollection( additional_ai_element_locations=[path_fixtures / "images"] ) - serializer = AskUiLocatorSerializer(ai_element_collection=ai_element_collection) - inference_api = AskUiInferenceApi(locator_serializer=serializer) reporter = SimpleHtmlReporter() + serializer = AskUiLocatorSerializer(ai_element_collection=ai_element_collection, reporter=reporter) + inference_api = AskUiInferenceApi(locator_serializer=serializer) model_router = ModelRouter( tools=agent_toolbox_mock, reporter=reporter, diff --git a/tests/e2e/agent/test_locate.py b/tests/e2e/agent/test_locate.py index f7cb49e1..0cf0d524 100644 --- a/tests/e2e/agent/test_locate.py +++ b/tests/e2e/agent/test_locate.py @@ -16,7 +16,6 @@ from askui.models import ModelName -@pytest.mark.skip("Skipping tests for now") @pytest.mark.parametrize( "model", [ @@ -161,6 +160,8 @@ def test_locate_with_image( path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator.""" + if model in [ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022]: + pytest.skip("Skipping test for Anthropic model because not supported yet") image_path = path_fixtures / "images" / "github_com__signin__button.png" image = PILImage.open(image_path) locator = Image(image=image) @@ -178,6 +179,8 @@ def test_locate_with_image_and_custom_params( path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator with custom parameters.""" + if model in [ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022]: + pytest.skip("Skipping test for Anthropic model because not supported yet") image_path = path_fixtures / "images" / "github_com__signin__button.png" image = PILImage.open(image_path) locator = Image( @@ -202,6 +205,8 @@ 
def test_locate_with_image_should_fail_when_threshold_is_too_high( path_fixtures: pathlib.Path, ) -> None: """Test locating elements using image locator with custom parameters.""" + if model in [ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022]: + pytest.skip("Skipping test for Anthropic model because not supported yet") image_path = path_fixtures / "images" / "github_com__icon.png" image = PILImage.open(image_path) locator = Image( @@ -219,12 +224,14 @@ def test_locate_with_ai_element_locator( model: str, ) -> None: """Test locating elements using an AI element locator.""" + if model in [ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022]: + pytest.skip("Skipping test for Anthropic model because not supported yet") locator = AiElement("github_com__icon") x, y = vision_agent.locate( locator, github_login_screenshot, model=model ) assert 350 <= x <= 570 - assert 240 <= y <= 320 + assert 50 <= y <= 130 def test_locate_with_ai_element_locator_should_fail_when_threshold_is_too_high( self, @@ -233,6 +240,8 @@ def test_locate_with_ai_element_locator_should_fail_when_threshold_is_too_high( model: str, ) -> None: """Test locating elements using image locator with custom parameters.""" - locator = AiElement("github_com__icon") + if model in [ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022]: + pytest.skip("Skipping test for Anthropic model because not supported yet") + locator = AiElement("github_com__icon", threshold=1.0) with pytest.raises(ElementNotFoundError): vision_agent.locate(locator, github_login_screenshot, model=model) diff --git a/tests/unit/locators/serializers/test_askui_locator_serializer.py b/tests/unit/locators/serializers/test_askui_locator_serializer.py index bae48adb..e79eadfb 100644 --- a/tests/unit/locators/serializers/test_askui_locator_serializer.py +++ b/tests/unit/locators/serializers/test_askui_locator_serializer.py @@ -9,6 +9,7 @@ from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import 
AiElementCollection from askui.utils.image_utils import image_to_base64 +from askui.reporting import CompositeReporter from askui.locators.relatable import CircularDependencyError @@ -23,7 +24,8 @@ def askui_serializer(path_fixtures: pathlib.Path) -> AskUiLocatorSerializer: additional_ai_element_locations=[ path_fixtures / "images" ] - ) + ), + reporter=CompositeReporter() ) diff --git a/tests/unit/locators/serializers/test_locator_string_representation.py b/tests/unit/locators/serializers/test_locator_string_representation.py index 84ddbb28..6529714f 100644 --- a/tests/unit/locators/serializers/test_locator_string_representation.py +++ b/tests/unit/locators/serializers/test_locator_string_representation.py @@ -178,7 +178,7 @@ def test_complex_relation_chain_str() -> None: ) -IMAGE_STR_PATTERN = re.compile(r'^element ".*" located by image$') +IMAGE_STR_PATTERN = re.compile(r'^element ".*" located by image \(threshold: \d+\.\d+, stop_threshold: \d+\.\d+, rotation_degree_per_step: \d+, image_compare_format: \w+, mask: None\)$') def test_image_str() -> None: @@ -188,14 +188,14 @@ def test_image_str() -> None: def test_image_with_name_str() -> None: image = Image(TEST_IMAGE, name="test_image") - assert str(image) == 'element "test_image" located by image' + assert str(image) == 'element "test_image" located by image (threshold: 0.5, stop_threshold: 0.5, rotation_degree_per_step: 0, image_compare_format: grayscale, mask: None)' def test_image_with_relation_str() -> None: image = Image(TEST_IMAGE, name="image") image.above_of(Text("hello")) lines = str(image).split("\n") - assert lines[0] == 'element "image" located by image' + assert lines[0] == 'element "image" located by image (threshold: 0.5, stop_threshold: 0.5, rotation_degree_per_step: 0, image_compare_format: grayscale, mask: None)' assert lines[1] == ' 1. 
above of boundary of the 1st text similar to "hello" (similarity >= 70%)' diff --git a/tests/unit/locators/test_locators.py b/tests/unit/locators/test_locators.py index 9a32f1ef..1b60fd9f 100644 --- a/tests/unit/locators/test_locators.py +++ b/tests/unit/locators/test_locators.py @@ -111,13 +111,13 @@ class TestImageLocator: def test_image(self) -> PILImage.Image: return PILImage.open(TEST_IMAGE_PATH) - _STR_PATTERN = re.compile(r'^element ".*" located by image$') + _STR_PATTERN = re.compile(r'^element ".*" located by image \(threshold: \d+\.\d+, stop_threshold: \d+\.\d+, rotation_degree_per_step: \d+, image_compare_format: \w+, mask: None\)$') def test_initialization_with_basic_params(self, test_image: PILImage.Image) -> None: locator = Image(image=test_image) assert locator.image.root == test_image assert locator.threshold == 0.5 - assert locator.stop_threshold == 0.9 + assert locator.stop_threshold == 0.5 assert locator.mask is None assert locator.rotation_degree_per_step == 0 assert locator.image_compare_format == "grayscale" @@ -125,7 +125,7 @@ def test_initialization_with_basic_params(self, test_image: PILImage.Image) -> N def test_initialization_with_name(self, test_image: PILImage.Image) -> None: locator = Image(image=test_image, name="test") - assert str(locator) == 'element "test" located by image' + assert str(locator) == 'element "test" located by image (threshold: 0.5, stop_threshold: 0.5, rotation_degree_per_step: 0, image_compare_format: grayscale, mask: None)' def test_initialization_with_custom_params(self, test_image: PILImage.Image) -> None: locator = Image( @@ -141,7 +141,7 @@ def test_initialization_with_custom_params(self, test_image: PILImage.Image) -> assert locator.mask == [(0, 0), (1, 0), (1, 1)] assert locator.rotation_degree_per_step == 45 assert locator.image_compare_format == "RGB" - assert re.match(self._STR_PATTERN, str(locator)) + assert re.match(r'^element "anonymous image [a-f0-9-]+" located by image \(threshold: 0.7, 
stop_threshold: 0.95, rotation_degree_per_step: 45, image_compare_format: RGB, mask: \[\(0.0, 0.0\), \(1.0, 0.0\), \(1.0, 1.0\)\]\)$', str(locator)) def test_initialization_with_invalid_args(self, test_image: PILImage.Image) -> None: with pytest.raises(ValueError): @@ -176,7 +176,7 @@ class TestAiElementLocator: def test_initialization_with_name(self) -> None: locator = AiElement("github_com__icon") assert locator.name == "github_com__icon" - assert str(locator) == 'ai element named "github_com__icon"' + assert str(locator) == 'ai element named "github_com__icon" (threshold: 0.5, stop_threshold: 0.5, rotation_degree_per_step: 0, image_compare_format: grayscale, mask: None)' def test_initialization_without_name_raises(self) -> None: with pytest.raises(ValueError): @@ -201,7 +201,7 @@ def test_initialization_with_custom_params(self) -> None: assert locator.mask == [(0, 0), (1, 0), (1, 1)] assert locator.rotation_degree_per_step == 45 assert locator.image_compare_format == "RGB" - assert str(locator) == 'ai element named "test_element"' + assert str(locator) == 'ai element named "test_element" (threshold: 0.7, stop_threshold: 0.95, rotation_degree_per_step: 45, image_compare_format: RGB, mask: [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])' def test_initialization_with_invalid_threshold(self) -> None: with pytest.raises(ValueError):