diff --git a/openviking/storage/vectordb/collection/local_collection.py b/openviking/storage/vectordb/collection/local_collection.py index dc84d39f1..74647a811 100644 --- a/openviking/storage/vectordb/collection/local_collection.py +++ b/openviking/storage/vectordb/collection/local_collection.py @@ -56,6 +56,7 @@ def get_or_create_local_collection( path: str = "", vectorizer: Optional[BaseVectorizer] = None, config: Optional[Dict[str, Any]] = None, + cache_config: Optional[Dict[str, Any]] = None, ): """Create or retrieve a local Collection. @@ -67,6 +68,10 @@ def get_or_create_local_collection( - "ttl_cleanup_seconds": Interval (in seconds) for TTL expiration data cleanup - "index_maintenance_seconds": Interval (in seconds) for index maintenance tasks If not provided, values will be obtained from environment variables or defaults + cache_config: Cache configuration for query result caching, optional settings include: + - "max_size": Maximum number of cache entries (default: 1000) + - "ttl_seconds": Time-to-live for cache entries in seconds (default: 300) + - "enabled": Whether caching is enabled (default: True) Returns: Collection: Collection instance @@ -81,6 +86,11 @@ def get_or_create_local_collection( ... config={ ... "ttl_cleanup_seconds": 5, ... "index_maintenance_seconds": 60 + ... }, + ... cache_config={ + ... "max_size": 2000, + ... "ttl_seconds": 600, + ... "enabled": True ... } ... 
) @@ -103,7 +113,7 @@ def get_or_create_local_collection( ) store_mgr = create_store_manager("local") collection = VolatileCollection( - meta=meta, store=store_mgr, vectorizer=vectorizer, config=config + meta=meta, store=store_mgr, vectorizer=vectorizer, config=config, cache_config=cache_config ) return Collection(collection) else: @@ -118,7 +128,7 @@ def get_or_create_local_collection( storage_path = os.path.join(path, STORAGE_DIR_NAME) store_mgr = create_store_manager("local", storage_path) collection = PersistCollection( - path=path, meta=meta, store=store_mgr, vectorizer=vectorizer, config=config + path=path, meta=meta, store=store_mgr, vectorizer=vectorizer, config=config, cache_config=cache_config ) return Collection(collection) @@ -130,6 +140,7 @@ def __init__( store_mgr: StoreManager, vectorizer: Optional[BaseVectorizer] = None, config: Optional[Dict[str, Any]] = None, + cache_config: Optional[Dict[str, Any]] = None, ): self.indexes = ThreadSafeDictManager[IIndex]() self.meta: CollectionMeta = meta @@ -160,6 +171,9 @@ def __init__( executors={"default": {"type": "threadpool", "max_workers": 1}} ) self.scheduler.start() + + # Cache configuration for all indexes + self.cache_config = cache_config or {} def update(self, fields: Optional[Dict[str, Any]] = None, description: Optional[str] = None): meta_data: Dict[str, Any] = {} @@ -326,6 +340,162 @@ def search_by_vector( ] return search_result + def batch_search_by_vector( + self, + index_name: str, + dense_vectors: List[List[float]], + limit: int = 10, + offset: int = 0, + filters: Optional[Dict[str, Any]] = None, + sparse_vectors: Optional[List[Dict[str, float]]] = None, + output_fields: Optional[List[str]] = None, + num_threads: Optional[int] = None, + ) -> List[SearchResult]: + """Perform batch vector similarity search with multiple query vectors. 
+ + This method searches with multiple query vectors in a single call, + providing significant performance improvements for batch workloads + through parallel processing and query result caching. + + Args: + index_name: Name of the index to search + dense_vectors: List of dense query vectors + limit: Maximum number of results to return per query. Defaults to 10. + offset: Number of results to skip per query. Defaults to 0. + filters: Query DSL for filtering results by scalar fields. + sparse_vectors: List of sparse vectors (dictionaries) for hybrid search. + output_fields: List of fields to include in results. + num_threads: Number of threads for parallel search. Defaults to 4. + + Returns: + List of SearchResult objects, one per query vector, in the same order + as the input dense_vectors. + + Example: + >>> query_vectors = [[0.1, 0.2, ...], [0.3, 0.4, ...], [0.5, 0.6, ...]] + >>> results = collection.batch_search_by_vector( + ... index_name="my_index", + ... dense_vectors=query_vectors, + ... limit=10 + ... ) + >>> for i, result in enumerate(results): + ... 
print(f"Query {i}: {len(result.data)} results") + """ + if not dense_vectors: + return [] + + index = self.indexes.get(index_name) + if not index: + return [SearchResult() for _ in dense_vectors] + + # Prepare sparse vectors if provided + sparse_raw_terms_list = None + sparse_values_list = None + if sparse_vectors: + sparse_raw_terms_list = [] + sparse_values_list = [] + for sv in sparse_vectors: + if sv and isinstance(sv, dict): + sparse_raw_terms_list.append(list(sv.keys())) + sparse_values_list.append(list(sv.values())) + else: + sparse_raw_terms_list.append([]) + sparse_values_list.append([]) + + # Perform batch search with parallel processing + actual_limit = limit + offset + batch_results = index.batch_search( + dense_vectors, actual_limit, filters, sparse_raw_terms_list, sparse_values_list, num_threads + ) + + # Process results for each query + search_results: List[SearchResult] = [] + if not output_fields: + output_fields = list(self.meta.fields_dict.keys()) + + for label_list, scores_list in batch_results: + search_result = SearchResult() + + # Apply offset by slicing the results + if offset > 0: + label_list = label_list[offset:] + scores_list = scores_list[offset:] + + # Limit to requested size + if len(label_list) > limit: + label_list = label_list[:limit] + scores_list = scores_list[:limit] + + pk_list = label_list + fields_list = [] + + if self.meta.primary_key or output_fields: + if not self.store_mgr: + raise RuntimeError("Store manager is not initialized") + + # Fetch candidate data for labels + if label_list: + cands_list = self.store_mgr.fetch_cands_data(label_list) + + valid_indices = [] + for i, cand in enumerate(cands_list): + if cand is not None: + valid_indices.append(i) + + if len(valid_indices) < len(cands_list): + cands_list = [cands_list[i] for i in valid_indices] + pk_list = [pk_list[i] for i in valid_indices] + scores_list = [scores_list[i] for i in valid_indices] + + if cands_list: + cands_fields = [json.loads(cand.fields) for cand in 
cands_list] + + if self.meta.primary_key: + pk_list = [ + cands_field.get(self.meta.primary_key, "") + for cands_field in cands_fields + ] + fields_list = [ + {field: cands_field.get(field, None) for field in output_fields} + for cands_field in cands_fields + ] + if self.meta.vector_key: + for i, cands in enumerate(cands_list): + fields_list[i][self.meta.vector_key] = cands.vector + + search_result.data = [ + SearchItemResult(id=pk, fields=fields, score=score) + for pk, score, fields in zip_longest(pk_list, scores_list, fields_list) + ] + search_results.append(search_result) + + return search_results + + def get_index_cache_stats(self, index_name: str) -> Optional[Dict[str, Any]]: + """Get cache statistics for a specific index. + + Args: + index_name: Name of the index + + Returns: + Dictionary containing cache statistics if the index exists, + None otherwise. + """ + index = self.indexes.get(index_name) + if not index: + return None + return index.get_cache_stats() + + def invalidate_index_cache(self, index_name: str) -> None: + """Invalidate the query cache for a specific index. 
+ + Args: + index_name: Name of the index + """ + index = self.indexes.get(index_name) + if index: + index.invalidate_cache() + def search_by_id( self, index_name: str, @@ -895,8 +1065,9 @@ def __init__( store: StoreManager, vectorizer: Optional[BaseVectorizer] = None, config: Optional[Dict[str, Any]] = None, + cache_config: Optional[Dict[str, Any]] = None, ): - super().__init__(meta, store, vectorizer, config) + super().__init__(meta, store, vectorizer, config, cache_config) LocalCollection._register_scheduler_job(self) def _new_index( @@ -911,6 +1082,7 @@ def _new_index( name=index_name, meta=meta, cands_list=cands_list, + cache_config=self.cache_config, ) return index @@ -926,12 +1098,13 @@ def __init__( store: StoreManager, vectorizer: Optional[BaseVectorizer] = None, config: Optional[Dict[str, Any]] = None, + cache_config: Optional[Dict[str, Any]] = None, ): self.collection_dir = path os.makedirs(self.collection_dir, exist_ok=True) self.index_dir = os.path.join(self.collection_dir, "index") os.makedirs(self.index_dir, exist_ok=True) - super().__init__(meta, store, vectorizer, config) + super().__init__(meta, store, vectorizer, config, cache_config) self._recover() LocalCollection._register_scheduler_job(self) # TTL expiration data cleanup @@ -1031,6 +1204,7 @@ def _new_index( meta=meta, cands_list=cands_list, force_rebuild=force_rebuild, + cache_config=self.cache_config, ) return index diff --git a/openviking/storage/vectordb/index/index.py b/openviking/storage/vectordb/index/index.py index 81b6b1b85..084934d6f 100644 --- a/openviking/storage/vectordb/index/index.py +++ b/openviking/storage/vectordb/index/index.py @@ -273,6 +273,65 @@ def need_rebuild(self) -> bool: """ return True + def batch_search( + self, + query_vectors: List[List[float]], + limit: int = 10, + filters: Optional[Dict[str, Any]] = None, + sparse_raw_terms_list: Optional[List[List[str]]] = None, + sparse_values_list: Optional[List[List[float]]] = None, + ) -> List[Tuple[List[int], 
List[float]]]: + """Perform batch vector similarity search with multiple query vectors. + + This method allows searching with multiple query vectors in a single call, + which can be more efficient than multiple individual searches due to + better cache utilization and reduced overhead. + + Args: + query_vectors: List of dense query vectors for similarity matching. + Each vector should have the same dimensionality as indexed vectors. + limit: Maximum number of results to return per query. Defaults to 10. + filters: Query DSL for filtering results by scalar fields. + Applied to all queries in the batch. + sparse_raw_terms_list: List of term token lists for sparse vector search. + Each inner list corresponds to a query vector. + sparse_values_list: List of weight lists for sparse vector search. + Each inner list corresponds to a query vector. + + Returns: + List of tuples, one per query vector, each containing: + - List of labels (record identifiers) sorted by similarity + - List of similarity scores corresponding to each label + + Note: + Default implementation calls search() for each query vector sequentially. + Subclasses may override this method to provide optimized batch processing. + """ + # Default implementation: sequential search + results = [] + for i, query_vector in enumerate(query_vectors): + sparse_terms = sparse_raw_terms_list[i] if sparse_raw_terms_list else None + sparse_values = sparse_values_list[i] if sparse_values_list else None + result = self.search(query_vector, limit, filters, sparse_terms, sparse_values) + results.append(result) + return results + + def get_cache_stats(self) -> Optional[Dict[str, Any]]: + """Get cache statistics for this index. + + Returns: + Dictionary containing cache statistics if caching is enabled, + None otherwise. + """ + return None + + def invalidate_cache(self) -> None: + """Invalidate the query cache for this index. + + Should be called when the underlying data is modified. 
+ """ + pass + class Index: """ @@ -480,3 +539,57 @@ def aggregate( if self.__index is None: raise RuntimeError("Index is not initialized") return self.__index.aggregate(filters) + + def batch_search( + self, + query_vectors: List[List[float]], + limit: int = 10, + filters: Optional[Dict[str, Any]] = None, + sparse_raw_terms_list: Optional[List[List[str]]] = None, + sparse_values_list: Optional[List[List[float]]] = None, + ) -> List[Tuple[List[int], List[float]]]: + """Perform batch vector similarity search with multiple query vectors. + + This method allows searching with multiple query vectors in a single call, + which can be more efficient than multiple individual searches due to + better cache utilization and reduced overhead. + + Args: + query_vectors: List of dense query vectors for similarity matching. + limit: Maximum number of results to return per query. Defaults to 10. + filters: Query DSL for filtering results by scalar fields. + sparse_raw_terms_list: List of term token lists for sparse vector search. + sparse_values_list: List of weight lists for sparse vector search. + + Returns: + List of tuples, one per query vector, each containing: + - List of labels (record identifiers) sorted by similarity + - List of similarity scores corresponding to each label + + Raises: + RuntimeError: If the underlying index is not initialized. + """ + if self.__index is None: + raise RuntimeError("Index is not initialized") + return self.__index.batch_search( + query_vectors, limit, filters, sparse_raw_terms_list, sparse_values_list + ) + + def get_cache_stats(self) -> Optional[Dict[str, Any]]: + """Get cache statistics for this index. + + Returns: + Dictionary containing cache statistics if caching is enabled, + None otherwise. + """ + if self.__index is None: + return None + return self.__index.get_cache_stats() + + def invalidate_cache(self) -> None: + """Invalidate the query cache for this index. + + Should be called when the underlying data is modified. 
+ """ + if self.__index is not None: + self.__index.invalidate_cache() diff --git a/openviking/storage/vectordb/index/local_index.py b/openviking/storage/vectordb/index/local_index.py index eadd17c5e..1a2b83ef4 100644 --- a/openviking/storage/vectordb/index/local_index.py +++ b/openviking/storage/vectordb/index/local_index.py @@ -5,6 +5,7 @@ import os import shutil import time +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union @@ -13,6 +14,7 @@ from openviking.storage.vectordb.store.data import CandidateData, DeltaRecord from openviking.storage.vectordb.utils.constants import IndexFileMarkers from openviking.storage.vectordb.utils.data_processor import DataProcessor +from openviking.storage.vectordb.utils.query_cache import QueryCache from openviking_cli.utils.logger import default_logger as logger @@ -190,6 +192,8 @@ class LocalIndex(IIndex): - Metadata management and updates - Search operations with filtering and aggregation - Data lifecycle (upsert, delete, close, drop) + - Query result caching for improved performance + - Batch search support for multiple queries This class serves as the base for both VolatileIndex (in-memory) and PersistentIndex (disk-backed with versioning). @@ -197,14 +201,24 @@ class LocalIndex(IIndex): Attributes: engine_proxy (IndexEngineProxy): Proxy to the underlying index engine meta: Index metadata including configuration and schema + query_cache: Optional LRU cache for query results """ - def __init__(self, index_path_or_json: str, meta: Any): + # Default cache configuration + DEFAULT_CACHE_MAX_SIZE = 1000 + DEFAULT_CACHE_TTL_SECONDS = 300.0 # 5 minutes + DEFAULT_BATCH_SEARCH_THREADS = 4 + + def __init__(self, index_path_or_json: str, meta: Any, cache_config: Optional[Dict[str, Any]] = None): """Initialize a local index instance. 
Args: index_path_or_json (str): Path to index files or JSON configuration meta: Index metadata object containing configuration + cache_config: Optional cache configuration with keys: + - max_size: Maximum number of cache entries (default: 1000) + - ttl_seconds: Time-to-live for cache entries (default: 300) + - enabled: Whether caching is enabled (default: True) """ # Get the vector normalization flag from meta normalize_vector_flag = meta.inner_meta.get("VectorIndex", {}).get("NormalizeVector", False) @@ -213,7 +227,14 @@ def __init__(self, index_path_or_json: str, meta: Any): ) self.meta = meta self.field_type_converter = DataProcessor(self.meta.collection_meta.fields_dict) - pass + + # Initialize query cache + cache_config = cache_config or {} + self.query_cache = QueryCache( + max_size=cache_config.get("max_size", self.DEFAULT_CACHE_MAX_SIZE), + ttl_seconds=cache_config.get("ttl_seconds", self.DEFAULT_CACHE_TTL_SECONDS), + enabled=cache_config.get("enabled", True), + ) def update( self, @@ -235,10 +256,14 @@ def get_meta_data(self): def upsert_data(self, delta_list: List[DeltaRecord]): if self.engine_proxy: self.engine_proxy.upsert_data(self._convert_delta_list_for_index(delta_list)) + # Invalidate cache when data is modified + self.invalidate_cache() def delete_data(self, delta_list: List[DeltaRecord]): if self.engine_proxy: self.engine_proxy.delete_data(self._convert_delta_list_for_index(delta_list)) + # Invalidate cache when data is modified + self.invalidate_cache() def search( self, @@ -257,13 +282,155 @@ def search( if sparse_values is None: sparse_values = [] + # Try to get from cache first + cached_result = self.query_cache.get( + query_vector, limit, filters, sparse_raw_terms, sparse_values + ) + if cached_result is not None: + return cached_result + + # Convert filters for index if self.field_type_converter and filters is not None: filters = self.field_type_converter.convert_filter_for_index(filters) - return self.engine_proxy.search( + + result = 
self.engine_proxy.search( query_vector, limit, filters, sparse_raw_terms, sparse_values ) + + # Cache the result + self.query_cache.put( + query_vector, limit, filters, sparse_raw_terms, sparse_values, + result[0], result[1] + ) + + return result return [], [] + def batch_search( + self, + query_vectors: List[List[float]], + limit: int = 10, + filters: Optional[Dict[str, Any]] = None, + sparse_raw_terms_list: Optional[List[List[str]]] = None, + sparse_values_list: Optional[List[List[float]]] = None, + num_threads: Optional[int] = None, + ) -> List[Tuple[List[int], List[float]]]: + """Perform batch vector similarity search with parallel processing. + + This method processes multiple query vectors in parallel using a thread pool, + providing significant performance improvements when searching with many queries. + + Args: + query_vectors: List of dense query vectors for similarity matching. + limit: Maximum number of results to return per query. Defaults to 10. + filters: Query DSL for filtering results by scalar fields. + sparse_raw_terms_list: List of term token lists for sparse vector search. + sparse_values_list: List of weight lists for sparse vector search. + num_threads: Number of threads for parallel search. Defaults to 4. + + Returns: + List of tuples, one per query vector, each containing: + - List of labels (record identifiers) sorted by similarity + - List of similarity scores corresponding to each label + + Note: + Results are returned in the same order as input query_vectors. + Queries with cache hits are served from cache without threading. 
+ """ + if not query_vectors: + return [] + + if not self.engine_proxy: + return [([], []) for _ in query_vectors] + + # Handle defaults + if filters is None: + filters = {} + if sparse_raw_terms_list is None: + sparse_raw_terms_list = [None] * len(query_vectors) + if sparse_values_list is None: + sparse_values_list = [None] * len(query_vectors) + + if num_threads is None: + num_threads = self.DEFAULT_BATCH_SEARCH_THREADS + + results: List[Optional[Tuple[List[int], List[float]]]] = [None] * len(query_vectors) + uncached_indices: List[int] = [] + uncached_queries: List[Tuple[int, List[float], Optional[List[str]], Optional[List[float]]]] = [] + + # Check cache for all queries + for i, query_vector in enumerate(query_vectors): + sparse_terms = sparse_raw_terms_list[i] + sparse_values = sparse_values_list[i] + + cached_result = self.query_cache.get( + query_vector, limit, filters, sparse_terms, sparse_values + ) + if cached_result is not None: + results[i] = cached_result + else: + uncached_indices.append(i) + uncached_queries.append((i, query_vector, sparse_terms, sparse_values)) + + # If all results are from cache, return early + if not uncached_queries: + return [r if r is not None else ([], []) for r in results] + + # Convert filters once for all queries + converted_filters = filters + if self.field_type_converter and filters is not None: + converted_filters = self.field_type_converter.convert_filter_for_index(filters) + + # Execute uncached queries in parallel + def search_single(args: Tuple[int, List[float], Optional[List[str]], Optional[List[float]]]) -> Tuple[int, Tuple[List[int], List[float]]]: + idx, query_vector, sparse_terms, sparse_values = args + if sparse_terms is None: + sparse_terms = [] + if sparse_values is None: + sparse_values = [] + result = self.engine_proxy.search( + query_vector, limit, converted_filters, sparse_terms, sparse_values + ) + return idx, result + + # Use thread pool for parallel execution + with 
ThreadPoolExecutor(max_workers=min(num_threads, len(uncached_queries))) as executor: + futures = [executor.submit(search_single, args) for args in uncached_queries] + + for future in as_completed(futures): + try: + idx, result = future.result() + results[idx] = result + + # Cache the result + query_vector = query_vectors[idx] + sparse_terms = sparse_raw_terms_list[idx] + sparse_values = sparse_values_list[idx] + self.query_cache.put( + query_vector, limit, filters, sparse_terms, sparse_values, + result[0], result[1] + ) + except Exception as e: + logger.error(f"Batch search error for query: {e}") + + # Fill in any remaining None results with empty tuples + return [r if r is not None else ([], []) for r in results] + + def get_cache_stats(self) -> Optional[Dict[str, Any]]: + """Get cache statistics for this index. + + Returns: + Dictionary containing cache statistics if caching is enabled. + """ + return self.query_cache.get_stats() + + def invalidate_cache(self) -> None: + """Invalidate the query cache for this index. + + Should be called when the underlying data is modified. + """ + self.query_cache.invalidate() + def aggregate( self, filters: Optional[Dict[str, Any]] = None, @@ -306,12 +473,14 @@ def aggregate( return agg_data def close(self): + self.invalidate_cache() pass def drop(self): if self.engine_proxy: self.engine_proxy.drop() self.meta = None + self.invalidate_cache() def get_newest_version(self) -> Union[int, str, Any]: return 0 @@ -389,6 +558,7 @@ class VolatileIndex(LocalIndex): - Data lost on process restart - Always requires rebuild from scratch on startup - Suitable for temporary indexes, testing, or when persistence is handled externally + - Supports query result caching for frequently repeated queries The index is created from an initial dataset and can be updated incrementally, but all changes exist only in memory. 
@@ -396,9 +566,16 @@ class VolatileIndex(LocalIndex): Attributes: engine_proxy (IndexEngineProxy): Proxy to the in-memory index engine meta: Index metadata and configuration + query_cache: LRU cache for query results """ - def __init__(self, name: str, meta: Any, cands_list: Optional[List[CandidateData]] = None): + def __init__( + self, + name: str, + meta: Any, + cands_list: Optional[List[CandidateData]] = None, + cache_config: Optional[Dict[str, Any]] = None, + ): """Initialize a volatile (in-memory) index. Creates a new in-memory index and populates it with the initial dataset. @@ -408,6 +585,10 @@ def __init__(self, name: str, meta: Any, cands_list: Optional[List[CandidateData meta: Index metadata containing configuration (dimensions, distance metric, etc.) cands_list (list): Initial list of CandidateData records to populate the index. Defaults to None (empty index). + cache_config: Optional cache configuration with keys: + - max_size: Maximum number of cache entries (default: 1000) + - ttl_seconds: Time-to-live for cache entries (default: 300) + - enabled: Whether caching is enabled (default: True) Note: The index is immediately built in memory with the provided data. @@ -431,6 +612,14 @@ def __init__(self, name: str, meta: Any, cands_list: Optional[List[CandidateData self.meta = meta self.field_type_converter = DataProcessor(self.meta.collection_meta.fields_dict) self.engine_proxy.add_data(self._convert_candidate_list_for_index(cands_list)) + + # Initialize query cache + cache_config = cache_config or {} + self.query_cache = QueryCache( + max_size=cache_config.get("max_size", self.DEFAULT_CACHE_MAX_SIZE), + ttl_seconds=cache_config.get("ttl_seconds", self.DEFAULT_CACHE_TTL_SECONDS), + enabled=cache_config.get("enabled", True), + ) def need_rebuild(self) -> bool: """Determine if rebuild is needed. 
@@ -466,6 +655,7 @@ class PersistentIndex(LocalIndex): - Crash recovery through versioned checkpoints - Background persistence without blocking operations - Old version cleanup to manage disk space + - Query result caching for improved performance The index maintains multiple versions on disk, each identified by a timestamp. New versions are created during persist() operations when the index has been modified. @@ -485,6 +675,7 @@ class PersistentIndex(LocalIndex): now_version (str): Current active version identifier engine_proxy (IndexEngineProxy): Proxy to the persistent index engine meta: Index metadata and configuration + query_cache: LRU cache for query results """ def __init__( @@ -495,6 +686,7 @@ def __init__( cands_list: Optional[List[CandidateData]] = None, force_rebuild: bool = False, initial_timestamp: Optional[int] = None, + cache_config: Optional[Dict[str, Any]] = None, ): """Initialize a persistent index with versioning support. @@ -510,6 +702,10 @@ def __init__( Defaults to False. initial_timestamp (Optional[int]): Timestamp to use if creating a new index from scratch. If None, uses current time. Useful for recovery scenarios. + cache_config: Optional cache configuration with keys: + - max_size: Maximum number of cache entries (default: 1000) + - ttl_seconds: Time-to-live for cache entries (default: 300) + - enabled: Whether caching is enabled (default: True) Process: 1. Create directory structure if not exists @@ -538,7 +734,7 @@ def __init__( self.now_version = str(newest_version) index_path = os.path.join(self.version_dir, self.now_version) - super().__init__(index_path, meta) + super().__init__(index_path, meta, cache_config) # Remove scheduling logic, unified scheduling by collection layer def _create_new_index( @@ -582,6 +778,7 @@ def close(self): 1. Persists any uncommitted changes to disk 2. Releases the index engine resources 3. Cleans up old version files, keeping only the latest + 4. 
@dataclass
class CacheEntry:
    """A single cache entry storing search results and metadata.

    Attributes:
        labels: List of result labels (record identifiers)
        scores: List of similarity scores
        created_at: Timestamp when the entry was created
        access_count: Number of times this entry has been accessed
    """
    labels: List[int]
    scores: List[float]
    created_at: float = field(default_factory=time.time)
    access_count: int = 0


class QueryCache:
    """Thread-safe LRU cache for vector search results.

    Results are keyed by a SHA-256 hash of the query parameters (vector,
    limit, filters, sparse terms/weights).

    Features:
    - Thread-safe operations via a reentrant lock
    - LRU eviction when capacity is reached
    - TTL-based expiration of stale entries (0 disables TTL)
    - Hit/miss/eviction statistics

    Attributes:
        max_size: Maximum number of entries (<= 0 means store nothing)
        ttl_seconds: Time-to-live for entries in seconds (0 = no TTL)
        enabled: Whether caching is enabled
    """

    def __init__(
        self,
        max_size: int = 1000,
        ttl_seconds: float = 300.0,
        enabled: bool = True,
    ):
        """Initialize the query cache.

        Args:
            max_size: Maximum number of entries to store. Defaults to 1000.
            ttl_seconds: Time-to-live for entries in seconds; 0 disables
                TTL expiration. Defaults to 300 (5 minutes).
            enabled: Whether caching is enabled. Defaults to True.
        """
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self.enabled = enabled
        self._cache: "OrderedDict[str, CacheEntry]" = OrderedDict()
        self._lock = threading.RLock()

        # Statistics counters (guarded by _lock).
        self._hits = 0
        self._misses = 0
        self._evictions = 0

    def _compute_key(
        self,
        query_vector: Optional[List[float]],
        limit: int,
        filters: Optional[Dict[str, Any]],
        sparse_raw_terms: Optional[List[str]],
        sparse_values: Optional[List[float]],
    ) -> str:
        """Compute a deterministic cache key from the query parameters.

        Floats are rounded to 6 decimal places so tiny representation noise
        maps to the same key. Empty/None filters and sparse parts are
        deliberately omitted so None and {} / [] produce identical keys.

        Returns:
            Hex SHA-256 digest uniquely identifying the query.
        """
        key_parts = []

        if query_vector is not None:
            key_parts.append(("vector", tuple(round(v, 6) for v in query_vector)))

        key_parts.append(("limit", limit))

        if filters:
            # sort_keys makes logically-equal dicts hash identically.
            key_parts.append(("filters", json.dumps(filters, sort_keys=True)))

        if sparse_raw_terms and sparse_values:
            sparse_tuple = tuple(
                zip(sparse_raw_terms, [round(v, 6) for v in sparse_values])
            )
            key_parts.append(("sparse", sparse_tuple))

        return hashlib.sha256(str(key_parts).encode()).hexdigest()

    def get(
        self,
        query_vector: Optional[List[float]],
        limit: int,
        filters: Optional[Dict[str, Any]],
        sparse_raw_terms: Optional[List[str]],
        sparse_values: Optional[List[float]],
    ) -> Optional[Tuple[List[int], List[float]]]:
        """Retrieve cached search results if present and not expired.

        Returns:
            Copies of (labels, scores) on a hit, None on a miss, when
            disabled, or when the entry's TTL has elapsed.
        """
        if not self.enabled:
            return None

        key = self._compute_key(
            query_vector, limit, filters, sparse_raw_terms, sparse_values
        )

        with self._lock:
            if key not in self._cache:
                self._misses += 1
                return None

            entry = self._cache[key]

            # TTL check: drop and count as a miss when the entry is stale.
            if self.ttl_seconds > 0:
                if time.time() - entry.created_at > self.ttl_seconds:
                    del self._cache[key]
                    self._misses += 1
                    return None

            # Promote to most-recently-used.
            self._cache.move_to_end(key)
            entry.access_count += 1
            self._hits += 1

            # Return copies so callers cannot mutate the cached lists.
            return (entry.labels.copy(), entry.scores.copy())

    def put(
        self,
        query_vector: Optional[List[float]],
        limit: int,
        filters: Optional[Dict[str, Any]],
        sparse_raw_terms: Optional[List[str]],
        sparse_values: Optional[List[float]],
        labels: List[int],
        scores: List[float],
    ) -> None:
        """Store search results in the cache (no-op when disabled or max_size <= 0).

        Args:
            query_vector, limit, filters, sparse_raw_terms, sparse_values:
                The query parameters that produced the result (cache key).
            labels: Result labels from the search.
            scores: Result scores from the search.
        """
        if not self.enabled:
            return

        key = self._compute_key(
            query_vector, limit, filters, sparse_raw_terms, sparse_values
        )

        with self._lock:
            # Remove an existing entry so re-insertion lands at the MRU end.
            if key in self._cache:
                del self._cache[key]

            # BUG FIX: with max_size <= 0 the original eviction loop called
            # popitem() on an empty OrderedDict and raised KeyError. Treat a
            # non-positive capacity as "store nothing".
            if self.max_size <= 0:
                return

            while len(self._cache) >= self.max_size:
                self._cache.popitem(last=False)
                self._evictions += 1

            # Store copies to decouple cached data from the caller's lists.
            self._cache[key] = CacheEntry(
                labels=labels.copy(),
                scores=scores.copy(),
            )
+ + Should be called when the underlying index is modified. + """ + with self._lock: + self._cache.clear() + + def get_stats(self) -> Dict[str, Any]: + """Get cache statistics. + + Returns: + Dictionary containing cache statistics: + - size: Current number of entries + - max_size: Maximum capacity + - hits: Number of cache hits + - misses: Number of cache misses + - evictions: Number of entries evicted + - hit_rate: Cache hit rate (0-1) + """ + with self._lock: + total = self._hits + self._misses + hit_rate = self._hits / total if total > 0 else 0.0 + + return { + "size": len(self._cache), + "max_size": self.max_size, + "hits": self._hits, + "misses": self._misses, + "evictions": self._evictions, + "hit_rate": hit_rate, + "enabled": self.enabled, + "ttl_seconds": self.ttl_seconds, + } + + def resize(self, new_max_size: int) -> None: + """Resize the cache capacity. + + Args: + new_max_size: New maximum number of entries + """ + with self._lock: + self.max_size = new_max_size + # Evict entries if new size is smaller + while len(self._cache) > new_max_size: + self._cache.popitem(last=False) + self._evictions += 1 + + def set_enabled(self, enabled: bool) -> None: + """Enable or disable caching. + + Args: + enabled: Whether to enable caching + """ + with self._lock: + self.enabled = enabled + if not enabled: + self._cache.clear() + + +class CacheManager: + """Manages multiple query caches for different indexes. + + This class provides a central point for managing caches across + multiple indexes in a collection. + + Attributes: + default_max_size: Default maximum cache size for new caches + default_ttl_seconds: Default TTL for new caches + default_enabled: Default enabled state for new caches + """ + + def __init__( + self, + default_max_size: int = 1000, + default_ttl_seconds: float = 300.0, + default_enabled: bool = True, + ): + """Initialize the cache manager. 
+ + Args: + default_max_size: Default max size for new caches + default_ttl_seconds: Default TTL for new caches + default_enabled: Default enabled state for new caches + """ + self.default_max_size = default_max_size + self.default_ttl_seconds = default_ttl_seconds + self.default_enabled = default_enabled + self._caches: Dict[str, QueryCache] = {} + self._lock = threading.RLock() + + def get_cache(self, index_name: str) -> QueryCache: + """Get or create a cache for the specified index. + + Args: + index_name: Name of the index + + Returns: + QueryCache instance for the index + """ + with self._lock: + if index_name not in self._caches: + self._caches[index_name] = QueryCache( + max_size=self.default_max_size, + ttl_seconds=self.default_ttl_seconds, + enabled=self.default_enabled, + ) + return self._caches[index_name] + + def invalidate_index(self, index_name: str) -> None: + """Invalidate cache for a specific index. + + Args: + index_name: Name of the index to invalidate + """ + with self._lock: + if index_name in self._caches: + self._caches[index_name].invalidate() + + def invalidate_all(self) -> None: + """Invalidate all caches.""" + with self._lock: + for cache in self._caches.values(): + cache.invalidate() + + def get_all_stats(self) -> Dict[str, Dict[str, Any]]: + """Get statistics for all caches. + + Returns: + Dictionary mapping index names to their cache statistics + """ + with self._lock: + return { + name: cache.get_stats() + for name, cache in self._caches.items() + } + + def set_enabled_all(self, enabled: bool) -> None: + """Enable or disable all caches. + + Args: + enabled: Whether to enable caching + """ + with self._lock: + for cache in self._caches.values(): + cache.set_enabled(enabled) + + +# Global cache manager instance (can be configured per collection) +_global_cache_manager: Optional[CacheManager] = None +_global_cache_lock = threading.Lock() + + +def get_global_cache_manager() -> CacheManager: + """Get the global cache manager instance. 
+ + Returns: + The global CacheManager instance, creating it if necessary + """ + global _global_cache_manager + with _global_cache_lock: + if _global_cache_manager is None: + _global_cache_manager = CacheManager() + return _global_cache_manager + + +def set_global_cache_manager(manager: CacheManager) -> None: + """Set the global cache manager instance. + + Args: + manager: The CacheManager instance to use globally + """ + global _global_cache_manager + with _global_cache_lock: + _global_cache_manager = manager \ No newline at end of file diff --git a/tests/vectordb/test_query_optimization.py b/tests/vectordb/test_query_optimization.py new file mode 100644 index 000000000..a173b95f7 --- /dev/null +++ b/tests/vectordb/test_query_optimization.py @@ -0,0 +1,441 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: Apache-2.0 +"""Tests for query caching and batch search optimization. + +This module provides tests and benchmarks for the vector retrieval +optimizations including: +- Query result caching (LRU cache) +- Batch search with parallel processing +""" + +import random +import time +from typing import Dict, List + +import pytest + +from openviking.storage.vectordb.collection.local_collection import get_or_create_local_collection + + +def create_test_collection( + collection_name: str = "test_collection", + dim: int = 128, + num_docs: int = 1000, + cache_config: Dict = None, +): + """Create a test collection with random data.""" + meta_data = { + "CollectionName": collection_name, + "Fields": [ + {"FieldName": "id", "FieldType": "int64", "IsPrimaryKey": True}, + {"FieldName": "embedding", "FieldType": "vector", "Dim": dim}, + {"FieldName": "text", "FieldType": "text"}, + {"FieldName": "category", "FieldType": "text"}, + ], + } + + collection = get_or_create_local_collection(meta_data=meta_data, cache_config=cache_config) + + # Insert test data + categories = ["tech", "science", "art", "sports", "music"] + data_list = [] + 
for i in range(num_docs): + data_list.append({ + "id": i, + "embedding": [random.random() for _ in range(dim)], + "text": f"Document {i}", + "category": categories[i % 5], + }) + + collection.upsert_data(data_list) + + # Create index + index_meta_data = { + "IndexName": "test_index", + "VectorIndex": { + "IndexType": "flat", + "Distance": "ip", + }, + "ScalarIndex": ["category"], + } + collection.create_index("test_index", index_meta_data) + + return collection + + +class TestQueryCache: + """Tests for query result caching.""" + + def test_cache_disabled(self): + """Test that caching can be disabled.""" + collection = create_test_collection( + collection_name="test_cache_disabled", + cache_config={"enabled": False}, + ) + + # Get cache stats + stats = collection.get_index_cache_stats("test_index") + assert stats is not None + assert stats["enabled"] is False + + # Perform searches + query = [random.random() for _ in range(128)] + result1 = collection.search_by_vector("test_index", query, limit=5) + result2 = collection.search_by_vector("test_index", query, limit=5) + + # Cache should have 0 hits since it's disabled + stats = collection.get_index_cache_stats("test_index") + assert stats["hits"] == 0 + + collection.close() + + def test_cache_enabled(self): + """Test that caching works when enabled.""" + collection = create_test_collection( + collection_name="test_cache_enabled", + cache_config={"enabled": True, "max_size": 100, "ttl_seconds": 60}, + ) + + # Get cache stats + stats = collection.get_index_cache_stats("test_index") + assert stats is not None + assert stats["enabled"] is True + assert stats["max_size"] == 100 + assert stats["ttl_seconds"] == 60 + + # Perform same search multiple times + query = [random.random() for _ in range(128)] + result1 = collection.search_by_vector("test_index", query, limit=5) + + # Check cache miss + stats = collection.get_index_cache_stats("test_index") + assert stats["misses"] == 1 + assert stats["hits"] == 0 + + # Same query 
should hit cache + result2 = collection.search_by_vector("test_index", query, limit=5) + + stats = collection.get_index_cache_stats("test_index") + assert stats["hits"] == 1 + + # Results should be identical + assert len(result1.data) == len(result2.data) + for i in range(len(result1.data)): + assert result1.data[i].id == result2.data[i].id + assert abs(result1.data[i].score - result2.data[i].score) < 1e-6 + + collection.close() + + def test_cache_invalidation_on_upsert(self): + """Test that cache is invalidated when data is modified.""" + collection = create_test_collection( + collection_name="test_cache_invalidation", + cache_config={"enabled": True}, + ) + + # Perform search to populate cache + query = [random.random() for _ in range(128)] + result1 = collection.search_by_vector("test_index", query, limit=5) + + stats = collection.get_index_cache_stats("test_index") + assert stats["misses"] == 1 + assert stats["hits"] == 0 + + # Insert new data - should invalidate cache + collection.upsert_data([{ + "id": 10000, + "embedding": [random.random() for _ in range(128)], + "text": "New document", + "category": "tech", + }]) + + # Same query should miss cache (it was invalidated) + result2 = collection.search_by_vector("test_index", query, limit=5) + + stats = collection.get_index_cache_stats("test_index") + # After upsert, cache was invalidated, so another miss + assert stats["misses"] == 2 + + collection.close() + + def test_cache_stats(self): + """Test cache statistics tracking.""" + collection = create_test_collection( + collection_name="test_cache_stats", + cache_config={"enabled": True, "max_size": 10}, + ) + + # Perform multiple searches + queries = [[random.random() for _ in range(128)] for _ in range(5)] + + # First round - all misses + for query in queries: + collection.search_by_vector("test_index", query, limit=5) + + stats = collection.get_index_cache_stats("test_index") + assert stats["misses"] == 5 + assert stats["hits"] == 0 + + # Second round - all 
hits (same queries) + for query in queries: + collection.search_by_vector("test_index", query, limit=5) + + stats = collection.get_index_cache_stats("test_index") + assert stats["hits"] == 5 + + # Test hit rate calculation + assert stats["hit_rate"] == 0.5 # 5 hits / 10 total requests + + collection.close() + + +class TestBatchSearch: + """Tests for batch search functionality.""" + + def test_batch_search_basic(self): + """Test basic batch search functionality.""" + collection = create_test_collection( + collection_name="test_batch_search_basic", + cache_config={"enabled": True}, + ) + + # Perform batch search + num_queries = 10 + queries = [[random.random() for _ in range(128)] for _ in range(num_queries)] + + results = collection.batch_search_by_vector( + index_name="test_index", + dense_vectors=queries, + limit=5, + ) + + assert len(results) == num_queries + for result in results: + assert len(result.data) <= 5 + for item in result.data: + assert item.id is not None + assert item.score is not None + + collection.close() + + def test_batch_search_with_filters(self): + """Test batch search with filters.""" + collection = create_test_collection( + collection_name="test_batch_search_filters", + cache_config={"enabled": True}, + ) + + num_queries = 5 + queries = [[random.random() for _ in range(128)] for _ in range(num_queries)] + + results = collection.batch_search_by_vector( + index_name="test_index", + dense_vectors=queries, + limit=10, + filters={"op": "must", "field": "category", "conds": ["tech"]}, + ) + + assert len(results) == num_queries + for result in results: + for item in result.data: + assert item.fields.get("category") == "tech" + + collection.close() + + def test_batch_search_with_sparse_vectors(self): + """Test batch search with sparse vectors.""" + collection = create_test_collection( + collection_name="test_batch_search_sparse", + cache_config={"enabled": True}, + ) + + num_queries = 3 + queries = [[random.random() for _ in range(128)] for _ in 
range(num_queries)] + sparse_vectors = [{"term1": 0.5, "term2": 0.3} for _ in range(num_queries)] + + results = collection.batch_search_by_vector( + index_name="test_index", + dense_vectors=queries, + sparse_vectors=sparse_vectors, + limit=5, + ) + + assert len(results) == num_queries + + collection.close() + + def test_batch_search_with_offset(self): + """Test batch search with offset.""" + collection = create_test_collection( + collection_name="test_batch_search_offset", + cache_config={"enabled": True}, + ) + + queries = [[random.random() for _ in range(128)] for _ in range(3)] + + # Search with offset=0 + results_no_offset = collection.batch_search_by_vector( + index_name="test_index", + dense_vectors=queries, + limit=5, + offset=0, + ) + + # Search with offset=2 + results_with_offset = collection.batch_search_by_vector( + index_name="test_index", + dense_vectors=queries, + limit=5, + offset=2, + ) + + # With offset, we should skip the first 2 results + for i in range(len(queries)): + # If there were enough results, the first result with offset + # should be different from the first result without offset + if len(results_no_offset[i].data) > 2: + assert results_with_offset[i].data[0].id != results_no_offset[i].data[0].id + + collection.close() + + def test_batch_search_cache_interaction(self): + """Test that batch search populates and uses cache.""" + collection = create_test_collection( + collection_name="test_batch_search_cache", + cache_config={"enabled": True}, + ) + + # Perform batch search + queries = [[random.random() for _ in range(128)] for _ in range(5)] + results1 = collection.batch_search_by_vector( + index_name="test_index", + dense_vectors=queries, + limit=5, + ) + + # Check cache stats - should have 5 misses + stats = collection.get_index_cache_stats("test_index") + assert stats["misses"] == 5 + + # Same batch search - should hit cache + results2 = collection.batch_search_by_vector( + index_name="test_index", + dense_vectors=queries, + limit=5, + 
) + + stats = collection.get_index_cache_stats("test_index") + assert stats["hits"] == 5 + + # Results should be identical + for i in range(len(queries)): + assert len(results1[i].data) == len(results2[i].data) + for j in range(len(results1[i].data)): + assert results1[i].data[j].id == results2[i].data[j].id + + collection.close() + + +class TestPerformanceBenchmark: + """Performance benchmarks for caching and batch search.""" + + @pytest.mark.skip(reason="Benchmark test - run manually") + def test_cache_performance_benchmark(self): + """Benchmark cache performance improvement.""" + collection = create_test_collection( + collection_name="benchmark_cache", + num_docs=5000, + cache_config={"enabled": True, "max_size": 1000}, + ) + + # Create a set of query vectors (some repeated) + all_queries = [[random.random() for _ in range(128)] for _ in range(100)] + # Repeat some queries to simulate cache hits + repeated_queries = all_queries[:20] * 5 + all_queries[20:] + + # Warm up cache + for query in repeated_queries[:20]: + collection.search_by_vector("test_index", query, limit=10) + + # Benchmark with cache + start_time = time.time() + for query in repeated_queries: + collection.search_by_vector("test_index", query, limit=10) + cached_time = time.time() - start_time + + # Get stats + stats = collection.get_index_cache_stats("test_index") + print(f"\nCache Performance:") + print(f" Total queries: {len(repeated_queries)}") + print(f" Cache hits: {stats['hits']}") + print(f" Cache misses: {stats['misses']}") + print(f" Hit rate: {stats['hit_rate']:.2%}") + print(f" Total time: {cached_time:.3f}s") + + collection.close() + + @pytest.mark.skip(reason="Benchmark test - run manually") + def test_batch_search_performance_benchmark(self): + """Benchmark batch search performance improvement.""" + collection = create_test_collection( + collection_name="benchmark_batch", + num_docs=5000, + cache_config={"enabled": False}, # Disable cache to measure batch effect + ) + + num_queries = 
50 + queries = [[random.random() for _ in range(128)] for _ in range(num_queries)] + + # Benchmark individual searches + start_time = time.time() + for query in queries: + collection.search_by_vector("test_index", query, limit=10) + individual_time = time.time() - start_time + + # Clear cache (though it's disabled) + collection.invalidate_index_cache("test_index") + + # Benchmark batch search + start_time = time.time() + collection.batch_search_by_vector( + index_name="test_index", + dense_vectors=queries, + limit=10, + num_threads=4, + ) + batch_time = time.time() - start_time + + print(f"\nBatch Search Performance:") + print(f" Number of queries: {num_queries}") + print(f" Individual search time: {individual_time:.3f}s") + print(f" Batch search time: {batch_time:.3f}s") + print(f" Speedup: {individual_time / batch_time:.2f}x") + + collection.close() + + +if __name__ == "__main__": + # Run basic tests + print("Running query cache tests...") + test_cache = TestQueryCache() + test_cache.test_cache_disabled() + print(" ✓ test_cache_disabled") + test_cache.test_cache_enabled() + print(" ✓ test_cache_enabled") + test_cache.test_cache_invalidation_on_upsert() + print(" ✓ test_cache_invalidation_on_upsert") + test_cache.test_cache_stats() + print(" ✓ test_cache_stats") + + print("\nRunning batch search tests...") + test_batch = TestBatchSearch() + test_batch.test_batch_search_basic() + print(" ✓ test_batch_search_basic") + test_batch.test_batch_search_with_filters() + print(" ✓ test_batch_search_with_filters") + test_batch.test_batch_search_with_offset() + print(" ✓ test_batch_search_with_offset") + test_batch.test_batch_search_cache_interaction() + print(" ✓ test_batch_search_cache_interaction") + + print("\nAll tests passed! ✓") \ No newline at end of file