diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9be656d..39db668 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,6 +33,9 @@ jobs: - name: Check formatting run: uv run --group dev black --check src tests + - name: Type check + run: uv run --group dev pyright + unit: name: Unit runs-on: ubuntu-latest diff --git a/AGENTS.md b/AGENTS.md index 4298c4a..85c682a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -93,20 +93,27 @@ Run these locally before opening a PR. They mirror the `CI` workflow ``` Run `uv run --group dev black src tests` (without `--check`) to auto-fix. -3. **Unit tests** (`unit` job) +3. **Type check** (`lint` job) + ```bash + uv run --group dev pyright + ``` + Pyright config lives in `pyproject.toml` (`[tool.pyright]`): basic mode over + `src/`, Python 3.12, resolving against the project `.venv`. + +4. **Unit tests** (`unit` job) ```bash uv sync uv run python -m unittest discover -s tests/unit -v ``` Fast, no external services or credentials required. -4. **Docker build** (`docker` job) +5. **Docker build** (`docker` job) ```bash docker build -t appwrite-mcp:ci . ``` The hosted HTTP image must build cleanly. -5. **Integration tests** (`integration` job) — *CI runs these only for pushes and +6. **Integration tests** (`integration` job) — *CI runs these only for pushes and for PRs from branches on the same repo (not forks).* They create and delete **real** Appwrite resources, so they need live credentials and are skipped when absent: diff --git a/pyproject.toml b/pyproject.toml index 41d644b..fb898df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ integration = [ dev = [ "black>=25.1.0", "ruff>=0.10.0", + "pyright>=1.1.390", # Only needed by scripts/build_docs_index.py to (re)build the docs index. "pyyaml>=6.0", ] @@ -56,6 +57,13 @@ line-length = 88 select = ["E", "F", "W", "I"] ignore = ["E501"] +[tool.pyright] +pythonVersion = "3.12" +include = ["src"] +venvPath = "." +venv = ".venv" +typeCheckingMode = "basic" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/mcp_server_appwrite/auth.py b/src/mcp_server_appwrite/auth.py index 1f78c21..c77c2c9 100644 --- a/src/mcp_server_appwrite/auth.py +++ b/src/mcp_server_appwrite/auth.py @@ -19,16 +19,19 @@ import time from urllib.parse import urlsplit, urlunsplit -import anyio import httpx import jwt +from anyio import to_thread from jwt import PyJWKClient from mcp.server.auth.provider import AccessToken, TokenVerifier from . import telemetry - -DEFAULT_ENDPOINT = "https://cloud.appwrite.io/v1" -DEFAULT_PROJECT_ID = "console" +from .constants import ( + DEFAULT_ENDPOINT, + DEFAULT_PROJECT_ID, + DISCOVERY_TTL_SECONDS, + PREFERRED_SCOPES, +) def _log(message: str) -> None: @@ -69,10 +72,30 @@ def resource_metadata_url() -> str: return urlunsplit((parts.scheme, parts.netloc, path, "", "")) -# Cache of scopes_supported, keyed by served project id (process lifetime; the -# project OAuth config is effectively static). Failed lookups raise and are not -# cached, so they retry. -_discovery_cache: dict[str, dict] = {} +def preferred_scopes() -> list[str]: + override = os.getenv("MCP_OAUTH_SCOPES", "").split() + return override or list(PREFERRED_SCOPES) + + +# Discovery cache keyed by served project id: (monotonic fetch time, document). +# Entries are refreshed after a TTL so authorization-server changes (issuer host, +# scope model) propagate without a redeploy; if a refresh fails, the stale copy +# keeps serving so an authorization-server blip doesn't take the MCP down. +_discovery_cache: dict[str, tuple[float, dict]] = {} + + +def _cached_discovery(project_id: str, *, allow_stale: bool = False) -> dict | None: + entry = _discovery_cache.get(project_id) + if entry is None: + return None + fetched_at, document = entry + if allow_stale or time.monotonic() - fetched_at < DISCOVERY_TTL_SECONDS: + return document + return None + + +def _store_discovery(project_id: str, document: dict) -> None: + _discovery_cache[project_id] = (time.monotonic(), document) def discovery_url() -> str: @@ -91,47 +114,68 @@ def _validate_discovery(doc: dict, url: str) -> dict: async def authorization_server_metadata() -> dict: project_id = configured_project_id() - cached = _discovery_cache.get(project_id) + cached = _cached_discovery(project_id) if cached is not None: return cached url = discovery_url() - async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client: - resp = await client.get(url) - resp.raise_for_status() - metadata = _validate_discovery(resp.json(), url) - - _discovery_cache[project_id] = metadata + try: + async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client: + resp = await client.get(url) + resp.raise_for_status() + metadata = _validate_discovery(resp.json(), url) + except Exception as exc: + stale = _cached_discovery(project_id, allow_stale=True) + if stale is not None: + _log(f"Discovery refresh failed ({exc}); serving stale metadata.") + return stale + raise + + _store_discovery(project_id, metadata) return metadata def authorization_server_metadata_sync() -> dict: project_id = configured_project_id() - cached = _discovery_cache.get(project_id) + cached = _cached_discovery(project_id) if cached is not None: return cached url = discovery_url() - resp = httpx.get(url, timeout=10.0, follow_redirects=True) - resp.raise_for_status() - metadata = _validate_discovery(resp.json(), url) - _discovery_cache[project_id] = metadata + try: + resp = httpx.get(url, timeout=10.0, follow_redirects=True) + resp.raise_for_status() + metadata = _validate_discovery(resp.json(), url) + except Exception as exc: + stale = _cached_discovery(project_id, allow_stale=True) + if stale is not None: + _log(f"Discovery refresh failed ({exc}); serving stale metadata.") + return stale + raise + + _store_discovery(project_id, metadata) return metadata -async def supported_scopes() -> list[str]: - """Scopes advertised in the protected-resource metadata, sourced live from the - served project's authorization-server discovery (`scopes_supported`). This is - exactly the set the project's OAuth server will grant, so it never drifts from - the tool surface. Raises if discovery is unreachable or malformed (the - authorization server is the same Appwrite deployment this MCP depends on).""" - metadata = await authorization_server_metadata() - scopes = metadata.get("scopes_supported") - if not isinstance(scopes, list): +def _advertised_scopes(metadata: dict) -> list[str]: + """The scope set to advertise: the preferred scopes intersected with the + authorization server's live ``scopes_supported`` (so a renamed/removed scope + is never advertised). Falls back to mirroring the full discovery list when + none of the preferred scopes exist — e.g. a self-hosted project with a + custom, compact scope catalog.""" + discovered = metadata.get("scopes_supported") + if not isinstance(discovered, list): raise ValueError( f"authorization server discovery missing scopes_supported: {discovery_url()}" ) - return scopes + scopes = [scope for scope in preferred_scopes() if scope in discovered] + if scopes: + return scopes + _log( + "None of the preferred scopes are in the authorization server's " + "scopes_supported; advertising the full discovered list." + ) + return discovered def build_resource_metadata(scopes: list[str], authorization_servers=None) -> dict: @@ -145,14 +189,10 @@ def build_resource_metadata(scopes: list[str], authorization_servers=None) -> di async def protected_resource_metadata() -> dict: - """RFC 9728 Protected Resource Metadata, with scopes sourced from AS discovery.""" + """RFC 9728 Protected Resource Metadata, with scopes validated against AS + discovery.""" metadata = await authorization_server_metadata() - scopes = metadata.get("scopes_supported") - if not isinstance(scopes, list): - raise ValueError( - f"authorization server discovery missing scopes_supported: {discovery_url()}" - ) - return build_resource_metadata(scopes, [metadata["issuer"]]) + return build_resource_metadata(_advertised_scopes(metadata), [metadata["issuer"]]) def project_id_from_issuer(iss: str | None) -> str | None: @@ -286,7 +326,7 @@ def _audience_ok(self, aud, expected_resource: str) -> bool: async def verify_token(self, token: str) -> AccessToken | None: start = time.monotonic() - access_token = await anyio.to_thread.run_sync(self._verify_sync, token) + access_token = await to_thread.run_sync(self._verify_sync, token) duration = time.monotonic() - start if access_token is None: # The specific rejection reason was already counted in _verify_sync; diff --git a/src/mcp_server_appwrite/constants.py b/src/mcp_server_appwrite/constants.py new file mode 100644 index 0000000..3700dbd --- /dev/null +++ b/src/mcp_server_appwrite/constants.py @@ -0,0 +1,144 @@ +"""Single home for the package's constants, grouped by the module that uses them.""" + +from __future__ import annotations + +from pathlib import Path + +from appwrite.models.bucket import Bucket +from appwrite.models.database import Database +from appwrite.models.function import Function +from appwrite.models.message import Message +from appwrite.models.site import Site +from appwrite.models.team import Team +from appwrite.models.user import User + +# --- server --------------------------------------------------------------- + +SERVER_VERSION = "0.8.1" + +DEFAULT_ENDPOINT = "https://cloud.appwrite.io/v1" +DEFAULT_TRANSPORT = "stdio" +TRANSPORTS = {"stdio", "http"} +VALIDATION_SERVICE_ORDER = ( + "tables_db", + "users", + "teams", + "functions", + "sites", + "storage", + "messaging", + "locale", + "avatars", +) + +# Service modules in the Appwrite SDK to skip (none by default — every service the +# installed SDK ships is exposed). Add a module name here to hide a service. +EXCLUDED_SERVICES: frozenset[str] = frozenset() + +MAX_FETCH_BYTES = 25 * 1024 * 1024 # 25 MB cap on server-fetched files +MAX_INLINE_BYTES = 256 * 1024 # 256 KB cap on decoded inline content +FETCH_TIMEOUT_SECONDS = 30.0 +FETCH_MAX_REDIRECTS = 5 + +HOSTED_PATH_GUIDANCE = ( + "The hosted Appwrite MCP server cannot read local file paths. For '{param}', pass a " + 'public URL as {{"url": "https://..."}} (preferred), or a small file inline as ' + '{{"filename": "...", "content": "", "encoding": "base64"}}.' +) + +# --- auth ----------------------------------------------------------------- + +DEFAULT_PROJECT_ID = "console" + +PREFERRED_SCOPES = [ + "openid", + "profile", + "email", + "all", + "project:all", + "organization:all", +] + +DISCOVERY_TTL_SECONDS = 300.0 + +# --- http_app ------------------------------------------------------------- + +CORS_HEADERS = { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS", + "Access-Control-Allow-Headers": "Authorization, Content-Type, Mcp-Session-Id, Mcp-Protocol-Version", + "Access-Control-Expose-Headers": "Mcp-Session-Id, WWW-Authenticate", +} + +# --- operator ------------------------------------------------------------- + +SEARCH_LIMIT = 8 +PREVIEW_THRESHOLD = 800 +RESULT_STORE_SIZE = 50 +CATALOG_URI = "appwrite://operator/catalog" +RESULT_URI_TEMPLATE = "appwrite://operator/results/{result_id}" +VERBS = {"list", "get", "create", "update", "delete"} +READ_VERBS = {"list", "get"} +CREATE_HINTS = {"add", "build", "create", "insert", "make", "new", "provision"} +UPDATE_HINTS = {"change", "edit", "modify", "rename", "set", "update"} +DELETE_HINTS = {"delete", "destroy", "drop", "remove"} +READ_HINTS = {"fetch", "find", "get", "list", "read", "search", "show", "view"} + +# --- docs_search ---------------------------------------------------------- + +DOCS_TOOL_NAME = "appwrite_search_docs" +EMBED_MODEL = "text-embedding-3-small" +DOCS_DEFAULT_LIMIT = 5 +DOCS_MAX_LIMIT = 10 +DOCS_DEFAULT_MIN_SCORE = 0.25 +DOCS_MIN_QUERY_LENGTH = 3 + +DATA_DIR = Path(__file__).parent / "data" +VECTORS_FILE = "docs_index.npz" +META_FILE = "docs_index_meta.json" + +# --- context -------------------------------------------------------------- + +SERVICE_PROBES = { + "tablesdb": { + "path": "/tablesdb", + "items_key": "databases", + "model": Database, + }, + "users": { + "path": "/users", + "items_key": "users", + "model": User, + }, + "storage": { + "path": "/storage/buckets", + "items_key": "buckets", + "model": Bucket, + }, + "functions": { + "path": "/functions", + "items_key": "functions", + "model": Function, + }, + "sites": { + "path": "/sites", + "items_key": "sites", + "model": Site, + }, + "messaging": { + "path": "/messaging/messages", + "items_key": "messages", + "model": Message, + }, + "teams": { + "path": "/teams", + "items_key": "teams", + "model": Team, + }, +} + +REDACTED_KEYS = {"password", "secret", "key", "token", "otp", "cookie", "session"} + +# --- telemetry ------------------------------------------------------------ + +ACTIVE_WINDOW_SECONDS = 300.0 # rolling window for "active users/clients" gauges diff --git a/src/mcp_server_appwrite/context.py b/src/mcp_server_appwrite/context.py index b48afec..f021803 100644 --- a/src/mcp_server_appwrite/context.py +++ b/src/mcp_server_appwrite/context.py @@ -5,57 +5,14 @@ from appwrite.client import Client from appwrite.exception import AppwriteException -from appwrite.models.bucket import Bucket -from appwrite.models.database import Database -from appwrite.models.function import Function -from appwrite.models.message import Message from appwrite.models.project import Project -from appwrite.models.site import Site from appwrite.models.team import Team from appwrite.models.user import User from appwrite.query import Query -ContextClientFactory = Callable[[str | None, str | None], Client] +from .constants import REDACTED_KEYS, SERVICE_PROBES -SERVICE_PROBES = { - "tablesdb": { - "path": "/tablesdb", - "items_key": "databases", - "model": Database, - }, - "users": { - "path": "/users", - "items_key": "users", - "model": User, - }, - "storage": { - "path": "/storage/buckets", - "items_key": "buckets", - "model": Bucket, - }, - "functions": { - "path": "/functions", - "items_key": "functions", - "model": Function, - }, - "sites": { - "path": "/sites", - "items_key": "sites", - "model": Site, - }, - "messaging": { - "path": "/messaging/messages", - "items_key": "messages", - "model": Message, - }, - "teams": { - "path": "/teams", - "items_key": "teams", - "model": Team, - }, -} - -REDACTED_KEYS = {"password", "secret", "key", "token", "otp", "cookie", "session"} +ContextClientFactory = Callable[[str | None, str | None], Client] def get_appwrite_context( diff --git a/src/mcp_server_appwrite/docs_search.py b/src/mcp_server_appwrite/docs_search.py index d66cfc4..dade65b 100644 --- a/src/mcp_server_appwrite/docs_search.py +++ b/src/mcp_server_appwrite/docs_search.py @@ -22,20 +22,20 @@ import mcp.types as types from . import telemetry +from .constants import ( + DATA_DIR, + DOCS_DEFAULT_LIMIT, + DOCS_DEFAULT_MIN_SCORE, + DOCS_MAX_LIMIT, + DOCS_MIN_QUERY_LENGTH, + DOCS_TOOL_NAME, + EMBED_MODEL, + META_FILE, + VECTORS_FILE, +) ToolContent = types.TextContent | types.ImageContent | types.EmbeddedResource -TOOL_NAME = "appwrite_search_docs" -EMBED_MODEL = "text-embedding-3-small" -DEFAULT_LIMIT = 5 -MAX_LIMIT = 10 -DEFAULT_MIN_SCORE = 0.25 -MIN_QUERY_LENGTH = 3 - -DATA_DIR = Path(__file__).parent / "data" -VECTORS_FILE = "docs_index.npz" -META_FILE = "docs_index_meta.json" - # An embedder maps a query string to its embedding vector. Embedder = Callable[[str], list[float]] @@ -63,7 +63,7 @@ def _clamp_limit(value: Any, default: int) -> int: limit = int(value) if limit < 1: raise ValueError("limit must be at least 1.") - return min(limit, MAX_LIMIT) + return min(limit, DOCS_MAX_LIMIT) class DocsSearch: @@ -80,14 +80,14 @@ def __init__( data_dir: Path | None = None, embedder: Embedder | None = None, min_score: float | None = None, - default_limit: int = DEFAULT_LIMIT, + default_limit: int = DOCS_DEFAULT_LIMIT, ): self._data_dir = data_dir or DATA_DIR self._embedder = embedder if embedder is not None else _default_embedder() self._min_score = ( min_score if min_score is not None - else float(os.getenv("DOCS_SEARCH_MIN_SCORE", DEFAULT_MIN_SCORE)) + else float(os.getenv("DOCS_SEARCH_MIN_SCORE", DOCS_DEFAULT_MIN_SCORE)) ) self._default_limit = int(os.getenv("DOCS_SEARCH_LIMIT", default_limit)) self._vectors = None # np.ndarray [N, D], L2-normalized @@ -118,7 +118,7 @@ def _load_index(self) -> bool: def get_tool(self) -> types.Tool: return types.Tool( - name=TOOL_NAME, + name=DOCS_TOOL_NAME, description=( "Search the Appwrite documentation with a natural-language query and " "return the most relevant documentation pages with their full content. " @@ -136,7 +136,7 @@ def get_tool(self) -> types.Tool: "limit": { "type": "integer", "minimum": 1, - "maximum": MAX_LIMIT, + "maximum": DOCS_MAX_LIMIT, "description": f"Maximum number of pages to return. Defaults to {self._default_limit}.", }, }, @@ -148,9 +148,9 @@ def get_tool(self) -> types.Tool: def search(self, arguments: dict[str, Any] | None) -> list[ToolContent]: arguments = arguments or {} query = str(arguments.get("query", "")).strip() - if len(query) < MIN_QUERY_LENGTH: + if len(query) < DOCS_MIN_QUERY_LENGTH: raise ValueError( - f"query must be at least {MIN_QUERY_LENGTH} characters long." + f"query must be at least {DOCS_MIN_QUERY_LENGTH} characters long." ) if not self.available: raise RuntimeError( @@ -185,6 +185,12 @@ def _rank(self, query: str, limit: int) -> tuple[list[dict[str, Any]], float]: """Return the ranked pages and the embedding call's duration in seconds.""" import numpy as np + # _rank is only reachable when the index is loaded and an embedder exists + # (guarded by `available`); narrow the optionals for the type checker. + assert self._embedder is not None + assert self._vectors is not None + assert self._chunk_page is not None + embed_start = time.monotonic() embedding = np.asarray(self._embedder(query), dtype=np.float32) embedding_duration_s = time.monotonic() - embed_start diff --git a/src/mcp_server_appwrite/http_app.py b/src/mcp_server_appwrite/http_app.py index c01e12a..0c40453 100644 --- a/src/mcp_server_appwrite/http_app.py +++ b/src/mcp_server_appwrite/http_app.py @@ -29,7 +29,7 @@ from starlette.requests import Request from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Route -from starlette.types import Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send from . import telemetry from .auth import ( @@ -37,20 +37,13 @@ protected_resource_metadata, resource_metadata_url, ) +from .constants import CORS_HEADERS, SERVER_VERSION from .server import ( - SERVER_VERSION, build_catalog_tools_manager, build_mcp_server, build_operator, ) -_CORS_HEADERS = { - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS", - "Access-Control-Allow-Headers": "Authorization, Content-Type, Mcp-Session-Id, Mcp-Protocol-Version", - "Access-Control-Expose-Headers": "Mcp-Session-Id, WWW-Authenticate", -} - class HealthzAccessLogFilter(logging.Filter): """Drop noisy load-balancer health probes from uvicorn access logs.""" @@ -98,7 +91,7 @@ class RequireBearer: against the token's granted scopes), so the gate only requires a valid token. """ - def __init__(self, app: object) -> None: + def __init__(self, app: ASGIApp) -> None: self.app = app async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -123,7 +116,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) async def _preflight(self, send: Send) -> None: - headers = [(k.lower().encode(), v.encode()) for k, v in _CORS_HEADERS.items()] + headers = [(k.lower().encode(), v.encode()) for k, v in CORS_HEADERS.items()] await send({"type": "http.response.start", "status": 204, "headers": headers}) await send({"type": "http.response.body", "body": b""}) @@ -137,7 +130,7 @@ def _has_authorization_header(scope: Scope) -> bool: async def protected_resource_metadata_endpoint(request: Request) -> JSONResponse: metadata = await protected_resource_metadata() - return JSONResponse(metadata, headers=_CORS_HEADERS) + return JSONResponse(metadata, headers=CORS_HEADERS) async def health_endpoint(request: Request) -> PlainTextResponse: diff --git a/src/mcp_server_appwrite/operator.py b/src/mcp_server_appwrite/operator.py index 4cc355e..681c059 100644 --- a/src/mcp_server_appwrite/operator.py +++ b/src/mcp_server_appwrite/operator.py @@ -11,23 +11,25 @@ import mcp.types as types from mcp.server.lowlevel.helper_types import ReadResourceContents +from pydantic import AnyUrl from . import telemetry +from .constants import ( + CATALOG_URI, + CREATE_HINTS, + DELETE_HINTS, + PREVIEW_THRESHOLD, + READ_HINTS, + READ_VERBS, + RESULT_STORE_SIZE, + RESULT_URI_TEMPLATE, + SEARCH_LIMIT, + UPDATE_HINTS, + VERBS, +) from .docs_search import DocsSearch from .tool_manager import ToolManager -SEARCH_LIMIT = 8 -PREVIEW_THRESHOLD = 800 -RESULT_STORE_SIZE = 50 -CATALOG_URI = "appwrite://operator/catalog" -RESULT_URI_TEMPLATE = "appwrite://operator/results/{result_id}" -VERBS = {"list", "get", "create", "update", "delete"} -READ_VERBS = {"list", "get"} -CREATE_HINTS = {"add", "build", "create", "insert", "make", "new", "provision"} -UPDATE_HINTS = {"change", "edit", "modify", "rename", "set", "update"} -DELETE_HINTS = {"delete", "destroy", "drop", "remove"} -READ_HINTS = {"fetch", "find", "get", "list", "read", "search", "show", "view"} - ToolContent = types.TextContent | types.ImageContent | types.EmbeddedResource # (tool_name, arguments, project_id, organization_id) -> content ToolExecutor = Callable[ @@ -300,7 +302,7 @@ def _get_context(self, arguments: dict[str, Any]) -> list[ToolContent]: def list_resources(self) -> list[types.Resource]: resources = [ types.Resource( - uri=CATALOG_URI, + uri=AnyUrl(CATALOG_URI), name="Appwrite Hidden Tool Catalog", description="Full internal Appwrite tool catalog used by the Appwrite operator surface.", mimeType="application/json", @@ -311,7 +313,7 @@ def list_resources(self) -> list[types.Resource]: for stored_result in self._result_store.list(): resources.append( types.Resource( - uri=stored_result.uri, + uri=AnyUrl(stored_result.uri), name=f"{stored_result.tool_name} result", description="Stored Appwrite tool result. Read this resource to inspect the full output.", mimeType="application/json", diff --git a/src/mcp_server_appwrite/server.py b/src/mcp_server_appwrite/server.py index 9949bb2..255766a 100644 --- a/src/mcp_server_appwrite/server.py +++ b/src/mcp_server_appwrite/server.py @@ -37,35 +37,29 @@ from mcp.server.auth.middleware.auth_context import get_access_token from mcp.server.lowlevel.helper_types import ReadResourceContents from mcp.server.models import InitializationOptions +from pydantic import AnyUrl from . import telemetry +from .constants import ( + CATALOG_URI, + DEFAULT_ENDPOINT, + DEFAULT_TRANSPORT, + EXCLUDED_SERVICES, + FETCH_MAX_REDIRECTS, + FETCH_TIMEOUT_SECONDS, + HOSTED_PATH_GUIDANCE, + MAX_FETCH_BYTES, + MAX_INLINE_BYTES, + SERVER_VERSION, + TRANSPORTS, + VALIDATION_SERVICE_ORDER, +) from .context import _normalize_sample_limit, get_appwrite_context from .docs_search import DocsSearch -from .operator import CATALOG_URI, Operator, _parse_tool_name +from .operator import Operator, _parse_tool_name from .service import Service from .tool_manager import ToolManager -SERVER_VERSION = "0.8.1" - -DEFAULT_ENDPOINT = "https://cloud.appwrite.io/v1" -DEFAULT_TRANSPORT = "stdio" -TRANSPORTS = {"stdio", "http"} -VALIDATION_SERVICE_ORDER = ( - "tables_db", - "users", - "teams", - "functions", - "sites", - "storage", - "messaging", - "locale", - "avatars", -) - -# Service modules in the Appwrite SDK to skip (none by default — every service the -# installed SDK ships is exposed). Add a module name here to hide a service. -EXCLUDED_SERVICES: frozenset[str] = frozenset() - def _discover_service_classes() -> dict[str, type]: """Discover every Appwrite SDK service class, keyed by its module name @@ -348,17 +342,6 @@ def _coerce_enum(enum_type: type[Enum], value: Any, param_name: str) -> Any: # paths are rejected with guidance; uploads come via URL fetch or inline bytes. _UPLOAD_TRANSPORT: str = "stdio" -_MAX_FETCH_BYTES = 25 * 1024 * 1024 # 25 MB cap on server-fetched files -_MAX_INLINE_BYTES = 256 * 1024 # 256 KB cap on decoded inline content -_FETCH_TIMEOUT_SECONDS = 30.0 -_FETCH_MAX_REDIRECTS = 5 - -_HOSTED_PATH_GUIDANCE = ( - "The hosted Appwrite MCP server cannot read local file paths. For '{param}', pass a " - 'public URL as {{"url": "https://..."}} (preferred), or a small file inline as ' - '{{"filename": "...", "content": "", "encoding": "base64"}}.' -) - def _configure_uploads(transport: str) -> None: """Set the upload mode for this server process. Called once from build_mcp_server.""" @@ -434,9 +417,9 @@ def _fetch_input_file(url: str, param_name: str) -> InputFile: _validate_fetch_url(url) try: with httpx.Client( - timeout=_FETCH_TIMEOUT_SECONDS, + timeout=FETCH_TIMEOUT_SECONDS, follow_redirects=True, - max_redirects=_FETCH_MAX_REDIRECTS, + max_redirects=FETCH_MAX_REDIRECTS, limits=httpx.Limits(max_connections=1), ) as client: with client.stream("GET", url) as resp: @@ -446,22 +429,22 @@ def _fetch_input_file(url: str, param_name: str) -> InputFile: declared = resp.headers.get("content-length") if declared is not None and declared.isdigit(): - if int(declared) > _MAX_FETCH_BYTES: + if int(declared) > MAX_FETCH_BYTES: telemetry.record_upload_error("too_large") raise ValueError( f"File at URL for '{param_name}' is too large " - f"({declared} bytes); max is {_MAX_FETCH_BYTES} bytes." + f"({declared} bytes); max is {MAX_FETCH_BYTES} bytes." ) chunks: list[bytes] = [] total = 0 for chunk in resp.iter_bytes(): total += len(chunk) - if total > _MAX_FETCH_BYTES: + if total > MAX_FETCH_BYTES: telemetry.record_upload_error("too_large") raise ValueError( f"File at URL for '{param_name}' exceeds the max of " - f"{_MAX_FETCH_BYTES} bytes." + f"{MAX_FETCH_BYTES} bytes." ) chunks.append(chunk) @@ -484,6 +467,9 @@ def _fetch_input_file(url: str, param_name: str) -> InputFile: def _coerce_inline_content(value: Mapping, param_name: str) -> InputFile: filename = value.get("filename") content = value.get("content") + if content is None: + telemetry.record_upload_error("decode") + raise ValueError(f"Missing inline 'content' for '{param_name}'.") encoding = str(value.get("encoding", "utf-8")).lower() if encoding == "base64": try: @@ -499,11 +485,11 @@ def _coerce_inline_content(value: Mapping, param_name: str) -> InputFile: f"Invalid encoding for '{param_name}'. Expected 'utf-8' or 'base64'." ) - if len(data) > _MAX_INLINE_BYTES: + if len(data) > MAX_INLINE_BYTES: telemetry.record_upload_error("too_large") raise ValueError( f"Inline content for '{param_name}' is too large " - f"({len(data)} bytes, max {_MAX_INLINE_BYTES}). For larger files pass " + f"({len(data)} bytes, max {MAX_INLINE_BYTES}). For larger files pass " '{"url": "https://..."} so the server can download it directly.' ) @@ -514,7 +500,7 @@ def _coerce_inline_content(value: Mapping, param_name: str) -> InputFile: def _coerce_path(path: str, param_name: str) -> InputFile: if _UPLOAD_TRANSPORT != "stdio": telemetry.record_upload_error("path_unsupported") - raise ValueError(_HOSTED_PATH_GUIDANCE.format(param=param_name)) + raise ValueError(HOSTED_PATH_GUIDANCE.format(param=param_name)) return InputFile.from_path(path) @@ -594,9 +580,8 @@ def _expected_argument_names(tool_info: dict) -> set[str]: if parameter_names: return parameter_names - input_schema = ( - tool_info.get("definition").inputSchema if tool_info.get("definition") else None - ) + definition = tool_info.get("definition") + input_schema = definition.inputSchema if definition is not None else None properties = ( input_schema.get("properties", {}) if isinstance(input_schema, dict) else {} ) @@ -688,9 +673,8 @@ def _validate_argument_keys( def _prepare_arguments(tool_info: dict, arguments: dict[str, Any]) -> dict[str, Any]: prepared_arguments = _normalize_argument_keys(tool_info, arguments) - tool_name = ( - tool_info.get("definition").name if tool_info.get("definition") else "tool" - ) + definition = tool_info.get("definition") + tool_name = definition.name if definition is not None else "tool" _validate_argument_keys(tool_name, tool_info, prepared_arguments) for param_name, param_type in tool_info.get("parameter_types", {}).items(): if param_name not in prepared_arguments: @@ -821,7 +805,7 @@ def _guess_mime_type(data: bytes, tool_name: str, arguments: dict[str, Any]) -> def _format_binary_result( tool_name: str, data: bytes, arguments: dict[str, Any] -) -> list[types.ImageContent | types.EmbeddedResource]: +) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: mime_type = _guess_mime_type(data, tool_name, arguments) encoded = base64.b64encode(data).decode("ascii") if mime_type.startswith("image/"): @@ -831,7 +815,7 @@ def _format_binary_result( types.EmbeddedResource( type="resource", resource=types.BlobResourceContents( - uri=f"appwrite://tool/{tool_name}", + uri=AnyUrl(f"appwrite://tool/{tool_name}"), blob=encoded, mimeType=mime_type, ), diff --git a/src/mcp_server_appwrite/service.py b/src/mcp_server_appwrite/service.py index d1de30d..cfcf8b1 100644 --- a/src/mcp_server_appwrite/service.py +++ b/src/mcp_server_appwrite/service.py @@ -115,7 +115,7 @@ def python_type_to_json_schema(self, py_type: Any) -> dict: if inspect.isclass(py_type) and issubclass(py_type, Enum): enum_values = [member.value for member in py_type] value_types = {type(value) for value in enum_values} - schema = {"enum": enum_values} + schema: dict[str, Any] = {"enum": enum_values} if len(value_types) == 1 and next(iter(value_types)) in type_mapping: schema["type"] = type_mapping[next(iter(value_types))] return schema @@ -181,7 +181,7 @@ def list_tools(self) -> Dict[str, Dict]: for doc_param in docstring.params: if doc_param.arg_name == param_name: properties[param_name]["description"] = self._clean_description( - doc_param.description + doc_param.description or "" ) if param.default is param.empty: diff --git a/src/mcp_server_appwrite/telemetry.py b/src/mcp_server_appwrite/telemetry.py index f8037ff..928dd85 100644 --- a/src/mcp_server_appwrite/telemetry.py +++ b/src/mcp_server_appwrite/telemetry.py @@ -37,7 +37,7 @@ import time from typing import Any, Iterable -_ACTIVE_WINDOW_SECONDS = 300.0 # rolling window for "active users/clients" gauges +from .constants import ACTIVE_WINDOW_SECONDS _enabled = False _lock = threading.Lock() @@ -46,7 +46,7 @@ _instruments: dict[str, Any] = {} # Rolling TTL sets for the active-user/active-client observable gauges. Keys expire -# after _ACTIVE_WINDOW_SECONDS so the gauges reflect a recent window, not all time. +# after ACTIVE_WINDOW_SECONDS so the gauges reflect a recent window, not all time. _active_users: dict[str, float] = {} _active_clients: dict[str, float] = {} # key: client name -> last-seen monotonic-ish ts _active_lock = threading.Lock() @@ -302,7 +302,7 @@ def _observe_active_clients(_options: Any) -> Iterable[Any]: def _touch_user(subject: str | None) -> None: if not subject: return - expiry = time.monotonic() + _ACTIVE_WINDOW_SECONDS + expiry = time.monotonic() + ACTIVE_WINDOW_SECONDS with _active_lock: _active_users[subject] = expiry @@ -313,7 +313,7 @@ def _touch_client(client_name: str | None, subject: str | None) -> None: if not client_name: return key = f"{client_name}\x00{subject or ''}" - expiry = time.monotonic() + _ACTIVE_WINDOW_SECONDS + expiry = time.monotonic() + ACTIVE_WINDOW_SECONDS with _active_lock: _active_clients[key] = expiry diff --git a/src/mcp_server_appwrite/tool_manager.py b/src/mcp_server_appwrite/tool_manager.py index 9b6cbfe..c9d48a5 100644 --- a/src/mcp_server_appwrite/tool_manager.py +++ b/src/mcp_server_appwrite/tool_manager.py @@ -1,4 +1,4 @@ -from typing import Dict, List +from __future__ import annotations from mcp.types import Tool @@ -7,18 +7,18 @@ class ToolManager: def __init__(self): - self.services: List[Service] = [] - self.tools_registry = {} + self.services: list[Service] = [] + self.tools_registry: dict[str, dict] = {} def register_service(self, service: Service): """Register a new service and its tools""" self.services.append(service) self.tools_registry.update(service.list_tools()) - def get_all_tools(self) -> List[Tool]: + def get_all_tools(self) -> list[Tool]: """Get all tool definitions""" return [tool_info["definition"] for tool_info in self.tools_registry.values()] - def get_tool(self, name: str) -> Dict: - """Get a specific tool by name""" + def get_tool(self, name: str) -> dict | None: + """Get a specific tool by name, or None if unregistered""" return self.tools_registry.get(name) diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index df52d1c..439ce33 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -13,6 +13,20 @@ "APPWRITE_PROJECT_ID": "console", } +# The issuer the authorization server *discovers to* is the regional host, which +# deliberately differs from the configured APPWRITE_ENDPOINT host — production +# Cloud discovery returns e.g. fra.cloud.appwrite.io while the MCP is configured +# with cloud.appwrite.io. +DISCOVERED_ISSUER = "https://fra.cloud.appwrite.io/v1/oauth2/console" + + +def discovery_doc(scopes: list[str], issuer: str = DISCOVERED_ISSUER) -> dict: + return { + "issuer": issuer, + "jwks_uri": f"{issuer}/.well-known/jwks.json", + "scopes_supported": scopes, + } + class AuthHelperTests(unittest.TestCase): def setUp(self): @@ -40,20 +54,112 @@ def test_build_resource_metadata_shape(self): self.assertEqual(meta["bearer_methods_supported"], ["header"]) self.assertEqual(meta["scopes_supported"], ["users.read", "teams.read"]) - def test_supported_scopes_uses_cache_without_network(self): + def test_advertised_scopes_use_cache_without_network(self): pid = auth.configured_project_id() - auth._discovery_cache[pid] = { - "issuer": "https://fra.cloud.appwrite.io/v1/oauth2/console", - "jwks_uri": "https://fra.cloud.appwrite.io/v1/oauth2/console/.well-known/jwks.json", - "scopes_supported": ["rows.read", "rows.write"], - } + auth._store_discovery(pid, discovery_doc(["rows.read", "rows.write"])) try: - scopes = asyncio.run(auth.supported_scopes()) + scopes = asyncio.run(auth.protected_resource_metadata())["scopes_supported"] finally: auth._discovery_cache.pop(pid, None) + # None of the preferred scopes exist, so the full discovered list is + # mirrored (custom scope catalogs on self-hosted projects). self.assertEqual(scopes, ["rows.read", "rows.write"]) - def test_supported_scopes_reads_dcr_enabled_discovery(self): + def test_advertised_scopes_prefer_curated_subset(self): + # The Cloud console authorization server advertises a very large + # fine-grained scope catalog; the MCP must advertise only the compact + # preferred set (clients request every advertised scope, and the + # authorize endpoint caps the scope parameter length). + pid = auth.configured_project_id() + auth._store_discovery( + pid, + discovery_doc( + [ + "openid", + "profile", + "email", + "phone", + "all", + "project:all", + "organization:all", + "project:users.read", + "project:users.write", + "organization:projects.read", + ] + ), + ) + try: + scopes = asyncio.run(auth.protected_resource_metadata())["scopes_supported"] + finally: + auth._discovery_cache.pop(pid, None) + self.assertEqual( + scopes, + ["openid", "profile", "email", "all", "project:all", "organization:all"], + ) + + def test_advertised_scopes_drop_preferred_scopes_missing_from_discovery(self): + pid = auth.configured_project_id() + auth._store_discovery(pid, discovery_doc(["openid", "email", "all"])) + try: + scopes = asyncio.run(auth.protected_resource_metadata())["scopes_supported"] + finally: + auth._discovery_cache.pop(pid, None) + self.assertEqual(scopes, ["openid", "email", "all"]) + + def test_advertised_scopes_env_override(self): + pid = auth.configured_project_id() + auth._store_discovery( + pid, discovery_doc(["openid", "email", "all", "project:all"]) + ) + try: + with mock.patch.dict( + os.environ, {"MCP_OAUTH_SCOPES": "openid project:all"} + ): + scopes = asyncio.run(auth.protected_resource_metadata())[ + "scopes_supported" + ] + finally: + auth._discovery_cache.pop(pid, None) + self.assertEqual(scopes, ["openid", "project:all"]) + + def test_discovery_cache_expires_after_ttl(self): + pid = auth.configured_project_id() + auth._store_discovery(pid, {"issuer": "x", "jwks_uri": "y"}) + try: + self.assertIsNotNone(auth._cached_discovery(pid)) + # Age the entry past the TTL. + fetched_at, doc = auth._discovery_cache[pid] + auth._discovery_cache[pid] = ( + fetched_at - auth.DISCOVERY_TTL_SECONDS - 1, + doc, + ) + self.assertIsNone(auth._cached_discovery(pid)) + # A stale entry is still reachable as a fallback for failed refreshes. + self.assertIsNotNone(auth._cached_discovery(pid, allow_stale=True)) + finally: + auth._discovery_cache.pop(pid, None) + + def test_stale_discovery_served_when_refresh_fails(self): + pid = auth.configured_project_id() + stale_doc = discovery_doc(["openid"]) + auth._store_discovery(pid, stale_doc) + fetched_at, doc = auth._discovery_cache[pid] + auth._discovery_cache[pid] = ( + fetched_at - auth.DISCOVERY_TTL_SECONDS - 1, + doc, + ) + + def _boom(*args, **kwargs): + raise RuntimeError("network down") + + try: + with mock.patch.object(auth.httpx, "get", _boom): + metadata = auth.authorization_server_metadata_sync() + finally: + auth._discovery_cache.pop(pid, None) + self.assertEqual(metadata, stale_doc) + + def test_metadata_reads_dcr_enabled_discovery(self): # The authorization server's discovery document now also advertises # `registration_endpoint` (RFC 7591). Sourcing scopes must keep working # against that document — the MCP points clients at this same AS, and @@ -90,7 +196,9 @@ async def get(self, url): with mock.patch.object(auth.httpx, "AsyncClient", _FakeAsyncClient): try: - scopes = asyncio.run(auth.supported_scopes()) + scopes = asyncio.run(auth.protected_resource_metadata())[ + "scopes_supported" + ] finally: auth._discovery_cache.pop("console", None) @@ -108,7 +216,7 @@ async def get(self, url): f"{discovery['issuer']}/register", ) - def test_supported_scopes_raises_when_discovery_unreachable(self): + def test_metadata_raises_when_discovery_unreachable(self): # Point discovery at an unroutable address so the fetch fails fast. with mock.patch.dict( os.environ, @@ -118,7 +226,7 @@ def test_supported_scopes_raises_when_discovery_unreachable(self): }, ): with self.assertRaises(Exception): - asyncio.run(auth.supported_scopes()) + asyncio.run(auth.protected_resource_metadata()) self.assertNotIn("unreachableproj", auth._discovery_cache) def test_project_id_from_issuer_accepts_matching_issuer(self): @@ -148,11 +256,7 @@ def test_project_id_from_issuer_rejects_foreign_issuer(self): def test_protected_resource_metadata_uses_discovered_issuer(self): pid = auth.configured_project_id() - auth._discovery_cache[pid] = { - "issuer": "https://fra.cloud.appwrite.io/v1/oauth2/console", - "jwks_uri": "https://fra.cloud.appwrite.io/v1/oauth2/console/.well-known/jwks.json", - "scopes_supported": ["users.read"], - } + auth._store_discovery(pid, discovery_doc(["users.read"])) try: meta = asyncio.run(auth.protected_resource_metadata()) finally: @@ -160,7 +264,7 @@ def test_protected_resource_metadata_uses_discovered_issuer(self): self.assertEqual( meta["authorization_servers"], - ["https://fra.cloud.appwrite.io/v1/oauth2/console"], + [DISCOVERED_ISSUER], ) @@ -200,11 +304,7 @@ def test_verify_rejects_token_for_other_project(self): def test_verify_rejects_token_with_undiscovered_issuer(self): pid = auth.configured_project_id() - auth._discovery_cache[pid] = { - "issuer": "https://fra.cloud.appwrite.io/v1/oauth2/console", - "jwks_uri": "https://fra.cloud.appwrite.io/v1/oauth2/console/.well-known/jwks.json", - "scopes_supported": ["users.read"], - } + auth._store_discovery(pid, discovery_doc(["users.read"])) token = jwt.encode( {"iss": "https://cloud.appwrite.io/v1/oauth2/console"}, "x" * 32, diff --git a/tests/unit/test_docs_search.py b/tests/unit/test_docs_search.py index 21497ae..45e0e30 100644 --- a/tests/unit/test_docs_search.py +++ b/tests/unit/test_docs_search.py @@ -5,7 +5,7 @@ import numpy as np -from mcp_server_appwrite.docs_search import MAX_LIMIT, DocsSearch, _clamp_limit +from mcp_server_appwrite.docs_search import DOCS_MAX_LIMIT, DocsSearch, _clamp_limit def write_index(data_dir: Path) -> None: @@ -117,7 +117,7 @@ def test_no_match_returns_message(self): def test_clamp_limit(self): self.assertEqual(_clamp_limit(None, 5), 5) - self.assertEqual(_clamp_limit(50, 5), MAX_LIMIT) + self.assertEqual(_clamp_limit(50, 5), DOCS_MAX_LIMIT) with self.assertRaises(ValueError): _clamp_limit(0, 5) diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 52a3bd5..d4149a1 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -576,13 +576,13 @@ def test_url_fetch_rejects_non_http_scheme(self): def test_url_fetch_size_cap_via_stream(self): response = _FakeResponse(data=b"0123456789") # 10 bytes, no content-length addr, client = self._patch_fetch(response) - with addr, client, patch.object(server_module, "_MAX_FETCH_BYTES", 4): + with addr, client, patch.object(server_module, "MAX_FETCH_BYTES", 4): with self.assertRaises(ValueError) as ctx: _coerce_argument("file", {"url": "https://example.com/x"}, InputFile) self.assertIn("max", str(ctx.exception).lower()) def test_inline_content_size_cap(self): - with patch.object(server_module, "_MAX_INLINE_BYTES", 4): + with patch.object(server_module, "MAX_INLINE_BYTES", 4): with self.assertRaises(ValueError) as ctx: _coerce_argument( "file", diff --git a/uv.lock b/uv.lock index 5c88af1..d028059 100644 --- a/uv.lock +++ b/uv.lock @@ -669,6 +669,7 @@ integration = [ [package.dev-dependencies] dev = [ { name = "black" }, + { name = "pyright" }, { name = "pyyaml" }, { name = "ruff" }, ] @@ -698,6 +699,7 @@ provides-extras = ["integration"] [package.metadata.requires-dev] dev = [ { name = "black", specifier = ">=25.1.0" }, + { name = "pyright", specifier = ">=1.1.390" }, { name = "pyyaml", specifier = ">=6.0" }, { name = "ruff", specifier = ">=0.10.0" }, ] @@ -720,6 +722,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "nodeenv" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, +] + [[package]] name = "numpy" version = "2.5.0" @@ -1087,6 +1098,19 @@ crypto = [ { name = "cryptography" }, ] +[[package]] +name = "pyright" +version = "1.1.411" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7e/ab/265f7dc69d28113ebba19092e57b075f41543b2ed048429c5f56e2b88eac/pyright-1.1.411.tar.gz", hash = "sha256:d885a0551f2e763b089a02702174e7f4ba77548cddabc972ab86d1f7f1b0f998", size = 4112861, upload-time = "2026-06-25T02:14:06.37Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/49/385be530a6a5b78d1cbcd5c2e38debc8959a2fc6bdb716f4e581002979fc/pyright-1.1.411-py3-none-any.whl", hash = "sha256:dc7c72a8e2700c55baa127554040e067041ea53ccfd50bf96308cc4291c7d5d9", size = 6181526, upload-time = "2026-06-25T02:14:04.691Z" }, +] + [[package]] name = "python-dotenv" version = "1.0.1"