From b89bf033b84160cf1419718afce1d8b878ae1cf6 Mon Sep 17 00:00:00 2001 From: Mark Widman Date: Tue, 24 Feb 2026 15:41:25 -0500 Subject: [PATCH 1/3] fix(mcp): harden gateway connection stability and code quality MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Migrate _check_pieces_os_status from threading.Lock to asyncio.Lock with asyncio.to_thread() to avoid blocking the event loop - Replace private session._received_notification monkey-patch with public ClientSession(message_handler=...) API - Capture and log upstream session ID from streamable HTTP transport - Modernize type hints to Python 3.11+ (X | None, list[], tuple[]) - Add type annotations to all nested functions and closures - Clean up unused imports across source and test files - Reorder imports to follow PEP 8 (stdlib → third-party → local) - Surface all errors with actionable CLI commands (pieces open/restart) - Add comprehensive docstrings documenting two-phase health check, connection lifecycle, and message handler behavior - Add developer documentation (documentation/mcp.md) - Add 29 new unit tests covering bug fixes and enhancements (68 total) --- documentation/mcp.md | 144 ++++ src/pieces/mcp/__init__.py | 23 + src/pieces/mcp/gateway.py | 673 +++++++++++------- src/pieces/mcp/handler.py | 49 +- src/pieces/mcp/integration.py | 29 +- src/pieces/mcp/tools_cache.py | 11 +- src/pieces/mcp/utils.py | 92 ++- tests/mcps/mcp_gateway/test_bug_fixes.py | 559 +++++++++++++++ tests/mcps/mcp_gateway/test_integration.py | 2 +- .../mcp_gateway/test_validation_advanced.py | 54 +- .../mcps/mcp_gateway/test_validation_core.py | 60 +- tests/mcps/mcp_gateway/utils.py | 8 +- 12 files changed, 1383 insertions(+), 321 deletions(-) create mode 100644 documentation/mcp.md create mode 100644 tests/mcps/mcp_gateway/test_bug_fixes.py diff --git a/documentation/mcp.md b/documentation/mcp.md new file mode 100644 index 00000000..ef084459 --- /dev/null +++ b/documentation/mcp.md @@ -0,0 +1,144 @@ +# MCP Gateway — Developer Documentation + +## Overview + +The MCP (Model Context Protocol) gateway enables IDE clients like Claude Desktop, +Cursor, and VS Code to communicate with PiecesOS tools through a standardized protocol. + +## Architecture + +``` + stdio streamable HTTP + IDE Client <--------> CLI Gateway <--------------------> PiecesOS + (Claude, (gateway.py) (MCP Server) + Cursor, etc.) +``` + +### Components + +| Component | File | Role | +| -------------------- | ---------------- | -------------------------------------- | +| Gateway Server | `gateway.py` `MCPGateway` | stdio server, routes requests | +| Upstream Connection | `gateway.py` `PosMcpConnection` | manages PiecesOS connection | +| URL Resolution | `utils.py` | schema version selection, caching | +| CLI Handlers | `handler.py` | `pieces mcp` subcommands | +| IDE Integrations | `integration.py` | config file management | +| Fallback Tools | `tools_cache.py` | offline tool definitions | + +## MCP Schema Versions + +| Version | Transport | Endpoint Pattern | Status | +| ------------ | --------------- | ------------------------------------------------- | --------- | +| `2024-11-05` | SSE | `/model_context_protocol/2024-11-05/sse` | Legacy | +| `2025-03-26` | Streamable HTTP | `/model_context_protocol/2025-03-26/mcp` | Preferred | + +The gateway prefers `2025-03-26` (streamable HTTP) for upstream connections because: + +- Request-response model is more robust than long-lived SSE connections +- No connection degradation over 30–45 minute sessions +- Better error recovery and reconnection behavior + +IDE integration configs (written to config files) still use `2024-11-05` SSE URLs +because IDEs connect directly to PiecesOS, not through the CLI gateway. + +## Connection Lifecycle + +1. **Startup**: `main()` initializes WebSockets (Health, Auth, LTM Vision), + resolves the upstream URL, creates the gateway. Blocking SDK calls during + health checks are offloaded to threads via `asyncio.to_thread()`. +2. **Connect**: `PosMcpConnection.connect()` creates a background task running + `_connection_handler`, which enters the transport context manager and + establishes a `ClientSession` +3. **Tool Discovery**: `update_tools()` fetches available tools, detects changes + via SHA-256 hashing, and notifies the IDE client +4. **Request Proxying**: IDE sends `tools/call` via stdio → gateway forwards to + PiecesOS via streamable HTTP → result returned to IDE +5. **Reconnection**: On connection failure, `connect()` cleans up stale state and + creates a new connection task. URL cache is invalidated to handle PiecesOS restarts. +6. **Shutdown**: Signal handlers set `shutdown_event`, gateway cancels, WebSockets + close, upstream connection cleans up + +## URL Caching + +Schema version URLs are cached in `utils._latest_urls` to avoid repeated API calls. +The cache is invalidated when: + +- The upstream URL is `None` (PiecesOS wasn't running at startup) +- PiecesOS transitions from down to up (health check succeeds after failure) +- A connection is cleaned up (URL may be stale after disconnect) + +## Validation Pipeline + +Before every tool call, `_validate_system_status()` runs 4 checks: + +1. **PiecesOS health** (via WebSocket) +2. **Version compatibility** (CLI vs PiecesOS) +3. **User authentication** +4. **LTM status** (for LTM-specific tools only) + +Each check returns an actionable error message if it fails. + +## Error Surfacing Philosophy + +This is a CLI — errors must be surfaced to the user with actionable remediation +steps, not silently swallowed. The decision tree for every caught exception: + +1. **Can we auto-retry?** → Retry with backoff +2. **Retry succeeded?** → Continue normally +3. **Retry failed or not retryable?** → Surface actionable error to user + +Connection errors are stored in `_last_connection_error` so that the next +`call_tool` invocation includes them in the user-facing response. Tool call +errors include the specific exception type and message alongside remediation +commands like `pieces restart` or `pieces open`. + +## Troubleshooting + +| Symptom | Likely Cause | Fix | +| ------------------------------------------ | ----------------------------------------- | -------------------------------------------- | +| "PiecesOS is not running" | PiecesOS crashed or not started | `pieces open` | +| Connection degrades after 30–45 min | Using SSE instead of streamable HTTP | Ensure `PREFERRED_SCHEMA_VERSION = "2025-03-26"` | +| Stale tools after PiecesOS restart | URL cache not invalidated | Restart the CLI gateway | +| "Cannot get MCP upstream URL" | PiecesOS not reachable | Check PiecesOS is running, check port | +| Tool calls timeout | PiecesOS overloaded or network issues | `pieces restart` | +| "Timed out connecting to PiecesOS" | PiecesOS overloaded or restarting | `pieces restart` | +| "Connection to PiecesOS was lost" | PiecesOS shut down unexpectedly | `pieces open` | +| "PiecesOS sent a malformed response" | Version mismatch | `pieces update` then `pieces restart` | + +## Async Health Check + +`_check_pieces_os_status` is fully async. Blocking SDK calls (health WebSocket +start, user snapshot, LTM status, etc.) are offloaded via `asyncio.to_thread()` +so the event loop is never stalled. An `asyncio.Lock` guards the fast-path +check to prevent redundant health probes. + +## Notification Handling + +The gateway uses the SDK's public `ClientSession(message_handler=...)` API to +receive upstream notifications. Only `ToolListChangedNotification` triggers +tool re-discovery and an IDE notification; all other message types are ignored +(the SDK handles them internally). + +## Session ID Tracking + +The `get_session_id` callable returned by `streamablehttp_client` is captured +and logged at connection establishment and included in Sentry breadcrumbs. This +aids debugging without adding functional complexity. The session ID is cleared +on connection cleanup. + +## Known Limitations + +- `_upstream_session_id` is logged for observability but not yet used for + session resumption or reconnection. + +## Testing + +Tests live in `tests/mcps/mcp_gateway/`. Key test files: + +| File | Description | +| ----------------------------- | -------------------------------------------------- | +| `test_bug_fixes.py` | Unit tests for connection lifecycle bugs | +| `test_validation_core.py` | System status validation tests | +| `test_validation_advanced.py` | Concurrency, edge cases, performance | +| `test_integration.py` | Integration tests (requires PiecesOS running) | +| `test_e2e.py` | End-to-end subprocess tests (requires PiecesOS) | diff --git a/src/pieces/mcp/__init__.py b/src/pieces/mcp/__init__.py index 2edabad1..b0624605 100644 --- a/src/pieces/mcp/__init__.py +++ b/src/pieces/mcp/__init__.py @@ -1,3 +1,26 @@ +""" +Pieces MCP (Model Context Protocol) Gateway. + +This package implements a gateway between IDE clients (via stdio) and PiecesOS +(via streamable HTTP / SSE). It handles: + +- Connection lifecycle management (connect, reconnect, cleanup) +- Tool discovery and proxying (list_tools, call_tool) +- IDE integration configuration (Claude, Cursor, VS Code, etc.) +- Health monitoring and validation (PiecesOS status, version compat, LTM) + +Architecture:: + + IDE <--stdio--> MCPGateway <--streamable HTTP--> PiecesOS + +Key modules: + gateway - Core gateway server and upstream connection management + utils - URL resolution and schema version selection + handler - CLI command handlers for ``pieces mcp`` subcommands + integration - IDE-specific configuration file management + tools_cache - Fallback tool definitions when PiecesOS is offline +""" + from .handler import handle_mcp, handle_mcp_docs, handle_repair, handle_status from .list_mcp import handle_list, handle_list_headless from .gateway import main as handle_gateway diff --git a/src/pieces/mcp/gateway.py b/src/pieces/mcp/gateway.py index ea6f5222..688598c1 100644 --- a/src/pieces/mcp/gateway.py +++ b/src/pieces/mcp/gateway.py @@ -1,15 +1,34 @@ +""" +MCP Gateway -- stdio-to-PiecesOS bridge. + +Implements the gateway that sits between IDE clients (communicating via stdio) +and PiecesOS (communicating via streamable HTTP or SSE). The gateway proxies +tool discovery and tool calls, handles reconnection, and validates system +status before each operation. + +Architecture:: + + IDE (Claude, Cursor, ...) <--stdio--> MCPGateway <--streamable HTTP--> PiecesOS + +Key classes: + PosMcpConnection -- manages the upstream connection to PiecesOS + MCPGateway -- stdio server that routes IDE requests to PiecesOS +""" + import asyncio import hashlib -from pydantic import ValidationError -import sentry_sdk +import re import signal import threading -from typing import Tuple, Callable, Awaitable -import httpx +from typing import Any, Awaitable, Callable + import httpcore +import httpx +import sentry_sdk +from pydantic import ValidationError from websocket import WebSocketConnectionClosedException -from pieces.mcp.utils import get_mcp_latest_url +from pieces.mcp.utils import get_mcp_latest_url, invalidate_mcp_url_cache from pieces.mcp.tools_cache import PIECES_MCP_TOOLS_CACHE from pieces.settings import Settings from .._vendor.pieces_os_client.wrapper.version_compatibility import ( @@ -19,7 +38,7 @@ from .._vendor.pieces_os_client.wrapper.websockets.health_ws import HealthWS from .._vendor.pieces_os_client.wrapper.websockets.ltm_vision_ws import LTMVisionWS from .._vendor.pieces_os_client.wrapper.websockets.auth_ws import AuthWS -from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client from mcp import ClientSession from mcp.server import Server import mcp.server.stdio @@ -29,48 +48,70 @@ class PosMcpConnection: - """Manages connection to the Pieces MCP server.""" + """Manages the upstream connection to the PiecesOS MCP server. + + Handles the full connection lifecycle: establishing a streamable HTTP + transport, creating a ``ClientSession``, discovering tools, handling + reconnection on failure, and cleaning up resources. + + Thread-safety: + * ``connection_lock`` (asyncio.Lock) serialises connect/cleanup calls. + * ``_health_check_lock`` (asyncio.Lock) guards the async PiecesOS + health check. Blocking SDK calls are wrapped in + ``asyncio.to_thread()`` to avoid stalling the event loop. + + Error surfacing: + Connection errors are stored in ``_last_connection_error`` so that + the next ``call_tool`` invocation can include them in the user-facing + response instead of returning a generic message. + """ def __init__( self, upstream_url: str, tools_changed_callback: Callable[[], Awaitable[None]] ): - self.upstream_url = ( - upstream_url # Can be None if PiecesOS wasn't running at startup - ) - self.CONNECTION_ESTABLISH_ATTEMPTS = 100 - self.CONNECTION_CHECK_INTERVAL = 0.1 - self.session = None - self.sse_client = None - self.discovered_tools = [] - self.connection_lock = asyncio.Lock() - self._pieces_os_running = None - self._ltm_enabled = None - self.result = None - self._previous_tools_hash = None - self._tools_changed_callback = tools_changed_callback - self._health_check_lock = threading.Lock() - - # Add cleanup coordination - self._cleanup_requested = asyncio.Event() - self._connection_task = None + self.upstream_url: str | None = upstream_url + self.CONNECTION_ESTABLISH_ATTEMPTS: int = 100 + self.CONNECTION_CHECK_INTERVAL: float = 0.1 + self.session: ClientSession | None = None + self._transport_ctx: object | None = None + self.discovered_tools: list[types.Tool] = [] + self.connection_lock: asyncio.Lock = asyncio.Lock() + self._pieces_os_running: bool | None = None + self._ltm_enabled: bool | None = None + self.result: object | None = None + self._previous_tools_hash: str | None = None + self._tools_changed_callback: Callable[[], Awaitable[None]] = tools_changed_callback + self._health_check_lock: asyncio.Lock = asyncio.Lock() + self._last_connection_error: str | None = None + self._upstream_session_id: str | None = None + + self._cleanup_requested: asyncio.Event = asyncio.Event() + self._connection_task: asyncio.Task | None = None + + def _try_get_upstream_url(self) -> bool: + """Try to resolve the upstream URL if we don't have one yet. + + Invalidates the URL cache before fetching so that a PiecesOS restart + on a different port is handled correctly. - def _try_get_upstream_url(self): - """Try to get the upstream URL if we don't have it yet.""" + Returns: + True if ``self.upstream_url`` is set, False otherwise. + """ if self.upstream_url is None: if Settings.pieces_client.is_pieces_running(): try: + invalidate_mcp_url_cache() self.upstream_url = get_mcp_latest_url() return True - except: # noqa: E722 - pass + except Exception as e: + Settings.logger.warning(f"Failed to get MCP upstream URL: {e}") return False return True - def request_cleanup(self): + def request_cleanup(self) -> None: """Request cleanup from exception handler (thread-safe).""" Settings.logger.debug("Cleanup requested from exception handler") - # Use asyncio's thread-safe method to schedule cleanup loop = None try: loop = asyncio.get_running_loop() @@ -79,31 +120,27 @@ def request_cleanup(self): return if loop and not loop.is_closed(): - # Schedule cleanup in the event loop loop.call_soon_threadsafe(self._schedule_cleanup) - def _schedule_cleanup(self): + def _schedule_cleanup(self) -> None: """Internal method to schedule cleanup in the event loop.""" - # Set cleanup event self._cleanup_requested.set() - # Cancel connection task if it exists if self._connection_task and not self._connection_task.done(): self._connection_task.cancel() Settings.logger.debug("Connection task cancelled due to cleanup request") - async def _cleanup_stale_session(self): - """Clean up a stale session and its resources.""" - # Store references to avoid race conditions - session = self.session - sse_client = self.sse_client + async def _cleanup_stale_session(self) -> None: + """Clean up a stale session and its resources. - # Clear instance variables immediately - self.session = None - self.sse_client = None - self.discovered_tools = [] + Ordering guarantee: ``__aexit__`` is called on the session and + transport context *before* the instance variables are nullified. + This prevents a concurrent ``connect()`` from seeing ``None`` and + starting a new connection while old resources are still tearing down. + """ + session = self.session + transport_ctx = self._transport_ctx - # Clean up session if it exists if session: try: await session.__aexit__(None, None, None) @@ -111,20 +148,26 @@ async def _cleanup_stale_session(self): except Exception as e: Settings.logger.debug(f"Error cleaning up session: {e}") - # Clean up SSE client if it exists - if sse_client: + if transport_ctx: try: - await sse_client.__aexit__(None, None, None) - Settings.logger.debug("SSE client cleaned up successfully") + await transport_ctx.__aexit__(None, None, None) + Settings.logger.debug("Transport context cleaned up successfully") except Exception as e: - Settings.logger.debug(f"Error cleaning up SSE client: {e}") + Settings.logger.debug(f"Error cleaning up transport context: {e}") - def _check_version_compatibility(self) -> Tuple[bool, str]: - """ - Check if the PiecesOS version is compatible with the MCP server. + self.session = None + self._transport_ctx = None + self._upstream_session_id = None + self.discovered_tools = [] + + invalidate_mcp_url_cache() + + def _check_version_compatibility(self) -> tuple[bool, str]: + """Check if the PiecesOS version is compatible with the MCP server. Returns: - Tuple[bool, str]: A tuple containing a boolean indicating compatibility, str: message if it is not compatible. + ``(True, "")`` if compatible, ``(False, message)`` otherwise. + The message includes actionable remediation steps for the user. """ version = Settings.pieces_client.version if version == "debug": @@ -139,7 +182,6 @@ def _check_version_compatibility(self) -> Tuple[bool, str]: if self.result.compatible: return True, "" - # These messages are sent to the llm to update the respective tool if self.result.update == UpdateEnum.Plugin: return ( False, @@ -151,72 +193,88 @@ def _check_version_compatibility(self) -> Tuple[bool, str]: "Please update PiecesOS to a compatible version to be able to run the tool call. Run 'pieces update' to get the latest version. Then retry your request again after updating.", ) - def _check_pieces_os_status(self): - """Check if PiecesOS is running using health WebSocket""" - with self._health_check_lock: - # First check if already connected + async def _check_pieces_os_status(self) -> bool: + """Check if PiecesOS is running and initialise health state. + + Two-phase check: + 1. **Fast path** (under ``_health_check_lock``): if the health + WebSocket is already running, return immediately. + 2. **Slow path**: call blocking SDK methods (``is_pieces_running``, + ``health_ws.start``, ``user_snapshot``, etc.) via + ``asyncio.to_thread()`` so the event loop is never stalled. + + Returns: + True if PiecesOS is healthy and reachable, False otherwise. + """ + async with self._health_check_lock: if HealthWS.is_running() and Settings.pieces_client.is_pos_stream_running: return True - # Check if PiecesOS is available - if not Settings.pieces_client.is_pieces_running(2): - return False + is_running = await asyncio.to_thread(Settings.pieces_client.is_pieces_running, 2) + if not is_running: + return False - try: - health_ws = HealthWS.get_instance() - if health_ws: - health_ws.start() + try: + health_ws = HealthWS.get_instance() + if health_ws: + await asyncio.to_thread(health_ws.start) - sentry_sdk.set_user({"id": Settings.get_os_id() or "unknown"}) + os_id = await asyncio.to_thread(Settings.get_os_id) + sentry_sdk.set_user({"id": os_id or "unknown"}) - # Update the user profile cache - Settings.pieces_client.user.user_profile = ( - Settings.pieces_client.user_api.user_snapshot().user - ) + snapshot = await asyncio.to_thread( + Settings.pieces_client.user_api.user_snapshot + ) + Settings.pieces_client.user.user_profile = snapshot.user - # Update LTM status cache - Settings.pieces_client.copilot.context.ltm.ltm_status = Settings.pieces_client.work_stream_pattern_engine_api.workstream_pattern_engine_processors_vision_status() - return True - except Exception as e: - Settings.logger.debug(f"Failed to start health WebSocket: {e}") - return False + ltm_status = await asyncio.to_thread( + Settings.pieces_client.work_stream_pattern_engine_api + .workstream_pattern_engine_processors_vision_status + ) + Settings.pieces_client.copilot.context.ltm.ltm_status = ltm_status + + invalidate_mcp_url_cache() + return True + except Exception as e: + Settings.logger.warning( + f"PiecesOS appears to be running but health check failed: {e}" + ) + return False - def _check_ltm_status(self): + def _check_ltm_status(self) -> bool: """Check if LTM is enabled.""" return Settings.pieces_client.copilot.context.ltm.is_enabled - def _validate_system_status(self, tool_name: str) -> tuple[bool, str]: - """ - Perform 4-step validation before executing any command: - 1. Check health WebSocket - 2. Check compatibility - 3. Check Auth - 4. Check LTM (for LTM tools) + async def _validate_system_status(self, tool_name: str) -> tuple[bool, str]: + """Perform 4-step validation before executing any tool. + + Steps: + 1. Check PiecesOS health (via WebSocket) + 2. Check version compatibility (CLI vs PiecesOS) + 3. Check user authentication + 4. Check LTM status (for LTM-specific tools only) Returns: - tuple[bool, str]: (is_valid, error_message) + ``(is_valid, error_message)`` -- error_message includes actionable + remediation steps when ``is_valid`` is False. """ - # Step 1: Check health WebSocket / PiecesOS status - if not self._check_pieces_os_status(): + if not await self._check_pieces_os_status(): return False, ( "PiecesOS is not running. To use this tool, please run:\n\n" "`pieces open`\n\n" "This will start PiecesOS, then you can retry your request." ) - # Step 2: Check version compatibility is_compatible, compatibility_message = self._check_version_compatibility() if not is_compatible: return False, compatibility_message - # Step 3: Check Auth status if not Settings.pieces_client.user.user_profile: return False, ( "User must sign up to use this tool, please run:\n\n`pieces login`\n\n" "This will open the authentication page in your browser. After signing in, you can retry your request." ) - # Step 4: Check LTM status (only for LTM-related tools) if tool_name in ["ask_pieces_ltm", "create_pieces_memory"]: ltm_enabled = self._check_ltm_status() if not ltm_enabled: @@ -227,46 +285,45 @@ def _validate_system_status(self, tool_name: str) -> tuple[bool, str]: "This will enable LTM, then you can retry your request." ) - # All checks passed return True, "" - def _get_error_message_for_tool(self, tool_name: str) -> str: - """Get appropriate error message based on the tool and system status.""" - # Use the 3-step validation system - is_valid, error_message = self._validate_system_status(tool_name) + async def _get_error_message_for_tool(self, tool_name: str) -> str: + """Get an actionable error message based on the tool and system status. + + Checks validation first, then falls back to ``_last_connection_error`` + if available, and finally returns a generic message with remediation. + """ + is_valid, error_message = await self._validate_system_status(tool_name) if not is_valid: return error_message + tool_name = self._sanitize_tool_name(tool_name) - # If all validations pass but we still have an error, return generic message + base_msg = f"Unable to execute '{tool_name}' tool." + + if self._last_connection_error: + return f"{base_msg}\n\n{self._last_connection_error}" return ( - f"Unable to execute '{tool_name}' tool. Please ensure PiecesOS is running " + f"{base_msg} Please ensure PiecesOS is running " "and try again. If the problem persists, run:\n\n" "`pieces restart`" ) def _sanitize_tool_name(self, tool_name: str) -> str: """Sanitize tool name for safe inclusion in messages.""" - import re - - # Remove control characters and limit length - sanitized = re.sub(r"[^\w\s\-_.]", "", tool_name) - return sanitized[:100] # Limit to reasonable length + sanitized: str = re.sub(r"[^\w\s\-_.]", "", tool_name) + return sanitized[:100] - def _get_tools_hash(self, tools): - """Generate a hash of the tools list for change detection.""" + def _get_tools_hash(self, tools: list[types.Tool]) -> str | None: + """Generate a stable SHA-256 hash of the tools list for change detection.""" if not tools: return None - # Create a stable hash using SHA256 hasher = hashlib.sha256() - - # Sort tools by name for consistency sorted_tools = sorted(tools, key=lambda t: t.name) for tool in sorted_tools: - # Use truncated description to catch content changes while avoiding memory issues description = tool.description or "" truncated_desc = ( description[:200] if len(description) > 200 else description @@ -276,11 +333,10 @@ def _get_tools_hash(self, tools): return hasher.hexdigest() - def _tools_have_changed(self, new_tools): + def _tools_have_changed(self, new_tools: list[types.Tool]) -> bool: """Check if the tools have changed since last check.""" new_hash = self._get_tools_hash(new_tools) if self._previous_tools_hash is None: - # First time, consider as changed if we have tools self._previous_tools_hash = new_hash return bool(new_tools) @@ -292,7 +348,7 @@ def _tools_have_changed(self, new_tools): return True return False - async def update_tools(self, session, send_notification: bool = True): + async def update_tools(self, session: ClientSession, send_notification: bool = True) -> None: """Fetch tools from the session and handle change detection.""" try: self.tools = await session.list_tools() @@ -300,12 +356,9 @@ async def update_tools(self, session, send_notification: bool = True): tool[1] for tool in self.tools if tool[0] == "tools" ][0] - # Check if tools have changed tools_changed = self._tools_have_changed(new_discovered_tools) - # Clean up old tool data if changed if tools_changed and self.discovered_tools: - # Clear references to old tools to prevent memory buildup self.discovered_tools.clear() self.discovered_tools = new_discovered_tools @@ -314,42 +367,73 @@ async def update_tools(self, session, send_notification: bool = True): f"Discovered {len(self.discovered_tools)} tools from upstream server" ) - # If tools changed, call the callback if send_notification and tools_changed: try: Settings.logger.info("Tools have changed - sending notification") await self._tools_changed_callback() except Exception as e: - Settings.logger.error(f"Error in tools changed callback: {e}") + Settings.logger.error( + f"Failed to notify IDE of tool changes: {e}. " + "Tools will be updated on the next request." + ) except Exception as e: Settings.logger.error(f"Error fetching tools: {e}", exc_info=True) raise - async def _connection_handler(self, send_notification: bool = True): - """Handle the connection lifecycle in a single task context.""" + async def _connection_handler(self, send_notification: bool = True) -> None: + """Handle the connection lifecycle in a single task context. + + Stages: + 1. Enter the streamable HTTP transport context manager and capture + the upstream session ID for observability + 2. Create and enter a ``ClientSession`` with a public + ``message_handler`` for notification handling + 3. Discover tools and set up notification handlers + 4. Wait for cleanup signal or cancellation + 5. Clean up resources in ``finally`` (orphaned + stale) + + Connection errors are stored in ``_last_connection_error`` so that + the next ``call_tool`` can surface them to the user. + """ + transport_ctx = None + session_obj = None try: Settings.logger.info( f"Connecting to upstream MCP server at {self.upstream_url}" ) - # Enter SSE client context - self.sse_client = sse_client(self.upstream_url) - read_stream, write_stream = await self.sse_client.__aenter__() + transport_ctx = streamablehttp_client(self.upstream_url) + read_stream, write_stream, get_session_id = await transport_ctx.__aenter__() + self._transport_ctx = transport_ctx + + self._upstream_session_id = get_session_id() + Settings.logger.debug( + f"Transport session ID: {self._upstream_session_id or '(not yet assigned)'}" + ) - # Enter session context - session = ClientSession(read_stream, write_stream) + session_obj = ClientSession( + read_stream, + write_stream, + message_handler=self._make_message_handler(send_notification), + ) Settings.logger.info("Connecting to the client session") - await session.__aenter__() - self.session = session + await session_obj.__aenter__() + self.session = session_obj - # Update tools and setup notifications - await self.update_tools(session, send_notification) - await self.setup_notification_handler(session) + self._last_connection_error = None - Settings.logger.info("Connection established successfully") + await self.update_tools(session_obj, send_notification) + + upstream_sid = get_session_id() + if upstream_sid: + self._upstream_session_id = upstream_sid + + Settings.logger.info( + f"Connection established successfully " + f"(session_id={self._upstream_session_id or 'N/A'})" + ) - # Add Sentry breadcrumb for successful MCP connection sentry_sdk.add_breadcrumb( message="MCP connection established", category="mcp", @@ -357,10 +441,10 @@ async def _connection_handler(self, send_notification: bool = True): data={ "upstream_url": self.upstream_url, "tools_count": len(self.discovered_tools), + "session_id": self._upstream_session_id, }, ) - # Keep connection alive until cleanup is requested or cancelled try: await self._cleanup_requested.wait() Settings.logger.debug("Cleanup requested, shutting down connection") @@ -378,12 +462,14 @@ async def _connection_handler(self, send_notification: bool = True): httpcore.ReadTimeout, httpcore.ConnectTimeout, ) as e: - # Handle SSE timeout errors gracefully without sending to Sentry - Settings.logger.info( - f"SSE connection timeout (expected for long-running connections): {type(e).__name__}" + self._last_connection_error = ( + f"Connection to PiecesOS timed out ({type(e).__name__}). " + "This may happen if PiecesOS is overloaded or restarting. " + "Try again, or run `pieces restart` if the issue persists." ) + Settings.logger.info(f"Connection timeout: {type(e).__name__}") sentry_sdk.add_breadcrumb( - message="SSE connection timeout handled", + message="MCP connection timeout handled", category="mcp", level="info", data={ @@ -391,15 +477,18 @@ async def _connection_handler(self, send_notification: bool = True): "upstream_url": self.upstream_url, }, ) - # Don't re-raise - this is a normal part of SSE connection lifecycle return except (httpx.RemoteProtocolError, httpcore.RemoteProtocolError) as e: - # Handle protocol errors gracefully + self._last_connection_error = ( + f"Connection to PiecesOS was interrupted ({type(e).__name__}). " + "PiecesOS may have restarted. Retrying should reconnect automatically. " + "If this persists, run `pieces restart`." + ) Settings.logger.info( - f"SSE protocol error (connection interrupted): {type(e).__name__}" + f"Protocol error (connection interrupted): {type(e).__name__}" ) sentry_sdk.add_breadcrumb( - message="SSE protocol error handled", + message="MCP protocol error handled", category="mcp", level="info", data={ @@ -407,14 +496,18 @@ async def _connection_handler(self, send_notification: bool = True): "upstream_url": self.upstream_url, }, ) - # Don't re-raise - this is expected when connections are interrupted return except BrokenPipeError as e: + self._last_connection_error = ( + "Connection to PiecesOS was lost (broken pipe). " + "This usually means PiecesOS shut down unexpectedly. " + "Run `pieces open` to restart it, then retry." + ) Settings.logger.info( - "SSE stream resource broken (connection closed during send)" + "Stream resource broken (connection closed during send)" ) sentry_sdk.add_breadcrumb( - message="SSE stream resource broken handled", + message="MCP stream resource broken handled", category="mcp", level="info", data={ @@ -424,6 +517,10 @@ async def _connection_handler(self, send_notification: bool = True): ) return except ValidationError as e: + self._last_connection_error = ( + "PiecesOS sent a malformed response. This may indicate a version mismatch. " + "Run `pieces update` to ensure you have the latest version, then `pieces restart`." + ) Settings.logger.info( "MCP server sent malformed JSON-RPC message (validation failed)" ) @@ -438,63 +535,86 @@ async def _connection_handler(self, send_notification: bool = True): ) return except Exception as e: + self._last_connection_error = ( + f"Unexpected connection error: {type(e).__name__}: {e}. " + "Run `pieces restart` if this persists." + ) Settings.logger.error(f"Error in connection handler: {e}", exc_info=True) raise finally: - # Cleanup happens in the same task context where __aenter__ was called + if session_obj and self.session is None: + try: + await session_obj.__aexit__(None, None, None) + except Exception as cleanup_err: + Settings.logger.debug(f"Error cleaning up orphaned session: {cleanup_err}") + if transport_ctx and self._transport_ctx is None: + try: + await transport_ctx.__aexit__(None, None, None) + except Exception as cleanup_err: + Settings.logger.debug(f"Error cleaning up orphaned transport: {cleanup_err}") await self._cleanup_stale_session() Settings.logger.debug("Connection handler cleanup completed") - async def connect(self, send_notification: bool = True): - """Ensures a connection to the POS server exists and returns it.""" + async def connect(self, send_notification: bool = True) -> ClientSession: + """Ensure a connection to PiecesOS exists and return the session. + + Uses a polling loop to wait for ``_connection_handler`` to set + ``self.session``. The loop also checks for early task failure so + that errors propagate immediately instead of waiting for the full + 10-second timeout. + + Returns: + The active ``ClientSession``. + + Raises: + ValueError: If the upstream URL cannot be resolved. + TimeoutError: If the connection is not established within 10 s. + """ async with self.connection_lock: - # Check if we have a valid existing connection + session = self.session if ( - self.session is not None + session is not None and self._connection_task and not self._connection_task.done() ): try: - await self.session.send_ping() + await session.send_ping() Settings.logger.debug("Using existing upstream connection") - return self.session + return session except Exception as e: Settings.logger.debug( f"Existing connection is stale: {e}, creating new connection" ) - # Fall through to create new connection - # Clean up any existing connection state await self._ensure_clean_state() - # Try to get upstream URL if we don't have it if not self._try_get_upstream_url(): raise ValueError( - "Cannot get MCP upstream URL - PiecesOS may not be running" + "Cannot get MCP upstream URL - PiecesOS may not be running. " + "Run `pieces open` to start PiecesOS." ) try: Settings.logger.info("Creating new connection to upstream server") - # Reset cleanup event for new connection self._cleanup_requested.clear() - # Start connection in a dedicated task self._connection_task = asyncio.create_task( self._connection_handler(send_notification) ) - # Wait for connection to establish with longer timeout Settings.logger.debug("Waiting for connection to establish...") - for attempt in range( - self.CONNECTION_ESTABLISH_ATTEMPTS - ): # Wait up to 10 seconds + for _attempt in range(self.CONNECTION_ESTABLISH_ATTEMPTS): if self.session is not None: Settings.logger.info("Connection established successfully") return self.session + if self._connection_task.done(): + exc = self._connection_task.exception() + if exc: + raise exc + break await asyncio.sleep(self.CONNECTION_CHECK_INTERVAL) - # Timeout occurred - clean up the running task Settings.logger.error("Connection establishment timed out") if self._connection_task and not self._connection_task.done(): self._connection_task.cancel() @@ -504,25 +624,30 @@ async def connect(self, send_notification: bool = True): pass raise TimeoutError( - "Connection establishment timed out after 10 seconds" + "Connection establishment timed out after 10 seconds. " + "PiecesOS may be starting up or overloaded. " + "Try again, or run `pieces restart`." ) except Exception as e: - # Ensure clean state on any error await self._ensure_clean_state() Settings.logger.error( f"Error connecting to upstream server: {e}", exc_info=True ) raise - async def _ensure_clean_state(self): - """Ensure all connection state is properly cleaned up.""" + async def _ensure_clean_state(self) -> None: + """Ensure all connection state is properly cleaned up. + + Unlike ``_cleanup_stale_session`` (which cleans up within the + connection handler's task context), this method is called from + ``connect()`` and ``cleanup()`` to cancel the connection task + and reset state from the outside. + """ Settings.logger.debug("Ensuring clean connection state") - # Signal cleanup if needed self._cleanup_requested.set() - # Cancel and wait for existing connection task if self._connection_task and not self._connection_task.done(): self._connection_task.cancel() try: @@ -532,55 +657,74 @@ async def _ensure_clean_state(self): except Exception as e: Settings.logger.debug(f"Error cleaning up connection task: {e}") - # Reset all state self.session = None - self.sse_client = None + self._transport_ctx = None self._connection_task = None - # Note: Don't clear discovered_tools here - keep them for fallback - async def setup_notification_handler(self, session): - """Setup the notification handler for the session.""" - if not hasattr(self, "main_notification_handler"): - self.main_notification_handler = session._received_notification + def _make_message_handler( + self, send_notification: bool = True + ) -> Callable[[Any], Awaitable[None]]: + """Create a ``message_handler`` callback for ``ClientSession``. - async def received_notification_handler( - notification: types.ServerNotification, - ): - """Handle received notifications from the SSE client.""" - Settings.logger.debug(f"Received notification: {notification.root}") - if isinstance(notification.root, types.ToolListChangedNotification): - await self.update_tools(session, send_notification=False) - await self._tools_changed_callback() - await self.main_notification_handler(notification) + Uses the SDK's public ``message_handler`` constructor parameter + instead of monkey-patching the private ``_received_notification`` + attribute. The handler is called *after* the SDK's built-in + notification processing, so logging and other default behaviour + is preserved. - session._received_notification = received_notification_handler + Only ``ToolListChangedNotification`` is handled here; all other + message types (requests, exceptions) are ignored -- the SDK + handles them internally. - async def cleanup(self): - """Cleans up the upstream connection.""" + Args: + send_notification: Forwarded to ``update_tools`` when a + tool-list-changed notification arrives. + + Returns: + An async callable suitable for the ``message_handler`` kwarg + of ``ClientSession.__init__``. + """ + conn = self + + async def _handler(message: Any) -> None: + if not isinstance(message, types.ServerNotification): + return + Settings.logger.debug(f"Received notification: {message.root}") + if isinstance(message.root, types.ToolListChangedNotification): + session = conn.session + if session is not None: + await conn.update_tools(session, send_notification=False) + await conn._tools_changed_callback() + + return _handler + + async def cleanup(self) -> None: + """Clean up the upstream connection (full teardown).""" async with self.connection_lock: Settings.logger.info("Starting connection cleanup") - # Ensure clean state (this handles task cancellation and cleanup) await self._ensure_clean_state() - # Clear discovered tools on full cleanup self.discovered_tools = [] Settings.logger.info("Connection cleanup completed") - async def call_tool(self, name, arguments): - """Calls a tool on the POS MCP server.""" + async def call_tool(self, name: str, arguments: dict) -> types.CallToolResult: + """Call a tool on the PiecesOS MCP server. + + Performs 4-step validation, connects if needed, and forwards the + call. Errors are surfaced to the user with specific messages and + remediation steps rather than being swallowed. + """ Settings.logger.debug(f"Calling tool: {name}") - # Perform 3-step validation before attempting to call tool - is_valid, error_message = self._validate_system_status(name) + is_valid, error_message = await self._validate_system_status(name) if not is_valid: Settings.logger.debug(f"Tool validation failed for {name}: {error_message}") return types.CallToolResult( content=[types.TextContent(type="text", text=error_message)] ) - # All validations passed, try to call the upstream tool try: Settings.logger.debug(f"Calling upstream tool: {name}") session = await self.connect() @@ -590,27 +734,51 @@ async def call_tool(self, name, arguments): Settings.logger.debug(f"with results: {result}") return result + except TimeoutError: + error_message = ( + f"Timed out connecting to PiecesOS while executing '{self._sanitize_tool_name(name)}'. " + "PiecesOS may be starting up or overloaded. Please try again in a few seconds.\n\n" + "If this persists, run: `pieces restart`" + ) + Settings.logger.error(f"Timeout calling tool {name}") + return types.CallToolResult( + content=[types.TextContent(type="text", text=error_message)] + ) + except (ConnectionError, OSError) as e: + error_message = ( + f"Cannot reach PiecesOS to execute '{self._sanitize_tool_name(name)}'. " + "Please ensure PiecesOS is running with `pieces open`, then retry.\n\n" + f"Error: {e}" + ) + Settings.logger.error(f"Connection error calling tool {name}: {e}") + return types.CallToolResult( + content=[types.TextContent(type="text", text=error_message)] + ) except Exception as e: Settings.logger.error(f"Error calling POS MCP {name}: {e}", exc_info=True) - # Return a helpful error message based on the tool and system status - error_message = self._get_error_message_for_tool(name) + error_message = await self._get_error_message_for_tool(name) + error_message += f"\n\nError details: {type(e).__name__}: {e}" return types.CallToolResult( content=[types.TextContent(type="text", text=error_message)] ) class MCPGateway: - """Gateway server between POS MCP server and stdio.""" + """Gateway server between IDE clients (stdio) and PiecesOS (upstream). + + Routes ``list_tools`` and ``call_tool`` requests from the IDE to PiecesOS, + handles tool change notifications, and manages the upstream connection + lifecycle. + """ - def __init__(self, server_name, upstream_url): - self.server_name = server_name + def __init__(self, server_name: str, upstream_url: str | None): + self.server_name: str = server_name self.server = Server(server_name) self.upstream = PosMcpConnection( upstream_url, self.send_tools_changed_notification ) - # Add MCP server info to Sentry context sentry_sdk.set_context( "mcp_gateway", { @@ -623,8 +791,8 @@ def __init__(self, server_name, upstream_url): self.setup_handlers() - async def send_tools_changed_notification(self): - """Send a tools/list_changed notification to the client.""" + async def send_tools_changed_notification(self) -> None: + """Send a tools/list_changed notification to the IDE client.""" try: ctx = self.server.request_context await ctx.session.send_notification( @@ -643,15 +811,15 @@ async def send_tools_changed_notification(self): "Tools have changed - clients will receive updated tools on next request" ) - def setup_handlers(self): - """Sets up the request handlers for the gateway server.""" + def setup_handlers(self) -> None: + """Set up the request handlers for the gateway server.""" Settings.logger.info("Setting up gateway request handlers") @self.server.list_tools() async def list_tools() -> list[types.Tool]: Settings.logger.debug("Received list_tools request") - if self.upstream._check_pieces_os_status(): + if await self.upstream._check_pieces_os_status(): await self.upstream.connect(send_notification=False) Settings.logger.debug( @@ -659,17 +827,16 @@ async def list_tools() -> list[types.Tool]: ) return self.upstream.discovered_tools else: - # Only use cached/fallback tools when PiecesOS is not running if self.upstream.discovered_tools: - Settings.logger.debug( - f"PiecesOS not running - returning cached tools: {len(self.upstream.discovered_tools)} tools" + Settings.logger.warning( + "PiecesOS is not running -- returning previously cached tools. " + "Results may be stale. Run `pieces open` to reconnect." ) return self.upstream.discovered_tools - Settings.logger.debug("PiecesOS not running - returning fallback tools") - # Use the hardcoded fallback tools - Settings.logger.debug( - f"Returning {len(PIECES_MCP_TOOLS_CACHE)} fallback tools" + Settings.logger.warning( + "PiecesOS is not running and no cached tools available -- " + "returning fallback tool definitions. Run `pieces open` to start PiecesOS." ) return PIECES_MCP_TOOLS_CACHE @@ -682,12 +849,11 @@ async def call_tool(name: str, arguments: dict) -> list[types.ContentBlock]: Settings.logger.debug(f"POS returnable {pos_returnable}") return pos_returnable.content - async def run(self): - """Runs the gateway server.""" + async def run(self) -> None: + """Run the gateway server (stdio transport).""" try: Settings.logger.info("Starting MCP Gateway server") - # Add Sentry breadcrumb for MCP gateway startup sentry_sdk.add_breadcrumb( message="MCP Gateway starting", category="mcp", @@ -701,7 +867,11 @@ async def run(self): try: await self.upstream.connect(send_notification=False) except Exception as e: - Settings.logger.error(f"Failed to connect to upstream server {e}") + Settings.logger.warning( + f"Could not connect to PiecesOS at startup: {e}. " + "The gateway will continue and retry on the next tool call. " + "Ensure PiecesOS is running with `pieces open`." + ) Settings.logger.info(f"Starting stdio server for {self.server.name}") async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): @@ -722,7 +892,6 @@ async def run(self): except KeyboardInterrupt: Settings.logger.info("Gateway interrupted by user") except Exception as e: - # Handle specific MCP-related errors more gracefully if ( "BrokenResourceError" in str(e) or "unhandled errors in a TaskGroup" in str(e) @@ -734,8 +903,6 @@ async def run(self): f"Error running gateway server: {e}", exc_info=True ) finally: - # Ensure we clean up the connection when the gateway exits - # But do it in a way that doesn't interfere with stdio cleanup Settings.logger.info("Gateway shutting down, cleaning up connections") try: await self.upstream.cleanup() @@ -743,21 +910,48 @@ async def run(self): Settings.logger.debug(f"Error during cleanup: {e}") -async def main(): - # Just initialize settings without starting services +async def _run_with_shutdown(gateway: MCPGateway, shutdown_event: asyncio.Event) -> None: + """Run the gateway and cancel it when the shutdown event is set. + + This wires up the signal-handler-set ``shutdown_event`` so that + SIGTERM/SIGINT actually trigger a graceful shutdown. + """ + gateway_task = asyncio.create_task(gateway.run()) + shutdown_task = asyncio.create_task(shutdown_event.wait()) + done, pending = await asyncio.wait( + [gateway_task, shutdown_task], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +async def main() -> None: + """Entry point for the MCP gateway process. + + Startup sequence: + 1. Set up the asyncio exception handler + 2. Register signal handlers for graceful shutdown + 3. Create WebSocket instances (Health, Auth, LTM Vision) + 4. Resolve the upstream MCP URL + 5. Create and run the gateway + 6. On shutdown, close WebSockets and clean up + """ Settings.logger.info("Starting MCP Gateway") is_pos_stream_running_lock = threading.Lock() upstream_connection = None - def asyncio_exception_handler(loop, context): + def asyncio_exception_handler(loop: asyncio.AbstractEventLoop, context: dict) -> None: exc = context.get("exception") - # Handle HTTP timeout and connection-related errors without sending to Sentry if isinstance(exc, (httpx.RemoteProtocolError)): with is_pos_stream_running_lock: Settings.pieces_client.is_pos_stream_running = False - # Log at info level instead of debug for better visibility Settings.logger.info( f"POS stream stopped due to HTTP timeout/connection error: {type(exc).__name__}" ) @@ -777,7 +971,6 @@ def asyncio_exception_handler(loop, context): Settings.logger.info( f"Timeout error (handled gracefully): {type(exc).__name__}" ) - # Add breadcrumb but don't send exception to Sentry sentry_sdk.add_breadcrumb( message="MCP timeout handled by async exception handler", category="mcp", @@ -788,11 +981,10 @@ def asyncio_exception_handler(loop, context): exc ): Settings.logger.info( - "SSE stream resource broken in async handler (connection closed during send)" + "Stream resource broken in async handler (connection closed during send)" ) - # Add breadcrumb but don't send exception to Sentry sentry_sdk.add_breadcrumb( - message="SSE stream resource broken handled by async exception handler", + message="Stream resource broken handled by async exception handler", category="mcp", level="info", data={ @@ -805,7 +997,6 @@ def asyncio_exception_handler(loop, context): Settings.logger.info( "MCP JSON-RPC validation error in async handler (server sent malformed message)" ) - # Add breadcrumb but don't send exception to Sentry sentry_sdk.add_breadcrumb( message="MCP JSON-RPC validation error handled by async exception handler", category="mcp", @@ -816,35 +1007,36 @@ def asyncio_exception_handler(loop, context): }, ) else: - Settings.logger.error(f"Async exception: {context}") + Settings.logger.error( + f"Unexpected async error: {context}. " + "If this causes issues, try `pieces restart`." + ) + if exc: + sentry_sdk.capture_exception(exc) loop = asyncio.get_event_loop() loop.set_exception_handler(asyncio_exception_handler) - # Set up signal handlers for graceful shutdown shutdown_event = asyncio.Event() - def signal_handler(): + def signal_handler() -> None: Settings.logger.info("Received shutdown signal") shutdown_event.set() - # Register signal handlers if hasattr(signal, "SIGTERM"): signal.signal(signal.SIGTERM, lambda s, f: signal_handler()) if hasattr(signal, "SIGINT"): signal.signal(signal.SIGINT, lambda s, f: signal_handler()) - # HealthWS starts the AuthWS, which starts the LTMVisionWS ltm_vision = LTMVisionWS(Settings.pieces_client, lambda x: None) user_ws = AuthWS( Settings.pieces_client, lambda x: None, lambda x: ltm_vision.start() ) - def on_ws_event(ws, e): + def on_ws_event(ws: Any, e: Exception) -> None: if isinstance(e, WebSocketConnectionClosedException): with is_pos_stream_running_lock: Settings.pieces_client.is_pos_stream_running = False - # Also request cleanup if we have the connection reference if upstream_connection: upstream_connection.request_cleanup() else: @@ -857,7 +1049,6 @@ def on_ws_event(ws, e): on_error=on_ws_event, ) - # Try to get the MCP URL, but continue even if it fails upstream_url = None if Settings.pieces_client.is_pieces_running(): upstream_url = get_mcp_latest_url() @@ -869,14 +1060,18 @@ def on_ws_event(ws, e): upstream_url=upstream_url, ) - # Store reference for exception handler upstream_connection = gateway.upstream try: - await gateway.run() + await _run_with_shutdown(gateway, shutdown_event) except KeyboardInterrupt: Settings.logger.info("Gateway interrupted by user") except Exception as e: Settings.logger.error(f"Unexpected error in main: {e}", exc_info=True) finally: Settings.logger.info("MCP Gateway shutting down") + for ws in [ltm_vision, user_ws, health_ws]: + try: + ws.close() + except Exception: + pass diff --git a/src/pieces/mcp/handler.py b/src/pieces/mcp/handler.py index 7c79576c..1692f33a 100644 --- a/src/pieces/mcp/handler.py +++ b/src/pieces/mcp/handler.py @@ -1,3 +1,11 @@ +""" +CLI command handlers for ``pieces mcp`` subcommands. + +Provides the entry points for MCP setup, repair, status, and documentation +commands. Each handler validates that PiecesOS is reachable before proceeding +and returns structured ``BaseResponse`` objects for headless mode. +""" + import json import time import urllib.parse @@ -17,7 +25,7 @@ create_mcp_repair_success, create_mcp_setup_success, ) -from pieces.mcp.utils import get_mcp_latest_url +from pieces.mcp.utils import get_mcp_sse_url from pieces.settings import Settings from pieces.urls import URLs @@ -55,14 +63,21 @@ def check_mcp_running(): + """Check if the PiecesOS MCP server is reachable. + + Uses a simple HTTP GET with a short timeout that works for both SSE + and streamable HTTP endpoints. + """ try: - with urllib.request.urlopen(get_mcp_latest_url(), timeout=1) as response: - for line in response: - message = line.decode("utf-8").strip() - if message: - break + url = get_mcp_sse_url() + if not url: + Settings.show_error("No MCP server URL available. Is PiecesOS running?") + return False + req = urllib.request.Request(url, method="GET") + with urllib.request.urlopen(req, timeout=3) as response: + response.read(1) except Exception as e: - Settings.show_error(f"Pieces MCP server is not running {e}") + Settings.show_error(f"Pieces MCP server is not running: {e}") return False return True @@ -72,6 +87,16 @@ def handle_mcp( stdio: bool = False, **kwargs, ) -> BaseResponse: + """Set up MCP integration for a specific IDE or show a selection menu. + + Args: + integration: Target IDE name, or ``None`` to show a selection menu. + stdio: If True, configure stdio transport instead of SSE. + **kwargs: Additional arguments (``local``, ``global`` for VS Code/Cursor). + + Returns: + A ``BaseResponse`` indicating success or failure with details. + """ if not check_mcp_running(): return ErrorResponse( "mcp setup", @@ -93,7 +118,7 @@ def handle_mcp( ) elif integration == "warp": jsn = ( - warp_stdio_json if stdio else warp_sse_json.format(url=get_mcp_latest_url()) + warp_stdio_json if stdio else warp_sse_json.format(url=get_mcp_sse_url()) ) text = warp_instructions.format(json=jsn) Settings.logger.print(Markdown(text)) @@ -187,6 +212,13 @@ def handle_mcp_docs( ide: Union[mcp_integration_types, Literal["all", "current", "raycast", "warp"]], **kwargs, ): + """Display or open documentation URLs for MCP integrations. + + Args: + ide: Target IDE, ``"all"`` for every integration, or ``"current"`` + for only those already set up. + **kwargs: Pass ``open=True`` to open the URL in a browser. + """ if ide == "all" or ide == "current": for mcp_name, mcp_integration in supported_mcps.items(): if ide == "current" and not mcp_integration.is_set_up(): @@ -286,6 +318,7 @@ def repair_single_integration( def handle_status(**kwargs): + """Check the status of all MCP integrations and offer to repair broken ones.""" if supported_mcps["vscode"].check_ltm(): Settings.logger.print("[green]LTM running[/green]") else: diff --git a/src/pieces/mcp/integration.py b/src/pieces/mcp/integration.py index 702066c6..68525bb5 100644 --- a/src/pieces/mcp/integration.py +++ b/src/pieces/mcp/integration.py @@ -1,3 +1,14 @@ +""" +IDE integration configuration management for MCP. + +Handles reading, writing, and repairing IDE-specific configuration files +(e.g. ``settings.json`` for VS Code/Cursor, ``claude_desktop_config.json`` +for Claude Desktop) so that IDEs can connect to PiecesOS via MCP. + +IDE configs use the SSE schema URL (``2024-11-05``) because IDEs connect +directly to PiecesOS, not through the CLI gateway. +""" + import json import os from typing import Callable, Dict, List, Tuple, Optional @@ -11,12 +22,17 @@ from pieces.headless.exceptions import HeadlessError from pieces.settings import Settings -from .utils import get_mcp_latest_url, get_mcp_urls +from .utils import get_mcp_sse_url, get_mcp_urls from ..utils import PiecesSelectMenu from pieces.config.schemas.mcp import IntegrationDict, mcp_types, mcp_integration_types class MCPProperties: + """Defines the JSON property structure for an MCP integration's config file. + + Holds separate property templates for stdio and SSE transports, along + with the JSON key names for URL, command, and args fields. + """ pieces_cli_bin_path: Optional[str] = None def __init__( @@ -66,7 +82,7 @@ def mcp_path(self, mcp_type: mcp_types): def mcp_modified_settings(self, mcp_type: mcp_types): mcp_settings = self.mcp_settings(mcp_type) if mcp_type == "sse": - mcp_settings[self.url_property_name] = get_mcp_latest_url() + mcp_settings[self.url_property_name] = get_mcp_sse_url() else: mcp_settings[self.command_property_name] = self.pieces_cli_bin_path mcp_settings[self.args_property_name] = [ @@ -78,6 +94,13 @@ def mcp_modified_settings(self, mcp_type: mcp_types): class Integration: + """Manages a single IDE's MCP configuration. + + Each instance represents one IDE (e.g. VS Code, Claude Desktop) and knows + how to locate, read, write, validate, and repair that IDE's MCP config + file. Supports both stdio and SSE transport modes. + """ + def __init__( self, options: List[Tuple], @@ -312,7 +335,7 @@ def check_properties(self, mcp_type: mcp_types, config: Dict) -> bool: mcp_settings = self.mcp_properties.mcp_modified_settings(mcp_type) for k, value in config.items(): if k == self.mcp_properties.url_property_name and mcp_type == "sse": - if value != get_mcp_latest_url(): + if value != get_mcp_sse_url(): return False elif k in mcp_settings: if mcp_settings[k] != value: diff --git a/src/pieces/mcp/tools_cache.py b/src/pieces/mcp/tools_cache.py index 7f78374d..3f732754 100644 --- a/src/pieces/mcp/tools_cache.py +++ b/src/pieces/mcp/tools_cache.py @@ -1,6 +1,15 @@ +""" +Fallback MCP tool definitions for offline operation. + +When PiecesOS is not running and no previously discovered tools are cached, +the gateway returns these hardcoded tool definitions so that IDE clients +still see the available tool signatures. The tools won't execute +successfully until PiecesOS is started, but the IDE can display them and +queue requests. +""" + from mcp.types import Tool -# Hardcoded fallback tools when PiecesOS isn't available PIECES_MCP_TOOLS_CACHE = [ Tool( name="ask_pieces_ltm", diff --git a/src/pieces/mcp/utils.py b/src/pieces/mcp/utils.py index e9213f52..a126ee40 100644 --- a/src/pieces/mcp/utils.py +++ b/src/pieces/mcp/utils.py @@ -1,4 +1,19 @@ -from typing import TYPE_CHECKING, List +""" +MCP URL resolution and schema version selection. + +Resolves PiecesOS MCP endpoint URLs from the schema versions API and caches +them to avoid repeated API calls. The cache is invalidated when PiecesOS +restarts or the connection is cleaned up. + +The gateway prefers the ``2025-03-26`` streamable HTTP schema for upstream +connections (more robust than SSE), while IDE integration configs continue +to use the ``2024-11-05`` SSE schema for direct IDE-to-PiecesOS connections. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + from ..settings import Settings if TYPE_CHECKING: @@ -6,10 +21,25 @@ ModelContextProtocolSchemaVersion, ) -_latest_urls = [] # cache the urls instead of sending to the api +PREFERRED_SCHEMA_VERSION = "2025-03-26" +"""Streamable HTTP schema -- preferred for gateway upstream connections.""" + +SSE_SCHEMA_VERSION = "2024-11-05" +"""SSE schema -- used for IDE integration configs that connect directly to PiecesOS.""" +_latest_urls: list[ModelContextProtocolSchemaVersion] = [] -def get_mcp_model_urls() -> List["ModelContextProtocolSchemaVersion"]: + +def get_mcp_model_urls() -> list[ModelContextProtocolSchemaVersion]: + """Fetch and cache the list of MCP schema versions from PiecesOS. + + Returns: + List of ``ModelContextProtocolSchemaVersion`` objects, each containing + ``entry_endpoint`` and ``version`` fields. + + Raises: + Exception: If the PiecesOS API call fails. + """ global _latest_urls if not _latest_urls: res = Settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions() @@ -17,9 +47,59 @@ def get_mcp_model_urls() -> List["ModelContextProtocolSchemaVersion"]: return _latest_urls -def get_mcp_latest_url(): - return get_mcp_model_urls()[0].entry_endpoint +def invalidate_mcp_url_cache() -> None: + """Clear the cached schema version URLs. + + Call this when PiecesOS restarts, the connection is cleaned up, or the + upstream URL needs to be re-resolved (e.g. PiecesOS changed ports). + """ + global _latest_urls + _latest_urls = [] + + +def get_mcp_latest_url() -> str: + """Get the preferred MCP endpoint URL for gateway upstream connections. + + Prefers the ``2025-03-26`` streamable HTTP schema. Falls back to the + first available schema if the preferred version is not found. + + Returns: + The entry endpoint URL string. + + Raises: + ValueError: If no MCP schema versions are available from PiecesOS. + """ + urls = get_mcp_model_urls() + if not urls: + raise ValueError("No MCP schema versions available from PiecesOS") + for schema in urls: + if schema.version == PREFERRED_SCHEMA_VERSION: + return schema.entry_endpoint + return urls[0].entry_endpoint + + +def get_mcp_sse_url() -> str | None: + """Get the SSE endpoint URL specifically for IDE integration configs. + + IDE integrations (Claude, Cursor, VS Code, etc.) connect directly to + PiecesOS using SSE, so they need the ``2024-11-05`` schema URL. + + Returns: + The SSE entry endpoint URL, or ``None`` if no schemas are available. + """ + urls = get_mcp_model_urls() + if not urls: + return None + for schema in urls: + if schema.version == SSE_SCHEMA_VERSION: + return schema.entry_endpoint + return urls[0].entry_endpoint + +def get_mcp_urls() -> list[str]: + """Get all known MCP endpoint URLs (all schema versions). -def get_mcp_urls(): + Returns: + List of endpoint URL strings. + """ return [mcp.entry_endpoint for mcp in get_mcp_model_urls()] diff --git a/tests/mcps/mcp_gateway/test_bug_fixes.py b/tests/mcps/mcp_gateway/test_bug_fixes.py new file mode 100644 index 00000000..f6d0bdec --- /dev/null +++ b/tests/mcps/mcp_gateway/test_bug_fixes.py @@ -0,0 +1,559 @@ +""" +Tests for MCP gateway bug fixes. + +Covers: schema version selection, URL cache invalidation, cleanup ordering, +notification handler reconnection, polling early failure detection, resource +leak prevention, error surfacing, and session ping race conditions. +""" + +import asyncio +import time +import pytest +import mcp.types as types +from unittest.mock import Mock, AsyncMock, patch + +from .utils import mock_connection + + +# --------------------------------------------------------------------------- +# Bug 1: Schema version selection +# --------------------------------------------------------------------------- + +class TestSchemaVersionSelection: + """Verify that get_mcp_latest_url prefers 2025-03-26 streamable HTTP.""" + + def _make_schema(self, version, endpoint): + s = Mock() + s.version = version + s.entry_endpoint = endpoint + return s + + @patch("pieces.mcp.utils._latest_urls", []) + @patch("pieces.mcp.utils.Settings") + def test_prefers_2025_03_26(self, mock_settings): + from pieces.mcp.utils import get_mcp_latest_url, invalidate_mcp_url_cache + + invalidate_mcp_url_cache() + schemas = [ + self._make_schema("2024-11-05", "http://localhost:39300/mcp/2024-11-05/sse"), + self._make_schema("2025-03-26", "http://localhost:39300/mcp/2025-03-26/mcp"), + ] + mock_settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions.return_value = Mock( + iterable=schemas + ) + + url = get_mcp_latest_url() + assert url == "http://localhost:39300/mcp/2025-03-26/mcp" + invalidate_mcp_url_cache() + + @patch("pieces.mcp.utils._latest_urls", []) + @patch("pieces.mcp.utils.Settings") + def test_falls_back_to_first(self, mock_settings): + from pieces.mcp.utils import get_mcp_latest_url, invalidate_mcp_url_cache + + invalidate_mcp_url_cache() + schemas = [ + self._make_schema("2024-11-05", "http://localhost:39300/mcp/2024-11-05/sse"), + ] + mock_settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions.return_value = Mock( + iterable=schemas + ) + + url = get_mcp_latest_url() + assert url == "http://localhost:39300/mcp/2024-11-05/sse" + invalidate_mcp_url_cache() + + @patch("pieces.mcp.utils._latest_urls", []) + @patch("pieces.mcp.utils.Settings") + def test_empty_list_raises(self, mock_settings): + from pieces.mcp.utils import get_mcp_latest_url, invalidate_mcp_url_cache + + invalidate_mcp_url_cache() + mock_settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions.return_value = Mock( + iterable=[] + ) + + with pytest.raises(ValueError, match="No MCP schema versions"): + get_mcp_latest_url() + invalidate_mcp_url_cache() + + @patch("pieces.mcp.utils._latest_urls", []) + @patch("pieces.mcp.utils.Settings") + def test_get_mcp_sse_url_returns_sse(self, mock_settings): + from pieces.mcp.utils import get_mcp_sse_url, invalidate_mcp_url_cache + + invalidate_mcp_url_cache() + schemas = [ + self._make_schema("2025-03-26", "http://localhost:39300/mcp/2025-03-26/mcp"), + self._make_schema("2024-11-05", "http://localhost:39300/mcp/2024-11-05/sse"), + ] + mock_settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions.return_value = Mock( + iterable=schemas + ) + + url = get_mcp_sse_url() + assert url == "http://localhost:39300/mcp/2024-11-05/sse" + invalidate_mcp_url_cache() + + +# --------------------------------------------------------------------------- +# Bug 2: URL cache invalidation +# --------------------------------------------------------------------------- + +class TestCacheInvalidation: + """Verify that invalidate_mcp_url_cache clears the cache.""" + + @patch("pieces.mcp.utils.Settings") + def test_invalidate_forces_refetch(self, mock_settings): + from pieces.mcp.utils import get_mcp_model_urls, invalidate_mcp_url_cache + + invalidate_mcp_url_cache() + schema = Mock() + schema.version = "2025-03-26" + schema.entry_endpoint = "http://test/mcp" + mock_settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions.return_value = Mock( + iterable=[schema] + ) + + get_mcp_model_urls() + assert mock_settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions.call_count == 1 + + get_mcp_model_urls() + assert mock_settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions.call_count == 1 + + invalidate_mcp_url_cache() + get_mcp_model_urls() + assert mock_settings.pieces_client.model_context_protocol_api.model_context_protocol_schema_versions.call_count == 2 + + invalidate_mcp_url_cache() + + @pytest.mark.asyncio + async def test_try_get_upstream_url_invalidates_when_none(self, mock_connection): + mock_connection.upstream_url = None + + with ( + patch("pieces.mcp.gateway.Settings") as mock_settings, + patch("pieces.mcp.gateway.get_mcp_latest_url", return_value="http://new-url/mcp"), + patch("pieces.mcp.gateway.invalidate_mcp_url_cache") as mock_invalidate, + ): + mock_settings.pieces_client.is_pieces_running.return_value = True + + result = mock_connection._try_get_upstream_url() + + assert result is True + assert mock_connection.upstream_url == "http://new-url/mcp" + mock_invalidate.assert_called_once() + + +# --------------------------------------------------------------------------- +# Bug 3: Cleanup ordering +# --------------------------------------------------------------------------- + +class TestCleanupOrdering: + """Verify __aexit__ is called before instance vars are nullified.""" + + @pytest.mark.asyncio + async def test_aexit_called_before_nullify(self, mock_connection): + session_mock = AsyncMock() + transport_mock = AsyncMock() + mock_connection.session = session_mock + mock_connection._transport_ctx = transport_mock + + call_order = [] + + async def session_aexit(*args): + assert mock_connection.session is session_mock, \ + "session should still be set when __aexit__ is called" + call_order.append("session_aexit") + + async def transport_aexit(*args): + call_order.append("transport_aexit") + + session_mock.__aexit__ = session_aexit + transport_mock.__aexit__ = transport_aexit + + await mock_connection._cleanup_stale_session() + + assert mock_connection.session is None + assert mock_connection._transport_ctx is None + assert "session_aexit" in call_order + assert "transport_aexit" in call_order + + @pytest.mark.asyncio + async def test_cleanup_continues_on_aexit_exception(self, mock_connection): + session_mock = AsyncMock() + transport_mock = AsyncMock() + mock_connection.session = session_mock + mock_connection._transport_ctx = transport_mock + + session_mock.__aexit__ = AsyncMock(side_effect=RuntimeError("boom")) + transport_mock.__aexit__ = AsyncMock() + + await mock_connection._cleanup_stale_session() + + assert mock_connection.session is None + assert mock_connection._transport_ctx is None + transport_mock.__aexit__.assert_called_once() + + +# --------------------------------------------------------------------------- +# Bug 4: Notification handler (now via public message_handler API) +# --------------------------------------------------------------------------- + +class TestMessageHandler: + """Verify _make_message_handler uses the public SDK API correctly.""" + + @pytest.mark.asyncio + async def test_each_call_returns_distinct_handler(self, mock_connection): + handler1 = mock_connection._make_message_handler(send_notification=True) + handler2 = mock_connection._make_message_handler(send_notification=True) + + assert handler1 is not handler2 + + @pytest.mark.asyncio + async def test_handles_tool_list_changed_notification(self, mock_connection): + mock_session = AsyncMock() + mock_connection.session = mock_session + mock_connection._tools_changed_callback = AsyncMock() + + handler = mock_connection._make_message_handler(send_notification=True) + + notification = types.ServerNotification( + root=types.ToolListChangedNotification( + method="notifications/tools/list_changed" + ) + ) + + with patch.object(mock_connection, "update_tools", new_callable=AsyncMock): + await handler(notification) + + mock_connection._tools_changed_callback.assert_awaited_once() + + @pytest.mark.asyncio + async def test_ignores_non_notification_messages(self, mock_connection): + handler = mock_connection._make_message_handler(send_notification=True) + mock_connection._tools_changed_callback = AsyncMock() + + await handler("not a notification") + await handler(Exception("some error")) + await handler(42) + + mock_connection._tools_changed_callback.assert_not_awaited() + + @pytest.mark.asyncio + async def test_no_stale_main_handler_attribute(self, mock_connection): + mock_connection._make_message_handler(send_notification=True) + + assert not hasattr(mock_connection, "main_notification_handler") + + +# --------------------------------------------------------------------------- +# Bug 5: Polling early failure detection +# --------------------------------------------------------------------------- + +class TestPollingEarlyExit: + """Verify connect() detects early task failure instead of timing out.""" + + @pytest.mark.asyncio + async def test_detects_early_task_failure(self, mock_connection): + mock_connection.upstream_url = "http://test/mcp" + + with patch.object(mock_connection, "_ensure_clean_state", new_callable=AsyncMock): + with patch.object( + mock_connection, + "_connection_handler", + side_effect=RuntimeError("immediate failure"), + ): + start = time.time() + with pytest.raises(RuntimeError, match="immediate failure"): + await mock_connection.connect() + elapsed = time.time() - start + + assert elapsed < 3.0, ( + f"connect() should fail fast, not wait 10s (took {elapsed:.1f}s)" + ) + + +# --------------------------------------------------------------------------- +# Bug 6: Resource leaks +# --------------------------------------------------------------------------- + +class TestResourceLeaks: + """Verify bare except is replaced and resources are cleaned up.""" + + def test_try_get_upstream_url_does_not_catch_system_exit(self, mock_connection): + mock_connection.upstream_url = None + + with patch("pieces.mcp.gateway.Settings") as mock_settings: + mock_settings.pieces_client.is_pieces_running.return_value = True + + with patch( + "pieces.mcp.gateway.get_mcp_latest_url", + side_effect=SystemExit(1), + ): + with pytest.raises(SystemExit): + mock_connection._try_get_upstream_url() + + +# --------------------------------------------------------------------------- +# Bug 9: Session ping race +# --------------------------------------------------------------------------- + +class TestSessionPingRace: + """Verify connect() handles session nullified during ping.""" + + @pytest.mark.asyncio + async def test_handles_session_nullified_during_ping(self, mock_connection): + mock_connection.upstream_url = "http://test/mcp" + + mock_session = AsyncMock() + mock_session.send_ping = AsyncMock(side_effect=Exception("session gone")) + mock_connection.session = mock_session + mock_connection._connection_task = Mock() + mock_connection._connection_task.done.return_value = False + + with ( + patch.object(mock_connection, "_ensure_clean_state", new_callable=AsyncMock), + patch.object(mock_connection, "_connection_handler", new_callable=AsyncMock) as mock_handler, + ): + async def fake_handler(send_notification=True): + mock_connection.session = AsyncMock() + await asyncio.sleep(10) + + mock_handler.side_effect = fake_handler + + session = await mock_connection.connect() + assert session is not None + + +# --------------------------------------------------------------------------- +# Bug 11: Error surfacing +# --------------------------------------------------------------------------- + +class TestErrorSurfacing: + """Verify errors are surfaced to users with actionable messages.""" + + @pytest.mark.asyncio + async def test_call_tool_timeout_includes_specific_error(self, mock_connection): + with patch.object( + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) + ): + with patch.object( + mock_connection, + "connect", + side_effect=TimeoutError("timed out"), + ): + result = await mock_connection.call_tool("test_tool", {}) + + assert isinstance(result, types.CallToolResult) + text = result.content[0].text + assert "Timed out" in text + assert "pieces restart" in text + + @pytest.mark.asyncio + async def test_call_tool_connection_error_includes_pieces_open(self, mock_connection): + with patch.object( + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) + ): + with patch.object( + mock_connection, + "connect", + side_effect=ConnectionError("refused"), + ): + result = await mock_connection.call_tool("test_tool", {}) + + assert isinstance(result, types.CallToolResult) + text = result.content[0].text + assert "pieces open" in text + assert "refused" in text + + @pytest.mark.asyncio + async def test_call_tool_generic_error_includes_details(self, mock_connection): + with patch.object( + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) + ): + with patch.object( + mock_connection, + "connect", + side_effect=RuntimeError("something broke"), + ): + result = await mock_connection.call_tool("test_tool", {}) + + assert isinstance(result, types.CallToolResult) + text = result.content[0].text + assert "RuntimeError" in text + assert "something broke" in text + + @pytest.mark.asyncio + async def test_last_connection_error_surfaced(self, mock_connection): + mock_connection._last_connection_error = ( + "Connection to PiecesOS timed out (ReadTimeout)." + ) + + with patch.object( + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) + ): + with patch.object( + mock_connection, + "connect", + side_effect=Exception("generic"), + ): + result = await mock_connection.call_tool("test_tool", {}) + + text = result.content[0].text + assert "ReadTimeout" in text + + @pytest.mark.asyncio + async def test_last_connection_error_cleared_on_success(self, mock_connection): + mock_connection._last_connection_error = "old error" + mock_connection.upstream_url = "http://test/mcp" + + mock_session = AsyncMock() + + with ( + patch.object(mock_connection, "_ensure_clean_state", new_callable=AsyncMock), + patch.object(mock_connection, "_connection_handler", new_callable=AsyncMock) as mock_handler, + ): + async def fake_handler(send_notification=True): + mock_connection.session = mock_session + mock_connection._last_connection_error = None + await asyncio.sleep(10) + + mock_handler.side_effect = fake_handler + + session = await mock_connection.connect() + assert session is mock_session + assert mock_connection._last_connection_error is None + + @pytest.mark.asyncio + async def test_list_tools_fallback_logs_warning(self): + from pieces.mcp.gateway import MCPGateway + + with ( + patch("pieces.mcp.gateway.Settings") as mock_settings, + patch("pieces.mcp.gateway.Server"), + patch("pieces.mcp.gateway.sentry_sdk"), + ): + mock_settings.logger = Mock() + gateway = MCPGateway( + server_name="test", + upstream_url="http://test/mcp", + ) + + gateway.upstream._check_pieces_os_status = AsyncMock(return_value=False) + gateway.upstream.discovered_tools = [] + + handlers = {} + original_list_tools = gateway.server.list_tools + + def capture_list_tools(): + def decorator(func): + handlers["list_tools"] = func + return func + return decorator + + gateway.server.list_tools = capture_list_tools + gateway.server.call_tool = lambda: lambda f: f + gateway.setup_handlers() + + if "list_tools" in handlers: + result = await handlers["list_tools"]() + mock_settings.logger.warning.assert_called() + warning_msg = mock_settings.logger.warning.call_args[0][0] + assert "pieces open" in warning_msg.lower() or "PiecesOS" in warning_msg + + +# --------------------------------------------------------------------------- +# Enhancement 1: asyncio.Lock + asyncio.to_thread in _check_pieces_os_status +# --------------------------------------------------------------------------- + +class TestAsyncHealthCheck: + """Verify _check_pieces_os_status is async and uses to_thread.""" + + @pytest.mark.asyncio + @patch("pieces.mcp.gateway.HealthWS") + async def test_returns_true_when_health_ws_running( + self, mock_health_ws, mock_connection + ): + mock_health_ws.is_running.return_value = True + with patch("pieces.mcp.gateway.Settings") as mock_settings: + mock_settings.pieces_client.is_pos_stream_running = True + + result = await mock_connection._check_pieces_os_status() + + assert result is True + + @pytest.mark.asyncio + async def test_returns_false_when_pieces_not_running(self, mock_connection): + with ( + patch("pieces.mcp.gateway.HealthWS") as mock_health_ws, + patch("pieces.mcp.gateway.Settings") as mock_settings, + ): + mock_health_ws.is_running.return_value = False + mock_settings.pieces_client.is_pos_stream_running = False + mock_settings.pieces_client.is_pieces_running.return_value = False + + result = await mock_connection._check_pieces_os_status() + + assert result is False + + @pytest.mark.asyncio + async def test_offloads_blocking_calls_to_thread(self, mock_connection): + with ( + patch("pieces.mcp.gateway.HealthWS") as mock_health_ws, + patch("pieces.mcp.gateway.Settings") as mock_settings, + patch("pieces.mcp.gateway.sentry_sdk"), + patch("asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread, + ): + mock_health_ws.is_running.return_value = False + mock_settings.pieces_client.is_pos_stream_running = False + + mock_to_thread.side_effect = [ + True, # is_pieces_running + None, # health_ws.start + "id1", # get_os_id + Mock(user=Mock()), # user_snapshot + Mock(), # vision_status + ] + + mock_health_ws.get_instance.return_value = Mock() + + result = await mock_connection._check_pieces_os_status() + + assert result is True + assert mock_to_thread.call_count == 5 + + @pytest.mark.asyncio + async def test_uses_asyncio_lock_not_threading_lock(self, mock_connection): + assert isinstance(mock_connection._health_check_lock, asyncio.Lock) + + @pytest.mark.asyncio + async def test_validate_system_status_is_async(self, mock_connection): + import inspect + assert inspect.iscoroutinefunction(mock_connection._validate_system_status) + + @pytest.mark.asyncio + async def test_get_error_message_for_tool_is_async(self, mock_connection): + import inspect + assert inspect.iscoroutinefunction(mock_connection._get_error_message_for_tool) + + +# --------------------------------------------------------------------------- +# Enhancement 3: get_session_id captured for observability +# --------------------------------------------------------------------------- + +class TestSessionIdTracking: + """Verify upstream session ID is captured and logged.""" + + def test_upstream_session_id_initialized_to_none(self, mock_connection): + assert mock_connection._upstream_session_id is None + + @pytest.mark.asyncio + async def test_session_id_cleared_on_cleanup(self, mock_connection): + mock_connection._upstream_session_id = "test-session-123" + mock_connection.session = None + mock_connection._transport_ctx = None + + await mock_connection._cleanup_stale_session() + + assert mock_connection._upstream_session_id is None diff --git a/tests/mcps/mcp_gateway/test_integration.py b/tests/mcps/mcp_gateway/test_integration.py index 0840b6c1..bf2a700e 100644 --- a/tests/mcps/mcp_gateway/test_integration.py +++ b/tests/mcps/mcp_gateway/test_integration.py @@ -29,7 +29,7 @@ def get_upstream_url(): return get_mcp_latest_url() except Exception: # We are mocking the settings so this will raise an exception most of the time we can hardcode the url instead - return "http://localhost:39300/model_context_protocol/2024-11-05/sse" + return "http://localhost:39300/model_context_protocol/2025-03-26/mcp" @pytest.fixture(scope="module") diff --git a/tests/mcps/mcp_gateway/test_validation_advanced.py b/tests/mcps/mcp_gateway/test_validation_advanced.py index 830ebc1a..3d249305 100644 --- a/tests/mcps/mcp_gateway/test_validation_advanced.py +++ b/tests/mcps/mcp_gateway/test_validation_advanced.py @@ -7,14 +7,12 @@ import time import pytest import mcp.types as types -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from .utils import ( - mock_tools_changed_callback, mock_connection, UpdateEnum, ) -from pieces.mcp.gateway import PosMcpConnection class TestMCPGatewayValidationAdvanced: @@ -25,7 +23,7 @@ async def test_concurrent_validation_calls(self, mock_connection): """Test that concurrent validation calls don't cause race conditions""" # Mock all components to return True with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): with patch.object( mock_connection, "_check_version_compatibility", return_value=(True, "") @@ -34,7 +32,7 @@ async def test_concurrent_validation_calls(self, mock_connection): # Run multiple validations concurrently results = [] for i in range(10): - result = mock_connection._validate_system_status(f"tool_{i}") + result = await mock_connection._validate_system_status(f"tool_{i}") results.append(result) # All should succeed @@ -53,10 +51,10 @@ async def test_malformed_tool_names(self, mock_connection): ] with patch.object( - mock_connection, "_check_pieces_os_status", return_value=False + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=False) ): for name in malicious_names: - is_valid, error_message = mock_connection._validate_system_status(name) + is_valid, error_message = await mock_connection._validate_system_status(name) assert is_valid is False # Should not contain raw tool name in error @@ -69,7 +67,7 @@ async def test_connection_timeout_handling(self, mock_connection): """Test handling of connection timeouts""" # Mock validation success with patch.object( - mock_connection, "_validate_system_status", return_value=(True, "") + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) ): # Mock connection to timeout with patch.object( @@ -87,14 +85,14 @@ async def test_partial_failure_states(self, mock_connection): """Test when some checks pass but others fail""" # PiecesOS running but incompatible version with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): mock_result = Mock() mock_result.compatible = False mock_result.update = UpdateEnum.PiecesOS mock_connection.result = mock_result - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "test_tool" ) @@ -106,14 +104,14 @@ async def test_partial_failure_states(self, mock_connection): # PiecesOS running, compatible, but LTM disabled for LTM tool with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): mock_result = Mock() mock_result.compatible = True mock_connection.result = mock_result with patch.object(mock_connection, "_check_ltm_status", return_value=False): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "ask_pieces_ltm" ) @@ -135,7 +133,7 @@ async def mock_call_tool_impl(tool_name, args): ) with patch.object( - mock_connection, "_validate_system_status", return_value=(True, "") + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) ): with patch.object(mock_connection, "connect"): # Mock the actual tool execution @@ -166,7 +164,7 @@ async def test_error_recovery_after_pos_restart(self, mock_connection): """Test gateway recovers after PiecesOS restart""" # Simulate PiecesOS down initially with patch.object( - mock_connection, "_check_pieces_os_status", return_value=False + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=False) ): result1 = await mock_connection.call_tool("test_tool", {}) assert isinstance(result1, types.CallToolResult) @@ -178,7 +176,7 @@ async def test_error_recovery_after_pos_restart(self, mock_connection): # Simulate PiecesOS back up with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): with patch.object(mock_connection, "_check_ltm_status", return_value=True): mock_result = Mock() @@ -187,7 +185,7 @@ async def test_error_recovery_after_pos_restart(self, mock_connection): with patch.object(mock_connection, "connect"): # Should work now - validation passes - is_valid, error_msg = mock_connection._validate_system_status( + is_valid, error_msg = await mock_connection._validate_system_status( "test_tool" ) assert is_valid is True @@ -198,9 +196,9 @@ async def test_error_message_content_validation(self, mock_connection): """Test that error messages provide helpful guidance to users""" # Test PiecesOS not running scenario with patch.object( - mock_connection, "_check_pieces_os_status", return_value=False + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=False) ): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "test_tool" ) @@ -211,7 +209,7 @@ async def test_error_message_content_validation(self, mock_connection): # Test CLI version incompatible scenario with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): with patch.object( mock_connection, @@ -221,7 +219,7 @@ async def test_error_message_content_validation(self, mock_connection): "Please update the CLI version to be able to run the tool call, run 'pieces manage update' to get the latest version. Then retry your request again after updating.", ), ): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "test_tool" ) @@ -232,7 +230,7 @@ async def test_error_message_content_validation(self, mock_connection): # Test PiecesOS version incompatible scenario with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): with patch.object( mock_connection, @@ -242,7 +240,7 @@ async def test_error_message_content_validation(self, mock_connection): "Please update PiecesOS to a compatible version to be able to run the tool call. run 'pieces update' to get the latest version. Then retry your request again after updating.", ), ): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "test_tool" ) @@ -253,7 +251,7 @@ async def test_error_message_content_validation(self, mock_connection): # Test LTM disabled scenario with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): with patch.object( mock_connection, "_check_version_compatibility", return_value=(True, "") @@ -261,7 +259,7 @@ async def test_error_message_content_validation(self, mock_connection): with patch.object( mock_connection, "_check_ltm_status", return_value=False ): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "ask_pieces_ltm" ) @@ -326,7 +324,7 @@ async def mock_list_tools(): async def test_performance_validation_overhead(self, mock_connection): """Test that validation doesn't add significant overhead""" with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): with patch.object(mock_connection, "_check_ltm_status", return_value=True): mock_result = Mock() @@ -335,13 +333,13 @@ async def test_performance_validation_overhead(self, mock_connection): # Warm up for _ in range(10): - mock_connection._validate_system_status("test_tool") + await mock_connection._validate_system_status("test_tool") # Measure performance start = time.time() iterations = 100 for _ in range(iterations): - is_valid, _ = mock_connection._validate_system_status("test_tool") + is_valid, _ = await mock_connection._validate_system_status("test_tool") assert is_valid is True elapsed = time.time() - start @@ -414,7 +412,7 @@ async def mock_cleanup(): # Mock successful validation with patch.object( - mock_connection, "_validate_system_status", return_value=(True, "") + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) ): # Mock the connection and cleanup methods original_connect = mock_connection.connect diff --git a/tests/mcps/mcp_gateway/test_validation_core.py b/tests/mcps/mcp_gateway/test_validation_core.py index 1f19c087..aa3a8341 100644 --- a/tests/mcps/mcp_gateway/test_validation_core.py +++ b/tests/mcps/mcp_gateway/test_validation_core.py @@ -5,14 +5,12 @@ import pytest import mcp.types as types -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from .utils import ( - mock_tools_changed_callback, mock_connection, UpdateEnum, ) -from pieces.mcp.gateway import PosMcpConnection class TestMCPGatewayValidationCore: @@ -22,9 +20,9 @@ class TestMCPGatewayValidationCore: async def test_validate_system_status_pieces_os_not_running(self, mock_connection): """Test validation when PiecesOS is not running""" with patch.object( - mock_connection, "_check_pieces_os_status", return_value=False + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=False) ): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "ask_pieces_ltm" ) @@ -39,7 +37,7 @@ async def test_validate_system_status_version_incompatible_plugin_update( """Test validation when CLI version needs updating""" # Mock PiecesOS running with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): # Mock version compatibility check to return plugin update needed mock_result = Mock() @@ -47,7 +45,7 @@ async def test_validate_system_status_version_incompatible_plugin_update( mock_result.update = UpdateEnum.Plugin mock_connection.result = mock_result - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "ask_pieces_ltm" ) @@ -63,7 +61,7 @@ async def test_validate_system_status_version_incompatible_pos_update( """Test validation when PiecesOS version needs updating""" # Mock PiecesOS running with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): # Mock version compatibility check to return POS update needed mock_result = Mock() @@ -71,7 +69,7 @@ async def test_validate_system_status_version_incompatible_pos_update( mock_result.update = UpdateEnum.PiecesOS # Or any value that's not Plugin mock_connection.result = mock_result - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "ask_pieces_ltm" ) @@ -85,7 +83,7 @@ async def test_validate_system_status_ltm_disabled(self, mock_connection): """Test validation when LTM is disabled for LTM tools""" # Mock PiecesOS running and version compatible with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): mock_result = Mock() mock_result.compatible = True @@ -93,7 +91,7 @@ async def test_validate_system_status_ltm_disabled(self, mock_connection): # Mock LTM disabled with patch.object(mock_connection, "_check_ltm_status", return_value=False): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "ask_pieces_ltm" ) @@ -108,7 +106,7 @@ async def test_validate_system_status_ltm_disabled_create_memory_tool( """Test validation when LTM is disabled for create_pieces_memory tool""" # Mock PiecesOS running and version compatible with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): mock_result = Mock() mock_result.compatible = True @@ -116,7 +114,7 @@ async def test_validate_system_status_ltm_disabled_create_memory_tool( # Mock LTM disabled with patch.object(mock_connection, "_check_ltm_status", return_value=False): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "create_pieces_memory" ) @@ -129,7 +127,7 @@ async def test_validate_system_status_non_ltm_tool_success(self, mock_connection """Test validation success for non-LTM tools when LTM is disabled""" # Mock PiecesOS running and version compatible with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): mock_result = Mock() mock_result.compatible = True @@ -137,7 +135,7 @@ async def test_validate_system_status_non_ltm_tool_success(self, mock_connection # Mock LTM disabled (shouldn't matter for non-LTM tools) with patch.object(mock_connection, "_check_ltm_status", return_value=False): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "some_other_tool" ) @@ -149,7 +147,7 @@ async def test_validate_system_status_all_checks_pass(self, mock_connection): """Test validation when all checks pass""" # Mock PiecesOS running with patch.object( - mock_connection, "_check_pieces_os_status", return_value=True + mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): # Mock version compatible mock_result = Mock() @@ -158,7 +156,7 @@ async def test_validate_system_status_all_checks_pass(self, mock_connection): # Mock LTM enabled with patch.object(mock_connection, "_check_ltm_status", return_value=True): - is_valid, error_message = mock_connection._validate_system_status( + is_valid, error_message = await mock_connection._validate_system_status( "ask_pieces_ltm" ) @@ -171,7 +169,7 @@ async def test_call_tool_with_validation_failure(self, mock_connection): # Mock validation failure error_msg = "Test validation error" with patch.object( - mock_connection, "_validate_system_status", return_value=(False, error_msg) + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(False, error_msg)) ): result = await mock_connection.call_tool("test_tool", {}) @@ -185,7 +183,7 @@ async def test_call_tool_with_connection_failure(self, mock_connection): """Test call_tool handles connection failures gracefully""" # Mock validation success with patch.object( - mock_connection, "_validate_system_status", return_value=(True, "") + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) ): # Mock connection failure with patch.object( @@ -204,9 +202,9 @@ async def test_get_error_message_for_tool_uses_validation(self, mock_connection) # Mock validation failure error_msg = "Validation failed" with patch.object( - mock_connection, "_validate_system_status", return_value=(False, error_msg) + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(False, error_msg)) ): - result = mock_connection._get_error_message_for_tool("test_tool") + result = await mock_connection._get_error_message_for_tool("test_tool") assert result == error_msg @@ -215,9 +213,9 @@ async def test_get_error_message_for_tool_validation_passes(self, mock_connectio """Test _get_error_message_for_tool when validation passes but still has error""" # Mock validation success with patch.object( - mock_connection, "_validate_system_status", return_value=(True, "") + mock_connection, "_validate_system_status", new=AsyncMock(return_value=(True, "")) ): - result = mock_connection._get_error_message_for_tool("test_tool") + result = await mock_connection._get_error_message_for_tool("test_tool") assert "Unable to execute 'test_tool' tool" in result assert "`pieces restart`" in result @@ -243,8 +241,9 @@ def test_check_version_compatibility_caches_result(self, mock_connection): # VersionChecker should only be called once due to caching mock_version_checker.assert_called_once() + @pytest.mark.asyncio @patch("pieces.mcp.gateway.HealthWS") - def test_check_pieces_os_status_health_ws_running( + async def test_check_pieces_os_status_health_ws_running( self, mock_health_ws, mock_connection ): """Test _check_pieces_os_status when health WS is already running""" @@ -255,14 +254,15 @@ def test_check_pieces_os_status_health_ws_running( with patch("pieces.mcp.gateway.Settings") as mock_settings: mock_settings.pieces_client.is_pos_stream_running = True - result = mock_connection._check_pieces_os_status() + result = await mock_connection._check_pieces_os_status() assert result is True mock_health_ws.is_running.assert_called_once() + @pytest.mark.asyncio @patch("pieces.mcp.gateway.HealthWS") @patch("pieces.mcp.gateway.Settings") - def test_check_pieces_os_status_starts_health_ws( + async def test_check_pieces_os_status_starts_health_ws( self, mock_settings, mock_health_ws, mock_connection ): """Test _check_pieces_os_status starts health WS when PiecesOS is running""" @@ -278,7 +278,7 @@ def test_check_pieces_os_status_starts_health_ws( # Mock the workstream API call mock_settings.pieces_client.work_stream_pattern_engine_api.workstream_pattern_engine_processors_vision_status.return_value = Mock() - result = mock_connection._check_pieces_os_status() + result = await mock_connection._check_pieces_os_status() assert result is True mock_health_ws_instance.start.assert_called_once() @@ -297,7 +297,7 @@ async def test_multiple_validation_calls_same_tool(self, mock_connection): """Test that multiple validation calls for the same tool work correctly""" # Mock all components with ( - patch.object(mock_connection, "_check_pieces_os_status", return_value=True), + patch.object(mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True)), patch.object(mock_connection, "_check_ltm_status", return_value=True), ): mock_result = Mock() @@ -305,8 +305,8 @@ async def test_multiple_validation_calls_same_tool(self, mock_connection): mock_connection.result = mock_result # Call validation multiple times - is_valid1, msg1 = mock_connection._validate_system_status("ask_pieces_ltm") - is_valid2, msg2 = mock_connection._validate_system_status("ask_pieces_ltm") + is_valid1, msg1 = await mock_connection._validate_system_status("ask_pieces_ltm") + is_valid2, msg2 = await mock_connection._validate_system_status("ask_pieces_ltm") assert is_valid1 == is_valid2 assert is_valid1 is True diff --git a/tests/mcps/mcp_gateway/utils.py b/tests/mcps/mcp_gateway/utils.py index a4c26d84..950c7883 100644 --- a/tests/mcps/mcp_gateway/utils.py +++ b/tests/mcps/mcp_gateway/utils.py @@ -11,8 +11,7 @@ import urllib.request import pytest import requests -import mcp.types as types -from unittest.mock import Mock, patch +from unittest.mock import Mock from pieces.mcp.gateway import MCPGateway, PosMcpConnection from pieces.mcp.utils import get_mcp_latest_url from pieces.settings import Settings @@ -25,8 +24,8 @@ TEST_SERVER_NAME = "pieces-test-mcp" """Default server name used in MCP Gateway tests.""" -DEFAULT_TEST_URL = "http://localhost:39300/model_context_protocol/2024-11-05/sse" -"""Fallback URL when Settings.startup() fails.""" +DEFAULT_TEST_URL = "http://localhost:39300/model_context_protocol/2025-03-26/mcp" +"""Fallback URL when Settings.startup() fails (uses preferred streamable HTTP schema).""" # ===== MOCK HELPERS ===== @@ -244,5 +243,4 @@ def __exit__(self, exc_type, exc_val, exc_tb): "MockPiecesOSContext", # Re-exports for convenience "UpdateEnum", - "types", ] From 493a9be55302cdb9e3f73994173877e705b306ec Mon Sep 17 00:00:00 2001 From: bishoy-at-pieces Date: Wed, 25 Feb 2026 16:34:23 +0200 Subject: [PATCH 2/3] test(mcps): update mock path for mcp latest url --- tests/mcps/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mcps/utils.py b/tests/mcps/utils.py index ee518c03..990c882b 100644 --- a/tests/mcps/utils.py +++ b/tests/mcps/utils.py @@ -205,7 +205,7 @@ def setUpClass(cls) -> None: # noqa: D401 – overrides TestCase hook cls.mcp_urls_patcher.start() cls.mcp_latest_url_patcher = patch( - "pieces.mcp.integration.get_mcp_latest_url", return_value="pieces_url" + "pieces.mcp.utils.get_mcp_latest_url", return_value="pieces_url" ) cls.mcp_latest_url_patcher.start() From 5c78be907de043f1d264aa942f3e1cc58f628848 Mon Sep 17 00:00:00 2001 From: Mark Widman Date: Thu, 5 Mar 2026 14:51:22 -0500 Subject: [PATCH 3/3] chore: pr feedback --- documentation/mcp.md | 5 +- src/pieces/mcp/gateway.py | 97 +++++++++++-------- src/pieces/mcp/integration.py | 9 +- .../mcp_gateway/test_validation_advanced.py | 13 ++- 4 files changed, 76 insertions(+), 48 deletions(-) diff --git a/documentation/mcp.md b/documentation/mcp.md index ef084459..eb38101b 100644 --- a/documentation/mcp.md +++ b/documentation/mcp.md @@ -109,8 +109,9 @@ commands like `pieces restart` or `pieces open`. `_check_pieces_os_status` is fully async. Blocking SDK calls (health WebSocket start, user snapshot, LTM status, etc.) are offloaded via `asyncio.to_thread()` -so the event loop is never stalled. An `asyncio.Lock` guards the fast-path -check to prevent redundant health probes. +so the event loop is never stalled. An `asyncio.Lock` guards the entire method +(both fast path and slow path) so that only one coroutine runs the health +probe at a time, preventing redundant health-WS starts and shared-state races. ## Notification Handling diff --git a/src/pieces/mcp/gateway.py b/src/pieces/mcp/gateway.py index 688598c1..fa2784ce 100644 --- a/src/pieces/mcp/gateway.py +++ b/src/pieces/mcp/gateway.py @@ -67,7 +67,7 @@ class PosMcpConnection: """ def __init__( - self, upstream_url: str, tools_changed_callback: Callable[[], Awaitable[None]] + self, upstream_url: str | None, tools_changed_callback: Callable[[], Awaitable[None]] ): self.upstream_url: str | None = upstream_url self.CONNECTION_ESTABLISH_ATTEMPTS: int = 100 @@ -87,6 +87,7 @@ def __init__( self._cleanup_requested: asyncio.Event = asyncio.Event() self._connection_task: asyncio.Task | None = None + self._event_loop: asyncio.AbstractEventLoop | None = None def _try_get_upstream_url(self) -> bool: """Try to resolve the upstream URL if we don't have one yet. @@ -109,17 +110,20 @@ def _try_get_upstream_url(self) -> bool: return True def request_cleanup(self) -> None: - """Request cleanup from exception handler (thread-safe).""" + """Request cleanup from exception handler (thread-safe). + + Uses the stored ``_event_loop`` reference rather than + ``asyncio.get_running_loop()`` because this method is invoked + from WebSocket callback threads where no asyncio loop is running. + """ Settings.logger.debug("Cleanup requested from exception handler") - loop = None - try: - loop = asyncio.get_running_loop() - except RuntimeError: - Settings.logger.debug("No running loop for cleanup request") + loop = self._event_loop + if loop is None: + Settings.logger.debug("No event loop stored yet for cleanup request") return - if loop and not loop.is_closed(): + if not loop.is_closed(): loop.call_soon_threadsafe(self._schedule_cleanup) def _schedule_cleanup(self) -> None: @@ -196,13 +200,17 @@ def _check_version_compatibility(self) -> tuple[bool, str]: async def _check_pieces_os_status(self) -> bool: """Check if PiecesOS is running and initialise health state. - Two-phase check: - 1. **Fast path** (under ``_health_check_lock``): if the health - WebSocket is already running, return immediately. + Two-phase check (both under ``_health_check_lock``): + 1. **Fast path**: if the health WebSocket is already running, + return immediately. 2. **Slow path**: call blocking SDK methods (``is_pieces_running``, ``health_ws.start``, ``user_snapshot``, etc.) via ``asyncio.to_thread()`` so the event loop is never stalled. + The lock covers the entire method so that only one coroutine runs + the slow-path probe at a time, preventing redundant health-WS + starts and shared-state races. + Returns: True if PiecesOS is healthy and reachable, False otherwise. """ @@ -210,36 +218,38 @@ async def _check_pieces_os_status(self) -> bool: if HealthWS.is_running() and Settings.pieces_client.is_pos_stream_running: return True - is_running = await asyncio.to_thread(Settings.pieces_client.is_pieces_running, 2) - if not is_running: - return False + is_running = await asyncio.to_thread( + Settings.pieces_client.is_pieces_running, 2 + ) + if not is_running: + return False - try: - health_ws = HealthWS.get_instance() - if health_ws: - await asyncio.to_thread(health_ws.start) + try: + health_ws = HealthWS.get_instance() + if health_ws: + await asyncio.to_thread(health_ws.start) - os_id = await asyncio.to_thread(Settings.get_os_id) - sentry_sdk.set_user({"id": os_id or "unknown"}) + os_id = await asyncio.to_thread(Settings.get_os_id) + sentry_sdk.set_user({"id": os_id or "unknown"}) - snapshot = await asyncio.to_thread( - Settings.pieces_client.user_api.user_snapshot - ) - Settings.pieces_client.user.user_profile = snapshot.user + snapshot = await asyncio.to_thread( + Settings.pieces_client.user_api.user_snapshot + ) + Settings.pieces_client.user.user_profile = snapshot.user - ltm_status = await asyncio.to_thread( - Settings.pieces_client.work_stream_pattern_engine_api - .workstream_pattern_engine_processors_vision_status - ) - Settings.pieces_client.copilot.context.ltm.ltm_status = ltm_status + ltm_status = await asyncio.to_thread( + Settings.pieces_client.work_stream_pattern_engine_api + .workstream_pattern_engine_processors_vision_status + ) + Settings.pieces_client.copilot.context.ltm.ltm_status = ltm_status - invalidate_mcp_url_cache() - return True - except Exception as e: - Settings.logger.warning( - f"PiecesOS appears to be running but health check failed: {e}" - ) - return False + invalidate_mcp_url_cache() + return True + except Exception as e: + Settings.logger.warning( + f"PiecesOS appears to be running but health check failed: {e}" + ) + return False def _check_ltm_status(self) -> bool: """Check if LTM is enabled.""" @@ -571,6 +581,9 @@ async def connect(self, send_notification: bool = True) -> ClientSession: TimeoutError: If the connection is not established within 10 s. """ async with self.connection_lock: + if self._event_loop is None: + self._event_loop = asyncio.get_running_loop() + session = self.session if ( session is not None @@ -1070,8 +1083,16 @@ def on_ws_event(ws: Any, e: Exception) -> None: Settings.logger.error(f"Unexpected error in main: {e}", exc_info=True) finally: Settings.logger.info("MCP Gateway shutting down") - for ws in [ltm_vision, user_ws, health_ws]: + + async def _close_ws(ws_instance: Any) -> None: try: - ws.close() + await asyncio.wait_for( + asyncio.to_thread(ws_instance.close), timeout=5.0 + ) except Exception: pass + + await asyncio.gather( + *(_close_ws(ws) for ws in [ltm_vision, user_ws, health_ws]), + return_exceptions=True, + ) diff --git a/src/pieces/mcp/integration.py b/src/pieces/mcp/integration.py index 68525bb5..7186ab21 100644 --- a/src/pieces/mcp/integration.py +++ b/src/pieces/mcp/integration.py @@ -82,7 +82,14 @@ def mcp_path(self, mcp_type: mcp_types): def mcp_modified_settings(self, mcp_type: mcp_types): mcp_settings = self.mcp_settings(mcp_type) if mcp_type == "sse": - mcp_settings[self.url_property_name] = get_mcp_sse_url() + sse_url = get_mcp_sse_url() + if sse_url is None: + raise HeadlessError( + "Unable to determine PiecesOS MCP SSE URL. " + "Please ensure PiecesOS is running and reachable before " + "configuring IDE integration." + ) + mcp_settings[self.url_property_name] = sse_url else: mcp_settings[self.command_property_name] = self.pieces_cli_bin_path mcp_settings[self.args_property_name] = [ diff --git a/tests/mcps/mcp_gateway/test_validation_advanced.py b/tests/mcps/mcp_gateway/test_validation_advanced.py index 3d249305..688b4ee8 100644 --- a/tests/mcps/mcp_gateway/test_validation_advanced.py +++ b/tests/mcps/mcp_gateway/test_validation_advanced.py @@ -21,7 +21,6 @@ class TestMCPGatewayValidationAdvanced: @pytest.mark.asyncio async def test_concurrent_validation_calls(self, mock_connection): """Test that concurrent validation calls don't cause race conditions""" - # Mock all components to return True with patch.object( mock_connection, "_check_pieces_os_status", new=AsyncMock(return_value=True) ): @@ -29,13 +28,13 @@ async def test_concurrent_validation_calls(self, mock_connection): mock_connection, "_check_version_compatibility", return_value=(True, "") ): with patch.object(mock_connection, "_check_ltm_status", return_value=True): - # Run multiple validations concurrently - results = [] - for i in range(10): - result = await mock_connection._validate_system_status(f"tool_{i}") - results.append(result) + results = await asyncio.gather( + *( + mock_connection._validate_system_status(f"tool_{i}") + for i in range(10) + ) + ) - # All should succeed assert all(result[0] for result in results) assert all(result[1] == "" for result in results)