Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 60 additions & 18 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,16 @@ def __init__(
self._connection_params = connection_params
self._errlog = errlog

# Session pool: maps session keys to (session, exit_stack, loop) tuples
# Session pool: maps session keys to
# (session, exit_stack, loop, session_context) tuples
self._sessions: Dict[
str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop]
str,
tuple[
ClientSession,
AsyncExitStack,
asyncio.AbstractEventLoop,
SessionContext,
],
] = {}

# Map of event loops to their respective locks to prevent race conditions
Expand Down Expand Up @@ -307,16 +314,49 @@ def _merge_headers(

return base_headers

def _is_session_disconnected(self, session: ClientSession) -> bool:
def _is_session_disconnected(
self,
session: ClientSession,
session_context: Optional[SessionContext] = None,
) -> bool:
"""Checks if a session is disconnected or closed.

Args:
session: The ClientSession to check.
session_context: Optional SessionContext to check if the background
task has died (e.g. due to a transport crash).

Returns:
True if the session is disconnected, False otherwise.
"""
return session._read_stream._closed or session._write_stream._closed
if session._read_stream._closed or session._write_stream._closed:
return True
if session_context is not None and not session_context.is_task_alive:
return True
return False

def get_session_context(
self, headers: Optional[Dict[str, str]] = None
) -> Optional[SessionContext]:
"""Returns the SessionContext for the session matching the given headers.

Note: This method reads from the session pool without acquiring
``_session_lock``. This is safe because it is called immediately after
``create_session()`` (which populates the entry under the lock) within
the same task, and dict reads are atomic in CPython.

Args:
headers: Optional headers used to identify the session.

Returns:
The SessionContext if a matching session exists, None otherwise.
"""
merged_headers = self._merge_headers(headers)
session_key = self._generate_session_key(merged_headers)
entry = self._sessions.get(session_key)
if entry is not None:
return entry[3]
return None

async def _cleanup_session(
self,
Expand Down Expand Up @@ -445,12 +485,14 @@ async def create_session(
async with self._session_lock:
# Check if we have an existing session
if session_key in self._sessions:
session, exit_stack, stored_loop = self._sessions[session_key]
session, exit_stack, stored_loop, session_ctx = self._sessions[
session_key
]

# Check if the existing session is still connected and bound to the current loop
current_loop = asyncio.get_running_loop()
if stored_loop is current_loop and not self._is_session_disconnected(
session
session, session_ctx
):
# Session is still good, return it
return session
Expand Down Expand Up @@ -479,25 +521,25 @@ async def create_session(
client = self._create_client(merged_headers)
is_stdio = isinstance(self._connection_params, StdioConnectionParams)

session_context = SessionContext(
client=client,
timeout=timeout_in_seconds,
sse_read_timeout=sse_read_timeout_in_seconds,
is_stdio=is_stdio,
sampling_callback=self._sampling_callback,
sampling_capabilities=self._sampling_capabilities,
)
session = await asyncio.wait_for(
exit_stack.enter_async_context(
SessionContext(
client=client,
timeout=timeout_in_seconds,
sse_read_timeout=sse_read_timeout_in_seconds,
is_stdio=is_stdio,
sampling_callback=self._sampling_callback,
sampling_capabilities=self._sampling_capabilities,
)
),
exit_stack.enter_async_context(session_context),
timeout=timeout_in_seconds,
)

# Store session, exit stack, and loop in the pool
# Store session, exit stack, loop, and context in the pool
self._sessions[session_key] = (
session,
exit_stack,
asyncio.get_running_loop(),
session_context,
)
logger.debug('Created new session: %s', session_key)
return session
Expand Down Expand Up @@ -541,7 +583,7 @@ async def close(self):
"""Closes all sessions and cleans up resources."""
async with self._session_lock:
for session_key in list(self._sessions.keys()):
_, exit_stack, stored_loop = self._sessions[session_key]
_, exit_stack, stored_loop, _ = self._sessions[session_key]
await self._cleanup_session(session_key, exit_stack, stored_loop)


Expand Down
15 changes: 14 additions & 1 deletion src/google/adk/tools/mcp_tool/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,12 +373,25 @@ async def _run_async_impl(
# Resolve progress callback (may be a factory that needs runtime context)
resolved_callback = self._resolve_progress_callback(tool_context)

response = await session.call_tool(
call_coro = session.call_tool(
self._mcp_tool.name,
arguments=args,
progress_callback=resolved_callback,
meta=meta_trace_context,
)

# Race the tool call against the background session task so that
# transport crashes (e.g. non-2xx HTTP responses) surface immediately
# instead of hanging until sse_read_timeout expires.
# ConnectionError is intentionally NOT caught here so that it
# propagates to retry_on_errors, which will create a fresh session.
session_context = self._mcp_session_manager.get_session_context(
headers=final_headers
)
if session_context:
response = await session_context.run_guarded(call_coro)
else:
response = await call_coro
result = response.model_dump(exclude_none=True, mode="json")

# Push UI widget to the event actions if the tool supports it.
Expand Down
69 changes: 69 additions & 0 deletions src/google/adk/tools/mcp_tool/session_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
from contextlib import AsyncExitStack
from datetime import timedelta
import logging
from typing import Any
from typing import AsyncContextManager
from typing import Coroutine
from typing import Optional
from typing import TypeVar

T = TypeVar('T')

from mcp import ClientSession
from mcp import SamplingCapability
Expand Down Expand Up @@ -89,6 +94,15 @@ def session(self) -> Optional[ClientSession]:
"""Get the managed ClientSession, if available."""
return self._session

@property
def is_task_alive(self) -> bool:
"""Whether the background session task is currently running.

Returns True only when the task has been started and has not yet completed.
Returns False if the task has not been started or has finished.
"""
return self._task is not None and not self._task.done()

async def start(self) -> ClientSession:
"""Start the runner and wait for the session to be ready.

Expand Down Expand Up @@ -123,8 +137,63 @@ async def start(self) -> ClientSession:
f'Failed to create MCP session: {self._task.exception()}'
) from self._task.exception()

if self._session is None:
raise ConnectionError('Failed to create MCP session: unknown error')

return self._session

async def run_guarded(self, coro: Coroutine[Any, Any, T]) -> T:
"""Run a coroutine while monitoring the background session task.

Races the given coroutine against the background task. If the task
dies first (e.g. transport crash from a non-2xx HTTP response), the
coroutine is cancelled and the original error is raised immediately
instead of hanging until a read timeout expires.

Args:
coro: The coroutine to run (e.g. session.call_tool(...)).

Returns:
The result of the coroutine.

Raises:
ConnectionError: If the background task has already died or dies
during execution, wrapping the original exception.
"""
if self._task is None:
coro.close()
raise ConnectionError('MCP session task has not been started')

if self._task.done():
exc = self._task.exception() if not self._task.cancelled() else None
# Close the coroutine to avoid "was never awaited" warnings
coro.close()
raise ConnectionError(
f'MCP session task has already terminated: {exc}'
) from exc

coro_task = asyncio.ensure_future(coro)

done, _ = await asyncio.wait(
[coro_task, self._task],
return_when=asyncio.FIRST_COMPLETED,
)

if coro_task in done:
# If the coroutine itself raised, the exception propagates as-is
# (not wrapped in ConnectionError) — this is intentional.
return coro_task.result()

# Background task finished first — transport crash
coro_task.cancel()
try:
await coro_task
except (asyncio.CancelledError, Exception):
pass

exc = self._task.exception() if not self._task.cancelled() else None
raise ConnectionError(f'MCP session connection lost: {exc}') from exc

async def close(self):
"""Signal the context task to close and wait for cleanup."""
# Set the close event to signal the task to close.
Expand Down
Loading
Loading