Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
os: [ubuntu-latest]
runs-on: ${{ matrix.os }}

Expand Down
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.9
3.10
107 changes: 92 additions & 15 deletions mcpauth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from contextvars import ContextVar
import logging
from typing import List, Literal, Optional, Union
from typing import Any, Callable, List, Literal, Optional, Union

from .middleware.create_bearer_auth import BearerAuthConfig
from .types import VerifyAccessTokenFunction
from .config import AuthServerConfig
from .types import AuthInfo, VerifyAccessTokenFunction
from .config import AuthServerConfig, ServerMetadataPaths
from .exceptions import MCPAuthAuthServerException, AuthServerExceptionCode
from .utils import validate_server_config
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.responses import Response, JSONResponse
from starlette.requests import Request
from starlette.routing import Route

_context_var_name = "mcp_auth_context"


class MCPAuth:
Expand All @@ -18,9 +23,22 @@ class MCPAuth:
See Also: https://mcp-auth.dev for more information about the library and its usage.
"""

def __init__(self, server: AuthServerConfig):
server: AuthServerConfig
"""
The configuration for the remote authorization server.
"""

def __init__(
self,
server: AuthServerConfig,
context_var: ContextVar[Optional[AuthInfo]] = ContextVar(
_context_var_name, default=None
),
):
"""
:param server: Configuration for the remote authorization server.
:param context_var: Context variable to store the `AuthInfo` object for the current request.
By default, it will be created with the name "mcp_auth_context".
"""

result = validate_server_config(server)
Expand All @@ -40,20 +58,78 @@ def __init__(self, server: AuthServerConfig):
logging.warning(f"- {warning}")

self.server = server
self._context_var = context_var

@property
def auth_info(self) -> Optional[AuthInfo]:
"""
The current `AuthInfo` object from the context variable.

This is useful for accessing the authenticated user's information in later middleware or
route handlers.
:return: The current `AuthInfo` object, or `None` if not set.
"""

return self._context_var.get()

def metadata_response(self) -> JSONResponse:
def metadata_endpoint(self) -> Callable[[Request], Any]:
"""
Returns a response containing the server metadata in JSON format with CORS support.
Returns a Starlette endpoint function that handles the OAuth 2.0 Authorization Metadata
endpoint (`/.well-known/oauth-authorization-server`) with CORS support.

Example:
```python
from starlette.applications import Starlette
from mcpauth import MCPAuth
from mcpauth.config import ServerMetadataPaths

mcp_auth = MCPAuth(server=your_server_config)
app = Starlette(routes=[
Route(
ServerMetadataPaths.OAUTH.value,
mcp_auth.metadata_endpoint(),
methods=["GET", "OPTIONS"] # Ensure to handle both GET and OPTIONS methods
)
])
```
"""

async def endpoint(request: Request) -> Response:
if request.method == "OPTIONS":
response = Response(status_code=204)
else:
server_config = self.server
response = JSONResponse(
server_config.metadata.model_dump(exclude_none=True),
status_code=200,
)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
response.headers["Access-Control-Allow-Headers"] = "*"
return response

return endpoint

def metadata_route(self) -> Route:
"""
Returns a Starlette route that handles the OAuth 2.0 Authorization Metadata endpoint
(`/.well-known/oauth-authorization-server`) with CORS support.

Example:
```python
from starlette.applications import Starlette
from mcpauth import MCPAuth

mcp_auth = MCPAuth(server=your_server_config)
app = Starlette(routes=[mcp_auth.metadata_route()])
```
"""
server_config = self.server

response = JSONResponse(
server_config.metadata.model_dump(exclude_none=True),
status_code=200,
return Route(
ServerMetadataPaths.OAUTH.value,
self.metadata_endpoint(),
methods=["GET", "OPTIONS"],
)
response.headers["Access-Control-Allow-Origin"] = "*"
response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS"
return response

def bearer_auth_middleware(
self,
Expand Down Expand Up @@ -101,10 +177,11 @@ def bearer_auth_middleware(

return create_bearer_auth(
verify,
BearerAuthConfig(
config=BearerAuthConfig(
issuer=metadata.issuer,
audience=audience,
required_scopes=required_scopes,
show_error_details=show_error_details,
),
context_var=self._context_var,
)
5 changes: 5 additions & 0 deletions mcpauth/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class AuthorizationServerMetadata(BaseModel):
code challenge methods supported by this authorization server.
"""

userinfo_endpoint: Optional[str] = None
"""
URL of the authorization server's UserInfo endpoint [[OpenID Connect](https://openid.net/specs/openid-connect-core-1_0.html#UserInfo)].
"""


class AuthServerType(str, Enum):
"""
Expand Down
24 changes: 12 additions & 12 deletions mcpauth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,31 +143,31 @@ def to_json(self, show_cause: bool = False) -> Dict[str, Optional[str]]:
return {k: v for k, v in data.items() if v is not None}


class MCPAuthJwtVerificationExceptionCode(str, Enum):
INVALID_JWT = "invalid_jwt"
JWT_VERIFICATION_FAILED = "jwt_verification_failed"
class MCPAuthTokenVerificationExceptionCode(str, Enum):
INVALID_TOKEN = "invalid_token"
TOKEN_VERIFICATION_FAILED = "token_verification_failed"


jwt_verification_exception_description: Dict[
MCPAuthJwtVerificationExceptionCode, str
token_verification_exception_description: Dict[
MCPAuthTokenVerificationExceptionCode, str
] = {
MCPAuthJwtVerificationExceptionCode.INVALID_JWT: "The provided JWT is invalid or malformed.",
MCPAuthJwtVerificationExceptionCode.JWT_VERIFICATION_FAILED: "JWT verification failed. The token could not be verified.",
MCPAuthTokenVerificationExceptionCode.INVALID_TOKEN: "The provided token is invalid or malformed.",
MCPAuthTokenVerificationExceptionCode.TOKEN_VERIFICATION_FAILED: "The token verification failed due to an error in the verification process.",
}


class MCPAuthJwtVerificationException(MCPAuthException):
class MCPAuthTokenVerificationException(MCPAuthException):
"""
Exception thrown when there is an issue when verifying JWT tokens.
Exception thrown when there is an issue when verifying access tokens.
"""

def __init__(
self, code: MCPAuthJwtVerificationExceptionCode, cause: ExceptionCause = None
self, code: MCPAuthTokenVerificationExceptionCode, cause: ExceptionCause = None
):
super().__init__(
code.value,
jwt_verification_exception_description.get(
code, "An exception occurred while verifying the JWT."
token_verification_exception_description.get(
code, "An exception occurred while verifying the token."
),
)
self.code = code
Expand Down
29 changes: 18 additions & 11 deletions mcpauth/middleware/create_bearer_auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextvars import ContextVar
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import logging
Expand All @@ -9,13 +10,13 @@

from ..exceptions import (
MCPAuthBearerAuthException,
MCPAuthJwtVerificationException,
MCPAuthTokenVerificationException,
MCPAuthAuthServerException,
MCPAuthConfigException,
BearerAuthExceptionCode,
MCPAuthBearerAuthExceptionDetails,
)
from ..types import VerifyAccessTokenFunction, Record
from ..types import AuthInfo, VerifyAccessTokenFunction, Record


class BearerAuthConfig(BaseModel):
Expand Down Expand Up @@ -92,7 +93,7 @@ def _handle_error(
Returns:
A tuple of (status_code, response_body).
"""
if isinstance(error, MCPAuthJwtVerificationException):
if isinstance(error, MCPAuthTokenVerificationException):
return 401, error.to_json(show_error_details)

if isinstance(error, MCPAuthBearerAuthException):
Expand All @@ -114,20 +115,22 @@ def _handle_error(


def create_bearer_auth(
verify_access_token: VerifyAccessTokenFunction, config: BearerAuthConfig
verify_access_token: VerifyAccessTokenFunction,
config: BearerAuthConfig,
context_var: ContextVar[Optional[AuthInfo]],
) -> type[BaseHTTPMiddleware]:
"""
Creates a middleware function for handling Bearer auth.

This middleware extracts the Bearer token from the `Authorization` header, verifies it using the
provided `verify_access_token` function, and checks the issuer, audience, and required scopes.

Args:
verify_access_token: A function that takes a Bearer token and returns an `AuthInfo` object.
config: Configuration for the Bearer auth handler.
:param verify_access_token: A function that takes a Bearer token and returns an `AuthInfo` object.
:param config: Configuration for the Bearer auth handler.
:param context_var: Context variable to store the `AuthInfo` object for the current request.
This allows access to the authenticated user's information in later middleware or route handlers.

Returns:
A middleware class that handles Bearer auth.
:return: A middleware class that handles Bearer auth.
"""

if not callable(verify_access_token):
Expand Down Expand Up @@ -206,8 +209,12 @@ async def dispatch(
cause=details,
)

# Attach auth info to the request
request.state.auth = auth_info
if context_var.get() is not None:
logging.warning(
"Overwriting existing auth info in context variable."
)

context_var.set(auth_info)

# Call the next middleware or route handler
response = await call_next(request)
Expand Down
Loading