@@ -43,19 +43,26 @@ class MockTokenStorage:
4343 def __init__ (self ):
4444 self ._tokens : OAuthToken | None = None
4545 self ._client_info : OAuthClientInformationFull | None = None
46+ self ._oauth_metadata : OAuthMetadata | None = None
4647
4748 async def get_tokens (self ) -> OAuthToken | None :
48- return self ._tokens # pragma: no cover
49+ return self ._tokens
4950
5051 async def set_tokens (self , tokens : OAuthToken ) -> None :
5152 self ._tokens = tokens
5253
5354 async def get_client_info (self ) -> OAuthClientInformationFull | None :
54- return self ._client_info # pragma: no cover
55+ return self ._client_info
5556
5657 async def set_client_info (self , client_info : OAuthClientInformationFull ) -> None :
5758 self ._client_info = client_info
5859
60+ async def get_oauth_metadata (self ) -> OAuthMetadata | None :
61+ return self ._oauth_metadata
62+
63+ async def set_oauth_metadata (self , metadata : OAuthMetadata ) -> None :
64+ self ._oauth_metadata = metadata
65+
5966
6067@pytest .fixture
6168def mock_storage ():
@@ -2618,3 +2625,252 @@ async def callback_handler() -> tuple[str, str | None]:
26182625 await auth_flow .asend (final_response )
26192626 except StopAsyncIteration :
26202627 pass
2628+
2629+
2630+ # --- Regression coverage for #1318: restore oauth_metadata on _initialize ---
2631+
2632+
2633+ @pytest .mark .anyio
2634+ async def test_initialize_restores_oauth_metadata (
2635+ oauth_provider : OAuthClientProvider ,
2636+ mock_storage : MockTokenStorage ,
2637+ ):
2638+ """``_initialize`` should restore ``oauth_metadata`` from storage.
2639+
2640+ Without this, ``_refresh_token`` loses the authoritative token endpoint
2641+ discovered during the prior session and falls back to ``<base_url>/token``
2642+ after every restart — a 404 for servers whose token endpoint sits on a
2643+ non-standard path.
2644+ """
2645+ stored_metadata = OAuthMetadata (
2646+ issuer = AnyHttpUrl ("https://auth.example.com" ),
2647+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
2648+ token_endpoint = AnyHttpUrl ("https://auth.example.com/oauth/v3/token" ),
2649+ )
2650+ await mock_storage .set_oauth_metadata (stored_metadata )
2651+
2652+ await oauth_provider ._initialize ()
2653+
2654+ assert oauth_provider .context .oauth_metadata is not None
2655+ assert str (oauth_provider .context .oauth_metadata .token_endpoint ) == ("https://auth.example.com/oauth/v3/token" )
2656+
2657+
2658+ @pytest .mark .anyio
2659+ async def test_refresh_token_uses_persisted_metadata_endpoint (
2660+ oauth_provider : OAuthClientProvider ,
2661+ mock_storage : MockTokenStorage ,
2662+ valid_tokens : OAuthToken ,
2663+ ):
2664+ """After a restart with persisted metadata, ``_refresh_token`` uses the
2665+ correct ``token_endpoint`` rather than the ``<base_url>/token`` fallback.
2666+ """
2667+ custom_token_endpoint = "https://auth.example.com/oauth/v3/token"
2668+ await mock_storage .set_tokens (valid_tokens )
2669+ await mock_storage .set_client_info (
2670+ OAuthClientInformationFull (
2671+ client_id = "test_client" ,
2672+ client_secret = "test_secret" ,
2673+ redirect_uris = [AnyUrl ("http://localhost:3030/callback" )],
2674+ token_endpoint_auth_method = "client_secret_post" ,
2675+ )
2676+ )
2677+ await mock_storage .set_oauth_metadata (
2678+ OAuthMetadata (
2679+ issuer = AnyHttpUrl ("https://auth.example.com" ),
2680+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
2681+ token_endpoint = AnyHttpUrl (custom_token_endpoint ),
2682+ )
2683+ )
2684+
2685+ await oauth_provider ._initialize ()
2686+ request = await oauth_provider ._refresh_token ()
2687+
2688+ assert str (request .url ) == custom_token_endpoint
2689+
2690+
2691+ @pytest .mark .anyio
2692+ async def test_initialize_backward_compat_without_metadata_methods (
2693+ client_metadata : OAuthClientMetadata ,
2694+ valid_tokens : OAuthToken ,
2695+ ):
2696+ """Storage implementations predating ``get_oauth_metadata`` keep working.
2697+
2698+ Duck-typed ``TokenStorage`` instances written before this method was
2699+ introduced must not raise ``AttributeError`` on ``_initialize``.
2700+ """
2701+
2702+ class LegacyStorage :
2703+ """Duck-typed storage matching the pre-change ``TokenStorage``."""
2704+
2705+ def __init__ (self , tokens : OAuthToken | None ):
2706+ self ._tokens = tokens
2707+ self ._client_info : OAuthClientInformationFull | None = None
2708+
2709+ async def get_tokens (self ) -> OAuthToken | None :
2710+ return self ._tokens
2711+
2712+ async def set_tokens (self , tokens : OAuthToken ) -> None :
2713+ self ._tokens = tokens # pragma: no cover
2714+
2715+ async def get_client_info (self ) -> OAuthClientInformationFull | None :
2716+ return self ._client_info
2717+
2718+ async def set_client_info (self , client_info : OAuthClientInformationFull ) -> None :
2719+ self ._client_info = client_info # pragma: no cover
2720+
2721+ legacy_storage = LegacyStorage (valid_tokens )
2722+
2723+ async def redirect_handler (url : str ) -> None :
2724+ pass # pragma: no cover
2725+
2726+ async def callback_handler () -> tuple [str , str | None ]:
2727+ return "test_auth_code" , "test_state" # pragma: no cover
2728+
2729+ provider = OAuthClientProvider (
2730+ server_url = "https://api.example.com/v1/mcp" ,
2731+ client_metadata = client_metadata ,
2732+ storage = legacy_storage , # type: ignore[arg-type]
2733+ redirect_handler = redirect_handler ,
2734+ callback_handler = callback_handler ,
2735+ )
2736+
2737+ await provider ._initialize ()
2738+
2739+ assert provider .context .current_tokens is valid_tokens
2740+ assert provider .context .oauth_metadata is None
2741+
2742+
2743+ @pytest .mark .anyio
2744+ async def test_token_storage_protocol_default_metadata_methods ():
2745+ """``TokenStorage`` provides no-op defaults for the optional metadata methods.
2746+
2747+ Storage subclasses that don't care about metadata persistence can inherit
2748+ ``TokenStorage`` without overriding ``get_oauth_metadata`` /
2749+ ``set_oauth_metadata``; the default ``get`` returns ``None`` and the
2750+ default ``set`` is a no-op (equivalent to opting out of persistence).
2751+ """
2752+ from mcp .client .auth .oauth2 import TokenStorage
2753+
2754+ class DefaultStorage (TokenStorage ):
2755+ async def get_tokens (self ) -> OAuthToken | None :
2756+ return None # pragma: no cover
2757+
2758+ async def set_tokens (self , tokens : OAuthToken ) -> None : ... # pragma: no cover
2759+
2760+ async def get_client_info (self ) -> OAuthClientInformationFull | None :
2761+ return None # pragma: no cover
2762+
2763+ async def set_client_info (self , client_info : OAuthClientInformationFull ) -> None : ... # pragma: no cover
2764+
2765+ storage = DefaultStorage ()
2766+ assert await storage .get_oauth_metadata () is None
2767+
2768+ metadata = OAuthMetadata (
2769+ issuer = AnyHttpUrl ("https://auth.example.com" ),
2770+ authorization_endpoint = AnyHttpUrl ("https://auth.example.com/authorize" ),
2771+ token_endpoint = AnyHttpUrl ("https://auth.example.com/token" ),
2772+ )
2773+ # No-op: set completes without storing
2774+ await storage .set_oauth_metadata (metadata )
2775+ assert await storage .get_oauth_metadata () is None
2776+
2777+
2778+ @pytest .mark .anyio
2779+ async def test_auth_flow_discovery_with_legacy_storage_skips_metadata_persistence (
2780+ client_metadata : OAuthClientMetadata ,
2781+ ):
2782+ """OAuth discovery succeeds when storage lacks ``set_oauth_metadata``.
2783+
2784+ Covers the ``getattr`` fallback branch in ``async_auth_flow`` that bypasses
2785+ persistence for storage implementations predating the metadata API.
2786+ """
2787+
2788+ class LegacyStorage :
2789+ """Duck-typed storage matching the pre-change ``TokenStorage``."""
2790+
2791+ def __init__ (self ) -> None :
2792+ self ._tokens : OAuthToken | None = None
2793+ self ._client_info : OAuthClientInformationFull | None = None
2794+
2795+ async def get_tokens (self ) -> OAuthToken | None :
2796+ return self ._tokens
2797+
2798+ async def set_tokens (self , tokens : OAuthToken ) -> None :
2799+ self ._tokens = tokens # pragma: no cover
2800+
2801+ async def get_client_info (self ) -> OAuthClientInformationFull | None :
2802+ return self ._client_info
2803+
2804+ async def set_client_info (self , client_info : OAuthClientInformationFull ) -> None :
2805+ self ._client_info = client_info # pragma: no cover
2806+
2807+ legacy_storage = LegacyStorage ()
2808+
2809+ async def redirect_handler (url : str ) -> None :
2810+ pass # pragma: no cover
2811+
2812+ async def callback_handler () -> tuple [str , str | None ]:
2813+ return "test_auth_code" , "test_state" # pragma: no cover
2814+
2815+ provider = OAuthClientProvider (
2816+ server_url = "https://api.example.com/v1/mcp" ,
2817+ client_metadata = client_metadata ,
2818+ storage = legacy_storage , # type: ignore[arg-type]
2819+ redirect_handler = redirect_handler ,
2820+ callback_handler = callback_handler ,
2821+ )
2822+
2823+ test_request = httpx .Request ("GET" , "https://api.example.com/mcp" )
2824+ auth_flow = provider .async_auth_flow (test_request )
2825+
2826+ # First yield: ``_initialize`` loads state from LegacyStorage (no tokens,
2827+ # no client info, no metadata fallback) and the original request goes out
2828+ # without an auth header.
2829+ request = await auth_flow .__anext__ ()
2830+ assert "Authorization" not in request .headers
2831+
2832+ # 401 → triggers full OAuth flow
2833+ unauthorized_response = httpx .Response (
2834+ 401 ,
2835+ headers = {
2836+ "WWW-Authenticate" : (
2837+ 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
2838+ )
2839+ },
2840+ request = test_request ,
2841+ )
2842+ prm_request = await auth_flow .asend (unauthorized_response )
2843+ assert "oauth-protected-resource" in str (prm_request .url )
2844+
2845+ # PRM discovery response
2846+ prm_response = httpx .Response (
2847+ 200 ,
2848+ content = (
2849+ b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}'
2850+ ),
2851+ request = prm_request ,
2852+ )
2853+ asm_request = await auth_flow .asend (prm_response )
2854+ assert str (asm_request .url ).startswith ("https://auth.example.com/" )
2855+
2856+ # OASM discovery response — this is where our set_oauth_metadata
2857+ # fallback (meta_setter is None) executes for LegacyStorage.
2858+ asm_response = httpx .Response (
2859+ 200 ,
2860+ content = (
2861+ b'{"issuer": "https://auth.example.com", '
2862+ b'"authorization_endpoint": "https://auth.example.com/authorize", '
2863+ b'"token_endpoint": "https://auth.example.com/token", '
2864+ b'"registration_endpoint": "https://auth.example.com/register"}'
2865+ ),
2866+ request = asm_request ,
2867+ )
2868+ next_request = await auth_flow .asend (asm_response )
2869+
2870+ # Discovery succeeded: flow advanced past metadata handling.
2871+ # (Legacy storage had no set_oauth_metadata, so persistence is skipped.)
2872+ assert next_request is not None
2873+ assert provider .context .oauth_metadata is not None
2874+ assert str (provider .context .oauth_metadata .token_endpoint ) == "https://auth.example.com/token"
2875+
2876+ await auth_flow .aclose ()
0 commit comments