From fb16e44e5c56a421b6ae339a475d3408bd5bcc47 Mon Sep 17 00:00:00 2001 From: James Duncan Date: Thu, 19 Mar 2026 18:55:54 +0000 Subject: [PATCH] feat(mcp): handle transport crashes gracefully in MCP tool calls Race tool call coroutines against the background session task so that transport crashes (e.g. non-2xx HTTP responses) surface immediately instead of hanging until sse_read_timeout expires. - Add run_guarded() method to SessionContext - Add is_task_alive property for clean cross-class API - Let ConnectionError propagate to retry_on_errors for session recovery - Enhance _is_session_disconnected to detect dead background tasks - Add 6 new test cases --- .../adk/tools/mcp_tool/mcp_session_manager.py | 78 +++++-- src/google/adk/tools/mcp_tool/mcp_tool.py | 15 +- .../adk/tools/mcp_tool/session_context.py | 69 +++++++ .../mcp_tool/test_mcp_session_manager.py | 112 +++++++++- .../unittests/tools/mcp_tool/test_mcp_tool.py | 1 + .../tools/mcp_tool/test_session_context.py | 195 ++++++++++++++++++ 6 files changed, 447 insertions(+), 23 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index e0cd1ebc89..b5b29d1769 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -232,9 +232,16 @@ def __init__( self._connection_params = connection_params self._errlog = errlog - # Session pool: maps session keys to (session, exit_stack, loop) tuples + # Session pool: maps session keys to + # (session, exit_stack, loop, session_context) tuples self._sessions: Dict[ - str, tuple[ClientSession, AsyncExitStack, asyncio.AbstractEventLoop] + str, + tuple[ + ClientSession, + AsyncExitStack, + asyncio.AbstractEventLoop, + SessionContext, + ], ] = {} # Map of event loops to their respective locks to prevent race conditions @@ -307,16 +314,49 @@ def _merge_headers( return base_headers - def _is_session_disconnected(self, session: ClientSession) -> bool: + def _is_session_disconnected( + self, + session: ClientSession, + session_context: Optional[SessionContext] = None, + ) -> bool: """Checks if a session is disconnected or closed. Args: session: The ClientSession to check. + session_context: Optional SessionContext to check if the background + task has died (e.g. due to a transport crash). Returns: True if the session is disconnected, False otherwise. """ - return session._read_stream._closed or session._write_stream._closed + if session._read_stream._closed or session._write_stream._closed: + return True + if session_context is not None and not session_context.is_task_alive: + return True + return False + + def get_session_context( + self, headers: Optional[Dict[str, str]] = None + ) -> Optional[SessionContext]: + """Returns the SessionContext for the session matching the given headers. + + Note: This method reads from the session pool without acquiring + ``_session_lock``. This is safe because it is called immediately after + ``create_session()`` (which populates the entry under the lock) within + the same task, and dict reads are atomic in CPython. + + Args: + headers: Optional headers used to identify the session. + + Returns: + The SessionContext if a matching session exists, None otherwise. + """ + merged_headers = self._merge_headers(headers) + session_key = self._generate_session_key(merged_headers) + entry = self._sessions.get(session_key) + if entry is not None: + return entry[3] + return None async def _cleanup_session( self, @@ -445,12 +485,14 @@ async def create_session( async with self._session_lock: # Check if we have an existing session if session_key in self._sessions: - session, exit_stack, stored_loop = self._sessions[session_key] + session, exit_stack, stored_loop, session_ctx = self._sessions[ + session_key + ] # Check if the existing session is still connected and bound to the current loop current_loop = asyncio.get_running_loop() if stored_loop is current_loop and not self._is_session_disconnected( - session + session, session_ctx ): # Session is still good, return it return session @@ -479,25 +521,25 @@ async def create_session( client = self._create_client(merged_headers) is_stdio = isinstance(self._connection_params, StdioConnectionParams) + session_context = SessionContext( + client=client, + timeout=timeout_in_seconds, + sse_read_timeout=sse_read_timeout_in_seconds, + is_stdio=is_stdio, + sampling_callback=self._sampling_callback, + sampling_capabilities=self._sampling_capabilities, + ) session = await asyncio.wait_for( - exit_stack.enter_async_context( - SessionContext( - client=client, - timeout=timeout_in_seconds, - sse_read_timeout=sse_read_timeout_in_seconds, - is_stdio=is_stdio, - sampling_callback=self._sampling_callback, - sampling_capabilities=self._sampling_capabilities, - ) - ), + exit_stack.enter_async_context(session_context), timeout=timeout_in_seconds, ) - # Store session, exit stack, and loop in the pool + # Store session, exit stack, loop, and context in the pool self._sessions[session_key] = ( session, exit_stack, asyncio.get_running_loop(), + session_context, ) logger.debug('Created new session: %s', session_key) return session @@ -541,7 +583,7 @@ async def close(self): """Closes all sessions and cleans up resources.""" async with self._session_lock: for session_key in list(self._sessions.keys()): - _, exit_stack, stored_loop = self._sessions[session_key] + _, exit_stack, stored_loop, _ = self._sessions[session_key] await self._cleanup_session(session_key, exit_stack, stored_loop) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 9a2fd5fcfd..eb7aa00618 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -373,12 +373,25 @@ async def _run_async_impl( # Resolve progress callback (may be a factory that needs runtime context) resolved_callback = self._resolve_progress_callback(tool_context) - response = await session.call_tool( + call_coro = session.call_tool( self._mcp_tool.name, arguments=args, progress_callback=resolved_callback, meta=meta_trace_context, ) + + # Race the tool call against the background session task so that + # transport crashes (e.g. non-2xx HTTP responses) surface immediately + # instead of hanging until sse_read_timeout expires. + # ConnectionError is intentionally NOT caught here so that it + # propagates to retry_on_errors, which will create a fresh session. + session_context = self._mcp_session_manager.get_session_context( + headers=final_headers + ) + if session_context: + response = await session_context.run_guarded(call_coro) + else: + response = await call_coro result = response.model_dump(exclude_none=True, mode="json") # Push UI widget to the event actions if the tool supports it. diff --git a/src/google/adk/tools/mcp_tool/session_context.py b/src/google/adk/tools/mcp_tool/session_context.py index 23e968fe52..1deaebbacd 100644 --- a/src/google/adk/tools/mcp_tool/session_context.py +++ b/src/google/adk/tools/mcp_tool/session_context.py @@ -18,8 +18,13 @@ from contextlib import AsyncExitStack from datetime import timedelta import logging +from typing import Any from typing import AsyncContextManager +from typing import Coroutine from typing import Optional +from typing import TypeVar + +T = TypeVar('T') from mcp import ClientSession from mcp import SamplingCapability @@ -89,6 +94,15 @@ def session(self) -> Optional[ClientSession]: """Get the managed ClientSession, if available.""" return self._session + @property + def is_task_alive(self) -> bool: + """Whether the background session task is currently running. + + Returns True only when the task has been started and has not yet completed. + Returns False if the task has not been started or has finished. + """ + return self._task is not None and not self._task.done() + async def start(self) -> ClientSession: """Start the runner and wait for the session to be ready. @@ -123,8 +137,63 @@ async def start(self) -> ClientSession: f'Failed to create MCP session: {self._task.exception()}' ) from self._task.exception() + if self._session is None: + raise ConnectionError('Failed to create MCP session: unknown error') + return self._session + async def run_guarded(self, coro: Coroutine[Any, Any, T]) -> T: + """Run a coroutine while monitoring the background session task. + + Races the given coroutine against the background task. If the task + dies first (e.g. transport crash from a non-2xx HTTP response), the + coroutine is cancelled and the original error is raised immediately + instead of hanging until a read timeout expires. + + Args: + coro: The coroutine to run (e.g. session.call_tool(...)). + + Returns: + The result of the coroutine. + + Raises: + ConnectionError: If the background task has already died or dies + during execution, wrapping the original exception. + """ + if self._task is None: + coro.close() + raise ConnectionError('MCP session task has not been started') + + if self._task.done(): + exc = self._task.exception() if not self._task.cancelled() else None + # Close the coroutine to avoid "was never awaited" warnings + coro.close() + raise ConnectionError( + f'MCP session task has already terminated: {exc}' + ) from exc + + coro_task = asyncio.ensure_future(coro) + + done, _ = await asyncio.wait( + [coro_task, self._task], + return_when=asyncio.FIRST_COMPLETED, + ) + + if coro_task in done: + # If the coroutine itself raised, the exception propagates as-is + # (not wrapped in ConnectionError) — this is intentional. + return coro_task.result() + + # Background task finished first — transport crash + coro_task.cancel() + try: + await coro_task + except (asyncio.CancelledError, Exception): + pass + + exc = self._task.exception() if not self._task.cancelled() else None + raise ConnectionError(f'MCP session connection lost: {exc}') from exc + async def close(self): """Signal the context task to close and wait for cleanup.""" # Set the close event to signal the task to close. diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 327df114a8..ce76200f68 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -301,9 +301,10 @@ async def test_create_session_stdio_new(self): assert len(manager._sessions) == 1 assert "stdio_session" in manager._sessions session_data = manager._sessions["stdio_session"] - assert len(session_data) == 3 + assert len(session_data) == 4 assert session_data[0] == mock_session assert session_data[2] == asyncio.get_running_loop() + assert session_data[3] is not None # SessionContext stored # Verify SessionContext was created mock_session_context_class.assert_called_once() @@ -318,10 +319,13 @@ async def test_create_session_reuse_existing(self): # Create mock existing session existing_session = MockClientSession() existing_exit_stack = MockAsyncExitStack() + mock_session_ctx = Mock() + mock_session_ctx.is_task_alive = True manager._sessions["stdio_session"] = ( existing_session, existing_exit_stack, asyncio.get_running_loop(), + mock_session_ctx, ) # Session is connected @@ -391,11 +395,13 @@ async def test_close_success(self): session1, exit_stack1, asyncio.get_running_loop(), + Mock(), ) manager._sessions["session2"] = ( session2, exit_stack2, asyncio.get_running_loop(), + Mock(), ) await manager.close() @@ -423,11 +429,13 @@ async def test_close_with_errors(self, mock_logger): session1, exit_stack1, asyncio.get_running_loop(), + Mock(), ) manager._sessions["session2"] = ( session2, exit_stack2, asyncio.get_running_loop(), + Mock(), ) # Should not raise exception @@ -553,10 +561,13 @@ async def test_create_session_cleans_up_without_aclose_if_loop_is_different( # Use a dummy object as a different loop different_loop = Mock(spec=asyncio.AbstractEventLoop) + mock_session_ctx = Mock() + mock_session_ctx.is_task_alive = True manager._sessions["stdio_session"] = ( mock_session, mock_exit_stack, different_loop, + mock_session_ctx, ) # 2. Mock creation of a new session @@ -594,11 +605,21 @@ async def test_close_skips_aclose_for_different_loop_sessions(self): session1 = MockClientSession() exit_stack1 = MockAsyncExitStack() - manager._sessions["session1"] = (session1, exit_stack1, current_loop) + manager._sessions["session1"] = ( + session1, + exit_stack1, + current_loop, + Mock(), + ) session2 = MockClientSession() exit_stack2 = MockAsyncExitStack() - manager._sessions["session2"] = (session2, exit_stack2, different_loop) + manager._sessions["session2"] = ( + session2, + exit_stack2, + different_loop, + Mock(), + ) await manager.close() @@ -619,7 +640,12 @@ async def test_pickle_mcp_session_manager(self): assert isinstance(lock, asyncio.Lock) # Add a mock session to verify it's cleared on pickling - manager._sessions["test"] = (Mock(), Mock(), asyncio.get_running_loop()) + manager._sessions["test"] = ( + Mock(), + Mock(), + asyncio.get_running_loop(), + Mock(), + ) # Pickle and unpickle pickled = pickle.dumps(manager) @@ -726,3 +752,81 @@ async def mock_function(self): await mock_function(mock_self) assert call_count == 1 + + +class TestGetSessionContext: + """Tests for MCPSessionManager.get_session_context().""" + + def setup_method(self): + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_stdio_connection_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=5.0 + ) + + def test_get_session_context_returns_context(self): + """Test that get_session_context returns the stored SessionContext.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + mock_ctx = Mock() + manager._sessions["stdio_session"] = ( + Mock(), + Mock(), + Mock(), + mock_ctx, + ) + + result = manager.get_session_context() + assert result is mock_ctx + + def test_get_session_context_returns_none_when_no_session(self): + """Test that get_session_context returns None when no session exists.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + result = manager.get_session_context() + assert result is None + + +class TestIsSessionDisconnectedWithContext: + """Tests for enhanced _is_session_disconnected with SessionContext.""" + + def setup_method(self): + self.mock_stdio_params = StdioServerParameters( + command="test_command", args=[] + ) + self.mock_stdio_connection_params = StdioConnectionParams( + server_params=self.mock_stdio_params, timeout=5.0 + ) + + def test_detects_dead_task(self): + """Test that a done background task is detected as disconnected.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + session = MockClientSession() + + mock_ctx = Mock() + mock_ctx.is_task_alive = False + + assert manager._is_session_disconnected(session, mock_ctx) + + def test_alive_task_not_disconnected(self): + """Test that an alive background task is not detected as disconnected.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + session = MockClientSession() + + mock_ctx = Mock() + mock_ctx.is_task_alive = True + + assert not manager._is_session_disconnected(session, mock_ctx) + + def test_no_context_falls_back_to_stream_check(self): + """Test that without context, only stream state is checked.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + session = MockClientSession() + + assert not manager._is_session_disconnected(session) + assert not manager._is_session_disconnected(session, None) + + session._read_stream._closed = True + assert manager._is_session_disconnected(session) + assert manager._is_session_disconnected(session, None) diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index d6e39b94f3..ee2f695842 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -69,6 +69,7 @@ def setup_method(self): """Set up test fixtures.""" self.mock_mcp_tool = MockMCPTool() self.mock_session_manager = Mock(spec=MCPSessionManager) + self.mock_session_manager.get_session_context.return_value = None self.mock_session = AsyncMock() self.mock_session_manager.create_session = AsyncMock( return_value=self.mock_session diff --git a/tests/unittests/tools/mcp_tool/test_session_context.py b/tests/unittests/tools/mcp_tool/test_session_context.py index 161cd1aba3..20241f2edc 100644 --- a/tests/unittests/tools/mcp_tool/test_session_context.py +++ b/tests/unittests/tools/mcp_tool/test_session_context.py @@ -548,3 +548,198 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): # Should not raise exception assert session_context._close_event.is_set() + + +class TestRunGuarded: + """Tests for SessionContext.run_guarded().""" + + @pytest.mark.asyncio + async def test_run_guarded_normal_completion(self): + """Test run_guarded returns result when coroutine completes normally.""" + mock_client = MockClient() + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + await session_context.start() + + async def my_coro(): + return 'hello' + + result = await session_context.run_guarded(my_coro()) + assert result == 'hello' + + await session_context.close() + + @pytest.mark.asyncio + async def test_run_guarded_background_task_crash(self): + """Test run_guarded raises ConnectionError when background task dies.""" + mock_client = MockClient() + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + await session_context.start() + + crash_error = RuntimeError('403 Forbidden') + + async def hanging_coro(): + await asyncio.sleep(100) + + # Kill the background task and replace with one that's already failed + session_context._task.cancel() + try: + await session_context._task + except (asyncio.CancelledError, Exception): + pass + + async def failing_task(): + raise crash_error + + session_context._task = asyncio.create_task(failing_task()) + await asyncio.sleep(0.01) + + coro = hanging_coro() + with pytest.raises(ConnectionError) as exc_info: + await session_context.run_guarded(coro) + + assert '403 Forbidden' in str(exc_info.value) + assert exc_info.value.__cause__ is crash_error + + @pytest.mark.asyncio + async def test_run_guarded_task_already_dead(self): + """Test run_guarded raises immediately when task is already done.""" + mock_client = MockClient() + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + await session_context.start() + + original_error = ValueError('transport died') + + async def failing_task(): + raise original_error + + session_context._task.cancel() + try: + await session_context._task + except (asyncio.CancelledError, Exception): + pass + + session_context._task = asyncio.create_task(failing_task()) + await asyncio.sleep(0.01) + + async def my_coro(): + return 'should not reach' + + coro = my_coro() + with pytest.raises(ConnectionError) as exc_info: + await session_context.run_guarded(coro) + + assert 'already terminated' in str(exc_info.value) + assert exc_info.value.__cause__ is original_error + + @pytest.mark.asyncio + async def test_run_guarded_cancels_coroutine_on_crash(self): + """Test that run_guarded cancels the coroutine when the task crashes.""" + mock_client = MockClient() + mock_session = MockClientSession() + + coro_was_cancelled = False + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + await session_context.start() + + async def slow_coro(): + nonlocal coro_was_cancelled + try: + await asyncio.sleep(100) + except asyncio.CancelledError: + coro_was_cancelled = True + raise + + # Replace the background task with one that will fail soon + session_context._close_event.set() + await asyncio.sleep(0.05) + + crash_error = RuntimeError('connection lost') + + async def crashing_task(): + await asyncio.sleep(0.05) + raise crash_error + + session_context._task = asyncio.create_task(crashing_task()) + + with pytest.raises(ConnectionError): + await session_context.run_guarded(slow_coro()) + + assert coro_was_cancelled + + @pytest.mark.asyncio + async def test_run_guarded_coroutine_raises(self): + """Test run_guarded propagates coroutine exceptions unwrapped.""" + mock_client = MockClient() + mock_session = MockClientSession() + + with patch( + 'google.adk.tools.mcp_tool.session_context.ClientSession' + ) as mock_session_class: + mock_session_class.return_value = mock_session + + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + await session_context.start() + + original_error = ValueError('invalid arguments') + + async def failing_coro(): + raise original_error + + with pytest.raises(ValueError) as exc_info: + await session_context.run_guarded(failing_coro()) + + assert exc_info.value is original_error + + await session_context.close() + + @pytest.mark.asyncio + async def test_run_guarded_no_task(self): + """Test run_guarded raises when task has not been started.""" + mock_client = MockClient() + session_context = SessionContext( + mock_client, timeout=5.0, sse_read_timeout=None + ) + + async def my_coro(): + return 'hello' + + coro = my_coro() + with pytest.raises(ConnectionError, match='not been started'): + await session_context.run_guarded(coro)