From 1265032f83ae6b22d9c8bd2ea40dd2cc1a6766b1 Mon Sep 17 00:00:00 2001 From: Gao Sun Date: Fri, 2 May 2025 12:16:55 -0700 Subject: [PATCH 1/2] refactor: update usage --- mcpauth/__init__.py | 48 ++++--- mcpauth/config.py | 145 ++++++++++++++++++++- mcpauth/middleware/create_bearer_auth.py | 20 +-- mcpauth/models/__init__.py | 7 - mcpauth/models/auth_server.py | 39 ------ mcpauth/models/oauth.py | 100 -------------- mcpauth/types.py | 2 +- mcpauth/utils/__init__.py | 1 - mcpauth/utils/_create_verify_jwt.py | 9 +- mcpauth/utils/_fetch_server_config.py | 25 ++-- mcpauth/utils/_validate_server_config.py | 2 +- samples/server/starlette.py | 26 ++-- tests/__init__test.py | 46 +++---- tests/utils/fetch_server_config_test.py | 10 +- tests/utils/validate_server_config_test.py | 3 +- 15 files changed, 223 insertions(+), 260 deletions(-) delete mode 100644 mcpauth/models/__init__.py delete mode 100644 mcpauth/models/auth_server.py delete mode 100644 mcpauth/models/oauth.py diff --git a/mcpauth/__init__.py b/mcpauth/__init__.py index d9f861c..8168ae2 100644 --- a/mcpauth/__init__.py +++ b/mcpauth/__init__.py @@ -1,9 +1,9 @@ import logging -from typing import Any, Literal, Union +from typing import List, Literal, Optional, Union -from .middleware.create_bearer_auth import BaseBearerAuthConfig, BearerAuthConfig +from .middleware.create_bearer_auth import BearerAuthConfig from .types import VerifyAccessTokenFunction -from .config import MCPAuthConfig +from .config import AuthServerConfig from .exceptions import MCPAuthAuthServerException, AuthServerExceptionCode from .utils import validate_server_config from starlette.middleware.base import BaseHTTPMiddleware @@ -16,12 +16,14 @@ class MCPAuth: functions for handling OAuth 2.0-related tasks and bearer token auth. See Also: https://mcp-auth.dev for more information about the library and its usage. - - :param config: An instance of `MCPAuthConfig` containing the server configuration. """ - def __init__(self, config: MCPAuthConfig): - result = validate_server_config(config.server) + def __init__(self, server: AuthServerConfig): + """ + :param server: Configuration for the remote authorization server. + """ + + result = validate_server_config(server) if not result.is_valid: logging.error( @@ -37,13 +39,13 @@ def __init__(self, config: MCPAuthConfig): for warning in result.warnings: logging.warning(f"- {warning}") - self.config = config + self.server = server def metadata_response(self) -> JSONResponse: """ Returns a response containing the server metadata in JSON format with CORS support. """ - server_config = self.config.server + server_config = self.server response = JSONResponse( server_config.metadata.model_dump(exclude_none=True), @@ -56,22 +58,26 @@ def metadata_response(self) -> JSONResponse: def bearer_auth_middleware( self, mode_or_verify: Union[Literal["jwt"], VerifyAccessTokenFunction], - config: BaseBearerAuthConfig = BaseBearerAuthConfig(), - jwt_options: dict[str, Any] = {}, + audience: Optional[str] = None, + required_scopes: Optional[List[str]] = None, + show_error_details: bool = False, + leeway: float = 60, ) -> type[BaseHTTPMiddleware]: """ Creates a middleware that handles bearer token authentication. :param mode_or_verify: If "jwt", uses built-in JWT verification; or a custom function that takes a string token and returns an `AuthInfo` object. - :param config: Configuration for the Bearer auth handler, including audience, required - scopes, etc. - :param jwt_options: Optional dictionary of additional options for JWT verification - (`jwt.decode`). Not used if a custom function is provided. + :param audience: Optional audience to verify against the token. + :param required_scopes: Optional list of scopes that the token must contain. + :param show_error_details: Whether to include detailed error information in the response. + Defaults to `False`. + :param leeway: Optional leeway in seconds for JWT verification (`jwt.decode`). Defaults to + `60`. Not used if a custom function is provided. :return: A middleware class that can be used in a Starlette or FastAPI application. """ - metadata = self.config.server.metadata + metadata = self.server.metadata if isinstance(mode_or_verify, str) and mode_or_verify == "jwt": from .utils import create_verify_jwt @@ -82,7 +88,7 @@ def bearer_auth_middleware( verify = create_verify_jwt( metadata.jwks_uri, - options=jwt_options, + leeway=leeway, ) elif callable(mode_or_verify): verify = mode_or_verify @@ -94,5 +100,11 @@ def bearer_auth_middleware( from .middleware.create_bearer_auth import create_bearer_auth return create_bearer_auth( - verify, BearerAuthConfig(issuer=metadata.issuer, **config.model_dump()) + verify, + BearerAuthConfig( + issuer=metadata.issuer, + audience=audience, + required_scopes=required_scopes, + show_error_details=show_error_details, + ), ) diff --git a/mcpauth/config.py b/mcpauth/config.py index 17bdead..a0c2e9b 100644 --- a/mcpauth/config.py +++ b/mcpauth/config.py @@ -1,13 +1,148 @@ -from .models.auth_server import AuthServerConfig +from enum import Enum +from typing import List, Optional from pydantic import BaseModel +from pydantic import BaseModel + + +class AuthorizationServerMetadata(BaseModel): + """ + Pydantic model for OAuth 2.0 Authorization Server Metadata as defined in RFC 8414. + """ + + issuer: str + """ + The authorization server's issuer identifier, which is a URL that uses the `https` scheme and + has no query or fragment components. + """ + authorization_endpoint: str + """ + URL of the authorization server's authorization endpoint [[RFC6749](https://rfc-editor.org/rfc/rfc6749)]. + This is REQUIRED unless no grant types are supported that use the authorization endpoint. + + See: https://rfc-editor.org/rfc/rfc6749#section-3.1 + """ + + token_endpoint: str + """ + URL of the authorization server's token endpoint [[RFC6749](https://rfc-editor.org/rfc/rfc6749)]. + This is REQUIRED unless only the implicit grant type is supported. + + See: https://rfc-editor.org/rfc/rfc6749#section-3.2 + """ -class MCPAuthConfig(BaseModel): + jwks_uri: Optional[str] = None """ - Configuration for the `MCPAuth` class. + URL of the authorization server's JWK Set [[JWK](https://www.rfc-editor.org/rfc/rfc8414.html#ref-JWK)] document. + The referenced document contains the signing key(s) the client uses to validate signatures + from the authorization server. This URL MUST use the `https` scheme. """ - server: AuthServerConfig + registration_endpoint: Optional[str] = None + """ + URL of the authorization server's OAuth 2.0 Dynamic Client Registration endpoint + [[RFC7591](https://www.rfc-editor.org/rfc/rfc7591)]. + """ + + scope_supported: Optional[List[str]] = None + + response_types_supported: List[str] + """ + JSON array containing a list of the OAuth 2.0 `response_type` values that this authorization + server supports. The array values used are the same as those used with the `response_types` + parameter defined by "OAuth 2.0 Dynamic Client Registration Protocol" [[RFC7591](https://www.rfc-editor.org/rfc/rfc7591)]. + """ + + response_modes_supported: Optional[List[str]] = None + """ + JSON array containing a list of the OAuth 2.0 `response_mode` values that this + authorization server supports, as specified in "OAuth 2.0 Multiple Response Type Encoding Practices" + [[OAuth.Responses](https://datatracker.ietf.org/doc/html/rfc8414#ref-OAuth.Responses)]. + + If omitted, the default is ["query", "fragment"]. The response mode value `"form_post"` is + also defined in "OAuth 2.0 Form Post Response Mode" [[OAuth.FormPost](https://datatracker.ietf.org/doc/html/rfc8414#ref-OAuth.Post)]. + """ + + grant_types_supported: Optional[List[str]] = None + """ + JSON array containing a list of the OAuth 2.0 grant type values that this authorization server supports. + The array values used are the same as those used with the `grant_types` parameter defined by + "OAuth 2.0 Dynamic Client Registration Protocol" [[RFC7591](https://www.rfc-editor.org/rfc/rfc7591)]. + + If omitted, the default value is ["authorization_code", "implicit"]. + """ + + token_endpoint_auth_methods_supported: Optional[List[str]] = None + token_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None + service_documentation: Optional[str] = None + ui_locales_supported: Optional[List[str]] = None + op_policy_uri: Optional[str] = None + op_tos_uri: Optional[str] = None + + revocation_endpoint: Optional[str] = None """ - Config for the remote authorization server. + URL of the authorization server's OAuth 2.0 revocation endpoint [[RFC7009](https://www.rfc-editor.org/rfc/rfc7009)]. """ + + revocation_endpoint_auth_methods_supported: Optional[List[str]] = None + revocation_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None + + introspection_endpoint: Optional[str] = None + """ + URL of the authorization server's OAuth 2.0 introspection endpoint [[RFC7662](https://www.rfc-editor.org/rfc/rfc7662)]. + """ + + introspection_endpoint_auth_methods_supported: Optional[List[str]] = None + introspection_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None + + code_challenge_methods_supported: Optional[List[str]] = None + """ + JSON array containing a list of Proof Key for Code Exchange (PKCE) [[RFC7636](https://www.rfc-editor.org/rfc/rfc7636)] + code challenge methods supported by this authorization server. + """ + + +class AuthServerType(str, Enum): + """ + The type of the authorization server. This information should be provided by the server + configuration and indicates whether the server is an OAuth 2.0 or OpenID Connect (OIDC) + authorization server. + """ + + OAUTH = "oauth" + OIDC = "oidc" + + +class AuthServerConfig(BaseModel): + """ + Configuration for the remote authorization server integrated with the MCP server. + """ + + metadata: AuthorizationServerMetadata + """ + The metadata of the authorization server, which should conform to the MCP specification + (based on OAuth 2.0 Authorization Server Metadata). + + This metadata is typically fetched from the server's well-known endpoint (OAuth 2.0 + Authorization Server Metadata or OpenID Connect Discovery); it can also be provided + directly in the configuration if the server does not support such endpoints. + + See: + - OAuth 2.0 Authorization Server Metadata: https://datatracker.ietf.org/doc/html/rfc8414 + - OpenID Connect Discovery: https://openid.net/specs/openid-connect-discovery-1_0.html + """ + + type: AuthServerType + """ + The type of the authorization server. See `AuthServerType` for possible values. + """ + + +class ServerMetadataPaths(str, Enum): + """ + Enum for server metadata paths. + This is used to define the standard paths for OAuth and OIDC well-known URLs. + """ + + OAUTH = "/.well-known/oauth-authorization-server" + OIDC = "/.well-known/openid-configuration" diff --git a/mcpauth/middleware/create_bearer_auth.py b/mcpauth/middleware/create_bearer_auth.py index 6c4e575..639d6cf 100644 --- a/mcpauth/middleware/create_bearer_auth.py +++ b/mcpauth/middleware/create_bearer_auth.py @@ -18,9 +18,14 @@ from ..types import VerifyAccessTokenFunction, Record -class BaseBearerAuthConfig(BaseModel): +class BearerAuthConfig(BaseModel): """ - Base configuration for the Bearer auth handler. + Configuration for the Bearer auth handler. + """ + + issuer: str + """ + The expected issuer of the access token. This should be a valid URL. """ audience: Optional[str] = None @@ -42,17 +47,6 @@ class BaseBearerAuthConfig(BaseModel): """ -class BearerAuthConfig(BaseBearerAuthConfig): - """ - Configuration for the Bearer auth handler. - """ - - issuer: str - """ - The expected issuer of the access token. This should be a valid URL. - """ - - def get_bearer_token_from_headers(headers: Headers) -> str: """ Extract the Bearer token from the request headers. diff --git a/mcpauth/models/__init__.py b/mcpauth/models/__init__.py deleted file mode 100644 index f44c556..0000000 --- a/mcpauth/models/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .auth_server import ( - AuthServerConfig as AuthServerConfig, - AuthServerType as AuthServerType, -) -from .oauth import ( - AuthorizationServerMetadata as AuthorizationServerMetadata, -) diff --git a/mcpauth/models/auth_server.py b/mcpauth/models/auth_server.py deleted file mode 100644 index 4dea47a..0000000 --- a/mcpauth/models/auth_server.py +++ /dev/null @@ -1,39 +0,0 @@ -from enum import Enum -from pydantic import BaseModel -from .oauth import AuthorizationServerMetadata - - -class AuthServerType(str, Enum): - """ - The type of the authorization server. This information should be provided by the server - configuration and indicates whether the server is an OAuth 2.0 or OpenID Connect (OIDC) - authorization server. - """ - - OAUTH = "oauth" - OIDC = "oidc" - - -class AuthServerConfig(BaseModel): - """ - Configuration for the remote authorization server integrated with the MCP server. - """ - - metadata: AuthorizationServerMetadata - """ - The metadata of the authorization server, which should conform to the MCP specification - (based on OAuth 2.0 Authorization Server Metadata). - - This metadata is typically fetched from the server's well-known endpoint (OAuth 2.0 - Authorization Server Metadata or OpenID Connect Discovery); it can also be provided - directly in the configuration if the server does not support such endpoints. - - See: - - OAuth 2.0 Authorization Server Metadata: https://datatracker.ietf.org/doc/html/rfc8414 - - OpenID Connect Discovery: https://openid.net/specs/openid-connect-discovery-1_0.html - """ - - type: AuthServerType - """ - The type of the authorization server. See `AuthServerType` for possible values. - """ diff --git a/mcpauth/models/oauth.py b/mcpauth/models/oauth.py deleted file mode 100644 index e4b9420..0000000 --- a/mcpauth/models/oauth.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import List, Optional -from pydantic import BaseModel - - -class AuthorizationServerMetadata(BaseModel): - """ - Pydantic model for OAuth 2.0 Authorization Server Metadata as defined in RFC 8414. - """ - - issuer: str - """ - The authorization server's issuer identifier, which is a URL that uses the `https` scheme and - has no query or fragment components. - """ - - authorization_endpoint: str - """ - URL of the authorization server's authorization endpoint [[RFC6749](https://rfc-editor.org/rfc/rfc6749)]. - This is REQUIRED unless no grant types are supported that use the authorization endpoint. - - See: https://rfc-editor.org/rfc/rfc6749#section-3.1 - """ - - token_endpoint: str - """ - URL of the authorization server's token endpoint [[RFC6749](https://rfc-editor.org/rfc/rfc6749)]. - This is REQUIRED unless only the implicit grant type is supported. - - See: https://rfc-editor.org/rfc/rfc6749#section-3.2 - """ - - jwks_uri: Optional[str] = None - """ - URL of the authorization server's JWK Set [[JWK](https://www.rfc-editor.org/rfc/rfc8414.html#ref-JWK)] document. - The referenced document contains the signing key(s) the client uses to validate signatures - from the authorization server. This URL MUST use the `https` scheme. - """ - - registration_endpoint: Optional[str] = None - """ - URL of the authorization server's OAuth 2.0 Dynamic Client Registration endpoint - [[RFC7591](https://www.rfc-editor.org/rfc/rfc7591)]. - """ - - scope_supported: Optional[List[str]] = None - - response_types_supported: List[str] - """ - JSON array containing a list of the OAuth 2.0 `response_type` values that this authorization - server supports. The array values used are the same as those used with the `response_types` - parameter defined by "OAuth 2.0 Dynamic Client Registration Protocol" [[RFC7591](https://www.rfc-editor.org/rfc/rfc7591)]. - """ - - response_modes_supported: Optional[List[str]] = None - """ - JSON array containing a list of the OAuth 2.0 `response_mode` values that this - authorization server supports, as specified in "OAuth 2.0 Multiple Response Type Encoding Practices" - [[OAuth.Responses](https://datatracker.ietf.org/doc/html/rfc8414#ref-OAuth.Responses)]. - - If omitted, the default is ["query", "fragment"]. The response mode value `"form_post"` is - also defined in "OAuth 2.0 Form Post Response Mode" [[OAuth.FormPost](https://datatracker.ietf.org/doc/html/rfc8414#ref-OAuth.Post)]. - """ - - grant_types_supported: Optional[List[str]] = None - """ - JSON array containing a list of the OAuth 2.0 grant type values that this authorization server supports. - The array values used are the same as those used with the `grant_types` parameter defined by - "OAuth 2.0 Dynamic Client Registration Protocol" [[RFC7591](https://www.rfc-editor.org/rfc/rfc7591)]. - - If omitted, the default value is ["authorization_code", "implicit"]. - """ - - token_endpoint_auth_methods_supported: Optional[List[str]] = None - token_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None - service_documentation: Optional[str] = None - ui_locales_supported: Optional[List[str]] = None - op_policy_uri: Optional[str] = None - op_tos_uri: Optional[str] = None - - revocation_endpoint: Optional[str] = None - """ - URL of the authorization server's OAuth 2.0 revocation endpoint [[RFC7009](https://www.rfc-editor.org/rfc/rfc7009)]. - """ - - revocation_endpoint_auth_methods_supported: Optional[List[str]] = None - revocation_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None - - introspection_endpoint: Optional[str] = None - """ - URL of the authorization server's OAuth 2.0 introspection endpoint [[RFC7662](https://www.rfc-editor.org/rfc/rfc7662)]. - """ - - introspection_endpoint_auth_methods_supported: Optional[List[str]] = None - introspection_endpoint_auth_signing_alg_values_supported: Optional[List[str]] = None - - code_challenge_methods_supported: Optional[List[str]] = None - """ - JSON array containing a list of Proof Key for Code Exchange (PKCE) [[RFC7636](https://www.rfc-editor.org/rfc/rfc7636)] - code challenge methods supported by this authorization server. - """ diff --git a/mcpauth/types.py b/mcpauth/types.py index cf275c8..e88bc3c 100644 --- a/mcpauth/types.py +++ b/mcpauth/types.py @@ -86,7 +86,7 @@ class VerifyAccessTokenFunction(Protocol): """ Function type for verifying an access token. - This function should throw an `MCPAuthJwtVerificationError` if the token is invalid, or return an + This function should throw an `MCPAuthJwtVerificationException` if the token is invalid, or return an `AuthInfo` instance if the token is valid. For example, if you have a JWT verification function, it should at least check the token's diff --git a/mcpauth/utils/__init__.py b/mcpauth/utils/__init__.py index fd1d2c0..511e0b8 100644 --- a/mcpauth/utils/__init__.py +++ b/mcpauth/utils/__init__.py @@ -2,7 +2,6 @@ from ._fetch_server_config import ( fetch_server_config as fetch_server_config, fetch_server_config_by_well_known_url as fetch_server_config_by_well_known_url, - ServerMetadataPaths as ServerMetadataPaths, ) from ._validate_server_config import ( validate_server_config as validate_server_config, diff --git a/mcpauth/utils/_create_verify_jwt.py b/mcpauth/utils/_create_verify_jwt.py index 37d7cc0..cd6e2e8 100644 --- a/mcpauth/utils/_create_verify_jwt.py +++ b/mcpauth/utils/_create_verify_jwt.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, List, Optional, Union +from typing import Annotated, List, Optional, Union from jwt import PyJWK, PyJWKClient, PyJWTError, decode from pydantic import BaseModel, StringConstraints, ValidationError from ..types import AuthInfo, VerifyAccessTokenFunction @@ -23,8 +23,7 @@ class JwtBaseModel(BaseModel): def create_verify_jwt( input: Union[str, PyJWKClient, PyJWK], algorithms: List[str] = ["RS256", "PS256", "ES256", "ES384", "ES512"], - leeway: int = 60, - options: dict[str, Any] = {}, + leeway: float = 60, ) -> VerifyAccessTokenFunction: """ Creates a JWT verification function using the provided JWKS URI. @@ -35,7 +34,6 @@ def create_verify_jwt( - An instance of `PyJWK` that represents a single JWK. :param algorithms: A list of acceptable algorithms for verifying the JWT signature. :param leeway: The amount of leeway (in seconds) to allow when checking the expiration time of the JWT. - :param options: Additional options to pass to the JWT decode function (`jwt.decode`). :return: A function that can be used to verify JWTs. """ @@ -66,8 +64,7 @@ def verify_jwt(token: str) -> AuthInfo: options={ "verify_aud": False, "verify_iss": False, - } - | options, + }, ) base_model = JwtBaseModel(**decoded) scopes = base_model.scope or base_model.scopes diff --git a/mcpauth/utils/_fetch_server_config.py b/mcpauth/utils/_fetch_server_config.py index 42c4d87..b186312 100644 --- a/mcpauth/utils/_fetch_server_config.py +++ b/mcpauth/utils/_fetch_server_config.py @@ -1,4 +1,3 @@ -from enum import Enum from typing import Callable, Optional from urllib.parse import urlparse, urlunparse import requests @@ -6,8 +5,12 @@ from pathlib import Path from ..types import Record -from ..models.oauth import AuthorizationServerMetadata -from ..models.auth_server import AuthServerConfig, AuthServerType +from ..config import ( + AuthServerConfig, + AuthServerType, + ServerMetadataPaths, + AuthorizationServerMetadata, +) from ..exceptions import ( AuthServerExceptionCode, MCPAuthAuthServerException, @@ -15,17 +18,7 @@ ) -class ServerMetadataPaths(str, Enum): - """ - Enum for server metadata paths. - This is used to define the standard paths for OAuth and OIDC well-known URLs. - """ - - OAUTH = "/.well-known/oauth-authorization-server" - OIDC = "/.well-known/openid-configuration" - - -def smart_join(*args: str) -> str: +def _smart_join(*args: str) -> str: """ Joins multiple path components into a single path string, regardless of leading or trailing slashes. @@ -36,13 +29,13 @@ def smart_join(*args: str) -> str: def get_oauth_well_known_url(issuer: str) -> str: parsed_url = urlparse(issuer) - new_path = smart_join(ServerMetadataPaths.OAUTH.value, parsed_url.path) + new_path = _smart_join(ServerMetadataPaths.OAUTH.value, parsed_url.path) return urlunparse(parsed_url._replace(path=new_path)) def get_oidc_well_known_url(issuer: str) -> str: parsed = urlparse(issuer) - new_path = smart_join(parsed.path, ServerMetadataPaths.OIDC.value) + new_path = _smart_join(parsed.path, ServerMetadataPaths.OIDC.value) return urlunparse(parsed._replace(path=new_path)) diff --git a/mcpauth/utils/_validate_server_config.py b/mcpauth/utils/_validate_server_config.py index 81bb542..da96d42 100644 --- a/mcpauth/utils/_validate_server_config.py +++ b/mcpauth/utils/_validate_server_config.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Dict, List, Optional from pydantic import BaseModel, ValidationError -from mcpauth.models.auth_server import AuthServerConfig +from ..config import AuthServerConfig class AuthServerConfigErrorCode(str, Enum): diff --git a/samples/server/starlette.py b/samples/server/starlette.py index 340b613..ea34bf1 100644 --- a/samples/server/starlette.py +++ b/samples/server/starlette.py @@ -1,33 +1,33 @@ from mcpauth import MCPAuth -from mcpauth.config import MCPAuthConfig -from mcpauth.models import AuthServerType -from mcpauth.utils import fetch_server_config, ServerMetadataPaths +from mcpauth.config import AuthServerType, ServerMetadataPaths +from mcpauth.utils import fetch_server_config from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.responses import JSONResponse from starlette.requests import Request +from starlette.routing import Route import os MCP_AUTH_ISSUER = ( os.getenv("MCP_AUTH_ISSUER") or "https://replace-with-your-issuer-url.com" ) -mcpAuth = MCPAuth( - MCPAuthConfig(server=fetch_server_config(MCP_AUTH_ISSUER, AuthServerType.OIDC)) -) +mcp_auth = MCPAuth(server=fetch_server_config(MCP_AUTH_ISSUER, AuthServerType.OIDC)) -protected_app = Starlette( - middleware=[Middleware(mcpAuth.bearer_auth_middleware("jwt"))] -) +async def mcp_endpoint(request: Request): + return JSONResponse({"auth": request.state.auth}) -@protected_app.route("/") # type: ignore -async def secret_endpoint(_: Request): - return JSONResponse({"secret": True}) +protected_app = Starlette( + middleware=[ + Middleware(mcp_auth.bearer_auth_middleware("jwt", required_scopes=["read"])) + ], + routes=[Route("/", endpoint=mcp_endpoint)], +) app = Starlette( debug=True, ) -app.mount(ServerMetadataPaths.OAUTH.value, mcpAuth.metadata_response()) +app.mount(ServerMetadataPaths.OAUTH.value, mcp_auth.metadata_response()) app.mount("/mcp", protected_app) diff --git a/tests/__init__test.py b/tests/__init__test.py index 5edffd8..9aff35a 100644 --- a/tests/__init__test.py +++ b/tests/__init__test.py @@ -1,13 +1,7 @@ import pytest from unittest.mock import patch, MagicMock from mcpauth import MCPAuth, MCPAuthAuthServerException, AuthServerExceptionCode -from mcpauth.config import MCPAuthConfig -from mcpauth.models.auth_server import AuthServerConfig, AuthServerType -from mcpauth.models.oauth import AuthorizationServerMetadata -from mcpauth.middleware.create_bearer_auth import BaseBearerAuthConfig -from mcpauth.middleware.create_bearer_auth import BaseBearerAuthConfig -from mcpauth.middleware.create_bearer_auth import BaseBearerAuthConfig -from mcpauth.middleware.create_bearer_auth import BaseBearerAuthConfig +from mcpauth.config import AuthServerConfig, AuthServerType, AuthorizationServerMetadata class TestMCPAuth: @@ -24,13 +18,12 @@ def test_init_with_valid_config(self): code_challenge_methods_supported=["S256"], ), ) - config = MCPAuthConfig(server=server_config) # Exercise - auth = MCPAuth(config) + auth = MCPAuth(server=server_config) # Verify - assert auth.config == config + assert auth.server == server_config def test_init_with_invalid_config(self): # Setup @@ -43,11 +36,10 @@ def test_init_with_invalid_config(self): response_types_supported=["token"], # Invalid response type ), ) - config = MCPAuthConfig(server=server_config) # Exercise & Verify with pytest.raises(MCPAuthAuthServerException) as exc_info: - MCPAuth(config) + MCPAuth(server=server_config) assert exc_info.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG @@ -66,10 +58,9 @@ def test_init_with_warnings(self, mock_warning: MagicMock): # Missing registration_endpoint will cause a warning ), ) - config = MCPAuthConfig(server=server_config) # Exercise - MCPAuth(config) + MCPAuth(server=server_config) # Verify assert mock_warning.called @@ -89,8 +80,7 @@ def test_metadata_response(self): code_challenge_methods_supported=["S256"], ), ) - config = MCPAuthConfig(server=server_config) - auth = MCPAuth(config) + auth = MCPAuth(server=server_config) # Exercise response = auth.metadata_response() @@ -116,20 +106,19 @@ def test_bearer_auth_middleware_jwt_mode(self): code_challenge_methods_supported=["S256"], ), ) - config = MCPAuthConfig(server=server_config) - auth = MCPAuth(config) + auth = MCPAuth(server=server_config) # Exercise with patch("mcpauth.utils.create_verify_jwt") as mock_create_verify_jwt: mock_create_verify_jwt.return_value = MagicMock() middleware_class = auth.bearer_auth_middleware( - "jwt", BaseBearerAuthConfig(required_scopes=["profile"]) + "jwt", required_scopes=["profile"] ) # Verify assert middleware_class is not None mock_create_verify_jwt.assert_called_once_with( - "https://example.com/.well-known/jwks.json", options={} + "https://example.com/.well-known/jwks.json", leeway=60 ) def test_bearer_auth_middleware_custom_verify(self): @@ -145,8 +134,7 @@ def test_bearer_auth_middleware_custom_verify(self): code_challenge_methods_supported=["S256"], ), ) - config = MCPAuthConfig(server=server_config) - auth = MCPAuth(config) + auth = MCPAuth(server=server_config) custom_verify = MagicMock() @@ -155,7 +143,7 @@ def test_bearer_auth_middleware_custom_verify(self): "mcpauth.middleware.create_bearer_auth.create_bearer_auth" ) as mock_create_bearer_auth: middleware_class = auth.bearer_auth_middleware( - custom_verify, BaseBearerAuthConfig(required_scopes=["profile"]) + custom_verify, required_scopes=["profile"] ) # Verify @@ -179,14 +167,11 @@ def test_bearer_auth_middleware_jwt_without_jwks_uri(self): code_challenge_methods_supported=["S256"], ), ) - config = MCPAuthConfig(server=server_config) - auth = MCPAuth(config) + auth = MCPAuth(server=server_config) # Exercise & Verify with pytest.raises(MCPAuthAuthServerException) as exc_info: - auth.bearer_auth_middleware( - "jwt", BaseBearerAuthConfig(required_scopes=["profile"]) - ) + auth.bearer_auth_middleware("jwt", required_scopes=["profile"]) assert exc_info.value.code == AuthServerExceptionCode.MISSING_JWKS_URI @@ -203,14 +188,13 @@ def test_bearer_auth_middleware_invalid_mode(self): code_challenge_methods_supported=["S256"], ), ) - config = MCPAuthConfig(server=server_config) - auth = MCPAuth(config) + auth = MCPAuth(server=server_config) # Exercise & Verify with pytest.raises(ValueError) as exc_info: auth.bearer_auth_middleware( "invalid_mode", # type: ignore - BaseBearerAuthConfig(required_scopes=["profile"]), + required_scopes=["profile"], ) assert "mode_or_verify must be 'jwt' or a callable function" in str( diff --git a/tests/utils/fetch_server_config_test.py b/tests/utils/fetch_server_config_test.py index 8720db8..ccf4de2 100644 --- a/tests/utils/fetch_server_config_test.py +++ b/tests/utils/fetch_server_config_test.py @@ -1,11 +1,10 @@ import pytest import responses -from mcpauth.models.auth_server import AuthServerType +from mcpauth.config import AuthServerType, ServerMetadataPaths from mcpauth.exceptions import MCPAuthAuthServerException, MCPAuthConfigException from mcpauth.types import Record from mcpauth.utils import ( - ServerMetadataPaths, fetch_server_config, fetch_server_config_by_well_known_url, ) @@ -68,15 +67,12 @@ def test_fetch_server_config_by_well_known_url_success_with_transpile(self): "token_endpoint": "https://example.com/oauth/token", } - def transpile(data: Record) -> Record: - return {**data, "response_types_supported": ["code"]} - responses.add(responses.GET, url=sample_well_known_url, json=sample_response) config = fetch_server_config_by_well_known_url( sample_well_known_url, - AuthServerType.OAUTH, - transpile_data=transpile, + type=AuthServerType.OAUTH, + transpile_data=lambda data: {**data, "response_types_supported": ["code"]}, ) assert config.type == AuthServerType.OAUTH diff --git a/tests/utils/validate_server_config_test.py b/tests/utils/validate_server_config_test.py index df1e727..4631152 100644 --- a/tests/utils/validate_server_config_test.py +++ b/tests/utils/validate_server_config_test.py @@ -1,5 +1,4 @@ -from mcpauth.models.auth_server import AuthServerConfig, AuthServerType -from mcpauth.models.oauth import AuthorizationServerMetadata +from mcpauth.config import AuthServerConfig, AuthServerType, AuthorizationServerMetadata from mcpauth.utils import ( validate_server_config, AuthServerConfigErrorCode, From 2081c161964cc11af048730e448f61ce4907c6cb Mon Sep 17 00:00:00 2001 From: Gao Sun Date: Fri, 2 May 2025 12:19:22 -0700 Subject: [PATCH 2/2] refactor: remove duplicate import --- mcpauth/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mcpauth/config.py b/mcpauth/config.py index a0c2e9b..f50c249 100644 --- a/mcpauth/config.py +++ b/mcpauth/config.py @@ -1,7 +1,6 @@ from enum import Enum from typing import List, Optional from pydantic import BaseModel -from pydantic import BaseModel class AuthorizationServerMetadata(BaseModel):