Skip to content

Commit d8b45a4

Browse files
Feat/Add annotation function
1 parent 3ee3a1c commit d8b45a4

File tree

7 files changed

+513
-2
lines changed

7 files changed

+513
-2
lines changed

src/askui/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""AskUI Vision Agent"""
22

3-
__version__ = "0.22.2"
3+
__version__ = "0.22.3"
44

55
import logging
66
import os

src/askui/agent_base.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
from askui.models.shared.tools import Tool, ToolCollection
1919
from askui.tools.agent_os import AgentOs
2020
from askui.tools.android.agent_os import AndroidAgentOs
21+
from askui.utils.annotation_writer import AnnotationWriter
2122
from askui.utils.image_utils import ImageSource
2223
from askui.utils.source_utils import InputSource, load_image_source
2324

2425
from .models import ModelComposition
2526
from .models.exceptions import ElementNotFoundError, WaitUntilError
2627
from .models.model_router import ModelRouter, initialize_default_model_registry
2728
from .models.models import (
29+
DetectedElement,
2830
ModelChoice,
2931
ModelName,
3032
ModelRegistry,
@@ -507,6 +509,101 @@ def locate_all(
507509
)
508510
return self._locate(locator=locator, screenshot=screenshot, model=model)
509511

512+
@telemetry.record_call(exclude={"locator", "screenshot"})
513+
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
514+
def locate_all_elements(
515+
self,
516+
screenshot: Optional[InputSource] = None,
517+
model_composition: ModelComposition | None = None,
518+
) -> list[DetectedElement]:
519+
"""Locate all elements in the current screen using AskUI Models.
520+
521+
Args:
522+
screenshot (InputSource | None, optional): The screenshot to use for
523+
locating the elements. Can be a path to an image file, a PIL Image
524+
object or a data URL. If `None`, takes a screenshot of the currently
525+
selected display.
526+
model_composition (ModelComposition | None, optional): The model composition
527+
to be used for locating the elements.
528+
529+
Returns:
530+
list[DetectedElement]: A list of detected elements
531+
532+
Example:
533+
```python
534+
from askui import VisionAgent
535+
536+
with VisionAgent() as agent:
537+
detected_elements = agent.locate_all_elements()
538+
print(f"Found {len(detected_elements)} elements: {detected_elements}")
539+
```
540+
"""
541+
_screenshot = load_image_source(
542+
self._agent_os.screenshot() if screenshot is None else screenshot
543+
)
544+
return self._model_router.locate_all_elements(
545+
image=_screenshot, model=model_composition or ModelName.ASKUI
546+
)
547+
548+
@telemetry.record_call(exclude={"screenshot", "output_directory"})
549+
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
550+
def annotate(
551+
self,
552+
screenshot: InputSource | None = None,
553+
model_composition: ModelComposition | None = None,
554+
output_directory: str = "reports",
555+
) -> None:
556+
"""Annotate the screenshot with the detected elements.
557+
Creates an interactive HTML file with the detected elements
558+
and saves it to the output directory.
559+
The HTML file can be opened in a browser to see the annotated image.
560+
The user can hover over the elements to see their names and text value
561+
and click on the box to copy the text value to the clipboard.
562+
563+
Args:
564+
screenshot (ImageSource | None, optional): The screenshot to annotate.
565+
If `None`, takes a screenshot of the currently selected display.
566+
model_composition (ModelComposition | None, optional): The composition
567+
or name of the model(s) to be used for locating the elements.
568+
output_directory (str, optional): The directory to save the annotated
569+
image. Defaults to "reports".
570+
571+
Example Using VisionAgent:
572+
```python
573+
from askui import VisionAgent
574+
575+
with VisionAgent() as agent:
576+
agent.annotate()
577+
```
578+
579+
Example Using AndroidVisionAgent:
580+
```python
581+
from askui import AndroidVisionAgent
582+
583+
with AndroidVisionAgent() as agent:
584+
agent.annotate()
585+
```
586+
587+
Example Using VisionAgent with custom screenshot and output directory:
588+
```python
589+
from askui import VisionAgent
590+
591+
with VisionAgent() as agent:
592+
agent.annotate(screenshot="screenshot.png", output_directory="htmls")
593+
```
594+
"""
595+
if screenshot is None:
596+
screenshot = self._agent_os.screenshot()
597+
598+
detected_elements = self.locate_all_elements(
599+
screenshot=screenshot,
600+
model_composition=model_composition,
601+
)
602+
AnnotationWriter(
603+
image=screenshot,
604+
elements=detected_elements,
605+
).write_to_file(output_directory)
606+
510607
@telemetry.record_call(exclude={"until"})
511608
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
512609
def wait(

src/askui/models/askui/models.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
QueryUnexpectedResponseError,
1818
)
1919
from askui.models.models import (
20+
DetectedElement,
2021
GetModel,
2122
LocateModel,
2223
ModelComposition,
@@ -131,6 +132,34 @@ def _locate(
131132
for element in detected_elements
132133
]
133134

135+
@override
136+
def locate_all_elements(
137+
self,
138+
image: ImageSource,
139+
model: ModelComposition | str,
140+
) -> list[DetectedElement]:
141+
request_body: dict[str, Any] = {
142+
"image": image.to_data_url(),
143+
"instruction": "get all elements",
144+
}
145+
146+
if isinstance(model, ModelComposition):
147+
request_body["modelComposition"] = model.model_dump(by_alias=True)
148+
logger.debug(
149+
"Model composition",
150+
extra={
151+
"modelComposition": json_lib.dumps(request_body["modelComposition"])
152+
},
153+
)
154+
155+
response = self._inference_api.post(path="/inference", json=request_body)
156+
content = response.json()
157+
assert content["type"] == "DETECTED_ELEMENTS", (
158+
f"Received unknown content type {content['type']}"
159+
)
160+
detected_elements = content["data"]["detected_elements"]
161+
return [DetectedElement.from_json(element) for element in detected_elements]
162+
134163

135164
class AskUiGetModel(GetModel):
136165
"""A GetModel implementation that is supposed to be as comprehensive and

src/askui/models/model_router.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from askui.models.models import (
1919
MODEL_TYPES,
2020
ActModel,
21+
DetectedElement,
2122
GetModel,
2223
LocateModel,
2324
Model,
@@ -145,7 +146,6 @@ def tars_handler() -> UiTarsApiHandler:
145146
ModelName.ASKUI__COMBO: askui_locate_model,
146147
ModelName.ASKUI__OCR: askui_locate_model,
147148
ModelName.ASKUI__PTA: askui_locate_model,
148-
ModelName.CLAUDE__SONNET__4__20250514: lambda: anthropic_facade("anthropic"),
149149
ModelName.HF__SPACES__ASKUI__PTA_1: hf_spaces_handler,
150150
ModelName.HF__SPACES__QWEN__QWEN2_VL_2B_INSTRUCT: hf_spaces_handler,
151151
ModelName.HF__SPACES__QWEN__QWEN2_VL_7B_INSTRUCT: hf_spaces_handler,
@@ -265,3 +265,17 @@ def locate(
265265
extra={"model": _model},
266266
)
267267
return m.locate(locator, screenshot, _model_composition or _model)
268+
269+
def locate_all_elements(
270+
self,
271+
image: ImageSource,
272+
model: ModelComposition | str,
273+
) -> list[DetectedElement]:
274+
_model = ModelName.ASKUI if isinstance(model, ModelComposition) else model
275+
_model_composition = model if isinstance(model, ModelComposition) else None
276+
m, _model = self._get_model(_model, "locate")
277+
logger.debug(
278+
"Routing locate_all_elements prediction to",
279+
extra={"model": _model},
280+
)
281+
return m.locate_all_elements(image, model=_model_composition or _model)

src/askui/models/models.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,80 @@ def __getitem__(self, index: int) -> ModelDefinition:
153153
"""
154154

155155

156+
class BoundingBox(BaseModel):
157+
model_config = ConfigDict(
158+
extra="ignore",
159+
)
160+
161+
xmin: int
162+
ymin: int
163+
xmax: int
164+
ymax: int
165+
166+
@staticmethod
167+
def from_json(data: dict[str, float]) -> "BoundingBox":
168+
return BoundingBox(
169+
xmin=int(data["xmin"]),
170+
ymin=int(data["ymin"]),
171+
xmax=int(data["xmax"]),
172+
ymax=int(data["ymax"]),
173+
)
174+
175+
def __str__(self) -> str:
176+
return f"[{self.xmin}, {self.ymin}, {self.xmax}, {self.ymax}]"
177+
178+
@property
179+
def width(self) -> int:
180+
"""The width of the bounding box."""
181+
return self.xmax - self.xmin
182+
183+
@property
184+
def height(self) -> int:
185+
"""The height of the bounding box."""
186+
return self.ymax - self.ymin
187+
188+
@property
189+
def center(self) -> Point:
190+
"""The center point of the bounding box."""
191+
return int((self.xmin + self.xmax) / 2), int((self.ymin + self.ymax) / 2)
192+
193+
194+
class DetectedElement(BaseModel):
195+
model_config = ConfigDict(
196+
extra="ignore",
197+
)
198+
199+
name: str
200+
text: str
201+
bounding_box: BoundingBox
202+
203+
@staticmethod
204+
def from_json(data: dict[str, str | float | dict[str, float]]) -> "DetectedElement":
205+
return DetectedElement(
206+
name=str(data["name"]),
207+
text=str(data["text"]),
208+
bounding_box=BoundingBox.from_json(data["bndbox"]), # type: ignore
209+
)
210+
211+
def __str__(self) -> str:
212+
return f"[name={self.name}, text={self.text}, bndbox={str(self.bounding_box)}]"
213+
214+
@property
215+
def center(self) -> Point:
216+
"""The center point of the detected element."""
217+
return self.bounding_box.center
218+
219+
@property
220+
def width(self) -> int:
221+
"""The width of the detected element."""
222+
return self.bounding_box.width
223+
224+
@property
225+
def height(self) -> int:
226+
"""The height of the detected element."""
227+
return self.bounding_box.height
228+
229+
156230
class ActModel(abc.ABC):
157231
"""Abstract base class for models that can execute autonomous actions.
158232
@@ -336,6 +410,23 @@ def locate(
336410
"""
337411
raise NotImplementedError
338412

413+
def locate_all_elements(
414+
self,
415+
image: ImageSource,
416+
model: ModelComposition | str,
417+
) -> list[DetectedElement]:
418+
"""Locate all elements in an image.
419+
420+
Args:
421+
image (ImageSource): The image to analyze (screenshot or provided image)
422+
model (ModelComposition | str): Either a string model name or a
423+
`ModelComposition` for models that support composition
424+
425+
Returns:
426+
A list of detected elements
427+
"""
428+
raise NotImplementedError
429+
339430

340431
Model = ActModel | GetModel | LocateModel
341432
"""Union type of all abstract model classes.

src/askui/models/shared/facade.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from askui.locators.locators import Locator
66
from askui.models.models import (
77
ActModel,
8+
DetectedElement,
89
GetModel,
910
LocateModel,
1011
ModelComposition,
@@ -65,3 +66,11 @@ def locate(
6566
model: ModelComposition | str,
6667
) -> PointList:
6768
return self._locate_model.locate(locator, image, model)
69+
70+
@override
71+
def locate_all_elements(
72+
self,
73+
image: ImageSource,
74+
model: ModelComposition | str,
75+
) -> list[DetectedElement]:
76+
return self._locate_model.locate_all_elements(image, model)

0 commit comments

Comments
 (0)