From a1e907c510b581b93d9838022d7f8ba3191b9a3f Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Fri, 5 Sep 2025 16:04:15 +0200 Subject: [PATCH 1/2] fix(chat)!: hold mcp client sessions - sessions are not immediately closed after each call to server but instead kept open as expected by mcp protocol - sessions initiated lazily - one session per client per workspace (across all chats in workspace) - add playwwright npx to default seeds --> requiring playwright - no update of mcp tools on server update or update of mcp configs --> restart of chat api required - tools with the same name override each other --> tool names need to be unique across all servers - no limitation on workspaces, mcp servers and tools, resources, prompts or open connections/sessions --> may lead to unexpected behavior in case of too many - mcp servers with the same name override each other --> names need to be unique - can easily be extended later to support holding mcp sessions across one run if it becomes possible to continue run with new message instead of creating new run --- src/askui/chat/api/app.py | 14 ++ src/askui/chat/api/mcp_clients/__init__.py | 0 .../chat/api/mcp_clients/dependencies.py | 14 ++ src/askui/chat/api/mcp_clients/manager.py | 156 ++++++++++++++++++ src/askui/chat/api/mcp_configs/seeds.py | 17 +- src/askui/chat/api/mcp_configs/service.py | 14 ++ src/askui/chat/api/runs/dependencies.py | 8 +- src/askui/chat/api/runs/runner/runner.py | 46 ++---- src/askui/chat/api/runs/service.py | 8 +- src/askui/models/shared/tools.py | 69 ++++++-- 10 files changed, 296 insertions(+), 50 deletions(-) create mode 100644 src/askui/chat/api/mcp_clients/__init__.py create mode 100644 src/askui/chat/api/mcp_clients/dependencies.py create mode 100644 src/askui/chat/api/mcp_clients/manager.py diff --git a/src/askui/chat/api/app.py b/src/askui/chat/api/app.py index 6d21eb55..226a9c2a 100644 --- a/src/askui/chat/api/app.py +++ b/src/askui/chat/api/app.py @@ -11,6 +11,8 @@ from askui.chat.api.dependencies import SetEnvFromHeadersDep, get_settings from askui.chat.api.files.router import router as files_router from askui.chat.api.health.router import router as health_router +from askui.chat.api.mcp_clients.dependencies import get_mcp_client_manager_manager +from askui.chat.api.mcp_clients.manager import McpServerConnectionError from askui.chat.api.mcp_configs.dependencies import get_mcp_config_service from askui.chat.api.mcp_configs.router import router as mcp_configs_router from askui.chat.api.mcps.computer import mcp as computer_mcp @@ -34,6 +36,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 mcp_config_service = get_mcp_config_service(settings=settings) mcp_config_service.seed() yield + await get_mcp_client_manager_manager(mcp_config_service).disconnect_all(force=True) app = FastAPI( @@ -142,6 +145,17 @@ def catch_all_exception_handler( ) +@app.exception_handler(McpServerConnectionError) +def mcp_server_connection_error_handler( + request: Request, # noqa: ARG001 + exc: McpServerConnectionError, +) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": str(exc)}, + ) + + app.add_middleware( CORSMiddleware, allow_origins=["*"], diff --git a/src/askui/chat/api/mcp_clients/__init__.py b/src/askui/chat/api/mcp_clients/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/askui/chat/api/mcp_clients/dependencies.py b/src/askui/chat/api/mcp_clients/dependencies.py new file mode 100644 index 00000000..3a2744fd --- /dev/null +++ b/src/askui/chat/api/mcp_clients/dependencies.py @@ -0,0 +1,14 @@ +from fastapi import Depends + +from askui.chat.api.mcp_clients.manager import McpClientManagerManager +from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep +from askui.chat.api.mcp_configs.service import McpConfigService + + +def get_mcp_client_manager_manager( + mcp_config_service: McpConfigService = McpConfigServiceDep, +) -> McpClientManagerManager: + return McpClientManagerManager(mcp_config_service) + + +McpClientManagerManagerDep = Depends(get_mcp_client_manager_manager) diff --git a/src/askui/chat/api/mcp_clients/manager.py b/src/askui/chat/api/mcp_clients/manager.py new file mode 100644 index 00000000..9aa49cde --- /dev/null +++ b/src/askui/chat/api/mcp_clients/manager.py @@ -0,0 +1,156 @@ +import types +from datetime import timedelta +from typing import Any, Type + +import anyio +import mcp +from fastmcp import Client +from fastmcp.client.client import CallToolResult, ProgressHandler +from fastmcp.exceptions import ToolError +from fastmcp.mcp_config import MCPConfig + +from askui.chat.api.mcp_configs.service import McpConfigService +from askui.chat.api.models import WorkspaceId + +McpServerName = str + + +class McpServerConnectionError(Exception): + """Exception raised when a MCP server connection fails.""" + + def __init__(self, mcp_server_name: McpServerName, error: Exception): + super().__init__(f"Failed to connect to MCP server: {mcp_server_name}: {error}") + self.mcp_server_name = mcp_server_name + self.error = error + + +class McpClientManager: + def __init__( + self, mcp_clients: dict[McpServerName, Client[Any]] | None = None + ) -> None: + self._mcp_clients = mcp_clients or {} + self._tools: dict[McpServerName, list[mcp.types.Tool]] = {} + + @classmethod + def from_config(cls, mcp_config: MCPConfig) -> "McpClientManager": + mcp_clients: dict[McpServerName, Client[Any]] = { + mcp_server_name: Client(mcp_server_config.to_transport()) + for mcp_server_name, mcp_server_config in mcp_config.mcpServers.items() + } + return cls(mcp_clients) + + async def connect(self) -> "McpClientManager": + for mcp_server_name, mcp_client in self._mcp_clients.items(): + try: + await mcp_client._connect() # noqa: SLF001 + except Exception as e: # noqa: PERF203 + raise McpServerConnectionError(mcp_server_name, e) from e + return self + + async def disconnect(self, force: bool = False) -> None: + for mcp_client in self._mcp_clients.values(): + if mcp_client.is_connected(): + await mcp_client._disconnect(force) # noqa: SLF001 + + async def list_tools( + self, + ) -> list[mcp.types.Tool]: # TODO Proper cache and parallelization + tools: list[mcp.types.Tool] = [] + for mcp_server_name, mcp_client in self._mcp_clients.items(): + if mcp_server_name not in self._tools: + self._tools[mcp_server_name] = await mcp_client.list_tools() + tools.extend(self._tools[mcp_server_name]) + return tools + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + timeout: timedelta | float | None = None, + progress_handler: ProgressHandler | None = None, + raise_on_error: bool = True, + ) -> CallToolResult: + for mcp_server_name, tools in self._tools.items(): # Make lookup faster + for tool in tools: + if tool.name == name: + return await self._mcp_clients[mcp_server_name].call_tool( + name, + arguments, + timeout, + progress_handler, + raise_on_error, + ) + error_msg = f"Unknown tool: {name}" + if raise_on_error: + raise ToolError(error_msg) + return CallToolResult( + content=[mcp.types.TextContent(type="text", text=error_msg)], + structured_content=None, + data=None, + is_error=True, + ) + + async def __aenter__(self) -> "McpClientManager": + return await self.connect() + + async def __aexit__( + self, + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + await self.disconnect() + + +McpClientManagerKey = str + + +class McpClientManagerManager: + _mcp_client_managers: dict[McpClientManagerKey, McpClientManager | None] = {} + _lock: anyio.Lock = anyio.Lock() + + def __init__(self, mcp_config_service: McpConfigService) -> None: + self._mcp_config_service = mcp_config_service + + async def get_mcp_client_manager( + self, workspace_id: WorkspaceId | None + ) -> McpClientManager | None: + key: McpClientManagerKey = ( + f"workspace_{workspace_id}" if workspace_id else "global" + ) + if key in McpClientManagerManager._mcp_client_managers: + return McpClientManagerManager._mcp_client_managers[key] + + fast_mcp_config = self._mcp_config_service.retrieve_fast_mcp_config( + workspace_id + ) + if not fast_mcp_config: + McpClientManagerManager._mcp_client_managers[key] = None + return None + + async with McpClientManagerManager._lock: + if key not in McpClientManagerManager._mcp_client_managers: + try: + mcp_client_manager = McpClientManager.from_config(fast_mcp_config) + McpClientManagerManager._mcp_client_managers[key] = ( + mcp_client_manager + ) + await mcp_client_manager.connect() + except Exception: + if key in McpClientManagerManager._mcp_client_managers: + if ( + _mcp_client_manager + := McpClientManagerManager._mcp_client_managers[key] + ): + await _mcp_client_manager.disconnect(force=True) + del McpClientManagerManager._mcp_client_managers[key] + raise + return McpClientManagerManager._mcp_client_managers[key] + + async def disconnect_all(self, force: bool = False) -> None: + async with McpClientManagerManager._lock: + for ( + mcp_client_manager + ) in McpClientManagerManager._mcp_client_managers.values(): + if mcp_client_manager: + await mcp_client_manager.disconnect(force) diff --git a/src/askui/chat/api/mcp_configs/seeds.py b/src/askui/chat/api/mcp_configs/seeds.py index 57dd1f1d..8ac1bb2b 100644 --- a/src/askui/chat/api/mcp_configs/seeds.py +++ b/src/askui/chat/api/mcp_configs/seeds.py @@ -1,4 +1,4 @@ -from fastmcp.mcp_config import RemoteMCPServer +from fastmcp.mcp_config import RemoteMCPServer, StdioMCPServer from askui.chat.api.dependencies import get_settings from askui.chat.api.mcp_configs.models import McpConfig @@ -10,7 +10,7 @@ ASKUI_CHAT_MCP = McpConfig( id="mcpcnf_68ac2c4edc4b2f27faa5a252", created_at=now(), - name="AskUI Chat MCP", + name="askui_chat", mcp_server=RemoteMCPServer( url=f"http://{settings.host}:{settings.port}/mcp/sse", transport="sse", @@ -18,4 +18,15 @@ ) -SEEDS = [ASKUI_CHAT_MCP] +PLAYWRIGHT_MCP = McpConfig( + id="mcpcnf_68ac2c4edc4b2f27faa5a251", + created_at=now(), + name="playwright", + mcp_server=StdioMCPServer( + command="npx", + args=["@playwright/mcp@latest", "--isolated"], + ), +) + + +SEEDS = [ASKUI_CHAT_MCP, PLAYWRIGHT_MCP] diff --git a/src/askui/chat/api/mcp_configs/service.py b/src/askui/chat/api/mcp_configs/service.py index 140d03fa..4e5980c9 100644 --- a/src/askui/chat/api/mcp_configs/service.py +++ b/src/askui/chat/api/mcp_configs/service.py @@ -1,5 +1,7 @@ from pathlib import Path +from fastmcp.mcp_config import MCPConfig + from askui.chat.api.mcp_configs.models import ( McpConfig, McpConfigCreateParams, @@ -69,6 +71,18 @@ def retrieve( else: return mcp_config + def retrieve_fast_mcp_config( + self, workspace_id: WorkspaceId | None + ) -> MCPConfig | None: + list_response = self.list_( + workspace_id=workspace_id, + query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"), + ) + mcp_servers_dict = { + mcp_config.name: mcp_config.mcp_server for mcp_config in list_response.data + } + return MCPConfig(mcpServers=mcp_servers_dict) if mcp_servers_dict else None + def _check_limit(self, workspace_id: WorkspaceId | None) -> None: limit = LIST_LIMIT_MAX list_result = self.list_(workspace_id, ListQuery(limit=limit)) diff --git a/src/askui/chat/api/runs/dependencies.py b/src/askui/chat/api/runs/dependencies.py index eea848d0..fca6d6bc 100644 --- a/src/askui/chat/api/runs/dependencies.py +++ b/src/askui/chat/api/runs/dependencies.py @@ -5,8 +5,8 @@ from askui.chat.api.assistants.dependencies import AssistantServiceDep from askui.chat.api.assistants.service import AssistantService from askui.chat.api.dependencies import WorkspaceDirDep -from askui.chat.api.mcp_configs.dependencies import McpConfigServiceDep -from askui.chat.api.mcp_configs.service import McpConfigService +from askui.chat.api.mcp_clients.dependencies import McpClientManagerManagerDep +from askui.chat.api.mcp_clients.manager import McpClientManagerManager from askui.chat.api.messages.dependencies import MessageServiceDep, MessageTranslatorDep from askui.chat.api.messages.service import MessageService from askui.chat.api.messages.translator import MessageTranslator @@ -17,7 +17,7 @@ def get_runs_service( workspace_dir: Path = WorkspaceDirDep, assistant_service: AssistantService = AssistantServiceDep, - mcp_config_service: McpConfigService = McpConfigServiceDep, + mcp_client_manager_manager: McpClientManagerManager = McpClientManagerManagerDep, message_service: MessageService = MessageServiceDep, message_translator: MessageTranslator = MessageTranslatorDep, ) -> RunService: @@ -25,7 +25,7 @@ def get_runs_service( return RunService( base_dir=workspace_dir, assistant_service=assistant_service, - mcp_config_service=mcp_config_service, + mcp_client_manager_manager=mcp_client_manager_manager, message_service=message_service, message_translator=message_translator, ) diff --git a/src/askui/chat/api/runs/runner/runner.py b/src/askui/chat/api/runs/runner/runner.py index 944b173c..76edb9e5 100644 --- a/src/askui/chat/api/runs/runner/runner.py +++ b/src/askui/chat/api/runs/runner/runner.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal, Sequence +from typing import TYPE_CHECKING, Literal import anthropic import anyio @@ -8,7 +8,6 @@ from asyncer import asyncify, syncify from fastmcp import Client from fastmcp.client.transports import MCPConfigTransport -from fastmcp.mcp_config import MCPConfig from askui.android_agent import AndroidVisionAgent from askui.chat.api.assistants.models import Assistant @@ -18,8 +17,7 @@ TESTING_AGENT, WEB_AGENT, ) -from askui.chat.api.mcp_configs.models import McpConfig -from askui.chat.api.mcp_configs.service import McpConfigService +from askui.chat.api.mcp_clients.manager import McpClientManagerManager from askui.chat.api.messages.models import MessageCreateParams from askui.chat.api.messages.service import MessageService from askui.chat.api.messages.translator import MessageTranslator @@ -57,13 +55,6 @@ logger = logging.getLogger(__name__) -def build_fast_mcp_config(mcp_configs: Sequence[McpConfig]) -> MCPConfig: - mcp_config_dict = { - mcp_config.id: mcp_config.mcp_server for mcp_config in mcp_configs - } - return MCPConfig(mcpServers=mcp_config_dict) - - McpClient = Client[MCPConfigTransport] @@ -85,7 +76,7 @@ def __init__( run: Run, message_service: MessageService, message_translator: MessageTranslator, - mcp_config_service: McpConfigService, + mcp_client_manager_manager: McpClientManagerManager, run_service: RunnerRunService, ) -> None: self._workspace_id = workspace_id @@ -94,18 +85,10 @@ def __init__( self._message_service = message_service self._message_translator = message_translator self._message_content_translator = message_translator.content_translator - self._mcp_config_service = mcp_config_service + self._mcp_client_manager_manager = mcp_client_manager_manager self._run_service = run_service self._agent_os = PynputAgentOs() - def _get_mcp_client(self) -> McpClient | None: - mcp_configs = self._mcp_config_service.list_( - workspace_id=self._workspace_id, - query=ListQuery(limit=LIST_LIMIT_MAX, order="asc"), - ) - fast_mcp_config = build_fast_mcp_config(mcp_configs.data) - return Client(fast_mcp_config) if fast_mcp_config.mcpServers else None - def _retrieve(self) -> Run: return self._run_service.retrieve( thread_id=self._run.thread_id, @@ -342,19 +325,24 @@ def _run_agent_inner() -> None: await asyncify(_run_agent_inner)() + async def _get_mcp_client(self) -> McpClient | None: + return await self._mcp_client_manager_manager.get_mcp_client_manager( # type: ignore + self._workspace_id + ) + async def run( self, send_stream: ObjectStream[Events], ) -> None: - mcp_client = self._get_mcp_client() - self._mark_run_as_started() - await send_stream.send( - RunEvent( - data=self._run, - event="thread.run.in_progress", - ) - ) try: + mcp_client = await self._get_mcp_client() + self._mark_run_as_started() + await send_stream.send( + RunEvent( + data=self._run, + event="thread.run.in_progress", + ) + ) if self._run.assistant_id == HUMAN_DEMONSTRATION_AGENT.id: await self._run_human_agent(send_stream) elif self._run.assistant_id == ANDROID_AGENT.id: diff --git a/src/askui/chat/api/runs/service.py b/src/askui/chat/api/runs/service.py index 83e3cf3c..98c72afd 100644 --- a/src/askui/chat/api/runs/service.py +++ b/src/askui/chat/api/runs/service.py @@ -6,7 +6,7 @@ from typing_extensions import override from askui.chat.api.assistants.service import AssistantService -from askui.chat.api.mcp_configs.service import McpConfigService +from askui.chat.api.mcp_clients.manager import McpClientManagerManager from askui.chat.api.messages.service import MessageService from askui.chat.api.messages.translator import MessageTranslator from askui.chat.api.models import RunId, ThreadId, WorkspaceId @@ -34,13 +34,13 @@ def __init__( self, base_dir: Path, assistant_service: AssistantService, - mcp_config_service: McpConfigService, + mcp_client_manager_manager: McpClientManagerManager, message_service: MessageService, message_translator: MessageTranslator, ) -> None: self._base_dir = base_dir self._assistant_service = assistant_service - self._mcp_config_service = mcp_config_service + self._mcp_client_manager_manager = mcp_client_manager_manager self._message_service = message_service self._message_translator = message_translator @@ -82,7 +82,7 @@ async def create( run=run, message_service=self._message_service, message_translator=self._message_translator, - mcp_config_service=self._mcp_config_service, + mcp_client_manager_manager=self._mcp_client_manager_manager, run_service=self, ) diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index 344a3281..5868e423 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -1,12 +1,13 @@ +import types from abc import ABC, abstractmethod -from typing import Any, Literal, cast +from datetime import timedelta +from typing import Any, Literal, Protocol, Type, cast +import mcp from anthropic.types.beta import BetaToolParam, BetaToolUnionParam from anthropic.types.beta.beta_tool_param import InputSchema from asyncer import syncify -from fastmcp import Client -from fastmcp.client.client import CallToolResult -from fastmcp.client.transports import ClientTransportT +from fastmcp.client.client import CallToolResult, ProgressHandler from mcp import Tool as McpTool from PIL import Image from pydantic import BaseModel, Field @@ -128,6 +129,56 @@ def __init__(self, message: str): super().__init__(self.message) +class McpClientProtocol(Protocol): + """ + Protocol defining the interface for MCP client managers. + + This protocol captures the essential methods for managing MCP tools: + listing available tools and calling them with appropriate parameters. + """ + + async def list_tools(self) -> list[mcp.types.Tool]: + """ + Retrieve all available tools from all connected MCP servers. + + Returns: + list[mcp.types.Tool]: A list of all available tools across all servers. + """ + ... + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + timeout: timedelta | float | None = None, # noqa: ASYNC109 + progress_handler: ProgressHandler | None = None, + raise_on_error: bool = True, + ) -> CallToolResult: + """ + Call a tool by name with the provided arguments. + + Args: + name (str): The name of the tool to call. + arguments (dict[str, Any] | None, optional): Arguments to pass to the tool. + timeout (timedelta | float | None, optional): Timeout for the tool call. + progress_handler (ProgressHandler | None, optional): Handler for progress updates. + raise_on_error (bool, optional): Whether to raise an exception on error. + + Returns: + CallToolResult: The result of the tool call. + """ + ... + + async def __aenter__(self) -> Self: ... + + async def __aexit__( + self, + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: ... + + class ToolCollection: """A collection of tools. @@ -146,14 +197,14 @@ class ToolCollection: Args: tools (list[Tool] | None, optional): The tools to add to the collection. Defaults to `None`. - mcp_client (Client[ClientTransportT] | None, optional): The client to use for + mcp_client (McpClientProtocol | None, optional): The client to use for the tools. Defaults to `None`. """ def __init__( self, tools: list[Tool] | None = None, - mcp_client: Client[ClientTransportT] | None = None, + mcp_client: McpClientProtocol | None = None, include: set[str] | None = None, ) -> None: _tools = tools or [] @@ -227,9 +278,7 @@ def _run_tool( tool_use_id=tool_use_block_param.id, ) - async def _list_mcp_tools( - self, mcp_client: Client[ClientTransportT] - ) -> list[McpTool]: + async def _list_mcp_tools(self, mcp_client: McpClientProtocol) -> list[McpTool]: async with mcp_client: return await mcp_client.list_tools() @@ -269,7 +318,7 @@ def _run_regular_tool( async def _call_mcp_tool( self, - mcp_client: Client[ClientTransportT], + mcp_client: McpClientProtocol, tool_use_block_param: ToolUseBlockParam, ) -> ToolCallResult: async with mcp_client: From b246bafc4a023900bb0cb0d8c2f8345cecf4efff Mon Sep 17 00:00:00 2001 From: Adrian Stritzinger Date: Fri, 5 Sep 2025 17:10:39 +0200 Subject: [PATCH 2/2] chore: fix tests & linting --- src/askui/chat/api/mcp_clients/manager.py | 4 +- src/askui/models/shared/tools.py | 32 +-------- .../integration/chat/api/test_mcp_configs.py | 12 ---- tests/integration/chat/api/test_runs.py | 68 +++++++++---------- 4 files changed, 38 insertions(+), 78 deletions(-) diff --git a/src/askui/chat/api/mcp_clients/manager.py b/src/askui/chat/api/mcp_clients/manager.py index 9aa49cde..825fd3a4 100644 --- a/src/askui/chat/api/mcp_clients/manager.py +++ b/src/askui/chat/api/mcp_clients/manager.py @@ -54,7 +54,7 @@ async def disconnect(self, force: bool = False) -> None: async def list_tools( self, - ) -> list[mcp.types.Tool]: # TODO Proper cache and parallelization + ) -> list[mcp.types.Tool]: tools: list[mcp.types.Tool] = [] for mcp_server_name, mcp_client in self._mcp_clients.items(): if mcp_server_name not in self._tools: @@ -66,7 +66,7 @@ async def call_tool( self, name: str, arguments: dict[str, Any] | None = None, - timeout: timedelta | float | None = None, + timeout: timedelta | float | None = None, # noqa: ASYNC109 progress_handler: ProgressHandler | None = None, raise_on_error: bool = True, ) -> CallToolResult: diff --git a/src/askui/models/shared/tools.py b/src/askui/models/shared/tools.py index 5868e423..4e45f1e7 100644 --- a/src/askui/models/shared/tools.py +++ b/src/askui/models/shared/tools.py @@ -130,21 +130,7 @@ def __init__(self, message: str): class McpClientProtocol(Protocol): - """ - Protocol defining the interface for MCP client managers. - - This protocol captures the essential methods for managing MCP tools: - listing available tools and calling them with appropriate parameters. - """ - - async def list_tools(self) -> list[mcp.types.Tool]: - """ - Retrieve all available tools from all connected MCP servers. - - Returns: - list[mcp.types.Tool]: A list of all available tools across all servers. - """ - ... + async def list_tools(self) -> list[mcp.types.Tool]: ... async def call_tool( self, @@ -153,21 +139,7 @@ async def call_tool( timeout: timedelta | float | None = None, # noqa: ASYNC109 progress_handler: ProgressHandler | None = None, raise_on_error: bool = True, - ) -> CallToolResult: - """ - Call a tool by name with the provided arguments. - - Args: - name (str): The name of the tool to call. - arguments (dict[str, Any] | None, optional): Arguments to pass to the tool. - timeout (timedelta | float | None, optional): Timeout for the tool call. - progress_handler (ProgressHandler | None, optional): Handler for progress updates. - raise_on_error (bool, optional): Whether to raise an exception on error. - - Returns: - CallToolResult: The result of the tool call. - """ - ... + ) -> CallToolResult: ... async def __aenter__(self) -> Self: ... diff --git a/tests/integration/chat/api/test_mcp_configs.py b/tests/integration/chat/api/test_mcp_configs.py index f29c5a2c..aa3ddfab 100644 --- a/tests/integration/chat/api/test_mcp_configs.py +++ b/tests/integration/chat/api/test_mcp_configs.py @@ -13,18 +13,6 @@ class TestMcpConfigsAPI: """Test suite for the MCP configs API endpoints.""" - def test_list_mcp_configs_empty( - self, test_client: TestClient, test_headers: dict[str, str] - ) -> None: - """Test listing MCP configs when no configs exist.""" - response = test_client.get("/v1/mcp-configs", headers=test_headers) - - assert response.status_code == status.HTTP_200_OK - data = response.json() - assert data["object"] == "list" - assert data["data"] == [] - assert data["has_more"] is False - def test_list_mcp_configs_with_configs(self, test_headers: dict[str, str]) -> None: """Test listing MCP configs when configs exist.""" temp_dir = tempfile.mkdtemp() diff --git a/tests/integration/chat/api/test_runs.py b/tests/integration/chat/api/test_runs.py index 0338180b..21ffa310 100644 --- a/tests/integration/chat/api/test_runs.py +++ b/tests/integration/chat/api/test_runs.py @@ -14,11 +14,11 @@ from askui.chat.api.threads.service import ThreadService -def create_mock_mcp_config_service() -> Mock: +def create_mock_mcp_client_manager_manager() -> Mock: """Create a properly configured mock MCP config service.""" mock_service = Mock() # Configure mock to return proper data structure - mock_service.list_.return_value.data = [] + mock_service.get_mcp_client_manager.return_value = None return mock_service @@ -54,13 +54,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -126,13 +126,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -199,13 +199,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -255,13 +255,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -321,13 +321,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -381,13 +381,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -430,13 +430,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -491,13 +491,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -543,13 +543,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -598,13 +598,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -666,13 +666,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -716,13 +716,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -803,13 +803,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -889,13 +889,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_assistant_service = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() mock_message_service = Mock() mock_message_translator = Mock() return RunService( base_dir=workspace_path, assistant_service=mock_assistant_service, - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -980,13 +980,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_message_service = Mock() mock_message_translator = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() from askui.chat.api.assistants.service import AssistantService return RunService( base_dir=workspace_path, assistant_service=AssistantService(workspace_path), - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, ) @@ -1068,13 +1068,13 @@ def override_thread_service() -> ThreadService: def override_runs_service() -> RunService: mock_message_service = Mock() mock_message_translator = Mock() - mock_mcp_config_service = create_mock_mcp_config_service() + mock_mcp_client_manager_manager = create_mock_mcp_client_manager_manager() from askui.chat.api.assistants.service import AssistantService return RunService( base_dir=workspace_path, assistant_service=AssistantService(workspace_path), - mcp_config_service=mock_mcp_config_service, + mcp_client_manager_manager=mock_mcp_client_manager_manager, message_service=mock_message_service, message_translator=mock_message_translator, )