Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/enapter_mcp_server/cli/serve_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
ENAPTER_OAUTH_PROXY_REQUIRED_SCOPES = os.getenv(
"ENAPTER_OAUTH_PROXY_REQUIRED_SCOPES", "openid,public"
)
ENAPTER_OAUTH_PROXY_ALLOWED_REDIRECT_URLS = os.getenv(
"ENAPTER_OAUTH_PROXY_ALLOWED_REDIRECT_URLS", ""
)
ENAPTER_OAUTH_PROXY_CLIENT_ID = os.getenv("ENAPTER_OAUTH_PROXY_CLIENT_ID")
ENAPTER_OAUTH_PROXY_CLIENT_SECRET = os.getenv("ENAPTER_OAUTH_PROXY_CLIENT_SECRET")
ENAPTER_OAUTH_PROXY_JWT_STORE_URL = os.getenv(
Expand Down Expand Up @@ -99,6 +102,11 @@ def register(parent: Subparsers) -> None:
default=ENAPTER_OAUTH_PROXY_REQUIRED_SCOPES,
help="Comma-separated list of required scopes for OAuth proxy",
)
parser.add_argument(
"--oauth-proxy-allowed-redirect-urls",
default=ENAPTER_OAUTH_PROXY_ALLOWED_REDIRECT_URLS,
help="Comma-separated list of allowed redirect URLs for OAuth proxy",
)
parser.add_argument(
"--oauth-proxy-client-id",
default=ENAPTER_OAUTH_PROXY_CLIENT_ID,
Expand Down Expand Up @@ -132,6 +140,11 @@ async def run(args: argparse.Namespace) -> None:
for scope in args.oauth_proxy_required_scopes.split(",")
if scope.strip()
]
allowed_redirect_urls = [
url.strip()
for url in args.oauth_proxy_allowed_redirect_urls.split(",")
if url.strip()
]
oauth_proxy_config = mcp.OAuthProxyConfig(
introspection_endpoint_url=args.oauth_proxy_introspection_url,
authorization_endpoint_url=args.oauth_proxy_authorization_url,
Expand All @@ -142,6 +155,9 @@ async def run(args: argparse.Namespace) -> None:
required_scopes=required_scopes,
client_id=args.oauth_proxy_client_id,
client_secret=args.oauth_proxy_client_secret,
allowed_redirect_urls=(
allowed_redirect_urls if allowed_redirect_urls else None
),
jwt_store_url=args.oauth_proxy_jwt_store_url,
jwt_signing_key=args.oauth_proxy_jwt_signing_key,
)
Expand Down
1 change: 1 addition & 0 deletions src/enapter_mcp_server/mcp/oauth_proxy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ class OAuthProxyConfig:
required_scopes: list[str]
client_id: str
client_secret: str
allowed_redirect_urls: list[str] | None = None
jwt_store_url: str | None = None
jwt_signing_key: str | None = None
1 change: 1 addition & 0 deletions src/enapter_mcp_server/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _select_auth_provider(self) -> fastmcp.server.auth.AuthProvider | None:
token_verifier=token_verifier,
base_url=self._config.oauth_proxy.protected_resource_url,
forward_pkce=self._config.oauth_proxy.forward_pkce,
allowed_client_redirect_uris=self._config.oauth_proxy.allowed_redirect_urls,
client_storage=jwt_store,
jwt_signing_key=self._config.oauth_proxy.jwt_signing_key,
)
Expand Down
82 changes: 82 additions & 0 deletions tests/unit/cli/test_serve_command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import argparse
import asyncio
from unittest.mock import patch

import pytest

from enapter_mcp_server.cli.serve_command import ServeCommand


class TestServeCommand:
@pytest.mark.asyncio
async def test_run_parses_allowed_redirect_urls(self):
args = argparse.Namespace(
address="127.0.0.1:8080",
enapter_http_api_url="https://api.example.com",
logo_url="https://logo.example.com",
oauth_proxy_enabled="1",
oauth_proxy_introspection_url="https://sso.example.com/introspect",
oauth_proxy_authorization_url="https://sso.example.com/authorize",
oauth_proxy_token_url="https://sso.example.com/token",
oauth_proxy_user_info_url="https://sso.example.com/me",
oauth_proxy_protected_resource_url="https://mcp.example.com",
oauth_proxy_forward_pkce="1",
oauth_proxy_required_scopes="openid,profile",
oauth_proxy_allowed_redirect_urls="http://localhost:3000,http://localhost:3001",
oauth_proxy_client_id="client_id",
oauth_proxy_client_secret="client_secret",
oauth_proxy_jwt_store_url="memory://",
oauth_proxy_jwt_signing_key="signing_key",
verbose=False,
)

with patch("enapter_mcp_server.mcp.Server", autospec=True) as mock_server:
# We need to mock the async context manager
mock_server_instance = mock_server.return_value
mock_server_instance.__aenter__.return_value = mock_server_instance

# ServeCommand.run has an infinite wait, so we need to break it
with patch("asyncio.Event.wait", side_effect=asyncio.CancelledError):
with pytest.raises(asyncio.CancelledError):
await ServeCommand.run(args)

mock_server.assert_called_once()
config = mock_server.call_args[1]["config"]
assert config.oauth_proxy is not None
assert config.oauth_proxy.allowed_redirect_urls == [
"http://localhost:3000",
"http://localhost:3001",
]

@pytest.mark.asyncio
async def test_run_handles_empty_allowed_redirect_urls(self):
args = argparse.Namespace(
address="127.0.0.1:8080",
enapter_http_api_url="https://api.example.com",
logo_url="https://logo.example.com",
oauth_proxy_enabled="1",
oauth_proxy_introspection_url="https://sso.example.com/introspect",
oauth_proxy_authorization_url="https://sso.example.com/authorize",
oauth_proxy_token_url="https://sso.example.com/token",
oauth_proxy_user_info_url="https://sso.example.com/me",
oauth_proxy_protected_resource_url="https://mcp.example.com",
oauth_proxy_forward_pkce="1",
oauth_proxy_required_scopes="openid,profile",
oauth_proxy_allowed_redirect_urls="",
oauth_proxy_client_id="client_id",
oauth_proxy_client_secret="client_secret",
oauth_proxy_jwt_store_url="memory://",
oauth_proxy_jwt_signing_key="signing_key",
verbose=False,
)

with patch("enapter_mcp_server.mcp.Server", autospec=True) as mock_server:
mock_server_instance = mock_server.return_value
mock_server_instance.__aenter__.return_value = mock_server_instance

with patch("asyncio.Event.wait", side_effect=asyncio.CancelledError):
with pytest.raises(asyncio.CancelledError):
await ServeCommand.run(args)

config = mock_server.call_args[1]["config"]
assert config.oauth_proxy.allowed_redirect_urls is None