Skip to content

Commit 697c715

Browse files
committed
feat(auth): persist oauth_metadata via TokenStorage to fix refresh after restart
Closes part of #1318. Problem: OAuthClientProvider._initialize() restores tokens and client_info from storage but not oauth_metadata. After a restart with cached tokens, _refresh_token() falls back to <base_url>/token (via get_authorization_base_url + urljoin), which is incorrect for servers whose token endpoint is at a non-standard path (e.g. HubSpot's MCP Auth App uses /oauth/v3/token). Refresh requests return 404, cascading into a full interactive OAuth flow that cannot complete in non-interactive environments (daemons, containers, long-running services). Fix: add optional get_oauth_metadata / set_oauth_metadata methods to the TokenStorage protocol. _initialize now restores metadata alongside tokens and client_info via a getattr fallback that preserves backward compatibility with storage implementations predating this API; they return None and the refresh path falls back to the legacy behaviour as before. Discovered metadata is persisted after OASM discovery in the 401 flow so subsequent restarts can resolve the correct token endpoint without rediscovery. Coverage: added 5 tests (3 feature + 2 to cover the getattr fallback branches and the Protocol default bodies). Also removed three "# pragma: no cover" markers on lines now exercised by the new tests per AGENTS.md.
1 parent 3d7b311 commit 697c715

2 files changed

Lines changed: 294 additions & 6 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,24 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
8787
"""Store client information."""
8888
...
8989

90+
async def get_oauth_metadata(self) -> OAuthMetadata | None:
91+
"""Get stored authorization server metadata.
92+
93+
Optional: implementations may return ``None`` if metadata persistence
94+
is not desired. Implementations that persist tokens across restarts
95+
should also persist metadata so :meth:`OAuthClientProvider._refresh_token`
96+
can resolve the correct token endpoint without rediscovering metadata
97+
on every restart.
98+
"""
99+
return None
100+
101+
async def set_oauth_metadata(self, metadata: OAuthMetadata) -> None:
102+
"""Store authorization server metadata.
103+
104+
Optional: no-op by default. See :meth:`get_oauth_metadata`.
105+
"""
106+
return
107+
90108

91109
@dataclass
92110
class OAuthContext:
@@ -473,10 +491,19 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p
473491
self.context.clear_tokens()
474492
return False
475493

476-
async def _initialize(self) -> None: # pragma: no cover
477-
"""Load stored tokens and client info."""
494+
async def _initialize(self) -> None:
495+
"""Load stored tokens, client info, and authorization server metadata."""
478496
self.context.current_tokens = await self.context.storage.get_tokens()
479497
self.context.client_info = await self.context.storage.get_client_info()
498+
# Restore authorization server metadata so ``_refresh_token`` can
499+
# resolve the correct token endpoint without rediscovering it on
500+
# every restart. ``getattr`` preserves backward compatibility with
501+
# storage implementations predating ``get_oauth_metadata``: they
502+
# return ``None`` and the refresh path falls back to the legacy
503+
# ``<base_url>/token`` behaviour as before.
504+
meta_getter = getattr(self.context.storage, "get_oauth_metadata", None)
505+
if meta_getter is not None:
506+
self.context.oauth_metadata = await meta_getter()
480507
self._initialized = True
481508

482509
def _add_auth_header(self, request: httpx.Request) -> None:
@@ -507,7 +534,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
507534
"""HTTPX auth flow integration."""
508535
async with self.context.lock:
509536
if not self._initialized:
510-
await self._initialize() # pragma: no cover
537+
await self._initialize()
511538

512539
# Capture protocol version from request headers
513540
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
@@ -572,6 +599,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
572599
break
573600
if ok and asm:
574601
self.context.oauth_metadata = asm
602+
# Persist so subsequent restarts can resolve the
603+
# correct token endpoint without rediscovery.
604+
meta_setter = getattr(self.context.storage, "set_oauth_metadata", None)
605+
if meta_setter is not None:
606+
await meta_setter(asm)
575607
break
576608
else:
577609
logger.debug(f"OAuth metadata discovery failed: {url}")
@@ -612,7 +644,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
612644
# Step 5: Perform authorization and complete token exchange
613645
token_response = yield await self._perform_authorization()
614646
await self._handle_token_response(token_response)
615-
except Exception: # pragma: no cover
647+
except Exception:
616648
logger.exception("OAuth flow error")
617649
raise
618650

tests/client/test_auth.py

Lines changed: 258 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,26 @@ class MockTokenStorage:
4343
def __init__(self):
4444
self._tokens: OAuthToken | None = None
4545
self._client_info: OAuthClientInformationFull | None = None
46+
self._oauth_metadata: OAuthMetadata | None = None
4647

4748
async def get_tokens(self) -> OAuthToken | None:
48-
return self._tokens # pragma: no cover
49+
return self._tokens
4950

5051
async def set_tokens(self, tokens: OAuthToken) -> None:
5152
self._tokens = tokens
5253

5354
async def get_client_info(self) -> OAuthClientInformationFull | None:
54-
return self._client_info # pragma: no cover
55+
return self._client_info
5556

5657
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
5758
self._client_info = client_info
5859

60+
async def get_oauth_metadata(self) -> OAuthMetadata | None:
61+
return self._oauth_metadata
62+
63+
async def set_oauth_metadata(self, metadata: OAuthMetadata) -> None:
64+
self._oauth_metadata = metadata
65+
5966

6067
@pytest.fixture
6168
def mock_storage():
@@ -2618,3 +2625,252 @@ async def callback_handler() -> tuple[str, str | None]:
26182625
await auth_flow.asend(final_response)
26192626
except StopAsyncIteration:
26202627
pass
2628+
2629+
2630+
# --- Regression coverage for #1318: restore oauth_metadata on _initialize ---
2631+
2632+
2633+
@pytest.mark.anyio
2634+
async def test_initialize_restores_oauth_metadata(
2635+
oauth_provider: OAuthClientProvider,
2636+
mock_storage: MockTokenStorage,
2637+
):
2638+
"""``_initialize`` should restore ``oauth_metadata`` from storage.
2639+
2640+
Without this, ``_refresh_token`` loses the authoritative token endpoint
2641+
discovered during the prior session and falls back to ``<base_url>/token``
2642+
after every restart — a 404 for servers whose token endpoint sits on a
2643+
non-standard path.
2644+
"""
2645+
stored_metadata = OAuthMetadata(
2646+
issuer=AnyHttpUrl("https://auth.example.com"),
2647+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
2648+
token_endpoint=AnyHttpUrl("https://auth.example.com/oauth/v3/token"),
2649+
)
2650+
await mock_storage.set_oauth_metadata(stored_metadata)
2651+
2652+
await oauth_provider._initialize()
2653+
2654+
assert oauth_provider.context.oauth_metadata is not None
2655+
assert str(oauth_provider.context.oauth_metadata.token_endpoint) == ("https://auth.example.com/oauth/v3/token")
2656+
2657+
2658+
@pytest.mark.anyio
2659+
async def test_refresh_token_uses_persisted_metadata_endpoint(
2660+
oauth_provider: OAuthClientProvider,
2661+
mock_storage: MockTokenStorage,
2662+
valid_tokens: OAuthToken,
2663+
):
2664+
"""After a restart with persisted metadata, ``_refresh_token`` uses the
2665+
correct ``token_endpoint`` rather than the ``<base_url>/token`` fallback.
2666+
"""
2667+
custom_token_endpoint = "https://auth.example.com/oauth/v3/token"
2668+
await mock_storage.set_tokens(valid_tokens)
2669+
await mock_storage.set_client_info(
2670+
OAuthClientInformationFull(
2671+
client_id="test_client",
2672+
client_secret="test_secret",
2673+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
2674+
token_endpoint_auth_method="client_secret_post",
2675+
)
2676+
)
2677+
await mock_storage.set_oauth_metadata(
2678+
OAuthMetadata(
2679+
issuer=AnyHttpUrl("https://auth.example.com"),
2680+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
2681+
token_endpoint=AnyHttpUrl(custom_token_endpoint),
2682+
)
2683+
)
2684+
2685+
await oauth_provider._initialize()
2686+
request = await oauth_provider._refresh_token()
2687+
2688+
assert str(request.url) == custom_token_endpoint
2689+
2690+
2691+
@pytest.mark.anyio
2692+
async def test_initialize_backward_compat_without_metadata_methods(
2693+
client_metadata: OAuthClientMetadata,
2694+
valid_tokens: OAuthToken,
2695+
):
2696+
"""Storage implementations predating ``get_oauth_metadata`` keep working.
2697+
2698+
Duck-typed ``TokenStorage`` instances written before this method was
2699+
introduced must not raise ``AttributeError`` on ``_initialize``.
2700+
"""
2701+
2702+
class LegacyStorage:
2703+
"""Duck-typed storage matching the pre-change ``TokenStorage``."""
2704+
2705+
def __init__(self, tokens: OAuthToken | None):
2706+
self._tokens = tokens
2707+
self._client_info: OAuthClientInformationFull | None = None
2708+
2709+
async def get_tokens(self) -> OAuthToken | None:
2710+
return self._tokens
2711+
2712+
async def set_tokens(self, tokens: OAuthToken) -> None:
2713+
self._tokens = tokens # pragma: no cover
2714+
2715+
async def get_client_info(self) -> OAuthClientInformationFull | None:
2716+
return self._client_info
2717+
2718+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
2719+
self._client_info = client_info # pragma: no cover
2720+
2721+
legacy_storage = LegacyStorage(valid_tokens)
2722+
2723+
async def redirect_handler(url: str) -> None:
2724+
pass # pragma: no cover
2725+
2726+
async def callback_handler() -> tuple[str, str | None]:
2727+
return "test_auth_code", "test_state" # pragma: no cover
2728+
2729+
provider = OAuthClientProvider(
2730+
server_url="https://api.example.com/v1/mcp",
2731+
client_metadata=client_metadata,
2732+
storage=legacy_storage, # type: ignore[arg-type]
2733+
redirect_handler=redirect_handler,
2734+
callback_handler=callback_handler,
2735+
)
2736+
2737+
await provider._initialize()
2738+
2739+
assert provider.context.current_tokens is valid_tokens
2740+
assert provider.context.oauth_metadata is None
2741+
2742+
2743+
@pytest.mark.anyio
2744+
async def test_token_storage_protocol_default_metadata_methods():
2745+
"""``TokenStorage`` provides no-op defaults for the optional metadata methods.
2746+
2747+
Storage subclasses that don't care about metadata persistence can inherit
2748+
``TokenStorage`` without overriding ``get_oauth_metadata`` /
2749+
``set_oauth_metadata``; the default ``get`` returns ``None`` and the
2750+
default ``set`` is a no-op (equivalent to opting out of persistence).
2751+
"""
2752+
from mcp.client.auth.oauth2 import TokenStorage
2753+
2754+
class DefaultStorage(TokenStorage):
2755+
async def get_tokens(self) -> OAuthToken | None:
2756+
return None # pragma: no cover
2757+
2758+
async def set_tokens(self, tokens: OAuthToken) -> None: ... # pragma: no cover
2759+
2760+
async def get_client_info(self) -> OAuthClientInformationFull | None:
2761+
return None # pragma: no cover
2762+
2763+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: ... # pragma: no cover
2764+
2765+
storage = DefaultStorage()
2766+
assert await storage.get_oauth_metadata() is None
2767+
2768+
metadata = OAuthMetadata(
2769+
issuer=AnyHttpUrl("https://auth.example.com"),
2770+
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
2771+
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
2772+
)
2773+
# No-op: set completes without storing
2774+
await storage.set_oauth_metadata(metadata)
2775+
assert await storage.get_oauth_metadata() is None
2776+
2777+
2778+
@pytest.mark.anyio
2779+
async def test_auth_flow_discovery_with_legacy_storage_skips_metadata_persistence(
2780+
client_metadata: OAuthClientMetadata,
2781+
):
2782+
"""OAuth discovery succeeds when storage lacks ``set_oauth_metadata``.
2783+
2784+
Covers the ``getattr`` fallback branch in ``async_auth_flow`` that bypasses
2785+
persistence for storage implementations predating the metadata API.
2786+
"""
2787+
2788+
class LegacyStorage:
2789+
"""Duck-typed storage matching the pre-change ``TokenStorage``."""
2790+
2791+
def __init__(self) -> None:
2792+
self._tokens: OAuthToken | None = None
2793+
self._client_info: OAuthClientInformationFull | None = None
2794+
2795+
async def get_tokens(self) -> OAuthToken | None:
2796+
return self._tokens
2797+
2798+
async def set_tokens(self, tokens: OAuthToken) -> None:
2799+
self._tokens = tokens # pragma: no cover
2800+
2801+
async def get_client_info(self) -> OAuthClientInformationFull | None:
2802+
return self._client_info
2803+
2804+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
2805+
self._client_info = client_info # pragma: no cover
2806+
2807+
legacy_storage = LegacyStorage()
2808+
2809+
async def redirect_handler(url: str) -> None:
2810+
pass # pragma: no cover
2811+
2812+
async def callback_handler() -> tuple[str, str | None]:
2813+
return "test_auth_code", "test_state" # pragma: no cover
2814+
2815+
provider = OAuthClientProvider(
2816+
server_url="https://api.example.com/v1/mcp",
2817+
client_metadata=client_metadata,
2818+
storage=legacy_storage, # type: ignore[arg-type]
2819+
redirect_handler=redirect_handler,
2820+
callback_handler=callback_handler,
2821+
)
2822+
2823+
test_request = httpx.Request("GET", "https://api.example.com/mcp")
2824+
auth_flow = provider.async_auth_flow(test_request)
2825+
2826+
# First yield: ``_initialize`` loads state from LegacyStorage (no tokens,
2827+
# no client info, no metadata fallback) and the original request goes out
2828+
# without an auth header.
2829+
request = await auth_flow.__anext__()
2830+
assert "Authorization" not in request.headers
2831+
2832+
# 401 → triggers full OAuth flow
2833+
unauthorized_response = httpx.Response(
2834+
401,
2835+
headers={
2836+
"WWW-Authenticate": (
2837+
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
2838+
)
2839+
},
2840+
request=test_request,
2841+
)
2842+
prm_request = await auth_flow.asend(unauthorized_response)
2843+
assert "oauth-protected-resource" in str(prm_request.url)
2844+
2845+
# PRM discovery response
2846+
prm_response = httpx.Response(
2847+
200,
2848+
content=(
2849+
b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}'
2850+
),
2851+
request=prm_request,
2852+
)
2853+
asm_request = await auth_flow.asend(prm_response)
2854+
assert str(asm_request.url).startswith("https://auth.example.com/")
2855+
2856+
# OASM discovery response — this is where our set_oauth_metadata
2857+
# fallback (meta_setter is None) executes for LegacyStorage.
2858+
asm_response = httpx.Response(
2859+
200,
2860+
content=(
2861+
b'{"issuer": "https://auth.example.com", '
2862+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
2863+
b'"token_endpoint": "https://auth.example.com/token", '
2864+
b'"registration_endpoint": "https://auth.example.com/register"}'
2865+
),
2866+
request=asm_request,
2867+
)
2868+
next_request = await auth_flow.asend(asm_response)
2869+
2870+
# Discovery succeeded: flow advanced past metadata handling.
2871+
# (Legacy storage had no set_oauth_metadata, so persistence is skipped.)
2872+
assert next_request is not None
2873+
assert provider.context.oauth_metadata is not None
2874+
assert str(provider.context.oauth_metadata.token_endpoint) == "https://auth.example.com/token"
2875+
2876+
await auth_flow.aclose()

0 commit comments

Comments
 (0)