Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/askui/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ def __init__(
load_dotenv()
configure_logging(level=log_level)
self._reporter = CompositeReporter(reporters=reporters)
self._tools = tools or AgentToolbox(
self.tools = tools or AgentToolbox(
agent_os=AskUiControllerClient(
display=display,
reporter=self._reporter,
),
)
self._model_router = ModelRouter(
tools=self._tools, reporter=self._reporter, models=models
tools=self.tools, reporter=self._reporter, models=models
)
self.model = model
self._retry = retry or ConfigurableRetry(
Expand Down Expand Up @@ -160,7 +160,7 @@ def click(
if locator is not None:
logger.debug("VisionAgent received instruction to click on %s", locator)
self._mouse_move(locator, model)
self._tools.os.click(button, repeat)
self.tools.os.click(button, repeat)

def _locate(
self,
Expand All @@ -170,7 +170,7 @@ def _locate(
) -> Point:
def locate_with_screenshot() -> Point:
_screenshot = ImageSource(
self._tools.os.screenshot() if screenshot is None else screenshot
self.tools.os.screenshot() if screenshot is None else screenshot
)
return self._model_router.locate(
screenshot=_screenshot,
Expand Down Expand Up @@ -219,7 +219,7 @@ def _mouse_move(
self, locator: str | Locator, model: ModelComposition | str | None = None
) -> None:
point = self._locate(locator=locator, model=model)
self._tools.os.mouse_move(point[0], point[1])
self.tools.os.mouse_move(point[0], point[1])

@telemetry.record_call(exclude={"locator"})
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
Expand Down Expand Up @@ -281,7 +281,7 @@ def mouse_scroll(
```
"""
self._reporter.add_message("User", f'mouse_scroll: "{x}", "{y}"')
self._tools.os.mouse_scroll(x, y)
self.tools.os.mouse_scroll(x, y)

@telemetry.record_call(exclude={"text"})
@validate_call
Expand All @@ -307,7 +307,7 @@ def type(
"""
self._reporter.add_message("User", f'type: "{text}"')
logger.debug("VisionAgent received instruction to type '%s'", text)
self._tools.os.type(text)
self.tools.os.type(text)

@overload
def get(
Expand Down Expand Up @@ -392,7 +392,7 @@ class UrlResponse(ResponseSchemaBase):
```
"""
logger.debug("VisionAgent received instruction to get '%s'", query)
_image = ImageSource(self._tools.os.screenshot() if image is None else image)
_image = ImageSource(self.tools.os.screenshot() if image is None else image)
self._reporter.add_message("User", f'get: "{query}"', image=_image.root)
response = self._model_router.get(
image=_image,
Expand Down Expand Up @@ -454,7 +454,7 @@ def key_up(
"""
self._reporter.add_message("User", f'key_up "{key}"')
logger.debug("VisionAgent received in key_up '%s'", key)
self._tools.os.keyboard_release(key)
self.tools.os.keyboard_release(key)

@telemetry.record_call()
@validate_call
Expand All @@ -479,7 +479,7 @@ def key_down(
"""
self._reporter.add_message("User", f'key_down "{key}"')
logger.debug("VisionAgent received in key_down '%s'", key)
self._tools.os.keyboard_pressed(key)
self.tools.os.keyboard_pressed(key)

@telemetry.record_call()
@validate_call
Expand All @@ -505,7 +505,7 @@ def mouse_up(
"""
self._reporter.add_message("User", f'mouse_up "{button}"')
logger.debug("VisionAgent received instruction to mouse_up '%s'", button)
self._tools.os.mouse_up(button)
self.tools.os.mouse_up(button)

@telemetry.record_call()
@validate_call
Expand All @@ -531,7 +531,7 @@ def mouse_down(
"""
self._reporter.add_message("User", f'mouse_down "{button}"')
logger.debug("VisionAgent received instruction to mouse_down '%s'", button)
self._tools.os.mouse_down(button)
self.tools.os.mouse_down(button)

@telemetry.record_call(exclude={"goal", "on_message"})
@validate_call
Expand Down Expand Up @@ -616,7 +616,7 @@ def keyboard(
msg += f" {repeat}x times"
self._reporter.add_message("User", msg)
logger.debug("VisionAgent received instruction to press '%s'", key)
self._tools.os.keyboard_tap(key, modifier_keys, count=repeat)
self.tools.os.keyboard_tap(key, modifier_keys, count=repeat)

@telemetry.record_call(exclude={"command"})
@validate_call
Expand Down Expand Up @@ -644,16 +644,16 @@ def cli(
```
"""
logger.debug("VisionAgent received instruction to execute '%s' on cli", command)
self._tools.os.run_command(command)
self.tools.os.run_command(command)

@telemetry.record_call(flush=True)
def close(self) -> None:
self._tools.os.disconnect()
self.tools.os.disconnect()
self._reporter.generate()

@telemetry.record_call()
def open(self) -> None:
self._tools.os.connect()
self.tools.os.connect()

@telemetry.record_call()
def __enter__(self) -> "VisionAgent":
Expand Down