Skip to content

Commit b76b429

Browse files
committed
PR feedback
1 parent e425c35 commit b76b429

File tree

7 files changed

+283
-213
lines changed

7 files changed

+283
-213
lines changed

durabletask-azuremanaged/durabletask/azuremanaged/client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Optional
77

88
from azure.core.credentials import TokenCredential
9+
from azure.core.credentials_async import AsyncTokenCredential
910

1011
from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import (
1112
DTSAsyncDefaultClientInterceptorImpl,
@@ -81,7 +82,7 @@ class AsyncDurableTaskSchedulerClient(AsyncTaskHubGrpcClient):
8182
def __init__(self, *,
8283
host_address: str,
8384
taskhub: str,
84-
token_credential: Optional[TokenCredential],
85+
token_credential: Optional[AsyncTokenCredential],
8586
secure_channel: bool = True,
8687
default_version: Optional[str] = None,
8788
log_handler: Optional[logging.Handler] = None,

durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55

66
from azure.core.credentials import AccessToken, TokenCredential
7+
from azure.core.credentials_async import AsyncTokenCredential
78

89
import durabletask.internal.shared as shared
910

@@ -47,3 +48,40 @@ def refresh_token(self):
4748
# Convert UNIX timestamp to timezone-aware datetime
4849
self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc)
4950
self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")
51+
52+
53+
class AsyncAccessTokenManager:
54+
"""Async version of AccessTokenManager that uses AsyncTokenCredential.
55+
56+
This avoids blocking the event loop when acquiring or refreshing tokens."""
57+
58+
_token: Optional[AccessToken]
59+
60+
def __init__(self, token_credential: Optional[AsyncTokenCredential],
61+
refresh_interval_seconds: int = 600):
62+
self._scope = "https://durabletask.io/.default"
63+
self._refresh_interval_seconds = refresh_interval_seconds
64+
self._logger = shared.get_logger("async_token_manager")
65+
66+
self._credential = token_credential
67+
self._token = None
68+
self.expiry_time = None
69+
70+
async def get_access_token(self) -> Optional[AccessToken]:
71+
if self._token is None or self.is_token_expired():
72+
await self.refresh_token()
73+
return self._token
74+
75+
def is_token_expired(self) -> bool:
76+
if self.expiry_time is None:
77+
return True
78+
return datetime.now(timezone.utc) >= (
79+
self.expiry_time - timedelta(seconds=self._refresh_interval_seconds))
80+
81+
async def refresh_token(self):
82+
if self._credential is not None:
83+
self._token = await self._credential.get_token(self._scope)
84+
85+
# Convert UNIX timestamp to timezone-aware datetime
86+
self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc)
87+
self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}")

durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66

77
import grpc
88
from azure.core.credentials import TokenCredential
9+
from azure.core.credentials_async import AsyncTokenCredential
910

10-
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
11+
from durabletask.azuremanaged.internal.access_token_manager import (
12+
AccessTokenManager,
13+
AsyncAccessTokenManager,
14+
)
1115
from durabletask.internal.grpc_interceptor import (
1216
DefaultAsyncClientInterceptorImpl,
1317
DefaultClientInterceptorImpl,
@@ -34,6 +38,7 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
3438
("x-user-agent", user_agent)] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
3539
super().__init__(self._metadata)
3640

41+
self._token_manager = None
3742
if token_credential is not None:
3843
self._token_credential = token_credential
3944
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
@@ -45,13 +50,21 @@ def _intercept_call(
4550
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
4651
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
4752
call details."""
48-
# Refresh the auth token if it is present and needed
49-
if self._metadata is not None:
50-
for i, (key, _) in enumerate(self._metadata):
51-
if key.lower() == "authorization": # Ensure case-insensitive comparison
52-
new_token = self._token_manager.get_access_token() # Get the new token
53-
if new_token is not None:
54-
self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token
53+
# Refresh the auth token if a credential was provided. The call to
54+
# get_access_token() is generally cheap, checking the expiry time and returning
55+
# the cached value without a network call when still valid.
56+
if self._token_manager is not None:
57+
access_token = self._token_manager.get_access_token()
58+
if access_token is not None:
59+
# Update the existing authorization header
60+
found = False
61+
for i, (key, _) in enumerate(self._metadata):
62+
if key.lower() == "authorization":
63+
self._metadata[i] = ("authorization", f"Bearer {access_token.token}")
64+
found = True
65+
break
66+
if not found:
67+
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
5568

5669
return super()._intercept_call(client_call_details)
5770

@@ -62,7 +75,7 @@ class DTSAsyncDefaultClientInterceptorImpl(DefaultAsyncClientInterceptorImpl):
6275
This class implements async gRPC interceptors to add DTS-specific headers
6376
(task hub name, user agent, and authentication token) to all async calls."""
6477

65-
def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: str):
78+
def __init__(self, token_credential: Optional[AsyncTokenCredential], taskhub_name: str):
6679
try:
6780
# Get the version of the azuremanaged package
6881
sdk_version = version('durabletask-azuremanaged')
@@ -75,23 +88,34 @@ def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: st
7588
("x-user-agent", user_agent)]
7689
super().__init__(self._metadata)
7790

91+
# Token acquisition is deferred to the first _intercept_call invocation
92+
# rather than happening in __init__, because get_token() on an
93+
# AsyncTokenCredential is async and cannot be awaited in a constructor.
94+
self._token_manager = None
7895
if token_credential is not None:
7996
self._token_credential = token_credential
80-
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
81-
access_token = self._token_manager.get_access_token()
82-
if access_token is not None:
83-
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
97+
self._token_manager = AsyncAccessTokenManager(token_credential=self._token_credential)
8498

85-
def _intercept_call(
99+
async def _intercept_call(
86100
self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails:
87101
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
88102
call details."""
89-
# Refresh the auth token if it is present and needed
90-
if self._metadata is not None:
91-
for i, (key, _) in enumerate(self._metadata):
92-
if key.lower() == "authorization": # Ensure case-insensitive comparison
93-
new_token = self._token_manager.get_access_token() # Get the new token
94-
if new_token is not None:
95-
self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token
96-
97-
return super()._intercept_call(client_call_details)
103+
# Refresh the auth token if a credential was provided. The call to
104+
# get_access_token() is generally cheap, checking the expiry time and returning
105+
# the cached value without a network call when still valid.
106+
if self._token_manager is not None:
107+
access_token = await self._token_manager.get_access_token()
108+
if access_token is not None:
109+
# Update the existing authorization header, or append one if this
110+
# is the first successful token acquisition (token is lazily
111+
# fetched on the first call since async constructors aren't possible).
112+
found = False
113+
for i, (key, _) in enumerate(self._metadata):
114+
if key.lower() == "authorization":
115+
self._metadata[i] = ("authorization", f"Bearer {access_token.token}")
116+
found = True
117+
break
118+
if not found:
119+
self._metadata.append(("authorization", f"Bearer {access_token.token}"))
120+
121+
return await super()._intercept_call(client_call_details)

durabletask/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,12 @@ async def close(self) -> None:
404404
"""Closes the underlying gRPC channel."""
405405
await self._channel.close()
406406

407+
async def __aenter__(self):
408+
return self
409+
410+
async def __aexit__(self, exc_type, exc_val, exc_tb):
411+
await self.close()
412+
407413
async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
408414
input: Optional[TInput] = None,
409415
instance_id: Optional[str] = None,

durabletask/internal/grpc_interceptor.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,11 @@ class DefaultAsyncClientInterceptorImpl(
9393
def __init__(self, metadata: list[tuple[str, str]]):
9494
self._metadata = metadata
9595

96-
def _intercept_call(
96+
async def _intercept_call(
9797
self, client_call_details: grpc.aio.ClientCallDetails) -> grpc.aio.ClientCallDetails:
9898
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
99-
call details."""
99+
call details. This method is async to allow subclasses to perform async operations
100+
(e.g., refreshing auth tokens) during interception."""
100101
new_metadata = _apply_metadata(client_call_details, self._metadata)
101102
if new_metadata is client_call_details.metadata:
102103
return client_call_details
@@ -110,17 +111,17 @@ def _intercept_call(
110111
)
111112

112113
async def intercept_unary_unary(self, continuation, client_call_details, request):
113-
new_client_call_details = self._intercept_call(client_call_details)
114+
new_client_call_details = await self._intercept_call(client_call_details)
114115
return await continuation(new_client_call_details, request)
115116

116117
async def intercept_unary_stream(self, continuation, client_call_details, request):
117-
new_client_call_details = self._intercept_call(client_call_details)
118+
new_client_call_details = await self._intercept_call(client_call_details)
118119
return await continuation(new_client_call_details, request)
119120

120121
async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
121-
new_client_call_details = self._intercept_call(client_call_details)
122+
new_client_call_details = await self._intercept_call(client_call_details)
122123
return await continuation(new_client_call_details, request_iterator)
123124

124125
async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
125-
new_client_call_details = self._intercept_call(client_call_details)
126+
new_client_call_details = await self._intercept_call(client_call_details)
126127
return await continuation(new_client_call_details, request_iterator)

0 commit comments

Comments
 (0)