Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ dependencies = [
"httpx>=0.28.1",
"fastmcp>=2.3.4",
"pure-python-adb>=0.3.0.dev0",
"transformers>=4.45.0",
"torch>=2.1.0",
]
requires-python = ">=3.10"
readme = "README.md"
Expand Down
224 changes: 224 additions & 0 deletions src/askui/models/huggingface/holo1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""Holo-1 Vision Language Model implementation for element location.

This module provides the Holo1LocateModel class that uses the Holo-1 VLM
for locating UI elements on screen based on natural language descriptions.
"""

import json
import re

from typing_extensions import override

from askui.exceptions import AutomationError, ElementNotFoundError
from askui.locators.locators import Locator
from askui.locators.serializers import VlmLocatorSerializer
from askui.logger import logger
from askui.models.models import LocateModel, ModelComposition, Point
from askui.utils.image_utils import ImageSource


class Holo1LocateModel(LocateModel):
    """Holo-1 model implementation for locating UI elements.

    This model uses the Holo-1 Vision Language Model for element detection
    and supports both GPU and CPU inference. Heavy dependencies (torch,
    transformers) are imported lazily, and the weights are only loaded on
    first use so that constructing the model router stays cheap.

    Attributes:
        _model_name: The Hugging Face model identifier
        _device: The device to run inference on (cuda/cpu)
        _locator_serializer: Serializer for converting locators to prompts
        _max_new_tokens: Upper bound on tokens generated per request
    """

    # Matches "(x, y)" coordinate pairs in free-form model output,
    # e.g. "Element at (120, 340)".
    _COORD_PATTERN = re.compile(r"\((\d+),\s*(\d+)\)")

    def __init__(
        self,
        locator_serializer: VlmLocatorSerializer,
        model_name: str = "Hcompany/Holo1-7B",
        device: str | None = None,
        max_new_tokens: int = 128,
    ) -> None:
        """Initialize the Holo-1 model.

        Args:
            locator_serializer: Serializer for converting locators to prompts
            model_name: The Hugging Face model identifier
            device: Device to run inference on. If None, auto-detects GPU
                availability
            max_new_tokens: Maximum number of tokens to generate per request
        """
        self._model_name = model_name
        self._locator_serializer = locator_serializer
        self._max_new_tokens = max_new_tokens
        self._model = None
        self._processor = None

        # Lazy import to avoid loading heavy dependencies at module import
        # time; torch is only needed here for device auto-detection.
        import torch

        if device is None:
            self._device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self._device = device

        logger.info(f"Holo-1 model will use device: {self._device}")

    def _load_model(self) -> None:
        """Lazy load the model and processor (no-op if already loaded).

        Raises:
            AutomationError: If the model or processor cannot be loaded.
        """
        if self._model is not None:
            return

        logger.info(f"Loading Holo-1 model from {self._model_name}")

        try:
            from transformers import AutoModelForImageTextToText, AutoProcessor

            self._processor = AutoProcessor.from_pretrained(self._model_name)
            self._model = AutoModelForImageTextToText.from_pretrained(
                self._model_name,
                torch_dtype="auto",
                # device_map is only meaningful for accelerator placement;
                # plain CPU loading needs no map.
                device_map=self._device if self._device != "cpu" else None,
            )

            # Inference only: switch off dropout / training-mode layers.
            self._model.eval()

            logger.info("Holo-1 model loaded successfully")

        except Exception as e:
            error_msg = f"Failed to load Holo-1 model: {e}"
            logger.error(error_msg)
            raise AutomationError(error_msg) from e

    def _parse_model_output(
        self, output: str, _image_width: int, _image_height: int
    ) -> Point:
        """Parse the model output to extract coordinates.

        Args:
            output: The model's text output
            _image_width: Width of the input image (currently unused; kept
                for future coordinate rescaling)
            _image_height: Height of the input image (currently unused)

        Returns:
            A tuple of (x, y) coordinates

        Raises:
            ElementNotFoundError: If coordinates cannot be parsed from output
        """
        try:
            # Expected format: {"bbox": [x1, y1, x2, y2]} or similar.
            # This may need adjustment based on actual model output format.
            if "bbox" in output:
                bbox_data = json.loads(output)
                x1, y1, x2, y2 = bbox_data["bbox"]

                # Return the center point of the bounding box.
                return int((x1 + x2) / 2), int((y1 + y2) / 2)

            # Fall back to plain-text coordinates such as "Element at (x, y)".
            # BUG FIX: the previous pattern used doubled backslashes inside a
            # raw string (r"\\(...\\)"), which matched literal backslash
            # characters and therefore never matched real coordinates.
            match = self._COORD_PATTERN.search(output)

            if match:
                return int(match.group(1)), int(match.group(2))

            error_msg = f"Could not parse coordinates from model output: {output}"
            raise ValueError(error_msg)  # noqa: TRY301

        except (json.JSONDecodeError, ValueError, KeyError) as e:
            error_msg = f"Failed to parse Holo-1 output: {output}"
            logger.error(error_msg)
            # NOTE(review): ElementNotFoundError is raised with empty locator
            # strings here — confirm the exception's constructor contract.
            empty_locator = ""
            raise ElementNotFoundError(empty_locator, empty_locator) from e

    @override
    def locate(
        self,
        locator: str | Locator,
        image: ImageSource,
        model_choice: ModelComposition | str,
    ) -> Point:
        """Locate an element using the Holo-1 model.

        Args:
            locator: Element description or locator object
            image: Screenshot to analyze
            model_choice: Model selection (ignored for single model)

        Returns:
            Coordinates of the located element as (x, y) tuple

        Raises:
            NotImplementedError: If a ModelComposition is passed
            AutomationError: If model inference fails
            ElementNotFoundError: If element cannot be found
        """
        if isinstance(model_choice, ModelComposition):
            error_msg = "Model composition is not supported for Holo-1"
            raise NotImplementedError(error_msg)

        # Ensure model and processor are loaded (first call loads weights).
        self._load_model()

        # Locator objects are serialized to a natural-language prompt;
        # plain strings are used verbatim.
        serialized_locator = (
            self._locator_serializer.serialize(locator)
            if isinstance(locator, Locator)
            else locator
        )

        # Prepare messages for the processor's chat template.
        messages = [
            {"role": "user", "content": f"Locate the UI element: {serialized_locator}"}
        ]

        # Parsing errors propagate untouched; everything else is wrapped
        # as an AutomationError at this boundary.
        try:
            text = self._processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # image.root is the underlying image object
            # (presumably PIL — it exposes .width/.height below).
            inputs = self._processor(
                text=[text], images=image.root, return_tensors="pt"
            )

            # Move tensors to the accelerator if one is in use.
            if self._device != "cpu":
                inputs = inputs.to(self._device)

            import torch

            # Inference only — no gradients needed.
            with torch.no_grad():
                generated_ids = self._model.generate(
                    **inputs,
                    max_new_tokens=self._max_new_tokens,
                )

            # Strip the prompt tokens so only the newly generated answer
            # is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids) :]
                for in_ids, out_ids in zip(
                    inputs.input_ids, generated_ids, strict=False
                )
            ]

            response = self._processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True
            )[0]

            logger.debug(f"Holo-1 response: {response}")

            # Parse coordinates from the decoded response.
            return self._parse_model_output(
                response, image.root.width, image.root.height
            )

        except ElementNotFoundError:
            raise
        except Exception as e:
            error_msg = f"Holo-1 inference failed: {e}"
            logger.error(error_msg)
            raise AutomationError(error_msg) from e
32 changes: 32 additions & 0 deletions src/askui/models/huggingface/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from pydantic import Field
from pydantic_settings import BaseSettings


class Holo1Settings(BaseSettings):
    """Settings for Holo-1 model configuration.

    Values are read from the environment using the ``HOLO1_`` prefix
    configured in ``model_config`` below.

    Environment variables:
        HOLO1_MODEL_NAME: Hugging Face model identifier (default: Hcompany/Holo1-7B)
        HOLO1_DEVICE: Device to run inference on (default: auto-detect)
        HOLO1_MAX_NEW_TOKENS: Maximum tokens to generate (default: 128)
        HOLO1_TEMPERATURE: Sampling temperature (default: 0.1)
    """

    # Hugging Face repo id of the Holo-1 checkpoint to load.
    model_name: str = Field(
        default="Hcompany/Holo1-7B",
        description="Hugging Face model identifier",
    )
    # "cuda" or "cpu"; None lets the model auto-detect GPU availability.
    device: str | None = Field(
        default=None,
        description="Device to run inference on (cuda/cpu, auto-detect if None)",
    )
    # NOTE(review): declared here but not visibly passed through by the
    # router wiring in this change — confirm it is consumed before relying
    # on the env var having any effect.
    max_new_tokens: int = Field(
        default=128,
        description="Maximum number of tokens to generate",
    )
    # NOTE(review): not visibly consumed by Holo1LocateModel in this change
    # — verify generation actually honors this before documenting it as
    # active configuration.
    temperature: float = Field(
        default=0.1,
        description="Sampling temperature for generation",
    )

    # Pydantic settings config: read environment variables prefixed HOLO1_.
    model_config = {"env_prefix": "HOLO1_"}
12 changes: 12 additions & 0 deletions src/askui/models/model_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
AskUiComputerAgentSettings,
)
from askui.models.exceptions import ModelNotFoundError, ModelTypeMismatchError
from askui.models.huggingface.holo1 import Holo1LocateModel
from askui.models.huggingface.settings import Holo1Settings
from askui.models.huggingface.spaces_api import HFSpacesHandler
from askui.models.models import (
MODEL_TYPES,
Expand Down Expand Up @@ -116,6 +118,15 @@ def hf_spaces_handler() -> HFSpacesHandler:
locator_serializer=vlm_locator_serializer(),
)

@functools.cache
def holo1_locate_model() -> Holo1LocateModel:
settings = Holo1Settings()
return Holo1LocateModel(
locator_serializer=vlm_locator_serializer(),
model_name=settings.model_name,
device=settings.device,
)

return {
ModelName.ASKUI: askui_facade,
ModelName.ASKUI__AI_ELEMENT: askui_model_router,
Expand All @@ -128,6 +139,7 @@ def hf_spaces_handler() -> HFSpacesHandler:
ModelName.HF__SPACES__QWEN__QWEN2_VL_7B_INSTRUCT: hf_spaces_handler,
ModelName.HF__SPACES__OS_COPILOT__OS_ATLAS_BASE_7B: hf_spaces_handler,
ModelName.HF__SPACES__SHOWUI__2B: hf_spaces_handler,
ModelName.HF__HOLO_1: holo1_locate_model,
}


Expand Down
1 change: 1 addition & 0 deletions src/askui/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ModelName(str, Enum):
HF__SPACES__QWEN__QWEN2_VL_7B_INSTRUCT = "Qwen/Qwen2-VL-7B-Instruct"
HF__SPACES__SHOWUI__2B = "showlab/ShowUI-2B"
TARS = "tars"
HF__HOLO_1 = "holo-1"


ANTHROPIC_MODEL_NAME_MAPPING = {
Expand Down
1 change: 1 addition & 0 deletions tests/integration/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for model implementations."""
Loading
Loading