diff --git a/src/askui/models/askui/google_genai_api.py b/src/askui/models/askui/google_genai_api.py
index f932bcb0..49b72d03 100644
--- a/src/askui/models/askui/google_genai_api.py
+++ b/src/askui/models/askui/google_genai_api.py
@@ -5,7 +5,14 @@
 from google.genai import types as genai_types
 from google.genai.errors import APIError
 from pydantic import ValidationError
-from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+from tenacity import (
+    RetryCallState,
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
+from tenacity.wait import wait_base
 from typing_extensions import override
 
 from askui.logger import logger
@@ -14,12 +21,56 @@
 from askui.models.models import GetModel, ModelName
 from askui.models.shared.prompts import SYSTEM_PROMPT_GET
 from askui.models.types.response_schemas import ResponseSchema, to_response_schema
+from askui.utils.http_utils import parse_retry_after_header
 from askui.utils.image_utils import ImageSource
 
 ASKUI_MODEL_CHOICE_PREFIX = "askui/"
 ASKUI_MODEL_CHOICE_PREFIX_LEN = len(ASKUI_MODEL_CHOICE_PREFIX)
 
 
+class _wait_for_retry_after_header(wait_base):
+    """Wait strategy that honors the `Retry-After` response header.
+
+    If the failed attempt raised an `APIError` whose response carries a
+    parseable `Retry-After` header, wait for the duration it specifies
+    (see RFC 6585 § 4). Otherwise, wait according to the fallback strategy.
+    """
+
+    def __init__(self, fallback: wait_base) -> None:
+        """Initialize the wait strategy with a fallback strategy.
+
+        Args:
+            fallback (wait_base): The fallback wait strategy to use when
+                Retry-After header is not available or invalid.
+        """
+        self._fallback = fallback
+
+    def __call__(self, retry_state: RetryCallState) -> float:
+        """Calculate the wait time based on Retry-After header or fallback.
+
+        Args:
+            retry_state (RetryCallState): The retry state containing the
+                exception information.
+
+        Returns:
+            float: The wait time in seconds. A Retry-After based wait is
+                clamped so that it is never negative.
+        """
+        if outcome := retry_state.outcome:
+            exc = outcome.exception()
+            if isinstance(exc, APIError):
+                # `APIError.response` may be `None`, so look up headers defensively.
+                headers = getattr(exc.response, "headers", None) or {}
+                retry_after: str | None = headers.get("Retry-After")
+                if retry_after:
+                    try:
+                        # A date in the past (or a negative number) yields a
+                        # negative value, which `time.sleep` rejects.
+                        return max(0.0, parse_retry_after_header(retry_after))
+                    except ValueError:
+                        pass
+        return self._fallback(retry_state)
+
+
 def _is_retryable_error(exception: BaseException) -> bool:
-    """Check if the exception is a retryable error (status codes 429, 502, or 529).
+    """Check if the exception is a retryable error (transient HTTP status codes).
 
@@ -27,7 +78,7 @@ def _is_retryable_error(exception: BaseException) -> bool:
     retry it.
     """
     if isinstance(exception, APIError):
-        return exception.code in (429, 502, 529)
+        return exception.code in (408, 413, 429, 500, 502, 503, 504, 521, 522, 524)
     return False
 
 
@@ -55,7 +106,9 @@ def __init__(self, settings: AskUiInferenceApiSettings | None = None) -> None:
 
     @retry(
         stop=stop_after_attempt(4),  # 3 retries
-        wait=wait_exponential(multiplier=30, min=30, max=120),  # 30s, 60s, 120s
+        wait=_wait_for_retry_after_header(
+            wait_exponential(multiplier=30, min=30, max=120)
+        ),  # Retry-After if provided, otherwise 30s, 60s, 120s
         retry=retry_if_exception(_is_retryable_error),
         reraise=True,
     )
diff --git a/src/askui/utils/http_utils.py b/src/askui/utils/http_utils.py
new file mode 100644
index 00000000..3969ce2f
--- /dev/null
+++ b/src/askui/utils/http_utils.py
@@ -0,0 +1,44 @@
+import math
+from datetime import datetime, timezone
+
+
+def parse_retry_after_header(retry_after: str) -> float:
+    """Parse the Retry-After header value.
+
+    The Retry-After header value can be a number of seconds or an HTTP date
+    (IMF-fixdate, see RFC 9110 § 5.6.7 and § 10.2.3):
+    - `<delay-seconds>`
+    - `<day-name>, <day> <month> <year> <hour>:<minute>:<second> GMT`
+
+    The date must be in GMT timezone.
+
+    Args:
+        retry_after (str): The Retry-After header value.
+
+    Returns:
+        float: The number of seconds to wait. Negative if the value is a
+            negative number or a date in the past.
+
+    Raises:
+        ValueError: If the header value cannot be parsed or is not a finite
+            number.
+    """
+    error_msg = f"Could not parse Retry-After header: {retry_after}"
+    try:
+        seconds = float(retry_after)
+    except (TypeError, ValueError):
+        pass
+    else:
+        # `float` accepts "nan" and "inf", which are not usable as a delay.
+        if not math.isfinite(seconds):
+            raise ValueError(error_msg)
+        return seconds
+    try:
+        dt = datetime.strptime(  # noqa: DTZ007
+            retry_after,
+            "%a, %d %b %Y %H:%M:%S GMT",
+        )
+    except (TypeError, ValueError) as e:
+        raise ValueError(error_msg) from e
+    now = datetime.now(timezone.utc)
+    return (dt.replace(tzinfo=timezone.utc) - now).total_seconds()
diff --git a/tests/unit/utils/test_http_utils.py b/tests/unit/utils/test_http_utils.py
new file mode 100644
index 00000000..03c895b3
--- /dev/null
+++ b/tests/unit/utils/test_http_utils.py
@@ -0,0 +1,153 @@
+from datetime import datetime, timezone
+from unittest.mock import patch
+
+import pytest
+
+from askui.utils.http_utils import parse_retry_after_header
+
+
+class TestParseRetryAfterHeader:
+    """Test cases for the `parse_retry_after_header` function."""
+
+    def test_parse_numeric_seconds(self) -> None:
+        """Test parsing numeric retry-after values."""
+        assert parse_retry_after_header("30") == 30.0
+        assert parse_retry_after_header("60.5") == 60.5
+        assert parse_retry_after_header("0") == 0.0
+        assert parse_retry_after_header("120") == 120.0
+
+    def test_parse_rfc2822_date_format(self) -> None:
+        """Test parsing RFC 2822 date format retry-after values."""
+        # Test with a future date
+        future_date = "Mon, 15 Jan 2024 12:00:00 GMT"
+        with patch("askui.utils.http_utils.datetime") as mock_datetime:
+            # Mock current time to be before the retry-after date
+            mock_datetime.now.return_value = datetime(
+                2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc
+            )
+            mock_datetime.strptime = datetime.strptime
+
+            result = parse_retry_after_header(future_date)
+            assert result > 0  # Should be positive seconds
+
+    def test_parse_rfc2822_date_format_past_date(self) -> None:
+        """Test parsing RFC 2822 date format with past date."""
+        past_date = "Mon, 15 Jan 2024 10:00:00 GMT"
+        with patch("askui.utils.http_utils.datetime") as mock_datetime:
+            # Mock current time to be after the retry-after date
+            mock_datetime.now.return_value = datetime(
+                2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc
+            )
+            mock_datetime.strptime = datetime.strptime
+
+            result = parse_retry_after_header(past_date)
+            assert result < 0  # Should be negative seconds (past date)
+
+    def test_parse_rfc2822_date_format_exact_time(self) -> None:
+        """Test parsing RFC 2822 date format with exact current time."""
+        exact_date = "Mon, 15 Jan 2024 11:00:00 GMT"
+        with patch("askui.utils.http_utils.datetime") as mock_datetime:
+            # Mock current time to be exactly the retry-after date
+            mock_datetime.now.return_value = datetime(
+                2024, 1, 15, 11, 0, 0, tzinfo=timezone.utc
+            )
+            mock_datetime.strptime = datetime.strptime
+
+            result = parse_retry_after_header(exact_date)
+            assert result == 0.0  # Should be zero seconds
+
+    def test_parse_invalid_numeric_input(self) -> None:
+        """Test parsing invalid numeric inputs."""
+        with pytest.raises(
+            ValueError, match="Could not parse Retry-After header: invalid"
+        ):
+            parse_retry_after_header("invalid")
+
+    def test_parse_invalid_date_format(self) -> None:
+        """Test parsing invalid date format inputs."""
+        with pytest.raises(
+            ValueError, match="Could not parse Retry-After header: invalid_date"
+        ):
+            parse_retry_after_header("invalid_date")
+
+    def test_parse_empty_string(self) -> None:
+        """Test parsing empty string input."""
+        with pytest.raises(ValueError, match="Could not parse Retry-After header: "):
+            parse_retry_after_header("")
+
+    def test_parse_whitespace_string(self) -> None:
+        """Test parsing whitespace-only string input."""
+        with pytest.raises(ValueError, match="Could not parse Retry-After header: "):
+            parse_retry_after_header("   ")
+
+    def test_parse_none_input(self) -> None:
+        """Test parsing None input."""
+        with pytest.raises(
+            ValueError, match="Could not parse Retry-After header: None"
+        ):
+            parse_retry_after_header(None)  # type: ignore
+
+    def test_parse_malformed_date(self) -> None:
+        """Test parsing malformed date string."""
+        malformed_date = "Mon, 15 Jan 2024 25:00:00 GMT"  # Invalid hour
+        with pytest.raises(
+            ValueError,
+            match="Could not parse Retry-After header: Mon, 15 Jan 2024 25:00:00 GMT",
+        ):
+            parse_retry_after_header(malformed_date)
+
+    def test_parse_date_without_gmt(self) -> None:
+        """Test parsing date without GMT timezone."""
+        date_without_gmt = "Mon, 15 Jan 2024 12:00:00"
+        with pytest.raises(
+            ValueError,
+            match="Could not parse Retry-After header: Mon, 15 Jan 2024 12:00:00",
+        ):
+            parse_retry_after_header(date_without_gmt)
+
+    def test_parse_negative_numeric(self) -> None:
+        """Test parsing negative numeric values."""
+        assert parse_retry_after_header("-30") == -30.0
+        assert parse_retry_after_header("-60.5") == -60.5
+
+    def test_parse_large_numeric(self) -> None:
+        """Test parsing large numeric values."""
+        assert parse_retry_after_header("86400") == 86400.0  # 24 hours
+        assert parse_retry_after_header("31536000") == 31536000.0  # 1 year
+
+    def test_parse_scientific_notation(self) -> None:
+        """Test parsing scientific notation."""
+        assert parse_retry_after_header("1e3") == 1000.0
+        assert parse_retry_after_header("1.5e2") == 150.0
+
+    def test_parse_non_finite_numeric(self) -> None:
+        """Test that non-finite numeric values are rejected."""
+        for value in ("nan", "inf", "-inf", "Infinity"):
+            with pytest.raises(
+                ValueError, match=f"Could not parse Retry-After header: {value}"
+            ):
+                parse_retry_after_header(value)
+
+    def test_parse_edge_case_dates(self) -> None:
+        """Test parsing edge case date formats."""
+        # Test leap year date
+        leap_year_date = "Thu, 29 Feb 2024 12:00:00 GMT"
+        with patch("askui.utils.http_utils.datetime") as mock_datetime:
+            mock_datetime.now.return_value = datetime(
+                2024, 2, 29, 11, 0, 0, tzinfo=timezone.utc
+            )
+            mock_datetime.strptime = datetime.strptime
+
+            result = parse_retry_after_header(leap_year_date)
+            assert result > 0
+
+        # Test end of year date
+        end_year_date = "Tue, 31 Dec 2024 23:59:59 GMT"
+        with patch("askui.utils.http_utils.datetime") as mock_datetime:
+            mock_datetime.now.return_value = datetime(
+                2024, 12, 31, 23, 0, 0, tzinfo=timezone.utc
+            )
+            mock_datetime.strptime = datetime.strptime
+
+            result = parse_retry_after_header(end_year_date)
+            assert result > 0