diff --git a/src/omop_emb/backends/base.py b/src/omop_emb/backends/base.py index 895c440..d42cd12 100644 --- a/src/omop_emb/backends/base.py +++ b/src/omop_emb/backends/base.py @@ -69,6 +69,7 @@ class EmbeddingBackend(ABC, Generic[T]): filters. """ DEFAULT_BASE_STORAGE_DIR = Path.home() / ".omop_emb" + DEFAULT_K_NEAREST = 10 def __init__( self, storage_base_dir: Optional[str | Path] = None, @@ -383,7 +384,6 @@ def get_nearest_concepts( query_embedding: ndarray, metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, - k: int = 10, _model_record: EmbeddingModelRecord, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: """ @@ -402,9 +402,7 @@ def get_nearest_concepts( metric_type : MetricType Similarity or distance metric for nearest-neighbor search. concept_filter : Optional[EmbeddingConceptFilter], optional - Optional filter restricting which OMOP concepts are considered. - k : int, optional - Number of nearest matches to return per query. + Optional filter restricting which OMOP concepts are considered. The "limit" field of the filter determines how many nearest neighbors are returned per query vector. If not set, defaults to the global DEFAULT_K_NEAREST. _model_record : EmbeddingModelRecord Internal registered-model record injected by ``@require_registered_model``. @@ -413,6 +411,23 @@ def get_nearest_concepts( Tuple[Tuple[NearestConceptMatch, ...], ...] A tuple of tuples containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner tuple contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query. """ + + def validate_nearest_concepts_output( + self, + nearest_concepts: Tuple[Tuple[NearestConceptMatch, ...], ...], + k: int, + query_embeddings: ndarray, + ) -> None: + assert all(len(d) <= k for d in nearest_concepts), ( + f"Expected at most {k} nearest neighbors per query embedding, but found a dictionary with {max(len(d) for d in nearest_concepts)} entries." + ) + + assert len(nearest_concepts) == query_embeddings.shape[0], ( + f"Expected nearest concepts for {query_embeddings.shape[0]} query embeddings, " + f"but got {len(nearest_concepts)}." + ) + + def has_any_embeddings( self, diff --git a/src/omop_emb/backends/faiss/faiss_backend.py b/src/omop_emb/backends/faiss/faiss_backend.py index 51760e4..3bbdb36 100644 --- a/src/omop_emb/backends/faiss/faiss_backend.py +++ b/src/omop_emb/backends/faiss/faiss_backend.py @@ -214,7 +214,6 @@ def get_nearest_concepts( metric_type: MetricType, *, concept_filter: Optional[EmbeddingConceptFilter] = None, - k: int = 10, _model_record: EmbeddingModelRecord, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: @@ -237,8 +236,15 @@ def get_nearest_concepts( if concept_filter is None: # Easier to not do any filter if all are allowed permitted_concept_ids = None + logger.debug(f"No concept filter provided. Setting number of returned nearest concepts (k) to default: {self.DEFAULT_K_NEAREST}") + k = self.DEFAULT_K_NEAREST else: permitted_concept_ids = np.array(list(permitted_concept_ids_storage.keys()), dtype=np.int64) + if concept_filter.limit is not None: + k = concept_filter.limit + else: + logger.debug(f"Concept filter provided without limit. Setting number of returned nearest concepts (k) to default: {self.DEFAULT_K_NEAREST}") + k = self.DEFAULT_K_NEAREST distances, concept_ids = storage_manager.search( query_vector=query_embeddings, @@ -258,16 +264,22 @@ def get_nearest_concepts( if row is None: logger.warning(f"Concept ID {concept_id} returned by FAISS search but not found in permitted concept IDs. This indicates a mismatch between the FAISS index and the database registry. Skipping this result.") continue + + similarity = get_similarity_from_distance(distance.item(), metric_type) + assert isinstance(similarity, float), f"Expected similarity to be a float, got {type(similarity)}" matches_per_query.append(NearestConceptMatch( concept_id=int(concept_id), concept_name=row.concept_name, - similarity=get_similarity_from_distance(distance, metric_type), + similarity=similarity, is_standard=bool(row.is_standard), is_active=bool(row.is_active), )) matches.append(tuple(matches_per_query)) - return tuple(matches) - + + matches_tuple = tuple(matches) + self.validate_nearest_concepts_output(matches_tuple, k, query_embeddings=query_embeddings) + return matches_tuple + @require_registered_model def get_embeddings_by_concept_ids( self, diff --git a/src/omop_emb/backends/pgvector/pgvector_backend.py b/src/omop_emb/backends/pgvector/pgvector_backend.py index b12f8be..64dfc5e 100644 --- a/src/omop_emb/backends/pgvector/pgvector_backend.py +++ b/src/omop_emb/backends/pgvector/pgvector_backend.py @@ -3,6 +3,7 @@ from typing import Mapping, Optional, Sequence, Type, Tuple from numpy import ndarray +import logging from sqlalchemy import Engine, text from sqlalchemy.orm import Session from sqlalchemy.exc import IntegrityError @@ -22,6 +23,8 @@ ) from omop_emb.model_registry import EmbeddingModelRecord +logger = logging.getLogger(__name__) + class PGVectorEmbeddingBackend(EmbeddingBackend[PGVectorConceptIDEmbeddingTable]): """ pgvector-backed embedding backend for postgresql databases. @@ -166,7 +169,6 @@ def get_nearest_concepts( query_embeddings: ndarray, metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, - k: int = 10, _model_record: EmbeddingModelRecord, ) -> Tuple[Tuple[NearestConceptMatch, ...], ...]: """ @@ -197,13 +199,27 @@ def get_nearest_concepts( ) self.validate_embeddings(embeddings=query_embeddings, dimensions=_model_record.dimensions) + # Guarantee that concept_filter has a limit set for K nearest neighbors + if concept_filter is None or concept_filter.limit is None: + logger.debug(f"No concept filter or concept filter limit provided. Setting number of returned nearest concepts (k) to default: {self.DEFAULT_K_NEAREST}") + + if concept_filter is None: + concept_filter = EmbeddingConceptFilter(limit=self.DEFAULT_K_NEAREST) + elif concept_filter.limit is None: + concept_filter = EmbeddingConceptFilter( + concept_ids=concept_filter.concept_ids, + domains=concept_filter.domains, + vocabularies=concept_filter.vocabularies, + require_standard=concept_filter.require_standard, + limit=self.DEFAULT_K_NEAREST, + ) + query_list = query_embeddings.tolist() query = q_embedding_nearest_concepts( embedding_table=embedding_table, query_embeddings=query_list, metric_type=metric_type, concept_filter=concept_filter, - limit=k ) rows = session.execute(query).all() @@ -220,5 +236,10 @@ def get_nearest_concepts( ) ) - return tuple(tuple(matches) for matches in results) + matches_tuple = tuple(tuple(matches) for matches in results) + + k = concept_filter.limit + assert k is not None, "Internal error: concept_filter.limit should have been set to a non-None value by this point." + self.validate_nearest_concepts_output(matches_tuple, k, query_embeddings=query_embeddings) + return matches_tuple diff --git a/src/omop_emb/backends/pgvector/pgvector_sql.py b/src/omop_emb/backends/pgvector/pgvector_sql.py index 048a9a8..0804b3b 100644 --- a/src/omop_emb/backends/pgvector/pgvector_sql.py +++ b/src/omop_emb/backends/pgvector/pgvector_sql.py @@ -124,7 +124,6 @@ def q_embedding_nearest_concepts( query_embeddings: List[List[float]], metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, - limit: int = 10, ) -> Select: """Constructs a SQL query to retrieve the nearest concepts for the given query embeddings, applying the specified metric and filters. The query uses a LATERAL join to compute distances/similarities for each query vector against all candidate concept embeddings, and then applies the necessary OMOP filters before returning the top K results per query vector. @@ -148,7 +147,7 @@ def q_embedding_nearest_concepts( metric_type : MetricType The distance metric to use for nearest neighbor search (e.g., COSINE, L2, etc.). This will determine which pgvector distance function is used in the query. concept_filter : Optional[EmbeddingConceptFilter], optional - An optional filter object containing criteria to filter the concepts (e.g., by concept_id, domain, vocabulary, standard_concept flag). + An optional filter object containing criteria to filter the concepts (e.g., by concept_id, domain, vocabulary, standard_concept flag). Also is used to limit the number of nearest neighbors (K) returned per query vector. limit : int, optional The number of nearest neighbors (K) to return for each query embedding, by default 10. """ @@ -186,7 +185,7 @@ def q_embedding_nearest_concepts( if concept_filter: inner_stmt = concept_filter.apply(inner_stmt) - lateral_subq = inner_stmt.limit(limit).lateral("top_k") + lateral_subq = inner_stmt.lateral("top_k") # Joins the Q vectors to their K nearest neighbors stmt = ( diff --git a/src/omop_emb/interface.py b/src/omop_emb/interface.py index e72a144..13d22d4 100644 --- a/src/omop_emb/interface.py +++ b/src/omop_emb/interface.py @@ -156,9 +156,7 @@ def get_nearest_concepts( *, metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, - k: int = 10, ) -> Tuple[Mapping[int, float], ...]: - """ Return nearest stored concepts for the query embedding. @@ -176,18 +174,12 @@ def get_nearest_concepts( metric_type : MetricType The similarity or distance metric to use for nearest neighbor search. This must be compatible with the index type used by the database. concept_filter : Optional[EmbeddingConceptFilter], optional - A filter to specify which concepts to consider as potential nearest neighbors. - vocabularies : Optional[Tuple[str, ...]], optional - If provided, only consider concepts from these vocabularies as potential nearest neighbors. - require_standard : bool, optional - If True, only consider standard concepts as potential nearest neighbors. By default False. - k : int, optional - K nearest neighbors to return for each query vector. Default is 10. + A filter to specify which concepts to consider as potential nearest neighbors. The `limit` field of this filter determines the number of neighbors returned. Returns ------- Tuple[Mapping[int, float], ...] - A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query. + A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, limit) where q is the number of query vectors and limit is the number of nearest neighbors returned per query as determined by the `concept_filter` argument or backend default. """ if not isinstance(metric_type, MetricType): raise TypeError( @@ -199,10 +191,9 @@ def get_nearest_concepts( index_type=index_type, query_embeddings=query_embedding, concept_filter=concept_filter, - metric_type=metric_type, - k=k + metric_type=metric_type ) - return tuple({match_per_query.concept_id: match_per_query.similarity for match_per_query in match} for match in nearest_concepts) + return tuple({match.concept_id: match.similarity for match in matches_per_query} for matches_per_query in nearest_concepts) def get_nearest_concepts_by_texts( self, @@ -213,13 +204,13 @@ def get_nearest_concepts_by_texts( *, metric_type: MetricType, concept_filter: Optional[EmbeddingConceptFilter] = None, - k: int = 10, batch_size: Optional[int] = None ) -> Tuple[Mapping[int, float], ...]: - """ Return nearest stored concepts for the query embedding. Convenience wrapper that embeds the query texts before performing the nearest neighbor search. + The number of neighbors returned is determined by the `limit` field of the `concept_filter` argument. If `limit` is not set, a backend default may be used. + Parameters ---------- session : Session @@ -233,16 +224,14 @@ def get_nearest_concepts_by_texts( metric_type : MetricType The similarity or distance metric to use for nearest neighbor search. This should be compatible with the index type used by the model. concept_filter : Optional[EmbeddingConceptFilter], optional - A filter to specify which concepts to consider as potential nearest neighbors. - k : int, optional - K nearest neighbors to return for each query vector. Default is 10. + A filter to specify which concepts to consider as potential nearest neighbors. The `limit` field of this filter determines the number of neighbors returned. batch_size : Optional[int], optional If provided, this batch size will be used when embedding the query texts. If not provided, the default batch size of the embedding client will be used. Returns ------- Tuple[Mapping[int, float], ...] - A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, k) where q is the number of query vectors and k is the number of nearest neighbors returned per query. + A tuple of dictionaries containing nearest concept matches for each query vector. The outer tuple corresponds to the query vectors in order, and each inner dictionary contains the nearest matches for that query vector, sorted by similarity. Returned shape is (q, limit) where q is the number of query vectors and limit is the number of nearest neighbors returned per query. """ if isinstance(query_texts, str): query_texts = (query_texts,) @@ -258,7 +247,6 @@ def get_nearest_concepts_by_texts( query_embedding=query_embeddings, metric_type=metric_type, concept_filter=concept_filter, - k=k, ) def get_embeddings_by_concept_ids( diff --git a/src/omop_emb/utils/embedding_utils.py b/src/omop_emb/utils/embedding_utils.py index 6d22557..7c24688 100644 --- a/src/omop_emb/utils/embedding_utils.py +++ b/src/omop_emb/utils/embedding_utils.py @@ -15,12 +15,15 @@ class EmbeddingConceptFilter: This mirrors the current OMOP grounding needs without importing ``omop_graph`` or its search-constraint objects into ``omop_emb``. + + The `limit` field determines the number of nearest neighbors returned by embedding search operations. If not set, a backend default may be used. """ concept_ids: Optional[tuple[int, ...]] = None domains: Optional[tuple[str, ...]] = None vocabularies: Optional[tuple[str, ...]] = None require_standard: bool = False + limit: Optional[int] = None def apply(self, query: Select) -> Select: if self.concept_ids is not None: @@ -35,7 +38,7 @@ def apply(self, query: Select) -> Select: if self.require_standard: query = query.where(Concept.standard_concept.in_(["S", "C"])) - return query + return query.limit(self.limit) @dataclass(frozen=True) diff --git a/tests/shared_backend_tests.py b/tests/shared_backend_tests.py index f260ea9..be734fa 100644 --- a/tests/shared_backend_tests.py +++ b/tests/shared_backend_tests.py @@ -155,7 +155,7 @@ def test_nearest_neighbor_search(self, session, backend, mock_llm_client, index_ index_type=index_type, query_embeddings=query_embeddings, metric_type=MetricType.L2, - k=1, + concept_filter=EmbeddingConceptFilter(limit=1), ) assert len(results) == 1 @@ -182,15 +182,13 @@ def test_nearest_neighbor_with_domain_filter(self, session, backend, mock_llm_cl concept_ids=concept_ids, embeddings=embeddings, ) - results = backend.get_nearest_concepts( session=session, model_name=MODEL_NAME, index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, - concept_filter=EmbeddingConceptFilter(domains=("Condition",)), + concept_filter=EmbeddingConceptFilter(domains=("Condition",), limit=10), metric_type=MetricType.L2, - k=10, ) expected_ids = {CONCEPTS["Hypertension"].concept_id, CONCEPTS["Diabetes"].concept_id} @@ -223,9 +221,8 @@ def test_nearest_neighbor_with_vocabulary_filter(self, session, backend, mock_ll model_name=MODEL_NAME, index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, - concept_filter=EmbeddingConceptFilter(vocabularies=("RxNorm",)), + concept_filter=EmbeddingConceptFilter(vocabularies=("RxNorm",), limit=10), metric_type=MetricType.L2, - k=10, ) assert len(results) == 1 @@ -297,7 +294,7 @@ def test_l2_similarity_exact_values(self, session, backend, mock_llm_client, ind index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, metric_type=MetricType.L2, - k=10, + concept_filter=EmbeddingConceptFilter(limit=10), ) expected_similarities = { @@ -351,7 +348,7 @@ def test_cosine_similarity_exact_values(self, session, backend, mock_llm_client, index_type=index_type, query_embeddings=TEST_CONCEPT_EMB, metric_type=MetricType.COSINE, - k=10, + concept_filter=EmbeddingConceptFilter(limit=10), ) # For 1D unit vectors, cos_sim = (norm_a) * (norm_b) diff --git a/tests/test_interface.py b/tests/test_interface.py index 978a651..44d7d3b 100644 --- a/tests/test_interface.py +++ b/tests/test_interface.py @@ -3,9 +3,9 @@ import pytest import numpy as np from unittest.mock import Mock -from sqlalchemy.orm import Session from omop_emb.interface import EmbeddingInterface +from omop_emb.utils.embedding_utils import EmbeddingConceptFilter from omop_emb.config import IndexType, MetricType from omop_emb.backends.base import NearestConceptMatch from .conftest import CONCEPTS, MODEL_NAME, EMBEDDING_DIM @@ -162,7 +162,7 @@ def test_search_return_structure_for_backends(self, session, mock_llm_client, ba index_type=index_type, query_embedding=query_embedding, metric_type=MetricType.COSINE, - k=2, + concept_filter=EmbeddingConceptFilter(limit=2), ) assert isinstance(result, tuple)