diff --git a/mcpauth/config.py b/mcpauth/config.py index 7c726da..a49bc1a 100644 --- a/mcpauth/config.py +++ b/mcpauth/config.py @@ -117,6 +117,11 @@ class AuthServerType(str, Enum): OIDC = "oidc" +class AuthorizationServerMetadataDefaults(Enum): + grant_types_supported = ["authorization_code", "implicit"] + response_modes_supported = ["query", "fragment"] + + class AuthServerConfig(BaseModel): """ Configuration for the remote authorization server integrated with the MCP server. diff --git a/mcpauth/utils/_validate_server_config.py b/mcpauth/utils/_validate_server_config.py index f0f7f57..ce742a4 100644 --- a/mcpauth/utils/_validate_server_config.py +++ b/mcpauth/utils/_validate_server_config.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Any, Dict, List, Optional from pydantic import BaseModel -from ..config import AuthServerConfig +from ..config import AuthServerConfig, AuthorizationServerMetadataDefaults class AuthServerConfigErrorCode(str, Enum): @@ -97,6 +97,7 @@ def validate_server_config( invalid (`{ is_valid: False }`), along with any errors or warnings encountered during validation. """ + MetadataDefaults = AuthorizationServerMetadataDefaults errors: List[AuthServerConfigError] = [] warnings: List[AuthServerConfigWarning] = [] metadata = config.metadata @@ -112,9 +113,10 @@ def validate_server_config( ) # Check if 'authorization_code' grant type is supported - if ( - not metadata.grant_types_supported - or "authorization_code" not in metadata.grant_types_supported + if "authorization_code" not in ( + metadata.grant_types_supported + if metadata.grant_types_supported is not None + else MetadataDefaults.grant_types_supported.value ): errors.append( _create_error( diff --git a/tests/utils/validate_server_config_test.py b/tests/utils/validate_server_config_test.py index 4631152..b39cc5b 100644 --- a/tests/utils/validate_server_config_test.py +++ b/tests/utils/validate_server_config_test.py @@ -26,6 +26,24 @@ def test_valid_server_config(self): assert not hasattr(result, "errors") or len(result.errors) == 0 assert result.warnings == [] + def test_valid_server_config_no_grant_types(self): + config = AuthServerConfig( + type=AuthServerType.OAUTH, + metadata=AuthorizationServerMetadata( + issuer="https://example.com", + authorization_endpoint="https://example.com/oauth/authorize", + token_endpoint="https://example.com/oauth/token", + response_types_supported=["code"], + code_challenge_methods_supported=["S256"], + registration_endpoint="https://example.com/register", + ), + ) + + result = validate_server_config(config) + assert result.is_valid is True + assert not hasattr(result, "errors") or len(result.errors) == 0 + assert result.warnings == [] + def test_invalid_server_config(self): config = AuthServerConfig( type=AuthServerType.OAUTH, @@ -42,10 +60,6 @@ def test_invalid_server_config(self): error_codes = [error.code for error in result.errors] assert AuthServerConfigErrorCode.CODE_RESPONSE_TYPE_NOT_SUPPORTED in error_codes - assert ( - AuthServerConfigErrorCode.AUTHORIZATION_CODE_GRANT_NOT_SUPPORTED - in error_codes - ) assert AuthServerConfigErrorCode.PKCE_NOT_SUPPORTED in error_codes warning_codes = [warning.code for warning in result.warnings] @@ -78,7 +92,7 @@ def test_warning_for_missing_dynamic_registration(self): ) assert len(result.warnings) == 1 - def test_code_challenge_methods(self): + def test_invalid_code_challenge_methods(self): config = AuthServerConfig( type=AuthServerType.OAUTH, metadata=AuthorizationServerMetadata( @@ -99,3 +113,25 @@ def test_code_challenge_methods(self): AuthServerConfigErrorCode.S256_CODE_CHALLENGE_METHOD_NOT_SUPPORTED in error_codes ) + + def test_invalid_grant_type(self): + config = AuthServerConfig( + type=AuthServerType.OAUTH, + metadata=AuthorizationServerMetadata( + issuer="https://example.com", + authorization_endpoint="https://example.com/oauth/authorize", + token_endpoint="https://example.com/oauth/token", + response_types_supported=["code"], + grant_types_supported=[], # Use empty list on purpose to ensure it should be treated correctly + code_challenge_methods_supported=["S256"], + ), + ) + + result = validate_server_config(config) + assert result.is_valid is False + + error_codes = [error.code for error in result.errors] + assert ( + AuthServerConfigErrorCode.AUTHORIZATION_CODE_GRANT_NOT_SUPPORTED + in error_codes + )