From a648f11694e211fb573c0166fd39670c8ddf5410 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 15 May 2025 21:45:25 +0200 Subject: [PATCH 01/20] feat(mcp): setup mcp app for vision agent --- .nvmrc | 1 + pdm.lock | 329 +++++++++++++++++++++++++++++++++++--- pyproject.toml | 14 +- src/askui/mcp/__init__.py | 22 +++ 4 files changed, 338 insertions(+), 28 deletions(-) create mode 100644 .nvmrc create mode 100644 src/askui/mcp/__init__.py diff --git a/.nvmrc b/.nvmrc new file mode 100644 index 00000000..2bd5a0a9 --- /dev/null +++ b/.nvmrc @@ -0,0 +1 @@ +22 diff --git a/pdm.lock b/pdm.lock index 7e8b0377..fd4c5ca9 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "chat", "test"] +groups = ["default", "chat", "mcp", "pynput", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:0a776bcb0858c92a70bc0221f9adafacdc5af3bc2dd86282aaeb4fce12ed32ac" +content_hash = "sha256:0afb61ae5075adaa78afb4d2995e9714fad46da7eeafc8f4f61ae16b9ec88b18" [[metadata.targets]] requires_python = ">=3.10" @@ -33,7 +33,7 @@ name = "annotated-types" version = "0.7.0" requires_python = ">=3.8" summary = "Reusable constraint types to use with typing.Annotated" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "typing-extensions>=4.0.0; python_version < \"3.9\"", ] @@ -67,7 +67,7 @@ name = "anyio" version = "4.9.0" requires_python = ">=3.9" summary = "High level compatibility layer for multiple asynchronous event loop implementations" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "exceptiongroup>=1.0.2; python_version < \"3.11\"", "idna>=2.8", @@ -128,7 +128,7 @@ name = "certifi" version = "2025.1.31" requires_python = ">=3.6" summary = "Python package for providing Mozilla's CA Bundle." -groups = ["default", "chat"] +groups = ["default", "chat", "mcp"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -202,7 +202,7 @@ name = "click" version = "8.1.8" requires_python = ">=3.7" summary = "Composable command line interface toolkit" -groups = ["chat"] +groups = ["chat", "mcp"] dependencies = [ "colorama; platform_system == \"Windows\"", "importlib-metadata; python_version < \"3.8\"", @@ -217,7 +217,7 @@ name = "colorama" version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." -groups = ["default", "chat", "test"] +groups = ["default", "chat", "mcp", "test"] marker = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, @@ -364,12 +364,23 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "evdev" +version = "1.9.2" +requires_python = ">=3.8" +summary = "Bindings to the Linux input handling subsystem" +groups = ["pynput"] +marker = "\"linux\" in sys_platform" +files = [ + {file = "evdev-1.9.2.tar.gz", hash = "sha256:5d3278892ce1f92a74d6bf888cc8525d9f68af85dbe336c95d1c87fb8f423069"}, +] + [[package]] name = "exceptiongroup" version = "1.2.2" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["default", "test"] +groups = ["default", "mcp", "test"] marker = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, @@ -581,7 +592,7 @@ name = "h11" version = "0.14.0" requires_python = ">=3.7" summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "typing-extensions; python_version < \"3.8\"", ] @@ -595,7 +606,7 @@ name = "httpcore" version = "1.0.7" requires_python = ">=3.8" summary = "A minimal low-level HTTP client." -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "certifi", "h11<0.15,>=0.13", @@ -610,7 +621,7 @@ name = "httpx" version = "0.28.1" requires_python = ">=3.8" summary = "The next generation HTTP client." -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "anyio", "certifi", @@ -622,6 +633,17 @@ files = [ {file = "httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc"}, ] +[[package]] +name = "httpx-sse" +version = "0.4.0" +requires_python = ">=3.8" +summary = "Consume Server-Sent Event (SSE) messages with HTTPX." +groups = ["mcp"] +files = [ + {file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"}, + {file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"}, +] + [[package]] name = "huggingface-hub" version = "0.30.1" @@ -647,7 +669,7 @@ name = "idna" version = "3.10" requires_python = ">=3.6" summary = "Internationalized Domain Names in Applications (IDNA)" -groups = ["default", "chat"] +groups = ["default", "chat", "mcp"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -777,7 +799,7 @@ name = "markdown-it-py" version = "3.0.0" requires_python = ">=3.8" summary = "Python port of markdown-it. Markdown parsing, done right!" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "mdurl~=0.1", ] @@ -846,12 +868,53 @@ files = [ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] +[[package]] +name = "mcp" +version = "1.8.1" +requires_python = ">=3.10" +summary = "Model Context Protocol SDK" +groups = ["mcp"] +dependencies = [ + "anyio>=4.5", + "httpx-sse>=0.4", + "httpx>=0.27", + "pydantic-settings>=2.5.2", + "pydantic<3.0.0,>=2.7.2", + "python-multipart>=0.0.9", + "sse-starlette>=1.6.1", + "starlette>=0.27", + "uvicorn>=0.23.1; sys_platform != \"emscripten\"", +] +files = [ + {file = "mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770"}, + {file = "mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e"}, +] + +[[package]] +name = "mcp" +version = "1.8.1" +extras = ["cli", "rich", "ws"] +requires_python = ">=3.10" +summary = "Model Context Protocol SDK" +groups = ["mcp"] +dependencies = [ + "mcp==1.8.1", + "python-dotenv>=1.0.0", + "rich>=13.9.4", + "typer>=0.12.4", + "websockets>=15.0.1", +] +files = [ + {file = "mcp-1.8.1-py3-none-any.whl", hash = "sha256:948e03783859fa35abe05b9b6c0a1d5519be452fc079dc8d7f682549591c1770"}, + {file = "mcp-1.8.1.tar.gz", hash = "sha256:ec0646271d93749f784d2316fb5fe6102fb0d1be788ec70a9e2517e8f2722c0e"}, +] + [[package]] name = "mdurl" version = "0.1.2" requires_python = ">=3.7" summary = "Markdown URL utilities" -groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, @@ -1225,7 +1288,7 @@ name = "pydantic" version = "2.11.2" requires_python = ">=3.9" summary = "Data validation using Python type hints" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "annotated-types>=0.6.0", "pydantic-core==2.33.1", @@ -1242,7 +1305,7 @@ name = "pydantic-core" version = "2.33.1" requires_python = ">=3.9" summary = "Core functionality for Pydantic validation and serialization" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "typing-extensions!=4.7.0,>=4.6.0", ] @@ -1331,7 +1394,7 @@ name = "pydantic-settings" version = "2.8.1" requires_python = ">=3.8" summary = "Settings management using Pydantic" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "pydantic>=2.7.0", "python-dotenv>=0.21.0", @@ -1361,7 +1424,7 @@ name = "pygments" version = "2.19.1" requires_python = ">=3.8" summary = "Pygments is a syntax highlighting package written in Python." -groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, @@ -1378,6 +1441,122 @@ files = [ {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, ] +[[package]] +name = "pynput" +version = "1.8.1" +summary = "Monitor and control user input devices" +groups = ["pynput"] +dependencies = [ + "enum34; python_version == \"2.7\"", + "evdev>=1.3; \"linux\" in sys_platform", + "pyobjc-framework-ApplicationServices>=8.0; sys_platform == \"darwin\"", + "pyobjc-framework-Quartz>=8.0; sys_platform == \"darwin\"", + "python-xlib>=0.17; \"linux\" in sys_platform", + "six", +] +files = [ + {file = "pynput-1.8.1-py2.py3-none-any.whl", hash = "sha256:42dfcf27404459ca16ca889c8fb8ffe42a9fe54f722fd1a3e130728e59e768d2"}, + {file = "pynput-1.8.1.tar.gz", hash = "sha256:70d7c8373ee98911004a7c938742242840a5628c004573d84ba849d4601df81e"}, +] + +[[package]] +name = "pyobjc-core" +version = "11.0" +requires_python = ">=3.8" +summary = "Python<->ObjC Interoperability Module" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +files = [ + {file = "pyobjc_core-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:10866b3a734d47caf48e456eea0d4815c2c9b21856157db5917b61dee06893a1"}, + {file = "pyobjc_core-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:50675c0bb8696fe960a28466f9baf6943df2928a1fd85625d678fa2f428bd0bd"}, + {file = "pyobjc_core-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a03061d4955c62ddd7754224a80cdadfdf17b6b5f60df1d9169a3b1b02923f0b"}, + {file = "pyobjc_core-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c338c1deb7ab2e9436d4175d1127da2eeed4a1b564b3d83b9f3ae4844ba97e86"}, + {file = "pyobjc_core-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b4e9dc4296110f251a4033ff3f40320b35873ea7f876bd29a1c9705bb5e08c59"}, + {file = "pyobjc_core-11.0.tar.gz", hash = "sha256:63bced211cb8a8fb5c8ff46473603da30e51112861bd02c438fbbbc8578d9a70"}, +] + +[[package]] +name = "pyobjc-framework-applicationservices" +version = "11.0" +requires_python = ">=3.9" +summary = "Wrappers for the framework ApplicationServices on macOS" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +dependencies = [ + "pyobjc-core>=11.0", + "pyobjc-framework-Cocoa>=11.0", + "pyobjc-framework-CoreText>=11.0", + "pyobjc-framework-Quartz>=11.0", +] +files = [ + {file = "pyobjc_framework_ApplicationServices-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:bc8f34b5b59ffd3c210ae883d794345c1197558ff3da0f5800669cf16435271e"}, + {file = "pyobjc_framework_ApplicationServices-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:61a99eef23abb704257310db4f5271137707e184768f6407030c01de4731b67b"}, + {file = "pyobjc_framework_ApplicationServices-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:5fbeb425897d6129471d451ec61a29ddd5b1386eb26b1dd49cb313e34616ee21"}, + {file = "pyobjc_framework_ApplicationServices-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:59becf3cd87a4f4cedf4be02ff6cf46ed736f5c1123ce629f788aaafad91eff0"}, + {file = "pyobjc_framework_ApplicationServices-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:44b466e8745fb49e8ac20f29f2ffd7895b45e97aa63a844b2a80a97c3a34346f"}, + {file = "pyobjc_framework_applicationservices-11.0.tar.gz", hash = "sha256:d6ea18dfc7d5626a3ecf4ac72d510405c0d3a648ca38cae8db841acdebecf4d2"}, +] + +[[package]] +name = "pyobjc-framework-cocoa" +version = "11.0" +requires_python = ">=3.9" +summary = "Wrappers for the Cocoa frameworks on macOS" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +dependencies = [ + "pyobjc-core>=11.0", +] +files = [ + {file = "pyobjc_framework_Cocoa-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:fbc65f260d617d5463c7fb9dbaaffc23c9a4fabfe3b1a50b039b61870b8daefd"}, + {file = "pyobjc_framework_Cocoa-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3ea7be6e6dd801b297440de02d312ba3fa7fd3c322db747ae1cb237e975f5d33"}, + {file = "pyobjc_framework_Cocoa-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:280a577b83c68175a28b2b7138d1d2d3111f2b2b66c30e86f81a19c2b02eae71"}, + {file = "pyobjc_framework_Cocoa-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:15b2bd977ed340074f930f1330f03d42912d5882b697d78bd06f8ebe263ef92e"}, + {file = "pyobjc_framework_Cocoa-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:5750001db544e67f2b66f02067d8f0da96bb2ef71732bde104f01b8628f9d7ea"}, + {file = "pyobjc_framework_cocoa-11.0.tar.gz", hash = "sha256:00346a8cb81ad7b017b32ff7bf596000f9faa905807b1bd234644ebd47f692c5"}, +] + +[[package]] +name = "pyobjc-framework-coretext" +version = "11.0" +requires_python = ">=3.9" +summary = "Wrappers for the framework CoreText on macOS" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +dependencies = [ + "pyobjc-core>=11.0", + "pyobjc-framework-Cocoa>=11.0", + "pyobjc-framework-Quartz>=11.0", +] +files = [ + {file = "pyobjc_framework_CoreText-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6939b4ea745b349b5c964823a2071f155f5defdc9b9fc3a13f036d859d7d0439"}, + {file = "pyobjc_framework_CoreText-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:56a4889858308b0d9f147d568b4d91c441cc0ffd332497cb4f709bb1990450c1"}, + {file = "pyobjc_framework_CoreText-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fb90e7f370b3fd7cb2fb442e3dc63fedf0b4af6908db1c18df694d10dc94669d"}, + {file = "pyobjc_framework_CoreText-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7947f755782456bd663e0b00c7905eeffd10f839f0bf2af031f68ded6a1ea360"}, + {file = "pyobjc_framework_CoreText-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:5356116bae33ec49f1f212c301378a7d08000440a2d6a7281aab351945528ab9"}, + {file = "pyobjc_framework_coretext-11.0.tar.gz", hash = "sha256:a68437153e627847e3898754dd3f13ae0cb852246b016a91f9c9cbccb9f91a43"}, +] + +[[package]] +name = "pyobjc-framework-quartz" +version = "11.0" +requires_python = ">=3.9" +summary = "Wrappers for the Quartz frameworks on macOS" +groups = ["pynput"] +marker = "sys_platform == \"darwin\"" +dependencies = [ + "pyobjc-core>=11.0", + "pyobjc-framework-Cocoa>=11.0", +] +files = [ + {file = "pyobjc_framework_Quartz-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:da3ab13c9f92361959b41b0ad4cdd41ae872f90a6d8c58a9ed699bc08ab1c45c"}, + {file = "pyobjc_framework_Quartz-11.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d251696bfd8e8ef72fbc90eb29fec95cb9d1cc409008a183d5cc3246130ae8c2"}, + {file = "pyobjc_framework_Quartz-11.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:cb4a9f2d9d580ea15e25e6b270f47681afb5689cafc9e25712445ce715bcd18e"}, + {file = "pyobjc_framework_Quartz-11.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:973b4f9b8ab844574461a038bd5269f425a7368d6e677e3cc81fcc9b27b65498"}, + {file = "pyobjc_framework_Quartz-11.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:66ab58d65348863b8707e63b2ec5cdc54569ee8189d1af90d52f29f5fdf6272c"}, + {file = "pyobjc_framework_quartz-11.0.tar.gz", hash = "sha256:3205bf7795fb9ae34747f701486b3db6dfac71924894d1f372977c4d70c3c619"}, +] + [[package]] name = "pyperclip" version = "1.9.0" @@ -1483,12 +1662,37 @@ name = "python-dotenv" version = "1.1.0" requires_python = ">=3.9" summary = "Read key-value pairs from a .env file and set them as environment variables" -groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d"}, {file = "python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5"}, ] +[[package]] +name = "python-multipart" +version = "0.0.20" +requires_python = ">=3.8" +summary = "A streaming multipart parser for Python" +groups = ["mcp"] +files = [ + {file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"}, + {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, +] + +[[package]] +name = "python-xlib" +version = "0.33" +summary = "Python X Library" +groups = ["pynput"] +marker = "\"linux\" in sys_platform" +dependencies = [ + "six>=1.10.0", +] +files = [ + {file = "python-xlib-0.33.tar.gz", hash = "sha256:55af7906a2c75ce6cb280a584776080602444f75815a7aff4d287bb2d7018b32"}, + {file = "python_xlib-0.33-py2.py3-none-any.whl", hash = "sha256:c3534038d42e0df2f1392a1b30a15a4ff5fdc2b86cfa94f072bf11b10a164398"}, +] + [[package]] name = "pytz" version = "2025.2" @@ -1583,7 +1787,7 @@ name = "rich" version = "14.0.0" requires_python = ">=3.8.0" summary = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "markdown-it-py>=2.2.0", "pygments<3.0.0,>=2.13.0", @@ -1747,12 +1951,23 @@ files = [ {file = "setuptools-78.1.0.tar.gz", hash = "sha256:18fd474d4a82a5f83dac888df697af65afa82dec7323d09c3e37d1f14288da54"}, ] +[[package]] +name = "shellingham" +version = "1.5.4" +requires_python = ">=3.7" +summary = "Tool to Detect Surrounding Shell" +groups = ["mcp"] +files = [ + {file = "shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686"}, + {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, +] + [[package]] name = "six" version = "1.17.0" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" summary = "Python 2 and 3 compatibility utilities" -groups = ["default", "chat"] +groups = ["default", "chat", "pynput"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -1774,12 +1989,42 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "sse-starlette" +version = "2.3.5" +requires_python = ">=3.9" +summary = "SSE plugin for Starlette" +groups = ["mcp"] +dependencies = [ + "anyio>=4.7.0", + "starlette>=0.41.3", +] +files = [ + {file = "sse_starlette-2.3.5-py3-none-any.whl", hash = "sha256:251708539a335570f10eaaa21d1848a10c42ee6dc3a9cf37ef42266cdb1c52a8"}, + {file = "sse_starlette-2.3.5.tar.gz", hash = "sha256:228357b6e42dcc73a427990e2b4a03c023e2495ecee82e14f07ba15077e334b2"}, +] + +[[package]] +name = "starlette" +version = "0.46.2" +requires_python = ">=3.9" +summary = "The little ASGI library that shines." +groups = ["mcp"] +dependencies = [ + "anyio<5,>=3.6.2", + "typing-extensions>=3.10.0; python_version < \"3.10\"", +] +files = [ + {file = "starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35"}, + {file = "starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5"}, +] + [[package]] name = "streamlit" version = "1.44.1" @@ -1909,6 +2154,23 @@ files = [ {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, ] +[[package]] +name = "typer" +version = "0.15.4" +requires_python = ">=3.7" +summary = "Typer, build great CLIs. Easy to code. Based on Python type hints." +groups = ["mcp"] +dependencies = [ + "click<8.2,>=8.0.0", + "rich>=10.11.0", + "shellingham>=1.3.0", + "typing-extensions>=3.7.4.3", +] +files = [ + {file = "typer-0.15.4-py3-none-any.whl", hash = "sha256:eb0651654dcdea706780c466cf06d8f174405a659ffff8f163cfbfee98c0e173"}, + {file = "typer-0.15.4.tar.gz", hash = "sha256:89507b104f9b6a0730354f27c39fae5b63ccd0c95b1ce1f1a6ba0cfd329997c3"}, +] + [[package]] name = "types-pillow" version = "10.2.0.20240822" @@ -1972,7 +2234,7 @@ name = "typing-extensions" version = "4.13.1" requires_python = ">=3.8" summary = "Backported and Experimental Type Hints for Python 3.8+" -groups = ["default", "chat", "test"] +groups = ["default", "chat", "mcp", "test"] files = [ {file = "typing_extensions-4.13.1-py3-none-any.whl", hash = "sha256:4b6cf02909eb5495cfbc3f6e8fd49217e6cc7944e145cdda8caa3734777f9e69"}, {file = "typing_extensions-4.13.1.tar.gz", hash = "sha256:98795af00fb9640edec5b8e31fc647597b4691f099ad75f469a2616be1a76dff"}, @@ -1983,7 +2245,7 @@ name = "typing-inspection" version = "0.4.0" requires_python = ">=3.9" summary = "Runtime typing introspection tools" -groups = ["default"] +groups = ["default", "mcp"] dependencies = [ "typing-extensions>=4.12.0", ] @@ -2014,6 +2276,23 @@ files = [ {file = "urllib3-2.3.0.tar.gz", hash = "sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d"}, ] +[[package]] +name = "uvicorn" +version = "0.34.2" +requires_python = ">=3.9" +summary = "The lightning-fast ASGI server." +groups = ["mcp"] +marker = "sys_platform != \"emscripten\"" +dependencies = [ + "click>=7.0", + "h11>=0.8", + "typing-extensions>=4.0; python_version < \"3.11\"", +] +files = [ + {file = "uvicorn-0.34.2-py3-none-any.whl", hash = "sha256:deb49af569084536d269fe0a6d67e3754f104cf03aba7c11c40f01aadf33c403"}, + {file = "uvicorn-0.34.2.tar.gz", hash = "sha256:0e929828f6186353a80b58ea719861d2629d766293b6d19baf086ba31d4f3328"}, +] + [[package]] name = "watchdog" version = "6.0.0" @@ -2054,7 +2333,7 @@ name = "websockets" version = "15.0.1" requires_python = ">=3.9" summary = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" -groups = ["default"] +groups = ["default", "mcp"] files = [ {file = "websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b"}, {file = "websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205"}, diff --git a/pyproject.toml b/pyproject.toml index f3031588..eaec5221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,8 +50,18 @@ lint = "ruff check src tests" typecheck = "mypy" "typecheck:all" = "mypy src tests" chat = "streamlit run src/askui/chat/__main__.py" +mcp = "mcp dev src/askui/mcp/__init__.py" [dependency-groups] +chat = [ + "streamlit>=1.42.0", +] +mcp = [ + "mcp[cli,rich,ws]>=1.8.1", +] +pynput = [ + "pynput>=1.8.1", +] test = [ "pytest>=8.3.4", "ruff>=0.9.5", @@ -67,9 +77,7 @@ test = [ "types-pyperclip>=1.8.2.20240311", "pytest-timeout>=2.4.0", ] -chat = [ - "streamlit>=1.42.0", -] + [tool.pytest.ini_options] addopts = "--cov=src/askui --cov-report=html" diff --git a/src/askui/mcp/__init__.py b/src/askui/mcp/__init__.py new file mode 100644 index 00000000..c23dd663 --- /dev/null +++ b/src/askui/mcp/__init__.py @@ -0,0 +1,22 @@ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass + +from mcp.server.fastmcp import FastMCP + +from askui.agent import VisionAgent + + +@dataclass +class AppContext: + vision_agent: VisionAgent + + +@asynccontextmanager +async def mcp_lifespan(server: FastMCP) -> AsyncIterator[AppContext]: # noqa: ARG001 + with VisionAgent(display=2) as vision_agent: + server.add_tool(vision_agent.click) + yield AppContext(vision_agent=vision_agent) + + +mcp = FastMCP("Vision Agent MCP App", lifespan=mcp_lifespan) From 67c4f47f90f63c83892697ecba46d687d2dae03c Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Fri, 16 May 2025 15:56:26 +0200 Subject: [PATCH 02/20] refactor(chat): extract threads and messages api --- src/askui/chat/__main__.py | 252 ++++++++++++++------------------- src/askui/chat/api/__init__.py | 0 src/askui/chat/api/messages.py | 234 ++++++++++++++++++++++++++++++ src/askui/chat/api/threads.py | 115 +++++++++++++++ src/askui/chat/api/utils.py | 22 +++ 5 files changed, 476 insertions(+), 147 deletions(-) create mode 100644 src/askui/chat/api/__init__.py create mode 100644 src/askui/chat/api/messages.py create mode 100644 src/askui/chat/api/threads.py create mode 100644 src/askui/chat/api/utils.py diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 05b39ae8..03de1bf5 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -3,71 +3,43 @@ import re from datetime import datetime, timezone from pathlib import Path -from random import randint -from typing import Union +from typing import Union, cast import streamlit as st from PIL import Image, ImageDraw -from typing_extensions import TypedDict, override +from typing_extensions import override from askui import VisionAgent +from askui.chat.api.messages import MessageRole, MessagesApi +from askui.chat.api.threads import ThreadsApi from askui.chat.click_recorder import ClickRecorder from askui.chat.exceptions import FunctionExecutionError, InvalidFunctionError from askui.models import ModelName from askui.reporting import Reporter from askui.utils.image_utils import base64_to_image, draw_point_on_image +# TODO Start backend server + st.set_page_config( page_title="Vision Agent Chat", page_icon="💬", ) +BASE_DIR = Path("./chat") +threads_api = ThreadsApi(BASE_DIR) +messages_api = MessagesApi(BASE_DIR) -CHAT_SESSIONS_DIR_PATH = Path("./chat/sessions") -CHAT_IMAGES_DIR_PATH = Path("./chat/images") - -click_recorder = ClickRecorder() - - -def setup_chat_dirs() -> None: - Path.mkdir(CHAT_SESSIONS_DIR_PATH, parents=True, exist_ok=True) - Path.mkdir(CHAT_IMAGES_DIR_PATH, parents=True, exist_ok=True) - - -def get_session_id_from_path(path: str) -> str: - """Get session ID from file path.""" - return Path(path).stem - - -def load_chat_history(session_id: str) -> list[dict]: - """Load chat history for a given session ID.""" - messages: list[dict] = [] - session_path = CHAT_SESSIONS_DIR_PATH / f"{session_id}.jsonl" - if session_path.exists(): - with session_path.open("r") as f: - messages.extend(json.loads(line) for line in f) - return messages - - -ROLE_MAP = { - "user": "user", - "anthropic computer use": "ai", - "agentos": "assistant", - "user (demonstration)": "user", -} - +click_recorder = ClickRecorder() # TODO Tool, pynput alternatively -UNKNOWN_ROLE = "unknown" - -def get_image(img_b64_str_or_path: str) -> Image.Image: +def get_image(img_b64_str_or_path: str) -> Image.Image: # TODO Image utils """Get image from base64 string or file path.""" if Path(img_b64_str_or_path).is_file(): return Image.open(img_b64_str_or_path) return base64_to_image(img_b64_str_or_path) -def write_message( +def write_message( # TODO updating frontend role: str, content: str | dict | list, timestamp: str, @@ -77,44 +49,42 @@ def write_message( | list[str] | list[Image.Image] | None = None, + message_id: str | None = None, ) -> None: - _role = ROLE_MAP.get(role.lower(), UNKNOWN_ROLE) - avatar = None if _role != UNKNOWN_ROLE else "❔" - with st.chat_message(_role, avatar=avatar): - st.markdown(f"*{timestamp}* - **{role}**\n\n") - st.markdown( - json.dumps(content, indent=2) - if isinstance(content, (dict, list)) - else content - ) - if image: - if isinstance(image, list): - for img in image: - img = get_image(img) if isinstance(img, str) else img + _role = messages_api.ROLE_MAP.get(role.lower(), MessageRole.UNKNOWN) + avatar = None if _role != MessageRole.UNKNOWN else "❔" + + # Create a container for the message and delete button + col1, col2 = st.columns([0.95, 0.05]) + + with col1: + with st.chat_message(_role.value, avatar=avatar): + st.markdown(f"*{timestamp}* - **{role}**\n\n") + st.markdown( + json.dumps(content, indent=2) + if isinstance(content, (dict, list)) + else content + ) + if image: + if isinstance(image, list): + for img in image: + img = get_image(img) if isinstance(img, str) else img + st.image(img) + else: + img = get_image(image) if isinstance(image, str) else image st.image(img) - else: - img = get_image(image) if isinstance(image, str) else image - st.image(img) - -def save_image(image: Image.Image) -> str: - """Save image to disk and return path.""" - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S_%f") - image_path = CHAT_IMAGES_DIR_PATH / f"image_{timestamp}.png" - image.save(image_path) - return str(image_path) - - -class Message(TypedDict): - role: str - content: str | dict | list - timestamp: str - image: str | list[str] | None + # Add delete button in the second column if message_id is provided + if message_id: + with col2: + if st.button("🗑️", key=f"delete_{message_id}"): + messages_api.delete(st.session_state.thread_id, message_id) + st.rerun() class ChatHistoryAppender(Reporter): - def __init__(self, session_id: str) -> None: - self._session_id = session_id + def __init__(self, thread_id: str) -> None: + self._thread_id = thread_id @override def add_message( @@ -123,45 +93,22 @@ def add_message( content: Union[str, dict, list], image: Image.Image | list[Image.Image] | None = None, ) -> None: - image_paths: list[str] = [] - if image is None: - _images = [] - elif isinstance(image, list): - _images = image - else: - _images = [image] - image_paths.extend(save_image(img) for img in _images) - message = Message( - role=role, - content=content, - timestamp=datetime.now(tz=timezone.utc).isoformat(), - image=image_paths, + message = messages_api.create( + thread_id=self._thread_id, role=role, content=content, image=image + ) + write_message( + role=message.role.value, + content=message.content[0].text or "", + timestamp=message.created_at.isoformat(), + image=message.content[0].image_paths, + message_id=message.id, ) - write_message(**message) - with (CHAT_SESSIONS_DIR_PATH / f"{self._session_id}.jsonl").open("a") as f: - json.dump(message, f) - f.write("\n") @override def generate(self) -> None: pass -def get_available_sessions() -> list[str]: - """Get list of available session IDs.""" - session_files = list(CHAT_SESSIONS_DIR_PATH.glob("*.jsonl")) - return sorted([get_session_id_from_path(f) for f in session_files], reverse=True) - - -def create_new_session() -> str: - """Create a new chat session.""" - timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M%S%f") - random_suffix = f"{randint(100, 999)}" - session_id = f"{timestamp}{random_suffix}" - (CHAT_SESSIONS_DIR_PATH / f"{session_id}.jsonl").touch() - return session_id - - def paint_crosshair( image: Image.Image, coordinates: tuple[int, int], @@ -207,18 +154,23 @@ def rerun() -> None: log_level=logging.DEBUG, ) as agent: screenshot: Image.Image | None = None - for message in st.session_state.messages: + for message in messages_api.list_(st.session_state.thread_id).data: try: if ( - message.get("role") == "AgentOS" - or message.get("role") == "User (Demonstration)" + message.role == MessageRole.ASSISTANT + or message.role == MessageRole.USER ): - if message.get("content") == "screenshot()": - screenshot = get_image(message["image"]) + content = message.content[0] + if content.text == "screenshot()": + screenshot = ( + get_image(content.image_paths[0]) + if content.image_paths + else None + ) continue - if message.get("content"): + if content.text: if match := re.match( - r"mouse\((\d+),\s*(\d+)\)", message["content"] + r"mouse\((\d+),\s*(\d+)\)", cast("str", content.text) ): if not screenshot: error_msg = "Screenshot is required to paint crosshair" @@ -232,10 +184,10 @@ def rerun() -> None: image=screenshot_with_crosshair, model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ) - write_message( - message["role"], - f"Move mouse to {element_description}", - datetime.now(tz=timezone.utc).isoformat(), + messages_api.create( + thread_id=st.session_state.thread_id, + role=message.role.value, + content=f"Move mouse to {element_description}", image=screenshot_with_crosshair, ) agent.mouse_move( @@ -243,58 +195,64 @@ def rerun() -> None: model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ) else: - write_message( - message["role"], - message["content"], - datetime.now(tz=timezone.utc).isoformat(), - message.get("image"), + messages_api.create( + thread_id=st.session_state.thread_id, + role=message.role.value, + content=content.text, + image=None, ) - func_call = f"agent.tools.os.{message['content']}" + func_call = f"agent.tools.os.{content.text}" eval(func_call) except json.JSONDecodeError: continue except AttributeError: - st.write(str(InvalidFunctionError(message["content"]))) + st.write(str(InvalidFunctionError(cast("str", content.text)))) except Exception as e: # noqa: BLE001 - We want to catch all other exceptions here - st.write(str(FunctionExecutionError(message["content"], e))) - + st.write(str(FunctionExecutionError(cast("str", content.text), e))) -setup_chat_dirs() if st.sidebar.button("New Chat"): - st.session_state.session_id = create_new_session() + thread = threads_api.create() + st.session_state.thread_id = thread.id st.rerun() -available_sessions = get_available_sessions() -session_id = st.session_state.get("session_id", None) +available_threads = threads_api.list_().data +thread_id = st.session_state.get("thread_id", None) -if not session_id and not available_sessions: - session_id = create_new_session() - st.session_state.session_id = session_id +if not thread_id and not available_threads: + thread = threads_api.create() + thread_id = thread.id + st.session_state.thread_id = thread_id st.rerun() -index_of_session = available_sessions.index(session_id) if session_id else 0 -session_id = st.sidebar.radio( - "Sessions", - available_sessions, - index=index_of_session, +index_of_thread = 0 +if thread_id: + for index, thread in enumerate(available_threads): + if thread.id == thread_id: + index_of_thread = index + break + +thread_id = st.sidebar.radio( + "Threads", + [t.id for t in available_threads], + index=index_of_thread, ) -if session_id != st.session_state.get("session_id"): - st.session_state.session_id = session_id +if thread_id != st.session_state.get("thread_id"): + st.session_state.thread_id = thread_id st.rerun() -reporter = ChatHistoryAppender(session_id) +reporter = ChatHistoryAppender(thread_id) -st.title(f"Vision Agent Chat - {session_id}") -st.session_state.messages = load_chat_history(session_id) +st.title(f"Vision Agent Chat - {thread_id}") # Display chat history -for message in st.session_state.messages: +for message in messages_api.list_(thread_id).data: write_message( - message["role"], - message["content"], - message["timestamp"], - message.get("image"), + message.role.value, + message.content[0].text or "", + message.created_at.isoformat(), + message.content[0].image_paths, + message.id, # Pass the message ID to enable deletion ) if value_to_type := st.chat_input("Simulate Typing for User (Demonstration)"): @@ -333,7 +291,7 @@ def rerun() -> None: log_level=logging.DEBUG, reporters=[reporter], ) as agent: - agent.act(act_prompt, model="claude") + agent.act(act_prompt, model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022) st.rerun() if st.button("Rerun"): diff --git a/src/askui/chat/api/__init__.py b/src/askui/chat/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/messages.py b/src/askui/chat/api/messages.py new file mode 100644 index 00000000..0ee4af38 --- /dev/null +++ b/src/askui/chat/api/messages.py @@ -0,0 +1,234 @@ +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Sequence, Union + +from PIL import Image +from pydantic import AwareDatetime, BaseModel, Field + +from askui.chat.api.utils import generate_time_ordered_id + + +class MessageRole(str, Enum): + """Valid message roles.""" + + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + AI = "ai" + UNKNOWN = "unknown" + + +class MessageContent(BaseModel): + """Message content with optional image paths.""" + + text: str | None = None + image_paths: list[str] | None = None + + +class Message(BaseModel): + """A message in a thread.""" + + id: str = Field(default_factory=lambda: generate_time_ordered_id("msg")) + thread_id: str + role: MessageRole + content: Sequence[MessageContent] + created_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + object: str = "message" + + +class MessageListResponse(BaseModel): + """Response model for listing messages.""" + + object: str = "list" + data: Sequence[Message] + first_id: str | None = None + last_id: str | None = None + has_more: bool = False + + +class MessagesApi: + """API for managing messages within threads.""" + + ROLE_MAP = { + "user": MessageRole.USER, + "anthropic computer use": MessageRole.AI, + "agentos": MessageRole.ASSISTANT, + "user (demonstration)": MessageRole.USER, + } + + def __init__(self, base_dir: Path) -> None: + """Initialize messages API. + + Args: + base_dir: Base directory to store message data + """ + self._base_dir = base_dir + self._threads_dir = base_dir / "threads" + self._images_dir = base_dir / "images" + + def list_(self, thread_id: str, limit: int | None = None) -> MessageListResponse: + """List all messages in a thread. + + Args: + thread_id: ID of thread to list messages from + limit: Optional maximum number of messages to return + + Returns: + MessageListResponse containing messages sorted by creation date + + Raises: + FileNotFoundError: If thread doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + messages = [] + with thread_file.open("r") as f: + for line in f: + msg = Message.model_validate_json(line) + messages.append(msg) + + # Sort by creation date + messages = sorted(messages, key=lambda m: m.created_at) + + # Apply limit if specified + if limit is not None: + messages = messages[:limit] + + return MessageListResponse( + data=messages, + first_id=messages[0].id if messages else None, + last_id=messages[-1].id if messages else None, + has_more=len(messages) > (limit or len(messages)), + ) + + def create( + self, + thread_id: str, + role: str, + content: Union[str, dict, list], + image: Image.Image | list[Image.Image] | None = None, + ) -> Message: + """Create a new message in a thread. + + Args: + thread_id: ID of thread to create message in + role: Role of message sender + content: Message content + image: Optional image(s) to attach + + Returns: + Created message object + + Raises: + FileNotFoundError: If thread doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + # Save images if provided + image_paths = [] + if image is not None: + if isinstance(image, list): + images = image + else: + images = [image] + + self._images_dir.mkdir(parents=True, exist_ok=True) + for img in images: + # Generate unique image ID using same format as thread/message IDs + image_id = generate_time_ordered_id("img") + image_path = self._images_dir / f"{image_id}.png" + img.save(image_path) + image_paths.append(str(image_path)) + + # Create message content + message_content = [ + MessageContent( + text=str(content), image_paths=image_paths if image_paths else None + ) + ] + + # Create message + message = Message( + thread_id=thread_id, + role=self.ROLE_MAP.get(role.lower(), MessageRole.UNKNOWN), + content=message_content, + ) + + # Save message + with thread_file.open("a") as f: + f.write(message.model_dump_json()) + f.write("\n") + + return message + + def retrieve(self, thread_id: str, message_id: str) -> Message: + """Retrieve a specific message from a thread. + + Args: + thread_id: ID of thread containing message + message_id: ID of message to retrieve + + Returns: + Message object + + Raises: + FileNotFoundError: If thread or message doesn't exist + """ + messages = self.list_(thread_id).data + for msg in messages: + if msg.id == message_id: + return msg + error_msg = f"Message {message_id} not found in thread {thread_id}" + raise FileNotFoundError(error_msg) + + def delete(self, thread_id: str, message_id: str) -> None: + """Delete a message from a thread. + + Args: + thread_id: ID of thread containing message + message_id: ID of message to delete + + Raises: + FileNotFoundError: If thread or message doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + # Get message and image paths before deletion + msg_to_delete = self.retrieve(thread_id, message_id) + image_paths = ( + msg_to_delete.content[0].image_paths if msg_to_delete.content else None + ) + + # Read all messages + messages = [] + with thread_file.open("r") as f: + for line in f: + msg = Message.model_validate_json(line) + if msg.id != message_id: + messages.append(msg) + + # Write back all messages except the deleted one + with thread_file.open("w") as f: + for msg in messages: + f.write(msg.model_dump_json()) + f.write("\n") + + # Delete associated images if any + if image_paths: + for img_path in image_paths: + try: + Path(img_path).unlink() + except FileNotFoundError: + pass # Image might have been deleted already diff --git a/src/askui/chat/api/threads.py b/src/askui/chat/api/threads.py new file mode 100644 index 00000000..6540a46e --- /dev/null +++ b/src/askui/chat/api/threads.py @@ -0,0 +1,115 @@ +from datetime import datetime, timezone +from pathlib import Path +from typing import Sequence + +from pydantic import AwareDatetime, BaseModel, Field + +from askui.chat.api.utils import generate_time_ordered_id + + +class Thread(BaseModel): + """A chat thread/session.""" + + id: str = Field(default_factory=lambda: generate_time_ordered_id("thread")) + created_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + object: str = "thread" + + +class ThreadListResponse(BaseModel): + """Response model for listing threads.""" + + object: str = "list" + data: Sequence[Thread] + first_id: str | None = None + last_id: str | None = None + has_more: bool = False + + +class ThreadsApi: + """API for managing chat threads/sessions.""" + + def __init__(self, base_dir: Path) -> None: + """Initialize threads API. + + Args: + base_dir: Base directory to store thread data + """ + self._base_dir = base_dir + self._threads_dir = base_dir / "threads" + + def list_(self, limit: int | None = None) -> ThreadListResponse: + """List all available threads. + + Args: + limit: Optional maximum number of threads to return + + Returns: + ThreadListResponse containing threads sorted by creation date (newest first) + """ + if not self._threads_dir.exists(): + return ThreadListResponse(data=[]) + + thread_files = list(self._threads_dir.glob("*.jsonl")) + threads = [] + for f in thread_files: + thread_id = f.stem + created_at = datetime.fromtimestamp(f.stat().st_ctime, tz=timezone.utc) + threads.append( + Thread( + id=thread_id, + created_at=created_at, + ) + ) + + # Sort by creation date, newest first + threads = sorted(threads, key=lambda t: t.created_at, reverse=True) + + # Apply limit if specified + if limit is not None: + threads = threads[:limit] + + return ThreadListResponse( + data=threads, + first_id=threads[0].id if threads else None, + last_id=threads[-1].id if threads else None, + has_more=len(thread_files) > (limit or len(thread_files)), + ) + + def create(self) -> Thread: + """Create a new thread. + + Returns: + Created thread object + """ + thread = Thread() + thread_file = self._threads_dir / f"{thread.id}.jsonl" + self._threads_dir.mkdir(parents=True, exist_ok=True) + thread_file.touch() + return thread + + def retrieve(self, thread_id: str) -> Thread: + """Retrieve a thread by ID. + + Args: + thread_id: ID of thread to retrieve + + Returns: + Thread object + + Raises: + FileNotFoundError: If thread doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + created_at = datetime.fromtimestamp( + thread_file.stat().st_ctime, tz=timezone.utc + ) + return Thread( + id=thread_id, + created_at=created_at, + ) diff --git a/src/askui/chat/api/utils.py b/src/askui/chat/api/utils.py new file mode 100644 index 00000000..590727c2 --- /dev/null +++ b/src/askui/chat/api/utils.py @@ -0,0 +1,22 @@ +import base64 +import os +import time + + +def generate_time_ordered_id(prefix: str) -> str: + """Generate a time-ordered ID with format: prefix_timestamp_random. + + Args: + prefix: Prefix for the ID (e.g. 'thread', 'msg') + + Returns: + Time-ordered ID string + """ + # Get current timestamp in milliseconds + timestamp = int(time.time() * 1000) + # Convert to base32 for shorter string (removing padding) + timestamp_b32 = base64.b32encode(str(timestamp).encode()).decode().rstrip("=").lower() + # Get 12 random bytes and convert to base32 (removing padding) + random_bytes = os.urandom(12) + random_b32 = base64.b32encode(random_bytes).decode().rstrip("=").lower() + return f"{prefix}_{timestamp_b32}{random_b32}" From 2907198b0c18cba309b7c493c9521a6f11814578 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Fri, 16 May 2025 16:18:56 +0200 Subject: [PATCH 03/20] feat(chat): add thread deletion --- src/askui/chat/__main__.py | 31 +++++++++++++++++++++++++------ src/askui/chat/api/threads.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 6 deletions(-) diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 03de1bf5..9f5c0ecc 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -1,7 +1,6 @@ import json import logging import re -from datetime import datetime, timezone from pathlib import Path from typing import Union, cast @@ -232,11 +231,31 @@ def rerun() -> None: index_of_thread = index break -thread_id = st.sidebar.radio( - "Threads", - [t.id for t in available_threads], - index=index_of_thread, -) +# Create columns for thread selection and delete buttons +thread_cols = st.sidebar.columns([0.8, 0.2]) +with thread_cols[0]: + thread_id = st.radio( + "Threads", + [t.id for t in available_threads], + index=index_of_thread, + ) + +# Add delete buttons for each thread +for t in available_threads: + with thread_cols[1]: + if st.button("🗑️", key=f"delete_thread_{t.id}"): + if t.id == thread_id: + # If deleting current thread, switch to first available thread + remaining_threads = [th for th in available_threads if th.id != t.id] + if remaining_threads: + st.session_state.thread_id = remaining_threads[0].id + else: + # Create new thread if no threads left + new_thread = threads_api.create() + st.session_state.thread_id = new_thread.id + threads_api.delete(t.id) + st.rerun() + if thread_id != st.session_state.get("thread_id"): st.session_state.thread_id = thread_id st.rerun() diff --git a/src/askui/chat/api/threads.py b/src/askui/chat/api/threads.py index 6540a46e..24aa1663 100644 --- a/src/askui/chat/api/threads.py +++ b/src/askui/chat/api/threads.py @@ -113,3 +113,36 @@ def retrieve(self, thread_id: str) -> Thread: id=thread_id, created_at=created_at, ) + + def delete(self, thread_id: str) -> None: + """Delete a thread and all its associated files. + + Args: + thread_id: ID of thread to delete + + Raises: + FileNotFoundError: If thread doesn't exist + """ + thread_file = self._threads_dir / f"{thread_id}.jsonl" + if not thread_file.exists(): + error_msg = f"Thread {thread_id} not found" + raise FileNotFoundError(error_msg) + + # Get all image paths from messages before deleting thread + from askui.chat.api.messages import MessagesApi + + messages_api = MessagesApi(self._base_dir) + try: + messages = messages_api.list_(thread_id).data + for msg in messages: + if msg.content and msg.content[0].image_paths: + for img_path in msg.content[0].image_paths: + try: + Path(img_path).unlink() + except FileNotFoundError: + pass # Image might have been deleted already + except FileNotFoundError: + pass # Thread might have been deleted already + + # Delete thread file + thread_file.unlink() From 15c948c10c436a0e8d6f6faf14a7df668a8bc8fe Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Mon, 19 May 2025 17:30:11 +0200 Subject: [PATCH 04/20] feat(tools)!: add `PynputAgentOs` also: - add different scripts for different test runs (with vs. without coverage) - remove general coverage reporting as it prevents break points in debugger - fix `askui.tools.anthropic.computer.ComputerTool`, e.g., hard coding of height and width of screen - rm nonworking mintlify VSCodeextensions from recommendations BREAKING CHANGE: - rename `AgentOs.set_display`'s `displayNumber` parameter to `display` (snake_case, consistent with `AskUiControllerClient` constructor parameter) - rename `AskUiControllerClient.set_display`'s `displayNumber` parameter to `display` (snake_case, consistent with `AskUiControllerClient` constructor parameter) - rename `AgentOs.mouse` to `AgentOs.mouse_move` (including implementations `PynputAgentOs` and `AskUiControllerClient`) --- .vscode/extensions.json | 1 - README.md | 2 +- pdm.lock | 44 ++- pyproject.toml | 14 +- src/askui/chat/__main__.py | 16 +- src/askui/models/ui_tars_ep/ui_tars_api.py | 2 +- src/askui/reporting.py | 4 +- src/askui/tools/agent_os.py | 6 +- src/askui/tools/anthropic/computer.py | 60 ++-- src/askui/tools/askui/askui_controller.py | 14 +- src/askui/tools/pynput/__init__.py | 0 src/askui/tools/pynput/pynput_agent_os.py | 328 ++++++++++++++++++ tests/e2e/tools/__init__.py | 0 tests/e2e/tools/pynput/__init__.py | 1 + .../tools/askui/test_askui_controller.py | 2 +- 15 files changed, 427 insertions(+), 67 deletions(-) create mode 100644 src/askui/tools/pynput/__init__.py create mode 100644 src/askui/tools/pynput/pynput_agent_os.py create mode 100644 tests/e2e/tools/__init__.py create mode 100644 tests/e2e/tools/pynput/__init__.py diff --git a/.vscode/extensions.json b/.vscode/extensions.json index cbe1788d..711148cb 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -15,7 +15,6 @@ "tamasfe.even-better-toml", "visualstudioexptteam.vscodeintellicode", "dongli.python-preview", - "mintlify.document", "kaih2o.python-resource-monitor", "littlefoxteam.vscode-python-test-adapter", "almenon.arepl" diff --git a/README.md b/README.md index b7b61140..7089109a 100644 --- a/README.md +++ b/README.md @@ -410,7 +410,7 @@ The controller for the operating system. ```python agent.tools.os.click("left", 2) # clicking -agent.tools.os.mouse(100, 100) # mouse movement +agent.tools.os.mouse_move(100, 100) # mouse movement agent.tools.os.keyboard_tap("v", modifier_keys=["control"]) # Paste # and many more ``` diff --git a/pdm.lock b/pdm.lock index fd4c5ca9..5bc26590 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "chat", "mcp", "pynput", "test"] +groups = ["default", "chat", "mcp", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:0afb61ae5075adaa78afb4d2995e9714fad46da7eeafc8f4f61ae16b9ec88b18" +content_hash = "sha256:5d81bb8ca8a8b98d18f4632c5fc1dd409c44d385724e4f321922ae71ec152f33" [[metadata.targets]] requires_python = ">=3.10" @@ -369,7 +369,7 @@ name = "evdev" version = "1.9.2" requires_python = ">=3.8" summary = "Bindings to the Linux input handling subsystem" -groups = ["pynput"] +groups = ["default"] marker = "\"linux\" in sys_platform" files = [ {file = "evdev-1.9.2.tar.gz", hash = "sha256:5d3278892ce1f92a74d6bf888cc8525d9f68af85dbe336c95d1c87fb8f423069"}, @@ -920,6 +920,17 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mss" +version = "10.0.0" +requires_python = ">=3.9" +summary = "An ultra fast cross-platform multiple screenshots module in pure python using ctypes." +groups = ["default"] +files = [ + {file = "mss-10.0.0-py3-none-any.whl", hash = "sha256:82cf6460a53d09e79b7b6d871163c982e6c7e9649c426e7b7591b74956d5cb64"}, + {file = "mss-10.0.0.tar.gz", hash = "sha256:d903e0d51262bf0f8782841cf16eaa6d7e3e1f12eae35ab41c2e318837c6637f"}, +] + [[package]] name = "mypy" version = "1.15.0" @@ -1445,7 +1456,7 @@ files = [ name = "pynput" version = "1.8.1" summary = "Monitor and control user input devices" -groups = ["pynput"] +groups = ["default"] dependencies = [ "enum34; python_version == \"2.7\"", "evdev>=1.3; \"linux\" in sys_platform", @@ -1464,7 +1475,7 @@ name = "pyobjc-core" version = "11.0" requires_python = ">=3.8" summary = "Python<->ObjC Interoperability Module" -groups = ["pynput"] +groups = ["default"] marker = "sys_platform == \"darwin\"" files = [ {file = "pyobjc_core-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:10866b3a734d47caf48e456eea0d4815c2c9b21856157db5917b61dee06893a1"}, @@ -1480,7 +1491,7 @@ name = "pyobjc-framework-applicationservices" version = "11.0" requires_python = ">=3.9" summary = "Wrappers for the framework ApplicationServices on macOS" -groups = ["pynput"] +groups = ["default"] marker = "sys_platform == \"darwin\"" dependencies = [ "pyobjc-core>=11.0", @@ -1502,7 +1513,7 @@ name = "pyobjc-framework-cocoa" version = "11.0" requires_python = ">=3.9" summary = "Wrappers for the Cocoa frameworks on macOS" -groups = ["pynput"] +groups = ["default"] marker = "sys_platform == \"darwin\"" dependencies = [ "pyobjc-core>=11.0", @@ -1521,7 +1532,7 @@ name = "pyobjc-framework-coretext" version = "11.0" requires_python = ">=3.9" summary = "Wrappers for the framework CoreText on macOS" -groups = ["pynput"] +groups = ["default"] marker = "sys_platform == \"darwin\"" dependencies = [ "pyobjc-core>=11.0", @@ -1542,7 +1553,7 @@ name = "pyobjc-framework-quartz" version = "11.0" requires_python = ">=3.9" summary = "Wrappers for the Quartz frameworks on macOS" -groups = ["pynput"] +groups = ["default"] marker = "sys_platform == \"darwin\"" dependencies = [ "pyobjc-core>=11.0", @@ -1683,7 +1694,7 @@ files = [ name = "python-xlib" version = "0.33" summary = "Python X Library" -groups = ["pynput"] +groups = ["default"] marker = "\"linux\" in sys_platform" dependencies = [ "six>=1.10.0", @@ -1967,7 +1978,7 @@ name = "six" version = "1.17.0" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" summary = "Python 2 and 3 compatibility utilities" -groups = ["default", "chat", "pynput"] +groups = ["default", "chat"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -2193,6 +2204,17 @@ files = [ {file = "types_protobuf-5.29.1.20250403.tar.gz", hash = "sha256:7ff44f15022119c9d7558ce16e78b2d485bf7040b4fadced4dd069bb5faf77a2"}, ] +[[package]] +name = "types-pynput" +version = "1.8.1.20250318" +requires_python = ">=3.9" +summary = "Typing stubs for pynput" +groups = ["test"] +files = [ + {file = "types_pynput-1.8.1.20250318-py3-none-any.whl", hash = "sha256:0c1038aa1550941633114a2728ad85e392f67dfba970aebf755e369ab57aca70"}, + {file = "types_pynput-1.8.1.20250318.tar.gz", hash = "sha256:13d4df97843a7d1e7cddccbf9987aca7f0d463b214a8a35b4f53275d2c5a3576"}, +] + [[package]] name = "types-pyperclip" version = "1.9.0.20250218" diff --git a/pyproject.toml b/pyproject.toml index eaec5221..bb60740e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ dependencies = [ "segment-analytics-python>=2.3.4", "py-machineid>=0.7.0", "httpx>=0.28.1", + "pynput>=1.8.1", + "mss>=10.0.0", ] requires-python = ">=3.10" readme = "README.md" @@ -35,15 +37,20 @@ build-backend = "hatchling.build" [tool.hatch.version] path = "src/askui/__init__.py" + [tool.pdm] distribution = true [tool.pdm.scripts] test = "pytest -n auto" +"test:cov" = "pytest -n auto --cov=src/askui --cov-report=html" +"test:cov:view" = "python -m http.server --directory htmlcov" "test:e2e" = "pytest -n auto tests/e2e" +"test:e2e:cov" = "pytest -n auto tests/e2e --cov=src/askui --cov-report=html" "test:integration" = "pytest -n auto tests/integration" +"test:integration:cov" = "pytest -n auto tests/integration --cov=src/askui --cov-report=html" "test:unit" = "pytest -n auto tests/unit" -"test:cov:view" = "python -m http.server --directory htmlcov" +"test:unit:cov" = "pytest -n auto tests/unit --cov=src/askui --cov-report=html" format = "ruff format src tests" lint = "ruff check src tests" "lint:fix" = "ruff check --fix src tests" @@ -59,9 +66,6 @@ chat = [ mcp = [ "mcp[cli,rich,ws]>=1.8.1", ] -pynput = [ - "pynput>=1.8.1", -] test = [ "pytest>=8.3.4", "ruff>=0.9.5", @@ -76,11 +80,11 @@ test = [ "grpc-stubs>=1.53.0.3", "types-pyperclip>=1.8.2.20240311", "pytest-timeout>=2.4.0", + "types-pynput>=1.8.1.20250318", ] [tool.pytest.ini_options] -addopts = "--cov=src/askui --cov-report=html" python_classes = ["Test*"] python_files = ["test_*.py"] python_functions = ["test_*"] diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 9f5c0ecc..e072b35a 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -15,6 +15,8 @@ from askui.chat.exceptions import FunctionExecutionError, InvalidFunctionError from askui.models import ModelName from askui.reporting import Reporter +from askui.tools.pynput.pynput_agent_os import PynputAgentOs +from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import base64_to_image, draw_point_on_image # TODO Start backend server @@ -24,11 +26,12 @@ page_icon="💬", ) + +# TODO Tool, pynput alternatively BASE_DIR = Path("./chat") threads_api = ThreadsApi(BASE_DIR) messages_api = MessagesApi(BASE_DIR) - -click_recorder = ClickRecorder() # TODO Tool, pynput alternatively +click_recorder = ClickRecorder() def get_image(img_b64_str_or_path: str) -> Image.Image: # TODO Image utils @@ -151,6 +154,7 @@ def rerun() -> None: st.markdown("### Re-running...") with VisionAgent( log_level=logging.DEBUG, + tools=tools, ) as agent: screenshot: Image.Image | None = None for message in messages_api.list_(st.session_state.thread_id).data: @@ -262,6 +266,9 @@ def rerun() -> None: reporter = ChatHistoryAppender(thread_id) +tools = AgentToolbox(agent_os=PynputAgentOs(reporter=reporter)) + + st.title(f"Vision Agent Chat - {thread_id}") # Display chat history @@ -300,15 +307,16 @@ def rerun() -> None: ) reporter.add_message( role="User (Demonstration)", - content=f"mouse({coordinates[0]}, {coordinates[1]})", + content=f"mouse_move({coordinates[0]}, {coordinates[1]})", image=draw_point_on_image(image, coordinates[0], coordinates[1]), ) st.rerun() if act_prompt := st.chat_input("Ask AI"): - with VisionAgent( + with VisionAgent( # we need the vision agent log_level=logging.DEBUG, reporters=[reporter], + tools=tools, ) as agent: agent.act(act_prompt, model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022) st.rerun() diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index 04312a60..ffe60186 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -301,7 +301,7 @@ def execute_act(self, message_history: list[dict[str, Any]]) -> None: action = message.parsed_action if action.action_type == "click": - self._agent_os.mouse(action.start_box.x, action.start_box.y) + self._agent_os.mouse_move(action.start_box.x, action.start_box.y) self._agent_os.click("left") time.sleep(1) if action.action_type == "type": diff --git a/src/askui/reporting.py b/src/askui/reporting.py index 12360407..df30542d 100644 --- a/src/askui/reporting.py +++ b/src/askui/reporting.py @@ -304,8 +304,8 @@ def generate(self) -> None: {% for image in msg.images %}
Message image + class="message-image" + alt="Message image"> {% endfor %} diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index cb719899..5623a4c0 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -175,7 +175,7 @@ def screenshot(self, report: bool = True) -> Image.Image: raise NotImplementedError @abstractmethod - def mouse(self, x: int, y: int) -> None: + def mouse_move(self, x: int, y: int) -> None: """ Moves the mouse cursor to specified screen coordinates. @@ -293,12 +293,12 @@ def keyboard_tap( raise NotImplementedError @abstractmethod - def set_display(self, displayNumber: int = 1) -> None: + def set_display(self, display: int = 1) -> None: """ Sets the active display for screen interactions. Args: - displayNumber (int, optional): The display number to set as active. + display (int, optional): The display ID to set as active. Defaults to `1`. """ raise NotImplementedError diff --git a/src/askui/tools/anthropic/computer.py b/src/askui/tools/anthropic/computer.py index 33991b78..1dbc12fe 100644 --- a/src/askui/tools/anthropic/computer.py +++ b/src/askui/tools/anthropic/computer.py @@ -208,8 +208,6 @@ class ComputerTool(BaseAnthropicTool): name: Literal["computer"] = "computer" api_type: Literal["computer_20241022"] = "computer_20241022" - width: int - height: int _screenshot_delay = 2.0 _scaling_enabled = True @@ -217,22 +215,20 @@ class ComputerTool(BaseAnthropicTool): @property def options(self) -> ComputerToolOptions: return { - "display_width_px": self.width, - "display_height_px": self.height, + "display_width_px": self._width, + "display_height_px": self._height, } def to_params(self) -> BetaToolComputerUse20241022Param: return {"name": self.name, "type": self.api_type, **self.options} - def __init__(self, controller_client: AgentOs) -> None: + def __init__(self, agent_os: AgentOs) -> None: super().__init__() - self.controller_client = controller_client - - self.width = 1280 - self.height = 800 - - self.real_screen_width = None - self.real_screen_height = None + self._agent_os = agent_os + self._width = 1280 + self._height = 800 + self._real_screen_width: int | None = None + self._real_screen_height: int | None = None def __call__( # noqa: C901 self, @@ -264,20 +260,20 @@ def __call__( # noqa: C901 x, y = scale_coordinates_back( coordinate[0], coordinate[1], - self.real_screen_width, - self.real_screen_height, - self.width, - self.height, + self._real_screen_width, + self._real_screen_height, + self._width, + self._height, ) x, y = int(x), int(y) if action == "mouse_move": - self.controller_client.mouse(x, y) + self._agent_os.mouse_move(x, y) return ToolResult() if action == "left_click_drag": - self.controller_client.mouse_down("left") - self.controller_client.mouse(x, y) - self.controller_client.mouse_up("left") + self._agent_os.mouse_down("left") + self._agent_os.mouse_move(x, y) + self._agent_os.mouse_up("left") return ToolResult() if action in ("key", "type"): @@ -300,11 +296,11 @@ def __call__( # noqa: C901 f"Key {text} is not a valid PC_KEY from {', '.join(PC_KEY)}" ) raise ToolError(error_msg) - self.controller_client.keyboard_pressed(text) - self.controller_client.keyboard_release(text) + self._agent_os.keyboard_pressed(text) + self._agent_os.keyboard_release(text) return ToolResult() if action == "type": - self.controller_client.type(text) + self._agent_os.type(text) return ToolResult() if action in ( @@ -328,16 +324,16 @@ def __call__( # noqa: C901 error_msg = "cursor_position is not implemented by this agent" raise ToolError(error_msg) if action == "left_click": - self.controller_client.click("left") + self._agent_os.click("left") return ToolResult() if action == "right_click": - self.controller_client.click("right") + self._agent_os.click("right") return ToolResult() if action == "middle_click": - self.controller_client.click("middle") + self._agent_os.click("middle") return ToolResult() if action == "double_click": - self.controller_client.click("left", 2) + self._agent_os.click("left", 2) return ToolResult() error_msg = f"Invalid action: {action}" @@ -348,9 +344,11 @@ def screenshot(self) -> ToolResult: Take a screenshot of the current screen, scale it and return the base64 encoded image. """ - screenshot = self.controller_client.screenshot() - self.real_screen_width = screenshot.width - self.real_screen_height = screenshot.height - scaled_screenshot = scale_image_with_padding(screenshot, 1280, 800) + screenshot = self._agent_os.screenshot() + self._real_screen_width = screenshot.width + self._real_screen_height = screenshot.height + scaled_screenshot = scale_image_with_padding( + screenshot, self._width, self._height + ) base64_image = image_to_base64(scaled_screenshot) return ToolResult(base64_image=base64_image) diff --git a/src/askui/tools/askui/askui_controller.py b/src/askui/tools/askui/askui_controller.py index fc9f01be..7c385bbf 100644 --- a/src/askui/tools/askui/askui_controller.py +++ b/src/askui/tools/askui/askui_controller.py @@ -427,7 +427,7 @@ def screenshot(self, report: bool = True) -> Image.Image: @telemetry.record_call() @override - def mouse(self, x: int, y: int) -> None: + def mouse_move(self, x: int, y: int) -> None: """ Moves the mouse cursor to specified screen coordinates. @@ -437,7 +437,7 @@ def mouse(self, x: int, y: int) -> None: """ self._reporter.add_message( "AgentOS", - f"mouse({x}, {y})", + f"mouse_move({x}, {y})", draw_point_on_image(self.screenshot(report=False), x, y, size=5), ) self._run_recorder_action( @@ -687,22 +687,22 @@ def keyboard_tap( @telemetry.record_call() @override - def set_display(self, displayNumber: int = 1) -> None: + def set_display(self, display: int = 1) -> None: """ Set the active display. Args: - displayNumber (int, optional): The display number to set as active. + display (int, optional): The display ID to set as active. Defaults to `1`. """ assert isinstance(self._stub, controller_v1.ControllerAPIStub), ( "Stub is not initialized" ) - self._reporter.add_message("AgentOS", f"set_display({displayNumber})") + self._reporter.add_message("AgentOS", f"set_display({display})") self._stub.SetActiveDisplay( - controller_v1_pbs.Request_SetActiveDisplay(displayID=displayNumber) + controller_v1_pbs.Request_SetActiveDisplay(displayID=display) ) - self._display = displayNumber + self._display = display @telemetry.record_call(exclude={"command"}) @override diff --git a/src/askui/tools/pynput/__init__.py b/src/askui/tools/pynput/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/tools/pynput/pynput_agent_os.py b/src/askui/tools/pynput/pynput_agent_os.py new file mode 100644 index 00000000..cc00d846 --- /dev/null +++ b/src/askui/tools/pynput/pynput_agent_os.py @@ -0,0 +1,328 @@ +import time +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast + +from mss import mss +from PIL import Image +from pynput.keyboard import Controller as KeyboardController +from pynput.keyboard import Key, KeyCode +from pynput.mouse import Button +from pynput.mouse import Controller as MouseController +from typing_extensions import override + +from askui.reporting import Reporter +from askui.tools.agent_os import AgentOs, ModifierKey, PcKey +from askui.utils.image_utils import draw_point_on_image + +if TYPE_CHECKING: + from mss.screenshot import ScreenShot + + +_KEY_MAP: dict[PcKey | ModifierKey, Key | KeyCode] = { + "backspace": Key.backspace, + "delete": Key.delete, + "enter": Key.enter, + "tab": Key.tab, + "escape": Key.esc, + "up": Key.up, + "down": Key.down, + "right": Key.right, + "left": Key.left, + "home": Key.home, + "end": Key.end, + "pageup": Key.page_up, + "pagedown": Key.page_down, + "f1": Key.f1, + "f2": Key.f2, + "f3": Key.f3, + "f4": Key.f4, + "f5": Key.f5, + "f6": Key.f6, + "f7": Key.f7, + "f8": Key.f8, + "f9": Key.f9, + "f10": Key.f10, + "f11": Key.f11, + "f12": Key.f12, + "space": Key.space, + "command": Key.cmd, + "alt": Key.alt, + "control": Key.ctrl, + "shift": Key.shift, + "right_shift": Key.shift_r, +} + + +_BUTTON_MAP: dict[Literal["left", "middle", "right"], Button] = { + "left": Button.left, + "middle": Button.middle, + "right": Button.right, +} + + +T = TypeVar("T") +F = TypeVar("F", bound=Callable[..., Any]) +C = TypeVar("C", bound=type) + + +def await_action(pre_action_wait: float, post_action_wait: float) -> Callable[[F], F]: + def wrapper(func: F) -> F: + @wraps(func) + def _wrapper(*args: Any, **kwargs: Any) -> Any: + time.sleep(pre_action_wait) + result = func(*args, **kwargs) + time.sleep(post_action_wait) + return result + + return cast("F", _wrapper) + + return wrapper + + +def decorate_all_methods( + pre_action_wait: float, post_action_wait: float +) -> Callable[[C], C]: + def decorate(cls: C) -> C: + for attr_name, attr in cls.__dict__.items(): + if callable(attr) and not attr_name.startswith("__"): + setattr( + cls, + attr_name, + await_action(pre_action_wait, post_action_wait)( + cast("Callable[..., Any]", attr) + ), + ) + return cls + + return decorate + + +@decorate_all_methods(pre_action_wait=0.1, post_action_wait=0.1) +class PynputAgentOs(AgentOs): + """ + Implementation of AgentOs using `pynput` for mouse and keyboard control, and `mss` + for screenshots. + + Args: + reporter (Reporter): Reporter used for reporting with the `AgentOs`. + display (int, optional): Display number to use. Defaults to `1`. + """ + + def __init__( + self, + reporter: Reporter, + display: int = 1, + ) -> None: + self._mouse = MouseController() + self._keyboard = KeyboardController() + self._sct = mss() + self._display = display + self._reporter = reporter + + @override + def connect(self) -> None: + """No connection needed for pynput.""" + + @override + def disconnect(self) -> None: + """No disconnection needed for pynput.""" + + @override + def screenshot(self, report: bool = True) -> Image.Image: + """ + Take a screenshot of the current screen. + + Args: + report (bool, optional): Whether to include the screenshot in reporting. + Defaults to `True`. + + Returns: + Image.Image: A PIL Image object containing the screenshot. + """ + monitor = self._sct.monitors[self._display] + screenshot: ScreenShot = self._sct.grab(monitor) + image = Image.frombytes( # type: ignore[arg-type] + "RGB", + screenshot.size, + screenshot.rgb, + ) + + scaled_size = (monitor["width"], monitor["height"]) + image = image.resize(scaled_size, Image.Resampling.LANCZOS) + + if report: + self._reporter.add_message("AgentOS", "screenshot()", image) + return image + + @override + def mouse_move(self, x: int, y: int) -> None: + """ + Move the mouse cursor to specified screen coordinates. + + Args: + x (int): The horizontal coordinate (in pixels) to move to. + y (int): The vertical coordinate (in pixels) to move to. + """ + self._reporter.add_message( + "AgentOS", + f"mouse_move({x}, {y})", + draw_point_on_image(self.screenshot(report=False), x, y, size=5), + ) + self._mouse.position = (x, y) + + @override + def type(self, text: str, typing_speed: int = 50) -> None: + """ + Type text at current cursor position as if entered on a keyboard. + + Args: + text (str): The text to type. + typing_speed (int, optional): The speed of typing in characters per second. + Defaults to `50`. + """ + self._reporter.add_message("AgentOS", f'type("{text}", {typing_speed})') + delay = 1.0 / typing_speed + for char in text: + self._keyboard.press(char) + self._keyboard.release(char) + time.sleep(delay) + + @override + def click( + self, button: Literal["left", "middle", "right"] = "left", count: int = 1 + ) -> None: + """ + Click a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + click. Defaults to `"left"`. + count (int, optional): Number of times to click. Defaults to `1`. + """ + self._reporter.add_message("AgentOS", f'click("{button}", {count})') + pynput_button = _BUTTON_MAP[button] + for _ in range(count): + self._mouse.click(pynput_button) + + @override + def mouse_down(self, button: Literal["left", "middle", "right"] = "left") -> None: + """ + Press and hold a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + press. Defaults to `"left"`. + """ + self._reporter.add_message("AgentOS", f'mouse_down("{button}")') + self._mouse.press(_BUTTON_MAP[button]) + + @override + def mouse_up(self, button: Literal["left", "middle", "right"] = "left") -> None: + """ + Release a mouse button. + + Args: + button (Literal["left", "middle", "right"], optional): The mouse button to + release. Defaults to "left". + """ + self._reporter.add_message("AgentOS", f'mouse_up("{button}")') + self._mouse.release(_BUTTON_MAP[button]) + + @override + def mouse_scroll(self, x: int, y: int) -> None: + """ + Scroll the mouse wheel. + + Args: + x (int): The horizontal scroll amount. Positive values scroll right, + negative values scroll left. + y (int): The vertical scroll amount. Positive values scroll down, + negative values scroll up. + """ + self._reporter.add_message("AgentOS", f"mouse_scroll({x}, {y})") + self._mouse.scroll(x, y) + + def _get_pynput_key(self, key: PcKey | ModifierKey) -> Key | KeyCode | str: + """Convert our key type to pynput key.""" + if key in _KEY_MAP: + return _KEY_MAP[key] + return key # For regular characters + + @override + def keyboard_pressed( + self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + ) -> None: + """ + Press and hold a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to press. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + press along with the main key. Defaults to `None`. + """ + self._reporter.add_message( + "AgentOS", f'keyboard_pressed("{key}", {modifier_keys})' + ) + if modifier_keys: + for mod in modifier_keys: + self._keyboard.press(_KEY_MAP[mod]) + self._keyboard.press(self._get_pynput_key(key)) + + @override + def keyboard_release( + self, key: PcKey | ModifierKey, modifier_keys: list[ModifierKey] | None = None + ) -> None: + """ + Release a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to release. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + release along with the main key. Defaults to `None`. + """ + self._reporter.add_message( + "AgentOS", f'keyboard_release("{key}", {modifier_keys})' + ) + self._keyboard.release(self._get_pynput_key(key)) + if modifier_keys: + for mod in reversed(modifier_keys): # Release in reverse order + self._keyboard.release(_KEY_MAP[mod]) + + @override + def keyboard_tap( + self, + key: PcKey | ModifierKey, + modifier_keys: list[ModifierKey] | None = None, + count: int = 1, + ) -> None: + """ + Press and immediately release a keyboard key. + + Args: + key (PcKey | ModifierKey): The key to tap. + modifier_keys (list[ModifierKey] | None, optional): List of modifier keys to + press along with the main key. Defaults to `None`. + count (int, optional): The number of times to tap the key. Defaults to `1`. + """ + self._reporter.add_message( + "AgentOS", + f'keyboard_tap("{key}", {modifier_keys}, {count})', + ) + for _ in range(count): + self.keyboard_pressed(key, modifier_keys) + self.keyboard_release(key, modifier_keys) + + @override + def set_display(self, display: int = 1) -> None: + """ + Set the active display. + + Args: + display (int, optional): The display ID to set as active. + Defaults to `1`. + """ + self._reporter.add_message("AgentOS", f"set_display({display})") + if display < 1 or len(self._sct.monitors) <= display: + error_msg = f"Display {display} not found" + raise ValueError(error_msg) + self._display = display diff --git a/tests/e2e/tools/__init__.py b/tests/e2e/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/e2e/tools/pynput/__init__.py b/tests/e2e/tools/pynput/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/e2e/tools/pynput/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/integration/tools/askui/test_askui_controller.py b/tests/integration/tools/askui/test_askui_controller.py index 52bc178b..76a51105 100644 --- a/tests/integration/tools/askui/test_askui_controller.py +++ b/tests/integration/tools/askui/test_askui_controller.py @@ -37,5 +37,5 @@ def test_find_remote_device_controller_by_component_registry( def test_actions(controller_client: AskUiControllerClient) -> None: with controller_client: controller_client.screenshot() - controller_client.mouse(0, 0) + controller_client.mouse_move(0, 0) controller_client.click() From 185f920c219e5c9b4971e4784c02d0cc5bb1e627 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 20 May 2025 13:00:08 +0200 Subject: [PATCH 05/20] feat(tools,chat): add click event listening/recording --- src/askui/chat/__main__.py | 39 +++++- src/askui/chat/api/messages.py | 4 +- src/askui/tools/agent_os.py | 157 ++++++++++++++++++++++ src/askui/tools/pynput/pynput_agent_os.py | 77 ++++++++++- 4 files changed, 268 insertions(+), 9 deletions(-) diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index e072b35a..0d513b06 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -2,7 +2,7 @@ import logging import re from pathlib import Path -from typing import Union, cast +from typing import Any, Union, cast import streamlit as st from PIL import Image, ImageDraw @@ -43,7 +43,7 @@ def get_image(img_b64_str_or_path: str) -> Image.Image: # TODO Image utils def write_message( # TODO updating frontend role: str, - content: str | dict | list, + content: str | dict[str, Any] | list[Any], timestamp: str, image: Image.Image | str @@ -92,7 +92,7 @@ def __init__(self, thread_id: str) -> None: def add_message( self, role: str, - content: Union[str, dict, list], + content: Union[str, dict[str, Any], list[Any]], image: Image.Image | list[Image.Image] | None = None, ) -> None: message = messages_api.create( @@ -266,7 +266,13 @@ def rerun() -> None: reporter = ChatHistoryAppender(thread_id) -tools = AgentToolbox(agent_os=PynputAgentOs(reporter=reporter)) + +@st.cache_resource +def get_tools() -> AgentToolbox: + return AgentToolbox(agent_os=PynputAgentOs(reporter=reporter)) + + +tools = get_tools() st.title(f"Vision Agent Chat - {thread_id}") @@ -312,6 +318,31 @@ def rerun() -> None: ) st.rerun() +if st.session_state.get("input_event_listening"): + while input_event := tools.os.poll_event(): + image = tools.os.screenshot(report=False) + if input_event.pressed: + reporter.add_message( + role="User (Demonstration)", + content=f"mouse_move({input_event.x}, {input_event.y})", + image=draw_point_on_image(image, input_event.x, input_event.y), + ) + reporter.add_message( + role="User (Demonstration)", + content=f'click("{input_event.button}")', + ) + if st.button("Refresh"): + st.rerun() + if st.button("Stop listening to input events"): + tools.os.stop_listening() + st.session_state["input_event_listening"] = False + st.rerun() +else: + if st.button("Listen to input events"): + tools.os.start_listening() + st.session_state["input_event_listening"] = True + st.rerun() + if act_prompt := st.chat_input("Ask AI"): with VisionAgent( # we need the vision agent log_level=logging.DEBUG, diff --git a/src/askui/chat/api/messages.py b/src/askui/chat/api/messages.py index 0ee4af38..7140176f 100644 --- a/src/askui/chat/api/messages.py +++ b/src/askui/chat/api/messages.py @@ -1,7 +1,7 @@ from datetime import datetime, timezone from enum import Enum from pathlib import Path -from typing import Sequence, Union +from typing import Any, Sequence, Union from PIL import Image from pydantic import AwareDatetime, BaseModel, Field @@ -111,7 +111,7 @@ def create( self, thread_id: str, role: str, - content: Union[str, dict, list], + content: Union[str, dict[str, Any], list[Any]], image: Image.Image | list[Image.Image] | None = None, ) -> Message: """Create a new message in a thread. diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index 5623a4c0..8f65e5e5 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -2,10 +2,13 @@ from typing import Literal from PIL import Image +from pydantic import BaseModel ModifierKey = Literal["command", "alt", "control", "shift", "right_shift"] """Modifier keys for keyboard actions.""" +ModifierKeys: list[ModifierKey] = ["command", "alt", "control", "shift", "right_shift"] + PcKey = Literal[ "backspace", "delete", @@ -130,6 +133,142 @@ ] """PC keys for keyboard actions.""" +PcKeys: list[PcKey] = [ + "backspace", + "delete", + "enter", + "tab", + "escape", + "up", + "down", + "right", + "left", + "home", + "end", + "pageup", + "pagedown", + "f1", + "f2", + "f3", + "f4", + "f5", + "f6", + "f7", + "f8", + "f9", + "f10", + "f11", + "f12", + "space", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", + "Q", + "R", + "S", + "T", + "U", + "V", + "W", + "X", + "Y", + "Z", + "!", + '"', + "#", + "$", + "%", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "<", + "=", + ">", + "?", + "@", + "[", + "\\", + "]", + "^", + "_", + "`", + "{", + "|", + "}", + "~", +] + + +class ClickEvent(BaseModel): + type: Literal["click"] = "click" + x: int + y: int + button: Literal["left", "middle", "right", "unknown"] + pressed: bool + injected: bool = False + timestamp: float + + +InputEvent = ClickEvent + class AgentOs(ABC): """ @@ -315,3 +454,21 @@ def run_command(self, command: str, timeout_ms: int = 30000) -> None: """ raise NotImplementedError + + def start_listening(self) -> None: + """ + Start listening for mouse and keyboard events. + """ + raise NotImplementedError + + @abstractmethod + def poll_event(self) -> InputEvent | None: + """ + Poll for a single input event. + """ + raise NotImplementedError + + @abstractmethod + def stop_listening(self) -> None: + """Stop listening for mouse and keyboard events.""" + raise NotImplementedError diff --git a/src/askui/tools/pynput/pynput_agent_os.py b/src/askui/tools/pynput/pynput_agent_os.py index cc00d846..861ca868 100644 --- a/src/askui/tools/pynput/pynput_agent_os.py +++ b/src/askui/tools/pynput/pynput_agent_os.py @@ -1,3 +1,6 @@ +import ctypes +import platform +import queue import time from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast @@ -8,16 +11,24 @@ from pynput.keyboard import Key, KeyCode from pynput.mouse import Button from pynput.mouse import Controller as MouseController +from pynput.mouse import Listener as MouseListener from typing_extensions import override +from askui.logger import logger from askui.reporting import Reporter -from askui.tools.agent_os import AgentOs, ModifierKey, PcKey +from askui.tools.agent_os import AgentOs, InputEvent, ModifierKey, PcKey from askui.utils.image_utils import draw_point_on_image +if platform.system() == "Windows": + try: + PROCESS_PER_MONITOR_DPI_AWARE = 2 + ctypes.windll.shcore.SetProcessDpiAwareness(PROCESS_PER_MONITOR_DPI_AWARE) # type: ignore[attr-defined] + except Exception as e: # noqa: BLE001 + logger.error(f"Could not set DPI awareness: {e}") + if TYPE_CHECKING: from mss.screenshot import ScreenShot - _KEY_MAP: dict[PcKey | ModifierKey, Key | KeyCode] = { "backspace": Key.backspace, "delete": Key.delete, @@ -53,10 +64,18 @@ } -_BUTTON_MAP: dict[Literal["left", "middle", "right"], Button] = { +_BUTTON_MAP: dict[Literal["left", "middle", "right", "unknown"], Button] = { "left": Button.left, "middle": Button.middle, "right": Button.right, + "unknown": Button.unknown, +} + +_BUTTON_MAP_REVERSE: dict[Button, Literal["left", "middle", "right", "unknown"]] = { + Button.left: "left", + Button.middle: "middle", + Button.right: "right", + Button.unknown: "unknown", } @@ -118,6 +137,8 @@ def __init__( self._sct = mss() self._display = display self._reporter = reporter + self._mouse_listener: MouseListener | None = None + self._input_event_queue: queue.Queue[InputEvent] = queue.Queue() @override def connect(self) -> None: @@ -326,3 +347,53 @@ def set_display(self, display: int = 1) -> None: error_msg = f"Display {display} not found" raise ValueError(error_msg) self._display = display + + def _on_mouse_click( + self, x: float, y: float, button: Button, pressed: bool, injected: bool + ) -> None: + """Handle mouse click events.""" + self._input_event_queue.put( + InputEvent( + x=int(x), + y=int(y), + button=_BUTTON_MAP_REVERSE[button], + pressed=pressed, + injected=injected, + timestamp=time.time(), + ) + ) + + @override + def start_listening(self) -> None: + """ + Start listening for mouse and keyboard events. + + Args: + callback (InputEventCallback): Callback function that will be called for + each event. + """ + if self._mouse_listener: + self.stop_listening() + self._mouse_listener = MouseListener( + on_click=self._on_mouse_click, # type: ignore[arg-type] + name="PynputAgentOsMouseListener", + args=(self._input_event_queue,), + ) + self._mouse_listener.start() + + @override + def poll_event(self) -> InputEvent | None: + """Poll for a single input event.""" + try: + return self._input_event_queue.get(False) + except queue.Empty: + return None + + @override + def stop_listening(self) -> None: + """Stop listening for mouse and keyboard events.""" + if self._mouse_listener: + self._mouse_listener.stop() + self._mouse_listener = None + while not self._input_event_queue.empty(): + self._input_event_queue.get() From b100ebe0ae1682ddc05d993d88b7be000252c130 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 3 Jun 2025 06:42:33 +0200 Subject: [PATCH 06/20] chore: add cursor rules --- .cursorrules | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 .cursorrules diff --git a/.cursorrules b/.cursorrules new file mode 100644 index 00000000..bb1615b3 --- /dev/null +++ b/.cursorrules @@ -0,0 +1,67 @@ +# Code Style +- Use `_` prefix for all private variables, constants, functions, methods, properties, etc. that don't need to be accessible from outside the module +- Mark everything as private that does not absolutely need to be accessible from outside the module +- Use `@override` (imported from `typing_extensions`) decorator for all methods that override a parent class method +- Use type hints for all variables, functions, methods, properties, return values, parameters etc. +- Omit `Any` within type hints unless absolutely necessary +- Use built-in types (e.g., `list`, `dict`, `tuple`, `set`, `str | None`) instead of types from `typing` module (e.g. `List`, `Dict`, `Tuple`, `Set`, `Optional`, `Union`) wherever possible + +# Testing +- Use `pytest-mock` for mocking in tests wherever you need to mock something and pytest-mock can do the job. + +# Documentation + +## Docstrings +- All public functions, constants, classes, types etc. should have docstrings +- Document the constructor (`__init__`) args as part of the class docstring +- Omit the `__init__` docstring +- All function parameter should be documented with their type (followed by `, optional` if there is a default value) in parenthesis and description +- In descriptions, use backticks for all code references (variables, types, etc.), including types, e.g., `str` +- When referencing a function, use the function name in backticks plus parentheses, e.g., `click()` +- When referencing a class, use the class name in backticks, e.g., `VisionAgent` +- When referencing a method, use the class name in backticks plus the method name in parentheses, e.g., `VisionAgent.click()` +- When referencing a class attribute, use the class name in backticks plus the attribute name, e.g., `VisionAgent.display` +- Use `Example` section for code examples +- Use `Returns` section for return values +- Use `Raises` section for exceptions listing all possible exceptions that can be raised by the function +- Use `Notes` section for additional notes +- Use `See Also` section for related functions +- Use `References` section for references +- Use `Examples` section for code examples +- Example of a good docstring: + ```python + def locate( + self, + locator: str | Locator, + screenshot: Img | None = None, + model: ModelComposition | str | None = None, + ) -> Point: + """ + Find the position of the UI element identified by the `locator` using the `model`. + + Args: + locator (str | Locator): The identifier or description of the element to locate. + screenshot (Img | None, optional): The screenshot to use for locating the + element. Can be a path to an image file, a PIL Image object or a data URL. + If `None`, takes a screenshot of the currently selected screen. + model (ModelComposition | str | None, optional): The composition or name of + the model(s) to be used for locating the element using the `locator`. + + Returns: + Point: The coordinates of a point on the element, usually the center of the element, as a tuple (x, y). + + Raises: + ValueError: If the arguments are not of the correct type. + ElementNotFoundError: If no element can be found. + + Example: + ```python + from askui import VisionAgent + + with VisionAgent() as agent: + point = agent.locate("Submit button") + print(f"Element found at coordinates: {point}") + ``` + """ + ... + ``` From 1fc7e83a6ac305c31115d7436104a7f36dbfe369 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 3 Jun 2025 07:43:42 +0200 Subject: [PATCH 07/20] feat(chat): wrap api into fastapi app --- .cursorrules | 5 ++ pdm.lock | 40 ++++++--- pyproject.toml | 5 +- src/askui/chat/api/dependencies.py | 11 +++ src/askui/chat/api/fastapi.py | 28 +++++++ src/askui/chat/api/images/__init__.py | 0 src/askui/chat/api/images/dependencies.py | 0 src/askui/chat/api/images/router.py | 19 +++++ src/askui/chat/api/messages/__init__.py | 0 src/askui/chat/api/messages/dependencies.py | 13 +++ src/askui/chat/api/messages/router.py | 83 +++++++++++++++++++ .../api/{messages.py => messages/service.py} | 6 +- src/askui/chat/api/settings.py | 17 ++++ src/askui/chat/api/threads/__init__.py | 0 src/askui/chat/api/threads/dependencies.py | 13 +++ src/askui/chat/api/threads/router.py | 47 +++++++++++ .../api/{threads.py => threads/service.py} | 12 +-- 17 files changed, 277 insertions(+), 22 deletions(-) create mode 100644 src/askui/chat/api/dependencies.py create mode 100644 src/askui/chat/api/fastapi.py create mode 100644 src/askui/chat/api/images/__init__.py create mode 100644 src/askui/chat/api/images/dependencies.py create mode 100644 src/askui/chat/api/images/router.py create mode 100644 src/askui/chat/api/messages/__init__.py create mode 100644 src/askui/chat/api/messages/dependencies.py create mode 100644 src/askui/chat/api/messages/router.py rename src/askui/chat/api/{messages.py => messages/service.py} (98%) create mode 100644 src/askui/chat/api/settings.py create mode 100644 src/askui/chat/api/threads/__init__.py create mode 100644 src/askui/chat/api/threads/dependencies.py create mode 100644 src/askui/chat/api/threads/router.py rename src/askui/chat/api/{threads.py => threads/service.py} (93%) diff --git a/.cursorrules b/.cursorrules index bb1615b3..de6d9034 100644 --- a/.cursorrules +++ b/.cursorrules @@ -5,6 +5,11 @@ - Use type hints for all variables, functions, methods, properties, return values, parameters etc. - Omit `Any` within type hints unless absolutely necessary - Use built-in types (e.g., `list`, `dict`, `tuple`, `set`, `str | None`) instead of types from `typing` module (e.g. `List`, `Dict`, `Tuple`, `Set`, `Optional`, `Union`) wherever possible +- Create a `__init__.py` file in each folder + +## FastAPI +- Instead of defining `response_model` within route annotation, use the model as the response type in the function signature +- Do not assign `None` to dependencies but instead move it before arguments with default values # Testing - Use `pytest-mock` for mocking in tests wherever you need to mock something and pytest-mock can do the job. diff --git a/pdm.lock b/pdm.lock index 5bc26590..6cd1425f 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "chat", "mcp", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:5d81bb8ca8a8b98d18f4632c5fc1dd409c44d385724e4f321922ae71ec152f33" +content_hash = "sha256:dedebe5b971455dfc718635d1180587b366c6d287b404138c030c4aec77452bc" [[metadata.targets]] requires_python = ">=3.10" @@ -202,7 +202,7 @@ name = "click" version = "8.1.8" requires_python = ">=3.7" summary = "Composable command line interface toolkit" -groups = ["chat", "mcp"] +groups = ["default", "chat", "mcp"] dependencies = [ "colorama; platform_system == \"Windows\"", "importlib-metadata; python_version < \"3.8\"", @@ -398,6 +398,22 @@ files = [ {file = "execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3"}, ] +[[package]] +name = "fastapi" +version = "0.115.12" +requires_python = ">=3.8" +summary = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +groups = ["default"] +dependencies = [ + "pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4", + "starlette<0.47.0,>=0.40.0", + "typing-extensions>=4.8.0", +] +files = [ + {file = "fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d"}, + {file = "fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681"}, +] + [[package]] name = "filelock" version = "3.18.0" @@ -1402,17 +1418,18 @@ files = [ [[package]] name = "pydantic-settings" -version = "2.8.1" -requires_python = ">=3.8" +version = "2.9.1" +requires_python = ">=3.9" summary = "Settings management using Pydantic" groups = ["default", "mcp"] dependencies = [ "pydantic>=2.7.0", "python-dotenv>=0.21.0", + "typing-inspection>=0.4.0", ] files = [ - {file = "pydantic_settings-2.8.1-py3-none-any.whl", hash = "sha256:81942d5ac3d905f7f3ee1a70df5dfb62d5569c12f51a5a647defc1c3d9ee2e9c"}, - {file = "pydantic_settings-2.8.1.tar.gz", hash = "sha256:d5c663dfbe9db9d5e1c646b2e161da12f0d734d422ee56f567d0ea2cee4e8585"}, + {file = "pydantic_settings-2.9.1-py3-none-any.whl", hash = "sha256:59b4f431b1defb26fe620c71a7d3968a710d719f5f4cdbbdb7926edeb770f6ef"}, + {file = "pydantic_settings-2.9.1.tar.gz", hash = "sha256:c509bf79d27563add44e8446233359004ed85066cd096d8b510f715e6ef5d268"}, ] [[package]] @@ -2026,7 +2043,7 @@ name = "starlette" version = "0.46.2" requires_python = ">=3.9" summary = "The little ASGI library that shines." -groups = ["mcp"] +groups = ["default", "mcp"] dependencies = [ "anyio<5,>=3.6.2", "typing-extensions>=3.10.0; python_version < \"3.10\"", @@ -2300,19 +2317,18 @@ files = [ [[package]] name = "uvicorn" -version = "0.34.2" +version = "0.34.3" requires_python = ">=3.9" summary = "The lightning-fast ASGI server." -groups = ["mcp"] -marker = "sys_platform != \"emscripten\"" +groups = ["default", "mcp"] dependencies = [ "click>=7.0", "h11>=0.8", "typing-extensions>=4.0; python_version < \"3.11\"", ] files = [ - {file = "uvicorn-0.34.2-py3-none-any.whl", hash = "sha256:deb49af569084536d269fe0a6d67e3754f104cf03aba7c11c40f01aadf33c403"}, - {file = "uvicorn-0.34.2.tar.gz", hash = "sha256:0e929828f6186353a80b58ea719861d2629d766293b6d19baf086ba31d4f3328"}, + {file = "uvicorn-0.34.3-py3-none-any.whl", hash = "sha256:16246631db62bdfbf069b0645177d6e8a77ba950cfedbfd093acef9444e4d885"}, + {file = "uvicorn-0.34.3.tar.gz", hash = "sha256:35919a9a979d7a59334b6b10e05d77c1d0d574c50e0fc98b8b1a0f165708b55a"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index bb60740e..496501cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,8 +15,8 @@ dependencies = [ "gradio-client>=1.4.3", "requests>=2.32.3", "Jinja2>=3.1.4", - "pydantic-settings>=2.7.0", "tenacity>=9.1.2", + "pydantic-settings>=2.9.1", "python-dateutil>=2.9.0.post0", "openai>=1.61.1", "segment-analytics-python>=2.3.4", @@ -24,6 +24,8 @@ dependencies = [ "httpx>=0.28.1", "pynput>=1.8.1", "mss>=10.0.0", + "fastapi>=0.115.12", + "uvicorn>=0.34.3", ] requires-python = ">=3.10" readme = "README.md" @@ -58,6 +60,7 @@ typecheck = "mypy" "typecheck:all" = "mypy src tests" chat = "streamlit run src/askui/chat/__main__.py" mcp = "mcp dev src/askui/mcp/__init__.py" +api = "uvicorn src.askui.chat.api.fastapi:app --reload --port 8000" [dependency-groups] chat = [ diff --git a/src/askui/chat/api/dependencies.py b/src/askui/chat/api/dependencies.py new file mode 100644 index 00000000..a9c78c2f --- /dev/null +++ b/src/askui/chat/api/dependencies.py @@ -0,0 +1,11 @@ +from fastapi import Depends + +from askui.chat.api.settings import Settings + + +def get_settings() -> Settings: + """Get ChatApiSettings instance.""" + return Settings() + + +SettingsDep = Depends(get_settings) diff --git a/src/askui/chat/api/fastapi.py b/src/askui/chat/api/fastapi.py new file mode 100644 index 00000000..e9d8cf00 --- /dev/null +++ b/src/askui/chat/api/fastapi.py @@ -0,0 +1,28 @@ +from fastapi import APIRouter, FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from askui.chat.api.images.router import router as images_router +from askui.chat.api.messages.router import router as messages_router +from askui.chat.api.threads.router import router as threads_router + +app = FastAPI( + title="AskUI Chat API", + description="REST API for managing chat threads and messages", + version="0.1.0", +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Include routers +v1_router = APIRouter(prefix="/v1") +v1_router.include_router(threads_router) +v1_router.include_router(messages_router) +v1_router.include_router(images_router) +app.include_router(v1_router) diff --git a/src/askui/chat/api/images/__init__.py b/src/askui/chat/api/images/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/images/dependencies.py b/src/askui/chat/api/images/dependencies.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/images/router.py b/src/askui/chat/api/images/router.py new file mode 100644 index 00000000..2377386d --- /dev/null +++ b/src/askui/chat/api/images/router.py @@ -0,0 +1,19 @@ +from fastapi import APIRouter, HTTPException +from fastapi.responses import FileResponse + +from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.settings import Settings + +router = APIRouter(prefix="/images", tags=["images"]) + + +@router.get("/{image_path:path}") +def get_image( + image_path: str, + settings: Settings = SettingsDep, +) -> FileResponse: + """Get an image by path.""" + full_path = settings.data_dir / "images" / image_path + if not full_path.exists(): + raise HTTPException(status_code=404, detail="Image not found") + return FileResponse(full_path) diff --git a/src/askui/chat/api/messages/__init__.py b/src/askui/chat/api/messages/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/messages/dependencies.py b/src/askui/chat/api/messages/dependencies.py new file mode 100644 index 00000000..86899d14 --- /dev/null +++ b/src/askui/chat/api/messages/dependencies.py @@ -0,0 +1,13 @@ +from fastapi import Depends + +from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.messages.service import MessageService +from askui.chat.api.settings import Settings + + +def get_message_service(settings: Settings = SettingsDep) -> MessageService: + """Get MessageService instance.""" + return MessageService(settings.data_dir) + + +MessageServiceDep = Depends(get_message_service) diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py new file mode 100644 index 00000000..328e5cc2 --- /dev/null +++ b/src/askui/chat/api/messages/router.py @@ -0,0 +1,83 @@ +from io import BytesIO +from typing import Any + +from fastapi import APIRouter, File, HTTPException, UploadFile +from PIL import Image +from pydantic import BaseModel + +from askui.chat.api.messages.dependencies import MessageServiceDep +from askui.chat.api.messages.service import Message, MessageListResponse, MessageService + + +class CreateMessageRequest(BaseModel): + """Request model for creating a message.""" + + role: str + content: str | dict[str, Any] | list[Any] + + +router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) + + +@router.get("", response_model=MessageListResponse) +def list_messages( + thread_id: str, + limit: int | None = None, + message_service: MessageService = MessageServiceDep, +) -> MessageListResponse: + """List all messages in a thread.""" + try: + return message_service.list_(thread_id, limit=limit) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.post("", response_model=Message) +async def create_message( + thread_id: str, + request: CreateMessageRequest, + image: UploadFile | None = File(None), + message_service: MessageService = MessageServiceDep, +) -> Message: + """Create a new message in a thread.""" + try: + # Handle image upload if provided + pil_image = None + if image: + img_data = await image.read() + pil_image = Image.open(BytesIO(img_data)) + + return message_service.create( + thread_id=thread_id, + role=request.role, + content=request.content, + image=pil_image, + ) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.get("/{message_id}", response_model=Message) +def get_message( + thread_id: str, + message_id: str, + message_service: MessageService = MessageServiceDep, +) -> Message: + """Get a specific message from a thread.""" + try: + return message_service.retrieve(thread_id, message_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.delete("/{message_id}") +def delete_message( + thread_id: str, + message_id: str, + message_service: MessageService = MessageServiceDep, +) -> None: + """Delete a message from a thread.""" + try: + message_service.delete(thread_id, message_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) diff --git a/src/askui/chat/api/messages.py b/src/askui/chat/api/messages/service.py similarity index 98% rename from src/askui/chat/api/messages.py rename to src/askui/chat/api/messages/service.py index 7140176f..edf80d0b 100644 --- a/src/askui/chat/api/messages.py +++ b/src/askui/chat/api/messages/service.py @@ -49,8 +49,8 @@ class MessageListResponse(BaseModel): has_more: bool = False -class MessagesApi: - """API for managing messages within threads.""" +class MessageService: + """Service for managing messages within threads.""" ROLE_MAP = { "user": MessageRole.USER, @@ -60,7 +60,7 @@ class MessagesApi: } def __init__(self, base_dir: Path) -> None: - """Initialize messages API. + """Initialize message service. Args: base_dir: Base directory to store message data diff --git a/src/askui/chat/api/settings.py b/src/askui/chat/api/settings.py new file mode 100644 index 00000000..52e3b50c --- /dev/null +++ b/src/askui/chat/api/settings.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + """Settings for the chat API.""" + + model_config = SettingsConfigDict( + env_prefix="ASKUI__CHAT_API__", env_nested_delimiter="__" + ) + + data_dir: Path = Field( + default_factory=lambda: Path.home() / ".askui" / "chat", + description="Base directory for storing chat data", + ) diff --git a/src/askui/chat/api/threads/__init__.py b/src/askui/chat/api/threads/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/threads/dependencies.py b/src/askui/chat/api/threads/dependencies.py new file mode 100644 index 00000000..3d607559 --- /dev/null +++ b/src/askui/chat/api/threads/dependencies.py @@ -0,0 +1,13 @@ +from fastapi import Depends + +from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.settings import Settings +from askui.chat.api.threads.service import ThreadService + + +def get_thread_service(settings: Settings = SettingsDep) -> ThreadService: + """Get ThreadService instance.""" + return ThreadService(settings.data_dir) + + +ThreadServiceDep = Depends(get_thread_service) diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py new file mode 100644 index 00000000..1fae51cb --- /dev/null +++ b/src/askui/chat/api/threads/router.py @@ -0,0 +1,47 @@ +from fastapi import APIRouter, HTTPException + +from askui.chat.api.threads.dependencies import ThreadServiceDep +from askui.chat.api.threads.service import Thread, ThreadListResponse, ThreadService + +router = APIRouter(prefix="/threads", tags=["threads"]) + + +@router.get("", response_model=ThreadListResponse) +def list_threads( + limit: int | None = None, + thread_service: ThreadService = ThreadServiceDep, +) -> ThreadListResponse: + """List all threads.""" + return thread_service.list_(limit=limit) + + +@router.post("", response_model=Thread) +def create_thread( + thread_service: ThreadService = ThreadServiceDep, +) -> Thread: + """Create a new thread.""" + return thread_service.create() + + +@router.get("/{thread_id}", response_model=Thread) +def get_thread( + thread_id: str, + thread_service: ThreadService = ThreadServiceDep, +) -> Thread: + """Get a thread by ID.""" + try: + return thread_service.retrieve(thread_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) + + +@router.delete("/{thread_id}") +def delete_thread( + thread_id: str, + thread_service: ThreadService = ThreadServiceDep, +) -> None: + """Delete a thread.""" + try: + thread_service.delete(thread_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) diff --git a/src/askui/chat/api/threads.py b/src/askui/chat/api/threads/service.py similarity index 93% rename from src/askui/chat/api/threads.py rename to src/askui/chat/api/threads/service.py index 24aa1663..d1799db7 100644 --- a/src/askui/chat/api/threads.py +++ b/src/askui/chat/api/threads/service.py @@ -27,11 +27,11 @@ class ThreadListResponse(BaseModel): has_more: bool = False -class ThreadsApi: - """API for managing chat threads/sessions.""" +class ThreadService: + """Service for managing chat threads/sessions.""" def __init__(self, base_dir: Path) -> None: - """Initialize threads API. + """Initialize thread service. Args: base_dir: Base directory to store thread data @@ -129,11 +129,11 @@ def delete(self, thread_id: str) -> None: raise FileNotFoundError(error_msg) # Get all image paths from messages before deleting thread - from askui.chat.api.messages import MessagesApi + from askui.chat.api.messages.service import MessageService - messages_api = MessagesApi(self._base_dir) + message_service = MessageService(self._base_dir) try: - messages = messages_api.list_(thread_id).data + messages = message_service.list_(thread_id).data for msg in messages: if msg.content and msg.content[0].image_paths: for img_path in msg.content[0].image_paths: From 73ee5c2c3cc3e325083f4d057b4ee5e5b210f38e Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 3 Jun 2025 11:31:48 +0200 Subject: [PATCH 08/20] refactor(models): extract methods etc. in computer agents --- src/askui/models/anthropic/computer_agent.py | 94 ++++++++++---------- src/askui/models/askui/computer_agent.py | 3 +- 2 files changed, 47 insertions(+), 50 deletions(-) diff --git a/src/askui/models/anthropic/computer_agent.py b/src/askui/models/anthropic/computer_agent.py index 94600691..0bffb2a6 100644 --- a/src/askui/models/anthropic/computer_agent.py +++ b/src/askui/models/anthropic/computer_agent.py @@ -1,18 +1,17 @@ import platform import sys from datetime import datetime, timezone -from typing import Any, cast +from typing import cast -from anthropic import Anthropic, APIError, APIResponseValidationError, APIStatusError +from anthropic import Anthropic, APIError +from anthropic._legacy_response import LegacyAPIResponse from anthropic.types.beta import ( BetaCacheControlEphemeralParam, BetaImageBlockParam, BetaMessage, BetaMessageParam, - BetaTextBlock, BetaTextBlockParam, BetaToolResultBlockParam, - BetaToolUseBlockParam, ) from typing_extensions import override @@ -23,7 +22,6 @@ from ...logger import logger from ...tools.anthropic import ComputerTool, ToolCollection, ToolResult -from ...utils.str_utils import truncate_long_strings PC_KEY = [ "backspace", @@ -184,18 +182,11 @@ def __init__( text=f"{SYSTEM_PROMPT}", ) - def step( + def _create_raw_response( self, messages: list[BetaMessageParam], model: str - ) -> list[BetaMessageParam]: - if self._settings.only_n_most_recent_images: - self._maybe_filter_to_n_most_recent_images( - messages, - self._settings.only_n_most_recent_images, - min_removal_threshold=self._settings.image_truncation_threshold, - ) - + ) -> LegacyAPIResponse[BetaMessage] | None: try: - raw_response = self._client.beta.messages.with_raw_response.create( + return self._client.beta.messages.with_raw_response.create( max_tokens=self._settings.max_tokens, messages=messages, model=model, @@ -203,38 +194,57 @@ def step( tools=self._tool_collection.to_params(), betas=self._settings.betas, ) - except (APIStatusError, APIResponseValidationError) as e: - logger.error(e) - return messages except APIError as e: logger.error(e) + return None + + def step( + self, messages: list[BetaMessageParam], model: str + ) -> list[BetaMessageParam]: + if self._settings.only_n_most_recent_images: + self._maybe_filter_to_n_most_recent_images( + messages, + self._settings.only_n_most_recent_images, + min_removal_threshold=self._settings.image_truncation_threshold, + ) + raw_response = self._create_raw_response(messages, model) + if raw_response is None: return messages - response = raw_response.parse() - response_params = self._response_to_params(response) - new_message: BetaMessageParam = { - "role": "assistant", - "content": response_params, - } - logger.debug(new_message) - messages.append(new_message) - self._reporter.add_message("Anthropic Computer Use", response_params) + response_message = raw_response.parse() + response_message_param: BetaMessageParam = cast( + "BetaMessageParam", + response_message.model_dump(include={"content", "role"}), + ) + logger.debug(response_message_param) + messages.append(response_message_param) + self._reporter.add_message( + "Anthropic Computer Use", dict(response_message_param) + ) + if tool_result_message := self._use_tools(response_message): + messages.append(tool_result_message) + return messages + def _use_tools(self, message: BetaMessage) -> BetaMessageParam | None: tool_result_content: list[BetaToolResultBlockParam] = [] - for content_block in response_params: - if content_block["type"] == "tool_use": + for content_block in message.content: + if content_block.type == "tool_use": + tool_input = content_block.input result = self._tool_collection.run( - name=content_block["name"], - tool_input=cast("dict[str, Any]", content_block["input"]), + name=content_block.name, + tool_input=tool_input, # type: ignore[arg-type] ) tool_result_content.append( - self._make_api_tool_result(result, content_block["id"]) + self._make_api_tool_result(result, content_block.id) ) if len(tool_result_content) > 0: - another_new_message = {"content": tool_result_content, "role": "user"} - logger.debug(truncate_long_strings(another_new_message, max_length=200)) - messages.append(cast("BetaMessageParam", another_new_message)) - return messages + tool_result_message: BetaMessageParam = { + "content": tool_result_content, + "role": "user", + } + logger.debug(tool_result_message) + return tool_result_message + return None @override def act(self, goal: str, model_choice: str) -> None: @@ -293,18 +303,6 @@ def _maybe_filter_to_n_most_recent_images( tool_result["content"] = new_content # type: ignore return None - @staticmethod - def _response_to_params( - response: BetaMessage, - ) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: - res: list[BetaTextBlockParam | BetaToolUseBlockParam] = [] - for block in response.content: - if isinstance(block, BetaTextBlock): - res.append({"type": "text", "text": block.text}) - else: - res.append(cast("BetaToolUseBlockParam", block.model_dump())) - return res - @staticmethod def _inject_prompt_caching( messages: list[BetaMessageParam], diff --git a/src/askui/models/askui/computer_agent.py b/src/askui/models/askui/computer_agent.py index 63707cc2..36c176f8 100644 --- a/src/askui/models/askui/computer_agent.py +++ b/src/askui/models/askui/computer_agent.py @@ -236,8 +236,7 @@ def step(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]: } logger.debug(new_message) messages.append(new_message) - if self._reporter is not None: - self._reporter.add_message("AskUI Computer Use", response_params) + self._reporter.add_message("AskUI Computer Use", response_params) tool_result_content: list[BetaToolResultBlockParam] = [] for content_block in response_params: From d180edc3467cae607833896e0a5261cca331d178 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 3 Jun 2025 16:03:19 +0200 Subject: [PATCH 09/20] feat(agent)!: allow passing messages and callbacks to `VisionAgent.act()` - callback for new messages and tool results - allow cancelling acting by returning `None` from callback - allow modifying message or tool result by returning a modified message or tool result from callback also refactor: - introduce shared `ComputerAgent` base class - refactor `AskuiComputerAgent` and `ClaudeComputerAgent` to use `ComputerAgent` base class - replace `AskUiFacade` and `ClaudeFacade` with generic `ModelFacade` facade BREAKING CHANGE: - `ActModel.act()` signature changed, original `goal` parameter is now in `messages[0]["content"]` --- .cursorrules | 1 + README.md | 44 +- src/askui/__init__.py | 6 + src/askui/agent.py | 30 +- src/askui/models/__init__.py | 4 + src/askui/models/anthropic/computer_agent.py | 373 +------------- src/askui/models/anthropic/facade.py | 43 -- src/askui/models/anthropic/settings.py | 9 +- src/askui/models/askui/computer_agent.py | 325 +----------- src/askui/models/askui/facade.py | 46 -- src/askui/models/askui/settings.py | 14 +- src/askui/models/model_router.py | 43 +- src/askui/models/models.py | 68 ++- src/askui/models/shared/__init__.py | 0 src/askui/models/shared/computer_agent.py | 467 ++++++++++++++++++ src/askui/models/shared/facade.py | 63 +++ src/askui/models/ui_tars_ep/ui_tars_api.py | 36 +- src/askui/tools/agent_os.py | 2 - src/askui/tools/pynput/pynput_agent_os.py | 2 +- tests/e2e/agent/conftest.py | 14 +- tests/e2e/agent/test_act.py | 4 +- tests/e2e/agent/test_get.py | 4 +- .../integration/agent/test_computer_agent.py | 5 +- tests/integration/test_custom_models.py | 62 ++- tests/unit/models/test_model_router.py | 51 +- 25 files changed, 858 insertions(+), 858 deletions(-) delete mode 100644 src/askui/models/anthropic/facade.py delete mode 100644 src/askui/models/askui/facade.py create mode 100644 src/askui/models/shared/__init__.py create mode 100644 src/askui/models/shared/computer_agent.py create mode 100644 src/askui/models/shared/facade.py diff --git a/.cursorrules b/.cursorrules index de6d9034..3889764c 100644 --- a/.cursorrules +++ b/.cursorrules @@ -5,6 +5,7 @@ - Use type hints for all variables, functions, methods, properties, return values, parameters etc. - Omit `Any` within type hints unless absolutely necessary - Use built-in types (e.g., `list`, `dict`, `tuple`, `set`, `str | None`) instead of types from `typing` module (e.g. `List`, `Dict`, `Tuple`, `Set`, `Optional`, `Union`) wherever possible +- Instead of `Optional` use `| None` - Create a `__init__.py` file in each folder ## FastAPI diff --git a/README.md b/README.md index 7089109a..a600f2f8 100644 --- a/README.md +++ b/README.md @@ -284,8 +284,11 @@ You can create and use your own models by subclassing the `ActModel` (used for ` Here's how to create and use custom models: ```python +import functools from askui import ( ActModel, + BetaMessageParam, + BetaToolUseBlockParam, GetModel, LocateModel, Locator, @@ -294,22 +297,40 @@ from askui import ( ModelRegistry, Point, ResponseSchema, + ToolResult, VisionAgent, ) -from typing import Type +from typing import Callable, Type +from typing_extensions import override # Define custom models class MyActModel(ActModel): - def act(self, goal: str, model_choice: str) -> None: + @override + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: # Implement custom act logic, e.g.: # - Use a different AI model # - Implement custom business logic # - Call external services + goal = messages[0]["content"] print(f"Custom act model executing goal: {goal}") # Because Python supports multiple inheritance, we can subclass both `GetModel` and `LocateModel` (and even `ActModel`) # to create a model that can both get and locate elements. class MyGetAndLocateModel(GetModel, LocateModel): + @override def get( self, query: str, @@ -324,6 +345,7 @@ class MyGetAndLocateModel(GetModel, LocateModel): return f"Custom response to query: {query}" + @override def locate( self, locator: str | Locator, @@ -366,11 +388,23 @@ You can also use model factories if you need to create models dynamically: ```python class DynamicActModel(ActModel): - def act(self, goal: str, model_choice: str) -> None: - # Use api_key in implementation + @override + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: pass - # going to be called each time model is chosen using `model` parameter def create_custom_model(api_key: str) -> ActModel: return DynamicActModel() diff --git a/src/askui/__init__.py b/src/askui/__init__.py index a9ff59d5..e3f321ec 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -6,6 +6,8 @@ from .locators import Locator from .models import ( ActModel, + BetaMessageParam, + BetaToolUseBlockParam, GetModel, LocateModel, Model, @@ -18,10 +20,13 @@ from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .retry import ConfigurableRetry, Retry from .tools import ModifierKey, PcKey +from .tools.anthropic import ToolResult from .utils.image_utils import ImageSource, Img __all__ = [ "ActModel", + "BetaMessageParam", + "BetaToolUseBlockParam", "GetModel", "ImageSource", "Img", @@ -38,6 +43,7 @@ "ResponseSchema", "ResponseSchemaBase", "Retry", + "ToolResult", "ConfigurableRetry", "VisionAgent", ] diff --git a/src/askui/agent.py b/src/askui/agent.py index 19ab02eb..95d22c3e 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -1,13 +1,15 @@ import logging import time import types -from typing import Annotated, Literal, Optional, Type, overload +from typing import Annotated, Callable, Literal, Optional, Type, overload +from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from dotenv import load_dotenv from pydantic import ConfigDict, Field, validate_call from askui.container import telemetry from askui.locators.locators import Locator +from askui.tools.anthropic import ToolResult from askui.utils.image_utils import ImageSource, Img from .exceptions import ElementNotFoundError @@ -216,7 +218,7 @@ def _mouse_move( self, locator: str | Locator, model: ModelComposition | str | None = None ) -> None: point = self._locate(locator=locator, model=model) - self._tools.os.mouse(point[0], point[1]) + self._tools.os.mouse_move(point[0], point[1]) @telemetry.record_call(exclude={"locator"}) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) @@ -534,8 +536,17 @@ def mouse_down( @validate_call def act( self, - goal: Annotated[str, Field(min_length=1)], + goal: Annotated[str | list[BetaMessageParam], Field(min_length=1)], model: str | None = None, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, ) -> None: """ Instructs the agent to achieve a specified goal through autonomous actions. @@ -547,6 +558,12 @@ def act( Args: goal (str): A description of what the agent should achieve. model (str | None, optional): The composition or name of the model(s) to be used for achieving the `goal`. + messages (list[BetaMessageParam] | None, optional): The message history to start from. If None, starts with a new message containing the goal. + on_message (Callable[[BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None], optional): Callback for new messages. If it returns `None`, stops and does not add the message. + on_tool_result (Callable[[ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], ToolResult | None], optional): Callback for tool results. If it returns `None`, stops and does not add the tool result. + + Returns: + None Example: ```python @@ -562,7 +579,12 @@ def act( logger.debug( "VisionAgent received instruction to act towards the goal '%s'", goal ) - self._model_router.act(goal, model or self._model_choice["act"]) + messages: list[BetaMessageParam] = ( + [{"role": "user", "content": goal}] if isinstance(goal, str) else goal + ) + self._model_router.act( + messages, model or self._model_choice["act"], on_message, on_tool_result + ) @telemetry.record_call() @validate_call diff --git a/src/askui/models/__init__.py b/src/askui/models/__init__.py index fada0c87..f1765661 100644 --- a/src/askui/models/__init__.py +++ b/src/askui/models/__init__.py @@ -1,3 +1,5 @@ +from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam + from .models import ( ActModel, GetModel, @@ -13,6 +15,8 @@ __all__ = [ "ActModel", + "BetaMessageParam", + "BetaToolUseBlockParam", "GetModel", "LocateModel", "Model", diff --git a/src/askui/models/anthropic/computer_agent.py b/src/askui/models/anthropic/computer_agent.py index 0bffb2a6..476e7ecd 100644 --- a/src/askui/models/anthropic/computer_agent.py +++ b/src/askui/models/anthropic/computer_agent.py @@ -1,375 +1,36 @@ -import platform -import sys -from datetime import datetime, timezone -from typing import cast - -from anthropic import Anthropic, APIError -from anthropic._legacy_response import LegacyAPIResponse -from anthropic.types.beta import ( - BetaCacheControlEphemeralParam, - BetaImageBlockParam, - BetaMessage, - BetaMessageParam, - BetaTextBlockParam, - BetaToolResultBlockParam, -) +from anthropic import Anthropic +from anthropic.types.beta import BetaMessage, BetaMessageParam from typing_extensions import override from askui.models.anthropic.settings import ClaudeComputerAgentSettings -from askui.models.models import ANTHROPIC_MODEL_NAME_MAPPING, ActModel, ModelName +from askui.models.models import ANTHROPIC_MODEL_NAME_MAPPING, ModelName +from askui.models.shared.computer_agent import ComputerAgent from askui.reporting import Reporter from askui.tools.agent_os import AgentOs -from ...logger import logger -from ...tools.anthropic import ComputerTool, ToolCollection, ToolResult - -PC_KEY = [ - "backspace", - "delete", - "enter", - "tab", - "escape", - "up", - "down", - "right", - "left", - "home", - "end", - "pageup", - "pagedown", - "f1", - "f2", - "f3", - "f4", - "f5", - "f6", - "f7", - "f8", - "f9", - "f10", - "f11", - "f12", - "space", - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "J", - "K", - "L", - "M", - "N", - "O", - "P", - "Q", - "R", - "S", - "T", - "U", - "V", - "W", - "X", - "Y", - "Z", - "!", - '"', - "#", - "$", - "%", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "<", - "=", - ">", - "?", - "@", - "[", - "\\", - "]", - "^", - "_", - "`", - "{", - "|", - "}", - "~", -] - - -SYSTEM_PROMPT = f""" -* You are utilising a {sys.platform} machine using {platform.machine()} architecture with internet access. -* When asked to perform web tasks try to open the browser (firefox, chrome, safari, ...) if not already open. Often you can find the browser icons in the toolbars of the operating systems. -* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. -* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. -* Valid keyboard keys available are {", ".join(PC_KEY)} -* The current date is {datetime.now(tz=timezone.utc).strftime("%A, %B %d, %Y").replace(" 0", " ")}. - - - -* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. -* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. -""" - -class ClaudeComputerAgent(ActModel): +class ClaudeComputerAgent(ComputerAgent[ClaudeComputerAgentSettings]): def __init__( self, agent_os: AgentOs, reporter: Reporter, settings: ClaudeComputerAgentSettings, ) -> None: - self._settings = settings + super().__init__(settings, agent_os, reporter) self._client = Anthropic( api_key=self._settings.anthropic.api_key.get_secret_value() ) - self._reporter = reporter - self._tool_collection = ToolCollection( - ComputerTool(agent_os), - ) - self._system = BetaTextBlockParam( - type="text", - text=f"{SYSTEM_PROMPT}", - ) - - def _create_raw_response( - self, messages: list[BetaMessageParam], model: str - ) -> LegacyAPIResponse[BetaMessage] | None: - try: - return self._client.beta.messages.with_raw_response.create( - max_tokens=self._settings.max_tokens, - messages=messages, - model=model, - system=[self._system], - tools=self._tool_collection.to_params(), - betas=self._settings.betas, - ) - except APIError as e: - logger.error(e) - return None - - def step( - self, messages: list[BetaMessageParam], model: str - ) -> list[BetaMessageParam]: - if self._settings.only_n_most_recent_images: - self._maybe_filter_to_n_most_recent_images( - messages, - self._settings.only_n_most_recent_images, - min_removal_threshold=self._settings.image_truncation_threshold, - ) - raw_response = self._create_raw_response(messages, model) - if raw_response is None: - return messages - - response_message = raw_response.parse() - response_message_param: BetaMessageParam = cast( - "BetaMessageParam", - response_message.model_dump(include={"content", "role"}), - ) - logger.debug(response_message_param) - messages.append(response_message_param) - self._reporter.add_message( - "Anthropic Computer Use", dict(response_message_param) - ) - if tool_result_message := self._use_tools(response_message): - messages.append(tool_result_message) - return messages - - def _use_tools(self, message: BetaMessage) -> BetaMessageParam | None: - tool_result_content: list[BetaToolResultBlockParam] = [] - for content_block in message.content: - if content_block.type == "tool_use": - tool_input = content_block.input - result = self._tool_collection.run( - name=content_block.name, - tool_input=tool_input, # type: ignore[arg-type] - ) - tool_result_content.append( - self._make_api_tool_result(result, content_block.id) - ) - if len(tool_result_content) > 0: - tool_result_message: BetaMessageParam = { - "content": tool_result_content, - "role": "user", - } - logger.debug(tool_result_message) - return tool_result_message - return None @override - def act(self, goal: str, model_choice: str) -> None: - messages: list[BetaMessageParam] = [{"role": "user", "content": goal}] - logger.debug(messages[0]) - while messages[-1]["role"] == "user": - messages = self.step( - messages=messages, - model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)], - ) - - @staticmethod - def _maybe_filter_to_n_most_recent_images( - messages: list[BetaMessageParam], - images_to_keep: int | None, - min_removal_threshold: int, - ) -> list[BetaMessageParam] | None: - """ - With the assumption that images are screenshots that are of diminishing value as - the conversation progresses, remove all but the final `images_to_keep` tool_result - images in place, with a chunk of min_removal_threshold to reduce the amount we - break the implicit prompt cache. - """ - if images_to_keep is None: - return messages - - tool_result_blocks = cast( - "list[BetaToolResultBlockParam]", - [ - item - for message in messages - for item in ( - message["content"] if isinstance(message["content"], list) else [] - ) - if isinstance(item, dict) and item.get("type") == "tool_result" - ], - ) - total_images = sum( - 1 - for tool_result in tool_result_blocks - for content in tool_result.get("content", []) - if isinstance(content, dict) and content.get("type") == "image" + def _create_message( + self, messages: list[BetaMessageParam], model_choice: str + ) -> BetaMessage: + response = self._client.beta.messages.with_raw_response.create( + max_tokens=self._settings.max_tokens, + messages=messages, + model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)], + system=[self._system], + tools=self._tool_collection.to_params(), + betas=self._settings.betas, ) - images_to_remove = total_images - images_to_keep - # for better cache behavior, we want to remove in chunks - images_to_remove -= images_to_remove % min_removal_threshold - for tool_result in tool_result_blocks: - if isinstance(tool_result.get("content"), list): - new_content = [] - for content in tool_result.get("content", []): - if isinstance(content, dict) and content.get("type") == "image": - if images_to_remove > 0: - images_to_remove -= 1 - continue - new_content.append(content) - tool_result["content"] = new_content # type: ignore - return None - - @staticmethod - def _inject_prompt_caching( - messages: list[BetaMessageParam], - ) -> None: - """ - Set cache breakpoints for the 3 most recent turns - one cache breakpoint is left for tools/system prompt, to be shared - across sessions - """ - - breakpoints_remaining = 3 - for message in reversed(messages): - if message["role"] == "user" and isinstance( - content := message["content"], list - ): - if breakpoints_remaining: - breakpoints_remaining -= 1 - content[-1]["cache_control"] = BetaCacheControlEphemeralParam( - {"type": "ephemeral"} - ) - else: - content[-1].pop("cache_control", None) - # we'll only every have one extra turn per loop - break - - def _make_api_tool_result( - self, result: ToolResult, tool_use_id: str - ) -> BetaToolResultBlockParam: - """Convert an agent ToolResult to an API ToolResultBlockParam.""" - tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = [] - is_error = False - if result.error: - is_error = True - tool_result_content = self._maybe_prepend_system_tool_result( - result, result.error - ) - else: - assert isinstance(tool_result_content, list) - if result.output: - tool_result_content.append( - { - "type": "text", - "text": self._maybe_prepend_system_tool_result( - result, result.output - ), - } - ) - if result.base64_image: - tool_result_content.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": result.base64_image, - }, - } - ) - return { - "type": "tool_result", - "content": tool_result_content, - "tool_use_id": tool_use_id, - "is_error": is_error, - } - - @staticmethod - def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: - if result.system: - result_text = f"{result.system}\n{result_text}" - return result_text + return response.parse() diff --git a/src/askui/models/anthropic/facade.py b/src/askui/models/anthropic/facade.py deleted file mode 100644 index 7188867d..00000000 --- a/src/askui/models/anthropic/facade.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Type - -from typing_extensions import override - -from askui.locators.locators import Locator -from askui.models.anthropic.computer_agent import ClaudeComputerAgent -from askui.models.anthropic.handler import ClaudeHandler -from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point -from askui.models.types.response_schemas import ResponseSchema -from askui.utils.image_utils import ImageSource - - -class AnthropicFacade(ActModel, GetModel, LocateModel): - def __init__( - self, - computer_agent: ClaudeComputerAgent, - handler: ClaudeHandler, - ) -> None: - self._computer_agent = computer_agent - self._handler = handler - - @override - def act(self, goal: str, model_choice: str) -> None: - self._computer_agent.act(goal, model_choice) - - @override - def get( - self, - query: str, - image: ImageSource, - response_schema: Type[ResponseSchema] | None, - model_choice: str, - ) -> ResponseSchema | str: - return self._handler.get(query, image, response_schema, model_choice) - - @override - def locate( - self, - locator: str | Locator, - image: ImageSource, - model_choice: ModelComposition | str, - ) -> Point: - return self._handler.locate(locator, image, model_choice) diff --git a/src/askui/models/anthropic/settings.py b/src/askui/models/anthropic/settings.py index 2fe56d54..e804495d 100644 --- a/src/askui/models/anthropic/settings.py +++ b/src/askui/models/anthropic/settings.py @@ -1,6 +1,8 @@ from pydantic import BaseModel, Field, SecretStr from pydantic_settings import BaseSettings +from askui.models.shared.computer_agent import ComputerAgentSettingsBase + COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" @@ -21,8 +23,5 @@ class ClaudeSettings(ClaudeSettingsBase): temperature: float = 0.0 -class ClaudeComputerAgentSettings(ClaudeSettingsBase): - max_tokens: int = 4096 - only_n_most_recent_images: int = 3 - image_truncation_threshold: int = 10 - betas: list[str] = Field(default_factory=lambda: [COMPUTER_USE_BETA_FLAG]) +class ClaudeComputerAgentSettings(ComputerAgentSettingsBase, ClaudeSettingsBase): + pass diff --git a/src/askui/models/askui/computer_agent.py b/src/askui/models/askui/computer_agent.py index 36c176f8..6538c5ba 100644 --- a/src/askui/models/askui/computer_agent.py +++ b/src/askui/models/askui/computer_agent.py @@ -1,170 +1,14 @@ -import platform -import sys -from datetime import datetime -from typing import Any, cast - import httpx -from anthropic.types.beta import ( - BetaImageBlockParam, - BetaMessage, - BetaMessageParam, - BetaTextBlock, - BetaTextBlockParam, - BetaToolResultBlockParam, - BetaToolUseBlockParam, -) +from anthropic.types.beta import BetaMessage, BetaMessageParam from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential -from typing_extensions import override from askui.models.askui.settings import AskUiComputerAgentSettings -from askui.models.models import ActModel +from askui.models.shared.computer_agent import ComputerAgent from askui.reporting import Reporter from askui.tools.agent_os import AgentOs -from askui.tools.anthropic import ComputerTool, ToolCollection, ToolResult -from askui.utils.str_utils import truncate_long_strings from ...logger import logger -COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" -PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" - -PC_KEY = [ - "backspace", - "delete", - "enter", - "tab", - "escape", - "up", - "down", - "right", - "left", - "home", - "end", - "pageup", - "pagedown", - "f1", - "f2", - "f3", - "f4", - "f5", - "f6", - "f7", - "f8", - "f9", - "f10", - "f11", - "f12", - "space", - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "J", - "K", - "L", - "M", - "N", - "O", - "P", - "Q", - "R", - "S", - "T", - "U", - "V", - "W", - "X", - "Y", - "Z", - "!", - '"', - "#", - "$", - "%", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "<", - "=", - ">", - "?", - "@", - "[", - "\\", - "]", - "^", - "_", - "`", - "{", - "|", - "}", - "~", -] - -SYSTEM_PROMPT = f""" -* You are utilising a {sys.platform} machine using {platform.machine()} architecture with internet access. -* When asked to perform web tasks try to open the browser (firefox, chrome, safari, ...) if not already open. Often you can find the browser icons in the toolbars of the operating systems. -* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. -* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. -* Valid keyboard keys available are {", ".join(PC_KEY)} -* The current date is {datetime.today().strftime("%A, %B %d, %Y").replace(" 0", " ")}. - - - -* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. -* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. -""" # noqa: DTZ002, E501 - def is_retryable_error(exception: BaseException) -> bool: """Check if the exception is a retryable error (status codes 429 or 529).""" @@ -173,19 +17,14 @@ def is_retryable_error(exception: BaseException) -> bool: return False -class AskUiComputerAgent(ActModel): +class AskUiComputerAgent(ComputerAgent[AskUiComputerAgentSettings]): def __init__( self, agent_os: AgentOs, reporter: Reporter, settings: AskUiComputerAgentSettings, ) -> None: - self._settings = settings - self._reporter = reporter - self._tool_collection = ToolCollection( - ComputerTool(agent_os), - ) - self._system = SYSTEM_PROMPT + super().__init__(settings, agent_os, reporter) self._client = httpx.Client( base_url=f"{self._settings.askui.base_url}", headers={ @@ -200,14 +39,11 @@ def __init__( retry=retry_if_exception(is_retryable_error), reraise=True, ) - def step(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]: - if self._settings.only_n_most_recent_images: - self._maybe_filter_to_n_most_recent_images( - messages, - self._settings.only_n_most_recent_images, - min_removal_threshold=self._settings.image_truncation_threshold, - ) - + def _create_message( + self, + messages: list[BetaMessageParam], + model_choice: str, # noqa: ARG002 + ) -> BetaMessage: try: request_body = { "max_tokens": self._settings.max_tokens, @@ -215,7 +51,7 @@ def step(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]: "model": self._settings.model, "tools": self._tool_collection.to_params(), "betas": self._settings.betas, - "system": [{"type": "text", "text": self._system}], + "system": [self._system], } logger.debug(request_body) response = self._client.post( @@ -223,147 +59,8 @@ def step(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]: ) response.raise_for_status() response_data = response.json() - beta_message = BetaMessage.model_validate(response_data) + return BetaMessage.model_validate(response_data) except Exception as e: # noqa: BLE001 if is_retryable_error(e): logger.debug(e) raise - - response_params = self._response_to_params(beta_message) - new_message: BetaMessageParam = { - "role": "assistant", - "content": response_params, - } - logger.debug(new_message) - messages.append(new_message) - self._reporter.add_message("AskUI Computer Use", response_params) - - tool_result_content: list[BetaToolResultBlockParam] = [] - for content_block in response_params: - if content_block["type"] == "tool_use": - result = self._tool_collection.run( - name=content_block["name"], - tool_input=cast("dict[str, Any]", content_block["input"]), - ) - tool_result_content.append( - self._make_api_tool_result(result, content_block["id"]) - ) - if len(tool_result_content) > 0: - another_new_message: BetaMessageParam = { - "content": tool_result_content, - "role": "user", - } - logger.debug( - truncate_long_strings(dict(another_new_message), max_length=200) - ) - messages.append(another_new_message) - return messages - - @override - def act(self, goal: str, model_choice: str) -> None: - messages: list[BetaMessageParam] = [{"role": "user", "content": goal}] - logger.debug(messages[0]) - while messages[-1]["role"] == "user": - messages = self.step(messages) - - @staticmethod - def _maybe_filter_to_n_most_recent_images( - messages: list, - images_to_keep: int | None, - min_removal_threshold: int, - ) -> list | None: - """ - With the assumption that images are screenshots that are of diminishing value as - the conversation progresses, remove all but the final `images_to_keep` tool_result - images in place, with a chunk of min_removal_threshold to reduce the amount we - break the implicit prompt cache. - """ # noqa: E501 - if images_to_keep is None: - return messages - - tool_result_blocks = [ - item - for message in messages - for item in ( - message["content"] if isinstance(message["content"], list) else [] - ) - if isinstance(item, dict) and item.get("type") == "tool_result" - ] - total_images = sum( - 1 - for tool_result in tool_result_blocks - for content in tool_result.get("content", []) - if isinstance(content, dict) and content.get("type") == "image" - ) - images_to_remove = total_images - images_to_keep - # for better cache behavior, we want to remove in chunks - images_to_remove -= images_to_remove % min_removal_threshold - for tool_result in tool_result_blocks: - if isinstance(tool_result.get("content"), list): - new_content = [] - for content in tool_result.get("content", []): - if isinstance(content, dict) and content.get("type") == "image": - if images_to_remove > 0: - images_to_remove -= 1 - continue - new_content.append(content) - tool_result["content"] = new_content - return None - - @staticmethod - def _response_to_params( - response: BetaMessage, - ) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: - res: list[BetaTextBlockParam | BetaToolUseBlockParam] = [] - for block in response.content: - if isinstance(block, BetaTextBlock): - res.append({"type": "text", "text": block.text}) - else: - res.append(cast("BetaToolUseBlockParam", block.model_dump())) - return res - - def _make_api_tool_result( - self, result: ToolResult, tool_use_id: str - ) -> BetaToolResultBlockParam: - """Convert an agent ToolResult to an API ToolResultBlockParam.""" - tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = [] - is_error = False - if result.error: - is_error = True - tool_result_content = self._maybe_prepend_system_tool_result( - result, result.error - ) - else: - assert isinstance(tool_result_content, list) - if result.output: - tool_result_content.append( - { - "type": "text", - "text": self._maybe_prepend_system_tool_result( - result, result.output - ), - } - ) - if result.base64_image: - tool_result_content.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": result.base64_image, - }, - } - ) - return { - "type": "tool_result", - "content": tool_result_content, - "tool_use_id": tool_use_id, - "is_error": is_error, - } - - @staticmethod - def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: - if result.system: - result_text = f"{result.system}\n{result_text}" - return result_text diff --git a/src/askui/models/askui/facade.py b/src/askui/models/askui/facade.py deleted file mode 100644 index a682522e..00000000 --- a/src/askui/models/askui/facade.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import Type - -from typing_extensions import override - -from askui.locators.locators import Locator -from askui.models.askui.computer_agent import AskUiComputerAgent -from askui.models.askui.inference_api import AskUiInferenceApi -from askui.models.askui.model_router import AskUiModelRouter -from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point -from askui.models.types.response_schemas import ResponseSchema -from askui.utils.image_utils import ImageSource - - -class AskUiFacade(ActModel, GetModel, LocateModel): - def __init__( - self, - computer_agent: AskUiComputerAgent, - inference_api: AskUiInferenceApi, - model_router: AskUiModelRouter, - ) -> None: - self._computer_agent = computer_agent - self._inference_api = inference_api - self._model_router = model_router - - @override - def act(self, goal: str, model_choice: str) -> None: - self._computer_agent.act(goal, model_choice) - - @override - def get( - self, - query: str, - image: ImageSource, - response_schema: Type[ResponseSchema] | None, - model_choice: str, - ) -> ResponseSchema | str: - return self._inference_api.get(query, image, response_schema, model_choice) - - @override - def locate( - self, - locator: str | Locator, - image: ImageSource, - model_choice: ModelComposition | str, - ) -> Point: - return self._model_router.locate(locator, image, model_choice) diff --git a/src/askui/models/askui/settings.py b/src/askui/models/askui/settings.py index 649aff4e..b018b40b 100644 --- a/src/askui/models/askui/settings.py +++ b/src/askui/models/askui/settings.py @@ -1,12 +1,11 @@ import base64 from functools import cached_property -from pydantic import UUID4, BaseModel, Field, HttpUrl, SecretStr +from pydantic import UUID4, Field, HttpUrl, SecretStr from pydantic_settings import BaseSettings from askui.models.models import ModelName - -COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" +from askui.models.shared.computer_agent import ComputerAgentSettingsBase class AskUiSettings(BaseSettings): @@ -37,13 +36,6 @@ def base_url(self) -> str: return f"{self.inference_endpoint}api/v1/workspaces/{self.workspace_id}" -class AskUiComputerAgentSettingsBase(BaseModel): +class AskUiComputerAgentSettings(ComputerAgentSettingsBase): model: str = ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 askui: AskUiSettings = Field(default_factory=AskUiSettings) - - -class AskUiComputerAgentSettings(AskUiComputerAgentSettingsBase): - max_tokens: int = 4096 - only_n_most_recent_images: int = 3 - image_truncation_threshold: int = 10 - betas: list[str] = Field(default_factory=lambda: [COMPUTER_USE_BETA_FLAG]) diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index 10ebd036..efdcde68 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -1,12 +1,12 @@ import functools -from typing import Type, overload +from typing import Callable, Type, overload +from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from typing_extensions import Literal from askui.exceptions import ModelNotFoundError, ModelTypeMismatchError from askui.locators.locators import Locator from askui.locators.serializers import AskUiLocatorSerializer, VlmLocatorSerializer -from askui.models.anthropic.facade import AnthropicFacade from askui.models.anthropic.settings import ( AnthropicSettings, ClaudeComputerAgentSettings, @@ -14,7 +14,6 @@ ) from askui.models.askui.ai_element_utils import AiElementCollection from askui.models.askui.computer_agent import AskUiComputerAgent -from askui.models.askui.facade import AskUiFacade from askui.models.askui.model_router import AskUiModelRouter from askui.models.askui.settings import AskUiComputerAgentSettings from askui.models.huggingface.spaces_api import HFSpacesHandler @@ -29,8 +28,10 @@ ModelRegistry, Point, ) +from askui.models.shared.facade import ModelFacade from askui.models.types.response_schemas import ResponseSchema from askui.reporting import CompositeReporter, Reporter +from askui.tools.anthropic import ToolResult from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource @@ -70,7 +71,7 @@ def vlm_locator_serializer() -> VlmLocatorSerializer: return VlmLocatorSerializer() @functools.cache - def anthropic_facade() -> AnthropicFacade: + def anthropic_facade() -> ModelFacade: settings = AnthropicSettings() computer_agent = ClaudeComputerAgent( agent_os=tools.os, @@ -85,13 +86,14 @@ def anthropic_facade() -> AnthropicFacade: ), locator_serializer=vlm_locator_serializer(), ) - return AnthropicFacade( - computer_agent=computer_agent, - handler=handler, + return ModelFacade( + act_model=computer_agent, + get_model=handler, + locate_model=handler, ) @functools.cache - def askui_facade() -> AskUiFacade: + def askui_facade() -> ModelFacade: computer_agent = AskUiComputerAgent( agent_os=tools.os, reporter=reporter, @@ -99,10 +101,10 @@ def askui_facade() -> AskUiFacade: askui=askui_settings(), ), ) - return AskUiFacade( - computer_agent=computer_agent, - inference_api=askui_inference_api(), - model_router=askui_model_router(), + return ModelFacade( + act_model=computer_agent, + get_model=askui_inference_api(), + locate_model=askui_model_router(), ) @functools.cache @@ -180,10 +182,23 @@ def _get_model( return model - def act(self, goal: str, model_choice: str) -> None: + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: m = self._get_model(model_choice, "act") logger.debug(f'Routing "act" to model "{model_choice}"') - return m.act(goal, model_choice) + return m.act(messages, model_choice, on_message, on_tool_result) def get( self, diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 80a6e1fe..d038716a 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -4,11 +4,13 @@ from enum import Enum from typing import Annotated, Callable, Type +from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from pydantic import BaseModel, ConfigDict, Field, RootModel from typing_extensions import Literal, TypedDict from askui.locators.locators import Locator from askui.models.types.response_schemas import ResponseSchema +from askui.tools.anthropic.base import ToolResult from askui.utils.image_utils import ImageSource @@ -154,27 +156,71 @@ class ActModel(abc.ABC): Example: ```python - from askui import ActModel, VisionAgent + from askui import ( + ActModel, + BetaMessageParam, + BetaToolUseBlockParam, + VisionAgent, + ToolResult, + ) + from typing_extensions import override class MyActModel(ActModel): - def act(self, goal: str, model_choice: str) -> None: - # Implement custom act logic - pass + @override + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: + print(messages) # implement custom logic here with VisionAgent(models={"my-act": MyActModel()}) as agent: agent.act("search for flights", model="my-act") - ``` """ @abc.abstractmethod - def act(self, goal: str, model_choice: str) -> None: - """Execute autonomous actions to achieve a goal. + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: + """ + Execute autonomous actions to achieve a goal, using a message history + and optional callbacks. Args: - goal (str): A description of what the model should achieve - model_choice (str): The name of the model being used (useful for models that - support multiple configurations) - """ + messages (list[BetaMessageParam]): The message history to start from. + model_choice (str): The name of the model being used (useful for models + that support multiple configurations) + on_message (Callable[[BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None], optional): Callback for new messages. + If it returns `None`, stops and does not add the message. + on_tool_result (Callable[[ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], ToolResult | None], optional): Callback for tool results. + If it returns `None`, stops and does not add the tool result. + + Returns: + None + + Raises: + NotImplementedError: If the method is not implemented. + """ # noqa: E501 raise NotImplementedError diff --git a/src/askui/models/shared/__init__.py b/src/askui/models/shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/models/shared/computer_agent.py b/src/askui/models/shared/computer_agent.py new file mode 100644 index 00000000..49fd34c5 --- /dev/null +++ b/src/askui/models/shared/computer_agent.py @@ -0,0 +1,467 @@ +import platform +import sys +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from typing import Callable, Generic, cast + +from anthropic.types.beta import ( + BetaContentBlockParam, + BetaImageBlockParam, + BetaMessage, + BetaMessageParam, + BetaTextBlockParam, + BetaToolResultBlockParam, + BetaToolUseBlockParam, +) +from pydantic import BaseModel, Field +from typing_extensions import TypeVar + +from askui.models.models import ActModel +from askui.reporting import Reporter +from askui.tools.agent_os import AgentOs +from askui.tools.anthropic import ComputerTool, ToolCollection, ToolResult + +from ...logger import logger + +COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" + +PC_KEY = [ + "backspace", + "delete", + "enter", + "tab", + "escape", + "up", + "down", + "right", + "left", + "home", + "end", + "pageup", + "pagedown", + "f1", + "f2", + "f3", + "f4", + "f5", + "f6", + "f7", + "f8", + "f9", + "f10", + "f11", + "f12", + "space", + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "A", + "B", + "C", + "D", + "E", + "F", + "G", + "H", + "I", + "J", + "K", + "L", + "M", + "N", + "O", + "P", + "Q", + "R", + "S", + "T", + "U", + "V", + "W", + "X", + "Y", + "Z", + "!", + '"', + "#", + "$", + "%", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "<", + "=", + ">", + "?", + "@", + "[", + "\\", + "]", + "^", + "_", + "`", + "{", + "|", + "}", + "~", +] + +SYSTEM_PROMPT = f""" +* You are utilising a {sys.platform} machine using {platform.machine()} architecture with internet access. +* When asked to perform web tasks try to open the browser (firefox, chrome, safari, ...) if not already open. Often you can find the browser icons in the toolbars of the operating systems. +* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available. +* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request. +* Valid keyboard keys available are {", ".join(PC_KEY)} +* The current date is {datetime.now(timezone.utc).strftime("%A, %B %d, %Y").replace(" 0", " ")}. + + + +* When using Firefox, if a startup wizard appears, IGNORE IT. Do not even click "skip this step". Instead, click on the address bar where it says "Search or enter address", and enter the appropriate search term or URL there. +* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool. +""" # noqa: DTZ002, E501 + + +class ComputerAgentSettingsBase(BaseModel): + """Settings for computer agents.""" + + max_tokens: int = 4096 + only_n_most_recent_images: int = 3 + image_truncation_threshold: int = 10 + betas: list[str] = Field(default_factory=lambda: [COMPUTER_USE_BETA_FLAG]) + + +ComputerAgentSettings = TypeVar( + "ComputerAgentSettings", bound=ComputerAgentSettingsBase +) + + +class ComputerAgent(ActModel, ABC, Generic[ComputerAgentSettings]): + """Base class for computer agents that can execute autonomous actions. + + This class provides common functionality for both AskUI and Anthropic + computer agents, + including tool handling, message processing, and image filtering. + """ + + def __init__( + self, + settings: ComputerAgentSettings, + agent_os: AgentOs, + reporter: Reporter, + ) -> None: + """Initialize the computer agent. + + Args: + settings (ComputerAgentSettings): The settings for the computer agent. + agent_os (AgentOs): The operating system agent for executing commands. + reporter (Reporter): The reporter for logging messages and actions. + """ + self._settings = settings + self._reporter = reporter + self._tool_collection = ToolCollection( + ComputerTool(agent_os), + ) + self._system = BetaTextBlockParam( + type="text", + text=f"{SYSTEM_PROMPT}", + ) + + @abstractmethod + def _create_message( + self, messages: list[BetaMessageParam], model_choice: str + ) -> BetaMessage: + """Create a message using the agent's API. + + Args: + messages (list[BetaMessageParam]): The message history. + model_choice (str): The model to use for message creation. + + Returns: + BetaMessage: The created message. + """ + raise NotImplementedError + + def _step( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: + """Execute a single step in the conversation. + + Args: + messages (list[BetaMessageParam]): The message history. + model_choice (str): The model to use for message creation. + on_message (Callable, optional): Callback for message processing. + on_tool_result (Callable, optional): Callback for tool result processing. + + Returns: + None + """ + if self._settings.only_n_most_recent_images: + self._maybe_filter_to_n_most_recent_images( + messages, + self._settings.only_n_most_recent_images, + min_removal_threshold=self._settings.image_truncation_threshold, + ) + response_message = self._create_message(messages, model_choice) + response_message_param: BetaMessageParam = cast( + "BetaMessageParam", + response_message.model_dump(include={"content", "role"}), + ) + if on_message is not None: + response_message_param_cb = on_message(response_message_param, messages) + if response_message_param_cb is None: + return + + response_message_param = response_message_param_cb + logger.debug(response_message_param) + messages.append(response_message_param) + self._reporter.add_message( + self.__class__.__name__, dict(response_message_param) + ) + if tool_result_message := self._use_tools( + response_message, messages, on_tool_result + ): + messages.append(tool_result_message) + self._step( + messages=messages, + model_choice=model_choice, + on_message=on_message, + on_tool_result=on_tool_result, + ) + + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: + logger.debug(messages[0]) + self._step( + messages=messages, + model_choice=model_choice, + on_message=on_message, + on_tool_result=on_tool_result, + ) + + def _use_tools( + self, + message: BetaMessage, + messages: list[BetaMessageParam], + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> BetaMessageParam | None: + """Process tool use blocks in a message. + + Args: + message (BetaMessage): The message containing tool use blocks. + messages (list[BetaMessageParam]): The message history. + on_tool_result (Callable, optional): Callback for tool result processing. + + Returns: + BetaMessageParam | None: A message containing tool results or `None` + if no tools were used. + """ + tool_result_content: list[BetaToolResultBlockParam] = [] + for content_block in message.content: + if content_block.type == "tool_use": + result = self._tool_collection.run( + name=content_block.name, + tool_input=content_block.input, # type: ignore[arg-type] + ) + if on_tool_result is not None: + content_block_param = cast( + "BetaToolUseBlockParam", content_block.model_dump() + ) + result_cb = on_tool_result(result, content_block_param, messages) + if result_cb is None: + return None + result = result_cb + tool_result_content.append( + self._make_api_tool_result(result, content_block.id) + ) + if len(tool_result_content) > 0: + tool_result_message: BetaMessageParam = { + "content": tool_result_content, + "role": "user", + } + logger.debug(tool_result_message) + return tool_result_message + return None + + @staticmethod + def _maybe_filter_to_n_most_recent_images( + messages: list[BetaMessageParam], + images_to_keep: int | None, + min_removal_threshold: int, + ) -> list[BetaMessageParam] | None: + """Filter messages to keep only the most recent images. + + Args: + messages (list[BetaMessageParam]): The message history. + images_to_keep (int | None): Number of most recent images to keep. + min_removal_threshold (int): Minimum number of images to remove at once. + + Returns: + list[BetaMessageParam] | None: The filtered message history or `None` if no + filtering was done. + """ + if images_to_keep is None: + return messages + + tool_result_blocks = cast( + "list[BetaToolResultBlockParam]", + [ + item + for message in messages + for item in ( + message["content"] if isinstance(message["content"], list) else [] + ) + if item.get("type") == "tool_result" + ], + ) + total_images = sum( + 1 + for tool_result in tool_result_blocks + for content in tool_result.get("content", []) + if isinstance(content, dict) and content.get("type") == "image" + ) + images_to_remove = total_images - images_to_keep + # for better cache behavior, we want to remove in chunks + images_to_remove -= images_to_remove % min_removal_threshold + for tool_result in tool_result_blocks: + if isinstance(tool_result.get("content"), list): + new_content: list[BetaContentBlockParam] = [] + for content in tool_result.get("content", []): + if isinstance(content, dict) and content.get("type") == "image": + if images_to_remove > 0: + images_to_remove -= 1 + continue + new_content.append(cast("BetaContentBlockParam", content)) + tool_result["content"] = new_content # type: ignore[typeddict-item] + return None + + def _make_api_tool_result( + self, result: ToolResult, tool_use_id: str + ) -> BetaToolResultBlockParam: + """Convert a tool result to an API tool result block. + + Args: + result (ToolResult): The tool result to convert. + tool_use_id (str): The ID of the tool use block. + + Returns: + BetaToolResultBlockParam: The API tool result block. + """ + tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = [] + is_error = False + if result.error: + is_error = True + tool_result_content = self._maybe_prepend_system_tool_result( + result, result.error + ) + else: + assert isinstance(tool_result_content, list) + if result.output: + tool_result_content.append( + { + "type": "text", + "text": self._maybe_prepend_system_tool_result( + result, result.output + ), + } + ) + if result.base64_image: + tool_result_content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": result.base64_image, + }, + } + ) + return { + "type": "tool_result", + "content": tool_result_content, + "tool_use_id": tool_use_id, + "is_error": is_error, + } + + @staticmethod + def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: + """Prepend system message to tool result text if available. + + Args: + result (ToolResult): The tool result. + result_text (str): The result text. + + Returns: + str: The result text with optional system message prepended. + """ + if result.system: + result_text = f"{result.system}\n{result_text}" + return result_text diff --git a/src/askui/models/shared/facade.py b/src/askui/models/shared/facade.py new file mode 100644 index 00000000..34491b49 --- /dev/null +++ b/src/askui/models/shared/facade.py @@ -0,0 +1,63 @@ +from typing import Callable, Type + +from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam +from typing_extensions import override + +from askui.locators.locators import Locator +from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point +from askui.models.types.response_schemas import ResponseSchema +from askui.tools.anthropic.base import ToolResult +from askui.utils.image_utils import ImageSource + + +class ModelFacade(ActModel, GetModel, LocateModel): + def __init__( + self, + act_model: ActModel, + get_model: GetModel, + locate_model: LocateModel, + ) -> None: + self._act_model = act_model + self._get_model = get_model + self._locate_model = locate_model + + @override + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: + self._act_model.act( + messages=messages, + model_choice=model_choice, + on_message=on_message, + on_tool_result=on_tool_result, + ) + + @override + def get( + self, + query: str, + image: ImageSource, + response_schema: Type[ResponseSchema] | None, + model_choice: str, + ) -> ResponseSchema | str: + return self._get_model.get(query, image, response_schema, model_choice) + + @override + def locate( + self, + locator: str | Locator, + image: ImageSource, + model_choice: ModelComposition | str, + ) -> Point: + return self._locate_model.locate(locator, image, model_choice) diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index ffe60186..6db29e64 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -1,8 +1,9 @@ import math import re import time -from typing import Any, Type +from typing import Any, Callable, Type +from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from openai import OpenAI from pydantic import Field, HttpUrl, SecretStr from pydantic_settings import BaseSettings @@ -15,6 +16,7 @@ from askui.models.types.response_schemas import ResponseSchema from askui.reporting import Reporter from askui.tools.agent_os import AgentOs +from askui.tools.anthropic.base import ToolResult from askui.utils.image_utils import ImageSource, image_to_base64 from .parser import UITarsEPMessage @@ -188,7 +190,37 @@ def get( return response @override - def act(self, goal: str, model_choice: str) -> None: + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: + if on_message is not None: + error_msg = "on_message is not supported for UI-TARS" + raise NotImplementedError(error_msg) + if on_tool_result is not None: + error_msg = "on_tool_result is not supported for UI-TARS" + raise NotImplementedError(error_msg) + if len(messages) != 1: + error_msg = "UI-TARS only supports one message" + raise ValueError(error_msg) + message = messages[0] + if message["role"] != "user": + error_msg = "UI-TARS only supports user messages" + raise ValueError(error_msg) + if not isinstance(message["content"], str): + error_msg = "UI-TARS only supports text messages" + raise ValueError(error_msg) # noqa: TRY004 + goal = message["content"] screenshot = self._agent_os.screenshot() self.act_history = [ { diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index 8f65e5e5..dd7ad3df 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -461,14 +461,12 @@ def start_listening(self) -> None: """ raise NotImplementedError - @abstractmethod def poll_event(self) -> InputEvent | None: """ Poll for a single input event. """ raise NotImplementedError - @abstractmethod def stop_listening(self) -> None: """Stop listening for mouse and keyboard events.""" raise NotImplementedError diff --git a/src/askui/tools/pynput/pynput_agent_os.py b/src/askui/tools/pynput/pynput_agent_os.py index 861ca868..c6a28f89 100644 --- a/src/askui/tools/pynput/pynput_agent_os.py +++ b/src/askui/tools/pynput/pynput_agent_os.py @@ -162,7 +162,7 @@ def screenshot(self, report: bool = True) -> Image.Image: """ monitor = self._sct.monitors[self._display] screenshot: ScreenShot = self._sct.grab(monitor) - image = Image.frombytes( # type: ignore[arg-type] + image = Image.frombytes( "RGB", screenshot.size, screenshot.rgb, diff --git a/tests/e2e/agent/conftest.py b/tests/e2e/agent/conftest.py index f6064450..dbbc30e5 100644 --- a/tests/e2e/agent/conftest.py +++ b/tests/e2e/agent/conftest.py @@ -11,11 +11,11 @@ from askui.locators.serializers import AskUiLocatorSerializer from askui.models.askui.ai_element_utils import AiElementCollection from askui.models.askui.computer_agent import AskUiComputerAgent -from askui.models.askui.facade import AskUiFacade from askui.models.askui.inference_api import AskUiInferenceApi, AskUiSettings from askui.models.askui.model_router import AskUiModelRouter from askui.models.askui.settings import AskUiComputerAgentSettings from askui.models.models import ModelName +from askui.models.shared.facade import ModelFacade from askui.reporting import Reporter, SimpleHtmlReporter from askui.tools.toolbox import AgentToolbox @@ -82,11 +82,11 @@ def askui_computer_agent( def askui_facade( askui_computer_agent: AskUiComputerAgent, askui_inference_api: AskUiInferenceApi, -) -> AskUiFacade: - return AskUiFacade( - computer_agent=askui_computer_agent, - inference_api=askui_inference_api, - model_router=AskUiModelRouter(inference_api=askui_inference_api), +) -> ModelFacade: + return ModelFacade( + act_model=askui_computer_agent, + get_model=askui_inference_api, + locate_model=AskUiModelRouter(inference_api=askui_inference_api), ) @@ -94,7 +94,7 @@ def askui_facade( def vision_agent( agent_toolbox_mock: AgentToolbox, simple_html_reporter: Reporter, - askui_facade: AskUiFacade, + askui_facade: ModelFacade, ) -> Generator[VisionAgent, None, None]: """Fixture providing a VisionAgent instance.""" with VisionAgent( diff --git a/tests/e2e/agent/test_act.py b/tests/e2e/agent/test_act.py index efd25e0a..e6707ca4 100644 --- a/tests/e2e/agent/test_act.py +++ b/tests/e2e/agent/test_act.py @@ -1,8 +1,8 @@ import pytest from askui.agent import VisionAgent -from askui.models.askui.facade import AskUiFacade from askui.models.models import ModelComposition, ModelDefinition, ModelName +from askui.models.shared.facade import ModelFacade from askui.reporting import Reporter from askui.tools.toolbox import AgentToolbox @@ -26,7 +26,7 @@ def test_act( def test_act_with_model_composition_should_use_default_model( agent_toolbox_mock: AgentToolbox, simple_html_reporter: Reporter, - askui_facade: AskUiFacade, + askui_facade: ModelFacade, ) -> None: with VisionAgent( reporters=[simple_html_reporter], diff --git a/tests/e2e/agent/test_get.py b/tests/e2e/agent/test_get.py index ce6e034c..be3fbb26 100644 --- a/tests/e2e/agent/test_get.py +++ b/tests/e2e/agent/test_get.py @@ -5,8 +5,8 @@ from askui import ResponseSchemaBase, VisionAgent from askui.models import ModelName -from askui.models.askui.facade import AskUiFacade from askui.models.models import ModelComposition, ModelDefinition +from askui.models.shared.facade import ModelFacade from askui.reporting import Reporter from askui.tools.toolbox import AgentToolbox @@ -42,7 +42,7 @@ def test_get( def test_get_with_model_composition_should_use_default_model( agent_toolbox_mock: AgentToolbox, - askui_facade: AskUiFacade, + askui_facade: ModelFacade, simple_html_reporter: Reporter, github_login_screenshot: PILImage.Image, ) -> None: diff --git a/tests/integration/agent/test_computer_agent.py b/tests/integration/agent/test_computer_agent.py index 438b38f4..2154339d 100644 --- a/tests/integration/agent/test_computer_agent.py +++ b/tests/integration/agent/test_computer_agent.py @@ -10,4 +10,7 @@ def test_act( claude_computer_agent: AskUiComputerAgent, ) -> None: - claude_computer_agent.act("Go to github.com/login", model_choice=ModelName.ASKUI) + claude_computer_agent.act( + [{"role": "user", "content": "Go to github.com/login"}], + model_choice=ModelName.ASKUI, + ) diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index b5f190ed..1bdfd521 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -1,8 +1,10 @@ """Integration tests for custom model registration and selection.""" -from typing import Optional, Type, Union +from typing import Callable, Optional, Type, Union import pytest +from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam +from typing_extensions import override from askui import ( ActModel, @@ -16,6 +18,7 @@ ) from askui.locators.locators import Locator from askui.models import ModelComposition, ModelDefinition, ModelName +from askui.tools.anthropic.base import ToolResult from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource @@ -24,11 +27,25 @@ class SimpleActModel(ActModel): """Simple act model that records goals.""" def __init__(self) -> None: - self.goals: list[str] = [] + self.goals: list[list[BetaMessageParam]] = [] self.model_choices: list[str] = [] - def act(self, goal: str, model_choice: str) -> None: - self.goals.append(goal) + @override + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: + self.goals.append(messages) self.model_choices.append(model_choice) @@ -137,7 +154,9 @@ def test_register_and_use_custom_act_model( with VisionAgent(models=model_registry, tools=agent_toolbox_mock) as agent: agent.act("test goal", model="custom-act") - assert act_model.goals == ["test goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + ] assert act_model.model_choices == ["custom-act"] def test_register_and_use_custom_get_model( @@ -182,7 +201,9 @@ def create_model() -> ActModel: with VisionAgent(models=registry, tools=agent_toolbox_mock) as agent: agent.act("test goal", model="factory-model") - assert act_model.goals == ["test goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + ] assert act_model.model_choices == ["factory-model"] def test_register_multiple_models_for_same_task( @@ -193,7 +214,21 @@ def test_register_multiple_models_for_same_task( """Test registering multiple models for the same task.""" class AnotherActModel(ActModel): - def act(self, goal: str, model_choice: str) -> None: + @override + def act( + self, + messages: list[BetaMessageParam], + model_choice: str, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None = None, + on_tool_result: Callable[ + [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], + ToolResult | None, + ] + | None = None, + ) -> None: pass registry: ModelRegistry = { @@ -205,7 +240,9 @@ def act(self, goal: str, model_choice: str) -> None: agent.act("test goal", model="act-1") agent.act("another goal", model="act-2") - assert act_model.goals == ["test goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + ] assert act_model.model_choices == ["act-1"] def test_use_response_schema_with_custom_get_model( @@ -240,7 +277,9 @@ def test_override_default_model( with VisionAgent(models=registry, tools=agent_toolbox_mock) as agent: agent.act("test goal") # Should use custom model since it overrides "askui" - assert act_model.goals == ["test goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + ] assert act_model.model_choices == [ModelName.ASKUI] def test_model_composition( @@ -319,4 +358,7 @@ def create_model() -> ActModel: agent.act("another goal", model="lazy-model") assert init_count == 2 - assert act_model.goals == ["test goal", "another goal"] + assert act_model.goals == [ + [{"role": "user", "content": "test goal"}], + [{"role": "user", "content": "another goal"}], + ] diff --git a/tests/unit/models/test_model_router.py b/tests/unit/models/test_model_router.py index 12cb439e..f35b7257 100644 --- a/tests/unit/models/test_model_router.py +++ b/tests/unit/models/test_model_router.py @@ -9,11 +9,10 @@ from pytest_mock import MockerFixture from askui.exceptions import ModelNotFoundError -from askui.models.anthropic.facade import AnthropicFacade -from askui.models.askui.facade import AskUiFacade from askui.models.huggingface.spaces_api import HFSpacesHandler from askui.models.model_router import ModelRouter from askui.models.models import ModelName +from askui.models.shared.facade import ModelFacade from askui.models.ui_tars_ep.ui_tars_api import UiTarsApiHandler from askui.reporting import CompositeReporter from askui.tools.toolbox import AgentToolbox @@ -55,9 +54,9 @@ def mock_hf_spaces(mocker: MockerFixture) -> HFSpacesHandler: @pytest.fixture -def mock_anthropic_facade(mocker: MockerFixture) -> AnthropicFacade: +def mock_anthropic_facade(mocker: MockerFixture) -> ModelFacade: """Fixture providing a mock Anthropic facade.""" - mock = cast("AnthropicFacade", mocker.MagicMock(spec=AnthropicFacade)) + mock = cast("ModelFacade", mocker.MagicMock(spec=ModelFacade)) mock.act.return_value = None # type: ignore[attr-defined] mock.get.return_value = "Mock response" # type: ignore[attr-defined] mock.locate.return_value = (50, 50) # type: ignore[attr-defined] @@ -65,9 +64,9 @@ def mock_anthropic_facade(mocker: MockerFixture) -> AnthropicFacade: @pytest.fixture -def mock_askui_facade(mocker: MockerFixture) -> AskUiFacade: +def mock_askui_facade(mocker: MockerFixture) -> ModelFacade: """Fixture providing a mock AskUI facade.""" - mock = cast("AskUiFacade", mocker.MagicMock(spec=AskUiFacade)) + mock = cast("ModelFacade", mocker.MagicMock(spec=ModelFacade)) mock.act.return_value = None # type: ignore[attr-defined] mock.get.return_value = "Mock response" # type: ignore[attr-defined] mock.locate.return_value = (50, 50) # type: ignore[attr-defined] @@ -77,8 +76,8 @@ def mock_askui_facade(mocker: MockerFixture) -> AskUiFacade: @pytest.fixture def model_router( agent_toolbox_mock: AgentToolbox, - mock_anthropic_facade: AnthropicFacade, - mock_askui_facade: AskUiFacade, + mock_anthropic_facade: ModelFacade, + mock_askui_facade: ModelFacade, mock_tars: UiTarsApiHandler, mock_hf_spaces: HFSpacesHandler, ) -> ModelRouter: @@ -110,7 +109,7 @@ def test_locate_with_askui_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI model.""" locator = "test locator" @@ -123,7 +122,7 @@ def test_locate_with_askui_pta_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI PTA model.""" locator = "test locator" @@ -138,7 +137,7 @@ def test_locate_with_askui_ocr_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI OCR model.""" locator = "test locator" @@ -153,7 +152,7 @@ def test_locate_with_askui_combo_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI combo model.""" locator = "test locator" @@ -168,7 +167,7 @@ def test_locate_with_askui_ai_element_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test locating elements using AskUI AI element model.""" locator = "test locator" @@ -196,7 +195,7 @@ def test_locate_with_claude_model( self, model_router: ModelRouter, mock_image: Image.Image, - mock_anthropic_facade: AnthropicFacade, + mock_anthropic_facade: ModelFacade, ) -> None: """Test locating elements using Claude model.""" locator = "test locator" @@ -239,7 +238,7 @@ def test_get_with_askui_model( self, model_router: ModelRouter, mock_image_source: ImageSource, - mock_askui_facade: AskUiFacade, + mock_askui_facade: ModelFacade, ) -> None: """Test getting inference using AskUI model.""" response = model_router.get( @@ -265,7 +264,7 @@ def test_get_with_claude_model( self, model_router: ModelRouter, mock_image_source: ImageSource, - mock_anthropic_facade: AnthropicFacade, + mock_anthropic_facade: ModelFacade, ) -> None: """Test getting inference using Claude model.""" response = model_router.get( @@ -289,21 +288,29 @@ def test_act_with_tars_model( self, model_router: ModelRouter, mock_tars: UiTarsApiHandler ) -> None: """Test acting using TARS model.""" - model_router.act("test goal", ModelName.TARS) - mock_tars.act.assert_called_once_with("test goal", ModelName.TARS) # type: ignore + model_router.act([{"role": "user", "content": "test goal"}], ModelName.TARS) + mock_tars.act.assert_called_once_with( # type: ignore[attr-defined] + [{"role": "user", "content": "test goal"}], ModelName.TARS, None, None + ) def test_act_with_claude_model( - self, model_router: ModelRouter, mock_anthropic_facade: AnthropicFacade + self, model_router: ModelRouter, mock_anthropic_facade: ModelFacade ) -> None: """Test acting using Claude model.""" model_router.act( - "test goal", ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 + [{"role": "user", "content": "test goal"}], + ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ) mock_anthropic_facade.act.assert_called_once_with( # type: ignore - "test goal", ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022 + [{"role": "user", "content": "test goal"}], + ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, + None, + None, ) def test_act_with_invalid_model(self, model_router: ModelRouter) -> None: """Test that acting with invalid model raises InvalidModelError.""" with pytest.raises(ModelNotFoundError): - model_router.act("test goal", "invalid-model") + model_router.act( + [{"role": "user", "content": "test goal"}], "invalid-model" + ) From 7326a9870a6582a8e70faa9d9dbe151ca8228558 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 3 Jun 2025 16:05:47 +0200 Subject: [PATCH 10/20] style(chat): fix linting issues --- src/askui/chat/__main__.py | 15 ++++++--------- src/askui/chat/api/messages/router.py | 16 ++++++++-------- src/askui/chat/api/messages/service.py | 2 +- src/askui/chat/api/threads/router.py | 10 +++++----- src/askui/chat/api/threads/service.py | 2 +- src/askui/chat/api/utils.py | 4 +++- 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 0d513b06..82036606 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -9,8 +9,8 @@ from typing_extensions import override from askui import VisionAgent -from askui.chat.api.messages import MessageRole, MessagesApi -from askui.chat.api.threads import ThreadsApi +from askui.chat.api.messages.service import MessageRole, MessageService +from askui.chat.api.threads.service import ThreadService from askui.chat.click_recorder import ClickRecorder from askui.chat.exceptions import FunctionExecutionError, InvalidFunctionError from askui.models import ModelName @@ -19,29 +19,26 @@ from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import base64_to_image, draw_point_on_image -# TODO Start backend server - st.set_page_config( page_title="Vision Agent Chat", page_icon="💬", ) -# TODO Tool, pynput alternatively BASE_DIR = Path("./chat") -threads_api = ThreadsApi(BASE_DIR) -messages_api = MessagesApi(BASE_DIR) +threads_api = ThreadService(BASE_DIR) +messages_api = MessageService(BASE_DIR) click_recorder = ClickRecorder() -def get_image(img_b64_str_or_path: str) -> Image.Image: # TODO Image utils +def get_image(img_b64_str_or_path: str) -> Image.Image: """Get image from base64 string or file path.""" if Path(img_b64_str_or_path).is_file(): return Image.open(img_b64_str_or_path) return base64_to_image(img_b64_str_or_path) -def write_message( # TODO updating frontend +def write_message( role: str, content: str | dict[str, Any] | list[Any], timestamp: str, diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 328e5cc2..98507b47 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -19,7 +19,7 @@ class CreateMessageRequest(BaseModel): router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) -@router.get("", response_model=MessageListResponse) +@router.get("") def list_messages( thread_id: str, limit: int | None = None, @@ -29,14 +29,14 @@ def list_messages( try: return message_service.list_(thread_id, limit=limit) except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e -@router.post("", response_model=Message) +@router.post("") async def create_message( thread_id: str, request: CreateMessageRequest, - image: UploadFile | None = File(None), + image: UploadFile | None = None, message_service: MessageService = MessageServiceDep, ) -> Message: """Create a new message in a thread.""" @@ -54,10 +54,10 @@ async def create_message( image=pil_image, ) except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e -@router.get("/{message_id}", response_model=Message) +@router.get("/{message_id}") def get_message( thread_id: str, message_id: str, @@ -67,7 +67,7 @@ def get_message( try: return message_service.retrieve(thread_id, message_id) except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e @router.delete("/{message_id}") @@ -80,4 +80,4 @@ def delete_message( try: message_service.delete(thread_id, message_id) except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index edf80d0b..0a7a150a 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -230,5 +230,5 @@ def delete(self, thread_id: str, message_id: str) -> None: for img_path in image_paths: try: Path(img_path).unlink() - except FileNotFoundError: + except FileNotFoundError: # noqa: PERF203 pass # Image might have been deleted already diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py index 1fae51cb..ffd9b49f 100644 --- a/src/askui/chat/api/threads/router.py +++ b/src/askui/chat/api/threads/router.py @@ -6,7 +6,7 @@ router = APIRouter(prefix="/threads", tags=["threads"]) -@router.get("", response_model=ThreadListResponse) +@router.get("") def list_threads( limit: int | None = None, thread_service: ThreadService = ThreadServiceDep, @@ -15,7 +15,7 @@ def list_threads( return thread_service.list_(limit=limit) -@router.post("", response_model=Thread) +@router.post("") def create_thread( thread_service: ThreadService = ThreadServiceDep, ) -> Thread: @@ -23,7 +23,7 @@ def create_thread( return thread_service.create() -@router.get("/{thread_id}", response_model=Thread) +@router.get("/{thread_id}") def get_thread( thread_id: str, thread_service: ThreadService = ThreadServiceDep, @@ -32,7 +32,7 @@ def get_thread( try: return thread_service.retrieve(thread_id) except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e @router.delete("/{thread_id}") @@ -44,4 +44,4 @@ def delete_thread( try: thread_service.delete(thread_id) except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) from e diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py index d1799db7..10a716e1 100644 --- a/src/askui/chat/api/threads/service.py +++ b/src/askui/chat/api/threads/service.py @@ -139,7 +139,7 @@ def delete(self, thread_id: str) -> None: for img_path in msg.content[0].image_paths: try: Path(img_path).unlink() - except FileNotFoundError: + except FileNotFoundError: # noqa: PERF203 pass # Image might have been deleted already except FileNotFoundError: pass # Thread might have been deleted already diff --git a/src/askui/chat/api/utils.py b/src/askui/chat/api/utils.py index 590727c2..d256b794 100644 --- a/src/askui/chat/api/utils.py +++ b/src/askui/chat/api/utils.py @@ -15,7 +15,9 @@ def generate_time_ordered_id(prefix: str) -> str: # Get current timestamp in milliseconds timestamp = int(time.time() * 1000) # Convert to base32 for shorter string (removing padding) - timestamp_b32 = base64.b32encode(str(timestamp).encode()).decode().rstrip("=").lower() + timestamp_b32 = ( + base64.b32encode(str(timestamp).encode()).decode().rstrip("=").lower() + ) # Get 12 random bytes and convert to base32 (removing padding) random_bytes = os.urandom(12) random_b32 = base64.b32encode(random_bytes).decode().rstrip("=").lower() From f5891353320aab045483938e6dce5f1f0bee3d19 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Tue, 3 Jun 2025 23:39:21 +0200 Subject: [PATCH 11/20] feat: add runs endpoints to chat api --- src/askui/chat/api/fastapi.py | 2 + src/askui/chat/api/images/router.py | 2 +- src/askui/chat/api/messages/router.py | 2 +- src/askui/chat/api/runs/__init__.py | 0 src/askui/chat/api/runs/dependencies.py | 14 +++ src/askui/chat/api/runs/router.py | 65 +++++++++++++ src/askui/chat/api/runs/service.py | 124 ++++++++++++++++++++++++ src/askui/chat/api/threads/router.py | 2 +- 8 files changed, 208 insertions(+), 3 deletions(-) create mode 100644 src/askui/chat/api/runs/__init__.py create mode 100644 src/askui/chat/api/runs/dependencies.py create mode 100644 src/askui/chat/api/runs/router.py create mode 100644 src/askui/chat/api/runs/service.py diff --git a/src/askui/chat/api/fastapi.py b/src/askui/chat/api/fastapi.py index e9d8cf00..93e0d5c2 100644 --- a/src/askui/chat/api/fastapi.py +++ b/src/askui/chat/api/fastapi.py @@ -3,6 +3,7 @@ from askui.chat.api.images.router import router as images_router from askui.chat.api.messages.router import router as messages_router +from askui.chat.api.runs.router import router as runs_router from askui.chat.api.threads.router import router as threads_router app = FastAPI( @@ -25,4 +26,5 @@ v1_router.include_router(threads_router) v1_router.include_router(messages_router) v1_router.include_router(images_router) +v1_router.include_router(runs_router) app.include_router(v1_router) diff --git a/src/askui/chat/api/images/router.py b/src/askui/chat/api/images/router.py index 2377386d..7181489f 100644 --- a/src/askui/chat/api/images/router.py +++ b/src/askui/chat/api/images/router.py @@ -8,7 +8,7 @@ @router.get("/{image_path:path}") -def get_image( +def retrieve_image( image_path: str, settings: Settings = SettingsDep, ) -> FileResponse: diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 98507b47..568cc90d 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -58,7 +58,7 @@ async def create_message( @router.get("/{message_id}") -def get_message( +def retrieve_message( thread_id: str, message_id: str, message_service: MessageService = MessageServiceDep, diff --git a/src/askui/chat/api/runs/__init__.py b/src/askui/chat/api/runs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py new file mode 100644 index 00000000..772c2545 --- /dev/null +++ b/src/askui/chat/api/runs/dependencies.py @@ -0,0 +1,14 @@ +from fastapi import Depends + +from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.settings import Settings + +from .service import RunService + + +def get_runs_service(settings: Settings = SettingsDep) -> RunService: + """Get RunService instance.""" + return RunService(settings.data_dir) + + +RunServiceDep = Depends(get_runs_service) diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py new file mode 100644 index 00000000..aa509fb6 --- /dev/null +++ b/src/askui/chat/api/runs/router.py @@ -0,0 +1,65 @@ +from typing import Annotated + +from fastapi import APIRouter, Body, HTTPException, Path +from pydantic import BaseModel + +from .dependencies import RunServiceDep +from .service import Run, RunListResponse, RunService + + +class CreateRunRequest(BaseModel): + stream: bool = False + + +router = APIRouter(prefix="/threads/{thread_id}/runs", tags=["runs"]) + + +@router.post("") +def create_run( + thread_id: Annotated[str, Path(...)], + request: Annotated[CreateRunRequest, Body(...)], + run_service: RunService = RunServiceDep, +) -> Run: + """ + Create a new run for a given thread. + """ + return run_service.create(thread_id, request.stream) + + +@router.get("/{run_id}") +def retrieve_run( + run_id: Annotated[str, Path(...)], + run_service: RunService = RunServiceDep, +) -> Run: + """ + Retrieve a run by its ID. + """ + try: + return run_service.retrieve(run_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + + +@router.get("") +def list_runs( + thread_id: Annotated[str, Path(...)], + run_service: RunService = RunServiceDep, +) -> RunListResponse: + """ + List runs, optionally filtered by thread. + """ + return run_service.list_(thread_id) + + +@router.post("/{run_id}/cancel") +def cancel_run( + run_id: Annotated[str, Path(...)], + run_service: RunService = RunServiceDep, +) -> Run: + """ + Cancel a run by its ID. + """ + try: + return run_service.cancel(run_id) + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py new file mode 100644 index 00000000..022410e0 --- /dev/null +++ b/src/askui/chat/api/runs/service.py @@ -0,0 +1,124 @@ +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal, Sequence + +from pydantic import AwareDatetime, BaseModel, Field, computed_field + +from askui.chat.api.utils import generate_time_ordered_id + +RunStatus = Literal[ + "queued", + "in_progress", + "completed", + "cancelling", + "cancelled", + "failed", + "expired", +] + + +class RunError(BaseModel): + message: str + code: Literal["server_error"] + + +class Run(BaseModel): + id: str = Field(default_factory=lambda: generate_time_ordered_id("run")) + thread_id: str + created_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + ) + started_at: AwareDatetime | None = None + completed_at: AwareDatetime | None = None + tried_cancelling_at: AwareDatetime | None = None + cancelled_at: AwareDatetime | None = None + expires_at: AwareDatetime | None = None + failed_at: AwareDatetime | None = None + last_error: RunError | None = None + object: Literal["run"] = "run" + + @computed_field + @property + def status(self) -> RunStatus: + if self.cancelled_at: + return "cancelled" + if self.failed_at: + return "failed" + if self.completed_at: + return "completed" + if self.expires_at and self.expires_at < datetime.now(tz=timezone.utc): + return "expired" + if self.tried_cancelling_at: + return "cancelling" + if self.started_at: + return "in_progress" + return "queued" + + +class RunListResponse(BaseModel): + object: Literal["list"] = "list" + data: Sequence[Run] + first_id: str | None = None + last_id: str | None = None + has_more: bool = False + + +class RunService: + """ + Service for managing runs. Handles creation, retrieval, listing, and cancellation of runs. + """ + + def __init__(self, base_dir: Path) -> None: + self._base_dir = base_dir + self._runs_dir = base_dir / "runs" + + def _run_path(self, thread_id: str, run_id: str) -> Path: + return self._runs_dir / f"{thread_id}__{run_id}.json" + + def create(self, thread_id: str, stream: bool) -> Run: + run = Run(thread_id=thread_id) + self._runs_dir.mkdir(parents=True, exist_ok=True) + run_file = self._run_path(thread_id, run.id) + with run_file.open("w") as f: + f.write(run.model_dump_json()) + return run + + def retrieve(self, run_id: str) -> Run: + # Find the file by run_id + for f in self._runs_dir.glob(f"*__{run_id}.json"): + with f.open("r") as file: + return Run.model_validate_json(file.read()) + error_msg = f"Run {run_id} not found" + raise FileNotFoundError(error_msg) + + def list_(self, thread_id: str | None = None) -> RunListResponse: + if not self._runs_dir.exists(): + return RunListResponse(data=[]) + if thread_id: + run_files = list(self._runs_dir.glob(f"{thread_id}__*.json")) + else: + run_files = list(self._runs_dir.glob("*__*.json")) + runs: list[Run] = [] + for f in run_files: + with f.open("r") as file: + runs.append(Run.model_validate_json(file.read())) + runs = sorted(runs, key=lambda r: r.created_at, reverse=True) + return RunListResponse( + data=runs, + first_id=runs[0].id if runs else None, + last_id=runs[-1].id if runs else None, + has_more=False, + ) + + def cancel(self, run_id: str) -> Run: + run = self.retrieve(run_id) + if run.status in ("cancelled", "cancelling", "completed", "failed", "expired"): + return run + run.tried_cancelling_at = datetime.now(tz=timezone.utc) + for f in self._runs_dir.glob(f"*__{run_id}.json"): + with f.open("w") as file: + file.write(run.model_dump_json()) + return run + # Find the file by run_id + error_msg = f"Run {run_id} not found" + raise FileNotFoundError(error_msg) diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py index ffd9b49f..00e4d16a 100644 --- a/src/askui/chat/api/threads/router.py +++ b/src/askui/chat/api/threads/router.py @@ -24,7 +24,7 @@ def create_thread( @router.get("/{thread_id}") -def get_thread( +def retrieve_thread( thread_id: str, thread_service: ThreadService = ThreadServiceDep, ) -> Thread: From c842dbf7b2fa75b5f6326e7a668124979b8a580d Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Wed, 4 Jun 2025 09:24:11 +0200 Subject: [PATCH 12/20] fix: call on_message on tool result message --- src/askui/models/shared/computer_agent.py | 47 +++++++++++++++-------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/src/askui/models/shared/computer_agent.py b/src/askui/models/shared/computer_agent.py index 49fd34c5..60be1aa0 100644 --- a/src/askui/models/shared/computer_agent.py +++ b/src/askui/models/shared/computer_agent.py @@ -259,27 +259,40 @@ def _step( "BetaMessageParam", response_message.model_dump(include={"content", "role"}), ) - if on_message is not None: - response_message_param_cb = on_message(response_message_param, messages) - if response_message_param_cb is None: - return - - response_message_param = response_message_param_cb - logger.debug(response_message_param) - messages.append(response_message_param) - self._reporter.add_message( - self.__class__.__name__, dict(response_message_param) + message_by_assistant = self._call_on_message( + on_message, response_message_param, messages ) + if message_by_assistant is None: + return + logger.debug(message_by_assistant) + messages.append(message_by_assistant) + self._reporter.add_message(self.__class__.__name__, dict(message_by_assistant)) if tool_result_message := self._use_tools( response_message, messages, on_tool_result ): - messages.append(tool_result_message) - self._step( - messages=messages, - model_choice=model_choice, - on_message=on_message, - on_tool_result=on_tool_result, - ) + if tool_result_message := self._call_on_message( + on_message, tool_result_message, messages + ): + messages.append(tool_result_message) + self._step( + messages=messages, + model_choice=model_choice, + on_message=on_message, + on_tool_result=on_tool_result, + ) + + def _call_on_message( + self, + on_message: Callable[ + [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None + ] + | None, + message: BetaMessageParam, + messages: list[BetaMessageParam], + ) -> BetaMessageParam | None: + if on_message is None: + return message + return on_message(message, messages) def act( self, From d6fcf843e332ea774ca6d1b38753dace84f9d58d Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 5 Jun 2025 11:57:15 +0200 Subject: [PATCH 13/20] feat!: make agents api work --- .cursorrules | 1 + README.md | 39 +- act.py | 7 - src/askui/__init__.py | 38 +- src/askui/agent.py | 30 +- src/askui/chat/__main__.py | 545 +++++++++--------- src/askui/chat/api/messages/router.py | 29 +- src/askui/chat/api/messages/service.py | 98 +--- src/askui/chat/api/runs/service.py | 88 ++- src/askui/chat/api/threads/service.py | 18 +- src/askui/models/__init__.py | 34 +- src/askui/models/anthropic/computer_agent.py | 19 +- src/askui/models/askui/computer_agent.py | 13 +- src/askui/models/model_router.py | 20 +- src/askui/models/models.py | 39 +- src/askui/models/shared/computer_agent.py | 242 ++++---- .../models/shared/computer_agent_cb_param.py | 14 + .../shared/computer_agent_message_param.py | 107 ++++ src/askui/models/shared/facade.py | 19 +- src/askui/models/ui_tars_ep/ui_tars_api.py | 27 +- src/askui/tools/anthropic/computer.py | 5 + .../integration/agent/test_computer_agent.py | 3 +- tests/integration/test_custom_models.py | 34 +- .../unit/models/test_computer_agent_filter.py | 129 +++++ tests/unit/models/test_model_router.py | 19 +- 25 files changed, 890 insertions(+), 727 deletions(-) delete mode 100644 act.py create mode 100644 src/askui/models/shared/computer_agent_cb_param.py create mode 100644 src/askui/models/shared/computer_agent_message_param.py create mode 100644 tests/unit/models/test_computer_agent_filter.py diff --git a/.cursorrules b/.cursorrules index 3889764c..31bd6766 100644 --- a/.cursorrules +++ b/.cursorrules @@ -7,6 +7,7 @@ - Use built-in types (e.g., `list`, `dict`, `tuple`, `set`, `str | None`) instead of types from `typing` module (e.g. `List`, `Dict`, `Tuple`, `Set`, `Optional`, `Union`) wherever possible - Instead of `Optional` use `| None` - Create a `__init__.py` file in each folder +- Never pass literals, e.g., `error_msg`, directly to `Exceptions`, but instead assign them to variables and pass them to the exception, e.g., `raise FileNotFoundError(error_msg)` instead of `raise FileNotFoundError(f"Thread {thread_id} not found")` ## FastAPI - Instead of defining `response_model` within route annotation, use the model as the response type in the function signature diff --git a/README.md b/README.md index a600f2f8..717751bd 100644 --- a/README.md +++ b/README.md @@ -287,20 +287,19 @@ Here's how to create and use custom models: import functools from askui import ( ActModel, - BetaMessageParam, - BetaToolUseBlockParam, GetModel, LocateModel, Locator, ImageSource, + MessageParam, ModelComposition, ModelRegistry, + OnMessageCb, Point, ResponseSchema, - ToolResult, VisionAgent, ) -from typing import Callable, Type +from typing import Type from typing_extensions import override # Define custom models @@ -308,24 +307,20 @@ class MyActModel(ActModel): @override def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: # Implement custom act logic, e.g.: # - Use a different AI model # - Implement custom business logic # - Call external services - goal = messages[0]["content"] - print(f"Custom act model executing goal: {goal}") + if len(messages) > 0: + goal = messages[0].content + print(f"Custom act model executing goal: {goal}") + else: + error_msg = "No messages provided" + raise ValueError(error_msg) # Because Python supports multiple inheritance, we can subclass both `GetModel` and `LocateModel` (and even `ActModel`) # to create a model that can both get and locate elements. @@ -391,17 +386,9 @@ class DynamicActModel(ActModel): @override def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: pass diff --git a/act.py b/act.py deleted file mode 100644 index f8f594b8..00000000 --- a/act.py +++ /dev/null @@ -1,7 +0,0 @@ -from askui import VisionAgent - -with VisionAgent(log_level="DEBUG") as agent: - agent.act("Click on the 'X' button to cancel the current search in Google Maps") - agent.act( - "Search for 'Linienstraße 145' in Google maps to find the route there from 'Google Berlin'" - ) diff --git a/src/askui/__init__.py b/src/askui/__init__.py index e3f321ec..478012f3 100644 --- a/src/askui/__init__.py +++ b/src/askui/__init__.py @@ -6,16 +6,29 @@ from .locators import Locator from .models import ( ActModel, - BetaMessageParam, - BetaToolUseBlockParam, + Base64ImageSourceParam, + CacheControlEphemeralParam, + CitationCharLocationParam, + CitationContentBlockLocationParam, + CitationPageLocationParam, + ContentBlockParam, GetModel, + ImageBlockParam, LocateModel, + MessageParam, Model, ModelChoice, ModelComposition, ModelDefinition, + ModelName, ModelRegistry, + OnMessageCb, Point, + TextBlockParam, + TextCitationParam, + ToolResultBlockParam, + ToolUseBlockParam, + UrlImageSourceParam, ) from .models.types.response_schemas import ResponseSchema, ResponseSchemaBase from .retry import ConfigurableRetry, Retry @@ -25,25 +38,38 @@ __all__ = [ "ActModel", - "BetaMessageParam", - "BetaToolUseBlockParam", + "Base64ImageSourceParam", + "CacheControlEphemeralParam", + "CitationCharLocationParam", + "CitationContentBlockLocationParam", + "CitationPageLocationParam", + "ConfigurableRetry", + "ContentBlockParam", "GetModel", + "ImageBlockParam", "ImageSource", "Img", "LocateModel", "Locator", + "MessageParam", "Model", + "ModelChoice", "ModelComposition", "ModelDefinition", - "ModelChoice", + "ModelName", "ModelRegistry", "ModifierKey", + "OnMessageCb", "PcKey", "Point", "ResponseSchema", "ResponseSchemaBase", "Retry", + "TextBlockParam", + "TextCitationParam", "ToolResult", - "ConfigurableRetry", + "ToolResultBlockParam", + "ToolUseBlockParam", + "UrlImageSourceParam", "VisionAgent", ] diff --git a/src/askui/agent.py b/src/askui/agent.py index 95d22c3e..0e1023a9 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -1,15 +1,15 @@ import logging import time import types -from typing import Annotated, Callable, Literal, Optional, Type, overload +from typing import Annotated, Literal, Optional, Type, overload -from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from dotenv import load_dotenv from pydantic import ConfigDict, Field, validate_call from askui.container import telemetry from askui.locators.locators import Locator -from askui.tools.anthropic import ToolResult +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.utils.image_utils import ImageSource, Img from .exceptions import ElementNotFoundError @@ -536,17 +536,9 @@ def mouse_down( @validate_call def act( self, - goal: Annotated[str | list[BetaMessageParam], Field(min_length=1)], + goal: Annotated[str | list[MessageParam], Field(min_length=1)], model: str | None = None, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: """ Instructs the agent to achieve a specified goal through autonomous actions. @@ -558,9 +550,7 @@ def act( Args: goal (str): A description of what the agent should achieve. model (str | None, optional): The composition or name of the model(s) to be used for achieving the `goal`. - messages (list[BetaMessageParam] | None, optional): The message history to start from. If None, starts with a new message containing the goal. - on_message (Callable[[BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None], optional): Callback for new messages. If it returns `None`, stops and does not add the message. - on_tool_result (Callable[[ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], ToolResult | None], optional): Callback for tool results. If it returns `None`, stops and does not add the tool result. + on_message (OnMessageCb | None, optional): Callback for new messages. If it returns `None`, stops and does not add the message. Returns: None @@ -579,12 +569,10 @@ def act( logger.debug( "VisionAgent received instruction to act towards the goal '%s'", goal ) - messages: list[BetaMessageParam] = ( - [{"role": "user", "content": goal}] if isinstance(goal, str) else goal - ) - self._model_router.act( - messages, model or self._model_choice["act"], on_message, on_tool_result + messages: list[MessageParam] = ( + [MessageParam(role="user", content=goal)] if isinstance(goal, str) else goal ) + self._model_router.act(messages, model or self._model_choice["act"], on_message) @telemetry.record_call() @validate_call diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 82036606..5bf6aeb7 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -1,23 +1,23 @@ +import io import json -import logging -import re +import time from pathlib import Path -from typing import Any, Union, cast +import httpx import streamlit as st -from PIL import Image, ImageDraw -from typing_extensions import override +from PIL import Image -from askui import VisionAgent -from askui.chat.api.messages.service import MessageRole, MessageService +from askui.chat.api.messages.service import Message, MessageService +from askui.chat.api.runs.service import RunService from askui.chat.api.threads.service import ThreadService -from askui.chat.click_recorder import ClickRecorder -from askui.chat.exceptions import FunctionExecutionError, InvalidFunctionError -from askui.models import ModelName -from askui.reporting import Reporter -from askui.tools.pynput.pynput_agent_os import PynputAgentOs -from askui.tools.toolbox import AgentToolbox -from askui.utils.image_utils import base64_to_image, draw_point_on_image + +# from askui.chat.click_recorder import ClickRecorder +from askui.models.shared.computer_agent_message_param import ( + Base64ImageSourceParam, + MessageParam, + UrlImageSourceParam, +) +from askui.utils.image_utils import base64_to_image st.set_page_config( page_title="Vision Agent Chat", @@ -26,201 +26,201 @@ BASE_DIR = Path("./chat") -threads_api = ThreadService(BASE_DIR) -messages_api = MessageService(BASE_DIR) -click_recorder = ClickRecorder() -def get_image(img_b64_str_or_path: str) -> Image.Image: - """Get image from base64 string or file path.""" - if Path(img_b64_str_or_path).is_file(): - return Image.open(img_b64_str_or_path) - return base64_to_image(img_b64_str_or_path) +@st.cache_resource +def get_thread_service() -> ThreadService: + return ThreadService(BASE_DIR) + + +@st.cache_resource +def get_message_service() -> MessageService: + return MessageService(BASE_DIR) + + +@st.cache_resource +def get_run_service() -> RunService: + return RunService(BASE_DIR) + + +thread_service = get_thread_service() +message_service = get_message_service() +run_service = get_run_service() + +# click_recorder = ClickRecorder() + + +def get_image( + source: Base64ImageSourceParam | UrlImageSourceParam, +) -> Image.Image: + match source.type: + case "base64": + data = source.data + if isinstance(data, str): + return base64_to_image(data) + error_msg = f"Image source data type not supported: {type(data)}" + raise NotImplementedError(error_msg) + case "url": + response = httpx.get(source.url) + return Image.open(io.BytesIO(response.content)) def write_message( - role: str, - content: str | dict[str, Any] | list[Any], - timestamp: str, - image: Image.Image - | str - | list[str | Image.Image] - | list[str] - | list[Image.Image] - | None = None, - message_id: str | None = None, + message: Message, ) -> None: - _role = messages_api.ROLE_MAP.get(role.lower(), MessageRole.UNKNOWN) - avatar = None if _role != MessageRole.UNKNOWN else "❔" - # Create a container for the message and delete button col1, col2 = st.columns([0.95, 0.05]) with col1: - with st.chat_message(_role.value, avatar=avatar): - st.markdown(f"*{timestamp}* - **{role}**\n\n") - st.markdown( - json.dumps(content, indent=2) - if isinstance(content, (dict, list)) - else content - ) - if image: - if isinstance(image, list): - for img in image: - img = get_image(img) if isinstance(img, str) else img - st.image(img) - else: - img = get_image(image) if isinstance(image, str) else image - st.image(img) + with st.chat_message(message.role): + st.markdown(f"*{message.created_at.isoformat()}* - **{message.role}**\n\n") + if isinstance(message.content, str): + st.markdown(message.content) + else: + for block in message.content: + match block.type: + case "image": + st.image(get_image(block.source)) + case "text": + st.markdown(block.text) + case "tool_result": + st.markdown(f"Tool use id: {block.tool_use_id}") + st.markdown(f"Erroneous: {block.is_error}") + content = block.content + if isinstance(content, str): + st.markdown(content) + else: + for nested_block in content: + match nested_block.type: + case "image": + st.image(get_image(nested_block.source)) + case "text": + st.markdown(nested_block.text) + case _: + st.markdown( + json.dumps(block.model_dump(mode="json"), indent=2) + ) # Add delete button in the second column if message_id is provided - if message_id: - with col2: - if st.button("🗑️", key=f"delete_{message_id}"): - messages_api.delete(st.session_state.thread_id, message_id) - st.rerun() - - -class ChatHistoryAppender(Reporter): - def __init__(self, thread_id: str) -> None: - self._thread_id = thread_id - - @override - def add_message( - self, - role: str, - content: Union[str, dict[str, Any], list[Any]], - image: Image.Image | list[Image.Image] | None = None, - ) -> None: - message = messages_api.create( - thread_id=self._thread_id, role=role, content=content, image=image - ) - write_message( - role=message.role.value, - content=message.content[0].text or "", - timestamp=message.created_at.isoformat(), - image=message.content[0].image_paths, - message_id=message.id, - ) - - @override - def generate(self) -> None: - pass + with col2: + if st.button("🗑️", key=f"delete_{message.id}"): + message_service.delete(st.session_state.thread_id, message.id) + st.rerun() -def paint_crosshair( - image: Image.Image, - coordinates: tuple[int, int], - size: int | None = None, - color: str = "red", - width: int = 4, -) -> Image.Image: - """ - Paints a crosshair at the given coordinates on the image. - - :param image: A PIL Image object. - :param coordinates: A tuple (x, y) representing the coordinates of the point. - :param size: Optional length of each line in the crosshair. Defaults to min(width,height)/20 - :param color: The color of the crosshair. - :param width: The width of the crosshair. - :return: A new image with the crosshair. - """ - if size is None: - size = ( - min(image.width, image.height) // 20 - ) # Makes crosshair ~5% of smallest image dimension - - image_copy = image.copy() - draw = ImageDraw.Draw(image_copy) - x, y = coordinates - # Draw horizontal and vertical lines - draw.line((x - size, y, x + size, y), fill=color, width=width) - draw.line((x, y - size, x, y + size), fill=color, width=width) - return image_copy - - -prompt = """The following image is a screenshot with a red crosshair on top of an element that the user wants to interact with. Give me a description that uniquely describes the element as concise as possible across all elements on the screen that the user most likely wants to interact with. Examples: - -- "Submit button" -- "Cell within the table about European countries in the third row and 6th column (area in km^2) in the right-hand browser window" -- "Avatar in the top right hand corner of the browser in focus that looks like a woman" -""" - - -def rerun() -> None: - st.markdown("### Re-running...") - with VisionAgent( - log_level=logging.DEBUG, - tools=tools, - ) as agent: - screenshot: Image.Image | None = None - for message in messages_api.list_(st.session_state.thread_id).data: - try: - if ( - message.role == MessageRole.ASSISTANT - or message.role == MessageRole.USER - ): - content = message.content[0] - if content.text == "screenshot()": - screenshot = ( - get_image(content.image_paths[0]) - if content.image_paths - else None - ) - continue - if content.text: - if match := re.match( - r"mouse\((\d+),\s*(\d+)\)", cast("str", content.text) - ): - if not screenshot: - error_msg = "Screenshot is required to paint crosshair" - raise ValueError(error_msg) # noqa: TRY301 - x, y = map(int, match.groups()) - screenshot_with_crosshair = paint_crosshair( - screenshot, (x, y) - ) - element_description = agent.get( - query=prompt, - image=screenshot_with_crosshair, - model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, - ) - messages_api.create( - thread_id=st.session_state.thread_id, - role=message.role.value, - content=f"Move mouse to {element_description}", - image=screenshot_with_crosshair, - ) - agent.mouse_move( - locator=element_description.replace('"', ""), - model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, - ) - else: - messages_api.create( - thread_id=st.session_state.thread_id, - role=message.role.value, - content=content.text, - image=None, - ) - func_call = f"agent.tools.os.{content.text}" - eval(func_call) - except json.JSONDecodeError: - continue - except AttributeError: - st.write(str(InvalidFunctionError(cast("str", content.text)))) - except Exception as e: # noqa: BLE001 - We want to catch all other exceptions here - st.write(str(FunctionExecutionError(cast("str", content.text), e))) +# def paint_crosshair( +# image: Image.Image, +# coordinates: tuple[int, int], +# size: int | None = None, +# color: str = "red", +# width: int = 4, +# ) -> Image.Image: +# """ +# Paints a crosshair at the given coordinates on the image. + +# :param image: A PIL Image object. +# :param coordinates: A tuple (x, y) representing the coordinates of the point. +# :param size: Optional length of each line in the crosshair. Defaults to min(width,height)/20 +# :param color: The color of the crosshair. +# :param width: The width of the crosshair. +# :return: A new image with the crosshair. +# """ +# if size is None: +# size = ( +# min(image.width, image.height) // 20 +# ) # Makes crosshair ~5% of smallest image dimension + +# image_copy = image.copy() +# draw = ImageDraw.Draw(image_copy) +# x, y = coordinates +# # Draw horizontal and vertical lines +# draw.line((x - size, y, x + size, y), fill=color, width=width) +# draw.line((x, y - size, x, y + size), fill=color, width=width) +# return image_copy + + +# prompt = """The following image is a screenshot with a red crosshair on top of an element that the user wants to interact with. Give me a description that uniquely describes the element as concise as possible across all elements on the screen that the user most likely wants to interact with. Examples: + +# - "Submit button" +# - "Cell within the table about European countries in the third row and 6th column (area in km^2) in the right-hand browser window" +# - "Avatar in the top right hand corner of the browser in focus that looks like a woman" +# """ + + +# def rerun() -> None: +# st.markdown("### Re-running...") +# with VisionAgent( +# log_level=logging.DEBUG, +# tools=tools, +# ) as agent: +# screenshot: Image.Image | None = None +# for message in messages_service.list_(st.session_state.thread_id).data: +# try: +# if ( +# message.role == MessageRole.ASSISTANT +# or message.role == MessageRole.USER +# ): +# content = message.content[0] +# if content.text == "screenshot()": +# screenshot = ( +# get_image(content.image_paths[0]) +# if content.image_paths +# else None +# ) +# continue +# if content.text: +# if match := re.match( +# r"mouse\((\d+),\s*(\d+)\)", cast("str", content.text) +# ): +# if not screenshot: +# error_msg = "Screenshot is required to paint crosshair" +# raise ValueError(error_msg) # noqa: TRY301 +# x, y = map(int, match.groups()) +# screenshot_with_crosshair = paint_crosshair( +# screenshot, (x, y) +# ) +# element_description = agent.get( +# query=prompt, +# image=screenshot_with_crosshair, +# model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, +# ) +# messages_service.create( +# thread_id=st.session_state.thread_id, +# role=message.role.value, +# content=f"Move mouse to {element_description}", +# image=screenshot_with_crosshair, +# ) +# agent.mouse_move( +# locator=element_description.replace('"', ""), +# model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, +# ) +# else: +# messages_service.create( +# thread_id=st.session_state.thread_id, +# role=message.role.value, +# content=content.text, +# image=None, +# ) +# func_call = f"agent.tools.os.{content.text}" +# eval(func_call) +# except json.JSONDecodeError: +# continue +# except AttributeError: +# st.write(str(InvalidFunctionError(cast("str", content.text)))) +# except Exception as e: # noqa: BLE001 - We want to catch all other exceptions here +# st.write(str(FunctionExecutionError(cast("str", content.text), e))) if st.sidebar.button("New Chat"): - thread = threads_api.create() + thread = thread_service.create() st.session_state.thread_id = thread.id st.rerun() -available_threads = threads_api.list_().data +available_threads = thread_service.list_().data thread_id = st.session_state.get("thread_id", None) if not thread_id and not available_threads: - thread = threads_api.create() + thread = thread_service.create() thread_id = thread.id st.session_state.thread_id = thread_id st.rerun() @@ -252,102 +252,105 @@ def rerun() -> None: st.session_state.thread_id = remaining_threads[0].id else: # Create new thread if no threads left - new_thread = threads_api.create() + new_thread = thread_service.create() st.session_state.thread_id = new_thread.id - threads_api.delete(t.id) + thread_service.delete(t.id) st.rerun() if thread_id != st.session_state.get("thread_id"): st.session_state.thread_id = thread_id st.rerun() -reporter = ChatHistoryAppender(thread_id) - - -@st.cache_resource -def get_tools() -> AgentToolbox: - return AgentToolbox(agent_os=PynputAgentOs(reporter=reporter)) - - -tools = get_tools() - st.title(f"Vision Agent Chat - {thread_id}") # Display chat history -for message in messages_api.list_(thread_id).data: - write_message( - message.role.value, - message.content[0].text or "", - message.created_at.isoformat(), - message.content[0].image_paths, - message.id, # Pass the message ID to enable deletion - ) - -if value_to_type := st.chat_input("Simulate Typing for User (Demonstration)"): - reporter.add_message( - role="User (Demonstration)", - content=f'type("{value_to_type}", 50)', - ) - st.rerun() - -if st.button("Simulate left click"): - reporter.add_message( - role="User (Demonstration)", - content='click("left", 1)', - ) - st.rerun() +messages = message_service.list_(thread_id).data +for message in messages: + write_message(message) + +last_message = messages[-1] if messages else None + +# if value_to_type := st.chat_input("Simulate Typing for User (Demonstration)"): +# reporter.add_message( +# role="user", +# content=f'type("{value_to_type}", 50)', +# ) +# st.rerun() + +# if st.button("Simulate left click"): +# reporter.add_message( +# role="User (Demonstration)", +# content='click("left", 1)', +# ) +# st.rerun() + +# # Chat input +# if st.button( +# "Demonstrate where to move mouse" +# ): # only single step, only click supported for now, independent of click always registered as click +# image, coordinates = click_recorder.record() +# reporter.add_message( +# role="User (Demonstration)", +# content="screenshot()", +# image=image, +# ) +# reporter.add_message( +# role="User (Demonstration)", +# content=f"mouse_move({coordinates[0]}, {coordinates[1]})", +# image=draw_point_on_image(image, coordinates[0], coordinates[1]), +# ) +# st.rerun() + +# if st.session_state.get("input_event_listening"): +# while input_event := tools.os.poll_event(): +# image = tools.os.screenshot(report=False) +# if input_event.pressed: +# reporter.add_message( +# role="User (Demonstration)", +# content=f"mouse_move({input_event.x}, {input_event.y})", +# image=draw_point_on_image(image, input_event.x, input_event.y), +# ) +# reporter.add_message( +# role="User (Demonstration)", +# content=f'click("{input_event.button}")', +# ) +# if st.button("Refresh"): +# st.rerun() +# if st.button("Stop listening to input events"): +# tools.os.stop_listening() +# st.session_state["input_event_listening"] = False +# st.rerun() +# else: +# if st.button("Listen to input events"): +# tools.os.start_listening() +# st.session_state["input_event_listening"] = True +# st.rerun() -# Chat input -if st.button( - "Demonstrate where to move mouse" -): # only single step, only click supported for now, independent of click always registered as click - image, coordinates = click_recorder.record() - reporter.add_message( - role="User (Demonstration)", - content="screenshot()", - image=image, - ) - reporter.add_message( - role="User (Demonstration)", - content=f"mouse_move({coordinates[0]}, {coordinates[1]})", - image=draw_point_on_image(image, coordinates[0], coordinates[1]), - ) - st.rerun() +if act_prompt := st.chat_input("Ask AI"): + if act_prompt != "Continue": + last_message = message_service.create( + thread_id=thread_id, + message=MessageParam( + role="user", + content=act_prompt, + ), + ) + write_message(last_message) + run = run_service.create(thread_id, stream=False) + print(run) + time.sleep(1) + while run := run_service.retrieve(run.id): + new_messages = message_service.list_( + thread_id, after=last_message.id if last_message else None + ).data + for message in new_messages: + write_message(message) + last_message = new_messages[-1] if new_messages else last_message + if run.status not in {"queued", "running", "in_progress"}: + break + time.sleep(1) -if st.session_state.get("input_event_listening"): - while input_event := tools.os.poll_event(): - image = tools.os.screenshot(report=False) - if input_event.pressed: - reporter.add_message( - role="User (Demonstration)", - content=f"mouse_move({input_event.x}, {input_event.y})", - image=draw_point_on_image(image, input_event.x, input_event.y), - ) - reporter.add_message( - role="User (Demonstration)", - content=f'click("{input_event.button}")', - ) - if st.button("Refresh"): - st.rerun() - if st.button("Stop listening to input events"): - tools.os.stop_listening() - st.session_state["input_event_listening"] = False - st.rerun() -else: - if st.button("Listen to input events"): - tools.os.start_listening() - st.session_state["input_event_listening"] = True - st.rerun() -if act_prompt := st.chat_input("Ask AI"): - with VisionAgent( # we need the vision agent - log_level=logging.DEBUG, - reporters=[reporter], - tools=tools, - ) as agent: - agent.act(act_prompt, model=ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022) - st.rerun() - -if st.button("Rerun"): - rerun() +# if st.button("Rerun"): +# rerun() diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 568cc90d..2544862b 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -1,20 +1,8 @@ -from io import BytesIO -from typing import Any - -from fastapi import APIRouter, File, HTTPException, UploadFile -from PIL import Image -from pydantic import BaseModel +from fastapi import APIRouter, HTTPException from askui.chat.api.messages.dependencies import MessageServiceDep from askui.chat.api.messages.service import Message, MessageListResponse, MessageService - - -class CreateMessageRequest(BaseModel): - """Request model for creating a message.""" - - role: str - content: str | dict[str, Any] | list[Any] - +from askui.models.shared.computer_agent_message_param import MessageParam router = APIRouter(prefix="/threads/{thread_id}/messages", tags=["messages"]) @@ -35,23 +23,14 @@ def list_messages( @router.post("") async def create_message( thread_id: str, - request: CreateMessageRequest, - image: UploadFile | None = None, + message: MessageParam, message_service: MessageService = MessageServiceDep, ) -> Message: """Create a new message in a thread.""" try: - # Handle image upload if provided - pil_image = None - if image: - img_data = await image.read() - pil_image = Image.open(BytesIO(img_data)) - return message_service.create( thread_id=thread_id, - role=request.role, - content=request.content, - image=pil_image, + message=message, ) except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 0a7a150a..e352e7d1 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -1,38 +1,17 @@ from datetime import datetime, timezone -from enum import Enum from pathlib import Path -from typing import Any, Sequence, Union -from PIL import Image from pydantic import AwareDatetime, BaseModel, Field from askui.chat.api.utils import generate_time_ordered_id +from askui.models.shared.computer_agent_message_param import MessageParam -class MessageRole(str, Enum): - """Valid message roles.""" - - USER = "user" - ASSISTANT = "assistant" - SYSTEM = "system" - AI = "ai" - UNKNOWN = "unknown" - - -class MessageContent(BaseModel): - """Message content with optional image paths.""" - - text: str | None = None - image_paths: list[str] | None = None - - -class Message(BaseModel): +class Message(MessageParam): """A message in a thread.""" id: str = Field(default_factory=lambda: generate_time_ordered_id("msg")) thread_id: str - role: MessageRole - content: Sequence[MessageContent] created_at: AwareDatetime = Field( default_factory=lambda: datetime.now(tz=timezone.utc) ) @@ -43,7 +22,7 @@ class MessageListResponse(BaseModel): """Response model for listing messages.""" object: str = "list" - data: Sequence[Message] + data: list[Message] first_id: str | None = None last_id: str | None = None has_more: bool = False @@ -52,13 +31,6 @@ class MessageListResponse(BaseModel): class MessageService: """Service for managing messages within threads.""" - ROLE_MAP = { - "user": MessageRole.USER, - "anthropic computer use": MessageRole.AI, - "agentos": MessageRole.ASSISTANT, - "user (demonstration)": MessageRole.USER, - } - def __init__(self, base_dir: Path) -> None: """Initialize message service. @@ -69,12 +41,15 @@ def __init__(self, base_dir: Path) -> None: self._threads_dir = base_dir / "threads" self._images_dir = base_dir / "images" - def list_(self, thread_id: str, limit: int | None = None) -> MessageListResponse: + def list_( + self, thread_id: str, limit: int | None = None, after: str | None = None + ) -> MessageListResponse: """List all messages in a thread. Args: thread_id: ID of thread to list messages from limit: Optional maximum number of messages to return + after: Optional message ID after which messages are returned Returns: MessageListResponse containing messages sorted by creation date @@ -87,7 +62,7 @@ def list_(self, thread_id: str, limit: int | None = None) -> MessageListResponse error_msg = f"Thread {thread_id} not found" raise FileNotFoundError(error_msg) - messages = [] + messages: list[Message] = [] with thread_file.open("r") as f: for line in f: msg = Message.model_validate_json(line) @@ -95,6 +70,8 @@ def list_(self, thread_id: str, limit: int | None = None) -> MessageListResponse # Sort by creation date messages = sorted(messages, key=lambda m: m.created_at) + if after: + messages = [m for m in messages if m.id > after] # Apply limit if specified if limit is not None: @@ -110,9 +87,7 @@ def list_(self, thread_id: str, limit: int | None = None) -> MessageListResponse def create( self, thread_id: str, - role: str, - content: Union[str, dict[str, Any], list[Any]], - image: Image.Image | list[Image.Image] | None = None, + message: MessageParam, ) -> Message: """Create a new message in a thread. @@ -120,7 +95,6 @@ def create( thread_id: ID of thread to create message in role: Role of message sender content: Message content - image: Optional image(s) to attach Returns: Created message object @@ -132,42 +106,14 @@ def create( if not thread_file.exists(): error_msg = f"Thread {thread_id} not found" raise FileNotFoundError(error_msg) - - # Save images if provided - image_paths = [] - if image is not None: - if isinstance(image, list): - images = image - else: - images = [image] - - self._images_dir.mkdir(parents=True, exist_ok=True) - for img in images: - # Generate unique image ID using same format as thread/message IDs - image_id = generate_time_ordered_id("img") - image_path = self._images_dir / f"{image_id}.png" - img.save(image_path) - image_paths.append(str(image_path)) - - # Create message content - message_content = [ - MessageContent( - text=str(content), image_paths=image_paths if image_paths else None - ) - ] - - # Create message - message = Message( + message = Message.model_construct( thread_id=thread_id, - role=self.ROLE_MAP.get(role.lower(), MessageRole.UNKNOWN), - content=message_content, + role=message.role, + content=message.content, ) - - # Save message with thread_file.open("a") as f: f.write(message.model_dump_json()) f.write("\n") - return message def retrieve(self, thread_id: str, message_id: str) -> Message: @@ -205,14 +151,8 @@ def delete(self, thread_id: str, message_id: str) -> None: error_msg = f"Thread {thread_id} not found" raise FileNotFoundError(error_msg) - # Get message and image paths before deletion - msg_to_delete = self.retrieve(thread_id, message_id) - image_paths = ( - msg_to_delete.content[0].image_paths if msg_to_delete.content else None - ) - # Read all messages - messages = [] + messages: list[Message] = [] with thread_file.open("r") as f: for line in f: msg = Message.model_validate_json(line) @@ -224,11 +164,3 @@ def delete(self, thread_id: str, message_id: str) -> None: for msg in messages: f.write(msg.model_dump_json()) f.write("\n") - - # Delete associated images if any - if image_paths: - for img_path in image_paths: - try: - Path(img_path).unlink() - except FileNotFoundError: # noqa: PERF203 - pass # Image might have been deleted already diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 022410e0..24c0ee84 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -1,10 +1,15 @@ -from datetime import datetime, timezone +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Literal, Sequence +from typing import Literal, Sequence, cast from pydantic import AwareDatetime, BaseModel, Field, computed_field +from askui.agent import VisionAgent +from askui.chat.api.messages.service import MessageService from askui.chat.api.utils import generate_time_ordered_id +from askui.models.shared.computer_agent_cb_param import OnMessageCbParam +from askui.models.shared.computer_agent_message_param import MessageParam RunStatus = Literal[ "queued", @@ -32,7 +37,9 @@ class Run(BaseModel): completed_at: AwareDatetime | None = None tried_cancelling_at: AwareDatetime | None = None cancelled_at: AwareDatetime | None = None - expires_at: AwareDatetime | None = None + expires_at: AwareDatetime = Field( + default_factory=lambda: datetime.now(tz=timezone.utc) + timedelta(minutes=10) + ) failed_at: AwareDatetime | None = None last_error: RunError | None = None object: Literal["run"] = "run" @@ -63,11 +70,70 @@ class RunListResponse(BaseModel): has_more: bool = False +class Runner: + def __init__(self, run: Run, base_dir: Path) -> None: + self._run = run + self._base_dir = base_dir + self._runs_dir = base_dir / "runs" + self._msg_service = MessageService(self._base_dir) + + def run_task(self) -> None: + self._mark_started() + messages: list[MessageParam] = [ + cast("MessageParam", msg) + for msg in self._msg_service.list_(self._run.thread_id).data + ] + + def on_message( + on_message_cb_param: OnMessageCbParam, + ) -> MessageParam | None: + self._msg_service.create( + thread_id=self._run.thread_id, + message=on_message_cb_param.message, + ) + updated_run = self._retrieve_run() + if self._should_abort(updated_run): + updated_run.cancelled_at = datetime.now(tz=timezone.utc) + self._update_run_file(updated_run) + return None + return on_message_cb_param.message + + try: + with VisionAgent() as agent: + agent.act(messages, on_message=on_message) + self._run.completed_at = datetime.now(tz=timezone.utc) + self._update_run_file(self._run) + except Exception as e: # noqa: BLE001 + self._run.failed_at = datetime.now(tz=timezone.utc) + self._run.last_error = RunError(message=str(e), code="server_error") + self._update_run_file(self._run) + raise + + def _mark_started(self) -> None: + self._run.started_at = datetime.now(tz=timezone.utc) + self._update_run_file(self._run) + + def _should_abort(self, run: Run) -> bool: + return run.status in ("cancelled", "cancelling", "expired") + + def _update_run_file(self, run: Run) -> None: + run_file = self._runs_dir / f"{run.thread_id}__{run.id}.json" + with run_file.open("w") as f: + f.write(run.model_dump_json()) + + def _retrieve_run(self) -> Run: + run_file = self._runs_dir / f"{self._run.thread_id}__{self._run.id}.json" + with run_file.open("r") as f: + return Run.model_validate_json(f.read()) + + class RunService: """ Service for managing runs. Handles creation, retrieval, listing, and cancellation of runs. """ + _executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=4) + def __init__(self, base_dir: Path) -> None: self._base_dir = base_dir self._runs_dir = base_dir / "runs" @@ -78,10 +144,22 @@ def _run_path(self, thread_id: str, run_id: str) -> Path: def create(self, thread_id: str, stream: bool) -> Run: run = Run(thread_id=thread_id) self._runs_dir.mkdir(parents=True, exist_ok=True) - run_file = self._run_path(thread_id, run.id) + self._update_run_file(run) + runner = Runner(run, self._base_dir) + # TODO(adi-wan-askui): Run differently depending on `stream` parameter + runner.run_task() + # if not stream: + # self._start_run_background(run) + return run + + def _start_run_background(self, run: Run) -> None: + runner = Runner(run, self._base_dir) + self._executor.submit(runner.run_task) + + def _update_run_file(self, run: Run) -> None: + run_file = self._run_path(run.thread_id, run.id) with run_file.open("w") as f: f.write(run.model_dump_json()) - return run def retrieve(self, run_id: str) -> Run: # Find the file by run_id diff --git a/src/askui/chat/api/threads/service.py b/src/askui/chat/api/threads/service.py index 10a716e1..5e4e7b4d 100644 --- a/src/askui/chat/api/threads/service.py +++ b/src/askui/chat/api/threads/service.py @@ -52,7 +52,7 @@ def list_(self, limit: int | None = None) -> ThreadListResponse: return ThreadListResponse(data=[]) thread_files = list(self._threads_dir.glob("*.jsonl")) - threads = [] + threads: list[Thread] = [] for f in thread_files: thread_id = f.stem created_at = datetime.fromtimestamp(f.stat().st_ctime, tz=timezone.utc) @@ -128,21 +128,5 @@ def delete(self, thread_id: str) -> None: error_msg = f"Thread {thread_id} not found" raise FileNotFoundError(error_msg) - # Get all image paths from messages before deleting thread - from askui.chat.api.messages.service import MessageService - - message_service = MessageService(self._base_dir) - try: - messages = message_service.list_(thread_id).data - for msg in messages: - if msg.content and msg.content[0].image_paths: - for img_path in msg.content[0].image_paths: - try: - Path(img_path).unlink() - except FileNotFoundError: # noqa: PERF203 - pass # Image might have been deleted already - except FileNotFoundError: - pass # Thread might have been deleted already - # Delete thread file thread_file.unlink() diff --git a/src/askui/models/__init__.py b/src/askui/models/__init__.py index f1765661..ec531fb0 100644 --- a/src/askui/models/__init__.py +++ b/src/askui/models/__init__.py @@ -1,5 +1,3 @@ -from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam - from .models import ( ActModel, GetModel, @@ -10,20 +8,48 @@ ModelDefinition, ModelName, ModelRegistry, + OnMessageCb, Point, ) +from .shared.computer_agent_message_param import ( + Base64ImageSourceParam, + CacheControlEphemeralParam, + CitationCharLocationParam, + CitationContentBlockLocationParam, + CitationPageLocationParam, + ContentBlockParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + TextCitationParam, + ToolResultBlockParam, + ToolUseBlockParam, + UrlImageSourceParam, +) __all__ = [ "ActModel", - "BetaMessageParam", - "BetaToolUseBlockParam", + "Base64ImageSourceParam", + "CacheControlEphemeralParam", + "CitationCharLocationParam", + "CitationContentBlockLocationParam", + "CitationPageLocationParam", + "ContentBlockParam", "GetModel", + "ImageBlockParam", "LocateModel", + "MessageParam", "Model", "ModelChoice", "ModelComposition", "ModelDefinition", "ModelName", "ModelRegistry", + "OnMessageCb", "Point", + "TextBlockParam", + "TextCitationParam", + "ToolResultBlockParam", + "ToolUseBlockParam", + "UrlImageSourceParam", ] diff --git a/src/askui/models/anthropic/computer_agent.py b/src/askui/models/anthropic/computer_agent.py index 476e7ecd..f7dc7ea2 100644 --- a/src/askui/models/anthropic/computer_agent.py +++ b/src/askui/models/anthropic/computer_agent.py @@ -1,13 +1,18 @@ +from typing import TYPE_CHECKING, cast + from anthropic import Anthropic -from anthropic.types.beta import BetaMessage, BetaMessageParam from typing_extensions import override from askui.models.anthropic.settings import ClaudeComputerAgentSettings from askui.models.models import ANTHROPIC_MODEL_NAME_MAPPING, ModelName from askui.models.shared.computer_agent import ComputerAgent +from askui.models.shared.computer_agent_message_param import MessageParam from askui.reporting import Reporter from askui.tools.agent_os import AgentOs +if TYPE_CHECKING: + from anthropic.types.beta import BetaMessageParam + class ClaudeComputerAgent(ComputerAgent[ClaudeComputerAgentSettings]): def __init__( @@ -23,14 +28,18 @@ def __init__( @override def _create_message( - self, messages: list[BetaMessageParam], model_choice: str - ) -> BetaMessage: + self, messages: list[MessageParam], model_choice: str + ) -> MessageParam: response = self._client.beta.messages.with_raw_response.create( max_tokens=self._settings.max_tokens, - messages=messages, + messages=[ + cast("BetaMessageParam", message.model_dump(mode="json")) + for message in messages + ], model=ANTHROPIC_MODEL_NAME_MAPPING[ModelName(model_choice)], system=[self._system], tools=self._tool_collection.to_params(), betas=self._settings.betas, ) - return response.parse() + parsed_response = response.parse() + return MessageParam.model_validate(parsed_response.model_dump()) diff --git a/src/askui/models/askui/computer_agent.py b/src/askui/models/askui/computer_agent.py index 6538c5ba..42073abf 100644 --- a/src/askui/models/askui/computer_agent.py +++ b/src/askui/models/askui/computer_agent.py @@ -1,9 +1,10 @@ import httpx -from anthropic.types.beta import BetaMessage, BetaMessageParam from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential +from typing_extensions import override from askui.models.askui.settings import AskUiComputerAgentSettings from askui.models.shared.computer_agent import ComputerAgent +from askui.models.shared.computer_agent_message_param import MessageParam from askui.reporting import Reporter from askui.tools.agent_os import AgentOs @@ -39,27 +40,27 @@ def __init__( retry=retry_if_exception(is_retryable_error), reraise=True, ) + @override def _create_message( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, # noqa: ARG002 - ) -> BetaMessage: + ) -> MessageParam: try: request_body = { "max_tokens": self._settings.max_tokens, - "messages": messages, + "messages": [msg.model_dump(mode="json") for msg in messages], "model": self._settings.model, "tools": self._tool_collection.to_params(), "betas": self._settings.betas, "system": [self._system], } - logger.debug(request_body) response = self._client.post( "/act/inference", json=request_body, timeout=300.0 ) response.raise_for_status() response_data = response.json() - return BetaMessage.model_validate(response_data) + return MessageParam.model_validate(response_data) except Exception as e: # noqa: BLE001 if is_retryable_error(e): logger.debug(e) diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index efdcde68..fbc49fe1 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -1,7 +1,6 @@ import functools -from typing import Callable, Type, overload +from typing import Type, overload -from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from typing_extensions import Literal from askui.exceptions import ModelNotFoundError, ModelTypeMismatchError @@ -28,10 +27,11 @@ ModelRegistry, Point, ) +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.models.shared.facade import ModelFacade from askui.models.types.response_schemas import ResponseSchema from askui.reporting import CompositeReporter, Reporter -from askui.tools.anthropic import ToolResult from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource @@ -184,21 +184,13 @@ def _get_model( def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: m = self._get_model(model_choice, "act") logger.debug(f'Routing "act" to model "{model_choice}"') - return m.act(messages, model_choice, on_message, on_tool_result) + return m.act(messages, model_choice, on_message) def get( self, diff --git a/src/askui/models/models.py b/src/askui/models/models.py index d038716a..71099796 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -4,13 +4,13 @@ from enum import Enum from typing import Annotated, Callable, Type -from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from pydantic import BaseModel, ConfigDict, Field, RootModel from typing_extensions import Literal, TypedDict from askui.locators.locators import Locator +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.models.types.response_schemas import ResponseSchema -from askui.tools.anthropic.base import ToolResult from askui.utils.image_utils import ImageSource @@ -158,10 +158,9 @@ class ActModel(abc.ABC): ```python from askui import ( ActModel, - BetaMessageParam, - BetaToolUseBlockParam, + MessageParam, + OnMessageCb, VisionAgent, - ToolResult, ) from typing_extensions import override @@ -169,17 +168,9 @@ class MyActModel(ActModel): @override def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: print(messages) # implement custom logic here @@ -190,30 +181,20 @@ def act( @abc.abstractmethod def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: """ Execute autonomous actions to achieve a goal, using a message history and optional callbacks. Args: - messages (list[BetaMessageParam]): The message history to start from. + messages (list[MessageParam]): The message history to start from. model_choice (str): The name of the model being used (useful for models that support multiple configurations) - on_message (Callable[[BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None], optional): Callback for new messages. + on_message (OnMessageCb | None, optional): Callback for new messages. If it returns `None`, stops and does not add the message. - on_tool_result (Callable[[ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], ToolResult | None], optional): Callback for tool results. - If it returns `None`, stops and does not add the tool result. Returns: None diff --git a/src/askui/models/shared/computer_agent.py b/src/askui/models/shared/computer_agent.py index 60be1aa0..ebcc6d2f 100644 --- a/src/askui/models/shared/computer_agent.py +++ b/src/askui/models/shared/computer_agent.py @@ -2,21 +2,22 @@ import sys from abc import ABC, abstractmethod from datetime import datetime, timezone -from typing import Callable, Generic, cast - -from anthropic.types.beta import ( - BetaContentBlockParam, - BetaImageBlockParam, - BetaMessage, - BetaMessageParam, - BetaTextBlockParam, - BetaToolResultBlockParam, - BetaToolUseBlockParam, -) +from typing import Generic + +from anthropic.types.beta import BetaTextBlockParam from pydantic import BaseModel, Field -from typing_extensions import TypeVar +from typing_extensions import TypeVar, override from askui.models.models import ActModel +from askui.models.shared.computer_agent_cb_param import OnMessageCb, OnMessageCbParam +from askui.models.shared.computer_agent_message_param import ( + Base64ImageSourceParam, + ContentBlockParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, +) from askui.reporting import Reporter from askui.tools.agent_os import AgentOs from askui.tools.anthropic import ComputerTool, ToolCollection, ToolResult @@ -210,66 +211,52 @@ def __init__( @abstractmethod def _create_message( - self, messages: list[BetaMessageParam], model_choice: str - ) -> BetaMessage: + self, messages: list[MessageParam], model_choice: str + ) -> MessageParam: """Create a message using the agent's API. Args: - messages (list[BetaMessageParam]): The message history. + messages (list[MessageParam]): The message history. model_choice (str): The model to use for message creation. Returns: - BetaMessage: The created message. + MessageParam: The created message. """ raise NotImplementedError def _step( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: """Execute a single step in the conversation. Args: - messages (list[BetaMessageParam]): The message history. + messages (list[MessageParam]): The message history. model_choice (str): The model to use for message creation. - on_message (Callable, optional): Callback for message processing. - on_tool_result (Callable, optional): Callback for tool result processing. + on_message (OnMessageCb | None, optional): Callback on new messages Returns: None """ if self._settings.only_n_most_recent_images: - self._maybe_filter_to_n_most_recent_images( + messages = self._maybe_filter_to_n_most_recent_images( messages, self._settings.only_n_most_recent_images, - min_removal_threshold=self._settings.image_truncation_threshold, + self._settings.image_truncation_threshold, ) response_message = self._create_message(messages, model_choice) - response_message_param: BetaMessageParam = cast( - "BetaMessageParam", - response_message.model_dump(include={"content", "role"}), - ) message_by_assistant = self._call_on_message( - on_message, response_message_param, messages + on_message, response_message, messages ) if message_by_assistant is None: return - logger.debug(message_by_assistant) + message_by_assistant_dict = message_by_assistant.model_dump(mode="json") + logger.debug(message_by_assistant_dict) messages.append(message_by_assistant) - self._reporter.add_message(self.__class__.__name__, dict(message_by_assistant)) - if tool_result_message := self._use_tools( - response_message, messages, on_tool_result - ): + self._reporter.add_message(self.__class__.__name__, message_by_assistant_dict) + if tool_result_message := self._use_tools(message_by_assistant): if tool_result_message := self._call_on_message( on_message, tool_result_message, messages ): @@ -278,147 +265,124 @@ def _step( messages=messages, model_choice=model_choice, on_message=on_message, - on_tool_result=on_tool_result, ) def _call_on_message( self, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None, - message: BetaMessageParam, - messages: list[BetaMessageParam], - ) -> BetaMessageParam | None: + on_message: OnMessageCb | None, + message: MessageParam, + messages: list[MessageParam], + ) -> MessageParam | None: if on_message is None: return message - return on_message(message, messages) + return on_message(OnMessageCbParam(message=message, messages=messages)) + @override def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: - logger.debug(messages[0]) self._step( messages=messages, model_choice=model_choice, on_message=on_message, - on_tool_result=on_tool_result, ) def _use_tools( self, - message: BetaMessage, - messages: list[BetaMessageParam], - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, - ) -> BetaMessageParam | None: + message: MessageParam, + ) -> MessageParam | None: """Process tool use blocks in a message. Args: - message (BetaMessage): The message containing tool use blocks. - messages (list[BetaMessageParam]): The message history. - on_tool_result (Callable, optional): Callback for tool result processing. + message (MessageParam): The message containing tool use blocks. Returns: - BetaMessageParam | None: A message containing tool results or `None` + MessageParam | None: A message containing tool results or `None` if no tools were used. """ - tool_result_content: list[BetaToolResultBlockParam] = [] + tool_result_content: list[ContentBlockParam] = [] + if isinstance(message.content, str): + return None + for content_block in message.content: if content_block.type == "tool_use": result = self._tool_collection.run( name=content_block.name, tool_input=content_block.input, # type: ignore[arg-type] ) - if on_tool_result is not None: - content_block_param = cast( - "BetaToolUseBlockParam", content_block.model_dump() - ) - result_cb = on_tool_result(result, content_block_param, messages) - if result_cb is None: - return None - result = result_cb tool_result_content.append( self._make_api_tool_result(result, content_block.id) ) - if len(tool_result_content) > 0: - tool_result_message: BetaMessageParam = { - "content": tool_result_content, - "role": "user", - } - logger.debug(tool_result_message) - return tool_result_message - return None + if len(tool_result_content) == 0: + return None + + return MessageParam( + content=tool_result_content, + role="user", + ) @staticmethod def _maybe_filter_to_n_most_recent_images( - messages: list[BetaMessageParam], + messages: list[MessageParam], images_to_keep: int | None, min_removal_threshold: int, - ) -> list[BetaMessageParam] | None: - """Filter messages to keep only the most recent images. + ) -> list[MessageParam]: + """ + Filter the message history in-place to keep only the most recent images, + according to the given chunking policy. Args: - messages (list[BetaMessageParam]): The message history. + messages (list[MessageParam]): The message history. images_to_keep (int | None): Number of most recent images to keep. min_removal_threshold (int): Minimum number of images to remove at once. Returns: - list[BetaMessageParam] | None: The filtered message history or `None` if no - filtering was done. + list[MessageParam]: The filtered message history. """ if images_to_keep is None: return messages - tool_result_blocks = cast( - "list[BetaToolResultBlockParam]", - [ - item - for message in messages - for item in ( - message["content"] if isinstance(message["content"], list) else [] - ) - if item.get("type") == "tool_result" - ], - ) + tool_result_blocks = [ + item + for message in messages + for item in (message.content if isinstance(message.content, list) else []) + if item.type == "tool_result" + ] total_images = sum( 1 for tool_result in tool_result_blocks - for content in tool_result.get("content", []) - if isinstance(content, dict) and content.get("type") == "image" + if not isinstance(tool_result.content, str) + for content in tool_result.content + if content.type == "image" ) images_to_remove = total_images - images_to_keep + if images_to_remove < min_removal_threshold: + return messages # for better cache behavior, we want to remove in chunks images_to_remove -= images_to_remove % min_removal_threshold + if images_to_remove <= 0: + return messages + + # Remove images from the oldest tool_result blocks first for tool_result in tool_result_blocks: - if isinstance(tool_result.get("content"), list): - new_content: list[BetaContentBlockParam] = [] - for content in tool_result.get("content", []): - if isinstance(content, dict) and content.get("type") == "image": - if images_to_remove > 0: - images_to_remove -= 1 - continue - new_content.append(cast("BetaContentBlockParam", content)) - tool_result["content"] = new_content # type: ignore[typeddict-item] - return None + if images_to_remove <= 0: + break + if isinstance(tool_result.content, list): + new_content: list[TextBlockParam | ImageBlockParam] = [] + for content in tool_result.content: + if content.type == "image" and images_to_remove > 0: + images_to_remove -= 1 + continue + new_content.append(content) + tool_result.content = new_content + return messages def _make_api_tool_result( self, result: ToolResult, tool_use_id: str - ) -> BetaToolResultBlockParam: + ) -> ToolResultBlockParam: """Convert a tool result to an API tool result block. Args: @@ -426,9 +390,9 @@ def _make_api_tool_result( tool_use_id (str): The ID of the tool use block. Returns: - BetaToolResultBlockParam: The API tool result block. + ToolResultBlockParam: The API tool result block. """ - tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = [] + tool_result_content: list[TextBlockParam | ImageBlockParam] | str = [] is_error = False if result.error: is_error = True @@ -439,30 +403,26 @@ def _make_api_tool_result( assert isinstance(tool_result_content, list) if result.output: tool_result_content.append( - { - "type": "text", - "text": self._maybe_prepend_system_tool_result( + TextBlockParam( + text=self._maybe_prepend_system_tool_result( result, result.output ), - } + ) ) if result.base64_image: tool_result_content.append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": result.base64_image, - }, - } + ImageBlockParam( + source=Base64ImageSourceParam( + media_type="image/png", + data=result.base64_image, + ), + ) ) - return { - "type": "tool_result", - "content": tool_result_content, - "tool_use_id": tool_use_id, - "is_error": is_error, - } + return ToolResultBlockParam( + content=tool_result_content, + tool_use_id=tool_use_id, + is_error=is_error, + ) @staticmethod def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str) -> str: diff --git a/src/askui/models/shared/computer_agent_cb_param.py b/src/askui/models/shared/computer_agent_cb_param.py new file mode 100644 index 00000000..80d07f49 --- /dev/null +++ b/src/askui/models/shared/computer_agent_cb_param.py @@ -0,0 +1,14 @@ +from typing import Callable, Literal + +from pydantic import BaseModel + +from askui.models.shared.computer_agent_message_param import MessageParam + + +class OnMessageCbParam(BaseModel): + type: Literal["message"] = "message" + message: MessageParam + messages: list[MessageParam] + + +OnMessageCb = Callable[[OnMessageCbParam], MessageParam | None] diff --git a/src/askui/models/shared/computer_agent_message_param.py b/src/askui/models/shared/computer_agent_message_param.py new file mode 100644 index 00000000..e525f172 --- /dev/null +++ b/src/askui/models/shared/computer_agent_message_param.py @@ -0,0 +1,107 @@ +from pydantic import BaseModel +from typing_extensions import Literal + + +class CitationCharLocationParam(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_char_index: int + start_char_index: int + type: Literal["char_location"] = "char_location" + + +class CitationPageLocationParam(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_page_number: int + start_page_number: int + type: Literal["page_location"] = "page_location" + + +class CitationContentBlockLocationParam(BaseModel): + cited_text: str + document_index: int + document_title: str | None = None + end_block_index: int + start_block_index: int + type: Literal["content_block_location"] = "content_block_location" + + +TextCitationParam = ( + CitationCharLocationParam + | CitationPageLocationParam + | CitationContentBlockLocationParam +) + + +class UrlImageSourceParam(BaseModel): + type: Literal["url"] = "url" + url: str + + +class Base64ImageSourceParam(BaseModel): + data: str + media_type: Literal["image/jpeg", "image/png", "image/gif", "image/webp"] + type: Literal["base64"] = "base64" + + +class CacheControlEphemeralParam(BaseModel): + type: Literal["ephemeral"] = "ephemeral" + + +class ImageBlockParam(BaseModel): + source: Base64ImageSourceParam | UrlImageSourceParam + type: Literal["image"] = "image" + cache_control: CacheControlEphemeralParam | None = None + + +class TextBlockParam(BaseModel): + text: str + type: Literal["text"] = "text" + cache_control: CacheControlEphemeralParam | None = None + citations: list[TextCitationParam] | None = None + + +class ToolResultBlockParam(BaseModel): + tool_use_id: str + type: Literal["tool_result"] = "tool_result" + cache_control: CacheControlEphemeralParam | None = None + content: str | list[TextBlockParam | ImageBlockParam] + is_error: bool = False + + +class ToolUseBlockParam(BaseModel): + id: str + input: object + name: str + type: Literal["tool_use"] = "tool_use" + cache_control: CacheControlEphemeralParam | None = None + + +ContentBlockParam = ( + ImageBlockParam | TextBlockParam | ToolResultBlockParam | ToolUseBlockParam +) + + +class MessageParam(BaseModel): + role: Literal["user", "assistant"] + content: str | list[ContentBlockParam] + + +__all__ = [ + "Base64ImageSourceParam", + "CacheControlEphemeralParam", + "CitationCharLocationParam", + "CitationContentBlockLocationParam", + "CitationPageLocationParam", + "ContentBlockParam", + "ImageBlockParam", + "MessageParam", + "TextBlockParam", + "TextCitationParam", + "ToolResultBlockParam", + "ToolUseBlockParam", + "UrlImageSourceParam", +] diff --git a/src/askui/models/shared/facade.py b/src/askui/models/shared/facade.py index 34491b49..e4ac7a0f 100644 --- a/src/askui/models/shared/facade.py +++ b/src/askui/models/shared/facade.py @@ -1,12 +1,12 @@ -from typing import Callable, Type +from typing import Type -from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from typing_extensions import override from askui.locators.locators import Locator from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.models.types.response_schemas import ResponseSchema -from askui.tools.anthropic.base import ToolResult from askui.utils.image_utils import ImageSource @@ -24,23 +24,14 @@ def __init__( @override def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: self._act_model.act( messages=messages, model_choice=model_choice, on_message=on_message, - on_tool_result=on_tool_result, ) @override diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index 6db29e64..e1c39306 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -1,9 +1,8 @@ import math import re import time -from typing import Any, Callable, Type +from typing import Any, Type -from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from openai import OpenAI from pydantic import Field, HttpUrl, SecretStr from pydantic_settings import BaseSettings @@ -13,10 +12,11 @@ from askui.locators.locators import Locator from askui.locators.serializers import VlmLocatorSerializer from askui.models.models import ActModel, GetModel, LocateModel, ModelComposition, Point +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.models.types.response_schemas import ResponseSchema from askui.reporting import Reporter from askui.tools.agent_os import AgentOs -from askui.tools.anthropic.base import ToolResult from askui.utils.image_utils import ImageSource, image_to_base64 from .parser import UITarsEPMessage @@ -192,35 +192,24 @@ def get( @override def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: if on_message is not None: error_msg = "on_message is not supported for UI-TARS" raise NotImplementedError(error_msg) - if on_tool_result is not None: - error_msg = "on_tool_result is not supported for UI-TARS" - raise NotImplementedError(error_msg) if len(messages) != 1: error_msg = "UI-TARS only supports one message" raise ValueError(error_msg) message = messages[0] - if message["role"] != "user": + if message.role != "user": error_msg = "UI-TARS only supports user messages" raise ValueError(error_msg) - if not isinstance(message["content"], str): + if not isinstance(message.content, str): error_msg = "UI-TARS only supports text messages" raise ValueError(error_msg) # noqa: TRY004 - goal = message["content"] + goal = message.content screenshot = self._agent_os.screenshot() self.act_history = [ { diff --git a/src/askui/tools/anthropic/computer.py b/src/askui/tools/anthropic/computer.py index 1dbc12fe..3997e282 100644 --- a/src/askui/tools/anthropic/computer.py +++ b/src/askui/tools/anthropic/computer.py @@ -257,6 +257,11 @@ def __call__( # noqa: C901 error_msg = f"{coordinate} must be a tuple of non-negative ints" raise ToolError(error_msg) + if self._real_screen_width is None or self._real_screen_height is None: + screenshot = self._agent_os.screenshot() + self._real_screen_width = screenshot.width + self._real_screen_height = screenshot.height + x, y = scale_coordinates_back( coordinate[0], coordinate[1], diff --git a/tests/integration/agent/test_computer_agent.py b/tests/integration/agent/test_computer_agent.py index 2154339d..24b5c680 100644 --- a/tests/integration/agent/test_computer_agent.py +++ b/tests/integration/agent/test_computer_agent.py @@ -2,6 +2,7 @@ from askui.models.askui.computer_agent import AskUiComputerAgent from askui.models.models import ModelName +from askui.models.shared.computer_agent_message_param import MessageParam @pytest.mark.skip( @@ -11,6 +12,6 @@ def test_act( claude_computer_agent: AskUiComputerAgent, ) -> None: claude_computer_agent.act( - [{"role": "user", "content": "Go to github.com/login"}], + [MessageParam(role="user", content="Go to github.com/login")], model_choice=ModelName.ASKUI, ) diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index 1bdfd521..01db14fd 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -1,9 +1,8 @@ """Integration tests for custom model registration and selection.""" -from typing import Callable, Optional, Type, Union +from typing import Any, Optional, Type, Union import pytest -from anthropic.types.beta import BetaMessageParam, BetaToolUseBlockParam from typing_extensions import override from askui import ( @@ -18,7 +17,8 @@ ) from askui.locators.locators import Locator from askui.models import ModelComposition, ModelDefinition, ModelName -from askui.tools.anthropic.base import ToolResult +from askui.models.shared.computer_agent_cb_param import OnMessageCb +from askui.models.shared.computer_agent_message_param import MessageParam from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource @@ -27,25 +27,17 @@ class SimpleActModel(ActModel): """Simple act model that records goals.""" def __init__(self) -> None: - self.goals: list[list[BetaMessageParam]] = [] + self.goals: list[list[dict[str, Any]]] = [] self.model_choices: list[str] = [] @override def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: - self.goals.append(messages) + self.goals.append([message.model_dump(mode="json") for message in messages]) self.model_choices.append(model_choice) @@ -217,17 +209,9 @@ class AnotherActModel(ActModel): @override def act( self, - messages: list[BetaMessageParam], + messages: list[MessageParam], model_choice: str, - on_message: Callable[ - [BetaMessageParam, list[BetaMessageParam]], BetaMessageParam | None - ] - | None = None, - on_tool_result: Callable[ - [ToolResult, BetaToolUseBlockParam, list[BetaMessageParam]], - ToolResult | None, - ] - | None = None, + on_message: OnMessageCb | None = None, ) -> None: pass diff --git a/tests/unit/models/test_computer_agent_filter.py b/tests/unit/models/test_computer_agent_filter.py new file mode 100644 index 00000000..4199e8b4 --- /dev/null +++ b/tests/unit/models/test_computer_agent_filter.py @@ -0,0 +1,129 @@ +from askui.models.shared.computer_agent import ComputerAgent +from askui.models.shared.computer_agent_message_param import ( + Base64ImageSourceParam, + ImageBlockParam, + MessageParam, + TextBlockParam, + ToolResultBlockParam, +) + + +def make_image_block() -> ImageBlockParam: + return ImageBlockParam( + source=Base64ImageSourceParam( + media_type="image/png", + data="abc", + ), + ) + + +def make_tool_result_block(num_images: int, num_texts: int = 0) -> ToolResultBlockParam: + content = [make_image_block() for _ in range(num_images)] + [ + TextBlockParam(text=f"text{i}") for i in range(num_texts) + ] + return ToolResultBlockParam(tool_use_id="id", content=content) + + +def make_message_with_tool_result(num_images: int, num_texts: int = 0) -> MessageParam: + return MessageParam( + role="user", content=[make_tool_result_block(num_images, num_texts)] + ) + + +def test_no_images() -> None: + messages = [make_message_with_tool_result(0, 2)] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 3, 2) + assert filtered == messages + + +def test_fewer_images_than_keep() -> None: + messages = [make_message_with_tool_result(2, 1)] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 3, 2) + # Only ToolResultBlockParam with list content should be checked + all_images = [ + c + for m in filtered + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + if getattr(c, "type", None) == "image" + ] + expected_images = [ + c + for b in (messages[0].content if isinstance(messages[0].content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + if getattr(c, "type", None) == "image" + ] + assert all_images == expected_images + + +def test_exactly_images_to_keep() -> None: + messages = [make_message_with_tool_result(3, 1)] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 3, 2) + # Only check .content if the type is correct + first_block = ( + filtered[0].content[0] + if isinstance(filtered[0].content, list) and len(filtered[0].content) > 0 + else None + ) + if isinstance(first_block, ToolResultBlockParam) and isinstance( + first_block.content, list + ): + assert len(first_block.content) == 4 + else: + raise AssertionError( + "filtered[0].content[0] is not a ToolResultBlockParam with list content" + ) + all_tool_result_contents = [ + c + for m in filtered + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + ] + assert ( + sum(1 for c in all_tool_result_contents if getattr(c, "type", None) == "image") + == 3 + ) + + +def test_more_images_than_keep_removes_oldest() -> None: + messages = [ + make_message_with_tool_result(2, 0), + make_message_with_tool_result(2, 0), + ] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 2, 2) + # Only 2 images should remain, and they should be the newest (from the last message) + all_images = [ + c + for m in filtered + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + if getattr(c, "type", None) == "image" + ] + assert len(all_images) == 2 + # They should be from the last message + assert all_images == [ + c + for b in (filtered[1].content if isinstance(filtered[1].content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content[:2] + if getattr(c, "type", None) == "image" + ] + + +def test_removal_chunking() -> None: + messages = [make_message_with_tool_result(5, 0)] + filtered = ComputerAgent._maybe_filter_to_n_most_recent_images(messages, 2, 2) + # Should remove 4 (chunk of 4), leaving 1 image + all_images = [ + c + for m in filtered + for b in (m.content if isinstance(m.content, list) else []) + if isinstance(b, ToolResultBlockParam) and isinstance(b.content, list) + for c in b.content + if getattr(c, "type", None) == "image" + ] + assert len(all_images) == 3 diff --git a/tests/unit/models/test_model_router.py b/tests/unit/models/test_model_router.py index f35b7257..258d7292 100644 --- a/tests/unit/models/test_model_router.py +++ b/tests/unit/models/test_model_router.py @@ -12,6 +12,7 @@ from askui.models.huggingface.spaces_api import HFSpacesHandler from askui.models.model_router import ModelRouter from askui.models.models import ModelName +from askui.models.shared.computer_agent_message_param import MessageParam from askui.models.shared.facade import ModelFacade from askui.models.ui_tars_ep.ui_tars_api import UiTarsApiHandler from askui.reporting import CompositeReporter @@ -288,29 +289,31 @@ def test_act_with_tars_model( self, model_router: ModelRouter, mock_tars: UiTarsApiHandler ) -> None: """Test acting using TARS model.""" - model_router.act([{"role": "user", "content": "test goal"}], ModelName.TARS) + messages = [MessageParam(role="user", content="test goal")] + model_router.act(messages, ModelName.TARS) mock_tars.act.assert_called_once_with( # type: ignore[attr-defined] - [{"role": "user", "content": "test goal"}], ModelName.TARS, None, None + messages, + ModelName.TARS, + None, ) def test_act_with_claude_model( self, model_router: ModelRouter, mock_anthropic_facade: ModelFacade ) -> None: """Test acting using Claude model.""" + messages = [MessageParam(role="user", content="test goal")] model_router.act( - [{"role": "user", "content": "test goal"}], + messages, ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, ) mock_anthropic_facade.act.assert_called_once_with( # type: ignore - [{"role": "user", "content": "test goal"}], + messages, ModelName.ANTHROPIC__CLAUDE__3_5__SONNET__20241022, None, - None, ) def test_act_with_invalid_model(self, model_router: ModelRouter) -> None: """Test that acting with invalid model raises InvalidModelError.""" + messages = [MessageParam(role="user", content="test goal")] with pytest.raises(ModelNotFoundError): - model_router.act( - [{"role": "user", "content": "test goal"}], "invalid-model" - ) + model_router.act(messages, "invalid-model") From 00cbbd34679b6e12c19f68b37f3e59e0b5194d16 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 5 Jun 2025 15:37:00 +0200 Subject: [PATCH 14/20] feat(chat): add streaming support to creating runs --- src/askui/chat/__main__.py | 28 +++++- src/askui/chat/api/messages/service.py | 7 ++ src/askui/chat/api/models.py | 7 ++ src/askui/chat/api/runs/router.py | 25 ++++- src/askui/chat/api/runs/service.py | 129 ++++++++++++++++++++----- 5 files changed, 169 insertions(+), 27 deletions(-) create mode 100644 src/askui/chat/api/models.py diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index 5bf6aeb7..cc3e8c37 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -338,7 +338,6 @@ def write_message( ) write_message(last_message) run = run_service.create(thread_id, stream=False) - print(run) time.sleep(1) while run := run_service.retrieve(run.id): new_messages = message_service.list_( @@ -352,5 +351,32 @@ def write_message( time.sleep(1) +if act_prompt := st.chat_input("Ask AI (streaming)"): + if act_prompt != "Continue": + last_message = message_service.create( + thread_id=thread_id, + message=MessageParam( + role="user", + content=act_prompt, + ), + ) + write_message(last_message) + + # Use the streaming API + event_stream = run_service.create(thread_id, stream=True) + import asyncio + + async def handle_stream() -> None: + last_msg_id = last_message.id if last_message else None + async for event in event_stream: + if event.event == "message.created": + msg = event.data + if msg and (not last_msg_id or msg.id > last_msg_id): + write_message(msg) + last_msg_id = msg.id + + # Run the async handler in Streamlit (sync context) + asyncio.run(handle_stream()) + # if st.button("Rerun"): # rerun() diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index e352e7d1..9cf84801 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -1,8 +1,10 @@ from datetime import datetime, timezone from pathlib import Path +from typing import Literal from pydantic import AwareDatetime, BaseModel, Field +from askui.chat.api.models import Event from askui.chat.api.utils import generate_time_ordered_id from askui.models.shared.computer_agent_message_param import MessageParam @@ -18,6 +20,11 @@ class Message(MessageParam): object: str = "message" +class MessageEvent(Event): + data: Message + event: Literal["message.created"] + + class MessageListResponse(BaseModel): """Response model for listing messages.""" diff --git a/src/askui/chat/api/models.py b/src/askui/chat/api/models.py new file mode 100644 index 00000000..81d23f14 --- /dev/null +++ b/src/askui/chat/api/models.py @@ -0,0 +1,7 @@ +from typing import Literal + +from pydantic import BaseModel + + +class Event(BaseModel): + object: Literal["event"] = "event" diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py index aa509fb6..120b40b8 100644 --- a/src/askui/chat/api/runs/router.py +++ b/src/askui/chat/api/runs/router.py @@ -1,10 +1,15 @@ -from typing import Annotated +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING, Annotated, cast from fastapi import APIRouter, Body, HTTPException, Path +from fastapi.responses import StreamingResponse from pydantic import BaseModel +if TYPE_CHECKING: + from askui.chat.api.messages.service import MessageEvent + from .dependencies import RunServiceDep -from .service import Run, RunListResponse, RunService +from .service import Run, RunEvent, RunListResponse, RunService class CreateRunRequest(BaseModel): @@ -19,11 +24,23 @@ def create_run( thread_id: Annotated[str, Path(...)], request: Annotated[CreateRunRequest, Body(...)], run_service: RunService = RunServiceDep, -) -> Run: +) -> Run | StreamingResponse: """ Create a new run for a given thread. """ - return run_service.create(thread_id, request.stream) + stream = request.stream + run_or_async_generator = run_service.create(thread_id, stream) + if stream: + async_generator = cast( + "AsyncGenerator[RunEvent | MessageEvent, None]", run_or_async_generator + ) + + async def sse_event_stream() -> AsyncGenerator[str, None]: + async for event in async_generator: + yield f"event: {event.event}\ndata: {event.model_dump_json()}\n\n" + + return StreamingResponse(sse_event_stream(), media_type="text/event-stream") + return cast("Run", run_or_async_generator) @router.get("/{run_id}") diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 24c0ee84..4dff82a9 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -1,12 +1,16 @@ -from concurrent.futures import ThreadPoolExecutor +import asyncio +import queue +import threading +from collections.abc import AsyncGenerator from datetime import datetime, timedelta, timezone from pathlib import Path -from typing import Literal, Sequence, cast +from typing import Literal, Sequence, cast, overload from pydantic import AwareDatetime, BaseModel, Field, computed_field from askui.agent import VisionAgent -from askui.chat.api.messages.service import MessageService +from askui.chat.api.messages.service import MessageEvent, MessageService +from askui.chat.api.models import Event from askui.chat.api.utils import generate_time_ordered_id from askui.models.shared.computer_agent_cb_param import OnMessageCbParam from askui.models.shared.computer_agent_message_param import MessageParam @@ -70,6 +74,18 @@ class RunListResponse(BaseModel): has_more: bool = False +class RunEvent(Event): + data: Run + event: Literal[ + "run.created", + "run.started", + "run.completed", + "run.failed", + "run.cancelled", + "run.expired", + ] + + class Runner: def __init__(self, run: Run, base_dir: Path) -> None: self._run = run @@ -77,8 +93,14 @@ def __init__(self, run: Run, base_dir: Path) -> None: self._runs_dir = base_dir / "runs" self._msg_service = MessageService(self._base_dir) - def run_task(self) -> None: + def run(self, event_queue: queue.Queue[RunEvent | MessageEvent | None]) -> None: self._mark_started() + event_queue.put( + RunEvent( + data=self._run, + event="run.started", + ) + ) messages: list[MessageParam] = [ cast("MessageParam", msg) for msg in self._msg_service.list_(self._run.thread_id).data @@ -87,27 +109,63 @@ def run_task(self) -> None: def on_message( on_message_cb_param: OnMessageCbParam, ) -> MessageParam | None: - self._msg_service.create( + message = self._msg_service.create( thread_id=self._run.thread_id, message=on_message_cb_param.message, ) + event_queue.put( + MessageEvent( + data=message, + event="message.created", + ) + ) updated_run = self._retrieve_run() - if self._should_abort(updated_run): + if updated_run.status == "cancelling": updated_run.cancelled_at = datetime.now(tz=timezone.utc) self._update_run_file(updated_run) + event_queue.put( + RunEvent( + data=updated_run, + event="run.cancelled", + ) + ) + return None + if updated_run.status == "expired": + event_queue.put( + RunEvent( + data=updated_run, + event="run.expired", + ) + ) return None return on_message_cb_param.message try: with VisionAgent() as agent: agent.act(messages, on_message=on_message) - self._run.completed_at = datetime.now(tz=timezone.utc) - self._update_run_file(self._run) + updated_run = self._retrieve_run() + if updated_run.status == "in_progress": + updated_run.completed_at = datetime.now(tz=timezone.utc) + self._update_run_file(updated_run) + event_queue.put( + RunEvent( + data=updated_run, + event="run.completed", + ) + ) except Exception as e: # noqa: BLE001 - self._run.failed_at = datetime.now(tz=timezone.utc) - self._run.last_error = RunError(message=str(e), code="server_error") - self._update_run_file(self._run) - raise + updated_run = self._retrieve_run() + updated_run.failed_at = datetime.now(tz=timezone.utc) + updated_run.last_error = RunError(message=str(e), code="server_error") + self._update_run_file(updated_run) + event_queue.put( + RunEvent( + data=updated_run, + event="run.failed", + ) + ) + finally: + event_queue.put(None) def _mark_started(self) -> None: self._run.started_at = datetime.now(tz=timezone.utc) @@ -132,8 +190,6 @@ class RunService: Service for managing runs. Handles creation, retrieval, listing, and cancellation of runs. """ - _executor: ThreadPoolExecutor = ThreadPoolExecutor(max_workers=4) - def __init__(self, base_dir: Path) -> None: self._base_dir = base_dir self._runs_dir = base_dir / "runs" @@ -141,20 +197,49 @@ def __init__(self, base_dir: Path) -> None: def _run_path(self, thread_id: str, run_id: str) -> Path: return self._runs_dir / f"{thread_id}__{run_id}.json" - def create(self, thread_id: str, stream: bool) -> Run: + def _create_run(self, thread_id: str) -> Run: run = Run(thread_id=thread_id) self._runs_dir.mkdir(parents=True, exist_ok=True) self._update_run_file(run) - runner = Runner(run, self._base_dir) - # TODO(adi-wan-askui): Run differently depending on `stream` parameter - runner.run_task() - # if not stream: - # self._start_run_background(run) return run - def _start_run_background(self, run: Run) -> None: + @overload + def create(self, thread_id: str, stream: Literal[False]) -> Run: ... + + @overload + def create( + self, thread_id: str, stream: Literal[True] + ) -> AsyncGenerator[RunEvent | MessageEvent, None]: ... + + @overload + def create( + self, thread_id: str, stream: bool + ) -> Run | AsyncGenerator[RunEvent | MessageEvent, None]: ... + + def create( + self, thread_id: str, stream: bool + ) -> Run | AsyncGenerator[RunEvent | MessageEvent, None]: + run = self._create_run(thread_id) + event_queue: queue.Queue[RunEvent | MessageEvent | None] = queue.Queue() runner = Runner(run, self._base_dir) - self._executor.submit(runner.run_task) + thread = threading.Thread(target=runner.run, args=(event_queue,)) + thread.start() + if stream: + + async def event_stream() -> AsyncGenerator[RunEvent | MessageEvent, None]: + yield RunEvent( + data=run, + event="run.created", + ) + loop = asyncio.get_event_loop() + while True: + event = await loop.run_in_executor(None, event_queue.get) + if event is None: + break + yield event + + return event_stream() + return run def _update_run_file(self, run: Run) -> None: run_file = self._run_path(run.thread_id, run.id) From 752e612509999670d39c5f4f13acd7b8fb4b6acb Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 5 Jun 2025 16:05:57 +0200 Subject: [PATCH 15/20] chore: structure deps better, e.g., making pynput optional --- pdm.lock | 48 ++++++++++++++++++++++++------------------------ pyproject.toml | 11 ++++++----- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/pdm.lock b/pdm.lock index 6cd1425f..7ac45d7b 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "chat", "mcp", "test"] +groups = ["default", "chat", "mcp", "pynput", "test"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:dedebe5b971455dfc718635d1180587b366c6d287b404138c030c4aec77452bc" +content_hash = "sha256:21d28e90c53b9d7f3439e469fe2125b83225d66dd3f6d182bedd1ae774544485" [[metadata.targets]] requires_python = ">=3.10" @@ -33,7 +33,7 @@ name = "annotated-types" version = "0.7.0" requires_python = ">=3.8" summary = "Reusable constraint types to use with typing.Annotated" -groups = ["default", "mcp"] +groups = ["default", "chat", "mcp"] dependencies = [ "typing-extensions>=4.0.0; python_version < \"3.9\"", ] @@ -67,7 +67,7 @@ name = "anyio" version = "4.9.0" requires_python = ">=3.9" summary = "High level compatibility layer for multiple asynchronous event loop implementations" -groups = ["default", "mcp"] +groups = ["default", "chat", "mcp"] dependencies = [ "exceptiongroup>=1.0.2; python_version < \"3.11\"", "idna>=2.8", @@ -202,7 +202,7 @@ name = "click" version = "8.1.8" requires_python = ">=3.7" summary = "Composable command line interface toolkit" -groups = ["default", "chat", "mcp"] +groups = ["chat", "mcp"] dependencies = [ "colorama; platform_system == \"Windows\"", "importlib-metadata; python_version < \"3.8\"", @@ -369,7 +369,7 @@ name = "evdev" version = "1.9.2" requires_python = ">=3.8" summary = "Bindings to the Linux input handling subsystem" -groups = ["default"] +groups = ["pynput"] marker = "\"linux\" in sys_platform" files = [ {file = "evdev-1.9.2.tar.gz", hash = "sha256:5d3278892ce1f92a74d6bf888cc8525d9f68af85dbe336c95d1c87fb8f423069"}, @@ -380,7 +380,7 @@ name = "exceptiongroup" version = "1.2.2" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["default", "mcp", "test"] +groups = ["default", "chat", "mcp", "test"] marker = "python_version < \"3.11\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, @@ -403,7 +403,7 @@ name = "fastapi" version = "0.115.12" requires_python = ">=3.8" summary = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" -groups = ["default"] +groups = ["chat"] dependencies = [ "pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4", "starlette<0.47.0,>=0.40.0", @@ -608,7 +608,7 @@ name = "h11" version = "0.14.0" requires_python = ">=3.7" summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -groups = ["default", "mcp"] +groups = ["default", "chat", "mcp"] dependencies = [ "typing-extensions; python_version < \"3.8\"", ] @@ -941,7 +941,7 @@ name = "mss" version = "10.0.0" requires_python = ">=3.9" summary = "An ultra fast cross-platform multiple screenshots module in pure python using ctypes." -groups = ["default"] +groups = ["pynput"] files = [ {file = "mss-10.0.0-py3-none-any.whl", hash = "sha256:82cf6460a53d09e79b7b6d871163c982e6c7e9649c426e7b7591b74956d5cb64"}, {file = "mss-10.0.0.tar.gz", hash = "sha256:d903e0d51262bf0f8782841cf16eaa6d7e3e1f12eae35ab41c2e318837c6637f"}, @@ -1315,7 +1315,7 @@ name = "pydantic" version = "2.11.2" requires_python = ">=3.9" summary = "Data validation using Python type hints" -groups = ["default", "mcp"] +groups = ["default", "chat", "mcp"] dependencies = [ "annotated-types>=0.6.0", "pydantic-core==2.33.1", @@ -1332,7 +1332,7 @@ name = "pydantic-core" version = "2.33.1" requires_python = ">=3.9" summary = "Core functionality for Pydantic validation and serialization" -groups = ["default", "mcp"] +groups = ["default", "chat", "mcp"] dependencies = [ "typing-extensions!=4.7.0,>=4.6.0", ] @@ -1473,7 +1473,7 @@ files = [ name = "pynput" version = "1.8.1" summary = "Monitor and control user input devices" -groups = ["default"] +groups = ["pynput"] dependencies = [ "enum34; python_version == \"2.7\"", "evdev>=1.3; \"linux\" in sys_platform", @@ -1492,7 +1492,7 @@ name = "pyobjc-core" version = "11.0" requires_python = ">=3.8" summary = "Python<->ObjC Interoperability Module" -groups = ["default"] +groups = ["pynput"] marker = "sys_platform == \"darwin\"" files = [ {file = "pyobjc_core-11.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:10866b3a734d47caf48e456eea0d4815c2c9b21856157db5917b61dee06893a1"}, @@ -1508,7 +1508,7 @@ name = "pyobjc-framework-applicationservices" version = "11.0" requires_python = ">=3.9" summary = "Wrappers for the framework ApplicationServices on macOS" -groups = ["default"] +groups = ["pynput"] marker = "sys_platform == \"darwin\"" dependencies = [ "pyobjc-core>=11.0", @@ -1530,7 +1530,7 @@ name = "pyobjc-framework-cocoa" version = "11.0" requires_python = ">=3.9" summary = "Wrappers for the Cocoa frameworks on macOS" -groups = ["default"] +groups = ["pynput"] marker = "sys_platform == \"darwin\"" dependencies = [ "pyobjc-core>=11.0", @@ -1549,7 +1549,7 @@ name = "pyobjc-framework-coretext" version = "11.0" requires_python = ">=3.9" summary = "Wrappers for the framework CoreText on macOS" -groups = ["default"] +groups = ["pynput"] marker = "sys_platform == \"darwin\"" dependencies = [ "pyobjc-core>=11.0", @@ -1570,7 +1570,7 @@ name = "pyobjc-framework-quartz" version = "11.0" requires_python = ">=3.9" summary = "Wrappers for the Quartz frameworks on macOS" -groups = ["default"] +groups = ["pynput"] marker = "sys_platform == \"darwin\"" dependencies = [ "pyobjc-core>=11.0", @@ -1711,7 +1711,7 @@ files = [ name = "python-xlib" version = "0.33" summary = "Python X Library" -groups = ["default"] +groups = ["pynput"] marker = "\"linux\" in sys_platform" dependencies = [ "six>=1.10.0", @@ -1995,7 +1995,7 @@ name = "six" version = "1.17.0" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" summary = "Python 2 and 3 compatibility utilities" -groups = ["default", "chat"] +groups = ["default", "chat", "pynput"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -2017,7 +2017,7 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["default", "mcp"] +groups = ["default", "chat", "mcp"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -2043,7 +2043,7 @@ name = "starlette" version = "0.46.2" requires_python = ">=3.9" summary = "The little ASGI library that shines." -groups = ["default", "mcp"] +groups = ["chat", "mcp"] dependencies = [ "anyio<5,>=3.6.2", "typing-extensions>=3.10.0; python_version < \"3.10\"", @@ -2284,7 +2284,7 @@ name = "typing-inspection" version = "0.4.0" requires_python = ">=3.9" summary = "Runtime typing introspection tools" -groups = ["default", "mcp"] +groups = ["default", "chat", "mcp"] dependencies = [ "typing-extensions>=4.12.0", ] @@ -2320,7 +2320,7 @@ name = "uvicorn" version = "0.34.3" requires_python = ">=3.9" summary = "The lightning-fast ASGI server." -groups = ["default", "mcp"] +groups = ["chat", "mcp"] dependencies = [ "click>=7.0", "h11>=0.8", diff --git a/pyproject.toml b/pyproject.toml index 496501cc..515b6dae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,10 +22,6 @@ dependencies = [ "segment-analytics-python>=2.3.4", "py-machineid>=0.7.0", "httpx>=0.28.1", - "pynput>=1.8.1", - "mss>=10.0.0", - "fastapi>=0.115.12", - "uvicorn>=0.34.3", ] requires-python = ">=3.10" readme = "README.md" @@ -39,7 +35,6 @@ build-backend = "hatchling.build" [tool.hatch.version] path = "src/askui/__init__.py" - [tool.pdm] distribution = true @@ -65,6 +60,12 @@ api = "uvicorn src.askui.chat.api.fastapi:app --reload --port 8000" [dependency-groups] chat = [ "streamlit>=1.42.0", + "fastapi>=0.115.12", + "uvicorn>=0.34.3", +] +pynput = [ + "mss>=10.0.0", + "pynput>=1.8.1", ] mcp = [ "mcp[cli,rich,ws]>=1.8.1", From 7ed587867659f33c603a2cce4d9f40a6df14c472 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 5 Jun 2025 17:01:52 +0200 Subject: [PATCH 16/20] feat!(chat): remove unused image API --- src/askui/chat/api/images/__init__.py | 0 src/askui/chat/api/images/dependencies.py | 0 src/askui/chat/api/images/router.py | 19 ------------------- src/askui/chat/api/messages/service.py | 1 - 4 files changed, 20 deletions(-) delete mode 100644 src/askui/chat/api/images/__init__.py delete mode 100644 src/askui/chat/api/images/dependencies.py delete mode 100644 src/askui/chat/api/images/router.py diff --git a/src/askui/chat/api/images/__init__.py b/src/askui/chat/api/images/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/images/dependencies.py b/src/askui/chat/api/images/dependencies.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/askui/chat/api/images/router.py b/src/askui/chat/api/images/router.py deleted file mode 100644 index 7181489f..00000000 --- a/src/askui/chat/api/images/router.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import APIRouter, HTTPException -from fastapi.responses import FileResponse - -from askui.chat.api.dependencies import SettingsDep -from askui.chat.api.settings import Settings - -router = APIRouter(prefix="/images", tags=["images"]) - - -@router.get("/{image_path:path}") -def retrieve_image( - image_path: str, - settings: Settings = SettingsDep, -) -> FileResponse: - """Get an image by path.""" - full_path = settings.data_dir / "images" / image_path - if not full_path.exists(): - raise HTTPException(status_code=404, detail="Image not found") - return FileResponse(full_path) diff --git a/src/askui/chat/api/messages/service.py b/src/askui/chat/api/messages/service.py index 9cf84801..8af02967 100644 --- a/src/askui/chat/api/messages/service.py +++ b/src/askui/chat/api/messages/service.py @@ -46,7 +46,6 @@ def __init__(self, base_dir: Path) -> None: """ self._base_dir = base_dir self._threads_dir = base_dir / "threads" - self._images_dir = base_dir / "images" def list_( self, thread_id: str, limit: int | None = None, after: str | None = None From 4b4e01752dcd000c5105e72d458fdd223b70fcba Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 5 Jun 2025 17:14:28 +0200 Subject: [PATCH 17/20] fix(chat): fix linting, status codes, return types etc. --- pyproject.toml | 2 +- src/askui/chat/__main__.py | 2 +- src/askui/chat/api/{fastapi.py => app.py} | 3 --- src/askui/chat/api/messages/router.py | 6 +++--- src/askui/chat/api/runs/router.py | 15 ++++++++++----- src/askui/chat/api/settings.py | 2 +- src/askui/chat/api/threads/router.py | 6 +++--- 7 files changed, 19 insertions(+), 17 deletions(-) rename src/askui/chat/api/{fastapi.py => app.py} (80%) diff --git a/pyproject.toml b/pyproject.toml index 515b6dae..0be5b9c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,8 +54,8 @@ lint = "ruff check src tests" typecheck = "mypy" "typecheck:all" = "mypy src tests" chat = "streamlit run src/askui/chat/__main__.py" +"chat:api" = "uvicorn src.askui.chat.api.app:app --reload --port 8000" mcp = "mcp dev src/askui/mcp/__init__.py" -api = "uvicorn src.askui.chat.api.fastapi:app --reload --port 8000" [dependency-groups] chat = [ diff --git a/src/askui/chat/__main__.py b/src/askui/chat/__main__.py index cc3e8c37..25d47365 100644 --- a/src/askui/chat/__main__.py +++ b/src/askui/chat/__main__.py @@ -65,7 +65,7 @@ def get_image( return Image.open(io.BytesIO(response.content)) -def write_message( +def write_message( # noqa: C901 message: Message, ) -> None: # Create a container for the message and delete button diff --git a/src/askui/chat/api/fastapi.py b/src/askui/chat/api/app.py similarity index 80% rename from src/askui/chat/api/fastapi.py rename to src/askui/chat/api/app.py index 93e0d5c2..48b4eafe 100644 --- a/src/askui/chat/api/fastapi.py +++ b/src/askui/chat/api/app.py @@ -1,14 +1,12 @@ from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware -from askui.chat.api.images.router import router as images_router from askui.chat.api.messages.router import router as messages_router from askui.chat.api.runs.router import router as runs_router from askui.chat.api.threads.router import router as threads_router app = FastAPI( title="AskUI Chat API", - description="REST API for managing chat threads and messages", version="0.1.0", ) @@ -25,6 +23,5 @@ v1_router = APIRouter(prefix="/v1") v1_router.include_router(threads_router) v1_router.include_router(messages_router) -v1_router.include_router(images_router) v1_router.include_router(runs_router) app.include_router(v1_router) diff --git a/src/askui/chat/api/messages/router.py b/src/askui/chat/api/messages/router.py index 2544862b..114a1c03 100644 --- a/src/askui/chat/api/messages/router.py +++ b/src/askui/chat/api/messages/router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, status from askui.chat.api.messages.dependencies import MessageServiceDep from askui.chat.api.messages.service import Message, MessageListResponse, MessageService @@ -20,7 +20,7 @@ def list_messages( raise HTTPException(status_code=404, detail=str(e)) from e -@router.post("") +@router.post("", status_code=status.HTTP_201_CREATED) async def create_message( thread_id: str, message: MessageParam, @@ -49,7 +49,7 @@ def retrieve_message( raise HTTPException(status_code=404, detail=str(e)) from e -@router.delete("/{message_id}") +@router.delete("/{message_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_message( thread_id: str, message_id: str, diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py index 120b40b8..c3ee94b7 100644 --- a/src/askui/chat/api/runs/router.py +++ b/src/askui/chat/api/runs/router.py @@ -1,8 +1,8 @@ from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Annotated, cast -from fastapi import APIRouter, Body, HTTPException, Path -from fastapi.responses import StreamingResponse +from fastapi import APIRouter, Body, HTTPException, Path, Response, status +from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel if TYPE_CHECKING: @@ -24,7 +24,7 @@ def create_run( thread_id: Annotated[str, Path(...)], request: Annotated[CreateRunRequest, Body(...)], run_service: RunService = RunServiceDep, -) -> Run | StreamingResponse: +) -> Response: """ Create a new run for a given thread. """ @@ -39,8 +39,13 @@ async def sse_event_stream() -> AsyncGenerator[str, None]: async for event in async_generator: yield f"event: {event.event}\ndata: {event.model_dump_json()}\n\n" - return StreamingResponse(sse_event_stream(), media_type="text/event-stream") - return cast("Run", run_or_async_generator) + return StreamingResponse( + status_code=status.HTTP_201_CREATED, + content=sse_event_stream(), + media_type="text/event-stream", + ) + run = cast("Run", run_or_async_generator) + return JSONResponse(status_code=status.HTTP_201_CREATED, content=run.model_dump()) @router.get("/{run_id}") diff --git a/src/askui/chat/api/settings.py b/src/askui/chat/api/settings.py index 52e3b50c..c091a6cf 100644 --- a/src/askui/chat/api/settings.py +++ b/src/askui/chat/api/settings.py @@ -12,6 +12,6 @@ class Settings(BaseSettings): ) data_dir: Path = Field( - default_factory=lambda: Path.home() / ".askui" / "chat", + default_factory=lambda: Path.cwd() / "chat", description="Base directory for storing chat data", ) diff --git a/src/askui/chat/api/threads/router.py b/src/askui/chat/api/threads/router.py index 00e4d16a..f899863e 100644 --- a/src/askui/chat/api/threads/router.py +++ b/src/askui/chat/api/threads/router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, status from askui.chat.api.threads.dependencies import ThreadServiceDep from askui.chat.api.threads.service import Thread, ThreadListResponse, ThreadService @@ -15,7 +15,7 @@ def list_threads( return thread_service.list_(limit=limit) -@router.post("") +@router.post("", status_code=status.HTTP_201_CREATED) def create_thread( thread_service: ThreadService = ThreadServiceDep, ) -> Thread: @@ -35,7 +35,7 @@ def retrieve_thread( raise HTTPException(status_code=404, detail=str(e)) from e -@router.delete("/{thread_id}") +@router.delete("/{thread_id}", status_code=status.HTTP_204_NO_CONTENT) def delete_thread( thread_id: str, thread_service: ThreadService = ThreadServiceDep, From 09993c08d21144749fd00bba3bcdab903197cb87 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 5 Jun 2025 17:15:00 +0200 Subject: [PATCH 18/20] fix(models): add logging for tool result messages removed by accident --- src/askui/models/shared/computer_agent.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/askui/models/shared/computer_agent.py b/src/askui/models/shared/computer_agent.py index ebcc6d2f..0c4a74be 100644 --- a/src/askui/models/shared/computer_agent.py +++ b/src/askui/models/shared/computer_agent.py @@ -260,6 +260,8 @@ def _step( if tool_result_message := self._call_on_message( on_message, tool_result_message, messages ): + tool_result_message_dict = tool_result_message.model_dump(mode="json") + logger.debug(tool_result_message_dict) messages.append(tool_result_message) self._step( messages=messages, From 87893c2390e2f5d62f81937f16836f0d75564c33 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 5 Jun 2025 17:15:54 +0200 Subject: [PATCH 19/20] docs(models,tools): improve docstrings for new features --- src/askui/models/models.py | 37 ++++++++++++++++++++++++++----------- src/askui/tools/agent_os.py | 12 +++++++++++- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 71099796..1c612d8c 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -149,10 +149,7 @@ def __getitem__(self, index: int) -> ModelDefinition: class ActModel(abc.ABC): """Abstract base class for models that can execute autonomous actions. - Models implementing this interface can be used with the `act()` method of - `VisionAgent` - to achieve goals through autonomous actions. These models analyze the screen and - determine necessary steps to accomplish a given goal. + Models implementing this interface can be used with the `VisionAgent.act()`. Example: ```python @@ -172,7 +169,7 @@ def act( model_choice: str, on_message: OnMessageCb | None = None, ) -> None: - print(messages) # implement custom logic here + pass # implement action logic here with VisionAgent(models={"my-act": MyActModel()}) as agent: agent.act("search for flights", model="my-act") @@ -187,14 +184,32 @@ def act( ) -> None: """ Execute autonomous actions to achieve a goal, using a message history - and optional callbacks. + and optional callbacks, encoded in the messages. In the simplest case, + it can be found in the first message `messages[0].content` as a `str`. + + The `messages` usually start with a `"user"` (role) message which is followed by + alternating `"assistant"` (AI agent) and `"user"` messages (which can be + automatic tool use, e.g., taking a screenshot) similar how you would + expect it from a conversation whereby the `"assistant"` determines the next + actions which are then automatically taking by the `"user"` programmatically + until it eventually returns, usually with an `"assistant"` message that either + says that the goal has been achieved or that it failed to achieve the goal. Args: - messages (list[MessageParam]): The message history to start from. - model_choice (str): The name of the model being used (useful for models - that support multiple configurations) - on_message (OnMessageCb | None, optional): Callback for new messages. - If it returns `None`, stops and does not add the message. + messages (list[MessageParam]): The message history to start that + determines the actions and following messages. + model_choice (str): The name of the model being used, e.g., useful for + models registered under multiple keys, e.g., `"my-act-1"` and + `"my-act-2"` that depending on the key (passed as `model_choice`) + behave differently. + on_message (OnMessageCb | None, optional): Callback for new messages + from either an assistant/agent or a user (including + automatic/programmatic tool use, e.g., taking a screenshot). + If it returns `None`, the acting is canceled and `act()` returns + immediately. If it returns a `MessageParam`, this `MessageParma` is + added to the message history and the acting continues based on the + message. The message may be modified by the callback to allow for + directing the assistant/agent or tool use. Returns: None diff --git a/src/askui/tools/agent_os.py b/src/askui/tools/agent_os.py index dd7ad3df..24b62ee0 100644 --- a/src/askui/tools/agent_os.py +++ b/src/askui/tools/agent_os.py @@ -458,15 +458,25 @@ def run_command(self, command: str, timeout_ms: int = 30000) -> None: def start_listening(self) -> None: """ Start listening for mouse and keyboard events. + + IMPORTANT: This method is still experimental and may not work at all and may + change in the future. """ raise NotImplementedError def poll_event(self) -> InputEvent | None: """ Poll for a single input event. + + IMPORTANT: This method is still experimental and may not work at all and may + change in the future. """ raise NotImplementedError def stop_listening(self) -> None: - """Stop listening for mouse and keyboard events.""" + """Stop listening for mouse and keyboard events. + + IMPORTANT: This method is still experimental and may not work at all and may + change in the future. + """ raise NotImplementedError From c6fbeeee89b8c90c70facf89fefcf8702a067582 Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Thu, 5 Jun 2025 17:16:21 +0200 Subject: [PATCH 20/20] feat(agent): fix logging and reporting of act() --- src/askui/agent.py | 9 +++++++-- tests/unit/models/test_computer_agent_filter.py | 3 ++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/askui/agent.py b/src/askui/agent.py index 0e1023a9..07e0fb49 100644 --- a/src/askui/agent.py +++ b/src/askui/agent.py @@ -565,9 +565,14 @@ def act( agent.act("Log in with username 'admin' and password '1234'") ``` """ - self._reporter.add_message("User", f'act: "{goal}"') + goal_str = ( + goal + if isinstance(goal, str) + else "\n".join(msg.model_dump_json() for msg in goal) + ) + self._reporter.add_message("User", f'act: "{goal_str}"') logger.debug( - "VisionAgent received instruction to act towards the goal '%s'", goal + "VisionAgent received instruction to act towards the goal '%s'", goal_str ) messages: list[MessageParam] = ( [MessageParam(role="user", content=goal)] if isinstance(goal, str) else goal diff --git a/tests/unit/models/test_computer_agent_filter.py b/tests/unit/models/test_computer_agent_filter.py index 4199e8b4..a57fc59e 100644 --- a/tests/unit/models/test_computer_agent_filter.py +++ b/tests/unit/models/test_computer_agent_filter.py @@ -72,9 +72,10 @@ def test_exactly_images_to_keep() -> None: ): assert len(first_block.content) == 4 else: - raise AssertionError( + error_msg = ( "filtered[0].content[0] is not a ToolResultBlockParam with list content" ) + raise AssertionError(error_msg) # noqa: TRY004 all_tool_result_contents = [ c for m in filtered