diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py
index f56ed7e9..2cfafdd0 100644
--- a/src/askui/models/ui_tars_ep/ui_tars_api.py
+++ b/src/askui/models/ui_tars_ep/ui_tars_api.py
@@ -1,3 +1,4 @@
+import math
 import pathlib
 import re
 import time
@@ -16,6 +17,60 @@ from .parser import UITarsEPMessage
 from .prompts import PROMPT, PROMPT_QA
 
 
+# Constants copied from vision_processing.py in package qwen_vl_utils
+# See also github.com/bytedance/UI-TARS/blob/main/README_coordinates.md
+IMAGE_FACTOR = 28
+MIN_PIXELS = 100 * 28 * 28  # 4 * 28 * 28 in the original vision_processing.py
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: float, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: float, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+def smart_resize(
+    height: int,
+    width: int,
+    factor: int = IMAGE_FACTOR,
+    min_pixels: int = MIN_PIXELS,
+    max_pixels: int = MAX_PIXELS,
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met (see github.com/bytedance/UI-TARS/blob/main/README_coordinates.md):
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        error_msg = f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        raise ValueError(error_msg)
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
 class UiTarsApiHandlerSettings(BaseSettings):
     """Settings for TARS API."""
 
@@ -90,8 +145,8 @@ def locate_prediction(
         if isinstance(image, pathlib.Path):
             image = Image.open(image)
         width, height = image.size
-        x = (x * width) // 1000
-        y = (y * height) // 1000
+        new_height, new_width = smart_resize(height, width)
+        x, y = (int(x / new_width * width), int(y / new_height * height))
         return x, y
     return None, None
 