Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 95 additions & 13 deletions mcpauth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,98 @@
from .exceptions import (
MCPAuthException as MCPAuthException,
MCPAuthConfigException as MCPAuthConfigException,
AuthServerExceptionCode as AuthServerExceptionCode,
MCPAuthAuthServerException as MCPAuthAuthServerException,
BearerAuthExceptionCode as BearerAuthExceptionCode,
MCPAuthBearerAuthExceptionDetails as MCPAuthBearerAuthExceptionDetails,
MCPAuthBearerAuthException as MCPAuthBearerAuthException,
MCPAuthJwtVerificationExceptionCode as MCPAuthJwtVerificationExceptionCode,
MCPAuthJwtVerificationException as MCPAuthJwtVerificationException,
)
import logging
from typing import Any, Literal, Union

from .middleware.create_bearer_auth import BaseBearerAuthConfig, BearerAuthConfig
from .types import VerifyAccessTokenFunction
from .config import MCPAuthConfig
from .exceptions import MCPAuthAuthServerException, AuthServerExceptionCode
from .utils import validate_server_config
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse


class MCPAuth:
def __init__(self):
self.config = None
"""
The main class for the mcp-auth library, which provides methods for creating middleware
functions for handling OAuth 2.0-related tasks and bearer token auth.

See Also: https://mcp-auth.dev for more information about the library and its usage.

:param config: An instance of `MCPAuthConfig` containing the server configuration.
"""

def __init__(self, config: MCPAuthConfig):
result = validate_server_config(config.server)

if not result.is_valid:
logging.error(
"The authorization server configuration is invalid:\n"
f"{result.errors}\n"
)
raise MCPAuthAuthServerException(
AuthServerExceptionCode.INVALID_SERVER_CONFIG, cause=result
)

if len(result.warnings) > 0:
logging.warning("The authorization server configuration has warnings:\n")
for warning in result.warnings:
logging.warning(f"- {warning}")

self.config = config

def metadata_response(self) -> JSONResponse:
"""
Returns a response containing the server metadata in JSON format with CORS support.
"""
server_config = self.config.server

response = JSONResponse(
server_config.metadata.model_dump(exclude_none=True),
status_code=200,
)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
return response

def bearer_auth_middleware(
self,
mode_or_verify: Union[Literal["jwt"], VerifyAccessTokenFunction],
config: BaseBearerAuthConfig = BaseBearerAuthConfig(),
jwt_options: dict[str, Any] = {},
) -> type[BaseHTTPMiddleware]:
"""
Creates a middleware that handles bearer token authentication.

:param mode_or_verify: If "jwt", uses built-in JWT verification; or a custom function that
takes a string token and returns an `AuthInfo` object.
:param config: Configuration for the Bearer auth handler, including audience, required
scopes, etc.
:param jwt_options: Optional dictionary of additional options for JWT verification
(`jwt.decode`). Not used if a custom function is provided.
:return: A middleware class that can be used in a Starlette or FastAPI application.
"""

metadata = self.config.server.metadata
if isinstance(mode_or_verify, str) and mode_or_verify == "jwt":
from .utils import create_verify_jwt

if not metadata.jwks_uri:
raise MCPAuthAuthServerException(
AuthServerExceptionCode.MISSING_JWKS_URI
)

verify = create_verify_jwt(
metadata.jwks_uri,
options=jwt_options,
)
elif callable(mode_or_verify):
verify = mode_or_verify
else:
raise ValueError(
"mode_or_verify must be 'jwt' or a callable function that verifies tokens."
)

from .middleware.create_bearer_auth import create_bearer_auth

return create_bearer_auth(
verify, BearerAuthConfig(issuer=metadata.issuer, **config.model_dump())
)
2 changes: 1 addition & 1 deletion mcpauth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def to_json(self, show_cause: bool = False) -> Record:
"error_description": self.message,
"cause": (
(
{k: v for k, v in self.cause.model_dump().items() if v is not None}
self.cause.model_dump(exclude_none=True)
if isinstance(self.cause, BaseModel)
else str(self.cause)
)
Expand Down
36 changes: 27 additions & 9 deletions mcpauth/middleware/create_bearer_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,39 @@
from ..types import VerifyAccessTokenFunction, Record


class BearerAuthConfig(BaseModel):
class BaseBearerAuthConfig(BaseModel):
"""
Configuration for the Bearer auth handler.

Attributes:
issuer: The expected issuer of the access token.
audience: The expected audience of the access token.
required_scopes: An array of required scopes that the access token must have.
show_error_details: Whether to show detailed error information in the response.
Base configuration for the Bearer auth handler.
"""

issuer: str
audience: Optional[str] = None
"""
The expected audience of the access token. If not provided, no audience check is performed.
"""

required_scopes: Optional[List[str]] = None
"""
An array of required scopes that the access token must have. If not provided, no scope check is
performed.
"""

show_error_details: bool = False
"""
Whether to show detailed error information in the response. Defaults to False.
If True, detailed error information will be included in the response body for debugging
purposes.
"""


class BearerAuthConfig(BaseBearerAuthConfig):
"""
Configuration for the Bearer auth handler.
"""

issuer: str
"""
The expected issuer of the access token. This should be a valid URL.
"""


def get_bearer_token_from_headers(headers: Headers) -> str:
Expand Down
7 changes: 7 additions & 0 deletions mcpauth/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .auth_server import (
AuthServerConfig as AuthServerConfig,
AuthServerType as AuthServerType,
)
from .oauth import (
AuthorizationServerMetadata as AuthorizationServerMetadata,
)
14 changes: 14 additions & 0 deletions mcpauth/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from ._create_verify_jwt import create_verify_jwt as create_verify_jwt
from ._fetch_server_config import (
fetch_server_config as fetch_server_config,
fetch_server_config_by_well_known_url as fetch_server_config_by_well_known_url,
ServerMetadataPaths as ServerMetadataPaths,
)
from ._validate_server_config import (
validate_server_config as validate_server_config,
AuthServerConfigErrorCode as AuthServerConfigErrorCode,
AuthServerConfigError as AuthServerConfigError,
AuthServerConfigWarningCode as AuthServerConfigWarningCode,
AuthServerConfigWarning as AuthServerConfigWarning,
AuthServerConfigValidationResult as AuthServerConfigValidationResult,
)
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Callable, Optional
from urllib.parse import urlparse, urlunparse
import aiohttp
import requests
import pydantic
from pathlib import Path

Expand Down Expand Up @@ -46,7 +46,7 @@ def get_oidc_well_known_url(issuer: str) -> str:
return urlunparse(parsed._replace(path=new_path))


async def fetch_server_config_by_well_known_url(
def fetch_server_config_by_well_known_url(
well_known_url: str,
type: AuthServerType,
transpile_data: Optional[Callable[[Record], Record]] = None,
Expand All @@ -69,14 +69,13 @@ async def fetch_server_config_by_well_known_url(
"""

try:
async with aiohttp.ClientSession() as session:
async with session.get(well_known_url) as response:
response.raise_for_status()
json = await response.json()
transpiled_data = transpile_data(json) if transpile_data else json
return AuthServerConfig(
metadata=AuthorizationServerMetadata(**transpiled_data), type=type
)
response = requests.get(well_known_url, timeout=10)
response.raise_for_status()
json = response.json()
transpiled_data = transpile_data(json) if transpile_data else json
return AuthServerConfig(
metadata=AuthorizationServerMetadata(**transpiled_data), type=type
)
except pydantic.ValidationError as e:
raise MCPAuthAuthServerException(
AuthServerExceptionCode.INVALID_SERVER_METADATA,
Expand All @@ -90,7 +89,7 @@ async def fetch_server_config_by_well_known_url(
) from e


async def fetch_server_config(
def fetch_server_config(
issuer: str,
type: AuthServerType,
transpile_data: Optional[Callable[[Record], Record]] = None,
Expand Down Expand Up @@ -141,6 +140,4 @@ async def fetch_server_config(
if type == AuthServerType.OAUTH
else get_oidc_well_known_url(issuer)
)
return await fetch_server_config_by_well_known_url(
well_known_url, type, transpile_data
)
return fetch_server_config_by_well_known_url(well_known_url, type, transpile_data)
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ keywords = [
"openid-connect",
]
dependencies = [
"aiohttp>=3.11.18",
"pydantic>=2.11.3",
"pyjwt[crypto]>=2.9.0",
"requests>=2.32.3",
"starlette>=0.46.2",
]

Expand All @@ -27,9 +27,14 @@ documentation = "https://mcp-auth.dev/docs"

[dependency-groups]
dev = [
"aresponses>=3.0.0",
"black>=24.8.0",
"pytest>=8.3.5",
"pytest-asyncio>=0.26.0",
"pytest-cov>=6.1.1",
"responses>=0.25.7",
"uvicorn>=0.34.2",
]

[tool.coverage.run]
branch = true
source = ["mcpauth"]
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[pytest]
pythonpath = .
5 changes: 0 additions & 5 deletions samples/server/fast_api.py

This file was deleted.

30 changes: 30 additions & 0 deletions samples/server/starlette.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from mcpauth import MCPAuth
from mcpauth.config import MCPAuthConfig
from mcpauth.models import AuthServerType
from mcpauth.utils import fetch_server_config, ServerMetadataPaths
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse
from starlette.requests import Request

mcpAuth = MCPAuth(
MCPAuthConfig(
server=fetch_server_config("https://auth.logto.io/oidc", AuthServerType.OIDC)
)
)

protected_app = Starlette(
middleware=[Middleware(mcpAuth.bearer_auth_middleware("jwt"))]
)


@protected_app.route("/") # type: ignore
async def secret_endpoint(_: Request):
return JSONResponse({"secret": True})


app = Starlette(
debug=True,
)
app.mount(ServerMetadataPaths.OAUTH.value, mcpAuth.metadata_response())
app.mount("/mcp", protected_app)
Loading