Skip to content

Commit 1e5cc3e

Browse files
committed
refactor: tool integration into models
- models define the abstract interface for tools - simplify a lot of logic - fix pressing "Escape" key - remove code duplication and unnecessary complex code - allow tool results to include multiple images and texts
1 parent 48fbe82 commit 1e5cc3e

File tree

21 files changed

+504
-732
lines changed

21 files changed

+504
-732
lines changed

src/askui/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase
3434
from .retry import ConfigurableRetry, Retry
3535
from .tools import ModifierKey, PcKey
36-
from .tools.anthropic import ToolResult
3736
from .utils.image_utils import ImageSource, Img
3837

3938
__all__ = [
@@ -67,7 +66,6 @@
6766
"Retry",
6867
"TextBlockParam",
6968
"TextCitationParam",
70-
"ToolResult",
7169
"ToolResultBlockParam",
7270
"ToolUseBlockParam",
7371
"UrlImageSourceParam",

src/askui/agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from askui.locators.locators import Locator
1111
from askui.models.shared.computer_agent_cb_param import OnMessageCb
1212
from askui.models.shared.computer_agent_message_param import MessageParam
13+
from askui.models.shared.tools import ToolCollection
14+
from askui.tools.computer import Computer20241022Tool
1315
from askui.utils.image_utils import ImageSource, Img
1416

1517
from .logger import configure_logging, logger
@@ -79,7 +81,9 @@ def __init__(
7981
),
8082
)
8183
self._model_router = ModelRouter(
82-
tools=self.tools, reporter=self._reporter, models=models
84+
tool_collection=ToolCollection(tools=[Computer20241022Tool(self.tools.os)]),
85+
reporter=self._reporter,
86+
models=models,
8387
)
8488
self.model = model
8589
self._retry = retry or ConfigurableRetry(

src/askui/models/anthropic/computer_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from askui.models.models import ANTHROPIC_MODEL_NAME_MAPPING, ModelName
88
from askui.models.shared.computer_agent import ComputerAgent
99
from askui.models.shared.computer_agent_message_param import MessageParam
10+
from askui.models.shared.tools import ToolCollection
1011
from askui.reporting import Reporter
11-
from askui.tools.agent_os import AgentOs
1212

1313
if TYPE_CHECKING:
1414
from anthropic.types.beta import BetaMessageParam
@@ -17,11 +17,11 @@
1717
class ClaudeComputerAgent(ComputerAgent[ClaudeComputerAgentSettings]):
1818
def __init__(
1919
self,
20-
agent_os: AgentOs,
20+
tool_collection: ToolCollection,
2121
reporter: Reporter,
2222
settings: ClaudeComputerAgentSettings,
2323
) -> None:
24-
super().__init__(settings, agent_os, reporter)
24+
super().__init__(settings, tool_collection, reporter)
2525
self._client = Anthropic(
2626
api_key=self._settings.anthropic.api_key.get_secret_value()
2727
)

src/askui/models/askui/computer_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from askui.models.askui.settings import AskUiComputerAgentSettings
66
from askui.models.shared.computer_agent import ComputerAgent
77
from askui.models.shared.computer_agent_message_param import MessageParam
8+
from askui.models.shared.tools import ToolCollection
89
from askui.reporting import Reporter
9-
from askui.tools.agent_os import AgentOs
1010

1111
from ...logger import logger
1212

@@ -21,11 +21,11 @@ def is_retryable_error(exception: BaseException) -> bool:
2121
class AskUiComputerAgent(ComputerAgent[AskUiComputerAgentSettings]):
2222
def __init__(
2323
self,
24-
agent_os: AgentOs,
24+
tool_collection: ToolCollection,
2525
reporter: Reporter,
2626
settings: AskUiComputerAgentSettings,
2727
) -> None:
28-
super().__init__(settings, agent_os, reporter)
28+
super().__init__(settings, tool_collection, reporter)
2929
self._client = httpx.Client(
3030
base_url=f"{self._settings.askui.base_url}",
3131
headers={

src/askui/models/askui/settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@ class AskUiSettings(BaseSettings):
1616
validation_alias="ASKUI_INFERENCE_ENDPOINT",
1717
)
1818
workspace_id: UUID4 = Field(
19+
default=...,
1920
validation_alias="ASKUI_WORKSPACE_ID",
2021
)
2122
token: SecretStr = Field(
23+
default=...,
2224
validation_alias="ASKUI_TOKEN",
2325
)
2426

src/askui/models/model_router.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,19 @@
3030
from askui.models.shared.computer_agent_cb_param import OnMessageCb
3131
from askui.models.shared.computer_agent_message_param import MessageParam
3232
from askui.models.shared.facade import ModelFacade
33+
from askui.models.shared.tools import ToolCollection
3334
from askui.models.types.response_schemas import ResponseSchema
3435
from askui.reporting import CompositeReporter, Reporter
35-
from askui.tools.toolbox import AgentToolbox
3636
from askui.utils.image_utils import ImageSource
3737

3838
from ..logger import logger
3939
from .anthropic.computer_agent import ClaudeComputerAgent
4040
from .anthropic.handler import ClaudeHandler
4141
from .askui.inference_api import AskUiInferenceApi, AskUiSettings
42-
from .ui_tars_ep.ui_tars_api import UiTarsApiHandler, UiTarsApiHandlerSettings
4342

4443

4544
def _initialize_default_model_registry( # noqa: C901
46-
tools: AgentToolbox,
45+
tool_collection: ToolCollection,
4746
reporter: Reporter,
4847
) -> ModelRegistry:
4948
@functools.cache
@@ -74,7 +73,7 @@ def vlm_locator_serializer() -> VlmLocatorSerializer:
7473
def anthropic_facade() -> ModelFacade:
7574
settings = AnthropicSettings()
7675
computer_agent = ClaudeComputerAgent(
77-
agent_os=tools.os,
76+
tool_collection=tool_collection,
7877
reporter=reporter,
7978
settings=ClaudeComputerAgentSettings(
8079
anthropic=settings,
@@ -95,7 +94,7 @@ def anthropic_facade() -> ModelFacade:
9594
@functools.cache
9695
def askui_facade() -> ModelFacade:
9796
computer_agent = AskUiComputerAgent(
98-
agent_os=tools.os,
97+
tool_collection=tool_collection,
9998
reporter=reporter,
10099
settings=AskUiComputerAgentSettings(
101100
askui=askui_settings(),
@@ -113,15 +112,6 @@ def hf_spaces_handler() -> HFSpacesHandler:
113112
locator_serializer=vlm_locator_serializer(),
114113
)
115114

116-
@functools.cache
117-
def ui_tars_api_handler() -> UiTarsApiHandler:
118-
return UiTarsApiHandler(
119-
locator_serializer=vlm_locator_serializer(),
120-
agent_os=tools.os,
121-
reporter=reporter,
122-
settings=UiTarsApiHandlerSettings(),
123-
)
124-
125115
return {
126116
ModelName.ASKUI: askui_facade,
127117
ModelName.ASKUI__AI_ELEMENT: askui_model_router,
@@ -134,20 +124,20 @@ def ui_tars_api_handler() -> UiTarsApiHandler:
134124
ModelName.HF__SPACES__QWEN__QWEN2_VL_7B_INSTRUCT: hf_spaces_handler,
135125
ModelName.HF__SPACES__OS_COPILOT__OS_ATLAS_BASE_7B: hf_spaces_handler,
136126
ModelName.HF__SPACES__SHOWUI__2B: hf_spaces_handler,
137-
ModelName.TARS: ui_tars_api_handler,
138127
}
139128

140129

141130
class ModelRouter:
142131
def __init__(
143132
self,
144-
tools: AgentToolbox,
133+
tool_collection: ToolCollection,
145134
reporter: Reporter | None = None,
146135
models: ModelRegistry | None = None,
147136
):
148-
self._tools = tools
149137
self._reporter = reporter or CompositeReporter()
150-
self._models = _initialize_default_model_registry(tools, self._reporter)
138+
self._models = _initialize_default_model_registry(
139+
tool_collection, self._reporter
140+
)
151141
self._models.update(models or {})
152142

153143
@overload

src/askui/models/shared/computer_agent.py

Lines changed: 12 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,12 @@
1111
from askui.models.models import ActModel
1212
from askui.models.shared.computer_agent_cb_param import OnMessageCb, OnMessageCbParam
1313
from askui.models.shared.computer_agent_message_param import (
14-
Base64ImageSourceParam,
15-
ContentBlockParam,
1614
ImageBlockParam,
1715
MessageParam,
1816
TextBlockParam,
19-
ToolResultBlockParam,
2017
)
18+
from askui.models.shared.tools import ToolCollection
2119
from askui.reporting import Reporter
22-
from askui.tools.agent_os import AgentOs
23-
from askui.tools.anthropic import ComputerTool, ToolCollection, ToolResult
2420

2521
from ...logger import logger
2622

@@ -189,21 +185,19 @@ class ComputerAgent(ActModel, ABC, Generic[ComputerAgentSettings]):
189185
def __init__(
190186
self,
191187
settings: ComputerAgentSettings,
192-
agent_os: AgentOs,
188+
tool_collection: ToolCollection,
193189
reporter: Reporter,
194190
) -> None:
195191
"""Initialize the computer agent.
196192
197193
Args:
198194
settings (ComputerAgentSettings): The settings for the computer agent.
199-
agent_os (AgentOs): The operating system agent for executing commands.
195+
tool_collection (ToolCollection): Collection of tools to be used
200196
reporter (Reporter): The reporter for logging messages and actions.
201197
"""
202198
self._settings = settings
203199
self._reporter = reporter
204-
self._tool_collection = ToolCollection(
205-
ComputerTool(agent_os),
206-
)
200+
self._tool_collection = tool_collection
207201
self._system = BetaTextBlockParam(
208202
type="text",
209203
text=f"{SYSTEM_PROMPT}",
@@ -315,24 +309,20 @@ def _use_tools(
315309
MessageParam | None: A message containing tool results or `None`
316310
if no tools were used.
317311
"""
318-
tool_result_content: list[ContentBlockParam] = []
319312
if isinstance(message.content, str):
320313
return None
321314

322-
for content_block in message.content:
323-
if content_block.type == "tool_use":
324-
result = self._tool_collection.run(
325-
name=content_block.name,
326-
tool_input=content_block.input, # type: ignore[arg-type]
327-
)
328-
tool_result_content.append(
329-
self._make_api_tool_result(result, content_block.id)
330-
)
331-
if len(tool_result_content) == 0:
315+
tool_use_content_blocks = [
316+
content_block
317+
for content_block in message.content
318+
if content_block.type == "tool_use"
319+
]
320+
content = self._tool_collection.run(tool_use_content_blocks)
321+
if len(content) == 0:
332322
return None
333323

334324
return MessageParam(
335-
content=tool_result_content,
325+
content=content,
336326
role="user",
337327
)
338328

@@ -391,62 +381,3 @@ def _maybe_filter_to_n_most_recent_images(
391381
new_content.append(content)
392382
tool_result.content = new_content
393383
return messages
394-
395-
def _make_api_tool_result(
396-
self, result: ToolResult, tool_use_id: str
397-
) -> ToolResultBlockParam:
398-
"""Convert a tool result to an API tool result block.
399-
400-
Args:
401-
result (ToolResult): The tool result to convert.
402-
tool_use_id (str): The ID of the tool use block.
403-
404-
Returns:
405-
ToolResultBlockParam: The API tool result block.
406-
"""
407-
tool_result_content: list[TextBlockParam | ImageBlockParam] | str = []
408-
is_error = False
409-
if result.error:
410-
is_error = True
411-
tool_result_content = self._maybe_prepend_system_tool_result(
412-
result, result.error
413-
)
414-
else:
415-
assert isinstance(tool_result_content, list)
416-
if result.output:
417-
tool_result_content.append(
418-
TextBlockParam(
419-
text=self._maybe_prepend_system_tool_result(
420-
result, result.output
421-
),
422-
)
423-
)
424-
if result.base64_image:
425-
tool_result_content.append(
426-
ImageBlockParam(
427-
source=Base64ImageSourceParam(
428-
media_type="image/png",
429-
data=result.base64_image,
430-
),
431-
)
432-
)
433-
return ToolResultBlockParam(
434-
content=tool_result_content,
435-
tool_use_id=tool_use_id,
436-
is_error=is_error,
437-
)
438-
439-
@staticmethod
440-
def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str:
441-
"""Prepend system message to tool result text if available.
442-
443-
Args:
444-
result (ToolResult): The tool result.
445-
result_text (str): The result text.
446-
447-
Returns:
448-
str: The result text with optional system message prepended.
449-
"""
450-
if result.system:
451-
result_text = f"<system>{result.system}</system>\n{result_text}"
452-
return result_text

0 commit comments

Comments
 (0)