Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions src/askui/models/askui/google_genai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from google.genai import types as genai_types
from google.genai.errors import APIError
from pydantic import ValidationError
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
from tenacity import (
RetryCallState,
retry,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
from tenacity.wait import wait_base
from typing_extensions import override

from askui.logger import logger
Expand All @@ -14,20 +21,60 @@
from askui.models.models import GetModel, ModelName
from askui.models.shared.prompts import SYSTEM_PROMPT_GET
from askui.models.types.response_schemas import ResponseSchema, to_response_schema
from askui.utils.http_utils import parse_retry_after_header
from askui.utils.image_utils import ImageSource

ASKUI_MODEL_CHOICE_PREFIX = "askui/"
ASKUI_MODEL_CHOICE_PREFIX_LEN = len(ASKUI_MODEL_CHOICE_PREFIX)


class _wait_for_retry_after_header(wait_base):
"""Wait strategy that tries to wait for the length specified by
the Retry-After header, or the underlying wait strategy if not.
See RFC 6585 § 4.

Otherwise, wait according to the fallback strategy.
"""

def __init__(self, fallback: wait_base) -> None:
"""Initialize the wait strategy with a fallback strategy.

Args:
fallback (wait_base): The fallback wait strategy to use when
Retry-After header is not available or invalid.
"""
self._fallback = fallback

def __call__(self, retry_state: RetryCallState) -> float:
"""Calculate the wait time based on Retry-After header or fallback.

Args:
retry_state (RetryCallState): The retry state containing the
exception information.

Returns:
float: The wait time in seconds.
"""
if outcome := retry_state.outcome:
exc = outcome.exception()
if isinstance(exc, APIError):
retry_after: str | None = exc.response.headers.get("Retry-After")
if retry_after:
try:
return parse_retry_after_header(retry_after)
except ValueError:
pass
return self._fallback(retry_state)


def _is_retryable_error(exception: BaseException) -> bool:
"""Check if the exception is a retryable error (status codes 429, 502, or 529).

The 502 status of the AskUI Inference API is usually temporary which is why we also
retry it.
"""
if isinstance(exception, APIError):
return exception.code in (429, 502, 529)
return exception.code in (408, 413, 429, 500, 502, 503, 504, 521, 522, 524)
return False


Expand Down Expand Up @@ -55,7 +102,9 @@ def __init__(self, settings: AskUiInferenceApiSettings | None = None) -> None:

@retry(
stop=stop_after_attempt(4), # 3 retries
wait=wait_exponential(multiplier=30, min=30, max=120), # 30s, 60s, 120s
wait=_wait_for_retry_after_header(
wait_exponential(multiplier=30, min=30, max=120)
), # retry after or as a fallback 30s, 60s, 120s
retry=retry_if_exception(_is_retryable_error),
reraise=True,
)
Expand Down
36 changes: 36 additions & 0 deletions src/askui/utils/http_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from datetime import datetime, timezone


def parse_retry_after_header(retry_after: str) -> float:
"""Parse the Retry-After header value.

The Retry-After header value can be a number of seconds or a date in
RFC 2822 format:
- `<number>`
- `<day-name>, <day> <month> <year> <hour>:<minute>:<second> GMT`

The date must be in GMT timezone.

Args:
retry_after (str): The Retry-After header value.

Returns:
float: The number of seconds to wait.

Raises:
ValueError: If the header value cannot be parsed.
"""
try:
return float(retry_after)
except (TypeError, ValueError):
try:
dt = datetime.strptime( # noqa: DTZ007
retry_after,
"%a, %d %b %Y %H:%M:%S GMT",
)
dt = dt.replace(tzinfo=timezone.utc)
now = datetime.now(timezone.utc)
return (dt - now).total_seconds()
except (TypeError, ValueError) as e:
error_msg = f"Could not parse Retry-After header: {retry_after}"
raise ValueError(error_msg) from e
145 changes: 145 additions & 0 deletions tests/unit/utils/test_http_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from datetime import datetime, timezone
from unittest.mock import patch

import pytest

from askui.utils.http_utils import parse_retry_after_header


class TestParseRetryAfterHeader:
"""Test cases for the `parse_retry_after_header` function."""

def test_parse_numeric_seconds(self) -> None:
"""Test parsing numeric retry-after values."""
assert parse_retry_after_header("30") == 30.0
assert parse_retry_after_header("60.5") == 60.5
assert parse_retry_after_header("0") == 0.0
assert parse_retry_after_header("120") == 120.0

def test_parse_rfc2822_date_format(self) -> None:
"""Test parsing RFC 2822 date format retry-after values."""
# Test with a future date
future_date = "Mon, 15 Jan 2024 12:00:00 GMT"
with patch("askui.utils.http_utils.datetime") as mock_datetime:
# Mock current time to be before the retry-after date
mock_datetime.now.return_value = datetime(
2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc
)
mock_datetime.strptime = datetime.strptime

result = parse_retry_after_header(future_date)
assert result > 0 # Should be positive seconds

def test_parse_rfc2822_date_format_past_date(self) -> None:
"""Test parsing RFC 2822 date format with past date."""
past_date = "Mon, 15 Jan 2024 10:00:00 GMT"
with patch("askui.utils.http_utils.datetime") as mock_datetime:
# Mock current time to be after the retry-after date
mock_datetime.now.return_value = datetime(
2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc
)
mock_datetime.strptime = datetime.strptime

result = parse_retry_after_header(past_date)
assert result < 0 # Should be negative seconds (past date)

def test_parse_rfc2822_date_format_exact_time(self) -> None:
"""Test parsing RFC 2822 date format with exact current time."""
exact_date = "Mon, 15 Jan 2024 11:00:00 GMT"
with patch("askui.utils.http_utils.datetime") as mock_datetime:
# Mock current time to be exactly the retry-after date
mock_datetime.now.return_value = datetime(
2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc
)
mock_datetime.strptime = datetime.strptime

result = parse_retry_after_header(exact_date)
assert result == 0.0 # Should be zero seconds

def test_parse_invalid_numeric_input(self) -> None:
"""Test parsing invalid numeric inputs."""
with pytest.raises(
ValueError, match="Could not parse Retry-After header: invalid"
):
parse_retry_after_header("invalid")

def test_parse_invalid_date_format(self) -> None:
"""Test parsing invalid date format inputs."""
with pytest.raises(
ValueError, match="Could not parse Retry-After header: invalid_date"
):
parse_retry_after_header("invalid_date")

def test_parse_empty_string(self) -> None:
"""Test parsing empty string input."""
with pytest.raises(ValueError, match="Could not parse Retry-After header: "):
parse_retry_after_header("")

def test_parse_whitespace_string(self) -> None:
"""Test parsing whitespace-only string input."""
with pytest.raises(ValueError, match="Could not parse Retry-After header: "):
parse_retry_after_header(" ")

def test_parse_none_input(self) -> None:
"""Test parsing None input."""
with pytest.raises(
ValueError, match="Could not parse Retry-After header: None"
):
parse_retry_after_header(None) # type: ignore

def test_parse_malformed_date(self) -> None:
"""Test parsing malformed date string."""
malformed_date = "Mon, 15 Jan 2024 25:00:00 GMT" # Invalid hour
with pytest.raises(
ValueError,
match="Could not parse Retry-After header: Mon, 15 Jan 2024 25:00:00 GMT",
):
parse_retry_after_header(malformed_date)

def test_parse_date_without_gmt(self) -> None:
"""Test parsing date without GMT timezone."""
date_without_gmt = "Mon, 15 Jan 2024 12:00:00"
with pytest.raises(
ValueError,
match="Could not parse Retry-After header: Mon, 15 Jan 2024 12:00:00",
):
parse_retry_after_header(date_without_gmt)

def test_parse_negative_numeric(self) -> None:
"""Test parsing negative numeric values."""
assert parse_retry_after_header("-30") == -30.0
assert parse_retry_after_header("-60.5") == -60.5

def test_parse_large_numeric(self) -> None:
"""Test parsing large numeric values."""
assert parse_retry_after_header("86400") == 86400.0 # 24 hours
assert parse_retry_after_header("31536000") == 31536000.0 # 1 year

def test_parse_scientific_notation(self) -> None:
"""Test parsing scientific notation."""
assert parse_retry_after_header("1e3") == 1000.0
assert parse_retry_after_header("1.5e2") == 150.0

def test_parse_edge_case_dates(self) -> None:
"""Test parsing edge case date formats."""
# Test leap year date
leap_year_date = "Mon, 29 Feb 2024 12:00:00 GMT"
with patch("askui.utils.http_utils.datetime") as mock_datetime:
mock_datetime.now.return_value = datetime(
2024, 2, 29, 11, 0, 0, tzinfo=timezone.utc
)
mock_datetime.strptime = datetime.strptime

result = parse_retry_after_header(leap_year_date)
assert result > 0

# Test end of year date
end_year_date = "Mon, 31 Dec 2024 23:59:59 GMT"
with patch("askui.utils.http_utils.datetime") as mock_datetime:
mock_datetime.now.return_value = datetime(
2024, 12, 31, 23, 0, 0, tzinfo=timezone.utc
)
mock_datetime.strptime = datetime.strptime

result = parse_retry_after_header(end_year_date)
assert result > 0