Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/askui/chat/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings
from askui.chat.api.files.router import router as files_router
from askui.chat.api.health.router import router as health_router
from askui.chat.api.mcp_clients.dependencies import get_mcp_client_manager_manager
from askui.chat.api.mcp_clients.manager import McpServerConnectionError
from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service
from askui.chat.api.mcp_configs.router import router as mcp_configs_router
from askui.chat.api.mcps.computer import mcp as computer_mcp
Expand All @@ -34,6 +36,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
mcp_config_service = get_mcp_config_service(settings=settings)
mcp_config_service.seed()
yield
await get_mcp_client_manager_manager(mcp_config_service).disconnect_all(force=True)


app = FastAPI(
Expand Down Expand Up @@ -142,6 +145,17 @@ def catch_all_exception_handler(
)


@app.exception_handler(McpServerConnectionError)
def mcp_server_connection_error_handler(
request: Request, # noqa: ARG001
exc: McpServerConnectionError,
) -> JSONResponse:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": str(exc)},
)


app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
Expand Down
Empty file.
14 changes: 14 additions & 0 deletions src/askui/chat/api/mcp_clients/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from fastapi import Depends

from askui.chat.api.mcp_clients.manager import McpClientManagerManager
from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep
from askui.chat.api.mcp_configs.service import McpConfigService


def get_mcp_client_manager_manager(
mcp_config_service: McpConfigService = McpConfigServiceDep,
) -> McpClientManagerManager:
return McpClientManagerManager(mcp_config_service)


McpClientManagerManagerDep = Depends(get_mcp_client_manager_manager)
156 changes: 156 additions & 0 deletions src/askui/chat/api/mcp_clients/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import types
from datetime import timedelta
from typing import Any, Type

import anyio
import mcp
from fastmcp import Client
from fastmcp.client.client import CallToolResult, ProgressHandler
from fastmcp.exceptions import ToolError
from fastmcp.mcp_config import MCPConfig

from askui.chat.api.mcp_configs.service import McpConfigService
from askui.chat.api.models import WorkspaceId

McpServerName = str


class McpServerConnectionError(Exception):
"""Exception raised when a MCP server connection fails."""

def __init__(self, mcp_server_name: McpServerName, error: Exception):
super().__init__(f"Failed to connect to MCP server: {mcp_server_name}: {error}")
self.mcp_server_name = mcp_server_name
self.error = error


class McpClientManager:
def __init__(
self, mcp_clients: dict[McpServerName, Client[Any]] | None = None
) -> None:
self._mcp_clients = mcp_clients or {}
self._tools: dict[McpServerName, list[mcp.types.Tool]] = {}

@classmethod
def from_config(cls, mcp_config: MCPConfig) -> "McpClientManager":
mcp_clients: dict[McpServerName, Client[Any]] = {
mcp_server_name: Client(mcp_server_config.to_transport())
for mcp_server_name, mcp_server_config in mcp_config.mcpServers.items()
}
return cls(mcp_clients)

async def connect(self) -> "McpClientManager":
for mcp_server_name, mcp_client in self._mcp_clients.items():
try:
await mcp_client._connect() # noqa: SLF001
except Exception as e: # noqa: PERF203
raise McpServerConnectionError(mcp_server_name, e) from e
return self

async def disconnect(self, force: bool = False) -> None:
for mcp_client in self._mcp_clients.values():
if mcp_client.is_connected():
await mcp_client._disconnect(force) # noqa: SLF001

async def list_tools(
self,
) -> list[mcp.types.Tool]:
tools: list[mcp.types.Tool] = []
for mcp_server_name, mcp_client in self._mcp_clients.items():
if mcp_server_name not in self._tools:
self._tools[mcp_server_name] = await mcp_client.list_tools()
tools.extend(self._tools[mcp_server_name])
return tools

async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
timeout: timedelta | float | None = None, # noqa: ASYNC109
progress_handler: ProgressHandler | None = None,
raise_on_error: bool = True,
) -> CallToolResult:
for mcp_server_name, tools in self._tools.items(): # Make lookup faster
for tool in tools:
if tool.name == name:
return await self._mcp_clients[mcp_server_name].call_tool(
name,
arguments,
timeout,
progress_handler,
raise_on_error,
)
error_msg = f"Unknown tool: {name}"
if raise_on_error:
raise ToolError(error_msg)
return CallToolResult(
content=[mcp.types.TextContent(type="text", text=error_msg)],
structured_content=None,
data=None,
is_error=True,
)

async def __aenter__(self) -> "McpClientManager":
return await self.connect()

async def __aexit__(
self,
exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
traceback: types.TracebackType | None,
) -> None:
await self.disconnect()


McpClientManagerKey = str


class McpClientManagerManager:
_mcp_client_managers: dict[McpClientManagerKey, McpClientManager | None] = {}
_lock: anyio.Lock = anyio.Lock()

def __init__(self, mcp_config_service: McpConfigService) -> None:
self._mcp_config_service = mcp_config_service

async def get_mcp_client_manager(
self, workspace_id: WorkspaceId | None
) -> McpClientManager | None:
key: McpClientManagerKey = (
f"workspace_{workspace_id}" if workspace_id else "global"
)
if key in McpClientManagerManager._mcp_client_managers:
return McpClientManagerManager._mcp_client_managers[key]

fast_mcp_config = self._mcp_config_service.retrieve_fast_mcp_config(
workspace_id
)
if not fast_mcp_config:
McpClientManagerManager._mcp_client_managers[key] = None
return None

async with McpClientManagerManager._lock:
if key not in McpClientManagerManager._mcp_client_managers:
try:
mcp_client_manager = McpClientManager.from_config(fast_mcp_config)
McpClientManagerManager._mcp_client_managers[key] = (
mcp_client_manager
)
await mcp_client_manager.connect()
except Exception:
if key in McpClientManagerManager._mcp_client_managers:
if (
_mcp_client_manager
:= McpClientManagerManager._mcp_client_managers[key]
):
await _mcp_client_manager.disconnect(force=True)
del McpClientManagerManager._mcp_client_managers[key]
raise
return McpClientManagerManager._mcp_client_managers[key]

async def disconnect_all(self, force: bool = False) -> None:
async with McpClientManagerManager._lock:
for (
mcp_client_manager
) in McpClientManagerManager._mcp_client_managers.values():
if mcp_client_manager:
await mcp_client_manager.disconnect(force)
17 changes: 14 additions & 3 deletions src/askui/chat/api/mcp_configs/seeds.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fastmcp.mcp_config import RemoteMCPServer
from fastmcp.mcp_config import RemoteMCPServer, StdioMCPServer

from askui.chat.api.dependencies import get_settings
from askui.chat.api.mcp_configs.models import McpConfig
Expand All @@ -10,12 +10,23 @@
ASKUI_CHAT_MCP = McpConfig(
id="mcpcnf_68ac2c4edc4b2f27faa5a252",
created_at=now(),
name="AskUI Chat MCP",
name="askui_chat",
mcp_server=RemoteMCPServer(
url=f"http://{settings.host}:{settings.port}/mcp/sse",
transport="sse",
),
)


SEEDS = [ASKUI_CHAT_MCP]
PLAYWRIGHT_MCP = McpConfig(
id="mcpcnf_68ac2c4edc4b2f27faa5a251",
created_at=now(),
name="playwright",
mcp_server=StdioMCPServer(
command="npx",
args=["@playwright/mcp@latest", "--isolated"],
),
)


SEEDS = [ASKUI_CHAT_MCP, PLAYWRIGHT_MCP]
14 changes: 14 additions & 0 deletions src/askui/chat/api/mcp_configs/service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pathlib import Path

from fastmcp.mcp_config import MCPConfig

from askui.chat.api.mcp_configs.models import (
McpConfig,
McpConfigCreateParams,
Expand Down Expand Up @@ -69,6 +71,18 @@ def retrieve(
else:
return mcp_config

def retrieve_fast_mcp_config(
self, workspace_id: WorkspaceId | None
) -> MCPConfig | None:
list_response = self.list_(
workspace_id=workspace_id,
query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"),
)
mcp_servers_dict = {
mcp_config.name: mcp_config.mcp_server for mcp_config in list_response.data
}
return MCPConfig(mcpServers=mcp_servers_dict) if mcp_servers_dict else None

def _check_limit(self, workspace_id: WorkspaceId | None) -> None:
limit = LIST_LIMIT_MAX
list_result = self.list_(workspace_id, ListQuery(limit=limit))
Expand Down
8 changes: 4 additions & 4 deletions src/askui/chat/api/runs/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from askui.chat.api.assistants.dependencies import AssistantServiceDep
from askui.chat.api.assistants.service import AssistantService
from askui.chat.api.dependencies import WorkspaceDirDep
from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep
from askui.chat.api.mcp_configs.service import McpConfigService
from askui.chat.api.mcp_clients.dependencies import McpClientManagerManagerDep
from askui.chat.api.mcp_clients.manager import McpClientManagerManager
from askui.chat.api.messages.dependencies import MessageServiceDep, MessageTranslatorDep
from askui.chat.api.messages.service import MessageService
from askui.chat.api.messages.translator import MessageTranslator
Expand All @@ -17,15 +17,15 @@
def get_runs_service(
workspace_dir: Path = WorkspaceDirDep,
assistant_service: AssistantService = AssistantServiceDep,
mcp_config_service: McpConfigService = McpConfigServiceDep,
mcp_client_manager_manager: McpClientManagerManager = McpClientManagerManagerDep,
message_service: MessageService = MessageServiceDep,
message_translator: MessageTranslator = MessageTranslatorDep,
) -> RunService:
"""Get RunService instance."""
return RunService(
base_dir=workspace_dir,
assistant_service=assistant_service,
mcp_config_service=mcp_config_service,
mcp_client_manager_manager=mcp_client_manager_manager,
message_service=message_service,
message_translator=message_translator,
)
Expand Down
46 changes: 17 additions & 29 deletions src/askui/chat/api/runs/runner/runner.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Literal, Sequence
from typing import TYPE_CHECKING, Literal

import anthropic
import anyio
from anyio.abc import ObjectStream
from asyncer import asyncify, syncify
from fastmcp import Client
from fastmcp.client.transports import MCPConfigTransport
from fastmcp.mcp_config import MCPConfig

from askui.android_agent import AndroidVisionAgent
from askui.chat.api.assistants.models import Assistant
Expand All @@ -18,8 +17,7 @@
TESTING_AGENT,
WEB_AGENT,
)
from askui.chat.api.mcp_configs.models import McpConfig
from askui.chat.api.mcp_configs.service import McpConfigService
from askui.chat.api.mcp_clients.manager import McpClientManagerManager
from askui.chat.api.messages.models import MessageCreateParams
from askui.chat.api.messages.service import MessageService
from askui.chat.api.messages.translator import MessageTranslator
Expand Down Expand Up @@ -57,13 +55,6 @@
logger = logging.getLogger(__name__)


def build_fast_mcp_config(mcp_configs: Sequence[McpConfig]) -> MCPConfig:
mcp_config_dict = {
mcp_config.id: mcp_config.mcp_server for mcp_config in mcp_configs
}
return MCPConfig(mcpServers=mcp_config_dict)


McpClient = Client[MCPConfigTransport]


Expand All @@ -85,7 +76,7 @@ def __init__(
run: Run,
message_service: MessageService,
message_translator: MessageTranslator,
mcp_config_service: McpConfigService,
mcp_client_manager_manager: McpClientManagerManager,
run_service: RunnerRunService,
) -> None:
self._workspace_id = workspace_id
Expand All @@ -94,18 +85,10 @@ def __init__(
self._message_service = message_service
self._message_translator = message_translator
self._message_content_translator = message_translator.content_translator
self._mcp_config_service = mcp_config_service
self._mcp_client_manager_manager = mcp_client_manager_manager
self._run_service = run_service
self._agent_os = PynputAgentOs()

def _get_mcp_client(self) -> McpClient | None:
mcp_configs = self._mcp_config_service.list_(
workspace_id=self._workspace_id,
query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"),
)
fast_mcp_config = build_fast_mcp_config(mcp_configs.data)
return Client(fast_mcp_config) if fast_mcp_config.mcpServers else None

def _retrieve(self) -> Run:
return self._run_service.retrieve(
thread_id=self._run.thread_id,
Expand Down Expand Up @@ -342,19 +325,24 @@ def _run_agent_inner() -> None:

await asyncify(_run_agent_inner)()

async def _get_mcp_client(self) -> McpClient | None:
return await self._mcp_client_manager_manager.get_mcp_client_manager( # type: ignore
self._workspace_id
)

async def run(
self,
send_stream: ObjectStream[Events],
) -> None:
mcp_client = self._get_mcp_client()
self._mark_run_as_started()
await send_stream.send(
RunEvent(
data=self._run,
event="thread.run.in_progress",
)
)
try:
mcp_client = await self._get_mcp_client()
self._mark_run_as_started()
await send_stream.send(
RunEvent(
data=self._run,
event="thread.run.in_progress",
)
)
if self._run.assistant_id == HUMAN_DEMONSTRATION_AGENT.id:
await self._run_human_agent(send_stream)
elif self._run.assistant_id == ANDROID_AGENT.id:
Expand Down
Loading