1313import pytest
1414
1515if TYPE_CHECKING :
16- from collections .abc import Awaitable , Callable , Coroutine
16+ from collections .abc import Awaitable , Callable
1717
1818logger = logging .getLogger (__name__ )
1919
@@ -96,23 +96,15 @@ def get_random_resource_name(label: str) -> str:
9696 return name_template .format (label , get_crypto_random_object_id (random_id_length ))
9797
9898
99- @overload
100- async def maybe_await (value : Coroutine [Any , Any , T ]) -> T : ...
101-
102-
103- @overload
104- async def maybe_await (value : T ) -> T : ...
105-
106-
107- async def maybe_await (value : T | Coroutine [Any , Any , T ]) -> T :
108- """Await coroutines, pass through other values.
99+ async def maybe_await (value : Awaitable [T ] | T ) -> T :
100+ """Await `value` if it is awaitable, otherwise return it unchanged.
109101
110102 Enables unified test code for both sync and async clients:
111103 result = await maybe_await(client.datasets().list())
112104 """
113- if hasattr (value , '__await__' ):
114- return await value # ty: ignore[invalid-await]
115- return value
105+ if inspect . isawaitable (value ):
106+ return await cast ( 'Awaitable[T]' , value )
107+ return cast ( 'T' , value )
116108
117109
118110async def maybe_sleep (seconds : float , * , is_async : bool ) -> None :
@@ -123,16 +115,6 @@ async def maybe_sleep(seconds: float, *, is_async: bool) -> None:
123115 time .sleep (seconds ) # noqa: ASYNC251
124116
125117
126- async def _maybe_await (value : Awaitable [T ] | T ) -> T :
127- """Await `value` if it is awaitable, otherwise return it unchanged.
128-
129- Lets `call_with_exp_backoff` and `poll_until_condition` accept both sync and async callables.
130- """
131- if inspect .isawaitable (value ):
132- return await cast ('Awaitable[T]' , value )
133- return cast ('T' , value )
134-
135-
136118@overload
137119async def call_with_exp_backoff (
138120 fn : Callable [[], Awaitable [T ]],
@@ -167,7 +149,7 @@ async def call_with_exp_backoff(
167149
168150 Unlike `poll_until_condition`, the delay between attempts grows exponentially rather than staying constant.
169151 """
170- result = await _maybe_await (fn ())
152+ result = await maybe_await (fn ())
171153 for attempt in range (max_retries ):
172154 if condition (result ):
173155 return result
@@ -176,7 +158,7 @@ async def call_with_exp_backoff(
176158 'Condition not met for %r, retrying in %ss (attempt %d/%d).' , result , delay , attempt + 1 , max_retries
177159 )
178160 await asyncio .sleep (delay )
179- result = await _maybe_await (fn ())
161+ result = await maybe_await (fn ())
180162 return result
181163
182164
@@ -214,13 +196,13 @@ async def poll_until_condition(
214196 `call_with_exp_backoff`, the interval between polls stays constant.
215197 """
216198 deadline = time .monotonic () + timeout
217- result = await _maybe_await (fn ())
199+ result = await maybe_await (fn ())
218200 while not condition (result ):
219201 remaining = deadline - time .monotonic ()
220202 if remaining <= 0 :
221203 break
222204 await asyncio .sleep (min (poll_interval , remaining ))
223- result = await _maybe_await (fn ())
205+ result = await maybe_await (fn ())
224206 return result
225207
226208
0 commit comments