From 6e296c93d2411449b7ae84e16f462d97e07abc27 Mon Sep 17 00:00:00 2001 From: GaoXiang233 <1679562189@qq.com> Date: Sat, 25 Apr 2026 11:38:48 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=A6=96=E6=AC=A1=E6=8F=90=E4=BA=A4?= =?UTF-8?q?=E6=88=91=E7=9A=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastapi_mcp/__init__.py | 37 +- fastapi_mcp/auth/api_key_manager.py | 210 +++++++ fastapi_mcp/auth/rate_limiter.py | 291 +++++++++ fastapi_mcp/auth/user_manager.py | 220 +++++++ fastapi_mcp/monitoring/__init__.py | 3 + fastapi_mcp/monitoring/call_logger.py | 300 +++++++++ fastapi_mcp/panel/__init__.py | 3 + fastapi_mcp/panel/panel.py | 852 ++++++++++++++++++++++++++ fastapi_mcp/server.py | 289 ++++++++- fastapi_mcp/types.py | 198 +++++- 10 files changed, 2382 insertions(+), 21 deletions(-) create mode 100644 fastapi_mcp/auth/api_key_manager.py create mode 100644 fastapi_mcp/auth/rate_limiter.py create mode 100644 fastapi_mcp/auth/user_manager.py create mode 100644 fastapi_mcp/monitoring/__init__.py create mode 100644 fastapi_mcp/monitoring/call_logger.py create mode 100644 fastapi_mcp/panel/__init__.py create mode 100644 fastapi_mcp/panel/panel.py diff --git a/fastapi_mcp/__init__.py b/fastapi_mcp/__init__.py index f748712..aa05e88 100644 --- a/fastapi_mcp/__init__.py +++ b/fastapi_mcp/__init__.py @@ -13,11 +13,46 @@ __version__ = "0.0.0.dev0" # pragma: no cover from .server import FastApiMCP -from .types import AuthConfig, OAuthMetadata +from .types import ( + AuthConfig, + OAuthMetadata, + ExtendedAuthConfig, + PanelConfig, + RateLimitConfig, + SecurityConfig, + UserRole, + Permission, + User, + UserCreate, + UserUpdate, + ApiKeyStatus, + ApiKey, + ApiKeyCreate, + ApiKeyUpdate, + CallStatus, + CallLog, + MonitorData, +) __all__ = [ "FastApiMCP", "AuthConfig", "OAuthMetadata", + "ExtendedAuthConfig", + "PanelConfig", + "RateLimitConfig", + "SecurityConfig", + "UserRole", + "Permission", + "User", + "UserCreate", + "UserUpdate", + "ApiKeyStatus", + "ApiKey", + "ApiKeyCreate", + "ApiKeyUpdate", + "CallStatus", + "CallLog", + "MonitorData", ] diff --git a/fastapi_mcp/auth/api_key_manager.py b/fastapi_mcp/auth/api_key_manager.py new file mode 100644 index 0000000..5f27655 --- /dev/null +++ b/fastapi_mcp/auth/api_key_manager.py @@ -0,0 +1,210 @@ +import logging +import uuid +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any +from typing_extensions import Annotated, Doc + +from fastapi_mcp.types import ( + ApiKey, + ApiKeyCreate, + ApiKeyUpdate, + ApiKeyStatus, + Permission, +) + + +logger = logging.getLogger(__name__) + + +class ApiKeyManager: + """ + API Key management system for FastAPI-MCP. + + Provides API key generation, validation, and management. + """ + + def __init__( + self, + key_prefix: Annotated[ + str, + Doc("Prefix for generated API keys"), + ] = "mcp", + default_permissions: Annotated[ + List[Permission], + Doc("Default permissions for new API keys"), + ] = None, + default_rate_limit: Annotated[ + Optional[int], + Doc("Default rate limit for new API keys (requests per window)"), + ] = None, + default_rate_limit_window: Annotated[ + int, + Doc("Default rate limit window in seconds"), + ] = 60, + ): + self._api_keys: Dict[str, ApiKey] = {} + self._api_keys_by_key: Dict[str, ApiKey] = {} + self._api_keys_by_user: Dict[str, List[str]] = {} + self._key_prefix = key_prefix + self._default_permissions = default_permissions or [Permission.READ_TOOLS, Permission.CALL_TOOLS] + self._default_rate_limit = default_rate_limit + self._default_rate_limit_window = default_rate_limit_window + + def create_api_key( + self, + user_id: str, + key_create: ApiKeyCreate, + ) -> ApiKey: + key_value = ApiKey.generate_key(self._key_prefix) + api_key_id = str(uuid.uuid4()) + + expires_at = None + if key_create.expires_in_days: + expires_at = datetime.utcnow() + timedelta(days=key_create.expires_in_days) + + api_key = ApiKey( + id=api_key_id, + key=key_value, + name=key_create.name, + description=key_create.description, + user_id=user_id, + status=ApiKeyStatus.ACTIVE, + permissions=key_create.permissions or self._default_permissions, + rate_limit=key_create.rate_limit or self._default_rate_limit, + rate_limit_window=key_create.rate_limit_window or self._default_rate_limit_window, + expires_at=expires_at, + created_at=datetime.utcnow(), + ) + + self._api_keys[api_key_id] = api_key + self._api_keys_by_key[key_value] = api_key + + if user_id not in self._api_keys_by_user: + self._api_keys_by_user[user_id] = [] + self._api_keys_by_user[user_id].append(api_key_id) + + logger.info(f"API Key created: {api_key.name} (ID: {api_key_id}) for user {user_id}") + return api_key + + def get_api_key(self, api_key_id: str) -> Optional[ApiKey]: + return self._api_keys.get(api_key_id) + + def get_api_key_by_key(self, key: str) -> Optional[ApiKey]: + return self._api_keys_by_key.get(key) + + def validate_api_key(self, key: str) -> Optional[ApiKey]: + api_key = self._api_keys_by_key.get(key) + if not api_key: + return None + if not api_key.is_valid(): + return None + return api_key + + def list_api_keys(self, user_id: Optional[str] = None) -> List[ApiKey]: + if user_id: + api_key_ids = self._api_keys_by_user.get(user_id, []) + return [self._api_keys[aid] for aid in api_key_ids if aid in self._api_keys] + return list(self._api_keys.values()) + + def list_active_api_keys(self, user_id: Optional[str] = None) -> List[ApiKey]: + keys = self.list_api_keys(user_id) + return [k for k in keys if k.is_valid()] + + def update_api_key(self, api_key_id: str, key_update: ApiKeyUpdate) -> Optional[ApiKey]: + api_key = self._api_keys.get(api_key_id) + if not api_key: + return None + + update_data = key_update.model_dump(exclude_unset=True) + + for key, value in update_data.items(): + if value is not None: + setattr(api_key, key, value) + + logger.info(f"API Key updated: {api_key.name} (ID: {api_key_id})") + return api_key + + def revoke_api_key(self, api_key_id: str) -> bool: + api_key = self._api_keys.get(api_key_id) + if not api_key: + return False + + api_key.status = ApiKeyStatus.REVOKED + logger.info(f"API Key revoked: {api_key.name} (ID: {api_key_id})") + return True + + def delete_api_key(self, api_key_id: str) -> bool: + api_key = self._api_keys.get(api_key_id) + if not api_key: + return False + + key_value = api_key.key + user_id = api_key.user_id + + del self._api_keys[api_key_id] + del self._api_keys_by_key[key_value] + + if user_id in self._api_keys_by_user: + if api_key_id in self._api_keys_by_user[user_id]: + self._api_keys_by_user[user_id].remove(api_key_id) + + logger.info(f"API Key deleted: {api_key.name} (ID: {api_key_id})") + return True + + def record_usage(self, api_key_id: str) -> bool: + api_key = self._api_keys.get(api_key_id) + if not api_key: + return False + + api_key.usage_count += 1 + api_key.last_used_at = datetime.utcnow() + return True + + def check_permission(self, api_key_id: str, permission: Permission) -> bool: + api_key = self._api_keys.get(api_key_id) + if not api_key: + return False + if not api_key.is_valid(): + return False + if Permission.ADMIN_ALL in api_key.permissions: + return True + return permission in api_key.permissions + + def get_permissions_for_key(self, api_key_id: str) -> List[Permission]: + api_key = self._api_keys.get(api_key_id) + if not api_key: + return [] + return api_key.permissions + + def rotate_api_key(self, api_key_id: str) -> Optional[ApiKey]: + api_key = self._api_keys.get(api_key_id) + if not api_key: + return None + + old_key = api_key.key + new_key = ApiKey.generate_key(self._key_prefix) + + del self._api_keys_by_key[old_key] + api_key.key = new_key + self._api_keys_by_key[new_key] = api_key + + logger.info(f"API Key rotated: {api_key.name} (ID: {api_key_id})") + return api_key + + def get_stats(self) -> Dict[str, Any]: + total_keys = len(self._api_keys) + active_keys = sum(1 for k in self._api_keys.values() if k.is_valid()) + revoked_keys = sum(1 for k in self._api_keys.values() if k.status == ApiKeyStatus.REVOKED) + inactive_keys = sum(1 for k in self._api_keys.values() if k.status == ApiKeyStatus.INACTIVE) + total_usage = sum(k.usage_count for k in self._api_keys.values()) + + users_with_keys = len(self._api_keys_by_user) + + return { + "total_keys": total_keys, + "active_keys": active_keys, + "revoked_keys": revoked_keys, + "inactive_keys": inactive_keys, + "total_usage": total_usage, + "users_with_keys": users_with_keys, + } diff --git a/fastapi_mcp/auth/rate_limiter.py b/fastapi_mcp/auth/rate_limiter.py new file mode 100644 index 0000000..7ee2212 --- /dev/null +++ b/fastapi_mcp/auth/rate_limiter.py @@ -0,0 +1,291 @@ +import logging +import time +from collections import defaultdict, deque +from typing import Dict, Optional, Any +from typing_extensions import Annotated, Doc + +from fastapi_mcp.types import ( + RateLimitConfig, + ApiKey, +) + + +logger = logging.getLogger(__name__) + + +class SlidingWindowRateLimiter: + """ + Sliding window rate limiter implementation. + + This uses a deque to track request timestamps within each window. + """ + + def __init__( + self, + limit: int, + window_seconds: int, + ): + self.limit = limit + self.window_seconds = window_seconds + self._requests: deque = deque() + + def check_and_acquire(self) -> bool: + now = time.time() + cutoff = now - self.window_seconds + + while self._requests and self._requests[0] <= cutoff: + self._requests.popleft() + + if len(self._requests) >= self.limit: + return False + + self._requests.append(now) + return True + + def get_remaining(self) -> int: + now = time.time() + cutoff = now - self.window_seconds + + while self._requests and self._requests[0] <= cutoff: + self._requests.popleft() + + return max(0, self.limit - len(self._requests)) + + def get_reset_time(self) -> float: + if not self._requests: + return 0.0 + now = time.time() + oldest = self._requests[0] + return max(0.0, oldest + self.window_seconds - now) + + +class RateLimiter: + """ + Rate limiter for FastAPI-MCP. + + Supports: + - Global rate limiting + - Per-user rate limiting + - Per-API-key rate limiting + - Custom rate limits for specific API keys + """ + + def __init__( + self, + config: Annotated[ + Optional[RateLimitConfig], + Doc("Rate limit configuration"), + ] = None, + ): + self.config = config or RateLimitConfig() + self._global_limiter: Optional[SlidingWindowRateLimiter] = None + self._user_limiters: Dict[str, SlidingWindowRateLimiter] = {} + self._api_key_limiters: Dict[str, SlidingWindowRateLimiter] = {} + self._start_time = time.time() + + if self.config.global_limit: + self._global_limiter = SlidingWindowRateLimiter( + limit=self.config.global_limit, + window_seconds=self.config.default_window, + ) + + def _get_or_create_limiter( + self, + key: str, + limit: int, + window: int, + limiter_dict: Dict[str, SlidingWindowRateLimiter], + ) -> SlidingWindowRateLimiter: + if key not in limiter_dict: + limiter_dict[key] = SlidingWindowRateLimiter(limit=limit, window_seconds=window) + return limiter_dict[key] + + def check_rate_limit( + self, + user_id: Optional[str] = None, + api_key: Optional[ApiKey] = None, + ) -> tuple[bool, Dict[str, Any]]: + if not self.config.enabled: + return True, {"allowed": True, "reason": "rate_limit_disabled"} + + if self._global_limiter: + if not self._global_limiter.check_and_acquire(): + return False, { + "allowed": False, + "reason": "global_rate_limit_exceeded", + "limit": self.config.global_limit, + "remaining": self._global_limiter.get_remaining(), + "reset_seconds": self._global_limiter.get_reset_time(), + } + + if api_key: + effective_limit = self.config.get_effective_limit(api_key) + effective_window = api_key.rate_limit_window if api_key.rate_limit else self.config.default_window + + limiter = self._get_or_create_limiter( + key=api_key.id, + limit=effective_limit, + window=effective_window, + limiter_dict=self._api_key_limiters, + ) + + if not limiter.check_and_acquire(): + return False, { + "allowed": False, + "reason": "api_key_rate_limit_exceeded", + "limit": effective_limit, + "window_seconds": effective_window, + "remaining": limiter.get_remaining(), + "reset_seconds": limiter.get_reset_time(), + "api_key_id": api_key.id, + } + + elif user_id: + effective_limit = self.config.user_limit or self.config.default_limit + effective_window = self.config.default_window + + limiter = self._get_or_create_limiter( + key=user_id, + limit=effective_limit, + window=effective_window, + limiter_dict=self._user_limiters, + ) + + if not limiter.check_and_acquire(): + return False, { + "allowed": False, + "reason": "user_rate_limit_exceeded", + "limit": effective_limit, + "window_seconds": effective_window, + "remaining": limiter.get_remaining(), + "reset_seconds": limiter.get_reset_time(), + "user_id": user_id, + } + + return True, {"allowed": True} + + def get_rate_limit_status( + self, + user_id: Optional[str] = None, + api_key: Optional[ApiKey] = None, + ) -> Dict[str, Any]: + status = { + "enabled": self.config.enabled, + "global": None, + "user": None, + "api_key": None, + } + + if self._global_limiter: + status["global"] = { + "limit": self.config.global_limit, + "remaining": self._global_limiter.get_remaining(), + "reset_seconds": self._global_limiter.get_reset_time(), + } + + if api_key and api_key.id in self._api_key_limiters: + limiter = self._api_key_limiters[api_key.id] + effective_limit = self.config.get_effective_limit(api_key) + status["api_key"] = { + "id": api_key.id, + "limit": effective_limit, + "remaining": limiter.get_remaining(), + "reset_seconds": limiter.get_reset_time(), + } + elif api_key: + effective_limit = self.config.get_effective_limit(api_key) + status["api_key"] = { + "id": api_key.id, + "limit": effective_limit, + "remaining": effective_limit, + "reset_seconds": 0.0, + } + + if user_id and user_id in self._user_limiters: + limiter = self._user_limiters[user_id] + effective_limit = self.config.user_limit or self.config.default_limit + status["user"] = { + "id": user_id, + "limit": effective_limit, + "remaining": limiter.get_remaining(), + "reset_seconds": limiter.get_reset_time(), + } + elif user_id: + effective_limit = self.config.user_limit or self.config.default_limit + status["user"] = { + "id": user_id, + "limit": effective_limit, + "remaining": effective_limit, + "reset_seconds": 0.0, + } + + return status + + def reset_limiter( + self, + user_id: Optional[str] = None, + api_key_id: Optional[str] = None, + ) -> bool: + if user_id and user_id in self._user_limiters: + del self._user_limiters[user_id] + logger.info(f"Rate limiter reset for user: {user_id}") + return True + + if api_key_id and api_key_id in self._api_key_limiters: + del self._api_key_limiters[api_key_id] + logger.info(f"Rate limiter reset for API key: {api_key_id}") + return True + + return False + + def get_stats(self) -> Dict[str, Any]: + uptime = time.time() - self._start_time + + active_user_limiters = len(self._user_limiters) + active_api_key_limiters = len(self._api_key_limiters) + + total_user_requests = sum( + len(limiter._requests) for limiter in self._user_limiters.values() + ) + total_api_key_requests = sum( + len(limiter._requests) for limiter in self._api_key_limiters.values() + ) + + return { + "enabled": self.config.enabled, + "uptime_seconds": round(uptime, 2), + "active_user_limiters": active_user_limiters, + "active_api_key_limiters": active_api_key_limiters, + "total_user_requests_tracked": total_user_requests, + "total_api_key_requests_tracked": total_api_key_requests, + "config": { + "default_limit": self.config.default_limit, + "default_window": self.config.default_window, + "user_limit": self.config.user_limit, + "api_key_limit": self.config.api_key_limit, + "global_limit": self.config.global_limit, + }, + } + + def cleanup_stale_limiters(self, max_age_seconds: int = 3600) -> int: + now = time.time() + cutoff = now - max_age_seconds + + cleaned = 0 + + for key in list(self._user_limiters.keys()): + limiter = self._user_limiters[key] + if not limiter._requests or (limiter._requests and limiter._requests[-1] <= cutoff): + del self._user_limiters[key] + cleaned += 1 + + for key in list(self._api_key_limiters.keys()): + limiter = self._api_key_limiters[key] + if not limiter._requests or (limiter._requests and limiter._requests[-1] <= cutoff): + del self._api_key_limiters[key] + cleaned += 1 + + if cleaned > 0: + logger.info(f"Cleaned up {cleaned} stale rate limiters") + + return cleaned diff --git a/fastapi_mcp/auth/user_manager.py b/fastapi_mcp/auth/user_manager.py new file mode 100644 index 0000000..629444e --- /dev/null +++ b/fastapi_mcp/auth/user_manager.py @@ -0,0 +1,220 @@ +import logging +import uuid +from datetime import datetime +from typing import Dict, List, Optional, Callable, Any +from typing_extensions import Annotated, Doc + +from fastapi_mcp.types import ( + User, + UserCreate, + UserUpdate, + UserRole, + Permission, +) + + +logger = logging.getLogger(__name__) + + +class UserManager: + """ + User management system for FastAPI-MCP. + + Provides user CRUD operations, authentication, and permission checking. + """ + + def __init__( + self, + enable_admin_user: Annotated[ + bool, + Doc("Whether to enable the default admin user"), + ] = True, + admin_username: Annotated[ + str, + Doc("Username for the default admin user"), + ] = "admin", + admin_password: Annotated[ + Optional[str], + Doc("Password for the default admin user (if None, a random one will be generated)"), + ] = None, + password_hasher: Annotated[ + Optional[Callable[[str], str]], + Doc("Optional custom password hasher function"), + ] = None, + password_verifier: Annotated[ + Optional[Callable[[str, str], bool]], + Doc("Optional custom password verifier function"), + ] = None, + ): + self._users: Dict[str, User] = {} + self._users_by_username: Dict[str, User] = {} + self._passwords: Dict[str, str] = {} + self._password_hasher = password_hasher or self._default_hash_password + self._password_verifier = password_verifier or self._default_verify_password + + if enable_admin_user: + self._create_default_admin(admin_username, admin_password) + + def _default_hash_password(self, password: str) -> str: + try: + import bcrypt + password_bytes = password.encode("utf-8") + salt = bcrypt.gensalt() + return bcrypt.hashpw(password_bytes, salt).decode("utf-8") + except ImportError: + import hashlib + return hashlib.sha256(password.encode("utf-8")).hexdigest() + + def _default_verify_password(self, password: str, hashed: str) -> bool: + try: + import bcrypt + return bcrypt.checkpw(password.encode("utf-8"), hashed.encode("utf-8")) + except ImportError: + import hashlib + return hashlib.sha256(password.encode("utf-8")).hexdigest() == hashed + + def _create_default_admin(self, username: str, password: Optional[str] = None) -> User: + if password is None: + password = uuid.uuid4().hex + logger.warning(f"Default admin user created with password: {password}") + logger.warning("Please change this password in production!") + + admin_user = User( + id=str(uuid.uuid4()), + username=username, + role=UserRole.ADMIN, + permissions=[Permission.ADMIN_ALL], + is_active=True, + ) + + self._users[admin_user.id] = admin_user + self._users_by_username[username] = admin_user + self._passwords[admin_user.id] = self._password_hasher(password) + + logger.info(f"Default admin user created: {username}") + return admin_user + + def create_user(self, user_create: UserCreate) -> User: + if user_create.username in self._users_by_username: + raise ValueError(f"Username '{user_create.username}' already exists") + + user = User( + id=str(uuid.uuid4()), + username=user_create.username, + email=user_create.email, + role=user_create.role, + permissions=user_create.permissions, + is_active=True, + ) + + self._users[user.id] = user + self._users_by_username[user.username] = user + + if user_create.password: + self._passwords[user.id] = self._password_hasher(user_create.password) + + logger.info(f"User created: {user.username} (ID: {user.id})") + return user + + def get_user(self, user_id: str) -> Optional[User]: + return self._users.get(user_id) + + def get_user_by_username(self, username: str) -> Optional[User]: + return self._users_by_username.get(username) + + def list_users(self) -> List[User]: + return list(self._users.values()) + + def update_user(self, user_id: str, user_update: UserUpdate) -> Optional[User]: + user = self._users.get(user_id) + if not user: + return None + + update_data = user_update.model_dump(exclude_unset=True) + + for key, value in update_data.items(): + if value is not None: + setattr(user, key, value) + + logger.info(f"User updated: {user.username} (ID: {user.id})") + return user + + def delete_user(self, user_id: str) -> bool: + user = self._users.get(user_id) + if not user: + return False + + del self._users[user_id] + del self._users_by_username[user.username] + + if user_id in self._passwords: + del self._passwords[user_id] + + logger.info(f"User deleted: {user.username} (ID: {user.id})") + return True + + def authenticate_user(self, username: str, password: str) -> Optional[User]: + user = self._users_by_username.get(username) + if not user: + logger.warning(f"Authentication failed: user '{username}' not found") + return None + + if not user.is_active: + logger.warning(f"Authentication failed: user '{username}' is inactive") + return None + + hashed_password = self._passwords.get(user.id) + if not hashed_password: + logger.warning(f"Authentication failed: no password set for user '{username}'") + return None + + if not self._password_verifier(password, hashed_password): + logger.warning(f"Authentication failed: invalid password for user '{username}'") + return None + + user.last_login = datetime.utcnow() + logger.info(f"User authenticated: {username} (ID: {user.id})") + return user + + def set_password(self, user_id: str, password: str) -> bool: + user = self._users.get(user_id) + if not user: + return False + + self._passwords[user_id] = self._password_hasher(password) + logger.info(f"Password set for user: {user.username} (ID: {user.id})") + return True + + def check_permission(self, user_id: str, permission: Permission) -> bool: + user = self._users.get(user_id) + if not user: + return False + return user.has_permission(permission) + + def check_role(self, user_id: str, role: UserRole) -> bool: + user = self._users.get(user_id) + if not user: + return False + return user.has_role(role) + + def get_permissions_for_user(self, user_id: str) -> List[Permission]: + user = self._users.get(user_id) + if not user: + return [] + if user.role == UserRole.ADMIN: + return list(Permission) + return user.permissions + + def get_stats(self) -> Dict[str, Any]: + total_users = len(self._users) + active_users = sum(1 for u in self._users.values() if u.is_active) + users_by_role = { + role.value: sum(1 for u in self._users.values() if u.role == role) + for role in UserRole + } + + return { + "total_users": total_users, + "active_users": active_users, + "users_by_role": users_by_role, + } diff --git a/fastapi_mcp/monitoring/__init__.py b/fastapi_mcp/monitoring/__init__.py new file mode 100644 index 0000000..f225448 --- /dev/null +++ b/fastapi_mcp/monitoring/__init__.py @@ -0,0 +1,3 @@ +from .call_logger import CallLogger + +__all__ = ["CallLogger"] diff --git a/fastapi_mcp/monitoring/call_logger.py b/fastapi_mcp/monitoring/call_logger.py new file mode 100644 index 0000000..5ceec06 --- /dev/null +++ b/fastapi_mcp/monitoring/call_logger.py @@ -0,0 +1,300 @@ +import logging +import uuid +import time +from collections import deque, defaultdict +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Any, Deque +from typing_extensions import Annotated, Doc + +from fastapi_mcp.types import ( + CallLog, + CallStatus, + MonitorData, +) + + +logger = logging.getLogger(__name__) + + +class CallLogger: + """ + Call logging and monitoring system for FastAPI-MCP. + + Provides: + - Call logging with detailed request/response information + - Real-time monitoring data aggregation + - Statistics collection + - Error tracking + """ + + def __init__( + self, + max_logs: Annotated[ + int, + Doc("Maximum number of logs to keep in memory"), + ] = 10000, + max_errors: Annotated[ + int, + Doc("Maximum number of errors to keep in memory"), + ] = 1000, + metrics_window_seconds: Annotated[ + int, + Doc("Time window for metrics calculation in seconds"), + ] = 3600, + ): + self._max_logs = max_logs + self._max_errors = max_errors + self._metrics_window_seconds = metrics_window_seconds + + self._logs: Deque[CallLog] = deque(maxlen=max_logs) + self._errors: Deque[CallLog] = deque(maxlen=max_errors) + + self._tool_call_counts: Dict[str, int] = defaultdict(int) + self._tool_durations: Dict[str, List[float]] = defaultdict(list) + self._recent_durations: Deque[float] = deque(maxlen=10000) + + self._start_time = time.time() + self._total_calls = 0 + self._success_calls = 0 + self._error_calls = 0 + self._rate_limited_calls = 0 + + self._user_activity: Dict[str, datetime] = {} + self._api_key_activity: Dict[str, datetime] = {} + + def log_call( + self, + tool_name: str, + status: CallStatus, + duration_ms: float, + request_method: str, + request_path: str, + user_id: Optional[str] = None, + api_key_id: Optional[str] = None, + request_headers: Optional[Dict[str, str]] = None, + request_body: Optional[str] = None, + response_status: int = 200, + response_body: Optional[str] = None, + error_message: Optional[str] = None, + client_ip: Optional[str] = None, + user_agent: Optional[str] = None, + ) -> CallLog: + call_log = CallLog( + id=str(uuid.uuid4()), + timestamp=datetime.utcnow(), + tool_name=tool_name, + user_id=user_id, + api_key_id=api_key_id, + status=status, + duration_ms=duration_ms, + request_method=request_method, + request_path=request_path, + request_headers=request_headers or {}, + request_body=request_body, + response_status=response_status, + response_body=response_body, + error_message=error_message, + client_ip=client_ip, + user_agent=user_agent, + ) + + self._logs.append(call_log) + + self._total_calls += 1 + self._tool_call_counts[tool_name] += 1 + self._tool_durations[tool_name].append(duration_ms) + self._recent_durations.append(duration_ms) + + if status == CallStatus.SUCCESS: + self._success_calls += 1 + elif status == CallStatus.ERROR: + self._error_calls += 1 + self._errors.append(call_log) + elif status == CallStatus.RATE_LIMITED: + self._rate_limited_calls += 1 + + if user_id: + self._user_activity[user_id] = datetime.utcnow() + if api_key_id: + self._api_key_activity[api_key_id] = datetime.utcnow() + + logger.debug( + f"Call logged: {tool_name} - {status.value} - {duration_ms}ms " + f"(user: {user_id}, api_key: {api_key_id})" + ) + + return call_log + + def get_logs( + self, + tool_name: Optional[str] = None, + user_id: Optional[str] = None, + api_key_id: Optional[str] = None, + status: Optional[CallStatus] = None, + limit: int = 100, + offset: int = 0, + ) -> List[CallLog]: + filtered = list(self._logs) + + if tool_name: + filtered = [log for log in filtered if log.tool_name == tool_name] + if user_id: + filtered = [log for log in filtered if log.user_id == user_id] + if api_key_id: + filtered = [log for log in filtered if log.api_key_id == api_key_id] + if status: + filtered = [log for log in filtered if log.status == status] + + filtered = list(reversed(filtered)) + + return filtered[offset : offset + limit] + + def get_errors( + self, + tool_name: Optional[str] = None, + limit: int = 100, + ) -> List[CallLog]: + errors = list(self._errors) + + if tool_name: + errors = [e for e in errors if e.tool_name == tool_name] + + return list(reversed(errors))[:limit] + + def get_monitor_data(self) -> MonitorData: + uptime = time.time() - self._start_time + + avg_response_time = 0.0 + if self._recent_durations: + avg_response_time = sum(self._recent_durations) / len(self._recent_durations) + + p99_response_time = 0.0 + if self._recent_durations: + sorted_durations = sorted(self._recent_durations) + p99_index = int(len(sorted_durations) * 0.99) + if p99_index < len(sorted_durations): + p99_response_time = sorted_durations[p99_index] + + window_start = time.time() - self._metrics_window_seconds + calls_in_window = sum( + 1 for log in self._logs + if log.timestamp.timestamp() >= window_start + ) + requests_per_second = calls_in_window / self._metrics_window_seconds if self._metrics_window_seconds > 0 else 0 + + top_tools = sorted( + [{"tool": t, "count": c} for t, c in self._tool_call_counts.items()], + key=lambda x: x["count"], + reverse=True, + )[:10] + + recent_errors = [ + { + "id": e.id, + "tool_name": e.tool_name, + "error_message": e.error_message, + "timestamp": e.timestamp.isoformat(), + "response_status": e.response_status, + } + for e in list(self._errors)[-20:] + ] + + active_cutoff = datetime.utcnow() - timedelta(minutes=30) + active_users = [ + {"user_id": uid, "last_active": last.isoformat()} + for uid, last in self._user_activity.items() + if last >= active_cutoff + ] + + return MonitorData( + total_calls=self._total_calls, + success_calls=self._success_calls, + error_calls=self._error_calls, + rate_limited_calls=self._rate_limited_calls, + total_users=len(self._user_activity), + active_api_keys=len(self._api_key_activity), + uptime_seconds=round(uptime, 2), + avg_response_time_ms=round(avg_response_time, 2), + p99_response_time_ms=round(p99_response_time, 2), + requests_per_second=round(requests_per_second, 4), + top_tools=top_tools, + recent_errors=recent_errors, + active_users=active_users, + ) + + def get_tool_statistics( + self, + tool_name: Optional[str] = None, + ) -> Dict[str, Any]: + if tool_name: + counts = {tool_name: self._tool_call_counts.get(tool_name, 0)} + durations = {tool_name: self._tool_durations.get(tool_name, [])} + else: + counts = dict(self._tool_call_counts) + durations = dict(self._tool_durations) + + stats = {} + for t, count in counts.items(): + tool_durations = durations.get(t, []) + if tool_durations: + avg = sum(tool_durations) / len(tool_durations) + min_d = min(tool_durations) + max_d = max(tool_durations) + sorted_d = sorted(tool_durations) + p99_idx = int(len(sorted_d) * 0.99) + p99 = sorted_d[p99_idx] if p99_idx < len(sorted_d) else max_d + else: + avg = 0.0 + min_d = 0.0 + max_d = 0.0 + p99 = 0.0 + + stats[t] = { + "count": count, + "avg_duration_ms": round(avg, 2), + "min_duration_ms": round(min_d, 2), + "max_duration_ms": round(max_d, 2), + "p99_duration_ms": round(p99, 2), + } + + return stats + + def clear_logs(self) -> int: + count = len(self._logs) + self._logs.clear() + self._errors.clear() + logger.info(f"Cleared {count} logs") + return count + + def get_stats_summary(self) -> Dict[str, Any]: + monitor_data = self.get_monitor_data() + + success_rate = 0.0 + if monitor_data.total_calls > 0: + success_rate = (monitor_data.success_calls / monitor_data.total_calls) * 100 + + error_rate = 0.0 + if monitor_data.total_calls > 0: + error_rate = (monitor_data.error_calls / monitor_data.total_calls) * 100 + + return { + "summary": { + "total_calls": monitor_data.total_calls, + "success_calls": monitor_data.success_calls, + "error_calls": monitor_data.error_calls, + "rate_limited_calls": monitor_data.rate_limited_calls, + "success_rate_pct": round(success_rate, 2), + "error_rate_pct": round(error_rate, 2), + }, + "performance": { + "avg_response_time_ms": monitor_data.avg_response_time_ms, + "p99_response_time_ms": monitor_data.p99_response_time_ms, + "requests_per_second": monitor_data.requests_per_second, + "uptime_seconds": monitor_data.uptime_seconds, + }, + "system": { + "total_users": monitor_data.total_users, + "active_api_keys": monitor_data.active_api_keys, + "active_tools": len(self._tool_call_counts), + }, + } diff --git a/fastapi_mcp/panel/__init__.py b/fastapi_mcp/panel/__init__.py new file mode 100644 index 0000000..620c46e --- /dev/null +++ b/fastapi_mcp/panel/__init__.py @@ -0,0 +1,3 @@ +from .panel import PanelRouter + +__all__ = ["PanelRouter"] diff --git a/fastapi_mcp/panel/panel.py b/fastapi_mcp/panel/panel.py new file mode 100644 index 0000000..52707aa --- /dev/null +++ b/fastapi_mcp/panel/panel.py @@ -0,0 +1,852 @@ +import logging +from typing import Optional, Dict, Any, List + +from fastapi import APIRouter, Request, HTTPException +from fastapi.responses import HTMLResponse, JSONResponse +from typing_extensions import Annotated, Doc + +from fastapi_mcp.types import ( + PanelConfig, + ApiKeyCreate, + ApiKeyUpdate, + UserCreate, + UserUpdate, + CallStatus, + Permission, +) +from fastapi_mcp.auth.user_manager import UserManager +from fastapi_mcp.auth.api_key_manager import ApiKeyManager +from fastapi_mcp.auth.rate_limiter import RateLimiter +from fastapi_mcp.monitoring.call_logger import CallLogger + + +logger = logging.getLogger(__name__) + + +class PanelRouter: + """ + Modern web panel for FastAPI-MCP. + + Provides: + - API testing interface + - Real-time monitoring dashboard + - API Key management + - User management (if enabled) + - Call logs viewer + """ + + def __init__( + self, + config: Annotated[ + PanelConfig, + Doc("Panel configuration"), + ], + user_manager: Annotated[ + Optional[UserManager], + Doc("User manager instance"), + ] = None, + api_key_manager: Annotated[ + Optional[ApiKeyManager], + Doc("API Key manager instance"), + ] = None, + rate_limiter: Annotated[ + Optional[RateLimiter], + Doc("Rate limiter instance"), + ] = None, + call_logger: Annotated[ + Optional[CallLogger], + Doc("Call logger instance"), + ] = None, + ): + self.config = config + self.user_manager = user_manager + self.api_key_manager = api_key_manager + self.rate_limiter = rate_limiter + self.call_logger = call_logger + + self.router = APIRouter() + self._setup_routes() + + def _setup_routes(self): + @self.router.get("/", response_class=HTMLResponse, include_in_schema=False) + async def get_panel(request: Request): + return HTMLResponse(content=self._get_index_html()) + + @self.router.get("/api/config", include_in_schema=False) + async def get_config(): + return JSONResponse( + content={ + "title": self.config.title, + "enable_testing": self.config.enable_testing, + "enable_monitoring": self.config.enable_monitoring, + "enable_api_keys": self.config.enable_api_keys, + "enable_users": self.config.enable_users, + "enable_logs": self.config.enable_logs, + "logo_url": self.config.logo_url, + "theme": self.config.theme, + } + ) + + @self.router.get("/api/tools", include_in_schema=False) + async def get_tools(): + return JSONResponse(content={"tools": []}) + + if self.config.enable_monitoring: + self._setup_monitoring_routes() + + if self.config.enable_logs: + self._setup_logs_routes() + + if self.config.enable_api_keys: + self._setup_api_keys_routes() + + if self.config.enable_users and self.user_manager: + self._setup_users_routes() + + def _setup_monitoring_routes(self): + @self.router.get("/api/monitoring", include_in_schema=False) + async def get_monitoring(): + if not self.call_logger: + raise HTTPException(status_code=501, detail="Monitoring not enabled") + monitor_data = self.call_logger.get_monitor_data() + return JSONResponse(content=monitor_data.model_dump()) + + @self.router.get("/api/monitoring/stats", include_in_schema=False) + async def get_stats(): + if not self.call_logger: + raise HTTPException(status_code=501, detail="Monitoring not enabled") + stats = self.call_logger.get_stats_summary() + return JSONResponse(content=stats) + + @self.router.get("/api/monitoring/tools", include_in_schema=False) + async def get_tool_stats(tool_name: Optional[str] = None): + if not self.call_logger: + raise HTTPException(status_code=501, detail="Monitoring not enabled") + stats = self.call_logger.get_tool_statistics(tool_name) + return JSONResponse(content=stats) + + def _setup_logs_routes(self): + @self.router.get("/api/logs", include_in_schema=False) + async def get_logs( + tool_name: Optional[str] = None, + user_id: Optional[str] = None, + api_key_id: Optional[str] = None, + status: Optional[str] = None, + limit: int = 100, + offset: int = 0, + ): + if not self.call_logger: + raise HTTPException(status_code=501, detail="Logging not enabled") + + call_status = None + if status: + try: + call_status = CallStatus(status) + except ValueError: + pass + + logs = self.call_logger.get_logs( + tool_name=tool_name, + user_id=user_id, + api_key_id=api_key_id, + status=call_status, + limit=limit, + offset=offset, + ) + + return JSONResponse( + content={ + "logs": [log.model_dump(mode="json") for log in logs], + "total": len(self.call_logger._logs) if self.call_logger else 0, + "limit": limit, + "offset": offset, + } + ) + + @self.router.get("/api/logs/errors", include_in_schema=False) + async def get_errors(tool_name: Optional[str] = None, limit: int = 100): + if not self.call_logger: + raise HTTPException(status_code=501, detail="Logging not enabled") + + errors = self.call_logger.get_errors(tool_name=tool_name, limit=limit) + return JSONResponse( + content={ + "errors": [e.model_dump(mode="json") for e in errors], + "total": len(self.call_logger._errors) if self.call_logger else 0, + } + ) + + def _setup_api_keys_routes(self): + @self.router.get("/api/api-keys", include_in_schema=False) + async def list_api_keys(user_id: Optional[str] = None): + if not self.api_key_manager: + raise HTTPException(status_code=501, detail="API Key management not enabled") + + keys = self.api_key_manager.list_api_keys(user_id=user_id) + return JSONResponse( + content={ + "api_keys": [ + { + "id": k.id, + "name": k.name, + "description": k.description, + "user_id": k.user_id, + "status": k.status.value, + "permissions": [p.value for p in k.permissions], + "rate_limit": k.rate_limit, + "rate_limit_window": k.rate_limit_window, + "expires_at": k.expires_at.isoformat() if k.expires_at else None, + "last_used_at": k.last_used_at.isoformat() if k.last_used_at else None, + "created_at": k.created_at.isoformat(), + "usage_count": k.usage_count, + } + for k in keys + ] + } + ) + + @self.router.post("/api/api-keys", include_in_schema=False) + async def create_api_key(data: ApiKeyCreate, user_id: str = "default"): + if not self.api_key_manager: + raise HTTPException(status_code=501, detail="API Key management not enabled") + + api_key = self.api_key_manager.create_api_key(user_id=user_id, key_create=data) + return JSONResponse( + content={ + "id": api_key.id, + "key": api_key.key, + "name": api_key.name, + "message": "Make sure to copy this key now - you won't be able to see it again!", + }, + status_code=201, + ) + + @self.router.get("/api/api-keys/{api_key_id}", include_in_schema=False) + async def get_api_key(api_key_id: str): + if not self.api_key_manager: + raise HTTPException(status_code=501, detail="API Key management not enabled") + + key = self.api_key_manager.get_api_key(api_key_id) + if not key: + raise HTTPException(status_code=404, detail="API Key not found") + + return JSONResponse( + content={ + "id": key.id, + "name": key.name, + "description": key.description, + "user_id": key.user_id, + "status": key.status.value, + "permissions": [p.value for p in key.permissions], + "rate_limit": key.rate_limit, + "rate_limit_window": key.rate_limit_window, + "expires_at": key.expires_at.isoformat() if key.expires_at else None, + "last_used_at": key.last_used_at.isoformat() if key.last_used_at else None, + "created_at": key.created_at.isoformat(), + "usage_count": key.usage_count, + } + ) + + @self.router.patch("/api/api-keys/{api_key_id}", include_in_schema=False) + async def update_api_key(api_key_id: str, data: ApiKeyUpdate): + if not self.api_key_manager: + raise HTTPException(status_code=501, detail="API Key management not enabled") + + key = self.api_key_manager.update_api_key(api_key_id, data) + if not key: + raise HTTPException(status_code=404, detail="API Key not found") + + return JSONResponse(content={"message": "API Key updated"}) + + @self.router.post("/api/api-keys/{api_key_id}/revoke", include_in_schema=False) + async def revoke_api_key(api_key_id: str): + if not self.api_key_manager: + raise HTTPException(status_code=501, detail="API Key management not enabled") + + if not self.api_key_manager.revoke_api_key(api_key_id): + raise HTTPException(status_code=404, detail="API Key not found") + + return JSONResponse(content={"message": "API Key revoked"}) + + @self.router.post("/api/api-keys/{api_key_id}/rotate", include_in_schema=False) + async def rotate_api_key(api_key_id: str): + if not self.api_key_manager: + raise HTTPException(status_code=501, detail="API Key management not enabled") + + key = self.api_key_manager.rotate_api_key(api_key_id) + if not key: + raise HTTPException(status_code=404, detail="API Key not found") + + return JSONResponse( + content={ + "id": key.id, + "key": key.key, + "message": "API Key rotated - make sure to copy the new key!", + } + ) + + @self.router.delete("/api/api-keys/{api_key_id}", include_in_schema=False) + async def delete_api_key(api_key_id: str): + if not self.api_key_manager: + raise HTTPException(status_code=501, detail="API Key management not enabled") + + if not self.api_key_manager.delete_api_key(api_key_id): + raise HTTPException(status_code=404, detail="API Key not found") + + return JSONResponse(content={"message": "API Key deleted"}, status_code=204) + + def _setup_users_routes(self): + @self.router.get("/api/users", include_in_schema=False) + async def list_users(): + users = self.user_manager.list_users() + return JSONResponse( + content={ + "users": [ + { + "id": u.id, + "username": u.username, + "email": u.email, + "role": u.role.value, + "permissions": [p.value for p in u.permissions], + "is_active": u.is_active, + "created_at": u.created_at.isoformat(), + "last_login": u.last_login.isoformat() if u.last_login else None, + } + for u in users + ] + } + ) + + @self.router.post("/api/users", include_in_schema=False) + async def create_user(data: UserCreate): + try: + user = self.user_manager.create_user(data) + return JSONResponse( + content={ + "id": user.id, + "username": user.username, + "message": "User created", + }, + status_code=201, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + @self.router.get("/api/users/{user_id}", include_in_schema=False) + async def get_user(user_id: str): + user = self.user_manager.get_user(user_id) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + return JSONResponse( + content={ + "id": user.id, + "username": user.username, + "email": user.email, + "role": user.role.value, + "permissions": [p.value for p in user.permissions], + "is_active": user.is_active, + "created_at": user.created_at.isoformat(), + "last_login": user.last_login.isoformat() if user.last_login else None, + } + ) + + @self.router.patch("/api/users/{user_id}", include_in_schema=False) + async def update_user(user_id: str, data: UserUpdate): + user = self.user_manager.update_user(user_id, data) + if not user: + raise HTTPException(status_code=404, detail="User not found") + + return JSONResponse(content={"message": "User updated"}) + + @self.router.delete("/api/users/{user_id}", include_in_schema=False) + async def delete_user(user_id: str): + if not self.user_manager.delete_user(user_id): + raise HTTPException(status_code=404, detail="User not found") + + return JSONResponse(content={"message": "User deleted"}, status_code=204) + + def _get_index_html(self) -> str: + return _PANEL_HTML + + +_PANEL_HTML = """ + + + + + + FastAPI MCP Panel + + + + +
+ +
+
+
+
+

Dashboard

+

Overview of your MCP server activity

+
+
+
+
+ Total Calls +
+ +
+
+
-
+
- success rate
+
+
+
+ Success +
+ +
+
+
-
+
successful calls
+
+
+
+ Errors +
+ +
+
+
-
+
- error rate
+
+
+
+ Avg Response +
+ +
+
+
-ms
+
P99: -ms
+
+
+
+

System Info

+
+
+
Uptime
+
-
+
+
+
Req/Sec
+
-
+
+
+
Active API Keys
+
-
+
+
+
Rate Limited
+
-
+
+
+
+
+ + + +
+
+
+ +
+ + + +""" diff --git a/fastapi_mcp/server.py b/fastapi_mcp/server.py index bb75106..7ae5e70 100644 --- a/fastapi_mcp/server.py +++ b/fastapi_mcp/server.py @@ -1,9 +1,10 @@ import json +import time import httpx from typing import Dict, Optional, Any, List, Union, Literal, Sequence from typing_extensions import Annotated, Doc -from fastapi import FastAPI, Request, APIRouter, params +from fastapi import FastAPI, Request, APIRouter, params, HTTPException from fastapi.openapi.utils import get_openapi from mcp.server.lowlevel.server import Server import mcp.types as types @@ -11,7 +12,16 @@ from fastapi_mcp.openapi.convert import convert_openapi_to_mcp_tools from fastapi_mcp.transport.sse import FastApiSseTransport from fastapi_mcp.transport.http import FastApiHttpSessionManager -from fastapi_mcp.types import HTTPRequestInfo, AuthConfig +from fastapi_mcp.types import ( + HTTPRequestInfo, + AuthConfig, + ExtendedAuthConfig, + PanelConfig, + RateLimitConfig, + CallStatus, + ApiKey, + User, +) import logging @@ -72,8 +82,8 @@ def __init__( Doc("List of tags to exclude from MCP tools. Cannot be used with include_tags."), ] = None, auth_config: Annotated[ - Optional[AuthConfig], - Doc("Configuration for MCP authentication"), + Optional[Union[AuthConfig, ExtendedAuthConfig]], + Doc("Configuration for MCP authentication and extended features"), ] = None, headers: Annotated[ List[str], @@ -84,6 +94,34 @@ def __init__( """ ), ] = ["authorization"], + enable_panel: Annotated[ + bool, + Doc("Whether to enable the web panel"), + ] = True, + panel_config: Annotated[ + Optional[PanelConfig], + Doc("Configuration for the web panel"), + ] = None, + enable_rate_limit: Annotated[ + bool, + Doc("Whether to enable rate limiting"), + ] = False, + rate_limit_config: Annotated[ + Optional[RateLimitConfig], + Doc("Configuration for rate limiting"), + ] = None, + enable_call_logging: Annotated[ + bool, + Doc("Whether to enable call logging and monitoring"), + ] = True, + enable_user_management: Annotated[ + bool, + Doc("Whether to enable user management"), + ] = False, + enable_api_key_management: Annotated[ + bool, + Doc("Whether to enable API key management"), + ] = False, ): # Validate operation and tag filtering options if include_operations is not None and exclude_operations is not None: @@ -119,10 +157,58 @@ def __init__( ) self._forward_headers = {h.lower() for h in headers} - self._http_transport: FastApiHttpSessionManager | None = None # Store reference to HTTP transport for cleanup - + self._http_transport: FastApiHttpSessionManager | None = None + + self._enable_panel = enable_panel + self._panel_config = panel_config or PanelConfig() + self._enable_rate_limit = enable_rate_limit + self._rate_limit_config = rate_limit_config or RateLimitConfig() + self._enable_call_logging = enable_call_logging + self._enable_user_management = enable_user_management + self._enable_api_key_management = enable_api_key_management + + self._user_manager = None + self._api_key_manager = None + self._rate_limiter = None + self._call_logger = None + self._panel_router = None + + self._setup_extensions() self.setup_server() + def _setup_extensions(self) -> None: + if self._enable_user_management: + try: + from fastapi_mcp.auth.user_manager import UserManager + self._user_manager = UserManager() + logger.info("User management enabled") + except ImportError as e: + logger.warning(f"Failed to enable user management: {e}") + + if self._enable_api_key_management: + try: + from fastapi_mcp.auth.api_key_manager import ApiKeyManager + self._api_key_manager = ApiKeyManager() + logger.info("API key management enabled") + except ImportError as e: + logger.warning(f"Failed to enable API key management: {e}") + + if self._enable_rate_limit: + try: + from fastapi_mcp.auth.rate_limiter import RateLimiter + self._rate_limiter = RateLimiter(self._rate_limit_config) + logger.info("Rate limiting enabled") + except ImportError as e: + logger.warning(f"Failed to enable rate limiting: {e}") + + if self._enable_call_logging: + try: + from fastapi_mcp.monitoring.call_logger import CallLogger + self._call_logger = CallLogger() + logger.info("Call logging enabled") + except ImportError as e: + logger.warning(f"Failed to enable call logging: {e}") + def setup_server(self) -> None: openapi_schema = get_openapi( title=self.fastapi.title, @@ -423,6 +509,75 @@ def mount_sse( logger.info(f"MCP SSE server listening at {mount_path}") + def mount_panel( + self, + router: Annotated[ + Optional[FastAPI | APIRouter], + Doc( + """ + The FastAPI app or APIRouter to mount the panel to. If not provided, the panel + will be mounted to the FastAPI app. + """ + ), + ] = None, + mount_path: Annotated[ + Optional[str], + Doc( + """ + Path where the panel will be mounted. + If not provided, uses the path from panel_config. + """ + ), + ] = None, + ) -> None: + """ + Mount the web panel to a FastAPI app or APIRouter. + + The panel provides: + - Dashboard with real-time monitoring + - API testing interface + - API Key management + - User management + - Call logs viewer + """ + if not self._enable_panel: + logger.info("Panel is disabled, skipping mount") + return + + if not router: + router = self.fastapi + + assert isinstance(router, (FastAPI, APIRouter)), f"Invalid router type: {type(router)}" + + actual_mount_path = mount_path or self._panel_config.mount_path + + if not actual_mount_path.startswith("/"): + actual_mount_path = f"/{actual_mount_path}" + if actual_mount_path.endswith("/"): + actual_mount_path = actual_mount_path[:-1] + + try: + from fastapi_mcp.panel.panel import PanelRouter + + panel_router = PanelRouter( + config=self._panel_config, + user_manager=self._user_manager, + api_key_manager=self._api_key_manager, + rate_limiter=self._rate_limiter, + call_logger=self._call_logger, + ) + + router.include_router(panel_router.router, prefix=actual_mount_path) + + if isinstance(router, APIRouter): + self.fastapi.include_router(router) + + self._panel_router = panel_router + logger.info(f"MCP Panel mounted at {actual_mount_path}") + + except ImportError as e: + logger.warning(f"Failed to mount panel: {e}") + def mount( self, router: Annotated[ @@ -498,6 +653,17 @@ async def _execute_api_tool( Returns: The result as MCP content types """ + start_time = time.time() + user_id = None + api_key_id = None + api_key: Optional[ApiKey] = None + call_status = CallStatus.SUCCESS + response_status = 200 + response_body = None + error_message = None + client_ip = None + user_agent = None + if tool_name not in operation_map: raise Exception(f"Unknown tool: {tool_name}") @@ -505,7 +671,57 @@ async def _execute_api_tool( path: str = operation["path"] method: str = operation["method"] parameters: List[Dict[str, Any]] = operation.get("parameters", []) - arguments = arguments.copy() if arguments else {} # Deep copy arguments to avoid mutating the original + arguments = arguments.copy() if arguments else {} + + if http_request_info and http_request_info.headers: + auth_header = http_request_info.headers.get("authorization", "") + if auth_header.lower().startswith("bearer "): + api_key_value = auth_header[7:].strip() + if self._api_key_manager: + api_key = self._api_key_manager.validate_api_key(api_key_value) + if api_key: + api_key_id = api_key.id + user_id = api_key.user_id + logger.debug(f"Authenticated with API key: {api_key_id}") + + client_ip = http_request_info.headers.get("x-forwarded-for", http_request_info.headers.get("x-real-ip")) + user_agent = http_request_info.headers.get("user-agent") + + if self._enable_rate_limit and self._rate_limiter: + user = None + if self._user_manager and user_id: + user = self._user_manager.get_user(user_id) + + allowed, rate_limit_info = self._rate_limiter.check_rate_limit( + user_id=user_id, + api_key=api_key, + ) + + if not allowed: + call_status = CallStatus.RATE_LIMITED + logger.warning(f"Rate limit exceeded for tool {tool_name}: {rate_limit_info}") + + if self._call_logger: + duration_ms = (time.time() - start_time) * 1000 + self._call_logger.log_call( + tool_name=tool_name, + status=call_status, + duration_ms=duration_ms, + request_method=method, + request_path=path, + user_id=user_id, + api_key_id=api_key_id, + request_headers=http_request_info.headers if http_request_info else {}, + response_status=429, + error_message="Rate limit exceeded", + client_ip=client_ip, + user_agent=user_agent, + ) + + raise HTTPException( + status_code=429, + detail=f"Rate limit exceeded. Try again in {rate_limit_info.get('reset_seconds', 0):.1f} seconds.", + ) for param in parameters: if param.get("in") == "path" and param.get("name") in arguments: @@ -530,45 +746,80 @@ async def _execute_api_tool( raise ValueError(f"Parameter name is None for parameter: {param}") headers[param_name] = arguments.pop(param_name) - # Forward headers that are in the allowlist if http_request_info and http_request_info.headers: for name, value in http_request_info.headers.items(): - # case-insensitive check for allowed headers if name.lower() in self._forward_headers: headers[name] = value body = arguments if arguments else None + request_body_str = None + if body: + try: + request_body_str = json.dumps(body, ensure_ascii=False) + except (TypeError, ValueError): + request_body_str = str(body) + result_text = "" try: logger.debug(f"Making {method.upper()} request to {path}") response = await self._request(client, method, path, query, headers, body) + response_status = response.status_code - # TODO: Better typing for the AsyncClientProtocol. It should return a ResponseProtocol that has a json() method that returns a dict/list/etc. try: result = response.json() result_text = json.dumps(result, indent=2, ensure_ascii=False) + response_body = result_text except json.JSONDecodeError: if hasattr(response, "text"): result_text = response.text + response_body = result_text else: result_text = response.content + response_body = str(result_text) if isinstance(result_text, bytes) else result_text - # If not raising an exception, the MCP server will return the result as a regular text response, without marking it as an error. - # TODO: Use a raise_for_status() method on the response (it needs to also be implemented in the AsyncClientProtocol) if 400 <= response.status_code < 600: - raise Exception( - f"Error calling {tool_name}. Status code: {response.status_code}. Response: {response.text}" - ) + call_status = CallStatus.ERROR + error_message = f"Status code: {response.status_code}. Response: {response.text}" + raise Exception(error_message) - try: - return [types.TextContent(type="text", text=result_text)] - except ValueError: - return [types.TextContent(type="text", text=result_text)] + except HTTPException: + raise except Exception as e: + call_status = CallStatus.ERROR + error_message = str(e) logger.exception(f"Error calling {tool_name}") raise e + finally: + duration_ms = (time.time() - start_time) * 1000 + + if api_key and self._api_key_manager: + self._api_key_manager.record_usage(api_key.id) + + if self._call_logger: + self._call_logger.log_call( + tool_name=tool_name, + status=call_status, + duration_ms=duration_ms, + request_method=method, + request_path=path, + user_id=user_id, + api_key_id=api_key_id, + request_headers=http_request_info.headers if http_request_info else {}, + request_body=request_body_str, + response_status=response_status, + response_body=response_body, + error_message=error_message, + client_ip=client_ip, + user_agent=user_agent, + ) + + try: + return [types.TextContent(type="text", text=result_text)] + except ValueError: + return [types.TextContent(type="text", text=result_text)] + async def _request( self, client: httpx.AsyncClient, diff --git a/fastapi_mcp/types.py b/fastapi_mcp/types.py index 2e8cf2e..4cec241 100644 --- a/fastapi_mcp/types.py +++ b/fastapi_mcp/types.py @@ -1,5 +1,7 @@ import time -from typing import Any, Dict, Annotated, Union, Optional, Sequence, Literal, List +import secrets +from datetime import datetime, timedelta +from typing import Any, Dict, Annotated, Union, Optional, Sequence, Literal, List, Callable from typing_extensions import Doc from pydantic import ( BaseModel, @@ -7,9 +9,11 @@ HttpUrl, field_validator, model_validator, + Field, ) from pydantic.main import IncEx from fastapi import params +from enum import Enum StrHttpUrl = Annotated[Union[str, HttpUrl], HttpUrl] @@ -382,3 +386,195 @@ class ClientRegistrationResponse(BaseType): grant_types: List[str] token_endpoint_auth_method: str client_name: str + + +class UserRole(str, Enum): + ADMIN = "admin" + USER = "user" + GUEST = "guest" + API_ONLY = "api_only" + + +class Permission(str, Enum): + READ_TOOLS = "read:tools" + CALL_TOOLS = "call:tools" + MANAGE_USERS = "manage:users" + MANAGE_API_KEYS = "manage:api_keys" + VIEW_LOGS = "view:logs" + VIEW_METRICS = "view:metrics" + ADMIN_ALL = "admin:all" + + +class User(BaseType): + id: str + username: str + email: Optional[str] = None + role: UserRole = UserRole.USER + permissions: List[Permission] = [] + is_active: bool = True + created_at: datetime = Field(default_factory=datetime.utcnow) + last_login: Optional[datetime] = None + + def has_permission(self, permission: Permission) -> bool: + if self.role == UserRole.ADMIN or Permission.ADMIN_ALL in self.permissions: + return True + return permission in self.permissions + + def has_role(self, role: UserRole) -> bool: + role_hierarchy = {UserRole.GUEST: 0, UserRole.USER: 1, UserRole.API_ONLY: 1, UserRole.ADMIN: 2} + return role_hierarchy.get(self.role, 0) >= role_hierarchy.get(role, 0) + + +class UserCreate(BaseType): + username: str + email: Optional[str] = None + role: UserRole = UserRole.USER + permissions: List[Permission] = [] + password: Optional[str] = None + + +class UserUpdate(BaseType): + email: Optional[str] = None + role: Optional[UserRole] = None + permissions: Optional[List[Permission]] = None + is_active: Optional[bool] = None + + +class ApiKeyStatus(str, Enum): + ACTIVE = "active" + INACTIVE = "inactive" + REVOKED = "revoked" + + +class ApiKey(BaseType): + id: str + key: str + name: str + description: Optional[str] = None + user_id: str + status: ApiKeyStatus = ApiKeyStatus.ACTIVE + permissions: List[Permission] = [] + rate_limit: Optional[int] = None + rate_limit_window: int = 60 + expires_at: Optional[datetime] = None + last_used_at: Optional[datetime] = None + created_at: datetime = Field(default_factory=datetime.utcnow) + usage_count: int = 0 + + @classmethod + def generate_key(cls, prefix: str = "mcp") -> str: + return f"{prefix}_{secrets.token_urlsafe(32)}" + + def is_valid(self) -> bool: + if self.status != ApiKeyStatus.ACTIVE: + return False + if self.expires_at and datetime.utcnow() > self.expires_at: + return False + return True + + +class ApiKeyCreate(BaseType): + name: str + description: Optional[str] = None + permissions: List[Permission] = [] + rate_limit: Optional[int] = None + rate_limit_window: int = 60 + expires_in_days: Optional[int] = None + + +class ApiKeyUpdate(BaseType): + name: Optional[str] = None + description: Optional[str] = None + status: Optional[ApiKeyStatus] = None + permissions: Optional[List[Permission]] = None + rate_limit: Optional[int] = None + rate_limit_window: Optional[int] = None + + +class RateLimitConfig(BaseType): + enabled: bool = True + default_limit: int = 100 + default_window: int = 60 + user_limit: Optional[int] = None + api_key_limit: Optional[int] = None + global_limit: Optional[int] = None + + def get_effective_limit(self, api_key: Optional[ApiKey] = None) -> int: + if api_key and api_key.rate_limit: + return api_key.rate_limit + if api_key and self.api_key_limit: + return self.api_key_limit + if self.user_limit: + return self.user_limit + return self.default_limit + + +class CallStatus(str, Enum): + SUCCESS = "success" + ERROR = "error" + RATE_LIMITED = "rate_limited" + UNAUTHORIZED = "unauthorized" + + +class CallLog(BaseType): + id: str + timestamp: datetime = Field(default_factory=datetime.utcnow) + tool_name: str + user_id: Optional[str] = None + api_key_id: Optional[str] = None + status: CallStatus + duration_ms: float + request_method: str + request_path: str + request_headers: Dict[str, str] = {} + request_body: Optional[str] = None + response_status: int + response_body: Optional[str] = None + error_message: Optional[str] = None + client_ip: Optional[str] = None + user_agent: Optional[str] = None + + +class MonitorData(BaseType): + total_calls: int = 0 + success_calls: int = 0 + error_calls: int = 0 + rate_limited_calls: int = 0 + total_users: int = 0 + active_api_keys: int = 0 + uptime_seconds: float = 0 + avg_response_time_ms: float = 0 + p99_response_time_ms: float = 0 + requests_per_second: float = 0 + top_tools: List[Dict[str, Any]] = [] + recent_errors: List[Dict[str, Any]] = [] + active_users: List[Dict[str, Any]] = [] + + +class PanelConfig(BaseType): + enabled: bool = True + mount_path: str = "/mcp-panel" + title: str = "FastAPI MCP Panel" + enable_testing: bool = True + enable_monitoring: bool = True + enable_api_keys: bool = True + enable_users: bool = False + enable_logs: bool = True + logo_url: Optional[str] = None + theme: Literal["light", "dark", "auto"] = "auto" + + +class SecurityConfig(BaseType): + enabled: bool = True + require_auth: bool = False + require_api_key: bool = False + allow_guest_access: bool = False + session_timeout_minutes: int = 60 + jwt_secret: Optional[str] = None + password_hash_rounds: int = 12 + + +class ExtendedAuthConfig(AuthConfig): + security: Optional[SecurityConfig] = None + rate_limit: Optional[RateLimitConfig] = None + panel: Optional[PanelConfig] = None