diff --git a/docs/mcp.md b/docs/mcp.md new file mode 100644 index 00000000..b15823ba --- /dev/null +++ b/docs/mcp.md @@ -0,0 +1,167 @@ +# MCP + +## What is MCP? + +The Model Context Protocol (MCP) is a standardized way to provide context and tools to Large Language Models (LLMs). It acts as a universal interface - often described as "the USB-C port for AI" - that allows LLMs to connect to external resources and functionality in a secure, standardized manner. + +MCP servers can: +- **Expose data** through `Resources` (similar to GET endpoints for loading information into the LLM's context) +- **Provide functionality** through `Tools` (similar to POST endpoints for executing code or producing side effects) +- **Define interaction patterns** through `Prompts` (reusable templates for LLM interactions) + +For more information, see the [MCP specification](https://modelcontextprotocol.io/docs/getting-started/intro). + +## Our MCP Support + +We currently support the use of tools from MCP servers both through the library `AgentBase.act()` (and extending classes, e.g., `VisionAgent.act()`, `AndroidVisionAgent.act()`) and through the Chat API. + +The use of tools is comprised of listing the tools available, passing them on to the model and calling them if the model requests them to be called. + +The Chat API builds on top of the library which in turn uses [`fastmcp`](https://gofastmcp.com/getting-started/welcome). + +## How to Use MCP with AskUI + +### With the Library + +You can integrate MCP tools directly into your AskUI agents by creating an MCP client and passing it to the `ToolCollection`: + +```python +from fastmcp import Client +from fastmcp.mcp_config import MCPConfig, RemoteMCPServer + +from askui.agent import VisionAgent +from askui.models.shared.agent_message_param import MessageParam +from askui.models.shared.agent_on_message_cb import OnMessageCbParam +from askui.models.shared.tools import ToolCollection +from askui.tools.mcp.config import StdioMCPServer + +# Create MCP configuration +mcp_config = MCPConfig( + mcpServers={ + # Make sure to use our patch of StdioMCPServer as we don't support the official one + "test_stdio_server": StdioMCPServer( + command="python", args=["-m", "askui.tools.mcp.servers.stdio"] + ), + "test_sse_server": RemoteMCPServer(url="http://127.0.0.1:8001/sse/"), + } +) + +# Create MCP client +mcp_client = Client(mcp_config) + +# Create tool collection with MCP tools +tools = ToolCollection(mcp_client=mcp_client) + + +def on_message(param: OnMessageCbParam) -> MessageParam | None: + print(param.message.model_dump_json()) + return param.message + + +# Use with VisionAgent +with VisionAgent() as agent: + agent.act( + "Use the `test_stdio_server_test_stdio_tool`", + tools=tools, + on_message=on_message, + ) +``` + +Tools are appended to the default tools of the agent, potentially, overriding them. + +For different ways to construct `Client`s see the [fastmcp documentation](https://gofastmcp.com/clients/client). + +Notice that the tool name (`test_stdio_tool`) is prefixed with the server name (`test_stdio_server`) to avoid conflicts. +This differs between Chat API and library. In the Chat API, the tool name is prefixed with the id of the MCP config. +More about that later. + + +If you would like to try out the `test_sse_server` you can run the following command before executing the code above: + +```bash +python -m askui.tools.mcp.servers.sse +``` + +**Caveats and limmitations** + +- **No Tool Selection/Filtering**: All MCP tools from connected servers are automatically available +- **Synchronous Code Requirement**: MCP tools must be run from within synchronous code contexts (no `async` or `await` allowed) +- **Limited Tool Response Content Types**: Only text and images (JPEG, PNG, GIF, WebP) are supported +- **Complexity Limits**: Tools are limited in number and complexity by the model's context window + +### With Chat + +To use MCP servers or, more specifically, the tools they provide with the Chat (API), you need to create MCP configs. All agents are going to have access to the servers specified within the configs if it possible to connect to them. + +#### Creating MCP Configs + +An MCP configuration can be created with either stdio or remote server similar with what is used when constructing the `MCPConfig` in the library example above + +```bash +curl -X 'POST' \ + 'http://localhost:9261/v1/mcp-configs' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "name": "test_stdio_server", + "mcp_server": { + "command": "python", + "args": [ + "-m", "askui.tools.mcp.servers.stdio" + ] + } +}' +curl -X 'POST' \ + 'http://localhost:9261/v1/mcp-configs' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "name": "test_sse_server", + "mcp_server": { + "url": "http://127.0.0.1:8001/sse/" + } +}' +``` + +Notice that each server needs to be created separately and is managed as a separate entity. For other endpoints to manage MCP configs, start the chat api using `python -m askui.chat.api` and see the [Chat API documentation](http://localhost:9261/docs#/mcp-configs). + +#### Caveats of Using MCP with Chat + +When using MCP through the Chat API, consider these limitations: + +- **Configuration Limit**: Maximum of 100 MCP configurations allowed +- **Universal Availability**: All servers are currently passed to all available agents +- **No Filtering**: No way to filter servers or tools for specific use cases +- **Tool Name Prefixing**: Tool names are automatically prefixed with the MCP config ID +- **Server Availability**: When a server is not available (Chat API cannot connect), it is silently ignored + +## How to Define Your Own MCP Server + +### Different Frameworks + +You can build MCP servers using various frameworks and languages: + +#### FastMCP (Python) - Recommended + +FastMCP provides the most Pythonic and straightforward way to build MCP servers: + +```python +from fastmcp import FastMCP + +mcp = FastMCP("My Server") + +@mcp.tool +def my_tool(param: str) -> str: + """My custom tool description.""" + return f"Processed: {param}" + +if __name__ == "__main__": + mcp.run(transport="stdio") # For AskUI integration + # or + mcp.run(transport="sse", port=8001) # For remote access +``` + +#### Official MCP SDKs + +The official MCP specification provides SDKs for multiple languages: +[https://modelcontextprotocol.io/docs/sdk](https://modelcontextprotocol.io/docs/sdk) diff --git a/pdm.lock b/pdm.lock index 1b361717..b7b75bde 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "all", "android", "chat", "dev", "mcp", "pynput", "test", "web"] +groups = ["default", "all", "android", "chat", "dev", "pynput", "test", "web"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:beb091cad08638d0d09be80ec10830745be0024dbe05a33bbd111a865950bba4" +content_hash = "sha256:0ca715599b1575a162ce51176bd474c59544297145f8d4ecd862bd7c14223775" [[metadata.targets]] requires_python = ">=3.10" @@ -15,7 +15,7 @@ name = "annotated-types" version = "0.7.0" requires_python = ">=3.8" summary = "Reusable constraint types to use with typing.Annotated" -groups = ["default", "all", "dev", "mcp"] +groups = ["default", "dev"] dependencies = [ "typing-extensions>=4.0.0; python_version < \"3.9\"", ] @@ -46,10 +46,10 @@ files = [ [[package]] name = "anyio" -version = "4.9.0" +version = "4.10.0" requires_python = ">=3.9" -summary = "High level compatibility layer for multiple asynchronous event loop implementations" -groups = ["default", "all", "mcp"] +summary = "High-level concurrency and networking framework on top of asyncio or Trio" +groups = ["default", "all", "chat"] dependencies = [ "exceptiongroup>=1.0.2; python_version < \"3.11\"", "idna>=2.8", @@ -57,8 +57,8 @@ dependencies = [ "typing-extensions>=4.5; python_version < \"3.13\"", ] files = [ - {file = "anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c"}, - {file = "anyio-4.9.0.tar.gz", hash = "sha256:673c0c244e15788651a4ff38710fea9675823028a6f08a5eda409e0c9840a028"}, + {file = "anyio-4.10.0-py3-none-any.whl", hash = "sha256:60e474ac86736bbfd6f210f7a61218939c318f43f9972497381f1c5e930ed3d1"}, + {file = "anyio-4.10.0.tar.gz", hash = "sha256:3f3fae35c96039744587aa5b8371e7e8e603c0702999535961dd336026973ba6"}, ] [[package]] @@ -72,12 +72,27 @@ files = [ {file = "argcomplete-3.6.2.tar.gz", hash = "sha256:d0519b1bc867f5f4f4713c41ad0aba73a4a5f007449716b16f385f2166dc6adf"}, ] +[[package]] +name = "asyncer" +version = "0.0.8" +requires_python = ">=3.8" +summary = "Asyncer, async and await, focused on developer experience." +groups = ["default"] +dependencies = [ + "anyio<5.0,>=3.4.0", + "typing-extensions>=4.8.0; python_version < \"3.10\"", +] +files = [ + {file = "asyncer-0.0.8-py3-none-any.whl", hash = "sha256:5920d48fc99c8f8f0f1576e1882f5022885589c5fcbc46ce4224ec3e53776eeb"}, + {file = "asyncer-0.0.8.tar.gz", hash = "sha256:a589d980f57e20efb07ed91d0dbe67f1d2fd343e7142c66d3a099f05c620739c"}, +] + [[package]] name = "attrs" version = "25.3.0" requires_python = ">=3.8" summary = "Classes Without Boilerplate" -groups = ["all", "mcp"] +groups = ["default"] files = [ {file = "attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3"}, {file = "attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b"}, @@ -88,7 +103,7 @@ name = "authlib" version = "1.6.1" requires_python = ">=3.9" summary = "The ultimate Python library in building OAuth and OpenID Connect servers and clients." -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "cryptography", ] @@ -108,6 +123,18 @@ files = [ {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +requires_python = "<3.11,>=3.8" +summary = "Backport of asyncio.Runner, a context manager that controls event loop life cycle." +groups = ["test"] +marker = "python_version < \"3.11\"" +files = [ + {file = "backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5"}, + {file = "backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162"}, +] + [[package]] name = "beautifulsoup4" version = "4.13.4" @@ -175,7 +202,7 @@ name = "certifi" version = "2025.1.31" requires_python = ">=3.6" summary = "Python package for providing Mozilla's CA Bundle." -groups = ["default", "all", "mcp"] +groups = ["default"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -186,7 +213,7 @@ name = "cffi" version = "1.17.1" requires_python = ">=3.8" summary = "Foreign Function Interface for Python calling C code." -groups = ["all", "mcp"] +groups = ["default"] marker = "platform_python_implementation != \"PyPy\"" dependencies = [ "pycparser", @@ -309,7 +336,7 @@ name = "click" version = "8.1.8" requires_python = ">=3.7" summary = "Composable command line interface toolkit" -groups = ["default", "all", "chat", "dev", "mcp"] +groups = ["default", "all", "chat", "dev"] dependencies = [ "colorama; platform_system == \"Windows\"", "importlib-metadata; python_version < \"3.8\"", @@ -335,7 +362,7 @@ name = "colorama" version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." -groups = ["default", "all", "chat", "dev", "mcp", "test"] +groups = ["default", "all", "chat", "dev", "test"] marker = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, @@ -491,7 +518,7 @@ name = "cryptography" version = "45.0.5" requires_python = "!=3.9.0,!=3.9.1,>=3.7" summary = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "cffi>=1.14; platform_python_implementation != \"PyPy\"", ] @@ -540,7 +567,7 @@ name = "cyclopts" version = "3.22.2" requires_python = ">=3.9" summary = "Intuitive, easy CLIs based on type hints." -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "attrs>=23.1.0", "docstring-parser>=0.15; python_version < \"4.0\"", @@ -604,7 +631,7 @@ name = "dnspython" version = "2.7.0" requires_python = ">=3.9" summary = "DNS toolkit" -groups = ["all", "mcp"] +groups = ["default"] files = [ {file = "dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86"}, {file = "dnspython-2.7.0.tar.gz", hash = "sha256:ce9c432eda0dc91cf618a5cedf1a4e142651196bbcd2c80e89ed5a907e5cfaf1"}, @@ -615,7 +642,7 @@ name = "docstring-parser" version = "0.17.0" requires_python = ">=3.8" summary = "Parse Python docstrings in reST, Google and Numpydoc format" -groups = ["all", "mcp"] +groups = ["default"] marker = "python_version < \"4.0\"" files = [ {file = "docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708"}, @@ -627,7 +654,7 @@ name = "docutils" version = "0.21.2" requires_python = ">=3.9" summary = "Docutils -- Python Documentation Utilities" -groups = ["all", "mcp"] +groups = ["default"] files = [ {file = "docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2"}, {file = "docutils-0.21.2.tar.gz", hash = "sha256:3a6b18732edf182daa3cd12775bbb338cf5691468f91eeeb109deff6ebfa986f"}, @@ -638,7 +665,7 @@ name = "email-validator" version = "2.2.0" requires_python = ">=3.8" summary = "A robust email address syntax and deliverability validation library." -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "dnspython>=2.0.0", "idna>=2.0.0", @@ -675,7 +702,7 @@ name = "exceptiongroup" version = "1.2.2" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["default", "all", "mcp", "test"] +groups = ["default", "all", "chat", "test"] files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -713,7 +740,7 @@ name = "fastmcp" version = "2.10.6" requires_python = ">=3.10" summary = "The fast, Pythonic way to build MCP servers and clients." -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "authlib>=1.5.2", "cyclopts>=3.0.0", @@ -1016,7 +1043,7 @@ name = "h11" version = "0.14.0" requires_python = ">=3.7" summary = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" -groups = ["default", "all", "chat", "mcp"] +groups = ["default", "all", "chat"] dependencies = [ "typing-extensions; python_version < \"3.8\"", ] @@ -1030,7 +1057,7 @@ name = "httpcore" version = "1.0.7" requires_python = ">=3.8" summary = "A minimal low-level HTTP client." -groups = ["default", "all", "mcp"] +groups = ["default"] dependencies = [ "certifi", "h11<0.15,>=0.13", @@ -1045,7 +1072,7 @@ name = "httpx" version = "0.28.1" requires_python = ">=3.8" summary = "The next generation HTTP client." -groups = ["default", "all", "mcp"] +groups = ["default"] dependencies = [ "anyio", "certifi", @@ -1062,7 +1089,7 @@ name = "httpx-sse" version = "0.4.1" requires_python = ">=3.9" summary = "Consume Server-Sent Event (SSE) messages with HTTPX." -groups = ["all", "mcp"] +groups = ["default"] files = [ {file = "httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37"}, {file = "httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e"}, @@ -1110,7 +1137,7 @@ name = "idna" version = "3.10" requires_python = ">=3.6" summary = "Internationalized Domain Names in Applications (IDNA)" -groups = ["default", "all", "mcp"] +groups = ["default", "all", "chat"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -1245,7 +1272,7 @@ name = "jsonschema" version = "4.25.0" requires_python = ">=3.9" summary = "An implementation of JSON Schema validation for Python" -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "attrs>=22.2.0", "jsonschema-specifications>=2023.03.6", @@ -1262,7 +1289,7 @@ name = "jsonschema-specifications" version = "2025.4.1" requires_python = ">=3.9" summary = "The JSON Schema meta-schemas and vocabularies, exposed as a Registry" -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "referencing>=0.31.0", ] @@ -1389,7 +1416,7 @@ name = "markdown-it-py" version = "3.0.0" requires_python = ">=3.8" summary = "Python port of markdown-it. Markdown parsing, done right!" -groups = ["default", "all", "mcp"] +groups = ["default"] dependencies = [ "mdurl~=0.1", ] @@ -1517,7 +1544,7 @@ name = "mcp" version = "1.12.1" requires_python = ">=3.10" summary = "Model Context Protocol SDK" -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "anyio>=4.5", "httpx-sse>=0.4", @@ -1541,7 +1568,7 @@ name = "mdurl" version = "0.1.2" requires_python = ">=3.7" summary = "Markdown URL utilities" -groups = ["default", "all", "mcp"] +groups = ["default"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, @@ -1757,7 +1784,7 @@ name = "openapi-pydantic" version = "0.5.1" requires_python = "<4.0,>=3.8" summary = "Pydantic OpenAPI schema implementation" -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "pydantic>=1.8", ] @@ -2034,7 +2061,7 @@ name = "pycparser" version = "2.22" requires_python = ">=3.8" summary = "C parser in Python" -groups = ["all", "mcp"] +groups = ["default"] marker = "platform_python_implementation != \"PyPy\"" files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, @@ -2046,7 +2073,7 @@ name = "pydantic" version = "2.11.7" requires_python = ">=3.9" summary = "Data validation using Python type hints" -groups = ["default", "all", "dev", "mcp"] +groups = ["default", "dev"] dependencies = [ "annotated-types>=0.6.0", "pydantic-core==2.33.2", @@ -2063,7 +2090,7 @@ name = "pydantic-core" version = "2.33.2" requires_python = ">=3.9" summary = "Core functionality for Pydantic validation and serialization" -groups = ["default", "all", "dev", "mcp"] +groups = ["default", "dev"] dependencies = [ "typing-extensions!=4.7.0,>=4.6.0", ] @@ -2152,7 +2179,7 @@ name = "pydantic-settings" version = "2.9.1" requires_python = ">=3.9" summary = "Settings management using Pydantic" -groups = ["default", "all", "mcp"] +groups = ["default"] dependencies = [ "pydantic>=2.7.0", "python-dotenv>=0.21.0", @@ -2169,7 +2196,7 @@ version = "2.11.7" extras = ["email"] requires_python = ">=3.9" summary = "Data validation using Python type hints" -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "email-validator>=2.0.0", "pydantic==2.11.7", @@ -2198,7 +2225,7 @@ name = "pygments" version = "2.19.1" requires_python = ">=3.8" summary = "Pygments is a syntax highlighting package written in Python." -groups = ["default", "all", "mcp"] +groups = ["default"] files = [ {file = "pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c"}, {file = "pygments-2.19.1.tar.gz", hash = "sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f"}, @@ -2335,7 +2362,7 @@ files = [ name = "pyperclip" version = "1.9.0" summary = "A cross-platform clipboard module for Python. (Only handles plain text for now.)" -groups = ["default", "all", "mcp"] +groups = ["default"] files = [ {file = "pyperclip-1.9.0.tar.gz", hash = "sha256:b7de0142ddc81bfc5c7507eea19da920b92252b548b96186caf94a5e2527d310"}, ] @@ -2371,6 +2398,22 @@ files = [ {file = "pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845"}, ] +[[package]] +name = "pytest-asyncio" +version = "1.1.0" +requires_python = ">=3.9" +summary = "Pytest support for asyncio" +groups = ["test"] +dependencies = [ + "backports-asyncio-runner<2,>=1.1; python_version < \"3.11\"", + "pytest<9,>=8.2", + "typing-extensions>=4.12; python_version < \"3.10\"", +] +files = [ + {file = "pytest_asyncio-1.1.0-py3-none-any.whl", hash = "sha256:5fe2d69607b0bd75c656d1211f969cadba035030156745ee09e7d71740e58ecf"}, + {file = "pytest_asyncio-1.1.0.tar.gz", hash = "sha256:796aa822981e01b68c12e4827b8697108f7205020f24b5793b3c41555dab68ea"}, +] + [[package]] name = "pytest-cov" version = "6.1.1" @@ -2448,7 +2491,7 @@ name = "python-dotenv" version = "1.1.0" requires_python = ">=3.9" summary = "Read key-value pairs from a .env file and set them as environment variables" -groups = ["default", "all", "mcp"] +groups = ["default"] files = [ {file = "python_dotenv-1.1.0-py3-none-any.whl", hash = "sha256:d7c01d9e2293916c18baf562d95698754b0dbbb5e74d457c45d4f6561fb9d55d"}, {file = "python_dotenv-1.1.0.tar.gz", hash = "sha256:41f90bc6f5f177fb41f53e87666db362025010eb28f60a01c9143bfa33a2b2d5"}, @@ -2459,7 +2502,7 @@ name = "python-multipart" version = "0.0.20" requires_python = ">=3.8" summary = "A streaming multipart parser for Python" -groups = ["all", "mcp"] +groups = ["default"] files = [ {file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"}, {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, @@ -2493,7 +2536,7 @@ files = [ name = "pywin32" version = "311" summary = "Python for Window Extensions" -groups = ["all", "mcp"] +groups = ["default"] marker = "sys_platform == \"win32\"" files = [ {file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"}, @@ -2564,7 +2607,7 @@ name = "referencing" version = "0.36.2" requires_python = ">=3.9" summary = "JSON Referencing + Python" -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "attrs>=22.2.0", "rpds-py>=0.7.0", @@ -2597,7 +2640,7 @@ name = "rich" version = "14.0.0" requires_python = ">=3.8.0" summary = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -groups = ["default", "all", "mcp"] +groups = ["default"] dependencies = [ "markdown-it-py>=2.2.0", "pygments<3.0.0,>=2.13.0", @@ -2613,7 +2656,7 @@ name = "rich-rst" version = "1.3.1" requires_python = ">=3.6" summary = "A beautiful reStructuredText renderer for rich" -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "docutils", "rich>=12.0.0", @@ -2628,7 +2671,7 @@ name = "rpds-py" version = "0.26.0" requires_python = ">=3.9" summary = "Python bindings to Rust's persistent data structures (rpds)" -groups = ["all", "mcp"] +groups = ["default"] files = [ {file = "rpds_py-0.26.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:4c70c70f9169692b36307a95f3d8c0a9fcd79f7b4a383aad5eaa0e9718b79b37"}, {file = "rpds_py-0.26.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:777c62479d12395bfb932944e61e915741e364c843afc3196b694db3d669fcd0"}, @@ -2836,7 +2879,7 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["default", "all", "mcp"] +groups = ["default", "all", "chat"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -2858,7 +2901,7 @@ name = "sse-starlette" version = "2.4.1" requires_python = ">=3.9" summary = "SSE plugin for Starlette" -groups = ["all", "mcp"] +groups = ["default"] dependencies = [ "anyio>=4.7.0", ] @@ -2872,7 +2915,7 @@ name = "starlette" version = "0.46.2" requires_python = ">=3.9" summary = "The little ASGI library that shines." -groups = ["default", "all", "mcp"] +groups = ["default"] dependencies = [ "anyio<5,>=3.6.2", "typing-extensions>=3.10.0; python_version < \"3.10\"", @@ -3053,7 +3096,7 @@ name = "typing-extensions" version = "4.14.1" requires_python = ">=3.9" summary = "Backported and Experimental Type Hints for Python 3.9+" -groups = ["default", "all", "chat", "dev", "mcp", "test", "web"] +groups = ["default", "all", "chat", "dev", "test", "web"] files = [ {file = "typing_extensions-4.14.1-py3-none-any.whl", hash = "sha256:d1e1e3b58374dc93031d6eda2420a48ea44a36c2b4766a4fdeb3710755731d76"}, {file = "typing_extensions-4.14.1.tar.gz", hash = "sha256:38b39f4aeeab64884ce9f74c94263ef78f3c22467c8724005483154c26648d36"}, @@ -3064,7 +3107,7 @@ name = "typing-inspection" version = "0.4.0" requires_python = ">=3.9" summary = "Runtime typing introspection tools" -groups = ["default", "all", "dev", "mcp"] +groups = ["default", "dev"] dependencies = [ "typing-extensions>=4.12.0", ] @@ -3100,7 +3143,7 @@ name = "uvicorn" version = "0.34.3" requires_python = ">=3.9" summary = "The lightning-fast ASGI server." -groups = ["all", "chat", "mcp"] +groups = ["default", "all", "chat"] dependencies = [ "click>=7.0", "h11>=0.8", diff --git a/pyproject.toml b/pyproject.toml index dd883241..943037bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ authors = [ dependencies = [ "anthropic>=0.54.0", "fastapi>=0.115.12", + "fastmcp>=2.3.0", "gradio-client>=1.4.3", "grpcio>=1.73.1", "httpx>=0.28.1", @@ -27,6 +28,7 @@ dependencies = [ "google-genai>=1.20.0", "filetype>=1.2.0", "markitdown[xls,xlsx,docx]>=0.1.2", + "asyncer==0.0.8", ] requires-python = ">=3.10" readme = "README.md" @@ -41,6 +43,7 @@ build-backend = "hatchling.build" path = "src/askui/__init__.py" + [tool.pdm] distribution = true @@ -60,7 +63,6 @@ lint = "ruff check src tests" typecheck = "mypy" "typecheck:all" = "mypy ." "chat:api" = "uvicorn askui.chat.api.app:app --reload --port 9261" -"mcp:dev" = "mcp dev src/askui/mcp/__init__.py" "qa:fix" = { composite = [ "typecheck:all", "format", @@ -86,6 +88,7 @@ test = [ "pytest-timeout>=2.4.0", "types-pynput>=1.8.1.20250318", "playwright>=1.41.0", + "pytest-asyncio>=1.1.0", ] dev = [ "datamodel-code-generator>=0.31.2", @@ -208,16 +211,14 @@ known-first-party = ["askui"] known-third-party = ["pytest", "mypy"] [project.optional-dependencies] -all = ["askui[android,chat,mcp,pynput,web]"] +all = ["askui[android,chat,pynput,web]"] android = [ "pure-python-adb>=0.3.0.dev0" ] chat = [ "askui[android,pynput,web]", "uvicorn>=0.34.3", -] -mcp = [ - "fastmcp>=2.3.4", + "anyio>=4.10.0", ] pynput = [ "mss>=10.0.0", @@ -225,4 +226,4 @@ pynput = [ ] web = [ "playwright>=1.41.0", -] \ No newline at end of file +] diff --git a/src/askui/agent_base.py b/src/askui/agent_base.py index cb2b15ca..43387ec9 100644 --- a/src/askui/agent_base.py +++ b/src/askui/agent_base.py @@ -13,7 +13,7 @@ from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import Tool +from askui.models.shared.tools import Tool, ToolCollection from askui.tools.agent_os import AgentOs from askui.tools.android.agent_os import AndroidAgentOs from askui.utils.image_utils import ImageSource @@ -51,7 +51,8 @@ def __init__( configure_logging(level=log_level) self._reporter = reporter self._agent_os = agent_os - self._tools: list[Tool] = tools or [] + + self._tools = tools or [] self._model_router = self._init_model_router( reporter=self._reporter, models=models or {}, @@ -111,13 +112,13 @@ def _init_model_choice( } @telemetry.record_call(exclude={"goal", "on_message", "settings", "tools"}) - @validate_call + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def act( self, goal: Annotated[str | list[MessageParam], Field(min_length=1)], model: str | None = None, on_message: OnMessageCb | None = None, - tools: list[Tool] | None = None, + tools: list[Tool] | ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: """ @@ -134,8 +135,8 @@ def act( be used for achieving the `goal`. on_message (OnMessageCb | None, optional): Callback for new messages. If it returns `None`, stops and does not add the message. - tools (list[Tool] | None, optional): The tools for the agent. - Defaults to a list of default tools depending on the selected model. + tools (list[Tool] | ToolCollection | None, optional): The tools for the + agent. Defaults to default tools depending on the selected model. settings (AgentSettings | None, optional): The settings for the agent. Defaults to a default settings depending on the selected model. @@ -171,7 +172,7 @@ def act( ) model_choice = model or self._model_choice["act"] _settings = settings or self._get_default_settings_for_act(model_choice) - _tools = tools or self._get_default_tools_for_act(model_choice) + _tools = self._build_tools(tools, model_choice) self._model_router.act( messages=messages, model_choice=model_choice, @@ -180,6 +181,16 @@ def act( tools=_tools, ) + def _build_tools( + self, tools: list[Tool] | ToolCollection | None, model_choice: str + ) -> ToolCollection: + default_tools = self._get_default_tools_for_act(model_choice) + if isinstance(tools, list): + return ToolCollection(tools=default_tools + tools) + if isinstance(tools, ToolCollection): + return ToolCollection(default_tools) + tools + return ToolCollection(tools=default_tools) + def _get_default_settings_for_act(self, model_choice: str) -> ActSettings: # noqa: ARG002 return ActSettings() diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index fe2781d4..8acfaf14 100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -8,6 +8,7 @@ from askui.chat.api.assistants.router import router as assistants_router from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings from askui.chat.api.health.router import router as health_router +from askui.chat.api.mcp_configs.router import router as mcp_configs_router from askui.chat.api.messages.router import router as messages_router from askui.chat.api.runs.router import router as runs_router from askui.chat.api.threads.router import router as threads_router @@ -15,7 +16,8 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 - assistant_service = get_assistant_service(settings=get_settings()) + settings = get_settings() + assistant_service = get_assistant_service(settings=settings) assistant_service.seed() yield @@ -42,5 +44,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 v1_router.include_router(threads_router) v1_router.include_router(messages_router) v1_router.include_router(runs_router) +v1_router.include_router(mcp_configs_router) v1_router.include_router(health_router) app.include_router(v1_router) diff --git a/src/askui/chat/api/mcp_configs/__init__.py b/src/askui/chat/api/mcp_configs/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/askui/chat/api/mcp_configs/__init__.py @@ -0,0 +1 @@ + diff --git a/src/askui/chat/api/mcp_configs/dependencies.py b/src/askui/chat/api/mcp_configs/dependencies.py new file mode 100644 index 00000000..3ed0a11b --- /dev/null +++ b/src/askui/chat/api/mcp_configs/dependencies.py @@ -0,0 +1,13 @@ +from fastapi import Depends + +from askui.chat.api.dependencies import SettingsDep +from askui.chat.api.mcp_configs.service import McpConfigService +from askui.chat.api.settings import Settings + + +def get_mcp_config_service(settings: Settings = SettingsDep) -> McpConfigService: + """Get McpConfigService instance.""" + return McpConfigService(settings.data_dir) + + +McpConfigServiceDep = Depends(get_mcp_config_service) diff --git a/src/askui/chat/api/mcp_configs/models.py b/src/askui/chat/api/mcp_configs/models.py new file mode 100644 index 00000000..b74cc3f7 --- /dev/null +++ b/src/askui/chat/api/mcp_configs/models.py @@ -0,0 +1,53 @@ +from typing import Literal + +from fastmcp.mcp_config import RemoteMCPServer, StdioMCPServer +from pydantic import BaseModel, ConfigDict, Field + +from askui.chat.api.models import McpConfigId +from askui.utils.datetime_utils import UnixDatetime, now +from askui.utils.id_utils import generate_time_ordered_id +from askui.utils.not_given import NOT_GIVEN, BaseModelWithNotGiven, NotGiven + +McpServer = StdioMCPServer | RemoteMCPServer + + +class McpConfigCreateParams(BaseModel): + """Parameters for creating an MCP configuration.""" + + name: str + mcp_server: McpServer + + +class McpConfigModifyParams(BaseModelWithNotGiven): + """Parameters for modifying an MCP configuration.""" + + name: str | NotGiven = NOT_GIVEN + mcp_server: McpServer | NotGiven = Field(default=NOT_GIVEN) + + +class McpConfig(BaseModel): + """An MCP configuration that can be stored and managed.""" + + id: McpConfigId = Field( + default_factory=lambda: generate_time_ordered_id("mcp_config") + ) + created_at: UnixDatetime = Field(default_factory=now) + name: str + object: Literal["mcp_config"] = "mcp_config" + mcp_server: McpServer = Field(description="The MCP server configuration") + + @classmethod + def create(cls, params: McpConfigCreateParams) -> "McpConfig": + return cls( + id=generate_time_ordered_id("mcp_config"), + created_at=now(), + **params.model_dump(), + ) + + def modify(self, params: McpConfigModifyParams) -> "McpConfig": + return McpConfig.model_validate( + { + **self.model_dump(), + **params.model_dump(), + } + ) diff --git a/src/askui/chat/api/mcp_configs/router.py b/src/askui/chat/api/mcp_configs/router.py new file mode 100644 index 00000000..62386785 --- /dev/null +++ b/src/askui/chat/api/mcp_configs/router.py @@ -0,0 +1,73 @@ +from fastapi import APIRouter, HTTPException, status + +from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep +from askui.chat.api.mcp_configs.models import ( + McpConfig, + McpConfigCreateParams, + McpConfigModifyParams, +) +from askui.chat.api.mcp_configs.service import McpConfigService +from askui.chat.api.models import ListQueryDep, McpConfigId +from askui.utils.api_utils import LimitReachedError, ListQuery, ListResponse + +router = APIRouter(prefix="/mcp-configs", tags=["mcp-configs"]) + + +@router.get("", response_model_exclude_none=True) +def list_mcp_configs( + query: ListQuery = ListQueryDep, + mcp_config_service: McpConfigService = McpConfigServiceDep, +) -> ListResponse[McpConfig]: + """List all MCP configurations.""" + return mcp_config_service.list_(query=query) + + +@router.post("", status_code=status.HTTP_201_CREATED, response_model_exclude_none=True) +def create_mcp_config( + params: McpConfigCreateParams, + mcp_config_service: McpConfigService = McpConfigServiceDep, +) -> McpConfig: + """Create a new MCP configuration.""" + try: + return mcp_config_service.create(params) + except LimitReachedError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) + ) from e + + +@router.get("/{mcp_config_id}", response_model_exclude_none=True) +def retrieve_mcp_config( + mcp_config_id: McpConfigId, + mcp_config_service: McpConfigService = McpConfigServiceDep, +) -> McpConfig: + """Get an MCP configuration by ID.""" + try: + return mcp_config_service.retrieve(mcp_config_id) + except FileNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + + +@router.patch("/{mcp_config_id}", response_model_exclude_none=True) +def modify_mcp_config( + mcp_config_id: McpConfigId, + params: McpConfigModifyParams, + mcp_config_service: McpConfigService = McpConfigServiceDep, +) -> McpConfig: + """Update an MCP configuration.""" + try: + return mcp_config_service.modify(mcp_config_id, params) + except FileNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e + + +@router.delete("/{mcp_config_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_mcp_config( + mcp_config_id: McpConfigId, + mcp_config_service: McpConfigService = McpConfigServiceDep, +) -> None: + """Delete an MCP configuration.""" + try: + mcp_config_service.delete(mcp_config_id) + except FileNotFoundError as e: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e diff --git a/src/askui/chat/api/mcp_configs/service.py b/src/askui/chat/api/mcp_configs/service.py new file mode 100644 index 00000000..abff9d55 --- /dev/null +++ b/src/askui/chat/api/mcp_configs/service.py @@ -0,0 +1,103 @@ +from pathlib import Path + +from pydantic import ValidationError + +from askui.utils.api_utils import ( + LIST_LIMIT_MAX, + ConflictError, + LimitReachedError, + ListQuery, + ListResponse, + NotFoundError, + list_resource_paths, +) + +from .models import McpConfig, McpConfigCreateParams, McpConfigId, McpConfigModifyParams + + +class McpConfigService: + """ + Service for managing McpConfig resources with filesystem persistence. + + Args: + base_dir (Path): Base directory for storing MCP configuration data. + """ + + def __init__(self, base_dir: Path) -> None: + self._base_dir = base_dir + self._mcp_configs_dir = base_dir / "mcp_configs" + self._mcp_configs_dir.mkdir(parents=True, exist_ok=True) + + def list_( + self, + query: ListQuery, + ) -> ListResponse[McpConfig]: + mcp_config_paths = list_resource_paths(self._mcp_configs_dir, query) + mcp_configs: list[McpConfig] = [] + for f in mcp_config_paths: + try: + mcp_config = McpConfig.model_validate_json(f.read_text()) + mcp_configs.append(mcp_config) + except ValidationError: # noqa: PERF203 + continue + has_more = len(mcp_configs) > query.limit + mcp_configs = mcp_configs[: query.limit] + return ListResponse( + data=mcp_configs, + first_id=mcp_configs[0].id if mcp_configs else None, + last_id=mcp_configs[-1].id if mcp_configs else None, + has_more=has_more, + ) + + def retrieve(self, mcp_config_id: McpConfigId) -> McpConfig: + mcp_config_file = self._mcp_configs_dir / f"{mcp_config_id}.json" + if not mcp_config_file.exists(): + error_msg = f"MCP configuration {mcp_config_id} not found" + raise NotFoundError(error_msg) + return McpConfig.model_validate_json(mcp_config_file.read_text()) + + def _check_limit(self) -> None: + limit = LIST_LIMIT_MAX + list_result = self.list_(ListQuery(limit=limit)) + if len(list_result.data) >= limit: + error_msg = ( + "MCP configuration limit reached. " + f"You may only have {limit} MCP configurations. " + "You can delete some MCP configurations to create new ones. " + ) + raise LimitReachedError(error_msg) + + def create(self, params: McpConfigCreateParams) -> McpConfig: + self._check_limit() + mcp_config = McpConfig.create(params) + self._save(mcp_config, new=True) + return mcp_config + + def modify( + self, mcp_config_id: McpConfigId, params: McpConfigModifyParams + ) -> McpConfig: + mcp_config = self.retrieve(mcp_config_id) + modified = mcp_config.modify(params) + self._save(modified) + return modified + + def delete(self, mcp_config_id: McpConfigId) -> None: + mcp_config_file = self._mcp_configs_dir / f"{mcp_config_id}.json" + if not mcp_config_file.exists(): + error_msg = f"MCP configuration {mcp_config_id} not found" + raise NotFoundError(error_msg) + mcp_config_file.unlink() + + def _save(self, mcp_config: McpConfig, new: bool = False) -> None: + """Save an MCP configuration to the file system.""" + self._mcp_configs_dir.mkdir(parents=True, exist_ok=True) + mcp_config_file = self._mcp_configs_dir / f"{mcp_config.id}.json" + if new and mcp_config_file.exists(): + error_msg = f"MCP configuration {mcp_config.id} already exists" + raise ConflictError(error_msg) + with mcp_config_file.open("w", encoding="utf-8") as f: + f.write( + mcp_config.model_dump_json( + exclude_unset=True, exclude_none=True, exclude_defaults=True + ) + ) diff --git a/src/askui/chat/api/models.py b/src/askui/chat/api/models.py index 091c0a64..2e222c2d 100644 --- a/src/askui/chat/api/models.py +++ b/src/askui/chat/api/models.py @@ -4,6 +4,7 @@ from askui.utils.api_utils import ListQuery AssistantId = str +McpConfigId = str FileId = str MessageId = str RunId = str diff --git a/src/askui/chat/api/runs/router.py b/src/askui/chat/api/runs/router.py index 915290f5..74db267e 100644 --- a/src/askui/chat/api/runs/router.py +++ b/src/askui/chat/api/runs/router.py @@ -1,7 +1,15 @@ from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Annotated, cast - -from fastapi import APIRouter, Body, HTTPException, Path, Response, status +from typing import Annotated + +from fastapi import ( + APIRouter, + BackgroundTasks, + Body, + HTTPException, + Path, + Response, + status, +) from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel @@ -13,29 +21,22 @@ from .models import Run from .service import RunService -if TYPE_CHECKING: - from .runner.events import Events - - router = APIRouter(prefix="/threads/{thread_id}/runs", tags=["runs"]) @router.post("") -def create_run( +async def create_run( thread_id: Annotated[ThreadId, Path(...)], request: Annotated[CreateRunRequest, Body(...)], + background_tasks: BackgroundTasks, run_service: RunService = RunServiceDep, ) -> Response: """ Create a new run for a given thread. """ stream = request.stream - run_or_async_generator = run_service.create(thread_id, stream, request) + run, async_generator = await run_service.create(thread_id, request) if stream: - async_generator = cast( - "AsyncGenerator[Events, None]", - run_or_async_generator, - ) async def sse_event_stream() -> AsyncGenerator[str, None]: async for event in async_generator: @@ -51,7 +52,12 @@ async def sse_event_stream() -> AsyncGenerator[str, None]: content=sse_event_stream(), media_type="text/event-stream", ) - run = cast("Run", run_or_async_generator) + + async def _run_async_generator() -> None: + async for _ in async_generator: + pass + + background_tasks.add_task(_run_async_generator) return JSONResponse(status_code=status.HTTP_201_CREATED, content=run.model_dump()) diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 78a25bc3..f8dd63b4 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -1,9 +1,15 @@ import logging -import queue import time from datetime import datetime, timezone from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Sequence + +import anyio +from anyio.abc import ObjectStream +from asyncer import asyncify, syncify +from fastmcp import Client +from fastmcp.client.transports import MCPConfigTransport +from fastmcp.mcp_config import MCPConfig from askui.agent import VisionAgent from askui.android_agent import AndroidVisionAgent @@ -14,6 +20,8 @@ ASKUI_WEB_TESTING_AGENT, HUMAN_DEMONSTRATION_AGENT, ) +from askui.chat.api.mcp_configs.models import McpConfig +from askui.chat.api.mcp_configs.service import McpConfigService from askui.chat.api.messages.service import MessageCreateRequest, MessageService from askui.chat.api.runs.models import Run, RunError from askui.chat.api.runs.runner.events.done_events import DoneEvent @@ -32,6 +40,7 @@ TextBlockParam, ) from askui.models.shared.agent_on_message_cb import OnMessageCbParam +from askui.models.shared.tools import ToolCollection from askui.tools.pynput_agent_os import PynputAgentOs from askui.utils.api_utils import LIST_LIMIT_MAX, ListQuery from askui.utils.image_utils import ImageSource @@ -44,6 +53,37 @@ logger = logging.getLogger(__name__) +def build_fast_mcp_config(mcp_configs: Sequence[McpConfig]) -> MCPConfig: + mcp_config_dict = { + mcp_config.id: mcp_config.mcp_server for mcp_config in mcp_configs + } + return MCPConfig(mcpServers=mcp_config_dict) + + +McpClient = Client[MCPConfigTransport] + + +def get_mcp_client( + base_dir: Path, +) -> McpClient: + """Get an MCP client from all available MCP configs. + + *Important*: This function can only handle up to 100 MCP server configs. Tool names + are prefixed with the `McpConfigId` (used as the FastMCP MCP server name) to avoid + conflicts. + + Args: + base_dir: The base directory of the MCP configs. + + Returns: + McpClient: A MCP client. + """ + mcp_config_service = McpConfigService(base_dir) + mcp_configs = mcp_config_service.list_(ListQuery(limit=LIST_LIMIT_MAX, order="asc")) + fast_mcp_config = build_fast_mcp_config(mcp_configs.data) + return Client(fast_mcp_config) + + class Runner: def __init__(self, run: Run, base_dir: Path) -> None: self._run = run @@ -52,7 +92,7 @@ def __init__(self, run: Run, base_dir: Path) -> None: self._msg_service = MessageService(self._base_dir) self._agent_os = PynputAgentOs() - def _run_human_agent(self, event_queue: queue.Queue[Events]) -> None: + async def _run_human_agent(self, send_stream: ObjectStream[Events]) -> None: message = self._msg_service.create( thread_id=self._run.thread_id, request=MessageCreateRequest( @@ -66,7 +106,7 @@ def _run_human_agent(self, event_queue: queue.Queue[Events]) -> None: run_id=self._run.id, ), ) - event_queue.put( + await send_stream.send( MessageEvent( data=message, event="thread.message.created", @@ -74,7 +114,7 @@ def _run_human_agent(self, event_queue: queue.Queue[Events]) -> None: ) self._agent_os.start_listening() screenshot = self._agent_os.screenshot() - time.sleep(0.1) + await anyio.sleep(0.1) recorded_events: list[InputEvent] = [] while True: updated_run = self._retrieve_run() @@ -113,14 +153,14 @@ def _run_human_agent(self, event_queue: queue.Queue[Events]) -> None: run_id=self._run.id, ), ) - event_queue.put( + await send_stream.send( MessageEvent( data=message, event="thread.message.created", ) ) screenshot = self._agent_os.screenshot() - time.sleep(0.1) + await anyio.sleep(0.1) self._agent_os.stop_listening() if len(recorded_events) == 0: text = "Nevermind, I didn't do anything." @@ -137,42 +177,56 @@ def _run_human_agent(self, event_queue: queue.Queue[Events]) -> None: run_id=self._run.id, ), ) - event_queue.put( + await send_stream.send( MessageEvent( data=message, event="thread.message.created", ) ) - def _run_askui_android_agent(self, event_queue: queue.Queue[Events]) -> None: - self._run_agent( + async def _run_askui_android_agent( + self, send_stream: ObjectStream[Events], mcp_client: McpClient + ) -> None: + await self._run_agent( agent_type="android", - event_queue=event_queue, + send_stream=send_stream, + mcp_client=mcp_client, ) - def _run_askui_vision_agent(self, event_queue: queue.Queue[Events]) -> None: - self._run_agent( + async def _run_askui_vision_agent( + self, send_stream: ObjectStream[Events], mcp_client: McpClient + ) -> None: + await self._run_agent( agent_type="vision", - event_queue=event_queue, + send_stream=send_stream, + mcp_client=mcp_client, ) - def _run_askui_web_agent(self, event_queue: queue.Queue[Events]) -> None: - self._run_agent( + async def _run_askui_web_agent( + self, send_stream: ObjectStream[Events], mcp_client: McpClient + ) -> None: + await self._run_agent( agent_type="web", - event_queue=event_queue, + send_stream=send_stream, + mcp_client=mcp_client, ) - def _run_askui_web_testing_agent(self, event_queue: queue.Queue[Events]) -> None: - self._run_agent( + async def _run_askui_web_testing_agent( + self, send_stream: ObjectStream[Events], mcp_client: McpClient + ) -> None: + await self._run_agent( agent_type="web_testing", - event_queue=event_queue, + send_stream=send_stream, + mcp_client=mcp_client, ) - def _run_agent( + async def _run_agent( self, agent_type: Literal["android", "vision", "web", "web_testing"], - event_queue: queue.Queue[Events], + send_stream: ObjectStream[Events], + mcp_client: McpClient, ) -> None: + tools = ToolCollection(mcp_client=mcp_client) messages: list[MessageParam] = [ MessageParam( role=msg.role, @@ -184,7 +238,7 @@ def _run_agent( ) ] - def on_message( + async def async_on_message( on_message_cb_param: OnMessageCbParam, ) -> MessageParam | None: message = self._msg_service.create( @@ -198,7 +252,7 @@ def on_message( run_id=self._run.id, ), ) - event_queue.put( + await send_stream.send( MessageEvent( data=message, event="thread.message.created", @@ -209,42 +263,52 @@ def on_message( return None return on_message_cb_param.message - if agent_type == "android": - with AndroidVisionAgent() as android_agent: - android_agent.act( - messages, - on_message=on_message, - ) - return + on_message = syncify(async_on_message) - if agent_type == "web": - with WebVisionAgent() as web_agent: - web_agent.act( - messages, - on_message=on_message, - ) - return + def _run_agent_inner() -> None: + if agent_type == "android": + with AndroidVisionAgent() as android_agent: + android_agent.act( + messages, + on_message=on_message, + tools=tools, + ) + return + + if agent_type == "web": + with WebVisionAgent() as web_agent: + web_agent.act( + messages, + on_message=on_message, + tools=tools, + ) + return - if agent_type == "web_testing": - with WebTestingAgent() as web_testing_agent: - web_testing_agent.act( + if agent_type == "web_testing": + with WebTestingAgent() as web_testing_agent: + web_testing_agent.act( + messages, + on_message=on_message, + tools=tools, + ) + return + + with VisionAgent() as agent: + agent.act( messages, on_message=on_message, + tools=tools, ) - return - with VisionAgent() as agent: - agent.act( - messages, - on_message=on_message, - ) + await asyncify(_run_agent_inner)() - def run( + async def run( self, - event_queue: queue.Queue[Events], + send_stream: ObjectStream[Events], ) -> None: + mcp_client = get_mcp_client(self._base_dir) self._mark_run_as_started() - event_queue.put( + await send_stream.send( RunEvent( data=self._run, event="thread.run.in_progress", @@ -252,27 +316,39 @@ def run( ) try: if self._run.assistant_id == HUMAN_DEMONSTRATION_AGENT.id: - self._run_human_agent(event_queue) + await self._run_human_agent(send_stream) elif self._run.assistant_id == ASKUI_VISION_AGENT.id: - self._run_askui_vision_agent(event_queue) + await self._run_askui_vision_agent( + send_stream, + mcp_client, + ) elif self._run.assistant_id == ANDROID_VISION_AGENT.id: - self._run_askui_android_agent(event_queue) + await self._run_askui_android_agent( + send_stream, + mcp_client, + ) elif self._run.assistant_id == ASKUI_WEB_AGENT.id: - self._run_askui_web_agent(event_queue) + await self._run_askui_web_agent( + send_stream, + mcp_client, + ) elif self._run.assistant_id == ASKUI_WEB_TESTING_AGENT.id: - self._run_askui_web_testing_agent(event_queue) + await self._run_askui_web_testing_agent( + send_stream, + mcp_client, + ) updated_run = self._retrieve_run() if updated_run.status == "in_progress": updated_run.completed_at = datetime.now(tz=timezone.utc) self._update_run_file(updated_run) - event_queue.put( + await send_stream.send( RunEvent( data=updated_run, event="thread.run.completed", ) ) if updated_run.status == "cancelling": - event_queue.put( + await send_stream.send( RunEvent( data=updated_run, event="thread.run.cancelling", @@ -280,33 +356,33 @@ def run( ) updated_run.cancelled_at = datetime.now(tz=timezone.utc) self._update_run_file(updated_run) - event_queue.put( + await send_stream.send( RunEvent( data=updated_run, event="thread.run.cancelled", ) ) if updated_run.status == "expired": - event_queue.put( + await send_stream.send( RunEvent( data=updated_run, event="thread.run.expired", ) ) - event_queue.put(DoneEvent()) + await send_stream.send(DoneEvent()) except Exception as e: # noqa: BLE001 logger.exception("Exception in runner") updated_run = self._retrieve_run() updated_run.failed_at = datetime.now(tz=timezone.utc) updated_run.last_error = RunError(message=str(e), code="server_error") self._update_run_file(updated_run) - event_queue.put( + await send_stream.send( RunEvent( data=updated_run, event="thread.run.failed", ) ) - event_queue.put( + await send_stream.send( ErrorEvent( data=ErrorEventData(error=ErrorEventDataError(message=str(e))) ) diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 9f67adb6..e5c4a647 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -1,11 +1,8 @@ -import asyncio -import queue -import threading from collections.abc import AsyncGenerator from datetime import datetime, timezone from pathlib import Path -from typing import Literal, overload +import anyio from pydantic import BaseModel from askui.chat.api.models import AssistantId, RunId, ThreadId @@ -42,32 +39,15 @@ def _create_run(self, thread_id: ThreadId, request: CreateRunRequest) -> Run: self._update_run_file(run) return run - @overload - def create( - self, thread_id: ThreadId, stream: Literal[False], request: CreateRunRequest - ) -> Run: ... - - @overload - def create( - self, thread_id: ThreadId, stream: Literal[True], request: CreateRunRequest - ) -> AsyncGenerator[Events, None]: ... - - @overload - def create( - self, thread_id: ThreadId, stream: bool, request: CreateRunRequest - ) -> Run | AsyncGenerator[Events, None]: ... - - def create( - self, thread_id: ThreadId, stream: bool, request: CreateRunRequest - ) -> Run | AsyncGenerator[Events, None]: + async def create( + self, thread_id: ThreadId, request: CreateRunRequest + ) -> tuple[Run, AsyncGenerator[Events, None]]: run = self._create_run(thread_id, request) - event_queue: queue.Queue[Events] = queue.Queue() + send_stream, receive_stream = anyio.create_memory_object_stream[Events]() runner = Runner(run, self._base_dir) - thread = threading.Thread(target=runner.run, args=(event_queue,), daemon=True) - thread.start() - if stream: - async def event_stream() -> AsyncGenerator[Events, None]: + async def event_generator() -> AsyncGenerator[Events, None]: + try: yield RunEvent( # run already in progress instead of queued which is # different from OpenAI @@ -80,15 +60,34 @@ async def event_stream() -> AsyncGenerator[Events, None]: data=run, event="thread.run.queued", ) - loop = asyncio.get_event_loop() - while True: - event = await loop.run_in_executor(None, event_queue.get) - yield event - if isinstance(event, DoneEvent) or isinstance(event, ErrorEvent): - break - - return event_stream() - return run + + # Start the runner in a background task + async def run_runner() -> None: + try: + await runner.run(send_stream) # type: ignore[arg-type] + finally: + await send_stream.aclose() + + # Create a task group to manage the runner and event processing + async with anyio.create_task_group() as tg: + # Start the runner in the background + tg.start_soon(run_runner) + + # Process events from the stream + while True: + try: + event = await receive_stream.receive() + yield event + if isinstance(event, DoneEvent) or isinstance( + event, ErrorEvent + ): + break + except anyio.EndOfStream: + break + finally: + await send_stream.aclose() + + return run, event_generator() def _update_run_file(self, run: Run) -> None: run_file = self._run_path(run.thread_id, run.id) diff --git a/src/askui/mcp/__init__.py b/src/askui/mcp/__init__.py deleted file mode 100644 index 0019dca5..00000000 --- a/src/askui/mcp/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager -from dataclasses import dataclass -from typing import Any - -from fastmcp import FastMCP -from fastmcp.tools.tool import Tool - -from askui.agent import VisionAgent - - -@dataclass -class AppContext: - vision_agent: VisionAgent - - -@asynccontextmanager -async def mcp_lifespan(server: FastMCP[Any]) -> AsyncIterator[AppContext]: # noqa: ARG001 - with VisionAgent(display=2) as vision_agent: - server.add_tool(Tool.from_function(vision_agent.click)) - yield AppContext(vision_agent=vision_agent) - - -mcp = FastMCP("Vision Agent MCP App", lifespan=mcp_lifespan) diff --git a/src/askui/models/model_router.py b/src/askui/models/model_router.py index f203ef3c..7ff266f6 100644 --- a/src/askui/models/model_router.py +++ b/src/askui/models/model_router.py @@ -28,7 +28,7 @@ from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.facade import ModelFacade from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import Tool +from askui.models.shared.tools import ToolCollection from askui.models.types.response_schemas import ResponseSchema from askui.reporting import NULL_REPORTER, CompositeReporter, Reporter from askui.utils.image_utils import ImageSource @@ -184,7 +184,7 @@ def act( messages: list[MessageParam], model_choice: str, on_message: OnMessageCb | None = None, - tools: list[Tool] | None = None, + tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: m = self._get_model(model_choice, "act") diff --git a/src/askui/models/models.py b/src/askui/models/models.py index 3c2b63b2..78407645 100644 --- a/src/askui/models/models.py +++ b/src/askui/models/models.py @@ -10,7 +10,7 @@ from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import Tool +from askui.models.shared.tools import ToolCollection from askui.models.types.response_schemas import ResponseSchema from askui.utils.image_utils import ImageSource from askui.utils.source_utils import Source @@ -191,7 +191,7 @@ def act( messages: list[MessageParam], model_choice: str, on_message: OnMessageCb | None = None, - tools: list[Tool] | None = None, + tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: """ @@ -222,7 +222,7 @@ def act( added to the message history and the acting continues based on the message. The message may be modified by the callback to allow for directing the assistant/agent or tool use. - tools (list[Tool] | None, optional): The tools for the agent. + tools (ToolCollection | None, optional): The tools for the agent. Defaults to `None`. settings (AgentSettings | None, optional): The settings for the agent. Defaults to `None`. diff --git a/src/askui/models/shared/agent.py b/src/askui/models/shared/agent.py index dc537d49..6a5a8e0e 100644 --- a/src/askui/models/shared/agent.py +++ b/src/askui/models/shared/agent.py @@ -14,7 +14,7 @@ ) from askui.models.shared.messages_api import MessagesApi from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import Tool, ToolCollection +from askui.models.shared.tools import ToolCollection from askui.reporting import NULL_REPORTER, Reporter from ...logger import logger @@ -131,7 +131,7 @@ def act( messages: list[MessageParam], model_choice: str, on_message: OnMessageCb | None = None, - tools: list[Tool] | None = None, + tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: _settings = settings or ActSettings() @@ -140,7 +140,7 @@ def act( model=_settings.messages.model or model_choice, on_message=on_message or NULL_ON_MESSAGE_CB, settings=_settings, - tool_collection=ToolCollection(tools), + tool_collection=tools or ToolCollection(), ) def _use_tools( diff --git a/src/askui/models/shared/facade.py b/src/askui/models/shared/facade.py index c919fdf6..72a8d777 100644 --- a/src/askui/models/shared/facade.py +++ b/src/askui/models/shared/facade.py @@ -13,7 +13,7 @@ from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import Tool +from askui.models.shared.tools import ToolCollection from askui.models.types.response_schemas import ResponseSchema from askui.utils.image_utils import ImageSource from askui.utils.source_utils import Source @@ -36,7 +36,7 @@ def act( messages: list[MessageParam], model_choice: str, on_message: OnMessageCb | None = None, - tools: list[Tool] | None = None, + tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: self._act_model.act( diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index 355a9aae..c8976e6b 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -1,8 +1,13 @@ from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Literal, cast from anthropic.types.beta import BetaToolParam, BetaToolUnionParam from anthropic.types.beta.beta_tool_param import InputSchema +from asyncer import syncify +from fastmcp import Client +from fastmcp.client.client import CallToolResult +from fastmcp.client.transports import ClientTransportT +from mcp import Tool as McpTool from PIL import Image from pydantic import BaseModel, Field from typing_extensions import Self @@ -24,15 +29,44 @@ PrimitiveToolCallResult | list[PrimitiveToolCallResult] | tuple[PrimitiveToolCallResult, ...] + | CallToolResult ) +IMAGE_MEDIA_TYPES_SUPPORTED: list[ + Literal["image/jpeg", "image/png", "image/gif", "image/webp"] +] = ["image/jpeg", "image/png", "image/gif", "image/webp"] + + def _convert_to_content( result: ToolCallResult, ) -> list[TextBlockParam | ImageBlockParam]: if result is None: return [] + if isinstance(result, CallToolResult): + _result: list[TextBlockParam | ImageBlockParam] = [] + for block in result.content: + match block.type: + case "text": + _result.append(TextBlockParam(text=block.text)) # type: ignore[union-attr] + case "image": + media_type = block.mimeType # type: ignore[union-attr] + if media_type not in IMAGE_MEDIA_TYPES_SUPPORTED: + logger.error(f"Unsupported image media type: {media_type}") + continue + _result.append( + ImageBlockParam( + source=Base64ImageSourceParam( + media_type=media_type, + data=result.data, + ) + ) + ) + case _: + logger.error(f"Unsupported block type: {block.type}") + return _result + if isinstance(result, str): return [TextBlockParam(text=result)] @@ -108,14 +142,48 @@ class ToolCollection: - Could be used for raising on an exception (instead of just returning `ContentBlockParam`) within tool call or doing tool call or if tool is not found + + Args: + tools (list[Tool] | None, optional): The tools to add to the collection. + Defaults to `None`. + mcp_client (Client[ClientTransportT] | None, optional): The client to use for + the tools. Defaults to `None`. """ - def __init__(self, tools: list[Tool] | None = None) -> None: + def __init__( + self, + tools: list[Tool] | None = None, + mcp_client: Client[ClientTransportT] | None = None, + ) -> None: _tools = tools or [] self._tool_map = {tool.to_params()["name"]: tool for tool in _tools} + self._mcp_client = mcp_client def to_params(self) -> list[BetaToolUnionParam]: - return [tool.to_params() for tool in self._tool_map.values()] + tool_map = { + **self._get_mcp_tool_params(), + **{ + tool_name: tool.to_params() + for tool_name, tool in self._tool_map.items() + }, + } + return list(tool_map.values()) + + def _get_mcp_tool_params(self) -> dict[str, BetaToolUnionParam]: + if not self._mcp_client: + return {} + mcp_tools = self._get_mcp_tools() + return { + tool_name: cast( + "BetaToolUnionParam", + BetaToolParam( + name=tool_name, + description=tool.description or "", + input_schema=tool.inputSchema, + ), + ) + for tool_name, tool in mcp_tools.items() + } def append_tool(self, *tools: Tool) -> "Self": """Append a tool to the collection.""" @@ -141,12 +209,41 @@ def _run_tool( self, tool_use_block_param: ToolUseBlockParam ) -> ToolResultBlockParam: tool = self._tool_map.get(tool_use_block_param.name) - if not tool: - return ToolResultBlockParam( - content=f"Tool not found: {tool_use_block_param.name}", - is_error=True, - tool_use_id=tool_use_block_param.id, - ) + if tool: + return self._run_regular_tool(tool_use_block_param, tool) + mcp_tool = self._get_mcp_tools().get(tool_use_block_param.name) + if mcp_tool: + return self._run_mcp_tool(tool_use_block_param) + return ToolResultBlockParam( + content=f"Tool not found: {tool_use_block_param.name}", + is_error=True, + tool_use_id=tool_use_block_param.id, + ) + + async def _list_mcp_tools( + self, mcp_client: Client[ClientTransportT] + ) -> list[McpTool]: + async with mcp_client: + return await mcp_client.list_tools() + + def _get_mcp_tools(self) -> dict[str, McpTool]: + """Get cached MCP tools or fetch them if not cached.""" + try: + if not self._mcp_client: + return {} + list_mcp_tools_sync = syncify(self._list_mcp_tools, raise_sync_error=False) + tools_list = list_mcp_tools_sync(self._mcp_client) + except Exception as e: # noqa: BLE001 + logger.error(f"Failed to list MCP tools: {e}", exc_info=True) + return {} + else: + return {tool.name: tool for tool in tools_list} + + def _run_regular_tool( + self, + tool_use_block_param: ToolUseBlockParam, + tool: Tool, + ) -> ToolResultBlockParam: try: tool_result: ToolCallResult = tool(**tool_use_block_param.input) # type: ignore return ToolResultBlockParam( @@ -162,3 +259,48 @@ def _run_tool( is_error=True, tool_use_id=tool_use_block_param.id, ) + + async def _call_mcp_tool( + self, + mcp_client: Client[ClientTransportT], + tool_use_block_param: ToolUseBlockParam, + ) -> ToolCallResult: + async with mcp_client: + return await mcp_client.call_tool( + tool_use_block_param.name, + tool_use_block_param.input, # type: ignore[arg-type] + ) + + def _run_mcp_tool( + self, + tool_use_block_param: ToolUseBlockParam, + ) -> ToolResultBlockParam: + """Run an MCP tool using the client.""" + if not self._mcp_client: + return ToolResultBlockParam( + content="MCP client not available", + is_error=True, + tool_use_id=tool_use_block_param.id, + ) + try: + call_mcp_tool_sync = syncify(self._call_mcp_tool, raise_sync_error=False) + result = call_mcp_tool_sync(self._mcp_client, tool_use_block_param) + return ToolResultBlockParam( + content=_convert_to_content(result), + tool_use_id=tool_use_block_param.id, + ) + except Exception as e: # noqa: BLE001 + logger.error( + f"MCP tool {tool_use_block_param.name} failed: {e}", exc_info=True + ) + return ToolResultBlockParam( + content=f"MCP tool {tool_use_block_param.name} failed: {e}", + is_error=True, + tool_use_id=tool_use_block_param.id, + ) + + def __add__(self, other: "ToolCollection") -> "ToolCollection": + return ToolCollection( + tools=list(self._tool_map.values()) + list(other._tool_map.values()), + mcp_client=other._mcp_client or self._mcp_client, + ) diff --git a/src/askui/models/ui_tars_ep/ui_tars_api.py b/src/askui/models/ui_tars_ep/ui_tars_api.py index 1f6b2df2..511927c8 100644 --- a/src/askui/models/ui_tars_ep/ui_tars_api.py +++ b/src/askui/models/ui_tars_ep/ui_tars_api.py @@ -21,9 +21,10 @@ from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import Tool +from askui.models.shared.tools import Tool, ToolCollection from askui.models.types.response_schemas import ResponseSchema from askui.reporting import Reporter +from askui.tools.computer import Computer20241022Tool from askui.utils.excel_utils import OfficeDocumentSource from askui.utils.image_utils import ImageSource, image_to_base64 from askui.utils.pdf_utils import PdfSource @@ -211,7 +212,7 @@ def act( messages: list[MessageParam], model_choice: str, on_message: OnMessageCb | None = None, - tools: list[Tool] | None = None, + tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: if on_message is not None: @@ -229,11 +230,11 @@ def act( raise ValueError(error_msg) # noqa: TRY004 # Find the computer tool - computer_tool = None + computer_tool: Computer20241022Tool | None = None if tools: for tool in tools: if tool.name == "computer": - computer_tool = tool + computer_tool: Computer20241022Tool = tool break if computer_tool is None: @@ -261,7 +262,7 @@ def act( self.execute_act(self.act_history, computer_tool) def add_screenshot_to_history( - self, message_history: list[dict[str, Any]], computer_tool: Tool + self, message_history: list[dict[str, Any]], computer_tool: Computer20241022Tool ) -> None: screenshot = computer_tool(action="screenshot") message_history.append( diff --git a/src/askui/tools/mcp/__init__.py b/src/askui/tools/mcp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/tools/mcp/config.py b/src/askui/tools/mcp/config.py new file mode 100644 index 00000000..21509617 --- /dev/null +++ b/src/askui/tools/mcp/config.py @@ -0,0 +1,17 @@ +from fastmcp.client.transports import StdioTransport +from fastmcp.mcp_config import StdioMCPServer as FastMCPStdioMCPServer + + +class StdioMCPServer(FastMCPStdioMCPServer): + keep_alive: bool = False + + def to_transport(self) -> StdioTransport: + from fastmcp.client.transports import StdioTransport + + return StdioTransport( + command=self.command, + args=self.args, + env=self.env, + cwd=self.cwd, + keep_alive=self.keep_alive, + ) diff --git a/src/askui/tools/mcp/servers/__init__.py b/src/askui/tools/mcp/servers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/tools/mcp/servers/sse.py b/src/askui/tools/mcp/servers/sse.py new file mode 100644 index 00000000..ad1645eb --- /dev/null +++ b/src/askui/tools/mcp/servers/sse.py @@ -0,0 +1,15 @@ +from typing import Any + +from fastmcp import FastMCP + +mcp: FastMCP[Any] = FastMCP("Test StdIO MCP App", port=8001) + + +@mcp.tool +def test_sse_tool() -> str: + print("test_sse_tool called") + return "I am a test sse tool" + + +if __name__ == "__main__": + mcp.run(transport="sse") diff --git a/src/askui/tools/mcp/servers/stdio.py b/src/askui/tools/mcp/servers/stdio.py new file mode 100644 index 00000000..7785f5be --- /dev/null +++ b/src/askui/tools/mcp/servers/stdio.py @@ -0,0 +1,15 @@ +from typing import Any + +from fastmcp import FastMCP + +mcp: FastMCP[Any] = FastMCP("Test StdIO MCP App") + + +@mcp.tool +def test_stdio_tool() -> str: + print("test_stdio_tool called") + return "I am a test stdio tool" + + +if __name__ == "__main__": + mcp.run(transport="stdio", show_banner=False) diff --git a/src/askui/utils/api_utils.py b/src/askui/utils/api_utils.py index b38e67ea..f1b994a5 100644 --- a/src/askui/utils/api_utils.py +++ b/src/askui/utils/api_utils.py @@ -40,6 +40,10 @@ class ConflictError(ApiError): pass +class LimitReachedError(ApiError): + pass + + class NotFoundError(ApiError): pass diff --git a/tests/integration/test_custom_models.py b/tests/integration/test_custom_models.py index 1f4a2449..4f98eaee 100644 --- a/tests/integration/test_custom_models.py +++ b/tests/integration/test_custom_models.py @@ -22,7 +22,7 @@ from askui.models.shared.agent_message_param import MessageParam from askui.models.shared.agent_on_message_cb import OnMessageCb from askui.models.shared.settings import ActSettings -from askui.models.shared.tools import Tool +from askui.models.shared.tools import ToolCollection from askui.tools.toolbox import AgentToolbox from askui.utils.image_utils import ImageSource from askui.utils.source_utils import Source @@ -41,7 +41,7 @@ def act( messages: list[MessageParam], model_choice: str, on_message: OnMessageCb | None = None, - tools: list[Tool] | None = None, + tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: self.goals.append([message.model_dump(mode="json") for message in messages]) @@ -230,7 +230,7 @@ def act( messages: list[MessageParam], model_choice: str, on_message: OnMessageCb | None = None, - tools: list[Tool] | None = None, + tools: ToolCollection | None = None, settings: ActSettings | None = None, ) -> None: pass