Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion openviking/models/embedder/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(self, model_name: str, config: Optional[Dict[str, Any]] = None):
"""
self.model_name = model_name
self.config = config or {}
self.max_retries = self.config.get("max_retries", 3) if self.config else 3

@abstractmethod
def embed(self, text: str, is_query: bool = False) -> EmbedResult:
Expand Down Expand Up @@ -255,7 +256,7 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes

return [
EmbedResult(dense_vector=d.dense_vector, sparse_vector=s.sparse_vector)
for d, s in zip(dense_results, sparse_results)
for d, s in zip(dense_results, sparse_results, strict=False)
]

def get_dimension(self) -> int:
Expand Down
40 changes: 24 additions & 16 deletions openviking/models/embedder/gemini_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
EmbedResult,
truncate_and_normalize,
)
from openviking.models.retry import transient_retry

logger = logging.getLogger("gemini_embedders")

Expand Down Expand Up @@ -146,15 +147,13 @@ def __init__(
)
if dimension is not None and not (1 <= dimension <= 3072):
raise ValueError(f"dimension must be between 1 and 3072, got {dimension}")
# Disable SDK-level retry; we use transient_retry for unified retry logic
if _HTTP_RETRY_AVAILABLE:
self.client = genai.Client(
api_key=api_key,
http_options=HttpOptions(
retry_options=HttpRetryOptions(
attempts=3,
initial_delay=1.0,
max_delay=30.0,
exp_base=2.0,
attempts=1,
)
),
)
Expand Down Expand Up @@ -209,11 +208,16 @@ def embed(
task_type = self.document_param
# SDK accepts plain str; converts to REST Parts format internally.
try:
result = self.client.models.embed_content(
model=self.model_name,
contents=text,
config=self._build_config(task_type=task_type, title=title),
)
embed_config = self._build_config(task_type=task_type, title=title)

def _call():
return self.client.models.embed_content(
model=self.model_name,
contents=text,
config=embed_config,
)

result = transient_retry(_call, max_retries=self.max_retries)
vector = truncate_and_normalize(list(result.embeddings[0].values), self._dimension)
return EmbedResult(dense_vector=vector)
except (APIError, ClientError) as e:
Expand All @@ -233,7 +237,7 @@ def embed_batch(
if titles is not None:
return [
self.embed(text, is_query=is_query, task_type=task_type, title=title)
for text, title in zip(texts, titles)
for text, title in zip(texts, titles, strict=False)
]
# Resolve effective task_type from is_query when no explicit override
if task_type is None:
Expand All @@ -254,13 +258,17 @@ def embed_batch(

non_empty_texts = [batch[j] for j in non_empty_indices]
try:
response = self.client.models.embed_content(
model=self.model_name,
contents=non_empty_texts,
config=config,
)

def _batch_call(texts=non_empty_texts, cfg=config):
return self.client.models.embed_content(
model=self.model_name,
contents=texts,
config=cfg,
)

response = transient_retry(_batch_call, max_retries=self.max_retries)
batch_results = [None] * len(batch)
for j, emb in zip(non_empty_indices, response.embeddings):
for j, emb in zip(non_empty_indices, response.embeddings, strict=False):
batch_results[j] = EmbedResult(
dense_vector=truncate_and_normalize(list(emb.values), self._dimension)
)
Expand Down
13 changes: 11 additions & 2 deletions openviking/models/embedder/jina_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DenseEmbedderBase,
EmbedResult,
)
from openviking.models.retry import transient_retry

# Default dimensions for Jina embedding models
JINA_MODEL_DIMENSIONS = {
Expand Down Expand Up @@ -113,9 +114,11 @@ def __init__(
raise ValueError("api_key is required")

# Initialize OpenAI-compatible client with Jina base URL
# Disable SDK retry; we use transient_retry for unified retry logic
self.client = openai.OpenAI(
api_key=self.api_key,
base_url=self.api_base,
max_retries=0,
)

# Determine dimension
Expand Down Expand Up @@ -174,7 +177,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
if extra_body:
kwargs["extra_body"] = extra_body

response = self.client.embeddings.create(**kwargs)
def _call():
return self.client.embeddings.create(**kwargs)

response = transient_retry(_call, max_retries=self.max_retries)
vector = response.data[0].embedding

return EmbedResult(dense_vector=vector)
Expand Down Expand Up @@ -209,7 +215,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes
if extra_body:
kwargs["extra_body"] = extra_body

response = self.client.embeddings.create(**kwargs)
def _call():
return self.client.embeddings.create(**kwargs)

response = transient_retry(_call, max_retries=self.max_retries)

return [EmbedResult(dense_vector=item.embedding) for item in response.data]
except openai.APIError as e:
Expand Down
13 changes: 11 additions & 2 deletions openviking/models/embedder/litellm_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import litellm

from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult
from openviking.models.retry import transient_retry
from openviking.telemetry import get_current_telemetry

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -157,7 +158,11 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
try:
kwargs = self._build_kwargs(is_query=is_query)
kwargs["input"] = [text]
response = litellm.embedding(**kwargs)

def _call():
return litellm.embedding(**kwargs)

response = transient_retry(_call, max_retries=self.max_retries)
self._update_telemetry_token_usage(response)
vector = response.data[0]["embedding"]
return EmbedResult(dense_vector=vector)
Expand All @@ -183,7 +188,11 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes
try:
kwargs = self._build_kwargs(is_query=is_query)
kwargs["input"] = texts
response = litellm.embedding(**kwargs)

def _call():
return litellm.embedding(**kwargs)

response = transient_retry(_call, max_retries=self.max_retries)
self._update_telemetry_token_usage(response)
return [EmbedResult(dense_vector=item["embedding"]) for item in response.data]
except Exception as e:
Expand Down
21 changes: 11 additions & 10 deletions openviking/models/embedder/minimax_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from urllib3.util.retry import Retry

from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult
from openviking.models.retry import transient_retry
from openviking_cli.utils.logger import default_logger as logger


Expand Down Expand Up @@ -89,12 +90,8 @@ def __init__(
def _create_session(self) -> requests.Session:
"""Create a requests session with retry logic"""
session = requests.Session()
retry_strategy = Retry(
total=6,
backoff_factor=1, # 1s, 2s, 4s, 8s, 16s, 32s
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["POST"],
)
# Disable transport-level retry; we use transient_retry for unified retry logic
retry_strategy = Retry(total=0)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("https://", adapter)
session.mount("http://", adapter)
Expand Down Expand Up @@ -163,17 +160,21 @@ def _call_api(self, texts: List[str], is_query: bool = False) -> List[List[float

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
"""Perform dense embedding on text"""
vectors = self._call_api([text], is_query=is_query)
vectors = transient_retry(
lambda: self._call_api([text], is_query=is_query),
max_retries=self.max_retries,
)
return EmbedResult(dense_vector=vectors[0])

def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
"""Batch embedding"""
if not texts:
return []

# MiniMax might have batch size limits, but let's assume the caller handles batching or use safe defaults
# For now, we pass through. If needed, we can implement internal chunking.
vectors = self._call_api(texts, is_query=is_query)
vectors = transient_retry(
lambda: self._call_api(texts, is_query=is_query),
max_retries=self.max_retries,
)
return [EmbedResult(dense_vector=v) for v in vectors]

def get_dimension(self) -> int:
Expand Down
18 changes: 14 additions & 4 deletions openviking/models/embedder/openai_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

import openai

from openviking.models.vlm.registry import DEFAULT_AZURE_API_VERSION
from openviking.models.embedder.base import (
DenseEmbedderBase,
EmbedResult,
HybridEmbedderBase,
SparseEmbedderBase,
)
from openviking.models.retry import transient_retry
from openviking.models.vlm.registry import DEFAULT_AZURE_API_VERSION
from openviking.telemetry import get_current_telemetry


Expand Down Expand Up @@ -118,7 +119,10 @@ def __init__(
if not self.api_key and not self.api_base:
raise ValueError("api_key is required")

client_kwargs: Dict[str, Any] = {"api_key": self.api_key or "no-key"}
client_kwargs: Dict[str, Any] = {
"api_key": self.api_key or "no-key",
"max_retries": 0, # Disable SDK retry; we use transient_retry
}
if self._provider == "azure":
if not self.api_base:
raise ValueError("api_base (Azure endpoint) is required for Azure provider")
Expand Down Expand Up @@ -242,7 +246,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
if extra_body:
kwargs["extra_body"] = extra_body

response = self.client.embeddings.create(**kwargs)
def _call():
return self.client.embeddings.create(**kwargs)

response = transient_retry(_call, max_retries=self.max_retries)
self._update_telemetry_token_usage(response)
vector = response.data[0].embedding

Expand Down Expand Up @@ -277,7 +284,10 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes
if extra_body:
kwargs["extra_body"] = extra_body

response = self.client.embeddings.create(**kwargs)
def _call():
return self.client.embeddings.create(**kwargs)

response = transient_retry(_call, max_retries=self.max_retries)
self._update_telemetry_token_usage(response)

return [EmbedResult(dense_vector=item.embedding) for item in response.data]
Expand Down
35 changes: 27 additions & 8 deletions openviking/models/embedder/vikingdb_embedders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
HybridEmbedderBase,
SparseEmbedderBase,
)
from openviking.models.retry import transient_retry
from openviking.storage.vectordb.collection.volcengine_clients import ClientForDataApi
from openviking_cli.utils.logger import default_logger as logger

Expand Down Expand Up @@ -124,7 +125,10 @@ def __init__(
self.dense_model = {"name": model_name, "version": model_version, "dim": dimension}

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
results = self._call_api([text], dense_model=self.dense_model)
results = transient_retry(
lambda: self._call_api([text], dense_model=self.dense_model),
max_retries=self.max_retries,
)
if not results:
return EmbedResult(dense_vector=[])

Expand All @@ -138,7 +142,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
if not texts:
return []
raw_results = self._call_api(texts, dense_model=self.dense_model)
raw_results = transient_retry(
lambda: self._call_api(texts, dense_model=self.dense_model),
max_retries=self.max_retries,
)
return [
EmbedResult(
dense_vector=self._truncate_and_normalize(
Expand Down Expand Up @@ -174,7 +181,10 @@ def __init__(
}

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
results = self._call_api([text], sparse_model=self.sparse_model)
results = transient_retry(
lambda: self._call_api([text], sparse_model=self.sparse_model),
max_retries=self.max_retries,
)
if not results:
return EmbedResult(sparse_vector={})

Expand All @@ -188,7 +198,10 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
if not texts:
return []
raw_results = self._call_api(texts, sparse_model=self.sparse_model)
raw_results = transient_retry(
lambda: self._call_api(texts, sparse_model=self.sparse_model),
max_retries=self.max_retries,
)
return [
EmbedResult(
sparse_vector=self._process_sparse_embedding(item.get("sparse_embedding", {}))
Expand Down Expand Up @@ -224,8 +237,11 @@ def __init__(
}

def embed(self, text: str, is_query: bool = False) -> EmbedResult:
results = self._call_api(
[text], dense_model=self.dense_model, sparse_model=self.sparse_model
results = transient_retry(
lambda: self._call_api(
[text], dense_model=self.dense_model, sparse_model=self.sparse_model
),
max_retries=self.max_retries,
)
if not results:
return EmbedResult(dense_vector=[], sparse_vector={})
Expand All @@ -244,8 +260,11 @@ def embed(self, text: str, is_query: bool = False) -> EmbedResult:
def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
if not texts:
return []
raw_results = self._call_api(
texts, dense_model=self.dense_model, sparse_model=self.sparse_model
raw_results = transient_retry(
lambda: self._call_api(
texts, dense_model=self.dense_model, sparse_model=self.sparse_model
),
max_retries=self.max_retries,
)
results = []
for item in raw_results:
Expand Down
Loading
Loading