Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/askui/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""AskUI Vision Agent"""

# NOTE(review): `__version__` is assigned twice in a row; only the second value
# ("0.22.3") is in effect once both lines have run. This looks like the old and
# new lines of a diff shown side by side — confirm which one should remain.
__version__ = "0.22.2"
__version__ = "0.22.3"

import logging
import os
Expand Down
98 changes: 98 additions & 0 deletions src/askui/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from askui.models.shared.tools import Tool, ToolCollection
from askui.tools.agent_os import AgentOs
from askui.tools.android.agent_os import AndroidAgentOs
from askui.utils.annotation_writer import AnnotationWriter
from askui.utils.image_utils import ImageSource
from askui.utils.source_utils import InputSource, load_image_source

from .models import ModelComposition
from .models.exceptions import ElementNotFoundError, WaitUntilError
from .models.model_router import ModelRouter, initialize_default_model_registry
from .models.models import (
DetectedElement,
ModelChoice,
ModelName,
ModelRegistry,
Expand Down Expand Up @@ -507,6 +509,102 @@ def locate_all(
)
return self._locate(locator=locator, screenshot=screenshot, model=model)

@telemetry.record_call(exclude={"screenshot"})
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def locate_all_elements(
    self,
    screenshot: Optional[InputSource] = None,
    model: ModelComposition | None = None,
) -> list[DetectedElement]:
    """Detect every element on the current screen (or in a supplied image).

    Args:
        screenshot (InputSource | None, optional): The image to search. Can be
            a path to an image file, a PIL Image object or a data URL. When
            `None`, a screenshot is requested from the agent OS.
        model (ModelComposition | None, optional): The model composition to
            use for detection. When not given, `ModelName.ASKUI` is used.

    Returns:
        list[DetectedElement]: The elements returned by the model router.

    Example:
        ```python
        from askui import VisionAgent

        with VisionAgent() as agent:
            detected_elements = agent.locate_all_elements()
            print(f"Found {len(detected_elements)} elements: {detected_elements}")
        ```
    """
    source = screenshot
    if source is None:
        # No image supplied by the caller: capture one via the agent OS.
        source = self._agent_os.screenshot()
    image = load_image_source(source)
    selected_model = model or ModelName.ASKUI
    return self._model_router.locate_all_elements(
        image=image,
        model=selected_model,
    )

@telemetry.record_call(exclude={"screenshot", "annotation_dir"})
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def annotate(
    self,
    screenshot: InputSource | None = None,
    annotation_dir: str = "annotations",
    model: ModelComposition | None = None,
) -> None:
    """Write an annotated view of a screenshot to a directory.

    Detects all elements in the screenshot and hands the screenshot together
    with the detected elements to `AnnotationWriter`, which saves the result
    to `annotation_dir`. The output is an interactive HTML file that can be
    opened in a browser: hovering over an element shows its name and text
    value, and clicking on a box copies the text value to the clipboard.

    Args:
        screenshot (InputSource | None, optional): The screenshot to annotate.
            If `None`, takes a screenshot of the currently selected display.
        annotation_dir (str): The directory to save the annotated
            image. Defaults to "annotations".
        model (ModelComposition | None, optional): The composition
            of the model(s) to be used for annotating the image.
            If `None`, uses the default model.

    Example Using VisionAgent:
        ```python
        from askui import VisionAgent

        with VisionAgent() as agent:
            agent.annotate()
        ```

    Example Using AndroidVisionAgent:
        ```python
        from askui import AndroidVisionAgent

        with AndroidVisionAgent() as agent:
            agent.annotate()
        ```

    Example Using VisionAgent with custom screenshot and annotation directory:
        ```python
        from askui import VisionAgent

        with VisionAgent() as agent:
            agent.annotate(screenshot="screenshot.png", annotation_dir="htmls")
        ```
    """
    # Resolve the image once so detection and annotation use the same input.
    image = self._agent_os.screenshot() if screenshot is None else screenshot

    elements = self.locate_all_elements(screenshot=image, model=model)
    writer = AnnotationWriter(image=image, elements=elements)
    writer.save_to_dir(annotation_dir)

@telemetry.record_call(exclude={"until"})
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def wait(
Expand Down
29 changes: 29 additions & 0 deletions src/askui/models/askui/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
QueryUnexpectedResponseError,
)
from askui.models.models import (
DetectedElement,
GetModel,
LocateModel,
ModelComposition,
Expand Down Expand Up @@ -131,6 +132,34 @@ def _locate(
for element in detected_elements
]

@override
def locate_all_elements(
    self,
    image: ImageSource,
    model: ModelComposition | str,
) -> list[DetectedElement]:
    """Ask the inference API for all elements it detects in `image`.

    Posts the image (as a data URL) with the instruction "get all elements"
    to the `/inference` endpoint and converts the returned elements.

    Args:
        image (ImageSource): The image to analyze.
        model (ModelComposition | str): If a `ModelComposition`, it is
            serialized into the request body under `modelComposition`.
            A plain string is not added to the request body.

    Returns:
        list[DetectedElement]: One entry per item in the response's
        `data.detected_elements`.

    Raises:
        AssertionError: If the response `type` is not `"DETECTED_ELEMENTS"`.
    """
    request_body: dict[str, Any] = {
        "image": image.to_data_url(),
        "instruction": "get all elements",
    }

    if isinstance(model, ModelComposition):
        request_body["modelComposition"] = model.model_dump(by_alias=True)
        logger.debug(
            "Model composition",
            extra={
                "modelComposition": json_lib.dumps(request_body["modelComposition"])
            },
        )

    response = self._inference_api.post(path="/inference", json=request_body)
    content = response.json()
    # Raise explicitly instead of using `assert`: assert statements are removed
    # when Python runs with -O, which would let an unexpected response through.
    # The exception type and message are the same as the assert produced.
    if content["type"] != "DETECTED_ELEMENTS":
        error_msg = f"Received unknown content type {content['type']}"
        raise AssertionError(error_msg)
    return [
        DetectedElement.from_json(element)
        for element in content["data"]["detected_elements"]
    ]


class AskUiGetModel(GetModel):
"""A GetModel implementation that is supposed to be as comprehensive and
Expand Down
16 changes: 15 additions & 1 deletion src/askui/models/model_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from askui.models.models import (
MODEL_TYPES,
ActModel,
DetectedElement,
GetModel,
LocateModel,
Model,
Expand Down Expand Up @@ -145,7 +146,6 @@ def tars_handler() -> UiTarsApiHandler:
ModelName.ASKUI__COMBO: askui_locate_model,
ModelName.ASKUI__OCR: askui_locate_model,
ModelName.ASKUI__PTA: askui_locate_model,
ModelName.CLAUDE__SONNET__4__20250514: lambda: anthropic_facade("anthropic"),
ModelName.HF__SPACES__ASKUI__PTA_1: hf_spaces_handler,
ModelName.HF__SPACES__QWEN__QWEN2_VL_2B_INSTRUCT: hf_spaces_handler,
ModelName.HF__SPACES__QWEN__QWEN2_VL_7B_INSTRUCT: hf_spaces_handler,
Expand Down Expand Up @@ -265,3 +265,17 @@ def locate(
extra={"model": _model},
)
return m.locate(locator, screenshot, _model_composition or _model)

def locate_all_elements(
    self,
    image: ImageSource,
    model: ModelComposition | str,
) -> list[DetectedElement]:
    """Forward an all-elements detection request to a "locate" model.

    Args:
        image (ImageSource): The image to analyze.
        model (ModelComposition | str): A model name, or a `ModelComposition`.
            For a composition, the model is looked up under `ModelName.ASKUI`
            and the composition itself is passed on to that model.

    Returns:
        list[DetectedElement]: The result of the selected model's
        `locate_all_elements`.
    """
    if isinstance(model, ModelComposition):
        requested_name, composition = ModelName.ASKUI, model
    else:
        requested_name, composition = model, None

    locate_model, resolved_name = self._get_model(requested_name, "locate")
    logger.debug(
        "Routing locate_all_elements prediction to",
        extra={"model": resolved_name},
    )
    return locate_model.locate_all_elements(
        image,
        model=composition or resolved_name,
    )
91 changes: 91 additions & 0 deletions src/askui/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,80 @@ def __getitem__(self, index: int) -> ModelDefinition:
"""


class BoundingBox(BaseModel):
    """Rectangle described by integer `xmin`, `ymin`, `xmax`, `ymax` values.

    Extra fields passed to the model are ignored (`extra="ignore"`).
    """

    model_config = ConfigDict(extra="ignore")

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @staticmethod
    def from_json(data: dict[str, float]) -> "BoundingBox":
        """Build a `BoundingBox` from a mapping with the four corner keys.

        Each of `xmin`, `ymin`, `xmax`, `ymax` is converted with `int()`.
        """
        corners = {
            key: int(data[key]) for key in ("xmin", "ymin", "xmax", "ymax")
        }
        return BoundingBox(**corners)

    def __str__(self) -> str:
        return "[{}, {}, {}, {}]".format(
            self.xmin, self.ymin, self.xmax, self.ymax
        )

    @property
    def width(self) -> int:
        """Horizontal extent: `xmax - xmin`."""
        horizontal_extent = self.xmax - self.xmin
        return horizontal_extent

    @property
    def height(self) -> int:
        """Vertical extent: `ymax - ymin`."""
        vertical_extent = self.ymax - self.ymin
        return vertical_extent

    @property
    def center(self) -> Point:
        """Midpoint of the box, with each coordinate converted via `int()`."""
        mid_x = int((self.xmin + self.xmax) / 2)
        mid_y = int((self.ymin + self.ymax) / 2)
        return mid_x, mid_y


class DetectedElement(BaseModel):
    """A detected element: its name, its text and its bounding box.

    Extra fields passed to the model are ignored (`extra="ignore"`).
    """

    model_config = ConfigDict(extra="ignore")

    name: str
    text: str
    bounding_box: BoundingBox

    @staticmethod
    def from_json(data: dict[str, str | float | dict[str, float]]) -> "DetectedElement":
        """Build a `DetectedElement` from a mapping.

        Reads the keys `name`, `text` (both converted with `str()`) and
        `bndbox` (converted with `BoundingBox.from_json`).
        """
        element_name = str(data["name"])
        element_text = str(data["text"])
        box = BoundingBox.from_json(data["bndbox"])  # type: ignore
        return DetectedElement(
            name=element_name,
            text=element_text,
            bounding_box=box,
        )

    def __str__(self) -> str:
        return f"[name={self.name}, text={self.text}, bndbox={self.bounding_box!s}]"

    @property
    def center(self) -> Point:
        """Center point, taken from the bounding box."""
        box = self.bounding_box
        return box.center

    @property
    def width(self) -> int:
        """Width, taken from the bounding box."""
        box = self.bounding_box
        return box.width

    @property
    def height(self) -> int:
        """Height, taken from the bounding box."""
        box = self.bounding_box
        return box.height


class ActModel(abc.ABC):
"""Abstract base class for models that can execute autonomous actions.

Expand Down Expand Up @@ -336,6 +410,23 @@ def locate(
"""
raise NotImplementedError

def locate_all_elements(
    self,
    image: ImageSource,
    model: ModelComposition | str,
) -> list[DetectedElement]:
    """Detect every element in the given image.

    This default implementation is not functional and always raises.

    Args:
        image (ImageSource): The image to analyze (screenshot or provided
            image).
        model (ModelComposition | str): A model name as a string, or a
            `ModelComposition` for models that support composition.

    Returns:
        list[DetectedElement]: The detected elements.

    Raises:
        NotImplementedError: Always, in this implementation.
    """
    raise NotImplementedError


Model = ActModel | GetModel | LocateModel
"""Union type of all abstract model classes.
Expand Down
9 changes: 9 additions & 0 deletions src/askui/models/shared/facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from askui.locators.locators import Locator
from askui.models.models import (
ActModel,
DetectedElement,
GetModel,
LocateModel,
ModelComposition,
Expand Down Expand Up @@ -65,3 +66,11 @@ def locate(
model: ModelComposition | str,
) -> PointList:
return self._locate_model.locate(locator, image, model)

@override
def locate_all_elements(
    self,
    image: ImageSource,
    model: ModelComposition | str,
) -> list[DetectedElement]:
    """Delegate all-elements detection to the wrapped locate model.

    Args:
        image (ImageSource): The image to analyze.
        model (ModelComposition | str): Passed through unchanged.

    Returns:
        list[DetectedElement]: The wrapped model's result.
    """
    delegate = self._locate_model
    return delegate.locate_all_elements(image, model)
Loading