diff --git a/src/enapter_mcp_server/cli/serve_command.py b/src/enapter_mcp_server/cli/serve_command.py index ea13218..09c0284 100644 --- a/src/enapter_mcp_server/cli/serve_command.py +++ b/src/enapter_mcp_server/cli/serve_command.py @@ -31,6 +31,9 @@ ENAPTER_OAUTH_PROXY_REQUIRED_SCOPES = os.getenv( "ENAPTER_OAUTH_PROXY_REQUIRED_SCOPES", "openid,public" ) +ENAPTER_OAUTH_PROXY_ALLOWED_REDIRECT_URLS = os.getenv( + "ENAPTER_OAUTH_PROXY_ALLOWED_REDIRECT_URLS", "" +) ENAPTER_OAUTH_PROXY_CLIENT_ID = os.getenv("ENAPTER_OAUTH_PROXY_CLIENT_ID") ENAPTER_OAUTH_PROXY_CLIENT_SECRET = os.getenv("ENAPTER_OAUTH_PROXY_CLIENT_SECRET") ENAPTER_OAUTH_PROXY_JWT_STORE_URL = os.getenv( @@ -99,6 +102,11 @@ def register(parent: Subparsers) -> None: default=ENAPTER_OAUTH_PROXY_REQUIRED_SCOPES, help="Comma-separated list of required scopes for OAuth proxy", ) + parser.add_argument( + "--oauth-proxy-allowed-redirect-urls", + default=ENAPTER_OAUTH_PROXY_ALLOWED_REDIRECT_URLS, + help="Comma-separated list of allowed redirect URLs for OAuth proxy", + ) parser.add_argument( "--oauth-proxy-client-id", default=ENAPTER_OAUTH_PROXY_CLIENT_ID, @@ -132,6 +140,11 @@ async def run(args: argparse.Namespace) -> None: for scope in args.oauth_proxy_required_scopes.split(",") if scope.strip() ] + allowed_redirect_urls = [ + url.strip() + for url in args.oauth_proxy_allowed_redirect_urls.split(",") + if url.strip() + ] oauth_proxy_config = mcp.OAuthProxyConfig( introspection_endpoint_url=args.oauth_proxy_introspection_url, authorization_endpoint_url=args.oauth_proxy_authorization_url, @@ -142,6 +155,9 @@ async def run(args: argparse.Namespace) -> None: required_scopes=required_scopes, client_id=args.oauth_proxy_client_id, client_secret=args.oauth_proxy_client_secret, + allowed_redirect_urls=( + allowed_redirect_urls if allowed_redirect_urls else None + ), jwt_store_url=args.oauth_proxy_jwt_store_url, jwt_signing_key=args.oauth_proxy_jwt_signing_key, ) diff --git a/src/enapter_mcp_server/mcp/oauth_proxy_config.py b/src/enapter_mcp_server/mcp/oauth_proxy_config.py index 6d35cb7..80872da 100644 --- a/src/enapter_mcp_server/mcp/oauth_proxy_config.py +++ b/src/enapter_mcp_server/mcp/oauth_proxy_config.py @@ -13,5 +13,6 @@ class OAuthProxyConfig: required_scopes: list[str] client_id: str client_secret: str + allowed_redirect_urls: list[str] | None = None jwt_store_url: str | None = None jwt_signing_key: str | None = None diff --git a/src/enapter_mcp_server/mcp/server.py b/src/enapter_mcp_server/mcp/server.py index c7b41a1..2252752 100644 --- a/src/enapter_mcp_server/mcp/server.py +++ b/src/enapter_mcp_server/mcp/server.py @@ -86,6 +86,7 @@ def _select_auth_provider(self) -> fastmcp.server.auth.AuthProvider | None: token_verifier=token_verifier, base_url=self._config.oauth_proxy.protected_resource_url, forward_pkce=self._config.oauth_proxy.forward_pkce, + allowed_client_redirect_uris=self._config.oauth_proxy.allowed_redirect_urls, client_storage=jwt_store, jwt_signing_key=self._config.oauth_proxy.jwt_signing_key, ) diff --git a/tests/unit/cli/test_serve_command.py b/tests/unit/cli/test_serve_command.py new file mode 100644 index 0000000..eeb41d3 --- /dev/null +++ b/tests/unit/cli/test_serve_command.py @@ -0,0 +1,82 @@ +import argparse +import asyncio +from unittest.mock import patch + +import pytest + +from enapter_mcp_server.cli.serve_command import ServeCommand + + +class TestServeCommand: + @pytest.mark.asyncio + async def test_run_parses_allowed_redirect_urls(self): + args = argparse.Namespace( + address="127.0.0.1:8080", + enapter_http_api_url="https://api.example.com", + logo_url="https://logo.example.com", + oauth_proxy_enabled="1", + oauth_proxy_introspection_url="https://sso.example.com/introspect", + oauth_proxy_authorization_url="https://sso.example.com/authorize", + oauth_proxy_token_url="https://sso.example.com/token", + oauth_proxy_user_info_url="https://sso.example.com/me", + oauth_proxy_protected_resource_url="https://mcp.example.com", + oauth_proxy_forward_pkce="1", + oauth_proxy_required_scopes="openid,profile", + oauth_proxy_allowed_redirect_urls="http://localhost:3000,http://localhost:3001", + oauth_proxy_client_id="client_id", + oauth_proxy_client_secret="client_secret", + oauth_proxy_jwt_store_url="memory://", + oauth_proxy_jwt_signing_key="signing_key", + verbose=False, + ) + + with patch("enapter_mcp_server.mcp.Server", autospec=True) as mock_server: + # We need to mock the async context manager + mock_server_instance = mock_server.return_value + mock_server_instance.__aenter__.return_value = mock_server_instance + + # ServeCommand.run has an infinite wait, so we need to break it + with patch("asyncio.Event.wait", side_effect=asyncio.CancelledError): + with pytest.raises(asyncio.CancelledError): + await ServeCommand.run(args) + + mock_server.assert_called_once() + config = mock_server.call_args[1]["config"] + assert config.oauth_proxy is not None + assert config.oauth_proxy.allowed_redirect_urls == [ + "http://localhost:3000", + "http://localhost:3001", + ] + + @pytest.mark.asyncio + async def test_run_handles_empty_allowed_redirect_urls(self): + args = argparse.Namespace( + address="127.0.0.1:8080", + enapter_http_api_url="https://api.example.com", + logo_url="https://logo.example.com", + oauth_proxy_enabled="1", + oauth_proxy_introspection_url="https://sso.example.com/introspect", + oauth_proxy_authorization_url="https://sso.example.com/authorize", + oauth_proxy_token_url="https://sso.example.com/token", + oauth_proxy_user_info_url="https://sso.example.com/me", + oauth_proxy_protected_resource_url="https://mcp.example.com", + oauth_proxy_forward_pkce="1", + oauth_proxy_required_scopes="openid,profile", + oauth_proxy_allowed_redirect_urls="", + oauth_proxy_client_id="client_id", + oauth_proxy_client_secret="client_secret", + oauth_proxy_jwt_store_url="memory://", + oauth_proxy_jwt_signing_key="signing_key", + verbose=False, + ) + + with patch("enapter_mcp_server.mcp.Server", autospec=True) as mock_server: + mock_server_instance = mock_server.return_value + mock_server_instance.__aenter__.return_value = mock_server_instance + + with patch("asyncio.Event.wait", side_effect=asyncio.CancelledError): + with pytest.raises(asyncio.CancelledError): + await ServeCommand.run(args) + + config = mock_server.call_args[1]["config"] + assert config.oauth_proxy.allowed_redirect_urls is None