diff --git a/README.md b/README.md index a57fc76a..c8c19368 100644 --- a/README.md +++ b/README.md @@ -367,12 +367,12 @@ class MyGetAndLocateModel(GetModel, LocateModel): locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: # Implement custom locate logic, e.g.: # - Use a different object detection model # - Implement custom element finding # - Call external vision services - return (100, 100) # Example coordinates + return [(100, 100)] # Example coordinates # Create model registry diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 9d821963..4d57743a 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -25,6 +25,7 @@ OnMessageCb, OnMessageCbParam, Point, + PointList, TextBlockParam, TextCitationParam, ToolResultBlockParam, @@ -82,6 +83,7 @@ "OnMessageCbParam", "PcKey", "Point", + "PointList", "ResponseSchema", "ResponseSchemaBase", "Retry", diff --git a/src/askui/agent.py b/src/askui/agent.py index c6bd2dac..fd53cd2e 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -182,7 +182,7 @@ def _click( def _mouse_move( self, locator: str | Locator, model: ModelComposition | str | None = None ) -> None: - point = self._locate(locator=locator, model=model) + point = self._locate(locator=locator, model=model)[0] self.tools.os.mouse_move(point[0], point[1]) @telemetry.record_call(exclude={"locator"}) diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index 88d8e517..a270b6f7 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -29,6 +29,7 @@ ModelName, ModelRegistry, Point, + PointList, TotalModelChoice, ) from .models.types.response_schemas import ResponseSchema @@ -352,13 +353,14 @@ class LinkedListNode(ResponseSchemaBase): self._reporter.add_message("Agent", message_content) return response + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def _locate( self, locator: str | Locator, screenshot: Optional[Img] = None, model: ModelComposition | str | None = None, - ) -> Point: - def locate_with_screenshot() -> Point: + ) -> PointList: + def locate_with_screenshot() -> PointList: _screenshot = load_image_source( self._agent_os.screenshot() if screenshot is None else screenshot ) @@ -368,10 +370,10 @@ def locate_with_screenshot() -> Point: model_choice=model or self._model_choice["locate"], ) - point = self._retry.attempt(locate_with_screenshot) - self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})") - logger.debug("ModelRouter locate: (%d, %d)", point[0], point[1]) - return point + points = self._retry.attempt(locate_with_screenshot) + self._reporter.add_message("ModelRouter", f"locate {len(points)} elements") + logger.debug("ModelRouter locate: %d elements", len(points)) + return points @telemetry.record_call(exclude={"locator", "screenshot"}) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) @@ -382,7 +384,7 @@ def locate( model: ModelComposition | str | None = None, ) -> Point: """ - Locates the UI element identified by the provided locator. + Locates the first matching UI element identified by the provided locator. Args: locator (str | Locator): The identifier or description of the element to @@ -405,8 +407,53 @@ def locate( print(f"Element found at coordinates: {point}") ``` """ - self._reporter.add_message("User", f"locate {locator}") - logger.debug("VisionAgent received instruction to locate %s", locator) + self._reporter.add_message("User", f"locate first matching element {locator}") + logger.debug( + "VisionAgent received instruction to locate first matching element %s", + locator, + ) + return self._locate(locator, screenshot, model)[0] + + @telemetry.record_call(exclude={"locator", "screenshot"}) + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) + def locate_all( + self, + locator: str | Locator, + screenshot: Optional[Img] = None, + model: ModelComposition | str | None = None, + ) -> PointList: + """ + Locates all matching UI elements identified by the provided locator. + + Note: Some LocateModels can only locate a single element. In this case, the + returned list will have a length of 1. + + Args: + locator (str | Locator): The identifier or description of the element to + locate. + screenshot (Img | None, optional): The screenshot to use for locating the + element. Can be a path to an image file, a PIL Image object or a data + URL. If `None`, takes a screenshot of the currently selected display. + model (ModelComposition | str | None, optional): The composition or name + of the model(s) to be used for locating the element using the `locator`. + + Returns: + PointList: The coordinates of the elements as a list of tuples (x, y). + + Example: + ```python + from askui import VisionAgent + + with VisionAgent() as agent: + points = agent.locate_all("Submit button") + print(f"Found {len(points)} elements at coordinates: {points}") + ``` + """ + self._reporter.add_message("User", f"locate all matching UI elements {locator}") + logger.debug( + "VisionAgent received instruction to locate all matching UI elements %s", + locator, + ) return self._locate(locator, screenshot, model) @telemetry.record_call() diff --git a/src/askui/android_agent.py b/src/askui/android_agent.py index 7292338d..41b96720 100644 --- a/src/askui/android_agent.py +++ b/src/askui/android_agent.py @@ -198,7 +198,7 @@ def tap( msg += f" on {target}" self._reporter.add_message("User", msg) logger.debug("VisionAgent received instruction to click on %s", target) - point = self._locate(locator=target, model=model) + point = self._locate(locator=target, model=model)[0] self.os.tap(point[0], point[1]) @telemetry.record_call(exclude={"text"}) diff --git a/src/askui/models/__init__.py b/src/askui/models/__init__.py index 3de6de65..f496d769 100644 --- a/src/askui/models/__init__.py +++ b/src/askui/models/__init__.py @@ -9,6 +9,7 @@ ModelName, ModelRegistry, Point, + PointList, ) from .openrouter.model import OpenRouterModel from .openrouter.settings import ChatCompletionsCreateSettings, OpenRouterSettings @@ -53,6 +54,7 @@ "OpenRouterModel", "OpenRouterSettings", "Point", + "PointList", "TextBlockParam", "TextCitationParam", "ToolResultBlockParam", diff --git a/src/askui/models/anthropic/messages_api.py b/src/askui/models/anthropic/messages_api.py index c9b7a26d..b92e9f9c 100644 --- a/src/askui/models/anthropic/messages_api.py +++ b/src/askui/models/anthropic/messages_api.py @@ -27,7 +27,7 @@ LocateModel, ModelComposition, ModelName, - Point, + PointList, ) from askui.models.shared.agent_message_param import ( Base64ImageSourceParam, @@ -198,7 +198,7 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: if not isinstance(model_choice, str): error_msg = "Model composition is not supported for Claude" raise NotImplementedError(error_msg) @@ -219,12 +219,14 @@ def locate( ), model_choice=model_choice, ) - return scale_coordinates( - extract_click_coordinates(content), - image.root.size, - self._settings.resolution, - inverse=True, - ) + return [ + scale_coordinates( + extract_click_coordinates(content), + image.root.size, + self._settings.resolution, + inverse=True, + ) + ] except ( _UnexpectedResponseError, ValueError, diff --git a/src/askui/models/askui/inference_api.py b/src/askui/models/askui/inference_api.py index b30d40cb..231ae093 100644 --- a/src/askui/models/askui/inference_api.py +++ b/src/askui/models/askui/inference_api.py @@ -20,7 +20,7 @@ from askui.locators.serializers import AskUiLocatorSerializer, AskUiSerializedLocator from askui.logger import logger from askui.models.exceptions import ElementNotFoundError -from askui.models.models import GetModel, LocateModel, ModelComposition, Point +from askui.models.models import GetModel, LocateModel, ModelComposition, PointList from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.messages_api import MessagesApi from askui.models.shared.settings import MessageSettings @@ -162,7 +162,7 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: serialized_locator = ( self._locator_serializer.serialize(locator=locator) if isinstance(locator, Locator) @@ -171,7 +171,7 @@ def locate( logger.debug(f"serialized_locator:\n{json_lib.dumps(serialized_locator)}") json: dict[str, Any] = { "image": image.to_data_url(), - "instruction": f"Click on {serialized_locator['instruction']}", + "instruction": f"get element {serialized_locator['instruction']}", } if "customElements" in serialized_locator: json["customElements"] = serialized_locator["customElements"] @@ -182,17 +182,20 @@ def locate( ) response = self._post(path="/inference", json=json) content = response.json() - assert content["type"] == "COMMANDS", ( + assert content["type"] == "DETECTED_ELEMENTS", ( f"Received unknown content type {content['type']}" ) - actions = [ - el for el in content["data"]["actions"] if el["inputEvent"] == "MOUSE_MOVE" - ] - if len(actions) == 0: + detected_elements = content["data"]["detected_elements"] + if len(detected_elements) == 0: raise ElementNotFoundError(locator, serialized_locator) - position = actions[0]["position"] - return int(position["x"]), int(position["y"]) + return [ + ( + int((element["bndbox"]["xmax"] + element["bndbox"]["xmin"]) / 2), + int((element["bndbox"]["ymax"] + element["bndbox"]["ymin"]) / 2), + ) + for element in detected_elements + ] @override def get( diff --git a/src/askui/models/askui/model_router.py b/src/askui/models/askui/model_router.py index d2bf857f..0dc7076c 100644 --- a/src/askui/models/askui/model_router.py +++ b/src/askui/models/askui/model_router.py @@ -8,7 +8,7 @@ ElementNotFoundError, ModelNotFoundError, ) -from askui.models.models import LocateModel, ModelComposition, ModelName, Point +from askui.models.models import LocateModel, ModelComposition, ModelName, PointList from askui.utils.image_utils import ImageSource @@ -18,7 +18,7 @@ def __init__(self, inference_api: AskUiInferenceApi): def _locate_with_askui_ocr( self, screenshot: ImageSource, locator: str | Text - ) -> Point: + ) -> PointList: locator = Text(locator) if isinstance(locator, str) else locator return self._inference_api.locate( locator, screenshot, model_choice=ModelName.ASKUI__OCR @@ -30,7 +30,7 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: if ( isinstance(model_choice, ModelComposition) or model_choice == ModelName.ASKUI diff --git a/src/askui/models/huggingface/spaces_api.py b/src/askui/models/huggingface/spaces_api.py index a12d37bd..eedef8c1 100644 --- a/src/askui/models/huggingface/spaces_api.py +++ b/src/askui/models/huggingface/spaces_api.py @@ -10,7 +10,7 @@ from askui.exceptions import AutomationError from askui.locators.locators import Locator from askui.locators.serializers import VlmLocatorSerializer -from askui.models.models import LocateModel, ModelComposition, ModelName, Point +from askui.models.models import LocateModel, ModelComposition, ModelName, PointList from askui.utils.image_utils import ImageSource @@ -65,7 +65,7 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: """Predict element location using Hugging Face Spaces.""" if not isinstance(model_choice, str): error_msg = "Model composition is not supported for Hugging Face Spaces" @@ -76,9 +76,9 @@ def locate( if isinstance(locator, Locator) else locator ) - return self._spaces[model_choice]( - image.root, serialized_locator, model_choice - ) + return [ + self._spaces[model_choice](image.root, serialized_locator, model_choice) + ] except (ValueError, json.JSONDecodeError, httpx.HTTPError) as e: error_msg = f"Hugging Face Spaces Exception: {e}" raise AutomationError(error_msg) from e diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 394e7a3e..f203ef3c 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -21,7 +21,7 @@ ModelComposition, ModelName, ModelRegistry, - Point, + PointList, ) from askui.models.shared.agent import Agent from askui.models.shared.agent_message_param import MessageParam @@ -213,7 +213,7 @@ def locate( screenshot: ImageSource, locator: str | Locator, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: _model_choice = ( ModelName.ASKUI if isinstance(model_choice, ModelComposition) diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 82e3043e..3c2b63b2 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -148,6 +148,11 @@ def __getitem__(self, index: int) -> ModelDefinition: A tuple of two integers representing the coordinates of a point on the screen. """ +PointList = Annotated[list[Point], Field(min_length=1)] +""" +A list of points representing the coordinates of elements on the screen. +""" + class ActModel(abc.ABC): """Abstract base class for models that can execute autonomous actions. @@ -293,7 +298,7 @@ class LocateModel(abc.ABC): Example: ```python - from askui import LocateModel, VisionAgent, Locator, ImageSource, Point + from askui import LocateModel, VisionAgent, Locator, ImageSource, PointList from askui.models import ModelComposition class MyLocateModel(LocateModel): @@ -302,9 +307,9 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: # Implement custom locate logic - return (100, 100) + return [(100, 100)] with VisionAgent(models={"my-locate": MyLocateModel()}) as agent: agent.click("button", model="my-locate") @@ -317,7 +322,7 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: """Find the coordinates of a UI element in an image. Args: @@ -328,7 +333,7 @@ def locate( `ModelComposition` for models that support composition Returns: - A tuple of (x, y) coordinates where the element was found + A list of (x, y) coordinates where the element was found, minimum length 1 """ raise NotImplementedError diff --git a/src/askui/models/shared/facade.py b/src/askui/models/shared/facade.py index a26c9cfd..c919fdf6 100644 --- a/src/askui/models/shared/facade.py +++ b/src/askui/models/shared/facade.py @@ -3,7 +3,13 @@ from typing_extensions import override from askui.locators.locators import Locator -from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point +from askui.models.models import ( + ActModel, + GetModel, + LocateModel, + ModelComposition, + PointList, +) from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.settings import ActSettings @@ -57,5 +63,5 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: return self._locate_model.locate(locator, image, model_choice) diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index 84cf46a1..1eec36bd 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -11,7 +11,13 @@ from askui.locators.locators import Locator from askui.locators.serializers import VlmLocatorSerializer from askui.models.exceptions import ElementNotFoundError, QueryNoResponseError -from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point +from askui.models.models import ( + ActModel, + GetModel, + LocateModel, + ModelComposition, + PointList, +) from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.settings import ActSettings @@ -148,7 +154,7 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: if not isinstance(model_choice, str): error_msg = "Model composition is not supported for UI-TARS" raise NotImplementedError(error_msg) @@ -171,7 +177,7 @@ def locate( width, height = image.root.size new_height, new_width = smart_resize(height, width) x, y = (int(x / new_width * width), int(y / new_height * height)) - return x, y + return [(x, y)] raise ElementNotFoundError(locator, locator_serialized) @override diff --git a/tests/integration/agent/test_retry.py b/tests/integration/agent/test_retry.py index 76f1dd67..44aaffb6 100644 --- a/tests/integration/agent/test_retry.py +++ b/tests/integration/agent/test_retry.py @@ -27,11 +27,11 @@ def locate( locator: Union[str, Locator], image: ImageSource, # noqa: ARG002 model_choice: Union[ModelComposition, str], # noqa: ARG002 - ) -> Tuple[int, int]: + ) -> list[Tuple[int, int]]: self.calls += 1 if self.calls <= self.fail_times: raise ElementNotFoundError(locator, locator) - return self.succeed_point + return [self.succeed_point] @pytest.fixture diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index 962def95..1f4a2449 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -12,6 +12,7 @@ LocateModel, ModelRegistry, Point, + PointList, ResponseSchema, ResponseSchemaBase, VisionAgent, @@ -95,11 +96,11 @@ def locate( locator: str | Locator, image: ImageSource, model_choice: ModelComposition | str, - ) -> Point: + ) -> PointList: self.locators.append(locator) self.images.append(image) self.model_choices.append(model_choice) - return self._point + return [self._point] class SimpleResponseSchema(ResponseSchemaBase):