From aa5a0709d974e6cd50b420018f45fa8fb2ebea7a Mon Sep 17 00:00:00 2001 From: Gao Sun Date: Wed, 30 Apr 2025 07:05:09 -0700 Subject: [PATCH] feat: create verify jwt --- mcpauth/exceptions.py | 2 - mcpauth/middleware/create_bearer_auth.py | 2 - mcpauth/types.py | 2 +- mcpauth/utils/create_verify_jwt.py | 95 +++++++++++ mcpauth/utils/create_verify_jwt_test.py | 204 +++++++++++++++++++++++ 5 files changed, 300 insertions(+), 5 deletions(-) create mode 100644 mcpauth/utils/create_verify_jwt.py create mode 100644 mcpauth/utils/create_verify_jwt_test.py diff --git a/mcpauth/exceptions.py b/mcpauth/exceptions.py index e949e65..c10ebb5 100644 --- a/mcpauth/exceptions.py +++ b/mcpauth/exceptions.py @@ -146,7 +146,6 @@ def to_json(self, show_cause: bool = False) -> Dict[str, Optional[str]]: class MCPAuthJwtVerificationExceptionCode(str, Enum): INVALID_JWT = "invalid_jwt" JWT_VERIFICATION_FAILED = "jwt_verification_failed" - JWT_EXPIRED = "jwt_expired" jwt_verification_exception_description: Dict[ @@ -154,7 +153,6 @@ class MCPAuthJwtVerificationExceptionCode(str, Enum): ] = { MCPAuthJwtVerificationExceptionCode.INVALID_JWT: "The provided JWT is invalid or malformed.", MCPAuthJwtVerificationExceptionCode.JWT_VERIFICATION_FAILED: "JWT verification failed. The token could not be verified.", - MCPAuthJwtVerificationExceptionCode.JWT_EXPIRED: "The provided JWT has expired.", } diff --git a/mcpauth/middleware/create_bearer_auth.py b/mcpauth/middleware/create_bearer_auth.py index 152a410..bbfa9b2 100644 --- a/mcpauth/middleware/create_bearer_auth.py +++ b/mcpauth/middleware/create_bearer_auth.py @@ -51,8 +51,6 @@ def get_bearer_token_from_headers(headers: Headers) -> str: auth_header = headers.get("authorization") or headers.get("Authorization") - print(f"Authorization header: {auth_header}") - if not auth_header: raise MCPAuthBearerAuthException(BearerAuthExceptionCode.MISSING_AUTH_HEADER) diff --git a/mcpauth/types.py b/mcpauth/types.py index e768b9c..cf275c8 100644 --- a/mcpauth/types.py +++ b/mcpauth/types.py @@ -75,7 +75,7 @@ class AuthInfo(BaseModel): - https://datatracker.ietf.org/doc/html/rfc8707 """ - claims: Optional[Dict[str, Any]] + claims: Dict[str, Any] """ The raw claims from the token, which can include any additional information provided by the token issuer. diff --git a/mcpauth/utils/create_verify_jwt.py b/mcpauth/utils/create_verify_jwt.py new file mode 100644 index 0000000..37d7cc0 --- /dev/null +++ b/mcpauth/utils/create_verify_jwt.py @@ -0,0 +1,95 @@ +from typing import Annotated, Any, List, Optional, Union +from jwt import PyJWK, PyJWKClient, PyJWTError, decode +from pydantic import BaseModel, StringConstraints, ValidationError +from ..types import AuthInfo, VerifyAccessTokenFunction +from ..exceptions import ( + MCPAuthJwtVerificationException, + MCPAuthJwtVerificationExceptionCode, +) + +NonEmptyString = Annotated[str, StringConstraints(min_length=1)] + + +class JwtBaseModel(BaseModel): + aud: Optional[Union[NonEmptyString, List[NonEmptyString]]] = None + iss: NonEmptyString + client_id: NonEmptyString + sub: NonEmptyString + scope: Optional[Union[str, List[str]]] = None + scopes: Optional[Union[str, List[str]]] = None + exp: Optional[int] = None + + +def create_verify_jwt( + input: Union[str, PyJWKClient, PyJWK], + algorithms: List[str] = ["RS256", "PS256", "ES256", "ES384", "ES512"], + leeway: int = 60, + options: dict[str, Any] = {}, +) -> VerifyAccessTokenFunction: + """ + Creates a JWT verification function using the provided JWKS URI. + + :param input: Supports one of the following: + - A JWKS URI (string) that points to a JSON Web Key Set. + - An instance of `PyJWKClient` that has been initialized with the JWKS URI. + - An instance of `PyJWK` that represents a single JWK. + :param algorithms: A list of acceptable algorithms for verifying the JWT signature. + :param leeway: The amount of leeway (in seconds) to allow when checking the expiration time of the JWT. + :param options: Additional options to pass to the JWT decode function (`jwt.decode`). + :return: A function that can be used to verify JWTs. + """ + + jwks = ( + input + if isinstance(input, PyJWKClient) + else ( + PyJWKClient( + input, headers={"user-agent": "@mcp-auth/python", "accept": "*/*"} + ) + if isinstance(input, str) + else input + ) + ) + + def verify_jwt(token: str) -> AuthInfo: + try: + signing_key = ( + jwks.get_signing_key_from_jwt(token) + if isinstance(jwks, PyJWKClient) + else jwks + ) + decoded = decode( + token, + signing_key.key, + algorithms=algorithms, + leeway=leeway, + options={ + "verify_aud": False, + "verify_iss": False, + } + | options, + ) + base_model = JwtBaseModel(**decoded) + scopes = base_model.scope or base_model.scopes + return AuthInfo( + token=token, + issuer=base_model.iss, + client_id=base_model.client_id, + subject=base_model.sub, + audience=base_model.aud, + scopes=(scopes.split(" ") if isinstance(scopes, str) else scopes) or [], + expires_at=base_model.exp, + claims=decoded, + ) + except (PyJWTError, ValidationError) as e: + raise MCPAuthJwtVerificationException( + MCPAuthJwtVerificationExceptionCode.INVALID_JWT, + cause=e, + ) + except Exception as e: + raise MCPAuthJwtVerificationException( + MCPAuthJwtVerificationExceptionCode.JWT_VERIFICATION_FAILED, + cause=e, + ) + + return verify_jwt diff --git a/mcpauth/utils/create_verify_jwt_test.py b/mcpauth/utils/create_verify_jwt_test.py new file mode 100644 index 0000000..c79bfa7 --- /dev/null +++ b/mcpauth/utils/create_verify_jwt_test.py @@ -0,0 +1,204 @@ +import pytest +import time +import jwt +import base64 +from typing import Dict, Any +from mcpauth.utils.create_verify_jwt import create_verify_jwt +from mcpauth.types import AuthInfo + + +from mcpauth.exceptions import ( + MCPAuthJwtVerificationException, + MCPAuthJwtVerificationExceptionCode, +) + +_secret_key = b"super-secret-key-for-testing" +_algorithm = "HS256" + + +def create_jwk(key: bytes = _secret_key) -> jwt.PyJWK: + """Create a JWK for testing purposes""" + return jwt.PyJWK( + { + "kty": "oct", + "k": base64.urlsafe_b64encode(key).decode("utf-8"), + "alg": _algorithm, + } + ) + + +def create_jwt(payload: Dict[str, Any]) -> str: + """Create a test JWT with the given payload""" + return jwt.encode( + { + **payload, + "iat": int(time.time()), + "exp": int(time.time()) + 3600, # 1 hour + }, + _secret_key, + algorithm=_algorithm, + ) + + +verify_jwt = create_verify_jwt(create_jwk(), algorithms=[_algorithm]) + + +class TestCreateVerifyJwtErrorHandling: + def test_should_throw_error_if_signature_verification_fails(self): + # Create JWT with correct secret + jwt_token = create_jwt({"client_id": "client12345", "sub": "user12345"}) + verify_jwt = create_verify_jwt( + create_jwk(b"wrong-secret-key-for-testing"), algorithms=[_algorithm] + ) + + # Verify that the correct exception is raised + with pytest.raises(MCPAuthJwtVerificationException) as exc_info: + verify_jwt(jwt_token) + + assert exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT + assert isinstance(exc_info.value.cause, jwt.InvalidSignatureError) + + def test_should_throw_error_if_jwt_payload_missing_iss(self): + # Test different invalid JWT payloads + jwt_missing_iss = create_jwt({"client_id": "client12345", "sub": "user12345"}) + jwt_invalid_iss_type = create_jwt( + {"iss": 12345, "client_id": "client12345", "sub": "user12345"} + ) + jwt_empty_iss = create_jwt( + {"iss": "", "client_id": "client12345", "sub": "user12345"} + ) + + for token in [jwt_missing_iss, jwt_invalid_iss_type, jwt_empty_iss]: + with pytest.raises(MCPAuthJwtVerificationException) as exc_info: + verify_jwt(token) + assert ( + exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT + ) + + def test_should_throw_error_if_jwt_payload_missing_client_id(self): + # Test different invalid JWT payloads + jwt_missing_client_id = create_jwt( + {"iss": "https://logto.io/", "sub": "user12345"} + ) + jwt_invalid_client_id_type = create_jwt( + {"iss": "https://logto.io/", "client_id": 12345, "sub": "user12345"} + ) + jwt_empty_client_id = create_jwt( + {"iss": "https://logto.io/", "client_id": "", "sub": "user12345"} + ) + + for token in [ + jwt_missing_client_id, + jwt_invalid_client_id_type, + jwt_empty_client_id, + ]: + with pytest.raises(MCPAuthJwtVerificationException) as exc_info: + verify_jwt(token) + assert ( + exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT + ) + + def test_should_throw_error_if_jwt_payload_missing_sub(self): + # Test different invalid JWT payloads + jwt_missing_sub = create_jwt( + {"iss": "https://logto.io/", "client_id": "client12345"} + ) + jwt_invalid_sub_type = create_jwt( + {"iss": "https://logto.io/", "client_id": "client12345", "sub": 12345} + ) + jwt_empty_sub = create_jwt( + {"iss": "https://logto.io/", "client_id": "client12345", "sub": ""} + ) + + for token in [jwt_missing_sub, jwt_invalid_sub_type, jwt_empty_sub]: + with pytest.raises(MCPAuthJwtVerificationException) as exc_info: + verify_jwt(token) + assert ( + exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT + ) + + +class TestCreateVerifyJwtNormalBehavior: + def test_should_return_verified_jwt_payload_with_string_scope(self): + # Create JWT with string scope + claims = { + "iss": "https://logto.io/", + "client_id": "client12345", + "sub": "user12345", + "scope": "read write", + "aud": "audience12345", + } + jwt_token = create_jwt(claims) + + # Verify + result = verify_jwt(jwt_token) + + # Assertions + assert isinstance(result, AuthInfo) + assert result.token == jwt_token + assert result.issuer == claims["iss"] + assert result.client_id == claims["client_id"] + assert result.subject == claims["sub"] + assert result.audience == claims["aud"] + assert result.scopes == ["read", "write"] + assert "exp" in result.claims + assert "iat" in result.claims + assert result.expires_at is not None + + def test_should_return_verified_jwt_payload_with_array_scope(self): + # Create JWT with array scope + claims: Dict[str, Any] = { + "iss": "https://logto.io/", + "client_id": "client12345", + "sub": "user12345", + "scope": ["read", "write"], + } + jwt_token = create_jwt(claims) + + # Verify + result = verify_jwt(jwt_token) + + # Assertions + assert result.issuer == claims["iss"] + assert result.client_id == claims["client_id"] + assert result.subject == claims["sub"] + assert result.scopes == ["read", "write"] + + def test_should_return_verified_jwt_payload_with_scopes_field(self): + # Create JWT with scopes field + claims: Dict[str, Any] = { + "iss": "https://logto.io/", + "client_id": "client12345", + "sub": "user12345", + "scopes": ["read", "write"], + } + jwt_token = create_jwt(claims) + + # Verify + result = verify_jwt(jwt_token) + + # Assertions + assert result.issuer == claims["iss"] + assert result.client_id == claims["client_id"] + assert result.subject == claims["sub"] + assert result.scopes == ["read", "write"] + + def test_should_return_verified_jwt_payload_without_scopes(self): + # Create JWT without scope or scopes + claims = { + "iss": "https://logto.io/", + "client_id": "client12345", + "sub": "user12345", + "aud": "audience12345", + } + jwt_token = create_jwt(claims) + + # Verify + result = verify_jwt(jwt_token) + + # Assertions + assert result.issuer == claims["iss"] + assert result.client_id == claims["client_id"] + assert result.subject == claims["sub"] + assert result.audience == claims["aud"] + assert result.scopes == []