diff --git a/README.md b/README.md index 8fcf51f6..0b53f7d1 100644 --- a/README.md +++ b/README.md @@ -159,9 +159,26 @@ You can use Vision Agent with UI-TARS if you provide your own UI-TARS API endpoi 1. Step: Host the model locally or in the cloud. More information about hosting UI-TARS can be found [here](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#deployment). -2. Step: Provide the `TARS_URL` and `TARS_API_KEY` environment variables to Vision Agent. +2. Step: Provide the `TARS_URL`, `TARS_API_KEY`, and `TARS_MODEL_NAME` environment variables to Vision Agent. -3. Step: Use the `model="tars"` parameter in your `click()`, `get()` and `act()` etc. commands or when initializing the `VisionAgent`. +3. Step: Use the `model="tars"` parameter in your `click()`, `get()` and `act()` etc. commands or when initializing the `VisionAgent`. The TARS model will be automatically registered if the environment variables are available. + +**Example Code:** +```python +# Set environment variables before running this code: +# TARS_URL=http://your-tars-endpoint.com/v1 +# TARS_API_KEY=your-tars-api-key +# TARS_MODEL_NAME=your-model-name + +from askui import VisionAgent + + +# Use TARS model directly +with VisionAgent(model="tars") as agent: + agent.click("Submit button") # Uses TARS automatically + agent.get("What's on screen?") # Uses TARS automatically + agent.act("Search for flights") # Uses TARS automatically +``` ## ▶️ Start Building diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 3429ef0a..10af8f75 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -33,6 +33,7 @@ from ..logger import logger from .askui.inference_api import AskUiInferenceApi +from .ui_tars_ep.ui_tars_api import UiTarsApiHandler, UiTarsApiHandlerSettings def initialize_default_model_registry( # noqa: C901 @@ -90,6 +91,20 @@ def hf_spaces_handler() -> HFSpacesHandler: locator_serializer=vlm_locator_serializer(), ) + 
@functools.cache + def tars_handler() -> UiTarsApiHandler: + try: + settings = UiTarsApiHandlerSettings() + locator_serializer = VlmLocatorSerializer() + return UiTarsApiHandler( + reporter=reporter, + settings=settings, + locator_serializer=locator_serializer, + ) + except Exception as e: + error_msg = f"Failed to initialize TARS model: {e}" + raise ValueError(error_msg) from e + return { ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022: anthropic_facade, ModelName.ASKUI: askui_facade, @@ -103,6 +118,7 @@ def hf_spaces_handler() -> HFSpacesHandler: ModelName.HF__SPACES__QWEN__QWEN2_VL_7B_INSTRUCT: hf_spaces_handler, ModelName.HF__SPACES__OS_COPILOT__OS_ATLAS_BASE_7B: hf_spaces_handler, ModelName.HF__SPACES__SHOWUI__2B: hf_spaces_handler, + ModelName.TARS: tars_handler, } diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index b8f0bb66..aa128d07 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -18,7 +18,6 @@ from askui.models.shared.tools import Tool from askui.models.types.response_schemas import ResponseSchema from askui.reporting import Reporter -from askui.tools.agent_os import AgentOs from askui.utils.image_utils import ImageSource, image_to_base64 from .parser import UITarsEPMessage @@ -85,22 +84,26 @@ class UiTarsApiHandlerSettings(BaseSettings): tars_url: HttpUrl = Field( validation_alias="TARS_URL", + description="URL of the TARS API", ) tars_api_key: SecretStr = Field( min_length=1, validation_alias="TARS_API_KEY", + description="API key for authenticating with the TARS API", + ) + tars_model_name: str = Field( + validation_alias="TARS_MODEL_NAME", + description="Name of the TARS model to use for inference", ) class UiTarsApiHandler(ActModel, LocateModel, GetModel): def __init__( self, - agent_os: AgentOs, reporter: Reporter, settings: UiTarsApiHandlerSettings, locator_serializer: VlmLocatorSerializer, ) -> None: - self._agent_os = agent_os
self._reporter = reporter self._settings = settings self._client = OpenAI( @@ -111,7 +114,7 @@ def __init__( def _predict(self, image_url: str, instruction: str, prompt: str) -> str | None: chat_completion = self._client.chat.completions.create( - model="tgi", + model=self._settings.tars_model_name, messages=[ { "role": "user", @@ -159,11 +162,10 @@ def locate( prompt=PROMPT, ) assert prediction is not None - pattern = r"click\(start_box='(\(\d+,\d+\))'\)" + pattern = r"click\(start_box='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)" match = re.search(pattern, prediction) if match: - x, y = match.group(1).strip("()").split(",") - x, y = int(x), int(y) + x, y = int(match.group(1)), int(match.group(2)) width, height = image.root.size new_height, new_width = smart_resize(height, width) x, y = (int(x / new_width * width), int(y / new_height * height)) @@ -213,8 +215,21 @@ def act( if not isinstance(message.content, str): error_msg = "UI-TARS only supports text messages" raise ValueError(error_msg) # noqa: TRY004 + + # Find the computer tool + computer_tool = None + if tools: + for tool in tools: + if tool.name == "computer": + computer_tool = tool + break + + if computer_tool is None: + error_msg = "Computer tool is required for UI-TARS act() method" + raise ValueError(error_msg) + goal = message.content - screenshot = self._agent_os.screenshot() + screenshot = computer_tool(action="screenshot") self.act_history = [ { "role": "user", @@ -231,10 +246,10 @@ def act( ], } ] - self.execute_act(self.act_history) + self.execute_act(self.act_history, computer_tool) - def add_screenshot_to_history(self, message_history: list[dict[str, Any]]) -> None: - screenshot = self._agent_os.screenshot() + def add_screenshot_to_history(self, message_history: list[dict[str, Any]], computer_tool: Tool) -> None: + screenshot = computer_tool(action="screenshot") message_history.append( { "role": "user", @@ -293,11 +308,11 @@ def filter_message_thread( return filtered_messages - def 
execute_act(self, message_history: list[dict[str, Any]]) -> None: + def execute_act(self, message_history: list[dict[str, Any]], computer_tool: Tool) -> None: message_history = self.filter_message_thread(message_history) chat_completion = self._client.chat.completions.create( - model="tgi", + model=self._settings.tars_model_name, messages=message_history, top_p=None, temperature=None, @@ -321,21 +336,20 @@ def execute_act(self, message_history: list[dict[str, Any]]) -> None: message_history.append( {"role": "user", "content": [{"type": "text", "text": str(e)}]} ) - self.execute_act(message_history) + self.execute_act(message_history, computer_tool) return action = message.parsed_action if action.action_type == "click": - self._agent_os.mouse_move(action.start_box.x, action.start_box.y) - self._agent_os.click("left") + computer_tool(action="mouse_move", coordinate=(action.start_box.x, action.start_box.y)) + computer_tool(action="left_click") time.sleep(1) if action.action_type == "type": - self._agent_os.click("left") - self._agent_os.type(action.content) + computer_tool(action="left_click") + computer_tool(action="type", text=action.content) time.sleep(0.5) if action.action_type == "hotkey": - self._agent_os.keyboard_pressed(action.key) - self._agent_os.keyboard_release(action.key) + computer_tool(action="key", text=action.key) time.sleep(0.5) if action.action_type == "call_user": time.sleep(1) @@ -344,8 +358,8 @@ def execute_act(self, message_history: list[dict[str, Any]]) -> None: if action.action_type == "finished": return - self.add_screenshot_to_history(message_history) - self.execute_act(message_history) + self.add_screenshot_to_history(message_history, computer_tool) + self.execute_act(message_history, computer_tool) def _filter_messages( self, messages: list[UITarsEPMessage], max_messages: int