Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,3 @@ dmypy.json
.Trashes
ehthumbs.db
Thumbs.db

# Poetry
poetry.lock
2 changes: 2 additions & 0 deletions make_api_request/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
to_form_urlencoded,
)
from .response import AsyncStreamResponse, StreamResponse, from_encodable
from .retry import RetryStrategy

__all__ = [
"ApiError",
Expand All @@ -49,4 +50,5 @@
"AsyncStreamResponse",
"StreamResponse",
"QueryParams",
"RetryStrategy",
]
67 changes: 62 additions & 5 deletions make_api_request/base_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from asyncio import sleep as async_sleep
from time import sleep
from typing import (
Any,
Dict,
Expand All @@ -17,8 +19,13 @@
from .auth import AuthProvider
from .binary_response import BinaryResponse
from .query import QueryParams
from .request import RequestConfig, RequestOptions, default_request_options
from .request import (
RequestConfig,
RequestOptions,
default_request_options,
)
from .response import AsyncStreamResponse, StreamResponse, from_encodable
from .retry import RetryConfig, RetryStrategy
from .utils import filter_binary_response, get_response_type

NoneType = type(None)
Expand All @@ -43,6 +50,7 @@ def __init__(
self,
base_url: Union[str, Dict[str, str]],
auths: Optional[Dict[str, AuthProvider]] = None,
retries: Optional[RetryStrategy] = None,
):
"""Initialize the base client"""
self._base_url = (
Expand All @@ -51,6 +59,7 @@ def __init__(
else {_DEFAULT_SERVICE_NAME: base_url}
)
self._auths: Dict[str, AuthProvider] = auths or {}
self._retries = retries

def register_auth(self, auth_id: str, provider: AuthProvider) -> None:
"""Register an authentication provider.
Expand Down Expand Up @@ -351,15 +360,37 @@ def __init__(
base_url: Union[str, Dict[str, str]],
httpx_client: httpx.Client,
auths: Optional[Dict[str, AuthProvider]] = None,
retries: Optional[RetryStrategy] = None,
):
"""Initialize the synchronous client.

Args:
httpx_client: Synchronous HTTPX client instance
"""
super().__init__(base_url=base_url, auths=auths)
super().__init__(base_url=base_url, auths=auths, retries=retries)
self.httpx_client = httpx_client

def _request_with_retires(
self,
*,
req_cfg: RequestConfig,
request_options: Optional[RequestOptions] = None,
) -> httpx.Response:
response = self.httpx_client.request(**req_cfg)

retry_override = (request_options or {}).get("retries")
if retry_override or self._retries:
retry = RetryConfig(base=self._retries, override=retry_override)
attempt = 1
delay = float(retry.initial_delay)
while retry.should_retry(attempt=attempt, status_code=response.status_code):
sleep(delay / 1000)
response = self.httpx_client.request(**req_cfg)
delay = retry.calc_next_delay(curr_delay=delay)
attempt += 1

return response

def request(
self,
*,
Expand Down Expand Up @@ -414,7 +445,9 @@ def request(
content=content,
request_options=request_options,
)
response = self.httpx_client.request(**req_cfg)
response = self._request_with_retires(
req_cfg=req_cfg, request_options=request_options
)

if not response.is_success:
raise ApiError(response=response)
Expand Down Expand Up @@ -495,15 +528,37 @@ def __init__(
base_url: Union[str, Dict[str, str]],
httpx_client: httpx.AsyncClient,
auths: Optional[Dict[str, AuthProvider]] = None,
retries: Optional[RetryStrategy] = None,
):
"""Initialize the asynchronous client.

Args:
httpx_client: Asynchronous HTTPX client instance
"""
super().__init__(base_url=base_url, auths=auths)
super().__init__(base_url=base_url, auths=auths, retries=retries)
self.httpx_client = httpx_client

async def _request_with_retires(
self,
*,
req_cfg: RequestConfig,
request_options: Optional[RequestOptions] = None,
) -> httpx.Response:
response = await self.httpx_client.request(**req_cfg)

retry_override = (request_options or {}).get("retries")
if retry_override or self._retries:
retry = RetryConfig(base=self._retries, override=retry_override)
attempt = 1
delay = float(retry.initial_delay)
while retry.should_retry(attempt=attempt, status_code=response.status_code):
await async_sleep(delay / 1000)
response = await self.httpx_client.request(**req_cfg)
delay = retry.calc_next_delay(curr_delay=delay)
attempt += 1

return response

async def request(
self,
*,
Expand Down Expand Up @@ -558,7 +613,9 @@ async def request(
content=content,
request_options=request_options,
)
response = await self.httpx_client.request(**req_cfg)
response = await self._request_with_retires(
req_cfg=req_cfg, request_options=request_options
)

if not response.is_success:
raise ApiError(response=response)
Expand Down
2 changes: 2 additions & 0 deletions make_api_request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing_extensions import NotRequired, Required, TypedDict

from .query import QueryParams, QueryParamStyle, encode_query_param
from .retry import RetryStrategy
from .type_utils import NotGiven

"""
Expand Down Expand Up @@ -54,6 +55,7 @@ class RequestOptions(TypedDict):
timeout: NotRequired[int]
additional_headers: NotRequired[Dict[str, str]]
additional_params: NotRequired[QueryParams]
retries: NotRequired[RetryStrategy]


def default_request_options() -> RequestOptions:
Expand Down
99 changes: 99 additions & 0 deletions make_api_request/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import List, Optional

from typing_extensions import NotRequired, TypedDict


class RetryStrategy(TypedDict):
"""
Configuration for retrying HTTP requests.
"""

max_retries: NotRequired[int]
"""
maximum amount of retries allowed after first request failure. if 5,
the request could be sent a total of 6 times
"""
status_codes: NotRequired[List[int]]
"""
Response status codes that will trigger a retry. These must either be:
- exact status code (100 <= code < 600), e.g. 408, or
- unit (0 < num < 6) that represents a status code range, e.g. 5 -> 5XX
"""
initial_delay: NotRequired[int]
"""
Initial wait time (milliseconds) after first request failure before a retry is sent
"""
max_delay: NotRequired[int]
"""
Maximum wait time between retries
"""
backoff_factor: NotRequired[float]
"""
the factor applied to the current wait time to determine the next wait time
min(current_delay * backoff, max_delay)
"""


class RetryConfig:
max_retries: int
status_codes: List[int]
initial_delay: int
max_delay: int
backoff_factor: float

def __init__(
self,
*,
base: Optional[RetryStrategy] = None,
override: Optional[RetryStrategy] = None
):
_base: RetryStrategy = base or {}
_override: RetryStrategy = override or {}

self.max_retries = _override.get("max_retries", _base.get("max_retries", 5))
self.status_codes = _override.get(
"status_codes",
_base.get(
"status_codes",
[
5, # 5XX
408, # Timeout
409, # Conflict
429, # Too Many Requests
],
),
)
self.initial_delay = _override.get(
"initial_delay", _base.get("initial_delay", 500)
)
self.max_delay = _override.get("max_delay", _base.get("max_delay", 10000))
self.backoff_factor = _override.get(
"backoff_factor", _base.get("backoff_factor", 2.0)
)

def _matches_code(self, status_code: int, retry_code: int) -> bool:
"""
Custom status_code comparison to support exact match and
range matches
"""
if retry_code < 6:
# Range check (e.g., 4 means 400-499)
return retry_code * 100 <= status_code < (retry_code + 1) * 100
else:
# Exact match
return status_code == retry_code

def should_retry(self, *, attempt: int, status_code: int) -> bool:
"""
Checks if a retry is allowed according to the config
"""
return attempt <= self.max_retries and any(
self._matches_code(status_code, c) for c in self.status_codes
)

def calc_next_delay(self, *, curr_delay: float) -> float:
"""
Calculates the time (ms) the retrier should wait before the
next attempt according to the config
"""
return min(float(self.max_delay), curr_delay * self.backoff_factor)
Loading
Loading