Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,12 @@ class MyGetAndLocateModel(GetModel, LocateModel):
locator: str | Locator,
image: ImageSource,
model_choice: ModelComposition | str,
) -> Point:
) -> PointList:
# Implement custom locate logic, e.g.:
# - Use a different object detection model
# - Implement custom element finding
# - Call external vision services
return (100, 100) # Example coordinates
return [(100, 100)] # Example coordinates


# Create model registry
Expand Down
2 changes: 2 additions & 0 deletions src/askui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
OnMessageCb,
OnMessageCbParam,
Point,
PointList,
TextBlockParam,
TextCitationParam,
ToolResultBlockParam,
Expand Down Expand Up @@ -82,6 +83,7 @@
"OnMessageCbParam",
"PcKey",
"Point",
"PointList",
"ResponseSchema",
"ResponseSchemaBase",
"Retry",
Expand Down
2 changes: 1 addition & 1 deletion src/askui/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _click(
def _mouse_move(
self, locator: str | Locator, model: ModelComposition | str | None = None
) -> None:
point = self._locate(locator=locator, model=model)
point = self._locate(locator=locator, model=model)[0]
self.tools.os.mouse_move(point[0], point[1])

@telemetry.record_call(exclude={"locator"})
Expand Down
65 changes: 56 additions & 9 deletions src/askui/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ModelName,
ModelRegistry,
Point,
PointList,
TotalModelChoice,
)
from .models.types.response_schemas import ResponseSchema
Expand Down Expand Up @@ -352,13 +353,14 @@ class LinkedListNode(ResponseSchemaBase):
self._reporter.add_message("Agent", message_content)
return response

@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def _locate(
self,
locator: str | Locator,
screenshot: Optional[Img] = None,
model: ModelComposition | str | None = None,
) -> Point:
def locate_with_screenshot() -> Point:
) -> PointList:
def locate_with_screenshot() -> PointList:
_screenshot = load_image_source(
self._agent_os.screenshot() if screenshot is None else screenshot
)
Expand All @@ -368,10 +370,10 @@ def locate_with_screenshot() -> Point:
model_choice=model or self._model_choice["locate"],
)

point = self._retry.attempt(locate_with_screenshot)
self._reporter.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})")
logger.debug("ModelRouter locate: (%d, %d)", point[0], point[1])
return point
points = self._retry.attempt(locate_with_screenshot)
self._reporter.add_message("ModelRouter", f"locate {len(points)} elements")
logger.debug("ModelRouter locate: %d elements", len(points))
return points

@telemetry.record_call(exclude={"locator", "screenshot"})
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
Expand All @@ -382,7 +384,7 @@ def locate(
model: ModelComposition | str | None = None,
) -> Point:
"""
Locates the UI element identified by the provided locator.
Locates the first matching UI element identified by the provided locator.

Args:
locator (str | Locator): The identifier or description of the element to
Expand All @@ -405,8 +407,53 @@ def locate(
print(f"Element found at coordinates: {point}")
```
"""
self._reporter.add_message("User", f"locate {locator}")
logger.debug("VisionAgent received instruction to locate %s", locator)
self._reporter.add_message("User", f"locate first matching element {locator}")
logger.debug(
"VisionAgent received instruction to locate first matching element %s",
locator,
)
return self._locate(locator, screenshot, model)[0]

@telemetry.record_call(exclude={"locator", "screenshot"})
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def locate_all(
    self,
    locator: str | Locator,
    screenshot: Optional[Img] = None,
    model: ModelComposition | str | None = None,
) -> PointList:
    """
    Locates every UI element that matches the provided locator.

    Note: Some `LocateModel` implementations can only locate a single element.
    In that case, the returned list contains exactly one point.

    Args:
        locator (str | Locator): The identifier or description of the elements
            to locate.
        screenshot (Img | None, optional): The screenshot to search in. Can be a
            path to an image file, a PIL Image object or a data URL. If `None`,
            takes a screenshot of the currently selected display.
        model (ModelComposition | str | None, optional): The composition or name
            of the model(s) to be used for locating the elements using the
            `locator`.

    Returns:
        PointList: The coordinates of the matching elements as a list of
            (x, y) tuples.

    Example:
        ```python
        from askui import VisionAgent

        with VisionAgent() as agent:
            points = agent.locate_all("Submit button")
            print(f"Found {len(points)} elements at coordinates: {points}")
        ```
    """
    # Record the request in the report and the debug log before delegating.
    user_message = f"locate all matching UI elements {locator}"
    self._reporter.add_message("User", user_message)
    logger.debug(
        "VisionAgent received instruction to locate all matching UI elements %s",
        locator,
    )
    # Unlike `locate`, hand back the complete list instead of its first entry.
    matches = self._locate(locator=locator, screenshot=screenshot, model=model)
    return matches

@telemetry.record_call()
Expand Down
2 changes: 1 addition & 1 deletion src/askui/android_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def tap(
msg += f" on {target}"
self._reporter.add_message("User", msg)
logger.debug("VisionAgent received instruction to click on %s", target)
point = self._locate(locator=target, model=model)
point = self._locate(locator=target, model=model)[0]
self.os.tap(point[0], point[1])

@telemetry.record_call(exclude={"text"})
Expand Down
2 changes: 2 additions & 0 deletions src/askui/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ModelName,
ModelRegistry,
Point,
PointList,
)
from .openrouter.model import OpenRouterModel
from .openrouter.settings import ChatCompletionsCreateSettings, OpenRouterSettings
Expand Down Expand Up @@ -53,6 +54,7 @@
"OpenRouterModel",
"OpenRouterSettings",
"Point",
"PointList",
"TextBlockParam",
"TextCitationParam",
"ToolResultBlockParam",
Expand Down
18 changes: 10 additions & 8 deletions src/askui/models/anthropic/messages_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
LocateModel,
ModelComposition,
ModelName,
Point,
PointList,
)
from askui.models.shared.agent_message_param import (
Base64ImageSourceParam,
Expand Down Expand Up @@ -198,7 +198,7 @@ def locate(
locator: str | Locator,
image: ImageSource,
model_choice: ModelComposition | str,
) -> Point:
) -> PointList:
if not isinstance(model_choice, str):
error_msg = "Model composition is not supported for Claude"
raise NotImplementedError(error_msg)
Expand All @@ -219,12 +219,14 @@ def locate(
),
model_choice=model_choice,
)
return scale_coordinates(
extract_click_coordinates(content),
image.root.size,
self._settings.resolution,
inverse=True,
)
return [
scale_coordinates(
extract_click_coordinates(content),
image.root.size,
self._settings.resolution,
inverse=True,
)
]
except (
_UnexpectedResponseError,
ValueError,
Expand Down
23 changes: 13 additions & 10 deletions src/askui/models/askui/inference_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from askui.locators.serializers import AskUiLocatorSerializer, AskUiSerializedLocator
from askui.logger import logger
from askui.models.exceptions import ElementNotFoundError
from askui.models.models import GetModel, LocateModel, ModelComposition, Point
from askui.models.models import GetModel, LocateModel, ModelComposition, PointList
from askui.models.shared.agent_message_param import MessageParam
from askui.models.shared.messages_api import MessagesApi
from askui.models.shared.settings import MessageSettings
Expand Down Expand Up @@ -162,7 +162,7 @@ def locate(
locator: str | Locator,
image: ImageSource,
model_choice: ModelComposition | str,
) -> Point:
) -> PointList:
serialized_locator = (
self._locator_serializer.serialize(locator=locator)
if isinstance(locator, Locator)
Expand All @@ -171,7 +171,7 @@ def locate(
logger.debug(f"serialized_locator:\n{json_lib.dumps(serialized_locator)}")
json: dict[str, Any] = {
"image": image.to_data_url(),
"instruction": f"Click on {serialized_locator['instruction']}",
"instruction": f"get element {serialized_locator['instruction']}",
}
if "customElements" in serialized_locator:
json["customElements"] = serialized_locator["customElements"]
Expand All @@ -182,17 +182,20 @@ def locate(
)
response = self._post(path="/inference", json=json)
content = response.json()
assert content["type"] == "COMMANDS", (
assert content["type"] == "DETECTED_ELEMENTS", (
f"Received unknown content type {content['type']}"
)
actions = [
el for el in content["data"]["actions"] if el["inputEvent"] == "MOUSE_MOVE"
]
if len(actions) == 0:
detected_elements = content["data"]["detected_elements"]
if len(detected_elements) == 0:
raise ElementNotFoundError(locator, serialized_locator)

position = actions[0]["position"]
return int(position["x"]), int(position["y"])
return [
(
int((element["bndbox"]["xmax"] + element["bndbox"]["xmin"]) / 2),
int((element["bndbox"]["ymax"] + element["bndbox"]["ymin"]) / 2),
)
for element in detected_elements
]

@override
def get(
Expand Down
6 changes: 3 additions & 3 deletions src/askui/models/askui/model_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
ElementNotFoundError,
ModelNotFoundError,
)
from askui.models.models import LocateModel, ModelComposition, ModelName, Point
from askui.models.models import LocateModel, ModelComposition, ModelName, PointList
from askui.utils.image_utils import ImageSource


Expand All @@ -18,7 +18,7 @@ def __init__(self, inference_api: AskUiInferenceApi):

def _locate_with_askui_ocr(
self, screenshot: ImageSource, locator: str | Text
) -> Point:
) -> PointList:
locator = Text(locator) if isinstance(locator, str) else locator
return self._inference_api.locate(
locator, screenshot, model_choice=ModelName.ASKUI__OCR
Expand All @@ -30,7 +30,7 @@ def locate(
locator: str | Locator,
image: ImageSource,
model_choice: ModelComposition | str,
) -> Point:
) -> PointList:
if (
isinstance(model_choice, ModelComposition)
or model_choice == ModelName.ASKUI
Expand Down
10 changes: 5 additions & 5 deletions src/askui/models/huggingface/spaces_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from askui.exceptions import AutomationError
from askui.locators.locators import Locator
from askui.locators.serializers import VlmLocatorSerializer
from askui.models.models import LocateModel, ModelComposition, ModelName, Point
from askui.models.models import LocateModel, ModelComposition, ModelName, PointList
from askui.utils.image_utils import ImageSource


Expand Down Expand Up @@ -65,7 +65,7 @@ def locate(
locator: str | Locator,
image: ImageSource,
model_choice: ModelComposition | str,
) -> Point:
) -> PointList:
"""Predict element location using Hugging Face Spaces."""
if not isinstance(model_choice, str):
error_msg = "Model composition is not supported for Hugging Face Spaces"
Expand All @@ -76,9 +76,9 @@ def locate(
if isinstance(locator, Locator)
else locator
)
return self._spaces[model_choice](
image.root, serialized_locator, model_choice
)
return [
self._spaces[model_choice](image.root, serialized_locator, model_choice)
]
except (ValueError, json.JSONDecodeError, httpx.HTTPError) as e:
error_msg = f"Hugging Face Spaces Exception: {e}"
raise AutomationError(error_msg) from e
Expand Down
4 changes: 2 additions & 2 deletions src/askui/models/model_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ModelComposition,
ModelName,
ModelRegistry,
Point,
PointList,
)
from askui.models.shared.agent import Agent
from askui.models.shared.agent_message_param import MessageParam
Expand Down Expand Up @@ -213,7 +213,7 @@ def locate(
screenshot: ImageSource,
locator: str | Locator,
model_choice: ModelComposition | str,
) -> Point:
) -> PointList:
_model_choice = (
ModelName.ASKUI
if isinstance(model_choice, ModelComposition)
Expand Down
15 changes: 10 additions & 5 deletions src/askui/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ def __getitem__(self, index: int) -> ModelDefinition:
A tuple of two integers representing the coordinates of a point on the screen.
"""

PointList = Annotated[list[Point], Field(min_length=1)]
"""
A non-empty list of points (at least one entry), each giving the coordinates of an element on the screen.
"""


class ActModel(abc.ABC):
"""Abstract base class for models that can execute autonomous actions.
Expand Down Expand Up @@ -293,7 +298,7 @@ class LocateModel(abc.ABC):

Example:
```python
from askui import LocateModel, VisionAgent, Locator, ImageSource, Point
from askui import LocateModel, VisionAgent, Locator, ImageSource, PointList
from askui.models import ModelComposition

class MyLocateModel(LocateModel):
Expand All @@ -302,9 +307,9 @@ def locate(
locator: str | Locator,
image: ImageSource,
model_choice: ModelComposition | str,
) -> Point:
) -> PointList:
# Implement custom locate logic
return (100, 100)
return [(100, 100)]

with VisionAgent(models={"my-locate": MyLocateModel()}) as agent:
agent.click("button", model="my-locate")
Expand All @@ -317,7 +322,7 @@ def locate(
locator: str | Locator,
image: ImageSource,
model_choice: ModelComposition | str,
) -> Point:
) -> PointList:
"""Find the coordinates of a UI element in an image.

Args:
Expand All @@ -328,7 +333,7 @@ def locate(
`ModelComposition` for models that support composition

Returns:
A tuple of (x, y) coordinates where the element was found
A non-empty list of (x, y) coordinates, one for each matching element that was found
"""
raise NotImplementedError

Expand Down
Loading