From 48fbe829e3af066b9e6de41060c615a78fa6531a Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 18 Jun 2025 08:47:21 +0200 Subject: [PATCH 1/4] chore(mcp): replace mcp with fastmcp - has more capabilities - better documentation --- pdm.lock | 133 +++++++++++++++++++++----------------- pyproject.toml | 9 ++- src/askui/mcp/__init__.py | 5 +- 3 files changed, 81 insertions(+), 66 deletions(-) diff --git a/pdm.lock b/pdm.lock index e70dcd9b..df52c8f7 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "chat", "mcp", "pynput", "test"] +groups = ["default", "chat", "pynput", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:1265b0e0daca5f17ed1ec690f9b11c9c0265ce255976a32163348594e60cce92" +content_hash = "sha256:a520273efe45333a1e8449f9c90dec81be9f9a526656398207266b8d91f83134" [[metadata.targets]] requires_python = ">=3.10" @@ -15,7 +15,7 @@ name = "annotated-types" version = "0.7.0" requires_python = ">=3.8" summary = "Reusable constraint types to use with typing.Annotated" -groups = ["default", "chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "typing-extensions>=4.0.0; python_version < \"3.9\"", ] @@ -26,22 +26,22 @@ files = [ [[package]] name = "anthropic" -version = "0.49.0" +version = "0.54.0" requires_python = ">=3.8" summary = "The official Python library for the anthropic API" groups = ["default"] dependencies = [ "anyio<5,>=3.5.0", "distro<2,>=1.7.0", - "httpx<1,>=0.23.0", + "httpx<1,>=0.25.0", "jiter<1,>=0.4.0", "pydantic<3,>=1.9.0", "sniffio", "typing-extensions<5,>=4.10", ] files = [ - {file = "anthropic-0.49.0-py3-none-any.whl", hash = "sha256:bbc17ad4e7094988d2fa86b87753ded8dce12498f4b85fe5810f208f454a8375"}, - {file = "anthropic-0.49.0.tar.gz", hash = "sha256:c09e885b0f674b9119b4f296d8508907f6cff0009bc20d5cf6b35936c40b4398"}, + {file = "anthropic-0.54.0-py3-none-any.whl", hash = "sha256:c1062a0a905daeec17ca9c06c401e4b3f24cb0495841d29d752568a1d4018d56"}, + {file = "anthropic-0.54.0.tar.gz", hash = "sha256:5e6f997d97ce8e70eac603c3ec2e7f23addeff953fbbb76b19430562bb6ba815"}, ] [[package]] @@ -49,7 +49,7 @@ name = "anyio" version = "4.9.0" requires_python = ">=3.9" summary = "High level compatibility layer for multiple asynchronous event loop implementations" -groups = ["default", "chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "exceptiongroup>=1.0.2; python_version < \"3.11\"", "idna>=2.8", @@ -77,7 +77,7 @@ name = "certifi" version = "2025.1.31" requires_python = ">=3.6" summary = "Python package for providing Mozilla's CA Bundle." -groups = ["default", "mcp"] +groups = ["default"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -151,7 +151,7 @@ name = "click" version = "8.1.8" requires_python = ">=3.7" summary = "Composable command line interface toolkit" -groups = ["chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "colorama; platform_system == \"Windows\"", "importlib-metadata; python_version < \"3.8\"", @@ -166,7 +166,7 @@ name = "colorama" version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." -groups = ["default", "chat", "mcp", "test"] +groups = ["default", "chat", "test"] marker = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, @@ -329,8 +329,7 @@ name = "exceptiongroup" version = "1.2.2" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["default", "chat", "mcp", "test"] -marker = "python_version < \"3.11\"" +groups = ["default", "chat", "test"] files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -363,6 +362,27 @@ files = [ {file = "fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681"}, ] +[[package]] +name = "fastmcp" +version = "2.3.4" +requires_python = ">=3.10" +summary = "The fast, Pythonic way to build MCP servers." +groups = ["default"] +dependencies = [ + "exceptiongroup>=1.2.2", + "httpx>=0.28.1", + "mcp<2.0.0,>=1.8.1", + "openapi-pydantic>=0.5.1", + "python-dotenv>=1.1.0", + "rich>=13.9.4", + "typer>=0.15.2", + "websockets>=14.0", +] +files = [ + {file = "fastmcp-2.3.4-py3-none-any.whl", hash = "sha256:12a45f72dd95aeaa1a6a56281fff96ca46929def3ccd9f9eb125cb97b722fbab"}, + {file = "fastmcp-2.3.4.tar.gz", hash = "sha256:f3fe004b8735b365a65ec2547eeb47db8352d5613697254854bc7c9c3c360eea"}, +] + [[package]] name = "filelock" version = "3.18.0" @@ -528,7 +548,7 @@ name = "h11" version = "0.14.0" requires_python = ">=3.7" summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -groups = ["default", "chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "typing-extensions; python_version < \"3.8\"", ] @@ -542,7 +562,7 @@ name = "httpcore" version = "1.0.7" requires_python = ">=3.8" summary = "A minimal low-level HTTP client." -groups = ["default", "mcp"] +groups = ["default"] dependencies = [ "certifi", "h11<0.15,>=0.13", @@ -557,7 +577,7 @@ name = "httpx" version = "0.28.1" requires_python = ">=3.8" summary = "The next generation HTTP client." -groups = ["default", "mcp"] +groups = ["default"] dependencies = [ "anyio", "certifi", @@ -574,7 +594,7 @@ name = "httpx-sse" version = "0.4.0" requires_python = ">=3.8" summary = "Consume Server-Sent Event (SSE) messages with HTTPX." -groups = ["mcp"] +groups = ["default"] files = [ {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"}, @@ -605,7 +625,7 @@ name = "idna" version = "3.10" requires_python = ">=3.6" summary = "Internationalized Domain Names in Applications (IDNA)" -groups = ["default", "chat", "mcp"] +groups = ["default", "chat"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -702,7 +722,7 @@ name = "markdown-it-py" version = "3.0.0" requires_python = ">=3.8" summary = "Python port of markdown-it. Markdown parsing, done right!" -groups = ["default", "mcp"] +groups = ["default"] dependencies = [ "mdurl~=0.1", ] @@ -773,10 +793,10 @@ files = [ [[package]] name = "mcp" -version = "1.8.1" +version = "1.9.4" requires_python = ">=3.10" summary = "Model Context Protocol SDK" -groups = ["mcp"] +groups = ["default"] dependencies = [ "anyio>=4.5", "httpx-sse>=0.4", @@ -789,27 +809,8 @@ dependencies = [ "uvicorn>=0.23.1; sys_platform != \"emscripten\"", ] files = [ - {file = "mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770"}, - {file = "mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e"}, -] - -[[package]] -name = "mcp" -version = "1.8.1" -extras = ["cli", "rich", "ws"] -requires_python = ">=3.10" -summary = "Model Context Protocol SDK" -groups = ["mcp"] -dependencies = [ - "mcp==1.8.1", - "python-dotenv>=1.0.0", - "rich>=13.9.4", - "typer>=0.12.4", - "websockets>=15.0.1", -] -files = [ - {file = "mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770"}, - {file = "mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e"}, + {file = "mcp-1.9.4-py3-none-any.whl", hash = "sha256:7fcf36b62936adb8e63f89346bccca1268eeca9bf6dfb562ee10b1dfbda9dac0"}, + {file = "mcp-1.9.4.tar.gz", hash = "sha256:cfb0bcd1a9535b42edaef89947b9e18a8feb49362e1cc059d6e7fc636f2cb09f"}, ] [[package]] @@ -817,7 +818,7 @@ name = "mdurl" version = "0.1.2" requires_python = ">=3.7" summary = "Markdown URL utilities" -groups = ["default", "mcp"] +groups = ["default"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, @@ -906,6 +907,20 @@ files = [ {file = "openai-1.85.0.tar.gz", hash = "sha256:6ba76e4ebc5725f71f2f6126c7cb5169ca8de60dd5aa61f350f9448ad162c913"}, ] +[[package]] +name = "openapi-pydantic" +version = "0.5.1" +requires_python = "<4.0,>=3.8" +summary = "Pydantic OpenAPI schema implementation" +groups = ["default"] +dependencies = [ + "pydantic>=1.8", +] +files = [ + {file = "openapi_pydantic-0.5.1-py3-none-any.whl", hash = "sha256:a3a09ef4586f5bd760a8df7f43028b60cafb6d9f61de2acba9574766255ab146"}, + {file = "openapi_pydantic-0.5.1.tar.gz", hash = "sha256:ff6835af6bde7a459fb93eb93bb92b8749b754fc6e51b2f1590a19dc3005ee0d"}, +] + [[package]] name = "packaging" version = "24.2" @@ -1031,7 +1046,7 @@ name = "pydantic" version = "2.11.2" requires_python = ">=3.9" summary = "Data validation using Python type hints" -groups = ["default", "chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "annotated-types>=0.6.0", "pydantic-core==2.33.1", @@ -1048,7 +1063,7 @@ name = "pydantic-core" version = "2.33.1" requires_python = ">=3.9" summary = "Core functionality for Pydantic validation and serialization" -groups = ["default", "chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "typing-extensions!=4.7.0,>=4.6.0", ] @@ -1137,7 +1152,7 @@ name = "pydantic-settings" version = "2.9.1" requires_python = ">=3.9" summary = "Settings management using Pydantic" -groups = ["default", "mcp"] +groups = ["default"] dependencies = [ "pydantic>=2.7.0", "python-dotenv>=0.21.0", @@ -1153,7 +1168,7 @@ name = "pygments" version = "2.19.1" requires_python = ">=3.8" summary = "Pygments is a syntax highlighting package written in Python." -groups = ["default", "mcp"] +groups = ["default"] files = [ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, @@ -1391,7 +1406,7 @@ name = "python-dotenv" version = "1.1.0" requires_python = ">=3.9" summary = "Read key-value pairs from a .env file and set them as environment variables" -groups = ["default", "mcp"] +groups = ["default"] files = [ {file = "python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d"}, {file = "python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5"}, @@ -1402,7 +1417,7 @@ name = "python-multipart" version = "0.0.20" requires_python = ">=3.8" summary = "A streaming multipart parser for Python" -groups = ["mcp"] +groups = ["default"] files = [ {file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"}, {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, @@ -1490,7 +1505,7 @@ name = "rich" version = "14.0.0" requires_python = ">=3.8.0" summary = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -groups = ["default", "mcp"] +groups = ["default"] dependencies = [ "markdown-it-py>=2.2.0", "pygments<3.0.0,>=2.13.0", @@ -1561,7 +1576,7 @@ name = "shellingham" version = "1.5.4" requires_python = ">=3.7" summary = "Tool to Detect Surrounding Shell" -groups = ["mcp"] +groups = ["default"] files = [ {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, @@ -1583,7 +1598,7 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["default", "chat", "mcp"] +groups = ["default", "chat"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -1594,7 +1609,7 @@ name = "sse-starlette" version = "2.3.5" requires_python = ">=3.9" summary = "SSE plugin for Starlette" -groups = ["mcp"] +groups = ["default"] dependencies = [ "anyio>=4.7.0", "starlette>=0.41.3", @@ -1609,7 +1624,7 @@ name = "starlette" version = "0.46.2" requires_python = ">=3.9" summary = "The little ASGI library that shines." -groups = ["chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "anyio<5,>=3.6.2", "typing-extensions>=3.10.0; python_version < \"3.10\"", @@ -1691,7 +1706,7 @@ name = "typer" version = "0.15.4" requires_python = ">=3.7" summary = "Typer, build great CLIs. Easy to code. Based on Python type hints." -groups = ["mcp"] +groups = ["default"] dependencies = [ "click<8.2,>=8.0.0", "rich>=10.11.0", @@ -1777,7 +1792,7 @@ name = "typing-extensions" version = "4.13.1" requires_python = ">=3.8" summary = "Backported and Experimental Type Hints for Python 3.8+" -groups = ["default", "chat", "mcp", "test"] +groups = ["default", "chat", "test"] files = [ {file = "typing_extensions-4.13.1-py3-none-any.whl", hash = "sha256:4b6cf02909eb5495cfbc3f6e8fd49217e6cc7944e145cdda8caa3734777f9e69"}, {file = "typing_extensions-4.13.1.tar.gz", hash = "sha256:98795af00fb9640edec5b8e31fc647597b4691f099ad75f469a2616be1a76dff"}, @@ -1788,7 +1803,7 @@ name = "typing-inspection" version = "0.4.0" requires_python = ">=3.9" summary = "Runtime typing introspection tools" -groups = ["default", "chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "typing-extensions>=4.12.0", ] @@ -1813,7 +1828,7 @@ name = "uvicorn" version = "0.34.3" requires_python = ">=3.9" summary = "The lightning-fast ASGI server." -groups = ["chat", "mcp"] +groups = ["default", "chat"] dependencies = [ "click>=7.0", "h11>=0.8", @@ -1829,7 +1844,7 @@ name = "websockets" version = "15.0.1" requires_python = ">=3.9" summary = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -groups = ["default", "mcp"] +groups = ["default"] files = [ {file = "websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b"}, {file = "websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205"}, diff --git a/pyproject.toml b/pyproject.toml index 97efe2e5..f85c8f31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "grpcio-tools>=1.67.0", "pillow>=11.0.0", "pydantic>=2.11.0", - "anthropic>=0.49.0", + "anthropic>=0.54.0", "rich>=13.9.4", "pyperclip>=1.9.0", "gradio-client>=1.4.3", @@ -22,6 +22,7 @@ dependencies = [ "segment-analytics-python>=2.3.4", "py-machineid>=0.7.0", "httpx>=0.28.1", + "fastmcp>=2.3.4", ] requires-python = ">=3.10" readme = "README.md" @@ -35,6 +36,7 @@ build-backend = "hatchling.build" [tool.hatch.version] path = "src/askui/__init__.py" + [tool.pdm] distribution = true @@ -56,7 +58,7 @@ typecheck = "mypy" "chat:api" = "uvicorn chat.api.app:app --reload --port 8000" "chat:ui:install" = {shell = "cd src/chat/ui && npm ci"} "chat:ui" = {shell = "cd src/chat/ui && npm run dev"} -mcp = "mcp dev src/askui/mcp/__init__.py" +"mcp:dev" = "mcp dev src/askui/mcp/__init__.py" [dependency-groups] chat = [ @@ -67,9 +69,6 @@ pynput = [ "mss>=10.0.0", "pynput>=1.8.1", ] -mcp = [ - "mcp[cli,rich,ws]>=1.8.1", -] test = [ "pytest>=8.3.4", "ruff>=0.9.5", diff --git a/src/askui/mcp/__init__.py b/src/askui/mcp/__init__.py index c23dd663..f6878350 100644 --- a/src/askui/mcp/__init__.py +++ b/src/askui/mcp/__init__.py @@ -1,8 +1,9 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass +from typing import Any -from mcp.server.fastmcp import FastMCP +from fastmcp import FastMCP from askui.agent import VisionAgent @@ -13,7 +14,7 @@ class AppContext: @asynccontextmanager -async def mcp_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # noqa: ARG001 +async def mcp_lifespan(server: FastMCP[Any]) -> AsyncIterator[AppContext]: # noqa: ARG001 with VisionAgent(display=2) as vision_agent: server.add_tool(vision_agent.click) yield AppContext(vision_agent=vision_agent) From 1e5cc3eb5ac432d5aab02c30512a9ecaaccb46ea Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 18 Jun 2025 12:54:23 +0200 Subject: [PATCH 2/4] refactor: tool integration into models - models define the abstract interface for tools - simplify a lot of logic - fix pressing "Escape" key - remove code duplication and unnecessary complex code - allow tool results to include multiple images and texts --- src/askui/__init__.py | 2 - src/askui/agent.py | 6 +- src/askui/models/anthropic/computer_agent.py | 6 +- src/askui/models/askui/computer_agent.py | 6 +- src/askui/models/askui/settings.py | 2 + src/askui/models/model_router.py | 26 +- src/askui/models/shared/computer_agent.py | 93 +---- src/askui/models/shared/tools.py | 104 ++++++ src/askui/tools/agent_os.py | 123 ------- src/askui/tools/anthropic/__init__.py | 10 - src/askui/tools/anthropic/base.py | 84 ----- src/askui/tools/anthropic/collection.py | 34 -- src/askui/tools/anthropic/computer.py | 359 ------------------- src/askui/tools/computer.py | 205 +++++++++++ src/askui/tools/exceptions.py | 6 - src/askui/utils/dict_utils.py | 36 ++ test.py | 107 ++++++ tests/conftest.py | 8 + tests/e2e/agent/conftest.py | 5 +- tests/integration/agent/conftest.py | 8 +- tests/unit/models/test_model_router.py | 6 +- 21 files changed, 504 insertions(+), 732 deletions(-) create mode 100644 src/askui/models/shared/tools.py delete mode 100644 src/askui/tools/anthropic/__init__.py delete mode 100644 src/askui/tools/anthropic/base.py delete mode 100644 src/askui/tools/anthropic/collection.py delete mode 100644 src/askui/tools/anthropic/computer.py create mode 100644 src/askui/tools/computer.py delete mode 100644 src/askui/tools/exceptions.py create mode 100644 src/askui/utils/dict_utils.py create mode 100644 test.py diff --git a/src/askui/__init__.py b/src/askui/__init__.py index 90cc1776..ff756cd8 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -33,7 +33,6 @@ from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .retry import ConfigurableRetry, Retry from .tools import ModifierKey, PcKey -from .tools.anthropic import ToolResult from .utils.image_utils import ImageSource, Img __all__ = [ @@ -67,7 +66,6 @@ "Retry", "TextBlockParam", "TextCitationParam", - "ToolResult", "ToolResultBlockParam", "ToolUseBlockParam", "UrlImageSourceParam", diff --git a/src/askui/agent.py b/src/askui/agent.py index 1dbd33fc..009fae1c 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -10,6 +10,8 @@ from askui.locators.locators import Locator from askui.models.shared.computer_agent_cb_param import OnMessageCb from askui.models.shared.computer_agent_message_param import MessageParam +from askui.models.shared.tools import ToolCollection +from askui.tools.computer import Computer20241022Tool from askui.utils.image_utils import ImageSource, Img from .logger import configure_logging, logger @@ -79,7 +81,9 @@ def __init__( ), ) self._model_router = ModelRouter( - tools=self.tools, reporter=self._reporter, models=models + tool_collection=ToolCollection(tools=[Computer20241022Tool(self.tools.os)]), + reporter=self._reporter, + models=models, ) self.model = model self._retry = retry or ConfigurableRetry( diff --git a/src/askui/models/anthropic/computer_agent.py b/src/askui/models/anthropic/computer_agent.py index f7dc7ea2..054279a3 100644 --- a/src/askui/models/anthropic/computer_agent.py +++ b/src/askui/models/anthropic/computer_agent.py @@ -7,8 +7,8 @@ from askui.models.models import ANTHROPIC_MODEL_NAME_MAPPING, ModelName from askui.models.shared.computer_agent import ComputerAgent from askui.models.shared.computer_agent_message_param import MessageParam +from askui.models.shared.tools import ToolCollection from askui.reporting import Reporter -from askui.tools.agent_os import AgentOs if TYPE_CHECKING: from anthropic.types.beta import BetaMessageParam @@ -17,11 +17,11 @@ class ClaudeComputerAgent(ComputerAgent[ClaudeComputerAgentSettings]): def __init__( self, - agent_os: AgentOs, + tool_collection: ToolCollection, reporter: Reporter, settings: ClaudeComputerAgentSettings, ) -> None: - super().__init__(settings, agent_os, reporter) + super().__init__(settings, tool_collection, reporter) self._client = Anthropic( api_key=self._settings.anthropic.api_key.get_secret_value() ) diff --git a/src/askui/models/askui/computer_agent.py b/src/askui/models/askui/computer_agent.py index 42073abf..9792d9fa 100644 --- a/src/askui/models/askui/computer_agent.py +++ b/src/askui/models/askui/computer_agent.py @@ -5,8 +5,8 @@ from askui.models.askui.settings import AskUiComputerAgentSettings from askui.models.shared.computer_agent import ComputerAgent from askui.models.shared.computer_agent_message_param import MessageParam +from askui.models.shared.tools import ToolCollection from askui.reporting import Reporter -from askui.tools.agent_os import AgentOs from ...logger import logger @@ -21,11 +21,11 @@ def is_retryable_error(exception: BaseException) -> bool: class AskUiComputerAgent(ComputerAgent[AskUiComputerAgentSettings]): def __init__( self, - agent_os: AgentOs, + tool_collection: ToolCollection, reporter: Reporter, settings: AskUiComputerAgentSettings, ) -> None: - super().__init__(settings, agent_os, reporter) + super().__init__(settings, tool_collection, reporter) self._client = httpx.Client( base_url=f"{self._settings.askui.base_url}", headers={ diff --git a/src/askui/models/askui/settings.py b/src/askui/models/askui/settings.py index b018b40b..3ad709d1 100644 --- a/src/askui/models/askui/settings.py +++ b/src/askui/models/askui/settings.py @@ -16,9 +16,11 @@ class AskUiSettings(BaseSettings): validation_alias="ASKUI_INFERENCE_ENDPOINT", ) workspace_id: UUID4 = Field( + default=..., validation_alias="ASKUI_WORKSPACE_ID", ) token: SecretStr = Field( + default=..., validation_alias="ASKUI_TOKEN", ) diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 2f14c0a1..83329f86 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -30,20 +30,19 @@ from askui.models.shared.computer_agent_cb_param import OnMessageCb from askui.models.shared.computer_agent_message_param import MessageParam from askui.models.shared.facade import ModelFacade +from askui.models.shared.tools import ToolCollection from askui.models.types.response_schemas import ResponseSchema from askui.reporting import CompositeReporter, Reporter -from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource from ..logger import logger from .anthropic.computer_agent import ClaudeComputerAgent from .anthropic.handler import ClaudeHandler from .askui.inference_api import AskUiInferenceApi, AskUiSettings -from .ui_tars_ep.ui_tars_api import UiTarsApiHandler, UiTarsApiHandlerSettings def _initialize_default_model_registry( # noqa: C901 - tools: AgentToolbox, + tool_collection: ToolCollection, reporter: Reporter, ) -> ModelRegistry: @functools.cache @@ -74,7 +73,7 @@ def vlm_locator_serializer() -> VlmLocatorSerializer: def anthropic_facade() -> ModelFacade: settings = AnthropicSettings() computer_agent = ClaudeComputerAgent( - agent_os=tools.os, + tool_collection=tool_collection, reporter=reporter, settings=ClaudeComputerAgentSettings( anthropic=settings, @@ -95,7 +94,7 @@ def anthropic_facade() -> ModelFacade: @functools.cache def askui_facade() -> ModelFacade: computer_agent = AskUiComputerAgent( - agent_os=tools.os, + tool_collection=tool_collection, reporter=reporter, settings=AskUiComputerAgentSettings( askui=askui_settings(), @@ -113,15 +112,6 @@ def hf_spaces_handler() -> HFSpacesHandler: locator_serializer=vlm_locator_serializer(), ) - @functools.cache - def ui_tars_api_handler() -> UiTarsApiHandler: - return UiTarsApiHandler( - locator_serializer=vlm_locator_serializer(), - agent_os=tools.os, - reporter=reporter, - settings=UiTarsApiHandlerSettings(), - ) - return { ModelName.ASKUI: askui_facade, ModelName.ASKUI__AI_ELEMENT: askui_model_router, @@ -134,20 +124,20 @@ def ui_tars_api_handler() -> UiTarsApiHandler: ModelName.HF__SPACES__QWEN__QWEN2_VL_7B_INSTRUCT: hf_spaces_handler, ModelName.HF__SPACES__OS_COPILOT__OS_ATLAS_BASE_7B: hf_spaces_handler, ModelName.HF__SPACES__SHOWUI__2B: hf_spaces_handler, - ModelName.TARS: ui_tars_api_handler, } class ModelRouter: def __init__( self, - tools: AgentToolbox, + tool_collection: ToolCollection, reporter: Reporter | None = None, models: ModelRegistry | None = None, ): - self._tools = tools self._reporter = reporter or CompositeReporter() - self._models = _initialize_default_model_registry(tools, self._reporter) + self._models = _initialize_default_model_registry( + tool_collection, self._reporter + ) self._models.update(models or {}) @overload diff --git a/src/askui/models/shared/computer_agent.py b/src/askui/models/shared/computer_agent.py index eda81819..b4fb6a75 100644 --- a/src/askui/models/shared/computer_agent.py +++ b/src/askui/models/shared/computer_agent.py @@ -11,16 +11,12 @@ from askui.models.models import ActModel from askui.models.shared.computer_agent_cb_param import OnMessageCb, OnMessageCbParam from askui.models.shared.computer_agent_message_param import ( - Base64ImageSourceParam, - ContentBlockParam, ImageBlockParam, MessageParam, TextBlockParam, - ToolResultBlockParam, ) +from askui.models.shared.tools import ToolCollection from askui.reporting import Reporter -from askui.tools.agent_os import AgentOs -from askui.tools.anthropic import ComputerTool, ToolCollection, ToolResult from ...logger import logger @@ -189,21 +185,19 @@ class ComputerAgent(ActModel, ABC, Generic[ComputerAgentSettings]): def __init__( self, settings: ComputerAgentSettings, - agent_os: AgentOs, + tool_collection: ToolCollection, reporter: Reporter, ) -> None: """Initialize the computer agent. Args: settings (ComputerAgentSettings): The settings for the computer agent. - agent_os (AgentOs): The operating system agent for executing commands. + tool_collection (ToolCollection): Collection of tools to be used reporter (Reporter): The reporter for logging messages and actions. """ self._settings = settings self._reporter = reporter - self._tool_collection = ToolCollection( - ComputerTool(agent_os), - ) + self._tool_collection = tool_collection self._system = BetaTextBlockParam( type="text", text=f"{SYSTEM_PROMPT}", @@ -315,24 +309,20 @@ def _use_tools( MessageParam | None: A message containing tool results or `None` if no tools were used. """ - tool_result_content: list[ContentBlockParam] = [] if isinstance(message.content, str): return None - for content_block in message.content: - if content_block.type == "tool_use": - result = self._tool_collection.run( - name=content_block.name, - tool_input=content_block.input, # type: ignore[arg-type] - ) - tool_result_content.append( - self._make_api_tool_result(result, content_block.id) - ) - if len(tool_result_content) == 0: + tool_use_content_blocks = [ + content_block + for content_block in message.content + if content_block.type == "tool_use" + ] + content = self._tool_collection.run(tool_use_content_blocks) + if len(content) == 0: return None return MessageParam( - content=tool_result_content, + content=content, role="user", ) @@ -391,62 +381,3 @@ def _maybe_filter_to_n_most_recent_images( new_content.append(content) tool_result.content = new_content return messages - - def _make_api_tool_result( - self, result: ToolResult, tool_use_id: str - ) -> ToolResultBlockParam: - """Convert a tool result to an API tool result block. - - Args: - result (ToolResult): The tool result to convert. - tool_use_id (str): The ID of the tool use block. - - Returns: - ToolResultBlockParam: The API tool result block. - """ - tool_result_content: list[TextBlockParam | ImageBlockParam] | str = [] - is_error = False - if result.error: - is_error = True - tool_result_content = self._maybe_prepend_system_tool_result( - result, result.error - ) - else: - assert isinstance(tool_result_content, list) - if result.output: - tool_result_content.append( - TextBlockParam( - text=self._maybe_prepend_system_tool_result( - result, result.output - ), - ) - ) - if result.base64_image: - tool_result_content.append( - ImageBlockParam( - source=Base64ImageSourceParam( - media_type="image/png", - data=result.base64_image, - ), - ) - ) - return ToolResultBlockParam( - content=tool_result_content, - tool_use_id=tool_use_id, - is_error=is_error, - ) - - @staticmethod - def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: - """Prepend system message to tool result text if available. - - Args: - result (ToolResult): The tool result. - result_text (str): The result text. - - Returns: - str: The result text with optional system message prepended. - """ - if result.system: - result_text = f"{result.system}\n{result_text}" - return result_text diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py new file mode 100644 index 00000000..49854218 --- /dev/null +++ b/src/askui/models/shared/tools.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractmethod +from typing import Any, cast + +from anthropic.types.beta import BetaToolUnionParam +from PIL import Image + +from askui.models.shared.computer_agent_message_param import ( + Base64ImageSourceParam, + ContentBlockParam, + ImageBlockParam, + TextBlockParam, + ToolResultBlockParam, + ToolUseBlockParam, +) +from askui.utils.image_utils import ImageSource + +ToolCallResult = Image.Image | None + + +def _convert_to_content( + result: ToolCallResult, +) -> list[TextBlockParam | ImageBlockParam]: + if result is None: + return [] + + return [ + ImageBlockParam( + source=Base64ImageSourceParam( + media_type="image/png", + data=ImageSource(result).to_base64(), + ) + ) + ] + + +class Tool(ABC): + """Abstract base class for tools.""" + + @abstractmethod + def __call__(self, *args: Any, **kwargs: Any) -> ToolCallResult: + """Executes the tool with the given arguments.""" + raise NotImplementedError + + @abstractmethod + def to_params( + self, + ) -> BetaToolUnionParam: + raise NotImplementedError + + +class ToolCollection: + """A collection of tools. + + Use for dispatching tool calls + + Vision: + - Could be used for parallelizing tool calls configurable through init arg + - Could be used for raising on an exception + (instead of just returning `ContentBlockParam`) + within tool call or doing tool call or if tool is not found + """ + + def __init__(self, tools: list[Tool]) -> None: + self._tools = tools + self._tool_map = {tool.to_params()["name"]: tool for tool in tools} + + def to_params( + self, + ) -> list[BetaToolUnionParam]: + return [tool.to_params() for tool in self._tools] + + def run( + self, tool_use_block_params: list[ToolUseBlockParam] + ) -> list[ContentBlockParam]: + return [ + self._run_tool(tool_use_block_param) + for tool_use_block_param in tool_use_block_params + ] + + def _run_tool( + self, tool_use_block_param: ToolUseBlockParam + ) -> ToolResultBlockParam: + tool = self._tool_map.get(tool_use_block_param.name) + if not tool: + return ToolResultBlockParam( + content=f"Tool not found: {tool_use_block_param.name}", + is_error=True, + tool_use_id=tool_use_block_param.id, + ) + try: + tool_result: ToolCallResult = cast( + "ToolCallResult", + tool(**tool_use_block_param.input), # type: ignore + ) + return ToolResultBlockParam( + content=_convert_to_content(tool_result), + tool_use_id=tool_use_block_param.id, + ) + except Exception as e: # noqa: BLE001 + return ToolResultBlockParam( + content=f"Tool {tool_use_block_param.name} failed: {e}", + is_error=True, + tool_use_id=tool_use_block_param.id, + ) diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index 24b62ee0..d0e36af9 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -133,129 +133,6 @@ ] """PC keys for keyboard actions.""" -PcKeys: list[PcKey] = [ - "backspace", - "delete", - "enter", - "tab", - "escape", - "up", - "down", - "right", - "left", - "home", - "end", - "pageup", - "pagedown", - "f1", - "f2", - "f3", - "f4", - "f5", - "f6", - "f7", - "f8", - "f9", - "f10", - "f11", - "f12", - "space", - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "J", - "K", - "L", - "M", - "N", - "O", - "P", - "Q", - "R", - "S", - "T", - "U", - "V", - "W", - "X", - "Y", - "Z", - "!", - '"', - "#", - "$", - "%", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "<", - "=", - ">", - "?", - "@", - "[", - "\\", - "]", - "^", - "_", - "`", - "{", - "|", - "}", - "~", -] - class ClickEvent(BaseModel): type: Literal["click"] = "click" diff --git a/src/askui/tools/anthropic/__init__.py b/src/askui/tools/anthropic/__init__.py deleted file mode 100644 index 0a058914..00000000 --- a/src/askui/tools/anthropic/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .base import CLIResult, ToolResult -from .collection import ToolCollection -from .computer import ComputerTool - -__ALL__ = [ - CLIResult, - ComputerTool, - ToolCollection, - ToolResult, -] diff --git a/src/askui/tools/anthropic/base.py b/src/askui/tools/anthropic/base.py deleted file mode 100644 index bd4116ef..00000000 --- a/src/askui/tools/anthropic/base.py +++ /dev/null @@ -1,84 +0,0 @@ -from abc import ABCMeta, abstractmethod -from dataclasses import dataclass, fields, replace -from typing import Any, Optional - -from anthropic.types.beta import BetaToolUnionParam - - -class BaseAnthropicTool(metaclass=ABCMeta): - """Abstract base class for Anthropic-defined tools.""" - - @abstractmethod - def __call__(self, **kwargs: Any) -> Any: - """Executes the tool with the given arguments.""" - ... - - @abstractmethod - def to_params( - self, - ) -> BetaToolUnionParam: - raise NotImplementedError - - -@dataclass(kw_only=True, frozen=True) -class ToolResult: - """Represents the result of a tool execution. - - Args: - output (str | None, optional): The output of the tool. - error (str | None, optional): The error message of the tool. - base64_image (str | None, optional): The base64 image of the tool. - system (str | None, optional): The system message of the tool. - """ - - output: str | None = None - error: str | None = None - base64_image: str | None = None - system: str | None = None - - def __bool__(self) -> bool: - return any(getattr(self, field.name) for field in fields(self)) - - def __add__(self, other: "ToolResult") -> "ToolResult": - def combine_fields( - field: str | None, other_field: str | None, concatenate: bool = True - ) -> str | None: - if field and other_field: - if concatenate: - return field + other_field - error_msg = "Cannot combine tool results" - raise ValueError(error_msg) - return field or other_field - - return ToolResult( - output=combine_fields(self.output, other.output), - error=combine_fields(self.error, other.error), - base64_image=combine_fields(self.base64_image, other.base64_image, False), - system=combine_fields(self.system, other.system), - ) - - def replace(self, **kwargs: Any) -> "ToolResult": - """Returns a new ToolResult with the given fields replaced.""" - return replace(self, **kwargs) - - -class CLIResult(ToolResult): - """A ToolResult that can be rendered as a CLI output.""" - - -class ToolFailure(ToolResult): - """A ToolResult that represents a failure.""" - - -class ToolError(Exception): - """Raised when a tool encounters an error. - - Args: - message (str): The error message. - result (ToolResult, optional): The ToolResult that caused the error. - """ - - def __init__(self, message: str, result: Optional[ToolResult] = None): - self.message = message - self.result = result - super().__init__(self.message) diff --git a/src/askui/tools/anthropic/collection.py b/src/askui/tools/anthropic/collection.py deleted file mode 100644 index 2141c8fa..00000000 --- a/src/askui/tools/anthropic/collection.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Collection classes for managing multiple tools.""" - -from typing import Any, cast - -from anthropic.types.beta import BetaToolUnionParam - -from .base import ( - BaseAnthropicTool, - ToolError, - ToolFailure, - ToolResult, -) - - -class ToolCollection: - """A collection of anthropic-defined tools.""" - - def __init__(self, *tools: BaseAnthropicTool): - self.tools = tools - self.tool_map = {tool.to_params()["name"]: tool for tool in tools} - - def to_params( - self, - ) -> list[BetaToolUnionParam]: - return [tool.to_params() for tool in self.tools] - - def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult: - tool = self.tool_map.get(name) - if not tool: - return ToolFailure(error=f"Tool {name} is invalid") - try: - return cast("ToolResult", tool(**tool_input)) - except ToolError as e: - return ToolFailure(error=e.message) diff --git a/src/askui/tools/anthropic/computer.py b/src/askui/tools/anthropic/computer.py deleted file mode 100644 index 3997e282..00000000 --- a/src/askui/tools/anthropic/computer.py +++ /dev/null @@ -1,359 +0,0 @@ -from typing import Any, Literal, TypedDict - -from anthropic.types.beta import BetaToolComputerUse20241022Param - -from askui.tools.agent_os import AgentOs -from askui.utils.image_utils import ( - image_to_base64, - scale_coordinates_back, - scale_image_with_padding, -) - -from .base import BaseAnthropicTool, ToolError, ToolResult - -Action = Literal[ - "key", - "type", - "mouse_move", - "left_click", - "left_click_drag", - "right_click", - "middle_click", - "double_click", - "screenshot", - "cursor_position", -] - - -PC_KEY = [ - "backspace", - "delete", - "enter", - "tab", - "escape", - "up", - "down", - "right", - "left", - "home", - "end", - "pageup", - "pagedown", - "f1", - "f2", - "f3", - "f4", - "f5", - "f6", - "f7", - "f8", - "f9", - "f10", - "f11", - "f12", - "space", - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "J", - "K", - "L", - "M", - "N", - "O", - "P", - "Q", - "R", - "S", - "T", - "U", - "V", - "W", - "X", - "Y", - "Z", - "!", - '"', - "#", - "$", - "%", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "<", - "=", - ">", - "?", - "@", - "[", - "\\", - "]", - "^", - "_", - "`", - "{", - "|", - "}", - "~", -] - - -KEYSYM_MAP = { - "BackSpace": "backspace", - "Delete": "delete", - "Return": "enter", - "Enter": "enter", - "Tab": "tab", - "Escpage": "escape", - "Up": "up", - "Down": "down", - "Right": "right", - "Left": "left", - "Home": "home", - "End": "end", - "Page_Up": "pageup", - "Page_Down": "pagedown", - "F1": "f1", - "F2": "f2", - "F3": "f3", - "F4": "f4", - "F5": "f5", - "F6": "f6", - "F7": "f7", - "F8": "f8", - "F9": "f9", - "F10": "f10", - "F11": "f11", - "F12": "f12", -} - - -class Resolution(TypedDict): - width: int - height: int - - -# sizes above XGA/WXGA are not recommended (see README.md) -# scale down to one of these targets if ComputerTool._scaling_enabled is set -MAX_SCALING_TARGETS: dict[str, Resolution] = { - "XGA": Resolution(width=1024, height=768), # 4:3 - "WXGA": Resolution(width=1280, height=800), # 16:10 - "FWXGA": Resolution(width=1366, height=768), # ~16:9 -} - - -class ComputerToolOptions(TypedDict): - display_height_px: int - display_width_px: int - display_number: int | None - - -class ComputerTool(BaseAnthropicTool): - """ - A tool that allows the agent to interact with the screen, keyboard, and mouse of - the current computer. - The tool parameters are defined by Anthropic and are not editable. - """ - - name: Literal["computer"] = "computer" - api_type: Literal["computer_20241022"] = "computer_20241022" - - _screenshot_delay = 2.0 - _scaling_enabled = True - - @property - def options(self) -> ComputerToolOptions: - return { - "display_width_px": self._width, - "display_height_px": self._height, - } - - def to_params(self) -> BetaToolComputerUse20241022Param: - return {"name": self.name, "type": self.api_type, **self.options} - - def __init__(self, agent_os: AgentOs) -> None: - super().__init__() - self._agent_os = agent_os - self._width = 1280 - self._height = 800 - self._real_screen_width: int | None = None - self._real_screen_height: int | None = None - - def __call__( # noqa: C901 - self, - *, - action: Action | None = None, - text: str | None = None, - coordinate: tuple[int, int] | None = None, - **kwargs: Any, # noqa: ARG002 - ) -> ToolResult: - """Execute computer action.""" - if action is None: - error_msg = "Action is missing" - raise ToolError(error_msg) - - if action in ("mouse_move", "left_click_drag"): - if coordinate is None: - error_msg = f"coordinate is required for {action}" - raise ToolError(error_msg) - if text is not None: - error_msg = f"text is not accepted for {action}" - raise ToolError(error_msg) - if not isinstance(coordinate, list) or len(coordinate) != 2: - error_msg = f"{coordinate} must be a tuple of length 2" - raise ToolError(error_msg) - if not all(isinstance(i, int) and i >= 0 for i in coordinate): - error_msg = f"{coordinate} must be a tuple of non-negative ints" - raise ToolError(error_msg) - - if self._real_screen_width is None or self._real_screen_height is None: - screenshot = self._agent_os.screenshot() - self._real_screen_width = screenshot.width - self._real_screen_height = screenshot.height - - x, y = scale_coordinates_back( - coordinate[0], - coordinate[1], - self._real_screen_width, - self._real_screen_height, - self._width, - self._height, - ) - x, y = int(x), int(y) - - if action == "mouse_move": - self._agent_os.mouse_move(x, y) - return ToolResult() - if action == "left_click_drag": - self._agent_os.mouse_down("left") - self._agent_os.mouse_move(x, y) - self._agent_os.mouse_up("left") - return ToolResult() - - if action in ("key", "type"): - if text is None: - error_msg = f"text is required for {action}" - raise ToolError(error_msg) - if coordinate is not None: - error_msg = f"coordinate is not accepted for {action}" - raise ToolError(error_msg) - if not isinstance(text, str): - error_msg = f"{text} must be a string" - raise ToolError(error_msg) - - if action == "key": - if text in KEYSYM_MAP.keys(): - text = KEYSYM_MAP[text] - - if text not in PC_KEY: - error_msg = ( - f"Key {text} is not a valid PC_KEY from {', '.join(PC_KEY)}" - ) - raise ToolError(error_msg) - self._agent_os.keyboard_pressed(text) - self._agent_os.keyboard_release(text) - return ToolResult() - if action == "type": - self._agent_os.type(text) - return ToolResult() - - if action in ( - "left_click", - "right_click", - "double_click", - "middle_click", - "screenshot", - "cursor_position", - ): - if text is not None: - error_msg = f"text is not accepted for {action}" - raise ToolError(error_msg) - if coordinate is not None: - error_msg = f"coordinate is not accepted for {action}" - raise ToolError(error_msg) - - if action == "screenshot": - return self.screenshot() - if action == "cursor_position": - error_msg = "cursor_position is not implemented by this agent" - raise ToolError(error_msg) - if action == "left_click": - self._agent_os.click("left") - return ToolResult() - if action == "right_click": - self._agent_os.click("right") - return ToolResult() - if action == "middle_click": - self._agent_os.click("middle") - return ToolResult() - if action == "double_click": - self._agent_os.click("left", 2) - return ToolResult() - - error_msg = f"Invalid action: {action}" - raise ToolError(error_msg) - - def screenshot(self) -> ToolResult: - """ - Take a screenshot of the current screen, scale it and return the base64 - encoded image. - """ - screenshot = self._agent_os.screenshot() - self._real_screen_width = screenshot.width - self._real_screen_height = screenshot.height - scaled_screenshot = scale_image_with_padding( - screenshot, self._width, self._height - ) - base64_image = image_to_base64(scaled_screenshot) - return ToolResult(base64_image=base64_image) diff --git a/src/askui/tools/computer.py b/src/askui/tools/computer.py new file mode 100644 index 00000000..c41c5b4d --- /dev/null +++ b/src/askui/tools/computer.py @@ -0,0 +1,205 @@ +from typing import Annotated, Literal + +from anthropic.types.beta import BetaToolComputerUse20241022Param +from PIL import Image +from pydantic import Field, validate_call + +from askui.tools.agent_os import AgentOs, PcKey +from askui.utils.dict_utils import IdentityDefaultDict +from askui.utils.image_utils import scale_coordinates_back, scale_image_with_padding + +from ..models.shared.tools import Tool + +Action20241022 = Literal[ + "key", + "type", + "mouse_move", + "left_click", + "left_click_drag", + "right_click", + "middle_click", + "double_click", + "screenshot", + "cursor_position", +] + +KeysToMap = Literal[ + "BackSpace", + "Delete", + "Return", + "Enter", + "Tab", + "Escape", + "Up", + "Down", + "Right", + "Left", + "Home", + "End", + "Page_Up", + "Page_Down", + "F1", + "F2", + "F3", + "F4", + "F5", + "F6", + "F7", + "F8", + "F9", + "F10", + "F11", + "F12", +] + +Key = PcKey | KeysToMap + +KEYS_MAPPING: IdentityDefaultDict[Key, PcKey] = IdentityDefaultDict( + { + "BackSpace": "backspace", + "Delete": "delete", + "Return": "enter", + "Enter": "enter", + "Tab": "tab", + "Escape": "escape", + "Up": "up", + "Down": "down", + "Right": "right", + "Left": "left", + "Home": "home", + "End": "end", + "Page_Up": "pageup", + "Page_Down": "pagedown", + "F1": "f1", + "F2": "f2", + "F3": "f3", + "F4": "f4", + "F5": "f5", + "F6": "f6", + "F7": "f7", + "F8": "f8", + "F9": "f9", + "F10": "f10", + "F11": "f11", + "F12": "f12", + } +) + + +class ActionNotImplementedError(NotImplementedError): + def __init__(self, action: Action20241022, tool_name: str) -> None: + self.action = action + self.tool_name = tool_name + super().__init__( + f'Action "{action}" has not been implemented by tool "{tool_name}"' + ) + + +class Computer20241022Tool(Tool): + name: Literal["computer"] = "computer" + api_type: Literal["computer_20241022"] = "computer_20241022" + + def to_params(self) -> BetaToolComputerUse20241022Param: + return { + "name": self.name, + "type": self.api_type, + "display_width_px": self._width, + "display_height_px": self._height, + } + + def __init__(self, agent_os: AgentOs) -> None: + self._agent_os = agent_os + self._width = 1280 + self._height = 800 + self._real_screen_width: int | None = None + self._real_screen_height: int | None = None + + @validate_call + def __call__( + self, + action: Action20241022, + text: str | None = None, + coordinate: tuple[Annotated[int, Field(ge=0)], Annotated[int, Field(ge=0)]] + | None = None, + ) -> Image.Image | None: + match action: + case "mouse_move": + self._mouse_move(coordinate) # type: ignore[arg-type] + case "left_click_drag": + # does not seem to work + self._left_click_drag(coordinate) # type: ignore[arg-type] + case "screenshot": + return self._screenshot() + case "left_click": + self._agent_os.click("left") + case "right_click": + self._agent_os.click("right") + case "middle_click": + self._agent_os.click("middle") + case "double_click": + self._agent_os.click("left", 2) + case "type": + self._type(text) # type: ignore[arg-type] + case "key": + # we do not seem to support all kinds of key nor modifier keys + key combinations + self._key(text) # type: ignore[arg-type] + case _: + raise ActionNotImplementedError(action, self.name) + return None + + @validate_call + def _type(self, text: str) -> None: + self._agent_os.type(text) + + @validate_call + def _key(self, key: Key) -> None: + _key = KEYS_MAPPING[key] + self._agent_os.keyboard_pressed(_key) + self._agent_os.keyboard_release(_key) + + def _scale_coordinates_back( + self, + coordinate: tuple[Annotated[int, Field(ge=0)], Annotated[int, Field(ge=0)]], + ) -> tuple[int, int]: + if self._real_screen_width is None or self._real_screen_height is None: + screenshot = self._agent_os.screenshot() + self._real_screen_width = screenshot.width + self._real_screen_height = screenshot.height + x, y = scale_coordinates_back( + coordinate[0], + coordinate[1], + self._real_screen_width, # + self._real_screen_height, + self._width, + self._height, + ) + x, y = int(x), int(y) + return x, y + + @validate_call + def _mouse_move( + self, + coordinate: tuple[Annotated[int, Field(ge=0)], Annotated[int, Field(ge=0)]], + ) -> None: + x, y = self._scale_coordinates_back(coordinate) + self._agent_os.mouse_move(x, y) + + @validate_call + def _left_click_drag( + self, + coordinate: tuple[Annotated[int, Field(ge=0)], Annotated[int, Field(ge=0)]], + ) -> None: + x, y = self._scale_coordinates_back(coordinate) + # holding key pressed does not seem to work + self._agent_os.mouse_down("left") + self._agent_os.mouse_move(x, y) + self._agent_os.mouse_up("left") + + def _screenshot(self) -> Image.Image: + """ + Take a screenshot of the current screen, scale it and return it + """ + screenshot = self._agent_os.screenshot() + self._real_screen_width = screenshot.width + self._real_screen_height = screenshot.height + return scale_image_with_padding(screenshot, self._width, self._height) diff --git a/src/askui/tools/exceptions.py b/src/askui/tools/exceptions.py deleted file mode 100644 index dd052b8e..00000000 --- a/src/askui/tools/exceptions.py +++ /dev/null @@ -1,6 +0,0 @@ -from .anthropic.base import ToolError, ToolResult - -__all__ = [ - "ToolError", - "ToolResult", -] diff --git a/src/askui/utils/dict_utils.py b/src/askui/utils/dict_utils.py new file mode 100644 index 00000000..ecc48bdf --- /dev/null +++ b/src/askui/utils/dict_utils.py @@ -0,0 +1,36 @@ +from collections import defaultdict +from typing import Any, TypeVar + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +class IdentityDefaultDict(defaultdict[_KT, _VT]): + """ + A `defaultdict` variant that returns the key itself if the key is not found. + + Args: + d (dict[_KT, _VT] | None, optional): Initial dictionary to populate the mapping. + If `None`, an empty dict is used. + + Example: + ```python + d = IdentityDefaultDict({'a': 1}) + print(d['a']) # 1 + print(d['b']) # 'b' + ``` + + Returns: + IdentityDefaultDict: An instance of the mapping. + + Notes: + This is useful for mapping lookups where missing keys should fall back to the + key itself (e.g., identity mapping). + """ + + def __init__(self, d: dict[_KT, _VT] | None = None) -> None: + _d = d or {} + super().__init__(None, _d) + + def __missing__(self, key: Any) -> Any: + return key diff --git a/test.py b/test.py new file mode 100644 index 00000000..bc4f5a9b --- /dev/null +++ b/test.py @@ -0,0 +1,107 @@ +import math +from collections.abc import Callable +from typing import Any + +# Tool = TypeVar("Tool", bound=Callable[..., Any]) + + +# class Agent(Generic[Tool]): +# def __init__(self, tools: list[Tool]) -> None: +# for tool in tools: +# setattr(self, tool.__name__, tool) + + +def add(a: int, b: int) -> int: + return a + b + + +def subtract(a: int, b: int) -> int: + return a - b + + +# class MathAgent(Protocol): +# add: type[add] + + +# def make_agent(tools: list[Tool]) -> Agent[Tool] & Protocol: + + +# agent = Agent([add, subtract]) + +# print(agent.add(1, 2)) +# print(agent.subtract(1, 2)) + + +# class X: +# pass + +# X.add = staticmethod(add) + + +# def greet(self): +# print(f"Hi, I'm {self.name}!") + +# def set_name(self, name): +# self.name = name + +# def make_person_class(): +# return type('Person', (object,), { +# 'greet': greet, +# 'set_name': set_name, +# }) + +# Person = make_person_class() + + +# person = Person() + +# person.greet() + + +class ToolsBase: + def __init__(self) -> None: + self._tools: dict[str, Callable[..., Any]] = {} + for attr in dir(self): + if attr.startswith("_"): + continue + val = getattr(self, attr) + if callable(val): + self._tools[attr] = val + + +class Tools(ToolsBase): + add = staticmethod(add) + subtract = staticmethod(subtract) + floor = staticmethod(math.floor) + + +# T = TypeVar("T") + + +# class Agent(Generic[T]): +# def __init__(self, tools: T) -> None: +# self.tools = tools + + +class MathAgent(Tools): + pass + + +math_agent = MathAgent() + +math_agent.add + + +# agent = Agent({ +# 'add': add, +# 'subtract': subtract, +# 'floor': math.floor, +# }) + +# print(agent.tools['add'](1, 2)) +# print(agent.tools['subtract'](1, 2)) + + + +class Agent: + diff --git a/tests/conftest.py b/tests/conftest.py index 83e27bce..45a6a637 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,9 @@ from pytest_mock import MockerFixture from askui.models.model_router import ModelRouter +from askui.models.shared.tools import ToolCollection from askui.tools.agent_os import AgentOs +from askui.tools.computer import Computer20241022Tool from askui.tools.toolbox import AgentToolbox @@ -51,6 +53,12 @@ def agent_toolbox_mock(agent_os_mock: AgentOs) -> AgentToolbox: return AgentToolbox(agent_os=agent_os_mock) +@pytest.fixture +def tool_collection_mock(agent_os_mock: AgentOs) -> ToolCollection: + """Fixture providing a mock tool collection.""" + return ToolCollection(tools=[Computer20241022Tool(agent_os_mock)]) + + @pytest.fixture def model_router_mock(mocker: MockerFixture) -> ModelRouter: """Fixture providing a mock model router.""" diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index dbbc30e5..98564def 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -16,6 +16,7 @@ from askui.models.askui.settings import AskUiComputerAgentSettings from askui.models.models import ModelName from askui.models.shared.facade import ModelFacade +from askui.models.shared.tools import ToolCollection from askui.reporting import Reporter, SimpleHtmlReporter from askui.tools.toolbox import AgentToolbox @@ -65,12 +66,12 @@ def askui_inference_api( @pytest.fixture def askui_computer_agent( - agent_toolbox_mock: AgentToolbox, + tool_collection_mock: ToolCollection, askui_settings: AskUiSettings, simple_html_reporter: Reporter, ) -> AskUiComputerAgent: return AskUiComputerAgent( - agent_os=agent_toolbox_mock.os, + tool_collection=tool_collection_mock, reporter=simple_html_reporter, settings=AskUiComputerAgentSettings( askui=askui_settings, diff --git a/tests/integration/agent/conftest.py b/tests/integration/agent/conftest.py index b4d5857f..acb9df65 100644 --- a/tests/integration/agent/conftest.py +++ b/tests/integration/agent/conftest.py @@ -1,4 +1,4 @@ -from typing import Generator, Optional, Union +from typing import Any, Generator, Optional, Union import pytest from PIL import Image as PILImage @@ -6,8 +6,10 @@ from askui.models.askui.computer_agent import AskUiComputerAgent from askui.models.askui.settings import AskUiComputerAgentSettings, AskUiSettings +from askui.models.shared.tools import ToolCollection from askui.reporting import Reporter from askui.tools.agent_os import AgentOs +from askui.tools.computer import Computer20241022Tool class ReporterMock(Reporter): @@ -15,7 +17,7 @@ class ReporterMock(Reporter): def add_message( self, role: str, - content: Union[str, dict, list], + content: Union[str, dict[str, Any], list[Any]], image: Optional[PILImage.Image | list[PILImage.Image]] = None, ) -> None: pass @@ -31,7 +33,7 @@ def claude_computer_agent( ) -> Generator[AskUiComputerAgent, None, None]: """Fixture providing a AskUiClaudeComputerAgent instance.""" agent = AskUiComputerAgent( - agent_os=agent_os_mock, + tool_collection=ToolCollection(tools=[Computer20241022Tool(agent_os_mock)]), reporter=ReporterMock(), settings=AskUiComputerAgentSettings(askui=AskUiSettings()), ) diff --git a/tests/unit/models/test_model_router.py b/tests/unit/models/test_model_router.py index 5cb6f192..0aa9e97b 100644 --- a/tests/unit/models/test_model_router.py +++ b/tests/unit/models/test_model_router.py @@ -14,9 +14,9 @@ from askui.models.models import ModelName from askui.models.shared.computer_agent_message_param import MessageParam from askui.models.shared.facade import ModelFacade +from askui.models.shared.tools import ToolCollection from askui.models.ui_tars_ep.ui_tars_api import UiTarsApiHandler from askui.reporting import CompositeReporter -from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource # Test UUID for workspace_id @@ -76,7 +76,7 @@ def mock_askui_facade(mocker: MockerFixture) -> ModelFacade: @pytest.fixture def model_router( - agent_toolbox_mock: AgentToolbox, + tool_collection_mock: ToolCollection, mock_anthropic_facade: ModelFacade, mock_askui_facade: ModelFacade, mock_tars: UiTarsApiHandler, @@ -84,7 +84,7 @@ def model_router( ) -> ModelRouter: """Fixture providing a ModelRouter instance with mocked dependencies.""" return ModelRouter( - tools=agent_toolbox_mock, + tool_collection=tool_collection_mock, reporter=CompositeReporter(), models={ ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022: mock_anthropic_facade, From a46066c1972eadec41edd5240cc48c0184e5c5a9 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 18 Jun 2025 13:55:47 +0200 Subject: [PATCH 3/4] feat(tools): add basic support for Anthropic's Computer20250124Tool --- src/askui/tools/computer.py | 146 ++++++++++++++++++++++++++++++++---- 1 file changed, 130 insertions(+), 16 deletions(-) diff --git a/src/askui/tools/computer.py b/src/askui/tools/computer.py index c41c5b4d..86f03f6b 100644 --- a/src/askui/tools/computer.py +++ b/src/askui/tools/computer.py @@ -1,8 +1,12 @@ -from typing import Annotated, Literal +from typing import Annotated, Literal, TypedDict -from anthropic.types.beta import BetaToolComputerUse20241022Param +from anthropic.types.beta import ( + BetaToolComputerUse20241022Param, + BetaToolComputerUse20250124Param, +) from PIL import Image from pydantic import Field, validate_call +from typing_extensions import override from askui.tools.agent_os import AgentOs, PcKey from askui.utils.dict_utils import IdentityDefaultDict @@ -23,6 +27,20 @@ "cursor_position", ] +Action20250124 = ( + Action20241022 + | Literal[ + "left_mouse_down", + "left_mouse_up", + "scroll", + "hold_key", + "wait", + "triple_click", + ] +) + +ScrollDirection = Literal["up", "down", "left", "right"] + KeysToMap = Literal[ "BackSpace", "Delete", @@ -87,7 +105,7 @@ class ActionNotImplementedError(NotImplementedError): - def __init__(self, action: Action20241022, tool_name: str) -> None: + def __init__(self, action: Action20250124, tool_name: str) -> None: self.action = action self.tool_name = tool_name super().__init__( @@ -95,29 +113,40 @@ def __init__(self, action: Action20241022, tool_name: str) -> None: ) -class Computer20241022Tool(Tool): - name: Literal["computer"] = "computer" - api_type: Literal["computer_20241022"] = "computer_20241022" +class BetaToolComputerUseParamBase(TypedDict): + name: Literal["computer"] + display_width_px: int + display_height_px: int - def to_params(self) -> BetaToolComputerUse20241022Param: - return { - "name": self.name, - "type": self.api_type, - "display_width_px": self._width, - "display_height_px": self._height, - } - def __init__(self, agent_os: AgentOs) -> None: +class ComputerToolBase(Tool): + name: Literal["computer"] = "computer" + + def __init__( + self, + agent_os: AgentOs, + ) -> None: self._agent_os = agent_os self._width = 1280 self._height = 800 self._real_screen_width: int | None = None self._real_screen_height: int | None = None + @property + def params_base( + self, + ) -> BetaToolComputerUseParamBase: + return { + "name": self.name, + "display_width_px": self._width, + "display_height_px": self._height, + } + + @override @validate_call def __call__( self, - action: Action20241022, + action: Action20250124, text: str | None = None, coordinate: tuple[Annotated[int, Field(ge=0)], Annotated[int, Field(ge=0)]] | None = None, @@ -141,7 +170,8 @@ def __call__( case "type": self._type(text) # type: ignore[arg-type] case "key": - # we do not seem to support all kinds of key nor modifier keys + key combinations + # we do not seem to support all kinds of key nor modifier keys + # + key combinations self._key(text) # type: ignore[arg-type] case _: raise ActionNotImplementedError(action, self.name) @@ -157,6 +187,16 @@ def _key(self, key: Key) -> None: self._agent_os.keyboard_pressed(_key) self._agent_os.keyboard_release(_key) + @validate_call + def _keyboard_pressed(self, key: Key) -> None: + _key = KEYS_MAPPING[key] + self._agent_os.keyboard_pressed(_key) + + @validate_call + def _keyboard_released(self, key: Key) -> None: + _key = KEYS_MAPPING[key] + self._agent_os.keyboard_release(_key) + def _scale_coordinates_back( self, coordinate: tuple[Annotated[int, Field(ge=0)], Annotated[int, Field(ge=0)]], @@ -203,3 +243,77 @@ def _screenshot(self) -> Image.Image: self._real_screen_width = screenshot.width self._real_screen_height = screenshot.height return scale_image_with_padding(screenshot, self._width, self._height) + + +class Computer20241022Tool(ComputerToolBase): + type: Literal["computer_20241022"] = "computer_20241022" + + @override + def to_params( + self, + ) -> BetaToolComputerUse20241022Param: + return { + **self.params_base, + "type": self.type, + } + + +class Computer20250124Tool(ComputerToolBase): + type: Literal["computer_20250124"] = "computer_20250124" + + @override + def to_params( + self, + ) -> BetaToolComputerUse20250124Param: + return { + **self.params_base, + "type": self.type, + } + + @override + @validate_call + def __call__( + self, + action: Action20250124, + text: str | None = None, + coordinate: tuple[Annotated[int, Field(ge=0)], Annotated[int, Field(ge=0)]] + | None = None, + scroll_direction: ScrollDirection | None = None, + scroll_amount: int | None = None, + duration: float | None = None, + key: str | None = None, # maybe not all keys supported + ) -> Image.Image | None: + match action: + case "left_mouse_down": + self._agent_os.mouse_down("left") + case "left_mouse_up": + self._agent_os.mouse_up("left") + case "left_click": + self._click("left", coordinate=coordinate, key=key) + case "right_click": + self._click("right", coordinate=coordinate, key=key) + case "middle_click": + self._click("middle", coordinate=coordinate, key=key) + case "double_click": + self._click("left", count=2, coordinate=coordinate, key=key) + case "triple_click": + self._click("left", count=3, coordinate=coordinate, key=key) + case _: + return super().__call__(action, text, coordinate) + return None + + def _click( + self, + button: Literal["left", "right", "middle"], + count: int = 1, + coordinate: tuple[Annotated[int, Field(ge=0)], Annotated[int, Field(ge=0)]] + | None = None, + key: str | None = None, + ) -> None: + if coordinate is not None: + self._mouse_move(coordinate) + if key is not None: + self._keyboard_pressed(key) # type: ignore[arg-type] + self._agent_os.click(button, count) + if key is not None: + self._keyboard_released(key) # type: ignore[arg-type] From 9fe820e650c015d0fe2e11bd05189a39720ef3f7 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 18 Jun 2025 15:03:00 +0200 Subject: [PATCH 4/4] chore: remove scratch file --- test.py | 107 -------------------------------------------------------- 1 file changed, 107 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index bc4f5a9b..00000000 --- a/test.py +++ /dev/null @@ -1,107 +0,0 @@ -import math -from collections.abc import Callable -from typing import Any - -# Tool = TypeVar("Tool", bound=Callable[..., Any]) - - -# class Agent(Generic[Tool]): -# def __init__(self, tools: list[Tool]) -> None: -# for tool in tools: -# setattr(self, tool.__name__, tool) - - -def add(a: int, b: int) -> int: - return a + b - - -def subtract(a: int, b: int) -> int: - return a - b - - -# class MathAgent(Protocol): -# add: type[add] - - -# def make_agent(tools: list[Tool]) -> Agent[Tool] & Protocol: - - -# agent = Agent([add, subtract]) - -# print(agent.add(1, 2)) -# print(agent.subtract(1, 2)) - - -# class X: -# pass - -# X.add = staticmethod(add) - - -# def greet(self): -# print(f"Hi, I'm {self.name}!") - -# def set_name(self, name): -# self.name = name - -# def make_person_class(): -# return type('Person', (object,), { -# 'greet': greet, -# 'set_name': set_name, -# }) - -# Person = make_person_class() - - -# person = Person() - -# person.greet() - - -class ToolsBase: - def __init__(self) -> None: - self._tools: dict[str, Callable[..., Any]] = {} - for attr in dir(self): - if attr.startswith("_"): - continue - val = getattr(self, attr) - if callable(val): - self._tools[attr] = val - - -class Tools(ToolsBase): - add = staticmethod(add) - subtract = staticmethod(subtract) - floor = staticmethod(math.floor) - - -# T = TypeVar("T") - - -# class Agent(Generic[T]): -# def __init__(self, tools: T) -> None: -# self.tools = tools - - -class MathAgent(Tools): - pass - - -math_agent = MathAgent() - -math_agent.add - - -# agent = Agent({ -# 'add': add, -# 'subtract': subtract, -# 'floor': math.floor, -# }) - -# print(agent.tools['add'](1, 2)) -# print(agent.tools['subtract'](1, 2)) - - - -class Agent: -