Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 58 additions & 2 deletions src/askui/models/ui_tars_ep/ui_tars_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import pathlib
import re
import time
Expand All @@ -16,6 +17,61 @@
from .parser import UITarsEPMessage
from .prompts import PROMPT, PROMPT_QA

# Constants copied from vision_processing.py in package qwen_vl_utils
# See also github.com/bytedance/UI-TARS/blob/main/README_coordinates.md
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28 # 4 * 28 * 28 in the original vision_processing.py
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor


@staticmethod
def smart_resize(
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met (see github.com/bytedance/UI-TARS/blob/main/README_coordinates.md):

1. Both dimensions (height and width) are divisible by 'factor'.

2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
error_msg = f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
raise ValueError(error_msg)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar


class UiTarsApiHandlerSettings(BaseSettings):
"""Settings for TARS API."""
Expand Down Expand Up @@ -90,8 +146,8 @@ def locate_prediction(
if isinstance(image, pathlib.Path):
image = Image.open(image)
width, height = image.size
x = (x * width) // 1000
y = (y * height) // 1000
new_height, new_width = smart_resize(height, width)
x, y = (int(x / new_width * width), int(y / new_height * height))
return x, y
return None, None

Expand Down