diff --git a/src/mcpadapt/core.py b/src/mcpadapt/core.py index 39c767c..5842a18 100644 --- a/src/mcpadapt/core.py +++ b/src/mcpadapt/core.py @@ -8,6 +8,7 @@ import threading from abc import ABC, abstractmethod from contextlib import AsyncExitStack, asynccontextmanager +from datetime import timedelta from functools import partial from typing import Any, AsyncGenerator, Callable, Coroutine @@ -71,6 +72,7 @@ def async_adapt( @asynccontextmanager async def mcptools( serverparams: StdioServerParameters | dict[str, Any], + client_session_timeout_seconds: float | timedelta | None = 5, ) -> AsyncGenerator[tuple[ClientSession, list[mcp.types.Tool]], None]: """Async context manager that yields tools from an MCP server. @@ -81,6 +83,7 @@ async def mcptools( serverparams: Parameters passed to either the stdio client or sse client. * if StdioServerParameters, run the MCP server using the stdio protocol. * if dict, assume the dict corresponds to parameters to an sse MCP server. + client_session_timeout_seconds: Timeout for MCP ClientSession calls Yields: A tuple of (MCP Client Session, list of MCP tools) available on the MCP server. @@ -98,8 +101,18 @@ async def mcptools( f"Invalid serverparams, expected StdioServerParameters or dict found `{type(serverparams)}`" ) + timeout = None + if isinstance(client_session_timeout_seconds, float): + timeout = timedelta(seconds=client_session_timeout_seconds) + elif isinstance(client_session_timeout_seconds, timedelta): + timeout = client_session_timeout_seconds + async with client as (read, write): - async with ClientSession(read, write) as session: + async with ClientSession( + read, + write, + timeout, + ) as session: # Initialize the connection and get the tools from the mcp server await session.initialize() tools = await session.list_tools()