66
77import grpc
88from azure .core .credentials import TokenCredential
9+ from azure .core .credentials_async import AsyncTokenCredential
910
10- from durabletask .azuremanaged .internal .access_token_manager import AccessTokenManager
11+ from durabletask .azuremanaged .internal .access_token_manager import (
12+ AccessTokenManager ,
13+ AsyncAccessTokenManager ,
14+ )
1115from durabletask .internal .grpc_interceptor import (
1216 DefaultAsyncClientInterceptorImpl ,
1317 DefaultClientInterceptorImpl ,
@@ -34,6 +38,7 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
3438 ("x-user-agent" , user_agent )] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
3539 super ().__init__ (self ._metadata )
3640
41+ self ._token_manager = None
3742 if token_credential is not None :
3843 self ._token_credential = token_credential
3944 self ._token_manager = AccessTokenManager (token_credential = self ._token_credential )
@@ -45,13 +50,21 @@ def _intercept_call(
4550 self , client_call_details : _ClientCallDetails ) -> grpc .ClientCallDetails :
4651 """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
4752 call details."""
48- # Refresh the auth token if it is present and needed
49- if self ._metadata is not None :
50- for i , (key , _ ) in enumerate (self ._metadata ):
51- if key .lower () == "authorization" : # Ensure case-insensitive comparison
52- new_token = self ._token_manager .get_access_token () # Get the new token
53- if new_token is not None :
54- self ._metadata [i ] = ("authorization" , f"Bearer { new_token .token } " ) # Update the token
53+ # Refresh the auth token if a credential was provided. The call to
54+ # get_access_token() is generally cheap, checking the expiry time and returning
55+ # the cached value without a network call when still valid.
56+ if self ._token_manager is not None :
57+ access_token = self ._token_manager .get_access_token ()
58+ if access_token is not None :
59+ # Update the existing authorization header
60+ found = False
61+ for i , (key , _ ) in enumerate (self ._metadata ):
62+ if key .lower () == "authorization" :
63+ self ._metadata [i ] = ("authorization" , f"Bearer { access_token .token } " )
64+ found = True
65+ break
66+ if not found :
67+ self ._metadata .append (("authorization" , f"Bearer { access_token .token } " ))
5568
5669 return super ()._intercept_call (client_call_details )
5770
@@ -62,7 +75,7 @@ class DTSAsyncDefaultClientInterceptorImpl(DefaultAsyncClientInterceptorImpl):
6275 This class implements async gRPC interceptors to add DTS-specific headers
6376 (task hub name, user agent, and authentication token) to all async calls."""
6477
65- def __init__ (self , token_credential : Optional [TokenCredential ], taskhub_name : str ):
78+ def __init__ (self , token_credential : Optional [AsyncTokenCredential ], taskhub_name : str ):
6679 try :
6780 # Get the version of the azuremanaged package
6881 sdk_version = version ('durabletask-azuremanaged' )
@@ -75,23 +88,34 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
7588 ("x-user-agent" , user_agent )]
7689 super ().__init__ (self ._metadata )
7790
91+ # Token acquisition is deferred to the first _intercept_call invocation
92+ # rather than happening in __init__, because get_token() on an
93+ # AsyncTokenCredential is async and cannot be awaited in a constructor.
94+ self ._token_manager = None
7895 if token_credential is not None :
7996 self ._token_credential = token_credential
80- self ._token_manager = AccessTokenManager (token_credential = self ._token_credential )
81- access_token = self ._token_manager .get_access_token ()
82- if access_token is not None :
83- self ._metadata .append (("authorization" , f"Bearer { access_token .token } " ))
97+ self ._token_manager = AsyncAccessTokenManager (token_credential = self ._token_credential )
8498
85- def _intercept_call (
99+ async def _intercept_call (
86100 self , client_call_details : _AsyncClientCallDetails ) -> grpc .aio .ClientCallDetails :
87101 """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
88102 call details."""
89- # Refresh the auth token if it is present and needed
90- if self ._metadata is not None :
91- for i , (key , _ ) in enumerate (self ._metadata ):
92- if key .lower () == "authorization" : # Ensure case-insensitive comparison
93- new_token = self ._token_manager .get_access_token () # Get the new token
94- if new_token is not None :
95- self ._metadata [i ] = ("authorization" , f"Bearer { new_token .token } " ) # Update the token
96-
97- return super ()._intercept_call (client_call_details )
103+ # Refresh the auth token if a credential was provided. The call to
104+ # get_access_token() is generally cheap, checking the expiry time and returning
105+ # the cached value without a network call when still valid.
106+ if self ._token_manager is not None :
107+ access_token = await self ._token_manager .get_access_token ()
108+ if access_token is not None :
109+ # Update the existing authorization header, or append one if this
110+ # is the first successful token acquisition (token is lazily
111+ # fetched on the first call since async constructors aren't possible).
112+ found = False
113+ for i , (key , _ ) in enumerate (self ._metadata ):
114+ if key .lower () == "authorization" :
115+ self ._metadata [i ] = ("authorization" , f"Bearer { access_token .token } " )
116+ found = True
117+ break
118+ if not found :
119+ self ._metadata .append (("authorization" , f"Bearer { access_token .token } " ))
120+
121+ return await super ()._intercept_call (client_call_details )
0 commit comments