Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions mcpauth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,13 @@ def to_json(self, show_cause: bool = False) -> Dict[str, Optional[str]]:
class MCPAuthJwtVerificationExceptionCode(str, Enum):
INVALID_JWT = "invalid_jwt"
JWT_VERIFICATION_FAILED = "jwt_verification_failed"
JWT_EXPIRED = "jwt_expired"


jwt_verification_exception_description: Dict[
MCPAuthJwtVerificationExceptionCode, str
] = {
MCPAuthJwtVerificationExceptionCode.INVALID_JWT: "The provided JWT is invalid or malformed.",
MCPAuthJwtVerificationExceptionCode.JWT_VERIFICATION_FAILED: "JWT verification failed. The token could not be verified.",
MCPAuthJwtVerificationExceptionCode.JWT_EXPIRED: "The provided JWT has expired.",
}


Expand Down
2 changes: 0 additions & 2 deletions mcpauth/middleware/create_bearer_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ def get_bearer_token_from_headers(headers: Headers) -> str:

auth_header = headers.get("authorization") or headers.get("Authorization")

print(f"Authorization header: {auth_header}")

if not auth_header:
raise MCPAuthBearerAuthException(BearerAuthExceptionCode.MISSING_AUTH_HEADER)

Expand Down
2 changes: 1 addition & 1 deletion mcpauth/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class AuthInfo(BaseModel):
- https://datatracker.ietf.org/doc/html/rfc8707
"""

claims: Optional[Dict[str, Any]]
claims: Dict[str, Any]
"""
The raw claims from the token, which can include any additional information provided by the
token issuer.
Expand Down
95 changes: 95 additions & 0 deletions mcpauth/utils/create_verify_jwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Annotated, Any, List, Optional, Union
from jwt import PyJWK, PyJWKClient, PyJWTError, decode
from pydantic import BaseModel, StringConstraints, ValidationError
from ..types import AuthInfo, VerifyAccessTokenFunction
from ..exceptions import (
MCPAuthJwtVerificationException,
MCPAuthJwtVerificationExceptionCode,
)

NonEmptyString = Annotated[str, StringConstraints(min_length=1)]


class JwtBaseModel(BaseModel):
aud: Optional[Union[NonEmptyString, List[NonEmptyString]]] = None
iss: NonEmptyString
client_id: NonEmptyString
sub: NonEmptyString
scope: Optional[Union[str, List[str]]] = None
scopes: Optional[Union[str, List[str]]] = None
exp: Optional[int] = None


def create_verify_jwt(
input: Union[str, PyJWKClient, PyJWK],
algorithms: List[str] = ["RS256", "PS256", "ES256", "ES384", "ES512"],
leeway: int = 60,
options: dict[str, Any] = {},
) -> VerifyAccessTokenFunction:
"""
Creates a JWT verification function using the provided JWKS URI.

:param input: Supports one of the following:
- A JWKS URI (string) that points to a JSON Web Key Set.
- An instance of `PyJWKClient` that has been initialized with the JWKS URI.
- An instance of `PyJWK` that represents a single JWK.
:param algorithms: A list of acceptable algorithms for verifying the JWT signature.
:param leeway: The amount of leeway (in seconds) to allow when checking the expiration time of the JWT.
:param options: Additional options to pass to the JWT decode function (`jwt.decode`).
:return: A function that can be used to verify JWTs.
"""

jwks = (
input
if isinstance(input, PyJWKClient)
else (
PyJWKClient(
input, headers={"user-agent": "@mcp-auth/python", "accept": "*/*"}
)
if isinstance(input, str)
else input
)
)

def verify_jwt(token: str) -> AuthInfo:
try:
signing_key = (
jwks.get_signing_key_from_jwt(token)
if isinstance(jwks, PyJWKClient)
else jwks
)
decoded = decode(
token,
signing_key.key,
algorithms=algorithms,
leeway=leeway,
options={
"verify_aud": False,
"verify_iss": False,
}
| options,
)
base_model = JwtBaseModel(**decoded)
scopes = base_model.scope or base_model.scopes
return AuthInfo(
token=token,
issuer=base_model.iss,
client_id=base_model.client_id,
subject=base_model.sub,
audience=base_model.aud,
scopes=(scopes.split(" ") if isinstance(scopes, str) else scopes) or [],
expires_at=base_model.exp,
claims=decoded,
)
except (PyJWTError, ValidationError) as e:
raise MCPAuthJwtVerificationException(
MCPAuthJwtVerificationExceptionCode.INVALID_JWT,
cause=e,
)
except Exception as e:
raise MCPAuthJwtVerificationException(
MCPAuthJwtVerificationExceptionCode.JWT_VERIFICATION_FAILED,
cause=e,
)

return verify_jwt
204 changes: 204 additions & 0 deletions mcpauth/utils/create_verify_jwt_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import pytest
import time
import jwt
import base64
from typing import Dict, Any
from mcpauth.utils.create_verify_jwt import create_verify_jwt
from mcpauth.types import AuthInfo


from mcpauth.exceptions import (
MCPAuthJwtVerificationException,
MCPAuthJwtVerificationExceptionCode,
)

_secret_key = b"super-secret-key-for-testing"
_algorithm = "HS256"


def create_jwk(key: bytes = _secret_key) -> jwt.PyJWK:
"""Create a JWK for testing purposes"""
return jwt.PyJWK(
{
"kty": "oct",
"k": base64.urlsafe_b64encode(key).decode("utf-8"),
"alg": _algorithm,
}
)


def create_jwt(payload: Dict[str, Any]) -> str:
"""Create a test JWT with the given payload"""
return jwt.encode(
{
**payload,
"iat": int(time.time()),
"exp": int(time.time()) + 3600, # 1 hour
},
_secret_key,
algorithm=_algorithm,
)


verify_jwt = create_verify_jwt(create_jwk(), algorithms=[_algorithm])


class TestCreateVerifyJwtErrorHandling:
def test_should_throw_error_if_signature_verification_fails(self):
# Create JWT with correct secret
jwt_token = create_jwt({"client_id": "client12345", "sub": "user12345"})
verify_jwt = create_verify_jwt(
create_jwk(b"wrong-secret-key-for-testing"), algorithms=[_algorithm]
)

# Verify that the correct exception is raised
with pytest.raises(MCPAuthJwtVerificationException) as exc_info:
verify_jwt(jwt_token)

assert exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT
assert isinstance(exc_info.value.cause, jwt.InvalidSignatureError)

def test_should_throw_error_if_jwt_payload_missing_iss(self):
# Test different invalid JWT payloads
jwt_missing_iss = create_jwt({"client_id": "client12345", "sub": "user12345"})
jwt_invalid_iss_type = create_jwt(
{"iss": 12345, "client_id": "client12345", "sub": "user12345"}
)
jwt_empty_iss = create_jwt(
{"iss": "", "client_id": "client12345", "sub": "user12345"}
)

for token in [jwt_missing_iss, jwt_invalid_iss_type, jwt_empty_iss]:
with pytest.raises(MCPAuthJwtVerificationException) as exc_info:
verify_jwt(token)
assert (
exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT
)

def test_should_throw_error_if_jwt_payload_missing_client_id(self):
# Test different invalid JWT payloads
jwt_missing_client_id = create_jwt(
{"iss": "https://logto.io/", "sub": "user12345"}
)
jwt_invalid_client_id_type = create_jwt(
{"iss": "https://logto.io/", "client_id": 12345, "sub": "user12345"}
)
jwt_empty_client_id = create_jwt(
{"iss": "https://logto.io/", "client_id": "", "sub": "user12345"}
)

for token in [
jwt_missing_client_id,
jwt_invalid_client_id_type,
jwt_empty_client_id,
]:
with pytest.raises(MCPAuthJwtVerificationException) as exc_info:
verify_jwt(token)
assert (
exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT
)

def test_should_throw_error_if_jwt_payload_missing_sub(self):
# Test different invalid JWT payloads
jwt_missing_sub = create_jwt(
{"iss": "https://logto.io/", "client_id": "client12345"}
)
jwt_invalid_sub_type = create_jwt(
{"iss": "https://logto.io/", "client_id": "client12345", "sub": 12345}
)
jwt_empty_sub = create_jwt(
{"iss": "https://logto.io/", "client_id": "client12345", "sub": ""}
)

for token in [jwt_missing_sub, jwt_invalid_sub_type, jwt_empty_sub]:
with pytest.raises(MCPAuthJwtVerificationException) as exc_info:
verify_jwt(token)
assert (
exc_info.value.code == MCPAuthJwtVerificationExceptionCode.INVALID_JWT
)


class TestCreateVerifyJwtNormalBehavior:
def test_should_return_verified_jwt_payload_with_string_scope(self):
# Create JWT with string scope
claims = {
"iss": "https://logto.io/",
"client_id": "client12345",
"sub": "user12345",
"scope": "read write",
"aud": "audience12345",
}
jwt_token = create_jwt(claims)

# Verify
result = verify_jwt(jwt_token)

# Assertions
assert isinstance(result, AuthInfo)
assert result.token == jwt_token
assert result.issuer == claims["iss"]
assert result.client_id == claims["client_id"]
assert result.subject == claims["sub"]
assert result.audience == claims["aud"]
assert result.scopes == ["read", "write"]
assert "exp" in result.claims
assert "iat" in result.claims
assert result.expires_at is not None

def test_should_return_verified_jwt_payload_with_array_scope(self):
# Create JWT with array scope
claims: Dict[str, Any] = {
"iss": "https://logto.io/",
"client_id": "client12345",
"sub": "user12345",
"scope": ["read", "write"],
}
jwt_token = create_jwt(claims)

# Verify
result = verify_jwt(jwt_token)

# Assertions
assert result.issuer == claims["iss"]
assert result.client_id == claims["client_id"]
assert result.subject == claims["sub"]
assert result.scopes == ["read", "write"]

def test_should_return_verified_jwt_payload_with_scopes_field(self):
# Create JWT with scopes field
claims: Dict[str, Any] = {
"iss": "https://logto.io/",
"client_id": "client12345",
"sub": "user12345",
"scopes": ["read", "write"],
}
jwt_token = create_jwt(claims)

# Verify
result = verify_jwt(jwt_token)

# Assertions
assert result.issuer == claims["iss"]
assert result.client_id == claims["client_id"]
assert result.subject == claims["sub"]
assert result.scopes == ["read", "write"]

def test_should_return_verified_jwt_payload_without_scopes(self):
# Create JWT without scope or scopes
claims = {
"iss": "https://logto.io/",
"client_id": "client12345",
"sub": "user12345",
"aud": "audience12345",
}
jwt_token = create_jwt(claims)

# Verify
result = verify_jwt(jwt_token)

# Assertions
assert result.issuer == claims["iss"]
assert result.client_id == claims["client_id"]
assert result.subject == claims["sub"]
assert result.audience == claims["aud"]
assert result.scopes == []