diff --git a/cogsol/core/api.py b/cogsol/core/api.py index f8347e8..cd1e885 100644 --- a/cogsol/core/api.py +++ b/cogsol/core/api.py @@ -403,6 +403,63 @@ def get_script(self, script_id: int) -> Any: def get_retrieval_tool(self, tool_id: int) -> Any: return self.request("GET", f"/tools/retrievals/{tool_id}/") + # ========================================================================= + # MCP Servers & Tools + # ========================================================================= + + def list_mcp_servers(self) -> Any: + """List all MCP servers.""" + return self.request("GET", "/mcp-servers/") + + def create_mcp_server(self, payload: dict[str, Any]) -> int: + """Create an MCP server and return its id.""" + data = self.request("POST", "/mcp-servers/", payload) + return self._ensure_id(data, "MCPServer") + + def upsert_mcp_server(self, *, remote_id: int | None, payload: dict[str, Any]) -> int: + """Create or update an MCP server.""" + if remote_id: + data = self.request("PUT", f"/mcp-servers/{remote_id}/", payload) + else: + data = self.request("POST", "/mcp-servers/", payload) + return self._ensure_id(data, "MCPServer") + + def delete_mcp_server(self, server_id: int) -> None: + """Delete an MCP server by id.""" + self.request("DELETE", f"/mcp-servers/{server_id}/") + + def get_mcp_server(self, server_id: int) -> Any: + """Retrieve an MCP server by id.""" + return self.request("GET", f"/mcp-servers/{server_id}/") + + def discover_mcp_oauth(self, server_id: int) -> Any: + """Discover OAuth metadata for an MCP server.""" + return self.request("POST", f"/mcp-servers/{server_id}/oauth/discover/") + + def get_mcp_oauth_authorization_url(self, server_id: int) -> Any: + """Get OAuth authorization URL for an MCP server.""" + return self.request("GET", f"/mcp-servers/{server_id}/oauth/authorize/") + + def list_mcp_server_tools(self, server_id: int) -> Any: + """List tools currently configured on an MCP server.""" + return self.request("GET", 
f"/mcp-servers/{server_id}/tools/") + + def sync_mcp_server_tools(self, server_id: int, selected_tools: list[str]) -> Any: + """Sync (create/update) the selected tools on an MCP server. + + The backend uses a POST with ``{"selected_tools": [...]}`` + to reconcile the tool set. + """ + return self.request( + "POST", + f"/mcp-servers/{server_id}/tools/", + {"selected_tools": selected_tools}, + ) + + def delete_mcp_tool(self, tool_id: int) -> None: + """Delete an MCP tool by id.""" + self.request("DELETE", f"/mcp-tools/{tool_id}/") + # ========================================================================= # Content API - Nodes (Topics) # ========================================================================= diff --git a/cogsol/core/loader.py b/cogsol/core/loader.py index e619aad..357bf3c 100644 --- a/cogsol/core/loader.py +++ b/cogsol/core/loader.py @@ -28,6 +28,8 @@ BaseFAQ, BaseFixedResponse, BaseLesson, + BaseMCPServer, + BaseMCPTool, BaseRetrievalTool, BaseTool, ) @@ -69,6 +71,8 @@ def serialize_value(value: Any) -> Any: return ( getattr(value, "name", None) or getattr(value, "key", None) or value.__class__.__name__ ) + if isinstance(value, type) and issubclass(value, BaseMCPServer): + return getattr(value, "name", None) or value.__name__ if isinstance(value, type): if issubclass(value, BaseRetrieval): return getattr(value, "name", None) or value.__name__ @@ -294,6 +298,8 @@ def collect_definitions( "agents": {}, "tools": {}, "retrieval_tools": {}, + "mcp_servers": {}, + "mcp_tools": {}, "faqs": {}, "fixed_responses": {}, "lessons": {}, @@ -346,6 +352,64 @@ def collect_definitions( if not _ignore_missing_module(exc, f"{app_name}.searches"): _raise_import_error("retrieval tools module", f"{app_name}.searches", exc) + # MCP servers (global) + try: + mcp_server_module = _import_module(f"{app_name}.mcp_servers", project_path) + mcp_srv_prefix = f"{mcp_server_module.__name__}." 
+ for _, obj in inspect.getmembers(mcp_server_module, inspect.isclass): + if ( + issubclass(obj, BaseMCPServer) + and obj is not BaseMCPServer + and ( + obj.__module__ == mcp_server_module.__name__ + or obj.__module__.startswith(mcp_srv_prefix) + ) + ): + fields, meta = _extract_class_fields(obj) + name = fields.get("name") or obj.__name__ + fields["name"] = name + # Resolve env-var references so values are evaluated at collect time + if hasattr(obj, "url") and obj.url is not None: + fields["url"] = obj.url + if hasattr(obj, "headers") and obj.headers: + fields["headers"] = obj.headers + # OAuth fields + if hasattr(obj, "auth_type"): + fields["auth_type"] = obj.auth_type + if hasattr(obj, "oauth_client_id") and obj.oauth_client_id is not None: + fields["oauth_client_id"] = obj.oauth_client_id + if hasattr(obj, "oauth_scopes") and obj.oauth_scopes is not None: + fields["oauth_scopes"] = obj.oauth_scopes + definitions["mcp_servers"][name] = {"fields": fields, "meta": meta} + except ModuleNotFoundError as exc: + if not _ignore_missing_module(exc, f"{app_name}.mcp_servers"): + _raise_import_error("MCP servers module", f"{app_name}.mcp_servers", exc) + + # MCP tools (global) + try: + mcp_tool_module = _import_module(f"{app_name}.mcp_tools", project_path) + mcp_tool_prefix = f"{mcp_tool_module.__name__}." 
+ for _, obj in inspect.getmembers(mcp_tool_module, inspect.isclass): + if ( + issubclass(obj, BaseMCPTool) + and obj is not BaseMCPTool + and ( + obj.__module__ == mcp_tool_module.__name__ + or obj.__module__.startswith(mcp_tool_prefix) + ) + ): + fields, meta = _extract_class_fields(obj) + name = fields.get("name") or obj.__name__ + fields["name"] = name + # Serialize server reference as the server class name + server_cls = getattr(obj, "server", None) + if server_cls is not None and isinstance(server_cls, type): + fields["server"] = getattr(server_cls, "name", None) or server_cls.__name__ + definitions["mcp_tools"][name] = {"fields": fields, "meta": meta} + except ModuleNotFoundError as exc: + if not _ignore_missing_module(exc, f"{app_name}.mcp_tools"): + _raise_import_error("MCP tools module", f"{app_name}.mcp_tools", exc) + # Per-agent packages (agents//agent.py) for sub in sorted(app_path.iterdir()): if not sub.is_dir(): @@ -483,6 +547,8 @@ def collect_classes(project_path: Path, app_name: str = "agents") -> dict[str, d "agents": {}, "tools": {}, "retrieval_tools": {}, + "mcp_servers": {}, + "mcp_tools": {}, } # Tools @@ -521,6 +587,44 @@ def collect_classes(project_path: Path, app_name: str = "agents") -> dict[str, d if not _ignore_missing_module(exc, f"{app_name}.searches"): _raise_import_error("retrieval tools module", f"{app_name}.searches", exc) + # MCP servers + try: + mcp_server_module = _import_module(f"{app_name}.mcp_servers", project_path) + mcp_srv_prefix = f"{mcp_server_module.__name__}." 
+ for _, obj in inspect.getmembers(mcp_server_module, inspect.isclass): + if ( + issubclass(obj, BaseMCPServer) + and obj is not BaseMCPServer + and ( + obj.__module__ == mcp_server_module.__name__ + or obj.__module__.startswith(mcp_srv_prefix) + ) + ): + name = getattr(obj, "name", None) or obj.__name__ + classes["mcp_servers"][name] = obj + except ModuleNotFoundError as exc: + if not _ignore_missing_module(exc, f"{app_name}.mcp_servers"): + _raise_import_error("MCP servers module", f"{app_name}.mcp_servers", exc) + + # MCP tools + try: + mcp_tool_module = _import_module(f"{app_name}.mcp_tools", project_path) + mcp_tool_prefix = f"{mcp_tool_module.__name__}." + for _, obj in inspect.getmembers(mcp_tool_module, inspect.isclass): + if ( + issubclass(obj, BaseMCPTool) + and obj is not BaseMCPTool + and ( + obj.__module__ == mcp_tool_module.__name__ + or obj.__module__.startswith(mcp_tool_prefix) + ) + ): + name = getattr(obj, "name", None) or obj.__name__ + classes["mcp_tools"][name] = obj + except ModuleNotFoundError as exc: + if not _ignore_missing_module(exc, f"{app_name}.mcp_tools"): + _raise_import_error("MCP tools module", f"{app_name}.mcp_tools", exc) + # Agents per folder for sub in sorted(app_path.iterdir()): if not sub.is_dir(): diff --git a/cogsol/core/management.py b/cogsol/core/management.py index 22d35da..91a5c0b 100644 --- a/cogsol/core/management.py +++ b/cogsol/core/management.py @@ -21,6 +21,7 @@ def _command_registry() -> dict[str, str]: "importagent": "cogsol.management.commands.importagent", "makemigrations": "cogsol.management.commands.makemigrations", "migrate": "cogsol.management.commands.migrate", + "addmcptools": "cogsol.management.commands.addmcptools", "chat": "cogsol.management.commands.chat", } diff --git a/cogsol/core/mcp.py b/cogsol/core/mcp.py new file mode 100644 index 0000000..c70f1e4 --- /dev/null +++ b/cogsol/core/mcp.py @@ -0,0 +1,220 @@ +"""Lightweight MCP (Model Context Protocol) client using only stdlib. 
+ +Used by the ``addmcptools`` CLI command to discover tools exposed by an +MCP server. No third-party dependencies are required. +""" + +from __future__ import annotations + +import json +import re +import uuid +from typing import Any, cast +from urllib import error, request + + +class MCPClientError(RuntimeError): + """Raised when an MCP protocol or connection error occurs.""" + + +class MCPClient: + """Simplified MCP client that speaks JSON-RPC 2.0 over HTTP(S). + + Parameters + ---------- + server_url: + The base URL of the MCP server (e.g. ``https://mcp.example.com/sse``). + headers: + Optional extra headers to send with every request (e.g. API keys). + auth_type: + Authentication type declared on the server (``"none"``, ``"headers"`` + or ``"oauth2"``). Used only to produce a helpful diagnostic when the + server responds with 401 during tool discovery. + """ + + def __init__( + self, + server_url: str, + headers: dict[str, str] | None = None, + auth_type: str = "headers", + ): + self.server_url = server_url.rstrip("/") + self.extra_headers: dict[str, str] = dict(headers or {}) + self.auth_type = auth_type + self.session_id: str | None = None + self.tools: list[dict[str, Any]] = [] + self.initialized = False + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def initialize(self) -> bool: + """Perform the MCP ``initialize`` + ``tools/list`` handshake. + + Returns ``True`` on success, ``False`` on failure. 
+ """ + try: + self._make_request( + "initialize", + { + "protocolVersion": "2025-03-26", + "capabilities": {"tools": {}}, + "clientInfo": { + "name": "cognitive-mcp-client", + "version": "1.0.0", + }, + }, + ) + + result = self._make_request("tools/list") + self.tools = [ + { + "name": t["name"], + "description": t.get("description", ""), + "input_schema": t.get("inputSchema", {}), + } + for t in result.get("tools", []) + ] + self.initialized = True + return True + except MCPClientError as exc: + msg = str(exc) + if "HTTP 401" in msg and self.auth_type == "oauth2": + print( + "[MCPClient] Received 401 — this OAuth 2.1 server requires user " + "authorization.\n" + " Tool discovery without auth failed, but you can still create the " + "server definition.\n" + " Complete the OAuth authorization flow from the CogSol portal after " + "running `migrate`." + ) + else: + print(f"[MCPClient] Failed to initialize: {exc}") + return False + except Exception as exc: + print(f"[MCPClient] Failed to initialize: {exc}") + return False + + def list_tools(self) -> list[dict[str, Any]]: + """Return tools discovered during ``initialize``.""" + return list(self.tools) + + def disconnect(self) -> None: + """Send a best-effort DELETE to close the session.""" + if not self.session_id: + return + try: + headers = {"Mcp-Session-Id": self.session_id} + headers.update(self.extra_headers) + req = request.Request(self.server_url, headers=headers, method="DELETE") + with request.urlopen(req, timeout=5): + pass + except Exception: + pass + + # ------------------------------------------------------------------ + # Internals + # ------------------------------------------------------------------ + + @staticmethod + def _summarize_http_error(detail: str) -> str: + """Return a concise HTTP error detail suitable for terminal output.""" + clean = (detail or "").strip() + if not clean: + return "" + + # Cloudflare / proxy pages often return full HTML documents. 
+ if "<html" in clean.lower() or "<!doctype" in clean.lower(): + title_match = re.search(r"<title[^>]*>(.*?)</title>", clean, re.IGNORECASE | re.DOTALL) + if title_match: + title = re.sub(r"\s+", " ", title_match.group(1)).strip() + return f"HTML error page returned ({title})" + h1_match = re.search(r"<h1[^>]*>(.*?)</h1>", clean, re.IGNORECASE | re.DOTALL) + if h1_match: + heading = re.sub(r"\s+", " ", h1_match.group(1)).strip() + return f"HTML error page returned ({heading})" + return "HTML error page returned by remote server" + + single_line = re.sub(r"\s+", " ", clean) + if len(single_line) > 240: + return f"{single_line[:237]}..." + return single_line + + def _make_request(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + request_id = str(uuid.uuid4()) + payload: dict[str, Any] = { + "jsonrpc": "2.0", + "id": request_id, + "method": method, + } + if params: + payload["params"] = params + + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if self.session_id: + headers["Mcp-Session-Id"] = self.session_id + headers.update(self.extra_headers) + + body = json.dumps(payload).encode("utf-8") + req = request.Request(self.server_url, data=body, headers=headers, method="POST") + + try: + with request.urlopen(req, timeout=30) as resp: + # Capture session id on initialize + if method == "initialize": + sid = resp.headers.get("Mcp-Session-Id") + if sid: + self.session_id = sid + + content_type = (resp.headers.get("Content-Type") or "").lower() + raw = resp.read().decode("utf-8") + + if "text/event-stream" in content_type: + return self._parse_sse(raw) + + # Default: treat as JSON + result_obj = json.loads(raw) + if not isinstance(result_obj, dict): + raise MCPClientError("Invalid JSON-RPC response payload") + result = cast(dict[str, Any], result_obj) + if "error" in result: + raise MCPClientError(f"MCP Error: {result['error']}") + result_payload = result.get("result", {}) + if isinstance(result_payload, dict): + return cast(dict[str, Any], result_payload) + raise MCPClientError("Invalid JSON-RPC result 
payload") + + except error.HTTPError as exc: + detail = exc.read().decode("utf-8", errors="ignore") + short_detail = self._summarize_http_error(detail) + if short_detail: + raise MCPClientError(f"HTTP {exc.code} {exc.reason}: {short_detail}") from exc + raise MCPClientError(f"HTTP {exc.code} {exc.reason}") from exc + except error.URLError as exc: + raise MCPClientError(f"Connection error: {exc.reason}") from exc + + @staticmethod + def _parse_sse(text: str) -> dict[str, Any]: + """Extract the first JSON-RPC result from an SSE stream.""" + for line in text.split("\n"): + line = line.strip() + if line.startswith("data: "): + try: + data_obj = json.loads(line[6:]) + if not isinstance(data_obj, dict): + continue + data = cast(dict[str, Any], data_obj) + if "result" in data: + result_payload = data["result"] + if isinstance(result_payload, dict): + return cast(dict[str, Any], result_payload) + raise MCPClientError("Invalid JSON-RPC result payload") + if "error" in data: + raise MCPClientError(f"MCP Error: {data['error']}") + except json.JSONDecodeError: + continue + raise MCPClientError("No valid JSON-RPC response found in SSE stream") diff --git a/cogsol/core/migrations.py b/cogsol/core/migrations.py index 75afa3a..14402b7 100644 --- a/cogsol/core/migrations.py +++ b/cogsol/core/migrations.py @@ -18,6 +18,8 @@ def empty_state() -> dict[str, dict[str, dict[str, Any]]]: "lessons": {}, "faqs": {}, "fixed_responses": {}, + "mcp_servers": {}, + "mcp_tools": {}, } @@ -77,6 +79,8 @@ def _diff_bucket(entity: str, prev_defs: dict[str, Any], current_defs: dict[str, "lessons": ops.CreateLesson, "faqs": ops.CreateFAQ, "fixed_responses": ops.CreateFixedResponse, + "mcp_servers": ops.CreateMCPServer, + "mcp_tools": ops.CreateMCPTool, }[entity] for name, definition in current_defs.items(): @@ -166,7 +170,16 @@ def diff_states( ) else: # Cognitive API entities (agents) - for entity in ["retrieval_tools", "tools", "agents", "lessons", "faqs", "fixed_responses"]: + for entity in [ + 
"mcp_servers", + "mcp_tools", + "retrieval_tools", + "tools", + "agents", + "lessons", + "faqs", + "fixed_responses", + ]: operations.extend( _diff_bucket( entity, diff --git a/cogsol/db/migrations.py b/cogsol/db/migrations.py index 75d0e50..e2f987a 100644 --- a/cogsol/db/migrations.py +++ b/cogsol/db/migrations.py @@ -129,6 +129,25 @@ def __init__(self, name: str, fields: dict[str, Any]) -> None: super().__init__(name=name, fields=fields, entity="retrievals") +# ============================================================================= +# MCP entities (agents/ folder) +# ============================================================================= + + +class CreateMCPServer(CreateDefinition): + """Creates an MCP Server definition.""" + + def __init__(self, name: str, fields: dict[str, Any]) -> None: + super().__init__(name=name, fields=fields, entity="mcp_servers") + + +class CreateMCPTool(CreateDefinition): + """Creates an MCP Tool definition linked to an MCP Server.""" + + def __init__(self, name: str, fields: dict[str, Any]) -> None: + super().__init__(name=name, fields=fields, entity="mcp_tools") + + @dataclass class AlterField: model_name: str diff --git a/cogsol/management/commands/addmcptools.py b/cogsol/management/commands/addmcptools.py new file mode 100644 index 0000000..77c1d99 --- /dev/null +++ b/cogsol/management/commands/addmcptools.py @@ -0,0 +1,812 @@ +"""Interactive command for adding MCP server + tool definitions. + +Mirrors the frontend flow: + Step 1 – prompt for server details (name, description, URL, auth type) + • auth_type="none" – no credentials needed + • auth_type="headers" – static headers (e.g. 
API keys) + • auth_type="oauth2" – OAuth 2.1 / PKCE flow + Step 2 – connect to the MCP server, list tools, let user select + (OAuth servers are contacted without auth; full authorization + is completed from the CogSol portal after ``migrate``) + Step 3 – generate ``agents/mcp_servers.py`` and ``agents/mcp_tools.py`` + and update ``.env`` with any sensitive values. + +OAuth 2.1 note +-------------- +``client_id`` and ``client_secret`` are **both optional** — the cognitive +backend supports Dynamic Client Registration (RFC 7591) and will obtain them +automatically if omitted. + +The ``client_secret`` is NEVER written to ``.env`` or to source files. It is +sent write-only to the CogSol API, which stores it in Azure Key Vault. +""" + +from __future__ import annotations + +import os +import re +import time +import webbrowser +from getpass import getpass +from pathlib import Path +from typing import Any + +from cogsol.core.api import CogSolAPIError, CogSolClient +from cogsol.core.env import load_dotenv +from cogsol.core.mcp import MCPClient +from cogsol.management.base import BaseCommand + +# Header keys offered by the frontend's McpToolGeneral component. 
+HEADER_KEYS = [ + "Authorization", + "x-api-key", + "Content-Type", + "Accept", + "User-Agent", + "x-custom-header", + "x-auth-token", + "x-client-id", +] + +AUTH_TYPES = ["none", "headers", "oauth2"] +OAUTH_POLL_SECONDS_DEFAULT = 300 +OAUTH_POLL_INTERVAL_SECONDS = 2 + + +def _ask(prompt: str, default: str = "") -> str: + """Simple input wrapper with an optional default shown in brackets.""" + suffix = f" [{default}]" if default else "" + value = input(f"{prompt}{suffix}: ").strip() + return value or default + + +def _ask_secret(prompt: str) -> str: + """Ask for a secret value using masked terminal input when possible.""" + try: + value = getpass(f"{prompt} (leave blank to skip): ").strip() + except Exception: + value = input(f"{prompt} (leave blank to skip): ").strip() + return value + + +def _ask_yes_no(prompt: str, default: bool = True) -> bool: + hint = "Y/n" if default else "y/N" + answer = input(f"{prompt} [{hint}]: ").strip().lower() + if not answer: + return default + return answer.startswith("y") + + +def _to_env_key(name: str) -> str: + """Derive an env-var name from a server/header name.""" + key = re.sub(r"[^a-zA-Z0-9]+", "_", name).strip("_").upper() + return f"MCP_{key}" + + +def _py_str(value: str) -> str: + """Return a safe Python string literal for generated source.""" + return repr(value) + + +def _oauth_config(client_id: str, scopes: str) -> dict[str, Any]: + """Build oauth_config payload for the API, omitting empty values.""" + cfg: dict[str, Any] = {} + if client_id: + cfg["client_id"] = client_id + if scopes: + cfg["scopes"] = scopes + return cfg + + +class Command(BaseCommand): + help = "Interactively configure an MCP server and select tools." 
+ + def add_arguments(self, parser): + parser.add_argument( + "--app", + default="agents", + help="App folder (default: agents).", + ) + parser.add_argument( + "--oauth-timeout", + default=OAUTH_POLL_SECONDS_DEFAULT, + type=int, + help="Seconds to wait for OAuth completion when browser flow is triggered.", + ) + + def _extract_results(self, payload: Any) -> list[dict[str, Any]]: + if isinstance(payload, list): + return [item for item in payload if isinstance(item, dict)] + if isinstance(payload, dict): + if isinstance(payload.get("results"), list): + return [item for item in payload["results"] if isinstance(item, dict)] + if isinstance(payload.get("tools"), list): + return [item for item in payload["tools"] if isinstance(item, dict)] + return [] + + def _norm(self, value: Any) -> str: + text = str(value or "") + return re.sub(r"\s+", " ", text).strip().casefold() + + def _find_remote_server( + self, + *, + client: CogSolClient, + server_name: str, + server_url: str, + ) -> dict[str, Any] | None: + servers = self._extract_results(client.list_mcp_servers()) + server_name_n = self._norm(server_name) + server_url_n = self._norm(str(server_url).rstrip("/")) + + exact = [ + s + for s in servers + if self._norm(s.get("name")) == server_name_n + and self._norm(str(s.get("url", "")).rstrip("/")) == server_url_n + ] + if exact: + return exact[0] + + name_match = [s for s in servers if self._norm(s.get("name")) == server_name_n] + if len(name_match) == 1: + return name_match[0] + + url_match = [ + s for s in servers if self._norm(str(s.get("url", "")).rstrip("/")) == server_url_n + ] + if len(url_match) == 1: + return url_match[0] + + # If duplicates exist for the same URL, prefer oauth2 and most recently updated. 
+ if len(url_match) > 1: + oauth_candidates = [s for s in url_match if self._norm(s.get("auth_type")) == "oauth2"] + candidates = oauth_candidates or url_match + candidates.sort( + key=lambda s: str(s.get("updated_at") or s.get("created_at") or ""), + reverse=True, + ) + return candidates[0] + + if len(name_match) > 1: + name_match.sort( + key=lambda s: str(s.get("updated_at") or s.get("created_at") or ""), + reverse=True, + ) + return name_match[0] + + return None + + def _wait_for_oauth_connected( + self, + *, + client: CogSolClient, + server_id: int, + timeout_seconds: int, + ) -> bool: + start = time.time() + while time.time() - start < timeout_seconds: + try: + data = client.get_mcp_server(server_id) or {} + status = str(data.get("oauth_status", "")).lower() + if status == "connected": + return True + except CogSolAPIError: + # Keep polling; callback might still be processing. + pass + time.sleep(OAUTH_POLL_INTERVAL_SECONDS) + return False + + def _is_oauth_reauthorization_error(self, exc: CogSolAPIError) -> bool: + return "oauth re-authorization required" in str(exc).lower() + + def _start_oauth_authorization( + self, + *, + client: CogSolClient, + server_id: int, + server_name: str, + oauth_timeout: int, + ) -> None: + auth_payload = client.get_mcp_oauth_authorization_url(server_id) or {} + authorization_url = auth_payload.get("authorization_url") + if not authorization_url: + raise CogSolAPIError( + "OAuth authorization URL could not be generated for " f"MCP server '{server_name}'." + ) + + opened = webbrowser.open(str(authorization_url), new=1, autoraise=True) + if not opened: + print(" Could not auto-open browser. Open this URL manually:") + print(f" {authorization_url}") + + connected = self._wait_for_oauth_connected( + client=client, + server_id=server_id, + timeout_seconds=max(5, int(oauth_timeout)), + ) + if not connected: + raise CogSolAPIError( + "OAuth authorization did not complete within timeout. 
" + "Please finish OAuth in browser and retry addmcptools." + ) + + def _oauth_assisted_discovery( + self, + *, + server_name: str, + server_description: str, + server_url: str, + oauth_client_id: str, + oauth_client_secret: str, + oauth_scopes: str, + oauth_timeout: int, + ) -> list[dict[str, Any]]: + api_base = os.environ.get("COGSOL_API_BASE") + api_key = os.environ.get("COGSOL_API_KEY") + if not api_base: + print("COGSOL_API_BASE is required to run assisted OAuth tool discovery.") + return [] + + api_client = CogSolClient(base_url=api_base, api_key=api_key) + remote = self._find_remote_server( + client=api_client, + server_name=server_name, + server_url=server_url, + ) + if not remote or "id" not in remote: + payload: dict[str, Any] = { + "name": server_name, + "description": server_description, + "url": server_url, + "headers": {}, + "protocol_version": "2025-03-26", + "client_name": "cognitive-mcp-client", + "client_version": "1.0.0", + "active": True, + "auth_type": "oauth2", + "oauth_config": _oauth_config(oauth_client_id, oauth_scopes), + } + if oauth_client_secret: + payload["oauth_client_secret"] = oauth_client_secret + try: + print("MCP server not found in API. 
Creating it now for OAuth discovery...") + server_id = int(api_client.upsert_mcp_server(remote_id=None, payload=payload)) + except CogSolAPIError as exc: + print( + "Could not create MCP server in CogSol API during OAuth onboarding.\n" + f"Details: {exc}" + ) + return [] + else: + server_id = int(remote["id"]) + + try: + print("Discovering OAuth metadata from CogSol API...") + api_client.discover_mcp_oauth(server_id) + auth_payload = api_client.get_mcp_oauth_authorization_url(server_id) or {} + authorization_url = auth_payload.get("authorization_url") + if not authorization_url: + print("Could not obtain OAuth authorization URL from the API.") + return [] + + print("Opening browser for OAuth authorization...") + opened = webbrowser.open(str(authorization_url), new=1, autoraise=True) + if not opened: + print("Could not auto-open browser. Open this URL manually:") + print(str(authorization_url)) + + print(f"Waiting for OAuth completion (timeout: {oauth_timeout}s)...") + connected = self._wait_for_oauth_connected( + client=api_client, + server_id=server_id, + timeout_seconds=max(5, int(oauth_timeout)), + ) + if not connected: + print("OAuth authorization timeout. 
You can retry addmcptools after authorizing.") + return [] + + tools_payload = api_client.list_mcp_server_tools(server_id) + tools = self._extract_results(tools_payload) + if not tools: + print("OAuth connected, but no tools were returned by the MCP server.") + return [] + + normalized: list[dict[str, Any]] = [] + for tool in tools: + if not tool.get("name"): + continue + normalized.append( + { + "name": str(tool.get("name")), + "description": str(tool.get("description") or ""), + "input_schema": tool.get("input_schema") or tool.get("inputSchema") or {}, + } + ) + return normalized + except CogSolAPIError as exc: + print(f"OAuth discovery via CogSol API failed: {exc}") + return [] + + def _publish_to_cognitive( + self, + *, + server_name: str, + server_description: str, + server_url: str, + auth_type: str, + headers: dict[str, str], + oauth_client_id: str, + oauth_client_secret: str, + oauth_scopes: str, + selected_tools: list[dict[str, Any]], + oauth_timeout: int, + ) -> None: + api_base = os.environ.get("COGSOL_API_BASE") + api_key = os.environ.get("COGSOL_API_KEY") + if not api_base: + raise CogSolAPIError( + "COGSOL_API_BASE is required. addmcptools now publishes MCP servers/tools " + "directly to Cognitive." 
+ ) + + client = CogSolClient(base_url=api_base, api_key=api_key) + existing = self._find_remote_server( + client=client, + server_name=server_name, + server_url=server_url, + ) + remote_id = int(existing["id"]) if existing and existing.get("id") else None + + payload: dict[str, Any] = { + "name": server_name, + "description": server_description, + "url": server_url, + "headers": headers if auth_type == "headers" else {}, + "protocol_version": "2025-03-26", + "client_name": "cognitive-mcp-client", + "client_version": "1.0.0", + "active": True, + "auth_type": auth_type, + } + if auth_type == "oauth2": + payload["oauth_config"] = _oauth_config(oauth_client_id, oauth_scopes) + if oauth_client_secret: + payload["oauth_client_secret"] = oauth_client_secret + + server_id = int(client.upsert_mcp_server(remote_id=remote_id, payload=payload)) + action = "Updated" if remote_id else "Created" + print(f" {action} MCP server in Cognitive (id={server_id}).") + + if auth_type == "oauth2": + print(f" Refreshing OAuth metadata for server id={server_id}...") + force_authorization = False + try: + client.discover_mcp_oauth(server_id) + except CogSolAPIError as exc: + if self._is_oauth_reauthorization_error(exc): + print( + " OAuth re-authorization required while refreshing metadata; " + "continuing with authorization flow..." + ) + force_authorization = True + else: + raise + + server_data = client.get_mcp_server(server_id) or {} + status = str(server_data.get("oauth_status", "")).lower() + if force_authorization or status != "connected": + print(" OAuth server is not connected yet. 
Starting authorization flow...") + self._start_oauth_authorization( + client=client, + server_id=server_id, + server_name=server_name, + oauth_timeout=oauth_timeout, + ) + + tool_names = [str(t.get("name")) for t in selected_tools if t.get("name")] + if not tool_names: + print(" No MCP tools selected to sync in Cognitive.") + return + + try: + client.sync_mcp_server_tools(server_id, tool_names) + except CogSolAPIError as exc: + if auth_type != "oauth2" or not self._is_oauth_reauthorization_error(exc): + raise + + print( + " OAuth re-authorization required during tools sync. " + "Starting recovery flow and retrying once..." + ) + self._start_oauth_authorization( + client=client, + server_id=server_id, + server_name=server_name, + oauth_timeout=oauth_timeout, + ) + client.sync_mcp_server_tools(server_id, tool_names) + print(f" Synced {len(tool_names)} MCP tool(s) in Cognitive.") + + def handle(self, project_path: Path | None, **options: Any) -> int: # noqa: C901 + assert project_path is not None, "project_path is required" + app = str(options.get("app") or "agents") + oauth_timeout = int(options.get("oauth_timeout") or OAUTH_POLL_SECONDS_DEFAULT) + + load_dotenv(project_path / ".env") + + # ── Step 1: Server details ─────────────────────────────────── + print("\n=== Step 1: MCP Server Configuration ===\n") + server_name = _ask("Server name") + if not server_name: + print("A server name is required.") + return 1 + server_description = _ask("Description", "") + server_url = _ask("Server URL (e.g. https://mcp.example.com/mcp)") + if not server_url: + print("A server URL is required.") + return 1 + + # Auth type + print("\nAuthentication type:") + for i, auth_option in enumerate(AUTH_TYPES, 1): + suffix = " ← default" if auth_option == "headers" else "" + print(f" {i}. 
{auth_option}{suffix}") + auth_choice = _ask("Select auth type", "2") + try: + auth_idx = int(auth_choice) - 1 + if auth_idx < 0 or auth_idx >= len(AUTH_TYPES): + raise ValueError + except ValueError: + print("Invalid selection, defaulting to 'headers'.") + auth_idx = 1 + auth_type = AUTH_TYPES[auth_idx] + print(f" → auth_type: {auth_type}\n") + + # Credentials depending on auth type + headers: dict[str, str] = {} + oauth_client_id = "" + oauth_client_secret = "" + oauth_scopes = "" + + if auth_type == "headers": + print("Available header keys:") + for i, key in enumerate(HEADER_KEYS, 1): + print(f" {i}. {key}") + print(" 0. Skip / done adding headers") + print() + + while True: + choice = _ask("Select header key number (0 to finish)", "0") + if choice == "0": + break + try: + idx = int(choice) - 1 + if idx < 0 or idx >= len(HEADER_KEYS): + raise ValueError + except ValueError: + print("Invalid selection, try again.") + continue + hdr_key = HEADER_KEYS[idx] + hdr_value = _ask(f" Value for '{hdr_key}'") + if hdr_value: + headers[hdr_key] = hdr_value + + elif auth_type == "oauth2": + print("OAuth 2.1 Configuration") + print("(All fields are optional — the server supports Dynamic Client Registration)\n") + oauth_client_id = _ask("Client ID (leave blank for auto-registration)", "") + oauth_client_secret = _ask_secret( + "Client Secret (leave blank for auto-registration)" + "\n ⚠ This value will NOT be saved to .env or source files." + "\n It is sent securely to the CogSol API (Azure Key Vault)." + ) + oauth_scopes = _ask( + "Scopes (space-separated, e.g. 
'read:jira write:confluence')", "" + ) + + # ── Step 2: Connect & list tools ───────────────────────────── + print("\n=== Step 2: Discovering Tools ===\n") + print(f"Connecting to {server_url} ...") + + # For OAuth servers, attempt discovery without auth (many servers allow + # tools/list unauthenticated; actual execution requires the portal auth flow) + discovery_headers = headers if auth_type == "headers" else {} + client = MCPClient(server_url, headers=discovery_headers, auth_type=auth_type) + connected = client.initialize() + + tools = client.list_tools() if connected else [] + client.disconnect() + + if auth_type == "oauth2" and not tools: + tools = self._oauth_assisted_discovery( + server_name=server_name, + server_description=server_description, + server_url=server_url, + oauth_client_id=oauth_client_id, + oauth_client_secret=oauth_client_secret, + oauth_scopes=oauth_scopes, + oauth_timeout=oauth_timeout, + ) + + if not connected and auth_type != "oauth2": + print("Failed to connect to the MCP server. Check URL and headers.") + return 1 + + selected_tools: list[dict[str, Any]] = [] + + if tools: + print(f"\nFound {len(tools)} tool(s):\n") + for i, tool in enumerate(tools, 1): + desc = tool.get("description", "") + print(f" {i}. {tool['name']}") + if desc: + print(f" {desc[:100]}") + print() + + print("Enter tool numbers to add (e.g. 
'1,3,5'), 'all' for all, or '0' to cancel:") + selection = _ask("Selection", "all") + if selection == "0": + print("Cancelled.") + return 0 + + if selection.lower() == "all": + selected_tools = tools + else: + selected_indices: list[int] = [] + for part in selection.split(","): + part = part.strip() + if part.isdigit(): + idx = int(part) - 1 + if 0 <= idx < len(tools): + selected_indices.append(idx) + selected_tools = [tools[i] for i in selected_indices] + + if not selected_tools: + print("No tools selected.") + return 1 + + print(f"\nSelected {len(selected_tools)} tool(s).\n") + else: + if auth_type == "oauth2": + print( + "Could not list tools yet (OAuth still required or unavailable).\n" + "The server definition will be created without tool entries.\n" + "Complete OAuth authorization and re-run `addmcptools`, or add tools manually.\n" + ) + else: + print("The server reported no tools.") + + # ── Step 3: Generate files ─────────────────────────────────── + print("=== Step 3: Generating Files ===\n") + + # Python-safe class name from the server name + cls_base = re.sub(r"[^a-zA-Z0-9]+", " ", server_name).title().replace(" ", "") + server_cls_name = f"{cls_base}MCPServer" + + # ── Build mcp_servers.py snippet ───────────────────────────── + # URL is hardcoded in the class — it lives in the CogSol API after migrate. 
+ env_new_vars: dict[str, str] = {} + + if auth_type == "none": + server_body_lines = [ + f" name = {_py_str(server_name)}", + f" description = {_py_str(server_description)}", + ' auth_type = "none"', + f" url = {_py_str(server_url)}", + ] + + elif auth_type == "headers": + header_env_entries: dict[str, str] = {} + header_attr_lines: list[str] = [] + for hk, hv in headers.items(): + env_key = _to_env_key(f"{server_name}_{hk}") + header_env_entries[env_key] = hv + env_new_vars[env_key] = hv + header_attr_lines.append( + f" {_py_str(hk)}: os.environ.get({_py_str(env_key)}, '')," + ) + headers_block = ( + "{\n" + "\n".join(header_attr_lines) + "\n }" if header_attr_lines else "{}" + ) + server_body_lines = [ + f" name = {_py_str(server_name)}", + f" description = {_py_str(server_description)}", + f" url = {_py_str(server_url)}", + f" headers = {headers_block}", + ] + + else: # oauth2 + oauth_cid_env = _to_env_key(f"{server_name}_OAUTH_CLIENT_ID") + oauth_scopes_env = _to_env_key(f"{server_name}_OAUTH_SCOPES") + if oauth_client_id: + env_new_vars[oauth_cid_env] = oauth_client_id + if oauth_scopes: + env_new_vars[oauth_scopes_env] = oauth_scopes + server_body_lines = [ + f" name = {_py_str(server_name)}", + f" description = {_py_str(server_description)}", + ' auth_type = "oauth2"', + f" url = {_py_str(server_url)}", + ] + if oauth_client_id: + server_body_lines.append( + f" oauth_client_id = os.environ.get({_py_str(oauth_cid_env)}, '')" + ) + if oauth_scopes: + server_body_lines.append( + f" oauth_scopes = os.environ.get({_py_str(oauth_scopes_env)}, '')" + ) + + server_body = "\n".join(server_body_lines) + # Build server_code line-by-line to avoid f-string + textwrap.dedent + # indentation issues when server_body spans multiple lines. + # Only emit `import os` when auth type uses os.environ (headers/oauth2). 
+ header_lines = (["import os", ""] if auth_type != "none" else []) + [ + "from cogsol.tools import BaseMCPServer", + "", + "", + ] + server_code = "\n".join( + header_lines + + [ + f"class {server_cls_name}(BaseMCPServer):", + ' """MCP server definition."""', + "", + server_body, + "", + ] + ) + + # ── Build mcp_tools.py snippet ─────────────────────────────── + tool_classes: list[str] = [] + for tool in selected_tools: + t_name = tool["name"] + t_desc = tool.get("description", "") or "" + t_cls_base = re.sub(r"[^a-zA-Z0-9]+", " ", t_name).title().replace(" ", "") + t_cls_name = f"{t_cls_base}MCPTool" + tool_classes.append( + "\n".join( + [ + "", + "", + f"class {t_cls_name}(BaseMCPTool):", + ' """MCP tool definition."""', + "", + f" name = {_py_str(t_name)}", + f" description = {_py_str(t_desc)}", + f" server = {server_cls_name}", + "", + ] + ) + ) + + tools_code = "\n".join( + [ + "from cogsol.tools import BaseMCPTool", + "", + f"from {app}.mcp_servers import {server_cls_name}", + ] + ) + "".join(tool_classes) + + # ── Write files ─────────────────────────────────────────────── + app_path = project_path / app + app_path.mkdir(parents=True, exist_ok=True) + + servers_file = app_path / "mcp_servers.py" + tools_file = app_path / "mcp_tools.py" + + if servers_file.exists(): + existing = servers_file.read_text(encoding="utf-8") + if server_cls_name in existing: + print( + f" Server class '{server_cls_name}' already exists in " + f"{servers_file.name}; skipping." 
+ ) + else: + class_only = ( + "\n\n" + + "\n".join( + line + for line in server_code.splitlines() + if not line.startswith("import ") and not line.startswith("from ") + ).strip() + + "\n" + ) + servers_file.write_text(existing.rstrip() + class_only, encoding="utf-8") + print(f" Appended {server_cls_name} to {servers_file.name}") + else: + servers_file.write_text(server_code, encoding="utf-8") + print(f" Created {servers_file.name}") + + if tool_classes: + if tools_file.exists(): + existing = tools_file.read_text(encoding="utf-8") + import_line = f"from {app}.mcp_servers import {server_cls_name}" + if import_line not in existing: + existing = existing.rstrip() + f"\n{import_line}\n" + for cls_block in tool_classes: + cls_name_match = re.search(r"class (\w+)\(", cls_block) + if cls_name_match and cls_name_match.group(1) not in existing: + existing += cls_block + tools_file.write_text(existing, encoding="utf-8") + print(f" Updated {tools_file.name}") + else: + tools_file.write_text(tools_code, encoding="utf-8") + print(f" Created {tools_file.name}") + else: + if not tools_file.exists(): + # Create a placeholder so the module is importable + placeholder = "\n".join( + [ + "from cogsol.tools import BaseMCPTool", + "", + f"from {app}.mcp_servers import {server_cls_name}", + "# No tools selected yet.", + "# Re-run `python manage.py addmcptools` after completing OAuth", + "# authorization in the CogSol portal to select tools.", + "", + ] + ) + tools_file.write_text(placeholder, encoding="utf-8") + print(f" Created {tools_file.name} (placeholder — tools to be added later)") + + # ── Update .env ─────────────────────────────────────────────── + env_path = project_path / ".env" + env_lines: list[str] = [] + if env_path.exists(): + env_lines = env_path.read_text(encoding="utf-8").splitlines() + + existing_keys = { + line.split("=", 1)[0].strip() + for line in env_lines + if "=" in line and not line.strip().startswith("#") + } + additions: list[str] = [] + for k, v in 
env_new_vars.items(): + if k not in existing_keys: + additions.append(f"{k}={v}") + + if additions: + env_lines.append("") + env_lines.append(f"# MCP Server: {server_name}") + env_lines.extend(additions) + env_path.write_text("\n".join(env_lines) + "\n", encoding="utf-8") + print(f" Updated .env with {len(additions)} new variable(s).") + else: + print(" .env already up-to-date.") + + print( + "\nPublishing MCP server/tools to Cognitive now " + "(this updates what appears in the portal immediately)..." + ) + try: + self._publish_to_cognitive( + server_name=server_name, + server_description=server_description, + server_url=server_url, + auth_type=auth_type, + headers=headers, + oauth_client_id=oauth_client_id, + oauth_client_secret=oauth_client_secret, + oauth_scopes=oauth_scopes, + selected_tools=selected_tools, + oauth_timeout=oauth_timeout, + ) + except CogSolAPIError as exc: + print(f"Failed to publish MCP catalog to Cognitive: {exc}") + return 1 + + if auth_type == "oauth2" and oauth_client_secret: + print( + "\n ℹ OAuth client_secret was entered but NOT written to .env.\n" + " It was sent securely to the CogSol API and stored in Azure Key Vault." + ) + + print( + "\nDone! Run 'python manage.py makemigrations' followed by " + "'python manage.py migrate'." 
+ ) + if auth_type == "oauth2": + print("OAuth authorization was completed (or attempted) during addmcptools.") + return 0 diff --git a/cogsol/management/commands/migrate.py b/cogsol/management/commands/migrate.py index 41353a1..464b484 100644 --- a/cogsol/management/commands/migrate.py +++ b/cogsol/management/commands/migrate.py @@ -2,12 +2,13 @@ import ast import copy +import inspect import json import os import re import textwrap from pathlib import Path -from typing import Any, cast +from typing import Any, Optional, cast from cogsol.agents import genconfigs from cogsol.content import BaseRetrieval @@ -18,8 +19,17 @@ get_content_api_base_url, ) from cogsol.core.env import load_dotenv +from cogsol.core.loader import _extract_tool_params, collect_classes, collect_content_classes from cogsol.db import migrations from cogsol.management.base import BaseCommand +from cogsol.prompts import Prompt +from cogsol.tools import BaseMCPTool, BaseTool + + +def _tool_key(obj: Any) -> str: + cls = obj if isinstance(obj, type) else obj.__class__ + cname = cls.__name__ + return cname[:-4] if cname.endswith("Tool") else cname def _normalize_code(code: Any) -> str: @@ -29,13 +39,12 @@ def _normalize_code(code: Any) -> str: return textwrap.dedent(code).rstrip() -def _name_aliases(name: str) -> set[str]: - aliases = {name} - if name.endswith("Tool") and len(name) > 4: - aliases.add(name[:-4]) - elif not name.endswith("Tool"): - aliases.add(f"{name}Tool") - return aliases +def sub_slug(cls: type | None) -> str | None: + if cls and hasattr(cls, "__module__"): + parts = cls.__module__.split(".") + if len(parts) >= 2: + return parts[1] + return None class Command(BaseCommand): @@ -119,20 +128,26 @@ def handle(self, project_path: Path | None, **options: Any) -> int: try: touched = self._touched_entities(pending_ops) if app_name == "data": + class_map = collect_content_classes(project_path, app_name) remote_ids, created = self._sync_content_with_api( api_base=content_base or api_base, 
api_key=api_key, state=temp_state, remote_ids=remote_ids, + class_map=class_map, + project_path=project_path, touched=touched, ) else: + class_map = collect_classes(project_path, app_name) remote_ids, created = self._sync_with_api( api_base=api_base, api_key=api_key, state=temp_state, remote_ids=remote_ids, + class_map=class_map, project_path=project_path, + app=app_name, touched=touched, ) @@ -192,6 +207,8 @@ def _empty_remote(self) -> dict[str, Any]: "agents": {}, "tools": {}, "retrieval_tools": {}, + "mcp_servers": {}, + "mcp_tools": {}, "lessons": {}, "faqs": {}, "fixed_responses": {}, @@ -206,6 +223,95 @@ def _empty_content_remote(self) -> dict[str, Any]: "metadata_configs": {}, } + def _has_state_entries(self, state: dict[str, Any]) -> bool: + for value in state.values(): + if isinstance(value, dict) and value: + return True + return False + + def _extract_results(self, payload: Any) -> list[dict[str, Any]]: + if isinstance(payload, list): + return [item for item in payload if isinstance(item, dict)] + if isinstance(payload, dict): + if isinstance(payload.get("results"), list): + return [item for item in payload["results"] if isinstance(item, dict)] + if isinstance(payload.get("tools"), list): + return [item for item in payload["tools"] if isinstance(item, dict)] + return [] + + def _norm_text(self, value: Any) -> str: + return re.sub(r"\s+", " ", str(value or "")).strip().casefold() + + def _mcp_server_hydration_targets( + self, + *, + state: dict[str, Any], + remote_ids: dict[str, Any], + ) -> tuple[set[int], set[str]]: + """Resolve server ids/names that are relevant for this project's MCP references.""" + refs: set[str] = set() + for server_key, definition in (state.get("mcp_servers", {}) or {}).items(): + refs.add(str(server_key)) + fields = definition.get("fields", {}) if definition else {} + if fields.get("name"): + refs.add(str(fields["name"])) + + for definition in (state.get("mcp_tools", {}) or {}).values(): + fields = definition.get("fields", {}) if 
definition else {} + server_ref = fields.get("server") + if server_ref: + refs.add(str(server_ref)) + + target_names = {self._norm_text(ref) for ref in refs if str(ref).strip()} + + target_ids: set[int] = set() + remote_servers = remote_ids.get("mcp_servers", {}) or {} + for ref in refs: + remote_id = remote_servers.get(ref) + if isinstance(remote_id, int): + target_ids.add(remote_id) + elif isinstance(remote_id, str) and remote_id.isdigit(): + target_ids.add(int(remote_id)) + + return target_ids, target_names + + def _hydrate_remote_mcp_tool_ids( + self, + *, + client: CogSolClient, + remote_ids: dict[str, Any], + target_server_ids: set[int] | None = None, + target_server_names: set[str] | None = None, + ) -> None: + """Populate ``remote_ids['mcp_tools']`` from current Cognitive MCP catalog.""" + ids_filter = set(target_server_ids or set()) + names_filter = {self._norm_text(name) for name in (target_server_names or set())} + if not ids_filter and not names_filter: + return + + servers = self._extract_results(client.list_mcp_servers()) + for server in servers: + server_id = server.get("id") + if server_id is None: + continue + try: + server_id_int = int(server_id) + except (TypeError, ValueError): + continue + + server_name_n = self._norm_text(server.get("name")) + if server_id_int not in ids_filter and server_name_n not in names_filter: + continue + try: + payload = client.list_mcp_server_tools(server_id_int) + except CogSolAPIError as exc: + print( + " Warning: could not list MCP tools for" + f" server '{server.get('name', server_id_int)}': {exc}" + ) + continue + self._update_mcp_tool_remote_ids(remote_ids, payload) + def _touched_entities(self, operations: list[Any]) -> dict[str, set[str]]: touched: dict[str, set[str]] = {} for op in operations: @@ -239,6 +345,8 @@ def _sync_content_with_api( api_key: str | None, state: dict[str, Any], remote_ids: dict[str, Any], + class_map: dict[str, dict[str, type]], + project_path: Path, touched: dict[str, set[str]] | None 
= None, ) -> tuple[dict[str, Any], list[tuple[str, int | None, int]]]: """Sync Content API entities (topics, formatters, retrievals) with the API.""" @@ -484,7 +592,9 @@ def _sync_with_api( api_key: str | None, state: dict[str, Any], remote_ids: dict[str, Any], + class_map: dict[str, dict[str, type]], project_path: Path, + app: str, touched: dict[str, set[str]] | None = None, ) -> tuple[dict[str, Any], list[tuple[str, int | None, int]]]: client = CogSolClient(api_base, api_key=api_key) @@ -496,55 +606,75 @@ def _sync_with_api( for tool_name, definition in state.get("tools", {}).items(): if touched is not None and tool_name not in touched.get("tools", set()): continue - payload = self._tool_payload(tool_name, definition) + cls = cast(Optional[type[BaseTool]], class_map.get("tools", {}).get(tool_name)) + payload = self._tool_payload(tool_name, definition, cls) remote_id = new_remote.get("tools", {}).get(tool_name) - if remote_id is None and payload.get("name"): - remote_id = new_remote.get("tools", {}).get(str(payload["name"])) new_id = client.upsert_script(remote_id=remote_id, payload=payload) if not remote_id: created.append(("tool", None, new_id)) + # store under multiple keys to ensure lookup (normalized, class name, explicit name) new_remote.setdefault("tools", {})[tool_name] = new_id - if payload.get("name"): - new_remote["tools"][str(payload["name"])] = new_id - for alias in _name_aliases(tool_name): - new_remote["tools"][alias] = new_id + if cls is not None: + norm = _tool_key(cls) + new_remote["tools"][norm] = new_id + new_remote["tools"][cls.__name__] = new_id + explicit_name = getattr(cls, "name", None) + if explicit_name: + new_remote["tools"][explicit_name] = new_id # Upsert retrieval tools. 
for tool_name, definition in state.get("retrieval_tools", {}).items(): if touched is not None and tool_name not in touched.get("retrieval_tools", set()): continue + cls = class_map.get("retrieval_tools", {}).get(tool_name) payload = self._retrieval_tool_payload( tool_name=tool_name, definition=definition, + cls=cls, project_path=project_path, ) remote_id = new_remote.get("retrieval_tools", {}).get(tool_name) - if remote_id is None and payload.get("name"): - remote_id = new_remote.get("retrieval_tools", {}).get(str(payload["name"])) new_id = client.upsert_retrieval_tool(remote_id=remote_id, payload=payload) if not remote_id: created.append(("retrieval_tool", None, new_id)) new_remote.setdefault("retrieval_tools", {})[tool_name] = new_id - if payload.get("name"): - new_remote["retrieval_tools"][str(payload["name"])] = new_id - for alias in _name_aliases(tool_name): - new_remote["retrieval_tools"][alias] = new_id + if cls is not None: + new_remote["retrieval_tools"][cls.__name__] = new_id + explicit_name = getattr(cls, "name", None) + if explicit_name: + new_remote["retrieval_tools"][explicit_name] = new_id + + if touched is None or touched.get("mcp_servers") or touched.get("mcp_tools"): + print( + " MCP server/tool migration operations are skipped in 'migrate'. " + "Use 'addmcptools' to publish MCP catalog to Cognitive." + ) - agents_with_faqs = self._agents_from_related_bucket(state.get("faqs", {})) - agents_with_fixed = self._agents_from_related_bucket(state.get("fixed_responses", {})) - agents_with_lessons = self._agents_from_related_bucket(state.get("lessons", {})) + # Migrate only needs MCP tool ids for assistant associations. + target_ids, target_names = self._mcp_server_hydration_targets( + state=state, + remote_ids=new_remote, + ) + self._hydrate_remote_mcp_tool_ids( + client=client, + remote_ids=new_remote, + target_server_ids=target_ids, + target_server_names=target_names, + ) # Upsert agents. 
for agent_name, definition in state.get("agents", {}).items(): if touched is not None and agent_name not in touched.get("agents", set()): continue + cls = class_map.get("agents", {}).get(agent_name) payload = self._assistant_payload( agent_name=agent_name, definition=definition, + cls=cls, remote_ids=new_remote, - faq_available=agent_name in agents_with_faqs, - fixed_available=agent_name in agents_with_fixed, - lessons_available=agent_name in agents_with_lessons, + project_path=project_path, + app=app, + slug=sub_slug(cls), ) remote_id = new_remote.get("agents", {}).get(agent_name) new_id = client.upsert_assistant(remote_id=remote_id, payload=payload) @@ -552,66 +682,95 @@ def _sync_with_api( created.append(("assistant", None, new_id)) new_remote.setdefault("agents", {})[agent_name] = new_id - # Upsert FAQs (common questions), fixed responses, lessons per agent from migration state. - for key, definition in state.get("faqs", {}).items(): - if touched is not None and key not in touched.get("faqs", set()): + # Upsert FAQs (common questions), fixed responses, lessons per agent. 
+ sync_related = True + sync_all_related = touched is None + faq_filter: dict[str, set[str]] = {} + fixed_filter: dict[str, set[str]] = {} + lesson_filter: dict[str, set[str]] = {} + if not sync_all_related: + touched_map = touched or {} + for key in touched_map.get("faqs", set()): + agent, _, name = key.partition("::") + if agent and name: + faq_filter.setdefault(agent, set()).add(name) + for key in touched_map.get("fixed_responses", set()): + agent, _, name = key.partition("::") + if agent and name: + fixed_filter.setdefault(agent, set()).add(name) + for key in touched_map.get("lessons", set()): + agent, _, name = key.partition("::") + if agent and name: + lesson_filter.setdefault(agent, set()).add(name) + sync_related = bool(faq_filter or fixed_filter or lesson_filter) + if not sync_related: + return new_remote, created + + agents_filter = set(faq_filter) | set(fixed_filter) | set(lesson_filter) + + for agent_name, agent_cls in class_map.get("agents", {}).items(): + if agents_filter and agent_name not in agents_filter: continue - fields = definition.get("fields", {}) if definition else {} - agent_name = str(fields.get("agent") or str(key).partition("::")[0]) assistant_id = new_remote.get("agents", {}).get(agent_name) if not assistant_id: continue - payload = self._faq_payload_from_fields(str(key), fields) new_remote.setdefault("faqs", {}).setdefault(agent_name, {}) - remote_id = new_remote["faqs"][agent_name].get(payload["name"]) - new_id = client.upsert_common_question( - assistant_id=assistant_id, - remote_id=remote_id, - payload=payload, - ) - if not remote_id: - created.append(("faq", assistant_id, new_id)) - new_remote["faqs"][agent_name][payload["name"]] = new_id - - for key, definition in state.get("fixed_responses", {}).items(): - if touched is not None and key not in touched.get("fixed_responses", set()): - continue - fields = definition.get("fields", {}) if definition else {} - agent_name = str(fields.get("agent") or str(key).partition("::")[0]) - 
assistant_id = new_remote.get("agents", {}).get(agent_name) - if not assistant_id: - continue - payload = self._fixed_payload_from_fields(str(key), fields) new_remote.setdefault("fixed_responses", {}).setdefault(agent_name, {}) - remote_id = new_remote["fixed_responses"][agent_name].get(payload["name"]) - new_id = client.upsert_fixed_response( - assistant_id=assistant_id, - remote_id=remote_id, - payload=payload, - ) - if not remote_id: - created.append(("fixed", assistant_id, new_id)) - new_remote["fixed_responses"][agent_name][payload["name"]] = new_id - - for key, definition in state.get("lessons", {}).items(): - if touched is not None and key not in touched.get("lessons", set()): - continue - fields = definition.get("fields", {}) if definition else {} - agent_name = str(fields.get("agent") or str(key).partition("::")[0]) - assistant_id = new_remote.get("agents", {}).get(agent_name) - if not assistant_id: - continue - payload = self._lesson_payload_from_fields(str(key), fields) new_remote.setdefault("lessons", {}).setdefault(agent_name, {}) - remote_id = new_remote["lessons"][agent_name].get(payload["name"]) - new_id = client.upsert_lesson( - assistant_id=assistant_id, - remote_id=remote_id, - payload=payload, - ) - if not remote_id: - created.append(("lesson", assistant_id, new_id)) - new_remote["lessons"][agent_name][payload["name"]] = new_id + + if sync_all_related or faq_filter: + for faq_obj in getattr(agent_cls, "faqs", []) or []: + payload = self._faq_payload(faq_obj) + if ( + faq_filter.get(agent_name) + and payload["name"] not in faq_filter[agent_name] + ): + continue + remote_id = new_remote["faqs"][agent_name].get(payload["name"]) + new_id = client.upsert_common_question( + assistant_id=assistant_id, + remote_id=remote_id, + payload=payload, + ) + if not remote_id: + created.append(("faq", assistant_id, new_id)) + new_remote["faqs"][agent_name][payload["name"]] = new_id + + if sync_all_related or fixed_filter: + for fx_obj in getattr(agent_cls, 
"fixed_responses", []) or []: + payload = self._fixed_payload(fx_obj) + if ( + fixed_filter.get(agent_name) + and payload["name"] not in fixed_filter[agent_name] + ): + continue + remote_id = new_remote["fixed_responses"][agent_name].get(payload["name"]) + new_id = client.upsert_fixed_response( + assistant_id=assistant_id, + remote_id=remote_id, + payload=payload, + ) + if not remote_id: + created.append(("fixed", assistant_id, new_id)) + new_remote["fixed_responses"][agent_name][payload["name"]] = new_id + + if sync_all_related or lesson_filter: + for lesson_obj in getattr(agent_cls, "lessons", []) or []: + payload = self._lesson_payload(lesson_obj) + if ( + lesson_filter.get(agent_name) + and payload["name"] not in lesson_filter[agent_name] + ): + continue + remote_id = new_remote["lessons"][agent_name].get(payload["name"]) + new_id = client.upsert_lesson( + assistant_id=assistant_id, + remote_id=remote_id, + payload=payload, + ) + if not remote_id: + created.append(("lesson", assistant_id, new_id)) + new_remote["lessons"][agent_name][payload["name"]] = new_id return new_remote, created except Exception: @@ -624,6 +783,34 @@ def _sync_with_api( continue raise + def _update_mcp_tool_remote_ids(self, remote_ids: dict[str, Any], payload: Any) -> bool: + """Extract MCP tool ids from API payloads and store them in ``remote_ids``. + + Accepts either a raw list of tool dicts or dict payloads containing a + ``results``/``configured_tools``/``tools`` list. 
+ """ + items: list[Any] = [] + if isinstance(payload, list): + items = payload + elif isinstance(payload, dict): + for key in ("results", "configured_tools", "tools"): + value = payload.get(key) + if isinstance(value, list): + items.extend(value) + + found = False + for item in items: + if not isinstance(item, dict): + continue + if "id" not in item: + continue + tname = item.get("name") + if not tname: + continue + remote_ids.setdefault("mcp_tools", {})[str(tname)] = int(item["id"]) + found = True + return found + def _delete_created_entry( self, client: CogSolClient, kind: str, parent_id: int | None, obj_id: int ) -> None: @@ -639,6 +826,10 @@ def _delete_created_entry( client.delete_script(obj_id) elif kind == "retrieval_tool": client.delete_retrieval_tool(obj_id) + elif kind == "mcp_server": + client.delete_mcp_server(obj_id) + elif kind == "mcp_tool": + client.delete_mcp_tool(obj_id) def _delete_content_created_entry( self, client: CogSolClient, kind: str, parent_id: int | None, obj_id: int @@ -684,51 +875,54 @@ def _tool_payload( self, tool_name: str, definition: dict[str, Any], + cls: type[BaseTool] | None = None, ) -> dict[str, Any]: fields = definition.get("fields", {}) if definition else {} - params: list[dict[str, Any]] = [] - raw_params = fields.get("parameters", {}) - if isinstance(raw_params, dict): - for name, meta in raw_params.items(): - meta = meta or {} - param_entry = { - "name": str(name), - "description": meta.get("description") or str(name), - "type": meta.get("type") or "string", - "required": bool(meta.get("required", True)), - } - if param_entry["type"] == "array" and "items" in meta: - param_entry["items"] = meta["items"] - params.append(param_entry) - elif isinstance(raw_params, list): - for item in raw_params: - if not isinstance(item, dict): - continue - name = str(item.get("name") or "") - if not name: - continue - param_entry = { - "name": name, - "description": item.get("description") or name, - "type": item.get("type") or 
"string", - "required": bool(item.get("required", True)), - } - if param_entry["type"] == "array" and "items" in item: - param_entry["items"] = item["items"] - params.append(param_entry) + def _get(attr: str, default=None): + if cls is not None and hasattr(cls, attr): + return getattr(cls, attr) + return fields.get(attr, default) - description = fields.get("description") or f"Tool {tool_name}" - code = self._tool_script_from_state(fields) + params = [] + if cls is not None: + param_def = _extract_tool_params(cls) + else: + param_def = definition.get("fields", {}).get("parameters", {}) if definition else {} + for name, meta in (param_def or {}).items(): + meta = meta or {} + param_entry = { + "name": name, + "description": meta.get("description") or name, + "type": meta.get("type") or "string", + "required": bool(meta.get("required", True)), + } + # Include 'items' for array types if specified + if param_entry["type"] == "array" and "items" in meta: + param_entry["items"] = meta["items"] + params.append(param_entry) + + description = ( + (definition.get("fields", {}) or {}).get("description") if definition else None + ) + if not description and cls is not None: + description = getattr(cls, "description", None) + description = description or f"Tool {tool_name}" + + code = ( + self._tool_script_from_class(cls) + if cls is not None + else self._tool_script_from_state(fields) + ) code = code or "# TODO: provide implementation\nresponse = None" return { - "name": fields.get("name") or tool_name, + "name": tool_name, "description": description, "parameters": params, - "show_tool_message": bool(fields.get("show_tool_message", False)), - "show_assistant_message": bool(fields.get("show_assistant_message", False)), - "edit_available": bool(fields.get("edit_available", True)), + "show_tool_message": bool(_get("show_tool_message", False)), + "show_assistant_message": bool(_get("show_assistant_message", False)), + "edit_available": bool(_get("edit_available", True)), "code": 
code, } @@ -737,10 +931,16 @@ def _retrieval_tool_payload( *, tool_name: str, definition: dict[str, Any], + cls: type | None, project_path: Path, ) -> dict[str, Any]: fields = definition.get("fields", {}) if definition else {} + def _get(attr: str, default=None): + if cls is not None and hasattr(cls, attr): + return getattr(cls, attr) + return fields.get(attr, default) + def _resolve_retrieval_id(value: Any) -> int: if value is None: raise CogSolAPIError(f"retrieval is required for retrieval tool '{tool_name}'.") @@ -763,7 +963,7 @@ def _resolve_retrieval_id(value: Any) -> int: ) return int(retrieval_id) - params = list(fields.get("parameters") or []) + params = list(_get("parameters") or []) if not params: params.append( { @@ -773,18 +973,18 @@ def _resolve_retrieval_id(value: Any) -> int: "required": True, } ) - description = fields.get("description") or f"Retrieval tool {tool_name}" - retrieval_id = _resolve_retrieval_id(fields.get("retrieval")) + description = _get("description") or f"Retrieval tool {tool_name}" + retrieval_id = _resolve_retrieval_id(_get("retrieval")) return { - "name": fields.get("name") or tool_name, + "name": _get("name") or tool_name, "description": description, "parameters": params, - "show_tool_message": bool(fields.get("show_tool_message", False)), - "show_assistant_message": bool(fields.get("show_assistant_message", False)), - "edit_available": bool(fields.get("edit_available", True)), + "show_tool_message": bool(_get("show_tool_message", False)), + "show_assistant_message": bool(_get("show_assistant_message", False)), + "edit_available": bool(_get("edit_available", True)), "retrieval_id": retrieval_id, - "answer": bool(fields.get("answer", True)), + "answer": bool(_get("answer", True)), } def _assistant_payload( @@ -792,15 +992,18 @@ def _assistant_payload( *, agent_name: str, definition: dict[str, Any], + cls: type | None, remote_ids: dict[str, Any], - faq_available: bool, - fixed_available: bool, - lessons_available: bool, + 
project_path: Path, + app: str, + slug: str | None = None, ) -> dict[str, Any]: fields = definition.get("fields", {}) if definition else {} meta = definition.get("meta", {}) if definition else {} def _get(attr: str, default=None): + if cls is not None and hasattr(cls, attr): + return getattr(cls, attr) return fields.get(attr, default) def _get_meta(attr: str, default=None): @@ -832,37 +1035,79 @@ def _first_non_none(*values: Any) -> Any: return v return None - def _resolve_tool_id(raw: Any) -> int | None: - candidates: list[str] = [] - if isinstance(raw, str): - candidates.append(raw) - elif isinstance(raw, dict) and raw.get("name"): - candidates.append(str(raw.get("name"))) - if not candidates and raw is not None: - candidates.append(str(raw)) - - for candidate in candidates: - for alias in _name_aliases(candidate): - tool_id = remote_ids.get("tools", {}).get(alias) - if isinstance(tool_id, int): - return tool_id - if isinstance(tool_id, str) and tool_id.isdigit(): - return int(tool_id) - retrieval_id = remote_ids.get("retrieval_tools", {}).get(alias) - if isinstance(retrieval_id, int): - return retrieval_id - if isinstance(retrieval_id, str) and retrieval_id.isdigit(): - return int(retrieval_id) + def _prompt_text(value: Any) -> str: + if isinstance(value, Prompt): + candidates = [] + if value.base_dir: + candidates.append(Path(value.base_dir) / "prompts" / value.path) + if slug: + candidates.append(project_path / app / slug / "prompts" / value.path) + candidates.append(project_path / app / "prompts" / value.path) + for candidate in candidates: + if candidate.exists(): + try: + return candidate.read_text(encoding="utf-8") + except FileNotFoundError: + continue + return str(value.path) + if isinstance(value, Path): + try: + return value.read_text(encoding="utf-8") + except FileNotFoundError: + return str(value) + return str(value) if value is not None else "" + + tools = getattr(cls, "tools", []) if cls else [] + pretools = getattr(cls, "pretools", []) if cls else 
[] + + def _resolve_tool_id(t) -> int | None: + is_mcp_tool = isinstance(t, BaseMCPTool) or ( + isinstance(t, type) and issubclass(t, BaseMCPTool) + ) + candidates = [ + getattr(t, "name", None), + _tool_key(t), + getattr(t, "__name__", None), + t.__class__.__name__, + ] + for name in candidates: + if not name: + continue + tool_id = remote_ids.get("tools", {}).get(name) + if isinstance(tool_id, int): + return tool_id + if isinstance(tool_id, str) and tool_id.isdigit(): + return int(tool_id) + + retrieval_id = remote_ids.get("retrieval_tools", {}).get(name) + if isinstance(retrieval_id, int): + return retrieval_id + if isinstance(retrieval_id, str) and retrieval_id.isdigit(): + return int(retrieval_id) + + mcp_tool_id = remote_ids.get("mcp_tools", {}).get(name) + if isinstance(mcp_tool_id, int): + return mcp_tool_id + if isinstance(mcp_tool_id, str) and mcp_tool_id.isdigit(): + return int(mcp_tool_id) + + if is_mcp_tool: + tool_name = next((str(n) for n in candidates if n), "") + raise CogSolAPIError( + "Assistant references MCP tool '" + f"{tool_name}' that is not present in Cognitive. " + "Run 'addmcptools' first to publish MCP server/tools." 
+ ) return None tool_ids: list[int] = [] - for t in list(_get("tools", []) or []): + for t in tools: remote_id = _resolve_tool_id(t) if remote_id: tool_ids.append(remote_id) pretool_ids: list[int] = [] - for t in list(_get("pretools", []) or []): + for t in pretools: remote_id = _resolve_tool_id(t) if remote_id: pretool_ids.append(remote_id) @@ -881,7 +1126,7 @@ def _resolve_tool_id(raw: Any) -> int | None: "generation_config": _normalize_config(_get("generation_config")), "generation_config_pretools": _normalize_config(_get("pregeneration_config")), "description": _get_meta("chat_name") or f"Agent {agent_name}", - "system_prompt": str(_get("system_prompt") or ""), + "system_prompt": _prompt_text(_get("system_prompt")), "temperature": float(_get("temperature") or 0.0), "max_responses": _int_or_default( _first_non_none(_get("max_responses"), _get("max_interactions")), default=0 @@ -903,9 +1148,13 @@ def _resolve_tool_id(raw: Any) -> int | None: "matrix_mode_available": bool(_get("realtime", False)), "not_info_message": _get("no_information_message"), "strategy_to_optimize_tokens": None, - "faq_available": bool(faq_available or fields.get("faqs")), - "fixed_available": bool(fixed_available or fields.get("fixed_responses")), - "lessons_available": bool(lessons_available or fields.get("lessons")), + "faq_available": bool(getattr(cls, "faqs", []) if cls else fields.get("faqs")), + "fixed_available": bool( + getattr(cls, "fixed_responses", []) if cls else fields.get("fixed_responses") + ), + "lessons_available": bool( + getattr(cls, "lessons", []) if cls else fields.get("lessons") + ), "realtime_available": bool(_get("realtime", False)), "info": None, "colors": colors, @@ -916,49 +1165,33 @@ def _resolve_tool_id(raw: Any) -> int | None: } return payload - def _agents_from_related_bucket(self, bucket: dict[str, Any]) -> set[str]: - agents: set[str] = set() - for key, definition in (bucket or {}).items(): - fields = definition.get("fields", {}) if isinstance(definition, 
dict) else {} - agent_name = str(fields.get("agent") or str(key).partition("::")[0]).strip() - if agent_name: - agents.add(agent_name) - return agents - - def _faq_payload_from_fields(self, key: str, fields: dict[str, Any]) -> dict[str, Any]: - _, _, default_name = key.partition("::") - name = str(fields.get("name") or default_name or key) - content = str(fields.get("content") or "") + def _faq_payload(self, faq_obj: Any) -> dict[str, Any]: + name = ( + getattr(faq_obj, "question", None) + or getattr(faq_obj, "name", None) + or faq_obj.__class__.__name__ + ) + content = getattr(faq_obj, "answer", None) or getattr(faq_obj, "content", None) or "" return { "name": name, "content": content, "additional_metadata": {}, } - def _fixed_payload_from_fields(self, key: str, fields: dict[str, Any]) -> dict[str, Any]: - _, _, default_name = key.partition("::") - name = str(fields.get("name") or default_name or key) - content = str(fields.get("content") or "") - meta = fields.get("meta") - topic = None - if isinstance(meta, dict): - topic = meta.get("topic") - topic = topic or name + def _fixed_payload(self, obj: Any) -> dict[str, Any]: + key = getattr(obj, "key", None) or getattr(obj, "name", None) or obj.__class__.__name__ + content = getattr(obj, "response", None) or getattr(obj, "content", None) or "" return { - "topic": topic, + "topic": key, "content": content, - "name": name, + "name": key, "additional_metadata": {}, } - def _lesson_payload_from_fields(self, key: str, fields: dict[str, Any]) -> dict[str, Any]: - _, _, default_name = key.partition("::") - name = str(fields.get("name") or default_name or key) - content = str(fields.get("content") or "") - context = "general" - meta = fields.get("meta") - if isinstance(meta, dict): - context = str(meta.get("context_of_application") or "general") + def _lesson_payload(self, obj: Any) -> dict[str, Any]: + name = getattr(obj, "name", None) or obj.__class__.__name__ + content = getattr(obj, "content", None) or "" + context = 
getattr(obj, "context_of_application", None) or "general" return { "name": name, "content": content, @@ -966,71 +1199,198 @@ def _lesson_payload_from_fields(self, key: str, fields: dict[str, Any]) -> dict[ "additional_metadata": {}, } - def _tool_script_from_state(self, fields: dict[str, Any]) -> str: - raw_code = fields.get("__code__") - if not isinstance(raw_code, str) or not raw_code.strip(): + def _tool_helper_source(self, node: ast.AST) -> str: + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): return "" - code = _normalize_code(raw_code) + fn_node = ast.fix_missing_locations(ast.copy_location(node, node)) + fn_node.decorator_list = [] + if fn_node.args.posonlyargs and fn_node.args.posonlyargs[0].arg == "self": + fn_node.args.posonlyargs = fn_node.args.posonlyargs[1:] + if fn_node.args.args and fn_node.args.args[0].arg == "self": + fn_node.args.args = fn_node.args.args[1:] + return ast.unparse(fn_node).strip() + + def _replace_self_calls(self, code: str, helper_names: list[str]) -> str: + rewritten = code + for name in helper_names: + rewritten = re.sub(rf"\bself\.{name}\b", name, rewritten) + return rewritten + + def _strip_self_from_signature(self, source: str) -> str: + lines = source.splitlines() + cleaned: list[str] = [] + in_signature = True + for line in lines: + if in_signature: + stripped = line.strip() + if stripped in {"self", "self,"}: + continue + line = re.sub(r"\(\s*self\s*,\s*", "(", line) + line = re.sub(r"\(\s*self\s*\)", "()", line) + line = re.sub(r",\s*self\s*(?=[,)])", "", line) + if stripped.endswith(":"): + in_signature = False + cleaned.append(line) + return "\n".join(cleaned) + + def _tool_script_from_state(self, fields: dict[str, Any]) -> str: + code = _normalize_code(fields.get("__code__", "") or "") if not code: return "" - params_to_bind = self._tool_param_names_from_fields(fields.get("parameters")) - return self._tool_script_from_code(code, params_to_bind) - - def _tool_param_names_from_fields(self, raw_params: 
Any) -> list[str]: - names: list[str] = [] - if isinstance(raw_params, dict): - names = [str(k) for k in raw_params] - elif isinstance(raw_params, list): - for item in raw_params: - if isinstance(item, dict) and item.get("name"): - names.append(str(item["name"])) - return names - - def _tool_script_from_code(self, code: str, params_to_bind: list[str]) -> str: - normalized = _normalize_code(code) + try: - tree = ast.parse(normalized) + tree = ast.parse(code) except SyntaxError: - script = normalized.strip() - if script and "response" not in script: - script += "\n\nresponse = None" - return script - - fn_nodes = [ - node for node in tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) - ] - run_node = next((node for node in fn_nodes if node.name == "run"), None) - helper_nodes = [ - node - for node in fn_nodes - if node is not run_node - and not (node.name.startswith("__") and node.name.endswith("__")) - ] - helper_names = [node.name for node in helper_nodes] - helper_sources = [ - src - for src in (self._tool_helper_source(node, normalized) for node in helper_nodes) - if src - ] - - if run_node is None: - helper_block = self._replace_self_calls( - "\n\n".join(helper_sources), helper_names - ).strip() - if helper_block and "response" not in helper_block: - helper_block += "\n\nresponse = None" - return helper_block - - if not params_to_bind: - params_to_bind = self._tool_param_names_from_run_node(run_node) - - run_body = self._run_body_source(run_node, normalized) - run_body = self._replace_self_calls(run_body, helper_names) + return "" + + lines = code.splitlines() + helper_sources: list[str] = [] + helper_names: list[str] = [] + run_node: ast.FunctionDef | ast.AsyncFunctionDef | None = None + + for node in tree.body: + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + if node.name == "run": + run_node = node + continue + if node.name.startswith("__") and node.name.endswith("__"): + continue + 
helper_names.append(node.name) + helper_src = "\n".join(lines[node.lineno - 1 : node.end_lineno]) + helper_src = self._strip_self_from_signature(helper_src) + helper_src = self._replace_self_calls(helper_src, helper_names) + helper_sources.append(helper_src.strip()) + + run_script = "" + if run_node is not None and run_node.body: + body_start = run_node.body[0].lineno - 1 + body_end = run_node.end_lineno or run_node.body[-1].lineno + body = "\n".join(lines[body_start:body_end]) + dedented = textwrap.dedent(body) + dedented = self._replace_self_calls(dedented, helper_names) + + params_to_bind = list((fields.get("parameters") or {}).keys()) + result_lines: list[str] = [] + for p in params_to_bind: + result_lines.append(f"{p} = params.get('{p}') if params else None") + + for line in dedented.splitlines(): + stripped = line.lstrip() + indent = line[: len(line) - len(stripped)] + if stripped.startswith("return "): + result_lines.append(f"{indent}response = {stripped[len('return '):]}") + continue + if stripped == "return": + result_lines.append(f"{indent}response = None") + continue + result_lines.append(line) + + run_script = "\n".join(result_lines).strip() + if "response" not in run_script: + run_script += ("\n\n" if run_script else "") + "response = None" + + script_parts: list[str] = [] + if helper_sources: + script_parts.append("\n\n".join(helper_sources)) + if run_script: + script_parts.append(run_script) + return "\n\n".join(script_parts).strip() + + def _tool_script_from_class(self, cls: type[BaseTool] | None) -> str: + if cls is None: + return "" + try: + run_fn = cls.run + except AttributeError: + return getattr(cls, "__doc__", "") or "" + + try: + source = inspect.getsource(run_fn) + except (OSError, TypeError): # pragma: no cover - best effort + return getattr(cls, "__doc__", "") or "" + + helper_sources: list[str] = [] + helper_names: list[str] = [] + try: + class_source = inspect.getsource(cls) + class_source = _normalize_code(class_source) + tree = 
ast.parse(class_source) + class_def = next( + ( + node + for node in tree.body + if isinstance(node, ast.ClassDef) and node.name == cls.__name__ + ), + None, + ) + if class_def is not None: + for node in class_def.body: + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + name = node.name + if name == "run" or (name.startswith("__") and name.endswith("__")): + continue + helper_names.append(name) + helper_sources.append(self._tool_helper_source(node)) + except Exception: # pragma: no cover - best effort + helper_sources = [] + helper_names = [] + + source = _normalize_code(source) + lines = source.splitlines() + # Strip decorator lines if any (not expected but safe). + while lines and lines[0].lstrip().startswith("@"): + lines.pop(0) + + try: + fn_tree = ast.parse(source) + run_node = next( + ( + node + for node in fn_tree.body + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + ), + None, + ) + except Exception: + run_node = None + + if run_node is not None and run_node.body: + body_start = run_node.body[0].lineno - 1 + body_end = run_node.end_lineno or run_node.body[-1].lineno + body = "\n".join(lines[body_start:body_end]) + else: + # Fallback for unusual source layouts. 
+ def_idx = None + for i, line in enumerate(lines): + if line.lstrip().startswith("def "): + def_idx = i + break + if def_idx is None: + return textwrap.dedent(source) + body = "\n".join(lines[def_idx + 1 :]) + + dedented = textwrap.dedent(body) + dedented = self._replace_self_calls(dedented, helper_names) + + # Detect parameters to bind from signature (excluding runtime args) + params_to_bind = [] + try: + sig = inspect.signature(run_fn) + for name, _param in sig.parameters.items(): + if name in {"self", "chat", "data", "secrets", "log", "params"}: + continue + params_to_bind.append(name) + except Exception: + params_to_bind = [] result_lines: list[str] = [] + # Prepend param extraction for p in params_to_bind: result_lines.append(f"{p} = params.get('{p}') if params else None") - for line in run_body.splitlines(): + + for line in dedented.splitlines(): stripped = line.lstrip() indent = line[: len(line) - len(stripped)] if stripped.startswith("return "): @@ -1047,105 +1407,10 @@ def _tool_script_from_code(self, code: str, params_to_bind: list[str]) -> str: script_parts: list[str] = [] if helper_sources: - helper_block = self._replace_self_calls("\n\n".join(helper_sources), helper_names) - script_parts.append(helper_block.strip()) + helper_block = "\n\n".join(helper_sources) + helper_block = self._replace_self_calls(helper_block, helper_names) + script_parts.append(helper_block) if run_script: script_parts.append(run_script) - return "\n\n".join(part for part in script_parts if part).strip() - - def _tool_param_names_from_run_node(self, run_node: ast.AST) -> list[str]: - if not isinstance(run_node, (ast.FunctionDef, ast.AsyncFunctionDef)): - return [] - ignore = {"self", "chat", "data", "secrets", "log", "params"} - names: list[str] = [] - arg_nodes = ( - list(run_node.args.posonlyargs) - + list(run_node.args.args) - + list(run_node.args.kwonlyargs) - ) - for arg in arg_nodes: - if arg.arg in ignore or arg.arg in names: - continue - names.append(arg.arg) - return 
names - - def _node_source(self, node: ast.AST, source: str) -> str: - segment = ast.get_source_segment(source, node) - if segment: - return segment - try: - return ast.unparse(node) - except Exception: - return "" - def _relative_node_source(self, node: ast.AST, source: str) -> str: - segment = self._node_source(node, source) - if not segment: - return "" - lines = segment.splitlines() - if len(lines) <= 1: - return segment - indent = max(getattr(node, "col_offset", 0), 0) - normalized = [lines[0]] - for line in lines[1:]: - if line.startswith(" " * indent): - normalized.append(line[indent:]) - else: - normalized.append(line.lstrip() if line.strip() else "") - return "\n".join(normalized) - - def _function_body_source(self, node: ast.AST, source: str) -> str: - if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - return "" - body_parts = [ - part - for part in (self._relative_node_source(stmt, source) for stmt in node.body) - if part - ] - return _normalize_code("\n".join(body_parts)) - - def _source_offset(self, source: str, node: ast.AST, target: ast.AST) -> int: - if ( - not hasattr(node, "lineno") - or not hasattr(target, "lineno") - or not hasattr(target, "col_offset") - ): - return len(source) - lines = source.splitlines(keepends=True) - line_index = max(target.lineno - node.lineno, 0) - if line_index >= len(lines): - return len(source) - offset = sum(len(line) for line in lines[:line_index]) - if line_index == 0: - col = max(target.col_offset - getattr(node, "col_offset", 0), 0) - else: - col = max(target.col_offset, 0) - return cast(int, min(offset + col, len(source))) - - def _strip_first_self_param(self, signature: str) -> str: - updated = re.sub(r"(\(\s*)self(\s*,\s*)", r"\1", signature, count=1) - return re.sub(r"(\(\s*)self(\s*\))", r"\1\2", updated, count=1) - - def _run_body_source(self, run_node: ast.AST, source: str) -> str: - return self._function_body_source(run_node, source) - - def _tool_helper_source(self, node: ast.AST, 
source: str) -> str: - if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - return "" - helper_src = self._node_source(node, source) - if not helper_src: - return "" - body_source = self._function_body_source(node, source) - header_source = helper_src - if node.body: - header_source = helper_src[: self._source_offset(helper_src, node, node.body[0])] - header_source = self._strip_first_self_param(header_source.rstrip()) - if not body_source: - return _normalize_code(header_source) - return _normalize_code(f"{header_source}\n{textwrap.indent(body_source, ' ')}") - - def _replace_self_calls(self, code: str, helper_names: list[str]) -> str: - rewritten = code - for name in helper_names: - rewritten = re.sub(rf"\bself\.{name}\b", name, rewritten) - return rewritten + return "\n\n".join(script_parts).strip() diff --git a/cogsol/management/commands/startproject.py b/cogsol/management/commands/startproject.py index 365ed91..cb38cd7 100644 --- a/cogsol/management/commands/startproject.py +++ b/cogsol/management/commands/startproject.py @@ -90,6 +90,52 @@ def run(self, chat=None, data=None, secrets=None, log=None, text: str = "", coun # response = "Thanks for trying CogSol!" """ +MCP_SERVERS_PY = """\ +import os + +from cogsol.tools import BaseMCPServer +# Define MCP servers here. Use os.environ for sensitive credentials. +# Run `python manage.py addmcptools` to interactively add a server. +# The server URL is hardcoded — it is sent to the CogSol API on `migrate` +# and served from there; it does not need to live in .env. +# +# --- auth_type="none" (no credentials) --- +# class ExampleMCPServer(BaseMCPServer): +# name = "example_server" +# description = "Example MCP server." +# url = "https://example.com/mcp" +# +# --- auth_type="headers" (API key) --- +# class ExampleMCPServer(BaseMCPServer): +# name = "example_server" +# description = "Example MCP server." 
+# url = "https://example.com/mcp" +# headers = {"x-api-key": os.environ.get("MCP_EXAMPLE_SERVER_X_API_KEY", "")} +# +# --- auth_type="oauth2" (OAuth 2.1 / PKCE) --- +# class AtlassianMCPServer(BaseMCPServer): +# name = "atlassian_server" +# description = "Atlassian MCP server via OAuth 2.1." +# auth_type = "oauth2" +# url = "https://mcp.atlassian.com/mcp" +# oauth_client_id = os.environ.get("MCP_ATLASSIAN_SERVER_OAUTH_CLIENT_ID", "") +# oauth_scopes = os.environ.get("MCP_ATLASSIAN_SERVER_OAUTH_SCOPES", "") +# # oauth_client_secret is NEVER stored here — handled by the CogSol backend +""" + +MCP_TOOLS_PY = """\ +from cogsol.tools import BaseMCPTool +# Define MCP tools here. Each tool references a BaseMCPServer subclass. +# Run `python manage.py addmcptools` to interactively select tools. +# +# from agents.mcp_servers import ExampleMCPServer +# +# class ExampleMCPTool(BaseMCPTool): +# name = "example_tool" +# description = "An example tool from the MCP server." +# server = ExampleMCPServer +""" + # Data folder templates for Content API DATA_FORMATTERS_PY = """\ from cogsol.content import BaseReferenceFormatter @@ -144,6 +190,7 @@ def run(self, chat=None, data=None, secrets=None, log=None, text: str = "", coun - Create agents with `python manage.py startagent MyAgent` (per-agent folders under `agents/`). - Define reusable tools in `agents/tools.py` and import them in each agent. - Define retrieval tools in `agents/searches.py` to query Content API retrievals. +- Add MCP servers/tools with `python manage.py addmcptools` or define in `agents/mcp_servers.py` and `agents/mcp_tools.py`. ## Data (Content API) - Create topics with `python manage.py starttopic my_topic` (nested folders under `data/`). 
@@ -189,6 +236,8 @@ def handle(self, project_path: Path | None, **options: Any) -> int: "agents/__init__.py": "", "agents/tools.py": TOOLS_PY, "agents/searches.py": SEARCHES_PY, + "agents/mcp_servers.py": MCP_SERVERS_PY, + "agents/mcp_tools.py": MCP_TOOLS_PY, "agents/migrations/__init__.py": "", "data/__init__.py": "", "data/formatters.py": DATA_FORMATTERS_PY, diff --git a/cogsol/tools/__init__.py b/cogsol/tools/__init__.py index 78ce3ff..2a9f115 100644 --- a/cogsol/tools/__init__.py +++ b/cogsol/tools/__init__.py @@ -11,7 +11,7 @@ class BaseTool: name: str | None = None description: str | None = None - parameters: dict[str, Any] = {} + parameters: dict[str, Any] | None = None def __init__(self, name: str | None = None, description: str | None = None): if name: @@ -22,6 +22,8 @@ def __init__(self, name: str | None = None, description: str | None = None): # Derive name from class (strip 'Tool' suffix if present) cls_name = self.__class__.__name__ self.name = cls_name[:-4] if cls_name.endswith("Tool") else cls_name + # Avoid sharing mutable metadata across subclasses/instances. + self.parameters = dict(getattr(self, "parameters", {}) or {}) def run(self, *args: Any, **kwargs: Any) -> Any: # pragma: no cover - placeholder raise NotImplementedError("Tool execution is not implemented in the CLI framework.") @@ -57,7 +59,7 @@ def __repr__(self) -> str: class BaseRetrievalTool: name: str | None = None description: str | None = None - parameters: list[dict[str, Any]] = [] + parameters: list[dict[str, Any]] | None = None retrieval: str | None = None show_tool_message: bool = False show_assistant_message: bool = False @@ -72,11 +74,89 @@ def __init__(self, name: str | None = None, description: str | None = None): if not getattr(self, "name", None): cls_name = self.__class__.__name__ self.name = cls_name[:-4] if cls_name.endswith("Tool") else cls_name + # Avoid sharing mutable metadata across subclasses/instances. 
+ self.parameters = list(getattr(self, "parameters", []) or []) def __repr__(self) -> str: return f"" +class BaseMCPServer: + """Base class for MCP server definitions. + + Subclass this to register an MCP server in your project. + URL, header values and OAuth credentials should be read from environment + variables (stored in .env) for security. + + ``auth_type`` controls how the server authenticates: + - ``"none"`` – no authentication + - ``"headers"`` – static headers (e.g. API keys) **default** + - ``"oauth2"`` – OAuth 2.1 / PKCE flow managed by the cognitive backend + + For OAuth 2.1, ``oauth_client_id`` and ``oauth_scopes`` are optional — + cogsol supports Dynamic Client Registration (RFC 7591) when they are + not provided. The client secret is **never** stored as a class attribute; + it is supplied interactively by ``addmcptools`` and sent write-only to the + API (stored in Azure Key Vault by the backend). + """ + + name: str | None = None + description: str | None = None + url: str | None = None + headers: dict[str, str] | None = None + protocol_version: str = "2025-03-26" + client_name: str = "cognitive-mcp-client" + client_version: str = "1.0.0" + active: bool = True + + # Authentication + auth_type: str = "headers" # "none" | "headers" | "oauth2" + + # OAuth 2.1 (only meaningful when auth_type == "oauth2") + oauth_client_id: str | None = None # Optional; DCR fills it when omitted + oauth_scopes: str | None = None # Space-separated scopes, e.g. "read:jira" + # NOTE: oauth_client_secret is intentionally NOT declared here — it is + # never stored in source code; addmcptools prompts for it and sends it + # write-only to the CogSol API (Azure Key Vault). + + def __init__(self, name: str | None = None): + if name: + self.name = name + if not getattr(self, "name", None): + cls_name = self.__class__.__name__ + self.name = cls_name[:-9] if cls_name.endswith("MCPServer") else cls_name + # Avoid sharing mutable headers across subclasses/instances. 
+ self.headers = dict(getattr(self, "headers", {}) or {}) + + def __repr__(self) -> str: + return f"" + + +class BaseMCPTool: + """Base class for MCP tool definitions. + + Subclass this to register an MCP tool selected from an MCP server. + The ``server`` attribute should reference the BaseMCPServer subclass. + """ + + name: str | None = None + description: str | None = None + server: type | None = None # Reference to a BaseMCPServer subclass + show_tool_message: bool = False + show_assistant_message: bool = False + edit_available: bool = True + + def __init__(self, name: str | None = None): + if name: + self.name = name + if not getattr(self, "name", None): + cls_name = self.__class__.__name__ + self.name = cls_name[:-7] if cls_name.endswith("MCPTool") else cls_name + + def __repr__(self) -> str: + return f"" + + def tool_params(**params: dict[str, Any]): """ Decorator to attach parameter metadata to a tool's run method. @@ -102,5 +182,7 @@ def decorator(func): "BaseFAQ", "BaseFixedResponse", "BaseRetrievalTool", + "BaseMCPServer", + "BaseMCPTool", "tool_params", ] diff --git a/tests/test_addmcptools.py b/tests/test_addmcptools.py new file mode 100644 index 0000000..4cd2dc5 --- /dev/null +++ b/tests/test_addmcptools.py @@ -0,0 +1,486 @@ +"""Tests for the addmcptools management command.""" + +import ast + +from cogsol.core.api import CogSolAPIError +from cogsol.management.commands import addmcptools + + +class TestAddMCPToolsCodegen: + def test_generates_valid_python_with_special_characters(self, monkeypatch, tmp_path): + class FakeMCPClient: + def __init__(self, *_args, **_kwargs): + pass + + def initialize(self): + return True + + def list_tools(self): + return [ + { + "name": 'tool "one"', + "description": 'desc with "quotes" and triple """ markers', + } + ] + + def disconnect(self): + return None + + class FakeCogSolClient: + def __init__(self, *_args, **_kwargs): + self.upserted = [] + self.synced = [] + + def list_mcp_servers(self): + return [] + + def 
upsert_mcp_server(self, *, remote_id, payload): + self.upserted.append((remote_id, payload)) + return 99 + + def sync_mcp_server_tools(self, server_id, selected_tools): + self.synced.append((server_id, selected_tools)) + return {"results": [{"id": 1, "name": n} for n in selected_tools]} + + answers = { + "Server name": 'Server "A"', + "Description": "Description with \"double\" and 'single' quotes", + "Server URL (e.g. https://mcp.example.com/mcp)": 'https://example.com/mcp?x="1"', + "Select auth type": "1", # none + "Selection": "all", + } + + def fake_ask(prompt: str, default: str = "") -> str: + return answers.get(prompt, default) + + monkeypatch.setattr(addmcptools, "MCPClient", FakeMCPClient) + monkeypatch.setattr(addmcptools, "CogSolClient", FakeCogSolClient) + monkeypatch.setattr(addmcptools, "_ask", fake_ask) + monkeypatch.setenv("COGSOL_API_BASE", "https://api.example.test") + + project_path = tmp_path + (project_path / ".env").write_text("", encoding="utf-8") + + result = addmcptools.Command().handle(project_path=project_path, app="agents") + + assert result == 0 + + servers_file = project_path / "agents" / "mcp_servers.py" + tools_file = project_path / "agents" / "mcp_tools.py" + + servers_source = servers_file.read_text(encoding="utf-8") + tools_source = tools_file.read_text(encoding="utf-8") + + ast.parse(servers_source) + ast.parse(tools_source) + + assert "name = 'Server \"A\"'" in servers_source + assert ( + "description = 'Description with \"double\" and \\'single\\' quotes'" in servers_source + ) + assert "name = 'tool \"one\"'" in tools_source + assert 'description = \'desc with "quotes" and triple """ markers\'' in tools_source + + +class TestAddMCPToolsOAuthAssisted: + def test_oauth_assisted_discovery_opens_browser_and_loads_tools(self, monkeypatch, tmp_path): + class FakeMCPClient: + def __init__(self, *_args, **_kwargs): + pass + + def initialize(self): + return False + + def list_tools(self): + return [] + + def disconnect(self): + return 
None + + class FakeCogSolClient: + def __init__(self, *_args, **_kwargs): + self.calls = 0 + + def list_mcp_servers(self): + return [{"id": 12, "name": "jira oauth", "url": "https://mcp.atlassian.com/v1/mcp"}] + + def discover_mcp_oauth(self, _server_id): + return {"success": True} + + def get_mcp_oauth_authorization_url(self, _server_id): + return {"authorization_url": "https://mcp.atlassian.com/v1/authorize?..."} + + def get_mcp_server(self, _server_id): + self.calls += 1 + return {"oauth_status": "connected" if self.calls >= 1 else "disconnected"} + + def list_mcp_server_tools(self, _server_id): + return { + "tools": [ + { + "name": "GETACCESSIBLEATLASSIANRESOURCES", + "description": "Get cloudId to make tool calls.", + } + ] + } + + def sync_mcp_server_tools(self, _server_id, _selected_tools): + return {"results": [{"id": 33, "name": "GETACCESSIBLEATLASSIANRESOURCES"}]} + + def upsert_mcp_server(self, *, remote_id, payload): + return remote_id or 12 + + answers = { + "Server name": "jira oauth", + "Description": "", + "Server URL (e.g. https://mcp.example.com/mcp)": "https://mcp.atlassian.com/v1/mcp", + "Select auth type": "3", + "Client ID (leave blank for auto-registration)": "", + "Scopes (space-separated, e.g. 
'read:jira write:confluence')": "", + "Selection": "all", + } + + def fake_ask(prompt: str, default: str = "") -> str: + return answers.get(prompt, default) + + monkeypatch.setenv("COGSOL_API_BASE", "https://api.example.test") + monkeypatch.setattr(addmcptools, "MCPClient", FakeMCPClient) + monkeypatch.setattr(addmcptools, "CogSolClient", FakeCogSolClient) + monkeypatch.setattr(addmcptools, "_ask", fake_ask) + monkeypatch.setattr(addmcptools, "_ask_secret", lambda _prompt: "") + opened_urls = [] + monkeypatch.setattr( + addmcptools.webbrowser, + "open", + lambda url, **_kwargs: opened_urls.append(url) or True, + ) + monkeypatch.setattr(addmcptools.time, "sleep", lambda _s: None) + + project_path = tmp_path + (project_path / ".env").write_text("", encoding="utf-8") + + result = addmcptools.Command().handle(project_path=project_path, app="agents") + + assert result == 0 + assert opened_urls + + tools_file = project_path / "agents" / "mcp_tools.py" + tools_source = tools_file.read_text(encoding="utf-8") + assert "GETACCESSIBLEATLASSIANRESOURCES" in tools_source + ast.parse(tools_source) + + def test_oauth_assisted_discovery_auto_creates_server_when_not_migrated( + self, monkeypatch, tmp_path + ): + class FakeMCPClient: + def __init__(self, *_args, **_kwargs): + pass + + def initialize(self): + return False + + def list_tools(self): + return [] + + def disconnect(self): + return None + + class FakeCogSolClient: + def __init__(self, *_args, **_kwargs): + self.calls = 0 + self.created_payload = None + + def list_mcp_servers(self): + return [] + + def upsert_mcp_server(self, *, remote_id, payload): + assert remote_id is None + self.created_payload = payload + return 77 + + def discover_mcp_oauth(self, _server_id): + return {"success": True} + + def get_mcp_oauth_authorization_url(self, _server_id): + return {"authorization_url": "https://mcp.atlassian.com/v1/authorize?..."} + + def get_mcp_server(self, _server_id): + self.calls += 1 + return {"oauth_status": "connected" 
if self.calls >= 1 else "disconnected"} + + def list_mcp_server_tools(self, _server_id): + return { + "tools": [ + { + "name": "GETACCESSIBLEATLASSIANRESOURCES", + "description": "Get cloudId to make tool calls.", + } + ] + } + + def sync_mcp_server_tools(self, _server_id, _selected_tools): + return {"results": [{"id": 44, "name": "GETACCESSIBLEATLASSIANRESOURCES"}]} + + answers = { + "Server name": "jira oauth", + "Description": "", + "Server URL (e.g. https://mcp.example.com/mcp)": "https://mcp.atlassian.com/v1/mcp", + "Select auth type": "3", + "Client ID (leave blank for auto-registration)": "", + "Scopes (space-separated, e.g. 'read:jira write:confluence')": "", + } + + def fake_ask(prompt: str, default: str = "") -> str: + return answers.get(prompt, default) + + monkeypatch.setenv("COGSOL_API_BASE", "https://api.example.test") + monkeypatch.setattr(addmcptools, "MCPClient", FakeMCPClient) + monkeypatch.setattr(addmcptools, "CogSolClient", FakeCogSolClient) + monkeypatch.setattr(addmcptools, "_ask", fake_ask) + monkeypatch.setattr(addmcptools, "_ask_secret", lambda _prompt: "") + monkeypatch.setattr(addmcptools.webbrowser, "open", lambda *_args, **_kwargs: True) + monkeypatch.setattr(addmcptools.time, "sleep", lambda _s: None) + + project_path = tmp_path + (project_path / ".env").write_text("", encoding="utf-8") + + result = addmcptools.Command().handle(project_path=project_path, app="agents") + + assert result == 0 + tools_file = project_path / "agents" / "mcp_tools.py" + tools_source = tools_file.read_text(encoding="utf-8") + assert "GETACCESSIBLEATLASSIANRESOURCES" in tools_source + + def test_find_remote_server_normalizes_name_and_uses_latest_url_match(self): + class FakeCogSolClient: + def list_mcp_servers(self): + return [ + { + "id": 1, + "name": "Atlassian mcp server", + "url": "https://mcp.atlassian.com/v1/mcp", + "auth_type": "oauth2", + "updated_at": "2026-03-17T10:00:00Z", + }, + { + "id": 2, + "name": "attlasian mcp server oauth", + "url": 
"https://mcp.atlassian.com/v1/mcp/", + "auth_type": "oauth2", + "updated_at": "2026-03-17T12:00:00Z", + }, + ] + + cmd = addmcptools.Command() + found = cmd._find_remote_server( + client=FakeCogSolClient(), + server_name=" Attlasian Mcp Server Oauth ", + server_url="https://mcp.atlassian.com/v1/mcp", + ) + + assert found is not None + assert found["id"] == 2 + + def test_publish_oauth_server_runs_discover_and_waits_connected_before_sync(self, monkeypatch): + calls: list[str] = [] + + class FakeCogSolClient: + def __init__(self, *_args, **_kwargs): + self.status_calls = 0 + + def list_mcp_servers(self): + return [ + { + "id": 173, + "name": "atlassian mcp server", + "url": "https://mcp.atlassian.com/v1/mcp", + } + ] + + def upsert_mcp_server(self, *, remote_id, payload): + calls.append("upsert") + assert remote_id == 173 + assert payload["auth_type"] == "oauth2" + return 173 + + def discover_mcp_oauth(self, server_id): + calls.append("discover") + assert server_id == 173 + return {"success": True} + + def get_mcp_server(self, server_id): + calls.append("get_server") + assert server_id == 173 + self.status_calls += 1 + if self.status_calls == 1: + return {"oauth_status": "disconnected"} + return {"oauth_status": "connected"} + + def get_mcp_oauth_authorization_url(self, server_id): + calls.append("authorize_url") + assert server_id == 173 + return {"authorization_url": "https://mcp.atlassian.com/v1/authorize?..."} + + def sync_mcp_server_tools(self, server_id, selected_tools): + calls.append("sync") + assert server_id == 173 + assert selected_tools == ["ATLASSIANUSERINFO"] + return {"results": [{"id": 1, "name": "ATLASSIANUSERINFO"}]} + + monkeypatch.setenv("COGSOL_API_BASE", "https://api.example.test") + monkeypatch.setattr(addmcptools, "CogSolClient", FakeCogSolClient) + monkeypatch.setattr(addmcptools.webbrowser, "open", lambda *_args, **_kwargs: True) + monkeypatch.setattr(addmcptools.time, "sleep", lambda _s: None) + + cmd = addmcptools.Command() + 
cmd._publish_to_cognitive( + server_name="atlassian mcp server", + server_description="", + server_url="https://mcp.atlassian.com/v1/mcp", + auth_type="oauth2", + headers={}, + oauth_client_id="", + oauth_client_secret="", + oauth_scopes="", + selected_tools=[{"name": "ATLASSIANUSERINFO"}], + oauth_timeout=5, + ) + + assert "discover" in calls + assert "authorize_url" in calls + assert calls.index("discover") < calls.index("sync") + + def test_publish_oauth_server_retries_sync_after_reauthorization(self, monkeypatch): + calls: list[str] = [] + + class FakeCogSolClient: + def __init__(self, *_args, **_kwargs): + self.sync_calls = 0 + self.status_calls = 0 + + def list_mcp_servers(self): + return [ + { + "id": 176, + "name": "atlassian mcp server", + "url": "https://mcp.atlassian.com/v1/mcp", + } + ] + + def upsert_mcp_server(self, *, remote_id, payload): + calls.append("upsert") + assert remote_id == 176 + assert payload["auth_type"] == "oauth2" + return 176 + + def discover_mcp_oauth(self, server_id): + calls.append("discover") + assert server_id == 176 + return {"success": True} + + def get_mcp_server(self, server_id): + calls.append("get_server") + assert server_id == 176 + self.status_calls += 1 + # First check says connected; recovery path should still reauthorize after sync failure. 
+ if self.status_calls == 1: + return {"oauth_status": "connected"} + return {"oauth_status": "connected"} + + def get_mcp_oauth_authorization_url(self, server_id): + calls.append("authorize_url") + assert server_id == 176 + return {"authorization_url": "https://mcp.atlassian.com/v1/authorize?..."} + + def sync_mcp_server_tools(self, server_id, selected_tools): + calls.append("sync") + assert server_id == 176 + assert selected_tools == ["ATLASSIANUSERINFO"] + self.sync_calls += 1 + if self.sync_calls == 1: + raise CogSolAPIError( + '500 Internal Server Error: {"error":"Internal server error: ' + "OAuth re-authorization required for MCP server 'atlassian mcp server'\"}" + ) + return {"results": [{"id": 1, "name": "ATLASSIANUSERINFO"}]} + + monkeypatch.setenv("COGSOL_API_BASE", "https://api.example.test") + monkeypatch.setattr(addmcptools, "CogSolClient", FakeCogSolClient) + monkeypatch.setattr(addmcptools.webbrowser, "open", lambda *_args, **_kwargs: True) + monkeypatch.setattr(addmcptools.time, "sleep", lambda _s: None) + + cmd = addmcptools.Command() + cmd._publish_to_cognitive( + server_name="atlassian mcp server", + server_description="", + server_url="https://mcp.atlassian.com/v1/mcp", + auth_type="oauth2", + headers={}, + oauth_client_id="", + oauth_client_secret="", + oauth_scopes="", + selected_tools=[{"name": "ATLASSIANUSERINFO"}], + oauth_timeout=5, + ) + + assert calls.count("sync") == 2 + assert "authorize_url" in calls + + def test_publish_oauth_server_does_not_retry_for_non_oauth_errors(self, monkeypatch): + calls: list[str] = [] + + class FakeCogSolClient: + def __init__(self, *_args, **_kwargs): + pass + + def list_mcp_servers(self): + return [ + { + "id": 176, + "name": "atlassian mcp server", + "url": "https://mcp.atlassian.com/v1/mcp", + } + ] + + def upsert_mcp_server(self, *, remote_id, payload): + assert remote_id == 176 + return 176 + + def discover_mcp_oauth(self, server_id): + assert server_id == 176 + return {"success": True} + + def 
get_mcp_server(self, server_id): + assert server_id == 176 + return {"oauth_status": "connected"} + + def get_mcp_oauth_authorization_url(self, server_id): + calls.append("authorize_url") + assert server_id == 176 + return {"authorization_url": "https://mcp.atlassian.com/v1/authorize?..."} + + def sync_mcp_server_tools(self, _server_id, _selected_tools): + raise CogSolAPIError("500 Internal Server Error: generic failure") + + monkeypatch.setenv("COGSOL_API_BASE", "https://api.example.test") + monkeypatch.setattr(addmcptools, "CogSolClient", FakeCogSolClient) + + cmd = addmcptools.Command() + try: + cmd._publish_to_cognitive( + server_name="atlassian mcp server", + server_description="", + server_url="https://mcp.atlassian.com/v1/mcp", + auth_type="oauth2", + headers={}, + oauth_client_id="", + oauth_client_secret="", + oauth_scopes="", + selected_tools=[{"name": "ATLASSIANUSERINFO"}], + oauth_timeout=5, + ) + raise AssertionError("Expected CogSolAPIError") + except CogSolAPIError: + pass + + assert "authorize_url" not in calls diff --git a/tests/test_mcp_client.py b/tests/test_mcp_client.py new file mode 100644 index 0000000..1a38f24 --- /dev/null +++ b/tests/test_mcp_client.py @@ -0,0 +1,28 @@ +"""Tests for MCP client error formatting.""" + +from cogsol.core.mcp import MCPClient + + +class TestMCPClientErrorSummary: + def test_summarizes_html_error_with_title(self): + html = """ + + + Access denied | Cloudflare +

Error 1010

+ + """ + + summary = MCPClient._summarize_http_error(html) + + assert "HTML error page returned" in summary + assert "Access denied | Cloudflare" in summary + assert " None: assert "response = helper(text=text)" in script assert "self.helper" not in script ast.parse(script) + + +class TestAssistantPayloadMCPTools: + def test_maps_mcp_tool_ids_from_remote_registry(self) -> None: + class PingMCPTool(BaseMCPTool): + name = "ping" + + class DemoAgent: + tools = [PingMCPTool()] + + payload = Command()._assistant_payload( + agent_name="DemoAgent", + definition={"fields": {}, "meta": {}}, + cls=DemoAgent, + remote_ids={ + "tools": {}, + "retrieval_tools": {}, + "mcp_tools": {"ping": 123}, + }, + project_path=Path("."), + app="agents", + ) + + assert payload["tools"] == [123] + + def test_raises_when_mcp_tool_is_not_published(self) -> None: + class MissingMCPTool(BaseMCPTool): + name = "missing_remote_tool" + + class DemoAgent: + tools = [MissingMCPTool()] + + with pytest.raises(CogSolAPIError, match="Run 'addmcptools' first"): + Command()._assistant_payload( + agent_name="DemoAgent", + definition={"fields": {}, "meta": {}}, + cls=DemoAgent, + remote_ids={ + "tools": {}, + "retrieval_tools": {}, + "mcp_tools": {}, + }, + project_path=Path("."), + app="agents", + ) + + +class TestMCPToolRemoteIdHarvest: + def test_extracts_ids_from_list_payload(self) -> None: + remote = {"mcp_tools": {}} + payload = [ + {"id": 10, "name": "read_wiki_structure"}, + {"id": 11, "name": "read_wiki_contents"}, + ] + + found = Command()._update_mcp_tool_remote_ids(remote, payload) + + assert found is True + assert remote["mcp_tools"]["read_wiki_structure"] == 10 + assert remote["mcp_tools"]["read_wiki_contents"] == 11 + + def test_extracts_ids_from_results_payload(self) -> None: + remote = {"mcp_tools": {}} + payload = { + "count": 1, + "results": [{"id": 42, "name": "ask_question"}], + } + + found = Command()._update_mcp_tool_remote_ids(remote, payload) + + assert found is True + assert 
remote["mcp_tools"]["ask_question"] == 42 + + def test_extracts_ids_from_configured_tools_payload(self) -> None: + remote = {"mcp_tools": {}} + payload = { + "tools": [ + {"name": "read_wiki_structure", "already_configured": True}, + {"name": "read_wiki_contents", "already_configured": True}, + ], + "configured_tools": [ + {"id": 903, "name": "ask_question", "configured": True}, + {"id": 904, "name": "read_wiki_contents", "configured": True}, + {"id": 905, "name": "read_wiki_structure", "configured": True}, + ], + } + + found = Command()._update_mcp_tool_remote_ids(remote, payload) + + assert found is True + assert remote["mcp_tools"]["ask_question"] == 903 + assert remote["mcp_tools"]["read_wiki_contents"] == 904 + assert remote["mcp_tools"]["read_wiki_structure"] == 905 + + +class TestToolScriptFromClass: + def test_handles_multiline_run_signature(self) -> None: + class WeatherTool(BaseTool): + def run( + self, + chat=None, + data=None, + latitude: float = 0.0, + longitude: float = 0.0, + ): + return {"lat": latitude, "lon": longitude} + + script = Command()._tool_script_from_class(WeatherTool) + + assert "latitude = params.get('latitude')" in script + assert "longitude = params.get('longitude')" in script + assert "chat=None" not in script + assert 'response = {"lat": latitude, "lon": longitude}' in script + ast.parse(script) + + +class TestMigrateMCPAssociationOnly: + def test_hydrates_mcp_ids_and_associates_existing_tool(self, monkeypatch) -> None: + class PingMCPTool(BaseMCPTool): + name = "ping" + + class DemoAgent: + tools = [PingMCPTool()] + + captured_payloads: list[dict] = [] + + class FakeClient: + def __init__(self, *_args, **_kwargs): + pass + + def list_mcp_servers(self): + return [{"id": 42, "name": "Demo MCP", "url": "https://mcp.demo"}] + + def list_mcp_server_tools(self, _server_id): + return {"results": [{"id": 901, "name": "ping"}]} + + def upsert_assistant(self, *, remote_id, payload): + captured_payloads.append(payload) + return remote_id or 
100 + + monkeypatch.setattr("cogsol.management.commands.migrate.CogSolClient", FakeClient) + + cmd = Command() + state = { + "tools": {}, + "retrieval_tools": {}, + "mcp_servers": { + "DemoMCPServer": { + "fields": { + "name": "Demo MCP", + "auth_type": "none", + "url": "https://mcp.demo", + } + } + }, + "mcp_tools": { + "PingMCPTool": { + "fields": { + "name": "ping", + "server": "DemoMCPServer", + } + } + }, + "agents": { + "DemoAgent": { + "fields": { + "description": "Demo", + "system_prompt": "Hi", + }, + "meta": {}, + } + }, + } + + remote, _ = cmd._sync_with_api( + api_base="https://api.invalid", + api_key=None, + state=state, + remote_ids=cmd._empty_remote(), + class_map={ + "tools": {}, + "retrieval_tools": {}, + "mcp_servers": {}, + "agents": {"DemoAgent": DemoAgent}, + }, + project_path=Path("."), + app="agents", + touched=None, + ) + + assert remote["mcp_tools"]["ping"] == 901 + assert captured_payloads + assert captured_payloads[0]["tools"] == [901] + + def test_hydration_skips_unrelated_remote_servers(self, monkeypatch) -> None: + class PingMCPTool(BaseMCPTool): + name = "ping" + + class DemoAgent: + tools = [PingMCPTool()] + + class FakeClient: + def __init__(self, *_args, **_kwargs): + pass + + def list_mcp_servers(self): + return [ + {"id": 42, "name": "Demo MCP", "url": "https://mcp.demo"}, + {"id": 99, "name": "moweek", "url": "https://moweek.invalid/mcp"}, + ] + + def list_mcp_server_tools(self, server_id): + if int(server_id) == 99: + raise AssertionError("Unrelated server should not be queried") + return {"results": [{"id": 901, "name": "ping"}]} + + def upsert_assistant(self, *, remote_id, payload): + return remote_id or 100 + + monkeypatch.setattr("cogsol.management.commands.migrate.CogSolClient", FakeClient) + + cmd = Command() + state = { + "tools": {}, + "retrieval_tools": {}, + "mcp_servers": { + "DemoMCPServer": { + "fields": { + "name": "Demo MCP", + "auth_type": "none", + "url": "https://mcp.demo", + } + } + }, + "mcp_tools": { + 
"PingMCPTool": { + "fields": { + "name": "ping", + "server": "DemoMCPServer", + } + } + }, + "agents": { + "DemoAgent": { + "fields": { + "description": "Demo", + "system_prompt": "Hi", + }, + "meta": {}, + } + }, + } + + remote, _ = cmd._sync_with_api( + api_base="https://api.invalid", + api_key=None, + state=state, + remote_ids=cmd._empty_remote(), + class_map={ + "tools": {}, + "retrieval_tools": {}, + "mcp_servers": {}, + "agents": {"DemoAgent": DemoAgent}, + }, + project_path=Path("."), + app="agents", + touched=None, + ) + + assert remote["mcp_tools"]["ping"] == 901 + + +class TestRollbackDeleteDispatch: + def test_delete_created_entry_supports_mcp_tool(self) -> None: + class FakeClient: + def __init__(self): + self.deleted = [] + + def delete_mcp_tool(self, tool_id): + self.deleted.append(("mcp_tool", tool_id)) + + client = FakeClient() + Command()._delete_created_entry(client, "mcp_tool", None, 77) + + assert ("mcp_tool", 77) in client.deleted diff --git a/tests/test_tools.py b/tests/test_tools.py index 70db185..a75965f 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -6,6 +6,8 @@ BaseFAQ, BaseFixedResponse, BaseLesson, + BaseMCPServer, + BaseRetrievalTool, BaseTool, tool_params, ) @@ -42,6 +44,46 @@ def test_repr(self): tool = BaseTool(name="test") assert "test" in repr(tool) + def test_parameters_not_shared_between_instances(self): + """BaseTool parameters should not be shared across instances.""" + + class ParamTool(BaseTool): + parameters = {"a": {"type": "string"}} + + first = ParamTool() + second = ParamTool() + + first.parameters["b"] = {"type": "integer"} + assert "b" not in second.parameters + + +class TestBaseRetrievalTool: + def test_parameters_not_shared_between_instances(self): + """BaseRetrievalTool parameters should not be shared across instances.""" + + class RetrievalTool(BaseRetrievalTool): + parameters = [{"name": "query"}] + + first = RetrievalTool() + second = RetrievalTool() + + first.parameters.append({"name": "limit"}) + 
assert len(second.parameters) == 1 + + +class TestBaseMCPServer: + def test_headers_not_shared_between_instances(self): + """BaseMCPServer headers should not be shared across instances.""" + + class DemoServer(BaseMCPServer): + headers = {"x-api-key": "one"} + + first = DemoServer() + second = DemoServer() + + first.headers["x-custom"] = "v" + assert "x-custom" not in second.headers + class TestBaseFAQ: """Tests for BaseFAQ class."""