diff --git a/agents/python/src/lib/mcp_integration.py b/agents/python/src/lib/mcp_integration.py index 3c8524e..6faf85d 100644 --- a/agents/python/src/lib/mcp_integration.py +++ b/agents/python/src/lib/mcp_integration.py @@ -128,8 +128,11 @@ async def reconnect(self): await self.initialize() logger.info(f"Reconnected successfully (session: {self.session_id[:8]}...)") - async def _send(self, method: str, params: dict = None) -> dict: - """Send JSON-RPC message to server and wait for response via SSE.""" + async def _send(self, method: str, params: dict = None, _retry: bool = True) -> dict: + """Send JSON-RPC message to server and wait for response via SSE. + + Automatically reconnects and retries once on session expiration. + """ if not self.session_id: raise RuntimeError("Not connected. Call connect() first.") @@ -153,25 +156,38 @@ async def _send(self, method: str, params: dict = None) -> dict: try: error_data = resp.json() error_msg = error_data.get("error", error_text) - - if resp.status_code == 410: - logger.warning("Session expired (410), needs reconnection") - raise SessionExpiredException( - "Session expired or not found. Reconnection required." - ) except json.JSONDecodeError: error_msg = error_text - if resp.status_code == 410: + + # Handle session expiration: 410 (Gone) or 404 with session-related message + # The MCP SDK returns 404 for expired sessions, not 410 + is_session_error = resp.status_code == 410 or ( + resp.status_code == 404 and "session" in error_msg.lower() + ) + + if is_session_error: + self._responses.pop(msg_id, None) + if _retry: + logger.warning(f"Session expired ({resp.status_code}), reconnecting...") + await self.reconnect() + return await self._send(method, params, _retry=False) + else: raise SessionExpiredException( - "Session expired or not found. Reconnection required." + "Session expired or not found. Reconnection failed." ) raise RuntimeError( f"HTTP {resp.status_code} from server: {error_msg}" ) except httpx.HTTPStatusError as e: - if e.response.status_code == 410: - raise SessionExpiredException("Session expired. Reconnection required.") + self._responses.pop(msg_id, None) + is_session_error = e.response.status_code in (404, 410) + if is_session_error and _retry: + logger.warning(f"Session expired ({e.response.status_code}), reconnecting...") + await self.reconnect() + return await self._send(method, params, _retry=False) + elif is_session_error: + raise SessionExpiredException("Session expired. Reconnection failed.") raise RuntimeError(f"HTTP error {e.response.status_code}: {e.response.text}") try: @@ -212,60 +228,51 @@ async def _get_mcp_client() -> SimpleMCPClient: return _mcp_client -async def _call_mcp_tool(tool_name: str, arguments: Dict[str, Any], retry_count: int = 1) -> Any: +async def _call_mcp_tool(tool_name: str, arguments: Dict[str, Any]) -> Any: """ - Call MCP server tool with session management and automatic retry. + Call MCP server tool with session management. + + Session reconnection is handled automatically by the client's _send method. Args: tool_name: Name of the MCP tool to call (e.g., "knowledge_search") arguments: Tool arguments as dict - retry_count: Number of retries on session expiry (default: 1) Returns: Tool result from MCP server """ logger.info(f"Calling MCP tool: {tool_name}") - for attempt in range(retry_count + 1): - try: - client = await _get_mcp_client() - result = await client.call_tool(tool_name, arguments) - - logger.info(f"MCP tool call succeeded: {tool_name}") - - # MCP protocol returns results in result.content array - if "result" in result and "content" in result["result"]: - content = result["result"]["content"] - if isinstance(content, list) and len(content) > 0: - first_content = content[0] - if isinstance(first_content, dict) and "text" in first_content: - text = first_content["text"] - if text and text.strip(): - try: - return json.loads(text) - except json.JSONDecodeError: - return text - return first_content - return content - - return result.get("result", {}) - - except SessionExpiredException: - if attempt < retry_count: - logger.warning(f"Session expired, retrying ({attempt + 1}/{retry_count})...") - global _mcp_client - if _mcp_client: - await _mcp_client.reconnect() - continue - else: - logger.error(f"Session expired after {retry_count} retries") - raise - - except Exception as e: - logger.error(f"Failed to call MCP tool {tool_name}: {e}") - return None - - return None + try: + client = await _get_mcp_client() + result = await client.call_tool(tool_name, arguments) + + logger.info(f"MCP tool call succeeded: {tool_name}") + + # MCP protocol returns results in result.content array + if "result" in result and "content" in result["result"]: + content = result["result"]["content"] + if isinstance(content, list) and len(content) > 0: + first_content = content[0] + if isinstance(first_content, dict) and "text" in first_content: + text = first_content["text"] + if text and text.strip(): + try: + return json.loads(text) + except json.JSONDecodeError: + return text + return first_content + return content + + return result.get("result", {}) + + except SessionExpiredException: + logger.error(f"Session expired and reconnection failed for tool: {tool_name}") + raise + + except Exception as e: + logger.error(f"Failed to call MCP tool {tool_name}: {e}") + return None async def search_knowledge_base(