diff --git a/.cursorrules b/.cursorrules new file mode 100644 index 00000000..31bd6766 --- /dev/null +++ b/.cursorrules @@ -0,0 +1,74 @@ +# Code Style +- Use `_` prefix for all private variables, constants, functions, methods, properties, etc. that don't need to be accessible from outside the module +- Mark everything as private that does not absolutely need to be accessible from outside the module +- Use `@override` (imported from `typing_extensions`) decorator for all methods that override a parent class method +- Use type hints for all variables, functions, methods, properties, return values, parameters etc. +- Omit `Any` within type hints unless absolutely necessary +- Use built-in types (e.g., `list`, `dict`, `tuple`, `set`, `str | None`) instead of types from `typing` module (e.g. `List`, `Dict`, `Tuple`, `Set`, `Optional`, `Union`) wherever possible +- Instead of `Optional` use `| None` +- Create a `__init__.py` file in each folder +- Never pass literals, e.g., `error_msg`, directly to `Exceptions`, but instead assign them to variables and pass them to the exception, e.g., `raise FileNotFoundError(error_msg)` instead of `raise FileNotFoundError(f"Thread {thread_id} not found")` + +## FastAPI +- Instead of defining `response_model` within route annotation, use the model as the response type in the function signature +- Do not assign `None` to dependencies but instead move it before arguments with default values + +# Testing +- Use `pytest-mock` for mocking in tests wherever you need to mock something and pytest-mock can do the job. + +# Documentation + +## Docstrings +- All public functions, constants, classes, types etc. 
should have docstrings +- Document the constructor (`__init__`) args as part of the class docstring +- Omit the `__init__` docstring +- All function parameters should be documented with their type (followed by `, optional` if there is a default value) in parentheses and description +- In descriptions, use backticks for all code references (variables, types, etc.), including types, e.g., `str` +- When referencing a function, use the function name in backticks plus parentheses, e.g., `click()` +- When referencing a class, use the class name in backticks, e.g., `VisionAgent` +- When referencing a method, use the class name in backticks plus the method name in parentheses, e.g., `VisionAgent.click()` +- When referencing a class attribute, use the class name in backticks plus the attribute name, e.g., `VisionAgent.display` +- Use `Example` section for code examples +- Use `Returns` section for return values +- Use `Raises` section for exceptions listing all possible exceptions that can be raised by the function +- Use `Notes` section for additional notes +- Use `See Also` section for related functions +- Use `References` section for references +- Use `Examples` section for code examples +- Example of a good docstring: + ```python + def locate( + self, + locator: str | Locator, + screenshot: Img | None = None, + model: ModelComposition | str | None = None, + ) -> Point: + """ + Find the position of the UI element identified by the `locator` using the `model`. + + Args: + locator (str | Locator): The identifier or description of the element to locate. + screenshot (Img | None, optional): The screenshot to use for locating the + element. Can be a path to an image file, a PIL Image object or a data URL. + If `None`, takes a screenshot of the currently selected screen. + model (ModelComposition | str | None, optional): The composition or name of + the model(s) to be used for locating the element using the `locator`. 
+ + Returns: + Point: The coordinates of a point on the element, usually the center of the element, as a tuple (x, y). + + Raises: + ValueError: If the arguments are not of the correct type. + ElementNotFoundError: If no element can be found. + + Example: + ```python + from askui import VisionAgent + + with VisionAgent() as agent: + point = agent.locate("Submit button") + print(f"Element found at coordinates: {point}") + ``` + """ + ... + ``` diff --git a/.nvmrc b/.nvmrc new file mode 100644 index 00000000..2bd5a0a9 --- /dev/null +++ b/.nvmrc @@ -0,0 +1 @@ +22 diff --git a/.vscode/extensions.json b/.vscode/extensions.json index cbe1788d..711148cb 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -15,7 +15,6 @@ "tamasfe.even-better-toml", "visualstudioexptteam.vscodeintellicode", "dongli.python-preview", - "mintlify.document", "kaih2o.python-resource-monitor", "littlefoxteam.vscode-python-test-adapter", "almenon.arepl" diff --git a/README.md b/README.md index b7b61140..717751bd 100644 --- a/README.md +++ b/README.md @@ -284,32 +284,48 @@ You can create and use your own models by subclassing the `ActModel` (used for ` Here's how to create and use custom models: ```python +import functools from askui import ( ActModel, GetModel, LocateModel, Locator, ImageSource, + MessageParam, ModelComposition, ModelRegistry, + OnMessageCb, Point, ResponseSchema, VisionAgent, ) from typing import Type +from typing_extensions import override # Define custom models class MyActModel(ActModel): - def act(self, goal: str, model_choice: str) -> None: + @override + def act( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: # Implement custom act logic, e.g.: # - Use a different AI model # - Implement custom business logic # - Call external services - print(f"Custom act model executing goal: {goal}") + if len(messages) > 0: + goal = messages[0].content + print(f"Custom act model executing goal: {goal}") + 
else: + error_msg = "No messages provided" + raise ValueError(error_msg) # Because Python supports multiple inheritance, we can subclass both `GetModel` and `LocateModel` (and even `ActModel`) # to create a model that can both get and locate elements. class MyGetAndLocateModel(GetModel, LocateModel): + @override def get( self, query: str, @@ -324,6 +340,7 @@ class MyGetAndLocateModel(GetModel, LocateModel): return f"Custom response to query: {query}" + @override def locate( self, locator: str | Locator, @@ -366,11 +383,15 @@ You can also use model factories if you need to create models dynamically: ```python class DynamicActModel(ActModel): - def act(self, goal: str, model_choice: str) -> None: - # Use api_key in implementation + @override + def act( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: pass - # going to be called each time model is chosen using `model` parameter def create_custom_model(api_key: str) -> ActModel: return DynamicActModel() @@ -410,7 +431,7 @@ The controller for the operating system. ```python agent.tools.os.click("left", 2) # clicking -agent.tools.os.mouse(100, 100) # mouse movement +agent.tools.os.mouse_move(100, 100) # mouse movement agent.tools.os.keyboard_tap("v", modifier_keys=["control"]) # Paste # and many more ``` diff --git a/act.py b/act.py deleted file mode 100644 index f8f594b8..00000000 --- a/act.py +++ /dev/null @@ -1,7 +0,0 @@ -from askui import VisionAgent - -with VisionAgent(log_level="DEBUG") as agent: - agent.act("Click on the 'X' button to cancel the current search in Google Maps") - agent.act( - "Search for 'Linienstraße 145' in Google maps to find the route there from 'Google Berlin'" - ) diff --git a/pdm.lock b/pdm.lock index 7e8b0377..7ac45d7b 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. 
[metadata] -groups = ["default", "chat", "test"] +groups = ["default", "chat", "mcp", "pynput", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:0a776bcb0858c92a70bc0221f9adafacdc5af3bc2dd86282aaeb4fce12ed32ac" +content_hash = "sha256:21d28e90c53b9d7f3439e469fe2125b83225d66dd3f6d182bedd1ae774544485" [[metadata.targets]] requires_python = ">=3.10" @@ -33,7 +33,7 @@ name = "annotated-types" version = "0.7.0" requires_python = ">=3.8" summary = "Reusable constraint types to use with typing.Annotated" -groups = ["default"] +groups = ["default", "chat", "mcp"] dependencies = [ "typing-extensions>=4.0.0; python_version < \"3.9\"", ] @@ -67,7 +67,7 @@ name = "anyio" version = "4.9.0" requires_python = ">=3.9" summary = "High level compatibility layer for multiple asynchronous event loop implementations" -groups = ["default"] +groups = ["default", "chat", "mcp"] dependencies = [ "exceptiongroup>=1.0.2; python_version < \"3.11\"", "idna>=2.8", @@ -128,7 +128,7 @@ name = "certifi" version = "2025.1.31" requires_python = ">=3.6" summary = "Python package for providing Mozilla's CA Bundle." -groups = ["default", "chat"] +groups = ["default", "chat", "mcp"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -202,7 +202,7 @@ name = "click" version = "8.1.8" requires_python = ">=3.7" summary = "Composable command line interface toolkit" -groups = ["chat"] +groups = ["chat", "mcp"] dependencies = [ "colorama; platform_system == \"Windows\"", "importlib-metadata; python_version < \"3.8\"", @@ -217,7 +217,7 @@ name = "colorama" version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." 
-groups = ["default", "chat", "test"] +groups = ["default", "chat", "mcp", "test"] marker = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, @@ -364,12 +364,23 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "evdev" +version = "1.9.2" +requires_python = ">=3.8" +summary = "Bindings to the Linux input handling subsystem" +groups = ["pynput"] +marker = "\"linux\" in sys_platform" +files = [ + {file = "evdev-1.9.2.tar.gz", hash = "sha256:5d3278892ce1f92a74d6bf888cc8525d9f68af85dbe336c95d1c87fb8f423069"}, +] + [[package]] name = "exceptiongroup" version = "1.2.2" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["default", "test"] +groups = ["default", "chat", "mcp", "test"] marker = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, @@ -387,6 +398,22 @@ files = [ {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"}, ] +[[package]] +name = "fastapi" +version = "0.115.12" +requires_python = ">=3.8" +summary = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +groups = ["chat"] +dependencies = [ + "pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4", + "starlette<0.47.0,>=0.40.0", + "typing-extensions>=4.8.0", +] +files = [ + {file = "fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d"}, + {file = "fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681"}, +] + [[package]] name = "filelock" version = "3.18.0" @@ -581,7 +608,7 @@ name = "h11" 
version = "0.14.0" requires_python = ">=3.7" summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -groups = ["default"] +groups = ["default", "chat", "mcp"] dependencies = [ "typing-extensions; python_version < \"3.8\"", ] @@ -595,7 +622,7 @@ name = "httpcore" version = "1.0.7" requires_python = ">=3.8" summary = "A minimal low-level HTTP client." -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "certifi", "h11<0.15,>=0.13", @@ -610,7 +637,7 @@ name = "httpx" version = "0.28.1" requires_python = ">=3.8" summary = "The next generation HTTP client." -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "anyio", "certifi", @@ -622,6 +649,17 @@ files = [ {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, ] +[[package]] +name = "httpx-sse" +version = "0.4.0" +requires_python = ">=3.8" +summary = "Consume Server-Sent Event (SSE) messages with HTTPX." +groups = ["mcp"] +files = [ + {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, + {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"}, +] + [[package]] name = "huggingface-hub" version = "0.30.1" @@ -647,7 +685,7 @@ name = "idna" version = "3.10" requires_python = ">=3.6" summary = "Internationalized Domain Names in Applications (IDNA)" -groups = ["default", "chat"] +groups = ["default", "chat", "mcp"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -777,7 +815,7 @@ name = "markdown-it-py" version = "3.0.0" requires_python = ">=3.8" summary = "Python port of markdown-it. Markdown parsing, done right!" 
-groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "mdurl~=0.1", ] @@ -846,17 +884,69 @@ files = [ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] +[[package]] +name = "mcp" +version = "1.8.1" +requires_python = ">=3.10" +summary = "Model Context Protocol SDK" +groups = ["mcp"] +dependencies = [ + "anyio>=4.5", + "httpx-sse>=0.4", + "httpx>=0.27", + "pydantic-settings>=2.5.2", + "pydantic<3.0.0,>=2.7.2", + "python-multipart>=0.0.9", + "sse-starlette>=1.6.1", + "starlette>=0.27", + "uvicorn>=0.23.1; sys_platform != \"emscripten\"", +] +files = [ + {file = "mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770"}, + {file = "mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e"}, +] + +[[package]] +name = "mcp" +version = "1.8.1" +extras = ["cli", "rich", "ws"] +requires_python = ">=3.10" +summary = "Model Context Protocol SDK" +groups = ["mcp"] +dependencies = [ + "mcp==1.8.1", + "python-dotenv>=1.0.0", + "rich>=13.9.4", + "typer>=0.12.4", + "websockets>=15.0.1", +] +files = [ + {file = "mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770"}, + {file = "mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e"}, +] + [[package]] name = "mdurl" version = "0.1.2" requires_python = ">=3.7" summary = "Markdown URL utilities" -groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mss" +version = "10.0.0" +requires_python = ">=3.9" +summary = "An ultra fast cross-platform multiple screenshots module in pure python using ctypes." 
+groups = ["pynput"] +files = [ + {file = "mss-10.0.0-py3-none-any.whl", hash = "sha256:82cf6460a53d09e79b7b6d871163c982e6c7e9649c426e7b7591b74956d5cb64"}, + {file = "mss-10.0.0.tar.gz", hash = "sha256:d903e0d51262bf0f8782841cf16eaa6d7e3e1f12eae35ab41c2e318837c6637f"}, +] + [[package]] name = "mypy" version = "1.15.0" @@ -1225,7 +1315,7 @@ name = "pydantic" version = "2.11.2" requires_python = ">=3.9" summary = "Data validation using Python type hints" -groups = ["default"] +groups = ["default", "chat", "mcp"] dependencies = [ "annotated-types>=0.6.0", "pydantic-core==2.33.1", @@ -1242,7 +1332,7 @@ name = "pydantic-core" version = "2.33.1" requires_python = ">=3.9" summary = "Core functionality for Pydantic validation and serialization" -groups = ["default"] +groups = ["default", "chat", "mcp"] dependencies = [ "typing-extensions!=4.7.0,>=4.6.0", ] @@ -1328,17 +1418,18 @@ files = [ [[package]] name = "pydantic-settings" -version = "2.8.1" -requires_python = ">=3.8" +version = "2.9.1" +requires_python = ">=3.9" summary = "Settings management using Pydantic" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "pydantic>=2.7.0", "python-dotenv>=0.21.0", + "typing-inspection>=0.4.0", ] files = [ - {file = "pydantic_settings-2.8.1-py3-none-any.whl", hash = "sha256:81942d5ac3d905f7f3ee1a70df5dfb62d5569c12f51a5a647defc1c3d9ee2e9c"}, - {file = "pydantic_settings-2.8.1.tar.gz", hash = "sha256:d5c663dfbe9db9d5e1c646b2e161da12f0d734d422ee56f567d0ea2cee4e8585"}, + {file = "pydantic_settings-2.9.1-py3-none-any.whl", hash = "sha256:59b4f431b1defb26fe620c71a7d3968a710d719f5f4cdbbdb7926edeb770f6ef"}, + {file = "pydantic_settings-2.9.1.tar.gz", hash = "sha256:c509bf79d27563add44e8446233359004ed85066cd096d8b510f715e6ef5d268"}, ] [[package]] @@ -1361,7 +1452,7 @@ name = "pygments" version = "2.19.1" requires_python = ">=3.8" summary = "Pygments is a syntax highlighting package written in Python." 
-groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, @@ -1378,6 +1469,122 @@ files = [ {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, ] +[[package]] +name = "pynput" +version = "1.8.1" +summary = "Monitor and control user input devices" +groups = ["pynput"] +dependencies = [ + "enum34; python_version == \"2.7\"", + "evdev>=1.3; \"linux\" in sys_platform", + "pyobjc-framework-ApplicationServices>=8.0; sys_platform == \"darwin\"", + "pyobjc-framework-Quartz>=8.0; sys_platform == \"darwin\"", + "python-xlib>=0.17; \"linux\" in sys_platform", + "six", +] +files = [ + {file = "pynput-1.8.1-py2.py3-none-any.whl", hash = "sha256:42dfcf27404459ca16ca889c8fb8ffe42a9fe54f722fd1a3e130728e59e768d2"}, + {file = "pynput-1.8.1.tar.gz", hash = "sha256:70d7c8373ee98911004a7c938742242840a5628c004573d84ba849d4601df81e"}, +] + +[[package]] +name = "pyobjc-core" +version = "11.0" +requires_python = ">=3.8" +summary = "Python<->ObjC Interoperability Module" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +files = [ + {file = "pyobjc_core-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:10866b3a734d47caf48e456eea0d4815c2c9b21856157db5917b61dee06893a1"}, + {file = "pyobjc_core-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:50675c0bb8696fe960a28466f9baf6943df2928a1fd85625d678fa2f428bd0bd"}, + {file = "pyobjc_core-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a03061d4955c62ddd7754224a80cdadfdf17b6b5f60df1d9169a3b1b02923f0b"}, + {file = "pyobjc_core-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c338c1deb7ab2e9436d4175d1127da2eeed4a1b564b3d83b9f3ae4844ba97e86"}, + {file = 
"pyobjc_core-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b4e9dc4296110f251a4033ff3f40320b35873ea7f876bd29a1c9705bb5e08c59"}, + {file = "pyobjc_core-11.0.tar.gz", hash = "sha256:63bced211cb8a8fb5c8ff46473603da30e51112861bd02c438fbbbc8578d9a70"}, +] + +[[package]] +name = "pyobjc-framework-applicationservices" +version = "11.0" +requires_python = ">=3.9" +summary = "Wrappers for the framework ApplicationServices on macOS" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +dependencies = [ + "pyobjc-core>=11.0", + "pyobjc-framework-Cocoa>=11.0", + "pyobjc-framework-CoreText>=11.0", + "pyobjc-framework-Quartz>=11.0", +] +files = [ + {file = "pyobjc_framework_ApplicationServices-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:bc8f34b5b59ffd3c210ae883d794345c1197558ff3da0f5800669cf16435271e"}, + {file = "pyobjc_framework_ApplicationServices-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61a99eef23abb704257310db4f5271137707e184768f6407030c01de4731b67b"}, + {file = "pyobjc_framework_ApplicationServices-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:5fbeb425897d6129471d451ec61a29ddd5b1386eb26b1dd49cb313e34616ee21"}, + {file = "pyobjc_framework_ApplicationServices-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:59becf3cd87a4f4cedf4be02ff6cf46ed736f5c1123ce629f788aaafad91eff0"}, + {file = "pyobjc_framework_ApplicationServices-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:44b466e8745fb49e8ac20f29f2ffd7895b45e97aa63a844b2a80a97c3a34346f"}, + {file = "pyobjc_framework_applicationservices-11.0.tar.gz", hash = "sha256:d6ea18dfc7d5626a3ecf4ac72d510405c0d3a648ca38cae8db841acdebecf4d2"}, +] + +[[package]] +name = "pyobjc-framework-cocoa" +version = "11.0" +requires_python = ">=3.9" +summary = "Wrappers for the Cocoa frameworks on macOS" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +dependencies = [ + "pyobjc-core>=11.0", +] +files = [ + {file = 
"pyobjc_framework_Cocoa-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fbc65f260d617d5463c7fb9dbaaffc23c9a4fabfe3b1a50b039b61870b8daefd"}, + {file = "pyobjc_framework_Cocoa-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3ea7be6e6dd801b297440de02d312ba3fa7fd3c322db747ae1cb237e975f5d33"}, + {file = "pyobjc_framework_Cocoa-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:280a577b83c68175a28b2b7138d1d2d3111f2b2b66c30e86f81a19c2b02eae71"}, + {file = "pyobjc_framework_Cocoa-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:15b2bd977ed340074f930f1330f03d42912d5882b697d78bd06f8ebe263ef92e"}, + {file = "pyobjc_framework_Cocoa-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:5750001db544e67f2b66f02067d8f0da96bb2ef71732bde104f01b8628f9d7ea"}, + {file = "pyobjc_framework_cocoa-11.0.tar.gz", hash = "sha256:00346a8cb81ad7b017b32ff7bf596000f9faa905807b1bd234644ebd47f692c5"}, +] + +[[package]] +name = "pyobjc-framework-coretext" +version = "11.0" +requires_python = ">=3.9" +summary = "Wrappers for the framework CoreText on macOS" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +dependencies = [ + "pyobjc-core>=11.0", + "pyobjc-framework-Cocoa>=11.0", + "pyobjc-framework-Quartz>=11.0", +] +files = [ + {file = "pyobjc_framework_CoreText-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6939b4ea745b349b5c964823a2071f155f5defdc9b9fc3a13f036d859d7d0439"}, + {file = "pyobjc_framework_CoreText-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:56a4889858308b0d9f147d568b4d91c441cc0ffd332497cb4f709bb1990450c1"}, + {file = "pyobjc_framework_CoreText-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fb90e7f370b3fd7cb2fb442e3dc63fedf0b4af6908db1c18df694d10dc94669d"}, + {file = "pyobjc_framework_CoreText-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7947f755782456bd663e0b00c7905eeffd10f839f0bf2af031f68ded6a1ea360"}, + {file = 
"pyobjc_framework_CoreText-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:5356116bae33ec49f1f212c301378a7d08000440a2d6a7281aab351945528ab9"}, + {file = "pyobjc_framework_coretext-11.0.tar.gz", hash = "sha256:a68437153e627847e3898754dd3f13ae0cb852246b016a91f9c9cbccb9f91a43"}, +] + +[[package]] +name = "pyobjc-framework-quartz" +version = "11.0" +requires_python = ">=3.9" +summary = "Wrappers for the Quartz frameworks on macOS" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +dependencies = [ + "pyobjc-core>=11.0", + "pyobjc-framework-Cocoa>=11.0", +] +files = [ + {file = "pyobjc_framework_Quartz-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:da3ab13c9f92361959b41b0ad4cdd41ae872f90a6d8c58a9ed699bc08ab1c45c"}, + {file = "pyobjc_framework_Quartz-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d251696bfd8e8ef72fbc90eb29fec95cb9d1cc409008a183d5cc3246130ae8c2"}, + {file = "pyobjc_framework_Quartz-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:cb4a9f2d9d580ea15e25e6b270f47681afb5689cafc9e25712445ce715bcd18e"}, + {file = "pyobjc_framework_Quartz-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:973b4f9b8ab844574461a038bd5269f425a7368d6e677e3cc81fcc9b27b65498"}, + {file = "pyobjc_framework_Quartz-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:66ab58d65348863b8707e63b2ec5cdc54569ee8189d1af90d52f29f5fdf6272c"}, + {file = "pyobjc_framework_quartz-11.0.tar.gz", hash = "sha256:3205bf7795fb9ae34747f701486b3db6dfac71924894d1f372977c4d70c3c619"}, +] + [[package]] name = "pyperclip" version = "1.9.0" @@ -1483,12 +1690,37 @@ name = "python-dotenv" version = "1.1.0" requires_python = ">=3.9" summary = "Read key-value pairs from a .env file and set them as environment variables" -groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d"}, {file = 
"python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5"}, ] +[[package]] +name = "python-multipart" +version = "0.0.20" +requires_python = ">=3.8" +summary = "A streaming multipart parser for Python" +groups = ["mcp"] +files = [ + {file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"}, + {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, +] + +[[package]] +name = "python-xlib" +version = "0.33" +summary = "Python X Library" +groups = ["pynput"] +marker = "\"linux\" in sys_platform" +dependencies = [ + "six>=1.10.0", +] +files = [ + {file = "python-xlib-0.33.tar.gz", hash = "sha256:55af7906a2c75ce6cb280a584776080602444f75815a7aff4d287bb2d7018b32"}, + {file = "python_xlib-0.33-py2.py3-none-any.whl", hash = "sha256:c3534038d42e0df2f1392a1b30a15a4ff5fdc2b86cfa94f072bf11b10a164398"}, +] + [[package]] name = "pytz" version = "2025.2" @@ -1583,7 +1815,7 @@ name = "rich" version = "14.0.0" requires_python = ">=3.8.0" summary = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "markdown-it-py>=2.2.0", "pygments<3.0.0,>=2.13.0", @@ -1747,12 +1979,23 @@ files = [ {file = "setuptools-78.1.0.tar.gz", hash = "sha256:18fd474d4a82a5f83dac888df697af65afa82dec7323d09c3e37d1f14288da54"}, ] +[[package]] +name = "shellingham" +version = "1.5.4" +requires_python = ">=3.7" +summary = "Tool to Detect Surrounding Shell" +groups = ["mcp"] +files = [ + {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, + {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, +] + [[package]] name = "six" version = "1.17.0" requires_python = 
"!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" summary = "Python 2 and 3 compatibility utilities" -groups = ["default", "chat"] +groups = ["default", "chat", "pynput"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -1774,12 +2017,42 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["default"] +groups = ["default", "chat", "mcp"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "sse-starlette" +version = "2.3.5" +requires_python = ">=3.9" +summary = "SSE plugin for Starlette" +groups = ["mcp"] +dependencies = [ + "anyio>=4.7.0", + "starlette>=0.41.3", +] +files = [ + {file = "sse_starlette-2.3.5-py3-none-any.whl", hash = "sha256:251708539a335570f10eaaa21d1848a10c42ee6dc3a9cf37ef42266cdb1c52a8"}, + {file = "sse_starlette-2.3.5.tar.gz", hash = "sha256:228357b6e42dcc73a427990e2b4a03c023e2495ecee82e14f07ba15077e334b2"}, +] + +[[package]] +name = "starlette" +version = "0.46.2" +requires_python = ">=3.9" +summary = "The little ASGI library that shines." 
+groups = ["chat", "mcp"] +dependencies = [ + "anyio<5,>=3.6.2", + "typing-extensions>=3.10.0; python_version < \"3.10\"", +] +files = [ + {file = "starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35"}, + {file = "starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5"}, +] + [[package]] name = "streamlit" version = "1.44.1" @@ -1909,6 +2182,23 @@ files = [ {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, ] +[[package]] +name = "typer" +version = "0.15.4" +requires_python = ">=3.7" +summary = "Typer, build great CLIs. Easy to code. Based on Python type hints." +groups = ["mcp"] +dependencies = [ + "click<8.2,>=8.0.0", + "rich>=10.11.0", + "shellingham>=1.3.0", + "typing-extensions>=3.7.4.3", +] +files = [ + {file = "typer-0.15.4-py3-none-any.whl", hash = "sha256:eb0651654dcdea706780c466cf06d8f174405a659ffff8f163cfbfee98c0e173"}, + {file = "typer-0.15.4.tar.gz", hash = "sha256:89507b104f9b6a0730354f27c39fae5b63ccd0c95b1ce1f1a6ba0cfd329997c3"}, +] + [[package]] name = "types-pillow" version = "10.2.0.20240822" @@ -1931,6 +2221,17 @@ files = [ {file = "types_protobuf-5.29.1.20250403.tar.gz", hash = "sha256:7ff44f15022119c9d7558ce16e78b2d485bf7040b4fadced4dd069bb5faf77a2"}, ] +[[package]] +name = "types-pynput" +version = "1.8.1.20250318" +requires_python = ">=3.9" +summary = "Typing stubs for pynput" +groups = ["test"] +files = [ + {file = "types_pynput-1.8.1.20250318-py3-none-any.whl", hash = "sha256:0c1038aa1550941633114a2728ad85e392f67dfba970aebf755e369ab57aca70"}, + {file = "types_pynput-1.8.1.20250318.tar.gz", hash = "sha256:13d4df97843a7d1e7cddccbf9987aca7f0d463b214a8a35b4f53275d2c5a3576"}, +] + [[package]] name = "types-pyperclip" version = "1.9.0.20250218" @@ -1972,7 +2273,7 @@ name = "typing-extensions" version = "4.13.1" requires_python = ">=3.8" summary = "Backported and 
Experimental Type Hints for Python 3.8+" -groups = ["default", "chat", "test"] +groups = ["default", "chat", "mcp", "test"] files = [ {file = "typing_extensions-4.13.1-py3-none-any.whl", hash = "sha256:4b6cf02909eb5495cfbc3f6e8fd49217e6cc7944e145cdda8caa3734777f9e69"}, {file = "typing_extensions-4.13.1.tar.gz", hash = "sha256:98795af00fb9640edec5b8e31fc647597b4691f099ad75f469a2616be1a76dff"}, @@ -1983,7 +2284,7 @@ name = "typing-inspection" version = "0.4.0" requires_python = ">=3.9" summary = "Runtime typing introspection tools" -groups = ["default"] +groups = ["default", "chat", "mcp"] dependencies = [ "typing-extensions>=4.12.0", ] @@ -2014,6 +2315,22 @@ files = [ {file = "urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d"}, ] +[[package]] +name = "uvicorn" +version = "0.34.3" +requires_python = ">=3.9" +summary = "The lightning-fast ASGI server." +groups = ["chat", "mcp"] +dependencies = [ + "click>=7.0", + "h11>=0.8", + "typing-extensions>=4.0; python_version < \"3.11\"", +] +files = [ + {file = "uvicorn-0.34.3-py3-none-any.whl", hash = "sha256:16246631db62bdfbf069b0645177d6e8a77ba950cfedbfd093acef9444e4d885"}, + {file = "uvicorn-0.34.3.tar.gz", hash = "sha256:35919a9a979d7a59334b6b10e05d77c1d0d574c50e0fc98b8b1a0f165708b55a"}, +] + [[package]] name = "watchdog" version = "6.0.0" @@ -2054,7 +2371,7 @@ name = "websockets" version = "15.0.1" requires_python = ">=3.9" summary = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b"}, {file = "websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205"}, diff --git a/pyproject.toml b/pyproject.toml index f3031588..0be5b9c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ 
-15,8 +15,8 @@ dependencies = [ "gradio-client>=1.4.3", "requests>=2.32.3", "Jinja2>=3.1.4", - "pydantic-settings>=2.7.0", "tenacity>=9.1.2", + "pydantic-settings>=2.9.1", "python-dateutil>=2.9.0.post0", "openai>=1.61.1", "segment-analytics-python>=2.3.4", @@ -40,18 +40,36 @@ distribution = true [tool.pdm.scripts] test = "pytest -n auto" +"test:cov" = "pytest -n auto --cov=src/askui --cov-report=html" +"test:cov:view" = "python -m http.server --directory htmlcov" "test:e2e" = "pytest -n auto tests/e2e" +"test:e2e:cov" = "pytest -n auto tests/e2e --cov=src/askui --cov-report=html" "test:integration" = "pytest -n auto tests/integration" +"test:integration:cov" = "pytest -n auto tests/integration --cov=src/askui --cov-report=html" "test:unit" = "pytest -n auto tests/unit" -"test:cov:view" = "python -m http.server --directory htmlcov" +"test:unit:cov" = "pytest -n auto tests/unit --cov=src/askui --cov-report=html" format = "ruff format src tests" lint = "ruff check src tests" "lint:fix" = "ruff check --fix src tests" typecheck = "mypy" "typecheck:all" = "mypy src tests" chat = "streamlit run src/askui/chat/__main__.py" +"chat:api" = "uvicorn src.askui.chat.api.app:app --reload --port 8000" +mcp = "mcp dev src/askui/mcp/__init__.py" [dependency-groups] +chat = [ + "streamlit>=1.42.0", + "fastapi>=0.115.12", + "uvicorn>=0.34.3", +] +pynput = [ + "mss>=10.0.0", + "pynput>=1.8.1", +] +mcp = [ + "mcp[cli,rich,ws]>=1.8.1", +] test = [ "pytest>=8.3.4", "ruff>=0.9.5", @@ -66,13 +84,11 @@ test = [ "grpc-stubs>=1.53.0.3", "types-pyperclip>=1.8.2.20240311", "pytest-timeout>=2.4.0", + "types-pynput>=1.8.1.20250318", ] -chat = [ - "streamlit>=1.42.0", -] + [tool.pytest.ini_options] -addopts = "--cov=src/askui --cov-report=html" python_classes = ["Test*"] python_files = ["test_*.py"] python_functions = ["test_*"] diff --git a/src/askui/__init__.py b/src/askui/__init__.py index a9ff59d5..478012f3 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -6,38 +6,70 @@ from 
.locators import Locator from .models import ( ActModel, + Base64ImageSourceParam, + CacheControlEphemeralParam, + CitationCharLocationParam, + CitationContentBlockLocationParam, + CitationPageLocationParam, + ContentBlockParam, GetModel, + ImageBlockParam, LocateModel, + MessageParam, Model, ModelChoice, ModelComposition, ModelDefinition, + ModelName, ModelRegistry, + OnMessageCb, Point, + TextBlockParam, + TextCitationParam, + ToolResultBlockParam, + ToolUseBlockParam, + UrlImageSourceParam, ) from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .retry import ConfigurableRetry, Retry from .tools import ModifierKey, PcKey +from .tools.anthropic import ToolResult from .utils.image_utils import ImageSource, Img __all__ = [ "ActModel", + "Base64ImageSourceParam", + "CacheControlEphemeralParam", + "CitationCharLocationParam", + "CitationContentBlockLocationParam", + "CitationPageLocationParam", + "ConfigurableRetry", + "ContentBlockParam", "GetModel", + "ImageBlockParam", "ImageSource", "Img", "LocateModel", "Locator", + "MessageParam", "Model", + "ModelChoice", "ModelComposition", "ModelDefinition", - "ModelChoice", + "ModelName", "ModelRegistry", "ModifierKey", + "OnMessageCb", "PcKey", "Point", "ResponseSchema", "ResponseSchemaBase", "Retry", - "ConfigurableRetry", + "TextBlockParam", + "TextCitationParam", + "ToolResult", + "ToolResultBlockParam", + "ToolUseBlockParam", + "UrlImageSourceParam", "VisionAgent", ] diff --git a/src/askui/agent.py b/src/askui/agent.py index 19ab02eb..07e0fb49 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -8,6 +8,8 @@ from askui.container import telemetry from askui.locators.locators import Locator +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.utils.image_utils import ImageSource, Img from .exceptions import ElementNotFoundError @@ -216,7 +218,7 @@ def _mouse_move( self, locator: str | 
Locator, model: ModelComposition | str | None = None ) -> None: point = self._locate(locator=locator, model=model) - self._tools.os.mouse(point[0], point[1]) + self._tools.os.mouse_move(point[0], point[1]) @telemetry.record_call(exclude={"locator"}) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) @@ -534,8 +536,9 @@ def mouse_down( @validate_call def act( self, - goal: Annotated[str, Field(min_length=1)], + goal: Annotated[str | list[MessageParam], Field(min_length=1)], model: str | None = None, + on_message: OnMessageCb | None = None, ) -> None: """ Instructs the agent to achieve a specified goal through autonomous actions. @@ -547,6 +550,10 @@ def act( Args: goal (str): A description of what the agent should achieve. model (str | None, optional): The composition or name of the model(s) to be used for achieving the `goal`. + on_message (OnMessageCb | None, optional): Callback for new messages. If it returns `None`, stops and does not add the message. + + Returns: + None Example: ```python @@ -558,11 +565,19 @@ def act( agent.act("Log in with username 'admin' and password '1234'") ``` """ - self._reporter.add_message("User", f'act: "{goal}"') + goal_str = ( + goal + if isinstance(goal, str) + else "\n".join(msg.model_dump_json() for msg in goal) + ) + self._reporter.add_message("User", f'act: "{goal_str}"') logger.debug( - "VisionAgent received instruction to act towards the goal '%s'", goal + "VisionAgent received instruction to act towards the goal '%s'", goal_str + ) + messages: list[MessageParam] = ( + [MessageParam(role="user", content=goal)] if isinstance(goal, str) else goal ) - self._model_router.act(goal, model or self._model_choice["act"]) + self._model_router.act(messages, model or self._model_choice["act"], on_message) @telemetry.record_call() @validate_call diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 05b39ae8..25d47365 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -1,21 +1,23 
@@ +import io import json -import logging -import re -from datetime import datetime, timezone +import time from pathlib import Path -from random import randint -from typing import Union +import httpx import streamlit as st -from PIL import Image, ImageDraw -from typing_extensions import TypedDict, override +from PIL import Image -from askui import VisionAgent -from askui.chat.click_recorder import ClickRecorder -from askui.chat.exceptions import FunctionExecutionError, InvalidFunctionError -from askui.models import ModelName -from askui.reporting import Reporter -from askui.utils.image_utils import base64_to_image, draw_point_on_image +from askui.chat.api.messages.service import Message, MessageService +from askui.chat.api.runs.service import RunService +from askui.chat.api.threads.service import ThreadService + +# from askui.chat.click_recorder import ClickRecorder +from askui.models.shared.computer_agent_message_param import ( + Base64ImageSourceParam, + MessageParam, + UrlImageSourceParam, +) +from askui.utils.image_utils import base64_to_image st.set_page_config( page_title="Vision Agent Chat", @@ -23,318 +25,358 @@ ) -CHAT_SESSIONS_DIR_PATH = Path("./chat/sessions") -CHAT_IMAGES_DIR_PATH = Path("./chat/images") - -click_recorder = ClickRecorder() - - -def setup_chat_dirs() -> None: - Path.mkdir(CHAT_SESSIONS_DIR_PATH, parents=True, exist_ok=True) - Path.mkdir(CHAT_IMAGES_DIR_PATH, parents=True, exist_ok=True) +BASE_DIR = Path("./chat") -def get_session_id_from_path(path: str) -> str: - """Get session ID from file path.""" - return Path(path).stem +@st.cache_resource +def get_thread_service() -> ThreadService: + return ThreadService(BASE_DIR) -def load_chat_history(session_id: str) -> list[dict]: - """Load chat history for a given session ID.""" - messages: list[dict] = [] - session_path = CHAT_SESSIONS_DIR_PATH / f"{session_id}.jsonl" - if session_path.exists(): - with session_path.open("r") as f: - messages.extend(json.loads(line) for line in f) - return 
messages +@st.cache_resource +def get_message_service() -> MessageService: + return MessageService(BASE_DIR) -ROLE_MAP = { - "user": "user", - "anthropic computer use": "ai", - "agentos": "assistant", - "user (demonstration)": "user", -} +@st.cache_resource +def get_run_service() -> RunService: + return RunService(BASE_DIR) -UNKNOWN_ROLE = "unknown" +thread_service = get_thread_service() +message_service = get_message_service() +run_service = get_run_service() +# click_recorder = ClickRecorder() -def get_image(img_b64_str_or_path: str) -> Image.Image: - """Get image from base64 string or file path.""" - if Path(img_b64_str_or_path).is_file(): - return Image.open(img_b64_str_or_path) - return base64_to_image(img_b64_str_or_path) - -def write_message( - role: str, - content: str | dict | list, - timestamp: str, - image: Image.Image - | str - | list[str | Image.Image] - | list[str] - | list[Image.Image] - | None = None, +def get_image( + source: Base64ImageSourceParam | UrlImageSourceParam, +) -> Image.Image: + match source.type: + case "base64": + data = source.data + if isinstance(data, str): + return base64_to_image(data) + error_msg = f"Image source data type not supported: {type(data)}" + raise NotImplementedError(error_msg) + case "url": + response = httpx.get(source.url) + return Image.open(io.BytesIO(response.content)) + + +def write_message( # noqa: C901 + message: Message, ) -> None: - _role = ROLE_MAP.get(role.lower(), UNKNOWN_ROLE) - avatar = None if _role != UNKNOWN_ROLE else "❔" - with st.chat_message(_role, avatar=avatar): - st.markdown(f"*{timestamp}* - **{role}**\n\n") - st.markdown( - json.dumps(content, indent=2) - if isinstance(content, (dict, list)) - else content - ) - if image: - if isinstance(image, list): - for img in image: - img = get_image(img) if isinstance(img, str) else img - st.image(img) + # Create a container for the message and delete button + col1, col2 = st.columns([0.95, 0.05]) + + with col1: + with st.chat_message(message.role): 
+ st.markdown(f"*{message.created_at.isoformat()}* - **{message.role}**\n\n") + if isinstance(message.content, str): + st.markdown(message.content) else: - img = get_image(image) if isinstance(image, str) else image - st.image(img) - - -def save_image(image: Image.Image) -> str: - """Save image to disk and return path.""" - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S_%f") - image_path = CHAT_IMAGES_DIR_PATH / f"image_{timestamp}.png" - image.save(image_path) - return str(image_path) - - -class Message(TypedDict): - role: str - content: str | dict | list - timestamp: str - image: str | list[str] | None - - -class ChatHistoryAppender(Reporter): - def __init__(self, session_id: str) -> None: - self._session_id = session_id - - @override - def add_message( - self, - role: str, - content: Union[str, dict, list], - image: Image.Image | list[Image.Image] | None = None, - ) -> None: - image_paths: list[str] = [] - if image is None: - _images = [] - elif isinstance(image, list): - _images = image - else: - _images = [image] - image_paths.extend(save_image(img) for img in _images) - message = Message( - role=role, - content=content, - timestamp=datetime.now(tz=timezone.utc).isoformat(), - image=image_paths, - ) - write_message(**message) - with (CHAT_SESSIONS_DIR_PATH / f"{self._session_id}.jsonl").open("a") as f: - json.dump(message, f) - f.write("\n") - - @override - def generate(self) -> None: - pass - - -def get_available_sessions() -> list[str]: - """Get list of available session IDs.""" - session_files = list(CHAT_SESSIONS_DIR_PATH.glob("*.jsonl")) - return sorted([get_session_id_from_path(f) for f in session_files], reverse=True) - - -def create_new_session() -> str: - """Create a new chat session.""" - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M%S%f") - random_suffix = f"{randint(100, 999)}" - session_id = f"{timestamp}{random_suffix}" - (CHAT_SESSIONS_DIR_PATH / f"{session_id}.jsonl").touch() - return session_id - - -def 
paint_crosshair( - image: Image.Image, - coordinates: tuple[int, int], - size: int | None = None, - color: str = "red", - width: int = 4, -) -> Image.Image: - """ - Paints a crosshair at the given coordinates on the image. - - :param image: A PIL Image object. - :param coordinates: A tuple (x, y) representing the coordinates of the point. - :param size: Optional length of each line in the crosshair. Defaults to min(width,height)/20 - :param color: The color of the crosshair. - :param width: The width of the crosshair. - :return: A new image with the crosshair. - """ - if size is None: - size = ( - min(image.width, image.height) // 20 - ) # Makes crosshair ~5% of smallest image dimension - - image_copy = image.copy() - draw = ImageDraw.Draw(image_copy) - x, y = coordinates - # Draw horizontal and vertical lines - draw.line((x - size, y, x + size, y), fill=color, width=width) - draw.line((x, y - size, x, y + size), fill=color, width=width) - return image_copy - - -prompt = """The following image is a screenshot with a red crosshair on top of an element that the user wants to interact with. Give me a description that uniquely describes the element as concise as possible across all elements on the screen that the user most likely wants to interact with. 
Examples: - -- "Submit button" -- "Cell within the table about European countries in the third row and 6th column (area in km^2) in the right-hand browser window" -- "Avatar in the top right hand corner of the browser in focus that looks like a woman" -""" - - -def rerun() -> None: - st.markdown("### Re-running...") - with VisionAgent( - log_level=logging.DEBUG, - ) as agent: - screenshot: Image.Image | None = None - for message in st.session_state.messages: - try: - if ( - message.get("role") == "AgentOS" - or message.get("role") == "User (Demonstration)" - ): - if message.get("content") == "screenshot()": - screenshot = get_image(message["image"]) - continue - if message.get("content"): - if match := re.match( - r"mouse\((\d+),\s*(\d+)\)", message["content"] - ): - if not screenshot: - error_msg = "Screenshot is required to paint crosshair" - raise ValueError(error_msg) # noqa: TRY301 - x, y = map(int, match.groups()) - screenshot_with_crosshair = paint_crosshair( - screenshot, (x, y) - ) - element_description = agent.get( - query=prompt, - image=screenshot_with_crosshair, - model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, - ) - write_message( - message["role"], - f"Move mouse to {element_description}", - datetime.now(tz=timezone.utc).isoformat(), - image=screenshot_with_crosshair, + for block in message.content: + match block.type: + case "image": + st.image(get_image(block.source)) + case "text": + st.markdown(block.text) + case "tool_result": + st.markdown(f"Tool use id: {block.tool_use_id}") + st.markdown(f"Erroneous: {block.is_error}") + content = block.content + if isinstance(content, str): + st.markdown(content) + else: + for nested_block in content: + match nested_block.type: + case "image": + st.image(get_image(nested_block.source)) + case "text": + st.markdown(nested_block.text) + case _: + st.markdown( + json.dumps(block.model_dump(mode="json"), indent=2) ) - agent.mouse_move( - locator=element_description.replace('"', ""), - 
model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, - ) - else: - write_message( - message["role"], - message["content"], - datetime.now(tz=timezone.utc).isoformat(), - message.get("image"), - ) - func_call = f"agent.tools.os.{message['content']}" - eval(func_call) - except json.JSONDecodeError: - continue - except AttributeError: - st.write(str(InvalidFunctionError(message["content"]))) - except Exception as e: # noqa: BLE001 - We want to catch all other exceptions here - st.write(str(FunctionExecutionError(message["content"], e))) + # Add delete button in the second column if message_id is provided + with col2: + if st.button("🗑️", key=f"delete_{message.id}"): + message_service.delete(st.session_state.thread_id, message.id) + st.rerun() + + +# def paint_crosshair( +# image: Image.Image, +# coordinates: tuple[int, int], +# size: int | None = None, +# color: str = "red", +# width: int = 4, +# ) -> Image.Image: +# """ +# Paints a crosshair at the given coordinates on the image. + +# :param image: A PIL Image object. +# :param coordinates: A tuple (x, y) representing the coordinates of the point. +# :param size: Optional length of each line in the crosshair. Defaults to min(width,height)/20 +# :param color: The color of the crosshair. +# :param width: The width of the crosshair. +# :return: A new image with the crosshair. +# """ +# if size is None: +# size = ( +# min(image.width, image.height) // 20 +# ) # Makes crosshair ~5% of smallest image dimension + +# image_copy = image.copy() +# draw = ImageDraw.Draw(image_copy) +# x, y = coordinates +# # Draw horizontal and vertical lines +# draw.line((x - size, y, x + size, y), fill=color, width=width) +# draw.line((x, y - size, x, y + size), fill=color, width=width) +# return image_copy + + +# prompt = """The following image is a screenshot with a red crosshair on top of an element that the user wants to interact with. 
Give me a description that uniquely describes the element as concise as possible across all elements on the screen that the user most likely wants to interact with. Examples: + +# - "Submit button" +# - "Cell within the table about European countries in the third row and 6th column (area in km^2) in the right-hand browser window" +# - "Avatar in the top right hand corner of the browser in focus that looks like a woman" +# """ + + +# def rerun() -> None: +# st.markdown("### Re-running...") +# with VisionAgent( +# log_level=logging.DEBUG, +# tools=tools, +# ) as agent: +# screenshot: Image.Image | None = None +# for message in messages_service.list_(st.session_state.thread_id).data: +# try: +# if ( +# message.role == MessageRole.ASSISTANT +# or message.role == MessageRole.USER +# ): +# content = message.content[0] +# if content.text == "screenshot()": +# screenshot = ( +# get_image(content.image_paths[0]) +# if content.image_paths +# else None +# ) +# continue +# if content.text: +# if match := re.match( +# r"mouse\((\d+),\s*(\d+)\)", cast("str", content.text) +# ): +# if not screenshot: +# error_msg = "Screenshot is required to paint crosshair" +# raise ValueError(error_msg) # noqa: TRY301 +# x, y = map(int, match.groups()) +# screenshot_with_crosshair = paint_crosshair( +# screenshot, (x, y) +# ) +# element_description = agent.get( +# query=prompt, +# image=screenshot_with_crosshair, +# model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, +# ) +# messages_service.create( +# thread_id=st.session_state.thread_id, +# role=message.role.value, +# content=f"Move mouse to {element_description}", +# image=screenshot_with_crosshair, +# ) +# agent.mouse_move( +# locator=element_description.replace('"', ""), +# model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, +# ) +# else: +# messages_service.create( +# thread_id=st.session_state.thread_id, +# role=message.role.value, +# content=content.text, +# image=None, +# ) +# func_call = f"agent.tools.os.{content.text}" +# 
eval(func_call) +# except json.JSONDecodeError: +# continue +# except AttributeError: +# st.write(str(InvalidFunctionError(cast("str", content.text)))) +# except Exception as e: # noqa: BLE001 - We want to catch all other exceptions here +# st.write(str(FunctionExecutionError(cast("str", content.text), e))) -setup_chat_dirs() if st.sidebar.button("New Chat"): - st.session_state.session_id = create_new_session() + thread = thread_service.create() + st.session_state.thread_id = thread.id st.rerun() -available_sessions = get_available_sessions() -session_id = st.session_state.get("session_id", None) +available_threads = thread_service.list_().data +thread_id = st.session_state.get("thread_id", None) -if not session_id and not available_sessions: - session_id = create_new_session() - st.session_state.session_id = session_id - st.rerun() - -index_of_session = available_sessions.index(session_id) if session_id else 0 -session_id = st.sidebar.radio( - "Sessions", - available_sessions, - index=index_of_session, -) -if session_id != st.session_state.get("session_id"): - st.session_state.session_id = session_id +if not thread_id and not available_threads: + thread = thread_service.create() + thread_id = thread.id + st.session_state.thread_id = thread_id st.rerun() -reporter = ChatHistoryAppender(session_id) - -st.title(f"Vision Agent Chat - {session_id}") -st.session_state.messages = load_chat_history(session_id) - -# Display chat history -for message in st.session_state.messages: - write_message( - message["role"], - message["content"], - message["timestamp"], - message.get("image"), +index_of_thread = 0 +if thread_id: + for index, thread in enumerate(available_threads): + if thread.id == thread_id: + index_of_thread = index + break + +# Create columns for thread selection and delete buttons +thread_cols = st.sidebar.columns([0.8, 0.2]) +with thread_cols[0]: + thread_id = st.radio( + "Threads", + [t.id for t in available_threads], + index=index_of_thread, ) -if 
value_to_type := st.chat_input("Simulate Typing for User (Demonstration)"): - reporter.add_message( - role="User (Demonstration)", - content=f'type("{value_to_type}", 50)', - ) +# Add delete buttons for each thread +for t in available_threads: + with thread_cols[1]: + if st.button("🗑️", key=f"delete_thread_{t.id}"): + if t.id == thread_id: + # If deleting current thread, switch to first available thread + remaining_threads = [th for th in available_threads if th.id != t.id] + if remaining_threads: + st.session_state.thread_id = remaining_threads[0].id + else: + # Create new thread if no threads left + new_thread = thread_service.create() + st.session_state.thread_id = new_thread.id + thread_service.delete(t.id) + st.rerun() + +if thread_id != st.session_state.get("thread_id"): + st.session_state.thread_id = thread_id st.rerun() -if st.button("Simulate left click"): - reporter.add_message( - role="User (Demonstration)", - content='click("left", 1)', - ) - st.rerun() -# Chat input -if st.button( - "Demonstrate where to move mouse" -): # only single step, only click supported for now, independent of click always registered as click - image, coordinates = click_recorder.record() - reporter.add_message( - role="User (Demonstration)", - content="screenshot()", - image=image, - ) - reporter.add_message( - role="User (Demonstration)", - content=f"mouse({coordinates[0]}, {coordinates[1]})", - image=draw_point_on_image(image, coordinates[0], coordinates[1]), - ) - st.rerun() +st.title(f"Vision Agent Chat - {thread_id}") + +# Display chat history +messages = message_service.list_(thread_id).data +for message in messages: + write_message(message) + +last_message = messages[-1] if messages else None + +# if value_to_type := st.chat_input("Simulate Typing for User (Demonstration)"): +# reporter.add_message( +# role="user", +# content=f'type("{value_to_type}", 50)', +# ) +# st.rerun() + +# if st.button("Simulate left click"): +# reporter.add_message( +# role="User 
(Demonstration)", +# content='click("left", 1)', +# ) +# st.rerun() + +# # Chat input +# if st.button( +# "Demonstrate where to move mouse" +# ): # only single step, only click supported for now, independent of click always registered as click +# image, coordinates = click_recorder.record() +# reporter.add_message( +# role="User (Demonstration)", +# content="screenshot()", +# image=image, +# ) +# reporter.add_message( +# role="User (Demonstration)", +# content=f"mouse_move({coordinates[0]}, {coordinates[1]})", +# image=draw_point_on_image(image, coordinates[0], coordinates[1]), +# ) +# st.rerun() + +# if st.session_state.get("input_event_listening"): +# while input_event := tools.os.poll_event(): +# image = tools.os.screenshot(report=False) +# if input_event.pressed: +# reporter.add_message( +# role="User (Demonstration)", +# content=f"mouse_move({input_event.x}, {input_event.y})", +# image=draw_point_on_image(image, input_event.x, input_event.y), +# ) +# reporter.add_message( +# role="User (Demonstration)", +# content=f'click("{input_event.button}")', +# ) +# if st.button("Refresh"): +# st.rerun() +# if st.button("Stop listening to input events"): +# tools.os.stop_listening() +# st.session_state["input_event_listening"] = False +# st.rerun() +# else: +# if st.button("Listen to input events"): +# tools.os.start_listening() +# st.session_state["input_event_listening"] = True +# st.rerun() if act_prompt := st.chat_input("Ask AI"): - with VisionAgent( - log_level=logging.DEBUG, - reporters=[reporter], - ) as agent: - agent.act(act_prompt, model="claude") - st.rerun() - -if st.button("Rerun"): - rerun() + if act_prompt != "Continue": + last_message = message_service.create( + thread_id=thread_id, + message=MessageParam( + role="user", + content=act_prompt, + ), + ) + write_message(last_message) + run = run_service.create(thread_id, stream=False) + time.sleep(1) + while run := run_service.retrieve(run.id): + new_messages = message_service.list_( + thread_id, 
after=last_message.id if last_message else None + ).data + for message in new_messages: + write_message(message) + last_message = new_messages[-1] if new_messages else last_message + if run.status not in {"queued", "running", "in_progress"}: + break + time.sleep(1) + + +if act_prompt := st.chat_input("Ask AI (streaming)"): + if act_prompt != "Continue": + last_message = message_service.create( + thread_id=thread_id, + message=MessageParam( + role="user", + content=act_prompt, + ), + ) + write_message(last_message) + + # Use the streaming API + event_stream = run_service.create(thread_id, stream=True) + import asyncio + + async def handle_stream() -> None: + last_msg_id = last_message.id if last_message else None + async for event in event_stream: + if event.event == "message.created": + msg = event.data + if msg and (not last_msg_id or msg.id > last_msg_id): + write_message(msg) + last_msg_id = msg.id + + # Run the async handler in Streamlit (sync context) + asyncio.run(handle_stream()) + +# if st.button("Rerun"): +# rerun() diff --git a/src/askui/chat/api/__init__.py b/src/askui/chat/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py new file mode 100644 index 00000000..48b4eafe --- /dev/null +++ b/src/askui/chat/api/app.py @@ -0,0 +1,27 @@ +from fastapi import APIRouter, FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from askui.chat.api.messages.router import router as messages_router +from askui.chat.api.runs.router import router as runs_router +from askui.chat.api.threads.router import router as threads_router + +app = FastAPI( + title="AskUI Chat API", + version="0.1.0", +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include routers +v1_router = APIRouter(prefix="/v1") +v1_router.include_router(threads_router) 
+v1_router.include_router(messages_router) +v1_router.include_router(runs_router) +app.include_router(v1_router) diff --git a/src/askui/chat/api/dependencies.py b/src/askui/chat/api/dependencies.py new file mode 100644 index 00000000..a9c78c2f --- /dev/null +++ b/src/askui/chat/api/dependencies.py @@ -0,0 +1,11 @@ +from fastapi import Depends + +from askui.chat.api.settings import Settings + + +def get_settings() -> Settings: + """Get ChatApiSettings instance.""" + return Settings() + + +SettingsDep = Depends(get_settings) diff --git a/src/askui/chat/api/messages/__init__.py b/src/askui/chat/api/messages/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/messages/dependencies.py b/src/askui/chat/api/messages/dependencies.py new file mode 100644 index 00000000..86899d14 --- /dev/null +++ b/src/askui/chat/api/messages/dependencies.py @@ -0,0 +1,13 @@ +from fastapi import Depends + +from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.messages.service import MessageService +from askui.chat.api.settings import Settings + + +def get_message_service(settings: Settings = SettingsDep) -> MessageService: + """Get MessageService instance.""" + return MessageService(settings.data_dir) + + +MessageServiceDep = Depends(get_message_service) diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py new file mode 100644 index 00000000..114a1c03 --- /dev/null +++ b/src/askui/chat/api/messages/router.py @@ -0,0 +1,62 @@ +from fastapi import APIRouter, HTTPException, status + +from askui.chat.api.messages.dependencies import MessageServiceDep +from askui.chat.api.messages.service import Message, MessageListResponse, MessageService +from askui.models.shared.computer_agent_message_param import MessageParam + +router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) + + +@router.get("") +def list_messages( + thread_id: str, + limit: int | None = None, + message_service: 
MessageService = MessageServiceDep, +) -> MessageListResponse: + """List all messages in a thread.""" + try: + return message_service.list_(thread_id, limit=limit) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + +@router.post("", status_code=status.HTTP_201_CREATED) +async def create_message( + thread_id: str, + message: MessageParam, + message_service: MessageService = MessageServiceDep, +) -> Message: + """Create a new message in a thread.""" + try: + return message_service.create( + thread_id=thread_id, + message=message, + ) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + +@router.get("/{message_id}") +def retrieve_message( + thread_id: str, + message_id: str, + message_service: MessageService = MessageServiceDep, +) -> Message: + """Get a specific message from a thread.""" + try: + return message_service.retrieve(thread_id, message_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + +@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_message( + thread_id: str, + message_id: str, + message_service: MessageService = MessageServiceDep, +) -> None: + """Delete a message from a thread.""" + try: + message_service.delete(thread_id, message_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py new file mode 100644 index 00000000..8af02967 --- /dev/null +++ b/src/askui/chat/api/messages/service.py @@ -0,0 +1,172 @@ +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal + +from pydantic import AwareDatetime, BaseModel, Field + +from askui.chat.api.models import Event +from askui.chat.api.utils import generate_time_ordered_id +from askui.models.shared.computer_agent_message_param import MessageParam + + +class 
Message(MessageParam): + """A message in a thread.""" + + id: str = Field(default_factory=lambda: generate_time_ordered_id("msg")) + thread_id: str + created_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + object: str = "message" + + +class MessageEvent(Event): + data: Message + event: Literal["message.created"] + + +class MessageListResponse(BaseModel): + """Response model for listing messages.""" + + object: str = "list" + data: list[Message] + first_id: str | None = None + last_id: str | None = None + has_more: bool = False + + +class MessageService: + """Service for managing messages within threads.""" + + def __init__(self, base_dir: Path) -> None: + """Initialize message service. + + Args: + base_dir: Base directory to store message data + """ + self._base_dir = base_dir + self._threads_dir = base_dir / "threads" + + def list_( + self, thread_id: str, limit: int | None = None, after: str | None = None + ) -> MessageListResponse: + """List all messages in a thread. 
+ + Args: + thread_id: ID of thread to list messages from + limit: Optional maximum number of messages to return + after: Optional message ID after which messages are returned + + Returns: + MessageListResponse containing messages sorted by creation date + + Raises: + FileNotFoundError: If thread doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + messages: list[Message] = [] + with thread_file.open("r") as f: + for line in f: + msg = Message.model_validate_json(line) + messages.append(msg) + + # Sort by creation date + messages = sorted(messages, key=lambda m: m.created_at) + if after: + messages = [m for m in messages if m.id > after] + + # Apply limit if specified + if limit is not None: + messages = messages[:limit] + + return MessageListResponse( + data=messages, + first_id=messages[0].id if messages else None, + last_id=messages[-1].id if messages else None, + has_more=len(messages) > (limit or len(messages)), + ) + + def create( + self, + thread_id: str, + message: MessageParam, + ) -> Message: + """Create a new message in a thread. + + Args: + thread_id: ID of thread to create message in + role: Role of message sender + content: Message content + + Returns: + Created message object + + Raises: + FileNotFoundError: If thread doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + message = Message.model_construct( + thread_id=thread_id, + role=message.role, + content=message.content, + ) + with thread_file.open("a") as f: + f.write(message.model_dump_json()) + f.write("\n") + return message + + def retrieve(self, thread_id: str, message_id: str) -> Message: + """Retrieve a specific message from a thread. 
+ + Args: + thread_id: ID of thread containing message + message_id: ID of message to retrieve + + Returns: + Message object + + Raises: + FileNotFoundError: If thread or message doesn't exist + """ + messages = self.list_(thread_id).data + for msg in messages: + if msg.id == message_id: + return msg + error_msg = f"Message {message_id} not found in thread {thread_id}" + raise FileNotFoundError(error_msg) + + def delete(self, thread_id: str, message_id: str) -> None: + """Delete a message from a thread. + + Args: + thread_id: ID of thread containing message + message_id: ID of message to delete + + Raises: + FileNotFoundError: If thread or message doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + # Read all messages + messages: list[Message] = [] + with thread_file.open("r") as f: + for line in f: + msg = Message.model_validate_json(line) + if msg.id != message_id: + messages.append(msg) + + # Write back all messages except the deleted one + with thread_file.open("w") as f: + for msg in messages: + f.write(msg.model_dump_json()) + f.write("\n") diff --git a/src/askui/chat/api/models.py b/src/askui/chat/api/models.py new file mode 100644 index 00000000..81d23f14 --- /dev/null +++ b/src/askui/chat/api/models.py @@ -0,0 +1,7 @@ +from typing import Literal + +from pydantic import BaseModel + + +class Event(BaseModel): + object: Literal["event"] = "event" diff --git a/src/askui/chat/api/runs/__init__.py b/src/askui/chat/api/runs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py new file mode 100644 index 00000000..772c2545 --- /dev/null +++ b/src/askui/chat/api/runs/dependencies.py @@ -0,0 +1,14 @@ +from fastapi import Depends + +from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.settings import Settings 
+ +from .service import RunService + + +def get_runs_service(settings: Settings = SettingsDep) -> RunService: + """Get RunService instance.""" + return RunService(settings.data_dir) + + +RunServiceDep = Depends(get_runs_service) diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py new file mode 100644 index 00000000..c3ee94b7 --- /dev/null +++ b/src/askui/chat/api/runs/router.py @@ -0,0 +1,87 @@ +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Annotated, cast + +from fastapi import APIRouter, Body, HTTPException, Path, Response, status +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel + +if TYPE_CHECKING: + from askui.chat.api.messages.service import MessageEvent + +from .dependencies import RunServiceDep +from .service import Run, RunEvent, RunListResponse, RunService + + +class CreateRunRequest(BaseModel): + stream: bool = False + + +router = APIRouter(prefix="/threads/{thread_id}/runs", tags=["runs"]) + + +@router.post("") +def create_run( + thread_id: Annotated[str, Path(...)], + request: Annotated[CreateRunRequest, Body(...)], + run_service: RunService = RunServiceDep, +) -> Response: + """ + Create a new run for a given thread. 
+ """ + stream = request.stream + run_or_async_generator = run_service.create(thread_id, stream) + if stream: + async_generator = cast( + "AsyncGenerator[RunEvent | MessageEvent, None]", run_or_async_generator + ) + + async def sse_event_stream() -> AsyncGenerator[str, None]: + async for event in async_generator: + yield f"event: {event.event}\ndata: {event.model_dump_json()}\n\n" + + return StreamingResponse( + status_code=status.HTTP_201_CREATED, + content=sse_event_stream(), + media_type="text/event-stream", + ) + run = cast("Run", run_or_async_generator) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=run.model_dump()) + + +@router.get("/{run_id}") +def retrieve_run( + run_id: Annotated[str, Path(...)], + run_service: RunService = RunServiceDep, +) -> Run: + """ + Retrieve a run by its ID. + """ + try: + return run_service.retrieve(run_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + +@router.get("") +def list_runs( + thread_id: Annotated[str, Path(...)], + run_service: RunService = RunServiceDep, +) -> RunListResponse: + """ + List runs, optionally filtered by thread. + """ + return run_service.list_(thread_id) + + +@router.post("/{run_id}/cancel") +def cancel_run( + run_id: Annotated[str, Path(...)], + run_service: RunService = RunServiceDep, +) -> Run: + """ + Cancel a run by its ID. 
+ """ + try: + return run_service.cancel(run_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py new file mode 100644 index 00000000..4dff82a9 --- /dev/null +++ b/src/askui/chat/api/runs/service.py @@ -0,0 +1,287 @@ +import asyncio +import queue +import threading +from collections.abc import AsyncGenerator +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Literal, Sequence, cast, overload + +from pydantic import AwareDatetime, BaseModel, Field, computed_field + +from askui.agent import VisionAgent +from askui.chat.api.messages.service import MessageEvent, MessageService +from askui.chat.api.models import Event +from askui.chat.api.utils import generate_time_ordered_id +from askui.models.shared.computer_agent_cb_param import OnMessageCbParam +from askui.models.shared.computer_agent_message_param import MessageParam + +RunStatus = Literal[ + "queued", + "in_progress", + "completed", + "cancelling", + "cancelled", + "failed", + "expired", +] + + +class RunError(BaseModel): + message: str + code: Literal["server_error"] + + +class Run(BaseModel): + id: str = Field(default_factory=lambda: generate_time_ordered_id("run")) + thread_id: str + created_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + started_at: AwareDatetime | None = None + completed_at: AwareDatetime | None = None + tried_cancelling_at: AwareDatetime | None = None + cancelled_at: AwareDatetime | None = None + expires_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + timedelta(minutes=10) + ) + failed_at: AwareDatetime | None = None + last_error: RunError | None = None + object: Literal["run"] = "run" + + @computed_field + @property + def status(self) -> RunStatus: + if self.cancelled_at: + return "cancelled" + if self.failed_at: + return "failed" + if 
self.completed_at: + return "completed" + if self.expires_at and self.expires_at < datetime.now(tz=timezone.utc): + return "expired" + if self.tried_cancelling_at: + return "cancelling" + if self.started_at: + return "in_progress" + return "queued" + + +class RunListResponse(BaseModel): + object: Literal["list"] = "list" + data: Sequence[Run] + first_id: str | None = None + last_id: str | None = None + has_more: bool = False + + +class RunEvent(Event): + data: Run + event: Literal[ + "run.created", + "run.started", + "run.completed", + "run.failed", + "run.cancelled", + "run.expired", + ] + + +class Runner: + def __init__(self, run: Run, base_dir: Path) -> None: + self._run = run + self._base_dir = base_dir + self._runs_dir = base_dir / "runs" + self._msg_service = MessageService(self._base_dir) + + def run(self, event_queue: queue.Queue[RunEvent | MessageEvent | None]) -> None: + self._mark_started() + event_queue.put( + RunEvent( + data=self._run, + event="run.started", + ) + ) + messages: list[MessageParam] = [ + cast("MessageParam", msg) + for msg in self._msg_service.list_(self._run.thread_id).data + ] + + def on_message( + on_message_cb_param: OnMessageCbParam, + ) -> MessageParam | None: + message = self._msg_service.create( + thread_id=self._run.thread_id, + message=on_message_cb_param.message, + ) + event_queue.put( + MessageEvent( + data=message, + event="message.created", + ) + ) + updated_run = self._retrieve_run() + if updated_run.status == "cancelling": + updated_run.cancelled_at = datetime.now(tz=timezone.utc) + self._update_run_file(updated_run) + event_queue.put( + RunEvent( + data=updated_run, + event="run.cancelled", + ) + ) + return None + if updated_run.status == "expired": + event_queue.put( + RunEvent( + data=updated_run, + event="run.expired", + ) + ) + return None + return on_message_cb_param.message + + try: + with VisionAgent() as agent: + agent.act(messages, on_message=on_message) + updated_run = self._retrieve_run() + if 
updated_run.status == "in_progress": + updated_run.completed_at = datetime.now(tz=timezone.utc) + self._update_run_file(updated_run) + event_queue.put( + RunEvent( + data=updated_run, + event="run.completed", + ) + ) + except Exception as e: # noqa: BLE001 + updated_run = self._retrieve_run() + updated_run.failed_at = datetime.now(tz=timezone.utc) + updated_run.last_error = RunError(message=str(e), code="server_error") + self._update_run_file(updated_run) + event_queue.put( + RunEvent( + data=updated_run, + event="run.failed", + ) + ) + finally: + event_queue.put(None) + + def _mark_started(self) -> None: + self._run.started_at = datetime.now(tz=timezone.utc) + self._update_run_file(self._run) + + def _should_abort(self, run: Run) -> bool: + return run.status in ("cancelled", "cancelling", "expired") + + def _update_run_file(self, run: Run) -> None: + run_file = self._runs_dir / f"{run.thread_id}__{run.id}.json" + with run_file.open("w") as f: + f.write(run.model_dump_json()) + + def _retrieve_run(self) -> Run: + run_file = self._runs_dir / f"{self._run.thread_id}__{self._run.id}.json" + with run_file.open("r") as f: + return Run.model_validate_json(f.read()) + + +class RunService: + """ + Service for managing runs. Handles creation, retrieval, listing, and cancellation of runs. + """ + + def __init__(self, base_dir: Path) -> None: + self._base_dir = base_dir + self._runs_dir = base_dir / "runs" + + def _run_path(self, thread_id: str, run_id: str) -> Path: + return self._runs_dir / f"{thread_id}__{run_id}.json" + + def _create_run(self, thread_id: str) -> Run: + run = Run(thread_id=thread_id) + self._runs_dir.mkdir(parents=True, exist_ok=True) + self._update_run_file(run) + return run + + @overload + def create(self, thread_id: str, stream: Literal[False]) -> Run: ... + + @overload + def create( + self, thread_id: str, stream: Literal[True] + ) -> AsyncGenerator[RunEvent | MessageEvent, None]: ... 
+ + @overload + def create( + self, thread_id: str, stream: bool + ) -> Run | AsyncGenerator[RunEvent | MessageEvent, None]: ... + + def create( + self, thread_id: str, stream: bool + ) -> Run | AsyncGenerator[RunEvent | MessageEvent, None]: + run = self._create_run(thread_id) + event_queue: queue.Queue[RunEvent | MessageEvent | None] = queue.Queue() + runner = Runner(run, self._base_dir) + thread = threading.Thread(target=runner.run, args=(event_queue,)) + thread.start() + if stream: + + async def event_stream() -> AsyncGenerator[RunEvent | MessageEvent, None]: + yield RunEvent( + data=run, + event="run.created", + ) + loop = asyncio.get_event_loop() + while True: + event = await loop.run_in_executor(None, event_queue.get) + if event is None: + break + yield event + + return event_stream() + return run + + def _update_run_file(self, run: Run) -> None: + run_file = self._run_path(run.thread_id, run.id) + with run_file.open("w") as f: + f.write(run.model_dump_json()) + + def retrieve(self, run_id: str) -> Run: + # Find the file by run_id + for f in self._runs_dir.glob(f"*__{run_id}.json"): + with f.open("r") as file: + return Run.model_validate_json(file.read()) + error_msg = f"Run {run_id} not found" + raise FileNotFoundError(error_msg) + + def list_(self, thread_id: str | None = None) -> RunListResponse: + if not self._runs_dir.exists(): + return RunListResponse(data=[]) + if thread_id: + run_files = list(self._runs_dir.glob(f"{thread_id}__*.json")) + else: + run_files = list(self._runs_dir.glob("*__*.json")) + runs: list[Run] = [] + for f in run_files: + with f.open("r") as file: + runs.append(Run.model_validate_json(file.read())) + runs = sorted(runs, key=lambda r: r.created_at, reverse=True) + return RunListResponse( + data=runs, + first_id=runs[0].id if runs else None, + last_id=runs[-1].id if runs else None, + has_more=False, + ) + + def cancel(self, run_id: str) -> Run: + run = self.retrieve(run_id) + if run.status in ("cancelled", "cancelling", "completed", 
"failed", "expired"): + return run + run.tried_cancelling_at = datetime.now(tz=timezone.utc) + for f in self._runs_dir.glob(f"*__{run_id}.json"): + with f.open("w") as file: + file.write(run.model_dump_json()) + return run + # Find the file by run_id + error_msg = f"Run {run_id} not found" + raise FileNotFoundError(error_msg) diff --git a/src/askui/chat/api/settings.py b/src/askui/chat/api/settings.py new file mode 100644 index 00000000..c091a6cf --- /dev/null +++ b/src/askui/chat/api/settings.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Settings for the chat API.""" + + model_config = SettingsConfigDict( + env_prefix="ASKUI__CHAT_API__", env_nested_delimiter="__" + ) + + data_dir: Path = Field( + default_factory=lambda: Path.cwd() / "chat", + description="Base directory for storing chat data", + ) diff --git a/src/askui/chat/api/threads/__init__.py b/src/askui/chat/api/threads/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/threads/dependencies.py b/src/askui/chat/api/threads/dependencies.py new file mode 100644 index 00000000..3d607559 --- /dev/null +++ b/src/askui/chat/api/threads/dependencies.py @@ -0,0 +1,13 @@ +from fastapi import Depends + +from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.settings import Settings +from askui.chat.api.threads.service import ThreadService + + +def get_thread_service(settings: Settings = SettingsDep) -> ThreadService: + """Get ThreadService instance.""" + return ThreadService(settings.data_dir) + + +ThreadServiceDep = Depends(get_thread_service) diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py new file mode 100644 index 00000000..f899863e --- /dev/null +++ b/src/askui/chat/api/threads/router.py @@ -0,0 +1,47 @@ +from fastapi import APIRouter, HTTPException, status + +from 
askui.chat.api.threads.dependencies import ThreadServiceDep +from askui.chat.api.threads.service import Thread, ThreadListResponse, ThreadService + +router = APIRouter(prefix="/threads", tags=["threads"]) + + +@router.get("") +def list_threads( + limit: int | None = None, + thread_service: ThreadService = ThreadServiceDep, +) -> ThreadListResponse: + """List all threads.""" + return thread_service.list_(limit=limit) + + +@router.post("", status_code=status.HTTP_201_CREATED) +def create_thread( + thread_service: ThreadService = ThreadServiceDep, +) -> Thread: + """Create a new thread.""" + return thread_service.create() + + +@router.get("/{thread_id}") +def retrieve_thread( + thread_id: str, + thread_service: ThreadService = ThreadServiceDep, +) -> Thread: + """Get a thread by ID.""" + try: + return thread_service.retrieve(thread_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + +@router.delete("/{thread_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_thread( + thread_id: str, + thread_service: ThreadService = ThreadServiceDep, +) -> None: + """Delete a thread.""" + try: + thread_service.delete(thread_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py new file mode 100644 index 00000000..5e4e7b4d --- /dev/null +++ b/src/askui/chat/api/threads/service.py @@ -0,0 +1,132 @@ +from datetime import datetime, timezone +from pathlib import Path +from typing import Sequence + +from pydantic import AwareDatetime, BaseModel, Field + +from askui.chat.api.utils import generate_time_ordered_id + + +class Thread(BaseModel): + """A chat thread/session.""" + + id: str = Field(default_factory=lambda: generate_time_ordered_id("thread")) + created_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + object: str = "thread" + + +class 
ThreadListResponse(BaseModel): + """Response model for listing threads.""" + + object: str = "list" + data: Sequence[Thread] + first_id: str | None = None + last_id: str | None = None + has_more: bool = False + + +class ThreadService: + """Service for managing chat threads/sessions.""" + + def __init__(self, base_dir: Path) -> None: + """Initialize thread service. + + Args: + base_dir: Base directory to store thread data + """ + self._base_dir = base_dir + self._threads_dir = base_dir / "threads" + + def list_(self, limit: int | None = None) -> ThreadListResponse: + """List all available threads. + + Args: + limit: Optional maximum number of threads to return + + Returns: + ThreadListResponse containing threads sorted by creation date (newest first) + """ + if not self._threads_dir.exists(): + return ThreadListResponse(data=[]) + + thread_files = list(self._threads_dir.glob("*.jsonl")) + threads: list[Thread] = [] + for f in thread_files: + thread_id = f.stem + created_at = datetime.fromtimestamp(f.stat().st_ctime, tz=timezone.utc) + threads.append( + Thread( + id=thread_id, + created_at=created_at, + ) + ) + + # Sort by creation date, newest first + threads = sorted(threads, key=lambda t: t.created_at, reverse=True) + + # Apply limit if specified + if limit is not None: + threads = threads[:limit] + + return ThreadListResponse( + data=threads, + first_id=threads[0].id if threads else None, + last_id=threads[-1].id if threads else None, + has_more=len(thread_files) > (limit or len(thread_files)), + ) + + def create(self) -> Thread: + """Create a new thread. + + Returns: + Created thread object + """ + thread = Thread() + thread_file = self._threads_dir / f"{thread.id}.jsonl" + self._threads_dir.mkdir(parents=True, exist_ok=True) + thread_file.touch() + return thread + + def retrieve(self, thread_id: str) -> Thread: + """Retrieve a thread by ID. 
+ + Args: + thread_id: ID of thread to retrieve + + Returns: + Thread object + + Raises: + FileNotFoundError: If thread doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + created_at = datetime.fromtimestamp( + thread_file.stat().st_ctime, tz=timezone.utc + ) + return Thread( + id=thread_id, + created_at=created_at, + ) + + def delete(self, thread_id: str) -> None: + """Delete a thread and all its associated files. + + Args: + thread_id: ID of thread to delete + + Raises: + FileNotFoundError: If thread doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + # Delete thread file + thread_file.unlink() diff --git a/src/askui/chat/api/utils.py b/src/askui/chat/api/utils.py new file mode 100644 index 00000000..d256b794 --- /dev/null +++ b/src/askui/chat/api/utils.py @@ -0,0 +1,24 @@ +import base64 +import os +import time + + +def generate_time_ordered_id(prefix: str) -> str: + """Generate a time-ordered ID with format: prefix_timestamp_random. + + Args: + prefix: Prefix for the ID (e.g. 
'thread', 'msg') + + Returns: + Time-ordered ID string + """ + # Get current timestamp in milliseconds + timestamp = int(time.time() * 1000) + # Convert to base32 for shorter string (removing padding) + timestamp_b32 = ( + base64.b32encode(str(timestamp).encode()).decode().rstrip("=").lower() + ) + # Get 12 random bytes and convert to base32 (removing padding) + random_bytes = os.urandom(12) + random_b32 = base64.b32encode(random_bytes).decode().rstrip("=").lower() + return f"{prefix}_{timestamp_b32}{random_b32}" diff --git a/src/askui/mcp/__init__.py b/src/askui/mcp/__init__.py new file mode 100644 index 00000000..c23dd663 --- /dev/null +++ b/src/askui/mcp/__init__.py @@ -0,0 +1,22 @@ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass + +from mcp.server.fastmcp import FastMCP + +from askui.agent import VisionAgent + + +@dataclass +class AppContext: + vision_agent: VisionAgent + + +@asynccontextmanager +async def mcp_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # noqa: ARG001 + with VisionAgent(display=2) as vision_agent: + server.add_tool(vision_agent.click) + yield AppContext(vision_agent=vision_agent) + + +mcp = FastMCP("Vision Agent MCP App", lifespan=mcp_lifespan) diff --git a/src/askui/models/__init__.py b/src/askui/models/__init__.py index fada0c87..ec531fb0 100644 --- a/src/askui/models/__init__.py +++ b/src/askui/models/__init__.py @@ -8,18 +8,48 @@ ModelDefinition, ModelName, ModelRegistry, + OnMessageCb, Point, ) +from .shared.computer_agent_message_param import ( + Base64ImageSourceParam, + CacheControlEphemeralParam, + CitationCharLocationParam, + CitationContentBlockLocationParam, + CitationPageLocationParam, + ContentBlockParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + TextCitationParam, + ToolResultBlockParam, + ToolUseBlockParam, + UrlImageSourceParam, +) __all__ = [ "ActModel", + "Base64ImageSourceParam", + "CacheControlEphemeralParam", + 
"CitationCharLocationParam", + "CitationContentBlockLocationParam", + "CitationPageLocationParam", + "ContentBlockParam", "GetModel", + "ImageBlockParam", "LocateModel", + "MessageParam", "Model", "ModelChoice", "ModelComposition", "ModelDefinition", "ModelName", "ModelRegistry", + "OnMessageCb", "Point", + "TextBlockParam", + "TextCitationParam", + "ToolResultBlockParam", + "ToolUseBlockParam", + "UrlImageSourceParam", ] diff --git a/src/askui/models/anthropic/computer_agent.py b/src/askui/models/anthropic/computer_agent.py index 94600691..f7dc7ea2 100644 --- a/src/askui/models/anthropic/computer_agent.py +++ b/src/askui/models/anthropic/computer_agent.py @@ -1,377 +1,45 @@ -import platform -import sys -from datetime import datetime, timezone -from typing import Any, cast +from typing import TYPE_CHECKING, cast -from anthropic import Anthropic, APIError, APIResponseValidationError, APIStatusError -from anthropic.types.beta import ( - BetaCacheControlEphemeralParam, - BetaImageBlockParam, - BetaMessage, - BetaMessageParam, - BetaTextBlock, - BetaTextBlockParam, - BetaToolResultBlockParam, - BetaToolUseBlockParam, -) +from anthropic import Anthropic from typing_extensions import override from askui.models.anthropic.settings import ClaudeComputerAgentSettings -from askui.models.models import ANTHROPIC_MODEL_NAME_MAPPING, ActModel, ModelName +from askui.models.models import ANTHROPIC_MODEL_NAME_MAPPING, ModelName +from askui.models.shared.computer_agent import ComputerAgent +from askui.models.shared.computer_agent_message_param import MessageParam from askui.reporting import Reporter from askui.tools.agent_os import AgentOs -from ...logger import logger -from ...tools.anthropic import ComputerTool, ToolCollection, ToolResult -from ...utils.str_utils import truncate_long_strings +if TYPE_CHECKING: + from anthropic.types.beta import BetaMessageParam -PC_KEY = [ - "backspace", - "delete", - "enter", - "tab", - "escape", - "up", - "down", - "right", - "left", - "home", - 
"end", - "pageup", - "pagedown", - "f1", - "f2", - "f3", - "f4", - "f5", - "f6", - "f7", - "f8", - "f9", - "f10", - "f11", - "f12", - "space", - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "J", - "K", - "L", - "M", - "N", - "O", - "P", - "Q", - "R", - "S", - "T", - "U", - "V", - "W", - "X", - "Y", - "Z", - "!", - '"', - "#", - "$", - "%", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "<", - "=", - ">", - "?", - "@", - "[", - "\\", - "]", - "^", - "_", - "`", - "{", - "|", - "}", - "~", -] - -SYSTEM_PROMPT = f""" -* You are utilising a {sys.platform} machine using {platform.machine()} architecture with internet access. -* When asked to perform web tasks try to open the browser (firefox, chrome, safari, ...) if not already open. Often you can find the browser icons in the toolbars of the operating systems. -* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. -* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. -* Valid keyboard keys available are {", ".join(PC_KEY)} -* The current date is {datetime.now(tz=timezone.utc).strftime("%A, %B %d, %Y").replace(" 0", " ")}. - - - -* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. 
-* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. -""" - - -class ClaudeComputerAgent(ActModel): +class ClaudeComputerAgent(ComputerAgent[ClaudeComputerAgentSettings]): def __init__( self, agent_os: AgentOs, reporter: Reporter, settings: ClaudeComputerAgentSettings, ) -> None: - self._settings = settings + super().__init__(settings, agent_os, reporter) self._client = Anthropic( api_key=self._settings.anthropic.api_key.get_secret_value() ) - self._reporter = reporter - self._tool_collection = ToolCollection( - ComputerTool(agent_os), - ) - self._system = BetaTextBlockParam( - type="text", - text=f"{SYSTEM_PROMPT}", - ) - - def step( - self, messages: list[BetaMessageParam], model: str - ) -> list[BetaMessageParam]: - if self._settings.only_n_most_recent_images: - self._maybe_filter_to_n_most_recent_images( - messages, - self._settings.only_n_most_recent_images, - min_removal_threshold=self._settings.image_truncation_threshold, - ) - - try: - raw_response = self._client.beta.messages.with_raw_response.create( - max_tokens=self._settings.max_tokens, - messages=messages, - model=model, - system=[self._system], - tools=self._tool_collection.to_params(), - betas=self._settings.betas, - ) - except (APIStatusError, APIResponseValidationError) as e: - logger.error(e) - return messages - except APIError as e: - logger.error(e) - return messages - - response = raw_response.parse() - response_params = self._response_to_params(response) - new_message: BetaMessageParam = { - "role": "assistant", - "content": response_params, - } - logger.debug(new_message) - messages.append(new_message) - self._reporter.add_message("Anthropic Computer 
Use", response_params) - - tool_result_content: list[BetaToolResultBlockParam] = [] - for content_block in response_params: - if content_block["type"] == "tool_use": - result = self._tool_collection.run( - name=content_block["name"], - tool_input=cast("dict[str, Any]", content_block["input"]), - ) - tool_result_content.append( - self._make_api_tool_result(result, content_block["id"]) - ) - if len(tool_result_content) > 0: - another_new_message = {"content": tool_result_content, "role": "user"} - logger.debug(truncate_long_strings(another_new_message, max_length=200)) - messages.append(cast("BetaMessageParam", another_new_message)) - return messages @override - def act(self, goal: str, model_choice: str) -> None: - messages: list[BetaMessageParam] = [{"role": "user", "content": goal}] - logger.debug(messages[0]) - while messages[-1]["role"] == "user": - messages = self.step( - messages=messages, - model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)], - ) - - @staticmethod - def _maybe_filter_to_n_most_recent_images( - messages: list[BetaMessageParam], - images_to_keep: int | None, - min_removal_threshold: int, - ) -> list[BetaMessageParam] | None: - """ - With the assumption that images are screenshots that are of diminishing value as - the conversation progresses, remove all but the final `images_to_keep` tool_result - images in place, with a chunk of min_removal_threshold to reduce the amount we - break the implicit prompt cache. 
- """ - if images_to_keep is None: - return messages - - tool_result_blocks = cast( - "list[BetaToolResultBlockParam]", - [ - item + def _create_message( + self, messages: list[MessageParam], model_choice: str + ) -> MessageParam: + response = self._client.beta.messages.with_raw_response.create( + max_tokens=self._settings.max_tokens, + messages=[ + cast("BetaMessageParam", message.model_dump(mode="json")) for message in messages - for item in ( - message["content"] if isinstance(message["content"], list) else [] - ) - if isinstance(item, dict) and item.get("type") == "tool_result" ], + model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)], + system=[self._system], + tools=self._tool_collection.to_params(), + betas=self._settings.betas, ) - total_images = sum( - 1 - for tool_result in tool_result_blocks - for content in tool_result.get("content", []) - if isinstance(content, dict) and content.get("type") == "image" - ) - images_to_remove = total_images - images_to_keep - # for better cache behavior, we want to remove in chunks - images_to_remove -= images_to_remove % min_removal_threshold - for tool_result in tool_result_blocks: - if isinstance(tool_result.get("content"), list): - new_content = [] - for content in tool_result.get("content", []): - if isinstance(content, dict) and content.get("type") == "image": - if images_to_remove > 0: - images_to_remove -= 1 - continue - new_content.append(content) - tool_result["content"] = new_content # type: ignore - return None - - @staticmethod - def _response_to_params( - response: BetaMessage, - ) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: - res: list[BetaTextBlockParam | BetaToolUseBlockParam] = [] - for block in response.content: - if isinstance(block, BetaTextBlock): - res.append({"type": "text", "text": block.text}) - else: - res.append(cast("BetaToolUseBlockParam", block.model_dump())) - return res - - @staticmethod - def _inject_prompt_caching( - messages: list[BetaMessageParam], - ) -> None: - """ - 
Set cache breakpoints for the 3 most recent turns - one cache breakpoint is left for tools/system prompt, to be shared - across sessions - """ - - breakpoints_remaining = 3 - for message in reversed(messages): - if message["role"] == "user" and isinstance( - content := message["content"], list - ): - if breakpoints_remaining: - breakpoints_remaining -= 1 - content[-1]["cache_control"] = BetaCacheControlEphemeralParam( - {"type": "ephemeral"} - ) - else: - content[-1].pop("cache_control", None) - # we'll only every have one extra turn per loop - break - - def _make_api_tool_result( - self, result: ToolResult, tool_use_id: str - ) -> BetaToolResultBlockParam: - """Convert an agent ToolResult to an API ToolResultBlockParam.""" - tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = [] - is_error = False - if result.error: - is_error = True - tool_result_content = self._maybe_prepend_system_tool_result( - result, result.error - ) - else: - assert isinstance(tool_result_content, list) - if result.output: - tool_result_content.append( - { - "type": "text", - "text": self._maybe_prepend_system_tool_result( - result, result.output - ), - } - ) - if result.base64_image: - tool_result_content.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": result.base64_image, - }, - } - ) - return { - "type": "tool_result", - "content": tool_result_content, - "tool_use_id": tool_use_id, - "is_error": is_error, - } - - @staticmethod - def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: - if result.system: - result_text = f"{result.system}\n{result_text}" - return result_text + parsed_response = response.parse() + return MessageParam.model_validate(parsed_response.model_dump()) diff --git a/src/askui/models/anthropic/facade.py b/src/askui/models/anthropic/facade.py deleted file mode 100644 index 7188867d..00000000 --- a/src/askui/models/anthropic/facade.py +++ /dev/null @@ -1,43 +0,0 
@@ -from typing import Type - -from typing_extensions import override - -from askui.locators.locators import Locator -from askui.models.anthropic.computer_agent import ClaudeComputerAgent -from askui.models.anthropic.handler import ClaudeHandler -from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point -from askui.models.types.response_schemas import ResponseSchema -from askui.utils.image_utils import ImageSource - - -class AnthropicFacade(ActModel, GetModel, LocateModel): - def __init__( - self, - computer_agent: ClaudeComputerAgent, - handler: ClaudeHandler, - ) -> None: - self._computer_agent = computer_agent - self._handler = handler - - @override - def act(self, goal: str, model_choice: str) -> None: - self._computer_agent.act(goal, model_choice) - - @override - def get( - self, - query: str, - image: ImageSource, - response_schema: Type[ResponseSchema] | None, - model_choice: str, - ) -> ResponseSchema | str: - return self._handler.get(query, image, response_schema, model_choice) - - @override - def locate( - self, - locator: str | Locator, - image: ImageSource, - model_choice: ModelComposition | str, - ) -> Point: - return self._handler.locate(locator, image, model_choice) diff --git a/src/askui/models/anthropic/settings.py b/src/askui/models/anthropic/settings.py index 2fe56d54..e804495d 100644 --- a/src/askui/models/anthropic/settings.py +++ b/src/askui/models/anthropic/settings.py @@ -1,6 +1,8 @@ from pydantic import BaseModel, Field, SecretStr from pydantic_settings import BaseSettings +from askui.models.shared.computer_agent import ComputerAgentSettingsBase + COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" @@ -21,8 +23,5 @@ class ClaudeSettings(ClaudeSettingsBase): temperature: float = 0.0 -class ClaudeComputerAgentSettings(ClaudeSettingsBase): - max_tokens: int = 4096 - only_n_most_recent_images: int = 3 - image_truncation_threshold: int = 10 - betas: list[str] = Field(default_factory=lambda: [COMPUTER_USE_BETA_FLAG]) 
+class ClaudeComputerAgentSettings(ComputerAgentSettingsBase, ClaudeSettingsBase): + pass diff --git a/src/askui/models/askui/computer_agent.py b/src/askui/models/askui/computer_agent.py index 63707cc2..42073abf 100644 --- a/src/askui/models/askui/computer_agent.py +++ b/src/askui/models/askui/computer_agent.py @@ -1,170 +1,15 @@ -import platform -import sys -from datetime import datetime -from typing import Any, cast - import httpx -from anthropic.types.beta import ( - BetaImageBlockParam, - BetaMessage, - BetaMessageParam, - BetaTextBlock, - BetaTextBlockParam, - BetaToolResultBlockParam, - BetaToolUseBlockParam, -) from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential from typing_extensions import override from askui.models.askui.settings import AskUiComputerAgentSettings -from askui.models.models import ActModel +from askui.models.shared.computer_agent import ComputerAgent +from askui.models.shared.computer_agent_message_param import MessageParam from askui.reporting import Reporter from askui.tools.agent_os import AgentOs -from askui.tools.anthropic import ComputerTool, ToolCollection, ToolResult -from askui.utils.str_utils import truncate_long_strings from ...logger import logger -COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" -PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" - -PC_KEY = [ - "backspace", - "delete", - "enter", - "tab", - "escape", - "up", - "down", - "right", - "left", - "home", - "end", - "pageup", - "pagedown", - "f1", - "f2", - "f3", - "f4", - "f5", - "f6", - "f7", - "f8", - "f9", - "f10", - "f11", - "f12", - "space", - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "J", - "K", - "L", - "M", - "N", - "O", - "P", - "Q", - "R", - "S", - 
"T", - "U", - "V", - "W", - "X", - "Y", - "Z", - "!", - '"', - "#", - "$", - "%", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "<", - "=", - ">", - "?", - "@", - "[", - "\\", - "]", - "^", - "_", - "`", - "{", - "|", - "}", - "~", -] - -SYSTEM_PROMPT = f""" -* You are utilising a {sys.platform} machine using {platform.machine()} architecture with internet access. -* When asked to perform web tasks try to open the browser (firefox, chrome, safari, ...) if not already open. Often you can find the browser icons in the toolbars of the operating systems. -* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. -* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. -* Valid keyboard keys available are {", ".join(PC_KEY)} -* The current date is {datetime.today().strftime("%A, %B %d, %Y").replace(" 0", " ")}. - - - -* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. -* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. 
-""" # noqa: DTZ002, E501 - def is_retryable_error(exception: BaseException) -> bool: """Check if the exception is a retryable error (status codes 429 or 529).""" @@ -173,19 +18,14 @@ def is_retryable_error(exception: BaseException) -> bool: return False -class AskUiComputerAgent(ActModel): +class AskUiComputerAgent(ComputerAgent[AskUiComputerAgentSettings]): def __init__( self, agent_os: AgentOs, reporter: Reporter, settings: AskUiComputerAgentSettings, ) -> None: - self._settings = settings - self._reporter = reporter - self._tool_collection = ToolCollection( - ComputerTool(agent_os), - ) - self._system = SYSTEM_PROMPT + super().__init__(settings, agent_os, reporter) self._client = httpx.Client( base_url=f"{self._settings.askui.base_url}", headers={ @@ -200,171 +40,28 @@ def __init__( retry=retry_if_exception(is_retryable_error), reraise=True, ) - def step(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]: - if self._settings.only_n_most_recent_images: - self._maybe_filter_to_n_most_recent_images( - messages, - self._settings.only_n_most_recent_images, - min_removal_threshold=self._settings.image_truncation_threshold, - ) - + @override + def _create_message( + self, + messages: list[MessageParam], + model_choice: str, # noqa: ARG002 + ) -> MessageParam: try: request_body = { "max_tokens": self._settings.max_tokens, - "messages": messages, + "messages": [msg.model_dump(mode="json") for msg in messages], "model": self._settings.model, "tools": self._tool_collection.to_params(), "betas": self._settings.betas, - "system": [{"type": "text", "text": self._system}], + "system": [self._system], } - logger.debug(request_body) response = self._client.post( "/act/inference", json=request_body, timeout=300.0 ) response.raise_for_status() response_data = response.json() - beta_message = BetaMessage.model_validate(response_data) + return MessageParam.model_validate(response_data) except Exception as e: # noqa: BLE001 if is_retryable_error(e): logger.debug(e) 
raise - - response_params = self._response_to_params(beta_message) - new_message: BetaMessageParam = { - "role": "assistant", - "content": response_params, - } - logger.debug(new_message) - messages.append(new_message) - if self._reporter is not None: - self._reporter.add_message("AskUI Computer Use", response_params) - - tool_result_content: list[BetaToolResultBlockParam] = [] - for content_block in response_params: - if content_block["type"] == "tool_use": - result = self._tool_collection.run( - name=content_block["name"], - tool_input=cast("dict[str, Any]", content_block["input"]), - ) - tool_result_content.append( - self._make_api_tool_result(result, content_block["id"]) - ) - if len(tool_result_content) > 0: - another_new_message: BetaMessageParam = { - "content": tool_result_content, - "role": "user", - } - logger.debug( - truncate_long_strings(dict(another_new_message), max_length=200) - ) - messages.append(another_new_message) - return messages - - @override - def act(self, goal: str, model_choice: str) -> None: - messages: list[BetaMessageParam] = [{"role": "user", "content": goal}] - logger.debug(messages[0]) - while messages[-1]["role"] == "user": - messages = self.step(messages) - - @staticmethod - def _maybe_filter_to_n_most_recent_images( - messages: list, - images_to_keep: int | None, - min_removal_threshold: int, - ) -> list | None: - """ - With the assumption that images are screenshots that are of diminishing value as - the conversation progresses, remove all but the final `images_to_keep` tool_result - images in place, with a chunk of min_removal_threshold to reduce the amount we - break the implicit prompt cache. 
- """ # noqa: E501 - if images_to_keep is None: - return messages - - tool_result_blocks = [ - item - for message in messages - for item in ( - message["content"] if isinstance(message["content"], list) else [] - ) - if isinstance(item, dict) and item.get("type") == "tool_result" - ] - total_images = sum( - 1 - for tool_result in tool_result_blocks - for content in tool_result.get("content", []) - if isinstance(content, dict) and content.get("type") == "image" - ) - images_to_remove = total_images - images_to_keep - # for better cache behavior, we want to remove in chunks - images_to_remove -= images_to_remove % min_removal_threshold - for tool_result in tool_result_blocks: - if isinstance(tool_result.get("content"), list): - new_content = [] - for content in tool_result.get("content", []): - if isinstance(content, dict) and content.get("type") == "image": - if images_to_remove > 0: - images_to_remove -= 1 - continue - new_content.append(content) - tool_result["content"] = new_content - return None - - @staticmethod - def _response_to_params( - response: BetaMessage, - ) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: - res: list[BetaTextBlockParam | BetaToolUseBlockParam] = [] - for block in response.content: - if isinstance(block, BetaTextBlock): - res.append({"type": "text", "text": block.text}) - else: - res.append(cast("BetaToolUseBlockParam", block.model_dump())) - return res - - def _make_api_tool_result( - self, result: ToolResult, tool_use_id: str - ) -> BetaToolResultBlockParam: - """Convert an agent ToolResult to an API ToolResultBlockParam.""" - tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = [] - is_error = False - if result.error: - is_error = True - tool_result_content = self._maybe_prepend_system_tool_result( - result, result.error - ) - else: - assert isinstance(tool_result_content, list) - if result.output: - tool_result_content.append( - { - "type": "text", - "text": self._maybe_prepend_system_tool_result( - 
result, result.output - ), - } - ) - if result.base64_image: - tool_result_content.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": result.base64_image, - }, - } - ) - return { - "type": "tool_result", - "content": tool_result_content, - "tool_use_id": tool_use_id, - "is_error": is_error, - } - - @staticmethod - def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: - if result.system: - result_text = f"{result.system}\n{result_text}" - return result_text diff --git a/src/askui/models/askui/facade.py b/src/askui/models/askui/facade.py deleted file mode 100644 index a682522e..00000000 --- a/src/askui/models/askui/facade.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Type - -from typing_extensions import override - -from askui.locators.locators import Locator -from askui.models.askui.computer_agent import AskUiComputerAgent -from askui.models.askui.inference_api import AskUiInferenceApi -from askui.models.askui.model_router import AskUiModelRouter -from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point -from askui.models.types.response_schemas import ResponseSchema -from askui.utils.image_utils import ImageSource - - -class AskUiFacade(ActModel, GetModel, LocateModel): - def __init__( - self, - computer_agent: AskUiComputerAgent, - inference_api: AskUiInferenceApi, - model_router: AskUiModelRouter, - ) -> None: - self._computer_agent = computer_agent - self._inference_api = inference_api - self._model_router = model_router - - @override - def act(self, goal: str, model_choice: str) -> None: - self._computer_agent.act(goal, model_choice) - - @override - def get( - self, - query: str, - image: ImageSource, - response_schema: Type[ResponseSchema] | None, - model_choice: str, - ) -> ResponseSchema | str: - return self._inference_api.get(query, image, response_schema, model_choice) - - @override - def locate( - self, - locator: str | Locator, - 
image: ImageSource, - model_choice: ModelComposition | str, - ) -> Point: - return self._model_router.locate(locator, image, model_choice) diff --git a/src/askui/models/askui/settings.py b/src/askui/models/askui/settings.py index 649aff4e..b018b40b 100644 --- a/src/askui/models/askui/settings.py +++ b/src/askui/models/askui/settings.py @@ -1,12 +1,11 @@ import base64 from functools import cached_property -from pydantic import UUID4, BaseModel, Field, HttpUrl, SecretStr +from pydantic import UUID4, Field, HttpUrl, SecretStr from pydantic_settings import BaseSettings from askui.models.models import ModelName - -COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" +from askui.models.shared.computer_agent import ComputerAgentSettingsBase class AskUiSettings(BaseSettings): @@ -37,13 +36,6 @@ def base_url(self) -> str: return f"{self.inference_endpoint}api/v1/workspaces/{self.workspace_id}" -class AskUiComputerAgentSettingsBase(BaseModel): +class AskUiComputerAgentSettings(ComputerAgentSettingsBase): model: str = ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 askui: AskUiSettings = Field(default_factory=AskUiSettings) - - -class AskUiComputerAgentSettings(AskUiComputerAgentSettingsBase): - max_tokens: int = 4096 - only_n_most_recent_images: int = 3 - image_truncation_threshold: int = 10 - betas: list[str] = Field(default_factory=lambda: [COMPUTER_USE_BETA_FLAG]) diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 10ebd036..fbc49fe1 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -6,7 +6,6 @@ from askui.exceptions import ModelNotFoundError, ModelTypeMismatchError from askui.locators.locators import Locator from askui.locators.serializers import AskUiLocatorSerializer, VlmLocatorSerializer -from askui.models.anthropic.facade import AnthropicFacade from askui.models.anthropic.settings import ( AnthropicSettings, ClaudeComputerAgentSettings, @@ -14,7 +13,6 @@ ) from 
askui.models.askui.ai_element_utils import AiElementCollection from askui.models.askui.computer_agent import AskUiComputerAgent -from askui.models.askui.facade import AskUiFacade from askui.models.askui.model_router import AskUiModelRouter from askui.models.askui.settings import AskUiComputerAgentSettings from askui.models.huggingface.spaces_api import HFSpacesHandler @@ -29,6 +27,9 @@ ModelRegistry, Point, ) +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam +from askui.models.shared.facade import ModelFacade from askui.models.types.response_schemas import ResponseSchema from askui.reporting import CompositeReporter, Reporter from askui.tools.toolbox import AgentToolbox @@ -70,7 +71,7 @@ def vlm_locator_serializer() -> VlmLocatorSerializer: return VlmLocatorSerializer() @functools.cache - def anthropic_facade() -> AnthropicFacade: + def anthropic_facade() -> ModelFacade: settings = AnthropicSettings() computer_agent = ClaudeComputerAgent( agent_os=tools.os, @@ -85,13 +86,14 @@ def anthropic_facade() -> AnthropicFacade: ), locator_serializer=vlm_locator_serializer(), ) - return AnthropicFacade( - computer_agent=computer_agent, - handler=handler, + return ModelFacade( + act_model=computer_agent, + get_model=handler, + locate_model=handler, ) @functools.cache - def askui_facade() -> AskUiFacade: + def askui_facade() -> ModelFacade: computer_agent = AskUiComputerAgent( agent_os=tools.os, reporter=reporter, @@ -99,10 +101,10 @@ def askui_facade() -> AskUiFacade: askui=askui_settings(), ), ) - return AskUiFacade( - computer_agent=computer_agent, - inference_api=askui_inference_api(), - model_router=askui_model_router(), + return ModelFacade( + act_model=computer_agent, + get_model=askui_inference_api(), + locate_model=askui_model_router(), ) @functools.cache @@ -180,10 +182,15 @@ def _get_model( return model - def act(self, goal: str, model_choice: str) -> None: + def act( + 
self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: m = self._get_model(model_choice, "act") logger.debug(f'Routing "act" to model "{model_choice}"') - return m.act(goal, model_choice) + return m.act(messages, model_choice, on_message) def get( self, diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 80a6e1fe..1c612d8c 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -8,6 +8,8 @@ from typing_extensions import Literal, TypedDict from askui.locators.locators import Locator +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.models.types.response_schemas import ResponseSchema from askui.utils.image_utils import ImageSource @@ -147,34 +149,74 @@ def __getitem__(self, index: int) -> ModelDefinition: class ActModel(abc.ABC): """Abstract base class for models that can execute autonomous actions. - Models implementing this interface can be used with the `act()` method of - `VisionAgent` - to achieve goals through autonomous actions. These models analyze the screen and - determine necessary steps to accomplish a given goal. + Models implementing this interface can be used with the `VisionAgent.act()`. 
Example: ```python - from askui import ActModel, VisionAgent + from askui import ( + ActModel, + MessageParam, + OnMessageCb, + VisionAgent, + ) + from typing_extensions import override class MyActModel(ActModel): - def act(self, goal: str, model_choice: str) -> None: - # Implement custom act logic - pass + @override + def act( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: + pass # implement action logic here with VisionAgent(models={"my-act": MyActModel()}) as agent: agent.act("search for flights", model="my-act") - ``` """ @abc.abstractmethod - def act(self, goal: str, model_choice: str) -> None: - """Execute autonomous actions to achieve a goal. + def act( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: + """ + Execute autonomous actions to achieve a goal, using a message history + and optional callbacks, encoded in the messages. In the simplest case, + it can be found in the first message `messages[0].content` as a `str`. + + The `messages` usually start with a `"user"` (role) message which is followed by + alternating `"assistant"` (AI agent) and `"user"` messages (which can be + automatic tool use, e.g., taking a screenshot) similar how you would + expect it from a conversation whereby the `"assistant"` determines the next + actions which are then automatically taking by the `"user"` programmatically + until it eventually returns, usually with an `"assistant"` message that either + says that the goal has been achieved or that it failed to achieve the goal. Args: - goal (str): A description of what the model should achieve - model_choice (str): The name of the model being used (useful for models that - support multiple configurations) - """ + messages (list[MessageParam]): The message history to start that + determines the actions and following messages. 
+ model_choice (str): The name of the model being used, e.g., useful for + models registered under multiple keys, e.g., `"my-act-1"` and + `"my-act-2"` that depending on the key (passed as `model_choice`) + behave differently. + on_message (OnMessageCb | None, optional): Callback for new messages + from either an assistant/agent or a user (including + automatic/programmatic tool use, e.g., taking a screenshot). + If it returns `None`, the acting is canceled and `act()` returns + immediately. If it returns a `MessageParam`, this `MessageParma` is + added to the message history and the acting continues based on the + message. The message may be modified by the callback to allow for + directing the assistant/agent or tool use. + + Returns: + None + + Raises: + NotImplementedError: If the method is not implemented. + """ # noqa: E501 raise NotImplementedError diff --git a/src/askui/models/shared/__init__.py b/src/askui/models/shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/models/shared/computer_agent.py b/src/askui/models/shared/computer_agent.py new file mode 100644 index 00000000..0c4a74be --- /dev/null +++ b/src/askui/models/shared/computer_agent.py @@ -0,0 +1,442 @@ +import platform +import sys +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from typing import Generic + +from anthropic.types.beta import BetaTextBlockParam +from pydantic import BaseModel, Field +from typing_extensions import TypeVar, override + +from askui.models.models import ActModel +from askui.models.shared.computer_agent_cb_param import OnMessageCb, OnMessageCbParam +from askui.models.shared.computer_agent_message_param import ( + Base64ImageSourceParam, + ContentBlockParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, +) +from askui.reporting import Reporter +from askui.tools.agent_os import AgentOs +from askui.tools.anthropic import ComputerTool, ToolCollection, ToolResult + +from ...logger 
import logger + +COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" + +PC_KEY = [ + "backspace", + "delete", + "enter", + "tab", + "escape", + "up", + "down", + "right", + "left", + "home", + "end", + "pageup", + "pagedown", + "f1", + "f2", + "f3", + "f4", + "f5", + "f6", + "f7", + "f8", + "f9", + "f10", + "f11", + "f12", + "space", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", + "Q", + "R", + "S", + "T", + "U", + "V", + "W", + "X", + "Y", + "Z", + "!", + '"', + "#", + "$", + "%", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "<", + "=", + ">", + "?", + "@", + "[", + "\\", + "]", + "^", + "_", + "`", + "{", + "|", + "}", + "~", +] + +SYSTEM_PROMPT = f""" +* You are utilising a {sys.platform} machine using {platform.machine()} architecture with internet access. +* When asked to perform web tasks try to open the browser (firefox, chrome, safari, ...) if not already open. Often you can find the browser icons in the toolbars of the operating systems. +* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. +* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. +* Valid keyboard keys available are {", ".join(PC_KEY)} +* The current date is {datetime.now(timezone.utc).strftime("%A, %B %d, %Y").replace(" 0", " ")}. + + + +* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". 
Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. +* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. +""" # noqa: DTZ002, E501 + + +class ComputerAgentSettingsBase(BaseModel): + """Settings for computer agents.""" + + max_tokens: int = 4096 + only_n_most_recent_images: int = 3 + image_truncation_threshold: int = 10 + betas: list[str] = Field(default_factory=lambda: [COMPUTER_USE_BETA_FLAG]) + + +ComputerAgentSettings = TypeVar( + "ComputerAgentSettings", bound=ComputerAgentSettingsBase +) + + +class ComputerAgent(ActModel, ABC, Generic[ComputerAgentSettings]): + """Base class for computer agents that can execute autonomous actions. + + This class provides common functionality for both AskUI and Anthropic + computer agents, + including tool handling, message processing, and image filtering. + """ + + def __init__( + self, + settings: ComputerAgentSettings, + agent_os: AgentOs, + reporter: Reporter, + ) -> None: + """Initialize the computer agent. + + Args: + settings (ComputerAgentSettings): The settings for the computer agent. + agent_os (AgentOs): The operating system agent for executing commands. + reporter (Reporter): The reporter for logging messages and actions. + """ + self._settings = settings + self._reporter = reporter + self._tool_collection = ToolCollection( + ComputerTool(agent_os), + ) + self._system = BetaTextBlockParam( + type="text", + text=f"{SYSTEM_PROMPT}", + ) + + @abstractmethod + def _create_message( + self, messages: list[MessageParam], model_choice: str + ) -> MessageParam: + """Create a message using the agent's API. 
+ + Args: + messages (list[MessageParam]): The message history. + model_choice (str): The model to use for message creation. + + Returns: + MessageParam: The created message. + """ + raise NotImplementedError + + def _step( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: + """Execute a single step in the conversation. + + Args: + messages (list[MessageParam]): The message history. + model_choice (str): The model to use for message creation. + on_message (OnMessageCb | None, optional): Callback on new messages + + Returns: + None + """ + if self._settings.only_n_most_recent_images: + messages = self._maybe_filter_to_n_most_recent_images( + messages, + self._settings.only_n_most_recent_images, + self._settings.image_truncation_threshold, + ) + response_message = self._create_message(messages, model_choice) + message_by_assistant = self._call_on_message( + on_message, response_message, messages + ) + if message_by_assistant is None: + return + message_by_assistant_dict = message_by_assistant.model_dump(mode="json") + logger.debug(message_by_assistant_dict) + messages.append(message_by_assistant) + self._reporter.add_message(self.__class__.__name__, message_by_assistant_dict) + if tool_result_message := self._use_tools(message_by_assistant): + if tool_result_message := self._call_on_message( + on_message, tool_result_message, messages + ): + tool_result_message_dict = tool_result_message.model_dump(mode="json") + logger.debug(tool_result_message_dict) + messages.append(tool_result_message) + self._step( + messages=messages, + model_choice=model_choice, + on_message=on_message, + ) + + def _call_on_message( + self, + on_message: OnMessageCb | None, + message: MessageParam, + messages: list[MessageParam], + ) -> MessageParam | None: + if on_message is None: + return message + return on_message(OnMessageCbParam(message=message, messages=messages)) + + @override + def act( + self, + messages: 
list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: + self._step( + messages=messages, + model_choice=model_choice, + on_message=on_message, + ) + + def _use_tools( + self, + message: MessageParam, + ) -> MessageParam | None: + """Process tool use blocks in a message. + + Args: + message (MessageParam): The message containing tool use blocks. + + Returns: + MessageParam | None: A message containing tool results or `None` + if no tools were used. + """ + tool_result_content: list[ContentBlockParam] = [] + if isinstance(message.content, str): + return None + + for content_block in message.content: + if content_block.type == "tool_use": + result = self._tool_collection.run( + name=content_block.name, + tool_input=content_block.input, # type: ignore[arg-type] + ) + tool_result_content.append( + self._make_api_tool_result(result, content_block.id) + ) + if len(tool_result_content) == 0: + return None + + return MessageParam( + content=tool_result_content, + role="user", + ) + + @staticmethod + def _maybe_filter_to_n_most_recent_images( + messages: list[MessageParam], + images_to_keep: int | None, + min_removal_threshold: int, + ) -> list[MessageParam]: + """ + Filter the message history in-place to keep only the most recent images, + according to the given chunking policy. + + Args: + messages (list[MessageParam]): The message history. + images_to_keep (int | None): Number of most recent images to keep. + min_removal_threshold (int): Minimum number of images to remove at once. + + Returns: + list[MessageParam]: The filtered message history. 
+ """ + if images_to_keep is None: + return messages + + tool_result_blocks = [ + item + for message in messages + for item in (message.content if isinstance(message.content, list) else []) + if item.type == "tool_result" + ] + total_images = sum( + 1 + for tool_result in tool_result_blocks + if not isinstance(tool_result.content, str) + for content in tool_result.content + if content.type == "image" + ) + images_to_remove = total_images - images_to_keep + if images_to_remove < min_removal_threshold: + return messages + # for better cache behavior, we want to remove in chunks + images_to_remove -= images_to_remove % min_removal_threshold + if images_to_remove <= 0: + return messages + + # Remove images from the oldest tool_result blocks first + for tool_result in tool_result_blocks: + if images_to_remove <= 0: + break + if isinstance(tool_result.content, list): + new_content: list[TextBlockParam | ImageBlockParam] = [] + for content in tool_result.content: + if content.type == "image" and images_to_remove > 0: + images_to_remove -= 1 + continue + new_content.append(content) + tool_result.content = new_content + return messages + + def _make_api_tool_result( + self, result: ToolResult, tool_use_id: str + ) -> ToolResultBlockParam: + """Convert a tool result to an API tool result block. + + Args: + result (ToolResult): The tool result to convert. + tool_use_id (str): The ID of the tool use block. + + Returns: + ToolResultBlockParam: The API tool result block. 
+ """ + tool_result_content: list[TextBlockParam | ImageBlockParam] | str = [] + is_error = False + if result.error: + is_error = True + tool_result_content = self._maybe_prepend_system_tool_result( + result, result.error + ) + else: + assert isinstance(tool_result_content, list) + if result.output: + tool_result_content.append( + TextBlockParam( + text=self._maybe_prepend_system_tool_result( + result, result.output + ), + ) + ) + if result.base64_image: + tool_result_content.append( + ImageBlockParam( + source=Base64ImageSourceParam( + media_type="image/png", + data=result.base64_image, + ), + ) + ) + return ToolResultBlockParam( + content=tool_result_content, + tool_use_id=tool_use_id, + is_error=is_error, + ) + + @staticmethod + def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: + """Prepend system message to tool result text if available. + + Args: + result (ToolResult): The tool result. + result_text (str): The result text. + + Returns: + str: The result text with optional system message prepended. 
+ """ + if result.system: + result_text = f"{result.system}\n{result_text}" + return result_text diff --git a/src/askui/models/shared/computer_agent_cb_param.py b/src/askui/models/shared/computer_agent_cb_param.py new file mode 100644 index 00000000..80d07f49 --- /dev/null +++ b/src/askui/models/shared/computer_agent_cb_param.py @@ -0,0 +1,14 @@ +from typing import Callable, Literal + +from pydantic import BaseModel + +from askui.models.shared.computer_agent_message_param import MessageParam + + +class OnMessageCbParam(BaseModel): + type: Literal["message"] = "message" + message: MessageParam + messages: list[MessageParam] + + +OnMessageCb = Callable[[OnMessageCbParam], MessageParam | None] diff --git a/src/askui/models/shared/computer_agent_message_param.py b/src/askui/models/shared/computer_agent_message_param.py new file mode 100644 index 00000000..e525f172 --- /dev/null +++ b/src/askui/models/shared/computer_agent_message_param.py @@ -0,0 +1,107 @@ +from pydantic import BaseModel +from typing_extensions import Literal + + +class CitationCharLocationParam(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_char_index: int + start_char_index: int + type: Literal["char_location"] = "char_location" + + +class CitationPageLocationParam(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_page_number: int + start_page_number: int + type: Literal["page_location"] = "page_location" + + +class CitationContentBlockLocationParam(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_block_index: int + start_block_index: int + type: Literal["content_block_location"] = "content_block_location" + + +TextCitationParam = ( + CitationCharLocationParam + | CitationPageLocationParam + | CitationContentBlockLocationParam +) + + +class UrlImageSourceParam(BaseModel): + type: Literal["url"] = "url" + url: str + + +class Base64ImageSourceParam(BaseModel): + data: 
class CacheControlEphemeralParam(BaseModel):
    """Marks a block for ephemeral prompt caching."""

    type: Literal["ephemeral"] = "ephemeral"


class ImageBlockParam(BaseModel):
    """An image content block, sourced from base64 data or a URL."""

    source: Base64ImageSourceParam | UrlImageSourceParam
    type: Literal["image"] = "image"
    cache_control: CacheControlEphemeralParam | None = None


class TextBlockParam(BaseModel):
    """A text content block, optionally carrying citations."""

    text: str
    type: Literal["text"] = "text"
    cache_control: CacheControlEphemeralParam | None = None
    citations: list[TextCitationParam] | None = None


class ToolResultBlockParam(BaseModel):
    """The result of a tool use, matched to its request via `tool_use_id`."""

    tool_use_id: str
    type: Literal["tool_result"] = "tool_result"
    cache_control: CacheControlEphemeralParam | None = None
    # Either a plain string (e.g. an error message) or a list of text/image
    # blocks.
    content: str | list[TextBlockParam | ImageBlockParam]
    is_error: bool = False


class ToolUseBlockParam(BaseModel):
    """A request by the model to invoke a tool."""

    id: str
    # Tool input payload; its schema depends on the tool being invoked.
    input: object
    name: str
    type: Literal["tool_use"] = "tool_use"
    cache_control: CacheControlEphemeralParam | None = None


# Union of every content block type that may appear inside a message.
ContentBlockParam = (
    ImageBlockParam | TextBlockParam | ToolResultBlockParam | ToolUseBlockParam
)


class MessageParam(BaseModel):
    """A single chat message: a `role` plus string or block-list `content`."""

    role: Literal["user", "assistant"]
    content: str | list[ContentBlockParam]


__all__ = [
    "Base64ImageSourceParam",
    "CacheControlEphemeralParam",
    "CitationCharLocationParam",
    "CitationContentBlockLocationParam",
    "CitationPageLocationParam",
    "ContentBlockParam",
    "ImageBlockParam",
    "MessageParam",
    "TextBlockParam",
    "TextCitationParam",
    "ToolResultBlockParam",
    "ToolUseBlockParam",
    "UrlImageSourceParam",
]
class ModelFacade(ActModel, GetModel, LocateModel):
    """Facade bundling separate act/get/locate models behind one interface.

    Args:
        act_model (ActModel): Model used for `ModelFacade.act()`.
        get_model (GetModel): Model used for `ModelFacade.get()`.
        locate_model (LocateModel): Model used for `ModelFacade.locate()`.
    """

    def __init__(
        self,
        act_model: ActModel,
        get_model: GetModel,
        locate_model: LocateModel,
    ) -> None:
        self._act_model = act_model
        self._get_model = get_model
        self._locate_model = locate_model

    @override
    def act(
        self,
        messages: list[MessageParam],
        model_choice: str,
        on_message: OnMessageCb | None = None,
    ) -> None:
        """Delegate to `ModelFacade._act_model`."""
        self._act_model.act(messages, model_choice, on_message)

    @override
    def get(
        self,
        query: str,
        image: ImageSource,
        response_schema: Type[ResponseSchema] | None,
        model_choice: str,
    ) -> ResponseSchema | str:
        """Delegate to `ModelFacade._get_model`."""
        return self._get_model.get(
            query=query,
            image=image,
            response_schema=response_schema,
            model_choice=model_choice,
        )

    @override
    def locate(
        self,
        locator: str | Locator,
        image: ImageSource,
        model_choice: ModelComposition | str,
    ) -> Point:
        """Delegate to `ModelFacade._locate_model`."""
        return self._locate_model.locate(
            locator=locator,
            image=image,
            model_choice=model_choice,
        )
def act(self, goal: str, model_choice: str) -> None: + def act( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: + if on_message is not None: + error_msg = "on_message is not supported for UI-TARS" + raise NotImplementedError(error_msg) + if len(messages) != 1: + error_msg = "UI-TARS only supports one message" + raise ValueError(error_msg) + message = messages[0] + if message.role != "user": + error_msg = "UI-TARS only supports user messages" + raise ValueError(error_msg) + if not isinstance(message.content, str): + error_msg = "UI-TARS only supports text messages" + raise ValueError(error_msg) # noqa: TRY004 + goal = message.content screenshot = self._agent_os.screenshot() self.act_history = [ { @@ -301,7 +322,7 @@ def execute_act(self, message_history: list[dict[str, Any]]) -> None: action = message.parsed_action if action.action_type == "click": - self._agent_os.mouse(action.start_box.x, action.start_box.y) + self._agent_os.mouse_move(action.start_box.x, action.start_box.y) self._agent_os.click("left") time.sleep(1) if action.action_type == "type": diff --git a/src/askui/reporting.py b/src/askui/reporting.py index 12360407..df30542d 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -304,8 +304,8 @@ def generate(self) -> None: {% for image in msg.images %}
ModifierKey = Literal["command", "alt", "control", "shift", "right_shift"]
"""Modifier keys for keyboard actions."""

# Runtime enumeration of every `ModifierKey` literal, e.g. for validation or
# iteration. Must stay in sync with the `ModifierKey` type above.
ModifierKeys: list[ModifierKey] = [
    "command",
    "alt",
    "control",
    "shift",
    "right_shift",
]
int) -> None: + def mouse_move(self, x: int, y: int) -> None: """ Moves the mouse cursor to specified screen coordinates. @@ -293,12 +432,12 @@ def keyboard_tap( raise NotImplementedError @abstractmethod - def set_display(self, displayNumber: int = 1) -> None: + def set_display(self, display: int = 1) -> None: """ Sets the active display for screen interactions. Args: - displayNumber (int, optional): The display number to set as active. + display (int, optional): The display ID to set as active. Defaults to `1`. """ raise NotImplementedError @@ -315,3 +454,29 @@ def run_command(self, command: str, timeout_ms: int = 30000) -> None: """ raise NotImplementedError + + def start_listening(self) -> None: + """ + Start listening for mouse and keyboard events. + + IMPORTANT: This method is still experimental and may not work at all and may + change in the future. + """ + raise NotImplementedError + + def poll_event(self) -> InputEvent | None: + """ + Poll for a single input event. + + IMPORTANT: This method is still experimental and may not work at all and may + change in the future. + """ + raise NotImplementedError + + def stop_listening(self) -> None: + """Stop listening for mouse and keyboard events. + + IMPORTANT: This method is still experimental and may not work at all and may + change in the future. 
+ """ + raise NotImplementedError diff --git a/src/askui/tools/anthropic/computer.py b/src/askui/tools/anthropic/computer.py index 33991b78..3997e282 100644 --- a/src/askui/tools/anthropic/computer.py +++ b/src/askui/tools/anthropic/computer.py @@ -208,8 +208,6 @@ class ComputerTool(BaseAnthropicTool): name: Literal["computer"] = "computer" api_type: Literal["computer_20241022"] = "computer_20241022" - width: int - height: int _screenshot_delay = 2.0 _scaling_enabled = True @@ -217,22 +215,20 @@ class ComputerTool(BaseAnthropicTool): @property def options(self) -> ComputerToolOptions: return { - "display_width_px": self.width, - "display_height_px": self.height, + "display_width_px": self._width, + "display_height_px": self._height, } def to_params(self) -> BetaToolComputerUse20241022Param: return {"name": self.name, "type": self.api_type, **self.options} - def __init__(self, controller_client: AgentOs) -> None: + def __init__(self, agent_os: AgentOs) -> None: super().__init__() - self.controller_client = controller_client - - self.width = 1280 - self.height = 800 - - self.real_screen_width = None - self.real_screen_height = None + self._agent_os = agent_os + self._width = 1280 + self._height = 800 + self._real_screen_width: int | None = None + self._real_screen_height: int | None = None def __call__( # noqa: C901 self, @@ -261,23 +257,28 @@ def __call__( # noqa: C901 error_msg = f"{coordinate} must be a tuple of non-negative ints" raise ToolError(error_msg) + if self._real_screen_width is None or self._real_screen_height is None: + screenshot = self._agent_os.screenshot() + self._real_screen_width = screenshot.width + self._real_screen_height = screenshot.height + x, y = scale_coordinates_back( coordinate[0], coordinate[1], - self.real_screen_width, - self.real_screen_height, - self.width, - self.height, + self._real_screen_width, + self._real_screen_height, + self._width, + self._height, ) x, y = int(x), int(y) if action == "mouse_move": - 
self.controller_client.mouse(x, y) + self._agent_os.mouse_move(x, y) return ToolResult() if action == "left_click_drag": - self.controller_client.mouse_down("left") - self.controller_client.mouse(x, y) - self.controller_client.mouse_up("left") + self._agent_os.mouse_down("left") + self._agent_os.mouse_move(x, y) + self._agent_os.mouse_up("left") return ToolResult() if action in ("key", "type"): @@ -300,11 +301,11 @@ def __call__( # noqa: C901 f"Key {text} is not a valid PC_KEY from {', '.join(PC_KEY)}" ) raise ToolError(error_msg) - self.controller_client.keyboard_pressed(text) - self.controller_client.keyboard_release(text) + self._agent_os.keyboard_pressed(text) + self._agent_os.keyboard_release(text) return ToolResult() if action == "type": - self.controller_client.type(text) + self._agent_os.type(text) return ToolResult() if action in ( @@ -328,16 +329,16 @@ def __call__( # noqa: C901 error_msg = "cursor_position is not implemented by this agent" raise ToolError(error_msg) if action == "left_click": - self.controller_client.click("left") + self._agent_os.click("left") return ToolResult() if action == "right_click": - self.controller_client.click("right") + self._agent_os.click("right") return ToolResult() if action == "middle_click": - self.controller_client.click("middle") + self._agent_os.click("middle") return ToolResult() if action == "double_click": - self.controller_client.click("left", 2) + self._agent_os.click("left", 2) return ToolResult() error_msg = f"Invalid action: {action}" @@ -348,9 +349,11 @@ def screenshot(self) -> ToolResult: Take a screenshot of the current screen, scale it and return the base64 encoded image. 
""" - screenshot = self.controller_client.screenshot() - self.real_screen_width = screenshot.width - self.real_screen_height = screenshot.height - scaled_screenshot = scale_image_with_padding(screenshot, 1280, 800) + screenshot = self._agent_os.screenshot() + self._real_screen_width = screenshot.width + self._real_screen_height = screenshot.height + scaled_screenshot = scale_image_with_padding( + screenshot, self._width, self._height + ) base64_image = image_to_base64(scaled_screenshot) return ToolResult(base64_image=base64_image) diff --git a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index fc9f01be..7c385bbf 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -427,7 +427,7 @@ def screenshot(self, report: bool = True) -> Image.Image: @telemetry.record_call() @override - def mouse(self, x: int, y: int) -> None: + def mouse_move(self, x: int, y: int) -> None: """ Moves the mouse cursor to specified screen coordinates. @@ -437,7 +437,7 @@ def mouse(self, x: int, y: int) -> None: """ self._reporter.add_message( "AgentOS", - f"mouse({x}, {y})", + f"mouse_move({x}, {y})", draw_point_on_image(self.screenshot(report=False), x, y, size=5), ) self._run_recorder_action( @@ -687,22 +687,22 @@ def keyboard_tap( @telemetry.record_call() @override - def set_display(self, displayNumber: int = 1) -> None: + def set_display(self, display: int = 1) -> None: """ Set the active display. Args: - displayNumber (int, optional): The display number to set as active. + display (int, optional): The display ID to set as active. Defaults to `1`. 
""" assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( "Stub is not initialized" ) - self._reporter.add_message("AgentOS", f"set_display({displayNumber})") + self._reporter.add_message("AgentOS", f"set_display({display})") self._stub.SetActiveDisplay( - controller_v1_pbs.Request_SetActiveDisplay(displayID=displayNumber) + controller_v1_pbs.Request_SetActiveDisplay(displayID=display) ) - self._display = displayNumber + self._display = display @telemetry.record_call(exclude={"command"}) @override diff --git a/src/askui/tools/pynput/__init__.py b/src/askui/tools/pynput/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/tools/pynput/pynput_agent_os.py b/src/askui/tools/pynput/pynput_agent_os.py new file mode 100644 index 00000000..c6a28f89 --- /dev/null +++ b/src/askui/tools/pynput/pynput_agent_os.py @@ -0,0 +1,399 @@ +import ctypes +import platform +import queue +import time +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast + +from mss import mss +from PIL import Image +from pynput.keyboard import Controller as KeyboardController +from pynput.keyboard import Key, KeyCode +from pynput.mouse import Button +from pynput.mouse import Controller as MouseController +from pynput.mouse import Listener as MouseListener +from typing_extensions import override + +from askui.logger import logger +from askui.reporting import Reporter +from askui.tools.agent_os import AgentOs, InputEvent, ModifierKey, PcKey +from askui.utils.image_utils import draw_point_on_image + +if platform.system() == "Windows": + try: + PROCESS_PER_MONITOR_DPI_AWARE = 2 + ctypes.windll.shcore.SetProcessDpiAwareness(PROCESS_PER_MONITOR_DPI_AWARE) # type: ignore[attr-defined] + except Exception as e: # noqa: BLE001 + logger.error(f"Could not set DPI awareness: {e}") + +if TYPE_CHECKING: + from mss.screenshot import ScreenShot + +_KEY_MAP: dict[PcKey | ModifierKey, Key | KeyCode] = { + "backspace": Key.backspace, + 
"delete": Key.delete, + "enter": Key.enter, + "tab": Key.tab, + "escape": Key.esc, + "up": Key.up, + "down": Key.down, + "right": Key.right, + "left": Key.left, + "home": Key.home, + "end": Key.end, + "pageup": Key.page_up, + "pagedown": Key.page_down, + "f1": Key.f1, + "f2": Key.f2, + "f3": Key.f3, + "f4": Key.f4, + "f5": Key.f5, + "f6": Key.f6, + "f7": Key.f7, + "f8": Key.f8, + "f9": Key.f9, + "f10": Key.f10, + "f11": Key.f11, + "f12": Key.f12, + "space": Key.space, + "command": Key.cmd, + "alt": Key.alt, + "control": Key.ctrl, + "shift": Key.shift, + "right_shift": Key.shift_r, +} + + +_BUTTON_MAP: dict[Literal["left", "middle", "right", "unknown"], Button] = { + "left": Button.left, + "middle": Button.middle, + "right": Button.right, + "unknown": Button.unknown, +} + +_BUTTON_MAP_REVERSE: dict[Button, Literal["left", "middle", "right", "unknown"]] = { + Button.left: "left", + Button.middle: "middle", + Button.right: "right", + Button.unknown: "unknown", +} + + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Any]) +C = TypeVar("C", bound=type) + + +def await_action(pre_action_wait: float, post_action_wait: float) -> Callable[[F], F]: + def wrapper(func: F) -> F: + @wraps(func) + def _wrapper(*args: Any, **kwargs: Any) -> Any: + time.sleep(pre_action_wait) + result = func(*args, **kwargs) + time.sleep(post_action_wait) + return result + + return cast("F", _wrapper) + + return wrapper + + +def decorate_all_methods( + pre_action_wait: float, post_action_wait: float +) -> Callable[[C], C]: + def decorate(cls: C) -> C: + for attr_name, attr in cls.__dict__.items(): + if callable(attr) and not attr_name.startswith("__"): + setattr( + cls, + attr_name, + await_action(pre_action_wait, post_action_wait)( + cast("Callable[..., Any]", attr) + ), + ) + return cls + + return decorate + + +@decorate_all_methods(pre_action_wait=0.1, post_action_wait=0.1) +class PynputAgentOs(AgentOs): + """ + Implementation of AgentOs using `pynput` for mouse and keyboard control, and 
`mss` + for screenshots. + + Args: + reporter (Reporter): Reporter used for reporting with the `AgentOs`. + display (int, optional): Display number to use. Defaults to `1`. + """ + + def __init__( + self, + reporter: Reporter, + display: int = 1, + ) -> None: + self._mouse = MouseController() + self._keyboard = KeyboardController() + self._sct = mss() + self._display = display + self._reporter = reporter + self._mouse_listener: MouseListener | None = None + self._input_event_queue: queue.Queue[InputEvent] = queue.Queue() + + @override + def connect(self) -> None: + """No connection needed for pynput.""" + + @override + def disconnect(self) -> None: + """No disconnection needed for pynput.""" + + @override + def screenshot(self, report: bool = True) -> Image.Image: + """ + Take a screenshot of the current screen. + + Args: + report (bool, optional): Whether to include the screenshot in reporting. + Defaults to `True`. + + Returns: + Image.Image: A PIL Image object containing the screenshot. + """ + monitor = self._sct.monitors[self._display] + screenshot: ScreenShot = self._sct.grab(monitor) + image = Image.frombytes( + "RGB", + screenshot.size, + screenshot.rgb, + ) + + scaled_size = (monitor["width"], monitor["height"]) + image = image.resize(scaled_size, Image.Resampling.LANCZOS) + + if report: + self._reporter.add_message("AgentOS", "screenshot()", image) + return image + + @override + def mouse_move(self, x: int, y: int) -> None: + """ + Move the mouse cursor to specified screen coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to move to. + y (int): The vertical coordinate (in pixels) to move to. + """ + self._reporter.add_message( + "AgentOS", + f"mouse_move({x}, {y})", + draw_point_on_image(self.screenshot(report=False), x, y, size=5), + ) + self._mouse.position = (x, y) + + @override + def type(self, text: str, typing_speed: int = 50) -> None: + """ + Type text at current cursor position as if entered on a keyboard. 
+ + Args: + text (str): The text to type. + typing_speed (int, optional): The speed of typing in characters per second. + Defaults to `50`. + """ + self._reporter.add_message("AgentOS", f'type("{text}", {typing_speed})') + delay = 1.0 / typing_speed + for char in text: + self._keyboard.press(char) + self._keyboard.release(char) + time.sleep(delay) + + @override + def click( + self, button: Literal["left", "middle", "right"] = "left", count: int = 1 + ) -> None: + """ + Click a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + click. Defaults to `"left"`. + count (int, optional): Number of times to click. Defaults to `1`. + """ + self._reporter.add_message("AgentOS", f'click("{button}", {count})') + pynput_button = _BUTTON_MAP[button] + for _ in range(count): + self._mouse.click(pynput_button) + + @override + def mouse_down(self, button: Literal["left", "middle", "right"] = "left") -> None: + """ + Press and hold a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + press. Defaults to `"left"`. + """ + self._reporter.add_message("AgentOS", f'mouse_down("{button}")') + self._mouse.press(_BUTTON_MAP[button]) + + @override + def mouse_up(self, button: Literal["left", "middle", "right"] = "left") -> None: + """ + Release a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + release. Defaults to "left". + """ + self._reporter.add_message("AgentOS", f'mouse_up("{button}")') + self._mouse.release(_BUTTON_MAP[button]) + + @override + def mouse_scroll(self, x: int, y: int) -> None: + """ + Scroll the mouse wheel. + + Args: + x (int): The horizontal scroll amount. Positive values scroll right, + negative values scroll left. + y (int): The vertical scroll amount. Positive values scroll down, + negative values scroll up. 
+ """ + self._reporter.add_message("AgentOS", f"mouse_scroll({x}, {y})") + self._mouse.scroll(x, y) + + def _get_pynput_key(self, key: PcKey | ModifierKey) -> Key | KeyCode | str: + """Convert our key type to pynput key.""" + if key in _KEY_MAP: + return _KEY_MAP[key] + return key # For regular characters + + @override + def keyboard_pressed( + self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + ) -> None: + """ + Press and hold a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to press. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + press along with the main key. Defaults to `None`. + """ + self._reporter.add_message( + "AgentOS", f'keyboard_pressed("{key}", {modifier_keys})' + ) + if modifier_keys: + for mod in modifier_keys: + self._keyboard.press(_KEY_MAP[mod]) + self._keyboard.press(self._get_pynput_key(key)) + + @override + def keyboard_release( + self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + ) -> None: + """ + Release a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to release. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + release along with the main key. Defaults to `None`. + """ + self._reporter.add_message( + "AgentOS", f'keyboard_release("{key}", {modifier_keys})' + ) + self._keyboard.release(self._get_pynput_key(key)) + if modifier_keys: + for mod in reversed(modifier_keys): # Release in reverse order + self._keyboard.release(_KEY_MAP[mod]) + + @override + def keyboard_tap( + self, + key: PcKey | ModifierKey, + modifier_keys: list[ModifierKey] | None = None, + count: int = 1, + ) -> None: + """ + Press and immediately release a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to tap. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + press along with the main key. Defaults to `None`. + count (int, optional): The number of times to tap the key. 
Defaults to `1`. + """ + self._reporter.add_message( + "AgentOS", + f'keyboard_tap("{key}", {modifier_keys}, {count})', + ) + for _ in range(count): + self.keyboard_pressed(key, modifier_keys) + self.keyboard_release(key, modifier_keys) + + @override + def set_display(self, display: int = 1) -> None: + """ + Set the active display. + + Args: + display (int, optional): The display ID to set as active. + Defaults to `1`. + """ + self._reporter.add_message("AgentOS", f"set_display({display})") + if display < 1 or len(self._sct.monitors) <= display: + error_msg = f"Display {display} not found" + raise ValueError(error_msg) + self._display = display + + def _on_mouse_click( + self, x: float, y: float, button: Button, pressed: bool, injected: bool + ) -> None: + """Handle mouse click events.""" + self._input_event_queue.put( + InputEvent( + x=int(x), + y=int(y), + button=_BUTTON_MAP_REVERSE[button], + pressed=pressed, + injected=injected, + timestamp=time.time(), + ) + ) + + @override + def start_listening(self) -> None: + """ + Start listening for mouse and keyboard events. + + Args: + callback (InputEventCallback): Callback function that will be called for + each event. 
+ """ + if self._mouse_listener: + self.stop_listening() + self._mouse_listener = MouseListener( + on_click=self._on_mouse_click, # type: ignore[arg-type] + name="PynputAgentOsMouseListener", + args=(self._input_event_queue,), + ) + self._mouse_listener.start() + + @override + def poll_event(self) -> InputEvent | None: + """Poll for a single input event.""" + try: + return self._input_event_queue.get(False) + except queue.Empty: + return None + + @override + def stop_listening(self) -> None: + """Stop listening for mouse and keyboard events.""" + if self._mouse_listener: + self._mouse_listener.stop() + self._mouse_listener = None + while not self._input_event_queue.empty(): + self._input_event_queue.get() diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index f6064450..dbbc30e5 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -11,11 +11,11 @@ from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import AiElementCollection from askui.models.askui.computer_agent import AskUiComputerAgent -from askui.models.askui.facade import AskUiFacade from askui.models.askui.inference_api import AskUiInferenceApi, AskUiSettings from askui.models.askui.model_router import AskUiModelRouter from askui.models.askui.settings import AskUiComputerAgentSettings from askui.models.models import ModelName +from askui.models.shared.facade import ModelFacade from askui.reporting import Reporter, SimpleHtmlReporter from askui.tools.toolbox import AgentToolbox @@ -82,11 +82,11 @@ def askui_computer_agent( def askui_facade( askui_computer_agent: AskUiComputerAgent, askui_inference_api: AskUiInferenceApi, -) -> AskUiFacade: - return AskUiFacade( - computer_agent=askui_computer_agent, - inference_api=askui_inference_api, - model_router=AskUiModelRouter(inference_api=askui_inference_api), +) -> ModelFacade: + return ModelFacade( + act_model=askui_computer_agent, + get_model=askui_inference_api, + 
locate_model=AskUiModelRouter(inference_api=askui_inference_api), ) @@ -94,7 +94,7 @@ def askui_facade( def vision_agent( agent_toolbox_mock: AgentToolbox, simple_html_reporter: Reporter, - askui_facade: AskUiFacade, + askui_facade: ModelFacade, ) -> Generator[VisionAgent, None, None]: """Fixture providing a VisionAgent instance.""" with VisionAgent( diff --git a/tests/e2e/agent/test_act.py b/tests/e2e/agent/test_act.py index efd25e0a..e6707ca4 100644 --- a/tests/e2e/agent/test_act.py +++ b/tests/e2e/agent/test_act.py @@ -1,8 +1,8 @@ import pytest from askui.agent import VisionAgent -from askui.models.askui.facade import AskUiFacade from askui.models.models import ModelComposition, ModelDefinition, ModelName +from askui.models.shared.facade import ModelFacade from askui.reporting import Reporter from askui.tools.toolbox import AgentToolbox @@ -26,7 +26,7 @@ def test_act( def test_act_with_model_composition_should_use_default_model( agent_toolbox_mock: AgentToolbox, simple_html_reporter: Reporter, - askui_facade: AskUiFacade, + askui_facade: ModelFacade, ) -> None: with VisionAgent( reporters=[simple_html_reporter], diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index ce6e034c..be3fbb26 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -5,8 +5,8 @@ from askui import ResponseSchemaBase, VisionAgent from askui.models import ModelName -from askui.models.askui.facade import AskUiFacade from askui.models.models import ModelComposition, ModelDefinition +from askui.models.shared.facade import ModelFacade from askui.reporting import Reporter from askui.tools.toolbox import AgentToolbox @@ -42,7 +42,7 @@ def test_get( def test_get_with_model_composition_should_use_default_model( agent_toolbox_mock: AgentToolbox, - askui_facade: AskUiFacade, + askui_facade: ModelFacade, simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, ) -> None: diff --git a/tests/e2e/tools/__init__.py b/tests/e2e/tools/__init__.py 
new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/tools/pynput/__init__.py b/tests/e2e/tools/pynput/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/e2e/tools/pynput/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/integration/agent/test_computer_agent.py b/tests/integration/agent/test_computer_agent.py index 438b38f4..24b5c680 100644 --- a/tests/integration/agent/test_computer_agent.py +++ b/tests/integration/agent/test_computer_agent.py @@ -2,6 +2,7 @@ from askui.models.askui.computer_agent import AskUiComputerAgent from askui.models.models import ModelName +from askui.models.shared.computer_agent_message_param import MessageParam @pytest.mark.skip( @@ -10,4 +11,7 @@ def test_act( claude_computer_agent: AskUiComputerAgent, ) -> None: - claude_computer_agent.act("Go to github.com/login", model_choice=ModelName.ASKUI) + claude_computer_agent.act( + [MessageParam(role="user", content="Go to github.com/login")], + model_choice=ModelName.ASKUI, + ) diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index b5f190ed..01db14fd 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -1,8 +1,9 @@ """Integration tests for custom model registration and selection.""" -from typing import Optional, Type, Union +from typing import Any, Optional, Type, Union import pytest +from typing_extensions import override from askui import ( ActModel, @@ -16,6 +17,8 @@ ) from askui.locators.locators import Locator from askui.models import ModelComposition, ModelDefinition, ModelName +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource @@ -24,11 +27,17 @@ class SimpleActModel(ActModel): """Simple act model that records goals.""" def __init__(self) -> None: - self.goals: 
list[str] = [] + self.goals: list[list[dict[str, Any]]] = [] self.model_choices: list[str] = [] - def act(self, goal: str, model_choice: str) -> None: - self.goals.append(goal) + @override + def act( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: + self.goals.append([message.model_dump(mode="json") for message in messages]) self.model_choices.append(model_choice) @@ -137,7 +146,9 @@ def test_register_and_use_custom_act_model( with VisionAgent(models=model_registry, tools=agent_toolbox_mock) as agent: agent.act("test goal", model="custom-act") - assert act_model.goals == ["test goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + ] assert act_model.model_choices == ["custom-act"] def test_register_and_use_custom_get_model( @@ -182,7 +193,9 @@ def create_model() -> ActModel: with VisionAgent(models=registry, tools=agent_toolbox_mock) as agent: agent.act("test goal", model="factory-model") - assert act_model.goals == ["test goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + ] assert act_model.model_choices == ["factory-model"] def test_register_multiple_models_for_same_task( @@ -193,7 +206,13 @@ def test_register_multiple_models_for_same_task( """Test registering multiple models for the same task.""" class AnotherActModel(ActModel): - def act(self, goal: str, model_choice: str) -> None: + @override + def act( + self, + messages: list[MessageParam], + model_choice: str, + on_message: OnMessageCb | None = None, + ) -> None: pass registry: ModelRegistry = { @@ -205,7 +224,9 @@ def act(self, goal: str, model_choice: str) -> None: agent.act("test goal", model="act-1") agent.act("another goal", model="act-2") - assert act_model.goals == ["test goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + ] assert act_model.model_choices == ["act-1"] def test_use_response_schema_with_custom_get_model( @@ -240,7 +261,9 @@ 
def test_override_default_model( with VisionAgent(models=registry, tools=agent_toolbox_mock) as agent: agent.act("test goal") # Should use custom model since it overrides "askui" - assert act_model.goals == ["test goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + ] assert act_model.model_choices == [ModelName.ASKUI] def test_model_composition( @@ -319,4 +342,7 @@ def create_model() -> ActModel: agent.act("another goal", model="lazy-model") assert init_count == 2 - assert act_model.goals == ["test goal", "another goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + [{"role": "user", "content": "another goal"}], + ] diff --git a/tests/integration/tools/askui/test_askui_controller.py b/tests/integration/tools/askui/test_askui_controller.py index 52bc178b..76a51105 100644 --- a/tests/integration/tools/askui/test_askui_controller.py +++ b/tests/integration/tools/askui/test_askui_controller.py @@ -37,5 +37,5 @@ def test_find_remote_device_controller_by_component_registry( def test_actions(controller_client: AskUiControllerClient) -> None: with controller_client: controller_client.screenshot() - controller_client.mouse(0, 0) + controller_client.mouse_move(0, 0) controller_client.click() diff --git a/tests/unit/models/test_computer_agent_filter.py b/tests/unit/models/test_computer_agent_filter.py new file mode 100644 index 00000000..a57fc59e --- /dev/null +++ b/tests/unit/models/test_computer_agent_filter.py @@ -0,0 +1,130 @@ +from askui.models.shared.computer_agent import ComputerAgent +from askui.models.shared.computer_agent_message_param import ( + Base64ImageSourceParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, +) + + +def make_image_block() -> ImageBlockParam: + return ImageBlockParam( + source=Base64ImageSourceParam( + media_type="image/png", + data="abc", + ), + ) + + +def make_tool_result_block(num_images: int, num_texts: int = 0) -> ToolResultBlockParam: + 
content = [make_image_block() for _ in range(num_images)] + [ + TextBlockParam(text=f"text{i}") for i in range(num_texts) + ] + return ToolResultBlockParam(tool_use_id="id", content=content) + + +def make_message_with_tool_result(num_images: int, num_texts: int = 0) -> MessageParam: + return MessageParam( + role="user", content=[make_tool_result_block(num_images, num_texts)] + ) + + +def test_no_images() -> None: + messages = [make_message_with_tool_result(0, 2)] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 3, 2) + assert filtered == messages + + +def test_fewer_images_than_keep() -> None: + messages = [make_message_with_tool_result(2, 1)] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 3, 2) + # Only ToolResultBlockParam with list content should be checked + all_images = [ + c + for m in filtered + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + if getattr(c, "type", None) == "image" + ] + expected_images = [ + c + for b in (messages[0].content if isinstance(messages[0].content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + if getattr(c, "type", None) == "image" + ] + assert all_images == expected_images + + +def test_exactly_images_to_keep() -> None: + messages = [make_message_with_tool_result(3, 1)] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 3, 2) + # Only check .content if the type is correct + first_block = ( + filtered[0].content[0] + if isinstance(filtered[0].content, list) and len(filtered[0].content) > 0 + else None + ) + if isinstance(first_block, ToolResultBlockParam) and isinstance( + first_block.content, list + ): + assert len(first_block.content) == 4 + else: + error_msg = ( + "filtered[0].content[0] is not a ToolResultBlockParam with list content" + ) + raise AssertionError(error_msg) # noqa: 
TRY004 + all_tool_result_contents = [ + c + for m in filtered + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + ] + assert ( + sum(1 for c in all_tool_result_contents if getattr(c, "type", None) == "image") + == 3 + ) + + +def test_more_images_than_keep_removes_oldest() -> None: + messages = [ + make_message_with_tool_result(2, 0), + make_message_with_tool_result(2, 0), + ] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 2, 2) + # Only 2 images should remain, and they should be the newest (from the last message) + all_images = [ + c + for m in filtered + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + if getattr(c, "type", None) == "image" + ] + assert len(all_images) == 2 + # They should be from the last message + assert all_images == [ + c + for b in (filtered[1].content if isinstance(filtered[1].content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content[:2] + if getattr(c, "type", None) == "image" + ] + + +def test_removal_chunking() -> None: + messages = [make_message_with_tool_result(5, 0)] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 2, 2) + # Should remove 2 (a chunk that is a multiple of min_removal_threshold=2), leaving 3 images + all_images = [ + c + for m in filtered + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + if getattr(c, "type", None) == "image" + ] + assert len(all_images) == 3 diff --git a/tests/unit/models/test_model_router.py b/tests/unit/models/test_model_router.py index 12cb439e..258d7292 100644 --- a/tests/unit/models/test_model_router.py +++ b/tests/unit/models/test_model_router.py @@ -9,11 +9,11 @@ from pytest_mock import MockerFixture from
askui.exceptions import ModelNotFoundError -from askui.models.anthropic.facade import AnthropicFacade -from askui.models.askui.facade import AskUiFacade from askui.models.huggingface.spaces_api import HFSpacesHandler from askui.models.model_router import ModelRouter from askui.models.models import ModelName +from askui.models.shared.computer_agent_message_param import MessageParam +from askui.models.shared.facade import ModelFacade from askui.models.ui_tars_ep.ui_tars_api import UiTarsApiHandler from askui.reporting import CompositeReporter from askui.tools.toolbox import AgentToolbox @@ -55,9 +55,9 @@ def mock_hf_spaces(mocker: MockerFixture) -> HFSpacesHandler: @pytest.fixture -def mock_anthropic_facade(mocker: MockerFixture) -> AnthropicFacade: +def mock_anthropic_facade(mocker: MockerFixture) -> ModelFacade: """Fixture providing a mock Anthropic facade.""" - mock = cast("AnthropicFacade", mocker.MagicMock(spec=AnthropicFacade)) + mock = cast("ModelFacade", mocker.MagicMock(spec=ModelFacade)) mock.act.return_value = None # type: ignore[attr-defined] mock.get.return_value = "Mock response" # type: ignore[attr-defined] mock.locate.return_value = (50, 50) # type: ignore[attr-defined] @@ -65,9 +65,9 @@ def mock_anthropic_facade(mocker: MockerFixture) -> AnthropicFacade: @pytest.fixture -def mock_askui_facade(mocker: MockerFixture) -> AskUiFacade: +def mock_askui_facade(mocker: MockerFixture) -> ModelFacade: """Fixture providing a mock AskUI facade.""" - mock = cast("AskUiFacade", mocker.MagicMock(spec=AskUiFacade)) + mock = cast("ModelFacade", mocker.MagicMock(spec=ModelFacade)) mock.act.return_value = None # type: ignore[attr-defined] mock.get.return_value = "Mock response" # type: ignore[attr-defined] mock.locate.return_value = (50, 50) # type: ignore[attr-defined] @@ -77,8 +77,8 @@ def mock_askui_facade(mocker: MockerFixture) -> AskUiFacade: @pytest.fixture def model_router( agent_toolbox_mock: AgentToolbox, - mock_anthropic_facade: AnthropicFacade, - 
mock_askui_facade: AskUiFacade, + mock_anthropic_facade: ModelFacade, + mock_askui_facade: ModelFacade, mock_tars: UiTarsApiHandler, mock_hf_spaces: HFSpacesHandler, ) -> ModelRouter: @@ -110,7 +110,7 @@ def test_locate_with_askui_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI model.""" locator = "test locator" @@ -123,7 +123,7 @@ def test_locate_with_askui_pta_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI PTA model.""" locator = "test locator" @@ -138,7 +138,7 @@ def test_locate_with_askui_ocr_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI OCR model.""" locator = "test locator" @@ -153,7 +153,7 @@ def test_locate_with_askui_combo_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI combo model.""" locator = "test locator" @@ -168,7 +168,7 @@ def test_locate_with_askui_ai_element_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI AI element model.""" locator = "test locator" @@ -196,7 +196,7 @@ def test_locate_with_claude_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_anthropic_facade: AnthropicFacade, + mock_anthropic_facade: ModelFacade, ) -> None: """Test locating elements using Claude model.""" locator = "test locator" @@ -239,7 +239,7 @@ def test_get_with_askui_model( self, model_router: ModelRouter, mock_image_source: ImageSource, - mock_askui_facade: AskUiFacade, + 
mock_askui_facade: ModelFacade, ) -> None: """Test getting inference using AskUI model.""" response = model_router.get( @@ -265,7 +265,7 @@ def test_get_with_claude_model( self, model_router: ModelRouter, mock_image_source: ImageSource, - mock_anthropic_facade: AnthropicFacade, + mock_anthropic_facade: ModelFacade, ) -> None: """Test getting inference using Claude model.""" response = model_router.get( @@ -289,21 +289,31 @@ def test_act_with_tars_model( self, model_router: ModelRouter, mock_tars: UiTarsApiHandler ) -> None: """Test acting using TARS model.""" - model_router.act("test goal", ModelName.TARS) - mock_tars.act.assert_called_once_with("test goal", ModelName.TARS) # type: ignore + messages = [MessageParam(role="user", content="test goal")] + model_router.act(messages, ModelName.TARS) + mock_tars.act.assert_called_once_with( # type: ignore[attr-defined] + messages, + ModelName.TARS, + None, + ) def test_act_with_claude_model( - self, model_router: ModelRouter, mock_anthropic_facade: AnthropicFacade + self, model_router: ModelRouter, mock_anthropic_facade: ModelFacade ) -> None: """Test acting using Claude model.""" + messages = [MessageParam(role="user", content="test goal")] model_router.act( - "test goal", ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 + messages, + ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ) mock_anthropic_facade.act.assert_called_once_with( # type: ignore - "test goal", ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 + messages, + ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, + None, ) def test_act_with_invalid_model(self, model_router: ModelRouter) -> None: """Test that acting with invalid model raises InvalidModelError.""" + messages = [MessageParam(role="user", content="test goal")] with pytest.raises(ModelNotFoundError): - model_router.act("test goal", "invalid-model") + model_router.act(messages, "invalid-model")