From ecfa39a00539d0ce5e2c394b96bc9a2f958d744d Mon Sep 17 00:00:00 2001 From: bot-of-qin-ctx Date: Tue, 31 Mar 2026 15:00:34 +0800 Subject: [PATCH] Revert "feat: unify config-driven retry across VLM and embedding (#1049)" This reverts commit a092d64867a5a69d89cabb11b9f735476e75cf22. --- openviking/models/embedder/base.py | 3 +- .../models/embedder/gemini_embedders.py | 40 +- openviking/models/embedder/jina_embedders.py | 13 +- .../models/embedder/litellm_embedders.py | 13 +- .../models/embedder/minimax_embedders.py | 21 +- .../models/embedder/openai_embedders.py | 16 +- .../models/embedder/vikingdb_embedders.py | 35 +- .../models/embedder/volcengine_embedders.py | 53 ++- .../models/embedder/voyage_embedders.py | 13 +- openviking/models/retry.py | 287 ------------ openviking/models/vlm/backends/litellm_vlm.py | 49 +- openviking/models/vlm/backends/openai_vlm.py | 70 ++- .../models/vlm/backends/volcengine_vlm.py | 50 +- openviking/models/vlm/base.py | 4 +- openviking/models/vlm/llm.py | 32 +- .../utils/config/embedding_config.py | 8 - openviking_cli/utils/config/vlm_config.py | 29 +- tests/models/test_vlm_strip_think_tags.py | 2 +- tests/unit/test_backward_compat.py | 167 ------- .../unit/test_embedding_retry_integration.py | 230 --------- tests/unit/test_extra_headers_vlm.py | 4 +- tests/unit/test_retry.py | 441 ------------------ tests/unit/test_retry_config.py | 81 ---- tests/unit/test_stream_config_vlm.py | 4 +- tests/unit/test_vlm_retry_integration.py | 314 ------------- 25 files changed, 194 insertions(+), 1785 deletions(-) delete mode 100644 openviking/models/retry.py delete mode 100644 tests/unit/test_backward_compat.py delete mode 100644 tests/unit/test_embedding_retry_integration.py delete mode 100644 tests/unit/test_retry.py delete mode 100644 tests/unit/test_retry_config.py delete mode 100644 tests/unit/test_vlm_retry_integration.py diff --git a/openviking/models/embedder/base.py b/openviking/models/embedder/base.py index 62370c37a..9a9e5df3e 100644 --- a/openviking/models/embedder/base.py +++ b/openviking/models/embedder/base.py @@ -74,7 +74,6 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None): """ self.model_name = model_name self.config = config or {} - self.max_retries = self.config.get("max_retries", 3) if self.config else 3 @abstractmethod def embed(self, text: str, is_query: bool = False) -> EmbedResult: @@ -256,7 +255,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes return [ EmbedResult(dense_vector=d.dense_vector, sparse_vector=s.sparse_vector) - for d, s in zip(dense_results, sparse_results, strict=False) + for d, s in zip(dense_results, sparse_results) ] def get_dimension(self) -> int: diff --git a/openviking/models/embedder/gemini_embedders.py b/openviking/models/embedder/gemini_embedders.py index 6bc1e1ab6..c11d73e47 100644 --- a/openviking/models/embedder/gemini_embedders.py +++ b/openviking/models/embedder/gemini_embedders.py @@ -29,7 +29,6 @@ EmbedResult, truncate_and_normalize, ) -from openviking.models.retry import transient_retry logger = logging.getLogger("gemini_embedders") @@ -147,13 +146,15 @@ def __init__( ) if dimension is not None and not (1 <= dimension <= 3072): raise ValueError(f"dimension must be between 1 and 3072, got {dimension}") - # Disable SDK-level retry; we use transient_retry for unified retry logic if _HTTP_RETRY_AVAILABLE: self.client = genai.Client( api_key=api_key, http_options=HttpOptions( retry_options=HttpRetryOptions( - attempts=1, + attempts=3, + initial_delay=1.0, + 
max_delay=30.0, + exp_base=2.0, ) ), ) @@ -208,16 +209,11 @@ def embed( task_type = self.document_param # SDK accepts plain str; converts to REST Parts format internally. try: - embed_config = self._build_config(task_type=task_type, title=title) - - def _call(): - return self.client.models.embed_content( - model=self.model_name, - contents=text, - config=embed_config, - ) - - result = transient_retry(_call, max_retries=self.max_retries) + result = self.client.models.embed_content( + model=self.model_name, + contents=text, + config=self._build_config(task_type=task_type, title=title), + ) vector = truncate_and_normalize(list(result.embeddings[0].values), self._dimension) return EmbedResult(dense_vector=vector) except (APIError, ClientError) as e: @@ -237,7 +233,7 @@ def embed_batch( if titles is not None: return [ self.embed(text, is_query=is_query, task_type=task_type, title=title) - for text, title in zip(texts, titles, strict=False) + for text, title in zip(texts, titles) ] # Resolve effective task_type from is_query when no explicit override if task_type is None: @@ -258,17 +254,13 @@ def embed_batch( non_empty_texts = [batch[j] for j in non_empty_indices] try: - - def _batch_call(texts=non_empty_texts, cfg=config): - return self.client.models.embed_content( - model=self.model_name, - contents=texts, - config=cfg, - ) - - response = transient_retry(_batch_call, max_retries=self.max_retries) + response = self.client.models.embed_content( + model=self.model_name, + contents=non_empty_texts, + config=config, + ) batch_results = [None] * len(batch) - for j, emb in zip(non_empty_indices, response.embeddings, strict=False): + for j, emb in zip(non_empty_indices, response.embeddings): batch_results[j] = EmbedResult( dense_vector=truncate_and_normalize(list(emb.values), self._dimension) ) diff --git a/openviking/models/embedder/jina_embedders.py b/openviking/models/embedder/jina_embedders.py index 25159f2df..f94650fcc 100644 --- a/openviking/models/embedder/jina_embedders.py +++ b/openviking/models/embedder/jina_embedders.py @@ -10,7 +10,6 @@ DenseEmbedderBase, EmbedResult, ) -from openviking.models.retry import transient_retry # Default dimensions for Jina embedding models JINA_MODEL_DIMENSIONS = { @@ -114,11 +113,9 @@ def __init__( raise ValueError("api_key is required") # Initialize OpenAI-compatible client with Jina base URL - # Disable SDK retry; we use transient_retry for unified retry logic self.client = openai.OpenAI( api_key=self.api_key, base_url=self.api_base, - max_retries=0, ) # Determine dimension @@ -177,10 +174,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if extra_body: kwargs["extra_body"] = extra_body - def _call(): - return self.client.embeddings.create(**kwargs) - - response = transient_retry(_call, max_retries=self.max_retries) + response = self.client.embeddings.create(**kwargs) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) @@ -215,10 +209,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if extra_body: kwargs["extra_body"] = extra_body - def _call(): - return self.client.embeddings.create(**kwargs) - - response = transient_retry(_call, max_retries=self.max_retries) + response = self.client.embeddings.create(**kwargs) return [EmbedResult(dense_vector=item.embedding) for item in response.data] except openai.APIError as e: diff --git a/openviking/models/embedder/litellm_embedders.py b/openviking/models/embedder/litellm_embedders.py index 903b2fc67..ea24c8141 100644 --- 
a/openviking/models/embedder/litellm_embedders.py +++ b/openviking/models/embedder/litellm_embedders.py @@ -13,7 +13,6 @@ import litellm from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult -from openviking.models.retry import transient_retry from openviking.telemetry import get_current_telemetry logger = logging.getLogger(__name__) @@ -158,11 +157,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: try: kwargs = self._build_kwargs(is_query=is_query) kwargs["input"] = [text] - - def _call(): - return litellm.embedding(**kwargs) - - response = transient_retry(_call, max_retries=self.max_retries) + response = litellm.embedding(**kwargs) self._update_telemetry_token_usage(response) vector = response.data[0]["embedding"] return EmbedResult(dense_vector=vector) @@ -188,11 +183,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes try: kwargs = self._build_kwargs(is_query=is_query) kwargs["input"] = texts - - def _call(): - return litellm.embedding(**kwargs) - - response = transient_retry(_call, max_retries=self.max_retries) + response = litellm.embedding(**kwargs) self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item["embedding"]) for item in response.data] except Exception as e: diff --git a/openviking/models/embedder/minimax_embedders.py b/openviking/models/embedder/minimax_embedders.py index 1547b13af..aba462968 100644 --- a/openviking/models/embedder/minimax_embedders.py +++ b/openviking/models/embedder/minimax_embedders.py @@ -9,7 +9,6 @@ from urllib3.util.retry import Retry from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult -from openviking.models.retry import transient_retry from openviking_cli.utils.logger import default_logger as logger @@ -90,8 +89,12 @@ def __init__( def _create_session(self) -> requests.Session: """Create a requests session with retry logic""" session = requests.Session() - # Disable transport-level retry; we use transient_retry for unified retry logic - retry_strategy = Retry(total=0) + retry_strategy = Retry( + total=6, + backoff_factor=1, # 1s, 2s, 4s, 8s, 16s, 32s + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=["POST"], + ) adapter = HTTPAdapter(max_retries=retry_strategy) session.mount("https://", adapter) session.mount("http://", adapter) @@ -160,10 +163,7 @@ def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float def embed(self, text: str, is_query: bool = False) -> EmbedResult: """Perform dense embedding on text""" - vectors = transient_retry( - lambda: self._call_api([text], is_query=is_query), - max_retries=self.max_retries, - ) + vectors = self._call_api([text], is_query=is_query) return EmbedResult(dense_vector=vectors[0]) def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: @@ -171,10 +171,9 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - vectors = transient_retry( - lambda: self._call_api(texts, is_query=is_query), - max_retries=self.max_retries, - ) + # MiniMax might have batch size limits, but let's assume the caller handles batching or use safe defaults + # For now, we pass through. If needed, we can implement internal chunking. 
+ vectors = self._call_api(texts, is_query=is_query) return [EmbedResult(dense_vector=v) for v in vectors] def get_dimension(self) -> int: diff --git a/openviking/models/embedder/openai_embedders.py b/openviking/models/embedder/openai_embedders.py index 0477270a6..0ebabbeee 100644 --- a/openviking/models/embedder/openai_embedders.py +++ b/openviking/models/embedder/openai_embedders.py @@ -12,7 +12,6 @@ HybridEmbedderBase, SparseEmbedderBase, ) -from openviking.models.retry import transient_retry from openviking.models.vlm.registry import DEFAULT_AZURE_API_VERSION from openviking.telemetry import get_current_telemetry @@ -119,10 +118,7 @@ def __init__( if not self.api_key and not self.api_base: raise ValueError("api_key is required") - client_kwargs: Dict[str, Any] = { - "api_key": self.api_key or "no-key", - "max_retries": 0, # Disable SDK retry; we use transient_retry - } + client_kwargs: Dict[str, Any] = {"api_key": self.api_key or "no-key"} if self._provider == "azure": if not self.api_base: raise ValueError("api_base (Azure endpoint) is required for Azure provider") @@ -246,10 +242,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if extra_body: kwargs["extra_body"] = extra_body - def _call(): - return self.client.embeddings.create(**kwargs) - - response = transient_retry(_call, max_retries=self.max_retries) + response = self.client.embeddings.create(**kwargs) self._update_telemetry_token_usage(response) vector = response.data[0].embedding @@ -284,10 +277,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if extra_body: kwargs["extra_body"] = extra_body - def _call(): - return self.client.embeddings.create(**kwargs) - - response = transient_retry(_call, max_retries=self.max_retries) + response = self.client.embeddings.create(**kwargs) self._update_telemetry_token_usage(response) return [EmbedResult(dense_vector=item.embedding) for item in response.data] diff --git a/openviking/models/embedder/vikingdb_embedders.py b/openviking/models/embedder/vikingdb_embedders.py index d9c5cf49b..0253af9dc 100644 --- a/openviking/models/embedder/vikingdb_embedders.py +++ b/openviking/models/embedder/vikingdb_embedders.py @@ -10,7 +10,6 @@ HybridEmbedderBase, SparseEmbedderBase, ) -from openviking.models.retry import transient_retry from openviking.storage.vectordb.collection.volcengine_clients import ClientForDataApi from openviking_cli.utils.logger import default_logger as logger @@ -125,10 +124,7 @@ def __init__( self.dense_model = {"name": model_name, "version": model_version, "dim": dimension} def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = transient_retry( - lambda: self._call_api([text], dense_model=self.dense_model), - max_retries=self.max_retries, - ) + results = self._call_api([text], dense_model=self.dense_model) if not results: return EmbedResult(dense_vector=[]) @@ -142,10 +138,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = transient_retry( - lambda: self._call_api(texts, dense_model=self.dense_model), - max_retries=self.max_retries, - ) + raw_results = self._call_api(texts, dense_model=self.dense_model) return [ EmbedResult( dense_vector=self._truncate_and_normalize( @@ -181,10 +174,7 @@ def __init__( } def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = transient_retry( - lambda: self._call_api([text], 
sparse_model=self.sparse_model), - max_retries=self.max_retries, - ) + results = self._call_api([text], sparse_model=self.sparse_model) if not results: return EmbedResult(sparse_vector={}) @@ -198,10 +188,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = transient_retry( - lambda: self._call_api(texts, sparse_model=self.sparse_model), - max_retries=self.max_retries, - ) + raw_results = self._call_api(texts, sparse_model=self.sparse_model) return [ EmbedResult( sparse_vector=self._process_sparse_embedding(item.get("sparse_embedding", {})) @@ -237,11 +224,8 @@ def __init__( } def embed(self, text: str, is_query: bool = False) -> EmbedResult: - results = transient_retry( - lambda: self._call_api( - [text], dense_model=self.dense_model, sparse_model=self.sparse_model - ), - max_retries=self.max_retries, + results = self._call_api( + [text], dense_model=self.dense_model, sparse_model=self.sparse_model ) if not results: return EmbedResult(dense_vector=[], sparse_vector={}) @@ -260,11 +244,8 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]: if not texts: return [] - raw_results = transient_retry( - lambda: self._call_api( - texts, dense_model=self.dense_model, sparse_model=self.sparse_model - ), - max_retries=self.max_retries, + raw_results = self._call_api( + texts, dense_model=self.dense_model, sparse_model=self.sparse_model ) results = [] for item in raw_results: diff --git a/openviking/models/embedder/volcengine_embedders.py b/openviking/models/embedder/volcengine_embedders.py index 15ea42cca..7a3ef6f4e 100644 --- a/openviking/models/embedder/volcengine_embedders.py +++ b/openviking/models/embedder/volcengine_embedders.py @@ -11,13 +11,29 @@ EmbedResult, HybridEmbedderBase, SparseEmbedderBase, + exponential_backoff_retry, truncate_and_normalize, ) -from openviking.models.retry import transient_retry from openviking.telemetry import get_current_telemetry from openviking_cli.utils.logger import default_logger as logger +def is_429_error(exception: Exception) -> bool: + """ + 判断异常是否为 429 限流错误 + + Args: + exception: 要检查的异常 + + Returns: + 如果是 429 错误则返回 True,否则返回 False + """ + exception_str = str(exception) + return ( + "429" in exception_str or "TooManyRequests" in exception_str or "RateLimit" in exception_str + ) + + def process_sparse_embedding(sparse_data: Any) -> Dict[str, float]: """Process sparse embedding data from SDK response""" if not sparse_data: @@ -161,7 +177,15 @@ def _embed_call(): return EmbedResult(dense_vector=vector) try: - return transient_retry(_embed_call, max_retries=self.max_retries) + return exponential_backoff_retry( + _embed_call, + max_wait=10.0, + base_delay=0.5, + max_delay=2.0, + jitter=True, + is_retryable=is_429_error, + logger=logger, + ) except Exception as e: raise RuntimeError(f"Volcengine embedding failed: {str(e)}") from e @@ -181,7 +205,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if not texts: return [] - def _batch_call(): + try: if self.input_type == "multimodal": multimodal_inputs = [{"type": "text", "text": text} for text in texts] response = self.client.multimodal_embeddings.create( @@ -198,9 +222,6 @@ def _batch_call(): EmbedResult(dense_vector=truncate_and_normalize(item.embedding, self.dimension)) for item in data ] - - try: - return 
transient_retry(_batch_call, max_retries=self.max_retries) except Exception as e: logger.error( f"Volcengine batch embedding failed, texts length: {len(texts)}, input_type: {self.input_type}, model_name: {self.model_name}" @@ -274,7 +295,15 @@ def _embed_call(): return EmbedResult(sparse_vector=process_sparse_embedding(sparse_vector)) try: - return transient_retry(_embed_call, max_retries=self.max_retries) + return exponential_backoff_retry( + _embed_call, + max_wait=10.0, + base_delay=0.5, + max_delay=2.0, + jitter=True, + is_retryable=is_429_error, + logger=logger, + ) except Exception as e: raise RuntimeError(f"Volcengine sparse embedding failed: {str(e)}") from e @@ -371,7 +400,15 @@ def _embed_call(): ) try: - return transient_retry(_embed_call, max_retries=self.max_retries) + return exponential_backoff_retry( + _embed_call, + max_wait=10.0, + base_delay=0.5, + max_delay=2.0, + jitter=True, + is_retryable=is_429_error, + logger=logger, + ) except Exception as e: raise RuntimeError(f"Volcengine hybrid embedding failed: {str(e)}") from e diff --git a/openviking/models/embedder/voyage_embedders.py b/openviking/models/embedder/voyage_embedders.py index db8b85b3a..ed8d49f04 100644 --- a/openviking/models/embedder/voyage_embedders.py +++ b/openviking/models/embedder/voyage_embedders.py @@ -7,7 +7,6 @@ import openai from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult -from openviking.models.retry import transient_retry VOYAGE_MODEL_DIMENSIONS = { "voyage-3": 1024, @@ -75,11 +74,9 @@ def __init__( f"Supported dimensions: {supported}." ) - # Disable SDK retry; we use transient_retry for unified retry logic self.client = openai.OpenAI( api_key=self.api_key, base_url=self.api_base, - max_retries=0, ) self._dimension = dimension or get_voyage_model_default_dimension(normalized_model_name) @@ -91,10 +88,7 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult: if self.dimension is not None: kwargs["extra_body"] = {"output_dimension": self.dimension} - def _call(): - return self.client.embeddings.create(**kwargs) - - response = transient_retry(_call, max_retries=self.max_retries) + response = self.client.embeddings.create(**kwargs) vector = response.data[0].embedding return EmbedResult(dense_vector=vector) except openai.APIError as e: @@ -112,10 +106,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes if self.dimension is not None: kwargs["extra_body"] = {"output_dimension": self.dimension} - def _call(): - return self.client.embeddings.create(**kwargs) - - response = transient_retry(_call, max_retries=self.max_retries) + response = self.client.embeddings.create(**kwargs) return [EmbedResult(dense_vector=item.embedding) for item in response.data] except openai.APIError as e: raise RuntimeError(f"Voyage API error: {e.message}") from e diff --git a/openviking/models/retry.py b/openviking/models/retry.py deleted file mode 100644 index bf4b4e697..000000000 --- a/openviking/models/retry.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 - -"""Unified retry logic for VLM backends and embedding providers. - -Provides three public helpers: - -- ``is_transient_error`` — classifies an exception as transient (retryable) - or permanent (should propagate immediately). -- ``transient_retry`` — synchronous retry loop with exponential backoff. -- ``transient_retry_async`` — asynchronous counterpart using ``asyncio.sleep``. 
- -Transient errors are those that may resolve on their own (rate-limits, temporary -server errors, network resets). Permanent errors indicate a caller mistake -(bad auth, invalid input) and should never be retried. - -Usage example:: - - result = transient_retry(lambda: client.chat(...), max_retries=3) - result = await transient_retry_async(lambda: client.chat_async(...), max_retries=3) -""" - -from __future__ import annotations - -import asyncio -import logging -import random -import time -from collections.abc import Callable -from typing import Optional, TypeVar - -logger = logging.getLogger("openviking.models.retry") - -T = TypeVar("T") - -# --------------------------------------------------------------------------- -# Status code helpers -# --------------------------------------------------------------------------- - -_TRANSIENT_STATUS_CODES: frozenset[int] = frozenset({429, 500, 502, 503, 504}) -_PERMANENT_STATUS_CODES: frozenset[int] = frozenset({400, 401, 403, 404, 422}) - -# String patterns — permanent check runs first (more specific) -_PERMANENT_STR_PATTERNS: tuple[str, ...] = ( - "InvalidRequestError", - "AuthenticationError", -) -_TRANSIENT_STR_PATTERNS: tuple[str, ...] = ( - "TooManyRequests", - "RateLimit", - "RequestBurstTooFast", - "timed out", - "timeout", -) - - -def _extract_status_code(exc: Exception) -> int | None: - """Return numeric HTTP status from common status-bearing attributes. - - Checks ``.status_code``, ``.code``, and ``.http_status`` in that order. - Returns ``None`` if none of the attributes exist or hold an integer. - """ - for attr in ("status_code", "code", "http_status"): - value = getattr(exc, attr, None) - if isinstance(value, int): - return value - return None - - -# --------------------------------------------------------------------------- -# is_transient_error -# --------------------------------------------------------------------------- - - -def is_transient_error(exc: Exception) -> bool: - """Classify an exception as transient (retryable) or permanent. - - Evaluation order: - 1. Extract numeric status code from the exception attributes; check - permanent codes first, then transient codes. - 2. Check the exception type directly (built-in connection / timeout types). - 3. Scan ``str(exc)`` for known permanent string patterns, then transient - ones. - 4. Attempt to import ``openai`` and check against its error hierarchy. - 5. Default to ``False`` (conservative — unknown errors are not retried). - - Args: - exc: The exception to classify. - - Returns: - ``True`` if the error is likely transient and worth retrying. - ``False`` for permanent errors or any unrecognised exception. - """ - # ── 1. Numeric status code ──────────────────────────────────────────── - status = _extract_status_code(exc) - if status is not None: - if status in _PERMANENT_STATUS_CODES: - return False - if status in _TRANSIENT_STATUS_CODES: - return True - - # ── 2. Exception type ───────────────────────────────────────────────── - # asyncio.TimeoutError is a subclass of TimeoutError on 3.11+, but treat - # both explicitly for clarity on 3.10. - if isinstance(exc, (ConnectionError, ConnectionResetError, ConnectionRefusedError)): - return True - if isinstance(exc, (TimeoutError, asyncio.TimeoutError)): - return True - - # ── 3. 
String patterns ──────────────────────────────────────────────── - message = str(exc) - - for pattern in _PERMANENT_STR_PATTERNS: - if pattern in message: - return False - - for pattern in _TRANSIENT_STR_PATTERNS: - if pattern in message: - return True - - # ── 4. openai error types (optional dependency) ─────────────────────── - try: - import openai # type: ignore[import-untyped] - - # Permanent openai errors — check before transient - if isinstance(exc, openai.AuthenticationError): - return False - - # Transient openai errors - if isinstance( - exc, (openai.RateLimitError, openai.APITimeoutError, openai.APIConnectionError) - ): - return True - except ImportError: - pass - - # ── 5. Default: do not retry unknown errors ─────────────────────────── - return False - - -# --------------------------------------------------------------------------- -# transient_retry (sync) -# --------------------------------------------------------------------------- - - -def transient_retry( - func: Callable[[], T], - max_retries: int = 3, - base_delay: float = 0.5, - max_delay: float = 8.0, - jitter: bool = True, - is_retryable: Optional[Callable[[Exception], bool]] = None, -) -> T: - """Call *func* and retry on transient failures with exponential backoff. - - The delay between attempts follows the formula:: - - delay = min(base_delay * 2^attempt, max_delay) - - When ``jitter=True`` the delay is multiplied by a random factor in - ``[0.5, 1.5)`` to spread concurrent retries. - - Args: - func: Zero-argument callable to invoke. - max_retries: Maximum number of *additional* attempts after the first - failure. ``0`` disables retrying entirely. - base_delay: Initial delay in seconds before the first retry. - max_delay: Upper bound on the computed delay (seconds). - jitter: Whether to apply random jitter to the delay. - is_retryable: Optional predicate that decides whether an exception - should be retried. Defaults to ``is_transient_error``. - - Returns: - The return value of *func* on success. - - Raises: - Exception: The last exception raised by *func* after all retries are - exhausted, or immediately if the error is not retryable. - """ - _check = is_retryable if is_retryable is not None else is_transient_error - - last_exc: Exception - for attempt in range(max_retries + 1): - try: - return func() - except Exception as exc: - last_exc = exc - - if not _check(exc): - # Permanent — propagate immediately - raise - - if attempt >= max_retries: - # Retries exhausted - logger.warning( - "transient_retry: all %d retries exhausted; last error: %s", - max_retries, - exc, - ) - raise - - delay = min(base_delay * (2**attempt), max_delay) - if jitter: - delay *= 0.5 + random.random() # [0.5, 1.5) - - logger.info( - "transient_retry: attempt %d/%d failed (%s); retrying in %.2fs", - attempt + 1, - max_retries, - exc, - delay, - ) - time.sleep(delay) - - # Unreachable, but satisfies the type checker - raise last_exc # type: ignore[possibly-undefined] - - -# --------------------------------------------------------------------------- -# transient_retry_async -# --------------------------------------------------------------------------- - - -async def transient_retry_async( - coro_func: Callable[[], "asyncio.Coroutine[object, object, T]"], - max_retries: int = 3, - base_delay: float = 0.5, - max_delay: float = 8.0, - jitter: bool = True, - is_retryable: Optional[Callable[[Exception], bool]] = None, -) -> T: - """Async version of :func:`transient_retry`. 
- - Identical semantics to the sync variant but uses ``asyncio.sleep`` - so it does not block the event loop during backoff. - - Args: - coro_func: Zero-argument async callable (coroutine factory) to invoke. - max_retries: Maximum number of *additional* attempts after the first - failure. ``0`` disables retrying entirely. - base_delay: Initial delay in seconds before the first retry. - max_delay: Upper bound on the computed delay (seconds). - jitter: Whether to apply random jitter to the delay. - is_retryable: Optional predicate that decides whether an exception - should be retried. Defaults to ``is_transient_error``. - - Returns: - The return value of *coro_func()* on success. - - Raises: - Exception: The last exception raised by *coro_func* after all retries - are exhausted, or immediately if the error is not retryable. - """ - _check = is_retryable if is_retryable is not None else is_transient_error - - last_exc: Exception - for attempt in range(max_retries + 1): - try: - return await coro_func() - except Exception as exc: - last_exc = exc - - if not _check(exc): - raise - - if attempt >= max_retries: - logger.warning( - "transient_retry_async: all %d retries exhausted; last error: %s", - max_retries, - exc, - ) - raise - - delay = min(base_delay * (2**attempt), max_delay) - if jitter: - delay *= 0.5 + random.random() - - logger.info( - "transient_retry_async: attempt %d/%d failed (%s); retrying in %.2fs", - attempt + 1, - max_retries, - exc, - delay, - ) - await asyncio.sleep(delay) - - raise last_exc # type: ignore[possibly-undefined] diff --git a/openviking/models/vlm/backends/litellm_vlm.py b/openviking/models/vlm/backends/litellm_vlm.py index 36a2bf2d9..ca4a36aa7 100644 --- a/openviking/models/vlm/backends/litellm_vlm.py +++ b/openviking/models/vlm/backends/litellm_vlm.py @@ -8,6 +8,7 @@ os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" +import asyncio import base64 import time from pathlib import Path @@ -16,8 +17,6 @@ import litellm from litellm import acompletion, completion -from openviking.models.retry import transient_retry, transient_retry_async - from ..base import ToolCall, VLMBase, VLMResponse logger = logging.getLogger(__name__) @@ -295,11 +294,8 @@ def get_completion( kwargs = self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) - def _call(): - return completion(**kwargs) - t0 = time.perf_counter() - response = transient_retry(_call, max_retries=self.max_retries) + response = completion(**kwargs) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) @@ -308,6 +304,7 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, + max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -321,17 +318,25 @@ async def get_completion_async( kwargs = self._build_kwargs(model, kwargs_messages, tools, tool_choice, thinking=thinking) - async def _call(): - return await acompletion(**kwargs) - - t0 = time.perf_counter() - response = await transient_retry_async(_call, max_retries=self.max_retries) - elapsed = time.perf_counter() - t0 - self._update_token_usage_from_response( - response, - duration_seconds=elapsed, - ) - return self._build_vlm_response(response, has_tools=bool(tools)) + last_error = None + for attempt in range(max_retries + 1): + try: + t0 = time.perf_counter() + response = await 
acompletion(**kwargs) + elapsed = time.perf_counter() - t0 + self._update_token_usage_from_response( + response, + duration_seconds=elapsed, + ) + return self._build_vlm_response(response, has_tools=bool(tools)) + except Exception as e: + last_error = e + if attempt < max_retries: + await asyncio.sleep(2**attempt) + + if last_error: + raise last_error + raise RuntimeError("Unknown error in async completion") def get_vision_completion( self, @@ -357,11 +362,8 @@ def get_vision_completion( kwargs = self._build_kwargs(model, kwargs_messages, tools, thinking=thinking) - def _call(): - return completion(**kwargs) - t0 = time.perf_counter() - response = transient_retry(_call, max_retries=self.max_retries) + response = completion(**kwargs) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) @@ -390,11 +392,8 @@ async def get_vision_completion_async( kwargs = self._build_kwargs(model, kwargs_messages, tools, thinking=thinking) - async def _call(): - return await acompletion(**kwargs) - t0 = time.perf_counter() - response = await transient_retry_async(_call, max_retries=self.max_retries) + response = await acompletion(**kwargs) elapsed = time.perf_counter() - t0 self._update_token_usage_from_response(response, duration_seconds=elapsed) return self._build_vlm_response(response, has_tools=bool(tools)) diff --git a/openviking/models/vlm/backends/openai_vlm.py b/openviking/models/vlm/backends/openai_vlm.py index abcc35ba0..05c28c768 100644 --- a/openviking/models/vlm/backends/openai_vlm.py +++ b/openviking/models/vlm/backends/openai_vlm.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: AGPL-3.0 """OpenAI VLM backend implementation""" +import asyncio import base64 import json import logging @@ -10,8 +11,6 @@ from typing import Any, Dict, List, Optional, Union from urllib.parse import urlparse -from openviking.models.retry import transient_retry, transient_retry_async - from ..base import ToolCall, VLMBase, VLMResponse from ..registry import DEFAULT_AZURE_API_VERSION @@ -70,7 +69,6 @@ def get_client(self): self.api_version, self.extra_headers, ) - kwargs["max_retries"] = 0 # Disable SDK retry; we use transient_retry if self.provider == "azure": self._sync_client = openai.AzureOpenAI(**kwargs) else: @@ -91,7 +89,6 @@ def get_async_client(self): self.api_version, self.extra_headers, ) - kwargs["max_retries"] = 0 # Disable SDK retry; we use transient_retry_async if self.provider == "azure": self._async_client = openai.AsyncAzureOpenAI(**kwargs) else: @@ -292,11 +289,8 @@ def get_completion( kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" - def _call(): - return client.chat.completions.create(**kwargs) - t0 = time.perf_counter() - response = transient_retry(_call, max_retries=self.max_retries) + response = client.chat.completions.create(**kwargs) elapsed = time.perf_counter() - t0 if tools: @@ -315,6 +309,7 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, + max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -340,27 +335,36 @@ async def get_completion_async( kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" - async def _call(): - return await client.chat.completions.create(**kwargs) - - t0 = time.perf_counter() - response = await transient_retry_async(_call, max_retries=self.max_retries) - elapsed = 
time.perf_counter() - t0 - - if tools: - self._update_token_usage_from_response(response) - return self._build_vlm_response(response, has_tools=bool(tools)) - - if self.stream: - content = await self._process_streaming_response_async(response) + last_error = None + for attempt in range(max_retries + 1): + try: + t0 = time.perf_counter() + response = await client.chat.completions.create(**kwargs) + elapsed = time.perf_counter() - t0 + + if tools: + self._update_token_usage_from_response(response) + return self._build_vlm_response(response, has_tools=bool(tools)) + + if self.stream: + content = await self._process_streaming_response_async(response) + else: + self._update_token_usage_from_response( + response, + duration_seconds=elapsed, + ) + content = self._extract_content_from_response(response) + + return self._clean_response(content) + except Exception as e: + last_error = e + if attempt < max_retries: + await asyncio.sleep(2**attempt) + + if last_error: + raise last_error else: - self._update_token_usage_from_response( - response, - duration_seconds=elapsed, - ) - content = self._extract_content_from_response(response) - - return self._clean_response(content) + raise RuntimeError("Unknown error in async completion") def _detect_image_format(self, data: bytes) -> str: """Detect image format from magic bytes. @@ -450,11 +454,8 @@ def get_vision_completion( kwargs["tools"] = tools kwargs["tool_choice"] = "auto" - def _call(): - return client.chat.completions.create(**kwargs) - t0 = time.perf_counter() - response = transient_retry(_call, max_retries=self.max_retries) + response = client.chat.completions.create(**kwargs) elapsed = time.perf_counter() - t0 if tools: @@ -505,11 +506,8 @@ async def get_vision_completion_async( kwargs["tools"] = tools kwargs["tool_choice"] = "auto" - async def _call(): - return await client.chat.completions.create(**kwargs) - t0 = time.perf_counter() - response = await transient_retry_async(_call, max_retries=self.max_retries) + response = await client.chat.completions.create(**kwargs) elapsed = time.perf_counter() - t0 if tools: diff --git a/openviking/models/vlm/backends/volcengine_vlm.py b/openviking/models/vlm/backends/volcengine_vlm.py index 50912c600..06616def1 100644 --- a/openviking/models/vlm/backends/volcengine_vlm.py +++ b/openviking/models/vlm/backends/volcengine_vlm.py @@ -9,7 +9,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union -from openviking.models.retry import transient_retry_async +# Import run_async for sync-to-async calls from openviking_cli.utils import run_async from ..base import ToolCall, VLMResponse @@ -266,20 +266,6 @@ def _update_token_usage_from_response( ) return - def _parse_tool_calls(self, message) -> List[ToolCall]: - """Parse tool calls from VolcEngine response message.""" - tool_calls = [] - if hasattr(message, "tool_calls") and message.tool_calls: - for tc in message.tool_calls: - args = tc.function.arguments - if isinstance(args, str): - try: - args = json.loads(args) - except json.JSONDecodeError: - args = {"raw": args} - tool_calls.append(ToolCall(id=tc.id, name=tc.function.name, arguments=args)) - return tool_calls - def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMResponse]: """Build response from VolcEngine Responses API response. 
@@ -293,10 +279,10 @@ def _build_vlm_response(self, response, has_tools: bool) -> Union[str, VLMRespon # logger.info(f"[VolcEngineVLM] Full response: {response}") if hasattr(response, "output"): # logger.debug(f"[VolcEngineVLM] Output items: {len(response.output)}") - for _i, _item in enumerate(response.output): - # logger.debug(f"[VolcEngineVLM] Item {_i}: type={getattr(_item, 'type', 'unknown')}") + for i, item in enumerate(response.output): + # logger.debug(f"[VolcEngineVLM] Item {i}: type={getattr(item, 'type', 'unknown')}") # Print full item for debugging - # logger.info(f"[VolcEngineVLM] Item {_i} full: {_item}") + # logger.info(f"[VolcEngineVLM] Item {i} full: {item}") pass # Extract content from Responses API format @@ -450,7 +436,7 @@ def _convert_messages_to_input(self, messages: List[Dict[str, Any]]) -> List[Dic url = image_url.get("url", "") if url: image_urls.append(url) - has_images = True # noqa: F841 + has_images = True # Handle other block types else: # Try to extract text from any dict block @@ -553,6 +539,7 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, + max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -574,19 +561,32 @@ async def get_completion_async( # If we have static segments, try prefix cache response_format = None # Can be extended for structured output - async def _call(): - return await self.responseapi_prefixcache_completion( + try: + # Use prefix cache with multiple segments + response = await self.responseapi_prefixcache_completion( static_segments=static_segments, dynamic_messages=dynamic_messages, response_format=response_format, tools=tools, tool_choice=tool_choice, ) + elapsed = 0 # Timing handled in responseapi methods + self._update_token_usage_from_response(response, duration_seconds=elapsed) + return self._build_vlm_response(response, has_tools=bool(tools)) - response = await transient_retry_async(_call, max_retries=self.max_retries) - elapsed = 0 # Timing handled in responseapi methods - self._update_token_usage_from_response(response, duration_seconds=elapsed) - return self._build_vlm_response(response, has_tools=bool(tools)) + except Exception as e: + last_error = e + # Log token info from error response if available + error_response = getattr(e, "response", None) + if error_response and hasattr(error_response, "usage"): + u = error_response.usage + prompt_tokens = getattr(u, "input_tokens", 0) or 0 + completion_tokens = getattr(u, "output_tokens", 0) or 0 + logger.info( + f"[VolcEngineVLM] Error response - Input tokens: {prompt_tokens}, Output tokens: {completion_tokens}" + ) + logger.warning(f"[VolcEngineVLM] Request failed: {e}") + raise last_error def _detect_image_format(self, data: bytes) -> str: """Detect image format from magic bytes. 
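For context, the per-request retry behaviour this revert restores in the async VLM backends (the litellm_vlm.py and openai_vlm.py hunks above) is a plain count-based loop with an exponential sleep, in place of the deleted transient_retry_async helper. A minimal self-contained sketch of that restored pattern (the names _complete_with_retry and call_api are illustrative stand-ins, not functions from the codebase)::

    import asyncio

    async def _complete_with_retry(call_api, max_retries: int = 0):
        # One initial attempt plus up to max_retries extra attempts,
        # sleeping 2**attempt seconds between them. Every exception is
        # retried (no transient/permanent classification), and the last
        # error is re-raised once attempts are exhausted. max_retries=0,
        # the restored default, means no retry at all.
        last_error = None
        for attempt in range(max_retries + 1):
            try:
                return await call_api()
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    await asyncio.sleep(2**attempt)
        if last_error:
            raise last_error
        raise RuntimeError("Unknown error in async completion")

In the patch itself this loop is inlined directly around acompletion(**kwargs) or client.chat.completions.create(**kwargs) rather than factored into a shared helper.
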
diff --git a/openviking/models/vlm/base.py b/openviking/models/vlm/base.py index 270444283..c54285ad4 100644 --- a/openviking/models/vlm/base.py +++ b/openviking/models/vlm/base.py @@ -58,7 +58,7 @@ def __init__(self, config: Dict[str, Any]): self.api_key = config.get("api_key") self.api_base = config.get("api_base") self.temperature = config.get("temperature", 0.0) - self.max_retries = config.get("max_retries", 3) + self.max_retries = config.get("max_retries", 2) self.max_tokens = config.get("max_tokens") self.extra_headers = config.get("extra_headers") self.stream = config.get("stream", False) @@ -94,6 +94,7 @@ async def get_completion_async( self, prompt: str = "", thinking: bool = False, + max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, tool_choice: Optional[str] = None, messages: Optional[List[Dict[str, Any]]] = None, @@ -103,6 +104,7 @@ async def get_completion_async( Args: prompt: Text prompt (used if messages not provided) thinking: Whether to enable thinking mode + max_retries: Maximum number of retries tools: Optional list of tool definitions in OpenAI function format tool_choice: Optional tool choice mode ("auto", "none", or specific tool name) messages: Optional list of message dicts (takes precedence over prompt) diff --git a/openviking/models/vlm/llm.py b/openviking/models/vlm/llm.py index d28266a3a..e1cde6ccf 100644 --- a/openviking/models/vlm/llm.py +++ b/openviking/models/vlm/llm.py @@ -183,12 +183,7 @@ def complete_json( if schema and not messages: prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" - response = self._get_vlm().get_completion( - prompt=prompt, - thinking=thinking, - tools=tools, - messages=messages, - ) + response = self._get_vlm().get_completion(prompt, thinking, tools, messages) return parse_json_from_response(response) async def complete_json_async( @@ -196,6 +191,7 @@ async def complete_json_async( prompt: str = "", schema: Optional[Dict[str, Any]] = None, thinking: bool = False, + max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Optional[Dict[str, Any]]: @@ -204,10 +200,7 @@ async def complete_json_async( prompt = f"{prompt}\n\n{get_json_schema_prompt(schema)}" response = await self._get_vlm().get_completion_async( - prompt=prompt, - thinking=thinking, - tools=tools, - messages=messages, + prompt, thinking, max_retries, tools, messages ) return parse_json_from_response(response) @@ -234,13 +227,12 @@ async def complete_model_async( prompt: str, model_class: Type[T], thinking: bool = False, + max_retries: int = 0, ) -> Optional[T]: """Async version of complete_model.""" schema = model_class.model_json_schema() response = await self.complete_json_async( - prompt=prompt, - schema=schema, - thinking=thinking, + prompt, schema=schema, thinking=thinking, max_retries=max_retries ) if response is None: return None @@ -260,13 +252,7 @@ def get_vision_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get vision completion.""" - return self._get_vlm().get_vision_completion( - prompt=prompt, - images=images, - thinking=thinking, - tools=tools, - messages=messages, - ) + return self._get_vlm().get_vision_completion(prompt, images, thinking, tools, messages) async def get_vision_completion_async( self, @@ -278,9 +264,5 @@ async def get_vision_completion_async( ) -> Union[str, Any]: """Async vision completion.""" return await self._get_vlm().get_vision_completion_async( - prompt=prompt, - images=images, - thinking=thinking, - 
tools=tools, - messages=messages, + prompt, images, thinking, tools, messages ) diff --git a/openviking_cli/utils/config/embedding_config.py b/openviking_cli/utils/config/embedding_config.py index 1377d5f4a..8caad66a9 100644 --- a/openviking_cli/utils/config/embedding_config.py +++ b/openviking_cli/utils/config/embedding_config.py @@ -262,7 +262,6 @@ class EmbeddingConfig(BaseModel): sparse: Optional[EmbeddingModelConfig] = Field(default=None) hybrid: Optional[EmbeddingModelConfig] = Field(default=None) - max_retries: int = Field(default=3, description="Maximum retry attempts for transient errors") max_concurrent: int = Field( default=10, description="Maximum number of concurrent embedding requests" ) @@ -510,13 +509,6 @@ def _create_embedder( embedder_class, param_builder = factory_registry[key] params = param_builder(config) - - # Inject max_retries into the config dict so embedders pick it up - existing_config = params.get("config") or {} - if isinstance(existing_config, dict): - existing_config["max_retries"] = self.max_retries - params["config"] = existing_config - return embedder_class(**params) def get_embedder(self): diff --git a/openviking_cli/utils/config/vlm_config.py b/openviking_cli/utils/config/vlm_config.py index 5bff0a8e4..dd67e7c2c 100644 --- a/openviking_cli/utils/config/vlm_config.py +++ b/openviking_cli/utils/config/vlm_config.py @@ -12,7 +12,7 @@ class VLMConfig(BaseModel): api_key: Optional[str] = Field(default=None, description="API key") api_base: Optional[str] = Field(default=None, description="API base URL") temperature: float = Field(default=0.0, description="Generation temperature") - max_retries: int = Field(default=3, description="Maximum retry attempts") + max_retries: int = Field(default=2, description="Maximum retry attempts") provider: Optional[str] = Field(default=None, description="Provider type") backend: Optional[str] = Field( @@ -181,26 +181,19 @@ def get_completion( messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: """Get LLM completion.""" - return self.get_vlm_instance().get_completion( - prompt=prompt, - thinking=thinking, - tools=tools, - messages=messages, - ) + return self.get_vlm_instance().get_completion(prompt, thinking, tools, messages) async def get_completion_async( self, prompt: str = "", thinking: bool = False, + max_retries: int = 0, tools: Optional[List[Dict[str, Any]]] = None, messages: Optional[List[Dict[str, Any]]] = None, ) -> Union[str, Any]: - """Get LLM completion asynchronously.""" + """Get LLM completion asynchronously, max_retries=0 means no retry.""" return await self.get_vlm_instance().get_completion_async( - prompt=prompt, - thinking=thinking, - tools=tools, - messages=messages, + prompt, thinking, max_retries, tools, messages ) def is_available(self) -> bool: @@ -217,11 +210,7 @@ def get_vision_completion( ) -> Union[str, Any]: """Get LLM completion with images.""" return self.get_vlm_instance().get_vision_completion( - prompt=prompt, - images=images, - thinking=thinking, - tools=tools, - messages=messages, + prompt, images, thinking, tools, messages ) async def get_vision_completion_async( @@ -234,9 +223,5 @@ async def get_vision_completion_async( ) -> Union[str, Any]: """Get LLM completion with images asynchronously.""" return await self.get_vlm_instance().get_vision_completion_async( - prompt=prompt, - images=images, - thinking=thinking, - tools=tools, - messages=messages, + prompt, images, thinking, tools, messages ) diff --git a/tests/models/test_vlm_strip_think_tags.py 
b/tests/models/test_vlm_strip_think_tags.py index 95e45b682..c7a996cfb 100644 --- a/tests/models/test_vlm_strip_think_tags.py +++ b/tests/models/test_vlm_strip_think_tags.py @@ -18,7 +18,7 @@ class _Stub(VLMBase): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False): + async def get_completion_async(self, prompt, thinking=False, max_retries=0): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_backward_compat.py b/tests/unit/test_backward_compat.py deleted file mode 100644 index ba4d8a762..000000000 --- a/tests/unit/test_backward_compat.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 - -"""Backward compatibility tests for the retry migration. - -Verifies that: -- exponential_backoff_retry is still importable from the old location (base.py) -- exponential_backoff_retry signature is unchanged -- exponential_backoff_retry behaviour still works (time-based) -- transient_retry is count-based (different semantics) -""" - -from __future__ import annotations - -import inspect -from unittest.mock import patch - -import pytest - - -class _HttpError(Exception): - """Fake HTTP error carrying a numeric status code.""" - - def __init__(self, status_code: int, message: str = ""): - super().__init__(message or f"HTTP {status_code}") - self.status_code = status_code - - -class TestExponentialBackoffRetryImportable: - def test_importable_from_old_location(self): - """exponential_backoff_retry should still be importable from base.py.""" - from openviking.models.embedder.base import exponential_backoff_retry - - assert callable(exponential_backoff_retry) - - -class TestExponentialBackoffRetrySignature: - def test_signature_unchanged(self): - """exponential_backoff_retry should retain its original signature.""" - from openviking.models.embedder.base import exponential_backoff_retry - - sig = inspect.signature(exponential_backoff_retry) - param_names = list(sig.parameters.keys()) - - expected_params = [ - "func", - "max_wait", - "base_delay", - "max_delay", - "jitter", - "is_retryable", - "logger", - ] - - assert param_names == expected_params, ( - f"exponential_backoff_retry signature changed.\n" - f"Expected: {expected_params}\n" - f"Got: {param_names}" - ) - - def test_defaults_unchanged(self): - """Default parameter values should be preserved.""" - from openviking.models.embedder.base import exponential_backoff_retry - - sig = inspect.signature(exponential_backoff_retry) - params = sig.parameters - - assert params["max_wait"].default == 10.0 - assert params["base_delay"].default == 0.5 - assert params["max_delay"].default == 2.0 - assert params["jitter"].default is True - assert params["is_retryable"].default is None - assert params["logger"].default is None - - -class TestExponentialBackoffRetryBehavior: - def test_success_first_try(self): - """Function succeeds on first attempt.""" - from openviking.models.embedder.base import exponential_backoff_retry - - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - return "ok" - - result = exponential_backoff_retry(fn) - assert result == "ok" - assert call_count == 1 - - def test_retries_on_failure(self): - """Function retries on failure until success.""" - from openviking.models.embedder.base import exponential_backoff_retry - - errors = [Exception("fail"), Exception("fail")] - call_count = 0 - - def fn(): - nonlocal call_count - 
call_count += 1 - if errors: - raise errors.pop(0) - return "ok" - - with patch("time.sleep"): - result = exponential_backoff_retry(fn, max_wait=10.0) - - assert result == "ok" - assert call_count == 3 - - def test_is_time_based(self): - """exponential_backoff_retry should be time-based (uses max_wait, not count).""" - from openviking.models.embedder.base import exponential_backoff_retry - - sig = inspect.signature(exponential_backoff_retry) - param_names = list(sig.parameters.keys()) - - # Time-based: has max_wait, no max_retries - assert "max_wait" in param_names - assert "max_retries" not in param_names - - def test_respects_is_retryable(self): - """exponential_backoff_retry should respect is_retryable callback.""" - from openviking.models.embedder.base import exponential_backoff_retry - - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - raise ValueError("permanent") - - # is_retryable returns False => no retry - with patch("time.sleep"): - with pytest.raises(ValueError): - exponential_backoff_retry(fn, is_retryable=lambda e: False) - - assert call_count == 1 - - -class TestTransientRetryIsCountBased: - def test_is_count_based(self): - """transient_retry should be count-based (uses max_retries, not max_wait).""" - from openviking.models.retry import transient_retry - - sig = inspect.signature(transient_retry) - param_names = list(sig.parameters.keys()) - - # Count-based: has max_retries, no max_wait - assert "max_retries" in param_names - assert "max_wait" not in param_names - - def test_different_from_backoff_retry(self): - """transient_retry and exponential_backoff_retry should have different signatures.""" - from openviking.models.embedder.base import exponential_backoff_retry - from openviking.models.retry import transient_retry - - backoff_params = set(inspect.signature(exponential_backoff_retry).parameters.keys()) - retry_params = set(inspect.signature(transient_retry).parameters.keys()) - - # They share 'func', 'base_delay', 'max_delay', 'jitter', 'is_retryable' - # but differ on time vs count control params - assert "max_wait" in backoff_params - assert "max_wait" not in retry_params - assert "max_retries" in retry_params - assert "max_retries" not in backoff_params diff --git a/tests/unit/test_embedding_retry_integration.py b/tests/unit/test_embedding_retry_integration.py deleted file mode 100644 index 02011b4cb..000000000 --- a/tests/unit/test_embedding_retry_integration.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 - -"""Integration tests for embedding providers with unified retry logic. 
- -Tests cover (using OpenAI and VikingDB as representatives): -- embed retries on transient error (mock API client) -- embed does NOT retry on permanent error -- uses config max_retries -- VikingDB now has retry -""" - -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -class _HttpError(Exception): - """Fake HTTP error carrying a numeric status code.""" - - def __init__(self, status_code: int, message: str = ""): - super().__init__(message or f"HTTP {status_code}") - self.status_code = status_code - - -def _make_fake_embedding_response(vector=None): - """Build a minimal fake OpenAI embeddings response.""" - if vector is None: - vector = [0.1] * 10 - item = SimpleNamespace(embedding=vector) - usage = SimpleNamespace(prompt_tokens=5, total_tokens=5) - return SimpleNamespace(data=[item], usage=usage) - - -# --------------------------------------------------------------------------- -# OpenAI Embedder Tests -# --------------------------------------------------------------------------- - - -class TestOpenAIEmbedderRetry: - @pytest.fixture() - def openai_embedder(self): - """Create an OpenAIDenseEmbedder with mocked client.""" - from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder - - embedder = OpenAIDenseEmbedder( - model_name="text-embedding-3-small", - api_key="sk-test", - dimension=10, - config={"max_retries": 2}, - ) - embedder.client = MagicMock() - return embedder - - def test_embed_retries_on_transient_error(self, openai_embedder): - """embed() should retry on 429 (transient) and succeed.""" - errors = [_HttpError(429)] - call_count = 0 - - def fake_create(**kwargs): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return _make_fake_embedding_response() - - openai_embedder.client.embeddings.create = fake_create - - with patch("time.sleep"): - result = openai_embedder.embed("test text") - - assert result.dense_vector == [0.1] * 10 - assert call_count == 2 # 1 failure + 1 success - - def test_embed_no_retry_on_permanent_error(self, openai_embedder): - """embed() should NOT retry on 401 (permanent).""" - call_count = 0 - - def fake_create(**kwargs): - nonlocal call_count - call_count += 1 - raise _HttpError(401, "Unauthorized") - - openai_embedder.client.embeddings.create = fake_create - - with patch("time.sleep"): - # 401 is permanent, transient_retry won't retry it. - # It will propagate and be caught by the except block, re-raised as RuntimeError. 
- with pytest.raises((RuntimeError, _HttpError)): - openai_embedder.embed("test text") - - assert call_count == 1 # no retries - - def test_uses_config_max_retries(self): - """Embedder should use self.max_retries from config.""" - from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder - - embedder = OpenAIDenseEmbedder( - model_name="text-embedding-3-small", - api_key="sk-test", - dimension=10, - config={"max_retries": 5}, - ) - assert embedder.max_retries == 5 - - # Default - embedder2 = OpenAIDenseEmbedder( - model_name="text-embedding-3-small", - api_key="sk-test", - dimension=10, - ) - assert embedder2.max_retries == 3 - - def test_openai_sdk_retry_disabled(self): - """OpenAI client should be created with max_retries=0.""" - from openviking.models.embedder.openai_embedders import OpenAIDenseEmbedder - - with patch("openai.OpenAI") as mock_openai: - mock_openai.return_value = MagicMock() - OpenAIDenseEmbedder( - model_name="text-embedding-3-small", - api_key="sk-test", - dimension=10, - ) - call_kwargs = mock_openai.call_args - assert call_kwargs.kwargs.get("max_retries") == 0 - - -# --------------------------------------------------------------------------- -# VikingDB Embedder Tests -# --------------------------------------------------------------------------- - - -class TestVikingDBEmbedderRetry: - @pytest.fixture() - def vikingdb_embedder(self): - """Create a VikingDBDenseEmbedder with mocked client.""" - from openviking.models.embedder.vikingdb_embedders import VikingDBDenseEmbedder - - with patch("openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi"): - embedder = VikingDBDenseEmbedder( - model_name="test-model", - model_version="1.0", - ak="test-ak", - sk="test-sk", - region="cn-beijing", - dimension=10, - config={"max_retries": 2}, - ) - return embedder - - def test_embed_retries_on_transient_error(self, vikingdb_embedder): - """embed() should retry on transient error and succeed.""" - errors = [_HttpError(503)] - call_count = 0 - - def fake_call_api(*args, **kwargs): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return [{"dense_embedding": [0.1] * 10}] - - vikingdb_embedder._call_api = fake_call_api - - with patch("time.sleep"): - result = vikingdb_embedder.embed("test text") - - assert result.dense_vector == [0.1] * 10 - assert call_count == 2 # 1 failure + 1 success - - def test_embed_no_retry_on_permanent_error(self, vikingdb_embedder): - """embed() should NOT retry on 401 (permanent).""" - call_count = 0 - - def fake_call_api(*args, **kwargs): - nonlocal call_count - call_count += 1 - raise _HttpError(401, "Unauthorized") - - vikingdb_embedder._call_api = fake_call_api - - with patch("time.sleep"): - with pytest.raises(_HttpError): - vikingdb_embedder.embed("test text") - - assert call_count == 1 # no retries - - def test_uses_config_max_retries(self): - """VikingDB embedder should use self.max_retries from config.""" - from openviking.models.embedder.vikingdb_embedders import VikingDBDenseEmbedder - - with patch("openviking.storage.vectordb.collection.volcengine_clients.ClientForDataApi"): - embedder = VikingDBDenseEmbedder( - model_name="test-model", - model_version="1.0", - ak="test-ak", - sk="test-sk", - region="cn-beijing", - dimension=10, - config={"max_retries": 7}, - ) - assert embedder.max_retries == 7 - - def test_vikingdb_now_has_retry(self, vikingdb_embedder): - """VikingDB embed() should retry on 429 (was zero retry before unified retry).""" - errors = [_HttpError(429), _HttpError(429)] 
- call_count = 0 - - def fake_call_api(*args, **kwargs): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return [{"dense_embedding": [0.2] * 10}] - - vikingdb_embedder._call_api = fake_call_api - - with patch("time.sleep"): - result = vikingdb_embedder.embed("test text") - - assert result.dense_vector == [0.2] * 10 - assert call_count == 3 # 2 failures + 1 success diff --git a/tests/unit/test_extra_headers_vlm.py b/tests/unit/test_extra_headers_vlm.py index c087c2370..97f6bea00 100644 --- a/tests/unit/test_extra_headers_vlm.py +++ b/tests/unit/test_extra_headers_vlm.py @@ -210,7 +210,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False): + async def get_completion_async(self, prompt, thinking=False, max_retries=0): return "" def get_vision_completion(self, prompt, images, thinking=False): @@ -236,7 +236,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False): + async def get_completion_async(self, prompt, thinking=False, max_retries=0): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py deleted file mode 100644 index 176412d48..000000000 --- a/tests/unit/test_retry.py +++ /dev/null @@ -1,441 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 - -"""Comprehensive tests for the core retry module (openviking.models.retry). - -Tests cover: -- is_transient_error: ~28 parametrized cases (14 transient, 14 permanent) -- transient_retry (sync): 8 behavioral tests -- transient_retry_async (async): 8 mirrored behavioral tests -""" - -from __future__ import annotations - -import asyncio -from unittest.mock import AsyncMock, patch - -import pytest - -from openviking.models.retry import is_transient_error, transient_retry, transient_retry_async - -# --------------------------------------------------------------------------- -# Helper fake HTTP error with status_code attribute -# --------------------------------------------------------------------------- - - -class _HttpError(Exception): - """Fake HTTP error carrying a numeric status code for testing.""" - - def __init__(self, status_code: int, message: str = ""): - super().__init__(message or f"HTTP {status_code}") - self.status_code = status_code - - -# --------------------------------------------------------------------------- -# is_transient_error — parametrized cases -# --------------------------------------------------------------------------- - -_TRANSIENT_CASES = [ - # HTTP status codes via _HttpError.status_code - pytest.param(_HttpError(429), True, id="http_429"), - pytest.param(_HttpError(500), True, id="http_500"), - pytest.param(_HttpError(502), True, id="http_502"), - pytest.param(_HttpError(503), True, id="http_503"), - pytest.param(_HttpError(504), True, id="http_504"), - # Built-in connection exceptions - pytest.param(ConnectionError("connection failed"), True, id="ConnectionError"), - pytest.param(ConnectionResetError("reset"), True, id="ConnectionResetError"), - pytest.param(ConnectionRefusedError("refused"), True, id="ConnectionRefusedError"), - pytest.param(TimeoutError("timed out"), True, id="TimeoutError"), - pytest.param(asyncio.TimeoutError(), True, id="asyncio_TimeoutError"), - # String-pattern transient errors - pytest.param(Exception("TooManyRequests from 
server"), True, id="str_TooManyRequests"), - pytest.param(Exception("RateLimit exceeded"), True, id="str_RateLimit"), - pytest.param(Exception("RequestBurstTooFast"), True, id="str_RequestBurstTooFast"), - pytest.param(Exception("request timed out after 30s"), True, id="str_timed_out"), -] - -_PERMANENT_CASES = [ - # HTTP status codes via _HttpError.status_code - pytest.param(_HttpError(400), False, id="http_400"), - pytest.param(_HttpError(401), False, id="http_401"), - pytest.param(_HttpError(403), False, id="http_403"), - pytest.param(_HttpError(404), False, id="http_404"), - pytest.param(_HttpError(422), False, id="http_422"), - # Built-in value/type errors - pytest.param(ValueError("bad value"), False, id="ValueError"), - pytest.param(TypeError("wrong type"), False, id="TypeError"), - # String-pattern permanent errors - pytest.param( - Exception("InvalidRequestError: field missing"), False, id="str_InvalidRequestError" - ), - pytest.param( - Exception("AuthenticationError: invalid key"), False, id="str_AuthenticationError" - ), - # Unknown errors — conservative default False - pytest.param(Exception("some unknown error"), False, id="unknown_generic"), - pytest.param(RuntimeError("unexpected state"), False, id="RuntimeError_unknown"), - pytest.param(KeyError("missing key"), False, id="KeyError"), - pytest.param(AttributeError("no attr"), False, id="AttributeError"), - pytest.param( - Exception("config_value_out_of_range"), False, id="str_unknown_no_transient_keyword" - ), -] - - -@pytest.mark.parametrize("exc,expected", _TRANSIENT_CASES) -def test_is_transient_error_transient(exc, expected): - """Transient errors should be classified as retryable (True).""" - assert is_transient_error(exc) is expected - - -@pytest.mark.parametrize("exc,expected", _PERMANENT_CASES) -def test_is_transient_error_permanent(exc, expected): - """Permanent / unknown errors should not be retried (False).""" - assert is_transient_error(exc) is expected - - -# --------------------------------------------------------------------------- -# transient_retry (sync) -# --------------------------------------------------------------------------- - - -class TestTransientRetrySync: - """Sync retry behaviour tests.""" - - def test_success_first_try(self): - """Function succeeds on first attempt — call_count == 1.""" - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - return "ok" - - result = transient_retry(fn, max_retries=3) - assert result == "ok" - assert call_count == 1 - - def test_retry_then_success(self): - """Two transient failures then success — call_count == 3.""" - errors = [_HttpError(429), _HttpError(503)] - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return "ok" - - with patch("time.sleep"): - result = transient_retry(fn, max_retries=3) - - assert result == "ok" - assert call_count == 3 - - def test_permanent_error_no_retry(self): - """Permanent error (401) should not be retried — call_count == 1 and raises.""" - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - raise _HttpError(401) - - with patch("time.sleep"): - with pytest.raises(_HttpError): - transient_retry(fn, max_retries=3) - - assert call_count == 1 - - def test_max_retries_exhausted(self): - """4 consecutive 429 errors with max_retries=3 → raises after 4 calls.""" - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - raise _HttpError(429) - - with patch("time.sleep"): - with pytest.raises(_HttpError): - transient_retry(fn, 
max_retries=3) - - assert call_count == 4 # 1 initial + 3 retries - - def test_max_retries_zero_raises_immediately(self): - """max_retries=0 disables retrying — call_count == 1.""" - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - raise _HttpError(429) - - with patch("time.sleep"): - with pytest.raises(_HttpError): - transient_retry(fn, max_retries=0) - - assert call_count == 1 - - def test_max_retries_one(self): - """max_retries=1: one failure then success → call_count == 2.""" - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - if call_count == 1: - raise _HttpError(429) - return "done" - - with patch("time.sleep"): - result = transient_retry(fn, max_retries=1) - - assert result == "done" - assert call_count == 2 - - def test_backoff_delays_exponential(self): - """Verify exponential backoff: base_delay=1.0, jitter=False → 1.0, 2.0, 4.0.""" - call_count = 0 - sleep_calls = [] - - def fn(): - nonlocal call_count - call_count += 1 - raise _HttpError(429) - - with patch("time.sleep", side_effect=lambda d: sleep_calls.append(d)): - with pytest.raises(_HttpError): - transient_retry(fn, max_retries=3, base_delay=1.0, max_delay=100.0, jitter=False) - - assert len(sleep_calls) == 3 - assert sleep_calls[0] == pytest.approx(1.0) - assert sleep_calls[1] == pytest.approx(2.0) - assert sleep_calls[2] == pytest.approx(4.0) - - def test_delay_capped_at_max_delay(self): - """Delays must not exceed max_delay even with many retries.""" - sleep_calls = [] - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - raise _HttpError(503) - - with patch("time.sleep", side_effect=lambda d: sleep_calls.append(d)): - with pytest.raises(_HttpError): - transient_retry(fn, max_retries=10, base_delay=1.0, max_delay=8.0, jitter=False) - - assert all(d <= 8.0 for d in sleep_calls), f"Some delays exceed max_delay: {sleep_calls}" - - -# --------------------------------------------------------------------------- -# transient_retry_async (async) -# --------------------------------------------------------------------------- - - -class TestTransientRetryAsync: - """Async retry behaviour tests — mirrors sync suite.""" - - async def test_success_first_try(self): - """Async function succeeds on first attempt — call_count == 1.""" - call_count = 0 - - async def coro(): - nonlocal call_count - call_count += 1 - return "ok" - - result = await transient_retry_async(coro, max_retries=3) - assert result == "ok" - assert call_count == 1 - - async def test_retry_then_success(self): - """Two transient failures then success — call_count == 3.""" - errors = [_HttpError(429), _HttpError(503)] - call_count = 0 - - async def coro(): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return "ok" - - with patch("asyncio.sleep", new_callable=AsyncMock): - result = await transient_retry_async(coro, max_retries=3) - - assert result == "ok" - assert call_count == 3 - - async def test_permanent_error_no_retry(self): - """Permanent error (401) should not be retried — call_count == 1 and raises.""" - call_count = 0 - - async def coro(): - nonlocal call_count - call_count += 1 - raise _HttpError(401) - - with patch("asyncio.sleep", new_callable=AsyncMock): - with pytest.raises(_HttpError): - await transient_retry_async(coro, max_retries=3) - - assert call_count == 1 - - async def test_max_retries_exhausted(self): - """4 consecutive 429 errors with max_retries=3 → raises after 4 calls.""" - call_count = 0 - - async def coro(): - nonlocal call_count - call_count += 1 - 
raise _HttpError(429) - - with patch("asyncio.sleep", new_callable=AsyncMock): - with pytest.raises(_HttpError): - await transient_retry_async(coro, max_retries=3) - - assert call_count == 4 - - async def test_max_retries_zero_raises_immediately(self): - """max_retries=0 disables retrying — call_count == 1.""" - call_count = 0 - - async def coro(): - nonlocal call_count - call_count += 1 - raise _HttpError(429) - - with patch("asyncio.sleep", new_callable=AsyncMock): - with pytest.raises(_HttpError): - await transient_retry_async(coro, max_retries=0) - - assert call_count == 1 - - async def test_max_retries_one(self): - """max_retries=1: one failure then success → call_count == 2.""" - call_count = 0 - - async def coro(): - nonlocal call_count - call_count += 1 - if call_count == 1: - raise _HttpError(429) - return "done" - - with patch("asyncio.sleep", new_callable=AsyncMock): - result = await transient_retry_async(coro, max_retries=1) - - assert result == "done" - assert call_count == 2 - - async def test_backoff_delays_exponential(self): - """Verify exponential backoff: base_delay=1.0, jitter=False → 1.0, 2.0, 4.0.""" - call_count = 0 - sleep_calls = [] - - async def fake_sleep(d): - sleep_calls.append(d) - - async def coro(): - nonlocal call_count - call_count += 1 - raise _HttpError(429) - - with patch("asyncio.sleep", side_effect=fake_sleep): - with pytest.raises(_HttpError): - await transient_retry_async( - coro, max_retries=3, base_delay=1.0, max_delay=100.0, jitter=False - ) - - assert len(sleep_calls) == 3 - assert sleep_calls[0] == pytest.approx(1.0) - assert sleep_calls[1] == pytest.approx(2.0) - assert sleep_calls[2] == pytest.approx(4.0) - - async def test_delay_capped_at_max_delay(self): - """Async delays must not exceed max_delay even with many retries.""" - sleep_calls = [] - call_count = 0 - - async def fake_sleep(d): - sleep_calls.append(d) - - async def coro(): - nonlocal call_count - call_count += 1 - raise _HttpError(503) - - with patch("asyncio.sleep", side_effect=fake_sleep): - with pytest.raises(_HttpError): - await transient_retry_async( - coro, max_retries=10, base_delay=1.0, max_delay=8.0, jitter=False - ) - - assert all(d <= 8.0 for d in sleep_calls), f"Some delays exceed max_delay: {sleep_calls}" - - -# --------------------------------------------------------------------------- -# Additional edge-case tests -# --------------------------------------------------------------------------- - - -class TestIsTransientErrorEdgeCases: - """Edge cases for is_transient_error.""" - - def test_timeout_substring_in_message(self): - """'timeout' substring in message → transient.""" - err = Exception("connection timeout after 10s") - assert is_transient_error(err) is True - - def test_status_code_attribute_takes_priority(self): - """status_code=503 → transient, even if message says 'bad request'.""" - err = _HttpError(503, "bad request") - assert is_transient_error(err) is True - - def test_status_code_401_permanent_priority(self): - """status_code=401 → permanent, even if message contains 'timeout'.""" - err = _HttpError(401, "timeout auth failure") - assert is_transient_error(err) is False - - def test_custom_is_retryable_overrides(self): - """Custom is_retryable callback overrides default classification.""" - # 429 is normally transient but we pass a custom fn that returns False - call_count = 0 - - def fn(): - nonlocal call_count - call_count += 1 - raise _HttpError(429) - - with patch("time.sleep"): - with pytest.raises(_HttpError): - transient_retry(fn, max_retries=3, 
is_retryable=lambda e: False) - - assert call_count == 1 # no retries because custom fn says not retryable - - def test_http_status_attribute_variant(self): - """Objects with .http_status should be checked for transient status.""" - - class AltHttpError(Exception): - def __init__(self, http_status: int): - super().__init__(f"HTTP {http_status}") - self.http_status = http_status - - assert is_transient_error(AltHttpError(503)) is True - assert is_transient_error(AltHttpError(401)) is False - - def test_code_attribute_variant(self): - """Objects with .code should be checked for transient status.""" - - class CodeError(Exception): - def __init__(self, code: int): - super().__init__(f"Error code {code}") - self.code = code - - assert is_transient_error(CodeError(429)) is True - assert is_transient_error(CodeError(403)) is False diff --git a/tests/unit/test_retry_config.py b/tests/unit/test_retry_config.py deleted file mode 100644 index 24f28fb5e..000000000 --- a/tests/unit/test_retry_config.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 - -"""Tests for retry configuration fields on VLMConfig and EmbeddingConfig. - -Verifies that: -- VLMConfig default max_retries = 3 -- EmbeddingConfig has max_retries field, default = 3 -- EmbeddingConfig accepts custom max_retries -""" - -from __future__ import annotations - - -class TestVLMConfigMaxRetries: - def test_default_max_retries(self): - """VLMConfig should default max_retries to 3.""" - from openviking_cli.utils.config.vlm_config import VLMConfig - - cfg = VLMConfig( - model="gpt-4o-mini", - api_key="sk-test", - provider="openai", - ) - assert cfg.max_retries == 3 - - def test_custom_max_retries(self): - """VLMConfig should accept custom max_retries.""" - from openviking_cli.utils.config.vlm_config import VLMConfig - - cfg = VLMConfig( - model="gpt-4o-mini", - api_key="sk-test", - provider="openai", - max_retries=10, - ) - assert cfg.max_retries == 10 - - -class TestEmbeddingConfigMaxRetries: - def test_has_max_retries_field(self): - """EmbeddingConfig should have a max_retries field.""" - from openviking_cli.utils.config.embedding_config import EmbeddingConfig - - fields = EmbeddingConfig.model_fields - assert "max_retries" in fields, ( - f"EmbeddingConfig is missing 'max_retries' field. 
Fields: {list(fields.keys())}" - ) - - def test_default_max_retries(self): - """EmbeddingConfig should default max_retries to 3.""" - from openviking_cli.utils.config.embedding_config import ( - EmbeddingConfig, - EmbeddingModelConfig, - ) - - cfg = EmbeddingConfig( - dense=EmbeddingModelConfig( - model="text-embedding-3-small", - api_key="sk-test", - provider="openai", - ), - ) - assert cfg.max_retries == 3 - - def test_custom_max_retries(self): - """EmbeddingConfig should accept custom max_retries.""" - from openviking_cli.utils.config.embedding_config import ( - EmbeddingConfig, - EmbeddingModelConfig, - ) - - cfg = EmbeddingConfig( - dense=EmbeddingModelConfig( - model="text-embedding-3-small", - api_key="sk-test", - provider="openai", - ), - max_retries=7, - ) - assert cfg.max_retries == 7 diff --git a/tests/unit/test_stream_config_vlm.py b/tests/unit/test_stream_config_vlm.py index dea3e285b..64b2f81c2 100644 --- a/tests/unit/test_stream_config_vlm.py +++ b/tests/unit/test_stream_config_vlm.py @@ -253,7 +253,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False): + async def get_completion_async(self, prompt, thinking=False, max_retries=0): return "" def get_vision_completion(self, prompt, images, thinking=False): @@ -277,7 +277,7 @@ class StubVLM(OpenAIVLM): def get_completion(self, prompt, thinking=False): return "" - async def get_completion_async(self, prompt, thinking=False): + async def get_completion_async(self, prompt, thinking=False, max_retries=0): return "" def get_vision_completion(self, prompt, images, thinking=False): diff --git a/tests/unit/test_vlm_retry_integration.py b/tests/unit/test_vlm_retry_integration.py deleted file mode 100644 index e65f0ef37..000000000 --- a/tests/unit/test_vlm_retry_integration.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. -# SPDX-License-Identifier: Apache-2.0 - -"""Integration tests for VLM backends with unified retry logic. 
- -Tests cover (using OpenAI backend as representative): -- completion retries on 429 (transient) -- completion does NOT retry on 401 (permanent) -- vision completion now retries (was zero before) -- uses config max_retries -- max_retries parameter removed from get_completion_async signature -""" - -from __future__ import annotations - -import inspect -from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -class _HttpError(Exception): - """Fake HTTP error carrying a numeric status code.""" - - def __init__(self, status_code: int, message: str = ""): - super().__init__(message or f"HTTP {status_code}") - self.status_code = status_code - - -def _make_fake_response(content: str = "ok") -> SimpleNamespace: - """Build a minimal fake OpenAI ChatCompletion response.""" - message = SimpleNamespace(content=content, tool_calls=None) - choice = SimpleNamespace(message=message, finish_reason="stop") - usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) - return SimpleNamespace(choices=[choice], usage=usage) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture() -def openai_vlm(): - """Create an OpenAIVLM instance with mocked clients.""" - from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - - vlm = OpenAIVLM( - { - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - "max_retries": 2, - } - ) - - # Mock sync client - mock_sync = MagicMock() - vlm._sync_client = mock_sync - - # Mock async client - mock_async = MagicMock() - vlm._async_client = mock_async - - return vlm - - -# --------------------------------------------------------------------------- -# Tests: get_completion_async retries on 429 -# --------------------------------------------------------------------------- - - -class TestCompletionAsyncRetries: - async def test_retries_on_429(self, openai_vlm): - """get_completion_async should retry on 429 (transient) and succeed.""" - errors = [_HttpError(429), _HttpError(429)] - call_count = 0 - - async def fake_create(**kwargs): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return _make_fake_response("success") - - openai_vlm._async_client.chat.completions.create = fake_create - - with patch("asyncio.sleep", new_callable=AsyncMock): - result = await openai_vlm.get_completion_async(prompt="test") - - assert result == "success" - assert call_count == 3 # 2 failures + 1 success - - async def test_no_retry_on_401(self, openai_vlm): - """get_completion_async should NOT retry on 401 (permanent).""" - call_count = 0 - - async def fake_create(**kwargs): - nonlocal call_count - call_count += 1 - raise _HttpError(401, "Unauthorized") - - openai_vlm._async_client.chat.completions.create = fake_create - - with patch("asyncio.sleep", new_callable=AsyncMock): - with pytest.raises(_HttpError): - await openai_vlm.get_completion_async(prompt="test") - - assert call_count == 1 # no retries - - async def test_uses_config_max_retries(self): - """Backend should use self.max_retries from config, not a param.""" - from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - - vlm = OpenAIVLM( - { - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": 
"openai", - "max_retries": 5, - } - ) - assert vlm.max_retries == 5 - - # Config default is now 3 - vlm2 = OpenAIVLM( - { - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - } - ) - assert vlm2.max_retries == 3 - - -# --------------------------------------------------------------------------- -# Tests: get_vision_completion_async now retries -# --------------------------------------------------------------------------- - - -class TestVisionCompletionAsyncRetries: - async def test_vision_retries_on_429(self, openai_vlm): - """get_vision_completion_async should retry on 429 (was zero retry before).""" - errors = [_HttpError(429)] - call_count = 0 - - async def fake_create(**kwargs): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return _make_fake_response("vision ok") - - openai_vlm._async_client.chat.completions.create = fake_create - - with patch("asyncio.sleep", new_callable=AsyncMock): - result = await openai_vlm.get_vision_completion_async( - prompt="describe", - images=["http://example.com/img.png"], - ) - - assert result == "vision ok" - assert call_count == 2 # 1 failure + 1 success - - -# --------------------------------------------------------------------------- -# Tests: sync completion retries -# --------------------------------------------------------------------------- - - -class TestCompletionSyncRetries: - def test_sync_retries_on_429(self, openai_vlm): - """get_completion should retry on 429.""" - errors = [_HttpError(429)] - call_count = 0 - - def fake_create(**kwargs): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return _make_fake_response("sync ok") - - openai_vlm._sync_client.chat.completions.create = fake_create - - with patch("time.sleep"): - result = openai_vlm.get_completion(prompt="test") - - assert result == "sync ok" - assert call_count == 2 - - def test_sync_vision_retries_on_503(self, openai_vlm): - """get_vision_completion should retry on 503.""" - errors = [_HttpError(503)] - call_count = 0 - - def fake_create(**kwargs): - nonlocal call_count - call_count += 1 - if errors: - raise errors.pop(0) - return _make_fake_response("vision sync ok") - - openai_vlm._sync_client.chat.completions.create = fake_create - - with patch("time.sleep"): - result = openai_vlm.get_vision_completion( - prompt="describe", - images=["http://example.com/img.png"], - ) - - assert result == "vision sync ok" - assert call_count == 2 - - -# --------------------------------------------------------------------------- -# Tests: signature change verification -# --------------------------------------------------------------------------- - - -class TestSignatureChange: - def test_no_max_retries_in_get_completion_async(self): - """get_completion_async should no longer accept max_retries parameter.""" - from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - - sig = inspect.signature(OpenAIVLM.get_completion_async) - param_names = list(sig.parameters.keys()) - - assert "max_retries" not in param_names, ( - f"max_retries should be removed from get_completion_async, got params: {param_names}" - ) - - def test_no_max_retries_in_base_get_completion_async(self): - """VLMBase.get_completion_async should no longer accept max_retries parameter.""" - from openviking.models.vlm.base import VLMBase - - sig = inspect.signature(VLMBase.get_completion_async) - param_names = list(sig.parameters.keys()) - - assert "max_retries" not in param_names, ( - f"max_retries should be removed from VLMBase.get_completion_async, 
got params: {param_names}" - ) - - def test_no_max_retries_in_litellm_get_completion_async(self): - """LiteLLMVLMProvider.get_completion_async should no longer accept max_retries.""" - from openviking.models.vlm.backends.litellm_vlm import LiteLLMVLMProvider - - sig = inspect.signature(LiteLLMVLMProvider.get_completion_async) - param_names = list(sig.parameters.keys()) - - assert "max_retries" not in param_names - - def test_no_max_retries_in_volcengine_get_completion_async(self): - """VolcEngineVLM.get_completion_async should no longer accept max_retries.""" - from openviking.models.vlm.backends.volcengine_vlm import VolcEngineVLM - - sig = inspect.signature(VolcEngineVLM.get_completion_async) - param_names = list(sig.parameters.keys()) - - assert "max_retries" not in param_names - - -# --------------------------------------------------------------------------- -# Tests: OpenAI SDK retry disabled -# --------------------------------------------------------------------------- - - -class TestOpenAISDKRetryDisabled: - def test_sync_client_max_retries_zero(self): - """OpenAI sync client should be created with max_retries=0.""" - from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - - vlm = OpenAIVLM( - { - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - } - ) - - with patch("openai.OpenAI") as mock_openai: - mock_openai.return_value = MagicMock() - vlm._sync_client = None # force re-creation - vlm.get_client() - call_kwargs = mock_openai.call_args - assert call_kwargs[1].get("max_retries") == 0 or ( - len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0 - ) - - def test_async_client_max_retries_zero(self): - """OpenAI async client should be created with max_retries=0.""" - from openviking.models.vlm.backends.openai_vlm import OpenAIVLM - - vlm = OpenAIVLM( - { - "api_key": "sk-test", - "model": "gpt-4o-mini", - "provider": "openai", - } - ) - - with patch("openai.AsyncOpenAI") as mock_async_openai: - mock_async_openai.return_value = MagicMock() - vlm._async_client = None # force re-creation - vlm.get_async_client() - call_kwargs = mock_async_openai.call_args - assert call_kwargs[1].get("max_retries") == 0 or ( - len(call_kwargs[0]) == 0 and call_kwargs.kwargs.get("max_retries") == 0 - )